In [1]:
import os
from pathlib import Path
import sys
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))
from collections import Counter
import pandas as pd
from src.utils.data import read_fasta
from src.data.datasets import ProteinDataset
import numpy as np
from src.utils.data import read_pickle, save_to_pickle,read_json
from src.utils.evaluation import metrics_per_label_df
from torchmetrics.classification import AveragePrecision,Specificity
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from src.utils.losses import FocalLoss
from src.utils.evaluation import EvalMetrics


In [2]:
def fasta_list_to_df(fasta_list):
    """Turn parsed FASTA records into a DataFrame.

    Every record is unpacked as ``(sequence, identifier, labels)``; the
    labels are joined into a single space-separated string.

    Returns a DataFrame with the columns ``sequence``, ``id`` and ``labels``.
    """
    rows = []
    for sequence, identifier, terms in fasta_list:
        rows.append((sequence, identifier, " ".join(terms)))
    return pd.DataFrame(rows, columns=['sequence','id','labels'])

# What can change between 2019 and 2024?

In [3]:
# GO annotation tables: the 2019-07-01 snapshot and the May 2024 snapshot.
annotations_old = pd.read_pickle('../data/annotations/go_annotations_2019_07_01.pkl')
annotations = pd.read_pickle('../data/annotations/go_annotations_may_2024.pkl')

# "Parenthood" vocabularies (JSON) matching each snapshot.
p2019 = read_json('../data/vocabularies/parenthood_2019.json')
p2024 = read_json('../data/vocabularies/parenthood_may_2024.json')

In [4]:
# Lowercase the term names in both snapshots so the comparisons below are
# case-insensitive.
# NOTE(review): `label` is overwritten with the lowercased `name`, not derived
# from any pre-existing `label` column - confirm this is intended.
for frame in (annotations, annotations_old):
    lowered = frame['name'].str.lower()
    frame['name'] = lowered
    frame['label'] = lowered

# Move each table's index into a regular `index` column so it can serve as a
# merge key.
old_terms = annotations_old[['name','label']].reset_index()
new_terms = annotations[['name','label']].reset_index()
merge_opts = dict(how='outer', suffixes=('_old','_new'), indicator=True)

# Outer joins with `_merge` indicator: once keyed on the name, once on the index.
merged_on_names = old_terms.merge(new_terms, on=['name'], **merge_opts)
merged_on_ids = old_terms.merge(new_terms, on=['index'], **merge_opts)

In [5]:
def _is_current(term_name):
    """Return True when the name does not contain the substring 'obsolete'."""
    return 'obsolete' not in term_name


# Rows found on one side only of the index-keyed outer merge.
n_added = len(merged_on_ids.query('_merge=="right_only"'))
print('Added GO Terms: ', n_added)
n_removed = len(merged_on_ids.query('_merge=="left_only"'))
print('Removed GO Terms: ', n_removed)

# Same index in both snapshots, but the (lowercased) name differs.
temp = merged_on_ids.query('_merge=="both" and name_old != name_new')
still_current = temp['name_old'].apply(_is_current) & temp['name_new'].apply(_is_current)
print('Same ID, different definitions', len(temp[still_current]))

# Same (lowercased) name in both snapshots, but the index differs.
temp = merged_on_names.query('_merge=="both" and index_old!=index_new')
print('Different IDs, same definitions', len(temp[temp['name'].apply(_is_current)]))

Added GO Terms:  1490
Removed GO Terms:  1066
Same ID, different definitions 1767
Different IDs, same definitions 17


In [6]:
def _load_split_df(fasta_path):
    """Read one FASTA file and return its DataFrame with parsed labels.

    The space-separated `labels` string produced by `fasta_list_to_df` is
    stripped, split on single spaces and sorted, so each row holds a sorted
    list of labels.
    """
    split_df = fasta_list_to_df(read_fasta(fasta_path))
    split_df['labels'] = split_df['labels'].apply(lambda x: sorted(x.strip().split(' ')))
    return split_df


