In [1]:
import os
import re
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import csv
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
    set_seed,
)

from cyvcf2 import VCF
from pyfaidx import Fasta

In [2]:
from peft import LoraConfig, get_peft_model
from transformers.trainer_callback import EarlyStoppingCallback, TrainerCallback 

In [3]:
# Reproducibility
SEED = 42
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)


In [4]:
DATA_PATH = './clinvar_data/'

In [5]:
REF_DIR = os.path.join(DATA_PATH,'refs/GRCh38_gencode')


In [6]:
DATA_DIR = os.path.join(DATA_PATH,"processed_data")
OUT_DIR = os.path.join(DATA_PATH,'clinvar_outs')
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

In [7]:
CLINVAR_VCF_PATH = os.path.join(DATA_PATH, "clinvar_GRCh38.vcf.gz")
GENCODE_FASTA_GZ = os.path.join(REF_DIR,"GRCh38.primary_assembly.genome.fa.gz")
GENCODE_FASTA = os.path.join(REF_DIR,"GRCh38.primary_assembly.genome.fa")

# Model configurations

In [8]:
# Task / model
MODEL_NAME   = "InstaDeepAI/nucleotide-transformer-500m-human-ref"
KMER         = 6
MAX_TOKENS   = 1000                 # 999 k-mers + CLS
SEQ_LEN      = 1004                 # → 999 k-mers for k=6 (L - k + 1)
# FLANK        = (SEQ_LEN - 1) // 2   # 
LEFT_FLANK   = SEQ_LEN // 2         # 501 on each side when L=1004 
RIGHT_FLANK  = SEQ_LEN - LEFT_FLANK - 1 # 502 bases to the left of SNV


In [9]:
POS_LABEL    = "pathogenic"
NEG_LABEL    = "benign"
USE_LABELS   = [NEG_LABEL, POS_LABEL]  # label mapping: benign->0, pathogenic->1

In [10]:
TRAIN_EPOCHS = 5
BATCH_SIZE   = 32
LR           = 5e-4

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8
USE_FP16 = (not USE_BF16) and torch.cuda.is_available()

In [12]:
device

device(type='cuda')

# Parse clinvar
### Benign/Pathogenic (+ optional Uncertain sample)

In [13]:
vcf = VCF(CLINVAR_VCF_PATH)

# clnsig = var.INFO.get("CLNSIG")

In [14]:
def parse_clinvar(vcf_path: str, sample_uncertain_k: int = 100_000):
    """Parse ClinVar VCF and bucket SNVs by clinical significance.
       keep *Pathogenic* and *Benign* **only**.
       Exclude Likely_* (both likely_pathogenic/likely_benign), Uncertain, Conflicting, and mixed cases.
    """
    vcf = VCF(vcf_path)
    benign, pathogenic, uncertain = [], [], []

    def normalize_sig(val: str) -> list[str]:
        # split on common separators and normalize
        toks = re.split(r"[\|;,/]+", val)
        out = []
        for t in toks:
            t = t.strip().lower().replace(" ", "_").replace("-", "_")
            if t:
                out.append(t)
        return out

    for var in tqdm(vcf, desc="Parsing ClinVar VCF (strict)"):
        # SNVs only
        if len(var.REF) != 1 or var.ALT is None or len(var.ALT) != 1 or any(len(a) != 1 for a in var.ALT):
            continue
        clnsig = var.INFO.get("CLNSIG")
        if clnsig is None:
            continue
        toks = set(normalize_sig(str(clnsig)))  

        # core flags
        has_path = "pathogenic" in toks
        has_ben  = "benign" in toks
        has_lpath = "likely_pathogenic" in toks
        has_lben  = "likely_benign" in toks
        has_unc   = "uncertain_significance" in toks
        has_conf  = "conflicting_interpretations_of_pathogenicity" in toks

        # strict filters — accept only pure Pathogenic or pure Benign
        if has_conf or has_unc or has_lpath or has_lben:
            continue  # strict mode drops these entirely

        # Mixed strict classes also dropped
        if has_path and has_ben:
            continue

        if has_path and not has_ben:
            pathogenic.append(var)
            continue
        if has_ben and not has_path:
            benign.append(var)
            continue
        # else: nothing to do (drop)

    print(f"Collected (STRICT): benign={len(benign)}, pathogenic={len(pathogenic)} | dropped all other labels")
    # uncertain list kept empty in strict mode; return signature unchanged
    return benign, pathogenic, uncertain

In [15]:
benign_vars, pathogenic_vars, uncertain_vars = parse_clinvar(CLINVAR_VCF_PATH)


Parsing ClinVar VCF (strict): 3683953it [00:31, 116580.67it/s]

Collected (STRICT): benign=181304, pathogenic=81598 | dropped all other labels





#  Reference genome (GENCODE) & contig name harmonization

### ClinVar GRCh38 VCF may use **RefSeq** contig names like `NC_000001.11`, whereas **GENCODE** uses Ensembl-style names `1..22, X, Y, MT`.
### The function below generates a set of candidate contig names to try when fetching sequences.

In [16]:
ref_genome = Fasta(GENCODE_FASTA)

In [17]:
ref_genome

Fasta("./clinvar_data/refs/GRCh38_gencode/GRCh38.primary_assembly.genome.fa")

In [18]:
# ref_genome.keys()

In [19]:
# vcf.seqnames[:10]

In [20]:
# Map RefSeq NC_* accessions to Ensembl-like chrom labels when possible
NC_TO_ENSEMBL_SPECIAL = {
    23: "X",
    24: "Y",
    12920: "MT",  # NC_012920.1 → mitochondrion
}


