In [1]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType




In [2]:
DATASET_NAME = "pubmed_qa"
DATASET_SUBSET = "pqa_labeled"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
OUTPUT_DIR = "./llama-3.2-1B-pubmedqa"
MAX_LENGTH = 512
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 1e-5
NUM_EPOCHS = 5
SAVE_STEPS = 1000
HF_TOKEN = "MY ACCESS TOKEN"

In [3]:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token

Loading tokenizer...


In [4]:
print("Loading dataset...")
dataset = load_dataset(DATASET_NAME, DATASET_SUBSET)
print("Dataset loaded.")

Loading dataset...
Dataset loaded.


In [5]:
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 1000
    })
})


In [6]:
import pandas as pd
df = pd.DataFrame(dataset['train'])
print(df.head())
print(df.columns)

      pubid                                           question  \
0  21645374  Do mitochondria play a role in remodelling lac...   
1  16418930  Landolt C and snellen e acuity: differences in...   
2   9488747  Syncope during bathing in infants, a pediatric...   
3  17208539  Are the long-term results of the transanal pul...   
4  10808977  Can tailored interventions increase mammograph...   

                                             context  \
0  {'contexts': ['Programmed cell death (PCD) is ...   
1  {'contexts': ['Assessment of visual acuity dep...   
2  {'contexts': ['Apparent life-threatening event...   
3  {'contexts': ['The transanal endorectal pull-t...   
4  {'contexts': ['Telephone counseling and tailor...   

                                         long_answer final_decision  
0  Results depicted mitochondrial dynamics in viv...            yes  
1  Using the charts described, there was only a s...             no  
2  "Aquagenic maladies" could be a pediatric form...    

In [5]:
def explore_dataset(ds):
    print(f"Dataset size: {len(ds['train'])}")
    print(f"Dataset columns: {ds['train'].column_names}")
    print(f"Sample entry: {ds['train'][0]}")
    return ds

In [6]:
dataset = explore_dataset(dataset)

Dataset size: 1000
Dataset columns: ['pubid', 'question', 'context', 'long_answer', 'final_decision']
Sample entry: {'pubid': 21645374, 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?', 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.', 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occu

In [7]:
def preprocess_pubmed_qa(examples):
    formatted_prompts = []
    for i in range(len(examples["pubid"])):
        question = examples["question"][i]
        context = examples["context"][i]["contexts"][0] if examples["context"][i]["contexts"] else ""
        
        label = examples["final_decision"][i]
        
        if label == "yes":
            answer = f"Yes. Based on the medical literature: {context}"
        elif label == "no":
            answer = f"No. Based on the medical literature: {context}"
        else:  # maybe
            answer = f"The evidence is inconclusive. Based on the medical literature: {context}"
        
        instruction = f"Answer the following medical question based on evidence from medical literature: {question}"
        
        formatted_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful medical assistant that provides evidence-based answers.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n{answer}<|eot_id|>"
        formatted_prompts.append(formatted_prompt)
    
    return {"formatted_text": formatted_prompts}

In [8]:
processed_dataset = dataset.map(
    preprocess_pubmed_qa,
    batched=True,
    remove_columns=dataset["train"].column_names  # Remove original columns
)
print("Dataset processed.")
print(f"Sample processed entry: {processed_dataset['train'][0]}")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset processed.
Sample processed entry: {'formatted_text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful medical assistant that provides evidence-based answers.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\nAnswer the following medical question based on evidence from medical literature: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\nYes. Based on the medical literature: Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals

In [9]:
processed_dataset = processed_dataset["train"].train_test_split(test_size=0.1)
train_dataset = processed_dataset["train"]
eval_dataset = processed_dataset["test"]

In [10]:
def tokenize_function(examples):
    return tokenizer(
        examples["formatted_text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )

In [11]:
print("Tokenizing dataset...")
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["formatted_text"])
tokenized_eval = eval_dataset.map(tokenize_function, batched=True, remove_columns=["formatted_text"])

Tokenizing dataset...


Map:   0%|          | 0/900 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [12]:
tokenized_train.set_format("torch")
tokenized_eval.set_format("torch")

In [15]:
print("Setting up model...")
# Load base model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    token=HF_TOKEN,
    device_map="auto"
)

Setting up model...


In [16]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]
)

In [17]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # Shows what percentage of parameters will be trained

trainable params: 1,703,936 || all params: 1,237,518,336 || trainable%: 0.1377


In [19]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    evaluation_strategy="steps",
    eval_steps=SAVE_STEPS,
    eval_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="tensorboard",
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)


In [20]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # We're doing causal language modeling, not masked
)

In [21]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [22]:
print("Starting training...")
trainer.train()


Starting training...


Step,Training Loss,Validation Loss


TrainOutput(global_step=140, training_loss=3.856314195905413, metrics={'train_runtime': 14636.408, 'train_samples_per_second': 0.307, 'train_steps_per_second': 0.01, 'total_flos': 1.3081021469687808e+16, 'train_loss': 3.856314195905413, 'epoch': 4.8533333333333335})

In [23]:
print("Saving model...")
model.save_pretrained(f"{OUTPUT_DIR}/final")
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")

Saving model...


('./llama-3.2-1B-pubmedqa/final\\tokenizer_config.json',
 './llama-3.2-1B-pubmedqa/final\\special_tokens_map.json',
 './llama-3.2-1B-pubmedqa/final\\tokenizer.json')

In [24]:
print("Evaluating model...")
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")
print("Fine-tuning complete!")

Evaluating model...


Evaluation results: {'eval_loss': 3.56388783454895, 'eval_runtime': 34.1247, 'eval_samples_per_second': 2.93, 'eval_steps_per_second': 0.733, 'epoch': 4.8533333333333335}
Fine-tuning complete!


In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model_path = "./llama-3.2-1B-pubmedqa/final"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)

model = model.to(device)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

def ask_question(question, max_length=512):
    prompt = f"Question: {question}\nAnswer:"
    
    inputs = tokenizer(
        prompt, 
        return_tensors="pt", 
        padding=True,
        truncation=True,
        max_length=model.config.max_position_embeddings
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # Pass attention mask
            max_length=max_length,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id  # Explicitly pass pad_token_id
        )
    
    outputs = outputs.cpu()
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer = response.split("Answer:")[1].strip() if "Answer:" in response else response

    return answer

question = "What are the effects of statins on cardiovascular health?"
answer = ask_question(question)
print(answer)

Using device: cuda
Statins have been shown to lower cholesterol levels, reduce the risk of heart attack, stroke, and heart failure, and improve overall cardiovascular health.

Key points:

*   Statins work by inhibiting the enzyme HMG-CoA reductase, which is involved in the production of cholesterol in the liver.
*   By reducing cholesterol levels, statins can help to lower the risk of heart attack, stroke, and heart failure.
*   Statins can also help to improve blood lipid profiles, which can reduce the risk of cardiovascular disease.
*   In addition to lowering cholesterol, statins can help to improve blood vessel function and reduce inflammation in the blood vessels.
