In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import gc
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    get_scheduler,
    BitsAndBytesConfig
)
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Explicitly import bitsandbytes correctly
try:
    import bitsandbytes as bnb
except ImportError:
    raise ImportError()

def clear_gpu_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    print(f"Memory cleared. GPU memory: {torch.cuda.memory_allocated() // 1024**2} MB")

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"BitsAndBytes Version: {bnb.__version__}")
        print(f"Initial GPU memory: {torch.cuda.memory_allocated() // 1024**2} MB")

    # Define model identifiers from the Hugging Face model hub
    teacher_model_id = "meta-llama/llama-3.2-3b"  # Replace with actual identifier
    student_model_id = "meta-llama/llama-3.2-1b"  # Replace with actual identifier

    # Create 8-bit quantization configuration
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16
    )

    # Load the tokenizer first
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load teacher model with BitsAndBytesConfig in evaluation mode
    print("Loading teacher model in 8-bit...")
    teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_model_id,
        quantization_config=quantization_config,
        device_map="auto"
    )
    teacher_model.eval()  # Set to evaluation mode
    print(f"Teacher model loaded. GPU memory: {torch.cuda.memory_allocated() // 1024**2} MB")

    # Load student model with BitsAndBytesConfig in training mode
    print("Loading student model in 8-bit...")
    student_model = AutoModelForCausalLM.from_pretrained(
        student_model_id,
        quantization_config=quantization_config,
        device_map="auto"
    )
    student_model.train()  # Set to training mode
    print(f"Student model loaded. GPU memory: {torch.cuda.memory_allocated() // 1024**2} MB")

    # Load a sample dataset (using WikiText-2 as an example)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")  # Use smaller portion

    # Tokenize the dataset; adjust max_length as needed
    def tokenize_function(example):
        return tokenizer(example["text"], truncation=True, max_length=256)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Use a data collator for causal language modeling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Create a DataLoader with smaller batch size
    batch_size = 2  # Adjust based on available memory
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator
    )

    # Define optimizer using BitsAndBytes 8-bit optimizer
    print("Setting up 8-bit optimizer...")
    optimizer = bnb.optim.AdamW8bit(
        student_model.parameters(),
        lr=5e-5,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01
    )

    num_epochs = 2
    total_steps = num_epochs * len(dataloader)

    # Set up scheduler
    scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # Temperature for softening logits during distillation
    temperature = 2.0
    kl_loss_fct = nn.KLDivLoss(reduction="batchmean")

    # Training loop for distillation
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits / temperature
                del teacher_outputs
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits / temperature
            del student_outputs

            loss = kl_loss_fct(
                F.log_softmax(student_logits, dim=-1),
                F.softmax(teacher_logits, dim=-1)
            ) * (temperature ** 2)

            del teacher_logits
            del student_logits

            loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")
        print(f"GPU memory after epoch: {torch.cuda.memory_allocated() // 1024**2} MB")

    print("Saving 8-bit model...")

    student_model.save_pretrained("distilled_llama_1b_8bit", safe_serialization=True)
    tokenizer.save_pretrained("distilled_llama_1b_8bit")

    with open("distilled_llama_1b_8bit/quantization_config.json", "w") as f:
        f.write(quantization_config.to_json_string())

    print("Distilled model saved as 'distilled_llama_1b_8bit'.")

    del teacher_model
    del student_model
    clear_gpu_memory()
    print("Process completed successfully.")

if __name__ == "__main__":
    main()

Using device: cuda
CUDA Version: 12.4
BitsAndBytes Version: 0.45.3
Initial GPU memory: 0 MB
Loading tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading teacher model in 8-bit...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Teacher model loaded. GPU memory: 3443 MB
Loading student model in 8-bit...
Student model loaded. GPU memory: 4875 MB


Map:   0%|          | 0/1836 [00:00<?, ? examples/s]

Setting up 8-bit optimizer...


Epoch 1:   0%|          | 1/918 [00:00<14:00,  1.09it/s, loss=168]

Memory cleared. GPU memory: 5974 MB


