In [1]:
!pip install timm datasets transformers



In [2]:
# setup / imports
import torch
import torch.nn.functional as F
import timm
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, ViTModel

device = "cuda" if torch.cuda.is_available() else "cpu"

In [13]:
# load student model, with randomly initialized weights
student_model = timm.create_model("tiny_vit_5m_224.in1k", pretrained=False, num_classes=0).to(device)
# if we set num_classes to 0, we get the embeddings, should we do this and add a separate classification head for evaluation?

In [4]:
# load pretrained teacher model
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
teacher_model_name = "google/vit-base-patch16-224"
teacher_model = ViTModel.from_pretrained(teacher_model_name).to(device)
teacher_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

In [14]:
print(student_model.num_features)
print(teacher_model.config.hidden_size)
proj = torch.nn.Linear(student_model.num_features, teacher_model.config.hidden_size).to(device)

320
768


In [15]:
def get_teacher_embedding(pil_images):
    pil_images = [img.convert("RGB") for img in pil_images] # converts img to RGB, some images in dataset are in grayscale??

    pixel_values = torch.stack([
        processor(img, return_tensors="pt").pixel_values[0]
        for img in pil_images
    ]).to(device)

    with torch.no_grad():
        outputs = teacher_model(pixel_values)
    return outputs.last_hidden_state[:, 0, :]

def get_student_embedding(imgs):
    imgs = imgs.to(device)
    feats = student_model.forward_features(imgs)   # [B, C, H, W]
    pooled = feats.mean(dim=[2, 3])                # [B, C]
    return proj(pooled)                            # [B, 768]

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [33]:
from datasets import load_dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader
import ast
import numpy as np

# ds = load_dataset("zh-plus/tiny-imagenet")

ds = load_dataset("csv", data_files={
    "train": "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_49.csv",
    "test": "/content/drive/MyDrive/6.7960_Final_Project/Teacher_Embeddings/teacher_training_data_49.csv",
})

