In [1]:
import scanpy as sc
import scvelo as scv
import torch
import torch.nn as nn

from dataloaders import setup_dataloaders_ranked

In [3]:
# Seed torch's global RNG so the randomly initialised embedding tables below are reproducible
# across Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Dimensions shared by every later cell. Each gene contributes an unspliced and a spliced
# feature, so per-cell sequences have length 2 * num_genes.
num_genes = 2000
embedding_dim = 128

In [5]:
# Load the scVelo pancreas dataset and keep the `num_genes` most variable genes, so the number
# of features always matches the embedding tables sized from `num_genes` (previously a separate
# hardcoded 2000 that could silently drift).
adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=num_genes)
# Compute moments of un/spliced abundances over a 200-neighbour graph; this adds the
# adata.layers "Ms" and "Mu" used further down.
scv.pp.moments(adata, n_neighbors=200)

Filtered out 20801 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Extracted 2000 highly variable genes.
Logarithmized X.
computing neighbors


  log1p(adata)


    finished (0:00:04) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:02) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)


In [8]:
# Build train / test / full loaders (setup_dataloaders_ranked is imported in the top import cell).
# NOTE(review): the batch shapes displayed below show a leading dimension of 256, which does not
# match batch_size=12 here -- those outputs look stale; re-run to refresh them.
train, test, full = setup_dataloaders_ranked(adata, batch_size=12, num_genes=num_genes, embedding_dim=embedding_dim)

In [9]:
# Inspect just the first batch from the `full` loader; it yields a 3-tuple (see shapes below).
[tokens, combined, idx] = next(iter(full))

In [11]:
tokens.size()  # batched token tensor

torch.Size([256, 4000, 128])

In [12]:
combined.size()  # batched combined (unspliced + spliced) values

torch.Size([256, 4000])

In [13]:
idx.size()  # one index per item in the batch

torch.Size([256])

In [5]:
# Token table: one learned vector per feature id; ids cover the unspliced and spliced copy of
# every gene, hence 2 * num_genes rows.
embeddings = nn.Embedding(num_embeddings=2 * num_genes, embedding_dim=embedding_dim)
# Learned positional table: one vector per rank position in a length-(2 * num_genes) sequence.
pos_embeddings = nn.Parameter(torch.randn(2 * num_genes, embedding_dim))

In [6]:
# Build the feature vector of a single cell: its unspliced moments followed by its spliced moments.
cell_idx = 0  # row of `adata` to tokenize
unspliced = torch.tensor(adata.layers["Mu"][cell_idx], dtype=torch.float32)
spliced = torch.tensor(adata.layers["Ms"][cell_idx], dtype=torch.float32)
combined = torch.cat([unspliced, spliced])
# The embedding tables are sized 2 * num_genes; fail early if preprocessing kept a different
# number of genes instead of hitting an index/broadcast error later.
assert combined.shape[0] == 2 * num_genes, (combined.shape, num_genes)


In [7]:
# Rank features by expression, highest first, and keep the feature ids in that order.
# A stable sort keeps tied values (e.g. the many zero-expression features) in ascending
# feature-id order, so the ranking is deterministic instead of arbitrary among ties.
ranked_indices = torch.sort(combined, descending=True, stable=True).indices
ranked_indices

tensor([ 865, 2295, 3060,  ..., 1774, 2286, 1768])

In [10]:
combined.index_select(0, ranked_indices)  # values in ranked order; should be non-increasing

tensor([80.4759, 78.5706, 73.4623,  ...,  0.0000,  0.0000,  0.0000])

In [12]:
ranked_indices.size()  # one entry per feature: 2 * num_genes

torch.Size([4000])

In [13]:
# Look up one learned vector per ranked feature id: row i embeds the i-th highest-valued feature.
# NOTE(review): the shape displayed below ends in 10, which predates embedding_dim = 128; re-run.
tokens = embeddings(ranked_indices)
tokens.size()

torch.Size([4000, 10])

In [15]:
pos_embeddings.size()  # NOTE(review): displayed shape ending in 10 predates embedding_dim = 128; re-run

torch.Size([4000, 10])

In [None]:
# Add one positional vector per rank position. Recompute from `ranked_indices` instead of
# `tokens += ...` so re-running this cell does not add the positions a second time, and slice
# the positional table by the actual sequence length rather than by its full size.
tokens = embeddings(ranked_indices) + pos_embeddings[: ranked_indices.shape[0]]
assert tokens.shape == (ranked_indices.shape[0], embedding_dim)
tokens