### Toy Example 4 Class

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def logit_standardization(logits: torch.Tensor) -> torch.Tensor:
    mean = torch.mean(logits, dim=-1, keepdim=True)
    std = torch.std(logits, dim=-1, keepdim=True)
    return (logits - mean) / (std + 1e-8)

def calculate_kl(student_logits, teacher_logits, temperature=1.0, standardize=False):
    if standardize:
        student_logits_p = logit_standardization(student_logits)
        teacher_logits_p = logit_standardization(teacher_logits)
    else:
        student_logits_p = student_logits
        teacher_logits_p = teacher_logits
    teacher_probs = F.softmax(teacher_logits_p / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits_p / temperature, dim=-1)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='sum')
    return max(kl_div.item(), 0.0)
    return kl_div.item()


CLASSES = ['Cat', 'Dog', 'Bird', 'Frog']

# logits from the paper's Figure 2
teacher_logits = torch.tensor([[1.0, 4.0, 3.0, 1.0]])
student1_logits = torch.tensor([[2.0, 2.8, 3.0, 1.0]]) # Magnitude-close, Rank-wrong
student2_logits = torch.tensor([[0.1, 0.4, 0.3, 0.1]]) # Magnitude-far, Rank-correct

teacher_pred = torch.argmax(teacher_logits, dim=-1).item()
s1_pred = torch.argmax(student1_logits, dim=-1).item()
s2_pred = torch.argmax(student2_logits, dim=-1).item()

print(f"Teacher Prediction: '{CLASSES[teacher_pred]}' (Correctly 'Dog')")
print(f"Student 1 Prediction: '{CLASSES[s1_pred]}' (Incorrectly 'Bird')")
print(f"Student 2 Prediction: '{CLASSES[s2_pred]}' (Correctly 'Dog')")
print("\n--- Before Standardization (Normal KD) ---")

# KL divergence without standardization
kl_s1_before = calculate_kl(student1_logits, teacher_logits)
kl_s2_before = calculate_kl(student2_logits, teacher_logits)

print(f"KL Div(Teacher || Student 1): {kl_s1_before:.4f}")
print(f"KL Div(Teacher || Student 2): {kl_s2_before:.4f}")
print(f"Student 1 has a LOWER loss ({kl_s1_before:.4f} < {kl_s2_before:.4f}),")

print("\n--- Analysis After Logit Standardization ---")

# KL divergence WITH standardization
kl_s1_after = calculate_kl(student1_logits, teacher_logits, standardize=True)
kl_s2_after = calculate_kl(student2_logits, teacher_logits, standardize=True)

print(f"KL Div(Teacher || Student 1): {kl_s1_after:.4f}")
print(f"KL Div(Teacher || Student 2): {kl_s2_after:.4f}")
print(f"\n Student 2 has a lower loss ({kl_s2_after:.4f} << {kl_s1_after:.4f}).")

### Toy Example 10 Class

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def logit_standardization(logits: torch.Tensor) -> torch.Tensor:
    mean = torch.mean(logits, dim=-1, keepdim=True)
    std = torch.std(logits, dim=-1, keepdim=True)
    return (logits - mean) / (std + 1e-8)

def calculate_kl(student_logits, teacher_logits, temperature=1.0, standardize=False):
    student_logits = student_logits.float()
    teacher_logits = teacher_logits.float()

    if standardize:
        student_logits_p = logit_standardization(student_logits)
        teacher_logits_p = logit_standardization(teacher_logits)
    else:
        student_logits_p = student_logits
        teacher_logits_p = teacher_logits

    teacher_probs = F.softmax(teacher_logits_p / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits_p / temperature, dim=-1)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return kl_div.item()

# Teacher's top two logits (at indices 2 and 4) are close.
teacher_logits_10 = torch.tensor([[1.0, 2.0, 8.0, 3.0, 7.5, 1.5, -2.0, 0.5, -1.0, 2.5]])
# Correct class is index 2.

# Student 1: Swaps top two logits (indices 2 and 4). Magnitudes are very close, rank is wrong.
student1_logits_10 = torch.tensor([[1.0, 2.0, 7.5, 3.0, 8.0, 1.5, -2.0, 0.5, -1.0, 2.5]])

# Student 2: Scales down teacher logits. Magnitudes are far, rank is correct.
student2_logits_10 = teacher_logits_10 / 5.0

teacher_pred_10 = torch.argmax(teacher_logits_10, dim=-1).item()
s1_pred_10 = torch.argmax(student1_logits_10, dim=-1).item()
s2_pred_10 = torch.argmax(student2_logits_10, dim=-1).item()

