In [1]:
# Install the latest compatible versions of the required libraries
!pip install transformers datasets torch bitsandbytes accelerate -U -q

print("Libraries installed/updated.")
print("‼️ IMPORTANT: Please restart the runtime now (Runtime -> Restart session) before proceeding.")

Libraries installed/updated.
‼️ IMPORTANT: Please restart the runtime now (Runtime -> Restart session) before proceeding.


In [2]:
# Function definitions
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
import os
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def measure_vram(load_model_func):
    if DEVICE == "cpu":
        # No VRAM to measure on CPU
        model = load_model_func()
        return model, 0

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(DEVICE)
    vram_before = torch.cuda.max_memory_allocated(DEVICE)
    model = load_model_func()
    torch.cuda.synchronize(DEVICE)
    vram_after = torch.cuda.max_memory_allocated(DEVICE)
    vram_used = (vram_after - vram_before) / 1e6  # Convert to MB
    return model, vram_used

def load_quantized_model():
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    m = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    return m


def get_vram_usage():
    if DEVICE == "cpu":
        return 0
    return torch.cuda.memory_allocated() / 1e6

def clear_vram():
    if DEVICE == "cpu":
        return
    gc.collect()
    torch.cuda.empty_cache()


# --- Data Storage ---
model_data = []

print(f"Running on device: {DEVICE}")
if DEVICE == "cpu":
    print("WARNING: This notebook is optimized for GPU. CPU execution will be slow and VRAM will not be measured.")

Running on device: cpu


In [3]:
# Load models
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"

# Load BF16 (Baseline) Model
print(f"Loading: BF16 (Baseline) for {MODEL_NAME}")

def load_bf16_model():
    m = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map=DEVICE
    )
    return m

model_bf16, vram_bf16 = measure_vram(load_bf16_model)
model_data.append({
    "name": "BF16 (Baseline)",
    "model": model_bf16,
    "vram_mb": vram_bf16
})
print(f"Model loaded. Peak VRAM used by this model: {vram_bf16:.2f} MB")

# Load INT8 (8-bit) Model
print(f"Loading: INT8 (8-bit) for {MODEL_NAME}")

def load_int8_model():
    config = BitsAndBytesConfig(load_in_8bit=True)
    m = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    return m

model_int8, vram_int8 = measure_vram(load_int8_model)
model_data.append({
    "name": "INT8 (8-bit)",
    "model": model_int8,
    "vram_mb": vram_int8
})
print(f"Model loaded. Peak VRAM used by this model: {vram_int8:.2f} MB")

# Load NF4 (4-bit) Model
print(f"Loading: NF4 (4-bit) for {MODEL_NAME}")

def load_nf4_model():
    config = BitsAndBytesConfig(load_in_4bit=True)
    m = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    return m

model_nf4, vram_nf4 = measure_vram(load_nf4_model)
model_data.append({
    "name": "NF4 (4-bit)",
    "model": model_nf4,
    "vram_mb": vram_nf4
})
print(f"Model loaded. Peak VRAM used by this model: {vram_nf4:.2f} MB\n")
print(f"All models loaded.")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("Tokenizer loaded.")

Loading: BF16 (Baseline) for Qwen/Qwen2-1.5B-Instruct


config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


ValueError: Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` requires `accelerate`. You can install it with `pip install accelerate`

In [None]:
# --- Generate text ---
import random

print("Generating Text...\n")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

SEED = int.from_bytes(os.urandom(4), byteorder="big")
PROMPT = "In a galaxy far, far away," # Try different prompts

for data in model_data:
    print(f"Generating with: {MODEL_NAME} {data['name']}...")
    torch.manual_seed(SEED)  # Same seed for each model for deterministic comparison

    inputs = tokenizer(PROMPT, return_tensors="pt").to(DEVICE)

    start_time = time.time()
    outputs = data["model"].generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7, # Experiment with different temperatures
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id # Suppress warning
    )
    end_time = time.time()

    data["inference_time_s"] = end_time - start_time
    data["generated_text"] = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Completed in {data['inference_time_s']:.2f} seconds.\n")

print("All text generations complete.")

In [None]:
# Print the output and the comparison

import textwrap # For the output to be readable on a screen

print("="*80)
print(" " * 25 + "COMPARISON: GENERATED TEXT")
print("="*80 + "\n")

for data in model_data:
    print(f"### MODEL: {MODEL_NAME}: {data['name']} ###")
    print("-" * 50)
    wrapped_text = textwrap.fill(data["generated_text"], width=80)
    print(wrapped_text)

    print("\n" + "="*80 + "\n")


print("\n" + "="*80)
print(" " * 24 + "COMPARISON: RESOURCE USAGE")
print("="*80 + "\n")

# --- Header ---
print(f"{'Model Precision':<20} | {'VRAM Usage (MB)':<20} | {'Inference Time (s)':<22} |")
print(f"{'':-<20} | {'':-<20} | {'':-<22} |")

# --- Table Rows ---
for data in model_data:
    name = data['name']
    vram = f"{data['vram_mb']:.2f}"
    inf_time = f"{data['inference_time_s']:.2f}"
    print(f"{name:<20} | {vram:<20} | {inf_time:<22} |")

print("\n" + "="*80 + "\n")