In [None]:
import torch
import pandas as pd
from transformers import AutoModel, AutoTokenizer

In [None]:

# Lookup table pairing metabolite names ("Exact Match to Standard ...") with SMILES strings.
df_metname_smile = pd.read_csv(filepath_or_buffer="/home/metabolites_to_SMILES.csv")
df_metname_smile.shape

(466, 3)

In [None]:
# Only rows with a SMILES string can be embedded; renumber rows 0..n-1 afterwards.
df_metname_smile = df_metname_smile[df_metname_smile["SMILES"].notna()].reset_index(drop=True)

In [None]:
# Inspect the name -> SMILES table after rows without a SMILES string were dropped.
df_metname_smile

Unnamed: 0.1,Unnamed: 0,Exact Match to Standard (* = isomer family),SMILES
0,HILIC-neg_Cluster_0622,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",CC1NC(Cc2c1[nH]c3ccccc23)C(O)=O
1,C18-neg_Cluster_0183,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",OC(=O)C1Cc2c([nH]c3ccccc23)C(N1)C(O)=O
2,C18-neg_Cluster_0393,12.13-diHOME,CCCCCC(O)C(O)C\C=C/CCCCCCCC(O)=O
3,HILIC-neg_Cluster_0480,1-3-7-trimethylurate,CN1C(=O)N(C)C2=C(N(C)C(=O)N2)C1=O
4,C18-neg_Cluster_0530,13-docosenoate,CCCCCCCCC=CCCCCCCCCCCCC([O-])=O
...,...,...,...
296,HILIC-pos_Cluster_0116,urocanic acid,OC(=O)/C=C/c1[nH]cnc1
297,HILIC-pos_Cluster_0046,valine,CC(C)[C@H](N)C(O)=O
298,HILIC-neg_Cluster_0066,valine,CC(C)[C@H](N)C(O)=O
299,HILIC-neg_Cluster_0187,xanthine,O=C1NC(=O)c2[nH]cnc2N1


In [None]:
# SMILES strings in table order; the embedding loop below produces one row per entry, in this order.
smiles_list = df_metname_smile["SMILES"].to_list()
smiles_list[:20]  # preview the first 20

['CC1NC(Cc2c1[nH]c3ccccc23)C(O)=O',
 'OC(=O)C1Cc2c([nH]c3ccccc23)C(N1)C(O)=O',
 'CCCCCC(O)C(O)C\\C=C/CCCCCCCC(O)=O',
 'CN1C(=O)N(C)C2=C(N(C)C(=O)N2)C1=O',
 'CCCCCCCCC=CCCCCCCCCCCCC([O-])=O',
 'CN1C(=Nc2nc[nH]c2C1=O)N',
 'CN1CC(=CC=C1)C(N)=O',
 'CCC(N)C(O)=O',
 'Nc1ncnc2n(cnc12)C3CC(O)C(CO)O3',
 'CCC(C)C(O)C(O)=O',
 'OC(CCC([O-])=O)C([O-])=O',
 'CCCCCCCCCCCCCCC(O)C([O-])=O',
 'CCCCCCCCCCCCC(O)C(O)=O',
 'NCCc1ccccc1O',
 'CC(CCC(O)=O)C(O)=O',
 'CC(CCC(O)=O)CC(O)=O.OC(=O)CCCCCC(O)=O',
 'Cn1cncc1C[C@H](N)C(O)=O',
 'CN1C(=O)NC(=O)c2[nH]cnc12',
 'NC(=N)NCCCC(O)=O',
 'CC(=O)c1ccc(O)c(C)c1']

In [None]:
# MoLFormer checkpoint and tokenizer, loaded as in the Hugging Face model card example:
# https://huggingface.co/ibm-research/MoLFormer-XL-both-10pct
# NOTE: trust_remote_code=True executes model code downloaded from the Hub repository.
# deterministic_eval=True is forwarded to the remote MoLFormer config; per the model docs it keeps
# the linear-attention random features fixed in eval mode (presumably for reproducible embeddings).
tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "ibm/MoLFormer-XL-both-10pct",
    deterministic_eval=True,
    trust_remote_code=True,
)

In [None]:
batch_size = 32
batch_embeddings = []  # one [batch, hidden_dim] tensor per batch

# from_pretrained already returns the model in eval mode; calling eval() again is harmless
# and keeps the cell correct if the model was switched to training mode in between.
model.eval()

