In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive/ (Colab-only API).
# NOTE(review): later cells read data through the relative path
# "drive/MyDrive/...", which resolves under this mount only when the working
# directory is /content -- confirm before running outside the default Colab setup.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
def get_chr_index(adata_atac):
  r"""
  Group the feature positions of an ATAC AnnData object by chromosome.

  The chromosome of a feature is the part of its name in
  ``adata_atac.var.index`` that precedes the first ``"-"``
  (e.g. ``"chr1-100-200"`` -> ``"chr1"``).

  Parameters
  ----------
  adata_atac
      annData for ATAC; only ``adata_atac.var.index`` is read.
  Returns
  -------
  chr_index
      Dictionary mapping each chromosome name to the list of integer
      positions (in ascending order) of its features within
      ``adata_atac.var.index``. Keys are inserted in sorted order of the
      chromosome name (plain string ordering, so ``chr10`` precedes ``chr2``).
  """
  # Single pass over the feature names: each position is appended to the
  # bucket of its chromosome, instead of rescanning the whole list once per
  # chromosome.
  positions = dict()
  for pos, name in enumerate(adata_atac.var.index):
    positions.setdefault(name.split("-")[0], []).append(pos)

  # Rebuild the dict with sorted keys so iteration order is deterministic and
  # independent of the order in which chromosomes first appear.
  return {chrom: positions[chrom] for chrom in sorted(positions)}

In [3]:
!pip install scanpy --quiet

