In [3]:
import os
from pathlib import Path
import sys
curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))
import torch
import numpy as np
import os
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
import pandas as pd
from src.utils.evaluation import EvalMetrics
from src.utils.data import generate_vocabularies

In [4]:
# Which label-embedding file and which annotation vocabulary to evaluate against.
label_embeddings = '2024_E5_multiling_inst_frozen_label_embeddings_mean'
annotation_type = 'GO'

# ProteInfer test-set logits (one column per training-vocabulary label, see the column
# assignment at the end of this cell) and the ground-truth zero-shot label matrix.
zero_shot_pinf_logits = pd.read_parquet(f'../outputs/results/test_logits_{annotation_type}_unseen_proteinfer_ABCD.parquet')
zero_shot_labels = pd.read_parquet('../outputs/results/test_1_labels_normal_test_label_aug_v4.parquet')
embeddings = torch.load(f'../data/embeddings/{label_embeddings}.pt')
# The index file holds a pickled pandas DataFrame rather than plain tensors. torch >= 2.6
# defaults to weights_only=True, which refuses to unpickle it, so opt out explicitly.
# SECURITY: weights_only=False executes arbitrary pickle code -- only load trusted local files.
embeddings_idx = torch.load(f'../data/embeddings/{label_embeddings}_index.pt', weights_only=False)
vocabularies = generate_vocabularies(file_path = f'../data/swissprot/proteinfer_splits/random/full_{annotation_type}.fasta')
# Name the logit columns with the training label vocabulary.
# NOTE(review): assumes the logit column order matches the vocabulary order -- confirm.
zero_shot_pinf_logits.columns = vocabularies['label_vocab']


In [5]:
# Keep only the 'name' description variant of each label. The same boolean mask filters both
# the metadata frame and the embedding tensor, so row i of one still describes row i of the other.
# Masks are converted to explicit bool tensors instead of relying on torch's implicit
# conversion of a pandas Series used as an index.
embedding_mask = embeddings_idx['description_type']=='name'
embeddings_idx = embeddings_idx[embedding_mask].reset_index(drop=True)
embeddings = embeddings[torch.as_tensor(embedding_mask.to_numpy())]

# Embeddings of labels in the training vocabulary: the candidates a zero-shot label can map to.
train_embeddings_mask = embeddings_idx['id'].isin(vocabularies['label_vocab'])
train_embeddings_idx = embeddings_idx[train_embeddings_mask].reset_index(drop=True)
train_embeddings = embeddings[torch.as_tensor(train_embeddings_mask.to_numpy())]

# Embeddings of the zero-shot labels that need a baseline prediction.
zero_shot_embeddings_mask = embeddings_idx['id'].isin(zero_shot_labels.columns)
zero_shot_embeddings_idx = embeddings_idx[zero_shot_embeddings_mask].reset_index(drop=True)
zero_shot_embeddings = embeddings[torch.as_tensor(zero_shot_embeddings_mask.to_numpy())]

# Every zero-shot label column needs a 'name' embedding; otherwise the nearest-label mapping
# built from these embeddings has no entry for it and the baseline lookup fails with a bare KeyError.
missing_zero_shot_labels = set(zero_shot_labels.columns) - set(zero_shot_embeddings_idx['id'])
assert not missing_zero_shot_labels, (
    f'{len(missing_zero_shot_labels)} zero-shot labels have no name embedding, '
    f'e.g. {sorted(missing_zero_shot_labels)[:5]}'
)

In [6]:
# Cosine similarity between every zero-shot label (rows) and every training label (columns):
# L2-normalize both embedding sets, then take all pairwise dot products.
label_train_2_zero_shot_similarities = (
    torch.nn.functional.normalize(zero_shot_embeddings)
    @ torch.nn.functional.normalize(train_embeddings).T
)

# For each zero-shot label, the row position of its most similar training label.
nearest_train_rows = label_train_2_zero_shot_similarities.max(dim=-1).indices.tolist()
# zero-shot label id -> id of its nearest training label.
zero_shot_label_mapping = dict(
    zip(zero_shot_embeddings_idx['id'], train_embeddings_idx['id'].iloc[nearest_train_rows])
)


