In [None]:
# Mount Google Drive at /content/drive/; later cells read project files through
# relative "drive/MyDrive/..." paths.
from google.colab import drive
drive.mount('/content/drive/')

In [2]:
!pip install scanpy --quiet

[K     |████████████████████████████████| 2.0 MB 39.7 MB/s 
[K     |████████████████████████████████| 96 kB 7.4 MB/s 
[K     |████████████████████████████████| 88 kB 10.7 MB/s 
[K     |████████████████████████████████| 9.4 MB 58.8 MB/s 
[K     |████████████████████████████████| 295 kB 81.0 MB/s 
[K     |████████████████████████████████| 965 kB 69.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 64.4 MB/s 
[K     |████████████████████████████████| 63 kB 1.7 MB/s 
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for session-info (setup.py) ... [?25l[?25hdone


In [3]:
import sys
import os
import gc
import collections 
import anndata as ad
from argparse import Namespace

# Notebook-wide settings, collected on a single Namespace object.
config = Namespace()
config.DEVICE = 'cuda'
config.BATCH_SIZE = 128
config.NUM_WORKERS = 4
config.N_GENES = 13431
config.N_PEAKS = 116465
config.MAX_SEQ_LEN_GEX = 1500
config.MAX_SEQ_LEN_ATAC = 15000

In [5]:
# Debugging aid: show the module search path of the running kernel.
print(sys.path)

['/content', 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources', '/env/python', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '', '/usr/local/lib/python3.8/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.8/dist-packages/IPython/extensions', '/root/.ipython']


In [10]:
# `execfile` is not a Python 3 builtin; it only resolves when the hosting
# environment happens to inject it. Run the helper scripts with the portable
# exec/compile equivalent instead, in the notebook's global namespace, so the
# names they define are available to the following cells (presumably
# get_chr_index, get_dataloaders, Encoder, bidirectTripletLoss -- confirm).
for _script_path in (
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/data.py",
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/models.py",
):
    # Read as bytes so compile() honours any source-encoding declaration.
    with open(_script_path, "rb") as _script_file:
        exec(compile(_script_file.read(), _script_path, "exec"), globals())

## Import data

In [11]:
# Derive `index` (later handed to Encoder) from the processed ATAC file. The
# loaded AnnData object is passed straight through, so no reference to it is kept.
atac_h5ad_path = "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"
index = get_chr_index(ad.read_h5ad(atac_h5ad_path))

In [12]:
# Run a full garbage-collection pass after the ATAC file was loaded above;
# the cell output is gc.collect()'s return value.
gc.collect()

320

In [None]:
# Read the GEX file once (it was previously re-read for every field) and pull
# out everything that is needed from it.
gex_adata = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")

# Split row positions by the value of obs['batch']: 's1d1' -> validation,
# 's2d4' -> test, every other batch -> training.
batch = list(gex_adata.obs['batch'])
train_id = [a for a, l in enumerate(batch) if l not in ['s2d4','s1d1']]
val_id =  [a for a, l in enumerate(batch) if l == 's1d1']
test_id = [a for a, l in enumerate(batch) if l == 's2d4']

cell_type_all = gex_adata.obs['cell_type']

csr_gex = gex_adata.layers['log_norm']
# Drop the AnnData handle; only the fields extracted above are used afterwards.
del gex_adata

csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [None]:
# Run a full garbage-collection pass after the data-loading cell above.
gc.collect()

In [None]:
import random
random.seed(0)

# Positions are drawn from the first *_POOL entries of train_id / val_id.
# Fail with a clear message if a split is smaller than its pool, rather than
# with an IndexError from inside the comprehension.
TRAIN_POOL, N_TRAIN = 45000, 10240
VAL_POOL, N_VAL = 6000, 1024
assert len(train_id) >= TRAIN_POOL, "train split has fewer than %d cells" % TRAIN_POOL
assert len(val_id) >= VAL_POOL, "validation split has fewer than %d cells" % VAL_POOL

idx_train = [train_id[i] for i in random.sample(range(0, TRAIN_POOL), N_TRAIN)]
gex_train = csr_gex[idx_train,:]
atac_train = csr_atac[idx_train,:]
# idx_* hold row positions, so look labels up positionally with .iloc; plain
# Series[int] on a non-integer index relies on a deprecated pandas fallback.
cell_type_train = cell_type_all.iloc[idx_train].tolist()

idx_val = [val_id[i] for i in random.sample(range(0, VAL_POOL), N_VAL)]
gex_val = csr_gex[idx_val,:]
atac_val = csr_atac[idx_val,:]
cell_type_val = cell_type_all.iloc[idx_val].tolist()

data_train, data_val = get_dataloaders(gex_train, atac_train, cell_type_train, gex_val, atac_val, cell_type_val)

In [None]:
def train(model, criterion, optimizer, data_train, epochs, loss_type):
  """Train `model` on `data_train` for `epochs` epochs.

  Args:
    model: callable as model(gex, atac) returning four outputs, which are
      forwarded unchanged to `criterion`.
    criterion: callable as
      criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type)
      returning (loss, loss_triplet, loss_cross, ct_match_prob); each value
      must support .item(), and the one selected by `loss_type` .backward().
    optimizer: optimizer whose zero_grad()/step() are called once per batch.
    data_train: sized iterable of batches; each batch is indexed with the
      keys 'gex', 'atac' (moved to config.DEVICE) and 'cell_type'.
    epochs: number of passes over `data_train`.
    loss_type: which loss is back-propagated -- "both" (loss),
      "entropy" (loss_cross) or "triplet" (loss_triplet).

  Raises:
    ValueError: if `loss_type` is not one of the three supported values, or
      if `data_train` contains no batches.
  """
  valid_loss_types = ("both", "entropy", "triplet")
  if loss_type not in valid_loss_types:
    # Previously an unknown value silently skipped every batch.
    raise ValueError(
        "loss_type must be one of {0}, got {1!r}".format(valid_loss_types, loss_type))

  n_batches = len(data_train)
  if n_batches == 0:
    raise ValueError("data_train contains no batches")

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    running_loss_cross = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0
    for data in data_train:
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']

      model.zero_grad()
      optimizer.zero_grad()

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob = criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)

      ### Propagate the selected loss; the other terms are only tracked for reporting
      objective = {"both": loss, "entropy": loss_cross, "triplet": loss_triplet}[loss_type]
      running_loss += objective.item()
      objective.backward()

      running_ct_prob += ct_match_prob.item()
      running_loss_triplet += loss_triplet.item()
      running_loss_cross += loss_cross.item()

      ### update parameters
      optimizer.step()

      del gex_input
      del atac_input
      del cell_type_input
      torch.cuda.empty_cache()

    # Values of the last batch of this epoch.
    print('cross_loss = ', loss_cross.item(), '; triplet_loss = ', loss_triplet.item(), '; ct_match_prob = ', ct_match_prob.item())
    # Every 10th epoch, report the per-batch averages over the epoch.
    if (e+1) % 10 == 0:
      print('Epoch-{0}: lr = {1}, loss = {2}, entropy_loss = {3}, triplet loss = {4}, cell type match prob = {5}'.format(
          e+1,
          optimizer.param_groups[0]['lr'],
          running_loss / n_batches,
          running_loss_cross / n_batches,
          running_loss_triplet / n_batches,
          running_ct_prob / n_batches
          )
      )

## Select hyperparameters

In [None]:
# Add the hyperparameters to the existing settings instead of replacing the
# whole object: rebinding `config` to a fresh Namespace would drop DEVICE,
# BATCH_SIZE, etc., which the following cells still read (e.g. config.DEVICE).
# Merging through a dict keeps the cell safe to re-run.
config = Namespace(**{
    **vars(config),
    'ALPHA': 0.8,
    'MARGIN': 0.5,
    'N_CHANNELS': 64,
    'LEARNING_RATE': 0.00002,
})

In [None]:
# Loss, configured from the hyperparameters and moved to the target device.
criterion = bidirectTripletLoss(
    alpha = config.ALPHA,
    margin = config.MARGIN,
).to(config.DEVICE)

# Encoder; the ATAC kernel sizes were reduced from an earlier setting.
model = Encoder(
    kernel_size_gex = 100,
    kernel_size_atac_1 = 30,
    kernel_size_atac_2 = 5,
    index = index,
).to(config.DEVICE)

# Adam over all encoder parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr = config.LEARNING_RATE,
)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [100], gamma = 0.1)



In [None]:
# Train for 100 epochs, back-propagating only the triplet term of the criterion.
train(model, criterion, optimizer, data_train, epochs = 100, loss_type = "triplet") 

cross_loss =  4.900207042694092 ; triplet_loss =  0.6797603368759155 ; ct_match_prob =  0.05097425729036331
cross_loss =  4.881697654724121 ; triplet_loss =  0.6276385188102722 ; ct_match_prob =  0.05448698252439499
cross_loss =  4.866706848144531 ; triplet_loss =  0.6524909734725952 ; ct_match_prob =  0.05606905370950699
cross_loss =  4.746600151062012 ; triplet_loss =  0.5818232893943787 ; ct_match_prob =  0.06194540858268738
cross_loss =  4.537942886352539 ; triplet_loss =  0.5772849917411804 ; ct_match_prob =  0.06884630024433136
cross_loss =  4.448002815246582 ; triplet_loss =  0.5775848627090454 ; ct_match_prob =  0.0724879801273346


KeyboardInterrupt: ignored

In [None]:
# Persist the model weights (state_dict only) to Drive.
# NOTE(review): the filename says "entropy", but the visible training call uses
# loss_type = "triplet" and was interrupted before 100 epochs -- confirm the
# name matches the run that produced these weights.
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_10240cells_entropy_100epochs')

In [None]:
# train() requires the criterion and optimizer as its 2nd and 3rd arguments;
# without them this call raises a TypeError.
train(model, criterion, optimizer, data_train, epochs = 100, loss_type = "triplet")