In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms

!pip install kornia
from kornia import augmentation as K
from kornia.augmentation import AugmentationSequential
from torch.utils.data import random_split

import numpy as np

import os
import time
from pathlib import Path
import pickle

from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

!pip install kymatio
from kymatio.torch import Scattering2D

DEBUG = False
MODEL_NAME = "baseline_100_scat"
TRAIN_SIZE = 1000        # stratified training images drawn from the CIFAR-10 train split
VALIDATION_SIZE = 5000   # stratified validation images drawn from the remaining train images
SEED = 42                # global seed for the torch / numpy RNGs
env = 'kaggle' # 'kaggle' or 'colab'

# Seed the global RNGs so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Choose where checkpoints and training statistics are persisted.
if env == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')
    base_dir = Path('/content/drive/MyDrive/dl_pj')
elif env == 'kaggle':
    base_dir = Path('/kaggle/working/')
else:
    # Fail here with a clear message rather than a NameError on `base_dir` below.
    raise ValueError(f"Unknown env {env!r}; expected 'kaggle' or 'colab'")

checkpoint_dir = base_dir / 'checkpoints'
checkpoint_dir.mkdir(parents=True, exist_ok=True)
training_stats_dir = base_dir / 'stats'
training_stats_dir.mkdir(parents=True, exist_ok=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Collecting kymatio
  Downloading kymatio-0.3.0-py3-none-any.whl.metadata (9.6 kB)
Collecting appdirs (from kymatio)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting configparser (from kymatio)
  Downloading configparser-7.2.0-py3-none-any.whl.metadata (5.5 kB)
Downloading kymatio-0.3.0-py3-none-any.whl (87 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.6/87.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Downloading configparser-7.2.0-py3-none-any.whl (17 kB)
Installing collected packages: appdirs, configparser, kymatio
Successfully installed appdirs-1.4.4 configparser-7.2.0 kymatio-0.3.0


In [2]:
class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, add shortcut, ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution. When it is not 1, or when
            ``in_planes != planes``, the shortcut becomes a strided 1x1 conv + BN
            so its output shape matches the main branch; otherwise it is identity.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        # An empty Sequential passes its input through unchanged (identity shortcut).
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        residual = self.shortcut(x)
        return F.relu(hidden + residual)



class ResNet(nn.Module):
    """ResNet classifier that consumes 2-D scattering coefficients instead of raw pixels.

    The input is reshaped to ``(batch, scattering_output_channels, H, W)``, upsampled
    x2 by a depthwise (``groups == channels``) transposed convolution, and then passed
    through a standard four-stage ResNet trunk and a linear head.

    Args:
        scattering_output_channels: number of channels after merging all non-spatial
            scattering axes (27 for the J=1 first-order scattering used in this notebook).
        block: residual block class, called as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output logit vector.
    """

    def __init__(self, scattering_output_channels, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.scat_channels = scattering_output_channels
        # Depthwise x2 upsampling of each scattering map (e.g. 16x16 -> 32x32).
        self.deconv1 = nn.ConvTranspose2d(in_channels=self.scat_channels, out_channels=self.scat_channels, groups=self.scat_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(scattering_output_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layer = []
        for s in strides:
            layer.append(block(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layer)

    def forward(self, x):
        """Return class logits for a batch of scattering coefficients.

        ``x`` may be the raw scattering output (e.g. ``(batch, colour, coeffs, H, W)``)
        or already merged to ``(batch, channels, H, W)``: every axis between the batch
        axis and the last two spatial axes is folded into the channel axis.
        """
        x = x.reshape(x.size(0), -1, x.size(-2), x.size(-1))
        x = self.deconv1(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Assumes a 4x4 trunk output (32x32 after deconv1, halved by layers 2-4), so a
        # 4x4 average pool collapses it to 1x1 before the linear head.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

def ResNet18(scattering_output_channels, num_classes=10):
    """Build a ResNet-18 (four stages of two BasicBlocks) over scattering features."""
    return ResNet(scattering_output_channels, BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


In [3]:
# First-order (max_order=1), single-scale (J=1) scattering of 32x32 images with
# SCAT_L orientations. Per colour channel, kymatio returns 1 low-pass + J*L band-pass
# maps at 32 / 2**J resolution, i.e. (batch, 3, 1 + J*L, 16, 16) here.
SCAT_J = 1
SCAT_L = 8  # kymatio's default number of orientations, made explicit
scattering = Scattering2D(J=SCAT_J, shape=(32, 32), L=SCAT_L, max_order=1).to(device)
# Derived instead of hard-coded so it stays in sync with J and L: 3 * (1 + 1 * 8) = 27.
scattering_output_channels = 3 * (1 + SCAT_J * SCAT_L)

In [4]:
def get_model_summary(model):
    """Count a model's parameters and their storage footprint.

    Only ``model.parameters()`` are counted; buffers such as BatchNorm running
    statistics are not included.

    Returns:
        (num_params, size_mb): the total number of parameter elements, and their
        combined size in MiB (bytes / 1024**2).
    """
    count = 0
    nbytes = 0
    for param in model.parameters():
        elements = param.numel()
        count += elements
        nbytes += elements * param.element_size()
    return count, nbytes / 1024 ** 2

# Report the size of a freshly initialised ResNet-18 over scattering features.
total_params, model_size_mb = get_model_summary(
    ResNet18(scattering_output_channels)
)
print("Total Parameters: {:,}".format(total_params))
print("Model Size: {:.2f} MB".format(model_size_mb))

Total Parameters: 11,187,921
Model Size: 42.68 MB


In [5]:
def calculate_accuracy(model, dataloader, device, normalize_fn=None, feature_fn=None):
    """Return the top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    Each batch is moved to ``device``, normalised, mapped to features, and classified.
    The prediction is the argmax over the last axis of the model output.

    Args:
        model: classifier. It is put in eval mode (dropout off, BatchNorm uses its
            running statistics) and is left in eval mode on return.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the batch tensors are moved to.
        normalize_fn: preprocessing applied to the images. Defaults to the
            notebook-level ``normalize`` transform, looked up at call time.
        feature_fn: feature extractor applied after normalisation. Defaults to the
            notebook-level ``scattering`` transform, looked up at call time.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if normalize_fn is None:
        normalize_fn = normalize
    if feature_fn is None:
        feature_fn = scattering

    model.eval()
    total_correct = 0
    total_images = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(feature_fn(normalize_fn(images)))
            predictions = torch.argmax(outputs, dim=-1)
            total_images += labels.size(0)
            total_correct += (predictions == labels).sum().item()

    if total_images == 0:
        raise ValueError("calculate_accuracy: dataloader yielded no samples")
    return total_correct / total_images * 100


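Split the dataset into train, validation, and test sets.
We draw stratified, disjoint subsets of `TRAIN_SIZE` (1,000) training images and `VALIDATION_SIZE` (5,000) validation images from the 50,000-image CIFAR-10 training split. The official CIFAR-10 test split is used for testing.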

In [6]:
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# Stratified subsets of the CIFAR-10 training split: TRAIN_SIZE images for training
# and VALIDATION_SIZE disjoint images for validation.
train_val_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Named distinctly from the `targets` batch variable of the training loop, which
# would otherwise overwrite this array.
train_val_targets = np.array(train_val_set.targets)

indices = np.arange(len(train_val_targets))
train_indices, remaining_indices = train_test_split(
    indices,
    train_size=TRAIN_SIZE,
    stratify=train_val_targets,
    random_state=42
)

# Draw the validation subset only from images not already used for training.
validation_indices, _ = train_test_split(
    remaining_indices,
    train_size=VALIDATION_SIZE,
    stratify=train_val_targets[remaining_indices],
    random_state=42
)

# Sanity checks: requested sizes, and no train/validation leakage.
assert len(train_indices) == TRAIN_SIZE
assert len(validation_indices) == VALIDATION_SIZE
assert not set(train_indices.tolist()) & set(validation_indices.tolist())

trainset = Subset(train_val_set, train_indices)
valset = Subset(train_val_set, validation_indices)
print(f"Original size: {len(train_val_set)}")
print(f"Train size: {len(trainset)}")
print(f"Val size: {len(valset)}")

from collections import Counter
# Class histogram of the train subset, then of the validation subset.
for split_indices in (train_indices, validation_indices):
    print("Samples per class", Counter(train_val_targets[split_indices].tolist()))

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)


100%|██████████| 170M/170M [00:05<00:00, 29.0MB/s]


Original size: 50000
Train size: 1000
Val size: 5000
Samples per class Counter({np.int64(7): 100, np.int64(1): 100, np.int64(8): 100, np.int64(4): 100, np.int64(3): 100, np.int64(9): 100, np.int64(5): 100, np.int64(0): 100, np.int64(2): 100, np.int64(6): 100})
Samples per class Counter({np.int64(1): 500, np.int64(9): 500, np.int64(2): 500, np.int64(4): 500, np.int64(6): 500, np.int64(5): 500, np.int64(0): 500, np.int64(3): 500, np.int64(8): 500, np.int64(7): 500})


In [7]:
# Hyperparameters
batch_size = 128

# SGD with momentum and L2 weight decay.
lr = 0.1
momentum = 0.9
weight_decay = 5e-4

n_epochs = 1 if DEBUG else 200

# Cosine annealing period, in scheduler steps (one per epoch). Tied to n_epochs so the
# learning rate completes its decay at the final epoch, including in DEBUG runs.
T_max = n_epochs

print_progress_every = 1              # print a progress line every N epochs
val_accuracy_storing_threshold = 40   # only checkpoint models whose val accuracy (%) exceeds this


In [8]:

# Per-channel normalisation constants for CIFAR-10 images in [0, 1].
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010); the
# per-pixel std of the CIFAR-10 train set is usually quoted as ~(0.247, 0.243, 0.261).
# Left unchanged so results stay comparable with earlier runs. Confirm before changing.
mean = torch.tensor([0.4914, 0.4822, 0.4465]).to(device)
std = torch.tensor([0.2023, 0.1994, 0.2010]).to(device)
normalize = K.Normalize(mean=mean, std=std)

# On-device training augmentation. Normalisation runs last, so the photometric and
# geometric augmentations operate on [0, 1] images.
aug_list = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.2),
    K.RandomResizedCrop(size=(32,32), scale=(0.7, 1.0), p=0.5),
    normalize,
    same_on_batch=False
).to(device)


trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation only aggregates correct counts, so shuffling gains nothing. A fixed order
# avoids consuming the global RNG and keeps evaluation deterministic.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False)
model = ResNet18(scattering_output_channels).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
# Stepped once per epoch by the training loop.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

stats = {
    'total_training_time': 0,      # wall-clock seconds for the whole loop, set after training
    'loss': [],                    # mean training loss per epoch
    'time_per_epoch': [],          # seconds spent in each epoch's training pass (excludes validation)
    'total_time_per_epoch': [],    # cumulative seconds since start, recorded after validation
    'val_accuracy': [],            # validation accuracy (%) per epoch
    'max_val_accuracy': 0,         # best checkpointed validation accuracy (%)
    'allocated_memory': [], # Memory currently used by Tensors
    'reserved_memory': [], # Memory held by the PyTorch caching allocator
}


def _save_checkpoint(epoch, val_accuracy, message):
    """Print ``message``, then save the current weights to the ``*_max_acc.pth`` checkpoint."""
    print(message)
    state = {
        'net': model.state_dict(),
        'epoch': epoch,
        'acc': val_accuracy
    }
    save_path = checkpoint_dir / f"{MODEL_NAME}_max_acc.pth"
    torch.save(state, save_path)


start_time = time.time()
for epoch in range(n_epochs):
    # --- training pass ---
    model.train()
    iteration_losses = []
    epoch_start_time = time.time()
    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        inputs = aug_list(inputs)  # augment + normalise on the device

        outputs = model(scattering(inputs))

        optimizer.zero_grad()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        iteration_losses.append(loss.item())

    scheduler.step()
    epoch_end_time = time.time()

    # --- validation pass (calculate_accuracy puts the model in eval mode) ---
    val_accuracy = calculate_accuracy(model, valloader, device)

    # Track stats every epoch
    stats['loss'].append(np.mean(iteration_losses))
    stats['val_accuracy'].append(val_accuracy)
    stats['allocated_memory'].append(torch.cuda.memory_allocated())
    stats['reserved_memory'].append(torch.cuda.memory_reserved())
    stats['time_per_epoch'].append(epoch_end_time - epoch_start_time)
    stats['total_time_per_epoch'].append(time.time() - start_time)

    # Store the best model, but only once it clears the minimum accuracy threshold.
    if val_accuracy > stats['max_val_accuracy'] and val_accuracy > val_accuracy_storing_threshold:
        stats['max_val_accuracy'] = val_accuracy
        _save_checkpoint(epoch, val_accuracy, '==> Saving model ...')

    if DEBUG:
        # Always write the checkpoint in debug runs, so the *_max_acc.pth file exists
        # even when the accuracy threshold is never reached.
        _save_checkpoint(epoch, val_accuracy, '==> Saving model ... DEBUG')

    # Print progress
    if (epoch % print_progress_every) == 0:
        print(f"Epoch {epoch} Loss {stats['loss'][-1]:.3f} Val Acc {stats['val_accuracy'][-1]:.3f}")

# Total wall-clock time of the run, including validation passes.
stats['total_training_time'] = time.time() - start_time



Epoch 0 Loss 5.005 Val Acc 10.000
Epoch 1 Loss 3.791 Val Acc 10.000
Epoch 2 Loss 3.121 Val Acc 9.820
Epoch 3 Loss 2.719 Val Acc 9.260
Epoch 4 Loss 2.421 Val Acc 11.660
Epoch 5 Loss 2.197 Val Acc 15.540
Epoch 6 Loss 2.111 Val Acc 17.660
Epoch 7 Loss 2.036 Val Acc 22.160
Epoch 8 Loss 2.043 Val Acc 23.560
Epoch 9 Loss 1.975 Val Acc 25.820
Epoch 10 Loss 1.931 Val Acc 26.460
Epoch 11 Loss 1.911 Val Acc 24.640
Epoch 12 Loss 1.831 Val Acc 27.720
Epoch 13 Loss 1.798 Val Acc 31.040
Epoch 14 Loss 1.813 Val Acc 27.780
Epoch 15 Loss 1.749 Val Acc 32.160
Epoch 16 Loss 1.738 Val Acc 32.540
Epoch 17 Loss 1.683 Val Acc 34.100
Epoch 18 Loss 1.649 Val Acc 30.680
Epoch 19 Loss 1.621 Val Acc 34.500
Epoch 20 Loss 1.629 Val Acc 32.180
Epoch 21 Loss 1.542 Val Acc 31.720
Epoch 22 Loss 1.550 Val Acc 33.660
Epoch 23 Loss 1.491 Val Acc 36.760
Epoch 24 Loss 1.458 Val Acc 33.960
Epoch 25 Loss 1.355 Val Acc 36.900
Epoch 26 Loss 1.353 Val Acc 34.400
Epoch 27 Loss 1.411 Val Acc 31.680
Epoch 28 Loss 1.317 Val Acc 34.3

In [9]:
# Reload the best checkpoint written during training and evaluate it on every split.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ResNet18(scattering_output_channels).to(device)
checkpoint_path = checkpoint_dir / f"{MODEL_NAME}_max_acc.pth"
if not checkpoint_path.exists():
    # Training only saves once val accuracy exceeds the threshold (or in DEBUG mode).
    raise FileNotFoundError(
        f"No checkpoint at {checkpoint_path}: validation accuracy may never have "
        f"exceeded {val_accuracy_storing_threshold}% during training"
    )
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['net'])
model.eval()

# Summarise the model being evaluated instead of building a throwaway copy.
total_params, model_size_mb = get_model_summary(model)
stats['total_params'] = total_params
stats['model_size_mb'] = model_size_mb
stats['train_acc'] = calculate_accuracy(model, trainloader, device)
stats['val_acc'] = calculate_accuracy(model, valloader, device)
# Order does not affect accuracy, so the test set is read sequentially.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
stats['test_acc'] = calculate_accuracy(model, testloader, device)

with open(training_stats_dir / f'{MODEL_NAME}_stats.pkl', 'wb') as file:
    pickle.dump(stats, file)

# The inner quotes differ from the outer ones: reusing the same quote inside an
# f-string replacement field is a SyntaxError before Python 3.12.
print(f"Final test accuracy is: {stats['test_acc']}")

Final test accuracy is: 49.85
