In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

Let's load the different bird species from the `class_names.npy` file and then the attributes from `attributes.npy` which has for every class 312 features that are explained by the file `attributes.txt`.

In [3]:
# `.item()` pulls the single stored Python object out of the loaded array; later
# cells use it as a dict mapping class name -> class id.
# NOTE(review): allow_pickle=True can run arbitrary code while unpickling, so
# class_names.npy must come from a trusted source.
bird_classes = np.load("class_names.npy", allow_pickle=True).item()

In [4]:
# Attribute matrix: one row per bird class, one column per attribute
# (the displayed shape confirms the row/column counts).
attributes = np.load('attributes.npy')
attributes.shape

(200, 312)

In [5]:
# Read one attribute description per line, trimming surrounding whitespace
# (including the trailing newline).
with open("attributes.txt", "r") as f:
    attribute_names = [entry.strip() for entry in f]

attribute_names[:5]

['1 has_bill_shape::curved_(up_or_down)',
 '2 has_bill_shape::dagger',
 '3 has_bill_shape::hooked',
 '4 has_bill_shape::needle',
 '5 has_bill_shape::hooked_seabird']

Unify the attribute files into a single mapping from every bird species to its features.

In [6]:
# Build {class id: {attribute name: value}}. Class ids are 1-based, while the
# rows of `attributes` are 0-based, hence the `+ 1` on the key.
class_attributes = {}

for class_id in range(attributes.shape[0]):
    class_attributes[class_id + 1] = {
        name: attributes[class_id, col] for col, name in enumerate(attribute_names)
    }

Create a data frame `birds_df` with the class_id and the 312 attributes of each bird class. Then merge it with the class name of each bird.

In [7]:
# One row per class: the dict keys become the index, which is then named
# "class_id" and promoted to an ordinary first column.
birds_df = (
    pd.DataFrame.from_dict(class_attributes, orient="index")
    .rename_axis("class_id")
    .reset_index()
)
birds_df.head()

Unnamed: 0,class_id,1 has_bill_shape::curved_(up_or_down),2 has_bill_shape::dagger,3 has_bill_shape::hooked,4 has_bill_shape::needle,5 has_bill_shape::hooked_seabird,6 has_bill_shape::spatulate,7 has_bill_shape::all-purpose,8 has_bill_shape::cone,9 has_bill_shape::specialized,...,303 has_crown_color::pink,304 has_crown_color::orange,305 has_crown_color::black,306 has_crown_color::white,307 has_crown_color::red,308 has_crown_color::buff,309 has_wing_pattern::solid,310 has_wing_pattern::spotted,311 has_wing_pattern::striped,312 has_wing_pattern::multi-colored
0,1,0.010638,0.010638,0.007092,0.003546,0.138299,0.065603,0.0,0.005319,0.0,...,0.0,0.005439,0.005439,0.228446,0.0,0.0,0.18602,0.009186,0.025262,0.020669
1,2,0.0,0.011332,0.009444,0.0,0.202095,0.041552,0.01511,0.005666,0.0,...,0.006291,0.0,0.111144,0.008388,0.0,0.046135,0.202572,0.002665,0.021323,0.058639
2,3,0.0,0.0,0.007425,0.0,0.002475,0.0,0.0,0.074247,0.14602,...,0.0,0.0,0.190411,0.012555,0.0,0.010462,0.203609,0.0,0.008853,0.017705
3,4,0.0,0.0,0.003861,0.0,0.003861,0.013514,0.005792,0.07336,0.138998,...,0.004885,0.0,0.190531,0.0,0.0,0.0,0.15275,0.00684,0.036478,0.043317
4,5,0.0,0.035088,0.0,0.0,0.0,0.0,0.102458,0.070177,0.0,...,0.0,0.0,0.204036,0.002458,0.002458,0.0,0.03164,0.002751,0.015132,0.1582


In [8]:
# Two-column lookup table: the dict keys (class names) become "class" and the
# dict values (numeric ids) become "id".
classes = (
    pd.DataFrame.from_dict(bird_classes, orient="index")
    .reset_index()
    .set_axis(["class", "id"], axis=1)
)
classes.head()