Epoch 1:   2%|▏         | 21/918 [00:06<05:22,  2.78it/s, loss=18.8]

Memory cleared. GPU memory: 5917 MB


Epoch 1:   4%|▍         | 41/918 [00:12<05:38,  2.59it/s, loss=461]

Memory cleared. GPU memory: 6090 MB


Epoch 1:   7%|▋         | 61/918 [00:18<05:20,  2.67it/s, loss=226]

Memory cleared. GPU memory: 6007 MB


Epoch 1:   9%|▉         | 81/918 [00:24<05:04,  2.74it/s, loss=124]

Memory cleared. GPU memory: 5993 MB


Epoch 1:  11%|█         | 101/918 [00:30<05:06,  2.66it/s, loss=89.8]

Memory cleared. GPU memory: 5982 MB


Epoch 1:  13%|█▎        | 121/918 [00:36<04:54,  2.71it/s, loss=320]

Memory cleared. GPU memory: 6140 MB


Epoch 1:  15%|█▌        | 141/918 [00:41<04:40,  2.77it/s, loss=10.3]

Memory cleared. GPU memory: 5917 MB


Epoch 1:  18%|█▊        | 161/918 [00:47<04:26,  2.84it/s, loss=251]

Memory cleared. GPU memory: 6136 MB


Epoch 1:  20%|█▉        | 181/918 [00:53<04:20,  2.83it/s, loss=152]

Memory cleared. GPU memory: 6084 MB


Epoch 1:  22%|██▏       | 201/918 [00:58<04:17,  2.78it/s, loss=98.6]

Memory cleared. GPU memory: 6053 MB


Epoch 1:  24%|██▍       | 221/918 [01:04<03:58,  2.92it/s, loss=0.139]

Memory cleared. GPU memory: 5912 MB


Epoch 1:  26%|██▋       | 241/918 [01:09<03:53,  2.90it/s, loss=157]

Memory cleared. GPU memory: 6104 MB


Epoch 1:  28%|██▊       | 261/918 [01:15<03:41,  2.96it/s, loss=0.131]

Memory cleared. GPU memory: 5910 MB


Epoch 1:  31%|███       | 281/918 [01:21<03:36,  2.94it/s, loss=0.125]

Memory cleared. GPU memory: 5911 MB


Epoch 1:  33%|███▎      | 301/918 [01:26<03:34,  2.87it/s, loss=3.38]

Memory cleared. GPU memory: 5917 MB


Epoch 1:  35%|███▍      | 321/918 [01:32<03:36,  2.75it/s, loss=165]

Memory cleared. GPU memory: 6084 MB


Epoch 1:  37%|███▋      | 341/918 [01:38<03:22,  2.86it/s, loss=77.3]

Memory cleared. GPU memory: 5985 MB


Epoch 1:  39%|███▉      | 361/918 [01:43<03:17,  2.82it/s, loss=11]

Memory cleared. GPU memory: 5922 MB


Epoch 1:  42%|████▏     | 381/918 [01:49<03:15,  2.75it/s, loss=190]

Memory cleared. GPU memory: 6138 MB


Epoch 1:  44%|████▎     | 401/918 [01:55<02:55,  2.95it/s, loss=0.109]

Memory cleared. GPU memory: 5909 MB


Epoch 1:  46%|████▌     | 421/918 [02:00<02:51,  2.90it/s, loss=0.103]

Memory cleared. GPU memory: 5910 MB


Epoch 1:  48%|████▊     | 441/918 [02:06<02:56,  2.70it/s, loss=128]

Memory cleared. GPU memory: 6106 MB


Epoch 1:  50%|█████     | 461/918 [02:12<02:36,  2.92it/s, loss=5.65]

Memory cleared. GPU memory: 5921 MB


Epoch 1:  52%|█████▏    | 481/918 [02:17<02:36,  2.80it/s, loss=73.1]

Memory cleared. GPU memory: 6015 MB


Epoch 1:  55%|█████▍    | 501/918 [02:23<02:27,  2.83it/s, loss=169]

Memory cleared. GPU memory: 6089 MB


