In [None]:
!pip install wandb

In [None]:
# Authenticate with Weights & Biases without writing the API key into the notebook:
# wandb.login() uses the WANDB_API_KEY environment variable (or previously stored
# credentials) when available and otherwise prompts for the key.
import wandb

wandb.login()

In [16]:
import argparse
import os
import os.path

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, random_split

import wandb
from shared_methods import all_labels
from transformers import ViTImageProcessor
from huggingface_pretrained import get_vision_transformer

# Select the compute device (CUDA, MPS or CPU)

In [17]:
# Pick the backend in order of preference: CUDA, then MPS, and the CPU as fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# To force CPU execution, assign device = "cpu" here. Author's note: the CPU was
# observed to be faster than CUDA, maybe because the network is very small.
print(f"Using {device} device")

Using cuda device


# Define the hyperparameters (wandb itself is initialized further below)

In [18]:
# Hyperparameters of this run, kept in one namespace so that they can later be
# turned into a plain dict with vars(config).
config = argparse.Namespace(
    learning_rate=0.01,
    epochs=2,
    batch_size=32,
)

# Creating a custom Dataset Class

In [26]:
def get_label_for_image_path(character):
    """Return the integer class index of ``character``.

    The index is the position of ``character`` in ``all_labels``;
    ``list.index`` raises ``ValueError`` when the name is not listed.
    """
    label_index = all_labels.index(character)
    return label_index


