In [48]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [49]:
!pip install scanpy --quiet

In [50]:
import torch
from torch import nn
from torch.autograd import Variable
import anndata as ad
import numpy as np
import os
import collections 
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
config = Namespace(
    DEVICE = 'cuda',
    BATCH_SIZE = 64,
    NUM_WORKERS = 4,
    N_GENES = 13431,
    N_PEAKS = 116465,
    N_CHANNELS = 128,
    MAX_SEQ_LEN_GEX = 1500,
    MAX_SEQ_LEN_ATAC = 15000,
    LEARNING_RATE = 0.0005
)

In [51]:
# adata_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")

In [52]:
# adata_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad")

In [53]:
def get_chr_index(adata_atac):
  r"""
  Output row indices for each chromosome for each chromosome
  Parameters
  ----------
  adata_atac
      annData for ATAC
  Returns
  -------
  chr_index
      Dictionary of indices for each chromosome
  """
  row_name = adata_atac.var.index
  chr_name = [c.split("-")[0] for c in row_name]
  lst = np.unique(chr_name) # names for chromosome

  chr_index = dict()
  for i in range(len(lst)):
    index = [a for a, l in enumerate(chr_name) if l == lst[i]]
    if lst[i] not in chr_index:
      chr_index[lst[i]]=index

  return chr_index

In [54]:
## Write cnn modules for gex modalities
class gexCNN(nn.Module):
    """customized  module"""
    #argument index is the poisition for each choromosome
    def __init__(self, kernel_size):
        super(gexCNN, self).__init__()

        # Conv layer
        self.in_channels = 1 
        self.out_channels = config.N_CHANNELS
        self.kernel_size = kernel_size   
        self.stride = 50 # TO CHANGE 
        self.padding = 25 # TO CHANGE
        self.pool_size = 2
        self.pool_stride = 1
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels, 
                      out_channels = self.out_channels, 
                      kernel_size = self.kernel_size,
                      stride = self.stride,
                      padding = self.padding),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = self.pool_size,
                         stride = self.pool_stride)
        )

        # # FC layer
        # self.conv_out_features = int((config.N_GENES + 2*self.padding - self.kernel_size) / self.stride + 1)
        # self.fc_in_features = int((self.conv_out_features - self.pool_size) / self.pool_stride + 1) * self.out_channels
        # self.fc_out_feature = 300
        # self.fc = nn.Linear(in_features = self.fc_in_features, out_features = self.fc_out_feature) 

    def forward(self, x):
        r"""  
        Generate GEX embeddings
        
        Parameters
        ----------
        x
            Pre-processed GEX data (batch_size x 1 x N_GENES)
        
        Returns
        -------
        gex_embed
            GEX embeddings of a batch (batch_size x seq_len x dim_size)
        """
        gex_embed = self.convs(x.float())
        # gex_embed = torch.flatten(gex_embed, 1)
        # gex_embed = self.fc(gex_embed)
        return gex_embed.transpose(1,2).to(config.DEVICE)

In [55]:
# # Test for gexCNN()
# x = torch.tensor(np.asarray(csr_gex[:5].todense())).unsqueeze(1) # 5 cells
# print(x.size())
# model = gexCNN(kernel_size = 100)
# print(model(x).size())

