In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import numpy as np
from tqdm import tqdm

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
from google.colab import drive
drive.mount('/content/drive')

# ----------------------------
# 1. LOAD DATASET FROM CSV
# ----------------------------

ds = load_dataset("csv", data_files={
    "train": "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_49.csv",
    "test":  "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_49.csv",
})


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [7]:
# ----------------------------
# 2. IMAGE TRANSFORM (to load images)
# ----------------------------

import torchvision.transforms as T
student_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

# ----------------------------
# 3. COLLATE FUNCTION
#    - decodes images
#    - loads TEACHER embeddings from CSV
# ----------------------------

def collate_fn(batch):

    pixel_values = torch.stack([
        student_transform(
            Image.open(BytesIO(eval(item["image"])["bytes"])).convert("RGB")
        )
        for item in batch
    ])

    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)

    teacher_embeddings = torch.tensor(
        [np.fromstring(item["teacher_embedding"].strip("[]"), sep=" ") for item in batch],
        dtype=torch.float32
    )

    return {
        "teacher_embedding": teacher_embeddings,
        "labels": labels,
        "pixel_values": pixel_values,  # unused
    }

# ----------------------------
# 4. DATALOADERS
# ----------------------------

train_loader = DataLoader(
    ds["train"], batch_size=64, shuffle=True, collate_fn=collate_fn
)

test_loader = DataLoader(
    ds["test"], batch_size=64, shuffle=False, collate_fn=collate_fn
)

# ----------------------------
# 5. BUILD TEACHER CLASSIFIER
# ----------------------------

embedding_dim = 768     # teacher embedding dim
num_classes = 200       # Tiny-ImageNet num classes

classifier = nn.Linear(embedding_dim, num_classes).to(device)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

In [8]:
# ----------------------------
# 6. TRAIN FUNCTION
# ----------------------------

def train_teacher_classifier(classifier, loader, optimizer, device):
    classifier.train()
    total_loss, total_correct, total_samples = 0, 0, 0

    for batch in tqdm(loader):
        teacher_emb = batch["teacher_embedding"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        logits = classifier(teacher_emb)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    acc = total_correct / total_samples
    return total_loss / len(loader), acc

# ----------------------------
# 7. EVAL FUNCTION
# ----------------------------

@torch.no_grad()
def eval_teacher_classifier(classifier, loader, device):
    classifier.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    for batch in loader:
        teacher_emb = batch["teacher_embedding"].to(device)
        labels = batch["labels"].to(device)

        logits = classifier(teacher_emb)
        loss = F.cross_entropy(logits, labels)

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    acc = total_correct / total_samples
    return total_loss / len(loader), acc


In [9]:
# ----------------------------
# 8. RUN TRAINING
# ----------------------------

num_epochs = 5
for epoch in range(num_epochs):
    train_loss, train_acc = train_teacher_classifier(
        classifier, train_loader, optimizer, device
    )
    val_loss, val_acc = eval_teacher_classifier(
        classifier, test_loader, device
    )

    print(f"Epoch {epoch+1} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

100%|██████████| 32/32 [00:05<00:00,  5.51it/s]


Epoch 1 | Train Acc: 0.7800 | Val Acc: 0.0000


100%|██████████| 32/32 [00:04<00:00,  6.56it/s]


Epoch 2 | Train Acc: 0.9485 | Val Acc: 0.0000


100%|██████████| 32/32 [00:04<00:00,  6.93it/s]


Epoch 3 | Train Acc: 0.9570 | Val Acc: 0.0000


100%|██████████| 32/32 [00:05<00:00,  5.89it/s]


Epoch 4 | Train Acc: 0.9655 | Val Acc: 0.0000


100%|██████████| 32/32 [00:13<00:00,  2.35it/s]


Epoch 5 | Train Acc: 0.9710 | Val Acc: 0.0000


In [10]:
from sklearn.metrics import f1_score, recall_score, roc_auc_score
import numpy as np

@torch.no_grad()
def evaluate_teacher_metrics(classifier, loader, device):
    classifier.eval()

    all_labels = []
    all_preds = []
    all_probs = []

    for batch in loader:
        teacher_emb = batch["teacher_embedding"].to(device)
        labels = batch["labels"].to(device)

        logits = classifier(teacher_emb)
        probs = torch.softmax(logits, dim=1)

        preds = logits.argmax(dim=1)

        all_labels.append(labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_preds  = np.concatenate(all_preds)
    all_probs  = np.concatenate(all_probs)

    f1 = f1_score(all_labels, all_preds, average="macro")

    recall = recall_score(all_labels, all_preds, average="macro")

    try:
        auc = roc_auc_score(all_labels, all_probs, multi_class="ovr")
    except ValueError:
        auc = float("nan")  # happens if a class is missing in predictions

    return {
        "f1_macro": f1,
        "recall_macro": recall,
        "auc_macro": auc,
    }


In [11]:
teacher_train_metrics = evaluate_teacher_metrics(
    classifier,
    train_loader,
    device
)

print("Teacher Train Metrics:")
for k, v in teacher_train_metrics.items():
    print(f"{k}: {v:.4f}")

Teacher Train Metrics:
f1_macro: 0.9790
recall_macro: 0.9790
auc_macro: nan
