In [None]:
!pip install -q "transformers>=4.44.0" "datasets>=2.20.0" "accelerate>=0.34.0" \
              "bitsandbytes>=0.43.1" "peft>=0.12.0" "trl>=0.9.6" "scikit-learn"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.5/465.5 kB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import transformers
transformers.__version__

'4.44.0'

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

In [None]:
dataset = load_dataset("araag2/MedNLI", "processed")
dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

processed/train-00000-of-00001.parquet:   0%|          | 0.00/637k [00:00<?, ?B/s]

processed/dev-00000-of-00001.parquet:   0%|          | 0.00/84.2k [00:00<?, ?B/s]

processed/test-00000-of-00001.parquet:   0%|          | 0.00/83.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11232 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/1395 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1422 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 11232
    })
    dev: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 1395
    })
    test: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 1422
    })
})

In [None]:
train_ds = dataset["train"]
val_ds   = dataset["dev"]
test_ds  = dataset["test"]

In [None]:
train_ds[3]

{'id': '23eb9ba2-66c7-11e7-9ac1-f45c89b91419',
 'Label': 'entailment',
 'Premise': 'Nystagmus and twiching of R arm was noted.',
 'Hypothesis': ' The patient had abnormal neuro exam.'}

### Baseline Model

### Label Mapping

In [None]:
# A clean way to switch back and forth between text labels and numeric labels.

label_list = ["entailment", "neutral", "contradiction"]
label2id = {l: i for i, l in enumerate(label_list)} # from labels to numbers
id2label = {i: l for l, i in label2id.items()} # from numbers to labels

### Preprocessing

In [None]:
# Tokenizing the data and making it ready for the model

from transformers import AutoTokenizer # loads the correct tokenizer based on the model given

encoder_model_id = "emilyalsentzer/Bio_ClinicalBERT" # medical-domain BERT. 12 transformer layers, ~110 million parameters, WordPiece vocabulary
tokenizer = AutoTokenizer.from_pretrained(encoder_model_id, use_fast=True)

def preprocess_fn(example):
    enc = tokenizer(
        example["Premise"],
        example["Hypothesis"],
        truncation=True,
        max_length=256,
    )
    enc["labels"] = label2id[example["Label"]]
    return enc

tokenized_ds = dataset.map(preprocess_fn, batched=False)

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/11232 [00:00<?, ? examples/s]

Map:   0%|          | 0/1395 [00:00<?, ? examples/s]

Map:   0%|          | 0/1422 [00:00<?, ? examples/s]

In [None]:
from transformers import AutoModelForSequenceClassification, Trainer
from transformers import TrainingArguments as HFTrainingArguments
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# from_pretrained loads the model architecture (BioClinicalBERT), the pretrained weights & then adds a new classification head on top for the task
encoder_model = AutoModelForSequenceClassification.from_pretrained(
    encoder_model_id,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# How the model will be evaluated
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average="macro")
    return {"accuracy": acc, "macro_f1": macro_f1}

