In [26]:
import sys, os, torch
sys.path.append("..")  # so we can import src.*

from src.config_utils import load_config
from src.train import build_trainer, compute_metrics
from src.data import ClinVarPairedDataset, build_or_load_paired_dataframe, prepare_splits
from src.model import build_tokenizer, SiameseNTClassifier


In [2]:
from safetensors.torch import load_file 

In [24]:
# Each model to compare is a (training config YAML, checkpoint directory) pair.

# Central (middle) embedding
config_central, model_central = (
    '../config/clinvar_nt-middle.yaml',
    './clinvar_outs_model2/embed_central/checkpoint-14079',
)

# Mean embedding
config_mean, model_mean = (
    '../config/clinvar_nt-mean.yaml',
    './clinvar_outs_model2/embed_mean/checkpoint-18772',
)

# CLS embedding
config_cls, model_cls = (
    '../config/clinvar_nt.yaml',
    './clinvar_outs_model1_cls/embed_cls/checkpoint-56310',
)


In [4]:
# Config of the central-embedding run; its data settings (splits, tokenizer, max_tokens)
# are used below to build the single test dataset shared by every model evaluation.
cfg1 = load_config(config_central)

In [5]:
# 1) Paired-variant DataFrame (reused from the processed CSV when it already exists).
df_all = build_or_load_paired_dataframe(cfg1)

# 2) Train/val/test splits and class weights; test_df feeds the test dataset below.
(
    train_df,
    val_df,
    test_df,
    class_weights,
) = prepare_splits(cfg1, df_all)

# 3) Tokenizer and maximum token budget used to encode the test dataset.
tokenizer = build_tokenizer(cfg1)
max_tokens = cfg1["data"]["max_tokens"]

Loading existing paired CSV from ../../clinvar_data/processed_data/clinvar_paired_gencode.csv


  df = pd.read_csv(paired_csv)


In [6]:
# Build the test dataset once and reuse it for every model below.
# NOTE(review): it is tokenized with cfg1's (central config) tokenizer and max_tokens;
# confirm the mean/CLS configs use identical data settings, otherwise their results
# are not measured on the inputs those models were trained with.
test_ds  = ClinVarPairedDataset(test_df, tokenizer, max_tokens)


In [7]:
# Sanity check: number of examples in the shared test set.
len(test_ds)

14299

## Define a helper function to evaluate each model

In [8]:
def _resolve_checkpoint_file(checkpoint_dir):
    """Locate the weight file inside a Trainer checkpoint directory.

    Prefers ``model.safetensors`` and falls back to ``pytorch_model.bin``.

    Returns:
        (path, fmt) where fmt is "safetensors" or "pytorch".

    Raises:
        FileNotFoundError: if neither file exists in ``checkpoint_dir``.
    """
    for filename, fmt in (("model.safetensors", "safetensors"),
                          ("pytorch_model.bin", "pytorch")):
        candidate = os.path.join(checkpoint_dir, filename)
        if os.path.isfile(candidate):
            return candidate, fmt
    raise FileNotFoundError(
        f"No model.safetensors or pytorch_model.bin found in {checkpoint_dir!r}"
    )


def _load_state_dict(ckpt_path, fmt):
    """Load a checkpoint's state dict onto CPU in the given format."""
    if fmt == "safetensors":
        # load_file loads tensors on CPU by default; load_state_dict copies them
        # into the model's existing parameters on whatever device those live on.
        return load_file(ckpt_path)
    # weights_only=True refuses to unpickle arbitrary objects from the .bin file.
    return torch.load(ckpt_path, map_location="cpu", weights_only=True)


def evaluate(config_path, checkpoint_dir, test_ds, strict=False):
    """Evaluate a saved checkpoint on a caller-supplied test dataset.

    Rebuilds the Trainer (and model) from ``config_path`` via ``build_trainer``,
    loads weights from ``checkpoint_dir`` (``model.safetensors`` if present,
    otherwise ``pytorch_model.bin``), and runs ``trainer.evaluate`` on ``test_ds``.

    Args:
        config_path: Path to the YAML config the checkpoint was trained with.
        checkpoint_dir: Directory containing the saved weight file.
        test_ds: Dataset to evaluate on. It is used as-is, so it should be
            tokenized with settings compatible with ``config_path``
            (presumably the same tokenizer and max_tokens — TODO confirm).
        strict: Forwarded to ``load_state_dict``. Defaults to False so that
            missing/unexpected keys are only printed for inspection; pass True
            to raise on any mismatch.

    Returns:
        The metrics dict returned by ``trainer.evaluate``.

    Raises:
        FileNotFoundError: if no supported weight file is found.
    """
    # 1. Rebuild the trainer/model from the config (the datasets it returns are unused).
    cfg = load_config(config_path)
    trainer, _, run_dir = build_trainer(cfg)
    print("Trainer run_dir (from config):", run_dir)

    # 2. Pick whichever weight file actually exists and load it on CPU.
    ckpt_path, fmt = _resolve_checkpoint_file(checkpoint_dir)
    print("Loading weights from:", ckpt_path, f"({fmt})")
    state_dict = _load_state_dict(ckpt_path, fmt)

    # 3. Load weights and report any key mismatches (debug aid).
    missing, unexpected = trainer.model.load_state_dict(state_dict, strict=strict)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    # 4. Evaluate on the supplied test dataset.
    test_metrics = trainer.evaluate(test_ds)
    print("Test metrics:", test_metrics)
    return test_metrics