Unnamed: 0,class,id
0,001.Black_footed_Albatross,1
1,002.Laysan_Albatross,2
2,003.Sooty_Albatross,3
3,004.Groove_billed_Ani,4
4,005.Crested_Auklet,5


In [9]:
# Attach the human-readable class name to every attribute row.
# Dropping any "class" column left over from a previous run keeps this cell
# re-runnable: otherwise a second merge would yield "class_x"/"class_y" and the
# column selection below would fail with a KeyError on "class".
birds_df = (
    birds_df
    .drop(columns=["class"], errors="ignore")
    .merge(classes, left_on="class_id", right_on="id")
    .drop(columns=["id"])  # "id" duplicates "class_id" after the merge
)

# Reorder columns to have class_id and class first
leading_cols = ["class_id", "class"]
cols = leading_cols + [c for c in birds_df.columns if c not in leading_cols]
birds_df = birds_df[cols]
birds_df.head()

Unnamed: 0,class_id,class,1 has_bill_shape::curved_(up_or_down),2 has_bill_shape::dagger,3 has_bill_shape::hooked,4 has_bill_shape::needle,5 has_bill_shape::hooked_seabird,6 has_bill_shape::spatulate,7 has_bill_shape::all-purpose,8 has_bill_shape::cone,...,303 has_crown_color::pink,304 has_crown_color::orange,305 has_crown_color::black,306 has_crown_color::white,307 has_crown_color::red,308 has_crown_color::buff,309 has_wing_pattern::solid,310 has_wing_pattern::spotted,311 has_wing_pattern::striped,312 has_wing_pattern::multi-colored
0,1,001.Black_footed_Albatross,0.010638,0.010638,0.007092,0.003546,0.138299,0.065603,0.0,0.005319,...,0.0,0.005439,0.005439,0.228446,0.0,0.0,0.18602,0.009186,0.025262,0.020669
1,2,002.Laysan_Albatross,0.0,0.011332,0.009444,0.0,0.202095,0.041552,0.01511,0.005666,...,0.006291,0.0,0.111144,0.008388,0.0,0.046135,0.202572,0.002665,0.021323,0.058639
2,3,003.Sooty_Albatross,0.0,0.0,0.007425,0.0,0.002475,0.0,0.0,0.074247,...,0.0,0.0,0.190411,0.012555,0.0,0.010462,0.203609,0.0,0.008853,0.017705
3,4,004.Groove_billed_Ani,0.0,0.0,0.003861,0.0,0.003861,0.013514,0.005792,0.07336,...,0.004885,0.0,0.190531,0.0,0.0,0.0,0.15275,0.00684,0.036478,0.043317
4,5,005.Crested_Auklet,0.0,0.035088,0.0,0.0,0.0,0.0,0.102458,0.070177,...,0.0,0.0,0.204036,0.002458,0.002458,0.0,0.03164,0.002751,0.015132,0.1582


In [10]:
# Load the image index and prefix every path with "." so that it is resolved
# relative to the current working directory.
images_df = pd.read_csv("train_images.csv").assign(
    image_path=lambda frame: '.' + frame['image_path']
)
images_df.head()

Unnamed: 0,image_path,label
0,./train_images/1.jpg,1
1,./train_images/2.jpg,1
2,./train_images/3.jpg,1
3,./train_images/4.jpg,1
4,./train_images/5.jpg,1


### Load training metadata and create train/validation split

In this step, we load the `train_images.csv` file that contains the image paths and labels.  
Then we create a stratified train/validation split so that all 200 classes are represented proportionally in both sets.  
This split will be used to train the CNN on `train_images` and evaluate it on `val_images`.


In [11]:
# Hold out 20% of the rows for validation. The split is stratified on the
# label column and uses a fixed seed so that it is repeatable.
train_images, val_images = train_test_split(
    images_df,
    stratify=images_df["label"],
    test_size=0.2,
    random_state=42,
)

(len(train_images), len(val_images))

(3140, 786)

In [12]:
# Preview the training part of the split; the index still shows each row's
# original position in `images_df`.
train_images.head()

