In [None]:
import torch

In [None]:
# Sanity check: prints True when PyTorch can see a CUDA device from this kernel.
print(f"{torch.cuda.is_available()}")

In [None]:
# Show the visible GPU(s), driver/CUDA versions and current memory use.
!nvidia-smi

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# SECURITY: never hardcode access tokens in a notebook. A token that was
# previously pasted here in plain text should be treated as leaked: revoke it and
# generate a new one in your Hugging Face account settings.
# Read the token from the environment (e.g. `export HF_TOKEN=...` in the PSC job
# script); fall back to a non-echoing interactive prompt when it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=HF_TOKEN)

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
# Authentication comes from the HF_TOKEN environment variable / the login cell;
# never paste the token itself into the notebook.
# V100 optimized loading (float16).
model_id = "meta-llama/Meta-Llama-3-8B"
# Shared Hugging Face cache on PSC Ocean storage. Used for BOTH the tokenizer and
# the weights so everything is downloaded to, and read from, the same place.
HF_CACHE_DIR = "/ocean/projects/cis220031p/bkoduru/huggingface/hub"

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
# Reuse EOS as the padding token so padded batching has a valid pad id.
tokenizer.pad_token = tokenizer.eos_token
teacher = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=HF_CACHE_DIR,
)

# Map Llama tokens for "0"-"9" to their vocab ids (last sub-token if a digit splits).
# NOTE(review): after a prompt ending in "Digit:" the model may prefer a bare
# space or a space-prefixed digit as the next token rather than "0".."9";
# confirm these ids are the ones the teacher actually predicts.
digit_tokens = [tokenizer.encode(str(i), add_special_tokens=False)[-1] for i in range(10)]

def extract_logits(dataloader, limit=2000):
    """Collect the teacher's digit logits for the images yielded by ``dataloader``.

    Every image is rendered as a space-separated list of 0-255 pixel values
    inside a text prompt. The teacher's next-token logits at the last prompt
    position are then restricted to the ten ``digit_tokens`` ("0".."9").

    Uses the notebook globals ``teacher``, ``tokenizer`` and ``digit_tokens``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches. Images are assumed
            to be float tensors scaled to [0, 1] (e.g. ``transforms.ToTensor()``)
            -- TODO confirm for non-MNIST data. Any batch size works; each image
            gets its own prompt.
        limit: maximum number of *samples* to process (``None`` = all of them).

    Returns:
        ``(logits, labels)``: a float32 tensor of shape ``(N, len(digit_tokens))``
        and a label tensor of shape ``(N, 1)``, where ``N <= limit``.
    """
    import itertools

    # Stream individual (image, label) samples so that a batch is never merged
    # into a single prompt.
    samples = (
        (img, label)
        for img_batch, label_batch in dataloader
        for img, label in zip(img_batch, label_batch)
    )

    all_logits = []
    all_labels = []
    device = teacher.device  # send inputs wherever device_map placed the model

    for img, label in tqdm(itertools.islice(samples, limit), total=limit):
        # Convert the 28x28 image to a space-separated string. Round rather than
        # truncate so float error in (x / 255) * 255 cannot turn pixel x into x - 1.
        pixels = img.reshape(-1).mul(255).round().int().tolist()
        pixel_str = " ".join(map(str, pixels))
        prompt = f"The following is a handwritten digit image pixel list: {pixel_str}\nDigit:"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = teacher(**inputs)
        # Last-position logits for just the digit tokens, upcast from the model's
        # float16 so downstream CPU softmax / KL math runs in float32.
        logits = outputs.logits[0, -1, digit_tokens]
        all_logits.append(logits.float().cpu())
        all_labels.append(torch.as_tensor(label).reshape(1))

    if not all_logits:
        # torch.stack([]) raises; return correctly-shaped empty tensors instead.
        return torch.empty(0, len(digit_tokens)), torch.empty(0, 1, dtype=torch.long)
    return torch.stack(all_logits), torch.stack(all_labels)

# Save for fast training later
# mnist_logits, labels = extract_logits(train_loader)
# torch.save({'logits': mnist_logits, 'labels': labels}, 'llama3_mnist_logits.pt')

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class StudentMLP(nn.Module):
    """Small MLP student mapping 784 input features (a flattened 28x28 image) to 10 logits.

    Layer order is Linear(784, 512) -> ReLU -> Dropout(0.2) -> Linear(512, 10),
    held in ``self.net`` so parameters appear as ``net.0.*`` and ``net.3.*``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 512),  # net.0
            nn.ReLU(),            # net.1
            nn.Dropout(0.2),      # net.2 (only active in train mode)
            nn.Linear(512, 10),   # net.3
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten all non-batch dimensions of ``x`` and return raw logits ``(batch, 10)``."""
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        return self.net(flat)

