In [120]:
import os
import json
import pickle
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import HeteroData
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict

In [121]:
# ========== 路径配置 ==========
data_dir = r"C:\Users\Keqin\Projects\CellLineKG\data_process\tahoe-100M"
output_dir = os.path.join(data_dir, "outputs", "final_graph")
os.makedirs(output_dir, exist_ok=True)

# 边表
cell_line2protein_path = os.path.join(data_dir, "cell_line2protein.csv")
drug_protein_path = os.path.join(data_dir, "primeKG_drug_protein.csv")
protein2protein_path = os.path.join(data_dir, "protein2protein.csv")
drug_synergy_path = os.path.join(data_dir, "tahoe_cellline_drug_synergy.csv")

# ESM protein embedding
esm_emb_path = r"C:\Users\Keqin\xwechat_files\wxid_0nn3oeq70kpq12_3a1f\msg\file\2025-10\esm2_t33_650M_UR50D_embeddings_wildtype.npy"
esm_gene_path = r"C:\Users\Keqin\xwechat_files\wxid_0nn3oeq70kpq12_3a1f\msg\file\2025-10\esm2_t33_650M_UR50D_embeddings_wildtype_gene_list_processed.csv"
ensg_map_path = os.path.join(data_dir, "ensg_mapping_geneSymbol_NCBI.json")

# Cell line expression (for cell embedding)
cell_line_gene_path = os.path.join(data_dir, "cell_line_gene.json")  # {"cell": {"gene_symbol": expr, ...}}

# ChemBERTa model
chemberta_path = os.path.join(data_dir, "models", "ChemBERTa-77M-MTR")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [122]:
# Cell 2: 读取原始数据
df_cell2prot = pd.read_csv(cell_line2protein_path)
df_drug_prot = pd.read_csv(drug_protein_path)
df_ppi = pd.read_csv(protein2protein_path)
df_synergy = pd.read_csv(drug_synergy_path)
with open(cell_line_gene_path, "r") as f:
    cell_line_gene = json.load(f)

print("原始数据形状:")
print("cell_line2protein:", df_cell2prot.shape)
print("drug_protein:", df_drug_prot.shape)
print("protein2protein:", df_ppi.shape)
print("drug_synergy:", df_synergy.shape)

原始数据形状:
cell_line2protein: (50, 2)
drug_protein: (11170, 3)
protein2protein: (6084266, 3)
drug_synergy: (84921, 6)


In [123]:
# Cell 3: 构建蛋白集合
# 1. cell_line 覆盖的蛋白（Entrez ID）
cell2proteins = {}
all_proteins_cell = set()
for _, row in df_cell2prot.iterrows():
    cell = row["cell_line"].strip()
    entrez_str = str(row.get("entrez_list", ""))
    prots = [p.strip() for p in entrez_str.split(";") if p.strip()]
    cell2proteins[cell] = set(prots)
    all_proteins_cell.update(prots)

print("Cell-line 覆盖蛋白数:", len(all_proteins_cell))

Cell-line 覆盖蛋白数: 10211


In [124]:

# 2. ESM embedding 覆盖的蛋白（通过 gene symbol → Entrez）
esm_emb = np.load(esm_emb_path)  # (N, 1280)
df_esm_genes = pd.read_csv(esm_gene_path)
with open(ensg_map_path, "r") as f:
    ensg_map = json.load(f)

symbol_to_entrez = {
    v["symbol"]: str(v["entrez"])
    for v in ensg_map.values()
    if "symbol" in v and "entrez" in v
}

esm_proteins_entrez = set()
esm_gene_to_idx = {}
for i, row in df_esm_genes.iterrows():
    symbol = str(row["genes"])
    if symbol in symbol_to_entrez:
        entrez = symbol_to_entrez[symbol]
        esm_proteins_entrez.add(entrez)
        esm_gene_to_idx[entrez] = i

print("ESM 覆盖蛋白数:", len(esm_proteins_entrez))

ESM 覆盖蛋白数: 8554


In [125]:

# 3. 蛋白交集：必须同时在 cell_line 表达谱 AND ESM embedding 中
protein_intersection = all_proteins_cell & esm_proteins_entrez
print("蛋白交集大小:", len(protein_intersection))

