In [1]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
# from torch.ao.quantization import QuantStub, DeQuantStub
# import torch.optim as optim
# from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
from collections import OrderedDict
import torch.quantization as tq

import helper

# Reproducibility: torch.manual_seed seeds the CPU RNG and all CUDA RNGs, so the
# weight init in AlexNetCIFAR10 below is repeatable. Presumably it also fixes
# shuffling in helper's DataLoaders if they use the default generator — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device={device}")

Device=cuda


In [2]:
class AlexNetCIFAR10(nn.Module):
    """Small AlexNet-style CNN for CIFAR-10, wrapped in eager-mode quantization stubs.

    Architecture:
    - Five 3x3 / padding-1 convolutions (32, 64, 128, 256, 256 channels), each followed by ReLU.
    - Three 2x2 max pools: after conv1, after conv2 and after conv5.
    - A global average pool.
    - A 256 -> 256 -> num_classes MLP head.

    ``QuantStub`` / ``DeQuantStub`` mark the float <-> quantized boundaries for
    ``torch.quantization``. They act as identity until the model is converted.

    Args:
        num_classes: number of output logits.
        init_std: std of the zero-mean normal init applied to every Conv2d/Linear
            weight (0.01 reproduces the previous hard-coded value).
        use_batchnorm: insert ``BatchNorm2d(out_channels)`` between each conv and its
            ReLU. Off by default, which keeps layer indices (and therefore
            state_dict keys / saved checkpoints) identical to the original layout.
        dropout: if > 0, insert ``nn.Dropout(dropout)`` before each Linear layer of
            the classifier. 0 (default) keeps the original classifier layout.
    """

    # Conv layers whose out_channels is in this set get bias=1; every other
    # Conv2d/Linear bias is 0. This is presumably carried over from the reference
    # AlexNet recipe of positive biases on the later convs. In this network it
    # selects conv4 and conv5 (both 256 channels). The previous list
    # [192, 384, 256] also named widths no layer here has, so dropping them
    # changes nothing.
    _BIAS_ONE_CHANNELS = (256,)

    def __init__(self, num_classes=10, init_std=0.01, use_batchnorm=False, dropout=0.0):
        super().__init__()

        # Registration order (quant, dequant, features, pool, classifier) is kept
        # so self.modules() visits layers, and consumes the RNG during init, in
        # the same order as before.
        self.quant = tq.QuantStub()
        self.dequant = tq.DeQuantStub()

        self.features = nn.Sequential(
            *self._conv_block(3, 32, pool=True, use_batchnorm=use_batchnorm),
            *self._conv_block(32, 64, pool=True, use_batchnorm=use_batchnorm),
            *self._conv_block(64, 128, pool=False, use_batchnorm=use_batchnorm),
            *self._conv_block(128, 256, pool=False, use_batchnorm=use_batchnorm),
            *self._conv_block(256, 256, pool=True, use_batchnorm=use_batchnorm),
        )

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        classifier_layers = []
        if dropout > 0:
            classifier_layers.append(nn.Dropout(dropout))
        classifier_layers += [nn.Linear(256, 256), nn.ReLU(inplace=True)]
        if dropout > 0:
            classifier_layers.append(nn.Dropout(dropout))
        classifier_layers.append(nn.Linear(256, num_classes))
        self.classifier = nn.Sequential(*classifier_layers)

        self._init_weights(init_std)

    @staticmethod
    def _conv_block(in_channels, out_channels, pool, use_batchnorm):
        """Return the layers [Conv3x3, (BatchNorm), ReLU, (MaxPool 2x2)] as a list."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return layers

    def _init_weights(self, std):
        """Normal(0, std) init for Conv2d/Linear weights; biases set to 1 or 0 (see _BIAS_ONE_CHANNELS)."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    bias_one = (
                        isinstance(module, nn.Conv2d)
                        and module.out_channels in self._BIAS_ONE_CHANNELS
                    )
                    nn.init.constant_(module.bias, 1 if bias_one else 0)

    def forward(self, x):
        """Map a (N, 3, H, W) image batch to (N, num_classes) logits (no softmax applied)."""
        x = self.quant(x)
        x = self.features(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return self.dequant(x)

    def load_model(self, path='alexnet_cifar10.pth', device='cpu'):
        """Load a state_dict checkpoint, move the model to ``device`` and switch to eval mode.

        A leading ``module.`` on each key (the prefix added by ``nn.DataParallel``)
        is stripped so the keys match this module. Loading is strict. The model
        must therefore be constructed with the same ``use_batchnorm`` / ``dropout``
        settings as the model that was saved.

        Args:
            path: checkpoint file, e.g. one written by ``save_model``.
            device: used both as ``map_location`` and for the final ``.to(device)``.
        """
        # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
        # On PyTorch >= 1.13, consider weights_only=True (the default since 2.6).
        state_dict = torch.load(path, map_location=device)

        prefix = 'module.'
        cleaned = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

        self.load_state_dict(cleaned)
        self.to(device)
        self.eval()

        print(f"Model loaded from {path}")

    def save_model(self, path='alexnet_cifar10.pth'):
        """Save only this model's state_dict (weights/buffers, not constructor args) to ``path``."""
        torch.save(self.state_dict(), path)
        print(f"Model saved to {path}")

    

In [3]:
# helper.load_dataset() returns the CIFAR-10 loaders, unpacked here as (train, test).
cifar10_loaders = helper.load_dataset()
train_loader, test_loader = cifar10_loaders

Loading the CIFAR10 dataset
Loaded train data: 50000 total samples, 782 batches
Loaded test data: 10000 total samples, 157 batches


In [4]:
model_fp32 = AlexNetCIFAR10()
# To resume from a saved checkpoint instead of random init: model_fp32.load_model()

# One pass over the parameters, recording (element count, trainable?) for each tensor.
param_stats = [(param.numel(), param.requires_grad) for param in model_fp32.parameters()]

total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Total parameters: 1046858
Trainable parameters: 1046858


In [None]:
# Train the float32 baseline. Whatever helper.train_model returns is kept in
# fp32_metrics (presumably per-epoch metrics — confirm against helper).
fp32_metrics = helper.train_model(
    model=model_fp32,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=100,
)