# Setup

In [2]:
#@title Dependencies and Imports
from google.colab import userdata, drive
import os
from pathlib import Path

#@title Environment Variables
github_pat = userdata.get("GITHUB_PAT")
wandb_key = userdata.get("WANDB_API_KEY")
os.environ["WANDB_API_KEY"] = wandb_key
drive.mount("/content/drive")

%load_ext autoreload
%autoreload 2
%pip install -q lightning click transformers goatools toml wget fastobo pydantic loguru obonet


#@title Clone and cd
if os.getcwd() != "/content/contempro/work":
  if not Path("/content/contempro").exists():
    !git clone https://{github_pat}@github.com/boun-tabi-lifelu/contempro.git
  %cd /content/contempro/work

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Mounted at /content/drive
Cloning into 'contempro'...
remote: Enumerating objects: 364, done.[K
remote: Counting objects: 100% (364/364), done.[K
remote: Compressing objects: 100% (228/228), done.[K
remote: Total 364 (delta 186), reused 291 (delta 116), pack-reused 0 (from 0)[K
Receiving objects: 100% (364/364), 6.61 MiB | 13.29 MiB/s, done.
Resolving deltas: 100% (186/186), done.
/content/contempro/work


In [3]:
#@title Setup data
# Stage the dataset archive from the mounted Drive into ./datasets and extract it.
!mkdir -p datasets
!cp /content/drive/MyDrive/research/contempro/work-pfresgo-data.zip ./datasets
# Each `!` line runs in its own subshell, so the `cd` has to be chained with the
# unzip; `-o` overwrites existing files without prompting so the cell can be re-run.
!cd datasets && unzip -o work-pfresgo-data.zip

Archive:  work-pfresgo-data.zip
   creating: pfresgo/
  inflating: __MACOSX/._pfresgo      
  inflating: pfresgo/.DS_Store       
  inflating: __MACOSX/pfresgo/._.DS_Store  
  inflating: pfresgo/annot.tsv       
  inflating: __MACOSX/pfresgo/._annot.tsv  
  inflating: pfresgo/nrPDB-GO_2019.06.18_test.csv  
  inflating: __MACOSX/pfresgo/._nrPDB-GO_2019.06.18_test.csv  
  inflating: pfresgo/train.txt       
  inflating: __MACOSX/pfresgo/._train.txt  
  inflating: pfresgo/go.obo          
  inflating: __MACOSX/pfresgo/._go.obo  
  inflating: pfresgo/ontology.embeddings.npy  
  inflating: __MACOSX/pfresgo/._ontology.embeddings.npy  
  inflating: pfresgo/valid.txt       
  inflating: __MACOSX/pfresgo/._valid.txt  
  inflating: pfresgo/nrPDB-GO_2019.06.18_sequences.fasta  
  inflating: __MACOSX/pfresgo/._nrPDB-GO_2019.06.18_sequences.fasta  
  inflating: pfresgo/test.txt        
  inflating: __MACOSX/pfresgo/._test.txt  


In [4]:
# Copy the per-residue embeddings HDF5 file from Drive into the extracted pfresgo folder.
!cp /content/drive/MyDrive/research/per_residue_embeddings.h5 ./datasets/pfresgo

# Evaluate


In [None]:
# Run the repository's evaluation script with the medium ordered enc-dec config,
# the 2024 GO release and the biological_process subontology, with wandb enabled.
%run bin/evaluate.py --config_file configs/ordered_encdec_medium.toml --go_release 2024 --subontology biological_process --use_wandb

In [10]:
from typing import Literal
import torch
from model import TrainingModel, get_model_cls
from config import from_toml
from pathlib import Path
from data.datamodule import PFresGODataModule
from torch.nn.functional import sigmoid
import os

# Colab form fields: edit these values (or use the form widgets) to pick the run to evaluate.
subontology = "biological_process" # @param {type:"string"}
config_file = "configs/ordered_encdec_medium.toml" #@param {type:"string"}
go_release = "2024" #@param ["2020", "2024"]
use_wandb = True #@param {type:"boolean"}

# Config file name up to its first dot, with underscores turned into dashes.
model_type = config_file.rsplit("/", 1)[-1].partition(".")[0].replace("_", "-")
# Initials of the subontology words, e.g. "biological_process" -> "bp".
subontology_short = "".join(part[0] for part in subontology.split("_"))
# Name shared by the checkpoint artifact, the predictions file and the metrics file.
model_name = f"contempro-{subontology_short}-{go_release}-{model_type}"

# Parse the experiment configuration selected by `config_file`.
config = from_toml(config_file)

# Dataset root directory as declared in the training section of the config.
data_root_dir = Path(config.train.data_dir)

def load_model(ontology: Literal["molecular_function", "biological_process", "cellular_component"]):
  """Load the trained checkpoint for one GO subontology as a ``TrainingModel``.

  The checkpoint name is derived from the subontology initials plus the
  notebook-level ``go_release`` and ``model_type``. With ``use_wandb`` enabled
  the ``<model_name>:latest`` artifact is downloaded into ``trained_models/``
  inside a new wandb eval run; otherwise ``trained_models/<model_name>.ckpt``
  must already exist locally.

  Side effects: overrides ``config.train.subontology`` on the shared ``config``
  object and, when ``use_wandb`` is set, starts a wandb run that the caller is
  expected to finish.

  Args:
    ontology: Full subontology name used to pick the checkpoint.

  Returns:
    The module returned by ``TrainingModel.load_from_checkpoint``.
  """
  subontology_short = ''.join(word[0] for word in ontology.split("_"))

  config.train.subontology = ontology # override subontology

  # Local on purpose: the name follows `ontology`, not the notebook-level model_name.
  model_name = f"contempro-{subontology_short}-{go_release}-{model_type}"

  if use_wandb:
    import wandb
    run = wandb.init(project="contempro", name=model_name+"-eval", job_type="eval")
    wandb.config.update(config)
    wandb.config.update({"model_name": model_name})
    wandb.config.update({"subontology": ontology})
    artifact = run.use_artifact(f"{model_name}:latest")
    os.makedirs("trained_models", exist_ok=True)
    path = artifact.download(root="trained_models")
    print(path)
    checkpoint_path = "trained_models/"+artifact.files()[0].name
  else:
    checkpoint_path = f"trained_models/{model_name}.ckpt"

  # Build the bare network on the common path: previously it was only created in
  # the wandb branch, so the local-checkpoint branch failed with UnboundLocalError.
  model = get_model_cls(config.model.name)(config.model)
  return TrainingModel.load_from_checkpoint(
    checkpoint_path,
    model=model,
    training_config=config.train
  )

In [7]:
#@title Load Data

# Collect the datamodule settings first so they can be inspected before construction.
datamodule_kwargs = {
  "data_dir": data_root_dir,
  "batch_size": 32,
  "num_workers": config.train.dm_num_workers,
  "ontology": subontology,
  "order_go_terms": config.train.order_go_terms,
  "go_release": go_release,
}
dm = PFresGODataModule(**datamodule_kwargs)
# Only the test stage is prepared; this notebook does no training.
dm.setup("test")


In [10]:
#@title Inference
from tqdm import tqdm

model = load_model(subontology)
# Place the model explicitly instead of relying on where the checkpoint was saved,
# and fall back to CPU when no GPU is attached.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

batches_list = []
with torch.no_grad():
    for batch in tqdm(dm.test_dataloader()):
        batch_probs = sigmoid(model({
            "embeddings": batch["embeddings"].to(device),
            "attention_mask": batch["attention_mask"].to(device),
            "go_embeddings": batch["go_embeddings"].to(device)
        }))
        # Move each batch off the GPU right away so the full prediction matrix
        # accumulates in host memory rather than device memory.
        batches_list.append(batch_probs.cpu())

# One row per test protein, in dataloader order.
result = torch.cat(batches_list, dim=0)


torch.save(result, f"{model_name}_test_preds.pt")
if use_wandb:
    import wandb
    # add_file() returns the manifest entry, not the artifact, so keep the
    # artifact in its own variable and log that.
    # NOTE(review): this artifact reuses the model artifact's name with a
    # different type — confirm that W&B accepts that in this project.
    art = wandb.Artifact(model_name, type="predictions")
    art.add_file(f"{model_name}_test_preds.pt")
    wandb.log_artifact(art)
    wandb.finish()


100%|██████████| 107/107 [46:21<00:00, 26.00s/it]


## PFresGO Metrics

In [11]:
# Cell 1: Imports and Setup
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import pickle
from pathlib import Path
from data.datamodule import PFresGODataModule
from pfresgo_eval import Method, load_test_prots, protein_centric_aupr_curves

if use_wandb:
  wandb.init(project="contempro", name=model_name+"-eval-metrics", job_type="eval")
# Configuration
predictions_file = f"{model_name}_test_preds.pt"


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Cell 2: Load Data
# Initialize and setup datamodule
dm.setup("test")

# Load predictions
# The file holds a single tensor, so weights_only=True is enough; it avoids
# unpickling arbitrary objects and silences torch.load's FutureWarning.
test_preds = torch.load(predictions_file, map_location='cpu', weights_only=True)
n_test_samples = len(dm.test_dataset)
n_go_terms = len(dm.test_dataset.go_term_list)
print(f"Predictions shape: {test_preds.shape}")
print(f"Number of test samples: {n_test_samples}")
print(f"Number of GO terms: {n_go_terms}")

# Stale or mismatched prediction files would silently corrupt every metric below.
assert tuple(test_preds.shape) == (n_test_samples, n_go_terms), (
    f"predictions {tuple(test_preds.shape)} do not match dataset ({n_test_samples}, {n_go_terms})"
)

Predictions shape: torch.Size([3416, 30735])
Number of test samples: 3416
Number of GO terms: 30735


  test_preds = torch.load(predictions_file, map_location='cpu')


In [13]:
# Cell 3: Prepare Evaluation Data
# Stack the ground-truth annotation batches in dataloader order.
all_annots = torch.cat(
    [batch["annotations"] for batch in dm.test_dataloader()],
    dim=0,
)

# Bundle labels, predictions and their row/column identifiers into one dict.
eval_data = {
    'Y_true': all_annots.cpu().numpy(),
    'Y_pred': test_preds.cpu().numpy(),
    'goterms': dm.test_dataset.go_term_list,
    'proteins': dm.test_dataset.protein_ids
}

# Persist the bundle for the evaluation code
with open('eval_results.pckl', 'wb') as pickle_file:
    pickle.dump(eval_data, pickle_file)

In [None]:
# Cell 4: Create Method Object and Calculate Metrics
# Create evaluation method
# Arguments: a display name, the pickle written by the "Prepare Evaluation Data"
# cell, and the short subontology code derived from `subontology`.
method = Method('Contempro', 'eval_results.pckl', subontology_short)

# Load test protein indices
test_prots, seqid_mtrx = load_test_prots('./datasets/pfresgo/nrPDB-GO_2019.06.18_test.csv')
# Keep the row indices whose column 4 equals 1.
# NOTE(review): presumably column 4 marks one sequence-identity cutoff of the test
# split — confirm against load_test_prots.
prot_idx = np.where(seqid_mtrx[:, 4] == 1)[0]

# Calculate metrics
# All three metrics are restricted to the same protein subset via keep_pidx.
micro_aupr, macro_aupr, _ = method._function_centric_aupr(keep_pidx=prot_idx)
auc = method.AUC(keep_pidx=prot_idx)
fmax = method.fmax(keep_pidx=prot_idx)

# Print results
# `results` is reused by the "Save metrics to JSON" cell below.
results = {
    "Micro AUPR": micro_aupr,
    "Macro AUPR": macro_aupr,
    "AUC": auc,
    "Fmax": fmax
}

for metric, value in results.items():
    print(f"{metric}: {value:.3f}")

### Number of functions =1907


In [None]:
# Display the derived model name (it prefixes the predictions and metrics file names).
model_name

In [None]:
#@title Save metrics to JSON
import json

# Tag the metrics with the model they belong to.
results["model"] = model_name
# Convert numpy scalars to native floats so json can serialise them; strings
# (including str subclasses) pass through untouched.
results = {k: v if isinstance(v, str) else float(v) for k, v in results.items()}

with open(f"{model_name}_metrics.json", "w") as f:
    json.dump(results, f, indent=2)  # indent for readability
if use_wandb:
    # Import here so the cell does not depend on an earlier conditional import.
    import wandb
    wandb.log(results)
    # Closes the metrics run opened in the "Imports and Setup" cell.
    wandb.finish()

print(f"Metrics saved to {model_name}_metrics.json")

In [None]:
# Copy the metrics JSON to Drive; IPython substitutes {model_name} before the shell runs.
!cp {model_name}_metrics.json /content/drive/MyDrive/research/