for start in range(0, len(smiles_list), batch_size):
    batch_smiles = smiles_list[start : start + batch_size]
    # Pad to the longest SMILES in this batch and return PyTorch tensors.
    inputs = tokenizer(batch_smiles, padding=True, return_tensors="pt")

    # Inference only (pattern from the model card example): no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # pooler_output: one vector per SMILES, shape [batch, hidden_dim]. Move it to the CPU so
    # the .numpy() call below also works if the model/inputs are ever placed on a GPU.
    batch_embeddings.append(outputs.pooler_output.cpu())

all_embeddings = torch.cat(batch_embeddings, dim=0)  # [len(smiles_list), hidden_dim]
assert all_embeddings.shape[0] == len(smiles_list), "expected exactly one embedding per SMILES"
emb_np = all_embeddings.numpy()

emb_np.shape


(301, 768)

In [None]:
# Sanity check: exactly one embedding row per SMILES string. Asserting makes Restart & Run All
# fail loudly instead of relying on a reader comparing two displayed numbers by eye.
assert emb_np.shape[0] == len(smiles_list), (emb_np.shape, len(smiles_list))
len(smiles_list)

301

In [None]:
# One row per SMILES (same order as smiles_list) and one column per embedding dimension: emb_0, emb_1, ...
emb_df = pd.DataFrame(emb_np, columns=["emb_" + str(dim) for dim in range(emb_np.shape[1])])
emb_df.head()

Unnamed: 0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,0.505307,0.368583,-0.16389,0.557924,-0.69992,-0.661996,-0.216786,0.257202,-0.443332,0.217897,...,-0.682277,-0.032246,-0.241925,0.813374,0.190511,-0.061261,-2.48789,0.327604,0.040142,-0.362552
1,-0.133892,0.371886,0.260327,0.502376,-0.831947,-0.467348,-0.081658,-0.101592,-0.162596,-0.317,...,-0.346604,-0.498758,-0.502389,0.943361,0.42465,0.077521,-2.27077,0.091566,0.039425,-0.083055
2,1.270118,-0.746132,0.660812,-0.218625,-1.001234,-0.362599,0.194166,0.314788,0.016662,-0.636142,...,-0.398302,0.009072,-0.698953,0.039464,0.111021,-0.336268,-2.697375,0.631028,0.212545,-0.276969
3,-0.642621,0.317733,0.051413,0.600625,-0.67149,0.394049,-0.811712,0.481118,-0.312811,0.178391,...,0.026004,-0.118494,-0.923518,0.10561,0.44194,-0.332174,-2.005177,0.065245,-0.077093,-0.101834
4,0.732133,-0.032527,1.018785,-0.060513,0.025252,0.280306,-0.036591,-0.13801,-0.010335,-0.541228,...,-0.407527,-0.930379,-0.736568,-0.084411,-0.285262,0.272547,-2.808429,1.518496,0.133684,-0.614252


In [None]:
# Prepend the metabolite name (row-aligned with df_metname_smile) so embeddings can be joined
# back to the feature metadata.
emb_df.insert(
    loc=0,
    column="Exact Match to Standard (* = isomer family)",
    value=df_metname_smile["Exact Match to Standard (* = isomer family)"].to_numpy(),
)

emb_df.head()

Unnamed: 0,Exact Match to Standard (* = isomer family),emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",0.505307,0.368583,-0.16389,0.557924,-0.69992,-0.661996,-0.216786,0.257202,-0.443332,...,-0.682277,-0.032246,-0.241925,0.813374,0.190511,-0.061261,-2.48789,0.327604,0.040142,-0.362552
1,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",-0.133892,0.371886,0.260327,0.502376,-0.831947,-0.467348,-0.081658,-0.101592,-0.162596,...,-0.346604,-0.498758,-0.502389,0.943361,0.42465,0.077521,-2.27077,0.091566,0.039425,-0.083055
2,12.13-diHOME,1.270118,-0.746132,0.660812,-0.218625,-1.001234,-0.362599,0.194166,0.314788,0.016662,...,-0.398302,0.009072,-0.698953,0.039464,0.111021,-0.336268,-2.697375,0.631028,0.212545,-0.276969
3,1-3-7-trimethylurate,-0.642621,0.317733,0.051413,0.600625,-0.67149,0.394049,-0.811712,0.481118,-0.312811,...,0.026004,-0.118494,-0.923518,0.10561,0.44194,-0.332174,-2.005177,0.065245,-0.077093,-0.101834
4,13-docosenoate,0.732133,-0.032527,1.018785,-0.060513,0.025252,0.280306,-0.036591,-0.13801,-0.010335,...,-0.407527,-0.930379,-0.736568,-0.084411,-0.285262,0.272547,-2.808429,1.518496,0.133684,-0.614252


In [None]:
# Persist the name-labelled embedding table without the positional index.
emb_df.to_csv(path_or_buf="/home/metabolite_embeddings_molformer.csv", index=False)