In [56]:
# Write cnn modules for atac modalities
class atacCNN(nn.Module):
    #argument index is the poisition for each choromosome
    def __init__(self, index, kernel_size_1, kernel_size_2):
        super(atacCNN, self).__init__()
        self.index = index
        
        # Conv layer
        self.in_channels_1 = 1 
        self.out_channels_1 = int(config.N_CHANNELS / 2)
        self.kernel_size_1 = kernel_size_1
        self.stride_1 = 10 # TO CHANGE 
        self.padding_1 = 5 # TO CHANGE

        self.in_channels_2 = int(config.N_CHANNELS / 2)
        self.out_channels_2 = config.N_CHANNELS 
        self.kernel_size_2 = kernel_size_2
        self.stride_2 = 5 # TO CHANGE 
        self.padding_2 = 3 # TO CHANGE

        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels_1, 
                      out_channels = self.out_channels_1, 
                      kernel_size = self.kernel_size_1,
                      stride = self.stride_1,
                      padding = self.padding_1),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = 5, stride = 2),

            nn.Conv1d(in_channels = self.in_channels_2, 
                      out_channels = self.out_channels_2, 
                      kernel_size = self.kernel_size_2,
                      stride = self.stride_2,
                      padding = self.padding_2),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = 2, stride = 1)             
        )



    def forward(self, x):
        r"""  
        Generate ATAC embeddings
        
        Parameters
        ----------
        x
            Pre-processed ATAC data (batch_size x 1 x N_PEAKS)
        
        Returns
        -------
        atac_embed
            ATAC embeddings of a batch (batch_size x seq_len x dim_size)
        """
        atac_embed = []
        for chr in self.index.keys(): 
            idx = self.index[chr]
            x_chr = x[:,:,idx]
            x_chr = self.convs(x_chr.float())
            atac_embed.append(x_chr)
        atac_embed = torch.cat(atac_embed, dim = 2)
        return atac_embed.transpose(1,2).to(config.DEVICE)

In [57]:
# # Test for ATAC_CNN()
# x = torch.tensor(np.asarray(csr_atac[:5].todense())).unsqueeze(1) # 5 cells
# print(x.size())
# # index = get_chr_index(adata_atac)
# model = atacCNN(kernel_size_1 = 50, kernel_size_2 = 10, index = index)
# print(model(x).size())

In [58]:
class MultimodalAttention(nn.Module):
    def __init__(self):
        super(MultimodalAttention, self).__init__()
        self.nhead_gex = 1
        self.nhead_atac = 4
        self.nhead_multi = 4
        self.nlayer_gex = 1
        self.nlayer_atac = 1
        self.nlayer_multi = 1

        self.encoder_layer_gex = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_gex)
        self.transformer_encoder_gex = nn.TransformerEncoder(self.encoder_layer_gex, num_layers = self.nlayer_gex)
        self.linear_gex_0 = nn.LazyLinear(out_features = 1)

        self.encoder_layer_atac = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_atac)
        self.transformer_encoder_atac = nn.TransformerEncoder(self.encoder_layer_atac, num_layers = self.nlayer_atac)
        self.linear_atac_0 = nn.LazyLinear(out_features = 1)

        self.encoder_layer_multi = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_multi)
        self.transformer_encoder_multi = nn.TransformerEncoder(self.encoder_layer_multi, num_layers = self.nlayer_multi)
        self.linear_gex_1 = nn.LazyLinear(out_features = 1)
        self.linear_atac_1 = nn.LazyLinear(out_features = 1)
    

    def forward(self, gex_embed, atac_embed):
      r"""  
      Incorporate two self-attention and one cross-attention module

      Parameters
      ----------
      gex_embed
          GEX embeddings of a batch (batch_size x seq_len_gex x dim_size)
      atac_embed
          ATAC embeddings of a batch (batch_size x seq_len_atac x dim_size)

      Returns
      -------
      ## TO FILL
      """
      seq_len_gex = gex_embed.size()[1]
      seq_len_atac = atac_embed.size()[1]
      # print(gex_embed.size())
      # print(atac_embed.size())

      gex_context = self.transformer_encoder_gex(gex_embed)
      atac_context = self.transformer_encoder_atac(atac_embed)

      # # Average self-attention fragment representation
      # gex_out_0 = gex_context.mean(dim = 1)
      # atac_out_0 = atac_context.mean(dim = 1)
      gex_out_0 = self.linear_gex_0(gex_context.permute(0,2,1)).squeeze(2)
      atac_out_0 = self.linear_atac_0(atac_context.permute(0,2,1)).squeeze(2)

      multi_embed = torch.cat((gex_context, atac_context), dim = 1)
      multi_context = self.transformer_encoder_multi(multi_embed)
      # print(multi_context.size())
      
      multi_context_gex = multi_context[:, :seq_len_gex, :]
      multi_context_atac = multi_context[:, seq_len_gex:, :]

      # # Average cross-attention fragment representation
      # gex_out_1 = multi_context_gex.mean(dim = 1)
      # atac_out_1 = multi_context_atac.mean(dim = 1)
      gex_out_1 = self.linear_gex_1(multi_context_gex.permute(0,2,1)).squeeze(2)
      atac_out_1 = self.linear_atac_1(multi_context_atac.permute(0,2,1)).squeeze(2)

      return gex_out_0.to(config.DEVICE), gex_out_1.to(config.DEVICE), atac_out_0.to(config.DEVICE), atac_out_1.to(config.DEVICE)

