In [None]:
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
from peft import PeftModel

import torch
from radon.metrics import mi_visit
from tqdm.notebook import tqdm

from datasets import load_from_disk, Dataset

import signal

from copy import deepcopy

In [None]:
# Model to sample from and where to cache its downloaded weights.
MODEL_ID = "google/gemma-3-1b-it"
MODEL_CACHE_DIRECTORY = "./llm_models_cache"

# Generation below samples (do_sample=True), so seed the global NumPy and PyTorch
# RNGs to make Restart & Run All repeatable. torch.manual_seed seeds every device;
# bit-exact reproducibility of GPU sampling is still not guaranteed.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the base instruction-tuned model, its tokenizer and the preprocessed MBPP split.
# This notebook only samples from the model, so it is put in eval mode explicitly.
# Gradient checkpointing is deliberately NOT enabled: it only saves memory during
# backward passes and has no effect on inference.
# `trust_remote_code` is omitted: Gemma 3 is implemented in transformers itself, and the
# flag would otherwise allow executing Python code shipped with the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    cache_dir=MODEL_CACHE_DIRECTORY,
    device_map="cuda",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    cache_dir=MODEL_CACHE_DIRECTORY,
)

custom_dataset = load_from_disk("./mbpp_preprocessed_dataset")

In [5]:
def generate_prompt(prompt: str) -> list[dict]:
    """Build the chat messages (system turn + user turn) for one code-generation task.

    Args:
        prompt: Natural-language task description, used verbatim as the user turn.

    Returns:
        A two-element list of chat messages. Each message's ``content`` is a list
        holding a single ``{"type": "text", "text": ...}`` part, the shape passed to
        ``tokenizer.apply_chat_template`` elsewhere in this notebook.
    """
    system_instruction = "You are a code generation model. Your task is to generate code snippets based on user prompts. You are to only write the full code for the required function and stop after returning the function output."

    def text_turn(role: str, text: str) -> dict:
        # Every turn carries exactly one text part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    return [text_turn("system", system_instruction), text_turn("user", prompt)]

In [6]:
# Sample several candidate solutions per task from `model` using chat-templated prompts.
def evaluate_model(model, dataset, tokenizer, max_length=1024,
                   num_return_sequences=10, temperature=0.6, top_p=0.9):
    """Generate `num_return_sequences` sampled completions for every task in `dataset`.

    Args:
        model: Causal LM (or a PEFT wrapper around one) exposing ``generate`` and
            ``device``.
        dataset: Iterable of dict-like items with at least ``"task_id"`` and ``"text"``.
        tokenizer: Tokenizer with a chat template (used via ``apply_chat_template``).
        max_length: Maximum number of *new* tokens per completion (passed to
            ``generate`` as ``max_new_tokens``; the name is kept for compatibility).
        num_return_sequences: Completions sampled per task.
        temperature: Sampling temperature (lower = less diverse).
        top_p: Nucleus-sampling threshold (lower = less diverse).

    Returns:
        A list, in dataset order, of ``{"task_id": ..., "generated_sequences": [str, ...]}``
        where each string is a stripped completion without the prompt.
    """
    results = []

    for item in tqdm(dataset):
        messages = generate_prompt(item["text"])

        inputs = tokenizer.apply_chat_template(
            messages,
            add_special_tokens=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Decoding prompt + completion and
        # splitting on "model\n" (the previous approach) silently truncates any
        # completion that itself contains "model\n".
        prompt_length = inputs["input_ids"].shape[1]
        completions = tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=True
        )

        results.append(
            {
                "task_id": item["task_id"],
                "generated_sequences": [completion.strip() for completion in completions],
            }
        )

    return results

In [None]:
# Attach the DPO-finetuned adapter stored in the directory below on top of the base `model`.
peft_model = PeftModel.from_pretrained(
    model,
    "./gemma-3-1b-dpo-250-epochs",
    device_map="cuda",
)

# Sum the element counts of every parameter tensor reachable from the wrapped model.
total_params = sum(tensor.numel() for tensor in peft_model.parameters())
print(f"Total parameters after PEFT: {total_params}")

Total parameters after PEFT: 1012931712


In [8]:
# Turn off gradient tracking for every parameter of the wrapped model (base weights
# and LoRA adapter alike); this notebook only samples from it.
peft_model.requires_grad_(False)

In [9]:
class TimeoutException(Exception):
    """Raised when a call guarded by ``signal.alarm`` exceeds its deadline."""


def timeout_handler(signum, frame):
    """SIGALRM handler: raise ``TimeoutException`` so the interrupted code unwinds."""
    message = "Function call timed out"
    raise TimeoutException(message)