training_args = HFTrainingArguments(
    output_dir="bio_clinicalbert_mednli",
    per_device_train_batch_size=16, # how many examples the model processes in one step before updating weights.
    per_device_eval_batch_size=32, # During evaluation, we can use a bigger batch size because - We are NOT storing gradients, It’s only forward pass, Faster evaluation
    num_train_epochs=3,
    learning_rate=2e-5,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    load_best_model_at_end=True, # gives the model checkpoint where the metric was best.
    metric_for_best_model="macro_f1", # Macro-F1 cares about all classes
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# The manager that controls and coordinates the whole training process. A built-in training loop for models in Hugging Face.

trainer = Trainer(
    model=encoder_model,
    args=training_args,
    train_dataset=tokenized_ds["train"], # training data after tokenization.
    eval_dataset=tokenized_ds["dev"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.5856,0.49276,0.805735,0.807062
2,0.4226,0.46733,0.82509,0.825972
3,0.2904,0.486456,0.835842,0.836075


TrainOutput(global_step=2106, training_loss=0.4783052267172398, metrics={'train_runtime': 107.2939, 'train_samples_per_second': 314.053, 'train_steps_per_second': 19.628, 'total_flos': 1645568870531040.0, 'train_loss': 0.4783052267172398, 'epoch': 3.0})

In [None]:
metrics = trainer.evaluate(tokenized_ds["test"])
metrics

{'eval_loss': 0.5323934555053711,
 'eval_accuracy': 0.8171589310829818,
 'eval_macro_f1': 0.8168678106706434,
 'eval_runtime': 0.8582,
 'eval_samples_per_second': 1656.943,
 'eval_steps_per_second': 52.435,
 'epoch': 3.0}

In [None]:
trainer.save_model("baseline_bioclinicalbert_mednli")
tokenizer.save_pretrained("baseline_bioclinicalbert_mednli")

('baseline_bioclinicalbert_mednli/tokenizer_config.json',
 'baseline_bioclinicalbert_mednli/special_tokens_map.json',
 'baseline_bioclinicalbert_mednli/vocab.txt',
 'baseline_bioclinicalbert_mednli/added_tokens.json',
 'baseline_bioclinicalbert_mednli/tokenizer.json')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
save_path = "/content/drive/MyDrive/baseline_bioclinicalbert_mednli"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

('/content/drive/MyDrive/baseline_bioclinicalbert_mednli/tokenizer_config.json',
 '/content/drive/MyDrive/baseline_bioclinicalbert_mednli/special_tokens_map.json',
 '/content/drive/MyDrive/baseline_bioclinicalbert_mednli/vocab.txt',
 '/content/drive/MyDrive/baseline_bioclinicalbert_mednli/added_tokens.json',
 '/content/drive/MyDrive/baseline_bioclinicalbert_mednli/tokenizer.json')

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_path = "/content/drive/MyDrive/baseline_bioclinicalbert_mednli"   # or your saved folder name

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
import torch
import numpy as np

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}

def classify_nli(premise, hypothesis):
    # Prepare input for the model
    inputs = tokenizer(
        premise,
        hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Run model
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get label
    pred_id = int(torch.argmax(logits, dim=-1))
    pred_label = id2label[pred_id]

    return pred_label

In [None]:
premise = "To investigate the efficacy of 6 weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at 12 weeks in older adults with moderate to severe knee osteoarthritis ( OA ) . A total of 125 patients with primary knee OA were randomized 1:1 ; 63 received 7.5 mg/day of prednisolone and 62 received placebo for 6 weeks . Outcome measures included pain reduction and improvement in function scores and systemic inflammation markers . Pain was assessed using the visual analog pain scale ( 0-100 mm ) . Secondary outcome measures included the Western Ontario and McMaster Universities Osteoarthritis Index scores , patient global assessment ( PGA ) of the severity of knee OA , and 6-min walk distance ( 6MWD ) . Serum levels of interleukin 1 ( IL-1 ) , IL-6 , tumor necrosis factor ( TNF ) - , and high-sensitivity C-reactive protein ( hsCRP ) were measured . There was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , PGA , and 6MWD at 6 weeks . The mean difference between treatment arms ( 95 % CI ) was 10.9 ( 4.8-18 .0 ) , p < 0.001 ; 9.5 ( 3.7-15 .4 ) , p < 0.05 ; 15.7 ( 5.3-26 .1 ) , p < 0.001 ; and 86.9 ( 29.8-144 .1 ) , p < 0.05 , respectively . Further , there was a clinically relevant reduction in the serum levels of IL-1 , IL-6 , TNF - , and hsCRP at 6 weeks in the intervention group when compared to the placebo group . These differences remained significant at 12 weeks . The Outcome Measures in Rheumatology Clinical Trials-Osteoarthritis Research Society International responder rate was 65 % in the intervention group and 34 % in the placebo group ( p < 0.05 ) . Low-dose oral prednisolone had both a short-term and a longer sustained effect resulting in less knee pain , better physical function , and attenuation of systemic inflammation in older patients with knee OA"
hypothesis = "These results suggest a potential pathway for primary care practices to implement systematic training programs in MI, thereby enhancing the overall effectiveness of weight management strategies. Looking ahead, scaling this intervention could transform childhood obesity treatment by fostering collaborative healthcare environments where parents receive comprehensive support, ultimately leading to healthier lifestyles and improved long-term health trajectories for children."
print(classify_nli(premise, hypothesis))

entailment


In [None]:
import torch.nn.functional as F

def classify_with_probs(premise, hypothesis):
    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=-1).squeeze().tolist()
    pred_id = int(np.argmax(probs))
    return id2label[pred_id], probs


In [None]:
premise = "Multiple sclerosis (MS) progresses through brain region-specific inflammation and degeneration, with poorly defined mechanisms. In individuals with MS, we identified increased expression of formyl peptide receptor 1 (FPR1) in central nervous system (CNS)-resident microglia and CNS-infiltrating macrophages. Blood amounts of N-formylated peptides, which are endogenous agonists of FPR1, correlated with disease progression in patients with MS. In MS mouse models, signaling through FPR1 promoted microglial mitochondrial dysfunction, causing axonal loss and apoptosis. FPR1-expressing microglia sustained the clonal expansion of myelin-reactive CD4+ T cells in the CNS. A CNS-penetrating small molecule FPR1 antagonist, T0080, mitigated autoimmune responses and axonal degeneration. Our study identifies FPR1 signaling as a potential mechanism for MS progression and suggests antagonizing FPR1 as a therapeutic approach."
hypothesis = "The paper reported that the small-molecule antagonist T0080 can fully reverse MS after a single oral dose in humans."
print(classify_with_probs(premise, hypothesis))

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


('entailment', [0.7074811458587646, 0.06522099673748016, 0.2272978574037552])


### BioMistral-7B + QLoRA on MedNLI

In [None]:
# AutoTokenizer → converts text → tokens → numbers the model understands.
# AutoModelForCausalLM → loads a decoder-only model for text generation.
# BitsAndBytesConfig → tells Transformers to load the model in 4-bit (QLoRA requirement).

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

llm_id = "BioMistral/BioMistral-7B" # load this model weights

# Enables QLoRA. If we tried to fine-tune all 7 billion parameters of BioMistral - It would not fit in Colab easily & Training would be slow and expensive.
# QLoRA solves this by loading the base model in 4-bit (very compressed, low memory).
# Freezing the original weights (we don’t change them).
# Adding small trainable adapter layers (LoRA weights) on top.

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
)

# Downloads the correct tokenizer for BioMistral.
llm_tokenizer = AutoTokenizer.from_pretrained(llm_id)

# For Mistral-style models, often there is no explicit pad token → use eos as pad. Training with batches requires a pad token.
if llm_tokenizer.pad_token is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

# Load the model in 4-bit
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_id,
    quantization_config=quant_config,
    device_map="auto",
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [None]:
# BioMistral-7B is a text generator, not a classifier.
# So we need to convert every MedNLI example into a single text string that looks like a prompt:

def format_example(example):
    premise = example["Premise"]
    hypothesis = example["Hypothesis"]
    label = example["Label"]  # already a string like "entailment"
    text = (
        "### Instruction:\n"
        "You are a medical NLI classifier. Given a medical premise and hypothesis, "
        "answer with one of: entailment, contradiction, neutral.\n\n"
        f"### Premise:\n{premise}\n\n"
        f"### Hypothesis:\n{hypothesis}\n\n"
        f"### Answer:\n{label}"
    )
    return {"text": text}

train_llm_ds = dataset["train"].map(format_example)
val_llm_ds   = dataset["dev"].map(format_example)
test_llm_ds  = dataset["test"].map(format_example)

train_llm_ds[0]["text"][:500]


Map:   0%|          | 0/11232 [00:00<?, ? examples/s]

Map:   0%|          | 0/1395 [00:00<?, ? examples/s]

Map:   0%|          | 0/1422 [00:00<?, ? examples/s]

'### Instruction:\nYou are a medical NLI classifier. Given a medical premise and hypothesis, answer with one of: entailment, contradiction, neutral.\n\n### Premise:\nLabs were notable for Cr 1.7 (baseline 0.5 per old records) and lactate 2.4.\n\n### Hypothesis:\n Patient has elevated Cr\n\n### Answer:\nentailment'

#### PEFT (LoRA) config

In [None]:
# PEFT -  Parameter efficient fine tuning - a collection of techniques used to adapt large language models (LLMs) for specific tasks without modifying all of the model's parameters
# This step adds small LoRA layers on specific parts of the BioMistral architecture, and trains only those small layers — not the full 7B model.


from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"], # These are the most expressive parts of the model. Adding LoRA to these modules lets the LLM learn new tasks efficiently.
    bias="none",
    task_type=TaskType.CAUSAL_LM, # This tells PEFT - We are training a generative model (left-to-right). Apply the LoRA adapters in the correct format.
)