蛋白交集大小: 5242


In [126]:
# Cell 4: 筛选边表

# 1. PPI: 两端都在 protein_intersection
mask_ppi = df_ppi.apply(
    lambda r: (str(r["entrez1"]) in protein_intersection) and (str(r["entrez2"]) in protein_intersection),
    axis=1
)
df_ppi_filtered = df_ppi[mask_ppi].reset_index(drop=True)
print("PPI 筛选后:", len(df_ppi_filtered))

# 2. drug-protein: target ∈ protein_intersection
mask_dp = df_drug_prot["protein"].astype(str).isin(protein_intersection)
df_drug_prot_filtered = df_drug_prot[mask_dp].reset_index(drop=True)
print("drug-protein 筛选后:", len(df_drug_prot_filtered))

# 3. synergy drugs
synergy_drugs = set(
    df_synergy["drug_a_drugbank_id"].astype(str).tolist() +
    df_synergy["drug_b_drugbank_id"].astype(str).tolist()
)
drugs_with_target = set(df_drug_prot_filtered["drug"].astype(str))
candidate_drugs = synergy_drugs & drugs_with_target

# 4. 过滤 synergy
df_synergy_filtered = df_synergy[
    (df_synergy["drug_a_drugbank_id"].isin(candidate_drugs)) &
    (df_synergy["drug_b_drugbank_id"].isin(candidate_drugs)) &
    (df_synergy["cell_line"].isin(cell2proteins.keys()))
].reset_index(drop=True)
print("Synergy 筛选后:", len(df_synergy_filtered))

PPI 筛选后: 2104550
drug-protein 筛选后: 8376
Synergy 筛选后: 21613


In [127]:
print(len(df_synergy_filtered[
    (df_synergy["synergy"]>0)
]))
print(len(df_synergy_filtered[
    (df_synergy["synergy"]<0)
]))
print(len(df_synergy_filtered[
    (df_synergy["synergy"]>30)
]))
print(len(df_synergy_filtered[
    (df_synergy["synergy"]<30)
]))