In [21]:
def contig_candidates(chrom: str) -> List[str]:
    """Return plausible contig name candidates across naming conventions.
    - Accepts inputs like 'chr1', '1', 'NC_000001.11', 'chrX', 'X', 'MT', 'chrM'
    - Returns unique candidates in priority order.
    """
    cands = []
    # As-is
    cands.append(chrom)

    # Strip/add 'chr'
    if chrom.startswith("chr"):
        cands.append(chrom[3:])
    else:
        cands.append("chr" + chrom)

    # Map NC_0000XX.yy → 1..22/X/Y/MT
    m = re.match(r"NC_(\d{6})\.(\d+)", chrom)
    if m:
        num = int(m.group(1))
        if 1 <= num <= 22:
            cands += [str(num), f"chr{num}"]
        elif num in NC_TO_ENSEMBL_SPECIAL:
            val = NC_TO_ENSEMBL_SPECIAL[num]
            cands += [val, f"chr{val}"]

    # Support chrM/chrMT → MT
    if chrom in ("chrM", "chrMT"):
        cands += ["MT"]

    # Deduplicate preserving order
    out = []
    for x in cands:
        if x not in out:
            out.append(x)
    return out

# Build paired windows (ref/alt) centered on SNVs

In [22]:
def extract_centered_ref_alt(chrom: str, pos_1based: int, ref: str, alt: str) -> Optional[Tuple[str, str]]:
   # Use LEFT_FLANK / RIGHT_FLANK for SEQ_LEN=1004
    fasta_contig = None
    for name in contig_candidates(chrom):
        if name in ref_genome:
            fasta_contig = name
            break
    if fasta_contig is None:
        return None

    # Skip unresolvable scaffolds early (NT_ / NW_ are not in primary assembly)
    if chrom.startswith(("NT_", "NW_")) and fasta_contig not in ref_genome:  
        return None

    snv0 = pos_1based - 1
    start = snv0 - LEFT_FLANK       
    end   = snv0 + RIGHT_FLANK + 1  
    if start < 0 or end > len(ref_genome[fasta_contig]):
        return None

    ref_window = ref_genome[fasta_contig][start:end].seq.upper()
    if len(ref_window) != SEQ_LEN:  # exact length check for 1004
        return None

    center_idx = LEFT_FLANK         #SNV sits here
   

    alt_window = ref_window[:center_idx] + alt.upper() + ref_window[center_idx + 1:]
    return ref_window, alt_window


In [23]:
def variants_to_dataframe(variants, label: str, max_n: Optional[int] = None) -> pd.DataFrame:
    rows = []
    it = variants if max_n is None else variants[:max_n]

    # debug counters
    skipped_no_contig = skipped_bounds = skipped_len = skipped_other = 0  

    for v in tqdm(it, desc=f"Building rows: {label}"):
        chrom, pos, ref, alts = v.CHROM, v.POS, v.REF, v.ALT
        if alts is None or len(alts) != 1:
            continue
        alt = alts[0]

        pair = extract_centered_ref_alt(chrom, pos, ref, alt)  
        if pair is None:
            fasta_contig = next((n for n in contig_candidates(chrom) if n in ref_genome), None)
            if fasta_contig is None:
                skipped_no_contig += 1  
            else:
                snv0 = pos - 1
                start = snv0 - LEFT_FLANK
                end   = snv0 + RIGHT_FLANK + 1
                if start < 0 or end > len(ref_genome[fasta_contig]):
                    skipped_bounds += 1  
                else:
                    # could be length or masked mismatch; we already turned off strict central-base
                    skipped_len += 1  # closest bucket
            continue

        ref_seq, alt_seq = pair
        if len(ref_seq) != SEQ_LEN or len(alt_seq) != SEQ_LEN:
            skipped_len += 1  
            continue

        rows.append({
            "chrom": chrom,
            "pos": pos,
            "ref": ref,
            "alt": alt,
            "label": label,
            "ref_seq": ref_seq,
            "alt_seq": alt_seq,
        })

    # print debug summary
    print(f"[{label}] kept={len(rows)} "
          f"| no_contig={skipped_no_contig} bounds={skipped_bounds} len_mismatch={skipped_len} other={skipped_other}")  
    return pd.DataFrame(rows)

In [24]:

####Sanity Test
print("Has chr1?", "chr1" in ref_genome)     # True
print("Has chrM?", "chrM" in ref_genome)     # True

v = benign_vars[0]
print("Example variant:", v.CHROM, v.POS, v.REF, v.ALT)
print("Candidates:", contig_candidates(v.CHROM))
print("Extractable?", extract_centered_ref_alt(v.CHROM, v.POS, v.REF, v.ALT[0]) is not None)


Has chr1? True
Has chrM? True
Example variant: 1 930165 G ['A']
Candidates: ['1', 'chr1']
Extractable? True


In [25]:
# Build dataframes
print("Extracting benign windows ...")
df_benign     = variants_to_dataframe(benign_vars, NEG_LABEL)
print("Extracting pathogenic windows ...")
df_pathogenic = variants_to_dataframe(pathogenic_vars, POS_LABEL)

Extracting benign windows ...


Building rows: benign: 100%|██████████| 181304/181304 [00:03<00:00, 57969.22it/s]


[benign] kept=180412 | no_contig=892 bounds=0 len_mismatch=0 other=0
Extracting pathogenic windows ...


Building rows: pathogenic: 100%|██████████| 81598/81598 [00:01<00:00, 53038.06it/s]


[pathogenic] kept=81554 | no_contig=44 bounds=0 len_mismatch=0 other=0


In [26]:
from collections import Counter

def pick_fasta_contig(chrom: str):
    for name in contig_candidates(chrom):
        if name in ref_genome:
            return name
    return None