def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    """Temperature-scaled knowledge-distillation loss.

    Computes ``alpha * CE(student, labels) + (1 - alpha) * T**2 * KL(p_T || q_T)``,
    where ``p_T`` and ``q_T`` are the teacher and student softmax distributions
    at temperature ``T``. The ``T**2`` factor is the usual rescaling that keeps
    the soft term comparable in magnitude as ``T`` changes.

    Args:
        student_logits: ``(N, C)`` student logits.
        teacher_logits: ``(N, C)`` teacher logits. Any float dtype is accepted
            (e.g. float16 straight from the LLM); they are cast to the student's dtype.
        labels: integer class indices of shape ``(N,)`` or ``(N, 1)``.
        T: softmax temperature for the soft term.
        alpha: weight of the hard (ground-truth) loss; ``1 - alpha`` weights the soft loss.

    Returns:
        Scalar loss tensor.
    """
    # Labels stacked from batch_size=1 loaders arrive as (N, 1); cross_entropy wants (N,).
    if not labels.is_floating_point() and labels.dim() == 2 and labels.size(1) == 1:
        labels = labels.squeeze(1)
    # Compute the soft targets in the student's precision.
    teacher_logits = teacher_logits.to(student_logits.dtype)

    # 1. Hard Loss (Ground Truth)
    hard_loss = F.cross_entropy(student_logits, labels)

    # 2. Soft Loss (Teacher guidance): KL(Teacher || Student).
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (T ** 2)

    return alpha * hard_loss + (1 - alpha) * soft_loss

# Training Loop
def train(model, data, epochs=20, images=None, lr=1e-3, device="cuda"):
    """Full-batch distillation of ``model`` against precomputed teacher logits.

    Args:
        model: student network (e.g. ``StudentMLP``); moved to ``device`` and put in train mode.
        data: dict with ``'logits'`` (N, C) teacher logits and ``'labels'`` of shape
            (N,) or (N, 1), as saved by ``torch.save`` earlier in the notebook. It may
            also contain ``'images'``.
        epochs: number of full-batch optimisation steps.
        images: (N, ...) inputs aligned row-for-row with ``data['logits']``. Defaults
            to ``data['images']``. Integer (e.g. uint8) pixels are scaled to [0, 1]
            to match ``transforms.ToTensor()``.
        lr: Adam learning rate.
        device: device for the model and all tensors.

    Raises:
        ValueError: if no images are available or their count does not match the logits.
    """
    if images is None:
        images = data.get('images')
    if images is None:
        raise ValueError(
            "train() needs the original images: pass images=... or store them under data['images']."
        )

    teacher_logits = data['logits'].to(device).float()
    # Saved labels are (N, 1); cross_entropy needs a flat (N,) index vector.
    labels = data['labels'].to(device).view(-1)
    images = images.to(device)
    if not images.is_floating_point():
        images = images.float() / 255.0
    if images.size(0) != teacher_logits.size(0):
        raise ValueError(
            f"Got {images.size(0)} images but {teacher_logits.size(0)} teacher logit rows."
        )

    # Move the model before building the optimizer, then enable dropout.
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(images)
        loss = distillation_loss(outputs, teacher_logits, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch} Loss: {loss.item()}")

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Requires `teacher` and `tokenizer` from the model-loading cell.
# Vocabulary ids of the Llama tokens "0".."9", in digit order; when a digit
# encodes to several tokens, the last token id is kept.
digit_tokens = list(
    map(lambda d: tokenizer.encode(str(d), add_special_tokens=False)[-1], range(10))
)

