In [None]:
!pip install -q transformers datasets peft bitsandbytes accelerate

In [None]:
import os
import numpy as np
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, EvalPrediction, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, AutoPeftModelForCausalLM
import torch
import gc
import re

# Ensure CUDA is available: the 4-bit bitsandbytes load and fp16 training below need a GPU.
if not torch.cuda.is_available():
    raise SystemError("CUDA is not available. This script requires a CUDA-enabled GPU.")

# Allow TF32 matmuls and let cuDNN autotune kernels (favours speed over bit-exact reproducibility).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# --- Credentials ---
# Never hardcode API tokens in a script/notebook: anything committed or shared leaks them.
# NOTE(review): the tokens that used to be hardcoded here must be treated as compromised
# and revoked on huggingface.co and wandb.ai.
# Tokens are read from the environment; if one is missing (e.g. interactive notebook use)
# the user is prompted without echoing it.
from getpass import getpass

for _token_var in ("HF_TOKEN", "WANDB_API_KEY"):
    if not os.environ.get(_token_var):
        os.environ[_token_var] = getpass(f"Enter {_token_var}: ")

# Initialize Weights & Biases (optional, but good for tracking)
import wandb
wandb.login(key=os.environ["WANDB_API_KEY"])

# --- 1. Data Loading ---
print("Loading mathematics dataset from davidheineman/deepmind-math-large...")
dataset = load_dataset("davidheineman/deepmind-math-large", split='train')
# Hold out 0.5% for evaluation; the fixed seed makes the split deterministic across runs.
train_test_split = dataset.train_test_split(test_size=0.005, seed=123)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of test examples: {len(test_dataset)}")


def to_prompt_completion(ex):
    """Convert a raw question/answer row into the prompt/completion training format.

    The prompt must match the template used at inference time
    (``"Question: ...\\nAnswer:"`` with a real newline). The original code wrote ``\\\\n``,
    which put a literal backslash-n into every training prompt.
    """
    return {"prompt": f"Question: {ex['question']}\nAnswer:", "completion": " " + ex['answer']}


train = train_dataset.map(to_prompt_completion)
test = test_dataset.map(to_prompt_completion)

print("Example from training data:")
print(train[0])

# --- 2. Tokenization and Training ---
model_name = "deepseek-ai/deepseek-coder-1.3b-instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=os.environ["HF_TOKEN"],
)
# Reuse EOS as the padding token so every sequence can be padded to a fixed length.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(ex, max_length=128):
    """Tokenize a *batch* of prompt/completion pairs for causal-LM fine-tuning.

    Used with ``Dataset.map(..., batched=True)``, so each field of ``ex`` is a list
    holding one entry per example.

    Args:
        ex: Batch dict with parallel ``"prompt"`` and ``"completion"`` lists of strings.
        max_length: Length every sequence is truncated / padded to.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) plus ``labels``: a copy
        of ``input_ids`` with padding positions (``attention_mask == 0``) set to -100 so
        the loss ignores them.
    """
    # Join each prompt with its *own* completion. ``ex["prompt"] + ex["completion"]``
    # would concatenate the two lists, yielding separate prompt-only and
    # completion-only rows instead of full question+answer sequences.
    # An explicit EOS marks the end of the answer. Padding also uses EOS but is masked
    # out of the labels below, so this token is what teaches the model to stop.
    texts = [
        prompt + completion + tokenizer.eos_token
        for prompt, completion in zip(ex["prompt"], ex["completion"])
    ]
    tok = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Exclude padding from the loss; real tokens (including the explicit EOS) are kept.
    tok["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(tok["input_ids"], tok["attention_mask"])
    ]
    return tok

# Tokenize both splits in batches, dropping the text columns so only model inputs remain.
train_tok, test_tok = (
    split.map(tokenize_fn, batched=True, remove_columns=split.column_names)
    for split in (train, test)
)

def compute_metrics(eval_preds: "EvalPrediction"):
    """Score teacher-forced next-token predictions against the labels.

    In a causal LM the output at position ``t`` predicts the token at ``t + 1``, so
    predictions are shifted one step relative to the labels before comparing. The
    original compared them unshifted, which made ``exact_match`` essentially always 0.
    Positions whose label is -100 (ignored by the loss) are not scored.

    Args:
        eval_preds: ``(predictions, labels)``. ``predictions`` may be raw logits of shape
            ``(batch, seq, vocab)``, already-argmaxed ids of shape ``(batch, seq)``, or a
            tuple whose first element is one of those.

    Returns:
        ``exact_match``: fraction of sequences (with at least one scored position) whose
        every scored token is predicted correctly. ``token_accuracy``: fraction of scored
        tokens predicted correctly. Both are 0.0 when nothing is scored.
    """
    predictions, labels = eval_preds
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pred_ids = np.asarray(predictions)
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.argmax(axis=-1)
    labels = np.asarray(labels)

    # Causal-LM shift: the prediction at position t targets the label at t + 1.
    pred_ids = pred_ids[:, :-1]
    targets = labels[:, 1:]
    scored = targets != -100
    wrong = scored & (pred_ids != targets)

    n_scored = int(scored.sum())
    n_correct = n_scored - int(wrong.sum())
    token_accuracy = n_correct / n_scored if n_scored else 0.0

    # A sequence with nothing to score cannot count as an exact match.
    has_targets = scored.any(axis=1)
    n_sequences = int(has_targets.sum())
    n_exact = int((has_targets & ~wrong.any(axis=1)).sum())
    exact_match = n_exact / n_sequences if n_sequences else 0.0

    return {"exact_match": exact_match, "token_accuracy": token_accuracy}

