# Compare Original vs Pruned Model
This notebook compares the original Qwen2.5-Coder-3B-Instruct with your pruned model.

We will test both models on 20 Java coding problems and compare:
- **Speed**: How fast each model generates code
- **Accuracy**: How many problems each model solves correctly

## 1. Setup and Imports

In [1]:
import torch
import time
import json
import os
import re
import subprocess
from transformers import AutoTokenizer, AutoModelForCausalLM

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    
# Check for transformers version
import transformers
print(f"Transformers version: {transformers.__version__}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.5.1+cu121
CUDA available: True
GPU: NVIDIA GeForce GTX 1060 3GB
Transformers version: 4.57.3


In [2]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
print("CUDA version:", torch.version.cuda)


CUDA available: True
Device name: NVIDIA GeForce GTX 1060 3GB
CUDA version: 12.1


## 2. Load Test Problems
We'll use the first 20 problems from the Java benchmark

In [4]:
# Simple test problems (first 20)
benchmark = [
  {"id": "java_001_is_prime", "prompt": "Write a Java method isPrime that returns true if n is a prime number, otherwise false.", "signature": "public static boolean isPrime(int n)", "tests": [{"input": "2", "expected": "true"}, {"input": "4", "expected": "false"}, {"input": "17", "expected": "true"}]},
  {"id": "java_002_reverse_string", "prompt": "Write a Java method reverseString that returns the reversed version of the input string.", "signature": "public static String reverseString(String s)", "tests": [{"input": "\"abc\"", "expected": "\"cba\""}, {"input": "\"hello\"", "expected": "\"olleh\""}]},
  {"id": "java_003_sum_array", "prompt": "Write a Java method sumArray that returns the sum of all elements in the given integer array.", "signature": "public static int sumArray(int[] arr)", "tests": [{"input": "new int[]{1, 2, 3}", "expected": "6"}, {"input": "new int[]{0, 0, 0}", "expected": "0"}]},
  {"id": "java_004_factorial", "prompt": "Write a Java method factorial that returns n! (n factorial). Assume n is non-negative.", "signature": "public static long factorial(int n)", "tests": [{"input": "0", "expected": "1"}, {"input": "5", "expected": "120"}]},
  {"id": "java_005_max_in_array", "prompt": "Write a Java method maxInArray that returns the maximum value in the given integer array.", "signature": "public static int maxInArray(int[] arr)", "tests": [{"input": "new int[]{1, 2, 3}", "expected": "3"}, {"input": "new int[]{-1, -5, -3}", "expected": "-1"}]},
  {"id": "java_006_min_in_array", "prompt": "Write a Java method minInArray that returns the minimum value in the given integer array.", "signature": "public static int minInArray(int[] arr)", "tests": [{"input": "new int[]{1, 2, 3}", "expected": "1"}, {"input": "new int[]{-1, -5, -3}", "expected": "-5"}]},
  {"id": "java_007_is_palindrome", "prompt": "Write a Java method isPalindrome that returns true if the given string is a palindrome.", "signature": "public static boolean isPalindrome(String s)", "tests": [{"input": "\"racecar\"", "expected": "true"}, {"input": "\"abc\"", "expected": "false"}]},
  {"id": "java_008_count_vowels", "prompt": "Write a Java method countVowels that returns the number of vowels in the given string.", "signature": "public static int countVowels(String s)", "tests": [{"input": "\"hello\"", "expected": "2"}, {"input": "\"AEIOU\"", "expected": "5"}]},
  {"id": "java_009_fibonacci", "prompt": "Write a Java method fibonacci that returns the n-th Fibonacci number.", "signature": "public static int fibonacci(int n)", "tests": [{"input": "0", "expected": "0"}, {"input": "5", "expected": "5"}]},
  {"id": "java_010_find_index", "prompt": "Write a Java method findIndex that returns the index of target in the array, or -1 if not found.", "signature": "public static int findIndex(int[] arr, int target)", "tests": [{"input": "new int[]{1, 2, 3, 4}, 3", "expected": "2"}, {"input": "new int[]{1, 2, 3, 4}, 5", "expected": "-1"}]},
  {"id": "java_011_contains_duplicate", "prompt": "Write a Java method containsDuplicate that returns true if any value appears at least twice in the array.", "signature": "public static boolean containsDuplicate(int[] nums)", "tests": [{"input": "new int[]{1, 2, 3, 1}", "expected": "true"}, {"input": "new int[]{1, 2, 3, 4}", "expected": "false"}]},
  {"id": "java_012_max_subarray_sum", "prompt": "Write a Java method maxSubArraySum that returns the largest sum of a contiguous subarray.", "signature": "public static int maxSubArraySum(int[] nums)", "tests": [{"input": "new int[]{-2,1,-3,4,-1,2,1,-5,4}", "expected": "6"}, {"input": "new int[]{1}", "expected": "1"}]},
  {"id": "java_013_two_sum", "prompt": "Write a Java method hasTwoSum that returns true if there exist two distinct indices i and j such that nums[i] + nums[j] == target.", "signature": "public static boolean hasTwoSum(int[] nums, int target)", "tests": [{"input": "new int[]{2, 7, 11, 15}, 9", "expected": "true"}, {"input": "new int[]{1, 2, 3}, 10", "expected": "false"}]},
  {"id": "java_014_is_anagram", "prompt": "Write a Java method isAnagram that returns true if two given strings are anagrams of each other.", "signature": "public static boolean isAnagram(String s, String t)", "tests": [{"input": "\"anagram\", \"nagaram\"", "expected": "true"}, {"input": "\"rat\", \"car\"", "expected": "false"}]},
  {"id": "java_015_remove_whitespace", "prompt": "Write a Java method removeWhitespace that returns a new string with all whitespace removed.", "signature": "public static String removeWhitespace(String s)", "tests": [{"input": "\"a b c\"", "expected": "\"abc\""}, {"input": "\"   hello   world   \"", "expected": "\"helloworld\""}]},
  {"id": "java_016_power", "prompt": "Write a Java method power that returns x raised to the power n (x^n).", "signature": "public static long power(int x, int n)", "tests": [{"input": "2, 3", "expected": "8"}, {"input": "5, 2", "expected": "25"}]},
  {"id": "java_017_is_sorted", "prompt": "Write a Java method isSortedAscending that returns true if the array is sorted in ascending order.", "signature": "public static boolean isSortedAscending(int[] arr)", "tests": [{"input": "new int[]{1, 2, 3, 4}", "expected": "true"}, {"input": "new int[]{3, 2, 1}", "expected": "false"}]},
  {"id": "java_018_second_largest", "prompt": "Write a Java method secondLargest that returns the second largest distinct number in the array.", "signature": "public static int secondLargest(int[] arr)", "tests": [{"input": "new int[]{1, 2, 3}", "expected": "2"}, {"input": "new int[]{5, 1, 5, 2}", "expected": "2"}]},
  {"id": "java_019_is_rotation", "prompt": "Write a Java method isRotation that returns true if string b is a rotation of string a.", "signature": "public static boolean isRotation(String a, String b)", "tests": [{"input": "\"abcde\", \"cdeab\"", "expected": "true"}, {"input": "\"abcde\", \"abced\"", "expected": "false"}]},
  {"id": "java_020_valid_parentheses", "prompt": "Write a Java method isValidParentheses that returns true if the input string containing brackets is valid.", "signature": "public static boolean isValidParentheses(String s)", "tests": [{"input": "\"()\"", "expected": "true"}, {"input": "\"([)]\"", "expected": "false"}]}
]

print(f"Loaded {len(benchmark)} test problems")

Loaded 20 test problems


## 3. Load Models

**IMPORTANT**: Update the `PRUNED_MODEL_PATH` to point to your pruned model folder!

In [5]:
pruned_model_50_path = "C:/Users/namnd/Documents/QwenCoder-50"

pruned_model_5_path = "C:/Users/namnd/Documents/QwenCoder-5"

In [52]:
# Load pruned model (5%)
print(" Loading PRUNED model...")
pruned_tokenizer_5 = AutoTokenizer.from_pretrained(
    pruned_model_5_path,
    trust_remote_code=True,
    local_files_only=True
)

pruned_model_5 = AutoModelForCausalLM.from_pretrained(
    pruned_model_5_path,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,      # Optimize CPU memory usage
    trust_remote_code=True,
    local_files_only=True
)
print(f"✓ Pruned model loaded on GPU\n")
print("5% model device:", next(pruned_model_5.parameters()).device)

 Loading PRUNED model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 54.13it/s]

