In [2]:
import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
import anndata as ad
from chainer import Variable, optimizers
from sklearn.preprocessing import LabelEncoder

--------------------------------------------------------------------------------
CuPy (cupy-cuda101) version 9.6.0 may not be compatible with this version of Chainer.
Please consider installing the supported version by running:
  $ pip install 'cupy-cuda101>=7.7.0,<8.0.0'

See the following page for more details:
  https://docs.cupy.dev/en/latest/install.html
--------------------------------------------------------------------------------



In [3]:
# 用于生成正交矩阵，作为基因嵌入模型的初始参数
def _orthogonal_matrix(shape):
    M1 = np.random.randn(shape[0], shape[0])
    M2 = np.random.randn(shape[1], shape[1])
    Q1, R1 = np.linalg.qr(M1)
    Q2, R2 = np.linalg.qr(M2)
    Q1 = Q1 * np.sign(np.diag(R1))
    Q2 = Q2 * np.sign(np.diag(R2))
    n_min = min(shape[0], shape[1])
    return np.dot(Q1[:, :n_min], Q2[:n_min, :])

class GeneEmbedMixture(chainer.Chain):
    """
    Represent each gene as a weighted mixture of latent functional groups
    (e.g. biological pathways).

    Every gene owns one weight per latent group. Its embedding is the
    normalized group proportions multiplied by the group vectors held in
    ``factors``.

    Args:
        n_genes (int): number of rows in the gene -> group-weight lookup table.
        n_topics (int): number of latent functional groups / pathways.
        n_dim (int): dimensionality of each group vector, and therefore of
            each gene embedding.
        dropout_ratio (float): dropout ratio applied to ``factors`` in
            ``__call__``.
        temperature (float): multiplier applied to the weights before the
            softmax in ``proportions``. Because it multiplies, values > 1
            sharpen the group distribution and values < 1 flatten it.

    Attributes:
        weights (chainer.links.EmbedID): unnormalized gene -> group weights,
            shape (n_genes, n_topics).
        factors (chainer.links.Parameter): group vectors, shape
            (n_topics, n_dim).
    """

    def __init__(self, n_genes, n_topics, n_dim, dropout_ratio=0.2, temperature=1.0):
        super(GeneEmbedMixture, self).__init__()
        self.n_genes = n_genes
        self.n_topics = n_topics
        self.n_dim = n_dim
        self.dropout_ratio = dropout_ratio
        self.temperature = temperature
        # Orthogonal initial group vectors, scaled by 1 / sqrt(n_topics + n_dim).
        factors = _orthogonal_matrix((n_topics, n_dim)).astype('float32')
        factors /= np.sqrt(n_topics + n_dim)
        # Register child links inside init_scope(); passing links as keyword
        # arguments to Chain.__init__ is deprecated since Chainer v2.
        with self.init_scope():
            self.weights = L.EmbedID(n_genes, n_topics)
            self.factors = L.Parameter(factors)
        # Shrink EmbedID's initial weights by 1 / sqrt(n_genes + n_topics).
        self.weights.W.array[...] /= np.sqrt(n_genes + n_topics)

    def __call__(self, gene_ids, update_only_genes=False):
        """
        Return one embedding per gene ID.

        Args:
            gene_ids (chainer.Variable or array): 1-D array of integer gene IDs.
            update_only_genes (bool): if True, the dropped-out factor matrix is
                made a graph root, so backpropagation does not reach
                ``factors`` and only the per-gene weights receive gradients.

        Returns:
            chainer.Variable: shape (len(gene_ids), n_dim). Each row is the
            gene's group proportions multiplied by the (dropped-out) group
            vectors.
        """
        proportions = self.proportions(gene_ids, softmax=True)
        factors = F.dropout(self.factors(), ratio=self.dropout_ratio)
        if update_only_genes:
            factors.unchain_backward()
        return F.matmul(proportions, factors)

    def proportions(self, gene_ids, softmax=False):
        """
        Return each gene's weights over the latent groups.

        Args:
            gene_ids: 1-D array/Variable of integer gene IDs.
            softmax (bool): if False, return the raw (unnormalized) weights.
                If True, return ``softmax(weights * temperature)``. In training
                mode (``chainer.config.train``), each entry is also kept or
                zeroed with probability 1/2, and every row is renormalized to
                sum to 1.

        Returns:
            chainer.Variable: shape (len(gene_ids), n_topics).
        """
        w = self.weights(gene_ids)
        if not softmax:
            return w
        y = F.softmax(w * self.temperature)
        if chainer.config.train:
            # Random 0/1 mask over (gene, group) entries. randint's upper bound
            # is exclusive, so randint(0, 2) draws from {0, 1}, exactly like the
            # deprecated random_integers(0, 1). The mask is gated on train mode
            # so that, like F.dropout, it does not randomize inference.
            mask = self.xp.random.randint(0, 2, size=w.shape).astype('float32')
            y = y * mask
        # Renormalize each row. The epsilon keeps a fully masked row at 0
        # instead of producing NaN.
        norm, y = F.broadcast(F.expand_dims(F.sum(y, axis=1), 1), y)
        return y / (norm + 1e-7)


