In [1]:
# Mount Google Drive into the Colab runtime. Later cells read the helper
# scripts, the .h5ad data files and the model checkpoint from paths under
# drive/MyDrive/..., so this has to run first.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
!pip install scanpy --quiet

In [3]:
import random
import torch
import sys
import os
import gc
import collections 
import anndata as ad
import torch.nn as nn 
from argparse import Namespace

# Notebook-wide settings, exposed as attributes of `config` (for example
# `config.DEVICE` is read by inference() further down).
# A LEARNING_RATE entry (0.00002) is deliberately left disabled.
config = Namespace(**{
    'DEVICE': 'cpu',
    'BATCH_SIZE': 300,
    'NUM_WORKERS': 4,
    'N_GENES': 13431,
    'N_PEAKS': 116465,
    'MAX_SEQ_LEN_GEX': 1500,
    'MAX_SEQ_LEN_ATAC': 15000,
})

In [4]:
# Execute the project's helper scripts directly in the notebook namespace, so
# every name they define becomes a notebook global (presumably this is where
# get_chr_index, get_dataloaders, Encoder and bidirectTripletLoss come from --
# verify against the scripts).
# `execfile` is not a Python 3 builtin; it only resolves when the hosting
# environment happens to inject it. Reading the file and exec'ing the compiled
# source against globals() is the portable equivalent.
for _script in (
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/data.py",
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/models.py",
):
    with open(_script) as _fh:
        # Passing the path to compile() keeps tracebacks pointing at the script.
        exec(compile(_fh.read(), _script, "exec"), globals())
del _script, _fh

## Load data