7568
13996
615
20978


  print(len(df_synergy_filtered[
  print(len(df_synergy_filtered[
  print(len(df_synergy_filtered[
  print(len(df_synergy_filtered[


In [128]:
# Cell 5: 构建节点 ID 映射
cell_list = sorted(cell2proteins.keys())
drug_list = sorted(candidate_drugs)
prot_list = sorted(protein_intersection)

cell2id = {c: i for i, c in enumerate(cell_list)}
drug2id = {d: i for i, d in enumerate(drug_list)}
prot2id = {p: i for i, p in enumerate(prot_list)}

print("节点数:", len(cell2id), len(drug2id), len(prot2id))

节点数: 50 78 5242


In [129]:
# Cell 6: 构建 HeteroData（边去重 + 无向化）
# -------------------------------
# 工具函数部分
# -------------------------------
def add_undirected_edges(edge_list):
    """仅用于同类型节点之间的边（去自环 + 无向化 + 去重）"""
    edges = torch.tensor(edge_list, dtype=torch.long).t()
    src, dst = edges
    # 去自环
    mask = src != dst
    src, dst = src[mask], dst[mask]
    # 无向标准化
    min_nodes = torch.min(src, dst)
    max_nodes = torch.max(src, dst)
    undir_edges = torch.stack([min_nodes, max_nodes], dim=0)
    # 去重
    unique_edges = torch.unique(undir_edges, dim=1)
    return unique_edges


def make_bidirectional(edge_list):
    """用于异构图（不同节点类型之间）：添加反向边"""
    edges = torch.tensor(edge_list, dtype=torch.long).t()
    rev_edges = edges.flip(0)
    combined = torch.cat([edges, rev_edges], dim=1)
    unique_edges = torch.unique(combined, dim=1)
    return unique_edges


def add_edges(data, src_type, rel_type, dst_type, edge_list):
    """
    通用函数：
    - 同类型边（drug-drug, protein-protein）：无向 + 去重
    - 跨类型边（cell_line-protein, drug-protein）：双向 + 去重
    """
    if src_type == dst_type:
        # 无向关系
        edges = add_undirected_edges(edge_list)
        data[(src_type, rel_type, dst_type)].edge_index = edges
    else:
        # 异构关系：正向 + 反向
        edges = make_bidirectional(edge_list)
        mid = edges.size(1) // 2
        data[(src_type, rel_type, dst_type)].edge_index = edges[:, :mid]
        data[(dst_type, f"rev_{rel_type}", src_type)].edge_index = edges[:, mid:]
    return data


# -------------------------------
# 图构建部分
# -------------------------------
data = HeteroData()

# 添加节点数量
data["cell_line"].num_nodes = len(cell2id)
data["drug"].num_nodes = len(drug2id)
data["protein"].num_nodes = len(prot2id)

# 1️⃣ cell_line - expresses - protein
edge_cl_p = []
for cell, prots in cell2proteins.items():
    if cell in cell2id:
        cid = cell2id[cell]
        for p in prots:
            if p in prot2id:
                edge_cl_p.append([cid, prot2id[p]])
data = add_edges(data, "cell_line", "expresses", "protein", edge_cl_p)

# 2️⃣ drug - targets - protein
edge_d_p = []
for _, row in df_drug_prot_filtered.iterrows():
    d, p = str(row["drug"]), str(row["protein"])
    if d in drug2id and p in prot2id:
        edge_d_p.append([drug2id[d], prot2id[p]])
data = add_edges(data, "drug", "targets", "protein", edge_d_p)

# 3️⃣ protein - ppi - protein
edge_p_p = []
for _, row in df_ppi_filtered.iterrows():
    a, b = str(row["entrez1"]), str(row["entrez2"])
    if a in prot2id and b in prot2id:
        edge_p_p.append([prot2id[a], prot2id[b]])
data = add_edges(data, "protein", "ppi", "protein", edge_p_p)

# 4️⃣ drug - interacts - drug
edge_d_d = []
for _, row in df_synergy_filtered.iterrows():
    a, b = str(row["drug_a_drugbank_id"]), str(row["drug_b_drugbank_id"])
    if a in drug2id and b in drug2id:
        edge_d_d.append([drug2id[a], drug2id[b]])
data = add_edges(data, "drug", "interacts", "drug", edge_d_d)

print("图构建完成:", data)

图构建完成: HeteroData(
  cell_line={ num_nodes=50 },
  drug={ num_nodes=78 },
  protein={ num_nodes=5242 },
  (cell_line, expresses, protein)={ edge_index=[2, 128556] },
  (protein, rev_expresses, cell_line)={ edge_index=[2, 128557] },
  (drug, targets, protein)={ edge_index=[2, 314] },
  (protein, rev_targets, drug)={ edge_index=[2, 315] },
  (protein, ppi, protein)={ edge_index=[2, 1033904] },
  (drug, interacts, drug)={ edge_index=[2, 1638] }
)


In [130]:
# Cell 7: Protein Embedding
prot_emb = torch.zeros(len(prot2id), esm_emb.shape[1])
for entrez, idx in prot2id.items():
    emb_idx = esm_gene_to_idx[entrez]
    prot_emb[idx] = torch.from_numpy(esm_emb[emb_idx]).float()
data["protein"].x = prot_emb
print("Protein embedding shape:", data["protein"].x.shape)

Protein embedding shape: torch.Size([5242, 1280])


In [131]:
# Cell 8: Drug Embedding
# 提取 SMILES
drug_smiles = {}
for _, row in df_synergy_filtered.iterrows():
    for prefix in ["drug_a", "drug_b"]:
        did = str(row[f"{prefix}_drugbank_id"])
        smi = str(row.get(f"{prefix}_isosmiles", "")).strip()
        if did in drug2id and did not in drug_smiles and smi != "":
            drug_smiles[did] = smi

# 加载模型
tokenizer = AutoTokenizer.from_pretrained(chemberta_path)
model = AutoModel.from_pretrained(chemberta_path).to(device).eval()

# 生成 embedding
drug_emb = torch.zeros(len(drug2id), 384)
batch_size = 16
for i in tqdm(range(0, len(drug_smiles), batch_size)):
    batch = list(drug_smiles.items())[i:i+batch_size]
    dids, smiles = zip(*batch)
    inputs = tokenizer(list(smiles), return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)
    with torch.no_grad():
        out = model(**inputs).last_hidden_state.mean(dim=1).cpu()
    for j, did in enumerate(dids):
        drug_emb[drug2id[did]] = out[j]

data["drug"].x = drug_emb
print("Drug embedding shape:", data["drug"].x.shape)

Some weights of the model checkpoint at C:\Users\Keqin\Projects\CellLineKG\data_process\tahoe-100M\models\ChemBERTa-77M-MTR were not used when initializing RobertaModel: ['regression.out_proj.bias', 'regression.dense.weight', 'regression.out_proj.weight', 'norm_std', 'regression.dense.bias', 'norm_mean']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at C:\Users\Keqin\Projects\CellLineKG\data_process\tahoe-100M\models\ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler

Drug embedding shape: torch.Size([78, 384])





In [132]:
import pandas as pd
import ast

# ========== 路径 ==========
expr_path = r"C:\Users\Keqin\Downloads\OmicsExpressionTPMLogp1HumanProteinCodingGenes.csv"
map_path = r"C:\Users\Keqin\Projects\CellLineKG\data_process\drug_synergy\cellosaurus_mapped.csv"

# ========== 1) 加载表达矩阵 ==========
expr_df = pd.read_csv(expr_path)
expr_df = expr_df[expr_df["IsDefaultEntryForModel"] == "Yes"]
print(expr_df.columns)
print(expr_df.head())

In [143]:
import re

# 假设 GeneList_entrez 已经定义，例如：
# GeneList_entrez = [183, 4005, 57147, ...]

# ========== 1. 保留基础列 ==========
base_cols = ['ModelID']

# ========== 2. 从 expr_df.columns 中找到匹配的基因列 ==========
expr_cols = expr_df.columns.tolist()

# 用于按顺序存放筛选出的列
selected_gene_cols = []
GeneList_entrez = [29974,
 10006,
 25,
 27,
 2181,
 23305,
 90,
 92,
 4299,
 3899,
 27125,
 10142,
 207,
 208,
 10000,
 217,
 238,
 139285,
 286,
 9582,
 369,
 23092,
 394,
 9639,
 55160,
 23365,
 8289,
 57492,
 196528,
 405,
 79058,
 171023,
 55252,
 471,
 472,
 476,
 492,
 546,
 8312,
 8313,
 567,
 8314,
 580,
 581,
 11177,
 8915,
 53335,
 64919,
 596,
 83596,
 602,
 604,
 605,
 607,
 283149,
 9774,
 54880,
 63035,
 613,
 330,
 57448,
 641,
 653,
 657,
 673,
 672,
 8019,
 23476,
 695,
 776,
 811,
 23261,
 124583,
 84433,
 836,
 841,
 842,
 863,
 865,
 867,
 868,
 23624,
 8030,
 57820,
 892,
 595,
 894,
 896,
 898,
 1236,
 30835,
 29126,
 940,
 972,
 973,
 974,
 79577,
 1008,
 1009,
 1015,
 51755,
 1019,
 1021,
 1027,
 1029,
 1031,
 1045,
 84902,
 79145,
 1106,
 1108,
 11200,
 26511,
 50515,
 23152,
 4261,
 6249,
 10978,
 1213,
 8218,
 168975,
 7555,
 4849,
 26047,
 11064,
 1277,
 1280,
 1281,
 1345,
 22849,
 1385,
 90993,
 64764,
 1387,
 64109,
 51340,
 23373,
 64784,
 1436,
 1441,
 114788,
 10664,
 1496,
 1499,
 1500,
 1501,
 8452,
 1523,
 7852,
 1540,
 1558,
 57105,
 1616,
 340578,
 1630,
 1643,
 1649,
 4921,
 1662,
 1654,
 1655,
 1656,
 7913,
 54487,
 23405,
 3337,
 1785,
 1788,
 29102,
 1879,
 345930,
 8726,
 1956,
 1964,
 3646,
 1974,
 1999,
 2000,
 2005,
 8178,
 27436,
 2033,
 2034,
 2042,
 2045,
 2060,
 2064,
 2065,
 2066,
 23085,
 2068,
 2071,
 2072,
 2073,
 2078,
 2099,
 55500,
 2115,
 2118,
 2119,
 2120,
 2130,
 2131,
 2132,
 2146,
 7430,
 9715,
 51059,
 442444,
 2175,
 2176,
 2177,
 2178,
 2188,
 2189,
 2195,
 120114,
 79633,
 2199,
 80204,
 55294,
 2213,
 83417,
 2242,
 54738,
 2260,
 2263,
 2261,
 2264,
 2271,
 2272,
 81608,
 2313,
 2316,
 2322,
 2324,
 23048,
 3169,
 668,
 2308,
 2309,
 4303,
 27086,
 283150,
 10272,
 8880,
 2521,
 2623,
 2624,
 2625,
 2735,
 2767,
 2776,
 2778,
 9950,
 57120,
 2719,
 2262,
 10243,
 2903,
 2913,
 9709,
 23462,
 3092,
 3105,
 3159,
 8091,
 6927,
 3181,
 84376,
 3207,
 3209,
 3205,
 3227,
 3229,
 3239,
 3265,
 3320,
 3326,
 3399,
 3417,
 3418,
 10644,
 3551,
 10320,
 3558,
 50615,
 3572,
 3575,
 3662,
 8471,
 91464,
 3685,
 3702,
 3716,
 3717,
 3718,
 221895,
 3725,
 7994,
 23522,
 11143,
 3762,
 5927,
 8242,
 7403,
 3791,
 2531,
 9817,
 57670,
 3799,
 3815,
 9314,
 1316,
 8085,
 90417,
 3845,
 3895,
 23185,
 3927,
 9113,
 26524,
 3932,
 51176,
 23484,
 10186,
 3977,
 4000,
 4004,
 4005,
 4026,
 121227,
 53353,
 26065,
 4066,
 8216,
 346389,
 4094,
 9935,
 10892,
 84441,
 5604,
 5605,
 6416,
 4214,
 9175,
 5594,
 4149,
 151963,
 4193,
 4194,
 2122,
 9968,
 4221,
 4233,
 4255,
 4286,
 4291,
 4292,
 4298,
 8028,
 10962,
 4300,
 4302,
 4330,
 3110,
 4352,
 4436,
 2956,
 124540,
 4478,
 2475,
 4582,
 94025,
 4585,
 4595,
 4602,
 4609,
 4610,
 4613,
 4615,
 4629,
 4627,
 4644,
 4654,
 55728,
 4665,
 4666,
 26960,
 51517,
 8648,
 10499,
 9611,
 9612,
 10397,
 4763,
 4771,
 4773,
 4780,
 4781,
 4791,
 4794,
 51199,
 4841,
 4851,
 4853,
 4869,
 8013,
 4893,
 3084,
 64324,
 22978,
 4913,
 4914,
 4916,
 4926,
 8021,
 4928,
 10215,
 4958,
 286530,
 5049,
 79728,
 23598,
 5077,
 5079,
 5081,
 7849,
 55193,
 5087,
 5093,
 80380,
 9659,
 5155,
 5156,
 5159,
 84295,
 8929,
 8301,
 5290,
 5291,
 5295,
 5324,
 5335,
 5371,
 5378,
 5395,
 5424,
 5426,
 5428,
 25913,
 5450,
 5460,
 5468,
 8496,
 8493,
 5518,
 5537,
 5546,
 639,
 63976,
 7799,
 80243,
 5551,
 5566,
 5573,
 5579,
 25766,
 5396,
 11168,
 5727,
 5728,
 5753,
 5781,
 5783,
 5777,
 5787,
 5788,
 5789,
 5796,
 11122,
 114825,
 9444,
 9135,
 5884,
 5885,
 5890,
 5900,
 5903,
 5910,
 5914,
 5925,
 8241,
 64783,
 9401,
 5966,
 5979,
 55159,
 6000,
 387,
 399,
 116028,
 57674,
 54894,
 6092,
 6098,
 6146,
 6125,
 6184,
 340419,
 861,
 862,
 6278,
 57167,
 51119,
 6385,
 6389,
 6390,
 6391,
 6418,
 26040,
 23067,
 29072,
 9869,
 23451,
 6421,
 6424,
 6446,
 10019,
 6455,
 140885,
 6495,
 10736,
 6497,
 10568,
 85414,
 4087,
 4088,
 4089,
 6597,
 6598,
 6602,
 6605,
 8243,
 27044,
 92017,
 8651,
 11166,
 23013,
 8405,
 6714,
 9901,
 6427,
 6760,
 26039,
 6756,
 10735,
 6774,
 6777,
 6778,
 6491,
 6794,
 6801,
 51684,
 23512,
 6850,
 8148,
 6886,
 6887,
 79718,
 6926,
 6938,
 6934,
 8115,
 7006,
 54855,
 7015,
 80312,
 54790,
 7030,
 7942,
 10342,
 29844,
 7037,
 7048,
 9967,
 3195,
 30012,
 55654,
 7113,
 3371,
 7128,
 8764,
 608,
 7150,
 7157,
 8626,
 7170,
 7171,
 7175,
 84231,
 8805,
 5987,
 51592,
 9321,
 8295,
 7248,
 7249,
 7253,
 7307,
 51366,
 84101,
 9098,
 9101,
 7409,
 7428,
 143187,
 7454,
 11197,
 65268,
 7486,
 7490,
 25937,
 7507,
 7508,
 7514,
 7531,
 7704,
 55596,
 6935,
 463,
 7750,
 9203,
 55422,
 171017,
 353088,
 90827,
 25925,
 84133,
 8233]

for entrez in GeneList_entrez:
    pattern = fr'\({entrez}\)$'  # 匹配结尾为 (entrez) 的列名
    matched = [col for col in expr_cols if re.search(pattern, col)]
    if matched:
        selected_gene_cols.extend(matched)
    else:
        print(f"⚠️ 没找到Entrez {entrez} 对应的列")

# ========== 3. 拼接最终 DataFrame ==========
new_expr_df = expr_df[base_cols + selected_gene_cols]

print("新矩阵形状:", new_expr_df.shape)
print("前几列:", new_expr_df.columns[:10].tolist())


新矩阵形状: (1699, 641)
前几列: ['ModelID', 'A1CF (29974)', 'ABI1 (10006)', 'ABL1 (25)', 'ABL2 (27)', 'ACSL3 (2181)', 'ACSL6 (23305)', 'ACVR1 (90)', 'ACVR2A (92)', 'AFF1 (4299)']


In [144]:
expr_df = new_expr_df

In [145]:
non_gene_cols = [
    "Unnamed: 0",
    "SequencingID",
    "ModelID",
    "IsDefaultEntryForModel",
    "ModelConditionID",
    "IsDefaultEntryForMC",
    "ModelName",
    "DepMap_ID",
]
gene_cols = [c for c in expr_df.columns if c not in non_gene_cols]
print("Columns after cleaning:", expr_df.columns[:10])
print("Final matrix shape:", expr_df.shape)

# ---- 分离出 cell line ID 与基因特征矩阵 ----
cell_ids = expr_df["ModelID"].values if "ModelID" in expr_df.columns else expr_df.iloc[:, 0].values
gene_features = expr_df.drop(columns=["ModelID"], errors="ignore").values
print(f"Loaded expression matrix: {len(cell_ids)} cell lines × {gene_features.shape[1]} genes")


Columns after cleaning: Index(['ModelID', 'A1CF (29974)', 'ABI1 (10006)', 'ABL1 (25)', 'ABL2 (27)',
       'ACSL3 (2181)', 'ACSL6 (23305)', 'ACVR1 (90)', 'ACVR2A (92)',
       'AFF1 (4299)'],
      dtype='object')
Final matrix shape: (1699, 641)
Loaded expression matrix: 1699 cell lines × 640 genes


In [146]:
# ========================
# 2) 读取 Cellosaurus 映射表
# ========================
map_df = pd.read_csv(map_path)
map_df["accession"] = map_df["accession"].str.strip()
map_df["depmap"] = map_df["depmap"].str.strip()  # 这里 depmap 是 DepMap ID，例如 "ACH-000001"

# ========================
# 3) 构建匹配函数（DepMap直接匹配）
# ========================
def match_expression_row(accession):
    row = map_df[map_df["accession"] == accession]
    if row.empty or pd.isna(row.iloc[0]["depmap"]):
        return None
    depmap_id = row.iloc[0]["depmap"].strip().upper()
    hit = expr_df[expr_df["ModelID"].str.upper() == depmap_id]
    if not hit.empty:
        return hit.drop(columns=["ModelID"]).values.squeeze()
    return None


In [147]:
# ========================
# 4) 构建 cell_line.x —— 保存原始表达（640 维）
# ========================
missing_cells = []
expr_dim = gene_features.shape[1]  # 19220
cell_emb = torch.zeros(len(cell2id), expr_dim)

for cell, idx in cell2id.items():
    expr_vec = match_expression_row(cell)
    if expr_vec is not None:
        expr_vec = pd.to_numeric(pd.Series(expr_vec), errors="coerce").fillna(0).values.astype(np.float32)
        cell_emb[idx] = torch.tensor(expr_vec, dtype=torch.float)
    else:
        missing_cells.append(cell)
        cell_emb[idx] = torch.zeros(expr_dim)

print(f"Matched {len(cell2id) - len(missing_cells)} / {len(cell2id)} cell lines")
if missing_cells:
    print("Missing cell lines:", missing_cells[:10], "...")

# ⚠️ 关键修改：直接保存原始表达，不做 Linear 投影！
data["cell_line"].x = cell_emb  # [50, 19220]
print("Cell line embedding shape:", data["cell_line"].x.shape)

Matched 47 / 50 cell lines
Missing cell lines: ['CVCL_0218', 'CVCL_1098', 'CVCL_C466'] ...
Cell line embedding shape: torch.Size([50, 640])


In [148]:
data

HeteroData(
  cell_line={
    num_nodes=50,
    x=[50, 640],
  },
  drug={
    num_nodes=78,
    x=[78, 384],
  },
  protein={
    num_nodes=5242,
    x=[5242, 1280],
  },
  (cell_line, expresses, protein)={ edge_index=[2, 128556] },
  (protein, rev_expresses, cell_line)={ edge_index=[2, 128557] },
  (drug, targets, protein)={ edge_index=[2, 314] },
  (protein, rev_targets, drug)={ edge_index=[2, 315] },
  (protein, ppi, protein)={ edge_index=[2, 1033904] },
  (drug, interacts, drug)={ edge_index=[2, 1638] }
)

In [149]:
# Cell 10: 构建 & 划分 Triples
triples = []
for _, row in df_synergy_filtered.iterrows():
    a = str(row["drug_a_drugbank_id"])
    b = str(row["drug_b_drugbank_id"])
    c = str(row["cell_line"])
    label = float(row.get("synergy", row.iloc[-1]))
    if a in drug2id and b in drug2id and c in cell2id:
        triples.append((drug2id[a], drug2id[b], cell2id[c], label))
    else:
        print(a, b, c)

# 按 drug 划分 70/30
import random
random.seed(42)
all_drugs = list(drug2id.keys())
random.shuffle(all_drugs)
split = int(0.7 * len(all_drugs))
train_drug_set = set(all_drugs[:split])
test_drug_set = set(all_drugs[split:])

train_triples = [t for t in triples if (list(drug2id.keys())[t[0]] in train_drug_set) and (list(drug2id.keys())[t[1]] in train_drug_set)]
test_triples = [t for t in triples if (list(drug2id.keys())[t[0]] in test_drug_set) or (list(drug2id.keys())[t[1]] in test_drug_set)]

print(f"Train: {len(train_triples)}, Test: {len(test_triples)}")

Train: 11482, Test: 10131


In [150]:
# Cell 11: 保存结果
torch.save(data, os.path.join(output_dir, "hetero_graph_filtered.pt"))

metadata = {
    "cell2id": cell2id,
    "drug2id": drug2id,
    "prot2id": prot2id,
    "train_triples": train_triples,
    "test_triples": test_triples,
}
with open(os.path.join(output_dir, "graph_metadata_filtered.pkl"), "wb") as f:
    pickle.dump(metadata, f)

print("✅ 保存完成!")

✅ 保存完成!