# Each frame is fully built before the next file is read, so a missing file
# further down cannot leave the earlier frames with unparsed string labels.
test_2019 = _load_split_df('../data/swissprot/proteinfer_splits/random/test_GO.fasta')
test_2024 = _load_split_df('../data/swissprot/proteinfer_splits/random/test_GO_may_2024.fasta')
# NOTE: the recorded run of this cell raised FileNotFoundError for this file.
test_2024_pinf_labels = _load_split_df('../data/swissprot/proteinfer_splits/random/test_GO_may_2024_pinf_labels.fasta')

FileNotFoundError: [Errno 2] No such file or directory: '../data/swissprot/proteinfer_splits/random/test_GO_may_2024_pinf_labels.fasta'

In [23]:
test_2019.shape,test_2024.shape

((51751, 3), (51616, 3))

In [24]:
len(set([ j for i in test_2024['labels'] for j in i]))

22015

In [25]:
# Inner join on `id`: only proteins present in both test sets are kept; the
# shared columns get the `_old` (2019) and `_new` (2024) suffixes.
test_merged = pd.merge(test_2019, test_2024, how='inner', on='id', suffixes=('_old','_new'))

In [26]:
def get_pct_added_terms(x):
    """Percentage of the new label set that is absent from the old one.

    Parameters
    ----------
    x : mapping (e.g. a DataFrame row) with the keys ``labels_old`` and
        ``labels_new``, each an iterable of hashable labels.

    Returns
    -------
    float
        ``100 * |new - old| / |new|``, or NaN when ``labels_new`` is empty
        (the ratio is undefined; this previously raised ZeroDivisionError).
    """
    old = set(x['labels_old'])
    new = set(x['labels_new'])
    if not new:
        return float('nan')
    return len(new - old) * 100 / len(new)
def get_pct_removed_terms(x):
    """Percentage of the old label set that is missing from the new one.

    Parameters
    ----------
    x : mapping (e.g. a DataFrame row) with the keys ``labels_old`` and
        ``labels_new``, each an iterable of hashable labels.

    Returns
    -------
    float
        ``100 * |old - new| / |old|``, or NaN when ``labels_old`` is empty
        (the ratio is undefined; this previously raised ZeroDivisionError).
    """
    old = set(x['labels_old'])
    new = set(x['labels_new'])
    if not old:
        return float('nan')
    return len(old - new) * 100 / len(old)

def iou(x):
    """Intersection-over-union of the old and new label sets, as a percentage.

    Parameters
    ----------
    x : mapping (e.g. a DataFrame row) with the keys ``labels_old`` and
        ``labels_new``, each an iterable of hashable labels.

    Returns
    -------
    float
        ``100 * |old & new| / |old | new|``, or NaN when both sets are empty
        (the ratio is undefined; this previously raised ZeroDivisionError).
    """
    old = set(x['labels_old'])
    new = set(x['labels_new'])
    union = old | new
    if not union:
        return float('nan')
    return len(old & new) * 100 / len(union)

# Evaluate each row-wise metric and store it under its own column.
for column_name, row_metric in (
    ('pct_added_terms', get_pct_added_terms),
    ('pct_removed_terms', get_pct_removed_terms),
    ('pct_iou', iou),
):
    test_merged[column_name] = test_merged.apply(row_metric, axis=1)

In [27]:
test_merged[['pct_added_terms','pct_removed_terms','pct_iou']].describe()

Unnamed: 0,pct_added_terms,pct_removed_terms,pct_iou
count,51616.0,51616.0,51616.0
mean,14.826461,25.77554,65.808443
std,16.775159,16.547986,17.871252
min,0.0,0.0,0.0
25%,3.703704,14.583333,58.139535
50%,8.77193,23.728814,68.085106
75%,20.0,33.333333,78.571429
max,100.0,100.0,100.0


In [None]:
unseen_test = read_fasta('../data/zero_shot/SwissProt_2023_unseen_sequences_and_labels.fasta')
full_go = read_fasta('../data/swissprot/proteinfer_splits/random/full_GO.fasta')

# Each record is indexed as (sequence, id, ..., labels): position 0 is the
# sequence, position 1 the id, and the last position the iterable of labels.
unseen_seqs = {record[0] for record in unseen_test}
unseen_seq_ids = {record[1] for record in unseen_test}
unseen_terms = {term for record in unseen_test for term in record[-1]}

full_go_terms = {term for record in full_go for term in record[-1]}

# Index entries of the 2024 annotations that are absent from the 2019 ones.
new = set(annotations.index).difference(annotations_old.index)