def summarize_unresolved(vars_list, label):
    bad = []
    for v in vars_list:
        if pick_fasta_contig(v.CHROM) is None:
            bad.append(v.CHROM)
    c = Counter(bad)
    print(f"[{label}] unresolved contigs = {sum(c.values())} across {len(c)} names")
    for name, n in c.most_common(15):
        print(f"  {name}: {n}")

summarize_unresolved(benign_vars, "benign")
summarize_unresolved(pathogenic_vars, "pathogenic")

[benign] unresolved contigs = 892 across 2 names
  MT: 891
  NT_187693.1: 1
[pathogenic] unresolved contigs = 44 across 1 names
  MT: 44


In [27]:
df_benign.head()

Unnamed: 0,chrom,pos,ref,alt,label,ref_seq,alt_seq
0,1,930165,G,A,benign,CCCCAGGCCACAGGCAGATCCCAGGAGACACGCAGGGGCCCTAAGA...,CCCCAGGCCACAGGCAGATCCCAGGAGACACGCAGGGGCCCTAAGA...
1,1,930204,G,A,benign,CCTAAGAAGGGAGCTGGGAATGAGGGGCCACACAAGCCCGGGACGG...,CCTAAGAAGGGAGCTGGGAATGAGGGGCCACACAAGCCCGGGACGG...
2,1,930285,G,A,benign,CCTGGAGTTGGCCAGGACCCTCTAGCATCCTCAAGGGCTGGGCCAA...,CCTGGAGTTGGCCAGGACCCTCTAGCATCCTCAAGGGCTGGGCCAA...
3,1,930314,C,T,benign,CTCAAGGGCTGGGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAG...,CTCAAGGGCTGGGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAG...
4,1,930325,C,T,benign,GGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAGGGCTGAGCCAG...,GGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAGGGCTGAGCCAG...


In [28]:
df_pathogenic.head()

Unnamed: 0,chrom,pos,ref,alt,label,ref_seq,alt_seq
0,1,943995,C,T,pathogenic,AGTGCGCAGCAGGGACTGGACTGTGCACCCCACCTTTTTTTTTTTT...,AGTGCGCAGCAGGGACTGGACTGTGCACCCCACCTTTTTTTTTTTT...
1,1,1014143,C,T,pathogenic,CCAGCCAAGGTCTCCCAGGGGTGCAGGGAGAGCGGAGCTGCTCAGA...,CCAGCCAAGGTCTCCCAGGGGTGCAGGGAGAGCGGAGCTGCTCAGA...
2,1,1022368,C,A,pathogenic,TGTGGAAGGCAGGCACCCCAAGCCAGGTGGGCCCCCTTCCCAAATT...,TGTGGAAGGCAGGCACCCCAAGCCAGGTGGGCCCCCTTCCCAAATT...
3,1,1041582,C,T,pathogenic,GGCTGGGAGGGGCCTGGGGGGCGGAGCGGGGCGGGAGCGGGGCGGG...,GGCTGGGAGGGGCCTGGGGGGCGGAGCGGGGCGGGAGCGGGGCGGG...
4,1,1041975,C,A,pathogenic,CCCAGACCCCTGTCAGGGCGCCCTCCCTGACCCGAGCCGCAGCTGC...,CCCAGACCCCTGTCAGGGCGCCCTCCCTGACCCGAGCCGCAGCTGC...


In [29]:
print("Benign rows:", len(df_benign))
print("Pathogenic rows:", len(df_pathogenic))


Benign rows: 180412
Pathogenic rows: 81554


# Save the file 

In [30]:
DATA_DIR

'./clinvar_data/processed_data'

In [31]:
# Save paired dataset
paired_csv = os.path.join(DATA_DIR, "clinvar_paired_1004bp_gencode.csv")
# pd.concat([df_benign, df_pathogenic, df_uncertain], ignore_index=True).to_csv(paired_csv, index=False)
pd.concat([df_benign, df_pathogenic], ignore_index=True).to_csv(paired_csv, index=False)
print(f"Saved paired dataset: {paired_csv}")

Saved paired dataset: ./clinvar_data/processed_data/clinvar_paired_1004bp_gencode.csv


# Train/Val/Test split & Tokenizer (k=6)

In [32]:
df_all = pd.concat([df_benign, df_pathogenic], ignore_index=True)


In [33]:
label2id = {lbl: i for i, lbl in enumerate(USE_LABELS)}
id2label = {i: lbl for lbl, i in label2id.items()}

df_all["label_id"] = df_all["label"].map(label2id).astype(int)
print(df_all["label"].value_counts())

label
benign        180412
pathogenic     81554
Name: count, dtype: int64


In [34]:
df_all.head()

Unnamed: 0,chrom,pos,ref,alt,label,ref_seq,alt_seq,label_id
0,1,930165,G,A,benign,CCCCAGGCCACAGGCAGATCCCAGGAGACACGCAGGGGCCCTAAGA...,CCCCAGGCCACAGGCAGATCCCAGGAGACACGCAGGGGCCCTAAGA...,0
1,1,930204,G,A,benign,CCTAAGAAGGGAGCTGGGAATGAGGGGCCACACAAGCCCGGGACGG...,CCTAAGAAGGGAGCTGGGAATGAGGGGCCACACAAGCCCGGGACGG...,0
2,1,930285,G,A,benign,CCTGGAGTTGGCCAGGACCCTCTAGCATCCTCAAGGGCTGGGCCAA...,CCTGGAGTTGGCCAGGACCCTCTAGCATCCTCAAGGGCTGGGCCAA...,0
3,1,930314,C,T,benign,CTCAAGGGCTGGGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAG...,CTCAAGGGCTGGGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAG...,0
4,1,930325,C,T,benign,GGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAGGGCTGAGCCAG...,GGCCAACCAGGCTGGCGTGGGGTGGGGCAGGGGAGGGCTGAGCCAG...,0