print(f"Teacher Prediction: Correct class is index {teacher_pred_10}")
print(f"Student 1 Prediction: Incorrectly predicts index {s1_pred_10}")
print(f"Student 2 Prediction: Correctly predicts index {s2_pred_10}")

# 3. Analyze Before Standardization
print("\n--- Before Standardization (Normal KD) ---")

kl_s1_before_10 = calculate_kl(student1_logits_10, teacher_logits_10)
kl_s2_before_10 = calculate_kl(student2_logits_10, teacher_logits_10)

print(f"KL Div(Teacher || Student 1): {kl_s1_before_10:.4f}")
print(f"KL Div(Teacher || Student 2): {kl_s2_before_10:.4f}")

if kl_s1_before_10 < kl_s2_before_10:
    print(f"\n Student 1 has a LOWER loss ({kl_s1_before_10:.4f} < {kl_s2_before_10:.4f}),")
else:
    print("\n[Analysis Failed] The example did not produce the desired misleading signal.")

print("\n--- After Logit Standardization ---")

kl_s1_after_10 = calculate_kl(student1_logits_10, teacher_logits_10, standardize=True)
kl_s2_after_10 = calculate_kl(student2_logits_10, teacher_logits_10, standardize=True)

print(f"KL Div(Teacher || Student 1): {kl_s1_after_10:.4f}")
print(f"KL Div(Teacher || Student 2): {kl_s2_after_10:.4f}")
print(f"\n After standardization, Student 2 has the minimal loss ({kl_s2_after_10:.4f} << {kl_s1_after_10:.4f}).")

### Loading Models

In [6]:
import torch
from model import get_model
from data_utils import get_cifar100_loaders
from train_eval import evaluate
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader, test_loader = get_cifar100_loaders(batch_size=128)
teacher = get_model('vgg16_bn', pretrained=True, num_classes=100).to(device)
teacher.eval()
student_kd = get_model('vgg11_bn', pretrained=False, num_classes=100).to(device)
student_kd.load_state_dict(torch.load('distilled_student.pth', map_location=device))
student_kd.eval()
student_indep = get_model('vgg11_bn', pretrained=False, num_classes=100).to(device)
student_indep.load_state_dict(torch.load('vgg11_bn_finetuned.pth', map_location=device))
student_indep.eval()


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

### Analysis

In [7]:
import torch
import torch.nn.functional as F
from model import get_model
from data_utils import get_cifar100_loaders
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader, test_loader = get_cifar100_loaders(batch_size=256, num_workers=2)
num_examples_to_inspect = 5
T = 3.0
@torch.no_grad()
def collect_logits(model, loader, max_batches=None):
    model.eval()
    logits_list = []
    labels_list = []
    for i, (x, y) in enumerate(loader):
        x = x.to(device)
        logits = model(x)               # logits (pre-softmax)
        logits_list.append(logits.cpu())
        labels_list.append(y)
        if max_batches and i+1 >= max_batches:
            break
    return torch.cat(logits_list, dim=0), torch.cat(labels_list, dim=0)

print("Collecting train logits ...")
train_teacher_logits, train_labels = collect_logits(teacher, train_loader)
train_indep_logits, _ = collect_logits(student_indep, train_loader)
train_kd_logits, _    = collect_logits(student_kd, train_loader)
print("Train logits collected:", train_teacher_logits.shape)

print("Collecting test logits...")
test_teacher_logits, test_labels = collect_logits(teacher, test_loader)
test_indep_logits, _ = collect_logits(student_indep, test_loader)
test_kd_logits, _    = collect_logits(student_kd, test_loader)
print("Test logits collected:", test_teacher_logits.shape)

def per_sample_kl(teacher_logits, student_logits, T=1.0):
    # teacher_probs: p_t, student_log_probs: log p_s
    p_t = F.softmax(teacher_logits / T, dim=1)            # [N, C]
    log_p_s = F.log_softmax(student_logits / T, dim=1)    # [N, C]
    log_p_t = torch.log(p_t + 1e-12)
    kl = (p_t * (log_p_t - log_p_s)).mean(dim=1)  # [N]
    return kl

# Logit standardization
def standardize_logits(logits):
    mean = logits.mean(dim=1, keepdim=True)
    std = logits.std(dim=1, keepdim=True)
    return (logits - mean) / (std + 1e-8)