In [59]:
# # index = get_chr_index(adata_atac)

# x_gex = torch.tensor(np.asarray(csr_gex[:5].todense())).unsqueeze(1).to(config.DEVICE) # 5 cells
# x_atac = torch.tensor(np.asarray(csr_atac[:5].todense())).unsqueeze(1).to(config.DEVICE) # 5 cells

# gex_cnn = gexCNN(kernel_size = 100).to(config.DEVICE)
# atac_cnn = atacCNN(kernel_size_1 = 50, kernel_size_2 = 10, index = index).to(config.DEVICE)
# multi_attention = MultimodalAttention().to(config.DEVICE)

# gex_embed = gex_cnn(x_gex).to(config.DEVICE)
# atac_embed = atac_cnn(x_atac).to(config.DEVICE)

# gex_out_0, gex_out_1, atac_out_0, atac_out_1 = multi_attention(gex_embed, atac_embed)
# print(atac_out_0.size())

In [60]:
# embedding = nn.Embedding(1000, 128)
# anchor_ids = torch.randint(0, 1000, (1,))
# positive_ids = torch.randint(0, 1000, (1,))
# negative_ids = torch.randint(0, 1000, (1,))
# anchor = embedding(anchor_ids)
# positive = embedding(positive_ids)
# negative = embedding(negative_ids)

In [61]:
# print(anchor.size())
# print(positive.size())
# print(negative.size())

In [62]:
class Encoder(nn.Module):
    def __init__(self, kernel_size_gex, kernel_size_atac_1, kernel_size_atac_2, index):
        super(Encoder, self).__init__()

        self.kernel_size_gex = kernel_size_gex
        self.kernel_size_atac_1 = kernel_size_atac_1
        self.kernel_size_atac_2 = kernel_size_atac_2
        self.index = index

        self.gex_cnn = gexCNN(kernel_size = self.kernel_size_gex)
        self.atac_cnn = atacCNN(kernel_size_1 = self.kernel_size_atac_1, kernel_size_2 = self.kernel_size_atac_2, index = self.index)
        self.multi_attention = MultimodalAttention()

        
    def forward(self, x_gex, x_atac):

        gex_embed = self.gex_cnn(x_gex)
        atac_embed = self.atac_cnn(x_atac)
        gex_out_0, gex_out_1, atac_out_0, atac_out_1 = self.multi_attention(gex_embed, atac_embed)

        return gex_out_0, gex_out_1, atac_out_0, atac_out_1


In [63]:
# gex_test = torch.tensor(np.asarray(adata_gex.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells
# atac_test = torch.tensor(np.asarray(adata_atac.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells

# index = get_chr_index(adata_atac)


In [64]:
# encoder = Encoder(kernel_size = 32, index = index )
# gex_out_0, gex_out_1, atac_out_0, atac_out_1 = encoder(gex_test, atac_test)

