In [2]:
import os
from pathlib import Path
import sys
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))
import torch
import numpy as np
import os
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
import pandas as pd
from src.utils.evaluation import EvalMetrics
from src.utils.data import generate_vocabularies

In [1]:
annotation_type = 'GO'

# File stem of the frozen label embeddings for each supported annotation type.
label_embeddings_by_type = {
    'GO': '2024_E5_multiling_inst_frozen_label_embeddings_mean',
    'EC': 'ecv1_E5_multiling_inst_frozen_label_embeddings_mean',
}
if annotation_type not in label_embeddings_by_type:
    # Fail early with a clear message instead of a NameError on `label_embeddings` below.
    raise ValueError(
        f"Unsupported annotation_type {annotation_type!r}; expected one of {sorted(label_embeddings_by_type)}"
    )
label_embeddings = label_embeddings_by_type[annotation_type]

# The labels file name carries an extra "EC_" infix only for EC annotations.
labels_infix = 'EC_' if annotation_type == 'EC' else ''

zero_shot_pinf_logits = pd.read_parquet(f'../outputs/results/test_logits_{annotation_type}_unseen_proteinfer_ABCD.parquet')
zero_shot_labels = pd.read_parquet(f'../outputs/results/test_1_labels_{labels_infix}normal_test_label_aug_v4.parquet')
embeddings = torch.load(f'../data/embeddings/{label_embeddings}.pt')
embeddings_idx = torch.load(f'../data/embeddings/{label_embeddings}_index.pt')

# Build the vocabulary from the FASTA that matches the annotation type; the logits columns
# are renamed from it, so a GO vocabulary must not be applied to EC logits.
vocabularies = generate_vocabularies(file_path = f'../data/swissprot/proteinfer_splits/random/full_{annotation_type}.fasta')
zero_shot_pinf_logits.columns = vocabularies['label_vocab']


NameError: name 'pd' is not defined

In [62]:
# Show the first description string stored in the embeddings index.
embeddings_idx['description'].iloc[0]

"Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: mitochondrion inheritance"

In [63]:
# Compare the shapes of the zero-shot label frame and the ProteInfer logits frame.
zero_shot_labels.shape,zero_shot_pinf_logits.shape

((832, 228), (52874, 32102))

In [64]:
logits_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_logits.parquet")
labels_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_labels.parquet")

# Row ids and label columns shared by both label frames. They are materialised as sorted
# lists rather than sets: pandas rejects sets as `.loc` indexers, and a sorted list gives a
# deterministic row/column order downstream.
mask = sorted(set(zero_shot_labels.index.get_level_values(0)) & set(labels_unseen.index))
# Parenthesised explicitly: `-` binds tighter than `&`, so the original relied on the two
# groupings happening to give the same set.
# NOTE(review): the reason GO:0106314 is excluded is not visible here — confirm and document.
cols = sorted((set(zero_shot_labels.columns) & set(labels_unseen.columns)) - {'GO:0106314'})

In [65]:
# Keep only the index rows whose description type is 'name', and apply the same boolean
# mask to the embedding tensor.
# NOTE(review): assumes `embeddings` rows are in the same order as `embeddings_idx` rows — confirm.
embedding_mask = embeddings_idx['description_type']=='name'
embeddings_idx = embeddings_idx[embedding_mask].reset_index(drop=True)
embeddings = embeddings[embedding_mask]

# Subset for labels that appear in the training label vocabulary.
# The mask is computed on the already-filtered, re-indexed `embeddings_idx` from above.
train_embeddings_mask = embeddings_idx['id'].isin(vocabularies['label_vocab'])
train_embeddings_idx = embeddings_idx[train_embeddings_mask].reset_index(drop=True)
train_embeddings = embeddings[train_embeddings_mask]

# Subset for labels that appear as columns of the zero-shot label frame.
zero_shot_embeddings_mask = embeddings_idx['id'].isin(zero_shot_labels.columns)
zero_shot_embeddings_idx = embeddings_idx[zero_shot_embeddings_mask].reset_index(drop=True)
zero_shot_embeddings = embeddings[zero_shot_embeddings_mask]