Epoch 1:  57%|█████▋    | 521/918 [02:29<02:21,  2.81it/s, loss=4.06]

Memory cleared. GPU memory: 5916 MB


Epoch 1:  59%|█████▉    | 541/918 [02:34<02:10,  2.90it/s, loss=117]

Memory cleared. GPU memory: 6114 MB


Epoch 1:  61%|██████    | 561/918 [02:40<02:05,  2.85it/s, loss=0.092]

Memory cleared. GPU memory: 5912 MB


Epoch 1:  63%|██████▎   | 581/918 [02:46<01:58,  2.84it/s, loss=104]

Memory cleared. GPU memory: 6094 MB


Epoch 1:  65%|██████▌   | 601/918 [02:51<01:51,  2.85it/s, loss=6.54]

Memory cleared. GPU memory: 5926 MB


Epoch 1:  68%|██████▊   | 621/918 [02:57<01:47,  2.77it/s, loss=11.3]

Memory cleared. GPU memory: 5924 MB


Epoch 1:  70%|██████▉   | 641/918 [03:03<01:38,  2.80it/s, loss=6.65]

Memory cleared. GPU memory: 5920 MB


Epoch 1:  72%|███████▏  | 661/918 [03:09<01:34,  2.73it/s, loss=118]

Memory cleared. GPU memory: 6066 MB


Epoch 1:  74%|███████▍  | 681/918 [03:14<01:22,  2.89it/s, loss=2.44]

Memory cleared. GPU memory: 5915 MB


Epoch 1:  76%|███████▋  | 701/918 [03:20<01:20,  2.71it/s, loss=108]

Memory cleared. GPU memory: 6059 MB


Epoch 1:  79%|███████▊  | 721/918 [03:26<01:07,  2.91it/s, loss=0.0778]

Memory cleared. GPU memory: 5911 MB


Epoch 1:  81%|████████  | 741/918 [03:31<01:01,  2.88it/s, loss=10.1]

Memory cleared. GPU memory: 5926 MB


Epoch 1:  83%|████████▎ | 761/918 [03:37<00:52,  2.99it/s, loss=0.0767]

Memory cleared. GPU memory: 5911 MB


Epoch 1:  85%|████████▌ | 781/918 [03:43<00:47,  2.89it/s, loss=6.57]

Memory cleared. GPU memory: 5920 MB


Epoch 1:  87%|████████▋ | 801/918 [03:48<00:41,  2.83it/s, loss=170]

Memory cleared. GPU memory: 6137 MB


Epoch 1:  89%|████████▉ | 821/918 [03:54<00:32,  2.97it/s, loss=0.0756]

Memory cleared. GPU memory: 5910 MB


Epoch 1:  92%|█████████▏| 841/918 [04:00<00:26,  2.85it/s, loss=2.92]

Memory cleared. GPU memory: 5916 MB


Epoch 1:  94%|█████████▍| 861/918 [04:05<00:19,  2.86it/s, loss=112]

Memory cleared. GPU memory: 6081 MB


Epoch 1:  96%|█████████▌| 881/918 [04:11<00:13,  2.79it/s, loss=51.2]

Memory cleared. GPU memory: 6009 MB


Epoch 1:  98%|█████████▊| 901/918 [04:17<00:06,  2.81it/s, loss=4.47]

Memory cleared. GPU memory: 5916 MB


Epoch 1: 100%|██████████| 918/918 [04:21<00:00,  3.51it/s, loss=101]


Epoch 1 average loss: 88.7638
GPU memory after epoch: 6083 MB


Epoch 2:   0%|          | 1/918 [00:00<07:41,  1.99it/s, loss=0.0694]

Memory cleared. GPU memory: 5911 MB


Epoch 2:   2%|▏         | 21/918 [00:06<05:20,  2.80it/s, loss=74.3]

Memory cleared. GPU memory: 6018 MB


Epoch 2:   4%|▍         | 41/918 [00:12<05:16,  2.77it/s, loss=134]

Memory cleared. GPU memory: 6135 MB