In [None]:
# Per-feature metadata (retention time, m/z, chemical class, matched standard, adduct).
# header=1: column names are taken from the sheet's second row.
df_metname_metid = pd.read_excel(io="/home/41564_2018_306_MOESM3_ESM.xlsx", header=1)
df_metname_metid.head()

Unnamed: 0,Metabolomic Feature,Retention Time,m/z,Cluster (if DA),Putative Chemical Class,Exact Match to Standard (* = isomer family),Adduct
0,HILIC-neg_Cluster_0622,3.809685,229.097974,5.0,Harmala alkaloids,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",[M-H]-
1,C18-neg_Cluster_0183,1.282739,259.072196,2.0,,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",[M-H]-
2,C18-neg_Cluster_0393,10.162391,313.238314,,Long-chain fatty acids,12.13-diHOME,[M-H]-
3,HILIC-neg_Cluster_0480,3.83278,209.067885,,Imidazopyrimidines,1-3-7-trimethylurate,[M-H]-
4,C18-neg_Cluster_0530,16.986972,337.310869,,Very long-chain fatty acids,13-docosenoate,[M-H]-


In [None]:
name_col = "Exact Match to Standard (* = isomer family)"

# 1) Feature metadata -> SMILES, joined on the metabolite name (inner: names present in both tables).
#    Names are not unique in df_metname_smile (e.g. "valine" is listed for two features), so the
#    join can fan out; keep the first SMILES row per feature.
metid_metname_smiles = (
    df_metname_metid
    .merge(df_metname_smile, on=name_col, how="inner")
    .dropna(subset=["SMILES"])
    .drop_duplicates(subset=["Metabolomic Feature"])
    .reset_index(drop=True)
)

# 2) Attach embeddings keyed by SMILES, not by name. Joining on the name a second time fans out
#    again over duplicate names, and after de-duplication a feature could be paired with the
#    embedding of a different SMILES than the one kept in step 1.
#    emb_df rows are positionally aligned with df_metname_smile (both come from its SMILES column).
assert len(emb_df) == len(df_metname_smile), "emb_df must be row-aligned with df_metname_smile"
emb_value_cols = [c for c in emb_df.columns if c.startswith("emb_")]
emb_by_smiles = (
    emb_df[emb_value_cols]
    .assign(SMILES=df_metname_smile["SMILES"].to_numpy())
    .drop_duplicates(subset=["SMILES"])
)

# many_to_one: each SMILES has exactly one embedding row, so no fan-out is possible here.
metid_metname_smiles_emb = metid_metname_smiles.merge(
    emb_by_smiles, on="SMILES", how="inner", validate="many_to_one"
)
assert metid_metname_smiles_emb["Metabolomic Feature"].is_unique

metid_metname_smiles_emb.head()
# The recorded run yields 256 feature rows.

Unnamed: 0.1,Metabolomic Feature,Retention Time,m/z,Cluster (if DA),Putative Chemical Class,Exact Match to Standard (* = isomer family),Adduct,Unnamed: 0,SMILES,emb_0,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,HILIC-neg_Cluster_0622,3.809685,229.097974,5.0,Harmala alkaloids,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",[M-H]-,HILIC-neg_Cluster_0622,CC1NC(Cc2c1[nH]c3ccccc23)C(O)=O,0.505307,...,-0.682277,-0.032246,-0.241925,0.813374,0.190511,-0.061261,-2.48789,0.327604,0.040142,-0.362552
1,C18-neg_Cluster_0183,1.282739,259.072196,2.0,,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",[M-H]-,C18-neg_Cluster_0183,OC(=O)C1Cc2c([nH]c3ccccc23)C(N1)C(O)=O,-0.133892,...,-0.346604,-0.498758,-0.502389,0.943361,0.42465,0.077521,-2.27077,0.091566,0.039425,-0.083055
2,C18-neg_Cluster_0393,10.162391,313.238314,,Long-chain fatty acids,12.13-diHOME,[M-H]-,C18-neg_Cluster_0393,CCCCCC(O)C(O)C\C=C/CCCCCCCC(O)=O,1.270118,...,-0.398302,0.009072,-0.698953,0.039464,0.111021,-0.336268,-2.697375,0.631028,0.212545,-0.276969
3,HILIC-neg_Cluster_0480,3.83278,209.067885,,Imidazopyrimidines,1-3-7-trimethylurate,[M-H]-,HILIC-neg_Cluster_0480,CN1C(=O)N(C)C2=C(N(C)C(=O)N2)C1=O,-0.642621,...,0.026004,-0.118494,-0.923518,0.10561,0.44194,-0.332174,-2.005177,0.065245,-0.077093,-0.101834
4,C18-neg_Cluster_0530,16.986972,337.310869,,Very long-chain fatty acids,13-docosenoate,[M-H]-,C18-neg_Cluster_0530,CCCCCCCCC=CCCCCCCCCCCCC([O-])=O,0.732133,...,-0.407527,-0.930379,-0.736568,-0.084411,-0.285262,0.272547,-2.808429,1.518496,0.133684,-0.614252