def analyze_split(teacher_logits, indep_logits, kd_logits, labels, split_name="split"):
    n = labels.size(0)
    with torch.no_grad():
        # preds
        teacher_pred = teacher_logits.argmax(dim=1)
        indep_pred = indep_logits.argmax(dim=1)
        kd_pred = kd_logits.argmax(dim=1)

        # per-sample KL
        kl_indep = per_sample_kl(teacher_logits, indep_logits, T=T)
        kl_kd    = per_sample_kl(teacher_logits, kd_logits, T=T)

        # KD lower KL but KD disagrees with teacher, indep agrees with teacher OR indep matches gold
        indep_matches = (indep_pred == teacher_pred) | (indep_pred == labels)
        kd_disagrees_teacher = (kd_pred != teacher_pred)

        conflict_mask = (kl_kd < kl_indep) & kd_disagrees_teacher & indep_matches

        conflict_indices = torch.nonzero(conflict_mask).squeeze(1).cpu().tolist()
        num_conflicts = len(conflict_indices)

        t_std = standardize_logits(teacher_logits)
        indep_std = standardize_logits(indep_logits)
        kd_std = standardize_logits(kd_logits)

        kl_indep_std = per_sample_kl(t_std, indep_std, T=T)
        kl_kd_std    = per_sample_kl(t_std, kd_std, T=T)
        corrected = 0
        corrected_examples = []
        for idx in conflict_indices:
            if kl_indep_std[idx] < kl_kd_std[idx]:
                corrected += 1
                if len(corrected_examples) < num_examples_to_inspect:
                    corrected_examples.append(idx)

        # Build summary and some examples
        fraction = num_conflicts / float(n)
        corrected_frac = corrected / float(num_conflicts) if num_conflicts > 0 else 0.0

        print(f"\n=== {split_name.upper()} SUMMARY ===")
        print(f"Total samples inspected: {n}")
        print(f"Conflicts found (KD lower KL but KD disagrees with teacher & Indep matches teacher/gold): {num_conflicts} ({fraction*100:.3f}%)")
        print(f"After logit-standardization, conflicts corrected: {corrected}/{num_conflicts} ({corrected_frac*100:.2f}%)")

        # Show a few example cases (before and after KLs)
        if num_conflicts > 0:
            print(f"\nShowing up to {num_examples_to_inspect} example conflict indices (index, label, teacher_pred, indep_pred, kd_pred):")
            for i, idx in enumerate(conflict_indices[:num_examples_to_inspect]):
                print(f"#{i+1}: idx={idx}, label={labels[idx].item()}, teacher={teacher_pred[idx].item()}, "
                      f"indep={indep_pred[idx].item()}, kd={kd_pred[idx].item()}")
                print(f"   KL_indep_before={kl_indep[idx].item():.6f}, KL_kd_before={kl_kd[idx].item():.6f}")
                print(f"   KL_indep_std = {kl_indep_std[idx].item():.6f}, KL_kd_std = {kl_kd_std[idx].item():.6f}")
        else:
            print("No conflicts found in this split.")

        return {
            'n': n,
            'num_conflicts': num_conflicts,
            'fraction': fraction,
            'corrected': corrected,
            'corrected_frac': corrected_frac,
            'example_indices': conflict_indices[:num_examples_to_inspect]
        }

train_report = analyze_split(train_teacher_logits, train_indep_logits, train_kd_logits, train_labels, split_name="train")
test_report  = analyze_split(test_teacher_logits, test_indep_logits, test_kd_logits, test_labels, split_name="test")


Collecting train logits ...
Train logits collected: torch.Size([50000, 100])
Collecting test logits...
Test logits collected: torch.Size([10000, 100])

=== TRAIN SUMMARY ===
Total samples inspected: 50000
Conflicts found (KD lower KL but KD disagrees with teacher & Indep matches teacher/gold): 6 (0.012%)
After logit-standardization, conflicts corrected: 2/6 (33.33%)

Showing up to 5 example conflict indices (index, label, teacher_pred, indep_pred, kd_pred):
#1: idx=12000, label=22, teacher=22, indep=22, kd=61
   KL_indep_before=0.011890, KL_kd_before=0.011188
   KL_indep_std = 0.000608, KL_kd_std = 0.000449
#2: idx=16372, label=65, teacher=65, indep=65, kd=64
   KL_indep_before=0.018644, KL_kd_before=0.017165
   KL_indep_std = 0.000622, KL_kd_std = 0.000612
#3: idx=17589, label=72, teacher=72, indep=72, kd=32
   KL_indep_before=0.014524, KL_kd_before=0.013811
   KL_indep_std = 0.000406, KL_kd_std = 0.000548
#4: idx=29120, label=35, teacher=35, indep=35, kd=98
   KL_indep_before=0.01714