In [1]:
# Automatically reload edited modules (e.g. the txplm package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import os

from logging import getLogger

import pandas as pd
import numpy as np
import torch

from evaluate import load
from torch.utils.data import DataLoader, Dataset
from transformers import HfArgumentParser


from txplm.evaluate.framework.core import run_evaluation
from txplm.evaluate.framework.args import EvalArgs

from txplm.data.dataset import (
    AASeqDataset,
    AASeqTextUnifiedDataset,
)

from txplm.training.training_args_IT import (
    DataArgs,
    ModelArgs,
    postprocess_args,
)

from txplm.data.data_utils import (
    DATA_DIR,
    HOME_DIR,
)


In [4]:
checkpoint_dir = "/om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-01-17_12:36_LLAMA2_z2_all/checkpoint-87500"

# test_dir will be used to write config files and outputs from evaluations (cached embeddings, metrics CSV, plots)
test_dir = "/your/favorite/test_dir"

# Generated config files and the protein subset are written under inputs/;
# evaluation results (passed as `output_dir` in the eval config) go under outputs/.
inputs_dir = os.path.join(test_dir, "inputs")
outputs_dir = os.path.join(test_dir, "outputs")

os.makedirs(inputs_dir, exist_ok=True)
os.makedirs(outputs_dir, exist_ok=True)

# Make configs

## Create protein subset for testing

In [None]:
# Pickled protein metadata table under DATA_DIR; a random subset of it is drawn below.
ALL_PROTEINS_FILE = os.path.join(DATA_DIR, "integrated_data/v1/protein/protein_info_filtered.pkl")

In [None]:
# Draw a fixed-size random subset of proteins to use as a small retrieval target set.
# A fixed random_state keeps the subset (and any metrics computed against it)
# reproducible across kernel restarts.
N_SUBSET_PROTEINS = 100
SUBSET_SEED = 42

# NOTE: pd.read_pickle can execute arbitrary code on load; only use trusted files.
all_proteins = pd.read_pickle(ALL_PROTEINS_FILE)
subset = all_proteins.sample(n=N_SUBSET_PROTEINS, random_state=SUBSET_SEED)
subset.head()