# Route SIGALRM to the handler above, so that `signal.alarm(n)` interrupts code that
# runs too long (SIGALRM is Unix-only).
signal.signal(signal.SIGALRM, timeout_handler)

<Handlers.SIG_DFL: 0>

In [10]:
def test_one_example(example):
    temp_namespace = {}

    try:
        exec(example["code"], temp_namespace)
    except Exception as e:
        print(f"Error executing code for example id {example['task_id']}: {e}")
        return [False, ["Code execution error"]]

    # print(f"Running tests for example id: {example['task_id']}")

    failed_tests = []

    all_tests_passed = True
    for i, test in enumerate(example['test_list']):
        try:
            signal.alarm(5)  # Set the alarm for 5 seconds
            exec(test, temp_namespace)
            signal.alarm(0)  # Disable the alarm
            # print(f"Test {i+1} passed.")
        except AssertionError as e:
            all_tests_passed = False
            print(f"Test {i+1} failed.")
            print(f"  -> AssertionError: {e}")
            failed_tests.append(i+1)
        except TimeoutException as e:
            all_tests_passed = False
            print(f"Test {i+1} failed.")
            print(f"  -> Timeout: {e}")
            failed_tests.append(i+1)
        except Exception as e:
            all_tests_passed = False
            print(f"Test {i+1} failed.")
            print(f"  -> Exception: {e}")
            failed_tests.append(i+1)
        finally:
            signal.alarm(0)  # Disable the alarm always
    
    del temp_namespace

    return [all_tests_passed, failed_tests] # Empty list if all tests passed

In [11]:
# Display the dataset's columns and row count. Later cells rely on the
# 'task_id', 'text' and 'test_list' columns.
custom_dataset

Dataset({
    features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],
    num_rows: 372
})

In [12]:
# Index the dataset rows by task_id, so prompts and tests can be looked up directly
# while grading generated solutions.
custom_dataset_dict = {row["task_id"]: row for row in custom_dataset}

In [14]:
# Accumulates the passing generations for every task:
#   {task_id: {"valid_solutions": [str, ...]}}
# Each task gets its own fresh list.
new_dpo_dataset = {item["task_id"]: {"valid_solutions": []} for item in custom_dataset}

In [None]:
# Self-feedback data generation with the DPO-finetuned model:
#   1. sample candidate solutions for every task still in `current_dataset`,
#   2. run each candidate against the task's MBPP tests,
#   3. keep passing candidates in `new_dpo_dataset`,
#   4. re-queue tasks that have fewer than MIN_VALID_SOLUTIONS passing candidates,
#      appending that task's last failing attempt to the prompt as feedback.
# NOTE: this appends to `new_dpo_dataset`; re-run its initialisation cell before
# re-running this one, or solutions are counted twice.
# SECURITY: `test_one_example` execs model-generated code in this process.
NUM_REFINEMENT_ROUNDS = 5
MIN_VALID_SOLUTIONS = 2

current_dataset = custom_dataset.to_list()
total_len = len(current_dataset)

for i in range(NUM_REFINEMENT_ROUNDS):
    if not current_dataset:
        break  # every task already has enough passing solutions

    generated_dataset = evaluate_model(peft_model, current_dataset, tokenizer, max_length=1024)
    current_dataset = []  # tasks to retry in the next round

    for task in generated_dataset:
        task_id = task["task_id"]
        reference = custom_dataset_dict[task_id]
        print(f"Evaluating output {i + 1} for task {task_id}")

        # Reset per task, so feedback can never come from another task's attempt.
        last_failed_solution = None
        for generated_solution in task["generated_sequences"]:
            passed, _ = test_one_example({
                "task_id": task_id,
                # Strip markdown fences for execution; the raw generation is what gets stored.
                "code": generated_solution.replace("```python", "").replace("```", ""),
                "test_list": reference["test_list"],
                "test_setup_code": reference.get("test_setup_code", ""),
            })
            if passed:
                new_dpo_dataset[task_id]["valid_solutions"].append(generated_solution)
            else:
                last_failed_solution = generated_solution

        if len(new_dpo_dataset[task_id]["valid_solutions"]) < MIN_VALID_SOLUTIONS:
            # Start from the pristine example, so feedback does not pile up across rounds.
            retry_example = deepcopy(reference)
            if last_failed_solution is not None:
                retry_example["text"] += f"\nThe previous solution (provided below) did not pass all test cases. Please try again.\nFailed solution:\n{last_failed_solution}\n"
            current_dataset.append(retry_example)

    print(f"{total_len - len(current_dataset)} examples so far after iteration {i+1}")


  0%|          | 0/372 [00:00<?, ?it/s]