#### SFTTrainer with QLoRA

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer

llm_training_args = TrainingArguments(
    output_dir="biomistral_mednli_qlora",
    per_device_train_batch_size=2,   # tune based on your GPU
    gradient_accumulation_steps=8,   # effective batch size = 16
    num_train_epochs=3, # More than 3 may overfit (MedNLI is small)
    learning_rate=2e-4, # standard for QLoRA.
    warmup_ratio=0.03, # Warmup = gradually increasing learning rate for stability. So for the first 3% of training steps - LR starts at 0, Slowly increases to full LR, Prevents unstable early training
    logging_strategy="steps",
    logging_steps=50, # Print training logs (loss, lr, etc.) every 50 training steps.
    eval_strategy="epoch", # Evaluate ONLY at the end of each epoch.
    save_strategy="epoch", # Save model checkpoint at the end of each epoch.
    load_best_model_at_end=True, # automatically reload the best checkpoint, based on the validation loss or metric.
    bf16=torch.cuda.is_available(),
    lr_scheduler_type="cosine", # The learning rate follows a cosine curve - starts high, decreases smoothly
)

def formatting_func(example):
    return example["text"]

# "Use this LLM (BioMistral), this data, and this LoRA config to train"
sft_trainer = SFTTrainer(
    model=llm_model,
    args=llm_training_args,
    train_dataset=train_llm_ds,
    eval_dataset=val_llm_ds,
    processing_class=llm_tokenizer,
    peft_config=peft_config,
    formatting_func=formatting_func,
)

sft_trainer.train()