def precompute_logits(num_samples=2000, out_path='mnist_kd_data.pt', data_root='./data'):
    """Run the teacher on the first ``num_samples`` MNIST training images and save its digit logits.

    Uses the notebook globals ``teacher``, ``tokenizer`` and ``digit_tokens``.
    Writes a dict to ``out_path`` with ``torch.save``:

    * ``'logits'``: float32, shape ``(N, len(digit_tokens))``, columns in ``digit_tokens`` order.
    * ``'labels'``: shape ``(N, 1)`` class indices (downstream cells call ``.view(-1)``).
    * ``'images'``: the ``ToTensor()`` inputs, shape ``(N, 1, 28, 28)`` for MNIST,
      row-aligned with ``'logits'`` so a student can be trained on them directly.

    Note: the prompt here ("Pixels: ...") differs from the one in ``extract_logits``,
    so logits produced by the two functions are not interchangeable.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    from torch.utils.data import Subset

    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
    num_samples = min(num_samples, len(mnist_train))
    # First num_samples rows in dataset order (no shuffling): row i of the saved
    # tensors corresponds to mnist_train[i].
    loader = DataLoader(Subset(mnist_train, range(num_samples)), batch_size=1, shuffle=False)
    device = teacher.device  # wherever device_map="auto" placed the model

    all_logits = []
    all_labels = []
    all_images = []

    print(f"Extracting knowledge for {num_samples} samples...")
    for img, label in tqdm(loader, total=num_samples):
        # Flatten and stringify image. Round rather than truncate so float error
        # in (x / 255) * 255 cannot turn pixel value x into x - 1.
        pixels = img.view(-1).mul(255).round().int().tolist()
        pixel_str = " ".join(map(str, pixels))
        prompt = f"Pixels: {pixel_str}\nDigit:"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = teacher(**inputs)
        # Logits for just the 10 digit tokens at the final position, upcast from
        # float16 so later CPU analysis runs in float32.
        all_logits.append(outputs.logits[0, -1, digit_tokens].float().cpu())
        all_labels.append(label)
        all_images.append(img[0])

    torch.save(
        {
            'logits': torch.stack(all_logits),
            'labels': torch.stack(all_labels),
            'images': torch.stack(all_images),
        },
        out_path,
    )
    print(f"Done! Knowledge saved to {out_path}")


precompute_logits()

In [None]:
import torch
import torch.nn.functional as F

# Load the extracted knowledge (a dict of plain tensors, so weights_only loading suffices).
data = torch.load('mnist_kd_data.pt', weights_only=True)
# Work in float32: the logits may have been saved as float16 straight from the model.
t_logits = data['logits'].float()
labels = data['labels'].view(-1)

# 1. Teacher accuracy on the 10 digit tokens (chance level is 10%).
t_preds = t_logits.argmax(dim=1)
t_acc = (t_preds == labels).float().mean()
print(f"Teacher Accuracy: {t_acc.item():.2%}")

# 2. Mean predictive entropy over the 10 digits, in nats (maximum ln(10) ~= 2.3026).
# Higher entropy means a softer distribution with more mass on non-argmax digits
# (potential "dark knowledge"), but it can equally mean the teacher is simply
# unsure, so read it together with the accuracy above.
# log_softmax avoids the bias and instability of log(p + eps).
t_log_probs = F.log_softmax(t_logits, dim=-1)
t_probs = t_log_probs.exp()
entropy = -(t_probs * t_log_probs).sum(dim=1).mean()
print(f"Average Teacher Entropy: {entropy.item():.4f}")

In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

def analyze_teacher_failures(logits, labels, num_examples=5):
    """Plot the teacher's class distribution for its first few misclassifications.

    Args:
        logits: ``(N, C)`` teacher logits (any float dtype, any device).
        labels: true class indices, shape ``(N,)`` or ``(N, 1)``.
        num_examples: maximum number of misclassified samples to plot.

    Prints how many samples were misclassified, then shows one bar chart per
    plotted mistake, in sample-index order.
    """
    logits = logits.detach().float().cpu()
    # Flatten (N, 1) labels: comparing (N,) preds with (N, 1) labels would
    # broadcast to an (N, N) mask and report bogus failures.
    labels = labels.detach().cpu().reshape(-1)
    probs = F.softmax(logits, dim=-1)
    preds = logits.argmax(dim=1)

    # Find indices where the teacher was WRONG
    wrong_indices = (preds != labels).nonzero(as_tuple=True)[0]
    n_plot = min(num_examples, len(wrong_indices))
    print(f"Teacher misclassified {len(wrong_indices)} of {len(labels)} samples; plotting {n_plot}.")

    classes = list(range(probs.size(-1)))
    for idx in wrong_indices[:n_plot].tolist():
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.bar(classes, probs[idx].numpy(), color='salmon')
        ax.set_title(f"True Label: {labels[idx].item()} | Teacher Predicted: {preds[idx].item()}")
        ax.set_xticks(classes)
        ax.set_xlabel("Digit")
        ax.set_ylabel("Probability")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across calls

# Inspect up to five of the teacher's mistakes.
analyze_teacher_failures(logits=t_logits, labels=labels, num_examples=5)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch

# Load the saved teacher logits (a dict of plain tensors).
data = torch.load('mnist_kd_data.pt', weights_only=True)
# Dedicated names so this cell does not overwrite the torch `labels` used by
# earlier cells. Upcast in case the logits were stored as float16.
tsne_logits = data['logits'].float().numpy()
tsne_labels = data['labels'].view(-1).numpy()

# Embed the 10-dimensional digit-logit vectors in 2-D. Only two components are
# plotted, so compute only two.
# NOTE(review): perplexity=6 is a small neighbourhood for ~2000 points; tune if
# clusters look fragmented.
tsne = TSNE(n_components=2, perplexity=6, random_state=42)
embeddings = tsne.fit_transform(tsne_logits)

# Plot, coloured by the true digit.
fig, ax = plt.subplots(figsize=(10, 7))
scatter = ax.scatter(embeddings[:, 0], embeddings[:, 1], c=tsne_labels, cmap='tab10', alpha=0.4)
fig.colorbar(scatter, ax=ax, label="True digit")
ax.set_title("t-SNE Visualization of Llama 3 (8B) Logits for MNIST")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
plt.show()

In [2]:
import torch

In [4]:
# Show the visible GPU(s), driver/CUDA versions and current memory use.
!nvidia-smi

Sun Feb 15 14:41:23 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  |   00000000:15:00.0 Off |                    0 |
| N/A   29C    P0             41W /  300W |       4MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                