In [1]:
# Automatically reload imported modules (e.g. the txplm package) before each cell runs,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

from logging import getLogger

import pandas as pd
import numpy as np
import torch

from torch.utils.data import DataLoader, Dataset
from transformers import HfArgumentParser


from txplm.evaluate.framework.core import run_evaluation
from txplm.evaluate.framework.args import EvalArgs


from txplm.training.training_args_IT import (
    DataArgs,
    ModelArgs,
    postprocess_args,
)

from txplm.data.data_utils import (
    DATA_DIR,
    HOME_DIR,
)


  from .autonotebook import tqdm as notebook_tqdm
2024-04-17 23:30:36.532817: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-17 23:30:36.614676: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[2024-04-17 23:30:45,091] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# Checkpoint under evaluation. Set the TXPLM_CHECKPOINT_DIR environment variable to
# evaluate a different checkpoint without editing this cell.
# (Previously evaluated: /om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-01-17_12:36_LLAMA2_z2_all/checkpoint-87500)
checkpoint_dir = os.environ.get(
    "TXPLM_CHECKPOINT_DIR",
    "/om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-03-25_04:40_txllm_ALL_split_v1_R9/checkpoint-3440000-OM",
)
# Temporary config files and the eval output_dir are placed in the current working directory.
test_dir = os.getcwd()

In [None]:
# NOTE(review): `subset` is not defined anywhere in this notebook. It leaked from a deleted
# cell, and this cell was never executed (In [None]). Guard it so that Restart & Run All
# does not stop with a NameError; define `subset` in an earlier cell to re-enable saving.
if "subset" in globals():
    subset.to_pickle(os.path.join(test_dir, "test_protein_subset.pkl"))
else:
    print("Skipping: `subset` is not defined; test_protein_subset.pkl was not written")

# Write temporary evaluation config files

In [5]:
# Dataset config: captioning on DrugBank drug-carrier protein data, evaluated on both
# the eval_zero_shot and eval_pt_ft splits.
data_yml = """it_datasets:
  testing:
    # DrugBank:
    - aaseq_type: protein
      text_type: drugbank
      relations: [drug_carrier]
      tasks: [caption]
      splits: [eval_zero_shot, eval_pt_ft]"""

dataset_config_path = os.path.join(test_dir, "tmp_eval_dataset_config.yml")
with open(dataset_config_path, "w") as config_file:
    config_file.write(data_yml)

In [6]:
# Model config: the TxPLM checkpoint under evaluation, plus the UniformRandom and
# WeightedRandom baselines.
models_yml = f"""models:
    - model_name: TxPLM
      args:
        checkpoint_dir: {checkpoint_dir}
    - model_name: UniformRandom
      args:
        sample_from: full_dataset
    - model_name: WeightedRandom
      args:
        sample_from: split"""

model_config_path = os.path.join(test_dir, "tmp_eval_model_config.yml")
with open(model_config_path, "w") as config_file:
    config_file.write(models_yml)

In [7]:
# Paths of the sub-configs referenced by the combined eval config, and of the config itself.
dataset_config_path = os.path.join(test_dir, "tmp_eval_dataset_config.yml")
model_config_path = os.path.join(test_dir, "tmp_eval_model_config.yml")
eval_config_path = os.path.join(test_dir, "tmp_eval_config.yml")

# Combined EvalArgs / DataArgs / ModelArgs settings, parsed below with HfArgumentParser.
example_config = f"""# --------------------------------------   EVAL ARGUMENTS   --------------------------------------
# Data config
it_data_config_yml: {dataset_config_path}
models_config_yml: {model_config_path}

retrieval_use_cached_target_embeddings: False
retrieval_eval_all_proteins: False
retrieval_top_k_vals: [10, 20, 100]
retrieval_balanced_metrics_num_samples: 5
retrieval_balanced_metrics_neg_per_pos: 10

batch_size: 4

qa_num_samples: 5

filter_training_pairs: False
model_args_from_checkpoint: {checkpoint_dir}
data_args_from_checkpoint: {checkpoint_dir}

output_dir: {test_dir}

use_cached_results: False

# --------------------------------------   DATA ARGUMENTS   --------------------------------------
# General:
use_caption: False

# Splitting:
go_split_method: "sample_aware_ontology_go_centric"
val_split_type: "pt_ft"

# Dataset-specific attributes:
go_def_col: "standard"

# Negative sampling:
num_neg_samples_qa: 1
negative_sampling_strategy_qa: 'aaseq_only'
negative_sampling_strategy_retrieval: 'in_batch'
# --------------------------------------  MODEL ARGUMENTS   --------------------------------------
protein_encoder_num_params: '35m'
freeze_protein_encoder: "all"
use_aaseq_embeddings: False
freeze_aaseq_embeddings: False

# Text encoder:
use_text_embeddings: False
freeze_text_embeddings: False
text_encoder_fname: "biogpt"
max_text_len: 512
#freeze_text_encoder: "all"

# Modeling-specific:
ret_token_access: "last"
train_qa_full_lm: False
train_retrieval_lm: False
roll_num: 1"""

