In [2]:
import pandas as pd
import numpy as np
import re
import glob
from pathlib import Path
from captum.attr import IntegratedGradients
from NegativeClassOptimization.preprocessing import *
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from pathlib import Path
import torch
import torch.nn as nn
import pandas as pd
from captum.attr import IntegratedGradients
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics

In [3]:
from tqdm import tqdm

In [4]:
def compute_binary_metrics(y_test_pred, y_test_true) -> dict:
    acc = metrics.accuracy_score(y_true=y_test_true, y_pred=y_test_pred)
    recall= metrics.recall_score(y_true=y_test_true, y_pred=y_test_pred)
    prec = metrics.precision_score(y_true=y_test_true, y_pred=y_test_pred)
    f1 = metrics.f1_score(y_true=y_test_true, y_pred=y_test_pred)
    return [acc, recall, prec, f1]

In [6]:
from Bio import motifs
from Bio.Seq import Seq
from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon

def get_pwm(slides_1, slides_2):
    # Create a list of Seq objects
    seqs_1 = [Seq(slide) for slide in slides_1]
    seqs_2 = [Seq(slide) for slide in slides_2]

    # Create a motifs instance
    m_1 = motifs.create(seqs_1, alphabet=config.AMINOACID_ALPHABET) # type: ignore
    m_2 = motifs.create(seqs_2, alphabet=config.AMINOACID_ALPHABET) # type: ignore

    # Get the position weight matrix
    pwm_1: np.ndarray = pd.DataFrame(m_1.pwm).values
    pwm_1 += 1e-20  # Avoid log(0)
    pwm_2: np.ndarray = pd.DataFrame(m_2.pwm).values
    pwm_2 += 1e-20  # Avoid log(0)
    return pwm_1, pwm_2

def jensen_shannon_divergence_slides(slides_1, slides_2):

    pwm_1, pwm_2 = get_pwm(slides_1, slides_2)
    return jensenshannon(pwm_1, pwm_2, axis=1, base=2).sum()

In [5]:
task_template = {'1v9': '{pos_ag}__vs__9', 'high_vs_95low': '{pos_ag}_high__vs__{pos_ag}_95low', 'high_vs_looser': '{pos_ag}_high__vs__{pos_ag}_looser'}

In [10]:
shuffl_res = []
for task in tqdm(['1v9', 'high_vs_95low', 'high_vs_looser'], desc='task'):
    for pos_ag in tqdm(config.ANTIGENS, desc='positive antigen'):
        model_path = glob.glob(f"./torch_models/Frozen_MiniAbsolut_ML_shuffled/{task}/seed_0/split_42/{task_template[task].format(pos_ag=pos_ag)}/swa_model/data/model.pth")[0]
        model = torch.load(model_path, map_location=torch.device("cpu"))
        mps_device = torch.device("mps")
        model.to(mps_device)
        dataset = glob.glob(f'./torch_models/Frozen_MiniAbsolut_ML_shuffled/{task}/seed_0/split_42/{task_template[task].format(pos_ag=pos_ag)}/*_test_dataset.tsv')[0] #this part with for loop is repeated a often, probably I just should store these dfs
        df = pd.read_csv(dataset, sep='\t')
        slides_pos = df[df['y'] == 1]['Slide']
        slides_neg = df[df['y'] == 0]['Slide']
        jsd = jensen_shannon_divergence_slides(slides_neg, slides_pos)
        df['X'] = df['Slide'].apply(onehot_encode)
        X = torch.tensor(df['X'].tolist(), dtype=torch.float32).to(mps_device)
        y = torch.tensor(df['y'].tolist(), dtype=torch.float32).to(mps_device)
        y_pred, logits = model(X, return_logits = True)
        y_pred = y_pred.cpu().detach().numpy().reshape(-1).round()
        y = y.cpu().detach().numpy()
        acc, recall, prec, f1 = compute_binary_metrics(y_pred, y)
        shuffl_res.append([acc, recall, prec, f1, task, task, pos_ag, 0, 42, jsd])        

  X = torch.tensor(df['X'].tolist(), dtype=torch.float32).to(mps_device)
positive antigen: 100%|██████████| 10/10 [00:11<00:00,  1.14s/it]
positive antigen: 100%|██████████| 10/10 [00:11<00:00,  1.12s/it]
positive antigen: 100%|██████████| 10/10 [00:10<00:00,  1.07s/it]
task: 100%|██████████| 3/3 [00:33<00:00, 11.11s/it]


In [16]:
shuffled_closed_df = pd.DataFrame(shuffl_res, columns=['acc', 'recall', 'prec', 'f1', 'train_task', 'test_task','pos_ag', 'seed', 'split', 'jsd'])

In [37]:
shuffled_closed_df

Unnamed: 0,acc,recall,prec,f1,train_task,test_task,pos_ag,seed,split,jsd
0,0.505453,0.523,0.505509,0.514106,1v9,1v9,3VRL,0,42,0.4049
1,0.497649,0.4598,0.497727,0.478012,1v9,1v9,1NSN,0,42,0.437076
2,0.497849,0.4898,0.498068,0.493899,1v9,1v9,3RAJ,0,42,0.460353
3,0.498149,0.5134,0.498447,0.505813,1v9,1v9,5E94,0,42,0.434619
4,0.498449,0.487,0.498669,0.492765,1v9,1v9,1H0D,0,42,0.40701
5,0.504052,0.5278,0.504107,0.515681,1v9,1v9,1WEJ,0,42,0.397035
6,0.497149,0.4918,0.497371,0.49457,1v9,1v9,1ADQ,0,42,0.385468
7,0.494847,0.4896,0.495046,0.492308,1v9,1v9,1FBI,0,42,0.396808
8,0.496448,0.493,0.496675,0.494831,1v9,1v9,2YPV,0,42,0.364463
9,0.49955,0.503,0.499801,0.501396,1v9,1v9,1OB1,0,42,0.429553