In [5]:
# Build `index` from the processed ATAC AnnData; it is passed to Encoder(index=...)
# when the model is constructed below. The AnnData is read inline so that no
# notebook-level variable keeps it alive after this statement.
index = get_chr_index(ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"))

In [6]:
# Force a garbage-collection pass (presumably to reclaim the temporary AnnData
# read in the previous cell); the displayed value is gc.collect()'s return.
gc.collect()

277

In [7]:
# Read the processed GEX file ONCE and take everything needed from it; the
# previous version re-read the same .h5ad from Drive three times.
gex_adata = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")

# Positional cell indices per split, chosen by the `batch` annotation:
# 's1d1' is held out for validation, 's2d4' for testing, the rest is training.
batch = list(gex_adata.obs['batch'])
train_id = [pos for pos, label in enumerate(batch) if label not in ['s2d4', 's1d1']]
val_id = [pos for pos, label in enumerate(batch) if label == 's1d1']
test_id = [pos for pos, label in enumerate(batch) if label == 's2d4']

# Per-cell labels (a pandas Series taken from .obs) and the 'log_norm' layer.
cell_type_all = gex_adata.obs['cell_type']
csr_gex = gex_adata.layers['log_norm']

# Drop the AnnData wrapper so only the pieces extracted above stay referenced.
del gex_adata

# Only the 'log_norm' layer of the ATAC file is kept.
csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [8]:
# Force a garbage-collection pass (presumably to reclaim the AnnData objects
# that were read temporarily in the previous cell).
gc.collect()

240

In [9]:
random.seed(0)

# Evaluate on the full validation split. A seeded random subset was used
# previously and can be restored with:
#   idx_val = [val_id[i] for i in random.sample(range(0, 6000), 1024)]
idx_val = val_id
gex_val = csr_gex[idx_val, :]
atac_val = csr_atac[idx_val, :]

# `cell_type_all` is a pandas Series taken from AnnData.obs, so its index holds
# observation names rather than integer positions. Select by position
# explicitly with .iloc instead of relying on Series[int], whose positional
# fallback is deprecated in recent pandas.
cell_type_val = cell_type_all.iloc[idx_val].tolist()

data_val = get_dataloaders(gex_val, atac_val, cell_type_val)

In [10]:
def cellTypematchingProb2(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type, alpha=0.2):
    """Compute cell-type-level and cell-level matching probabilities for a batch.

    The four embedding matrices are L2-normalised row-wise and combined into a
    GEX-by-ATAC similarity matrix::

        score = gex_out_0 @ atac_out_0.T + alpha * (gex_out_1 @ atac_out_1.T)

    Args:
        gex_out_0, gex_out_1: 2-D tensors with one row per cell (GEX side).
        atac_out_0, atac_out_1: 2-D tensors with one row per cell (ATAC side).
            Row ``i`` on both sides is treated as the same cell (the diagonal
            of the score matrix is the "correct match").
        cell_type: sequence of hashable labels, one per cell, in row order.
            Its length must equal the number of rows of the embeddings.
        alpha: weight of the second embedding pair's similarity. Defaults to
            0.2, the value that was previously hard-coded.

    Returns:
        A 2-tuple of 0-dim tensors ``(ct_match_prob, cell_match_prob)``:

        * ``ct_match_prob`` -- scores are summed over every
          (GEX cell type, ATAC cell type) block, the resulting type-by-type
          matrix is soft-maxed along both axes and averaged, and the mean of
          its diagonal is returned.
        * ``cell_match_prob`` -- mean over cells of the average of the
          column-wise and row-wise softmax probability on the diagonal of
          the cell-by-cell score matrix.
    """
    gex_out_0 = nn.functional.normalize(gex_out_0, dim=1)
    gex_out_1 = nn.functional.normalize(gex_out_1, dim=1)
    atac_out_0 = nn.functional.normalize(atac_out_0, dim=1)
    atac_out_1 = nn.functional.normalize(atac_out_1, dim=1)

    score_mat = torch.mm(gex_out_0, atac_out_0.transpose(0, 1)) \
        + alpha * torch.mm(gex_out_1, atac_out_1.transpose(0, 1))

    # Cell-level matching probability: softmax over each column and each row,
    # then read off the diagonal (cell i matched with itself).
    score_norm_gex = score_mat.softmax(dim=0)
    score_norm_atac = score_mat.softmax(dim=1)
    match_probs = 0.5 * (torch.diagonal(score_norm_gex) + torch.diagonal(score_norm_atac))

    # Cell-type-level matching probability.
    # Give every distinct label an id in order of first appearance.
    type_ids = {}
    type_of_cell = [type_ids.setdefault(label, len(type_ids)) for label in cell_type]

    # membership[c, t] == 1 iff cell c has type t. It is created with the same
    # dtype/device as score_mat, so no numpy round trip and no device mismatch.
    membership = score_mat.new_zeros((len(type_of_cell), len(type_ids)))
    cell_pos = torch.arange(len(type_of_cell), device=score_mat.device)
    type_pos = torch.tensor(type_of_cell, dtype=torch.long, device=score_mat.device)
    membership[cell_pos, type_pos] = 1

    # (types x cells) @ (cells x cells) @ (cells x types): entry [i, j] is the
    # sum of score_mat over GEX cells of type i and ATAC cells of type j.
    sum_score_mat = membership.transpose(0, 1) @ score_mat @ membership
    score_mat_norm = 0.5 * (sum_score_mat.softmax(dim=0) + sum_score_mat.softmax(dim=1))

    # First output: cell-type matching prob; second output: cell matching prob.
    return torch.mean(torch.diagonal(score_mat_norm)), torch.mean(match_probs)

def inference(model, data_val):
    """Run `model` over every batch of `data_val` and return per-batch averages.

    Relies on three notebook-level names that must exist before the call:
    `config` (for `config.DEVICE`), `criterion` (the loss module) and
    `cellTypematchingProb2`.

    Args:
        model: module called as ``model(gex, atac)`` and returning the four
            embeddings ``(gex_out_0, gex_out_1, atac_out_0, atac_out_1)``.
        data_val: sized iterable of batches; each batch is a mapping with the
            keys ``'gex'``, ``'atac'`` and ``'cell_type'``.

    Returns:
        A 5-tuple of Python floats, each the sum over batches divided by
        ``len(data_val)``:

        1. ``loss_cross`` reported by `criterion`,
        2. ``loss_triplet`` reported by `criterion`,
        3. cell-type matching probability reported by `criterion`,
        4. cell-type matching probability from `cellTypematchingProb2`,
        5. cell matching probability reported by `criterion`.
    """
    # Evaluation mode, with the model and the loss module on the target device.
    model.eval()
    model.to(config.DEVICE)
    criterion.to(config.DEVICE)

    running_loss = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = running_ct_prob2 = 0.0
    running_cell_prob = 0.0

    # Nothing is back-propagated here, so skip building the autograd graph.
    with torch.no_grad():
        for data in data_val:
            gex_input = data['gex'].to(config.DEVICE)
            atac_input = data['atac'].to(config.DEVICE)
            cell_type_input = data['cell_type']

            ### Forward
            gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

            ### Compute loss and matching probabilities
            _loss, loss_triplet, loss_cross, ct_match_prob, cell_match_prob = criterion(
                gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)
            # cellTypematchingProb2 returns (cell-type prob, cell prob); only the
            # cell-type probability is tracked here. Calling .item() on the
            # tuple itself was the previous failure.
            ct_match_prob2, _ = cellTypematchingProb2(
                gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)

            # The reported "loss" is the cross term, not the combined loss.
            running_loss += loss_cross.item()
            running_loss_triplet += loss_triplet.item()
            running_ct_prob += ct_match_prob.item()
            running_ct_prob2 += ct_match_prob2.item()
            running_cell_prob += cell_match_prob.item()

            # Release the batch before the next one is loaded.
            del gex_input
            del atac_input
            del cell_type_input
            torch.cuda.empty_cache()

    n_batches = len(data_val)
    return (running_loss / n_batches,
            running_loss_triplet / n_batches,
            running_ct_prob / n_batches,
            running_ct_prob2 / n_batches,
            running_cell_prob / n_batches)

## Parameter set 1

In [11]:
# Hyper-parameters of "parameter set 1". They are used below to pick the
# checkpoint file name (all three) and to configure the loss (ALPHA, MARGIN).
vars(config).update(ALPHA=0.2, MARGIN=0.5, N_CHANNELS=32)

In [12]:
# The checkpoint file name encodes the hyper-parameters currently set on `config`.
file = ''.join([
    'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells_alpha',
    str(config.ALPHA),
    '_margin', str(config.MARGIN),
    '_nchannels', str(config.N_CHANNELS),
    '_30epochs',
])

# Author's note: the ATAC kernel sizes were changed to smaller values.
model = Encoder(kernel_size_gex=100, kernel_size_atac_1=30, kernel_size_atac_2=5, index=index)
model = model.to(config.DEVICE)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# Kept as the last expression so the key-matching report is shown as cell output.
model.load_state_dict(torch.load(file, map_location=torch.device('cpu')))



<All keys matched successfully>

In [None]:
# Loss module configured with the current parameter set. inference() reads it
# through the notebook-level name `criterion`.
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)

# inference() returns exactly five averaged scalars, in this order:
#   loss_cross, loss_triplet, cell-type match prob (criterion),
#   cell-type match prob (cellTypematchingProb2), cell match prob (criterion).
# Unpacking six values and indexing ct_match_prob2 did not match that return.
loss, loss_triplet, ct_match_prob, ct_match_prob2, cell_match_prob = inference(model, data_val)

# NOTE(review): the "XF" label is attached to the cellTypematchingProb2 value
# because that function is the one tagged "method1: xf" -- confirm.
print('loss = {0}, triplet loss = {1}, cell type match prob = {2},  cell match prob = {3}, XF cell type match prob n = {4}'.format(loss, loss_triplet, ct_match_prob, cell_match_prob, ct_match_prob2))