In [65]:
from numpy.lib.shape_base import row_stack
class bidirectTripletLoss(nn.Module):
    r"""
    
    Output bidirectional triplet loss for two pairs of gex and atac
    ----------
    gex_0_mat: Matrix of GEX embeddings from self-attention (batch_size x embedding_size_0)

    Returns
    -------
    loss
    """
    def __init__(self, alpha, margin):
        super(bidirectTripletLoss, self).__init__()

        self.alpha = alpha
        self.margin = margin
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def similarityScore(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1):
        r"""
        Output similarity scores for two pairs of gex and atac
        ----------


        Returns
        -------
        score: batch_size * batch_size
        similarity score between two modalities
        """ 

        # print(gex_mat.size())
        # print(atac_mat.size())

        # gex_mat0, gex_mat1 = torch.split(gex_mat, config.N_CHANNELS, dim = 1)
        # atac_mat0, atac_mat1 = torch.split(atac_mat,config.N_CHANNELS, dim = 1)

        gex_out_0 = nn.functional.normalize(gex_out_0, dim = 1)
        gex_out_1 = nn.functional.normalize(gex_out_1, dim = 1)
        atac_out_0 = nn.functional.normalize(atac_out_0, dim = 1)
        atac_out_1 = nn.functional.normalize(atac_out_1, dim = 1)
        
        score = torch.mm(gex_out_0, atac_out_0.transpose(0,1)) + self.alpha * torch.mm(gex_out_1, atac_out_1.transpose(0,1))
        return  score.to(config.DEVICE)

    def triplet(self, score_mat):

        true_score = torch.diagonal(score_mat)
        # print("true_score:\n", true_score)
        # print("true_score dimension:\n", true_score.size())
        
        reduced_score_mat = score_mat - torch.diag(true_score) # set the diagnoal score to be zero
        
        neg_index_1 = torch.argmax(reduced_score_mat, dim = 1) # indices of hard negatives for GEX
        neg_index_2 = torch.argmax(reduced_score_mat, dim = 0) # indices of hard negatives for ATAC
        # print("neg_index_1:\n", neg_index_1); print("neg_index_2:\n", neg_index_2)
        # print("neg_index_1 dimension:\n", neg_index_1.size()); print("neg_index_2 dimension:\n", neg_index_2.size())
        neg_1 = score_mat[[range(config.BATCH_SIZE), neg_index_1]] # hard negatives for GEX 
        neg_2 = score_mat[[neg_index_2, range(config.BATCH_SIZE)]] # hard negatives for ATAC
        # print("neg_1:\n", neg_1); print("neg_2:\n", neg_2)
        
        loss_1 = torch.max(self.margin - true_score + neg_1, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))
        loss_2 = torch.max(self.margin - true_score + neg_2, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))

        return torch.mean(loss_1 + loss_2)

    def crossEntropy(self, score_mat):

        batch_size = score_mat.size()[0]
        target = torch.arange(batch_size)

        loss_1 = self.cross_entropy_loss(score_mat, target.to(config.DEVICE))
        loss_2 = self.cross_entropy_loss(score_mat.T, target.to(config.DEVICE))

        return 0.5 * (loss_1 + loss_2)

    def cellTypeMatchingProbRow(self, score_mat, cell_type):

        # Collect list of index list for each cell type
        idx_in_type = collections.defaultdict(list)
        for i, x in enumerate(cell_type):
            idx_in_type[x].append(i)

        # Compute matching probs for each cell type
        score_mat_norm = score_mat.softmax(dim = 0)
        probs = []
        for idx in idx_in_type.values():
            prob_type = 0
            for i in idx:
                prob_type += score_mat_norm[i, idx].sum()
            probs.append(prob_type / len(idx)) 

        # Take average of matching prob from cell types
        ct_match_prob = torch.tensor(probs).mean()

        return ct_match_prob

    def cellTypeMatchingProb(self, score_mat, cell_type):
        row = self.cellTypeMatchingProbRow(score_mat, cell_type) # Softmax on rows (normalize GEX)
        col = self.cellTypeMatchingProbRow(score_mat.T, cell_type) # Softmax on cols (normalize ATAC) 
        
        return 0.5 * (row + col)

    def forward(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type):
      
        score_mat = self.similarityScore(gex_out_0, gex_out_1, atac_out_0, atac_out_1)#; print("score_mat:\n", score_mat)

        loss_triplet = self.triplet(score_mat)
        loss_cross = self.crossEntropy(score_mat)
        loss = loss_triplet + loss_cross
        # print(score_mat); print(cell_type)
        ct_match_prob = self.cellTypeMatchingProb(score_mat, cell_type)

        return loss.to(config.DEVICE), loss_triplet, loss_cross, ct_match_prob