Applying formatting function to train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Applying formatting function to eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mij132[0m ([33mij132-rutgers-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,0.4052,0.855999,0.559162,1135894.0,0.819453
2,0.1961,1.011632,0.401387,2271788.0,0.818359
3,0.1332,1.119829,0.353559,3407682.0,0.817625


TrainOutput(global_step=2106, training_loss=0.3359882453454752, metrics={'train_runtime': 7510.2209, 'train_samples_per_second': 4.487, 'train_steps_per_second': 0.28, 'total_flos': 1.6283456081107354e+17, 'train_loss': 0.3359882453454752, 'epoch': 3.0})

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
save_path = "/content/drive/MyDrive/BioMistral_MedNLI_LoRA"

# Create directory if not exists
import os
os.makedirs(save_path, exist_ok=True)

# Save LoRA adapters + tokenizer
sft_trainer.model.save_pretrained(save_path)
llm_tokenizer.save_pretrained(save_path)

print("Saved to:", save_path)

Saved to: /content/drive/MyDrive/BioMistral_MedNLI_LoRA


To load the model later

In [None]:
from google.colab import drive
drive.mount('/content/drive')

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "BioMistral/BioMistral-7B"
lora_path = "/content/drive/MyDrive/BioMistral_MedNLI_LoRA"

tokenizer = AutoTokenizer.from_pretrained(lora_path)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)

model = PeftModel.from_pretrained(base_model, lora_path)
model.eval()

print("Model loaded successfully!")


#### Simple evaluation helper

In [None]:
import re
from tqdm import tqdm
from datasets import Dataset
import torch

def predict_label(texts, max_new_tokens=10, max_length=4096):
    llm_model.eval()
    preds = []

    for t in tqdm(texts):
        # Allow long premises + hypotheses
        inputs = llm_tokenizer(
            t,
            return_tensors="pt",
            truncation=True,
            max_length=max_length
        ).to(llm_model.device)

        with torch.no_grad():
            out = llm_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id
            )

        # Only decode the newly generated tokens (not the input prompt)
        gen = llm_tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

        # Normalize
        gen = gen.strip().lower()
        gen = re.sub(r"[^a-z]", " ", gen)

        # Map generated text to label
        if "entailment" in gen:
            preds.append("entailment")
        elif "contradiction" in gen:
            preds.append("contradiction")
        elif "neutral" in gen:
            preds.append("neutral")
        else:
            # default fallback
            preds.append("neutral")

    return preds


In [None]:
premise = "Multiple sclerosis (MS) progresses through brain region-specific inflammation and degeneration, with poorly defined mechanisms. In individuals with MS, we identified increased expression of formyl peptide receptor 1 (FPR1) in central nervous system (CNS)-resident microglia and CNS-infiltrating macrophages. Blood amounts of N-formylated peptides, which are endogenous agonists of FPR1, correlated with disease progression in patients with MS. In MS mouse models, signaling through FPR1 promoted microglial mitochondrial dysfunction, causing axonal loss and apoptosis. FPR1-expressing microglia sustained the clonal expansion of myelin-reactive CD4+ T cells in the CNS. A CNS-penetrating small molecule FPR1 antagonist, T0080, mitigated autoimmune responses and axonal degeneration. Our study identifies FPR1 signaling as a potential mechanism for MS progression and suggests antagonizing FPR1 as a therapeutic approach."

hypothesis = "The paper reported that the small-molecule antagonist T0080 can fully reverse MS after a single oral dose in humans."

text = (
    "### Instruction:\n"
    "You are a medical NLI classifier. Given a medical premise and hypothesis, "
    "answer with one of: entailment, contradiction, neutral.\n\n"
    f"### Premise:\n{premise}\n\n"
    f"### Hypothesis:\n{hypothesis}\n\n"
    "### Answer:\n"
)

pred = predict_label([text])[0]
print(pred)


  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: expected scalar type Float but found Half

## BioMistral-7B Few Shot NLI

In [None]:
!pip install -q "transformers>=4.44.0" "accelerate>=0.34.0"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

model_id = "BioMistral/BioMistral-7B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# A100 has plenty of VRAM; run in bfloat16 on GPU
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",   # put model on GPU
)
model.eval()