✓ Pruned model loaded on GPU

5% model device: cpu





In [6]:
# Load pruned model

print(" Loading PRUNED model...")
pruned_tokenizer_50 = AutoTokenizer.from_pretrained(
    pruned_model_50_path,
    trust_remote_code=True,
    local_files_only=True,
    device_map="auto"
)

pruned_model_50 = AutoModelForCausalLM.from_pretrained(
    pruned_model_50_path,
    torch_dtype=torch.float16,  # Use float16 for CPU
    low_cpu_mem_usage=True,      # Optimize CPU memory usage
    trust_remote_code=True,
    local_files_only=True,
    device_map="auto"
)
print(f"✓ Pruned model loaded on GPU\n")
print("50% model device:", next(pruned_model_50.parameters()).device)

 Loading PRUNED model...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.75it/s]
Some parameters are on the meta device because they were offloaded to the cpu.


✓ Pruned model loaded on GPU

50% model device: cuda:0


## 4. Helper Functions
These functions help us generate code and test it

In [8]:
def generate_code(model, tokenizer, task):
    """
    Generate Java code for a given task using a model.
    Returns: (generated_code, time_taken_in_seconds)
    """
    signature = task["signature"]

    # Create prompt
    prompt = f"""
    You are a strict Java code generator.
  Write ONLY the Java method, this task
You MUST follow these unbreakable rules:

1. Output EXACTLY ONE Java method.
2. ZERO explanations.
3. ZERO repeated methods.
4. ZERO comments.
5. ZERO blank copies of the method.
6. Output MUST start with 'public static'.
7. Output MUST end with the closing bracket }} of that method.
Write ONLY the Java method, NO EXPLANATION, ONLY WRITE THE METHOD ONE TIME for this task:
{task['prompt']}

Signature: {signature}

Write the complete method:"""

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate and measure time
    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.2,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    elapsed_time = time.time() - start_time

    # Decode only the new tokens
    gen_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    code = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    # Clean up the output
    code = code.replace("```java", "").replace("```", "").strip()

    # Extract just the method
    if "public static" in code:
        code = code[code.index("public static"):]
        if "}" in code:
            code = code[:code.rfind("}")+1]

    return code, elapsed_time


