In [None]:
# Mount Google Drive; the data cells below read .h5ad files through the relative
# path "drive/MyDrive/..." (presumably resolved from Colab's default cwd /content).
import google.colab.drive as drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
!pip install scanpy --quiet

[K     |████████████████████████████████| 2.0 MB 14.4 MB/s 
[K     |████████████████████████████████| 88 kB 8.4 MB/s 
[K     |████████████████████████████████| 96 kB 5.7 MB/s 
[K     |████████████████████████████████| 9.4 MB 51.2 MB/s 
[K     |████████████████████████████████| 295 kB 81.9 MB/s 
[K     |████████████████████████████████| 965 kB 83.8 MB/s 
[K     |████████████████████████████████| 1.1 MB 64.2 MB/s 
[K     |████████████████████████████████| 63 kB 2.1 MB/s 
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for session-info (setup.py) ... [?25l[?25hdone


In [None]:
import torch
from torch import nn
from torch.autograd import Variable
import anndata as ad
import numpy as np
import os
from argparse import Namespace
# Shared hyper-parameters / dataset dimensions used by the modules below.
config = Namespace(**{
    "NUM_WORKERS": 4,           # not referenced in the visible cells (presumably for a DataLoader)
    "N_GENES": 13431,           # hard-coded GEX feature count; must match adata_gex
    "N_PEAKS": 116465,          # hard-coded ATAC peak count; must match adata_atac
    "N_CHANNELS": 32,           # conv out_channels and transformer d_model
    "MAX_SEQ_LEN_GEX": 1500,    # only referenced by commented-out code
    "MAX_SEQ_LEN_ATAC": 15000,  # only referenced by commented-out code
})

In [None]:
# Load the pre-processed gene-expression (GEX) AnnData from the mounted Drive.
gex_h5ad_path = "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad"
adata_gex = ad.read_h5ad(gex_h5ad_path)

In [None]:
# Load the pre-processed chromatin-accessibility (ATAC) AnnData from the mounted Drive.
atac_h5ad_path = "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad"
adata_atac = ad.read_h5ad(atac_h5ad_path)

In [None]:
def get_chr_index(adata_atac):
  r"""
  Group ATAC feature (peak) column positions by chromosome.

  The chromosome of a feature is taken to be the part of its name before the
  first "-" (names are assumed to look like "chr1-<start>-<end>"; this is not
  validated).

  Parameters
  ----------
  adata_atac
      AnnData for ATAC; only ``adata_atac.var.index`` (feature names) is read.

  Returns
  -------
  chr_index
      Dict mapping chromosome name -> list of integer column positions into
      ``adata_atac.var``, each list in ascending order. Keys are inserted in
      sorted order, so iterating the dict is deterministic.
  """
  # Single pass over the feature names; the previous version re-scanned the
  # whole name list once per chromosome.
  groups = {}
  for pos, feature_name in enumerate(adata_atac.var.index):
    chr_name = feature_name.split("-")[0]
    groups.setdefault(chr_name, []).append(pos)

  # Sorted key order reproduces the np.unique-based ordering of the original.
  return {chr_name: groups[chr_name] for chr_name in sorted(groups)}

In [None]:
## CNN module for the GEX modality
class gexCNN(nn.Module):
    r"""
    1-D CNN that turns a GEX profile into a sequence of embeddings.

    Pipeline: Conv1d(1 -> out_channels) -> ReLU -> MaxPool1d along the gene axis,
    then channels are moved last.

    Parameters
    ----------
    kernel_size
        Conv1d kernel width.
    stride, padding
        Conv1d stride / zero padding (defaults 10 / 10, the previously
        hard-coded values).
    pool_size, pool_stride
        MaxPool1d kernel / stride (defaults 2 / 1).
    out_channels
        Embedding dimension; defaults to ``config.N_CHANNELS`` when None.
    """
    def __init__(self, kernel_size, stride = 10, padding = 10,
                 pool_size = 2, pool_stride = 1, out_channels = None):
        super(gexCNN, self).__init__()

        # Conv layer
        self.in_channels = 1
        # Global default resolved lazily so an explicit value needs no `config`.
        self.out_channels = config.N_CHANNELS if out_channels is None else out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_stride = pool_stride
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels,
                      out_channels = self.out_channels,
                      kernel_size = self.kernel_size,
                      stride = self.stride,
                      padding = self.padding),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size = self.pool_size,
                         stride = self.pool_stride)
        )

    def forward(self, x):
        r"""
        Generate GEX embeddings.

        Parameters
        ----------
        x
            Pre-processed GEX data (batch_size x 1 x N_GENES). It is cast to the
            conv weight dtype first, so e.g. float64 input does not raise a dtype
            mismatch against float32 weights.

        Returns
        -------
        gex_embed
            GEX embeddings of a batch (batch_size x seq_len x out_channels), where
            seq_len is the pooled conv output length.
        """
        gex_embed = self.convs(x.to(self.convs[0].weight.dtype))
        # Conv1d output is (batch, channels, length); put channels last.
        return gex_embed.transpose(1, 2)

In [None]:
# Smoke test for gexCNN on the first 5 cells
gex_first5 = adata_gex.layers['log_norm'][:5].todense()
x = torch.tensor(np.asarray(gex_first5)).unsqueeze(1)  # add the single input channel
print(x.size())
model = gexCNN(kernel_size = 10)
print(model(x).size())

torch.Size([5, 1, 13431])
torch.Size([5, 1344, 32])


In [None]:
# CNN module for the ATAC modality: one shared conv stack applied per chromosome
class atacCNN(nn.Module):
    r"""
    1-D CNN that turns an ATAC profile into a sequence of embeddings.

    The input columns are split by chromosome according to ``index``. The same
    Conv1d -> ReLU -> MaxPool1d stack is applied to each chromosome's columns
    independently, so no window straddles two chromosomes. The per-chromosome
    outputs are then concatenated along the sequence axis in the iteration
    order of ``index``.

    Parameters
    ----------
    index
        Dict chromosome name -> list of column positions into the last axis of
        the input (e.g. the output of ``get_chr_index``). Every chromosome must
        have enough columns for at least one conv window and one pooling window,
        otherwise the conv/pool layers raise.
    kernel_size
        Conv1d kernel width.
    stride, padding
        Conv1d stride / zero padding (defaults 10 / 10, the previously
        hard-coded values).
    pool_size, pool_stride
        MaxPool1d kernel / stride (defaults 2 / 1).
    out_channels
        Embedding dimension; defaults to ``config.N_CHANNELS`` when None.
    """
    def __init__(self, index, kernel_size, stride = 10, padding = 10,
                 pool_size = 2, pool_stride = 1, out_channels = None):
        super(atacCNN, self).__init__()
        self.index = index

        # Conv layer
        self.in_channels = 1
        # Global default resolved lazily so an explicit value needs no `config`.
        self.out_channels = config.N_CHANNELS if out_channels is None else out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_stride = pool_stride
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels = self.in_channels,
                      out_channels = self.out_channels,
                      kernel_size = self.kernel_size,
                      stride = self.stride,
                      padding = self.padding),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size = self.pool_size,
                         stride = self.pool_stride)
        )

    def forward(self, x):
        r"""
        Generate ATAC embeddings.

        Parameters
        ----------
        x
            Pre-processed ATAC data (batch_size x 1 x N_PEAKS). Each chromosome
            slice is cast to the conv weight dtype before the conv stack.

        Returns
        -------
        atac_embed
            ATAC embeddings of a batch (batch_size x seq_len x out_channels), where
            seq_len is the sum of the pooled conv output lengths of all chromosomes.
        """
        weight_dtype = self.convs[0].weight.dtype
        per_chr = []
        for chr_name, idx in self.index.items():
            x_chr = x[:, :, idx].to(weight_dtype)  # columns of this chromosome only
            per_chr.append(self.convs(x_chr))
        # Concatenate along the length axis of the (batch, channels, length) outputs.
        atac_embed = torch.cat(per_chr, dim = 2)
        return atac_embed.transpose(1, 2)

In [None]:
# Smoke test for atacCNN on the first 5 cells
atac_first5 = adata_atac.layers['log_norm'][:5].todense()
x = torch.tensor(np.asarray(atac_first5)).unsqueeze(1)  # add the single input channel
print(x.size())
index = get_chr_index(adata_atac)
model = atacCNN(kernel_size = 50, index = index)
print(model(x).size())

torch.Size([5, 1, 116465])
torch.Size([5, 11566, 32])


In [None]:
class MultimodalAttention(nn.Module):
    r"""
    Self-attention within each modality, followed by joint attention over both.

    Three ``nn.TransformerEncoder`` stacks are used:
      * a GEX-only and an ATAC-only encoder (self-attention within a modality);
      * a multimodal encoder over the concatenation of both context sequences,
        so GEX positions can attend to ATAC positions and vice versa.

    All encoder layers use ``batch_first=True`` because the inputs are
    (batch_size x seq_len x dim_size). With PyTorch's default
    (``batch_first=False``), the first axis is read as the sequence axis. Attention
    then mixed the *cells of a batch* instead of sequence positions, and the joint
    encoder could not attend across modalities.

    NOTE(review): attention is quadratic in sequence length. The joint encoder
    sees seq_len_gex + seq_len_atac positions, which can be large for the CNN
    settings used in this notebook, so expect high memory use.

    Parameters
    ----------
    d_model
        Embedding size; defaults to ``config.N_CHANNELS`` when None.
    nhead_gex, nhead_atac, nhead_multi
        Attention heads per encoder (each must divide ``d_model``).
    nlayer_gex, nlayer_atac, nlayer_multi
        Encoder layers per stack.
    verbose
        If True, ``forward`` prints the intermediate context shapes.
    """
    def __init__(self, d_model = None, nhead_gex = 1, nhead_atac = 4, nhead_multi = 4,
                 nlayer_gex = 1, nlayer_atac = 1, nlayer_multi = 1, verbose = False):
        super(MultimodalAttention, self).__init__()
        # Global default resolved lazily so an explicit value needs no `config`.
        self.d_model = config.N_CHANNELS if d_model is None else d_model
        self.nhead_gex = nhead_gex
        self.nhead_atac = nhead_atac
        self.nhead_multi = nhead_multi
        self.nlayer_gex = nlayer_gex
        self.nlayer_atac = nlayer_atac
        self.nlayer_multi = nlayer_multi
        self.verbose = verbose

        self.encoder_layer_gex = nn.TransformerEncoderLayer(d_model = self.d_model, nhead = self.nhead_gex, batch_first = True)
        self.transformer_encoder_gex = nn.TransformerEncoder(self.encoder_layer_gex, num_layers = self.nlayer_gex)

        self.encoder_layer_atac = nn.TransformerEncoderLayer(d_model = self.d_model, nhead = self.nhead_atac, batch_first = True)
        self.transformer_encoder_atac = nn.TransformerEncoder(self.encoder_layer_atac, num_layers = self.nlayer_atac)

        self.encoder_layer_multi = nn.TransformerEncoderLayer(d_model = self.d_model, nhead = self.nhead_multi, batch_first = True)
        self.transformer_encoder_multi = nn.TransformerEncoder(self.encoder_layer_multi, num_layers = self.nlayer_multi)

    def forward(self, gex_embed, atac_embed):
        r"""
        Apply the two self-attention encoders, then the joint encoder.

        Parameters
        ----------
        gex_embed
            GEX embeddings of a batch (batch_size x seq_len_gex x dim_size)
        atac_embed
            ATAC embeddings of a batch (batch_size x seq_len_atac x dim_size)

        Returns
        -------
        gex_out_0, gex_out_1, atac_out_0, atac_out_1
            Each (batch_size x dim_size):
              * ``*_0`` is the mean over sequence positions of the modality's
                self-attention context;
              * ``*_1`` is the mean over that modality's positions in the joint
                (multimodal) context.
        """
        seq_len_gex = gex_embed.size(1)

        gex_context = self.transformer_encoder_gex(gex_embed)
        atac_context = self.transformer_encoder_atac(atac_embed)
        if self.verbose:
            print(gex_context.size())
            print(atac_context.size())

        # Average self-attention fragment representation
        gex_out_0 = gex_context.mean(dim = 1)
        atac_out_0 = atac_context.mean(dim = 1)

        # Concatenate along the sequence axis so the joint encoder attends across modalities.
        multi_embed = torch.cat((gex_context, atac_context), dim = 1)
        multi_context = self.transformer_encoder_multi(multi_embed)
        if self.verbose:
            print(multi_context.size())

        multi_context_gex = multi_context[:, :seq_len_gex, :]
        multi_context_atac = multi_context[:, seq_len_gex:, :]

        # Average cross-attention fragment representation
        gex_out_1 = multi_context_gex.mean(dim = 1)
        atac_out_1 = multi_context_atac.mean(dim = 1)

        return gex_out_0, gex_out_1, atac_out_0, atac_out_1

In [None]:
# End-to-end smoke test on 5 cells: CNN embeddings -> multimodal attention
index = get_chr_index(adata_atac)

first5_dense = [np.asarray(adata.layers['log_norm'][:5].todense()) for adata in (adata_gex, adata_atac)]
x_gex, x_atac = (torch.tensor(arr).unsqueeze(1) for arr in first5_dense)

# Construction order is kept (it determines which random weights each module gets).
gex_cnn = gexCNN(kernel_size = 10)
atac_cnn = atacCNN(kernel_size = 20, index = index)
multi_attention = MultimodalAttention()

gex_embed, atac_embed = gex_cnn(x_gex), atac_cnn(x_atac)
gex_out_0, gex_out_1, atac_out_0, atac_out_1 = multi_attention(gex_embed, atac_embed)


torch.Size([5, 1344, 32])
torch.Size([5, 11635, 32])
torch.Size([5, 12979, 32])


In [None]:
class Custom_MSE(nn.Module):
  """Mean squared error: the mean over all elements of (predictions - target) ** 2."""
  def __init__(self):
    super().__init__()

  def forward(self, predictions, target):
    residual = predictions - target
    return residual.pow(2).mean()
  
class bidirectTripletLoss(nn.Module):
    r"""
    Bidirectional hard-negative triplet ranking loss for paired GEX / ATAC embeddings.

    Row i of ``gex_mat`` and row i of ``atac_mat`` are treated as a positive pair
    (same cell); every other row of the batch is a negative. Each embedding is the
    concatenation [self-attention part | cross-attention part], each part
    ``n_channels`` wide. The similarity of two embeddings is

        s(g, a) = <g_0, a_0> + alpha * <g_1, a_1>

    For every positive pair, the hardest (most similar) negative is taken in both
    directions (GEX -> ATAC and ATAC -> GEX), and a hinge with ``margin`` is applied.

    Parameters
    ----------
    alpha
        Weight of the cross-attention similarity term.
    margin
        Hinge margin.
    n_channels
        Width of each half of an embedding; defaults to ``config.N_CHANNELS`` when None.
    """
    def __init__(self, alpha, margin, n_channels = None):
        super(bidirectTripletLoss, self).__init__()
        self.alpha = alpha
        self.margin = margin
        # Global default resolved lazily so an explicit value needs no `config`.
        self.n_channels = config.N_CHANNELS if n_channels is None else n_channels

    def similarityScore(self, gex_vec, atac_vec):
        r"""
        Similarity between GEX and ATAC embeddings.

        Parameters
        ----------
        gex_vec, atac_vec
            Either single embeddings (1-D, length 2*n_channels) or batches
            (2-D, batch x 2*n_channels).

        Returns
        -------
        A scalar for 1-D inputs, or a (batch_gex x batch_atac) score matrix for
        2-D inputs.
        """
        n = self.n_channels
        if gex_vec.dim() == 1:
            return torch.dot(gex_vec[:n], atac_vec[:n]) + self.alpha * torch.dot(gex_vec[n:], atac_vec[n:])
        return (torch.mm(gex_vec[:, :n], atac_vec[:, :n].t())
                + self.alpha * torch.mm(gex_vec[:, n:], atac_vec[:, n:].t()))

    def forward(self, gex_mat, atac_mat):
        r"""
        Bi-directional triplet ranking loss.

        Parameters
        ----------
        gex_mat, atac_mat
            (batch_size x 2*n_channels) embeddings with matching batch sizes.

        Returns
        -------
        Scalar mean over the batch of the GEX->ATAC plus ATAC->GEX hinge losses.
        A batch of one pair has no negatives and yields 0.
        """
        score_mat = self.similarityScore(gex_mat, atac_mat)
        true_score = torch.diagonal(score_mat)

        # Exclude positives from the negative search with -inf, not 0: scores can be
        # negative, and a zeroed diagonal could then be selected as the "negative".
        eye = torch.eye(*score_mat.shape, dtype = torch.bool, device = score_mat.device)
        neg_score_mat = score_mat.masked_fill(eye, float("-inf"))
        neg_1 = neg_score_mat.max(dim = 1).values  # hardest ATAC negative for each GEX row
        neg_2 = neg_score_mat.max(dim = 0).values  # hardest GEX negative for each ATAC column

        loss_1 = (self.margin - true_score + neg_1).clamp(min = 0)
        loss_2 = (self.margin - true_score + neg_2).clamp(min = 0)
        return torch.mean(loss_1 + loss_2)

In [None]:
# Scratch check of the shapes produced by single-id embedding lookups
embedding = nn.Embedding(1000, 128)
# Ids are drawn in the order anchor, positive, negative (one randint call each).
anchor_ids, positive_ids, negative_ids = (torch.randint(0, 1000, (1,)) for _ in range(3))
anchor, positive, negative = (embedding(ids) for ids in (anchor_ids, positive_ids, negative_ids))

In [None]:
for emb in (anchor, positive, negative):
    print(emb.size())

torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 128])


In [None]:
import torch
import torch.nn as nn
  
class bidirectTripletLoss(nn.Module):
    r"""
    Bidirectional hard-negative triplet ranking loss for paired GEX / ATAC embeddings.

    Row i of ``gex_mat`` and row i of ``atac_mat`` are treated as a positive pair;
    every other row of the batch is a negative. Each embedding is
    [self-attention part | cross-attention part], each part ``n_channels`` wide.
    The similarity is

        s(g, a) = <g_0, a_0> + alpha * <g_1, a_1>

    Parameters
    ----------
    alpha
        Weight of the cross-attention similarity term.
    margin
        Hinge margin.
    n_channels
        Width of each half of an embedding; defaults to ``config.N_CHANNELS`` when None.
    """
    def __init__(self, alpha, margin, n_channels = None):
        super(bidirectTripletLoss, self).__init__()

        self.alpha = alpha
        self.margin = margin
        # Global default resolved lazily so an explicit value needs no `config`.
        self.n_channels = config.N_CHANNELS if n_channels is None else n_channels

    def similarityScore(self, gex_mat, atac_mat):
        r"""
        Pairwise similarity scores between GEX and ATAC embeddings.

        Parameters
        ----------
        gex_mat
            batch_gex x (2 * n_channels)
        atac_mat
            batch_atac x (2 * n_channels)

        Returns
        -------
        score
            batch_gex x batch_atac similarity matrix
        """
        n = self.n_channels
        gex_mat0, gex_mat1 = gex_mat[:, :n], gex_mat[:, n:]
        atac_mat0, atac_mat1 = atac_mat[:, :n], atac_mat[:, n:]

        return torch.mm(gex_mat0, atac_mat0.t()) + self.alpha * torch.mm(gex_mat1, atac_mat1.t())

    def forward(self, gex_mat, atac_mat):
        r"""
        Bi-directional triplet ranking loss with in-batch hardest negatives.

        Parameters
        ----------
        gex_mat, atac_mat
            (batch_size x 2*n_channels) embeddings with matching batch sizes.

        Returns
        -------
        Scalar mean over the batch of the GEX->ATAC plus ATAC->GEX hinge losses.
        A batch of one pair has no negatives and yields 0.
        """
        score_mat = self.similarityScore(gex_mat, atac_mat)
        true_score = torch.diagonal(score_mat)

        # Mask positives with -inf (not 0): scores can be negative, and a zeroed
        # diagonal would then be chosen as the "hard negative" by the max.
        eye = torch.eye(*score_mat.shape, dtype = torch.bool, device = score_mat.device)
        neg_score_mat = score_mat.masked_fill(eye, float("-inf"))

        neg_1 = neg_score_mat.max(dim = 1).values  # hard negatives for GEX (best wrong ATAC)
        neg_2 = neg_score_mat.max(dim = 0).values  # hard negatives for ATAC (best wrong GEX)

        # clamp keeps everything on the input's device (no stray CPU zeros tensor).
        loss_1 = (self.margin - true_score + neg_1).clamp(min = 0)
        loss_2 = (self.margin - true_score + neg_2).clamp(min = 0)

        return torch.mean(loss_1 + loss_2)

In [None]:
import random
random.seed(0)        # seeds only Python's `random`; it does not affect torch.randn
torch.manual_seed(0)  # torch.randn draws from torch's RNG, so seed that for reproducibility
# 64 = 2 * config.N_CHANNELS: [self-attention part | cross-attention part]
gex_mat = torch.randn([5, 64])
atac_mat = torch.randn([5, 64])
loss = bidirectTripletLoss(alpha = 0.2, margin = 2)
res = loss(gex_mat, atac_mat)
res