student_transform = T.Compose([
    T.Lambda(lambda img: img.convert("RGB")),   # <- FORCE 3 channels ?
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# https://docs.pytorch.org/vision/0.8/models.html? for the normalization values

def apply_student_transforms(batch):
    batch["pixel_values"] = [student_transform(img) for img in batch["image"]]
    return batch

from PIL import Image
from io import BytesIO


def collate_fn(batch):
    pixel_student = torch.stack([
        student_transform(
            Image.open(
                BytesIO(
                    eval(item["image"])["bytes"]   # OK here since source is your own CSV
                )
            ).convert("RGB")
        )
        for item in batch
    ])

    labels = torch.tensor([item["label"] for item in batch])

    teacher_embeddings = torch.tensor(
        [np.fromstring(item["teacher_embedding"].strip("[]"), sep=" ") for item in batch],
        dtype=torch.float32
    )

    return {
        "pixel_values_student": pixel_student,
        "teacher_embedding": teacher_embeddings,
        "labels": labels,
    }

train_loader = DataLoader(
    ds["train"],
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

test_loader = DataLoader(
    ds["test"],
    batch_size=64,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)



In [None]:
# training loop
optimizer = torch.optim.AdamW(
    list(student_model.parameters()) + list(proj.parameters()),
    lr=3e-4
)

# AdamW optimizer for now, can look into other options? also, can revisit hyperparameters, currently using defaults

def train_one_epoch(student_model, train_loader, optimizer):
    student_model.train()
    proj.train()

    total_loss = 0.0

    for batch in tqdm(train_loader):
        imgs_student = batch["pixel_values_student"].to(device)
        teacher_emb = batch["teacher_embedding"].to(device)

        optimizer.zero_grad()

        student_emb = get_student_embedding(imgs_student)

        loss = F.mse_loss(student_emb, teacher_emb)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


# --- TRAIN ---
for epoch in range(30):
    loss = train_one_epoch(student_model, train_loader, optimizer)
    print("Epoch loss:", loss)

100%|██████████| 32/32 [00:16<00:00,  1.92it/s]


Epoch loss: 0.6332217752933502


100%|██████████| 32/32 [00:16<00:00,  1.90it/s]


Epoch loss: 0.6070549413561821


100%|██████████| 32/32 [00:16<00:00,  1.92it/s]


Epoch loss: 0.5786515884101391


100%|██████████| 32/32 [00:16<00:00,  1.95it/s]


Epoch loss: 0.5517709013074636


100%|██████████| 32/32 [00:16<00:00,  1.91it/s]


Epoch loss: 0.5312258694320917


100%|██████████| 32/32 [00:16<00:00,  1.94it/s]


Epoch loss: 0.5183593621477485


100%|██████████| 32/32 [00:16<00:00,  1.91it/s]


Epoch loss: 0.5062798112630844


100%|██████████| 32/32 [00:16<00:00,  1.91it/s]


Epoch loss: 0.4913893034681678


100%|██████████| 32/32 [00:16<00:00,  1.89it/s]


Epoch loss: 0.48164631612598896


100%|██████████| 32/32 [00:16<00:00,  1.94it/s]


Epoch loss: 0.470749469473958


100%|██████████| 32/32 [00:16<00:00,  1.95it/s]


Epoch loss: 0.4614019254222512


100%|██████████| 32/32 [00:16<00:00,  1.91it/s]


Epoch loss: 0.4557664766907692


100%|██████████| 32/32 [00:16<00:00,  1.93it/s]


Epoch loss: 0.4477397706359625


100%|██████████| 32/32 [00:16<00:00,  1.93it/s]


Epoch loss: 0.44077474903315306


100%|██████████| 32/32 [00:16<00:00,  1.95it/s]


Epoch loss: 0.4311323370784521


 50%|█████     | 16/32 [00:08<00:08,  1.95it/s]

In [None]:
next(iter(train_loader))

{'pixel_values_student': tensor([[[[ 0.0398,  0.0398, -0.0116,  ..., -1.3987, -1.4329, -1.4329],
           [ 0.0398,  0.0398, -0.0116,  ..., -1.3987, -1.4329, -1.4329],
           [ 0.0227,  0.0227, -0.0287,  ..., -1.3987, -1.4329, -1.4329],
           ...,
           [-1.3473, -1.3473, -1.2617,  ...,  0.0056, -0.0972, -0.0972],
           [-1.5528, -1.5528, -1.4158,  ..., -0.0287, -0.1314, -0.1314],
           [-1.5528, -1.5528, -1.4158,  ..., -0.0287, -0.1314, -0.1314]],
 
          [[ 0.4503,  0.4503,  0.3978,  ..., -1.0903, -1.1253, -1.1253],
           [ 0.4503,  0.4503,  0.3978,  ..., -1.0903, -1.1253, -1.1253],
           [ 0.4328,  0.4328,  0.3803,  ..., -1.0903, -1.1253, -1.1253],
           ...,
           [-1.1604, -1.1604, -1.0728,  ..., -0.1625, -0.2675, -0.2675],
           [-1.3704, -1.3704, -1.2304,  ..., -0.1975, -0.3025, -0.3025],
           [-1.3704, -1.3704, -1.2304,  ..., -0.1975, -0.3025, -0.3025]],
 
          [[ 0.6705,  0.6705,  0.6182,  ..., -0.4450, -0.4798,

In [34]:
@torch.no_grad()
def evaluate(student_model, proj, data_loader, device="cuda"):
    student_model.eval()
    proj.eval()

    total_loss = 0.0
    n_batches = 0

    for batch in data_loader:
        imgs_student = batch["pixel_values_student"].to(device)
        teacher_emb = batch["teacher_embedding"].to(device)

        student_emb = get_student_embedding(imgs_student)

        loss = F.mse_loss(student_emb, teacher_emb)

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

In [35]:
val_loss = evaluate(
    student_model=student_model,
    proj=proj,
    data_loader=test_loader,   # or train_loader if you want train eval
    device=device
)

print("Validation MSE:", val_loss)

Validation MSE: 0.7857959140092134