Unnamed: 0,image_path,label
1249,./train_images/1250.jpg,42
3882,./train_images/3883.jpg,193
686,./train_images/687.jpg,23
1452,./train_images/1453.jpg,49
2357,./train_images/2358.jpg,85


### Define Image Transformations for Training and Validation

Before training a CNN, all images need to be preprocessed in a consistent way.  
Here, we define two sets of transformations:

**Training transforms**
- **Resize to 224×224:** ResNet models expect fixed-size input.
- **Random horizontal flip:** A simple data augmentation step to help the model generalize.
- **Convert to tensor:** Converts the image to a PyTorch tensor with values in `[0,1]`.
- **Normalize with ImageNet statistics:** Since ResNet18 was pretrained on ImageNet, the same normalization must be applied for best performance.

**Validation transforms**
- Same as above but **without augmentation**, to ensure a stable and deterministic evaluation.

These transforms prepare raw images so they can be passed into the CNN during training and validation.


In [13]:
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T

# image transforms (basic baseline)
# The values shared by both pipelines are named once so the training and
# validation preprocessing cannot drift apart.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize, random horizontal flip (augmentation), tensor, normalize.
train_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation: the same steps without the random flip.
val_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

### Create a custom PyTorch Dataset for bird images

Here we define a `BirdsDataset` class that:
- Reads the image path and label from the DataFrame rows.
- Loads each image with PIL.
- Applies the appropriate transform (train or validation).
- Converts labels from 1–200 to 0–199 so they work with `nn.CrossEntropyLoss`.

This Dataset will be used together with a DataLoader to efficiently feed batches to the CNN.

In [14]:
class BirdsDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs from a DataFrame of image paths.

    Args:
        df: DataFrame with an ``image_path`` column (file to open) and a
            ``label`` column (integer class id, 1-based).
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
        use_attributes: Stored on the instance; nothing in this class reads it.
    """

    def __init__(self, df, transform=None, use_attributes=False):
        # Renumber rows 0..len-1 so positional access does not depend on the
        # (shuffled) index produced by the train/validation split.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.use_attributes = use_attributes

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, zero_based_label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was given. The label is shifted from 1-based to 0-based.
        """
        row = self.df.iloc[idx]
        label = int(row["label"]) - 1

        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory image, so it stays valid
        # after the file is closed.
        with Image.open(row["image_path"]) as source:
            img = source.convert("RGB")

        # Compare against None rather than truthiness so that a valid but
        # "falsy" transform (e.g. an empty container module) is not skipped.
        if self.transform is not None:
            img = self.transform(img)

        return img, label

### Wrap Datasets in DataLoaders

Now we create `DataLoader` objects for the training and validation sets.  
DataLoaders handle:
- Shuffling (for training),
- Batching,
- Parallel loading of images (with `num_workers`).

These will be used directly in the training and evaluation loops.


In [15]:
from torch.utils.data import DataLoader

batch_size = 32

# Same Dataset class for both parts; only the transform pipeline differs.
train_dataset = BirdsDataset(train_images, transform=train_transform)
val_dataset = BirdsDataset(val_images, transform=val_transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False
)

# Validation keeps a fixed order.
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False
)

### Define Device (GPU or CPU)

In [16]:
import torch.nn as nn

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cpu


## Define two models: Simple CNN and ResNet18

### Building a simple CNN from scratch (baseline CNN)

Before comparing with pretrained models like ResNet18, it's useful to build a classic convolutional neural network from scratch.  
This gives a "true baseline" — a model that only learns from the bird training images, without any prior ImageNet knowledge.

The custom CNN below contains:
- Three convolutional blocks (Conv → BatchNorm → ReLU → MaxPool)
- A flatten layer
- Two fully-connected layers
- A final output layer with 200 logits (one per bird species)

This model is lightweight, easy to understand, and suitable for verifying that the training loop and data pipeline work correctly.