Device: cuda


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((4096,)

In [None]:
import re
from typing import List, Dict, Tuple

def build_nli_prompt(
    premise: str,
    hypothesis: str,
    few_shot_examples: List[Dict[str, str]] = None,
) -> str:
    """
    few_shot_examples: list of dicts with keys: premise, hypothesis, label
                       label ∈ {"entailment", "contradiction", "neutral"}
    """

    system_instruction = (
        "You are a careful medical NLI classifier.\n"
        "Given a medical premise and a hypothesis, decide whether the hypothesis is:\n"
        "- entailment (definitely supported by the premise)\n"
        "- contradiction (definitely false given the premise)\n"
        "- neutral (not clearly supported or contradicted by the premise).\n"
        "Always answer with exactly one word: entailment, contradiction, or neutral.\n"
    )

    shots_str = ""
    if few_shot_examples:
        for i, ex in enumerate(few_shot_examples, start=1):
            shots_str += (
                f"\n### Example {i}\n"
                f"Premise: {ex['premise']}\n"
                f"Hypothesis: {ex['hypothesis']}\n"
                f"Answer: {ex['label']}\n"
            )

    query_str = (
        "\n### Now classify the following pair.\n"
        f"Premise: {premise}\n"
        f"Hypothesis: {hypothesis}\n"
        "Answer:"
    )

    return system_instruction + shots_str + query_str

In [None]:
def biomistral_nli(
    premise: str,
    hypothesis: str,
    few_shot_examples: List[Dict[str, str]] = None,
    max_new_tokens: int = 8,
) -> Tuple[str, str]:
    """
    Returns:
      (predicted_label, raw_generated_text)
    """

    prompt = build_nli_prompt(premise, hypothesis, few_shot_examples)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Only look at the newly generated tokens
    gen_text = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    # Normalize for robust matching
    norm = gen_text.lower()
    norm = re.sub(r"[^a-z]", " ", norm)
    norm = norm.strip()

    # Simple label parsing
    if "entailment" in norm:
        label = "entailment"
    elif "contradiction" in norm:
        label = "contradiction"
    elif "neutral" in norm:
        label = "neutral"
    else:
        # fallback: try first word, else default to neutral
        first = norm.split()[0] if norm else ""
        if first.startswith("entail"):
            label = "entailment"
        elif first.startswith("contra"):
            label = "contradiction"
        elif first.startswith("neutral"):
            label = "neutral"
        else:
            label = "neutral"

    return label, gen_text


In [None]:
few_shots = [
    {
        "premise": "Nystagmus and twiching of R arm was noted.",
        "hypothesis": "The patient had abnormal neuro exam.",
        "label": "entailment",
    },
    {
        "premise": "The patient denied any cough, dysuria, headache, photophobia, stiff neck, or diarrhea.",
        "hypothesis": "The patient complains of painful urination.",
        "label": "contradiction",
    },
    {
        "premise": "History of TIA [**5-/3025**] with left hemi[** Location **] that resolved.",
        "hypothesis": "Patient has abnormal brain MRI.",
        "label": "neutral",
    },
]

premise = """She started taking ibuprofen for it at [**First Name8 (NamePattern2) **] [**Last Name (un) 5416**] dose."""

hypothesis = """The patient is not in pain."""

label, raw = biomistral_nli(premise, hypothesis, few_shot_examples=few_shots)
print("Few-shot prediction:", label)
print("Raw model output:", repr(raw))

Few-shot prediction: contradiction
Raw model output: 'contradiction'


In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)
import numpy as np
import pandas as pd

labels_order = ["entailment", "contradiction", "neutral"]

def evaluate_mednli_biomistral(test_ds, few_shots):
    y_true = []
    y_pred = []
    raw_outputs = []

    for ex in tqdm(test_ds, desc="Evaluating BioMistral on MedNLI test"):
        premise = ex["Premise"]
        hypothesis = ex["Hypothesis"]
        gold = ex["Label"]  # already string like 'entailment'

        pred_label, raw = biomistral_nli(
            premise,
            hypothesis,
            few_shot_examples=few_shots,
            max_new_tokens=8,
        )

        y_true.append(gold)
        y_pred.append(pred_label)
        raw_outputs.append(raw)

    # Accuracy
    acc = accuracy_score(y_true, y_pred)

    # Micro / macro / weighted precision, recall, F1
    p_micro, r_micro, f1_micro, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="micro"
    )
    p_macro, r_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="macro"
    )
    p_weighted, r_weighted, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="weighted"
    )

    # Per-class metrics
    report = classification_report(
        y_true,
        y_pred,
        labels=labels_order,
        target_names=labels_order,
        digits=4,
    )

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=labels_order)
    cm_df = pd.DataFrame(cm, index=[f"true_{l}" for l in labels_order],
                            columns=[f"pred_{l}" for l in labels_order])

    metrics = {
        "accuracy": acc,
        "precision_micro": p_micro,
        "recall_micro": r_micro,
        "f1_micro": f1_micro,
        "precision_macro": p_macro,
        "recall_macro": r_macro,
        "f1_macro": f1_macro,
        "precision_weighted": p_weighted,
        "recall_weighted": r_weighted,
        "f1_weighted": f1_weighted,
        "classification_report": report,
        "confusion_matrix": cm_df,
        "y_true": y_true,
        "y_pred": y_pred,
        "raw_generations": raw_outputs,
    }

    return metrics

metrics = evaluate_mednli_biomistral(test_ds, few_shots=few_shots)

print("Accuracy:", metrics["accuracy"])
print("Micro  F1:", metrics["f1_micro"])
print("Macro  F1:", metrics["f1_macro"])
print("Weighted F1:", metrics["f1_weighted"])
print("\nPer-class classification report:\n")
print(metrics["classification_report"])
print("\nConfusion matrix:\n")
print(metrics["confusion_matrix"])