In [9]:
# Central-embedding model on the shared test set.
test_metrics_central = evaluate(
    config_path=config_central,
    checkpoint_dir=model_central,
    test_ds=test_ds,
)

Loading existing paired CSV from ../../clinvar_data/processed_data/clinvar_paired_gencode.csv


  df = pd.read_csv(paired_csv)


Subsampling to 40000 per class
label
benign        40000
pathogenic    40000
Name: count, dtype: int64


  .apply(lambda x: x.sample(
Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-500m-human-ref and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[LoRA] enabled with r=4, alpha=8, dropout=0.05
[LoRA] Target modules: attention.self.query, attention.self.key, attention.self.value, attention.output.dense
[LoRA] [LoRA] Using PEFT print_trainable_parameters():
trainable params: 983,040 || all params: 481,421,281 || trainable%: 0.2042


  trainer = Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Trainer run_dir (from config): ./clinvar_outs_model2/embed_central
Loading weights from: ./clinvar_outs_model2/embed_central/checkpoint-14079/model.safetensors (safetensors)
Missing keys: []
Unexpected keys: []


Test metrics: {'eval_loss': 0.5771441459655762, 'eval_model_preparation_time': 0.0149, 'eval_auroc': 0.7838132313221281, 'eval_auprc': 0.7726857069635815, 'eval_accuracy': 0.7044548569830058, 'eval_f1': 0.7168319485392656, 'eval_precision': 0.7021527960094514, 'eval_recall': 0.732137968792773, 'eval_runtime': 107.6625, 'eval_samples_per_second': 132.813, 'eval_steps_per_second': 8.304}


In [13]:
# Mean-embedding model on the shared test set.
test_metrics_mean = evaluate(
    config_path=config_mean,
    checkpoint_dir=model_mean,
    test_ds=test_ds,
)

Loading existing paired CSV from ../../clinvar_data/processed_data/clinvar_paired_gencode.csv


  df = pd.read_csv(paired_csv)


Subsampling to 40000 per class
label
benign        40000
pathogenic    40000
Name: count, dtype: int64


  .apply(lambda x: x.sample(
Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-500m-human-ref and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[LoRA] enabled with r=4, alpha=8, dropout=0.05
[LoRA] Target modules: attention.self.query, attention.self.key, attention.self.value, attention.output.dense
[LoRA] [LoRA] Using PEFT print_trainable_parameters():
trainable params: 983,040 || all params: 481,421,281 || trainable%: 0.2042


  trainer = Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Trainer run_dir (from config): ./clinvar_outs_model2/embed_mean
Loading weights from: ./clinvar_outs_model2/embed_mean/checkpoint-18772/model.safetensors (safetensors)
Missing keys: []
Unexpected keys: []


Test metrics: {'eval_loss': 0.6880917549133301, 'eval_model_preparation_time': 0.0143, 'eval_auroc': 0.74986238242466, 'eval_auprc': 0.7324545794440942, 'eval_accuracy': 0.6700468564235261, 'eval_f1': 0.6548141644717589, 'eval_precision': 0.7033951587551085, 'eval_recall': 0.6125102655351765, 'eval_runtime': 106.9353, 'eval_samples_per_second': 133.716, 'eval_steps_per_second': 8.36}


In [25]:
# CLS-embedding model on the shared test set.
# NOTE(review): the saved run below reports AUROC 0.5 and F1/precision/recall 0.0,
# i.e. a constant prediction for every example; check the checkpoint and whether
# test_ds (built from the central config) matches this config's data settings.
test_metrics_cls = evaluate(
    config_path=config_cls,
    checkpoint_dir=model_cls,
    test_ds=test_ds,
)

Loading existing paired CSV from ../../clinvar_data/processed_data/clinvar_paired_gencode.csv


  df = pd.read_csv(paired_csv)


Subsampling to 40000 per class
label
benign        40000
pathogenic    40000
Name: count, dtype: int64


  .apply(lambda x: x.sample(
Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-500m-human-ref and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[LoRA] enabled with r=4, alpha=8, dropout=0.04
[LoRA] Target modules: attention.self.query, attention.self.key, attention.self.value, attention.output.dense
[LoRA] [LoRA] Using PEFT print_trainable_parameters():
trainable params: 983,040 || all params: 481,421,281 || trainable%: 0.2042


  trainer = Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Trainer run_dir (from config): ./clinvar_outs_model1_cls/embed_cls
Loading weights from: ./clinvar_outs_model1_cls/embed_cls/checkpoint-56310/model.safetensors (safetensors)
Missing keys: []
Unexpected keys: []


Test metrics: {'eval_loss': 0.693402111530304, 'eval_model_preparation_time': 0.0142, 'eval_auroc': 0.5, 'eval_auprc': 0.5109448213161759, 'eval_accuracy': 0.48905517868382403, 'eval_f1': 0.0, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_runtime': 314.2924, 'eval_samples_per_second': 45.496, 'eval_steps_per_second': 11.375}