Epoch 2:   7%|▋         | 61/918 [00:17<05:04,  2.82it/s, loss=66.9]

Memory cleared. GPU memory: 6064 MB


Epoch 2:   9%|▉         | 81/918 [00:23<05:01,  2.78it/s, loss=186]

Memory cleared. GPU memory: 6134 MB


Epoch 2:  11%|█         | 101/918 [00:28<04:50,  2.81it/s, loss=168]

Memory cleared. GPU memory: 6137 MB


Epoch 2:  13%|█▎        | 121/918 [00:34<04:46,  2.78it/s, loss=69]

Memory cleared. GPU memory: 6047 MB


Epoch 2:  15%|█▌        | 141/918 [00:40<04:26,  2.92it/s, loss=6.28]

Memory cleared. GPU memory: 5921 MB


Epoch 2:  18%|█▊        | 161/918 [00:46<04:34,  2.75it/s, loss=63.5]

Memory cleared. GPU memory: 5990 MB


Epoch 2:  20%|█▉        | 181/918 [00:51<04:10,  2.94it/s, loss=5.25]

Memory cleared. GPU memory: 5919 MB


Epoch 2:  22%|██▏       | 201/918 [00:57<04:08,  2.88it/s, loss=107]

Memory cleared. GPU memory: 6055 MB


Epoch 2:  24%|██▍       | 221/918 [01:03<04:03,  2.86it/s, loss=158]

Memory cleared. GPU memory: 6095 MB


Epoch 2:  26%|██▋       | 241/918 [01:08<03:57,  2.85it/s, loss=63.8]

Memory cleared. GPU memory: 6044 MB


Epoch 2:  28%|██▊       | 261/918 [01:14<03:56,  2.78it/s, loss=170]

Memory cleared. GPU memory: 6121 MB


Epoch 2:  31%|███       | 281/918 [01:20<03:48,  2.78it/s, loss=133]

Memory cleared. GPU memory: 6092 MB


Epoch 2:  33%|███▎      | 301/918 [01:25<03:40,  2.80it/s, loss=59.4]

Memory cleared. GPU memory: 6043 MB


Epoch 2:  35%|███▍      | 321/918 [01:31<03:29,  2.84it/s, loss=90.8]

Memory cleared. GPU memory: 6069 MB


Epoch 2:  37%|███▋      | 341/918 [01:37<03:29,  2.75it/s, loss=107]

Memory cleared. GPU memory: 6076 MB


Epoch 2:  39%|███▉      | 361/918 [01:42<03:12,  2.89it/s, loss=0.0615]

Memory cleared. GPU memory: 5910 MB


Epoch 2:  42%|████▏     | 381/918 [01:48<03:08,  2.84it/s, loss=2.73]

Memory cleared. GPU memory: 5917 MB


Epoch 2:  44%|████▎     | 401/918 [01:54<03:01,  2.86it/s, loss=128]

Memory cleared. GPU memory: 6060 MB


Epoch 2:  46%|████▌     | 421/918 [01:59<02:49,  2.93it/s, loss=68.9]

Memory cleared. GPU memory: 6002 MB


Epoch 2:  48%|████▊     | 441/918 [02:05<02:48,  2.83it/s, loss=9.2]

Memory cleared. GPU memory: 5928 MB


Epoch 2:  50%|█████     | 461/918 [02:11<02:42,  2.81it/s, loss=32.8]

Memory cleared. GPU memory: 5982 MB


Epoch 2:  52%|█████▏    | 481/918 [02:16<02:36,  2.79it/s, loss=69.8]

Memory cleared. GPU memory: 6042 MB


Epoch 2:  55%|█████▍    | 501/918 [02:22<02:23,  2.90it/s, loss=74.4]

Memory cleared. GPU memory: 6051 MB


Epoch 2:  57%|█████▋    | 521/918 [02:28<02:19,  2.84it/s, loss=1.89]

Memory cleared. GPU memory: 5917 MB


Epoch 2:  59%|█████▉    | 541/918 [02:33<02:07,  2.95it/s, loss=0.058]

Memory cleared. GPU memory: 5910 MB