In [66]:
# TEST
# import random
# random.seed(0)
# gex_mat=torch.randn([5,64])
# # print(gex_mat[0,])
# atac_mat=torch.randn([5,64])
# # print(atac_mat[0,])
# # print(torch.dot(gex_mat[0,:32],atac_mat[0,:32]))
# loss=bidirectTripletLoss(alpha=0.2,margin=1)
# res=loss(gex_mat,atac_mat)
# res

In [67]:
# def cellTypeMatchingProbRow(score_mat, cell_type):

#     # Collect list of index list for each cell type
#     idx_in_type = collections.defaultdict(list)
#     for i, x in enumerate(cell_type):
#         idx_in_type[x].append(i)

#     # Compute matching probs for each cell type
#     score_mat_norm = score_mat.softmax(dim = 0)
#     print('score_mat_norm: \n', score_mat_norm)
#     probs = []
#     for idx in idx_in_type.values():
#         print('idx: ', idx)
#         prob_type = 0
#         for i in idx:
#             prob_type += score_mat_norm[i, idx].sum()
#         print('prob_type: ', prob_type)
#         probs.append(prob_type / len(idx)) 
#     print('probs: ', probs, "\n")

#     # Take average of matching prob from cell types
#     ct_match_prob = torch.tensor(probs).mean()

#     return ct_match_prob

# def cellTypeMatchingProb(score_mat, cell_type):
#     row = cellTypeMatchingProbRow(score_mat, cell_type) # Softmax on rows (normalize GEX)
#     col = cellTypeMatchingProbRow(score_mat.T, cell_type) # Softmax on cols (normalize ATAC) 
    
#     return 0.5 * (row + col)

# random.seed(1)
# mat=torch.randn([5,5])
# cell_type=['a','a','b','a','b']
# print(mat)
# print(cellTypeMatchingProb(mat, cell_type))

In [68]:
class MultiomeDataset(Dataset):
    def __init__(
        self, csr_gex, csr_atac, cell_type
    ):
        super().__init__()
        
        self.csr_gex = csr_gex
        self.csr_atac = csr_atac
    
    def __len__(self):
        return self.csr_gex.shape[0]
    
    def __getitem__(self, index: int):
        x_gex = torch.tensor(self.csr_gex[index,:].todense())
        x_atac = torch.tensor(self.csr_atac[index,:].todense())
        return {'gex':x_gex, 'atac':x_atac, 'cell_type':cell_type[index]}
  
def get_dataloaders(gex_train, atac_train, cell_type_train,  gex_val, atac_val, cell_type_val):
    
    # mod2_train = mod2_train.iloc[sol_train.values.argmax(1)]
    # mod2_test = mod2_test.iloc[sol_test.values.argmax(1)]
    
    dataset_train = MultiomeDataset(gex_train, atac_train, cell_type_train)
    data_train = DataLoader(dataset_train, config.BATCH_SIZE, shuffle = True, num_workers = config.NUM_WORKERS)
    
    dataset_val = MultiomeDataset(gex_val, atac_val, cell_type_val)
    data_val = DataLoader(dataset_val, config.BATCH_SIZE, shuffle = False, num_workers = config.NUM_WORKERS)
    
    return data_train, data_val