def test_code(task, method_code):
    """
    Test if the generated code passes all test cases.
    Returns: True if all tests pass, False otherwise
    """
    try:
        # Parse signature to get return type and method name
        sig_match = re.search(r'public\s+static\s+(\S+)\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(', task["signature"])
        if not sig_match:
            return False

        return_type, method_name = sig_match.group(1), sig_match.group(2)

        # Build test code
        test_calls = []
        for i, test in enumerate(task["tests"], 1):
            inp, expected = test["input"], test["expected"]

            if return_type == "String":
                condition = f"!{method_name}({inp}).equals({expected})"
            else:
                condition = f"{method_name}({inp}) != {expected}"

            test_calls.append(f"if ({condition}) throw new Exception(\"Test {i} failed\");")

        # Create full Java file
        java_code = f"""
public class Main {{
    {method_code}

    public static void main(String[] args) {{
        try {{
            {chr(10).join('            ' + tc for tc in test_calls)}
            System.out.println("OK");
        }} catch (Exception e) {{
            System.out.println("FAIL");
        }}
    }}
}}
"""

        # Write to file
        os.makedirs("temp", exist_ok=True)
        with open("temp/Main.java", "w") as f:
            f.write(java_code)

        # Compile
        compile_result = subprocess.run(["javac", "temp/Main.java"], capture_output=True, text=True)
        if compile_result.returncode != 0:
            return False

        # Run
        run_result = subprocess.run(["java", "-cp", "temp", "Main"], capture_output=True, text=True, timeout=5)
        return run_result.stdout.strip() == "OK"

    except Exception:
        return False