Epoch 2:  61%|██████    | 561/918 [02:39<02:06,  2.81it/s, loss=90.1]

Memory cleared. GPU memory: 6075 MB


Epoch 2:  63%|██████▎   | 581/918 [02:44<01:58,  2.84it/s, loss=120]

Memory cleared. GPU memory: 6086 MB


Epoch 2:  65%|██████▌   | 601/918 [02:50<01:47,  2.96it/s, loss=0.0601]

Memory cleared. GPU memory: 5910 MB


Epoch 2:  68%|██████▊   | 621/918 [02:56<01:42,  2.88it/s, loss=79.6]

Memory cleared. GPU memory: 6098 MB


Epoch 2:  70%|██████▉   | 641/918 [03:02<01:39,  2.80it/s, loss=31]

Memory cleared. GPU memory: 5972 MB


Epoch 2:  72%|███████▏  | 661/918 [03:07<01:31,  2.80it/s, loss=84.4]

Memory cleared. GPU memory: 6073 MB


Epoch 2:  74%|███████▍  | 681/918 [03:13<01:21,  2.89it/s, loss=7.54]

Memory cleared. GPU memory: 5922 MB


Epoch 2:  76%|███████▋  | 701/918 [03:19<01:18,  2.77it/s, loss=84.6]

Memory cleared. GPU memory: 6050 MB


Epoch 2:  79%|███████▊  | 721/918 [03:24<01:08,  2.86it/s, loss=70.6]

Memory cleared. GPU memory: 6065 MB


Epoch 2:  81%|████████  | 741/918 [03:30<01:00,  2.92it/s, loss=66.4]

Memory cleared. GPU memory: 6044 MB


Epoch 2:  83%|████████▎ | 761/918 [03:36<00:56,  2.79it/s, loss=13.7]

Memory cleared. GPU memory: 5930 MB


Epoch 2:  85%|████████▌ | 781/918 [03:41<00:46,  2.97it/s, loss=0.0614]

Memory cleared. GPU memory: 5910 MB


Epoch 2:  87%|████████▋ | 801/918 [03:47<00:41,  2.80it/s, loss=105]

Memory cleared. GPU memory: 6106 MB


Epoch 2:  89%|████████▉ | 821/918 [03:52<00:34,  2.81it/s, loss=102]

Memory cleared. GPU memory: 6102 MB


Epoch 2:  92%|█████████▏| 841/918 [03:58<00:27,  2.83it/s, loss=39.1]

Memory cleared. GPU memory: 5979 MB


Epoch 2:  94%|█████████▍| 861/918 [04:04<00:19,  2.97it/s, loss=0.0616]

Memory cleared. GPU memory: 5911 MB


Epoch 2:  96%|█████████▌| 881/918 [04:09<00:13,  2.82it/s, loss=101]

Memory cleared. GPU memory: 6091 MB


Epoch 2:  98%|█████████▊| 901/918 [04:15<00:06,  2.80it/s, loss=125]

Memory cleared. GPU memory: 6097 MB


Epoch 2: 100%|██████████| 918/918 [04:20<00:00,  3.53it/s, loss=1.77]


Epoch 2 average loss: 61.8027
GPU memory after epoch: 5916 MB
Saving 8-bit model...
Distilled model saved as 'distilled_llama_1b_8bit'.
Memory cleared. GPU memory: 2470 MB
Process completed successfully.


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

def generate_text(model, tokenizer, prompt, max_length=256, temperature=0.7, num_return_sequences=1):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        num_return_sequences=num_return_sequences,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