In [66]:
# Similarity between every zero-shot label embedding (rows) and every training label
# embedding (columns), computed as the dot product of the normalised vectors.
_normalized_zero_shot = torch.nn.functional.normalize(zero_shot_embeddings)
_normalized_train = torch.nn.functional.normalize(train_embeddings)
label_train_2_zero_shot_similarities = _normalized_zero_shot @ _normalized_train.T

# For each zero-shot label, the position of its most similar training label.
_nearest_train_positions = label_train_2_zero_shot_similarities.max(dim=-1).indices.tolist()

# Map zero-shot label id -> id of the most similar training label.
zero_shot_label_mapping = {
    zero_shot_id: train_embeddings_idx['id'].iloc[train_position]
    for zero_shot_id, train_position in zip(zero_shot_embeddings_idx['id'], _nearest_train_positions)
}


### Sanity check
Print zero-shot label descriptions alongside the description of the training label each one is mapped to. Each key should be semantically similar to its value.

In [67]:
# Build the id -> description lookup once instead of scanning `embeddings_idx` for every id.
# `drop_duplicates` keeps the first row per id, matching the original `.iloc[0]` choice.
id_to_description = (
    embeddings_idx.drop_duplicates(subset='id')
    .set_index('id')['description']
    .to_dict()
)
zero_shot_label_descriptions_mapping = {
    id_to_description[zero_shot_id]: id_to_description[train_id]
    for zero_shot_id, train_id in zero_shot_label_mapping.items()
}

# Show at most 10 randomly selected key-value pairs. The generator is seeded so the displayed
# sample is reproducible, and the sample size is capped so fewer than 10 pairs does not raise.
description_keys = list(zero_shot_label_descriptions_mapping.keys())
sample_rng = np.random.default_rng(0)
sampled_positions = sample_rng.choice(
    len(description_keys), size=min(10, len(description_keys)), replace=False
)
sampled_keys = {description_keys[position] for position in sampled_positions}
{k:v for k,v in zero_shot_label_descriptions_mapping.items() if k in sampled_keys}

{"Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: mitochondrial proliferation": "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: mitochondrial gene expression",
 "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: zinc ion sensor activity": "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: copper ion sensor activity",
 "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: Fc receptor-mediated immune complex endocytosis": "Instruct

### Create the zero-shot ProteInfer-based baseline prediction DataFrame

In [68]:
# For every zero-shot label column, look up the training label it was mapped to and take
# that training label's logits column as the baseline prediction.
_mapped_train_labels = [
    zero_shot_label_mapping[zero_shot_label] for zero_shot_label in zero_shot_labels.columns
]
zero_shot_pinf_baseline_logits = zero_shot_pinf_logits[_mapped_train_labels]

# Relabel the selected columns with the zero-shot label ids they stand in for.
zero_shot_pinf_baseline_logits.columns = zero_shot_labels.columns

In [69]:
# Shape of the baseline logits frame.
zero_shot_pinf_baseline_logits.shape

(52874, 228)

### Measure baseline performance

In [70]:
# Shape of the baseline logits frame (same check as the earlier shape cell).
zero_shot_pinf_baseline_logits.shape

(52874, 228)

In [72]:
# NOTE(review): the execution count suggests this ran after the filtering/evaluation cell
# below, so the displayed shape presumably reflects the filtered labels — re-run in order to confirm.
zero_shot_labels.shape

(815, 227)

In [71]:
from pprint import pprint

# Fall back to CPU when no GPU is available so the cell still runs.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Restrict the baseline logits to the shared proteins and labels. `cols` is turned into a
# sorted list because pandas does not accept a set as a `.loc` indexer, and sorting makes
# the column order deterministic.
eval_cols = sorted(cols)
zero_shot_pinf_baseline_logits = zero_shot_pinf_baseline_logits.loc[
    zero_shot_pinf_baseline_logits.index.isin(mask), eval_cols
]
# Select the labels by the logits' own index and columns so that row i / column j of both
# frames refer to the same protein and label. Two independent `isin` filters would leave
# each frame in its own original row order.
zero_shot_labels = zero_shot_labels.loc[
    zero_shot_pinf_baseline_logits.index, zero_shot_pinf_baseline_logits.columns
]
assert zero_shot_labels.shape == zero_shot_pinf_baseline_logits.shape, (
    "labels and logits must cover the same proteins and labels"
)

num_labels = zero_shot_labels.shape[-1]
eval_metrics = EvalMetrics(device=eval_device)
mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=num_labels)
metric_collection = eval_metrics.get_metric_collection_with_regex(
    pattern='f1_m.*',
    threshold=0.5,
    num_labels=num_labels,
)