In [35]:
df_all['chrom'].unique()

array(['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
       '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', 'X',
       'Y'], dtype=object)

In [36]:
df_all[df_all['chrom']=='X'].tail()

Unnamed: 0,chrom,pos,ref,alt,label,ref_seq,alt_seq,label_id
261933,X,155260942,G,T,pathogenic,TATAAGTGAGCTCAAATACCAAGGAACAGTTGTCCTGTCTGTAGGT...,TATAAGTGAGCTCAAATACCAAGGAACAGTTGTCCTGTCTGTAGGT...,1
261934,X,155264073,C,T,pathogenic,CATAGTAGATGATCACTGAATGTGTGCTGGTTGCAAGAAGGAATCA...,CATAGTAGATGATCACTGAATGTGTGCTGGTTGCAAGAAGGAATCA...,1
261935,X,155264101,C,T,pathogenic,GGTTGCAAGAAGGAATCAATCGTATCTATTTCATGTTAGAACATAT...,GGTTGCAAGAAGGAATCAATCGTATCTATTTCATGTTAGAACATAT...,1
261936,X,155264268,G,T,pathogenic,ACATAGGGACTGAGATTGATGAATTACAGATACAGAGATCAATTCC...,ACATAGGGACTGAGATTGATGAATTACAGATACAGAGATCAATTCC...,1
261937,X,155264287,A,G,pathogenic,TGAATTACAGATACAGAGATCAATTCCAGCTAGTTTGTTAGCCACC...,TGAATTACAGATACAGAGATCAATTCCAGCTAGTTTGTTAGCCACC...,1


In [37]:
def subsample(df_all,SUBSET_PER_CLASS = 40_000):
    # SUBSET_PER_CLASS = 20_000  # Just train this much sample for now
    df_sub = (
        df_all.groupby("label_id", group_keys=False)
              .apply(lambda x: x.sample(n=min(SUBSET_PER_CLASS, len(x)), random_state=SEED))
              .reset_index(drop=True)
    )
    print("Using subset per class:", SUBSET_PER_CLASS)
    print("Subset counts:", df_sub["label"].value_counts())
    return df_all


In [38]:
df_sub = subsample(df_all)

Using subset per class: 40000
Subset counts: label
benign        40000
pathogenic    40000
Name: count, dtype: int64


  .apply(lambda x: x.sample(n=min(SUBSET_PER_CLASS, len(x)), random_state=SEED))


In [39]:
df_sub['chrom'].value_counts()

chrom
1     21939
2     21189
17    16695
X     16137
11    14607
3     13820
12    13207
7     13090
5     12881
16    12399
19    12124
6     11980
9     11479
15     9761
10     9742
8      8837
4      8746
14     7341
22     5943
13     5821
20     5668
18     5097
21     3430
Y        33
Name: count, dtype: int64

In [40]:
def prepare_data(df_all, CHROM_SPLIT=True):
    if CHROM_SPLIT:
        ALL_CHROMS = sorted(df_all["chrom"].unique().tolist())
        print("Available CHROMs (VCF-style):", ALL_CHROMS)
    
        CHR_TEST = {"18", "21"}   #  holdout set
        CHR_VAL  = {"8"}          # small dev set
        CHR_TRAIN = set(ALL_CHROMS) - CHR_TEST - CHR_VAL

        def _by_chrom(df, chroms):
            return df[df["chrom"].isin(chroms)].reset_index(drop=True)

        train_df = _by_chrom(df_all, CHR_TRAIN)
        val_df   = _by_chrom(df_all, CHR_VAL)
        test_df  = _by_chrom(df_all, CHR_TEST)

        # sanity prints
        for name, d in [("train", train_df), ("val", val_df), ("test", test_df)]:
            uniq = sorted(d["chrom"].unique().tolist())
            counts = d["label"].value_counts().to_dict()
            print(f"{name}: n={len(d)} | chroms={uniq} | counts={counts}")

    else:
        # random stratified split
        train_df, temp_df = train_test_split(
            df_all, test_size=0.2, stratify=df_all["label_id"], random_state=SEED
        )
        val_df, test_df = train_test_split(
            temp_df, test_size=0.5, stratify=temp_df["label_id"], random_state=SEED
        )
        print("Splits:", len(train_df), len(val_df), len(test_df))
    return train_df, val_df,test_df

In [41]:
 train_df, val_df,test_df = prepare_data(df_sub)

Available CHROMs (VCF-style): ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '3', '4', '5', '6', '7', '8', '9', 'X', 'Y']
train: n=244602 | chroms=['1', '10', '11', '12', '13', '14', '15', '16', '17', '19', '2', '20', '22', '3', '4', '5', '6', '7', '9', 'X', 'Y'] | counts={'benign': 167808, 'pathogenic': 76794}
val: n=8837 | chroms=['8'] | counts={'benign': 6153, 'pathogenic': 2684}
test: n=8527 | chroms=['18', '21'] | counts={'benign': 6451, 'pathogenic': 2076}


# Tokenizer

In [42]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    kmer=KMER,
)

In [43]:
tokenizer

