In [1]:
pip install transformers datasets peft bitsandbytes accelerate torch tqdm pandas numpy


Collecting bitsandbytes
  Downloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manyl

In [2]:
pip install -q --no-deps xformers trl peft accelerate bitsandbytes


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m423.1/423.1 kB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[?25hNote: you may need to restart the kernel to use updated packages.


In [3]:
import os
import json
import torch
import warnings
import pandas as pd
import numpy as np
from datetime import datetime
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
warnings.filterwarnings('ignore')



2025-10-29 10:08:44.278697: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761732524.521113      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761732524.588842      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["HF_TOKEN"] = user_secrets.get_secret("huggingface")

os.environ["HF_USERNAME"] = "megrisdal"

# Step 1: Download and explore data

In [5]:
dataset = load_dataset('lavita/medical-qa-datasets', 'all-processed')
print(f"  Total examples: {len(dataset['train']):,}")

README.md: 0.00B [00:00, ?B/s]

all-processed/train-00000-of-00001-a77e2(…):   0%|          | 0.00/155M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/239357 [00:00<?, ? examples/s]

  Total examples: 239,357


In [6]:
print("Dataset Structure:")
print(f"   Splits: {list(dataset.keys())}")
print(f"   Features: {list(dataset['train'].features.keys())}")
print(f"   Total examples: {len(dataset['train']):,}")

Dataset Structure:
   Splits: ['train']
   Features: ['instruction', 'input', 'output', '__index_level_0__']
   Total examples: 239,357


In [7]:
# Show sample examples
print("Sample Examples:")
for i in range(3):
    example = dataset['train'][i]
    instruction = example.get('instruction', '')
    input_text = example.get('input', '')
    output_text = example.get('output', '')
    
    print(f"\n--- Example {i+1} ---")
    print(f"Instruction: {instruction[:100]}...")
    print(f"Input: {input_text[:100]}...")
    print(f"Output: {output_text[:150]}...")


Sample Examples:

--- Example 1 ---
Instruction: If you are a doctor, please answer the medical questions based on the patient's description....
Input: hi. im a home health aide and i have a client with scoliosis in the back and kidney disease. her fee...
Output: hi, thanks for contacting chatbot. swelling in the legs and feet can come from many causes, one of them being general circulation or ineffectiveness o...

--- Example 2 ---
Instruction: Please summerize the given abstract to a title...
Input: RATIONALE: The COVID-19 pandemic struck an immunologically naïve, globally interconnected population...
Output: Hydroxychloroquine vs. Azithromycin for Hospitalized Patients with COVID-19 (HAHPS): Results of a Randomized, Active Comparator Trial...

--- Example 3 ---
Instruction: Please summerize the given abstract to a title...
Input: Objectives: To investigate the experience of playing the harmonica for individuals with COPD. Method...
Output: Playing the harmonica with chronic obstruct

In [8]:
# Statistics
print("Dataset Statistics:")

instructions = [ex.get('instruction', '') for ex in dataset['train']]
inputs = [ex.get('input', '') for ex in dataset['train']]
outputs = [ex.get('output', '') for ex in dataset['train']]

avg_instruction_len = np.mean([len(s.split()) for s in instructions[:1000]])
avg_input_len = np.mean([len(s.split()) for s in inputs[:1000]])
avg_output_len = np.mean([len(s.split()) for s in outputs[:1000]])

print(f"   Average instruction length: {avg_instruction_len:.1f} words")
print(f"   Average input length: {avg_input_len:.1f} words")
print(f"   Average output length: {avg_output_len:.1f} words")

Dataset Statistics:
   Average instruction length: 11.0 words
   Average input length: 98.9 words
   Average output length: 65.8 words


# Step 2 : Clean and prepare data

In [9]:
def clean_text(text):
    if not text or pd.isna(text):
        return ""
    
    text = str(text).strip()
    # Remove excessive whitespace
    text = ' '.join(text.split())
    return text

In [10]:
def process_example(example):
    instruction = clean_text(example.get('instruction', ''))
    input_text = clean_text(example.get('input', ''))
    output_text = clean_text(example.get('output', ''))
    
    # Combine instruction and input as question
    if input_text:
        question = f"{instruction} {input_text}"
    else:
        question = instruction
    
    question = clean_text(question)
    answer = clean_text(output_text)
    
    # Quality filters
    if len(question) < 10 or len(answer) < 10:
        return None
    if len(question.split()) > 500 or len(answer.split()) > 500:
        return None
    
    return {'question': question, 'answer': answer}

In [11]:
# Process examples with a cap
processed_data = []
max_samples=50000

for i, example in enumerate(tqdm(dataset['train'], desc="Processing")):
    if len(processed_data) >= max_samples:
        break
    result = process_example(example)
    if result:
        processed_data.append(result)

print(f" Processed {len(processed_data):,} examples (max {max_samples:,})")
print(f" Filtered out: {len(dataset['train']) - len(processed_data):,} examples")


Processing:  22%|██▏       | 51485/239357 [00:04<00:16, 11467.80it/s]

 Processed 50,000 examples (max 50,000)
 Filtered out: 189,357 examples





In [12]:

train_size=0.7
val_size=0.15
test_size=0.15

# Shuffle data
np.random.seed(42)
indices = np.random.permutation(len(processed_data))

# Calculate split points
train_end = int(len(indices) * train_size)
val_end = train_end + int(len(indices) * val_size)

# Split indices
train_indices = indices[:train_end]
val_indices = indices[train_end:val_end]
test_indices = indices[val_end:]

# Create splits
train_data = [processed_data[i] for i in train_indices]
val_data = [processed_data[i] for i in val_indices]
test_data = [processed_data[i] for i in test_indices]

print(f"Split sizes:")
print(f"   Train: {len(train_data):,} ({train_size*100:.0f}%)")
print(f"   Validation: {len(val_data):,} ({val_size*100:.0f}%)")
print(f"   Test: {len(test_data):,} ({test_size*100:.0f}%)")



Split sizes:
   Train: 35,000 (70%)
   Validation: 7,500 (15%)
   Test: 7,500 (15%)


In [13]:
# Convert to HuggingFace Dataset format
def dict_format(data_list):
    return {
        'question': [item['question'] for item in data_list],
        'answer': [item['answer'] for item in data_list]
    }

dataset = DatasetDict({
    'train': Dataset.from_dict(dict_format(train_data)),
    'validation': Dataset.from_dict(dict_format(val_data)),
    'test': Dataset.from_dict(dict_format(test_data))
})

In [14]:
save_path="./prepared_medical_qa"
    
os.makedirs(save_path, exist_ok=True)

# Save in HuggingFace format
dataset.save_to_disk(save_path)

# Also save as JSON for inspection
json_path = os.path.join(save_path, "samples.json")
samples = {
    'train_samples': [dataset['train'][i] for i in range(min(5, len(dataset['train'])))],
    'val_samples': [dataset['validation'][i] for i in range(min(5, len(dataset['validation'])))],
    'statistics': {
        'train_size': len(dataset['train']),
        'val_size': len(dataset['validation']),
        'test_size': len(dataset['test']),
        'created_at': datetime.now().isoformat()
    }
}

with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(samples, f, indent=2, ensure_ascii=False)

Saving the dataset (0/1 shards):   0%|          | 0/35000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7500 [00:00<?, ? examples/s]

# Step 3 : setup the model

In [15]:

model_name="google/gemma-3-1b-it"
    
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

# Set padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [16]:
# Setup LoRA
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 2,981,888 || all params: 1,002,867,840 || trainable%: 0.2973


# Step 4 : Format data for the model

In [17]:
def format_prompt_gemma(question, answer=None):
    
    if answer is None:
        # For inference
        prompt = f"""<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model
"""
    else:
        # For training
        prompt = f"""<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model
{answer}<end_of_turn>"""
    
    return prompt

In [18]:
def tokenize_function(examples):
    # Format prompts
    prompts = [
        format_prompt_gemma(q, a)
        for q, a in zip(examples['question'], examples['answer'])
    ]
    
    # Tokenize
    tokenized = tokenizer(
        prompts,
        truncation=True,
        max_length=512,
        padding=True,
        return_tensors=None
    )
    
    # Labels are same as input_ids for causal LM
    tokenized["labels"] = tokenized["input_ids"].copy()
    
    return tokenized


In [19]:
max_length=512
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset['train'].column_names,
    desc="Tokenizing",
    num_proc=4  # Parallel processing
)