In [20]:
from scipy.sparse import coo_matrix

def cosine_similarity_loss(x, y, eps=1e-8):
    """
    Cosine-similarity loss between two batches of vectors.

    Computes ``1 - mean_i cos(x_i, y_i)``, taking the cosine along axis 1.

    Args:
        x: first input (chainer.Variable or array), compared row by row with
            ``y``. Assumed 2-D with the same shape as ``y``.
        y: second input.
        eps (float): small constant added to each squared norm before the
            square root. This keeps an all-zero row from dividing by zero in
            the forward pass. It also prevents the NaN gradient that
            ``sqrt`` at 0 would otherwise produce in the backward pass.

    Returns:
        chainer.Variable: scalar loss. It is approximately 0 when every row
        pair points the same way and approaches 2 when every pair is opposite.
    """
    # Epsilon goes *inside* the sqrt: d/ds sqrt(s) is infinite at s == 0, and
    # backprop through square() would then turn inf * 0 into NaN.
    norm_x = F.sqrt(F.sum(F.square(x), axis=1, keepdims=True) + eps)
    norm_y = F.sqrt(F.sum(F.square(y), axis=1, keepdims=True) + eps)

    dot_product = F.sum(x * y, axis=1, keepdims=True)

    # Both norms are >= sqrt(eps), so the denominator is never zero.
    cosine_sim = dot_product / (norm_x * norm_y)

    return 1 - F.mean(cosine_sim)


def load_h5ad_data(filepath):
    """
    Read an AnnData ``.h5ad`` file and unpack its main components.

    Args:
        filepath: path to the ``.h5ad`` file.

    Returns:
        tuple: ``(X, var_names, obs_names)``, where:
            - ``X`` is the expression matrix ``adata.X`` (a dense array or a
              sparse matrix);
            - ``var_names`` are the gene names (``adata.var_names``);
            - ``obs_names`` are the sample / cell names (``adata.obs_names``).
    """
    adata = ad.read_h5ad(filepath)
    expression, gene_names, sample_names = adata.X, adata.var_names, adata.obs_names
    return expression, gene_names, sample_names

def prepare_data(gene_expression_matrix, genes):
    """
    Encode gene names as integer IDs for the embedding model.

    Args:
        gene_expression_matrix: not used here; kept so existing call sites
            keep working.
        genes: sequence of gene names (e.g. ``adata.var_names``).

    Returns:
        tuple: ``(gene_ids, gene_encoder)``.
            - ``gene_ids[i]`` is the ID of ``genes[i]``. LabelEncoder assigns
              IDs by position in the sorted set of unique names, so duplicate
              names share an ID.
            - ``gene_encoder`` is the fitted LabelEncoder; its
              ``inverse_transform`` maps IDs back to names.
    """
    encoder = LabelEncoder()
    return encoder.fit_transform(genes), encoder

def _dense_rows(matrix, rows):
    """Return ``matrix[rows]`` as a dense float32 ndarray, for sparse or dense input."""
    picked = matrix[rows]
    if hasattr(picked, 'toarray'):  # scipy.sparse result
        picked = picked.toarray()
    return np.asarray(picked, dtype=np.float32)


def _similarity_structure_loss(embeddings, profiles, eps=1e-12):
    """
    Default training objective: make embeddings reproduce gene-gene expression similarity.

    Builds the (batch, batch) cosine-similarity matrix of the gene embeddings
    and of the (constant) expression profiles, then compares them row by row
    with ``cosine_similarity_loss``.

    Args:
        embeddings (chainer.Variable): (batch, n_dim) gene embeddings.
        profiles (numpy.ndarray): (batch, n_cells) float32 expression profiles.
        eps (float): lower bound on profile norms, so that all-zero genes do
            not divide by zero.
    """
    unit_embeddings = F.normalize(embeddings, axis=1)
    embedding_sim = F.matmul(unit_embeddings, unit_embeddings, transb=True)
    profile_norms = np.sqrt(np.square(profiles).sum(axis=1, keepdims=True))
    unit_profiles = profiles / np.maximum(profile_norms, eps)
    target_sim = unit_profiles @ unit_profiles.T  # constant target, no gradient
    return cosine_similarity_loss(embedding_sim, target_sim)