with open(eval_config_path, "w") as config_file:
    config_file.write(example_config)

In [8]:
# Parse EvalArgs / DataArgs / ModelArgs from the combined YAML written by the config cell.
parser = HfArgumentParser((EvalArgs, DataArgs, ModelArgs))
# Read from the same test_dir-based path the config was written to, rather than a
# cwd-relative "./" path that only matches while test_dir == os.getcwd().
eval_args, data_args, model_args = parser.parse_yaml_file(
    os.path.join(test_dir, "tmp_eval_config.yml")
)
# The first slot of postprocess_args presumably holds training args; none are used here,
# so pass None and discard that part of the result.
_, data_args, model_args = postprocess_args(None, data_args, model_args)

# Scratch space for debugging

In [9]:
from txplm.evaluate.framework.utils import (
    compare_and_warn_model_args,
    load_and_validate_model_args,
    load_datasets_for_eval,
    move_inputs_to_device,
)
from txplm.evaluate.framework.retrieval import (
    calc_retrieval_metrics,
    get_retrieval_target_set,
    get_retrieval_target_proteins_loader,
    prep_for_retrieval_eval,
)

from txplm.training.training_args_IT import (
    update_data_args_data_dir,
    update_model_args_data_dir
)

In [10]:
from txplm.evaluate.eval_utils import precision_recall_topk

In [11]:
# Check if we want to override ModelArgs using a TxPLM checkpoint
if eval_args.model_args_from_checkpoint != "":
    checkpoint_dir = eval_args.model_args_from_checkpoint
    print(f"Loading ModelArgs from TxPLM checkpoint: {checkpoint_dir}")
    model_args = torch.load(os.path.join(checkpoint_dir, "model_args.pt"))


if eval_args.data_args_from_checkpoint != "":
    checkpoint_dir = eval_args.data_args_from_checkpoint
    print(f"Loading DataArgs from TxPLM checkpoint: {checkpoint_dir}")
    loaded_data_args = torch.load(os.path.join(checkpoint_dir, "data_args.pt"))
    update_data_args_data_dir(loaded_data_args)

    # Prefer to use the data config specified in data_args passed into this function
    # over one specified in the serialized data config.
    if data_args.it_data_config_yml is not None:
        loaded_data_args.it_data_config_yml = data_args.it_data_config_yml
    data_args = loaded_data_args


# Parse model specifications.
models = load_and_validate_model_args(eval_args.models_config_yml)

# Load datasets
datasets, collators, dataset_eval_args = load_datasets_for_eval(
    data_args,
    model_args,
    eval_args.separate_splits,
)
for task, train_datasets in datasets["train"].items():
    if len(train_datasets) != 0:
        print(
            f"Received training datasets for task {task}, will not be used for evaluation (check data config)"
        )
