In [1]:
!pip install timm datasets transformers



In [2]:
# setup / imports
import torch
import torch.nn.functional as F
import timm
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, ViTModel

device = "cuda" if torch.cuda.is_available() else "cpu"

In [13]:
# load student model, with randomly initialized weights
student_model = timm.create_model("tiny_vit_5m_224.in1k", pretrained=False, num_classes=0).to(device)
# if we set num_classes to 0, we get the embeddings, should we do this and add a separate classification head for evaluation?

In [4]:
# load pretrained teacher model
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
teacher_model_name = "google/vit-base-patch16-224"
teacher_model = ViTModel.from_pretrained(teacher_model_name).to(device)
teacher_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

In [14]:
print(student_model.num_features)
print(teacher_model.config.hidden_size)
proj = torch.nn.Linear(student_model.num_features, teacher_model.config.hidden_size).to(device)

320
768


In [71]:
# not using anymore
def get_teacher_embedding(pil_images):
    pil_images = [img.convert("RGB") for img in pil_images] # converts img to RGB, some images in dataset are in grayscale??

    pixel_values = torch.stack([
        processor(img, return_tensors="pt").pixel_values[0]
        for img in pil_images
    ]).to(device)

    with torch.no_grad():
        outputs = teacher_model(pixel_values)
    return outputs.last_hidden_state[:, 0, :]


def get_student_embedding(imgs):
    imgs = imgs.to(device)
    feats = student_model.forward_features(imgs)   # [B, C, H, W]
    pooled = feats.mean(dim=[2, 3])                # [B, C]
    return proj(pooled)                            # [B, 768]

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [55]:
from datasets import load_dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader
import ast
import numpy as np
from PIL import Image
from io import BytesIO

# ds = load_dataset("zh-plus/tiny-imagenet")

ds = load_dataset("csv", data_files={
    "train": "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_49.csv",
    "val": "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_30.csv",
    "test":  "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_40.csv",
})

student_transform = T.Compose([
    T.Lambda(lambda img: img.convert("RGB")),   # <- FORCE 3 channels ?
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# https://docs.pytorch.org/vision/0.8/models.html? for the normalization values

def apply_student_transforms(batch):
    batch["pixel_values"] = [student_transform(img) for img in batch["image"]]
    return batch

def collate_fn(batch):
    pixel_student = torch.stack([
        student_transform(
            Image.open(
                BytesIO(
                    eval(item["image"])["bytes"]   # OK here since source is your own CSV
                )
            ).convert("RGB")
        )
        for item in batch
    ])

    labels = torch.tensor([item["label"] for item in batch])

    teacher_embeddings = torch.tensor(
        [np.fromstring(item["teacher_embedding"].strip("[]"), sep=" ") for item in batch],
        dtype=torch.float32
    )

    return {
        "pixel_values_student": pixel_student,
        "teacher_embedding": teacher_embeddings,
        "labels": labels,
    }

train_loader = DataLoader(
    ds["train"],
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

test_loader = DataLoader(
    ds["test"],
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    ds["val"],
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [61]:
for p in student_model.parameters():
    p.requires_grad = True
for p in proj.parameters():
    p.requires_grad = True

# training loop
optimizer = torch.optim.AdamW(
    list(student_model.parameters()) + list(proj.parameters()),
    lr=3e-4
)

# AdamW optimizer for now, can look into other options? also, can revisit hyperparameters, currently using defaults

def train_one_epoch(student_model, train_loader, optimizer):
    student_model.train()
    proj.train()

    total_loss = 0.0

    for batch in tqdm(train_loader):
        imgs_student = batch["pixel_values_student"].to(device)
        teacher_emb = batch["teacher_embedding"].to(device)

        optimizer.zero_grad()

        student_emb = get_student_embedding(imgs_student)

        loss = F.mse_loss(student_emb, teacher_emb)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def evaluate_student(student_model, val_loader, device="cuda"):
    student_model.eval()
    proj.eval()

    total_loss = 0.0
    n_batches = 0

    for batch in val_loader:
        imgs_student = batch["pixel_values_student"].to(device)
        teacher_emb = batch["teacher_embedding"].to(device)

        # student forward
        student_emb = get_student_embedding(imgs_student)

        loss = F.mse_loss(student_emb, teacher_emb)

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

# --- TRAIN ---
best_val_loss = float("inf")

for epoch in range(1):
    train_loss = train_one_epoch(student_model, train_loader, optimizer)
    val_loss = evaluate_student(student_model, val_loader)

    print(f"Epoch {epoch+1:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            "student_state_dict": student_model.state_dict(),
            "proj_state_dict": proj.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, "best_student.pth")
        print("  → Saved new best model!")


100%|██████████| 32/32 [00:23<00:00,  1.35it/s]


Epoch 01 | Train Loss: 0.2949 | Val Loss: 0.9151
  → Saved new best model!


In [62]:
checkpoint = torch.load("best_student.pth", map_location=device)

student_model.load_state_dict(checkpoint["student_state_dict"])
proj.load_state_dict(checkpoint["proj_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

student_model.to(device)
proj.to(device)

student_model.eval()
proj.eval()

print("Loaded best checkpoint!")

Loaded best checkpoint!


In [63]:
test_loss = evaluate_student(student_model, test_loader, device)
print(f"Final Test Loss: {test_loss:.4f}")

Final Test Loss: 0.9261


In [64]:
student_model.eval()
proj.eval()

for p in student_model.parameters():
    p.requires_grad = False

for p in proj.parameters():
    p.requires_grad = False

In [65]:
import torch.nn as nn

num_classes = 200  # Tiny-ImageNet = 200, change if needed
embedding_dim = 768  # this matches your proj output

classifier = nn.Linear(embedding_dim, num_classes).to(device)

In [66]:
def train_classifier(student_model, proj, classifier, data_loader, optimizer, device):
    student_model.eval()
    proj.eval()
    classifier.train()

    total_loss = 0
    total_correct = 0
    total_samples = 0

    for batch in data_loader:
        imgs = batch["pixel_values_student"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        with torch.no_grad():
            emb = get_student_embedding(imgs)  # [B, 768]

        logits = classifier(emb)

        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.numel()

    acc = total_correct / total_samples
    return total_loss / len(data_loader), acc


In [67]:
@torch.no_grad()
def eval_classifier(student_model, proj, classifier, data_loader, device):
    student_model.eval()
    proj.eval()
    classifier.eval()

    total_correct = 0
    total_samples = 0
    total_loss = 0

    for batch in data_loader:
        imgs = batch["pixel_values_student"].to(device)
        labels = batch["labels"].to(device)

        emb = get_student_embedding(imgs)
        logits = classifier(emb)
        loss = F.cross_entropy(logits, labels)

        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.numel()
        total_loss += loss.item()

    acc = total_correct / total_samples
    return total_loss / len(data_loader), acc


In [70]:
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(1):
    train_loss, train_acc = train_classifier(student_model, proj, classifier, train_loader, classifier_optimizer, device)
    val_loss, val_acc = eval_classifier(student_model, proj, classifier, test_loader, device)

    print(f"Epoch {epoch+1} | Train Acc: {train_acc:.3f} | Val Acc: {val_acc:.3f}")


Epoch 1 | Train Acc: 0.945 | Val Acc: 0.000