EsmTokenizer(name_or_path='InstaDeepAI/nucleotide-transformer-500m-human-ref', vocab_size=4107, model_max_length=1000, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [44]:
enc_test = tokenizer("ACGTN" * 250, padding="max_length", truncation=True, max_length=MAX_TOKENS)
print("Tokenizer test tokens:", len(enc_test["input_ids"]))


Tokenizer test tokens: 1000


#  Paired Dataset class

In [45]:
class ClinVarPairedDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer, max_tokens: int = 1000):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        row = self.df.iloc[idx]
        ref_seq = row["ref_seq"]
        alt_seq = row["alt_seq"]
        y = int(row["label_id"])

        ref_enc = self.tokenizer(
            ref_seq,
            padding="max_length",
            truncation=True,
            max_length=self.max_tokens,
            return_tensors="pt",
        )
        alt_enc = self.tokenizer(
            alt_seq,
            padding="max_length",
            truncation=True,
            max_length=self.max_tokens,
            return_tensors="pt",
        )
        item = {
            "ref_input_ids": ref_enc["input_ids"].squeeze(0),
            "ref_attention_mask": ref_enc["attention_mask"].squeeze(0),
            "alt_input_ids": alt_enc["input_ids"].squeeze(0),
            "alt_attention_mask": alt_enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(y, dtype=torch.long),
        }
        return item

In [46]:
from dataclasses import dataclass
from typing import Optional

# collate ref/alt pairs into batches
@dataclass
class DataCollatorSiamese:
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        batch = {}
        for key in ["ref_input_ids", "ref_attention_mask",
                    "alt_input_ids", "alt_attention_mask",
                    "labels"]:
            batch[key] = torch.stack([f[key] for f in features])
        return batch


In [47]:
train_ds = ClinVarPairedDataset(train_df, tokenizer, MAX_TOKENS)
val_ds   = ClinVarPairedDataset(val_df, tokenizer, MAX_TOKENS)
test_ds  = ClinVarPairedDataset(test_df, tokenizer, MAX_TOKENS)


In [48]:
# train_ds[0]

In [49]:
# ex = train_ds[0]
# for k, v in ex.items():
#     print(k, tuple(v.shape), v.dtype)
#     if "input_ids" in k:
#         print(" first 5 ids:", v[:5].tolist())
#     if "attention_mask" in k:
#         print(" sum(mask):", int(v.sum().item()))
# print("label id:", int(ex["labels"]))


 # Siamese NT model (shared encoder + fusion head)

In [50]:
counts = train_df["label_id"].value_counts().sort_index().to_numpy()
weights = (counts.sum() / (len(counts) * counts)).astype(np.float32)
print("Class counts:", counts, "-> weights:", weights)


Class counts: [167808  76794] -> weights: [0.7288151 1.5925853]


In [51]:
class SiameseNTClassifier(nn.Module):
    def __init__(self, base_model_name: str, num_labels: int, class_weights: Optional[np.ndarray] = None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name, trust_remote_code=True)
        hidden = self.encoder.config.hidden_size if hasattr(self.encoder, "config") else 768
        fuse_dim = hidden * 4  # [ref, alt, alt-ref, ref*alt]
        self.classifier = nn.Sequential(
            nn.Linear(fuse_dim, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_labels),
        )
        if class_weights is not None:
            self.register_buffer("class_weights", torch.tensor(class_weights, dtype=torch.float))
        else:
            self.class_weights = None

    @staticmethod
    def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(
        self,
        ref_input_ids=None,
        ref_attention_mask=None,
        alt_input_ids=None,
        alt_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        
        kwargs.pop("num_items_in_batch", None)   
        ref_out = self.encoder(input_ids=ref_input_ids, attention_mask=ref_attention_mask)  # ref
        alt_out = self.encoder(input_ids=alt_input_ids, attention_mask=alt_attention_mask)  # alt

        ref_last = ref_out.last_hidden_state
        alt_last = alt_out.last_hidden_state

        ref_repr = self.masked_mean_pool(ref_last, ref_attention_mask)
        alt_repr = self.masked_mean_pool(alt_last, alt_attention_mask)

        diff = alt_repr - ref_repr
        prod = alt_repr * ref_repr
        feat = torch.cat([ref_repr, alt_repr, diff, prod], dim=-1)

        logits = self.classifier(feat)
        loss = None
        if labels is not None:
            if hasattr(self, "class_weights") and self.class_weights is not None:
                loss_fn = nn.CrossEntropyLoss(weight=self.class_weights)
            else:
                loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

In [52]:
model = SiameseNTClassifier(MODEL_NAME, num_labels=2, class_weights=weights)


Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-500m-human-ref and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [53]:
model

SiameseNTClassifier(
  (encoder): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(4105, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1002, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-23): 24 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-12, elementwise_affine=True)
          )
     

# Using LORA

In [54]:
# Where to use LORA
tm = [ 
    "attention.self.query",
    "attention.self.key", 
    "attention.self.value",
    "attention.output.dense"
]

In [55]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=tm
)

In [56]:
model.encoder = get_peft_model(model.encoder, lora_config) 

In [57]:
 model.encoder.print_trainable_parameters() 

trainable params: 1,966,080 || all params: 482,404,321 || trainable%: 0.4076


In [58]:
def _summarize_params(m):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    non_trainable = total - trainable
    print(f"Total params: {total:,}")
    print(f"Trainable params: {trainable:,} ({trainable/total:.2%})")
    print(f"Non-trainable params: {non_trainable:,}")

In [59]:
_summarize_params(model)

Total params: 488,961,763
Trainable params: 8,523,522 (1.74%)
Non-trainable params: 480,438,241


In [60]:
model = model.to(device) #### TRANSFER TO GPU


In [61]:
# # Speed/memory-friendly flags (A100/H100 benefit; V100 ignores TF32)
# torch.backends.cuda.matmul.allow_tf32 = True
# try:
#     torch.set_float32_matmul_precision("high")
# except Exception:
#     pass

# # Enable input grads + non-reentrant GC (better with PyTorch ≥2.0)
# if hasattr(model, "enable_input_require_grads"):
#     model.enable_input_require_grads()
# if hasattr(model.encoder, "gradient_checkpointing_enable"):
#     try:
#         model.encoder.gradient_checkpointing_enable(
#             gradient_checkpointing_kwargs={"use_reentrant": False}
#         )
#     except TypeError:
#         model.encoder.gradient_checkpointing_enable()

