In [1]:
# Mount Google Drive inside the Colab VM so the project's scripts, data and
# model checkpoints become reachable through the filesystem.
# NOTE(review): later cells use relative paths beginning with "drive/MyDrive/",
# which assumes the working directory is /content — TODO confirm.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
!pip install scanpy --quiet

[K     |████████████████████████████████| 2.0 MB 15.3 MB/s 
[K     |████████████████████████████████| 9.4 MB 84.9 MB/s 
[K     |████████████████████████████████| 96 kB 6.8 MB/s 
[K     |████████████████████████████████| 88 kB 10.0 MB/s 
[K     |████████████████████████████████| 295 kB 76.3 MB/s 
[K     |████████████████████████████████| 965 kB 81.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 85.1 MB/s 
[K     |████████████████████████████████| 63 kB 2.6 MB/s 
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for session-info (setup.py) ... [?25l[?25hdone


In [3]:
import random
import torch
import sys
import os
import gc
import collections 
import anndata as ad
from argparse import Namespace

# Global run configuration. Later cells read config.DEVICE and
# config.LEARNING_RATE directly and attach ALPHA / MARGIN / N_CHANNELS per
# hyperparameter set; the remaining fields are presumably consumed by the
# resource scripts loaded below — TODO confirm.
config = Namespace()
config.LEARNING_RATE = 2e-05
config.DEVICE = 'cuda'
config.BATCH_SIZE = 100
config.NUM_WORKERS = 4
config.N_GENES = 13431
config.N_PEAKS = 116465
config.MAX_SEQ_LEN_GEX = 1500
config.MAX_SEQ_LEN_ATAC = 15000

In [4]:
# Execute the project's helper scripts directly in the notebook namespace so
# that the names they define (used by later cells, e.g. get_chr_index,
# get_dataloaders, Encoder, bidirectTripletLoss) become globals here.
# `execfile` is not a Python 3 builtin and only works where the environment
# happens to inject it, so the portable exec(compile(...)) form is used.
for _script_path in (
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/data.py",
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/models.py",
):
  with open(_script_path, encoding="utf-8") as _script_file:
    # Compiling with the real path keeps tracebacks pointing at the script.
    exec(compile(_script_file.read(), _script_path, "exec"))
del _script_path, _script_file

## Import data

In [5]:
# Build `index` from the processed ATAC AnnData via get_chr_index, a helper
# defined outside this notebook (presumably a per-chromosome peak index —
# TODO confirm). The AnnData object is released right after use.
_atac_adata = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad")
index = get_chr_index(_atac_adata)
del _atac_adata

In [6]:
# Force a garbage-collection pass after the temporary AnnData object above has
# gone out of scope; the number displayed is gc.collect()'s return value.
gc.collect()

242

In [7]:
# Read the processed GEX file ONCE (it was previously read from Drive three
# separate times) and pull out everything needed from it.
_gex_adata = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")
batch = list(_gex_adata.obs['batch'])
cell_type_all = _gex_adata.obs['cell_type']
csr_gex = _gex_adata.layers['log_norm']
del _gex_adata  # only the extracted pieces are needed from here on

# Split cells by batch label: 's1d1' is held out for validation, 's2d4' for
# testing, every other batch is used for training. Ids are row positions.
train_id = [pos for pos, label in enumerate(batch) if label not in ('s2d4', 's1d1')]
val_id = [pos for pos, label in enumerate(batch) if label == 's1d1']
test_id = [pos for pos, label in enumerate(batch) if label == 's2d4']

# NOTE(review): rows of the ATAC matrix are assumed to be in the same cell
# order as the GEX matrix, since the same ids index both — TODO confirm.
csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [8]:
# Reclaim memory held by the temporary AnnData objects created while loading;
# the number displayed is gc.collect()'s return value.
gc.collect()

240

In [9]:
# Seed every RNG in play. `random` is only used by the optional sub-sampling
# line below, while torch's RNG is presumably what drives weight
# initialisation and any dataloader shuffling — TODO confirm.
random.seed(0)
torch.manual_seed(0)

# idx_train = [train_id[i] for i in random.sample(range(0, 45000), 1024)]
idx_train = train_id # Full dataset
gex_train = csr_gex[idx_train,:]
atac_train = csr_atac[idx_train,:]
# idx_train holds row positions, so select positionally with .iloc. Plain
# `series[int]` on a label-indexed Series relies on a deprecated positional
# fallback in pandas.
cell_type_train = cell_type_all.iloc[idx_train].tolist()

data_train = get_dataloaders(gex_train, atac_train, cell_type_train)

In [10]:
def train(model, criterion, optimizer, data_train, epochs, loss_type, scheduler = None):
  """Train `model` on `data_train` for `epochs` epochs, printing progress.

  Args:
    model: module called as `model(gex, atac)`; must return four outputs
      (gex_out_0, gex_out_1, atac_out_0, atac_out_1).
    criterion: callable taking those four outputs plus the batch's cell types
      and returning `(loss, loss_triplet, loss_cross, ct_match_prob,
      cell_match_prob)`, each supporting `.item()`.
    optimizer: optimizer whose `zero_grad()` / `step()` are called per batch.
    data_train: sized iterable of batches; each batch is a mapping with keys
      'gex', 'atac' and 'cell_type'. 'gex' and 'atac' are moved to
      `config.DEVICE`.
    epochs: number of passes over `data_train`.
    loss_type: which term is back-propagated: "both" (combined loss),
      "entropy" (loss_cross) or "triplet" (loss_triplet). All terms are
      logged regardless of the choice.
    scheduler: optional learning-rate scheduler; when given, `step()` is
      called once at the end of every epoch. Defaults to None (constant
      learning rate), which matches the previous behaviour.

  Raises:
    ValueError: if `loss_type` is not one of the supported values, or if
      `data_train` contains no batches.
  """
  # Fail fast on a bad loss_type instead of silently skipping all batches and
  # reporting zeroed statistics.
  valid_loss_types = ("both", "entropy", "triplet")
  if loss_type not in valid_loss_types:
    raise ValueError(
        'Unknown loss_type {0!r}; expected one of {1}'.format(loss_type, valid_loss_types))

  n_batches = len(data_train)
  if n_batches == 0:
    raise ValueError('data_train is empty; nothing to train on')

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    running_loss_cross = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0
    running_cell_prob = 0.0
    for batch_idx, data in enumerate(data_train):
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']

      optimizer.zero_grad()

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob, cell_match_prob = criterion(
          gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)

      ### Propagate the selected loss term
      objective = {"both": loss, "entropy": loss_cross, "triplet": loss_triplet}[loss_type]
      objective.backward()

      ### update parameters
      optimizer.step()

      # Running sums for the per-epoch averages (every term is tracked, not
      # only the one being optimised).
      running_loss += objective.item()
      running_ct_prob += ct_match_prob.item()
      running_cell_prob += cell_match_prob.item()
      running_loss_triplet += loss_triplet.item()
      running_loss_cross += loss_cross.item()

      # Drop references to the batch tensors before releasing cached GPU memory.
      del gex_input, atac_input, cell_type_input
      if torch.cuda.is_available():
        torch.cuda.empty_cache()

      if (batch_idx + 1) % 50 == 0:
        print('Within-epoch iter', batch_idx + 1, ': cross_loss =', loss_cross.item(), '; triplet_loss =', loss_triplet.item(), '; ct_match =', ct_match_prob.item(), '; cell_match =', cell_match_prob.item())

    # The learning rate is reported BEFORE the scheduler steps, i.e. the value
    # that was actually used during this epoch.
    print('Epoch-{0}: lr = {1}, loss = {2}, entropy_loss = {3}, triplet loss = {4}, cell type match prob = {5}, cell_match = {6}'.format(
        e+1,
        optimizer.param_groups[0]['lr'],
        running_loss / n_batches,
        running_loss_cross / n_batches,
        running_loss_triplet / n_batches,
        running_ct_prob / n_batches,
        running_cell_prob / n_batches
        )
    )

    if scheduler is not None:
      scheduler.step()

## Select hyperparameters

### Hyperparam set 1

In [11]:
# Hyperparameter set 1 (read by the loss construction and checkpoint-name cells).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 0.5, 32

In [12]:
# Fresh loss, encoder and optimiser for the current hyperparameter set.
criterion = bidirectTripletLoss(
    alpha=config.ALPHA,
    margin=config.MARGIN,
).to(config.DEVICE)

# Changed to a smaller kernel size for ATAC.
# NOTE(review): config.N_CHANNELS is not passed here; presumably the Encoder
# reads it from the global config — TODO confirm.
model = Encoder(
    kernel_size_gex=100,
    kernel_size_atac_1=30,
    kernel_size_atac_2=5,
    index=index,
).to(config.DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# NOTE(review): this scheduler is never handed to train() in the call below,
# so the learning rate stays at config.LEARNING_RATE for the whole run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10],
    gamma=0.1,
)



In [None]:
# Set 1: back-propagate the triplet term only, for 15 epochs.
# Original author's note: "shuffle = True" (presumably the dataloader setting
# used for this run — TODO confirm).
train(
    model,
    criterion,
    optimizer,
    data_train,
    epochs=15,
    loss_type="triplet",
)

Within-epoch iter 50 : cross_loss = 3.9264414310455322 ; triplet_loss = 0.5119875073432922 ; ct_match = 0.10754494369029999 ; cell_match = 0.020508617162704468
Within-epoch iter 100 : cross_loss = 3.8863515853881836 ; triplet_loss = 0.5277578830718994 ; ct_match = 0.126285120844841 ; cell_match = 0.021259082481265068
Within-epoch iter 150 : cross_loss = 3.8958559036254883 ; triplet_loss = 0.4983513653278351 ; ct_match = 0.10487990081310272 ; cell_match = 0.02086810953915119
Within-epoch iter 200 : cross_loss = 3.974040985107422 ; triplet_loss = 0.5269578099250793 ; ct_match = 0.14636145532131195 ; cell_match = 0.0194865670055151
Within-epoch iter 250 : cross_loss = 3.9022748470306396 ; triplet_loss = 0.5279420018196106 ; ct_match = 0.12857791781425476 ; cell_match = 0.02071729116141796
Within-epoch iter 300 : cross_loss = 3.8046927452087402 ; triplet_loss = 0.5396959185600281 ; ct_match = 0.11438201367855072 ; cell_match = 0.022598857060074806
Within-epoch iter 350 : cross_loss = 3.866

In [None]:
# Checkpoint the trained weights; the file name encodes the current
# hyperparameters. The "_15epochs" suffix is hard-coded and must be kept in
# sync with the `epochs` argument of the train() call by hand.
file = (
    'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/'
    'trained_model_semihard_allcells_alpha{0}_margin{1}_nchannels{2}_15epochs'
).format(config.ALPHA, config.MARGIN, config.N_CHANNELS)
torch.save(model.state_dict(), file)

### Hyperparam set 2

In [15]:
# Hyperparameter set 2 (read by the loss construction and checkpoint-name cells).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 0.5, 64

In [16]:
# Fresh loss, encoder and optimiser for the current hyperparameter set.
criterion = bidirectTripletLoss(
    alpha=config.ALPHA,
    margin=config.MARGIN,
).to(config.DEVICE)

# Changed to a smaller kernel size for ATAC.
# NOTE(review): config.N_CHANNELS is not passed here; presumably the Encoder
# reads it from the global config — TODO confirm.
model = Encoder(
    kernel_size_gex=100,
    kernel_size_atac_1=30,
    kernel_size_atac_2=5,
    index=index,
).to(config.DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# NOTE(review): this scheduler is never handed to train() in the call below,
# so the learning rate stays at config.LEARNING_RATE for the whole run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10],
    gamma=0.1,
)

In [17]:
# Set 2: back-propagate the triplet term only, for 15 epochs.
train(
    model,
    criterion,
    optimizer,
    data_train,
    epochs=15,
    loss_type="triplet",
)

Within-epoch iter 50 : cross_loss = 4.61014461517334 ; triplet_loss = 0.7371470332145691 ; ct_match = 0.050371572375297546 ; cell_match = 0.010103763081133366
Within-epoch iter 100 : cross_loss = 4.658210754394531 ; triplet_loss = 0.6405066251754761 ; ct_match = 0.04822942987084389 ; cell_match = 0.010131610557436943
Within-epoch iter 150 : cross_loss = 4.684309005737305 ; triplet_loss = 0.6615475416183472 ; ct_match = 0.05279282480478287 ; cell_match = 0.010117468424141407
Within-epoch iter 200 : cross_loss = 4.700479030609131 ; triplet_loss = 0.6821862459182739 ; ct_match = 0.05956846848130226 ; cell_match = 0.010200748220086098
Within-epoch iter 250 : cross_loss = 4.715002536773682 ; triplet_loss = 0.6737810969352722 ; ct_match = 0.0561494454741478 ; cell_match = 0.009980382397770882
Within-epoch iter 300 : cross_loss = 4.675390243530273 ; triplet_loss = 0.6730854511260986 ; ct_match = 0.05434828996658325 ; cell_match = 0.010522532276809216
Within-epoch iter 350 : cross_loss = 4.665

In [18]:
# Checkpoint the trained weights; the file name encodes the current
# hyperparameters. The "_15epochs" suffix is hard-coded and must be kept in
# sync with the `epochs` argument of the train() call by hand.
file = (
    'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/'
    'trained_model_semihard_allcells_alpha{0}_margin{1}_nchannels{2}_15epochs'
).format(config.ALPHA, config.MARGIN, config.N_CHANNELS)
torch.save(model.state_dict(), file)

### Hyperparam set 3

In [19]:
# Hyperparameter set 3 (read by the loss construction and checkpoint-name cells).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 1, 32

In [20]:
# Fresh loss, encoder and optimiser for the current hyperparameter set.
criterion = bidirectTripletLoss(
    alpha=config.ALPHA,
    margin=config.MARGIN,
).to(config.DEVICE)

# Changed to a smaller kernel size for ATAC.
# NOTE(review): config.N_CHANNELS is not passed here; presumably the Encoder
# reads it from the global config — TODO confirm.
model = Encoder(
    kernel_size_gex=100,
    kernel_size_atac_1=30,
    kernel_size_atac_2=5,
    index=index,
).to(config.DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# NOTE(review): this scheduler is never handed to train() in the call below,
# so the learning rate stays at config.LEARNING_RATE for the whole run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10],
    gamma=0.1,
)

In [21]:
# Set 3: back-propagate the triplet term only, for 15 epochs.
train(
    model,
    criterion,
    optimizer,
    data_train,
    epochs=15,
    loss_type="triplet",
)

Within-epoch iter 50 : cross_loss = 4.600675582885742 ; triplet_loss = 1.8901630640029907 ; ct_match = 0.06265173107385635 ; cell_match = 0.010071961209177971
Within-epoch iter 100 : cross_loss = 4.603088855743408 ; triplet_loss = 1.7736588716506958 ; ct_match = 0.04804017394781113 ; cell_match = 0.010113043710589409
Within-epoch iter 150 : cross_loss = 4.647706031799316 ; triplet_loss = 1.5440195798873901 ; ct_match = 0.05908103287220001 ; cell_match = 0.010039076209068298
Within-epoch iter 200 : cross_loss = 4.697261810302734 ; triplet_loss = 1.4478212594985962 ; ct_match = 0.05857866629958153 ; cell_match = 0.010084299370646477
Within-epoch iter 250 : cross_loss = 4.718192100524902 ; triplet_loss = 1.375315546989441 ; ct_match = 0.05859018489718437 ; cell_match = 0.010075383819639683
Within-epoch iter 300 : cross_loss = 4.705885887145996 ; triplet_loss = 1.326669692993164 ; ct_match = 0.05704361945390701 ; cell_match = 0.010344154201447964
Within-epoch iter 350 : cross_loss = 4.7238

In [22]:
# Checkpoint the trained weights; the file name encodes the current
# hyperparameters. The "_15epochs" suffix is hard-coded and must be kept in
# sync with the `epochs` argument of the train() call by hand.
file = (
    'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/'
    'trained_model_semihard_allcells_alpha{0}_margin{1}_nchannels{2}_15epochs'
).format(config.ALPHA, config.MARGIN, config.N_CHANNELS)
torch.save(model.state_dict(), file)

### Hyperparam set 4

In [None]:
# Hyperparameter set 4 (read by the loss construction and checkpoint-name cells).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 1, 64

In [None]:
# Fresh loss, encoder and optimiser for the current hyperparameter set.
criterion = bidirectTripletLoss(
    alpha=config.ALPHA,
    margin=config.MARGIN,
).to(config.DEVICE)

# Changed to a smaller kernel size for ATAC.
# NOTE(review): config.N_CHANNELS is not passed here; presumably the Encoder
# reads it from the global config — TODO confirm.
model = Encoder(
    kernel_size_gex=100,
    kernel_size_atac_1=30,
    kernel_size_atac_2=5,
    index=index,
).to(config.DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# NOTE(review): this scheduler is never handed to train() in the call below,
# so the learning rate stays at config.LEARNING_RATE for the whole run.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10],
    gamma=0.1,
)

In [None]:
# Set 4: back-propagate the triplet term only, for 15 epochs.
train(
    model,
    criterion,
    optimizer,
    data_train,
    epochs=15,
    loss_type="triplet",
)

In [None]:
# Checkpoint the trained weights; the file name encodes the current
# hyperparameters. The "_15epochs" suffix is hard-coded and must be kept in
# sync with the `epochs` argument of the train() call by hand.
file = (
    'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/'
    'trained_model_semihard_allcells_alpha{0}_margin{1}_nchannels{2}_15epochs'
).format(config.ALPHA, config.MARGIN, config.N_CHANNELS)
torch.save(model.state_dict(), file)

### Hyperparam set 5

In [None]:
# Hyperparameter set 5.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 0.5, 32

### Hyperparam set 6

In [None]:
# Hyperparameter set 6.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 0.5, 64

### Hyperparam set 7

In [None]:
# Hyperparameter set 7.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 1, 32

### Hyperparam set 8

In [None]:
# Hyperparameter set 8.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 1, 64