In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [3]:
!pip install scanpy --quiet

[K     |████████████████████████████████| 2.0 MB 39.5 MB/s 
[K     |████████████████████████████████| 9.4 MB 81.6 MB/s 
[K     |████████████████████████████████| 96 kB 6.9 MB/s 
[K     |████████████████████████████████| 88 kB 9.5 MB/s 
[K     |████████████████████████████████| 295 kB 77.5 MB/s 
[K     |████████████████████████████████| 965 kB 65.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 82.7 MB/s 
[K     |████████████████████████████████| 63 kB 2.7 MB/s 
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for session-info (setup.py) ... [?25l[?25hdone


In [39]:
import torch
from torch import nn
from torch.autograd import Variable
import anndata as ad
import numpy as np
import os
import collections 
from argparse import Namespace
from torch.utils.data import Dataset, DataLoader
config = Namespace(
    DEVICE = 'cuda',
    BATCH_SIZE = 128,
    NUM_WORKERS = 4,
    N_GENES = 13431,
    N_PEAKS = 116465,
    N_CHANNELS = 32,
    MAX_SEQ_LEN_GEX = 1500,
    MAX_SEQ_LEN_ATAC = 15000,
    LEARNING_RATE = 0.0005
)

In [5]:
# adata_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")

In [6]:
# adata_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad")

In [7]:
def get_chr_index(adata_atac):
  r"""
  Output row indices for each chromosome for each chromosome
  Parameters
  ----------
  adata_atac
      annData for ATAC
  Returns
  -------
  chr_index
      Dictionary of indices for each chromosome
  """
  row_name = adata_atac.var.index
  chr_name = [c.split("-")[0] for c in row_name]
  lst = np.unique(chr_name) # names for chromosome

  chr_index = dict()
  for i in range(len(lst)):
    index = [a for a, l in enumerate(chr_name) if l == lst[i]]
    if lst[i] not in chr_index:
      chr_index[lst[i]]=index

  return chr_index

In [8]:
## Write cnn modules for gex modalities
class gexCNN(nn.Module):
    """customized  module"""
    #argument index is the poisition for each choromosome
    def __init__(self, kernel_size):
        super(gexCNN, self).__init__()

        # Conv layer
        self.in_channels = 1 
        self.out_channels = config.N_CHANNELS
        self.kernel_size = kernel_size   
        self.stride = 50 # TO CHANGE 
        self.padding = 25 # TO CHANGE
        self.pool_size = 2
        self.pool_stride = 1
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels, 
                      out_channels = self.out_channels, 
                      kernel_size = self.kernel_size,
                      stride = self.stride,
                      padding = self.padding),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = self.pool_size,
                         stride = self.pool_stride)
        )

        # # FC layer
        # self.conv_out_features = int((config.N_GENES + 2*self.padding - self.kernel_size) / self.stride + 1)
        # self.fc_in_features = int((self.conv_out_features - self.pool_size) / self.pool_stride + 1) * self.out_channels
        # self.fc_out_feature = 300
        # self.fc = nn.Linear(in_features = self.fc_in_features, out_features = self.fc_out_feature) 

    def forward(self, x):
        r"""  
        Generate GEX embeddings
        
        Parameters
        ----------
        x
            Pre-processed GEX data (batch_size x 1 x N_GENES)
        
        Returns
        -------
        gex_embed
            GEX embeddings of a batch (batch_size x seq_len x dim_size)
        """
        gex_embed = self.convs(x.float())
        # gex_embed = torch.flatten(gex_embed, 1)
        # gex_embed = self.fc(gex_embed)
        return gex_embed.transpose(1,2).to(config.DEVICE)

In [9]:
# # Test for gexCNN()
# x = torch.tensor(np.asarray(csr_gex[:5].todense())).unsqueeze(1) # 5 cells
# print(x.size())
# model = gexCNN(kernel_size = 100)
# print(model(x).size())

In [10]:
# Write cnn modules for atac modalities
class atacCNN(nn.Module):
    #argument index is the poisition for each choromosome
    def __init__(self, index, kernel_size_1, kernel_size_2):
        super(atacCNN, self).__init__()
        self.index = index
        
        # Conv layer
        self.in_channels_1 = 1 
        self.out_channels_1 = int(config.N_CHANNELS / 2)
        self.kernel_size_1 = kernel_size_1
        self.stride_1 = 50 # TO CHANGE 
        self.padding_1 = 25 # TO CHANGE

        self.in_channels_2 = int(config.N_CHANNELS / 2)
        self.out_channels_2 = config.N_CHANNELS 
        self.kernel_size_2 = kernel_size_2
        self.stride_2 = 5 # TO CHANGE 
        self.padding_2 = 3 # TO CHANGE

        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels_1, 
                      out_channels = self.out_channels_1, 
                      kernel_size = self.kernel_size_1,
                      stride = self.stride_1,
                      padding = self.padding_1),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = 5, stride = 2),

            nn.Conv1d(in_channels = self.in_channels_2, 
                      out_channels = self.out_channels_2, 
                      kernel_size = self.kernel_size_2,
                      stride = self.stride_2,
                      padding = self.padding_2),
            nn.LeakyReLU(),
            nn.MaxPool1d(kernel_size = 2, stride = 1)             
        )



    def forward(self, x):
        r"""  
        Generate ATAC embeddings
        
        Parameters
        ----------
        x
            Pre-processed ATAC data (batch_size x 1 x N_PEAKS)
        
        Returns
        -------
        atac_embed
            ATAC embeddings of a batch (batch_size x seq_len x dim_size)
        """
        atac_embed = []
        for chr in self.index.keys(): 
            idx = self.index[chr]
            x_chr = x[:,:,idx]
            x_chr = self.convs(x_chr.float())
            atac_embed.append(x_chr)
        atac_embed = torch.cat(atac_embed, dim = 2)
        return atac_embed.transpose(1,2).to(config.DEVICE)

In [11]:
# # Test for ATAC_CNN()
# x = torch.tensor(np.asarray(csr_atac[:5].todense())).unsqueeze(1) # 5 cells
# print(x.size())
# # index = get_chr_index(adata_atac)
# model = atacCNN(kernel_size_1 = 50, kernel_size_2 = 10, index = index)
# print(model(x).size())

In [20]:
class MultimodalAttention(nn.Module):
    def __init__(self):
        super(MultimodalAttention, self).__init__()
        self.nhead_gex = 1
        self.nhead_atac = 4
        self.nhead_multi = 4
        self.nlayer_gex = 1
        self.nlayer_atac = 1
        self.nlayer_multi = 1

        # self.encoder_layer_gex = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_gex)
        # self.transformer_encoder_gex = nn.TransformerEncoder(self.encoder_layer_gex, num_layers = self.nlayer_gex)
        # # self.linear_gex_0 = nn.LazyLinear(out_features = 1)
        # self.linear1_gex_0 = nn.LazyLinear(out_features = 10)
        # self.linear2_gex_0 = nn.LazyLinear(out_features = 300)


        # self.encoder_layer_atac = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_atac)
        # self.transformer_encoder_atac = nn.TransformerEncoder(self.encoder_layer_atac, num_layers = self.nlayer_atac)
        # # self.linear_atac_0 = nn.LazyLinear(out_features = 1)
        # self.linear1_atac_0 = nn.LazyLinear(out_features = 10)
        # self.linear2_atac_0 = nn.LazyLinear(out_features = 300)

        self.encoder_layer_multi = nn.TransformerEncoderLayer(d_model = config.N_CHANNELS, nhead = self.nhead_multi)
        self.transformer_encoder_multi = nn.TransformerEncoder(self.encoder_layer_multi, num_layers = self.nlayer_multi)
        # self.linear_gex_1 = nn.LazyLinear(out_features = 1)
        # self.linear_atac_1 = nn.LazyLinear(out_features = 1)
        self.linear1_gex_1 = nn.LazyLinear(out_features = 10)
        self.linear2_gex_1 = nn.LazyLinear(out_features = 300)
        self.linear1_atac_1 = nn.LazyLinear(out_features = 10)
        self.linear2_atac_1 = nn.LazyLinear(out_features = 300)
    

    def forward(self, gex_embed, atac_embed):
      r"""  
      Incorporate two self-attention and one cross-attention module

      Parameters
      ----------
      gex_embed
          GEX embeddings of a batch (batch_size x seq_len_gex x dim_size)
      atac_embed
          ATAC embeddings of a batch (batch_size x seq_len_atac x dim_size)

      Returns
      -------
      ## TO FILL
      """
      seq_len_gex = gex_embed.size()[1]
      seq_len_atac = atac_embed.size()[1]

      # # self-attention fragment representation
      # gex_context = self.transformer_encoder_gex(gex_embed)
      # atac_context = self.transformer_encoder_atac(atac_embed)
      # # gex_out_0 = self.linear_gex_0(gex_context.permute(0,2,1)).squeeze(2)
      # # atac_out_0 = self.linear_atac_0(atac_context.permute(0,2,1)).squeeze(2)
      # gex_out_0  = self.linear1_gex_0(gex_context.permute(0,2,1)).squeeze(2)
      # # print(gex_out_0.size())
      # gex_out_0  = self.linear2_gex_0(gex_out_0.flatten(start_dim=1))
      # # print(gex_out_0.size())
      # atac_out_0 = self.linear1_atac_0(atac_context.permute(0,2,1)).squeeze(2)
      # # print(atac_out_0.size())
      # atac_out_0 = self.linear2_atac_0(atac_out_0.flatten(start_dim=1))
      # # print(atac_out_0.size())

      multi_embed = torch.cat((gex_embed, atac_embed), dim = 1) # ignore self-attention module
      # multi_embed = torch.cat((gex_context, atac_context), dim = 1)
      multi_context = self.transformer_encoder_multi(multi_embed)
      
      multi_context_gex = multi_context[:, :seq_len_gex, :]
      multi_context_atac = multi_context[:, seq_len_gex:, :]
      
      # # Average cross-attention fragment representation
      # gex_out_1 = multi_context_gex.mean(dim = 1)
      # atac_out_1 = multi_context_atac.mean(dim = 1)
      gex_out_1  = self.linear1_gex_1(multi_context_gex.permute(0,2,1)).squeeze(2)
      gex_out_1  = self.linear2_gex_1(gex_out_1.flatten(start_dim=1))
      atac_out_1 = self.linear1_atac_1(multi_context_atac.permute(0,2,1)).squeeze(2)
      atac_out_1 = self.linear2_atac_1(atac_out_1.flatten(start_dim=1))

      return gex_out_1.to(config.DEVICE), atac_out_1.to(config.DEVICE)
      # return gex_out_0.to(config.DEVICE), gex_out_1.to(config.DEVICE), atac_out_0.to(config.DEVICE), atac_out_1.to(config.DEVICE)

In [21]:
# # index = get_chr_index(adata_atac)

x_gex = torch.tensor(np.asarray(csr_gex[:5].todense())).unsqueeze(1).to(config.DEVICE) # 5 cells
x_atac = torch.tensor(np.asarray(csr_atac[:5].todense())).unsqueeze(1).to(config.DEVICE) # 5 cells

gex_cnn = gexCNN(kernel_size = 100).to(config.DEVICE)
atac_cnn = atacCNN(kernel_size_1 = 50, kernel_size_2 = 10, index = index).to(config.DEVICE)
multi_attention = MultimodalAttention().to(config.DEVICE)

gex_embed = gex_cnn(x_gex).to(config.DEVICE)
atac_embed = atac_cnn(x_atac).to(config.DEVICE)
print(gex_embed.size()); print(atac_embed.size())
gex_out_1, atac_out_1 = multi_attention(gex_embed, atac_embed)
# gex_out_0, gex_out_1, atac_out_0, atac_out_1 = multi_attention(gex_embed, atac_embed)
print(atac_out_1.size())

torch.Size([5, 267, 32])
torch.Size([5, 196, 32])
torch.Size([5, 300])


In [None]:
# embedding = nn.Embedding(1000, 128)
# anchor_ids = torch.randint(0, 1000, (1,))
# positive_ids = torch.randint(0, 1000, (1,))
# negative_ids = torch.randint(0, 1000, (1,))
# anchor = embedding(anchor_ids)
# positive = embedding(positive_ids)
# negative = embedding(negative_ids)

In [None]:
# print(anchor.size())
# print(positive.size())
# print(negative.size())

In [22]:
class Encoder(nn.Module):
    def __init__(self, kernel_size_gex, kernel_size_atac_1, kernel_size_atac_2, index):
        super(Encoder, self).__init__()

        self.kernel_size_gex = kernel_size_gex
        self.kernel_size_atac_1 = kernel_size_atac_1
        self.kernel_size_atac_2 = kernel_size_atac_2
        self.index = index

        self.gex_cnn = gexCNN(kernel_size = self.kernel_size_gex)
        self.atac_cnn = atacCNN(kernel_size_1 = self.kernel_size_atac_1, kernel_size_2 = self.kernel_size_atac_2, index = self.index)
        self.multi_attention = MultimodalAttention()

        
    def forward(self, x_gex, x_atac):

        gex_embed = self.gex_cnn(x_gex)
        atac_embed = self.atac_cnn(x_atac)
        gex_out_1, atac_out_1 = self.multi_attention(gex_embed, atac_embed)
        # gex_out_0, gex_out_1, atac_out_0, atac_out_1 = self.multi_attention(gex_embed, atac_embed)

        return gex_out_1, atac_out_1
        # return gex_out_0, gex_out_1, atac_out_0, atac_out_1


In [23]:
# gex_test = torch.tensor(np.asarray(adata_gex.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells
# atac_test = torch.tensor(np.asarray(adata_atac.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells

# index = get_chr_index(adata_atac)


In [24]:
# encoder = Encoder(kernel_size = 32, index = index )
# gex_out_0, gex_out_1, atac_out_0, atac_out_1 = encoder(gex_test, atac_test)

In [31]:
# from numpy.lib.shape_base import row_stack
class bidirectTripletLoss(nn.Module):
    r"""
    
    Output bidirectional triplet loss for two pairs of gex and atac
    ----------
    gex_0_mat: Matrix of GEX embeddings from self-attention (batch_size x embedding_size_0)

    Returns
    -------
    loss
    """
    def __init__(self, alpha, margin):
        super(bidirectTripletLoss, self).__init__()

        self.alpha = alpha
        self.margin = margin
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.triplet_loss = nn.TripletMarginLoss(margin=self.margin, p=2)

    def similarityScore(self, gex_out_1, atac_out_1):
    # def similarityScore(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1):
        r"""
        Output similarity scores for two pairs of gex and atac
        ----------


        Returns
        -------
        score: batch_size * batch_size
        similarity score between two modalities
        """ 

        # print(gex_mat.size())
        # print(atac_mat.size())

        # gex_mat0, gex_mat1 = torch.split(gex_mat, config.N_CHANNELS, dim = 1)
        # atac_mat0, atac_mat1 = torch.split(atac_mat,config.N_CHANNELS, dim = 1)

        # gex_out_0 = nn.functional.normalize(gex_out_0, dim = 1)
        gex_out_1 = nn.functional.normalize(gex_out_1, dim = 1)
        # atac_out_0 = nn.functional.normalize(atac_out_0, dim = 1)
        atac_out_1 = nn.functional.normalize(atac_out_1, dim = 1)
        
        score = torch.mm(gex_out_1, atac_out_1.transpose(0,1))
        # score = torch.mm(gex_out_0, atac_out_0.transpose(0,1)) + self.alpha * torch.mm(gex_out_1, atac_out_1.transpose(0,1))
        return  score.to(config.DEVICE)
        # return score.to('cpu')

    def triplet(self, gex_out_1, atac_out_1):
    # def triplet(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1):

        score_mat = self.similarityScore(gex_out_1, atac_out_1)
        # score_mat = self.similarityScore(gex_out_0, gex_out_1, atac_out_0, atac_out_1)
        true_score = torch.diagonal(score_mat)
        gex_out = gex_out_1
        # gex_out = torch.cat((gex_out_0, gex_out_1), dim=1)
        atac_out = atac_out_1
        # atac_out = torch.cat((atac_out_0, atac_out_1), dim=1)
        # print(gex_out.size())

        gex_out = nn.functional.normalize(gex_out, dim = 1)
        atac_out = nn.functional.normalize(atac_out, dim = 1)
        # print(gex_out.size())
        # print("true_score:\n", true_score)
        # # print("true_score dimension:\n", true_score.size())
        
        reduced_score_mat = score_mat - torch.diag(true_score) # set the diagnoal score to be zero
        
        neg_index_1 = torch.argmax(reduced_score_mat, dim = 1) # indices of hard negatives for GEX
        neg_index_2 = torch.argmax(reduced_score_mat, dim = 0) # indices of hard negatives for ATAC
        # print("neg_index_1:\n", neg_index_1); print("neg_index_2:\n", neg_index_2)
        # print("neg_index_1 dimension:\n", neg_index_1.size()); print("neg_index_2 dimension:\n", neg_index_2.size())
        # neg_1 = score_mat[[range(config.BATCH_SIZE), neg_index_1]] # hard negatives for GEX 
        # neg_1 = score_mat[np.ix_(np.arange(config.BATCH_SIZE),neg_index_1)]
        # neg_2 = score_mat[np.ix_(neg_index_2,np.arange(config.BATCH_SIZE))]

        #for each gex sample, the most hard negitve atac sample
        neg_1_samples = atac_out[neg_index_1,:]

        #for each atac sample, the most hard negitve gex sample
        neg_2_samples = gex_out[neg_index_2,:]

        # neg_2 = score_mat[[neg_index_2, range(config.BATCH_SIZE)]] # hard negatives for ATAC
        # print("neg_1:\n", neg_1); print("neg_2:\n", neg_2)
        
        loss_11 = self.triplet_loss(anchor=gex_out, positive=atac_out, negative=neg_1_samples)
        loss_22 = self.triplet_loss(anchor=atac_out, positive=gex_out, negative=neg_2_samples)
        # loss_1 = torch.max(self.margin - true_score + neg_1, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))
        # loss_2 = torch.max(self.margin - true_score + neg_2, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))

        return torch.mean(loss_11 + loss_22)

    def crossEntropy(self, score_mat):

        batch_size = score_mat.size()[0]
        target = torch.arange(batch_size)

        loss_1 = self.cross_entropy_loss(score_mat, target.to(config.DEVICE))
        loss_2 = self.cross_entropy_loss(score_mat.T, target.to(config.DEVICE))

        return 0.5 * (loss_1 + loss_2)

    def cellTypeMatchingProbRow(self, score_mat, cell_type):

    # Collect list of index list for each cell type
      idx_in_type = collections.defaultdict(list)
      for i, x in enumerate(cell_type):
          idx_in_type[x].append(i)

    # Compute matching probs for each cell type
      score_mat_norm = score_mat.softmax(dim = 0)
      # print('score_mat_norm: \n', score_mat_norm)
      probs = []
      for idx in idx_in_type.values():
        # print('idx: ', idx)
        prob_type = 0
        for i in idx:
            prob_type += score_mat_norm[i, idx].sum()
        # print('prob_type: ', prob_type)
        probs.append(prob_type / len(idx)) 
      # print('probs: ', probs, "\n")

    # Take average of matching prob from cell types
      ct_match_prob = torch.tensor(probs).mean()

      return ct_match_prob

    def cellTypeMatchingProb(self, score_mat, cell_type):
      row = self.cellTypeMatchingProbRow(score_mat, cell_type) # Softmax on rows (normalize GEX)
      col = self.cellTypeMatchingProbRow(score_mat.T, cell_type) # Softmax on cols (normalize ATAC) 
    
      return 0.5 * (row + col)

    # def cellTypematchingProb(self, score_mat, cell_type):
    #   # Collect list of index list for each cell type
    #     idx_in_type = collections.defaultdict(list)
    #     for i, x in enumerate(cell_type):
    #         idx_in_type[x].append(i)

    #     sum_score_mat=torch.zeros(len(idx_in_type.values()),len(idx_in_type.values()))
    #     for i, dx in enumerate(idx_in_type.values()):
    #       for j, dx2 in enumerate(idx_in_type.values()):
    #         tem=score_mat[np.ix_(dx, dx2)].sum()
    #       # print(tem)
    #         sum_score_mat[i,j]=tem
    
    #     score_mat_norm = 0.5 * (sum_score_mat.softmax(dim = 0) + sum_score_mat.softmax(dim = 1))
 
    #     return torch.mean(torch.diagonal(score_mat_norm))

    def forward(self, gex_out_1, atac_out_1,cell_type):
    # def forward(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1,cell_type):
      
        score_mat = self.similarityScore(gex_out_1, atac_out_1)
        # score_mat = self.similarityScore(gex_out_0, gex_out_1, atac_out_0, atac_out_1)

        loss_triplet = self.triplet(gex_out_1, atac_out_1)
        # loss_triplet = self.triplet(gex_out_0,gex_out_1,atac_out_0,atac_out_1)
        loss_cross = self.crossEntropy(score_mat)
        loss = loss_triplet + loss_cross
    
        ct_match_prob = self.cellTypeMatchingProb(score_mat, cell_type)
        # ct_match_prob = self.cellTypematchingProb(score_mat, cell_type)

        return loss.to(config.DEVICE), loss_triplet, loss_cross, ct_match_prob

In [None]:
# sum_score_mat=torch.zeros(len(idx_in_type.values()),len(idx_in_type.values()))
#     for i, dx in enumerate(idx_in_type.values()):
#       for j, dx2 in enumerate(idx_in_type.values()):
#         tem=score_mat[np.ix_(dx, dx2)].sum()
#         print(tem)
#         sum_score_mat[i,j]=tem
#     print("score_mat before softmax")
#     print(sum_score_mat)
#     score_mat_norm = 0.5 * (sum_score_mat.softmax(dim = 0) + sum_score_mat.softmax(dim = 1))
#     print(sum_score_mat)
#     #return torch.tensor(probs).mean()
#     return np.sum(np.diagonal(score_mat_norm))

In [None]:
# TEST
# import random
# random.seed(0)
# gex_mat=torch.randn([5,64])
# # print(gex_mat[0,])
# atac_mat=torch.randn([5,64])
# # print(atac_mat[0,])
# # print(torch.dot(gex_mat[0,:32],atac_mat[0,:32]))
# loss=bidirectTripletLoss(alpha=0.2,margin=1)
# res=loss(gex_mat,atac_mat)
# res

In [None]:
# def cellTypeMatchingProbRow(score_mat, cell_type):

#     # Collect list of index list for each cell type
#     idx_in_type = collections.defaultdict(list)
#     for i, x in enumerate(cell_type):
#         idx_in_type[x].append(i)

#     # Compute matching probs for each cell type
#     score_mat_norm = score_mat.softmax(dim = 0)
#     print('score_mat_norm: \n', score_mat_norm)
#     probs = []
#     for idx in idx_in_type.values():
#         print('idx: ', idx)
#         prob_type = 0
#         for i in idx:
#             prob_type += score_mat_norm[i, idx].sum()
#         print('prob_type: ', prob_type)
#         probs.append(prob_type / len(idx)) 
#     print('probs: ', probs, "\n")

#     # Take average of matching prob from cell types
#     ct_match_prob = torch.tensor(probs).mean()

#     return ct_match_prob

# def cellTypeMatchingProb(score_mat, cell_type):
#     row = cellTypeMatchingProbRow(score_mat, cell_type) # Softmax on rows (normalize GEX)
#     col = cellTypeMatchingProbRow(score_mat.T, cell_type) # Softmax on cols (normalize ATAC) 
    
#     return 0.5 * (row + col)

# random.seed(1)
# mat=torch.randn([5,5])
# cell_type=['a','a','b','a','b']
# print(mat)
# print(cellTypeMatchingProb(mat, cell_type))

In [26]:
class MultiomeDataset(Dataset):
    def __init__(
        self, csr_gex, csr_atac, cell_type
    ):
        super().__init__()
        
        self.csr_gex = csr_gex
        self.csr_atac = csr_atac
    
    def __len__(self):
        return self.csr_gex.shape[0]
    
    def __getitem__(self, index: int):
        x_gex = torch.tensor(self.csr_gex[index,:].todense())
        x_atac = torch.tensor(self.csr_atac[index,:].todense())
        return {'gex':x_gex, 'atac':x_atac, 'cell_type':cell_type[index]}
  
def get_dataloaders(gex_train, atac_train, cell_type_train,  gex_val, atac_val, cell_type_val):
    
    # mod2_train = mod2_train.iloc[sol_train.values.argmax(1)]
    # mod2_test = mod2_test.iloc[sol_test.values.argmax(1)]
    
    dataset_train = MultiomeDataset(gex_train, atac_train, cell_type_train)
    data_train = DataLoader(dataset_train, config.BATCH_SIZE, shuffle = True, num_workers = config.NUM_WORKERS)
    
    dataset_val = MultiomeDataset(gex_val, atac_val, cell_type_val)
    data_val = DataLoader(dataset_val, config.BATCH_SIZE, shuffle = False, num_workers = config.NUM_WORKERS)
    
    return data_train, data_val

In [14]:
index = get_chr_index(ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"))

In [15]:
batch = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['batch']
batch = list(batch)
train_id = [a for a, l in enumerate(batch) if l not in ['s2d4','s1d1']]
val_id =  [a for a, l in enumerate(batch) if l == 's1d1']
test_id = [a for a, l in enumerate(batch) if l == 's2d4']

cell_type = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").obs['cell_type']

csr_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad").layers['log_norm']
csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

In [34]:
import random
random.seed(0)

idx_train = [train_id[i] for i in random.sample(range(0, 45000), 10240)]
gex_train = csr_gex[idx_train,:]
atac_train = csr_atac[idx_train,:]
cell_type_train = [cell_type[j] for j in idx_train]

idx_val = [train_id[i] for i in random.sample(range(0, 6000), 1024)]
gex_val = csr_gex[idx_val,:]
atac_val = csr_atac[idx_val,:]
cell_type_val = [cell_type[j] for j in idx_val]

data_train, data_val = get_dataloaders(gex_train, atac_train, cell_type_train, gex_val, atac_val, cell_type_val)

In [40]:
criterion = bidirectTripletLoss(alpha = 0.2, margin = 0.5).to(config.DEVICE)
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [100], gamma = 0.1)

In [41]:
def train(model, data_train, epochs, loss_type):

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    running_loss_cross = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0
    for iter, data in enumerate(data_train):
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']
      # print(cell_type_input)

      model.zero_grad()
      optimizer.zero_grad()

      ### Forward
      gex_out_1, atac_out_1 = model(gex_input, atac_input)
      # gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob = criterion(gex_out_1, atac_out_1, cell_type_input)
      # loss, loss_triplet, loss_cross, ct_match_prob = criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)
      
      ### Propagate loss
      if loss_type == "both":
        running_loss += loss.item()
        loss.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()    
        running_loss_cross += loss_cross.item()  
        
      elif loss_type == "entropy":
        running_loss += loss_cross.item()
        loss_cross.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()       
        running_loss_cross += loss_cross.item()  

      elif loss_type == "triplet":
        running_loss += loss_triplet.item()
        loss_triplet.backward()
        # Store other losses
        running_ct_prob += ct_match_prob.item()
        running_loss_triplet += loss_triplet.item()    
        running_loss_cross += loss_cross.item()  

      else:
        break

      ### update parameters
      optimizer.step()

      del gex_input
      del atac_input
      del cell_type_input
      torch.cuda.empty_cache()

    print('triplet_loss = ', loss_triplet.item(), '; cross_loss = ', loss_cross.item())
    if (e+1) % 10 == 0: 
      print('Epoch-{0}: lr = {1}, loss = {2}, entropy_loss = {3}, triplet loss = {4}, cell type match prob = {5}'.format(
          e+1, 
          optimizer.param_groups[0]['lr'], 
          running_loss / len(data_train), 
          running_loss_cross / len(data_train), 
          running_loss_triplet / len(data_train), 
          running_ct_prob / len(data_train)
          )
      )

    # scheduler.step()

In [42]:
train(model, data_train, epochs = 100, loss_type = "entropy") 

triplet_loss =  1.4923887252807617 ; cross_loss =  4.4708452224731445
triplet_loss =  1.4691064357757568 ; cross_loss =  4.358969688415527
triplet_loss =  1.4725902080535889 ; cross_loss =  4.275668144226074
triplet_loss =  1.458810806274414 ; cross_loss =  4.248224258422852
triplet_loss =  1.4263865947723389 ; cross_loss =  4.201499938964844
triplet_loss =  1.4231605529785156 ; cross_loss =  4.186094284057617
triplet_loss =  1.4690570831298828 ; cross_loss =  4.2050018310546875
triplet_loss =  1.3978073596954346 ; cross_loss =  4.162193298339844
triplet_loss =  1.4489340782165527 ; cross_loss =  4.18262243270874
triplet_loss =  1.448136329650879 ; cross_loss =  4.165668964385986
Epoch-10: lr = 0.0005, loss = 4.15670992732048, entropy_loss = 4.15670992732048, triplet loss = 1.389281901717186, cell type match prob = 0.06525408234447241
triplet_loss =  1.3621864318847656 ; cross_loss =  4.1512346267700195
triplet_loss =  1.4274544715881348 ; cross_loss =  4.166068077087402
triplet_loss =

In [43]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_100epochs')

In [None]:
train(model, data_train, epochs = 100, loss_type = "entropy")

In [None]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_200epochs')

In [45]:
train(model, data_train, epochs = 100, loss_type = "entropy")

triplet_loss =  1.0065240859985352 ; cross_loss =  4.028818130493164
triplet_loss =  1.0158454179763794 ; cross_loss =  4.031327724456787
triplet_loss =  0.9647536277770996 ; cross_loss =  4.031772613525391
triplet_loss =  0.9743040800094604 ; cross_loss =  4.031388759613037
triplet_loss =  0.9873414039611816 ; cross_loss =  4.028667449951172
triplet_loss =  0.9479917287826538 ; cross_loss =  4.027399063110352
triplet_loss =  0.976985514163971 ; cross_loss =  4.029253005981445
triplet_loss =  1.007913589477539 ; cross_loss =  4.03354024887085
triplet_loss =  1.04244065284729 ; cross_loss =  4.043071746826172
triplet_loss =  1.0016757249832153 ; cross_loss =  4.029843330383301
Epoch-10: lr = 0.0005, loss = 4.031217533349991, entropy_loss = 4.031217533349991, triplet loss = 0.9925471894443035, cell type match prob = 0.06751514300704002
triplet_loss =  0.9777954816818237 ; cross_loss =  4.0275115966796875
triplet_loss =  1.023401141166687 ; cross_loss =  4.034598350524902
triplet_loss =  

In [46]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_300epochs')

In [None]:
# model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
# model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_300epochs'))

<All keys matched successfully>

In [47]:
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE / 10)
train(model, data_train, epochs = 100, loss_type = "entropy")

triplet_loss =  0.938861608505249 ; cross_loss =  4.02773380279541
triplet_loss =  0.9537396430969238 ; cross_loss =  4.029812812805176
triplet_loss =  0.947595477104187 ; cross_loss =  4.024697303771973
triplet_loss =  0.9416443705558777 ; cross_loss =  4.023224830627441
triplet_loss =  0.9194592237472534 ; cross_loss =  4.023124694824219
triplet_loss =  0.9464056491851807 ; cross_loss =  4.02627420425415
triplet_loss =  0.8969321250915527 ; cross_loss =  4.023590087890625
triplet_loss =  0.929202139377594 ; cross_loss =  4.025856018066406
triplet_loss =  0.9222766160964966 ; cross_loss =  4.023134231567383
triplet_loss =  0.9089263677597046 ; cross_loss =  4.025506019592285
Epoch-10: lr = 5e-05, loss = 4.0254981994628904, entropy_loss = 4.0254981994628904, triplet loss = 0.9219556309282779, cell type match prob = 0.06718997908756137
triplet_loss =  0.9222042560577393 ; cross_loss =  4.029117107391357
triplet_loss =  0.9137256741523743 ; cross_loss =  4.027220249176025
triplet_loss = 

In [48]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_400epochs')

In [None]:
train(model, data_train, epochs = 100, loss_type = "entropy")

In [None]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_500epochs')

In [None]:
train(model, data_train, epochs = 100, loss_type = "entropy")

In [None]:
torch.save(model.state_dict(), 'drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_crossAttentionOnly_'+str(len(idx_train))+'cells_entropy_600epochs')

In [None]:
from numpy.lib.shape_base import row_stack
class bidirectTripletLoss2(nn.Module):
    r"""
    
    Output bidirectional triplet loss for two pairs of gex and atac
    ----------
    gex_0_mat: Matrix of GEX embeddings from self-attention (batch_size x embedding_size_0)

    Returns
    -------
    loss
    """
    def __init__(self, alpha, margin):
        super(bidirectTripletLoss2, self).__init__()

        self.alpha = alpha
        self.margin = margin
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def similarityScore(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1):
        r"""
        Output similarity scores for two pairs of gex and atac
        ----------


        Returns
        -------
        score: batch_size * batch_size
        similarity score between two modalities
        """ 

        # print(gex_mat.size())
        # print(atac_mat.size())

        # gex_mat0, gex_mat1 = torch.split(gex_mat, config.N_CHANNELS, dim = 1)
        # atac_mat0, atac_mat1 = torch.split(atac_mat,config.N_CHANNELS, dim = 1)

        gex_out_0 = nn.functional.normalize(gex_out_0, dim = 1)
        gex_out_1 = nn.functional.normalize(gex_out_1, dim = 1)
        atac_out_0 = nn.functional.normalize(atac_out_0, dim = 1)
        atac_out_1 = nn.functional.normalize(atac_out_1, dim = 1)
        
        score = torch.mm(gex_out_0, atac_out_0.transpose(0,1)) + self.alpha * torch.mm(gex_out_1, atac_out_1.transpose(0,1))
        return  score.to(config.DEVICE)

    def triplet(self, score_mat):

        true_score = torch.diagonal(score_mat)
        print("true_score:\n", true_score)
        # print("true_score dimension:\n", true_score.size())
        
        reduced_score_mat = score_mat - torch.diag(true_score) # set the diagnoal score to be zero
        
        neg_index_1 = torch.argmax(reduced_score_mat, dim = 1) # indices of hard negatives for GEX
        neg_index_2 = torch.argmax(reduced_score_mat, dim = 0) # indices of hard negatives for ATAC
        # print("neg_index_1:\n", neg_index_1); print("neg_index_2:\n", neg_index_2)
        # print("neg_index_1 dimension:\n", neg_index_1.size()); print("neg_index_2 dimension:\n", neg_index_2.size())
        neg_1 = score_mat[[range(config.BATCH_SIZE), neg_index_1]] # hard negatives for GEX 
        neg_2 = score_mat[[neg_index_2, range(config.BATCH_SIZE)]] # hard negatives for ATAC
        # print("neg_1:\n", neg_1); print("neg_2:\n", neg_2)
        
        loss_1 = torch.max(self.margin - true_score + neg_1, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))
        loss_2 = torch.max(self.margin - true_score + neg_2, torch.zeros(1, config.BATCH_SIZE).to(config.DEVICE))

        return torch.mean(loss_1 + loss_2)

    def crossEntropy(self, score_mat):

        batch_size = score_mat.size()[0]
        target = torch.arange(batch_size)

        loss_1 = self.cross_entropy_loss(score_mat, target.to(config.DEVICE))
        loss_2 = self.cross_entropy_loss(score_mat.T, target.to(config.DEVICE))

        return 0.5 * (loss_1 + loss_2)

    def cellTypeMatchingProbRow(self, score_mat, cell_type):

        # Collect list of index list for each cell type
        idx_in_type = collections.defaultdict(list)
        for i, x in enumerate(cell_type):
            idx_in_type[x].append(i)

        # Compute matching probs for each cell type
        score_mat_norm = score_mat.softmax(dim = 0)
        probs = []
        for idx in idx_in_type.values():
            prob_type = 0
            for i in idx:
                prob_type += score_mat_norm[i, idx].sum()
            probs.append(prob_type / len(idx)) 

        # Take average of matching prob from cell types
        ct_match_prob = torch.tensor(probs).mean()

        return ct_match_prob

    def cellTypeMatchingProb(self, score_mat, cell_type):
        row = self.cellTypeMatchingProbRow(score_mat, cell_type) # Softmax on rows (normalize GEX)
        col = self.cellTypeMatchingProbRow(score_mat.T, cell_type) # Softmax on cols (normalize ATAC) 
        
        return 0.5 * (row + col)

    def forward(self, gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type):
      
        score_mat = self.similarityScore(gex_out_0, gex_out_1, atac_out_0, atac_out_1)#; print("score_mat:\n", score_mat)

        loss_triplet = self.triplet(score_mat)
        loss_cross = self.crossEntropy(score_mat)
        loss = loss_triplet + loss_cross
        # print(score_mat); print(cell_type)
        ct_match_prob = self.cellTypeMatchingProb(score_mat, cell_type)

        return loss.to(config.DEVICE), loss_triplet, loss_cross, ct_match_prob

In [None]:
criterion2 =  bidirectTripletLoss2(alpha = 0.2, margin = 0.5).to(config.DEVICE)

In [None]:
def inference(model, data_val):

    # Initialize encoder & decoder 
    model.eval()
    model.to(config.DEVICE)
    criterion2.to(config.DEVICE)
    
    running_loss = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0    
    for iter, data in enumerate(data_val):
      gex_input = data['gex'].to(config.DEVICE)
      atac_input = data['atac'].to(config.DEVICE)
      cell_type_input = data['cell_type']

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob = criterion2(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)
      
      running_loss += loss_cross.item()
      running_loss_triplet += loss_triplet.item()
      running_ct_prob += ct_match_prob.item()

      del gex_input
      del atac_input
      del cell_type_input
      torch.cuda.empty_cache()

    del criterion2
    del model
    torch.cuda.empty_cache()
    return running_loss / len(data_val), running_loss_triplet / len(data_val), running_ct_prob / len(data_val)

In [None]:
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_100epochs'))

In [None]:
loss, loss_triplet, ct_match_prob = inference(model, data_val)
print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))

