In [1]:
import os
import torch
import random
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

from tqdm import tqdm
from plotly_utils import line
from dataclasses import dataclass
from torch.utils.data import DataLoader

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
train_dataset = torch.load('/home/user/Capstone/Israel/LLM Course/dataset/filtered_train_dataset.pth')
test_dataset = torch.load('/home/user/Capstone/Israel/LLM Course/dataset/filtered_test_dataset.pth')

  train_dataset = torch.load('/home/user/Capstone/Israel/LLM Course/dataset/filtered_train_dataset.pth')
  test_dataset = torch.load('/home/user/Capstone/Israel/LLM Course/dataset/filtered_test_dataset.pth')


In [4]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [11]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class MiniResNet(nn.Module):
    def __init__(self, block, num_blocks):
        super(MiniResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0])
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.layer5 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.layer6 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 4)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1) 
        x = self.fc(x)
        return x

# MinResNet instantiation
def MiniResNet34():
    return MiniResNet(BasicBlock, [3, 4, 6, 3])

In [6]:
model = MiniResNet34()
sum(p.numel() for p in model.parameters())

21279044

In [12]:
@dataclass
class ResNetCNNTrainingArgs():
    batch_size: int = 64
    epochs: int = 20
    lr: float = 1e-3
    max_steps_per_epoch: int = None 



class ResNetCNNTrainer:
    def __init__(self, args: ResNetCNNTrainingArgs, model: MiniResNet):
        self.model = model
        self.args = args
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.criterion = nn.CrossEntropyLoss()
        self.loss_list = []

    def training_step(self, batch):
        '''
        Performs a single training step.
        '''
        images, labels = batch
        images, labels = images.to(device), labels.to(device)

        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.loss_list.append(loss.item())
        return loss.item()

    def validation_step(self, batch):
        '''
        Performs a single validation step.
        '''
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        outputs = self.model(images)

        _, predicted = torch.max(outputs, 1)
        
        correct = (predicted == labels).sum().item()
        total = labels.size(0)
        
        return correct, total

    def train(self):
        '''
        Trains the model for `self.args.epochs` epochs, and evaluates on validation set.
        '''
        self.args.max_steps_per_epoch = self.calculate_max_steps_per_epoch()
        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)
    
        for epoch in range(self.args.epochs):
            self.model.train()
            running_loss = 0.0
            for i, batch in enumerate(self.train_loader()):
                loss = self.training_step(batch)
                progress_bar.update(1)
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.5f}")
                if i >= self.args.max_steps_per_epoch:
                    break

            self.model.eval()
            correct_predictions = 0
            total_predictions = 0
            with torch.no_grad(): 
                for batch in self.test_loader():
                    correct, total = self.validation_step(batch)
                    correct_predictions += correct
                    total_predictions += total
            accuracy = correct_predictions / total_predictions * 100
            print(f"Epoch [{epoch+1}/{self.args.epochs}], Validation Accuracy: {accuracy:.2f}%")

        line(
            self.loss_list,
            yaxis_range=[0, max(self.loss_list) + 0.1],
            x=torch.linspace(0, self.args.epochs, len(self.loss_list)),
            labels={"x": "Num epochs", "y": "Training Loss"},
            title="CNN Training Loss",
            width=700,
        )

    def train_loader(self) -> DataLoader:
        '''Returns train loader.'''
        return DataLoader(train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    def test_loader(self) -> DataLoader:
        '''Returns test loader.'''
        return DataLoader(test_dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=2, pin_memory=True)

    def calculate_max_steps_per_epoch(self):
        dataset_size = len(self.train_loader().dataset)
        return dataset_size // self.args.batch_size

In [None]:
args = ResNetCNNTrainingArgs(batch_size=64, epochs=30, lr=5e-4)
model = MiniResNet34().to(device)
trainer = ResNetCNNTrainer(args, model)

trainer.train()

Epoch 1, loss: 1.22235:   2%|▏         | 165/9360 [00:08<08:15, 18.57it/s]

In [10]:
model = trainer.model
os.makedirs('/home/user/Capstone/Israel/LLM Course/models', exist_ok=True)
torch.save(model.state_dict(), '/home/user/Capstone/Israel/LLM Course/models/animal_classifier.pth')

print("Model training complete and saved as 'animal_classifier.pth'.")

Model training complete and saved as 'animal_classifier.pth'.