In [None]:
print('New annotations from new terms:', len(unseen_terms))
print('New annotations:', len(new))
print('New predicted labels all swissprot',len(set(labels.columns)))


New annotations from new terms: 228
New annotations: 1260
New predicted labels all swissprot 544


In [5]:
all_sp = read_fasta('../data/swissprot/swissprot_2023.fasta')
id2seq = {seq_id: seq for seq, seq_id, _ in all_sp}

# Identify sequences based on proteinfer's split.
# Built once, outside the loop: it does not depend on the frame being
# processed, and rebuilding it re-read three FASTA files per frame.
# Later entries win, so a sequence present in several files ends up tagged
# 'test' over 'val' over 'train'.
seq2split = {
    **{seq:'train' for seq,_,_ in read_fasta('../data/swissprot/proteinfer_splits/random/train_GO.fasta')},
    **{seq:'val' for seq,_,_ in read_fasta('../data/swissprot/proteinfer_splits/random/dev_GO.fasta')},
    **{seq:'test' for seq,_,_ in read_fasta('../data/swissprot/proteinfer_splits/random/test_GO.fasta')},
}

# NOTE(review): `labels` and `logits` are not defined in any visible cell -
# they must be loaded before this cell runs.
for df in [labels,logits]:
    # Look up each row's sequence from its index value.
    df['sequence'] = list(df.index.map(id2seq))

    # Sequences that are in none of the proteinfer splits are tagged 'new'.
    df['split'] = df['sequence'].map(seq2split).fillna('new')

    # Mutates `labels`/`logits` themselves: `sequence` and `split` are appended
    # as extra index levels so later cells can use `query('split == ...')`.
    df.set_index(['sequence','split'],append=True,inplace=True)

Some proteins in the ProteInfer dataset do not appear in the full SwissProt data: fewer than 0.3% of ProteInfer sequences. These are simply proteins that have "INLL" appended at the end.

In [6]:
from collections import defaultdict
# Sequences that occur in `full_go` but are not among the values of `id2seq`.
del_seqs = {record[0] for record in full_go} - set(id2seq.values())

# Map every sequence to the list of ids that carry it.
# The loop variable is called `seq_id` rather than `id`: at cell level the
# name would leak into the notebook namespace and shadow the builtin `id`.
full_go_seq2ids = defaultdict(list)
for seq, seq_id, _ in full_go:
    full_go_seq2ids[seq].append(seq_id)



In [7]:
len(del_seqs)*100/len(full_go)

0.2965899806164097

In [8]:
[full_go_seq2ids[i] for i in del_seqs][:10]

[['Q9GLP6'],
 ['Q11098'],
 ['P17019'],
 ['Q90632'],
 ['Q96SE0'],
 ['Q5H9K5'],
 ['P63034'],
 ['Q96ME1'],
 ['A8WRP9'],
 ['Q09M05']]

unseen_test

In [6]:
def _split_sorted(label_string):
    """Strip the string, split it on single spaces and return the sorted parts."""
    return sorted(label_string.strip().split(' '))


def _split_sorted_without_root(label_string):
    """Same as `_split_sorted`, after removing 'GO:0003674' and collapsing
    the double spaces its removal leaves behind."""
    cleaned = label_string.replace('GO:0003674','').replace('  ',' ')
    return _split_sorted(cleaned)


unseen_test_df = fasta_list_to_df(unseen_test)
all_sp_df = fasta_list_to_df(all_sp)

all_sp_df['labels'] = all_sp_df['labels'].apply(_split_sorted_without_root)
unseen_test_df['labels'] = unseen_test_df['labels'].apply(_split_sorted)


NameError: name 'all_sp' is not defined

Performance

In [11]:
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
eval_metrics = EvalMetrics(device='cpu')

# Only the 'test' split is evaluated; the other entries are left disabled.
splits_to_evaluate = [
    #'val',
    'test',
    #'new'
]