# Build the probability and target tensors once and reuse them for every metric.
probabilities = torch.sigmoid(torch.tensor(zero_shot_pinf_baseline_logits.values))
targets = torch.tensor(zero_shot_labels.values)

metric_collection(probabilities.to(eval_device), targets.to(eval_device))
# Micro AUPRC treats every (protein, label) pair as one binary decision.
mAP_micro.update(probabilities.flatten(), targets.flatten())
mAP_macro.update(probabilities, targets)

metrics = metric_collection.compute()
metrics.update({
    "map_micro": mAP_micro.compute(),
    "map_macro": mAP_macro.compute(),
})
# Convert the values to plain Python numbers for printing.
metrics = {k: v.item() for k, v in metrics.items()}
pprint(metrics)



{'f1_macro': 0.08511335402727127,
 'f1_micro': 0.20441989600658417,
 'map_macro': 0.18999364972114563,
 'map_micro': 0.08505809307098389}


  zero_shot_pinf_baseline_logits = zero_shot_pinf_baseline_logits.loc[zero_shot_pinf_baseline_logits.index.isin(mask),cols]


In [38]:
logits_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_logits.parquet")
labels_unseen = pd.read_parquet("../outputs/results/unseen_zero_shot_labels.parquet")

# Row ids and label columns shared by both label frames, with no label excluded.
# They are sorted lists rather than sets: pandas rejects sets as `.loc` indexers, and a
# sorted list gives a deterministic row/column order downstream.
mask = sorted(set(zero_shot_labels.index.get_level_values(0)) & set(labels_unseen.index))
cols = sorted(set(zero_shot_labels.columns) & set(labels_unseen.columns))

In [39]:
# Shape of the unseen zero-shot logits frame.
logits_unseen.shape

(832, 228)

In [17]:
# NOTE(review): in the visible notebook `y` is only assigned in the evaluation cell further
# down, so on a fresh top-to-bottom run this cell would raise NameError — consider moving it.
y.shape

(815, 227)

In [15]:
from pprint import pprint

# Fall back to CPU when no GPU is available so the cell still runs.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pandas does not accept sets as `.loc` indexers, so use sorted lists. Both frames are
# indexed with the same lists, which keeps their rows and columns aligned.
eval_rows = sorted(mask)
eval_cols = sorted(cols)
num_labels = len(eval_cols)

eval_metrics = EvalMetrics(device=eval_device)
mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=num_labels)
metric_collection = eval_metrics.get_metric_collection_with_regex(
    pattern='f1_m.*',
    threshold=0.5,
    num_labels=num_labels,
)

pinf_logits = zero_shot_pinf_baseline_logits.loc[eval_rows, eval_cols]
y = zero_shot_labels.loc[eval_rows, eval_cols]

# Build the probability and target tensors once and reuse them for every metric.
probabilities = torch.sigmoid(torch.tensor(pinf_logits.values))
targets = torch.tensor(y.values)

metric_collection(probabilities.to(eval_device), targets.to(eval_device))
# Micro AUPRC treats every (protein, label) pair as one binary decision.
mAP_micro.update(probabilities.flatten(), targets.flatten())
mAP_macro.update(probabilities, targets)

metrics = metric_collection.compute()
metrics.update({
    "map_micro": mAP_micro.compute(),
    "map_macro": mAP_macro.compute(),
})
# Convert the values to plain Python numbers for printing.
metrics = {k: v.item() for k, v in metrics.items()}
pprint(metrics)