# LoRA adapter settings: rank-8 adapters (alpha 16, i.e. update scaling alpha/r = 2) on the
# q_proj / v_proj modules, with dropout on the adapter input and no bias training.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

# Quantize the frozen base weights to 4-bit NF4 (no double quantization); compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.float16,
)

# --- 3. Refactored Training Loop ---
# All checkpoints are written under this directory; create it up front so it can be scanned.
output_dir = "./deepseek_mathpatch"
os.makedirs(name=output_dir, exist_ok=True)

def get_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Only sub-directories whose whole name is ``checkpoint-`` followed by digits count.
    Other entries (files, ``checkpoint-tmp`` and the like) are ignored instead of
    crashing the step parser.

    Args:
        output_dir: Directory the Trainer saves checkpoints into.

    Returns:
        The path of the checkpoint directory with the largest step number (using its
        real on-disk name), or ``None`` if ``output_dir`` is missing or has no checkpoints.
    """
    if not os.path.isdir(output_dir):
        return None
    by_step = {}
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            by_step[int(match.group(1))] = name
    if not by_step:
        return None
    return os.path.join(output_dir, by_step[max(by_step)])

latest_checkpoint = get_latest_checkpoint(output_dir)

# Always build the model the same way: 4-bit quantized base model + fresh LoRA adapter.
# When resuming, `trainer.train(resume_from_checkpoint=latest_checkpoint)` restores the
# adapter weights together with optimizer/scheduler state from that directory.
# Loading the checkpoint here with AutoPeftModelForCausalLM (as before) had two problems.
# It loaded the base model unquantized, because no quantization_config was passed.
# It also loaded the adapter frozen, because `is_trainable` defaults to False, which
# left nothing trainable.
if latest_checkpoint:
    print(f"Resuming training from latest checkpoint: {latest_checkpoint}")
else:
    print("No checkpoint found. Starting training from scratch.")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    token=os.environ["HF_TOKEN"],
    use_cache=False,  # KV cache is not used in training and conflicts with gradient checkpointing
)
model = get_peft_model(model, lora_config)

# With gradient checkpointing and a frozen embedding layer, the checkpointed blocks'
# inputs would not require grad. This makes the embedding output require grad so
# gradients still reach the LoRA weights.
model.enable_input_require_grads()
model.print_trainable_parameters()

# Evaluate and checkpoint on one shared cadence: load_best_model_at_end requires the
# save schedule to line up with the eval schedule.
eval_and_save_steps = 245

training_args = TrainingArguments(
    output_dir=output_dir,
    # 1 sequence per device x 8 accumulation steps = 8 sequences per optimizer step (per device).
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,  # pass max_steps=... instead to cap training at a fixed step count
    logging_steps=50,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=eval_and_save_steps,
    save_steps=eval_and_save_steps,
    save_total_limit=3,  # older checkpoints are pruned; the best one is kept
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="exact_match",
    greater_is_better=True,
    report_to="wandb",
    run_name="deepseek-math-finetune",
    gradient_checkpointing=True,  # recompute activations in backward to save memory
)

# Wire the model, tokenized splits and metric function into a Trainer.
# NOTE(review): compute_metrics receives full-vocabulary logits for the whole eval set,
# which may use a lot of host memory. Consider `preprocess_logits_for_metrics` to argmax
# on device.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_tok,
    eval_dataset=test_tok,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

# latest_checkpoint is None for a fresh run. Otherwise the Trainer resumes from the
# state saved in that checkpoint directory.
trainer.train(resume_from_checkpoint=latest_checkpoint)

print("Full training process completed.")

# Final Evaluation and Inference with the best model
print("\n--- Final Model Evaluation and Prediction ---")
# load_best_model_at_end records the best checkpoint; the newest checkpoint on disk is
# usually NOT the best one. Fall back to the newest only if no best was recorded
# (e.g. training stopped before the first evaluation).
best_checkpoint = trainer.state.best_model_checkpoint or get_latest_checkpoint(output_dir)
if best_checkpoint:
    print(f"Loading checkpoint for final evaluation: {best_checkpoint}")
    final_model = AutoPeftModelForCausalLM.from_pretrained(
        best_checkpoint,
        torch_dtype=torch.float16,
        token=os.environ["HF_TOKEN"],
        use_cache=True,  # re-enable the KV cache for generation
    )
    final_trainer = Trainer(
        model=final_model,
        tokenizer=tokenizer,
        args=training_args,
        # Without compute_metrics, evaluate() would report only the loss.
        compute_metrics=compute_metrics,
    )

    eval_results = final_trainer.evaluate(eval_dataset=test_tok)
    print(f"Final Evaluation Results: {eval_results}")

    print("\n--- Testing Final Inference ---")
    test_question = "What is (15 + 27)?"
    # This template must match the one used to build the training prompts.
    inputs = tokenizer(f"Question: {test_question}\nAnswer:", return_tensors="pt").to(final_model.device)
    with torch.no_grad():
        outputs = final_model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; decode only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print(f"Input: {test_question}")
    print(f"Output: {response}")
else:
    print("No checkpoint found; skipping final evaluation.")

# Finish the Weights & Biases run
wandb.finish()