### Train with GEX data

In [1]:
import pandas as pd
import numpy as np
from SPICE import Soma

GEX_dict = np.load('GEX_dict.npy', allow_pickle=True).item()
df = pd.read_csv('SOMA_GEX_example_input.csv')
df.head()

Unnamed: 0,index_offset,seq,769P,786O,8MGBA,A172,A375,ACHN,CAL120,COGN278,...,SF126,SKNAS,SNU398,SNU423,SNU449,SNUC4,T47D,TOV21G,U251MG,VMRCRCZ
0,ENSG00000000003.15;TSPAN6;chrX-100632484-10063...,CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAG...,5.855052,6.95565,4.786596,5.066832,4.343257,5.347252,6.247928,5.17289,...,9.967226,4.703436,5.544321,3.755662,9.967226,9.967226,6.139551,9.967226,9.967226,9.967226
1,ENSG00000000003.15;TSPAN6;chrX-100633930-10063...,GCTTCGACACCGAGCTCGTCGAGAACTTATTTGACCTGAAACCAAA...,4.160823,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,...,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226
2,ENSG00000000003.15;TSPAN6;chrX-100635177-10063...,GCTTCGACACCGAGCTCGAGACGACCATTATTTTTTCTTTGACTCC...,,1.069751,-0.060542,-0.147135,-0.50373,1.844684,,1.994607,...,-0.095157,0.997839,1.779734,,-1.2868,1.089583,-0.147135,0.678072,0.617465,0.828326
3,ENSG00000000419.14;DPM1;chr20-50945736-5094576...,TGAGATTGAATCCAGGAAATGAAGCTTCGACACCGAGCTCGTTAGC...,1.252026,1.712718,3.99694,2.492914,-0.233995,2.835563,1.67278,4.786596,...,2.632603,4.829909,1.633412,4.160823,2.093702,2.2488,3.80119,1.163165,0.794913,1.641242
4,ENSG00000000419.14;DPM1;chr20-50948628-5094866...,CTTCGACACCGAGCTCGGTGCAACTATATTTCTATTAAAGTGAGTA...,,,0.280483,,-3.321928,-1.581126,,,...,,-1.116179,,,,,-1.736966,-2.2488,,


In [2]:
GEX_dict