# # Disable KV cache during training to save memory
# if hasattr(model.encoder, "config"):
#     model.encoder.config.use_cache = False


In [62]:


# speed/memory-friendly flags
torch.backends.cuda.matmul.allow_tf32 = True  #  speedup on V100, no accuracy change
torch.set_float32_matmul_precision("high")    #  PyTorch >= 2.0

#enable gradient checkpointing to trade compute for memory
if hasattr(model.encoder, "gradient_checkpointing_enable"):
    model.encoder.gradient_checkpointing_enable()  # 
# some HF models read this flag to disable key/value caching during training
if hasattr(model.encoder, "config"):
    model.encoder.config.use_cache = False  # 

# Metrics

In [63]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()
    preds = probs.argmax(axis=1)

    pos_id = 1  # pathogenic
    pos_probs = probs[:, pos_id]

    out = {}
    try:
        out["auroc"] = roc_auc_score(labels, pos_probs)
    except ValueError:
        out["auroc"] = float("nan")
    try:
        out["auprc"] = average_precision_score(labels, pos_probs)
    except ValueError:
        out["auprc"] = float("nan")

    out["accuracy"]  = accuracy_score(labels, preds)
    out["f1"]        = f1_score(labels, preds, zero_division=0)
    out["precision"] = precision_score(labels, preds, zero_division=0)
    out["recall"]    = recall_score(labels, preds, zero_division=0)
    return out


# Training Arguments(Huggingface)

In [64]:
args = TrainingArguments(
    output_dir=OUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LR,
    num_train_epochs=TRAIN_EPOCHS,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,             # keep only the 2 most recent/best checkpoints
    save_safetensors=True,          # use .safetensors when possible
    logging_strategy="steps",       # log by step
    logging_steps=100,
    logging_first_step=True,        # also log at step 1
    load_best_model_at_end=True,
    metric_for_best_model="auroc",
    greater_is_better=True,
    
    #data loading optimizations
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
    
    fp16=USE_FP16,
    bf16=USE_BF16,
    report_to=["tensorboard"],      # write TB event files
    logging_dir=os.path.join(OUT_DIR, "runs"),  # TB log dir
)

# Loggging

In [65]:
from transformers.utils import logging as hf_logging  
hf_logging.set_verbosity_info()                       
hf_logging.enable_default_handler()                   
hf_logging.enable_explicit_format()                   


In [66]:
class CSVLoggerCallback(TrainerCallback):  
    """ Write clean CSV logs at each logging/eval step.
    - Adds a 'split' column: 'train' or 'eval'
    - Skips noisy keys ('epoch' duplicate, runtime counters)
    - Rounds epoch to 6 decimals; filters NaN/Inf values
    """
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        if not os.path.exists(csv_path):
            with open(csv_path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["step", "epoch", "split", "metric", "value"])  

    @staticmethod
    def _should_skip_key(k: str) -> bool:
        skip = {
            "epoch", "total_flos", "train_runtime", "train_samples_per_second",
            "train_steps_per_second", "total_loss"
        }
        return (k in skip) or k.startswith("_")

    @staticmethod
    def _is_finite_number(v) -> bool:
        return isinstance(v, (int, float)) and math.isfinite(v)

    def _write_rows(self, step: int, epoch: float | str, split: str, logs: dict):
        with open(self.csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            ep = "" if epoch == "" else round(float(epoch), 6)
            for k, v in logs.items():
                if self._should_skip_key(k):
                    continue
                if not self._is_finite_number(v):
                    continue
                # Strip eval_ prefix for eval metrics
                name = k[5:] if split == "eval" and k.startswith("eval_") else k
                writer.writerow([step, ep, split, name, v])

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        step = int(state.global_step)
        epoch = state.epoch if state.epoch is not None else ""
        # Detect eval vs train by presence of any eval_* keys
        is_eval = any(k.startswith("eval_") for k in logs.keys())
        split = "eval" if is_eval else "train"
        self._write_rows(step, epoch, split, logs)

    # Also capture metrics coming via on_evaluate explicitly (belt & suspenders)
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return
        step = int(state.global_step)
        epoch = state.epoch if state.epoch is not None else ""
        self._write_rows(step, epoch, "eval", metrics)

In [67]:
OUT_DIR


'./clinvar_data/clinvar_outs'

In [68]:
log_csv_path = os.path.join(OUT_DIR, "logs", "train_log.csv")  # log path


In [69]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorSiamese(),   
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=1e-4),
        CSVLoggerCallback(log_csv_path),  #  write CSV logs
    ],
    
)

  trainer = Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
[INFO|trainer.py:748] 2025-10-19 22:55:39,635 >> Using auto half precision backend


In [70]:
train_result = trainer.train()
print(train_result)

[INFO|trainer.py:2414] 2025-10-19 22:55:40,175 >> ***** Running training *****
[INFO|trainer.py:2415] 2025-10-19 22:55:40,176 >>   Num examples = 244,602
[INFO|trainer.py:2416] 2025-10-19 22:55:40,177 >>   Num Epochs = 5
[INFO|trainer.py:2417] 2025-10-19 22:55:40,177 >>   Instantaneous batch size per device = 32
[INFO|trainer.py:2420] 2025-10-19 22:55:40,178 >>   Total train batch size (w. parallel, distributed & accumulation) = 32
[INFO|trainer.py:2421] 2025-10-19 22:55:40,178 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2422] 2025-10-19 22:55:40,179 >>   Total optimization steps = 38,220
[INFO|trainer.py:2423] 2025-10-19 22:55:40,183 >>   Number of trainable parameters = 8,523,522