In [39]:
task_rename = {'1v9': 'vs 9', 'high_vs_95low': 'vs Non-binder', 'high_vs_looser': 'vs Weak'}
shuffled_closed_df['train_task'] = shuffled_closed_df['train_task'].apply(lambda x: task_rename[x])
shuffled_closed_df['test_task'] = shuffled_closed_df['test_task'].apply(lambda x: task_rename[x])

shuffled_closed_df

Unnamed: 0,acc,recall,prec,f1,train_task,test_task,pos_ag,seed,split,jsd
0,0.505453,0.523,0.505509,0.514106,vs 9,vs 9,3VRL,0,42,0.4049
1,0.497649,0.4598,0.497727,0.478012,vs 9,vs 9,1NSN,0,42,0.437076
2,0.497849,0.4898,0.498068,0.493899,vs 9,vs 9,3RAJ,0,42,0.460353
3,0.498149,0.5134,0.498447,0.505813,vs 9,vs 9,5E94,0,42,0.434619
4,0.498449,0.487,0.498669,0.492765,vs 9,vs 9,1H0D,0,42,0.40701
5,0.504052,0.5278,0.504107,0.515681,vs 9,vs 9,1WEJ,0,42,0.397035
6,0.497149,0.4918,0.497371,0.49457,vs 9,vs 9,1ADQ,0,42,0.385468
7,0.494847,0.4896,0.495046,0.492308,vs 9,vs 9,1FBI,0,42,0.396808
8,0.496448,0.493,0.496675,0.494831,vs 9,vs 9,2YPV,0,42,0.364463
9,0.49955,0.503,0.499801,0.501396,vs 9,vs 9,1OB1,0,42,0.429553


In [30]:
task_clean_order = ['vs Non-binder', 'vs 1', 'vs 9', 'vs Weak']
ag_order = [
        "1FBI",
        "3VRL",
        "2YPV",
        "5E94",
        "1WEJ",
        "1OB1",
        "1NSN",
        "1H0D",
        "3RAJ",]
       
pellets = {'color_blind_light':['#a2c8ec','#cfcfcf', '#ffbc79'],
        'color_blind_dark':['#5CA7E5','#ababab', '#ff7700'],
        'r_like':['#94669E', '#F2D81D', '#00817A'],
        'chat_gpt': ['#FFC300', '#FF5733', '#00A6ED'],
        'antigens': ['#008080','#FFA07A','#000080','#FFD700','#228B22','#FF69B4','#800080','#FF6347','#00FF00','#FF1493']}

In [63]:
shuffled_closed_df.groupby('train_task').describe()

Unnamed: 0_level_0,acc,acc,acc,acc,acc,acc,acc,acc,recall,recall,...,split,split,jsd,jsd,jsd,jsd,jsd,jsd,jsd,jsd
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
train_task,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
vs 9,10.0,0.498959,0.003313,0.494847,0.497274,0.497999,0.499275,0.505453,10.0,0.49782,...,42.0,42.0,10.0,0.411729,0.02844,0.364463,0.396865,0.405955,0.433353,0.460353
vs Non-binder,10.0,0.49852,0.004741,0.4912,0.4961,0.49725,0.50055,0.5084,10.0,0.49438,...,42.0,42.0,10.0,0.397759,0.025337,0.342429,0.389117,0.396013,0.409228,0.438479
vs Weak,10.0,0.50117,0.003827,0.4951,0.498925,0.50145,0.503075,0.5068,10.0,0.49608,...,42.0,42.0,10.0,0.404984,0.018257,0.380773,0.393023,0.401778,0.41989,0.436097


In [57]:
#calculate correlation between jsd and acc in shuffled_closed_df so that it is one column
shuffled_closed_df.groupby(['train_task']).corr()[['acc','jsd']].reset_index().pivot(index='train_task', columns='level_1', values='acc')

level_1,acc,f1,jsd,prec,recall,seed,split
train_task,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
vs 9,1.0,0.778069,0.018394,0.999798,0.72146,,
vs Non-binder,1.0,0.428264,-0.006249,0.999608,0.159901,,
vs Weak,1.0,0.633859,0.397881,0.999462,0.579934,,


In [61]:
shuffled_r = dict()
for task in shuffled_closed_df['train_task'].unique():
    df_g = shuffled_closed_df[shuffled_closed_df['train_task'] == task]
    r, pval = pearsonr(df_g["jsd"], df_g["acc"])
    shuffled_r[task] = (r, pval)


In [62]:
shuffled_r

{'vs 9': (0.01839405592264354, 0.9597766137434811),
 'vs Non-binder': (-0.006248853686000326, 0.9863311663131851),
 'vs Weak': (0.3978813397420664, 0.25482723824666437)}

In [60]:
pval

0.25482723824666437