print("Helper functions ready!")

Helper functions ready!


In [9]:
output_path = "C:/Users/namnd/Documents/Pruning-LLM/output"
if os.path.exists(output_path):
   print("File found:", output_path)
else:
   print("File not found!")

File found: C:/Users/namnd/Documents/Pruning-LLM/output


## 5. Run Comparison
Test both models on all problems

In [10]:
output_file = open(os.path.join(output_path, "model_comparison_results.txt"), "w")
# Store accuracy & speed results
pruned_5_results = []
pruned_50_results = []


print("Starting comparison...\n")
print("=" * 60)

for i, task in enumerate(benchmark, 1):
    task_id = task["id"]
    print(f"\n[{i}/{len(benchmark)}] Testing: {task_id}")

    # ==========================================
    # 50% PRUNED MODEL
    # ==========================================
    print("  → 50% Pruned model...", end=" ")
    pruned_50_code, pruned_50_time = generate_code(pruned_model_50, pruned_tokenizer_50, task)
    pruned_50_pass = test_code(task, pruned_50_code)
    pruned_50_results.append({
        "id": task_id, "passed": pruned_50_pass, "time": pruned_50_time
    })
    print(f"{'✓ PASS' if pruned_50_pass else '✗ FAIL'} ({pruned_50_time:.2f}s)")

    # ==========================================
    # WRITE RESULTS TO SINGLE OUTPUT FILE
    # ==========================================
    output_file.write("\n" + "="*80 + "\n")
    output_file.write(f"TASK: {task_id}\n")
    output_file.write("="*80 + "\n\n")

    output_file.write("--- 50% PRUNED MODEL OUTPUT ---\n")
    output_file.write(pruned_50_code + "\n\n")

print("\n" + "="*60)
print("Comparison complete!")
print(f"Saved everything to: {os.path.join(output_path, 'model_comparison_results.txt')}")

output_file.close()


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Starting comparison...


[1/20] Testing: java_001_is_prime
  → 50% Pruned model... ✗ FAIL (192.81s)

[2/20] Testing: java_002_reverse_string
  → 50% Pruned model... ✗ FAIL (73.29s)

[3/20] Testing: java_003_sum_array
  → 50% Pruned model... ✗ FAIL (84.24s)

[4/20] Testing: java_004_factorial
  → 50% Pruned model... ✗ FAIL (305.16s)

[5/20] Testing: java_005_max_in_array
  → 50% Pruned model... ✗ FAIL (160.14s)

[6/20] Testing: java_006_min_in_array
  → 50% Pruned model... 

KeyboardInterrupt: 

In [None]:
   # ==========================================
    # 5% PRUNED MODEL
    # ==========================================
    print("  → 5% Pruned model...", end=" ")
    pruned_5_code, pruned_5_time = generate_code(pruned_model_5, pruned_tokenizer_5, task)
    pruned_5_pass = test_code(task, pruned_5_code)
    pruned_5_results.append({
        "id": task_id, "passed": pruned_5_pass, "time": pruned_5_time
    })
    print(f"{'✓ PASS' if pruned_5_pass else '✗ FAIL'} ({pruned_5_time:.2f}s)")


## 6. Results and Analysis