{'769P': array([2.82984956, 0.        , 6.35314683, ..., 0.        , 0.        ,
        0.        ]),
 '786O': array([2.44890095, 0.        , 7.18596563, ..., 0.        , 0.02856915,
        0.        ]),
 '8MGBA': array([4.37364821, 0.        , 6.96266469, ..., 0.07038933, 0.        ,
        0.02856915]),
 'A172': array([0.11103131, 0.        , 6.71383326, ..., 0.        , 0.04264434,
        0.36737107]),
 'A375': array([3.38543104, 0.        , 7.66341578, ..., 0.        , 0.        ,
        0.        ]),
 'ACHN': array([3.55213111, 0.        , 6.03584382, ..., 0.02856915, 0.02856915,
        0.        ]),
 'CAL120': array([4.34766566, 0.        , 6.99672769, ..., 0.        , 0.        ,
        0.        ]),
 'COGN278': array([5.16791987, 0.        , 6.55458885, ..., 0.        , 0.        ,
        0.31034012]),
 'COLO783': array([4.1268077 , 0.02856915, 6.92837032, ..., 0.        , 0.05658353,
        0.        ]),
 'DAOY': array([4.51032902, 0.        , 6.3554392 , ..., 0.     

## Held-out sequences

In [3]:
# split into train and test
df_train = df.sample(frac=0.8, random_state=42)
df_test = df.drop(df_train.index)

In [None]:
Soma.fintune_with_gex(
    df_train,
    GEX_dict,
    device='cuda',
    pretrained_params='SOMA_params_seed_0.pth',
    epochs=1,
    batch_size=256,
)

Epoch 1/1:  10%|█         | 578/5715 [01:16<10:15,  8.34batch/s, loss=14.0022]

In [None]:
# spearman
from scipy.stats import spearmanr


pred_psi, org_psi = Soma.predict_with_gex(
    df_test,
    GEX_dict,
    device='cuda',
    batch_size=512,
    pretrained_params='SOMA_params_seed_0.pth',
    pretrained_GEX_params='SOMA_with_GEX_params.pth',
)

rho, pval = spearmanr(org_psi, pred_psi)
print(f"Spearman ρ={rho:.4f}, p-value={pval:.2e}")

100%|██████████| 72/72 [00:09<00:00,  7.48it/s]

Spearman ρ=0.6671, p-value=0.00e+00





## Held-out cell lines

In [7]:
df

Unnamed: 0,index_offset,seq,769P,786O,8MGBA,A172,A375,ACHN,CAL120,COGN278,...,SF126,SKNAS,SNU398,SNU423,SNU449,SNUC4,T47D,TOV21G,U251MG,VMRCRCZ
0,ENSG00000000003.15;TSPAN6;chrX-100632484-10063...,CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAG...,5.855052,6.955650,4.786596,5.066832,4.343257,5.347252,6.247928,5.172890,...,9.967226,4.703436,5.544321,3.755662,9.967226,9.967226,6.139551,9.967226,9.967226,9.967226
1,ENSG00000000003.15;TSPAN6;chrX-100633930-10063...,GCTTCGACACCGAGCTCGTCGAGAACTTATTTGACCTGAAACCAAA...,4.160823,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,...,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226,9.967226
2,ENSG00000000003.15;TSPAN6;chrX-100635177-10063...,GCTTCGACACCGAGCTCGAGACGACCATTATTTTTTCTTTGACTCC...,,1.069751,-0.060542,-0.147135,-0.503730,1.844684,,1.994607,...,-0.095157,0.997839,1.779734,,-1.286800,1.089583,-0.147135,0.678072,0.617465,0.828326
3,ENSG00000000419.14;DPM1;chr20-50945736-5094576...,TGAGATTGAATCCAGGAAATGAAGCTTCGACACCGAGCTCGTTAGC...,1.252026,1.712718,3.996940,2.492914,-0.233995,2.835563,1.672780,4.786596,...,2.632603,4.829909,1.633412,4.160823,2.093702,2.248800,3.801190,1.163165,0.794913,1.641242
4,ENSG00000000419.14;DPM1;chr20-50948628-5094866...,CTTCGACACCGAGCTCGGTGCAACTATATTTCTATTAAAGTGAGTA...,,,0.280483,,-3.321928,-1.581126,,,...,,-1.116179,,,,,-1.736966,-2.248800,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
43778,ENSG00000288710.1;RP11-386G11.12;chr12-4900546...,CTTCGACACCGAGCTCGACCCAACATGCCCAAACACTGTTCTTTTT...,-0.684164,-0.135578,-0.474069,-2.756957,-2.731449,0.118249,-0.714717,0.503730,...,-0.895303,-0.819826,0.014413,-1.203909,-2.862496,-0.794913,-2.822237,-2.822237,-0.971986,-0.129800
43779,ENSG00000288710.1;RP11-386G11.12;chr12-4902227...,CTTCGACACCGAGCTCGACTGCCTGAGTCTCCTACCTGATCCCACA...,-0.539463,0.895303,0.095157,-0.391524,-1.238212,1.116179,-0.828326,1.192349,...,1.328998,-0.060542,1.581126,-0.379787,-1.422312,0.462233,-0.901645,-0.491853,0.690262,0.456320
43780,ENSG00000288717.1;RP11-852E15.4;chr3-46000912-...,GCTTCGACACCGAGCTCGAGATGAAGGCAAGGTTAGGGGTATCCGT...,-0.309611,0.545434,0.426815,1.036911,-1.541097,0.521577,-0.228193,1.176697,...,0.665905,-0.158698,0.303781,-0.112475,-0.474069,0.770115,0.257222,0.794913,0.344648,2.142811
43781,ENSG00000288720.1;RP11-852E15.3;chr3-45812038-...,GCTTCGACACCGAGCTCGACTTAAATTGAAAAAGAAATCCAGCTTC...,0.708594,3.824428,1.328998,0.037475,1.844684,3.689610,1.300808,2.706269,...,1.688686,2.160257,1.939508,1.321928,2.141044,1.836238,-0.757757,2.131497,2.504197,2.341037


In [15]:
# sample 20% columns as held-out cell lines
cell_lines = list(GEX_dict.keys())
np.random.seed(42)
held_out = np.random.choice(cell_lines, size=int(0.2*len(cell_lines)), replace=False).tolist()
held_in = [c for c in cell_lines if c not in held_out]
df_train = df[held_in]
df_test = df[held_out]