Evaluating output 1 for task 602
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Evaluating output 1 for task 603
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> Timeout: Function call timed out
Test 3 failed.
  -> Timeout: Function call timed out
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.

  0%|          | 0/189 [00:00<?, ?it/s]

Evaluating output 2 for task 602
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Evaluating output 2 for task 603
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 fa

  0%|          | 0/189 [00:00<?, ?it/s]

Evaluating output 3 for task 602
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Evaluating output 3 for task 603
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 failed.
  -> AssertionError: 
Test 2 failed.
  -> AssertionError: 
Test 3 failed.
  -> AssertionError: 
Test 1 fa

In [None]:
# Keep only the tasks that collected at least 2 passing solutions, since a
# chosen/rejected pair needs two candidates. The entries are the same dict objects
# as in `new_dpo_dataset` (not copies).
final_new_dpo_dataset = {
    task_id: entry
    for task_id, entry in new_dpo_dataset.items()
    if len(entry["valid_solutions"]) >= 2
}

In [None]:
# Rebuild the task_id -> example lookup over the FULL dataset. No filtering happens
# here: every task is kept, and only the keys present in `final_new_dpo_dataset` are
# looked up later.
custom_dataset_dict = {example_row["task_id"]: example_row for example_row in custom_dataset}

In [None]:
# Score every passing solution with radon's Maintainability Index (MI). The solution
# with the highest MI is taken as "best" (chosen) and the one with the lowest MI as
# "worst" (rejected).
for key, value in final_new_dpo_dataset.items():
    # Strip the markdown fences exactly as was done before running the tests, so MI is
    # computed on the very source that passed. The earlier pattern "python```" never
    # matched, which left a stray `python` line in every scored snippet.
    mi_scores = [
        mi_visit(solution.replace("```python", "").replace("```", ""), True)
        for solution in value["valid_solutions"]
    ]

    best_index = int(np.argmax(mi_scores))
    worst_index = int(np.argmin(mi_scores))

    # `value` is the same dict object as final_new_dpo_dataset[key].
    value["mi_scores"] = mi_scores
    value["best_solution"] = value["valid_solutions"][best_index]
    value["worst_solution"] = value["valid_solutions"][worst_index]

In [None]:
# Number of tasks with at least 2 valid solutions, i.e. the size of
# `final_new_dpo_dataset` (tasks eligible for a chosen/rejected pair).
len(final_new_dpo_dataset)

183


In [None]:
# Keep only the tasks whose best- and worst-MI solutions differ textually. An identical
# pair would give DPO no preference signal.
filtered_same = {
    task_id: entry
    for task_id, entry in final_new_dpo_dataset.items()
    if entry["best_solution"] != entry["worst_solution"]
}

In [None]:
# Resulting number of tasks whose best- and worst-MI solutions differ. Each of them
# becomes one (prompt, chosen, rejected) row of the saved DPO dataset.
len(filtered_same)

84


In [None]:
# NOTE(review): this redefines the `generate_prompt` from earlier in the notebook and
# shadows it. The two must stay identical, so that the DPO prompts match the prompts the
# solutions were sampled with; consider deleting this duplicate cell.
def generate_prompt(prompt: str) -> list[dict]:
    """Build the chat messages (system turn + user turn) for one code-generation task.

    Args:
        prompt: Natural-language task description, used verbatim as the user turn.

    Returns:
        A two-element list of chat messages. Each message's ``content`` is a list
        holding a single ``{"type": "text", "text": ...}`` part.
    """
    system_instruction = "You are a code generation model. Your task is to generate code snippets based on user prompts. You are to only write the full code for the required function and stop after returning the function output."

    def text_turn(role: str, text: str) -> dict:
        # Every turn carries exactly one text part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    return [text_turn("system", system_instruction), text_turn("user", prompt)]

In [None]:
# Build the (prompt, chosen, rejected) triples of the local DPO dataset:
# - prompt: chat-formatted ORIGINAL task text. Retry feedback that may have been in
#   the prompt when a solution was sampled is not included.
# - chosen / rejected: the highest- / lowest-MI passing solutions.
dpo_dataset_1b = [
    (
        generate_prompt(custom_dataset_dict[int(task_id)]["text"]),
        entry["best_solution"],
        entry["worst_solution"],
    )
    for task_id, entry in filtered_same.items()
]

In [None]:
# Persist the preference triples as a Hugging Face `datasets` folder. The data goes
# through a pandas DataFrame with one column per element of each triple.
columns = ["prompt", "chosen", "rejected"]
df = pd.DataFrame(dpo_dataset_1b, columns=columns)

dpo_dataset_1b_ds = Dataset.from_pandas(df)
dpo_dataset_1b_ds.save_to_disk(
    "./further_dpo_dataset_gemma3_1b_250_epochs"
)