In [None]:
from datasets import load_from_disk
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np

import torch
import os
import gzip
import json
import fileinput # Used to uncomment the execution line
from typing import Iterable, Dict, List
from tqdm.notebook import tqdm

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

In [None]:
MODEL_ID = "google/gemma-3-4b-it"
MODEL_CACHE_DIRECTORY = "./llm_models_cache"

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    cache_dir=MODEL_CACHE_DIRECTORY,
    device_map="cuda",
    trust_remote_code=True
)

# Enable gradient checkpointing to trade compute for memory
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    cache_dir=MODEL_CACHE_DIRECTORY,
    trust_remote_code=True
)

test_dataset = load_from_disk("./mbpp_test_with_signatures")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Check max token length of the chosen code
max_code_length = max([len(tokenizer.encode(item['code'])) for item in test_dataset])
print(f"Max code length: {max_code_length}")


Max code length: 512


In [6]:
def generate_prompt(prompt: str) -> list[dict]:
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a code generation model. Your task is to generate code snippets based on user prompts. You are to only write the full code for the required function and stop after returning the function output."},]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt},]
        },
    ]
    return messages

In [7]:
# Evaluate the baseline model on the test set with chat-templated prompts
def evaluate_model(model, dataset, tokenizer, max_length=1024):
    results = []

    for item in tqdm(dataset):
        results.append(
            {
                "task_id": item['task_id'],
                "generated_sequences": []
            }
        )

        input_prompt = item['text']
        messages = generate_prompt(input_prompt)
        
        inputs = tokenizer.apply_chat_template(
            messages,
            add_special_tokens = True,
            tokenize = True,
            return_dict = True,
            return_tensors = "pt",
            add_generation_prompt = True
        ).to("cuda")

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                top_p=0.9, # Adjust for more/less diversity
                temperature=0.6, # Adjust for more/less diversity
                num_return_sequences=10,
                pad_token_id=tokenizer.eos_token_id
            )

        for i in range(outputs.shape[0]):
            decoded_output = tokenizer.decode(outputs[i], skip_special_tokens=True)
            model_output = decoded_output.split("model\n")[-1].strip() # Extract the code part only
            results[-1]["generated_sequences"].append(model_output)

    return results

In [8]:
# Check total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 4300079472


In [None]:
peft_model = PeftModel.from_pretrained(
    model,
    "./gemma-3-4b-dpo-250-epochs-further-250-epochs",
    device_map="cuda"
    )

total_params = sum(p.numel() for p in peft_model.parameters())
print(f"Total parameters after PEFT: {total_params}")

Total parameters after PEFT: 4332867952


In [10]:
# Freeze the lora model
for param in peft_model.parameters():
    param.requires_grad = False

In [None]:
# Perform evaluation with the DPO finetuned model
for i in range(1, 6):
    dpo_results = evaluate_model(peft_model, test_dataset, tokenizer)
    # Save the results to a JSON file
    with open(f"gemma_4b_dpo_250_further_250_model_output_run_{i}.json", "w") as f:
        json.dump(dpo_results, f)