Evaluating BioMistral on MedNLI test:   0%|          | 0/1422 [00:00<?, ?it/s]

Accuracy: 0.5513361462728551
Micro  F1: 0.5513361462728551
Macro  F1: 0.5215872176893334
Weighted F1: 0.5215872176893335

Per-class classification report:

               precision    recall  f1-score   support

   entailment     0.7295    0.3186    0.4435       474
contradiction     0.5211    0.9367    0.6697       474
      neutral     0.5207    0.3987    0.4516       474

     accuracy                         0.5513      1422
    macro avg     0.5904    0.5513    0.5216      1422
 weighted avg     0.5904    0.5513    0.5216      1422


Confusion matrix:

                    pred_entailment  pred_contradiction  pred_neutral
true_entailment                 151                 178           145
true_contradiction                1                 444            29
true_neutral                     55                 230           189


## BioMistral-7B - Hyperparameter tuned

In [None]:
# ============================================================
# 0. Install dependencies
# ============================================================
!pip install -q "transformers>=4.44.0" "datasets>=2.20.0" "accelerate>=0.34.0" \
              "bitsandbytes>=0.43.1" "peft>=0.12.0" "trl>=0.9.6" "scikit-learn"


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.5/465.5 kB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# ============================================================
# 1. Imports and basic setup
# ============================================================
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, TaskType
from trl import SFTTrainer

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Optional: small speed-up flags for A100
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True



Device: cuda


  self.setter(val)


In [None]:
# ============================================================
# 2. Load MedNLI dataset
#    (araag2/MedNLI, "processed" config)
# ============================================================
dataset = load_dataset("araag2/MedNLI", "processed")
train_ds = dataset["train"]
val_ds   = dataset["dev"]
test_ds  = dataset["test"]

print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

processed/train-00000-of-00001.parquet:   0%|          | 0.00/637k [00:00<?, ?B/s]

processed/dev-00000-of-00001.parquet:   0%|          | 0.00/84.2k [00:00<?, ?B/s]

processed/test-00000-of-00001.parquet:   0%|          | 0.00/83.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11232 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/1395 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1422 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 11232
    })
    dev: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 1395
    })
    test: Dataset({
        features: ['id', 'Label', 'Premise', 'Hypothesis'],
        num_rows: 1422
    })
})


In [None]:
# ============================================================
# 3. Load BioMistral-7B in 4-bit with consistent fp16 dtype
# ============================================================
llm_id = "BioMistral/BioMistral-7B"

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,  # <- compute in fp16
)

llm_tokenizer = AutoTokenizer.from_pretrained(llm_id)
if llm_tokenizer.pad_token is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

llm_model = AutoModelForCausalLM.from_pretrained(
    llm_id,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16,           # <- model weights in fp16
)

# Optional: slightly safer for training with gradient checkpointing
llm_model.config.use_cache = False

print("Model loaded.")



tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Model loaded.


In [None]:
# ============================================================
# 4. Format MedNLI examples as instruction-tuned text
#    (same style you used earlier)
# ============================================================
def format_example(example):
    premise = example["Premise"]
    hypothesis = example["Hypothesis"]
    label = example["Label"]  # "entailment", "contradiction", "neutral"

    text = (
        "### Instruction:\n"
        "You are a careful medical natural language inference (NLI) classifier.\n"
        "You will be given a Premise (trusted medical information) and a "
        "Hypothesis (a claim about the patient or disease).\n"
        "Using ONLY the information in the Premise, decide which of the "
        "following labels is most appropriate:\n"
        "- entailment: the hypothesis must be true if the premise is true.\n"
        "- contradiction: the hypothesis must be false if the premise is true.\n"
        "- neutral: based on the premise alone, the hypothesis cannot be "
        "confidently confirmed or denied (it could be true or could be false).\n\n"
        "Always answer with exactly one word: entailment, contradiction, or neutral.\n\n"
        f"### Premise:\n{premise}\n\n"
        f"### Hypothesis:\n{hypothesis}\n\n"
        f"### Answer:\n{label}"
    )
    return {"text": text}

train_llm_ds = train_ds.map(format_example)
val_llm_ds   = val_ds.map(format_example)
test_llm_ds  = test_ds.map(format_example)

print("Example formatted sample:\n")
print(train_llm_ds[0]["text"][:500])

def formatting_func(example):
    # SFTTrainer expects a string or list of strings
    return example["text"]



Map:   0%|          | 0/11232 [00:00<?, ? examples/s]

Map:   0%|          | 0/1395 [00:00<?, ? examples/s]

Map:   0%|          | 0/1422 [00:00<?, ? examples/s]

Example formatted sample:

### Instruction:
You are a careful medical natural language inference (NLI) classifier.
You will be given a Premise (trusted medical information) and a Hypothesis (a claim about the patient or disease).
Using ONLY the information in the Premise, decide which of the following labels is most appropriate:
- entailment: the hypothesis must be true if the premise is true.
- contradiction: the hypothesis must be false if the premise is true.
- neutral: based on the premise alone, the hypothesis cannot