In [19]:
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small from-scratch CNN: three conv blocks followed by two linear layers.

    Each conv block is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2),
    so every block halves the spatial size (rounding down).

    Args:
        num_classes: Number of output logits.
        input_size: Spatial size of the input images, either a single int for
            square inputs or a ``(height, width)`` pair. Defaults to 224,
            which gives the 28x28 feature map of the original design.

    Raises:
        ValueError: If the input is too small to survive three 2x poolings.
    """

    def __init__(self, num_classes=200, input_size=224):
        super().__init__()

        # Submodule names and creation order are kept stable so that saved
        # state dicts keep the same keys.
        self.conv_block1 = self._conv_block(3, 32)     # e.g. 224 -> 112
        self.conv_block2 = self._conv_block(32, 64)    # e.g. 112 -> 56
        self.conv_block3 = self._conv_block(64, 128)   # e.g. 56 -> 28

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # Three floor-halvings are equivalent to one floor division by 8.
        feat_h, feat_w = height // 8, width // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for three 2x poolings"
            )

        # Flattened feature map (128 channels) -> hidden layer -> logits
        self.fc1 = nn.Linear(128 * feat_h * feat_w, 512)
        self.fc2 = nn.Linear(512, num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one Conv -> BatchNorm -> ReLU -> MaxPool(2) block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return class logits for a batch of images shaped (N, 3, H, W)."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)

        x = torch.flatten(x, 1)  # flatten all except batch
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [21]:
# Build the from-scratch CNN and move it to the selected device.
CNN_model = SimpleCNN(num_classes=200)
CNN_model = CNN_model.to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
CNN_optimizer = torch.optim.Adam(params=CNN_model.parameters(), lr=3e-4)

### Define a baseline CNN model (ResNet18)

As a strong baseline, we use a pretrained `ResNet18` from `torchvision.models`:
- We load ImageNet-pretrained weights.
- We replace the final fully-connected layer so it outputs 200 logits (one per bird class).
- The rest of the network starts from the pretrained weights and is fine-tuned together with the new layer (no layers are frozen).

This gives a solid starting point for accuracy without heavy custom architecture work.


In [22]:
# `models` is not imported by any of the cells above, so it is brought into
# scope here; without this import the cell fails with a NameError on a fresh kernel.
from torchvision import models

# Load pretrained ResNet18
weights = models.ResNet18_Weights.IMAGENET1K_V1
ResNet_model = models.resnet18(weights=weights)

# Replace the final layer to match 200 classes
num_features = ResNet_model.fc.in_features
ResNet_model.fc = nn.Linear(num_features, 200)

ResNet_model = ResNet_model.to(device)

# Define loss and optimizer.
# All parameters are passed to Adam, so the whole network is fine-tuned,
# not only the new final layer.
criterion = nn.CrossEntropyLoss()
ResNet_optimizer = torch.optim.Adam(ResNet_model.parameters(), lr=1e-4)

### Define training and validation loops

Here we implement two functions:

- `train_one_epoch`: runs one epoch over the training set, updates weights, and tracks loss and accuracy.
- `evaluate`: runs one full pass over the validation set without gradient updates, and reports loss and accuracy.

These utilities keep the main training loop clean and readable, and allow easy reuse later.

In [23]:
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(imgs)``; expected to return logits.
        loader: Iterable of ``(imgs, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device that each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample average loss and the fraction
        of correct predictions over everything seen in this pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the average.
        loss_sum += batch_loss.item() * imgs.size(0)
        n_correct += (logits.max(1).indices == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module called as ``model(imgs)``; expected to return logits.
        loader: Iterable of ``(imgs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device that each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample average loss and the fraction
        of correct predictions over the whole loader.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Val", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)

            logits = model(imgs)
            batch_loss = criterion(logits, labels)

            # Weight by batch size so the result is a true per-sample mean.
            loss_sum += batch_loss.item() * imgs.size(0)
            n_correct += (logits.max(1).indices == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


### Train the CNN baseline model and monitor accuracy

We run the training for a few epochs.  
For each epoch, we log:
- Training loss and accuracy
- Validation loss and accuracy

We also keep track of the best validation accuracy and save the model weights whenever a new best score is reached.  
This gives me a first baseline performance for the bird classification task.


In [26]:
num_epochs = 5
best_val_acc = 0.0

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")

    train_loss, train_acc = train_one_epoch(
        CNN_model, train_loader, CNN_optimizer, criterion, device
    )
    val_loss, val_acc = evaluate(CNN_model, val_loader, criterion, device)

    print(f"  Train  | loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Val    | loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    # Checkpoint the CNN only on a strict improvement in validation accuracy.
    if not val_acc > best_val_acc:
        continue

    best_val_acc = val_acc
    torch.save(CNN_model.state_dict(), "best_CNN_baseline.pt")
    print(f"New best model saved with val_acc = {best_val_acc:.4f}")


Epoch 1/5


                                                                                

  Train  | loss: 6.6179, acc: 0.0073
  Val    | loss: 5.2867, acc: 0.0051
New best model saved with val_acc = 0.0051
Epoch 2/5


                                                                                

  Train  | loss: 5.2842, acc: 0.0070
  Val    | loss: 5.2782, acc: 0.0115
New best model saved with val_acc = 0.0115
Epoch 3/5


                                                                                

  Train  | loss: 5.2738, acc: 0.0076
  Val    | loss: 5.2830, acc: 0.0153
New best model saved with val_acc = 0.0153
Epoch 4/5


                                                                                

  Train  | loss: 5.2472, acc: 0.0143
  Val    | loss: 5.2639, acc: 0.0115
Epoch 5/5


                                                                                

  Train  | loss: 5.2037, acc: 0.0169
  Val    | loss: 5.2339, acc: 0.0216
New best model saved with val_acc = 0.0216


### Train the ResNet18 baseline model and monitor accuracy


In [None]:
num_epochs = 5
best_val_acc = 0.0

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")

    train_loss, train_acc = train_one_epoch(
        ResNet_model, train_loader, ResNet_optimizer, criterion, device
    )
    val_loss, val_acc = evaluate(ResNet_model, val_loader, criterion, device)

    print(f"  Train  | loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Val    | loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    # Checkpoint ResNet18 only on a strict improvement in validation accuracy.
    if not val_acc > best_val_acc:
        continue

    best_val_acc = val_acc
    torch.save(ResNet_model.state_dict(), "best_resnet18_baseline.pt")
    print(f"New best model saved with val_acc = {best_val_acc:.4f}")


Epoch 1/5


                                                      

  Train  | loss: 4.6993, acc: 0.1127
  Val    | loss: 3.7666, acc: 0.2761
New best model saved with val_acc = 0.2761
Epoch 2/5


                                                      

  Train  | loss: 3.1974, acc: 0.4427
  Val    | loss: 2.9975, acc: 0.4262
New best model saved with val_acc = 0.4262
Epoch 3/5


                                                      

  Train  | loss: 2.3340, acc: 0.6366
  Val    | loss: 2.6871, acc: 0.4758
New best model saved with val_acc = 0.4758
Epoch 4/5


                                                      

  Train  | loss: 1.6966, acc: 0.7707
  Val    | loss: 2.2902, acc: 0.5229
New best model saved with val_acc = 0.5229
Epoch 5/5


                                                      

  Train  | loss: 1.2163, acc: 0.8774
  Val    | loss: 2.1153, acc: 0.5394
New best model saved with val_acc = 0.5394




### Re-load the ResNet18 model (No need to run)

In [None]:
# `models` is not imported by any of the cells above, so it is brought into
# scope here; without this import the cell fails with a NameError on a fresh kernel.
from torchvision import models

# Rebuild the same architecture that was trained: plain ResNet18 with a
# 200-way final layer, so the saved state-dict keys and shapes match.
ResNet_model = models.resnet18(weights=None)  # initialize architecture
num_features = ResNet_model.fc.in_features
ResNet_model.fc = nn.Linear(num_features, 200)

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(
    "best_resnet18_baseline.pt", map_location=device, weights_only=True
)
ResNet_model.load_state_dict(state_dict)
ResNet_model = ResNet_model.to(device)
ResNet_model.eval()

## Load `Falconsai/nsfw_image_detection` and adapt it for 200 bird classes

The `Falconsai/nsfw_image_detection` model is a ViT-based image classifier originally trained for 2 classes
(`normal` vs `nsfw`). I reuse the pretrained backbone and:

1. Load the model and its image processor from Hugging Face.
2. Replace the final classification layer (`classifier`) so that it outputs 200 logits (one per bird class).
3. Update the config metadata (`num_labels`, `id2label`, `label2id`) for consistency.

This gives me a strong transformer-based model specialized for my 200 bird classes.


In [17]:
## !pip install transformers

from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import torch.nn as nn


In [18]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

model_name = "Falconsai/nsfw_image_detection"

# The processor and the classifier come from the same checkpoint, so the
# preprocessing matches what the model was saved with.
processor = AutoImageProcessor.from_pretrained(model_name)
vit_model = AutoModelForImageClassification.from_pretrained(model_name)

print("Original num_labels:", vit_model.config.num_labels)

# Swap the classification head for a freshly initialised 200-way linear layer,
# keeping the input width of the head it replaces.
num_features = vit_model.classifier.in_features
vit_model.classifier = nn.Linear(num_features, 200)

# Keep the metadata in step with the new head. num_labels is assigned before
# the label maps, in the same order as before.
vit_model.config.num_labels = 200
vit_model.num_labels = 200
vit_model.config.id2label = {idx: f"class_{idx+1}" for idx in range(200)}
vit_model.config.label2id = {
    name: idx for idx, name in vit_model.config.id2label.items()
}

vit_model = vit_model.to(device)
print("Adapted num_labels:", vit_model.config.num_labels)


Using device: cpu


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Original num_labels: 2
Adapted num_labels: 200


### Create a Dataset that uses the ViT image processor

For the ViT model, I no longer use the `torchvision` transforms.
Instead, I use the Hugging Face `AutoImageProcessor`, which:

- Resizes the image to the correct resolution (224×224 for ViT)
- Converts it to a tensor
- Applies the exact normalization used during pretraining

I define a `BirdsDatasetViT` class that:
- Takes the same `train_images` / `val_images` DataFrames as before (with `image_path` and `label`)
- Loads each image with PIL
- Runs the image through the processor to get `pixel_values`
- Returns `(pixel_values, label)` where labels are 0–199


In [20]:
from torch.utils.data import Dataset
from PIL import Image
import os

class BirdsDatasetViT(Dataset):
    """Dataset yielding ``(pixel_values, label)`` pairs for a Hugging Face image model.

    Args:
        df: DataFrame holding one row per image.
        processor: Callable invoked as ``processor(images=img, return_tensors="pt")``;
            its result must expose a ``"pixel_values"`` entry.
        base_dir: Directory that the image paths are resolved against.
        label_col: Name of the column holding the 1-based integer class id.
        path_col: Name of the column holding the image path.
    """

    def __init__(self, df, processor, base_dir=".", label_col="label", path_col="image_path"):
        # Renumber rows 0..len-1 so positional access does not depend on the
        # (shuffled) index produced by the train/validation split.
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.base_dir = base_dir
        self.label_col = label_col
        self.path_col = path_col

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(pixel_values, zero_based_label)``.

        Raises:
            FileNotFoundError: If the resolved image path does not exist.
        """
        row = self.df.iloc[idx]

        # Strip leading "/" so os.path.join treats the path as relative to
        # base_dir instead of discarding base_dir for an absolute path.
        rel_path = str(row[self.path_col]).lstrip("/")
        img_path = os.path.join(self.base_dir, rel_path)

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory image, so it stays valid
        # after the file is closed.
        with Image.open(img_path) as source:
            img = source.convert("RGB")

        # Use HF processor to get ViT-ready pixel_values; squeeze(0) removes
        # the batch dimension the processor adds for a single image.
        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)

        label = int(row[self.label_col]) - 1  # 1–200 → 0–199

        return pixel_values, label