In [None]:
# Abundance sheet: rows are features (plus clinical fields such as Age and Diagnosis), columns are
# samples. Transpose so each row is one sample.
df_abund_samples = (
    pd.read_excel(io="/home/41564_2018_306_MOESM4_ESM.xlsx", header=1)
    .set_index("# Feature / Sample")
    .transpose()
)
df_abund_samples

# Feature / Sample,Age,Diagnosis,Fecal.Calprotectin,antibiotic,immunosuppressant,mesalamine,steroids,C18-neg_Cluster_0001,C18-neg_Cluster_0002,C18-neg_Cluster_0003,...,HILIC-pos_Cluster_2367,HILIC-pos_Cluster_2368,HILIC-pos_Cluster_2369,HILIC-pos_Cluster_2370,HILIC-pos_Cluster_2371,HILIC-pos_Cluster_2372,HILIC-pos_Cluster_2373,HILIC-pos_Cluster_2374,HILIC-pos_Cluster_2375,HILIC-pos_Cluster_2376
PRISM|7122,38,CD,207.484429,No,Yes,No,No,0,6391.01,288.808,...,9.01813,0,119.404,1272.3,722.609,10.4174,0,0,14.3306,0
PRISM|7147,50,CD,,No,No,Yes,No,1635.54,27.4461,59.2412,...,15.9256,0,6.32188,115.12,38.2105,27.6128,0,5.32097,40.4456,40.1677
PRISM|7150,41,CD,218.334517,No,Yes,No,No,0,8265.9,7708.63,...,0,0,18.5229,37.7271,0,0,0,20.8014,41.2349,0
PRISM|7153,51,CD,,No,No,Yes,No,203.783,14.2666,57.3647,...,11.3461,145.987,35.9724,5428.45,2868.94,9.98243,0,0,9.87679,77.0562
PRISM|7184,68,CD,20.167951,No,No,No,No,0,332.206,42.5518,...,13.3375,0,0,5664.39,2112.37,21.7531,6.01921,7.29409,9.92948,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Validation|UMCGIBD00593,21,UC,40,No,No,Yes,No,2944.38,293.233,86.8866,...,12.8184,0.016195,0,69.3297,69.7259,11.3143,0,8.39139,34.2163,0
Validation|UMCGIBD00233,32,CD,45,No,Yes,No,No,0,1017.67,71.9505,...,0,0,0,550.638,224.916,0,0,11.0293,14.818,0
Validation|UMCGIBD00238,38,CD,305,No,Yes,No,No,0,411.05,2631.76,...,7.99238,0,13.1122,140.88,234.404,0,10.1884,40.7231,10.3148,0
Validation|UMCGIBD00027,51,CD,44,No,Yes,No,No,0,453.672,24.7567,...,0,49.5064,5.40646,1009.55,604.698,0,0,0,27.8203,0


In [None]:
# Feature IDs (and clinical fields) that appear as columns of the abundance table.
features_in_abund = set(df_abund_samples.columns)

# Keep only embedded features that were also quantified in the abundance table
# (256 features in the recorded run).
meta_smiles_in_abund = (
    metid_metname_smiles_emb
    .loc[lambda frame: frame["Metabolomic Feature"].isin(features_in_abund)]
    .reset_index(drop=True)
)
feature_order = list(meta_smiles_in_abund["Metabolomic Feature"])

# Feature x embedding-dimension matrix, indexed by feature ID.
emb_cols = [col for col in meta_smiles_in_abund.columns if col.startswith("emb_")]
feat_emb_df = meta_smiles_in_abund.set_index("Metabolomic Feature").loc[:, emb_cols]
feat_emb_df.head()