In [None]:
# ============================================================
# 5. Gentle, narrow LoRA configuration (Strategy B)
#    - smaller rank
#    - fewer target modules (q_proj, v_proj only)
#    - more dropout
# ============================================================
peft_config = LoraConfig(
    r=8,                        # was 16
    lora_alpha=16,              # was 32
    lora_dropout=0.1,           # more regularization than 0.05
    target_modules=["q_proj", "v_proj"],  # narrower than all proj + MLP
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)


In [None]:
# ============================================================
# 6. TrainingArguments: softer training for small MedNLI
#    - 1 epoch
#    - smaller LR
#    - weight decay, grad clipping
#    - fp16-only (no bf16) to avoid dtype mishaps
# ============================================================
output_dir = "biomistral_mednli_qlora_gentle"

llm_training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,        # effective batch size 16
    num_train_epochs=1,                   # gentler: 1 epoch on 11k examples
    learning_rate=5e-5,                   # smaller LR than 2e-4
    warmup_ratio=0.1,                     # more gradual warmup

    weight_decay=0.05,                    # regularization on LoRA params
    max_grad_norm=1.0,                    # gradient clipping

    logging_strategy="steps",
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    fp16=True,
    bf16=False,

    lr_scheduler_type="cosine",
    report_to="none",                     # or ["tensorboard"] if you want
)

# ============================================================
# 7. SFTTrainer setup
# ============================================================
sft_trainer = SFTTrainer(
    model=llm_model,
    args=llm_training_args,
    train_dataset=train_llm_ds,
    eval_dataset=val_llm_ds,
    processing_class=llm_tokenizer,      # preferred arg name
    peft_config=peft_config,
    formatting_func=formatting_func,
)


Applying formatting function to train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/11232 [00:00<?, ? examples/s]

Applying formatting function to eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/1395 [00:00<?, ? examples/s]

In [None]:

# ============================================================
# 8. Train
# ============================================================
sft_trainer.train()

# After training, the Trainer will already have reloaded the best
# checkpoint according to eval_loss (because load_best_model_at_end=True).

# ============================================================
# 9. Save LoRA adapters + tokenizer (optional, e.g. to Drive)
# ============================================================
# If you want to save to Google Drive, mount it first:
# from google.colab import drive
# drive.mount('/content/drive')
# save_path = "/content/drive/MyDrive/BioMistral_MedNLI_LoRA_gentle"

save_path = "./BioMistral_MedNLI_LoRA_gentle"
os.makedirs(save_path, exist_ok=True)

sft_trainer.model.save_pretrained(save_path)
llm_tokenizer.save_pretrained(save_path)

print("Saved model and tokenizer to:", save_path)


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,0.3799,0.397116,0.405332,2483734.0,0.911627


Saved model and tokenizer to: ./BioMistral_MedNLI_LoRA_gentle


In [None]:
from google.colab import drive
import os

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Choose where to save in your Drive
#    Change this path if you want a different folder/name
save_path = "/content/drive/MyDrive/BioMistral_MedNLI_LoRA_gentle"

# 3. Create the directory if it doesn't exist
os.makedirs(save_path, exist_ok=True)

# 4. Save the LoRA-adapted model + tokenizer
sft_trainer.model.save_pretrained(save_path)
llm_tokenizer.save_pretrained(save_path)

print("Saved fine-tuned model and tokenizer to:", save_path)


Mounted at /content/drive
Saved fine-tuned model and tokenizer to: /content/drive/MyDrive/BioMistral_MedNLI_LoRA_gentle


In [None]:
def build_inference_prompt(premise, hypothesis):
    return (
        "### Instruction:\n"
        "You are a careful medical natural language inference (NLI) classifier.\n"
        "You will be given a Premise (trusted medical information) and a "
        "Hypothesis (a claim about the patient or disease).\n"
        "Using ONLY the information in the Premise, decide which of the "
        "following labels is most appropriate:\n"
        "- entailment: the hypothesis must be true if the premise is true.\n"
        "- contradiction: the hypothesis must be false if the premise is true.\n"
        "- neutral: based on the premise alone, the hypothesis cannot be "
        "confidently confirmed or denied (it could be true or could be false).\n\n"
        "Always answer with exactly one word: entailment, contradiction, or neutral.\n\n"
        f"### Premise:\n{premise}\n\n"
        f"### Hypothesis:\n{hypothesis}\n\n"
        "### Answer:\n"
    )


In [None]:
import re
import torch
from tqdm.auto import tqdm

llm_model.eval()
model_device = next(llm_model.parameters()).device
print("Model device:", model_device)