### Create DataLoaders for the ViT-based model

Now I wrap the `BirdsDatasetViT` in PyTorch DataLoaders.
On macOS, using `num_workers=0` avoids multiprocess issues while debugging.
These loaders will feed ViT-ready `pixel_values` and labels into the training loop.


In [21]:
from torch.utils.data import DataLoader

# Adjust this path: the folder that contains `train_images/`
# If your notebook is already in the project root, "." is fine.
base_dir = "."

batch_size = 32

# Both parts share one processor, so they are preprocessed identically.
train_dataset_vit = BirdsDatasetViT(train_images, processor=processor, base_dir=base_dir)
val_dataset_vit = BirdsDatasetViT(val_images, processor=processor, base_dir=base_dir)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader_vit = DataLoader(
    train_dataset_vit, batch_size=batch_size, shuffle=True, num_workers=0
)
val_loader_vit = DataLoader(
    val_dataset_vit, batch_size=batch_size, shuffle=False, num_workers=0
)

### Training and validation loops for the ViT model

The training logic is the same as before, but the forward pass changes slightly:

- For ResNet: `outputs = model(images)`
- For ViT (Hugging Face): `outputs = vit_model(pixel_values=images)`

From `outputs`, I use `outputs.logits` and compute cross-entropy loss as usual.
The rest of the loop (accuracy computation, backprop, logging) is unchanged.