# Package datasets and collators into data loaders.
data_loaders = {}
datasets = datasets["testing"]
collators = collators["testing"]
for task, task_datasets in datasets.items():
    task_loaders = {}
    for dataset_key, dataset in task_datasets.items():
        print(f"task: {task} dataset: {dataset_key} N: {len(dataset)}")
        task_loaders[dataset_key] = DataLoader(
            dataset,
            batch_size=eval_args.batch_size,
            collate_fn=collators[task][dataset_key],
            num_workers=eval_args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
    data_loaders[task] = task_loaders

Loading ModelArgs from TxPLM checkpoint: /om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-03-25_04:40_txllm_ALL_split_v1_R9/checkpoint-3440000-OM
Loading DataArgs from TxPLM checkpoint: /om2/vast/kellislab/shared/PLM/model_outputs/pretrain/2024-03-25_04:40_txllm_ALL_split_v1_R9/checkpoint-3440000-OM
updating data args DATA_DIR from /n/holystore01/LABS/mzitnik_lab/Lab/PLM/ -> /om2/vast/kellislab/shared/PLM/
task: caption dataset: protein_drugbank_drug_carrier_eval_zero_shot N: 13
task: caption dataset: protein_drugbank_drug_carrier_eval_pt_ft N: 11


In [12]:
# Inspect a single evaluation split: the zero-shot DrugBank drug-carrier captioning set.
eval_dataset_key = "protein_drugbank_drug_carrier_eval_zero_shot"
data_loader = data_loaders["caption"][eval_dataset_key]
this_dataset_eval_args = dataset_eval_args[eval_dataset_key]

In [13]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
# (Assign `device = torch.device("cpu")` afterwards to force CPU while debugging.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Caption

In [14]:
from txplm.evaluate.framework.txplm import TxPLMCaptionEval

In [16]:
# Build the caption evaluator from the "TxPLM" entry of the model config.
txplm_model_spec = models["TxPLM"]
model = TxPLMCaptionEval(txplm_model_spec, eval_args, model_args, device)

updating model args DATA_DIR from /n/holystore01/LABS/mzitnik_lab/Lab/PLM -> /om2/vast/kellislab/shared/PLM/
updating stale DATA_DIR for model arg: go_embeddings_path
updating stale DATA_DIR for model arg: pfam_embeddings_path
updating stale DATA_DIR for model arg: drugbank_embeddings_path
updating stale DATA_DIR for model arg: reactome_embeddings_path
updating stale DATA_DIR for model arg: omim_embeddings_path
updating stale DATA_DIR for model arg: ec_embeddings_path
updating stale DATA_DIR for model arg: protein_seq_embeddings_path
updating stale DATA_DIR for model arg: protein_struct_embeddings_path
updating stale DATA_DIR for model arg: protein_embeddings_idmap_path
updating stale DATA_DIR for model arg: drug_struct_embeddings_path
updating stale DATA_DIR for model arg: domain_embeddings_path
updating stale DATA_DIR for model arg: domain_embeddings_idmap_path
updating stale DATA_DIR for model arg: mouse_ortholog_embeddings_path
updating stale DATA_DIR for model arg: mouse_ortholog_

Loading checkpoint shards: 100%|██████████| 2/2 [02:34<00:00, 77.25s/it] 
Using sep_token, but it is not set yet.
Using pad_token, but it is not set yet.


True
embed_tokens.weight False
layers.0.self_attn.k_proj.weight False
layers.0.self_attn.k_proj.bias False
layers.0.self_attn.v_proj.weight False
layers.0.self_attn.v_proj.bias False
layers.0.self_attn.q_proj.weight False
layers.0.self_attn.q_proj.bias False
layers.0.self_attn.q_proj_lora_d.weight True
layers.0.self_attn.q_proj_lora_u.weight True
layers.0.self_attn.v_proj_lora_d.weight True
layers.0.self_attn.v_proj_lora_u.weight True
layers.0.self_attn.out_proj.weight False
layers.0.self_attn.out_proj.bias False
layers.0.self_attn_layer_norm.weight False
layers.0.self_attn_layer_norm.bias False
layers.0.fc1.weight False
layers.0.fc1.bias False
layers.0.fc2.weight False
layers.0.fc2.bias False
layers.0.final_layer_norm.weight False
layers.0.final_layer_norm.bias False
layers.1.self_attn.k_proj.weight False
layers.1.self_attn.k_proj.bias False
layers.1.self_attn.v_proj.weight False
layers.1.self_attn.v_proj.bias False
layers.1.self_attn.q_proj.weight False
layers.1.self_attn.q_proj.bias

In [17]:
from txplm.evaluate.framework.caption import run_caption_eval

In [21]:
# Generate captions for the selected split using whatever decoding `method` the model
# was constructed with (it is switched to "sampling" further below).
res = model.get_predictions(data_loader)

100%|██████████| 4/4 [00:34<00:00,  8.74s/it]


In [22]:
# Display the generated captions (long captions are truncated in the rendered table).
res

Unnamed: 0,seq_id,generated_caption
0,557,Level: High\nThat part of a multic 3 generatio...
1,1971,Level: High\nThat part of a multic 3 generatio...
2,15740,Level: High\nThat part of a chromosome are req...
3,66,Level: High\nThat part of a multic 3 microtubu...
4,67,Level: High\nA protein complex that possesses ...
5,17174,Name: Level: High\nThat part of a multicicular...
6,12497,Name: Level: High\nA protein complex that poss...
7,12757,Level: High\nA protein complex involved in DNA...
8,16602,Level: High\nThat part of a multic 37 macroM: ...
9,16350,Level: High\nThat part of a multic 37 macroph ...


In [27]:
# Show one generated caption in full, since the table above truncates it.
res["generated_caption"].iat[3]

'Level: High\nThat part of a multic 3 microtubule in experimental, di, di, bond bond bond bond no</s>approximately 第第第第第第High bond bond bond bond bond kid第第第第第зькозькозько第요第 фон фон фон фон фон фон фон第第'

In [28]:
# Switch the decoding method (presumably read by get_predictions) to sampling.
# NOTE(review): this mutates `model` in place, so re-running the earlier prediction cell
# afterwards will no longer reproduce its output. Reset `model.method` or rebuild the
# model first.
model.method = "sampling"

In [29]:
# Sampling-based decoding is stochastic. Seed torch's RNG (all devices) so the sampled
# captions are reproducible on re-run.
# NOTE(review): assumes get_predictions draws its randomness from torch's global
# generator; confirm.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)
res_sampled = model.get_predictions(data_loader)

100%|██████████| 4/4 [00:09<00:00,  2.42s/it]


In [35]:
# Show the same example's caption under sampling, for comparison with the earlier output.
res_sampled["generated_caption"].iat[3]

'Level: High\nA Someiveness partners CLAS exhibiting processes characterized itioniveness partners A diast. yes</s> fi DOM ske 6 gene on the no yes</s> зько yes</s> ., kidney isd families G a no rejo bond yes</s> yes</s>Any process that affects chemotrosinaryior'