Unnamed: 0_level_0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
Metabolomic Feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
HILIC-neg_Cluster_0622,0.505307,0.368583,-0.16389,0.557924,-0.69992,-0.661996,-0.216786,0.257202,-0.443332,0.217897,...,-0.682277,-0.032246,-0.241925,0.813374,0.190511,-0.061261,-2.48789,0.327604,0.040142,-0.362552
C18-neg_Cluster_0183,-0.133892,0.371886,0.260327,0.502376,-0.831947,-0.467348,-0.081658,-0.101592,-0.162596,-0.317,...,-0.346604,-0.498758,-0.502389,0.943361,0.42465,0.077521,-2.27077,0.091566,0.039425,-0.083055
C18-neg_Cluster_0393,1.270118,-0.746132,0.660812,-0.218625,-1.001234,-0.362599,0.194166,0.314788,0.016662,-0.636142,...,-0.398302,0.009072,-0.698953,0.039464,0.111021,-0.336268,-2.697375,0.631028,0.212545,-0.276969
HILIC-neg_Cluster_0480,-0.642621,0.317733,0.051413,0.600625,-0.67149,0.394049,-0.811712,0.481118,-0.312811,0.178391,...,0.026004,-0.118494,-0.923518,0.10561,0.44194,-0.332174,-2.005177,0.065245,-0.077093,-0.101834
C18-neg_Cluster_0530,0.732133,-0.032527,1.018785,-0.060513,0.025252,0.280306,-0.036591,-0.13801,-0.010335,-0.541228,...,-0.407527,-0.930379,-0.736568,-0.084411,-0.285262,0.272547,-2.808429,1.518496,0.133684,-0.614252


In [None]:
# Feature IDs that have both an embedding and an abundance column.
metabolite_cols = feat_emb_df.index.intersection(df_abund_samples.columns)

# Samples x features abundance matrix. The transposed sheet mixes clinical strings with numbers,
# so cast the selected feature columns to float.
abund_feat_df = df_abund_samples.loc[:, metabolite_cols].astype(float)
abund_feat_df.head()

# Feature / Sample,HILIC-neg_Cluster_0622,C18-neg_Cluster_0183,C18-neg_Cluster_0393,HILIC-neg_Cluster_0480,C18-neg_Cluster_0530,HILIC-pos_Cluster_0245,HILIC-pos_Cluster_0110,HILIC-neg_Cluster_0032,HILIC-pos_Cluster_0728,HILIC-neg_Cluster_0113,...,C18-neg_Cluster_0079,HILIC-neg_Cluster_0049,HILIC-neg_Cluster_0254,HILIC-neg_Cluster_0706,C18-neg_Cluster_2021,HILIC-pos_Cluster_0116,HILIC-pos_Cluster_0046,HILIC-neg_Cluster_0066,HILIC-neg_Cluster_0187,HILIC-neg_Cluster_0176
PRISM|7122,210.752,63.0097,507.955,873.165,311.94,17.1701,454.239,356.396,20.7759,295.788,...,203.581,1143.26,983.554,43.4791,708.803,15665.7,9555.28,29184.9,10964.2,85.3097
PRISM|7147,23.7645,62.714,162.325,690.634,0.0,0.027799,103.944,12.3401,0.0,26.1544,...,167.526,65.0007,479.314,779.622,1459.86,230.937,385.239,894.347,1056.03,0.870127
PRISM|7150,359.982,308.907,123.417,686.706,31.5585,0.020499,0.0,791.694,21.0533,5721.94,...,737.486,5492.72,3666.58,99.0469,1817.8,3355.72,6625.12,30172.0,12451.7,105.129
PRISM|7153,190.391,28.3744,1027.72,348.698,48.5136,0.0,0.0,19.8677,0.0,57.3873,...,84.4193,167.95,640.691,271.17,56.7944,323.066,2194.0,1047.39,404.082,0.0
PRISM|7184,45.9292,50.5728,566.753,798.689,49.7607,63.307,6.17649,265.094,38.4573,111.748,...,274.743,7881.59,458.678,611.721,21.1268,9744.98,4961.99,15037.5,14820.1,116.503


In [None]:
import numpy as np

# Per-sample feature weights: compress abundances with log(1 + x), then divide each sample (row)
# by its total so the weights sum to 1. A row summing to 0 is mapped to NaN and then filled with 0,
# so such samples get all-zero weights instead of 0/0; any other NaN cell also becomes 0.
log_abund_df = np.log1p(abund_feat_df)
row_sums = log_abund_df.sum(axis=1).replace(0, np.nan)
weights_df = log_abund_df.div(row_sums, axis=0).fillna(0.0)

weights_df.head()