Epoch,Training Loss,Validation Loss,Auroc,Auprc,Accuracy,F1,Precision,Recall
1,0.6639,0.687445,0.676441,0.462623,0.525744,0.513974,0.373127,0.825633
2,0.6522,0.634084,0.709385,0.531401,0.636189,0.528247,0.43573,0.670641
3,0.6333,0.693286,0.710216,0.523027,0.546452,0.525118,0.38499,0.825633
4,0.6332,0.661546,0.714985,0.541971,0.603599,0.531497,0.414563,0.740313
5,0.6127,0.615235,0.716591,0.54297,0.669797,0.53535,0.467464,0.626304


[INFO|trainer.py:4307] 2025-10-20 01:37:29,416 >> 
***** Running Evaluation *****
[INFO|trainer.py:4309] 2025-10-20 01:37:29,417 >>   Num examples = 8837
[INFO|trainer.py:4312] 2025-10-20 01:37:29,418 >>   Batch size = 32
[INFO|trainer.py:3984] 2025-10-20 01:43:05,374 >> Saving model checkpoint to ./clinvar_data/clinvar_outs/checkpoint-7644
[INFO|trainer.py:3998] 2025-10-20 01:43:05,385 >> Trainer.model is not a `PreTrainedModel`, only saving its state dict.
[INFO|tokenization_utils_base.py:2510] 2025-10-20 01:43:08,421 >> tokenizer config file saved in ./clinvar_data/clinvar_outs/checkpoint-7644/tokenizer_config.json
[INFO|tokenization_utils_base.py:2519] 2025-10-20 01:43:08,425 >> Special tokens file saved in ./clinvar_data/clinvar_outs/checkpoint-7644/special_tokens_map.json
[INFO|trainer.py:4307] 2025-10-20 04:25:00,675 >> 
***** Running Evaluation *****
[INFO|trainer.py:4309] 2025-10-20 04:25:00,677 >>   Num examples = 8837
[INFO|trainer.py:4312] 2025-10-20 04:25:00,677 >>   Batch

TrainOutput(global_step=38220, training_loss=0.6425233478760731, metrics={'train_runtime': 50260.4506, 'train_samples_per_second': 24.333, 'train_steps_per_second': 0.76, 'total_flos': 0.0, 'train_loss': 0.6425233478760731, 'epoch': 5.0})


In [None]:
train_result2 = trainer.train()
print(train_result2)

[INFO|trainer.py:2414] 2025-10-20 14:53:48,347 >> ***** Running training *****
[INFO|trainer.py:2415] 2025-10-20 14:53:48,348 >>   Num examples = 244,602
[INFO|trainer.py:2416] 2025-10-20 14:53:48,349 >>   Num Epochs = 5
[INFO|trainer.py:2417] 2025-10-20 14:53:48,350 >>   Instantaneous batch size per device = 32
[INFO|trainer.py:2420] 2025-10-20 14:53:48,351 >>   Total train batch size (w. parallel, distributed & accumulation) = 32
[INFO|trainer.py:2421] 2025-10-20 14:53:48,351 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2422] 2025-10-20 14:53:48,352 >>   Total optimization steps = 38,220
[INFO|trainer.py:2423] 2025-10-20 14:53:48,357 >>   Number of trainable parameters = 8,523,522


Epoch,Training Loss,Validation Loss


# Evaluate and save

In [71]:
print("Validation:", trainer.evaluate(eval_dataset=val_ds))
print("Test:", trainer.evaluate(eval_dataset=test_ds))

#  determine best checkpoint dir (fallback to OUT_DIR if none)
best_dir = trainer.state.best_model_checkpoint or OUT_DIR  # 
print(f" Best checkpoint: {best_dir}")  #

# save full trainer model snapshot into best_dir (will include classifier; PEFT may store adapters)
trainer.save_model(best_dir)  

#  save tokenizer alongside best checkpoint
tokenizer.save_pretrained(best_dir)  

# if using PEFT/LoRA, also save adapters explicitly
try:
    adapter_dir = os.path.join(best_dir, "lora_adapter")
    os.makedirs(adapter_dir, exist_ok=True)
    model.encoder.save_pretrained(adapter_dir)  # PEFT: writes adapter config + weights
    # Save classifier head separately for redundancy
    torch.save(model.classifier.state_dict(), os.path.join(best_dir, "classifier.pt"))
    print(f" Saved LoRA adapter to: {adapter_dir}")
except Exception as e:
    print(f" PEFT adapter save skipped/failed: {e}")

print(f" Best model & tokenizer saved to: {best_dir}")  

[INFO|trainer.py:4307] 2025-10-20 12:53:20,664 >> 
***** Running Evaluation *****
[INFO|trainer.py:4309] 2025-10-20 12:53:20,665 >>   Num examples = 8837
[INFO|trainer.py:4312] 2025-10-20 12:53:20,666 >>   Batch size = 32


[INFO|trainer.py:4307] 2025-10-20 12:58:56,735 >> 
***** Running Evaluation *****
[INFO|trainer.py:4309] 2025-10-20 12:58:56,737 >>   Num examples = 8527
[INFO|trainer.py:4312] 2025-10-20 12:58:56,737 >>   Batch size = 32


Validation: {'eval_loss': 0.6152347326278687, 'eval_auroc': 0.716590576658836, 'eval_auprc': 0.5429701473194328, 'eval_accuracy': 0.6697974425710083, 'eval_f1': 0.5353503184713376, 'eval_precision': 0.4674638487208009, 'eval_recall': 0.6263040238450075, 'eval_runtime': 336.0447, 'eval_samples_per_second': 26.297, 'eval_steps_per_second': 0.824, 'epoch': 5.0}