for split in splits_to_evaluate:
    split_labels = labels.query('split == @split')
    split_logits = logits.query('split == @split')

    # Column / row totals compared with `>= 0`.
    # NOTE(review): presumably the labels are 0/1, in which case `>= 0` keeps
    # every label and every sequence - confirm that no filtering is intended.
    label_totals = split_labels.sum(axis=0)
    label_mask = list(label_totals[label_totals >= 0].index)

    sequence_totals = split_labels.sum(axis=1)
    sequence_mask = list(sequence_totals[sequence_totals >= 0].index)

    n_labels = len(label_mask)
    mAP_micro = BinaryAUPRC(device='cpu')
    mAP_macro = MultilabelAUPRC(device='cpu', num_labels=n_labels)
    metrics = eval_metrics.get_metric_collection_with_regex(
        pattern="(f1_m.*)|(precision_macro)|(recall_macro)",
        threshold=0.3,
        num_labels=n_labels,
    )

    # Same row/column selection applied to logits and labels.
    selected_logits = split_logits.loc[sequence_mask][label_mask]
    selected_labels = split_labels.loc[sequence_mask][label_mask]
    probabilities = torch.sigmoid(torch.tensor(selected_logits.values))
    y = torch.tensor(selected_labels.values)

    # Micro: all (sequence, label) pairs flattened into one binary problem.
    micro_auprc = mAP_micro.update(probabilities.flatten(), y.flatten()).compute()
    macro_auprc = mAP_macro.update(probabilities, y).compute()
    print(micro_auprc, macro_auprc)

    print(metrics(probabilities, y))
    metrics.reset()
    # Second pass: the same probabilities scored against the inverted targets.
    print(metrics(probabilities, 1-y))

tensor(0.0117) tensor(0.0619)
{'f1_macro': tensor(0.0393), 'f1_micro': tensor(0.0251), 'precision_macro': tensor(0.0341), 'recall_macro': tensor(0.0925)}
{'f1_macro': tensor(0.0020), 'f1_micro': tensor(0.0022), 'precision_macro': tensor(0.7317), 'recall_macro': tensor(0.0011)}


In [61]:
# Re-attach the row index and column names of the selected logits to the
# probability tensor left over from the evaluation cell.
selected_logits_view = logits.query('split == @split').loc[sequence_mask][label_mask]
probabilities_df = pd.DataFrame(
    probabilities,
    columns=selected_logits_view.columns,
    index=selected_logits_view.index,
)

In [77]:
m=(probabilities_df>0.9)&(labels.query('split == @split').loc[sequence_mask][label_mask]==0)

In [26]:
from sklearn.metrics import confusion_matrix

In [40]:
torch.mean(y[probabilities>0.99].float())

tensor(0.0352)

In [60]:
torch.where((probabilities>0.99)&(y==0))