# Feature / Sample,HILIC-neg_Cluster_0622,C18-neg_Cluster_0183,C18-neg_Cluster_0393,HILIC-neg_Cluster_0480,C18-neg_Cluster_0530,HILIC-pos_Cluster_0245,HILIC-pos_Cluster_0110,HILIC-neg_Cluster_0032,HILIC-pos_Cluster_0728,HILIC-neg_Cluster_0113,...,C18-neg_Cluster_0079,HILIC-neg_Cluster_0049,HILIC-neg_Cluster_0254,HILIC-neg_Cluster_0706,C18-neg_Cluster_2021,HILIC-pos_Cluster_0116,HILIC-pos_Cluster_0046,HILIC-neg_Cluster_0066,HILIC-neg_Cluster_0187,HILIC-neg_Cluster_0176
PRISM|7122,0.003276,0.002544,0.003812,0.004143,0.003514,0.001774,0.003744,0.003596,0.001884,0.003482,...,0.003255,0.004307,0.004216,0.002321,0.004015,0.005908,0.005606,0.006289,0.00569,0.002727
PRISM|7147,0.002408,0.003117,0.003823,0.004906,0.0,2.1e-05,0.003491,0.001944,0.0,0.002477,...,0.003847,0.003143,0.004632,0.004997,0.005467,0.004086,0.004469,0.0051,0.005224,0.00047
PRISM|7150,0.003923,0.003821,0.003213,0.004352,0.00232,1.4e-05,0.0,0.004447,0.002061,0.005764,...,0.0044,0.005736,0.005467,0.003068,0.005,0.005408,0.005861,0.006871,0.006282,0.003107
PRISM|7153,0.003681,0.002368,0.004859,0.004103,0.002734,0.0,0.0,0.002128,0.0,0.002849,...,0.003116,0.003594,0.004529,0.003928,0.002842,0.00405,0.00539,0.004873,0.004206,0.0
PRISM|7184,0.002408,0.002467,0.003969,0.004183,0.002458,0.002606,0.001233,0.003494,0.0023,0.002957,...,0.003517,0.005615,0.003836,0.004016,0.001938,0.005748,0.005325,0.006019,0.00601,0.002983


In [None]:
# Sample embedding = abundance-weighted combination of feature embeddings:
#   weights_df [n_samples, n_features] @ feat_emb_df [n_features, emb_dim] -> [n_samples, emb_dim]
# pandas aligns weights_df's columns with feat_emb_df's index by label before multiplying.
sample_emb_df = weights_df @ feat_emb_df
sample_emb_df.head()

Unnamed: 0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
PRISM|7122,0.503727,0.071121,0.26221,0.208107,-0.750364,-0.162051,-0.231517,-0.044882,0.066673,-0.186569,...,-0.41247,-0.263849,-0.299218,0.275538,0.09952,-0.097561,-2.345192,0.402813,-0.203259,-0.167115
PRISM|7147,0.484642,0.089444,0.267869,0.230851,-0.757682,-0.170431,-0.255503,-0.029994,0.06492,-0.170134,...,-0.429335,-0.250167,-0.280226,0.294098,0.124096,-0.13024,-2.366274,0.384989,-0.20441,-0.120609
PRISM|7150,0.481692,0.075952,0.247671,0.231047,-0.770289,-0.181725,-0.251004,-0.064253,0.048593,-0.170349,...,-0.417299,-0.260405,-0.290031,0.304461,0.098469,-0.096384,-2.318566,0.387278,-0.216182,-0.139882
PRISM|7153,0.509081,0.053121,0.271979,0.215051,-0.742012,-0.153504,-0.227382,-0.00788,0.091546,-0.177813,...,-0.422414,-0.256328,-0.301447,0.26876,0.109144,-0.119116,-2.376013,0.39984,-0.204853,-0.154715
PRISM|7184,0.495793,0.069895,0.260924,0.212337,-0.760238,-0.166215,-0.250357,-0.034988,0.055714,-0.178719,...,-0.422994,-0.256692,-0.28935,0.282124,0.112781,-0.118279,-2.369301,0.406119,-0.207982,-0.134576


In [None]:
# Persist the per-sample embeddings and the transposed abundance/clinical table (index = sample ID).
sample_emb_df.to_csv(path_or_buf="/home/sample_emb_df.csv")
df_abund_samples.to_csv(path_or_buf="/home/df_abund_samples.csv")

# Classification: train on PRISM, evaluate on the Validation cohort

In [None]:
# Boolean sample masks by cohort, taken from the sample-ID prefix (e.g. "PRISM|7122", "Validation|UMCGIBD00593").
is_prism, is_valid = (
    df_abund_samples.index.str.startswith(prefix) for prefix in ("PRISM", "Validation")
)

is_prism.sum(), is_valid.sum()

(np.int64(155), np.int64(65))

In [None]:
# Features: per-sample embeddings, split by cohort.
X_prism = sample_emb_df.loc[is_prism].to_numpy()
X_valid = sample_emb_df.loc[is_valid].to_numpy()

