In [9]:
import os
from pathlib import Path
import sys
# Make the project root importable so the `src.*` imports resolve. Assumes the kernel's
# working directory is the folder this notebook lives in, one level below the root — TODO confirm.
curdir = Path(os.getcwd())
project_root = str(curdir.parent.absolute())
# Guard against duplicates: re-running this cell would otherwise append the same entry again.
if project_root not in sys.path:
    sys.path.append(project_root)
import torch
import numpy as np
import os
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
import pandas as pd
from src.utils.evaluation import EvalMetrics
from src.utils.data import generate_vocabularies

In [2]:
# ProteInfer outputs on the zero-shot test set: per-protein logits (columns are relabelled with the
# training GO vocabulary below) and the ground-truth labels whose columns are the zero-shot GO ids.
zero_shot_pinf_logits = pd.read_parquet('../outputs/results/test_logits_proteinfer.parquet')
zero_shot_labels = pd.read_parquet('../outputs/results/test_labels_proteinfer.parquet')
# Frozen BioGPT label embeddings (a tensor) and their index. The index is a pickled pandas
# DataFrame (it is filtered with pandas methods below). Since torch 2.6, torch.load defaults to
# weights_only=True, which refuses arbitrary pickled objects, so the index is loaded with
# weights_only=False. Only do this for trusted, locally produced files: unpickling can run code.
embeddings = torch.load('../data/embeddings/frozen_BioGPT_label_embeddings_mean.pt')
embeddings_idx = torch.load('../data/embeddings/frozen_BioGPT_label_embeddings_mean_index.pt',
                            weights_only=False)
vocabularies = generate_vocabularies(file_path = '../data/swissprot/proteinfer_splits/random/full_GO.fasta')



In [3]:
# The logits parquet has no GO ids as headers; name its columns with the training GO vocabulary.
zero_shot_pinf_logits = zero_shot_pinf_logits.set_axis(vocabularies['GO_label_vocab'], axis=1)

In [4]:
# Keep only the 'name' description of each GO term, then split the embeddings into labels from the
# training vocabulary and the zero-shot labels. Pandas masks are turned into bool tensors so the
# torch indexing is explicit rather than relying on implicit Series -> tensor conversion.
embedding_mask = embeddings_idx['description_type']=='name'
embeddings_idx = embeddings_idx[embedding_mask].reset_index(drop=True)
embeddings = embeddings[torch.as_tensor(embedding_mask.to_numpy(), device=embeddings.device)]

train_embeddings_mask = embeddings_idx['id'].isin(vocabularies['GO_label_vocab'])
train_embeddings_idx = embeddings_idx[train_embeddings_mask].reset_index(drop=True)
train_embeddings = embeddings[torch.as_tensor(train_embeddings_mask.to_numpy(), device=embeddings.device)]

zero_shot_embeddings_mask = embeddings_idx['id'].isin(zero_shot_labels.columns)
zero_shot_embeddings_idx = embeddings_idx[zero_shot_embeddings_mask].reset_index(drop=True)
zero_shot_embeddings = embeddings[torch.as_tensor(zero_shot_embeddings_mask.to_numpy(), device=embeddings.device)]

# Rows of each embedding tensor must line up with the rows of its index frame.
assert len(train_embeddings) == len(train_embeddings_idx)
assert len(zero_shot_embeddings) == len(zero_shot_embeddings_idx)
# Every zero-shot label needs an embedding; otherwise it silently drops out of the label mapping
# and the baseline column lookup later fails with a bare KeyError.
missing_zero_shot = set(zero_shot_labels.columns) - set(zero_shot_embeddings_idx['id'])
assert not missing_zero_shot, f"{len(missing_zero_shot)} zero-shot labels have no 'name' embedding"

In [5]:
# Cosine similarity between label embeddings: one row per zero-shot label, one column per training label.
zero_shot_unit = torch.nn.functional.normalize(zero_shot_embeddings)
train_unit = torch.nn.functional.normalize(train_embeddings)
label_train_2_zero_shot_similarities = zero_shot_unit @ train_unit.T

# Map each zero-shot GO id to the GO id of its most similar training label.
nearest_train_positions = label_train_2_zero_shot_similarities.max(dim=-1).indices.tolist()
zero_shot_ids = zero_shot_embeddings_idx['id']
train_ids = train_embeddings_idx['id']
zero_shot_label_mapping = {
    zero_shot_ids.iloc[row]: train_ids.iloc[col]
    for row, col in enumerate(nearest_train_positions)
}


### Sanity check
Print zero-shot label descriptions next to the description of the training label each one was mapped to. Each key (zero-shot description) should be semantically similar to its value (mapped training description).

In [6]:
# Build the GO id -> description lookup once (first match per id, as a row scan would find) instead
# of scanning the whole index frame for every key and every value of the mapping.
id_to_description = embeddings_idx.drop_duplicates(subset='id').set_index('id')['description']
zero_shot_label_descriptions_mapping = {
    id_to_description.at[k]: id_to_description.at[v] for k, v in zero_shot_label_mapping.items()
}

# Show only (up to) 10 randomly selected key-value pairs; the RNG is seeded so re-runs show the same sample.
SANITY_CHECK_SAMPLE_SIZE = 10
SANITY_CHECK_SEED = 0
sanity_rng = np.random.default_rng(SANITY_CHECK_SEED)
sampled_keys = sanity_rng.choice(
    list(zero_shot_label_descriptions_mapping.keys()),
    size=min(SANITY_CHECK_SAMPLE_SIZE, len(zero_shot_label_descriptions_mapping)),
    replace=False,
)
sampled_key_set = set(sampled_keys)
{k: v for k, v in zero_shot_label_descriptions_mapping.items() if k in sampled_key_set}

{'zinc ion sensor activity': 'copper ion sensor activity',
 'chloride:bicarbonate antiporter activity': 'potassium:chloride symporter activity',
 'chromosome, telomeric repeat region': 'chromosome, telomeric region',
 '4-hydroxybenzoate 3-monooxygenase [NADH] activity': '4-hydroxybenzoate 3-monooxygenase activity',
 'ADP-D-ribose modification-dependent protein binding': 'ADP-D-ribose binding',
 'ATP-dependent chromatin remodeler activity': 'ATP-dependent chromatin remodeling',
 'acquisition of seed longevity': 'acquisition of desiccation tolerance',
 'ferric-chelate reductase (NADH) activity': 'ferric-chelate reductase (NADPH) activity',
 'intracellular auxin homeostasis': 'intracellular auxin transport',
 'patulin biosynthetic process': 'aflatoxin biosynthetic process'}

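### Create the zero-shot ProteInfer-based baseline prediction DataFrame
Each zero-shot label reuses the ProteInfer logits of the training label it was mapped to above.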

In [7]:
# For every zero-shot GO id, take the logit column of its mapped training label, then relabel the
# selected columns with the zero-shot GO ids so they line up with `zero_shot_labels`.
mapped_train_columns = [zero_shot_label_mapping[go_id] for go_id in zero_shot_labels.columns]
zero_shot_pinf_baseline_logits = (
    zero_shot_pinf_logits
    .loc[:, mapped_train_columns]
    .set_axis(zero_shot_labels.columns, axis=1)
)
zero_shot_pinf_baseline_logits

Unnamed: 0,GO:0000514,GO:0000515,GO:0062136,GO:0062142,GO:0062143,GO:0062144,GO:0062145,GO:0062146,GO:0062147,GO:0062148,...,GO:0160044,GO:0160046,GO:0160047,GO:0160049,GO:0160055,GO:0160063,GO:0160064,GO:0160076,GO:0180000,GO:0180015
G2X4G0,-47.541676,-35.941223,-61.878963,-44.495914,-58.970070,-57.377369,-42.836117,-44.455158,-39.196896,-57.509708,...,-53.409294,-66.129837,-63.591774,-60.912781,-55.585606,-62.763256,-56.075291,-67.065620,-77.471016,-45.570148
A0A2H5AIX5,-28.039135,-26.346807,-45.581715,-25.986626,-25.285727,-29.607149,-26.832134,-25.457111,-23.234884,-23.628864,...,-27.532940,-44.751705,-32.070839,-39.413742,-28.595085,-46.748264,-42.851353,-48.469982,-48.511219,-28.387714
I3PB37,-23.530231,-22.563887,-38.289078,-23.179447,-24.087420,-24.551723,-26.981846,-22.930792,-16.681442,-21.751955,...,-26.015902,-37.000771,-26.627262,-37.092922,-27.264112,-43.996407,-33.706532,-33.973145,-42.105080,-25.982750
A0A2H5AIY4,-31.483477,-22.927664,-49.055218,-27.737114,-26.792818,-29.667669,-29.294580,-26.035662,-23.410570,-28.065226,...,-29.459387,-41.298523,-34.491634,-41.447037,-31.212029,-45.291801,-42.489094,-52.149776,-40.589699,-24.922121
P52307,-24.403194,-23.693031,-39.221725,-27.375862,-29.780519,-27.875626,-21.933073,-28.424910,-21.717852,-22.292011,...,-25.500454,-29.814806,-28.343555,-31.969307,-30.354818,-33.852108,-28.755169,-36.302696,-40.499111,-31.416523
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P36686,-43.435757,-30.996593,-57.310955,-36.687969,-37.442657,-40.926582,-33.965706,-41.469944,-31.772867,-28.116003,...,-38.182838,-44.858574,-47.884708,-45.461933,-51.475319,-47.955284,-41.089451,-63.447369,-60.083084,-36.696529
G5EF53,-30.021080,-18.978807,-36.144989,-16.941919,-27.434803,-22.078762,-23.249517,-31.551897,-21.085211,-20.889723,...,-23.596409,-27.621756,-19.742397,-28.200283,-20.976076,-27.850149,-29.412155,-29.050787,-34.520931,-25.870638
P16153,-23.835443,-28.946733,-31.913568,-23.731842,-34.245251,-33.641441,-24.805691,-31.442871,-32.797462,-31.195648,...,-27.139637,-34.582054,-35.583694,-32.547447,-43.178951,-35.115894,-34.893875,-45.132298,-47.130737,-29.842566
P50157,-26.026873,-23.293116,-28.085409,-24.480492,-26.680840,-22.168390,-20.396111,-23.725313,-24.015881,-24.891655,...,-23.710913,-33.456360,-21.716194,-34.093754,-30.613405,-27.027147,-21.244709,-25.739750,-28.145428,-26.247005


### Measure baseline performance

In [14]:
from pprint import pprint

# Fall back to CPU when no GPU is available so the notebook still runs end to end.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
F1_THRESHOLD = 0.5  # probability cut-off used by the F1 metrics

# Convert once and reuse: sigmoid probabilities of the baseline logits and the ground-truth labels.
# NOTE(review): only `.values` are compared, so this assumes both frames list proteins in the same
# row order — confirm.
baseline_probs = torch.sigmoid(torch.tensor(zero_shot_pinf_baseline_logits.values))
zero_shot_label_tensor = torch.tensor(zero_shot_labels.values)
num_labels = zero_shot_labels.shape[-1]

eval_metrics = EvalMetrics(device=device)
mAP_micro = BinaryAUPRC(device='cpu')
mAP_macro = MultilabelAUPRC(device='cpu', num_labels=num_labels)
f1_metrics = eval_metrics.get_metric_collection_with_regex(
    pattern='f1_m.*',
    threshold=F1_THRESHOLD,
    num_labels=num_labels,
)

f1_metrics(baseline_probs.to(device), zero_shot_label_tensor.to(device))
# Micro AUPRC treats every (protein, label) pair as one binary prediction.
mAP_micro.update(baseline_probs.flatten(), zero_shot_label_tensor.flatten())
mAP_macro.update(baseline_probs, zero_shot_label_tensor)

metrics = f1_metrics.compute()
metrics.update({
    "map_micro": mAP_micro.compute(),
    "map_macro": mAP_macro.compute(),
})
metrics = {k: v.item() for k, v in metrics.items()}
pprint(metrics)



{'f1_macro': 0.037531327456235886,
 'f1_micro': 0.05510534718632698,
 'map_macro': 0.10055525600910187,
 'map_micro': 0.013812655583024025}


Bad pipe message: %s [b"j\xbaDpG\xbc\x90-q`\xf9\xea\xab\xc0\xe0,\xd2c\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003"]
Bad pipe message: %s [b"\xa5\x1f\x03\x80)`|\xba\x17\xc7@\xbe\x12]\xf76;\t\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c"]
Bad pipe message: %s [b'\x03=\xd5#:\x8c\x99O\xde\x90\x97\x1er~\x8dF\xf7i\x00\x00>\xc0\x14\xc0\n\x009\x0