(tensor([  225,   840,  1002,  2066,  2071,  4416,  5082,  5286,  5286,  9112,
          9113,  9653,  9653,  9783,  9784,  9785,  9953,  9958,  9962, 10231,
         10247, 11231, 11474, 11476, 11477, 11481, 11487, 11489, 11493, 11499,
         11503, 11504, 11505, 11508, 11510, 12198, 12826, 13434, 13564, 13583,
         16123, 16123, 16900, 16942, 16942, 16945, 16945, 16946, 16946, 17633,
         17922, 18176, 18940, 18941, 19515, 19516, 20691, 21503, 21650, 23392,
         23396, 23396, 24572, 24573, 24574, 24575, 25507, 25507, 25881, 25886,
         25887, 26006, 26006, 26007, 26008, 26009, 26010, 26010, 26011, 26012,
         26013, 26014, 26015, 26016, 26017, 26018, 26019, 26020, 26748, 27477,
         27480, 27482, 27491, 27505, 27544, 27545, 28216, 28216, 28230, 28390,
         28392, 28394, 28395, 28396, 28794, 28795, 29020, 29097, 29456, 29457,
         29458, 29459, 29467, 29468, 29469, 29470, 29471, 29472, 29473, 29474,
         29475, 29478, 29480, 29492, 29493, 29494, 2

In [49]:
y[probabilities>0.].shape

torch.Size([19598])

In [53]:
y.sum()

tensor(1804)

In [51]:
51683*730

37728590

In [54]:
1804/37728590

4.7815197970557606e-05

In [52]:
19598/37728590

0.0005194469234074213

In [55]:
tn,fp,fn,tp = confusion_matrix(y.flatten().numpy(),(probabilities.flatten().numpy()>.3)*1).ravel()

In [56]:
tn,fp,fn,tp

(37685570, 41216, 1257, 547)

In [10]:
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC


In [44]:
tn,fp,fn,tp

(37726428, 358, 1800, 4)

In [24]:
def _index_where(flags):
    """Return, as a list, the index entries at which the boolean Series is True."""
    return list(flags[flags].index)


eval_metrics = EvalMetrics(device='cpu')

for split in [
    #'val',
    'test',
    #'new'
]:
    labels_in_split = labels.query('split == @split')
    logits_in_split = logits.query('split == @split')

    # Labels are kept only when their column total is strictly positive;
    # sequences are kept when their row total is >= 0.
    label_mask = _index_where(labels_in_split.sum(axis=0) > 0)
    sequence_mask = _index_where(labels_in_split.sum(axis=1) >= 0)

    mAP_micro = BinaryAUPRC(device='cpu')
    mAP_macro = MultilabelAUPRC(device='cpu', num_labels=len(label_mask))
    metrics = eval_metrics.get_metric_collection_with_regex(
        pattern="(f1_m.*)|(precision_macro)|(recall_macro)",
        threshold=0.3,
        num_labels=len(label_mask),
    )

    logit_values = logits_in_split.loc[sequence_mask][label_mask].values
    target_values = labels_in_split.loc[sequence_mask][label_mask].values
    probabilities = torch.sigmoid(torch.tensor(logit_values))
    y = torch.tensor(target_values)

    micro_score = mAP_micro.update(probabilities.flatten(), y.flatten()).compute()
    macro_score = mAP_macro.update(probabilities, y).compute()
    print(micro_score, macro_score)

    print(metrics(probabilities, y))
    # Clear accumulated state, then score against the inverted targets.
    metrics.reset()
    print(metrics(probabilities, 1-y))

tensor(0.0477) tensor(0.1429)
{'f1_macro': tensor(0.0938), 'f1_micro': tensor(0.0763), 'precision_macro': tensor(0.0823), 'recall_macro': tensor(0.2208)}
{'f1_macro': tensor(0.0023), 'f1_micro': tensor(0.0024), 'precision_macro': tensor(0.6941), 'recall_macro': tensor(0.0012)}


In [1]:
labels.query('split == @split').loc[sequence_mask][label_mask].shape

NameError: name 'labels' is not defined

In [13]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

In [None]:
seq2split

In [14]:
# Gather every SwissProt entry whose sequence is mapped to 'test' in
# `seq2split`, both as a Biopython SeqRecord and as the raw entry.
test_new_labels = []
test_new_labels_ = []
for entry in all_sp:
    # Sequences missing from `seq2split` yield None and are skipped as well.
    if seq2split.get(entry[0]) != 'test':
        continue
    sequence, seq_id, go_terms = entry[0], entry[1], entry[2]
    test_new_labels.append(
        SeqRecord(Seq(sequence), id=seq_id, description=" ".join(go_terms))
    )
    test_new_labels_.append(entry)


In [15]:
len(test_new_labels)

51687

In [25]:
# Sequences of the collected test entries (position 0 of each raw entry).
seqs = {record[0] for record in test_new_labels_}

# Rows of `labels` / `logits` whose `sequence` is one of those sequences.
labs = labels.query('sequence in @seqs')
# The original name `lags` looks like a typo for `logs`, the name paired with
# `labs` in the following cell; the old name is kept as an alias.
logs = logits.query('sequence in @seqs')
lags = logs

In [52]:
# Stored logits/labels for the zero-shot run, with the 'GO:0003674' column
# removed from both.
logs = pd.read_parquet("../outputs/results/test_1_logits_zero_shot_pinf_test.parquet").drop(columns='GO:0003674')
labs = pd.read_parquet("../outputs/results/test_1_labels_zero_shot_pinf_test.parquet").drop(columns='GO:0003674')

probabilities = torch.sigmoid(torch.tensor(logs.values))
y = torch.tensor(labs.values)

mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=labs.shape[-1])

# Micro score on the flattened pairs, macro score on the 2-D tensors.
micro_score = mAP_micro.update(probabilities.flatten(), y.flatten()).compute()
macro_score = mAP_macro.update(probabilities, y).compute()
print(micro_score, macro_score)

tensor(0.0478) tensor(0.1399)


In [65]:

# Keep the labels (columns) and the rows of the current `split` whose totals
# are strictly positive.
labels_in_split = labels.query('split == @split')

label_totals = labels_in_split.sum(axis=0)
label_mask = list(label_totals[label_totals > 0].index)

row_totals = labels_in_split.sum(axis=1)
sequence_mask = list(row_totals[row_totals > 0].index)

In [90]:
split = 'new'
labels_in_split = labels.query('split == @split')
logits_in_split = logits.query('split == @split')

# Only the index of the `> 0` comparison is used, so every label of the
# column totals is kept - the boolean values themselves filter nothing.
has_positive = labels_in_split.sum(axis=0) > 0
label_mask = list(has_positive.index) #list(unseen_terms&set(labels.query('split == @split').columns)) 

mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=len(label_mask))

probabilities = torch.sigmoid(torch.tensor(logits_in_split[label_mask].values))
y = torch.tensor(labels_in_split[label_mask].values)

micro_score = mAP_micro.update(probabilities.flatten(), y.flatten()).compute()
macro_score = mAP_macro.update(probabilities, y).compute()
print(micro_score, macro_score)

tensor(0.0212) tensor(0.0431)


In [89]:
(labels.query('split == @split')[label_mask].sum(axis=1)>0).sum()

987

In [8]:
import pandas as pd
logits_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_logits.parquet")
labels_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_labels.parquet")

In [13]:
probabilities.shape

torch.Size([815, 227])

In [14]:


# Unique first-level index values and all column names of `labels`.
mask = set(labels.index.get_level_values(0))#&set(labels_unseen.index)
cols = set(labels.columns) #set(labels_unseen.columns)

# pandas deprecated sets as `.loc` indexers and newer releases reject them.
# Materialise each set as a list once, so `logits` and `labels` are read with
# the same row and column order.
row_ids = list(mask)
col_ids = list(cols)

mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu',
                            num_labels=len(col_ids))