# Multiclass targets: the raw Diagnosis labels (CD / Control / UC).
y_prism_multi = df_abund_samples.loc[is_prism, "Diagnosis"].to_numpy()
y_valid_multi = df_abund_samples.loc[is_valid, "Diagnosis"].to_numpy()

# Binary targets: 1 for any non-"Control" diagnosis (IBD), 0 for Control.
y_prism_binary, y_valid_binary = (
    (labels != "Control").astype(int) for labels in (y_prism_multi, y_valid_multi)
)

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

# Train on PRISM, evaluate on the independent Validation cohort.
# The scaler is fit on PRISM only and then applied to Validation, which avoids leakage.
scaler_multi = StandardScaler()
X_prism_scaled = scaler_multi.fit_transform(X_prism)
X_valid_scaled = scaler_multi.transform(X_valid)

# No multi_class argument: it is deprecated since scikit-learn 1.5. With the default lbfgs solver,
# a 3-class target is already fit as a multinomial (softmax) model.
clf_multi = LogisticRegression(max_iter=1000, C=0.1, penalty="l2")
clf_multi.fit(X_prism_scaled, y_prism_multi)

# Evaluate on Validation.
y_valid_pred = clf_multi.predict(X_valid_scaled)
acc_valid = accuracy_score(y_valid_multi, y_valid_pred)

print("Validation Multiclass Accuracy (UC vs CD vs Control):", acc_valid)
print(classification_report(y_valid_multi, y_valid_pred))


Validation Multiclass Accuracy (UC vs CD vs Control): 0.5384615384615384
              precision    recall  f1-score   support

          CD       0.60      0.60      0.60        20
     Control       0.49      0.77      0.60        22
          UC       0.60      0.26      0.36        23

    accuracy                           0.54        65
   macro avg       0.56      0.54      0.52        65
weighted avg       0.56      0.54      0.52        65





In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# 5-fold stratified CV inside PRISM. This uses the same hyperparameters (C=0.1, L2) as the
# PRISM -> Validation model, so the internal estimate describes the model that is actually
# evaluated on Validation rather than a default C=1.0 variant.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

acc_list = []
for train_idx, test_idx in skf.split(X_prism, y_prism_multi):
    X_tr, X_te = X_prism[train_idx], X_prism[test_idx]
    y_tr, y_te = y_prism_multi[train_idx], y_prism_multi[test_idx]

    # Fit the scaler on the training fold only, so no test-fold statistics leak into training.
    scaler = StandardScaler()
    X_tr_s = scaler.fit_transform(X_tr)
    X_te_s = scaler.transform(X_te)

    clf = LogisticRegression(max_iter=1000, C=0.1, penalty="l2")
    clf.fit(X_tr_s, y_tr)
    y_te_pred = clf.predict(X_te_s)
    acc_list.append(accuracy_score(y_te, y_te_pred))

print("within PRISM 5-fold multi-classification Accuracy: {:.2f} ± {:.2f}".format(
    np.mean(acc_list), np.std(acc_list)
))



PRISM 内部 5-fold 多分类 Accuracy: 0.51 ± 0.09




In [None]:
from sklearn.metrics import roc_auc_score

# Binary IBD-vs-Control model: fit the scaler and classifier on PRISM only, then score Validation.
scaler_bin = StandardScaler().fit(X_prism)
X_prism_scaled_bin = scaler_bin.transform(X_prism)
X_valid_scaled_bin = scaler_bin.transform(X_valid)

clf_bin = LogisticRegression(max_iter=1000).fit(X_prism_scaled_bin, y_prism_binary)

# Column 1 of predict_proba = probability of class 1 (IBD) for each Validation sample.
y_valid_proba = clf_bin.predict_proba(X_valid_scaled_bin)[:, 1]
auc_valid = roc_auc_score(y_valid_binary, y_valid_proba)

print("Validation Binary AUC (IBD vs Control):", auc_valid)


Validation Binary AUC (IBD vs Control): 0.693446088794926


# Pooled cohorts: cross-validation on PRISM and Validation combined

In [None]:
# Pooled design: every sample from both cohorts.
X = sample_emb_df.to_numpy()
y_multi = df_abund_samples["Diagnosis"].to_numpy()
y_binary = (y_multi != "Control").astype(int)  # 1 = IBD (non-Control), 0 = Control


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Pooled 5-fold CV over PRISM + Validation (CD / Control / UC).
# StandardScaler sits inside the pipeline, so it is re-fit on each training fold only.
# Without it, the unscaled sample embeddings (which vary only slightly between samples) meet a
# strong L2 penalty (C=0.1), and the score is not comparable with the scaled models above.
# NOTE: folds are stratified by diagnosis only; cohort membership is not balanced across folds.
le = LabelEncoder()
# Encode into a new variable: overwriting y_multi made the cell non-idempotent, because a re-run
# would "encode" the already-encoded integers.
y_multi_enc = le.fit_transform(y_multi)
print("Classes:", le.classes_)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# A separate name keeps clf_multi (the fitted PRISM -> Validation model) intact.
pooled_multi_model = make_pipeline(
    StandardScaler(),
    LogisticRegression(penalty="l2", C=0.1, max_iter=1000),
)