Unnamed: 0,index,protein_id,name,entry,comments
8699,8699,P20916,MAG,MAG_HUMAN,[FUNCTION: Adhesion molecule that mediates int...
2887,2887,Q9UGN4,CD300A,CLM8_HUMAN,[FUNCTION: Inhibitory receptor which may contr...
7867,7867,A6PVL3,KNCN,KNCN_HUMAN,[FUNCTION: May play a role in stabilizing dens...
515,515,O75969,AKAP3,AKAP3_HUMAN,[FUNCTION: May function as a regulator of both...
3256,3256,P04632,CAPNS1,CPNS1_HUMAN,[FUNCTION: Regulatory subunit of the calcium-r...


In [None]:
# Persist the sampled proteins; the dataset config's `target_subset` points at this file.
subset.to_pickle(path=os.path.join(inputs_dir, "test_protein_subset.pkl"))

## Output config files

In [None]:
# Selects which evaluation configuration the cells below generate. One of:
#  "test": run with only TxPLM text -> protein retrieval evaluation
#  "protst": run with TxPLM and ProtST for text -> protein retrieval evaluation
#  "blast_ppi": run with TxPLM and BLAST for protein -> protein retrieval evaluation
#  "multitask": run with TxPLM only, evaluating across all tasks
VALID_RUN_TYPES = ("test", "protst", "blast_ppi", "multitask")
run_type = "test"

# Fail fast on a typo: the config cells below only know how to handle these values.
if run_type not in VALID_RUN_TYPES:
    raise ValueError(f"run_type must be one of {VALID_RUN_TYPES}, got {run_type!r}")

In [None]:
# Build the dataset-config YAML for the selected run_type and write it where the
# eval config's `it_data_config_yml` expects it. Entries under `train` are not
# used for evaluation (run_evaluation logs a warning about them); only
# `testing` entries are evaluated.
if run_type == "test":
    data_yml = f"""it_datasets:
    train:
      # GO:
      - aaseq_type: protein
        text_type: go
        relations: [process, function, component]
        tasks: [qa, caption, retrieval]
        splits: [CL_train]
    testing:
      # DrugBank:
      - aaseq_type: protein
        text_type: drugbank
        relations: [drug_carrier]
        tasks: [retrieval]
        splits: [CL_train]
      - aaseq_type: protein
        text_type: drugbank
        relations: [drug_carrier]
        tasks: [retrieval]
        splits: [CL_train]
        key_suffix: subset
        dataset_args:
          num_instruction_examples: 3
          text_variant_type: moa_only
        eval_args:
          target_subset: {os.path.join(inputs_dir, 'test_protein_subset.pkl')}
      - aaseq_type: protein
        text_type: drugbank
        relations: [drug_transporter]
        tasks: [retrieval]
        splits: [CL_train]
        dataset_args:
          num_instruction_examples: 3
          text_variant_type: moa_only"""
elif run_type == "blast_ppi":
    data_yml = """it_datasets:
    testing:
      - aaseq_type: protein
        text_type: protein
        relations: [homology]
        tasks: [retrieval]
        splits: [CL_train]
        dataset_args:
          ppi_store_reverse_edges: True"""
elif run_type == "protst":
    data_yml = """it_datasets:
  testing:
    # DrugBank:
    - aaseq_type: protein
      text_type: drugbank
      relations: [drug_carrier]
      tasks: [retrieval]
      splits: [eval_zero_shot, eval_pt_ft]"""
elif run_type == "multitask":
    data_yml = """it_datasets:
  testing:
    # DrugBank:
    - aaseq_type: protein
      text_type: drugbank
      relations: [drug_carrier]
      tasks: [retrieval, qa, caption]
      splits: [eval_zero_shot, eval_pt_ft]"""
else:
    # Without this branch an unknown run_type would leave data_yml undefined,
    # or silently reuse a stale value from an earlier execution of this cell.
    raise ValueError(
        f"Unknown run_type {run_type!r}; expected one of 'test', 'protst', 'blast_ppi', 'multitask'"
    )

with open(os.path.join(inputs_dir, "tmp_eval_dataset_config.yml"), "w") as fh:
    fh.write(data_yml)

In [None]:
# Build the model-config YAML listing which models to evaluate: BLAST or ProtST
# is evaluated alongside TxPLM for the matching run_type; every other run_type
# evaluates TxPLM alone.
if run_type == "blast_ppi":
    model_yml = f"""models:
      - model_name: BLAST
        args:
          max_ev: 10
      - model_name: TxPLM
        args:
          checkpoint_dir: {checkpoint_dir}"""
elif run_type == "protst":
    model_yml = f"""models:
    - model_name: ProtST
      args:
        max_prompt_len: 128
    - model_name: TxPLM
      args:
        checkpoint_dir: {checkpoint_dir}"""
else:
    model_yml = f"""models:
      - model_name: TxPLM
        args:
          checkpoint_dir: {checkpoint_dir}"""

# Write to the path referenced by `models_config_yml` in the eval config.
with open(os.path.join(inputs_dir, "tmp_eval_model_config.yml"), "w") as fh:
    fh.write(model_yml)

In [None]:
# Top-level eval config consumed by HfArgumentParser (EvalArgs, DataArgs, ModelArgs).
# Each key appears once: duplicate mapping keys are invalid YAML and are either
# rejected or silently overwritten depending on the loader.
example_config = f"""# --------------------------------------   EVAL ARGUMENTS   --------------------------------------
# Data config
it_data_config_yml: {os.path.join(inputs_dir, "tmp_eval_dataset_config.yml")}
models_config_yml: {os.path.join(inputs_dir, "tmp_eval_model_config.yml")}

retrieval_use_cached_target_embeddings: True
retrieval_eval_all_proteins: True
retrieval_top_k_vals: [10, 20, 100]

model_args_from_checkpoint: {checkpoint_dir}
data_args_from_checkpoint: {checkpoint_dir}

output_dir: {outputs_dir}

qa_num_samples: 5

filter_training_pairs: True

# --------------------------------------   DATA ARGUMENTS   --------------------------------------
# General:
use_caption: False

# Splitting:
go_split_method: "sample_aware_ontology_go_centric"
val_split_type: "pt_ft"

# Dataset-specific attributes:
go_def_col: "standard"

# Negative sampling:
num_neg_samples_qa: 1
negative_sampling_strategy_qa: 'aaseq_only'
negative_sampling_strategy_retrieval: 'in_batch'
# --------------------------------------  MODEL ARGUMENTS   --------------------------------------
protein_encoder_num_params: '35m'
freeze_protein_encoder: "all"
use_aaseq_embeddings: False
freeze_aaseq_embeddings: False

# Text encoder:
use_text_embeddings: False
freeze_text_embeddings: False
text_encoder_fname: "biogpt"
max_text_len: 512
#freeze_text_encoder: "all"

# Modeling-specific:
ret_token_access: "last"
train_qa_full_lm: False
train_retrieval_lm: False
roll_num: 1"""

with open(os.path.join(inputs_dir, "tmp_eval_config.yml"), "w") as fh:
    fh.write(example_config)

# In-notebook test

In [12]:
# Parse the generated eval config into (EvalArgs, DataArgs, ModelArgs) instances,
# then post-process the data/model args; the first argument/return slot of
# postprocess_args is not used here.
parser = HfArgumentParser((EvalArgs, DataArgs, ModelArgs))
eval_args, data_args, model_args = parser.parse_yaml_file(
    os.path.join(inputs_dir, "tmp_eval_config.yml")
)
_, data_args, model_args = postprocess_args(
    None,
    data_args,
    model_args,
)

In [14]:
# Run the evaluation described by the parsed configs; progress for each dataset/model is logged below.
metrics = run_evaluation(eval_args, data_args, model_args)

Loading ModelArgs from TxPLM checkpoint: /om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-01-17_12:36_LLAMA2_z2_all/checkpoint-87500


Received training datasets for task qa, will not be used for evaluation (check data config)
Received training datasets for task caption, will not be used for evaluation (check data config)
Received training datasets for task retrieval, will not be used for evaluation (check data config)


task: retrieval dataset: protein_drugbank_drug_carrier N: 486
task: retrieval dataset: protein_drugbank_drug_carrier_subset N: 486
task: retrieval dataset: protein_drugbank_drug_transporter N: 1812
retrieval: evaluating on 3 datasets
updating model args DATA_DIR from /n/holystore01/LABS/mzitnik_lab/Lab/PLM -> /om2/vast/kellislab/shared/PLM/
updating stale DATA_DIR for model arg: go_embeddings_path
updating stale DATA_DIR for model arg: pfam_embeddings_path
updating stale DATA_DIR for model arg: drugbank_embeddings_path
updating stale DATA_DIR for model arg: reactome_embeddings_path
updating stale DATA_DIR for model arg: omim_embeddings_path
updating stale DATA_DIR for model arg: ec_embeddings_path
updating stale DATA_DIR for model arg: protein_seq_embeddings_path
updating stale DATA_DIR for model arg: protein_struct_embeddings_path
updating stale DATA_DIR for model arg: protein_embeddings_idmap_path
updating stale DATA_DIR for model arg: drug_struct_embeddings_path
updating stale DATA_

Loading checkpoint shards: 100%|██████████| 2/2 [00:33<00:00, 16.63s/it]
Using sep_token, but it is not set yet.
Using pad_token, but it is not set yet.


model.embed_tokens.weight True
model.layers.0.self_attn.q_proj.weight True
model.layers.0.self_attn.k_proj.weight True
model.layers.0.self_attn.v_proj.weight True
model.layers.0.self_attn.o_proj.weight True
model.layers.0.mlp.gate_proj.weight True
model.layers.0.mlp.up_proj.weight True
model.layers.0.mlp.down_proj.weight True
model.layers.0.input_layernorm.weight True
model.layers.0.post_attention_layernorm.weight True
model.layers.1.self_attn.q_proj.weight True
model.layers.1.self_attn.k_proj.weight True
model.layers.1.self_attn.v_proj.weight True
model.layers.1.self_attn.o_proj.weight True
model.layers.1.mlp.gate_proj.weight True
model.layers.1.mlp.up_proj.weight True
model.layers.1.mlp.down_proj.weight True
model.layers.1.input_layernorm.weight True
model.layers.1.post_attention_layernorm.weight True
model.layers.2.self_attn.q_proj.weight True
model.layers.2.self_attn.k_proj.weight True
model.layers.2.self_attn.v_proj.weight True
model.layers.2.self_attn.o_proj.weight True
model.lay

100%|██████████| 31/31 [01:46<00:00,  3.44s/it]


loading cached target embeddings


  F_max = 2 * precision * recall  / (precision + recall)


retrieval: evaluating model TxPLM on dataset protein_drugbank_drug_carrier_subset , num_targets=100


100%|██████████| 31/31 [01:41<00:00,  3.27s/it]
  F_max = 2 * precision * recall  / (precision + recall)


loading cached target embeddings
retrieval: evaluating model TxPLM on dataset protein_drugbank_drug_transporter , num_targets=18174


100%|██████████| 114/114 [06:20<00:00,  3.34s/it]


loading cached target embeddings


In [None]:
# Display the metrics returned by run_evaluation.
metrics

# Run via CLI

To run via the CLI, navigate to `test_dir` and run the following script. Note that the generated config files live in the `inputs/` subdirectory of `test_dir`.
```bash
#!/bin/bash

# Perform any setup for your environment
#  source ~/.bashrc
#  conda activate txplm

# Run script using path to TxPLM repo
python /path/to/TxPLM/scripts/run_eval_framework.py --from_yaml ./inputs/tmp_eval_config.yml 2>&1 | tee log.txt
```

# Scratch space for debugging

In [20]:
from txplm.evaluate.eval_v2 import (
    load_and_validate_model_args,
    load_datasets_for_eval,
    get_target_set,
    get_retrieval_target_proteins_loader,
    prep_for_retrieval_eval,
)

In [15]:
# Reproduce by hand the setup steps that run_evaluation logs (load model args,
# build testing datasets, wrap them in DataLoaders) so intermediate state can be inspected.
logger = getLogger(__name__)

# Check if we want to override ModelArgs using a TxPLM checkpoint.
# Truthiness check so both "" and None mean "no override".
if eval_args.model_args_from_checkpoint:
    # Separate name so the notebook-level `checkpoint_dir` config is not clobbered.
    args_checkpoint_dir = eval_args.model_args_from_checkpoint
    print(f"Loading ModelArgs from TxPLM checkpoint: {args_checkpoint_dir}")
    # NOTE(review): model_args.pt is presumably a pickled ModelArgs object, not tensors.
    # torch>=2.6 defaults to weights_only=True and may refuse to load it; pass
    # weights_only=False there, and only for checkpoints you trust.
    model_args = torch.load(os.path.join(args_checkpoint_dir, "model_args.pt"))

# Parse model specifications.
models = load_and_validate_model_args(eval_args.models_config_yml)

# Load datasets; training datasets are reported but not evaluated.
datasets, collators, dataset_eval_args = load_datasets_for_eval(data_args, model_args)
for task, train_datasets in datasets["train"].items():
    if len(train_datasets) != 0:
        logger.warning(
            f"Received training datasets for task {task}, will not be used for evaluation (check data config)"
        )

# Package testing datasets and their collators into data loaders, keyed by task then dataset.
data_loaders = {}
datasets = datasets["testing"]
collators = collators["testing"]
for task, task_datasets in datasets.items():
    task_loaders = {}
    for dataset_key, dataset in task_datasets.items():
        print(f"task: {task} dataset: {dataset_key} N: {len(dataset)}")
        task_loaders[dataset_key] = DataLoader(
            dataset,
            batch_size=eval_args.batch_size,
            collate_fn=collators[task][dataset_key],
            num_workers=eval_args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
    data_loaders[task] = task_loaders

Loading ModelArgs from TxPLM checkpoint: /om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-01-17_12:36_LLAMA2_z2_all/checkpoint-87500


Received training datasets for task qa, will not be used for evaluation (check data config)
Received training datasets for task caption, will not be used for evaluation (check data config)
Received training datasets for task retrieval, will not be used for evaluation (check data config)


task: retrieval dataset: protein_drugbank_drug_carrier_subset N: 486


In [17]:
# Pick one retrieval dataset to debug, together with its per-dataset eval args.
SCRATCH_DATASET_KEY = "protein_drugbank_drug_carrier_subset"
data_loader = data_loaders["retrieval"][SCRATCH_DATASET_KEY]
this_dataset_eval_args = dataset_eval_args[SCRATCH_DATASET_KEY]

In [18]:
# Build the retrieval target set for the selected dataset.
# NOTE(review): presumably restricted by `target_subset` from the dataset's eval_args
# (the log above reports num_targets=100 for the subset dataset) -- confirm in get_target_set.
targets = get_target_set(
    data_loader.dataset,
    this_dataset_eval_args,
    eval_args,
)

In [21]:
# Batch the targets into a loader, and compute labels plus query/target orderings
# for scoring retrieval on this dataset.
target_loader = get_retrieval_target_proteins_loader(
    targets,
    eval_args.batch_size,
)
labels, query_order, target_order = prep_for_retrieval_eval(
    data_loader.dataset,
    targets,
)



In [23]:
# Fetch a single batch from the target loader to inspect what the model receives.
model_inputs = next(iter(target_loader))