class SimpsonsImageDataset(Dataset):
    """Image-folder dataset yielding ``(transformed_image, label)`` pairs.

    ``root_dir`` is scanned for sub-directories; every ``.jpg`` file inside a
    sub-directory becomes one sample whose label is the integer returned by
    ``get_label_for_image_path`` for the sub-directory name.

    Args:
        root_dir: Directory that holds one sub-directory per character.
        _transform: Callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``; its return value is handed back unchanged.
    """

    def __init__(self, root_dir, _transform):
        self.root_dir = root_dir
        self.transform = _transform
        self.samples = []

        # os.listdir() returns entries in arbitrary order; sorting both levels
        # makes the sample order independent of the file system
        for character in sorted(os.listdir(self.root_dir)):
            char_dir = os.path.join(self.root_dir, character)
            if not os.path.isdir(char_dir):
                continue

            for filename in sorted(os.listdir(char_dir)):
                if not filename.endswith(".jpg"):
                    continue

                img_path = os.path.join(char_dir, filename)
                label = get_label_for_image_path(character)

                # only the path is stored here, not the tensor; the tensor is
                # created in __getitem__, which keeps building the dataset fast
                self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of collected ``(img_path, label)`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(transform(image), label)``."""
        img_path, label = self.samples[idx]

        # the context manager closes the underlying file as soon as the pixel
        # data has been read; convert() returns an image that no longer needs it
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # the transform turns the PIL image into the model input
        _inputs = self.transform(image)

        return _inputs, label

# Creating the Dataset

In [27]:
# initialize the image processor, which is based on a pre-trained model,
# it handles tasks like resizing the image, normalizing it and converting it to a tensor
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')


def transform(image):
    """Run ``processor`` on a single image, requesting PyTorch tensors in the result."""
    return processor(image, return_tensors="pt")


# Initialize custom dataset and dataloader
simpsons_dataset = SimpsonsImageDataset(root_dir="data/train_vid", _transform=transform)
# the batch size comes from the shared config instead of a second hard-coded 32;
# note that the training loop further below iterates over train_loader / validation_loader
dataloader = DataLoader(simpsons_dataset, batch_size=config.batch_size, shuffle=True)

# Splitting into train and validation set

In [28]:
# Share of the samples used for training; the remainder is held out for validation.
TRAIN_FRACTION = 0.8
# Fixed seed so that random_split draws the same permutation of indices on every run.
SPLIT_SEED = 42

total_size = len(simpsons_dataset)
train_size = int(TRAIN_FRACTION * total_size)
# computed as the remainder so that both sizes always add up to total_size
validation_size = total_size - train_size

train_dataset, validation_dataset = random_split(
    simpsons_dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create the dataloader for training and validation

In [29]:
# Collate function for the DataLoader: it defines how the individual samples of a
# batch are collated (combined) into the tensors that are passed to the model.
def custom_collate_fn(batch):
    """Merge a list of ``(features, label)`` samples into two batch tensors.

    Args:
        batch: List of samples; element 0 of each sample is indexed with
            ``'pixel_values'`` to get its image tensor, element 1 is its label.

    Returns:
        Tuple ``(pixel_values, labels)`` where ``labels`` has dtype ``torch.long``.
    """
    image_tensors = []
    label_values = []
    for sample in batch:
        image_tensors.append(sample[0]['pixel_values'])
        label_values.append(sample[1])

    # stack() adds a new leading batch dimension; squeeze(1) then removes
    # dimension 1 if it has size 1 and leaves the tensor untouched otherwise
    pixel_values = torch.stack(image_tensors).squeeze(1)

    # one integer class index per sample
    _labels = torch.tensor(label_values, dtype=torch.long)
    return pixel_values, _labels

# Both loaders share batch size and collate function; only the training loader shuffles.
_loader_kwargs = {"batch_size": config.batch_size, "collate_fn": custom_collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
validation_loader = DataLoader(validation_dataset, shuffle=False, **_loader_kwargs)

print(f"I am using {len(train_dataset)} images for training and {len(validation_dataset)} images for validation")

I am using 32 images for training and 8 images for validation


# Neural Network Architecture, loss function and optimizer

In [30]:
# Build the Vision Transformer classifier and record its class in the run config
model = get_vision_transformer()
config.model = type(model)
print(model)

# move the model to the selected device before the optimizer is created
model.to(device)

# cross-entropy loss as the training criterion
loss_function = nn.CrossEntropyLoss()
# plain SGD using the configured learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
# scheduler that lowers the learning rate once the value passed to scheduler.step()
# stops decreasing ('min' mode)
scheduler = ReduceLROnPlateau(optimizer, mode='min')

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

# Train the Network

In [33]:
# Start the W&B run; every attribute currently set on `config` is passed along as the run config.
wandb.init(
    config=vars(config),
    project="kaggle-simpsons",
)

In [34]:
wandb.watch(model)

# lowest mean validation loss seen so far; decides when a checkpoint is written
best_val_loss = float('inf')

for epoch in range(1, config.epochs + 1):
    # Training
    model.train()
    train_loss = 0.0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        if batch_idx % 50 == 0:
            print(f"batch {batch_idx} from {len(train_loader)} ...")

        # move the batch to the device once; the loss expects integer class indices
        inputs = inputs.to(device)
        labels = labels.to(device=device, dtype=torch.long)

        optimizer.zero_grad()

        # the output of the Huggingface model is a dictionary that contains an entry 'logits'
        # logits are the raw scores output by the model, which can be converted to probabilities using a softmax function
        output = model(inputs)['logits']

        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            inputs = inputs.to(device)
            labels = labels.to(device=device, dtype=torch.long)
            output = model(inputs)['logits']

            loss = loss_function(output, labels)
            val_loss += loss.item()

    # Mean loss per batch. The same value is now used for the console output, the
    # wandb log, the scheduler and the checkpoint decision; before, only the console
    # showed the mean while everything else used the sum over all batches, which
    # scales with the number of batches.
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(validation_loader)

    # step the scheduler - adjust the learning rate if validation loss stops decreasing
    scheduler.step(avg_val_loss)

    print(
        f"Epoch {epoch}, Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}")
    wandb.log({'epoch': epoch, 'training loss': avg_train_loss, 'validation loss': avg_val_loss,
               'adjusted learning rate': optimizer.param_groups[0]['lr']})

    # Save model if validation loss has decreased
    if avg_val_loss < best_val_loss:
        torch.save(model.state_dict(), "best_model.pth")
        best_val_loss = avg_val_loss


batch 0 from 1 ...
Epoch 1, Train Loss: 3.422445297241211, Validation Loss: 3.265557289123535
batch 0 from 1 ...
Epoch 2, Train Loss: 3.269341230392456, Validation Loss: 3.117297410964966


In [35]:
  # Mark the W&B run as finished so that the remaining logged data gets uploaded.
  # NOTE(review): the leading indentation of this cell is tolerated by IPython but
  # would be an IndentationError in a plain .py script - consider removing it.
  wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
adjusted learning rate,▁▁
epoch,▁█
training loss,█▁
validation loss,█▁

0,1
adjusted learning rate,0.01
epoch,2.0
training loss,3.26934
validation loss,3.1173