In [69]:
index = get_chr_index(ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"))

In [70]:
batch = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['batch']
batch = list(batch)
train_id = [a for a, l in enumerate(batch) if l not in ['s2d4','s1d1']]
val_id =  [a for a, l in enumerate(batch) if l == 's1d1']
test_id = [a for a, l in enumerate(batch) if l == 's2d4']

cell_type = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['cell_type']

csr_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").layers['log_norm']
csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [71]:
import random
random.seed(0)

idx_train = [train_id[i] for i in random.sample(range(0, 45000), 10240)]
gex_train = csr_gex[idx_train,:]
atac_train = csr_atac[idx_train,:]
cell_type_train = [cell_type[j] for j in idx_train]

idx_val = [train_id[i] for i in random.sample(range(0, 6000), 1024)]
gex_val = csr_gex[idx_val,:]
atac_val = csr_atac[idx_val,:]
cell_type_val = [cell_type[j] for j in idx_val]

data_train, data_val = get_dataloaders(gex_train, atac_train, cell_type_train, gex_val, atac_val, cell_type_val)

In [72]:
criterion = bidirectTripletLoss(alpha = 0.2, margin = 0.5).to(config.DEVICE)
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [100, 300], gamma = 0.1)

In [None]:
def train(model, data_train, epochs):

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    for iter, data in enumerate(data_train):
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']
      # print(cell_type_input)

      model.zero_grad()
      optimizer.zero_grad()

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob = criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)
      
      ### Propagate loss
      # running_loss += loss.item()
      # loss.backward()
      running_loss += loss_cross.item()
      loss_cross.backward()

      ### update parameters
      optimizer.step()
    
    # scheduler.step()
    
    print('triplet_loss = ', loss_triplet); print('cross_loss = ', loss_cross); print('cell_type_match_prob = ', ct_match_prob)
    if (e+1) % 10 == 0: 
      print('Epoch-{0} lr: {1}'.format(e+1, optimizer.param_groups[0]['lr']))
      print('epoch[%d] = %.8f' % (e+1, running_loss / len(data_train)))


train(model, data_train, epochs = 100)

triplet_loss =  tensor(1.4120, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  tensor(3.6292, device='cuda:0', grad_fn=<MulBackward0>)
cell_type_match_prob =  tensor(0.0872)
triplet_loss =  tensor(1.2230, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  tensor(3.4763, device='cuda:0', grad_fn=<MulBackward0>)
cell_type_match_prob =  tensor(0.0817)
triplet_loss =  tensor(1.1990, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  tensor(3.4155, device='cuda:0', grad_fn=<MulBackward0>)
cell_type_match_prob =  tensor(0.0922)
triplet_loss =  tensor(1.2382, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  tensor(3.4329, device='cuda:0', grad_fn=<MulBackward0>)
cell_type_match_prob =  tensor(0.0744)
triplet_loss =  tensor(1.2268, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  tensor(3.4182, device='cuda:0', grad_fn=<MulBackward0>)
cell_type_match_prob =  tensor(0.0779)
triplet_loss =  tensor(1.2065, device='cuda:0', grad_fn=<MeanBackward0>)
cross_loss =  ten

In [None]:
# def train(model, data_train, epochs):

#   model.train()

#   for e in range(epochs):
#     running_loss = 0.0
#     for iter, data in enumerate(data_train):
#       gex_input = data['gex'].to(config.DEVICE)
#       atac_input = data['atac'].to(config.DEVICE)

#       model.zero_grad()
#       optimizer.zero_grad()

#       # forward
#       gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)
#       # print(gex_out_0); print(gex_out_1); print(atac_out_0); print(atac_out_1); 

#       loss, loss_triplet, loss_cross = criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1)
      
#       # running_loss += loss.item()
#       # loss.backward()
#       running_loss += loss_triplet.item()
#       loss_triplet.backward()

#       # update parameters
#       optimizer.step()
    
#     # scheduler.step()
    
#     print('triplet_loss = ', loss_triplet); print('cross_loss = ', loss_cross)
#     if (e+1) % 10 == 0: 
#       # print('Epoch-{0} lr: {1}'.format(e+1, optimizer.param_groups[0]['lr']))
#       print('epoch[%d] = %.8f' % (e+1, running_loss / len(data_train)))


# train(model, data_train, epochs = 100)

In [None]:
# train(model, data_train, epochs = 100)

In [None]:
# torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_10240cells')