def predict_single_label(prompt, max_new_tokens=8, max_length=4096):
    inputs = llm_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(model_device)

    with torch.no_grad():
        if model_device.type == "cuda":
            # Let autocast handle float16/float32 mixing like Trainer does
            with torch.cuda.amp.autocast():
                out = llm_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=llm_tokenizer.eos_token_id,
                    eos_token_id=llm_tokenizer.eos_token_id,
                )
        else:
            # CPU fallback
            out = llm_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
            )

    gen = llm_tokenizer.decode(
        out[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip().lower()

    norm = re.sub(r"[^a-z]", " ", gen).strip()

    if "entailment" in norm:
        return "entailment"
    elif "contradiction" in norm:
        return "contradiction"
    elif "neutral" in norm:
        return "neutral"
    else:
        first = norm.split()[0] if norm else ""
        if first.startswith("entail"):
            return "entailment"
        elif first.startswith("contra"):
            return "contradiction"
        elif first.startswith("neutral"):
            return "neutral"
        else:
            return "neutral"


Model device: cuda:0


In [None]:
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)
import pandas as pd

labels_order = ["entailment", "contradiction", "neutral"]

def evaluate_split(ds, split_name="test"):
    y_true = []
    y_pred = []

    for ex in tqdm(ds, desc=f"Evaluating {split_name}"):
        premise = ex["Premise"]
        hypothesis = ex["Hypothesis"]
        gold = ex["Label"]  # expected to be string: "entailment"/"contradiction"/"neutral"

        prompt = build_inference_prompt(premise, hypothesis)
        pred_label = predict_single_label(prompt)

        y_true.append(gold)
        y_pred.append(pred_label)

    # Accuracy
    acc = accuracy_score(y_true, y_pred)

    # Micro / macro / weighted precision, recall, F1
    p_micro, r_micro, f1_micro, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="micro"
    )
    p_macro, r_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="macro"
    )
    p_weighted, r_weighted, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels_order, average="weighted"
    )

    # Per-class breakdown
    report = classification_report(
        y_true,
        y_pred,
        labels=labels_order,
        target_names=labels_order,
        digits=4,
    )

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=labels_order)
    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{l}" for l in labels_order],
        columns=[f"pred_{l}" for l in labels_order],
    )

    metrics = {
        "split": split_name,
        "accuracy": acc,
        "precision_micro": p_micro,
        "recall_micro": r_micro,
        "f1_micro": f1_micro,
        "precision_macro": p_macro,
        "recall_macro": r_macro,
        "f1_macro": f1_macro,
        "precision_weighted": p_weighted,
        "recall_weighted": r_weighted,
        "f1_weighted": f1_weighted,
        "classification_report": report,
        "confusion_matrix": cm_df,
        "y_true": y_true,
        "y_pred": y_pred,
    }

    print(f"\n=== {split_name.upper()} RESULTS ===")
    print("Accuracy:", acc)
    print("Micro  F1:", f1_micro)
    print("Macro  F1:", f1_macro)
    print("Weighted F1:", f1_weighted)
    print("\nPer-class classification report:\n")
    print(report)
    print("\nConfusion matrix:\n")
    print(cm_df)

    return metrics


In [None]:
# Evaluate on test set
test_metrics = evaluate_split(test_ds, split_name="test")

Evaluating test:   0%|          | 0/1422 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():



=== TEST RESULTS ===
Accuracy: 0.8227848101265823
Micro  F1: 0.8227848101265823
Macro  F1: 0.8224692967108802
Weighted F1: 0.8224692967108802

Per-class classification report:

               precision    recall  f1-score   support

   entailment     0.8337    0.7722    0.8018       474
contradiction     0.8816    0.9114    0.8963       474
      neutral     0.7546    0.7848    0.7694       474

     accuracy                         0.8228      1422
    macro avg     0.8233    0.8228    0.8225      1422
 weighted avg     0.8233    0.8228    0.8225      1422


Confusion matrix:

                    pred_entailment  pred_contradiction  pred_neutral
true_entailment                 366                  23            85
true_contradiction                6                 432            36
true_neutral                     67                  35           372


In [None]:
# ============================================================
# 11. Small sanity check on a custom example
# ============================================================
premise = (
    "Multiple sclerosis (MS) progresses through brain region-specific inflammation "
    "and degeneration. In individuals with MS, we identified increased expression "
    "of FPR1 in CNS-resident microglia and CNS-infiltrating macrophages. "
    "A CNS-penetrating small molecule FPR1 antagonist, T0080, mitigated autoimmune "
    "responses and axonal degeneration in MS mouse models."
)

hypothesis = "The paper reported that the small-molecule antagonist T0080 can fully reverse MS after a single oral dose in humans."

test_prompt = (
    "### Instruction:\n"
    "You are a medical NLI classifier. Given a medical premise and hypothesis, "
    "answer with one of: entailment, contradiction, neutral.\n\n"
    f"### Premise:\n{premise}\n\n"
    f"### Hypothesis:\n{hypothesis}\n\n"
    "### Answer:\n"
)

pred = predict_label([test_prompt])[0]
print("Predicted label:", pred)


NameError: name 'predict_label' is not defined