probabilities = torch.sigmoid(torch.tensor(logits.loc[row_ids, col_ids].values))
y = torch.tensor(labels.loc[row_ids, col_ids].values)
# Sanity check: total number of positives vs. total predicted probability mass.
print(y.sum(),probabilities.sum())
print(mAP_micro.update(probabilities.flatten(), y.flatten()).compute(),mAP_macro.update(probabilities, y).compute())

  probabilities = torch.sigmoid(torch.tensor(logits.loc[mask,cols].values))
  y = torch.tensor(labels.loc[mask,cols].values)


tensor(69730) tensor(202205.6875)
tensor(0.0948) tensor(0.1447)


In [15]:
labels.sum(axis=0).sort_values()

GO:0120257        0
GO:0120258        0
GO:0140928        1
GO:0140796        1
GO:0140795        1
              ...  
GO:0140678     1404
GO:0140535     1693
GO:0140657     2466
GO:0140640     5604
GO:0110165    47658
Length: 544, dtype: int64

In [60]:
from sklearn.metrics import confusion_matrix
tn,fp,fn,tp = confusion_matrix(y.flatten().numpy(),(probabilities.flatten().numpy()>.3)*1).ravel()

In [61]:
tn,fp,fn,tp

(11604433, 13629, 51221, 1468)

In [None]:
test_merged



# Unique first-level index values and all column names of `labels`.
mask = set(labels.index.get_level_values(0))#&set(labels_unseen.index)
cols = set(labels.columns) #set(labels_unseen.columns)

# Sets are not valid `.loc` indexers in current pandas (deprecated, then
# removed). Convert each to a list once and reuse it for both frames so rows
# and columns line up between `logits` and `labels`.
selected_rows = list(mask)
selected_cols = list(cols)

mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu',
                            num_labels=len(selected_cols))

probabilities = torch.sigmoid(torch.tensor(logits.loc[selected_rows, selected_cols].values))
y = torch.tensor(labels.loc[selected_rows, selected_cols].values)
# Sanity check: total number of positives vs. total predicted probability mass.
print(y.sum(),probabilities.sum())
print(mAP_micro.update(probabilities.flatten(), y.flatten()).compute(),mAP_macro.update(probabilities, y).compute())