In [1]:
!pip install scanpy
!pip install einops

Collecting scanpy
  Downloading scanpy-1.9.4-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting anndata>=0.7.4 (from scanpy)
  Downloading anndata-0.9.2-py3-none-any.whl (104 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.2/104.2 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Collecting umap-learn>=0.3.10 (from scanpy)
  Downloading umap-learn-0.5.3.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.2/88.2 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting session-info (from scanpy)
  Downloading session_info-1.0.0.tar.gz (24 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pynndescent>=0.5 (from umap-learn>=0.3.10->scanpy)
  Downloading pynndescent-0.5.10.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1

In [2]:
import random
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
import math

In [3]:
# Hyperparameters shared (as module-level globals) by the data pipeline, the model and the training loop.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 8 # number of cells drawn by each get_batch call, i.e. per optimisation step
tr_ratio= 0.7 # fraction of cells placed in the training split; the remainder is used for validation
masking_ratio = 0.15 # masking fraction (original note: the value with the lowest loss in Geneformer); only referenced in commented-out code in the visible notebook
max_gs = 300 # maximum number of pathway / gene-set tokens, i.e. the output width of the cell-embedding layer
n_embd = 48 # embedding size of each token (original note: number of cell embeddings = number of weight matrices)
n_head = 4 # number of attention heads in each Transformer block
n_layer = 2 # number of Transformer blocks
dropout = 0.2 # probability passed to every nn.Dropout in the model
epochs = 1000 # total number of training iterations, one mini-batch each (original note: increase it when the learning rate is lowered)
eval_interval = 100 # the loss is evaluated every `eval_interval` training iterations
eval_iters = 100 # number of mini-batches averaged per split when estimating the loss
learning_rate = 3e-4 # original note: attention cannot tolerate a learning rate that is too large

In [4]:
# Fix the random seeds
def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Args:
        seed: integer passed unchanged to Python's ``random``, NumPy's
            legacy global generator, and PyTorch (CPU, current CUDA device
            and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(1464)

# input data process

In [5]:
# Load the raw single-cell dataset (an .h5ad file read from the mounted drive path below)
adata = sc.read('drive/MyDrive/data/demo_train.h5ad')
# Index the columns by the object's own var_names, so every gene is kept;
# the repr printed below reports the result as a "View of AnnData object".
adata = adata[:,adata.var_names]
print(adata)
# Number of cells per annotated cell type
print(adata.obs.Celltype.value_counts())

View of AnnData object with n_obs × n_vars = 10600 × 3000
    obs: 'Celltype'
    var: 'Gene Symbol'
alpha          3136
beta           2966
ductal         1290
acinar         1144
delta           793
PSC             524
PP              356
endothelial     273
macrophage       52
mast             25
epsilon          21
schwann          13
t_cell            7
Name: Celltype, dtype: int64


In [6]:
def toExp(adata):
    """Convert an AnnData-like object into a cells-by-genes DataFrame.

    Args:
        adata: object exposing ``X`` (a dense array or any SciPy sparse
            matrix), ``obs_names``, ``var_names`` and an ``obs`` table with a
            ``'Celltype'`` column.

    Returns:
        tuple ``(el_data, genes)``:
            el_data: DataFrame with one row per cell (indexed by
                ``obs_names``), one column per gene (named by ``var_names``)
                and a final ``'Celltype'`` column holding the label as str.
            genes: array of the gene column names (every column but the last).
    """
    # Function-scope import: the sparse submodule must be imported explicitly,
    # a bare ``import scipy`` is not guaranteed to expose ``scipy.sparse``.
    from scipy import sparse

    matrix = adata.X
    if sparse.issparse(matrix):
        # toarray() yields a plain ndarray for every sparse format
        # (todense() would give the legacy np.matrix type).
        matrix = matrix.toarray()

    el_data = pd.DataFrame(
        matrix,
        index=np.array(adata.obs_names).tolist(),
        columns=np.array(adata.var_names).tolist(),
    )
    # Append the cell type as the last column; assignment aligns on the cell index.
    el_data['Celltype'] = adata.obs['Celltype'].astype('str')
    genes = el_data.columns.values[:-1]
    return el_data, genes

def balance_populations(data):
    """Resample cells so that every cell type contributes the same number of rows.

    The label of each row is read from the LAST column of ``data``. Every
    label is resampled with replacement (``np.random.choice``) to the size of
    the largest population, capped at ``2_000_000 // n_labels`` rows per label.

    Labels are visited in order of first appearance, so the result is
    reproducible for a fixed NumPy seed (iterating a ``set`` of strings is
    not stable across interpreter runs).

    Args:
        data: DataFrame whose last column holds the cell-type label and whose
            other columns hold expression values.

    Returns:
        DataFrame with the same columns and dtypes as ``data``; rows keep
        their original index labels, so the index contains duplicates.
        An input without any labelled row yields an empty frame.
    """
    labels = data.iloc[:, -1]
    counts = labels.value_counts()
    if counts.empty:
        # Nothing to balance; also avoids dividing by zero below.
        return data.iloc[0:0]

    per_type = int(min(counts.max(), 2000000 // len(counts)))

    sampled_parts = []
    for ct in labels.dropna().unique():
        members = data.loc[(labels == ct).to_numpy()]
        picked = np.random.choice(len(members), per_type)  # with replacement
        sampled_parts.append(members.iloc[picked, :])

    # Concatenate once instead of growing a frame inside the loop.
    return pd.concat(sampled_parts)

In [7]:
# Build the expression matrix: rows are cells, columns are genes
el_data, genes = toExp(adata)
# Resample so that every cell type has the same number of cells
el_data = balance_populations(data = el_data)
# Drop the last column (the cell-type label) and keep the expression values as a NumPy array
el_data = np.array(el_data.iloc[:,:-1])
n_genes = len(genes)
print(el_data)
print(el_data.shape)
print(genes)
print(genes.shape)

[[0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        1.9713649 ... 0.        0.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]
 ...
 [0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        0.        ... 0.9642614 0.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]]
(40768, 3000)
['COL1A1' 'COL1A2' 'PPY' ... 'C9orf135' 'GRIN2D' 'HERC5']
(3000,)


# train and valid splits

In [8]:
# Train and valid splits
# NOTE(review): el_data was oversampled with replacement by balance_populations
# before this split, so it looks like copies of the same cell can end up in both
# splits, which would make the validation loss optimistic -- confirm whether the
# split should happen before balancing.
train_size = int(len(el_data) * tr_ratio)
# random_split returns two Subset objects that index into el_data
train_dataset, valid_dataset = torch.utils.data.random_split(el_data, [train_size,len(el_data)-train_size])
# Materialise each subset as a float32 tensor, keeping the first n_genes columns
train_dataset = torch.from_numpy(np.array(train_dataset)[:,:n_genes].astype(np.float32))
valid_dataset = torch.from_numpy(np.array(valid_dataset)[:,:n_genes].astype(np.float32))

print(train_dataset.shape)
print(train_dataset)
print(valid_dataset.shape)
print(valid_dataset)

torch.Size([28537, 3000])
tensor([[1.1370, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
torch.Size([12231, 3000])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 6.7095,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 1.5748,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 6.1754,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


In [9]:
# data loading
def get_batch(split):
    """
    Draw a random mini-batch of cells from one of the two splits.

    Args:
        split: 'train' reads from train_dataset; any other value reads from
            valid_dataset.

    Returns:
        Tensor of shape (batch_size, n_genes), moved to `device`. Rows are
        drawn uniformly at random with replacement.
    """
    if split == 'train':
        source = train_dataset
    else:
        source = valid_dataset
    row_ids = torch.randint(len(source), (batch_size,))
    # No masking is applied here: the batch is returned exactly as stored.
    batch = source[row_ids]
    return batch.to(device)

# model structure

## cell embedding

In [10]:
# Hand-written forward and backward passes for the weight-matrix product
class CustomizedLinearFunction(torch.autograd.Function):
    """Bias-free linear map ``input @ weight.T`` with an explicit backward pass."""

    @staticmethod  # invoked through the class: CustomizedLinearFunction.apply(...)
    def forward(ctx, input, weight):
        """Compute the gene-set tokens.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: 2-D tensor (batch_size, in_features).
            weight: 2-D tensor (out_features, in_features).

        Returns:
            2-D tensor (batch_size, out_features).
        """
        ctx.save_for_backward(input, weight)
        return torch.mm(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``input`` and ``weight``.

        A gradient is only computed for the operands that require one; the
        other slot is ``None``. Each gradient has the shape of its operand.
        """
        saved_input, saved_weight = ctx.saved_tensors

        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.mm(grad_output, saved_weight)

        grad_weight = None
        if ctx.needs_input_grad[1]:
            grad_weight = torch.mm(grad_output.t(), saved_input)

        return grad_input, grad_weight

# Fully connected, bias-free layer: a trainable weight matrix W(k, n) applied to the expression matrix
class CustomizedLinear(nn.Module):
    """Bias-free linear layer backed by ``CustomizedLinearFunction``.

    Args:
        input_features: number of input columns; defaults to the notebook
            global ``n_genes`` when omitted.
        output_features: number of output columns; defaults to the notebook
            global ``max_gs`` when omitted.

    Attributes:
        weight: trainable parameter of shape (output_features, input_features).
    """

    def __init__(self, input_features=None, output_features=None):
        super().__init__()
        self.input_features = n_genes if input_features is None else input_features
        self.output_features = max_gs if output_features is None else output_features
        # Allocate without initialising; reset_parameters() fills the values.
        # Wrapping in nn.Parameter registers the tensor for gradient tracking.
        self.weight = nn.Parameter(torch.empty(self.output_features, self.input_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight uniformly in [-1/sqrt(input_features), 1/sqrt(input_features))."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def reset_params_pos(self):
        """Fill the weight uniformly in [0, 1/sqrt(input_features)), i.e. non-negative values."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(0, bound)

    def forward(self, input):
        """Cast ``input`` to float32 and return ``input @ weight.T``."""
        input = input.float()
        return CustomizedLinearFunction.apply(input, self.weight)

# Gradient check verifying that the hand-written backward pass of the custom linear function is accurate.
# NOTE: the guard compares __name__ with the string 'check grad', so this block is
# skipped when the notebook runs normally; it only serves as a reference snippet.
if __name__ == 'check grad':
    from torch.autograd import gradcheck
    customlinear = CustomizedLinearFunction.apply

    # Double-precision operands: (input, weight) with 20 input features and 30 output features
    input = (torch.randn(20,20,dtype=torch.double,requires_grad=True),
         torch.randn(30,20,dtype=torch.double,requires_grad=True),)
    test = gradcheck(customlinear, input, eps=1e-6, atol=1e-4)
    print(test)

class FeatureEmbed(nn.Module):
    """Map expression profiles to gene-set tokens of shape (batch_size, max_gs, n_embd).

    NOTE(review): a single ``CustomizedLinear`` (one weight matrix) is created,
    so the n_embd channels of every token hold the same value and the returned
    ``weights`` are n_embd copies of one matrix. The notebook's comments speak
    of n_embd weight matrices -- confirm whether separate matrices were intended.
    """

    def __init__(self):
        super().__init__()
        self.fe = CustomizedLinear()

    def forward(self, x):
        """Embed a batch of cells.

        Args:
            x: tensor of shape (batch_size, n_genes).

        Returns:
            tuple ``(final_output, weights)``:
                final_output: (batch_size, max_gs, n_embd) tokens.
                weights: (n_embd, max_gs, n_genes) broadcast view of the
                    layer weight; treat it as read-only.
        """
        # The projection does not depend on the channel index, so compute it
        # once instead of n_embd times.
        token = rearrange(self.fe(x), 'h (w c) -> h c w ', c=max_gs)  # (batch_size, max_gs, 1)
        final_output = torch.cat([token] * n_embd, dim=-1)
        # expand() exposes the repeated weight without copying it on every call.
        weights = self.fe.weight.unsqueeze(0).expand(n_embd, -1, -1)
        return final_output, weights

## Encoder--multi-head self-attention

In [11]:
class Head(nn.Module):
    """A single self-attention head (no attention mask is applied)."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections, created in the order key, query, value
        # (the order in which the random initialisation is consumed).
        for name in ('key', 'query', 'value'):
            setattr(self, name, nn.Linear(n_embd, head_size, bias=False))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the token axis.

        Args:
            x: gene-set tokens of shape (batch_size, max_gs, n_embd).

        Returns:
            Tensor of shape (batch_size, max_gs, head_size).
        """
        batch, n_tokens, channels = x.shape  # x must be 3-D: (B, T, C)
        keys = self.key(x)        # (B, T, head_size)
        queries = self.query(x)   # (B, T, head_size)
        values = self.value(x)    # (B, T, head_size)

        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        attention = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)
        return attention @ values  # (B, T, T) @ (B, T, hs) -> (B, T, hs)

class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel, followed by an output projection.

    Args:
        head_size: output width of each head; ``n_head`` heads are created.
    """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        # The concatenated heads are n_head * head_size wide. That equals
        # n_embd only when n_embd is divisible by n_head, so size the
        # projection from the actual concatenated width.
        self.proj = nn.Linear(head_size * n_head, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected multi-head output, shape (batch_size, max_gs, n_embd)."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.proj(concatenated))

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``; the shape is unchanged."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: pre-norm self-attention followed by a pre-norm
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd // n_head)  # per-head width
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

## Decoder

In [12]:
class Decoder(nn.Module):
    """Three-layer MLP applied to the last axis: n_embd -> 256 -> 128 -> n_embd."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(n_embd, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, n_embd)  # output layer, no activation afterwards
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the decoded tensor; the last axis has size n_embd."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)

## Transformer

In [15]:
class Transformer(nn.Module):
    """Embedding -> Transformer encoder -> MLP decoder, trained to reconstruct
    its own gene-set tokens with a mean-squared-error loss.

    NOTE(review): the reconstruction target is the output of the trainable
    ``feature_embed`` and is not detached, so gradients of the loss also flow
    into the target -- confirm this objective is intended.
    """

    def __init__(self):
        super().__init__()
        # embedding: (batch_size, n_genes) -> (batch_size, max_gs, n_embd)
        self.feature_embed = FeatureEmbed()
        # encoder: n_layer Transformer blocks and a final layer norm
        encoder_layers = [Block(n_embd, n_head=n_head) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*encoder_layers)
        self.ln_f = nn.LayerNorm(n_embd)
        # decoder
        self.decoder = Decoder()
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise weights and biases.

        nn.Linear and nn.Embedding weights are drawn from N(0, 0.02^2);
        nn.Linear biases are set to zero. Other module types are left untouched.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if is_linear and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx):
        """Run the model on a batch of cells.

        Args:
            idx: tensor of shape (batch_size, n_genes).

        Returns:
            tuple ``(loss, w)``: the scalar MSE between the decoder output and
            the embedded tokens, and the embedding weights of shape
            (n_embd, max_gs, n_genes).
        """
        set_target, w = self.feature_embed(idx)  # (batch_size, max_gs, n_embd)
        encoded = self.ln_f(self.blocks(set_target))
        reconstruction = self.decoder(encoded)
        return F.mse_loss(reconstruction, set_target), w

# train the model

In [17]:
@torch.no_grad()
def estimate_loss():
    """
    Average the loss over eval_iters random batches per split to reduce noise.

    The model is switched to eval mode for the measurement and back to train
    mode before returning.

    Returns:
        dict with keys 'train' and 'val', each a 0-dim tensor holding the
        mean loss of that split.
    """
    model.eval()
    out = {}
    for split in ('train', 'val'):
        batch_losses = [model(get_batch(split))[0].item() for _ in range(eval_iters)]
        out[split] = torch.tensor(batch_losses).mean()
    model.train()
    return out

In [18]:
model = Transformer()
m = model.to(device)  # the module returned by .to(device); kept under the original name `m`
# print the number of model parameters (in millions)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# `epochs` counts optimisation steps: each iteration trains on one random mini-batch.
# The loop variable is named `step` so the builtin `iter` is not shadowed for later cells.
for step in range(epochs):
    # report the averaged train/val loss every eval_interval steps and on the last step
    if step % eval_interval == 0 or step == epochs - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb = get_batch('train')

    # forward pass, then clear old gradients, back-propagate and update
    loss, w = model(xb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

1.007984 M parameters
step 0: train loss 0.1030, val loss 0.1036
step 100: train loss 0.0245, val loss 0.0251
step 200: train loss 0.0185, val loss 0.0192
step 300: train loss 0.0140, val loss 0.0153
step 400: train loss 0.0123, val loss 0.0126
step 500: train loss 0.0109, val loss 0.0115
step 600: train loss 0.0093, val loss 0.0100
step 700: train loss 0.0081, val loss 0.0091
step 800: train loss 0.0076, val loss 0.0077
step 900: train loss 0.0070, val loss 0.0070
step 999: train loss 0.0062, val loss 0.0064
