In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.models as models
from torchvision.io import read_image
import torchvision.transforms.v2 as transforms
import lightning as L
import matplotlib.pyplot as plt

In [2]:
class Model(L.LightningModule):
    def __init__(self, num_classes=100):
        super().__init__()

        weights = models.EfficientNet_B2_Weights.IMAGENET1K_V1
        self.preprocess = weights.transforms()
        backbone = models.efficientnet_b2(weights=weights)
        num_filters = 1408
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # self.feature_extractor.eval()
        # for param in self.feature_extractor.parameters():
        #     param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(num_filters, num_classes)
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.feature_extractor(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        loss = F.cross_entropy(x, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.feature_extractor(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        loss = F.cross_entropy(x, y)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [3]:
model = Model()

In [4]:
class Food2kDataset(data.Dataset):
    def __init__(self, img_dir, transform=None):
        classes = os.listdir(img_dir)
        self.images = []
        for class_name in classes:
            for image in os.listdir(os.path.join(img_dir, class_name)):
                self.images.append((os.path.join(img_dir, class_name, image), int(class_name)))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = read_image(self.images[idx][0])
        label = self.images[idx][1]
        if self.transform:
            image = self.transform(image)
        return image, label

In [5]:
train_set = Food2kDataset('data/train', transforms.Compose([
                    transforms.ToTensor(),
                    model.preprocess
                ]))
val_set = Food2kDataset('data/val', transforms.Compose([
                    transforms.ToTensor(),
                    model.preprocess
                ]))
test_set = Food2kDataset('data/test', transforms.Compose([
                    transforms.ToTensor(),
                    model.preprocess
                ]))

train_loader = data.DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = data.DataLoader(val_set, batch_size=16)



In [None]:
trainer = L.Trainer(accelerator="gpu", devices=1)
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\Keanu Thakalath\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\loops\utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                | Params
----------------------------------------------------------
0 | preprocess        | ImageClassification | 0     
1 | feature_extractor | Sequential          | 7.7 M 
2 | classifier        | Sequential          | 140 K 
----------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.368    Total estimated model params size (MB)


Sanity Checking: |                                                                               | 0/? [00:00<?, ?it/s]

C:\Users\Keanu Thakalath\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                                                                       

C:\Users\Keanu Thakalath\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 0:  55%|██████████████████████▉                   | 1119/2045 [05:29<04:32,  3.39it/s, v_num=2, train_loss=0.543]

In [None]:
model.cpu()
model.eval()

In [None]:
x, y = next(iter(val_loader))
pred = model.feature_extractor(x)
pred = torch.flatten(pred, start_dim=1)
pred = model.classifier(pred)

In [None]:
ex = 1
plt.imshow(x[ex][0])
print(f"Label: {y[ex].item()} Predicted: {pred[ex].argmax().item()}")