In [1]:
import pandas as pd
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize, Normalize 
from torch.utils.data import DataLoader
from torch.optim import Adam

from data.transforms import flatten_transform, scale_tanh_range
from models.mixer import MLPMixer
from training.train import train_classifier
from viz.plot import plot_training_curves

In [9]:
data_transform = Compose([
    ToTensor(),
    scale_tanh_range(),
    Resize((64, 64)),
    # flatten_transform()
])

train_imgs = ImageFolder(
    root="imagenette2-160/train",
    transform=data_transform
)
val_imgs = ImageFolder(
    root="imagenette2-160/val",
    transform=data_transform
)

In [10]:
train_imgs[0]

(tensor([[[ 0.1595,  0.0251, -0.0744,  ...,  0.1543,  0.1607,  0.1848],
          [ 0.0412,  0.0158, -0.0642,  ...,  0.1534,  0.1525,  0.1534],
          [-0.0153,  0.2299,  0.2388,  ...,  0.1753,  0.1688,  0.1670],
          ...,
          [-0.0062, -0.0307, -0.0325,  ...,  0.0330,  0.0652, -0.0286],
          [-0.0067,  0.0254, -0.0492,  ...,  0.0293,  0.0009,  0.0326],
          [-0.0226, -0.0116, -0.0378,  ...,  0.0354, -0.0156,  0.0287]],
 
         [[ 0.1305, -0.0091, -0.0929,  ...,  0.2205,  0.2233,  0.2363],
          [ 0.0195, -0.0108, -0.0808,  ...,  0.2132,  0.2082,  0.2015],
          [-0.0260,  0.2142,  0.2237,  ...,  0.2220,  0.2137,  0.2099],
          ...,
          [ 0.1333,  0.0989,  0.0833,  ...,  0.0776,  0.1177,  0.0333],
          [ 0.1166,  0.1449,  0.0621,  ...,  0.0810,  0.0493,  0.0787],
          [ 0.0852,  0.0939,  0.0657,  ...,  0.0904,  0.0363,  0.0771]],
 
         [[ 0.0523, -0.0419, -0.1970,  ...,  0.1971,  0.1903,  0.1981],
          [-0.0736, -0.0707,

In [11]:
import torch
from torch import nn

class ImagenetteClassifier(nn.Module):
    def __init__(
        self, 
        input_size: int, 
        hidden_size: int, 
        num_classes: int
    ):
        super().__init__()
        self.model = nn.Sequential(
            nn.BatchNorm1d(input_size),
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, num_classes)
        )
    
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.model(inputs)

In [12]:
train_loader = DataLoader(
    train_imgs,
    batch_size=64,
    shuffle=True
)

val_loader = DataLoader(
    val_imgs,
    batch_size=64,
    shuffle=False
)

In [13]:
device = torch.device("mps")

In [14]:
classifier = MLPMixer(
    image_size=(3, 64, 64),
    patch_size=16,
    hidden_dim=128,
    n_classes=len(train_imgs.class_to_idx)
).to(device)

optimizer = Adam(classifier.parameters(), lr=5e-4)
loss_fn = nn.CrossEntropyLoss()

In [15]:
classifier, *metrics_history = train_classifier(
    classifier,
    optimizer,
    loss_fn,
    train_loader,
    val_loader,
    device,
    n_epochs=10,
    eval_steps=0.1
)

Training model: 100%|██████████| 1480/1480 [04:41<00:00,  5.25it/s, loss=1.19, eval_loss=1.58, eval_acc=0.479] 


In [None]:
plot_training_curves(*metrics_history, title="Imagenette Classifier Training (hidden_size=256)")