def train_gene_embedding_model(gene_ids, gene_expression_matrix, n_topics, n_dim, epochs, lr,
                               loss_fn=None, batch_size=None):
    """
    Train a GeneEmbedMixture on the gene columns of an expression matrix.

    ``gene_expression_matrix`` is laid out as (cells, genes), following the
    AnnData ``obs x var`` convention. Column ``j`` is the expression profile
    of the gene whose ID is ``gene_ids[j]``.

    NOTE(review): the previous version never computed a loss or called
    ``backward()``, so ``optimizer.update()`` changed nothing. The default
    objective below (``_similarity_structure_loss``) is an assumption about
    the intent. Confirm it, or pass your own ``loss_fn``.

    Args:
        gene_ids (array of int): one ID per gene column, aligned with the
            columns of ``gene_expression_matrix``.
        gene_expression_matrix: dense ndarray or scipy sparse matrix of shape
            (n_cells, n_genes).
        n_topics (int): number of latent groups, forwarded to GeneEmbedMixture.
        n_dim (int): embedding size, forwarded to GeneEmbedMixture.
        epochs (int): number of passes over all genes.
        lr (float): Adam step size (its ``alpha`` argument).
        loss_fn (callable, optional): ``loss_fn(embeddings, profiles)``, which
            must return a scalar chainer.Variable. ``embeddings`` is
            (batch, n_dim) and ``profiles`` is a (batch, n_cells) float32
            ndarray. Defaults to ``_similarity_structure_loss``.
        batch_size (int, optional): genes per optimizer step. ``None`` (the
            default) uses all genes in one full-batch step per epoch. The
            default loss then builds (n_genes, n_genes) matrices, so set this
            for large gene sets.

    Returns:
        GeneEmbedMixture: the trained model.

    Raises:
        ValueError: if ``gene_ids`` is empty, is not 1-D, or does not match the
            number of gene columns; or if ``batch_size`` is not positive.
    """
    from scipy.sparse import issparse

    # EmbedID expects integer IDs; int32 is Chainer's customary ID dtype.
    gene_ids = np.asarray(gene_ids, dtype=np.int32)
    n_genes = gene_expression_matrix.shape[1]
    if gene_ids.ndim != 1 or len(gene_ids) != n_genes:
        raise ValueError(
            f"gene_ids 的长度 {len(gene_ids)} 与基因表达矩阵的列数(基因数) {n_genes} 不一致")
    if n_genes == 0:
        raise ValueError("没有可训练的基因 (gene_ids 为空)")
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size 必须为正整数, 得到 {batch_size}")

    # The embedding table needs a row for every ID that will be looked up.
    # This is independent of how many genes we iterate over.
    n_vocab = n_genes
    max_gene_id = int(gene_ids.max())
    if max_gene_id >= n_vocab:
        print(f"警告: gene_ids 中的最大索引 {max_gene_id} 超过了基因表达矩阵的列数(基因数) {n_genes}")
        print("扩大嵌入表大小以匹配 gene_ids 中的最大索引...")
        n_vocab = max_gene_id + 1

    # One row per gene: (n_genes, n_cells). CSR makes per-gene row selection cheap.
    if issparse(gene_expression_matrix):
        profiles_by_gene = gene_expression_matrix.T.tocsr()
    else:
        profiles_by_gene = np.asarray(gene_expression_matrix).T

    if loss_fn is None:
        loss_fn = _similarity_structure_loss
    step = n_genes if batch_size is None else int(batch_size)

    model = GeneEmbedMixture(n_vocab, n_topics, n_dim)
    optimizer = chainer.optimizers.Adam(lr)
    optimizer.setup(model)

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')
        # Shuffle gene order each epoch. IDs and profiles stay paired because
        # both are indexed by the same gene positions.
        order = np.random.permutation(n_genes)
        epoch_loss = 0.0
        for start in range(0, n_genes, step):
            batch = order[start:start + step]
            profiles = _dense_rows(profiles_by_gene, batch)
            embeddings = model(gene_ids[batch])
            loss = loss_fn(embeddings, profiles)
            model.cleargrads()
            loss.backward()
            optimizer.update()
            epoch_loss += float(loss.array) * len(batch)
        print(f'  loss: {epoch_loss / n_genes:.6f}')

    return model


def generate_gene_embeddings(model, gene_ids, gene_encoder):
    """
    Compute embeddings for ``gene_ids`` and key them by gene name.

    The forward pass runs with ``chainer.config.train`` set to False, so
    ``F.dropout`` is a no-op, and under ``no_backprop_mode`` so no
    computational graph is kept. Randomness inside the model that is not gated
    on train mode is not suppressed by this.

    Args:
        model: a trained embedding model, called as ``model(gene_ids)``.
        gene_ids: 1-D array of integer gene IDs.
        gene_encoder: fitted LabelEncoder used to map IDs back to gene names.

    Returns:
        dict: gene name -> embedding vector (one row of the model output). If
        a name repeats, the last occurrence wins.
    """
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        gene_embeddings = model(Variable(gene_ids)).array

    gene_names = gene_encoder.inverse_transform(gene_ids)
    return dict(zip(gene_names, gene_embeddings))