In [22]:
from tqdm import tqdm

def train_one_epoch_vit(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` for a Hugging Face-style model.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result
            exposes a ``logits`` attribute.
        loader: Iterable of ``(pixel_values, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device that each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample average loss and the fraction
        of correct predictions over everything seen in this pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for pixel_values, labels in tqdm(loader, desc="Train (ViT)", leave=False):
        pixel_values, labels = pixel_values.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(pixel_values=pixel_values).logits

        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the average.
        loss_sum += batch_loss.item() * pixel_values.size(0)
        n_correct += (logits.max(1).indices == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def evaluate_vit(model, loader, criterion, device):
    """Score a Hugging Face-style model on ``loader`` without tracking gradients.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result
            exposes a ``logits`` attribute.
        loader: Iterable of ``(pixel_values, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device that each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample average loss and the fraction
        of correct predictions over the whole loader.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for pixel_values, labels in tqdm(loader, desc="Val (ViT)", leave=False):
            pixel_values, labels = pixel_values.to(device), labels.to(device)

            logits = model(pixel_values=pixel_values).logits
            batch_loss = criterion(logits, labels)

            # Weight by batch size so the result is a true per-sample mean.
            loss_sum += batch_loss.item() * pixel_values.size(0)
            n_correct += (logits.max(1).indices == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


### Fine-tune the Falconsai ViT as a "ceiling" model

Now I fine-tune the adapted ViT model on the bird dataset.
This model:

- Starts from a powerful Vision Transformer backbone
- Has a new classification head for 200 bird classes
- Is expected to perform at least as well as ResNet18, and possibly better,
  giving me an approximate performance "ceiling" for this assignment.

I reuse the same training hyperparameters as a starting point and monitor train/validation accuracy.


In [None]:
criterion_vit = nn.CrossEntropyLoss()
# All parameters are handed to Adam, so the backbone is fine-tuned along with
# the new head.
optimizer_vit = torch.optim.Adam(params=vit_model.parameters(), lr=1e-4)

num_epochs_vit = 5
best_val_acc_vit = 0.0

for epoch in range(1, num_epochs_vit + 1):
    print(f"Epoch {epoch}/{num_epochs_vit} (ViT)")

    train_loss, train_acc = train_one_epoch_vit(
        vit_model, train_loader_vit, optimizer_vit, criterion_vit, device
    )
    val_loss, val_acc = evaluate_vit(vit_model, val_loader_vit, criterion_vit, device)

    print(f"  Train (ViT) | loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Val   (ViT) | loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    # Checkpoint only on a strict improvement in validation accuracy.
    if not val_acc > best_val_acc_vit:
        continue

    best_val_acc_vit = val_acc
    torch.save(vit_model.state_dict(), "vit_nsfw_birds_state_dict.pt")
    print(f" New best ViT model saved with val_acc = {best_val_acc_vit:.4f}")


Epoch 1/5 (ViT)


Train (ViT):   1%|▎                            | 1/99 [01:20<2:11:58, 80.80s/it]