## Setup: imports and example data

In [None]:
from SPICE import Soma
from SPICE import Melange
import torch
import pandas as pd

# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [2]:
# Load the example input table from the working directory and preview its first rows.
# The preview shows `index_offset`, `psi` and `seq` columns.
input_csv = 'Melange_example_input.csv'
df = pd.read_csv(input_csv)
df.head()

Unnamed: 0,index_offset,psi,seq
0,ENSG00000000003.15;TSPAN6;chrX-100632484-10063...,7.496831,CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAG...
1,ENSG00000000003.15;TSPAN6;chrX-100633930-10063...,9.633673,GCTTCGACACCGAGCTCGTCGAGAACTTATTTGACCTGAAACCAAA...
2,ENSG00000000003.15;TSPAN6;chrX-100635177-10063...,1.012797,GCTTCGACACCGAGCTCGAGACGACCATTATTTTTTCTTTGACTCC...
3,ENSG00000000419.14;DPM1;chr20-50945736-5094576...,2.55393,TGAGATTGAATCCAGGAAATGAAGCTTCGACACCGAGCTCGTTAGC...
4,ENSG00000000419.14;DPM1;chr20-50948628-5094866...,-2.114327,CTTCGACACCGAGCTCGGTGCAACTATATTTCTATTAAAGTGAGTA...


## Train SOMA teachers

In [3]:
# Train the SOMA teachers: `num_seeds` is set to the number of teachers wanted,
# and the training log reports one "Saving model..." line per seed.
num_teachers = 5
Soma.train(
    df,
    device=device,
    epochs=20,
    batch_size=512,
    learning_rate=1e-4,
    num_seeds=num_teachers,
)

Training: 100%|██████████| 20/20 [03:36<00:00, 10.83s/epoch, train_loss=13.6947]


Training completed for seed 0. Saving model...


Training: 100%|██████████| 20/20 [03:36<00:00, 10.82s/epoch, train_loss=13.9914]


Training completed for seed 1. Saving model...


Training: 100%|██████████| 20/20 [03:35<00:00, 10.76s/epoch, train_loss=13.5696]


Training completed for seed 2. Saving model...


Training: 100%|██████████| 20/20 [03:37<00:00, 10.86s/epoch, train_loss=13.6409]


Training completed for seed 3. Saving model...


Training: 100%|██████████| 20/20 [03:38<00:00, 10.91s/epoch, train_loss=13.4340]

Training completed for seed 4. Saving model...





## Train Melange Student

In [3]:
# Options for training the Melange student.
#   clf_ids   - teachers 0-3 are passed here; the evaluation cell later in this
#               notebook uses clf_id=4 (presumably a held-out teacher - TODO confirm).
#   epochs    - NOTE(review): only a single epoch is run; confirm whether this is
#               meant as a quick demo setting.
#   save_path - the same file name is passed as `params` to the reconstruction
#               and evaluation calls further down.
student_train_options = {
    'max_len': 250,
    'lambda_cls': 5.0,
    'epochs': 1,
    'clf_ids': [0, 1, 2, 3],
    'mode': 'PSI0to1',
    'save_path': 'Melange_params.pth',
}
Melange.train(df, device=device, **student_train_options)

Training: 100%|██████████| 1/1 [01:11<00:00, 71.44s/epoch, loss=271.4958, recon=1.1808, cls=54.0626, kl=0.4414]

Training completed.





In [4]:
# Reconstruct the sequence from the first row of the table with the saved
# student parameters; the call prints the original, the reconstruction and a
# per-position match line.
first_seq = df['seq'].iloc[0]
Melange.reconstruct_sequence(
    first_seq,
    max_len=250,
    device=device,
    params='Melange_params.pth',
)

Original     : CTTCGACACCGAGCTCGATATGATCGAAGTATTTATTACCATAAAGAAAAGCACAGGCTGCTTGTGCTGTATTTAATCTTTGTTTTTTTCCTCCCATTAGGGTTGTTTTATAAAGGTGATGACCATTATAGAGTCAGAAATGGGAGTCGTTGCAGGAATTTCCTTTGGAGTTGCTTGCTTCCAAGTAAGTTTTTGTAGTTACTTAGGAAATATTTCATCCCTCTTGTAGGTGTGCAGCCATCTAAGTTTC
Reconstructed: AAAATAAAAAAAAAGAAAAAAGAAAGAAAAAAAAATAAAAAAAAATAAATATAATATTATTAATAATTATATATAATTTTTTAGATAGAAGGAGGAGGAGGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGGAGAAGAAGAGAGAGGAGAAGAGGAGAAGAAGAGAAGAAGAAGAAGAAGAGGTGAAGGTAAGTATAGATTGGTTGTTTTTTTTTTTTAGTAATTTTTTTTTGTAAAATATTTTAATATT
Match        : ✗✗✗✗✗✓✗✓✗✗✗✓✗✗✗✗✗✓✗✓✗✓✓✗✗✓✓✓✗✗✓✗✗✗✓✓✗✓✗✗✓✗✓✓✓✗✓✓✓✗✗✗✓✗✗✗✗✗✗✗✗✗✗✗✗✗✗✓✗✓✓✓✗✓✓✓✓✗✓✓✓✗✗✗✗✓✗✗✗✗✗✗✗✗✗✓✗✗✓✓✓✗✗✗✗✗✗✗✗✗✗✓✗✓✗✓✗✗✗✗✗✗✗✗✗✗✗✗✗✓✓✗✗✗✗✓✓✓✓✗✗✓✗✓✓✓✗✗✓✗✗✓✗✗✓✗✗✓✗✗✗✗✗✗✗✗✗✓✓✗✗✗✗✗✗✗✓✗✗✗✗✗✗✗✓✗✓✗✓✓✗✗✗✓✗✓✓✓✗✓✗✗✗✓✗✓✗✗✗✗✓✗✓✓✓✗✗✓✗✗✗✗✗✓✓✗✓✗✗✗✓✗✗✗✗✓✗✗✗✓✓✗✓✗✓✗✓✗✓✗


In [5]:
# Evaluate reconstructions on a reproducible random subset of the rows whose
# `psi` lies below the cutoff, scored with classifier 4 (not among the
# `clf_ids` given to Melange.train above).
psi_cutoff = -3.17  # NOTE(review): rationale for this cutoff is not documented here - confirm
n_eval = 500

low_psi_rows = df[df['psi'] < psi_cutoff]
if low_psi_rows.empty:
    # Fail with an explicit message instead of the generic sampling error.
    raise ValueError(f"No rows with psi < {psi_cutoff}; cannot build an evaluation set.")

# DataFrame.sample raises ValueError when asked for more rows than are available
# (sampling without replacement), so cap the request at the number of matching
# rows. With at least `n_eval` matching rows the sample is unchanged.
df_test = (
    low_psi_rows
    .sample(min(n_eval, len(low_psi_rows)), random_state=0)
    .reset_index(drop=True)
)
gen_seq_psi_pred, org_seq_psi_pred = Melange.evaluate_reconstructions(
    df_test,
    clf_id=4,
    device=device,
    max_len=250,
    params='Melange_params.pth'
)

100%|██████████| 500/500 [00:06<00:00, 77.61it/s]

Mean prediction of original sequences from the classifier: -0.2656
Mean prediction of generated sequences from the classifier: 9.7886