[INFO|trainer.py:3984] 2025-10-20 13:04:21,052 >> Saving model checkpoint to ./clinvar_data/clinvar_outs/checkpoint-38220
[INFO|trainer.py:3998] 2025-10-20 13:04:21,066 >> Trainer.model is not a `PreTrainedModel`, only saving its state dict.


Test: {'eval_loss': 0.6093655824661255, 'eval_auroc': 0.69308297558981, 'eval_auprc': 0.4528098351004526, 'eval_accuracy': 0.6735076814823502, 'eval_f1': 0.4583657587548638, 'eval_precision': 0.38446475195822455, 'eval_recall': 0.5674373795761078, 'eval_runtime': 324.302, 'eval_samples_per_second': 26.293, 'eval_steps_per_second': 0.823, 'epoch': 5.0}
 Best checkpoint: ./clinvar_data/clinvar_outs/checkpoint-38220


[INFO|tokenization_utils_base.py:2510] 2025-10-20 13:04:25,617 >> tokenizer config file saved in ./clinvar_data/clinvar_outs/checkpoint-38220/tokenizer_config.json
[INFO|tokenization_utils_base.py:2519] 2025-10-20 13:04:25,621 >> Special tokens file saved in ./clinvar_data/clinvar_outs/checkpoint-38220/special_tokens_map.json
[INFO|tokenization_utils_base.py:2510] 2025-10-20 13:04:25,632 >> tokenizer config file saved in ./clinvar_data/clinvar_outs/checkpoint-38220/tokenizer_config.json
[INFO|tokenization_utils_base.py:2519] 2025-10-20 13:04:25,635 >> Special tokens file saved in ./clinvar_data/clinvar_outs/checkpoint-38220/special_tokens_map.json
[INFO|configuration_utils.py:693] 2025-10-20 13:04:25,824 >> loading configuration file config.json from cache at /home/nkhat001/.cache/huggingface/hub/models--InstaDeepAI--nucleotide-transformer-500m-human-ref/snapshots/f87b5d7233295242e79c951873d290f4cf992045/config.json
[INFO|configuration_utils.py:765] 2025-10-20 13:04:25,827 >> Model con

 Saved LoRA adapter to: ./clinvar_data/clinvar_outs/checkpoint-38220/lora_adapter
 Best model & tokenizer saved to: ./clinvar_data/clinvar_outs/checkpoint-38220


# Reload helper (Best checkpoint)

In [72]:
def load_best_siamese(checkpoint_dir: str) -> SiameseNTClassifier:
    base = SiameseNTClassifier(MODEL_NAME, num_labels=2, class_weights=None)
    # Load LoRA adapter if present
    adapter_path = os.path.join(checkpoint_dir, "lora_adapter")
    if os.path.isdir(adapter_path):
        from peft import PeftModel
        base.encoder = PeftModel.from_pretrained(base.encoder, adapter_path)
    # Load the whole-state dict fallback if available (from trainer.save_model)
    state_dict_path_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
    state_dict_path_sft = os.path.join(checkpoint_dir, "model.safetensors")
    if os.path.exists(state_dict_path_sft):
        from safetensors.torch import load_file
        sd = load_file(state_dict_path_sft)
        base.load_state_dict(sd, strict=False)
    elif os.path.exists(state_dict_path_bin):
        sd = torch.load(state_dict_path_bin, map_location="cpu")
        base.load_state_dict(sd, strict=False)
    else:
        # load classifier if we saved it
        clf_path = os.path.join(checkpoint_dir, "classifier.pt")
        if os.path.exists(clf_path):
            base.classifier.load_state_dict(torch.load(clf_path, map_location="cpu"))
    return base

# Inference helper (paired ref/alt)

In [73]:
@torch.no_grad()
def predict_ref_alt(ref_seqs: List[str], alt_seqs: List[str]) -> List[Dict[str, float]]:
    assert len(ref_seqs) == len(alt_seqs)
    ref_enc = tokenizer(ref_seqs, padding=True, truncation=True, max_length=MAX_TOKENS, return_tensors="pt")
    alt_enc = tokenizer(alt_seqs, padding=True, truncation=True, max_length=MAX_TOKENS, return_tensors="pt")
    ref_enc = {k: v.to(device) for k, v in ref_enc.items()}
    alt_enc = {k: v.to(device) for k, v in alt_enc.items()}
    out = model(
        ref_input_ids=ref_enc["input_ids"],
        ref_attention_mask=ref_enc["attention_mask"],
        alt_input_ids=alt_enc["input_ids"],
        alt_attention_mask=alt_enc["attention_mask"],
    )
    probs = torch.softmax(out["logits"], dim=-1).cpu().numpy()
    return [{id2label[i]: float(p) for i, p in enumerate(row)} for row in probs]


In [74]:
preds = predict_ref_alt(test_df["ref_seq"].head(10).tolist(), test_df["alt_seq"].head(10).tolist())

In [75]:
preds

[{'benign': 0.41584742069244385, 'pathogenic': 0.5841525197029114},
 {'benign': 0.5528414845466614, 'pathogenic': 0.4471585154533386},
 {'benign': 0.5188356637954712, 'pathogenic': 0.4811643064022064},
 {'benign': 0.6111499667167664, 'pathogenic': 0.38885006308555603},
 {'benign': 0.4409506916999817, 'pathogenic': 0.5590493679046631},
 {'benign': 0.6813028454780579, 'pathogenic': 0.31869715452194214},
 {'benign': 0.6361271142959595, 'pathogenic': 0.36387285590171814},
 {'benign': 0.35309356451034546, 'pathogenic': 0.6469064354896545},
 {'benign': 0.6740504503250122, 'pathogenic': 0.3259495198726654},
 {'benign': 0.45314082503318787, 'pathogenic': 0.5468591451644897}]