In [None]:
# model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
# model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_200epochs'))



<All keys matched successfully>

In [None]:
# loss, loss_triplet, ct_match_prob = inference(model, data_val)
# print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))

true_score:
 tensor([1.1968, 1.1975, 1.1788, 1.0766, 1.1927, 1.0490, 1.1898, 1.1149, 1.1503,
        1.1836, 1.1979, 0.8274, 1.1933, 1.0821, 1.0032, 0.8467, 0.8696, 1.1873,
        1.1955, 0.5229, 1.1950, 1.1912, 1.0142, 0.8488, 1.1869, 1.1733, 0.9110,
        1.1491, 1.1952, 0.1134, 1.1834, 1.1447, 1.1933, 1.1045, 1.0960, 1.1966,
        0.6721, 0.9657, 0.7136, 1.1979, 1.1980, 1.1261, 0.6802, 1.1790, 1.1054,
        0.9145, 1.1803, 0.7301, 1.1383, 1.1890, 0.9857, 1.1959, 1.0625, 1.1930,
        1.0289, 0.9559, 0.7853, 1.1971, 1.1984, 1.0375, 0.9430, 1.1472, 1.1848,
        0.9206, 1.0730, 0.9288, 1.0737, 0.3035, 1.1099, 0.9366, 0.0568, 1.1854,
        1.0634, 0.9947, 1.1829, 1.1681, 1.1962, 1.0149, 1.1407, 1.0748, 1.1396,
        1.1840, 0.6322, 1.0293, 1.1723, 0.9474, 1.1955, 0.4931, 1.1488, 1.1931,
        1.1990, 1.1841, 1.1841, 1.1679, 1.1368, 0.9413, 0.5911, 1.1852, 1.1456,
        1.1758, 0.4013, 1.1901, 1.1826, 1.1193, 1.1469, 1.1721, 1.1958, 0.8836,
        0.6018, 1.1980, 1.1

In [None]:
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_300epochs'))

In [None]:
loss, loss_triplet, ct_match_prob = inference(model, data_val)
print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))

In [None]:
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_400epochs'))

In [None]:
loss, loss_triplet, ct_match_prob = inference(model, data_val)
print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))

In [None]:
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_500epochs'))

In [None]:
loss, loss_triplet, ct_match_prob = inference(model, data_val)
print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))

In [None]:
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 50, kernel_size_atac_2 = 10, index = index).to(config.DEVICE)
model.load_state_dict(torch.load('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_20480cells_entropy_600epochs'))

In [None]:
loss, loss_triplet, ct_match_prob = inference(model, data_val)
print('loss = {0}, triplet loss = {1}, cell type match prob = {2}'.format(loss, loss_triplet, ct_match_prob))