In [None]:
import sys
# Extend the module search path with the relative app directories so that the
# `sonnia` import further down can be resolved
# (presumably local checkouts of SONIA, OLGA and soNNia -- confirm the layout).
sys.path.append('../../apps/SONIA')
sys.path.append('../../apps/OLGA')
sys.path.append('../../apps/soNNia')

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
from os.path import exists

%matplotlib inline

In [None]:
from sonnia.sonnia import SoNNia

In [None]:
# settings
# Peptide identifier; it is used as the sub-directory name and as part of the
# file names for all peptide-specific inputs and outputs of this notebook.
pep = "NLVPMVATV"

In [None]:
# Read the Emerson training table and turn every row into an
# [amino-acid sequence, V gene, J gene] triple for SONIA
train_seqs_df = pd.read_csv('./train_train_WithoutDuplicates_aligned_20.csv.gz')
t_seqs = train_seqs_df['amino_acid'].tolist()
t_v = train_seqs_df['v_gene'].tolist()
t_j = train_seqs_df['j_gene'].tolist()
sonia_input_emerson = [[seq, v_gene, j_gene] for seq, v_gene, j_gene in zip(t_seqs, t_v, t_j)]

# Keep a random subset of at most 10^6 triples; the generator is seeded so the
# same subset is drawn on every run
n_max = 1000000
rng = np.random.default_rng(2021)
raninds = np.arange(len(sonia_input_emerson))
rng.shuffle(raninds)
raninds = raninds[:n_max]
sonia_input_emerson_1e6 = list(np.array(sonia_input_emerson)[raninds])

In [None]:
# Read the VDJdb table of the selected peptide, drop exact duplicate rows and
# build one [CDR3 beta, TRBV gene, TRBJ gene] triple per remaining row for SONIA
vdgdb_df = pd.read_csv('./{0}/VDJdb_{0}_WithAligned20.csv'.format(pep))
vdgdb_df = vdgdb_df.drop_duplicates().reset_index(drop=True)
t_seqs, t_v, t_j = (
    vdgdb_df[column].tolist() for column in ('CDR3_beta', 'TRBV_gene', 'TRBJ_gene')
)
sonia_input_vdgdb = [[seq, v_gene, j_gene] for seq, v_gene, j_gene in zip(t_seqs, t_v, t_j)]

In [None]:
# load and prepare second set of Emerson seqs for SONIA (to be used as negative)
filename_cdr3raw = './train_data_1.txt'

# Row numbers (0-based positions in the raw file) of the sequences to keep.
# The values are cast to int64: the previous int16 cast can only represent
# indices up to 32767 and silently corrupts larger ones.
# atleast_1d keeps a one-line index file iterable (loadtxt returns a 0-d array
# in that case).
inds_non_overlap = np.atleast_1d(np.loadtxt('./1_inds_nonoverlap_0.txt')).astype(np.int64)

# One entry per line of the raw file, with surrounding whitespace removed
with open(filename_cdr3raw) as f:
    t_seq0 = [line.strip() for line in f]

# Selected rows, each split into its tab-separated fields
t_seq = [t_seq0[ind].split('\t') for ind in inds_non_overlap]

In [None]:
# Ratio between the number of background (Emerson) sequences and the number of
# peptide-specific training sequences (80% of the VDJdb set, matching the
# train/validation split used in the main loop)
len(sonia_input_emerson_1e6) / int(80*len(sonia_input_vdgdb)/100)

274.8763056624519

In [None]:
# settings2: training hyper-parameters used for every replicate in the main loop
l2_fin = 0   # passed to SoNNia as l2_reg
epo = 30     # passed to infer_selection as epochs
bs = 10000   # passed to infer_selection as batch_size

In [None]:
# main computation: AUROCs
#
# For every replicate:
#   1. split the peptide-specific sequences 80/20 into train / validation
#      following a pre-computed index permutation,
#   2. pick as many negative validation sequences from the second Emerson set
#      as there are positive validation sequences,
#   3. train a SoNNia model on the training positives against the Emerson
#      background,
#   4. score positives and negatives with minus the model energy and store the
#      resulting AUROC together with the raw scores.

