In [4]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from pytorch_lightning.callbacks import ModelCheckpoint


## Load and Pre-process Data

In [5]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        ])

    def prepare_data(self):
        # Download CIFAR-10 dataset
        CIFAR10(root="data", train=True, download=True)
        CIFAR10(root="data", train=False, download=True)

    def setup(self, stage=None):
        # Split dataset into train and validation sets
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(root="data", train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(root="data", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


## Define the CNN Architecture

In [9]:
class CIFAR10Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


## Train the Model

In [12]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='my_model',
    filename='cifar10-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

data_module = CIFAR10DataModule()
model = CIFAR10Classifier()
trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=data_module)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


/Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/mle/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/my_model exists and is not empty.

  | Name    | Type      | Params
--------------------------------------
0 | conv1   | Conv2d    | 896   
1 | conv2   | Conv2d    | 18.5 K
2 | conv3   | Conv2d    | 73.9 K
3 | pool    | MaxPool2d | 0     
4 | fc1     | Linear    | 1.0 M 
5 | fc2     | Linear    | 5.1 K 
6 | dropout | Dropout   | 0     
--------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.590     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/mle/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/mle/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1407/1407 [00:12<00:00, 112.84it/s, v_num=2]
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/157 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/157 [00:00<?, ?it/s]
Validation DataLoader 0:   1%|          | 1/157 [00:00<00:00, 257.52it/s]
Validation DataLoader 0:   1%|▏         | 2/157 [00:00<00:00, 226.91it/s]
Validation DataLoader 0:   2%|▏         | 3/157 [00:00<00:00, 202.55it/s]
Validation DataLoader 0:   3%|▎         | 4/157 [00:00<00:00, 200.90it/s]
Validation DataLoader 0:   3%|▎         | 5/157 [00:00<00:00, 201.93it/s]
Validation DataLoader 0:   4%|▍         | 6/157 [00:00<00:00, 200.99it/s]
Validation DataLoader 0:   4%|▍         | 7/157 [00:00<00:00, 200.32it/s]
Validation DataLoader 0:   5%|▌         | 8/157 [00:00<00:00, 198.68it/s]
Validation DataLoader 0:   6%|▌         | 9/157 [00:00<00:00, 199.22it/s]
Validation DataLoader 0:   6%|▋         | 10/157 [00:00<00:00, 199.48it/s]
Validation DataLoader 0: 

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 1407/1407 [00:13<00:00, 102.27it/s, v_num=2]


## Evaluate the Model


In [14]:
# Evaluate the model
trainer.test(datamodule=data_module)



Files already downloaded and verified
Files already downloaded and verified


Restoring states from the checkpoint path at /Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/my_model/cifar10-epoch=01-val_loss=0.93.ckpt
Loaded model weights from the checkpoint at /Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/my_model/cifar10-epoch=01-val_loss=0.93.ckpt
/Users/krishan/Documents/GitHub/Image-Classification-with-Convolutional-Neural-Networks/mle/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:01<00:00, 232.07it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.9320631623268127
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.9320631623268127}]