Tokenizing (num_proc=4):   0%|          | 0/35000 [00:00<?, ? examples/s]

Tokenizing (num_proc=4):   0%|          | 0/7500 [00:00<?, ? examples/s]

Tokenizing (num_proc=4):   0%|          | 0/7500 [00:00<?, ? examples/s]

#  Step 5 : train the model

In [20]:
output_dir="/kaggle/working//gemma_medical_qa"

os.makedirs(output_dir, exist_ok=True)

In [21]:
# Training arguments
training_args = TrainingArguments(
    # Output
    output_dir=output_dir,
    logging_dir=f"{output_dir}/logs",
    
    # Training hyperparameters
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=100,
    
    # Optimization
    fp16=True,
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    max_grad_norm=0.3,
    
    # Logging and evaluation
    logging_steps=50,
    eval_steps=250,
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    
    # Best model
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    
    # Other
    report_to=[],
    seed=42,
    dataloader_num_workers=4,
    remove_unused_columns=False,
)

In [22]:
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False 
)

In [23]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    data_collator=data_collator,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [24]:
# Train
train_result = trainer.train()

It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
250,2.6172,2.604192
500,2.5671,2.548115
750,2.5129,2.517644
1000,2.5065,2.497792
1250,2.5034,2.483223
1500,2.473,2.472858
1750,2.4854,2.465068
2000,2.4666,2.459803