### Sanity check
Print a sample of zero-shot label descriptions (keys) next to the description of the training label each one was mapped to (values). Each key should be semantically close to its value.

In [7]:
# Build an id -> description lookup once instead of scanning embeddings_idx twice per label.
# drop_duplicates keeps the first row per id, matching a `.iloc[0]` lookup.
id_to_description = embeddings_idx.drop_duplicates(subset='id').set_index('id')['description']
zero_shot_label_descriptions_mapping = {
    id_to_description[zero_shot_id]: id_to_description[train_id]
    for zero_shot_id, train_id in zero_shot_label_mapping.items()
}

# Show a reproducible random sample of at most SANITY_SAMPLE_SIZE key-value pairs
# (capped so sampling without replacement cannot fail on a small mapping).
SANITY_SAMPLE_SIZE = 10
SANITY_SAMPLE_SEED = 0
sanity_rng = np.random.default_rng(SANITY_SAMPLE_SEED)
sampled_keys = set(
    sanity_rng.choice(
        list(zero_shot_label_descriptions_mapping),
        size=min(SANITY_SAMPLE_SIZE, len(zero_shot_label_descriptions_mapping)),
        replace=False,
    )
)
{k: v for k, v in zero_shot_label_descriptions_mapping.items() if k in sampled_keys}

{"Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: D-arabinose 1-dehydrogenase (NADP+) activity": "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: D-arabinose 1-dehydrogenase (NAD+) activity",
 "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: acetylgalactosaminyl-O-glycosyl-seryl-glycoprotein beta-1,6-N-acetylglucosaminyltransferase activity": "Instruct: Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function\nQuery: acetylgalactosaminyl-O-glycosyl-glycoprotein beta-1,6-N-acetylglucosaminyltransferase activity",
 "Instruct: Identify the main categories, themes,

### Create the zero-shot ProteInfer-based baseline prediction DataFrame

In [8]:
# Each zero-shot label borrows the ProteInfer logit column of its nearest training label;
# several zero-shot labels may borrow the same source column.
source_train_columns = [zero_shot_label_mapping[label] for label in zero_shot_labels.columns]
zero_shot_pinf_baseline_logits = zero_shot_pinf_logits.loc[:, source_train_columns].set_axis(
    zero_shot_labels.columns, axis=1
)

In [9]:
zero_shot_pinf_baseline_logits.shape

(52874, 596)

### Measure baseline performance

In [10]:
from pprint import pprint

# Fall back to CPU so the cell still runs on a kernel without a GPU.
# NOTE(review): assumes EvalMetrics accepts device='cpu' -- confirm.
metric_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Convert logits -> probabilities and labels -> tensor once, instead of re-materializing
# these full matrices separately for every metric.
baseline_probs = torch.sigmoid(torch.tensor(zero_shot_pinf_baseline_logits.values))
zero_shot_targets = torch.tensor(zero_shot_labels.values)

eval_metrics = EvalMetrics(device=metric_device)
# Micro mAP: a single binary AUPRC over every (protein, label) cell, hence the flatten below.
mAP_micro = BinaryAUPRC(device='cpu')
# Macro mAP: per-label AUPRC averaged over labels (torcheval's default average='macro').
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=zero_shot_labels.shape[-1])
# F1 metrics whose names match the regex, thresholded at 0.5.
f1_metrics = eval_metrics.get_metric_collection_with_regex(
    pattern='f1_m.*',
    threshold=0.5,
    num_labels=zero_shot_labels.shape[-1],
)

f1_metrics(baseline_probs.to(metric_device), zero_shot_targets.to(metric_device))
mAP_micro.update(baseline_probs.flatten(), zero_shot_targets.flatten())
mAP_macro.update(baseline_probs, zero_shot_targets)

metrics = f1_metrics.compute()
metrics.update({
    "map_micro": mAP_micro.compute(),
    "map_macro": mAP_macro.compute(),
})
metrics = {k: v.item() for k, v in metrics.items()}
pprint(metrics)



{'f1_macro': 0.10838527977466583,
 'f1_micro': 0.08086512237787247,
 'map_macro': 0.12346171587705612,
 'map_micro': 0.01908101886510849}