In [None]:
# Calculate statistics
original_passed = sum(1 for r in original_results if r["passed"])
pruned_passed = sum(1 for r in pruned_results if r["passed"])

original_avg_time = sum(r["time"] for r in original_results) / len(original_results)
pruned_avg_time = sum(r["time"] for r in pruned_results) / len(pruned_results)

total_tests = len(benchmark)

# Display results
print("\n" + "="*60)
print("                    FINAL RESULTS")
print("="*60)
print()
print(f"Total Problems: {total_tests}")
print()
print("ACCURACY (How many problems solved correctly):")
print(f"  Original Model: {original_passed}/{total_tests} = {original_passed/total_tests*100:.1f}%")
print(f"  Pruned Model:   {pruned_passed}/{total_tests} = {pruned_passed/total_tests*100:.1f}%")
print(f"  Difference:     {pruned_passed - original_passed} problems ({(pruned_passed - original_passed)/total_tests*100:+.1f}%)")
print()
print("SPEED (Average time per problem):")
print(f"  Original Model: {original_avg_time:.3f} seconds")
print(f"  Pruned Model:   {pruned_avg_time:.3f} seconds")
print(f"  Speedup:        {original_avg_time/pruned_avg_time:.2f}x {'faster' if pruned_avg_time < original_avg_time else 'slower'}")
print()
print("="*60)

# Show which problems each model got wrong
original_failed = [r["id"] for r in original_results if not r["passed"]]
pruned_failed = [r["id"] for r in pruned_results if not r["passed"]]

if original_failed:
    print(f"\nOriginal model failed on: {', '.join(original_failed)}")
if pruned_failed:
    print(f"Pruned model failed on: {', '.join(pruned_failed)}")

# Save generated code to files
os.makedirs("../outputs", exist_ok=True)

# Check if generated code was captured
if original_results and "generated_code" in original_results[0]:
    # Save original model solutions
    with open("../outputs/original_model_solutions.txt", "w", encoding="utf-8") as f:
        f.write("="*80 + "\n")
        f.write("ORIGINAL MODEL GENERATED SOLUTIONS\n")
        f.write("="*80 + "\n\n")
        for result in original_results:
            f.write(f"\n{'='*80}\n")
            f.write(f"Problem: {result['id']}\n")
            f.write(f"Status: {'PASS ✓' if result['passed'] else 'FAIL ✗'}\n")
            f.write(f"Time: {result['time']:.2f}s\n")
            f.write(f"{'='*80}\n\n")
            f.write(result['generated_code'])
            f.write("\n\n")
    print("\n✓ Original model solutions saved to: outputs/original_model_solutions.txt")
    
    # Save pruned model solutions
    with open("../outputs/pruned_model_solutions.txt", "w", encoding="utf-8") as f:
        f.write("="*80 + "\n")
        f.write("PRUNED MODEL GENERATED SOLUTIONS\n")
        f.write("="*80 + "\n\n")
        for result in pruned_results:
            f.write(f"\n{'='*80}\n")
            f.write(f"Problem: {result['id']}\n")
            f.write(f"Status: {'PASS ✓' if result['passed'] else 'FAIL ✗'}\n")
            f.write(f"Time: {result['time']:.2f}s\n")
            f.write(f"{'='*80}\n\n")
            f.write(result['generated_code'])
            f.write("\n\n")
    print("✓ Pruned model solutions saved to: outputs/pruned_model_solutions.txt")
else:
    print("\n⚠ Generated code not available in current results.")
    print("To capture generated code, modify cell 12 to add 'generated_code' to results:")
    print("  original_results.append({'id': ..., 'passed': ..., 'time': ..., 'generated_code': orig_code})")
    print("  pruned_results.append({'id': ..., 'passed': ..., 'time': ..., 'generated_code': pruned_code})")

## 6.5 Save Generated Code
Re-generate and save the code solutions from both models

## 7. Save Results (Optional)

## 8. Interactive Prompting
Test both models with custom prompts