In [25]:
# Save model
trainer.save_model()
tokenizer.save_pretrained(output_dir)

('/kaggle/working//gemma_medical_qa/tokenizer_config.json',
 '/kaggle/working//gemma_medical_qa/special_tokens_map.json',
 '/kaggle/working//gemma_medical_qa/chat_template.jinja',
 '/kaggle/working//gemma_medical_qa/tokenizer.model',
 '/kaggle/working//gemma_medical_qa/added_tokens.json',
 '/kaggle/working//gemma_medical_qa/tokenizer.json')

In [26]:
# Final evaluation
eval_results = trainer.evaluate()

print("Evaluation results:")
for key, value in eval_results.items():
    if isinstance(value, (int, float)):
        print(f"   {key}: {value:.4f}")


Evaluation results:
   eval_loss: 2.4598
   eval_runtime: 1390.1759
   eval_samples_per_second: 5.3950
   eval_steps_per_second: 1.3490
   epoch: 1.0000


In [27]:
# Save training info
training_info = {
    'model_name': 'google/gemma-2-1b-it',
    'dataset': 'lavita/medical-qa-datasets',
    'training_examples': len(tokenized_dataset['train']),
    'validation_examples': len(tokenized_dataset['validation']),
    'test_examples': len(tokenized_dataset['test']),
    'epochs': training_args.num_train_epochs,
    'learning_rate': training_args.learning_rate,
    'final_train_loss': float(train_result.training_loss),
    'final_eval_loss': float(eval_results.get('eval_loss', 0)),
    'training_time_seconds': train_result.metrics.get('train_runtime', 0),
    'timestamp': datetime.now().isoformat()
}

In [28]:
info_file = os.path.join(output_dir, 'training_info.json')
with open(info_file, 'w') as f:
    json.dump(training_info, f, indent=2)

print(info_file)

/kaggle/working//gemma_medical_qa/training_info.json
