In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
!pip install scanpy --quiet

In [3]:
import torch
import sys
import os
import gc
import collections 
import anndata as ad
from argparse import Namespace

config = Namespace(
    LEARNING_RATE = 0.00002,
    DEVICE = 'cuda',
    BATCH_SIZE = 100,
    NUM_WORKERS = 4,
    N_GENES = 13431,
    N_PEAKS = 116465,
    MAX_SEQ_LEN_GEX = 1500,
    MAX_SEQ_LEN_ATAC = 15000,
)

In [4]:
print(sys.path)

['/content', '/env/python', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '', '/usr/local/lib/python3.8/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.8/dist-packages/IPython/extensions', '/root/.ipython']


In [5]:
execfile("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/data.py")
execfile("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/models.py")

## Import data

In [6]:
index = get_chr_index(ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"))

In [7]:
gc.collect()

261

In [8]:
batch = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['batch']
batch = list(batch)
train_id = [a for a, l in enumerate(batch) if l not in ['s2d4','s1d1']]
val_id =  [a for a, l in enumerate(batch) if l == 's1d1']
test_id = [a for a, l in enumerate(batch) if l == 's2d4']

cell_type_all = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['cell_type']

csr_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").layers['log_norm']
csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [9]:
gc.collect()

240

In [10]:
import random
random.seed(0)

# idx_train = [train_id[i] for i in random.sample(range(0, 45000), 1024)]
idx_train = train_id # Full dataset
gex_train = csr_gex[idx_train,:]
atac_train = csr_atac[idx_train,:]
cell_type_train = [cell_type_all[j] for j in idx_train]

# idx_val = [val_id[i] for i in random.sample(range(0, 6000), 1024)]
idx_val = val_id  # Full dataset
gex_val = csr_gex[idx_val,:]
atac_val = csr_atac[idx_val,:]
cell_type_val = [cell_type_all[j] for j in idx_val]

data_train, data_val = get_dataloaders(gex_train, atac_train, cell_type_train, gex_val, atac_val, cell_type_val)

In [11]:
def train(model, criterion, optimizer, data_train, epochs, loss_type):

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    running_loss_cross = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0
    for iter, data in enumerate(data_train):
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']
      # print(cell_type_input)

      model.zero_grad()
      optimizer.zero_grad()

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob = criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)
      
      ### Propagate loss
      if loss_type == "both":
        running_loss += loss.item()
        loss.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()    
        running_loss_cross += loss_cross.item()  
        
      elif loss_type == "entropy":
        running_loss += loss_cross.item()
        loss_cross.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()       
        running_loss_cross += loss_cross.item()  

      elif loss_type == "triplet":
        running_loss += loss_triplet.item()
        loss_triplet.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()    
        running_loss_cross += loss_cross.item()  

      else:
        break

      ### update parameters
      optimizer.step()

      del gex_input
      del atac_input
      del cell_type_input
      torch.cuda.empty_cache()
      if (iter + 1) % 50 == 0: 
        print('Within-epoch iter', iter + 1, ': cross_loss = ', loss_cross.item(), '; triplet_loss = ', loss_triplet.item(), '; ct_match_prob = ', ct_match_prob.item())

    # print('cross_loss = ', loss_cross.item(), '; triplet_loss = ', loss_triplet.item(), '; ct_match_prob = ', ct_match_prob.item())
    if (e+1) % 1 == 0: 
      print('Epoch-{0}: lr = {1}, loss = {2}, entropy_loss = {3}, triplet loss = {4}, cell type match prob = {5}'.format(
          e+1, 
          optimizer.param_groups[0]['lr'], 
          running_loss / len(data_train), 
          running_loss_cross / len(data_train), 
          running_loss_triplet / len(data_train), 
          running_ct_prob / len(data_train)
          )
      )

    # scheduler.step()

## Select hyperparameters

### Hyperparam set 1

In [12]:
config.ALPHA = 0.2
config.MARGIN = 0.5
config.N_CHANNELS = 32

In [13]:
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 30, kernel_size_atac_2 = 5, index = index).to(config.DEVICE) ## CHANGED TO SMALLER KERNAL SIZE FOR ATAC
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [20], gamma = 0.1)



In [None]:
train(model, criterion, optimizer, data_train, epochs = 30, loss_type = "triplet") 

Within-epoch iter 50 : cross_loss =  4.603342533111572 ; triplet_loss =  0.8728950023651123 ; ct_match_prob =  0.052921898663043976
Within-epoch iter 100 : cross_loss =  4.6072468757629395 ; triplet_loss =  0.7201617360115051 ; ct_match_prob =  0.053178150206804276
Within-epoch iter 150 : cross_loss =  4.637065887451172 ; triplet_loss =  0.6258715987205505 ; ct_match_prob =  0.05323915183544159
Within-epoch iter 200 : cross_loss =  4.6693806648254395 ; triplet_loss =  0.6141918301582336 ; ct_match_prob =  0.05683872103691101
Within-epoch iter 250 : cross_loss =  4.662415504455566 ; triplet_loss =  0.5955432057380676 ; ct_match_prob =  0.05970562994480133
Within-epoch iter 300 : cross_loss =  4.690706253051758 ; triplet_loss =  0.6222420930862427 ; ct_match_prob =  0.05978010594844818
Within-epoch iter 350 : cross_loss =  4.707892894744873 ; triplet_loss =  0.6346340179443359 ; ct_match_prob =  0.05375109612941742
Within-epoch iter 400 : cross_loss =  4.702280044555664 ; triplet_loss = 

In [None]:
file = 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells_alpha' + \
        str(config.ALPHA) + '_margin' + str(config.MARGIN) + '_nchannels' + str(config.N_CHANNELS) + '_30epochs'
torch.save(model.state_dict(), file)

### Hyperparam set 2

In [None]:
config.ALPHA = 0.2
config.MARGIN = 0.5
config.N_CHANNELS = 64

### Hyperparam set 3

In [None]:
config.ALPHA = 0.2
config.MARGIN = 1
config.N_CHANNELS = 32

### Hyperparam set 4

In [None]:
config.ALPHA = 0.2
config.MARGIN = 1
config.N_CHANNELS = 64

### Hyperparam set 5

In [None]:
config.ALPHA = 0.8
config.MARGIN = 0.5
config.N_CHANNELS = 32

### Hyperparam set 6

In [None]:
config.ALPHA = 0.8
config.MARGIN = 0.5
config.N_CHANNELS = 64

### Hyperparam set 7

In [None]:
config.ALPHA = 0.8
config.MARGIN = 1
config.N_CHANNELS = 32

### Hyperparam set 8

In [None]:
config.ALPHA = 0.8
config.MARGIN = 1
config.N_CHANNELS = 64