In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchmetrics import Accuracy

from typing import Any, Literal
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
path = os.path.dirname(os.path.realpath('__file__'))
train_path = os.path.join(path, 'inaturalist_12K/train/')
test_path = os.path.join(path, 'inaturalist_12K/val/')

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
])
dataset = ImageFolder(train_path, transform=transform)

train_size = int(0.7 * len(dataset))
val_size = (len(dataset) - train_size) * 2 // 3
tu_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, tune_dataset = random_split(dataset, [train_size, val_size, tu_size])

test_dataset = ImageFolder(test_path, transform=transform)


In [4]:
len(train_dataset), len(val_dataset), len(test_dataset), len(tune_dataset)

(6999, 2000, 2000, 1000)

In [5]:
dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
for x, y in dl:
    print(x.shape, y.shape)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32])


In [11]:
class ConvolutionBlock(pl.LightningModule):
    def __init__(
            self, in_channels: int, out_channels: int, kernel_size: int,
            stride: int, padding: int, batch_norm: bool=True,
            activation: Literal['relu', 'gelu', 'silu', 'mish']="relu"
            ):
        super(ConvolutionBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        if batch_norm:
            self.bn = nn.BatchNorm2d(out_channels)
        match activation:
            case "relu":
                self.activation = nn.ReLU()
            case "gelu":
                self.activation = nn.GELU()
            case "silu":
                self.activation = nn.SiLU()
            case "mish":
                self.activation = nn.Mish()
            case _:
                self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        x = self.conv(x)
        if self.bn:          
            x = self.bn(x)
        x = self.activation(x)
        x = self.maxpool(x)  
        return x
    
class CNNBase(pl.LightningModule):
    def __init__(
            self, in_channels: int, out_channels: int,
            kernel_size: int, stride: int, padding: int,
            batch_norm: bool=True, activation: Literal['relu', 'gelu', 'silu', 'mish']="relu",
            kernel_strategy: Literal['same', 'double', 'half'] = 'same'
            ):
        super(CNNBase, self).__init__()
        if kernel_strategy == 'same':
            coeff = 1
        elif kernel_strategy == 'double':
            coeff = 2
        elif kernel_strategy == 'half':
            coeff = 0.5
        for i in range(1, 6):
            setattr(self, f"conv{i}", ConvolutionBlock(in_channels, out_channels, kernel_size, stride, padding, batch_norm, activation))
            in_channels = out_channels
            out_channels *= coeff
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

class ClassifierHead(pl.LightningModule):
    def __init__(self, num_classes: int, in_size: int, hidden_size: int, dropout: float=0.0, activation: Literal['relu', 'gelu', 'silu', 'mish']='relu') -> None:
        super(ClassifierHead, self).__init__()
        self.fc1 = nn.Linear(in_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)
        match activation:
            case "relu":
                self.activation = nn.ReLU()
            case "gelu":
                self.activation = nn.GELU()
            case "silu":
                self.activation = nn.SiLU()
            case "mish":
                self.activation = nn.Mish()
            case _:
                self.activation = nn.ReLU()
        self.o_activation = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.activation(self.dropout(self.fc1(x)))
        x = self.o_activation(self.fc2(x))
        return x
    


In [12]:
class NeuralNetwork(pl.LightningModule):
    def __init__(
            self, in_channels: int, out_channels: int,  # Convolutional Layers
            kernel_size: int, stride: int, padding: int,
            batch_norm: bool=True, activation: Literal['relu', 'gelu', 'silu', 'mish']="relu",
            kernel_strategy: Literal['same', 'double', 'half'] = 'same',
            dropout: float=0.0, num_classes: int=10, hidden_size: int=64   # Fully-Connected Layers
            ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.cnn = CNNBase(in_channels, out_channels, kernel_size, stride, padding, batch_norm, activation, kernel_strategy)
        in_size = self.get_in_size()
        self.classifier = ClassifierHead(num_classes, in_size, hidden_size, dropout, activation)
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    
    def get_in_size(self):
        x = torch.randn(1, 3, 224, 224)
        x = self.cnn(x)
        return x.numel()
    
    def forward(self, x):
        x = self.cnn(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx) -> Any:
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, batch, batch_idx) -> Any:
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.accuracy(y_hat, y)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def configure_optimizers(self) -> None:
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    
    def test_dataloader(self) -> DataLoader:
        return DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    def val_dataloader(self) -> DataLoader:
        return DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [13]:
model = NeuralNetwork(3, 32, 3, 1, 1, batch_norm=True, activation='gelu', kernel_strategy='double', dropout=0.0, num_classes=10, hidden_size=64)
logger = pl.loggers.WandbLogger(project='iNaturalist', name='cnn-test')
callbacks = [pl.callbacks.ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1, save_last=True)]
trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=callbacks)
# trainer = pl.Trainer(max_epochs=10, fast_dev_run=True)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
trainer.fit(model)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | cnn        | CNNBase            | 1.6 M 
1 | classifier | ClassifierHead     | 1.6 M 
2 | accuracy   | MulticlassAccuracy | 0     
--------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.708    Total estimated model params size (MB)


                                                                           