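# Train SOMA teachers and a Melange student

This notebook loads `Melange_example_input.csv` and trains an ensemble of SOMA teachers. It then trains a Melange student against a subset of those teachers (`clf_ids`). Finally, it evaluates the student's reconstructions with a classifier (`clf_id`) that is not in that subset.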

In [1]:
# Setup: third-party libraries first, then the project's SPICE models.
import pandas as pd
import torch

from SPICE import Melange, Soma

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Load the example input: one row per sequence, with a numeric `psi` value and the
# sequence string in `seq`. (`Unnamed: 0` appears to be a row index saved with the CSV.)
DATA_PATH = 'Melange_example_input.csv'
df = pd.read_csv(DATA_PATH)

# Later cells read `seq` (reconstruction) and `psi` (test-set selection); fail here
# with a clear message instead of deep inside training if either is missing.
missing_cols = {'psi', 'seq'} - set(df.columns)
assert not missing_cols, f"{DATA_PATH} is missing required columns: {sorted(missing_cols)}"

df.head()

Unnamed: 0,index_offset,psi,seq
0,ENSG00000000003.15;TSPAN6;chrX-100632484-10063...,7.496831,CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAG...
1,ENSG00000000003.15;TSPAN6;chrX-100633930-10063...,9.633673,GCTTCGACACCGAGCTCGTCGAGAACTTATTTGACCTGAAACCAAA...
2,ENSG00000000003.15;TSPAN6;chrX-100635177-10063...,1.012797,GCTTCGACACCGAGCTCGAGACGACCATTATTTTTTCTTTGACTCC...
3,ENSG00000000419.14;DPM1;chr20-50945736-5094576...,2.55393,TGAGATTGAATCCAGGAAATGAAGCTTCGACACCGAGCTCGTTAGC...
4,ENSG00000000419.14;DPM1;chr20-50948628-5094866...,-2.114327,CTTCGACACCGAGCTCGGTGCAACTATATTTCTATTAAAGTGAGTA...


## Train SOMA teachers

In [3]:
# Train an ensemble of SOMA teachers, one per seed (seeds 0 .. num_teachers - 1).
# Later cells refer to classifiers by integer id; presumably clf_id i is the teacher
# trained with seed i — TODO confirm against SPICE.
num_teachers = 5
Soma.train(
    df,
    # Use the device picked in the setup cell rather than hard-coding 'cuda', so this
    # cell also runs on CPU-only machines. str() keeps the plain string form
    # ('cuda' / 'cpu') that this call originally received.
    device=str(device),
    epochs=5,
    batch_size=512,
    learning_rate=1e-4,
    num_seeds=num_teachers,
)

Training: 100%|██████████| 5/5 [00:55<00:00, 11.11s/epoch, train_loss=20.3511]


Training completed for seed 0. Saving model...


Training: 100%|██████████| 5/5 [00:54<00:00, 11.00s/epoch, train_loss=20.2293]


Training completed for seed 1. Saving model...


Training: 100%|██████████| 5/5 [00:54<00:00, 10.84s/epoch, train_loss=20.4927]


Training completed for seed 2. Saving model...


Training: 100%|██████████| 5/5 [00:53<00:00, 10.68s/epoch, train_loss=20.3390]


Training completed for seed 3. Saving model...


Training: 100%|██████████| 5/5 [00:53<00:00, 10.79s/epoch, train_loss=20.1764]

Training completed for seed 4. Saving model...





## Train Melange Student

In [9]:
# Train the Melange student against every teacher except the last one; the evaluation
# cell scores the student with the last classifier id (num_teachers - 1), which is
# therefore not used here. Deriving the ids from `num_teachers` keeps this cell in sync
# with the trained ensemble — a hard-coded [0, 1, 2, 3] would point at untrained
# teachers, or stop excluding the evaluation classifier, if the ensemble size changed.
assert num_teachers >= 2, 'need >= 2 teachers: some to train against, one held out for evaluation'
student_clf_ids = list(range(num_teachers - 1))  # [0, 1, 2, 3] for 5 teachers

Melange.train(
    df,
    device=device,
    max_len=250,           # kept equal to the max_len used in the cells below
    lambda_cls=5.0,        # presumably the weight of the `cls` loss term shown in the log — confirm
    epochs=10,
    clf_ids=student_clf_ids,
    mode='PSI1to0',        # presumably steers sequences from high to low PSI — confirm in SPICE docs
    save_path='Melange_params.pth',
)

Training: 100%|██████████| 10/10 [10:58<00:00, 65.86s/epoch, loss=6.7258, recon=3.7845, cls=0.5876, kl=0.6251]

Training completed.





In [10]:
# Sanity check: reconstruct the first example sequence from the saved student weights;
# the call prints the original, the reconstruction and a per-position match line.
# NOTE(review): the saved output shows mostly mismatches (a G-heavy reconstruction) —
# worth interpreting in prose before relying on the student.
first_seq = df['seq'].iloc[0]
Melange.reconstruct_sequence(
    first_seq,
    device=device,
    max_len=250,
    params='Melange_params.pth',
)

Original     : CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAGAAAAGCACAGGCTGCTTGTGCTGTATTTAATCTTTGTTTTTTTCCTCCCATTAGGGTTGTTTTATAAAGGTGATGACCATTATAGAGTCAGAAATGGGAGTCGTTGCAGGAATTTCCTTTGGAGTTGCTTGCTTCCAAGTAAGTTTTTGTAGTTACTTAGGAAATATTTCATCCCTCTTGTAGGTGTGCAGCCATCTAAGTTTC
Reconstructed: GGTGGGAGGGGGGGTTGGGGTAGTGGAAGGGGGGGGTGAAAGAGGGGGGGGGAAGAAAGGGTGGGGGTGAAGGGGAAAGGGGGGGGGATGAGGAGAGGAGAGGGGGGGGGGGAGGGGGAGGGATTGGGAAGGGGGGGGGAGGAGAGGGGGGGCTGAAAGGGAAGGGGGAGGAGGGGGAGGGGGGGTAGGGGAATGGAGCGACTGGGGAGGGGGGGGGGGGGGGGGGGGGGGGTGGGGGGGGATTCGGGGG
Match        : ✗✗✓✗✓✗✗✗✗✗✓✗✓✗✓✗✓✗✗✗✓✗✗✓✗✓✓✓✓✗✗✗✗✗✗✗✓✗✗✗✓✗✓✗✗✓✗✗✗✗✓✗✓✗✗✗✗✗✗✓✗✓✗✓✗✓✗✓✓✗✓✗✗✗✗✓✗✗✗✗✗✓✗✗✗✗✗✗✓✗✗✗✗✗✗✓✗✗✓✓✗✓✗✗✓✗✗✗✗✗✗✗✓✗✓✓✗✓✓✗✓✗✗✗✗✗✗✗✗✓✓✗✓✗✗✗✓✗✗✓✗✓✗✓✓✓✗✗✓✗✗✓✓✗✓✗✓✓✗✗✗✗✗✗✗✗✓✓✓✓✗✗✓✗✗✗✓✗✗✗✗✗✗✗✓✓✓✗✓✗✗✗✗✓✓✗✓✓✗✗✓✓✓✗✗✓✓✓✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✗✓✗✗✓✓✗✓✓✓✗✗✓✗✗✗✗✗✓✗✗✓✗✗✗✗


In [11]:
# Evaluate the student on high-PSI examples, scoring original vs. generated sequences
# with the last classifier id (num_teachers - 1 == 4 for 5 teachers), which was not in
# the clf_ids used to train the student.
PSI_THRESHOLD = 3.17  # NOTE(review): origin of this cut-off is not documented — confirm
N_EVAL = 1000
EVAL_SEED = 0

high_psi = df[df['psi'] > PSI_THRESHOLD]
assert len(high_psi) > 0, f'no rows with psi > {PSI_THRESHOLD}'
# DataFrame.sample raises ValueError when n exceeds the number of rows; cap n so a
# smaller dataset still runs (identical draw when at least N_EVAL rows qualify).
df_test = (
    high_psi
    .sample(min(N_EVAL, len(high_psi)), random_state=EVAL_SEED)
    .reset_index(drop=True)
)
# NOTE(review): these rows come from the same `df` the student was trained on, so unless
# Melange.train holds data out internally this is an in-sample evaluation.
gen_seq_psi_pred, org_seq_psi_pred = Melange.evaluate_reconstructions(
    df_test,
    clf_id=num_teachers - 1,
    device=device,
    max_len=250,
)

Mean prediction of original sequences from the classifier: 4.3359
Mean prediction of generated sequences from the classifier: -1.6584