{'f1_macro': 0.07445017248392105,
 'f1_micro': 0.13045518100261688,
 'map_macro': 0.18683546781539917,
 'map_micro': 0.07307962328195572}


  pinf_logits = zero_shot_pinf_baseline_logits.loc[mask,cols]
  y = zero_shot_labels.loc[mask,cols]


# EC 

In [2]:
from src.utils.data import read_fasta,get_vocab_mappings

In [3]:
# Read the zero-shot EC FASTA files.
train_EC = read_fasta('../data/zero_shot/train_EC.fasta')
# NOTE(review): `val_EC` reads the same path as `train_EC` — this looks like a copy-paste
# slip; confirm the intended validation file and update the path.
val_EC = read_fasta('../data/zero_shot/train_EC.fasta')
test_EC = read_fasta('../data/zero_shot/test_EC.fasta')

In [9]:
# Build the vocabularies from the full EC FASTA file; the 'label_vocab' entry is used below.
ec_vocabs = generate_vocabularies('../data/swissprot/proteinfer_splits/random/full_EC.fasta')

In [10]:
# Mappings derived from the EC label vocabulary; `label2int` is indexed by label in `process_labels`.
# NOTE(review): presumably label -> integer id and its inverse — confirm against get_vocab_mappings.
label2int, int2label = get_vocab_mappings(ec_vocabs['label_vocab'])


def process_labels(labels):
    """Encode a collection of labels as one vector over the EC label vocabulary.

    Each label is looked up in the module-level `label2int` mapping, one-hot encoded
    against `len(ec_vocabs['label_vocab'])` classes, and the one-hot rows are summed.
    A label that occurs more than once therefore contributes more than 1 at its position.

    Per the original note, this target is meant for the loss function rather than as a
    model input, so it should not be affected by augmentation.

    Args:
        labels: iterable of label strings that are keys of `label2int`.

    Returns:
        A 1-D integer tensor with one entry per vocabulary label.

    Raises:
        KeyError: if a label is missing from `label2int`.
    """
    class_ids = torch.tensor([label2int[name] for name in labels], dtype=torch.long)
    vocab_size = len(ec_vocabs['label_vocab'])
    per_label_one_hot = torch.nn.functional.one_hot(class_ids, num_classes=vocab_size)
    # Collapse the per-label rows into a single vector.
    return per_label_one_hot.sum(dim=0)

In [None]:
from tqdm import tqdm

# Encode the labels (last field of each record) of every test record as a NumPy vector.
d = [process_labels(record[-1]).numpy() for record in tqdm(test_EC)]


In [16]:
# Wrap the first 10 encoded label vectors in a DataFrame for inspection.
df = pd.DataFrame(d[:10])

In [17]:
# Display the preview frame.
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,5124,5125,5126,5127,5128,5129,5130,5131,5132,5133
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
6,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
# Flatten each record's labels into one space-separated string.
# The isinstance guard keeps the cell idempotent: without it, re-running would join the
# characters of the already-joined string ("E C : 4 ...").
test_EC = [
    (first_field, second_field, labels if isinstance(labels, str) else " ".join(labels))
    for first_field, second_field, labels in test_EC
]


In [6]:
# Inspect the first test record.
test_EC[0]

('MLKNDLFLRALKRQPCSRTPIWVMRQAGRYLPEYRAVREKTDFLTLCKTPELATEVTIQPVELVGVDAAIIFSDILVVNEAMGQEVNIIETKGIKLAPPIRSQADIDKLIVPDIDEKLGYVLDALRMTKKELDNRVPLIGFSGAAWTLFTYAVEGGGSKNYAYAKQMMYREPQMAHSLLSKISQTITAYTLKQIEAGADAIQIFDSWASALSEDDYREYALPYIKDTVQAIKAKHPETPVIVFSKDCNTILSDIADTGCDAVGLGWGIDISKARTELNDRVALQGNLDPTVLYGTQERIKIEAGKILKSFGQHNHHSGHVFNLGHGILPDMDPDNLRCLVEFVKEESAKYH',
 'Q3AUB7',
 'EC:4.-.-.- EC:4.1.-.- EC:4.1.1.- EC:4.1.1.37')