def evaluate_perplexity(model, tokenizer, dataset, batch_size=4):
    model.eval()

    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=512)
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Use a collator that pads batches (no masked language modeling needed)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

    total_loss = 0.0
    total_tokens = 0
    for batch in tqdm(dataloader, desc="Evaluating Perplexity"):
        batch = {k: v.to(model.device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        # outputs.loss is the average loss per token; multiply by total tokens in batch
        loss = outputs.loss.item() * batch["input_ids"].numel()
        total_loss += loss
        total_tokens += batch["input_ids"].numel()
    avg_loss = total_loss / total_tokens
    ppl = np.exp(avg_loss)
    return ppl

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    teacher_model_id = "meta-llama/llama-3.2-3b"  # 8-bit teacher model
    student_model_id = "meta-llama/llama-3.2-1b"    # 8-bit student model
    distilled_model_path = "distilled_llama_1b_8bit"       # 8-bit distilled model saved locally

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16
    )

    tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_model_id,
        device_map="auto",
        quantization_config=quantization_config
    )
    teacher_model.eval()

    student_model = AutoModelForCausalLM.from_pretrained(
        student_model_id,
        device_map="auto",
        quantization_config=quantization_config
    )
    student_model.eval() 

    distilled_model = AutoModelForCausalLM.from_pretrained(
        distilled_model_path,
        device_map="auto",
        quantization_config=quantization_config
    )
    distilled_model.eval()

    print("\n--- Code Generation Evaluation ---")
    code_prompts = [
        "def add(a, b):\n    \"\"\"Return the sum of a and b.\"\"\"\n    ",
        "def factorial(n):\n    \"\"\"Return the factorial of n.\"\"\"\n    ",
        "def is_prime(n):\n    \"\"\"Return True if n is a prime number, False otherwise.\"\"\"\n    "
    ]

    for prompt in code_prompts:
        print("\nPrompt:")
        print(prompt)
        teacher_output = generate_text(teacher_model, tokenizer, prompt)[0]
        student_output = generate_text(student_model, tokenizer, prompt)[0]
        distilled_output = generate_text(distilled_model, tokenizer, prompt)[0]

        print("\nTeacher Model Output:")
        print(teacher_output)
        print("\nStudent Model Output:")
        print(student_output)
        print("\nDistilled Model Output:")
        print(distilled_output)

    print("\n--- Perplexity Evaluation on WikiText-2 ---")
    wikitext_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

    teacher_ppl = evaluate_perplexity(teacher_model, tokenizer, wikitext_dataset, batch_size=4)
    student_ppl = evaluate_perplexity(student_model, tokenizer, wikitext_dataset, batch_size=4)
    distilled_ppl = evaluate_perplexity(distilled_model, tokenizer, wikitext_dataset, batch_size=4)

    print("\nPerplexity Scores:")
    print(f"Teacher Model Perplexity: {teacher_ppl:.2f}")
    print(f"Student Model Perplexity: {student_ppl:.2f}")
    print(f"Distilled Model Perplexity: {distilled_ppl:.2f}")

if __name__ == "__main__":
    main()

Using device: cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Code Generation Evaluation ---

Prompt:
def add(a, b):
    """Return the sum of a and b."""
    

Teacher Model Output:
def add(a, b):
    """Return the sum of a and b."""
    

Student Model Output:
def add(a, b):
    """Return the sum of a and b."""
    

Distilled Model Output:
def add(a, b):
    """Return the sum of a and b."""
    

Prompt:
def factorial(n):
    """Return the factorial of n."""
    

Teacher Model Output:
def factorial(n):
    """Return the factorial of n."""
    

Student Model Output:
def factorial(n):
    """Return the factorial of n."""
    

Distilled Model Output:
def factorial(n):
    """Return the factorial of n."""
    

Prompt:
def is_prime(n):
    """Return True if n is a prime number, False otherwise."""
    

Teacher Model Output:
def is_prime(n):
    """Return True if n is a prime number, False otherwise."""
    

Student Model Output:
def is_prime(n):
    """Return True if n is a prime number, False otherwise."""
    

Distilled Model Output:
d

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Evaluating Perplexity: 100%|██████████| 940/940 [02:14<00:00,  6.96it/s]


Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Evaluating Perplexity: 100%|██████████| 940/940 [01:18<00:00, 12.03it/s]
Evaluating Perplexity: 100%|██████████| 940/940 [01:18<00:00, 12.02it/s]


Perplexity Scores:
Teacher Model Perplexity: 16.57
Student Model Perplexity: 21.07
Distilled Model Perplexity: 20.32



