In [11]:
!pip install dspy-ai==2.4.17 bitsandbytes accelerate transformers datasets



In [12]:
import torch
import gc
import dspy
import numpy as np
from datasets import load_dataset
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate import Evaluate

# ==========================================
# 1. AGGRESSIVE MEMORY CLEANUP
# ==========================================
# This fixes the "Zombie Process" OOM error
torch.cuda.empty_cache()
gc.collect()

# ==========================================
# 2. LOAD LIGHTWEIGHT MODEL (Llama-3.2-3B)
# ==========================================
# We switch to 3B (2.4GB VRAM) instead of 8B (5.7GB VRAM)
model_id = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
print(f"Loading lightweight model: {model_id}...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load Model
hf_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fix pad token to prevent infinite loops
tokenizer.pad_token = tokenizer.eos_token
hf_model.config.pad_token_id = tokenizer.eos_token_id

# ==========================================
# 3. DEFINE ROBUST WRAPPER
# ==========================================
class LocalLlamaWrapper(dspy.LM):
    def __init__(self, model, tokenizer):
        super().__init__("local-llama")
        self.model = model
        self.tokenizer = tokenizer

    def basic_request(self, prompt, **kwargs):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,   # Keep it concise
                do_sample=False,      # Deterministic
                pad_token_id=self.tokenizer.eos_token_id,
                temperature=0.0
            )
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return [response.strip()]

    def __call__(self, prompt=None, messages=None, **kwargs):
        if messages and not prompt:
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return self.basic_request(prompt or " ", **kwargs)

# Connect DSPy to the new light model
student_lm = LocalLlamaWrapper(hf_model, tokenizer)
dspy.settings.configure(lm=student_lm)
print("✅ Llama-3.2-3B Loaded Successfully.")

# ==========================================
# 4. DATA SETUP (SQUAD)
# ==========================================
print("Loading Data...")
dataset = load_dataset("squad", split="validation")

def convert_to_dspy(row):
    return dspy.Example(
        question=row['question'],
        context=row['context'],
        answer=row['answers']['text'][0]
    ).with_inputs('question', 'context')

# We use 10 train / 20 dev examples for speed
trainset = [convert_to_dspy(x) for x in dataset.select(range(0, 10))]
devset = [convert_to_dspy(x) for x in dataset.select(range(10, 30))]

class QASignature(dspy.Signature):
    """Answer questions based on the context. Keep answers concise and exact."""
    context = dspy.InputField(desc="facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="short exact answer")

class QAModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought(QASignature)

    def forward(self, question, context):
        return self.prog(question=question, context=context)

# ==========================================
# 5. BASELINE EVALUATION
# ==========================================
# Evaluate how the raw 3B model performs before optimization
print("\n--- Baseline Evaluation (Zero-Shot) ---")
# Metric: Exact Match logic
def exact_match_metric(gold, pred, trace=None):
    return gold.answer.lower() in pred.answer.lower()

evaluator = Evaluate(devset=devset, metric=exact_match_metric, num_threads=1, display_progress=True)
baseline_score = evaluator(QAModule())
print(f"Baseline Score: {baseline_score}%")

# ==========================================
# 6. THE IMPROVEMENT: COST-AWARE OPTIMIZATION
# ==========================================
# This metric penalizes the "rambling" common in small models
def cost_aware_metric(gold, pred, trace=None):
    # 1. Base Accuracy
    is_correct = gold.answer.lower() in pred.answer.lower()
    base_score = 1.0 if is_correct else 0.0

    # 2. Length Penalty (The Fix for 'Bulky' answers)
    length = len(pred.answer.split())
    penalty = np.exp(-0.1 * max(0, length - 5)) # Penalize if answer > 5 words

    return base_score * penalty

print("\n--- Starting Optimization (BootstrapFewShot) ---")
# BootstrapFewShot is lighter and more robust for 3B models than MIPRO
teleprompter = BootstrapFewShot(
    metric=cost_aware_metric,
    max_bootstrapped_demos=2, # Learn from 2 examples
    max_labeled_demos=2
)

print("Compiling...")
# This step teaches the 3B model to be "smart and concise"
optimized_program = teleprompter.compile(
    QAModule(),
    trainset=trainset
)

# ==========================================
# 7. FINAL EVALUATION
# ==========================================
print("\n--- Final Evaluation (Optimized) ---")
final_score = evaluator(optimized_program)

print(f"\n=== RESULTS ===")
print(f"Model: Llama-3.2-3B (Lightweight)")
print(f"Baseline: {baseline_score}%")
print(f"Optimized: {final_score}%")

# Show the optimized prompt that improved the result
print("\n--- Winning Prompt ---")
student_lm.inspect_history(n=1)

Loading lightweight model: unsloth/Llama-3.2-3B-Instruct-bnb-4bit...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

✅ Llama-3.2-3B Loaded Successfully.
Loading Data...

--- Baseline Evaluation (Zero-Shot) ---
Average Metric: 17 / 20  (85.0): 100%|██████████| 20/20 [01:39<00:00,  4.96s/it]
Baseline Score: 85.0%

--- Starting Optimization (BootstrapFewShot) ---
Compiling...


 20%|██        | 2/10 [00:12<00:51,  6.44s/it]


Bootstrapped 2 full traces after 3 examples in round 0.

--- Final Evaluation (Optimized) ---
Average Metric: 16 / 20  (80.0): 100%|██████████| 20/20 [02:00<00:00,  6.03s/it]

=== RESULTS ===
Model: Llama-3.2-3B (Lightweight)
Baseline: 85.0%
Optimized: 80.0%

--- Winning Prompt ---