acc_multi = cross_val_score(pooled_multi_model, X, y_multi_enc, cv=cv, scoring="accuracy")
print("Multiclass accuracy:", acc_multi.mean(), "±", acc_multi.std())

Classes: ['CD' 'Control' 'UC']




Multiclass accuracy: 0.4818181818181818 ± 0.033402132856134234




In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Pooled binary CV (IBD vs Control), using the `cv` splitter defined in the multiclass cell.
# Splits are stratified on y_binary here, so the folds differ from the multiclass run.
# The scaler is inside the pipeline, so each fold is scaled with training-fold statistics only.
# A separate name keeps clf_bin (the fitted PRISM -> Validation binary model) intact.
pooled_bin_model = make_pipeline(
    StandardScaler(),
    LogisticRegression(penalty="l2", C=0.1, max_iter=1000),
)

auc_bin = cross_val_score(pooled_bin_model, X, y_binary, cv=cv, scoring="roc_auc")
print("Binary AUC:", auc_bin.mean(), "±", auc_bin.std())

Binary AUC: 0.8003572658402204 ± 0.056546249700863895


In [None]:
# Inspect the per-sample embedding matrix (samples x emb_* dimensions) that the models use as X.
sample_emb_df

Unnamed: 0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
PRISM|7122,0.503727,0.071121,0.262210,0.208107,-0.750364,-0.162051,-0.231517,-0.044882,0.066673,-0.186569,...,-0.412470,-0.263849,-0.299218,0.275538,0.099520,-0.097561,-2.345192,0.402813,-0.203259,-0.167115
PRISM|7147,0.484642,0.089444,0.267869,0.230851,-0.757682,-0.170431,-0.255503,-0.029994,0.064920,-0.170134,...,-0.429335,-0.250167,-0.280226,0.294098,0.124096,-0.130240,-2.366274,0.384989,-0.204410,-0.120609
PRISM|7150,0.481692,0.075952,0.247671,0.231047,-0.770289,-0.181725,-0.251004,-0.064253,0.048593,-0.170349,...,-0.417299,-0.260405,-0.290031,0.304461,0.098469,-0.096384,-2.318566,0.387278,-0.216182,-0.139882
PRISM|7153,0.509081,0.053121,0.271979,0.215051,-0.742012,-0.153504,-0.227382,-0.007880,0.091546,-0.177813,...,-0.422414,-0.256328,-0.301447,0.268760,0.109144,-0.119116,-2.376013,0.399840,-0.204853,-0.154715
PRISM|7184,0.495793,0.069895,0.260924,0.212337,-0.760238,-0.166215,-0.250357,-0.034988,0.055714,-0.178719,...,-0.422994,-0.256692,-0.289350,0.282124,0.112781,-0.118279,-2.369301,0.406119,-0.207982,-0.134576
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Validation|UMCGIBD00593,0.502575,0.076741,0.275264,0.219681,-0.759462,-0.178421,-0.238460,-0.050786,0.047305,-0.196069,...,-0.420964,-0.257025,-0.290962,0.289705,0.105019,-0.112961,-2.350529,0.396957,-0.213745,-0.147743
Validation|UMCGIBD00233,0.482082,0.077421,0.253952,0.230179,-0.769524,-0.182850,-0.245883,-0.053507,0.053351,-0.186225,...,-0.425728,-0.264030,-0.296729,0.298296,0.105425,-0.101481,-2.332845,0.379583,-0.217140,-0.137309
Validation|UMCGIBD00238,0.499768,0.064344,0.253247,0.211145,-0.771870,-0.186956,-0.244779,-0.017791,0.068735,-0.170680,...,-0.426691,-0.254302,-0.288724,0.278089,0.109093,-0.120586,-2.348456,0.390866,-0.216875,-0.137928
Validation|UMCGIBD00027,0.473860,0.091007,0.230490,0.233488,-0.772370,-0.186495,-0.248090,-0.046487,0.049006,-0.187249,...,-0.435052,-0.258990,-0.278011,0.299739,0.106933,-0.106292,-2.339369,0.361394,-0.215707,-0.130801