[K     |████████████████████████████████| 2.0 MB 4.9 MB/s 
[K     |████████████████████████████████| 9.4 MB 40.0 MB/s 
[K     |████████████████████████████████| 88 kB 6.2 MB/s 
[K     |████████████████████████████████| 96 kB 4.9 MB/s 
[K     |████████████████████████████████| 295 kB 70.5 MB/s 
[K     |████████████████████████████████| 965 kB 59.9 MB/s 
[K     |████████████████████████████████| 1.1 MB 58.1 MB/s 
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for session-info (setup.py) ... [?25l[?25hdone


In [54]:
import torch
from torch import nn
from torch.autograd import Variable
import anndata as ad
import numpy as np
import os
from argparse import Namespace
# Notebook-wide constants, accessed as attributes (config.N_GENES, ...).
config = Namespace(
    # Equals the size of the last axis of the GEX tensor printed by the
    # GEX_CNN test cell (torch.Size([5, 1, 13431])).
    N_GENES = 13431,
    # Equals the size of the last axis of the ATAC tensor printed by the
    # ATAC_CNN test cell (torch.Size([5, 1, 116465])).
    N_PEAKS = 116465,
    # Number of Conv1d output channels used by GEX_CNN and ATAC_CNN.
    N_CHANNELS = 32
)

In [62]:
## CNN module for the GEX modality
class GEX_CNN(nn.Module):
    """1-D convolutional feature extractor for the GEX input.

    Applies ``Conv1d -> ReLU -> MaxPool1d`` along the last axis of the input
    and returns the result with the channel axis moved to the end.

    Parameters
    ----------
    kernel_size
        Kernel size of the convolution.
    stride, padding
        Stride and zero-padding of the convolution (both default to 10).
    pool_size, pool_stride
        Kernel size and stride of the max-pooling layer (default 2 and 1).
    out_channels
        Number of convolution filters. ``None`` (the default) uses
        ``config.N_CHANNELS``.
    """

    def __init__(self, kernel_size, stride=10, padding=10,
                 pool_size=2, pool_stride=1, out_channels=None):
        super(GEX_CNN, self).__init__()

        # Conv layer
        self.in_channels = 1
        # Resolve the notebook-wide default lazily so an explicit value can be
        # supplied without touching the global config.
        self.out_channels = config.N_CHANNELS if out_channels is None else out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_stride = pool_stride
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels=self.in_channels,
                      out_channels=self.out_channels,
                      kernel_size=self.kernel_size,
                      stride=self.stride,
                      padding=self.padding),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=self.pool_size,
                         stride=self.pool_stride)
        )

        # NOTE: a fully connected head on the flattened output is currently
        # disabled. For an input of length N its in_features would be
        # out_channels * L_pool, with
        #   L_conv = (N + 2 * padding - kernel_size) // stride + 1
        #   L_pool = (L_conv - pool_size) // pool_stride + 1

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, length) to (batch, pooled_length, out_channels)."""
        # Cast to the dtype of the conv weights so that non-float32 input
        # (e.g. float64 from a dense numpy matrix) does not raise a dtype
        # mismatch; this is a no-op when the dtypes already agree.
        x = x.to(self.convs[0].weight.dtype)
        x = self.convs(x)
        # Move channels last: (batch, channels, positions) -> (batch, positions, channels).
        return x.transpose(1, 2)

In [59]:
# Test for GEX_CNN(): load the processed GEX AnnData object from the mounted Drive.
# The path is relative to the current working directory.
adata_gex = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")

In [63]:
# Smoke test: take the first 5 rows of the 'log_norm' layer, densify them and
# insert a singleton channel axis at dim 1, giving (5, 1, n_features).
x = torch.tensor(np.asarray(adata_gex.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells
print(x.size())
model = GEX_CNN(kernel_size = 10)
# The output is (cells, pooled positions, channels); see the recorded output below.
print(model(x).size())

torch.Size([5, 1, 13431])
torch.Size([5, 1344, 32])


In [64]:
# CNN module for the ATAC modality
class ATAC_CNN(nn.Module):
    """1-D convolutional feature extractor for the ATAC input, applied per chromosome.

    The last axis of the input is split into groups given by ``index``; the
    same ``Conv1d -> ReLU -> MaxPool1d`` stack (shared weights) is applied to
    every group separately, so no convolution window spans two groups. The
    per-group outputs are concatenated along the position axis and returned
    with the channel axis moved to the end.

    Parameters
    ----------
    index
        Mapping from chromosome name to the positions (along the last input
        axis) that belong to it, e.g. the result of ``get_chr_index``.
        Groups are processed in the mapping's iteration order.
    kernel_size
        Kernel size of the convolution.
    stride, padding
        Stride and zero-padding of the convolution (both default to 10).
    pool_size, pool_stride
        Kernel size and stride of the max-pooling layer (default 2 and 1).
    out_channels
        Number of convolution filters. ``None`` (the default) uses
        ``config.N_CHANNELS``.
    """

    def __init__(self, index, kernel_size, stride=10, padding=10,
                 pool_size=2, pool_stride=1, out_channels=None):
        super(ATAC_CNN, self).__init__()
        self.index = index

        # Conv layer
        self.in_channels = 1
        # Resolve the notebook-wide default lazily so an explicit value can be
        # supplied without touching the global config.
        self.out_channels = config.N_CHANNELS if out_channels is None else out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_stride = pool_stride
        self.convs = nn.Sequential(
            nn.Conv1d(in_channels=self.in_channels,
                      out_channels=self.out_channels,
                      kernel_size=self.kernel_size,
                      stride=self.stride,
                      padding=self.padding),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=self.pool_size,
                         stride=self.pool_stride)
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, length) to (batch, total_pooled_length, out_channels)."""
        # Cast once, up front, to the dtype of the conv weights (float32 for a
        # default-constructed module) instead of casting every slice.
        x = x.to(self.convs[0].weight.dtype)
        # Run the shared conv stack on each chromosome's columns independently.
        per_chrom = [self.convs(x[:, :, cols]) for cols in self.index.values()]
        # Concatenate along the position axis, then move channels last:
        # (batch, channels, positions) -> (batch, positions, channels).
        return torch.cat(per_chrom, dim=2).transpose(1, 2)

In [6]:
# Test for ATAC_CNN(): load the processed ATAC AnnData object from the mounted Drive.
# The path is relative to the current working directory.
adata_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad")

In [66]:
# Smoke test: take the first 5 rows of the 'log_norm' layer, densify them and
# insert a singleton channel axis at dim 1, giving (5, 1, n_features).
x = torch.tensor(np.asarray(adata_atac.layers['log_norm'][:5].todense())).unsqueeze(1) # 5 cells
print(x.size())
# Positions of the features grouped by chromosome (parsed from adata_atac.var.index).
index = get_chr_index(adata_atac)
model = ATAC_CNN(kernel_size = 10, index = index)
# The output is (cells, pooled positions summed over chromosomes, channels);
# see the recorded output below.
print(model(x).size())

torch.Size([5, 1, 116465])
torch.Size([5, 11658, 32])


In [None]:
class GEX_attention(nn.Module):
    """Placeholder for the GEX attention module (not implemented yet).

    The constructor takes no arguments and registers no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Stub: ``x`` is ignored and None is returned until the layer is written.
        return None

In [None]:
class ATAC_attention(nn.Module):
    """Placeholder for the ATAC attention module (not implemented yet).

    The constructor takes no arguments and registers no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Stub: ``x`` is ignored and None is returned until the layer is written.
        return None

In [None]:
class cross_attention(nn.Module):
    """Placeholder for the cross-attention module (not implemented yet).

    The constructor takes no arguments and registers no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Stub: ``x`` is ignored and None is returned until the layer is written.
        return None