n_repl = 50
auroc_file = './' + pep + '/AUROCs.txt'

if exists(auroc_file):
    # atleast_1d: a file holding a single value is read back as a 0-d array
    AUROCs = np.atleast_1d(np.loadtxt(auroc_file))
else:
    AUROCs = np.array([])

# One AUROC is stored per replicate, so len(AUROCs) is the first replicate that
# has not been processed yet. Starting there resumes an interrupted run instead
# of recomputing replicates and appending duplicate entries to the file.
for i in range(len(AUROCs), n_repl):
    repl = i

    ## prepare positives (train, test) ##
    path_o = './' + pep + '/indices/index_permutation_repl' + str(repl) + '.txt'
    # int64 instead of int16: int16 cannot represent indices above 32767
    full_intR = (np.loadtxt(path_o)).astype(np.int64)
    data = [sonia_input_vdgdb[t] for t in full_intR]
    n_train = int(80*len(data)/100)
    train_data = data[:n_train]
    val_data = data[n_train:]

    ## prepare negatives (test) ##
    path_o = './' + pep + '/indices/index_permutationN_repl' + str(repl) + '.txt'
    full_intR = (np.loadtxt(path_o)).astype(np.int64)
    val_dataN0 = [t_seq[t] for t in full_intR]
    # balanced validation set: as many negatives as positives
    val_dataN = val_dataN0[:len(val_data)]

    ## train model ##
    qm = SoNNia(data_seqs = train_data,
            gen_seqs = sonia_input_emerson_1e6,
            l2_reg = l2_fin,
            deep=False, include_joint_genes=True, include_indep_genes=False,
            )
    qm.infer_selection(epochs = epo, batch_size = bs, validation_split=0.01, verbose=0)

    ## check for nans ##
    # np.min propagates NaN, so a single NaN in likelihood_train is detected
    t_min = np.min(qm.likelihood_train)
    if np.isnan(t_min):
        print("ERROR: nan obtained! Try to have a larger minibatch to prevent this...")
        # Record -1 as a sentinel for the failed replicate and skip the scoring,
        # so that exactly one entry per replicate ends up in the AUROC file.
        AUROCs = np.append(AUROCs, -1)
        np.savetxt(auroc_file, AUROCs)
        continue

    ## compute AUROC ##
    LR_vdgdbn = [qm.find_seq_features(x) for x in val_data]
    LR_emerson = [qm.find_seq_features(x) for x in val_dataN]
    # score = minus the model energy
    scores_positive = - qm.compute_energy(LR_vdgdbn)
    scores_negative = - qm.compute_energy(LR_emerson)
    # label 0 = negative (Emerson), label 1 = positive (peptide-specific)
    labels = np.hstack((np.zeros((len(scores_negative))), np.ones((len(scores_positive)))))
    scores = np.hstack((scores_negative, scores_positive))
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
    AUROCs = np.append(AUROCs, metrics.auc(fpr, tpr))

    # save resulting AUROC file (this rewrites the file completely, but it is a short file so no problem...)
    np.savetxt(auroc_file, AUROCs)

    # save resulting positives_scores
    np.savetxt('./' + pep + '/scores_positives_' + str(i) + '.txt', scores_positive)

    # save resulting negatives_scores
    np.savetxt('./' + pep + '/scores_negatives_' + str(i) + '.txt', scores_negative)

 44%|████▍     | 1601/3638 [00:00<00:00, 16007.57it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16035.85it/s]
  0%|          | 2693/1000000 [00:00<01:14, 13459.21it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12340.89it/s]
 18%|█▊        | 654/3638 [00:00<00:00, 6535.24it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 9913.21it/s] 
  0%|          | 998/1000000 [00:00<01:40, 9973.68it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:30<00:00, 11088.95it/s]
 45%|████▌     | 1639/3638 [00:00<00:00, 16379.63it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16248.08it/s]
  0%|          | 1314/1000000 [00:00<01:16, 13131.56it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13265.83it/s]
 67%|██████▋   | 2444/3638 [00:00<00:00, 12202.81it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 12164.82it/s]
  0%|          | 2036/1000000 [00:00<01:37, 10186.66it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:38<00:00, 10106.20it/s]
 45%|████▌     | 1654/3638 [00:00<00:00, 16531.07it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16453.55it/s]
  0%|          | 2670/1000000 [00:00<01:14, 13358.89it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:19<00:00, 12596.28it/s]
 13%|█▎        | 469/3638 [00:00<00:00, 4686.08it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 11022.33it/s]
  0%|          | 2630/1000000 [00:00<01:16, 13087.06it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:18<00:00, 12692.43it/s]
 20%|██        | 743/3638 [00:00<00:00, 7424.33it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10347.68it/s]
  0%|          | 2011/1000000 [00:00<01:39, 10055.62it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13305.59it/s]
 21%|██        | 749/3638 [00:00<00:00, 7481.82it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10107.84it/s]
  0%|          | 1341/1000000 [00:00<01:14, 13403.05it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:22<00:00, 12050.48it/s]
 17%|█▋        | 630/3638 [00:00<00:00, 6294.57it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 9475.77it/s]
  0%|          | 613/1000000 [00:00<02:43, 6124.30it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:30<00:00, 11087.54it/s]
 34%|███▍      | 1239/3638 [00:00<00:00, 12384.34it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 11997.82it/s]
  0%|          | 2654/1000000 [00:00<01:15, 13267.16it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13251.47it/s]
 44%|████▍     | 1594/3638 [00:00<00:00, 15933.94it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16097.34it/s]
  0%|          | 1331/1000000 [00:00<01:15, 13305.29it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13347.71it/s]
 18%|█▊        | 650/3638 [00:00<00:00, 6498.94it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 9890.33it/s] 
  0%|          | 1010/1000000 [00:00<01:38, 10099.31it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:25<00:00, 11725.53it/s]
 42%|████▏     | 1524/3638 [00:00<00:00, 15237.87it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 15804.28it/s]
  0%|          | 2694/1000000 [00:00<01:14, 13449.18it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:35<00:00, 10420.60it/s]
 43%|████▎     | 1550/3638 [00:00<00:00, 15497.10it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 15923.29it/s]
  0%|          | 2636/1000000 [00:00<01:15, 13205.19it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:20<00:00, 12426.41it/s]
 72%|███████▏  | 2603/3638 [00:00<00:00, 13395.56it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 13722.10it/s]
  0%|          | 2666/1000000 [00:00<01:14, 13333.24it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:22<00:00, 12095.96it/s]
 20%|██        | 737/3638 [00:00<00:00, 7364.00it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10486.96it/s]
  0%|          | 2103/1000000 [00:00<01:34, 10577.84it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:19<00:00, 12623.08it/s]
 21%|██        | 746/3638 [00:00<00:00, 7455.33it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 8524.56it/s]
  0%|          | 2013/1000000 [00:00<01:39, 10061.20it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12260.59it/s]
 44%|████▍     | 1610/3638 [00:00<00:00, 16090.50it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16009.40it/s]
  0%|          | 1331/1000000 [00:00<01:15, 13304.08it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13290.68it/s]
 45%|████▌     | 1649/3638 [00:00<00:00, 16486.05it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16205.92it/s]
  0%|          | 1339/1000000 [00:00<01:14, 13387.01it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13165.51it/s]
 20%|██        | 735/3638 [00:00<00:00, 7343.48it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10396.29it/s]
  0%|          | 2042/1000000 [00:00<01:38, 10173.28it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13418.92it/s]
 45%|████▌     | 1639/3638 [00:00<00:00, 16383.88it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16268.90it/s]
  0%|          | 2663/1000000 [00:00<01:14, 13308.85it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12262.91it/s]
 38%|███▊      | 1400/3638 [00:00<00:00, 7463.91it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10225.40it/s]
  0%|          | 2567/1000000 [00:00<01:17, 12892.73it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13254.03it/s]
 45%|████▍     | 1624/3638 [00:00<00:00, 16234.52it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16096.54it/s]
  0%|          | 2671/1000000 [00:00<01:14, 13362.85it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:23<00:00, 11941.67it/s]
 77%|███████▋  | 2797/3638 [00:00<00:00, 14360.99it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 14408.50it/s]
  0%|          | 1327/1000000 [00:00<01:15, 13263.59it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12341.25it/s]
 44%|████▎     | 1583/3638 [00:00<00:00, 15829.41it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 15974.54it/s]
  0%|          | 1359/1000000 [00:00<01:13, 13584.80it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:18<00:00, 12766.49it/s]
 21%|██        | 764/3638 [00:00<00:00, 7633.60it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 8926.18it/s]
  0%|          | 1005/1000000 [00:00<01:39, 10040.70it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:45<00:00, 9444.66it/s]
 19%|█▊        | 681/3638 [00:00<00:00, 6805.46it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 9450.78it/s]
  0%|          | 1009/1000000 [00:00<01:39, 10082.87it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:19<00:00, 12534.48it/s]
 45%|████▍     | 1635/3638 [00:00<00:00, 16344.48it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16122.34it/s]
  0%|          | 2638/1000000 [00:00<01:15, 13133.75it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12297.20it/s]
 21%|██        | 746/3638 [00:00<00:00, 7452.15it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10424.88it/s]
  0%|          | 1010/1000000 [00:00<01:38, 10099.26it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:15<00:00, 13250.39it/s]
 88%|████████▊ | 3194/3638 [00:00<00:00, 15955.66it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 15910.04it/s]
  0%|          | 2637/1000000 [00:00<01:15, 13186.84it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:20<00:00, 12417.33it/s]
 30%|██▉       | 1075/3638 [00:00<00:00, 10748.37it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 13928.27it/s]
  0%|          | 1352/1000000 [00:00<01:13, 13514.28it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13360.14it/s]
 45%|████▍     | 1632/3638 [00:00<00:00, 16310.99it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16212.43it/s]
  0%|          | 1354/1000000 [00:00<01:13, 13537.88it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13344.67it/s]
 45%|████▌     | 1653/3638 [00:00<00:00, 16526.55it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 14506.71it/s]
  0%|          | 1360/1000000 [00:00<01:13, 13595.28it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13490.86it/s]
 45%|████▍     | 1636/3638 [00:00<00:00, 16350.66it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16317.51it/s]
  0%|          | 2696/1000000 [00:00<01:13, 13489.43it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:21<00:00, 12263.37it/s]
 44%|████▍     | 1616/3638 [00:00<00:00, 16154.82it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 16111.31it/s]
  0%|          | 1334/1000000 [00:00<01:14, 13334.93it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:16<00:00, 12989.80it/s]
 84%|████████▍ | 3055/3638 [00:00<00:00, 15425.46it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 15384.26it/s]
  0%|          | 2718/1000000 [00:00<01:13, 13566.05it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:14<00:00, 13422.00it/s]
 20%|██        | 733/3638 [00:00<00:00, 7328.89it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 10326.07it/s]
  0%|          | 1982/1000000 [00:00<01:40, 9918.35it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:18<00:00, 12659.54it/s]
 33%|███▎      | 1209/3638 [00:00<00:00, 12087.10it/s]

Encode data.


100%|██████████| 3638/3638 [00:00<00:00, 14431.19it/s]
  0%|          | 1354/1000000 [00:00<01:13, 13538.43it/s]

Encode gen.


100%|██████████| 1000000/1000000 [01:26<00:00, 11625.54it/s]
