In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Define the Model with SEBlock and Enhanced Architecture
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = x.view(b, c, -1).mean(2)  # Global Average Pooling
        y = torch.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        y = y.view(b, c, 1, 1)
        return x * y

class CIFAR10EnhancedModel(nn.Module):
    def __init__(self, num_classes=10):
        super(CIFAR10EnhancedModel, self).__init__()
        def dw_sep_conv(in_ch, out_ch, stride=1):
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch, 3, stride, 1, groups=in_ch, bias=False),
                nn.BatchNorm2d(in_ch),
                nn.GELU(),
                nn.Conv2d(in_ch, out_ch, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
                SEBlock(out_ch)
            )

        def inverted_residual(in_ch, out_ch, expansion, stride):
            mid_ch = in_ch * expansion
            return nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, 1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.GELU(),
                dw_sep_conv(mid_ch, out_ch, stride),
            ) if stride > 1 or in_ch != out_ch else nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, 1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.GELU(),
                dw_sep_conv(mid_ch, out_ch, stride),
            )

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.GELU(),
        )

        self.features = nn.Sequential(
            dw_sep_conv(32, 64, 1),
            inverted_residual(64, 128, 4, 2),
            inverted_residual(128, 128, 4, 1),
            inverted_residual(128, 256, 4, 2),
            dw_sep_conv(256, 512, 2),
        )

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.features(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
!sudo apt update
!sudo apt install nvidia-driver-470

[33m0% [Working][0m            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:5 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:6 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:8 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:9 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:10 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,224 kB]
Get:11 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Packages [2,619 kB]
Get:12 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:13 http://security.ubuntu.com/

In [2]:
#when using cpu and no new changes
import torch
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_model():
    # Data Augmentation and Preprocessing
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

    # Model, Loss, Optimizer, Scheduler
    model = CIFAR10EnhancedModel().to('cpu')
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=3e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=50)

    # Training Loop
    for epoch in range(50):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for inputs, targets in pbar:
            inputs, targets = inputs.to('cpu'), targets.to('cpu')

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / len(train_loader))

        scheduler.step()

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to('cpu'), targets.to('cpu')
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100.0 * correct / total
        print(f"Validation Accuracy after Epoch {epoch+1}: {acc:.2f}%")

    # Save Model
    torch.save(model.state_dict(), "cifar10_light_model.pth")
    print("Model training complete. Saved as 'cifar10_light_model.pth'")

if __name__ == "__main__":
    train_model()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 12.6MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




NameError: name 'CIFAR10LightModel' is not defined

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_model():
    # Data Augmentation and Preprocessing
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load CIFAR-10 Dataset
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

    # Model, Loss, Optimizer, Scheduler
    model = CIFAR10EnhancedModel().to('cuda')
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_loader), epochs=50)

    # Training Loop
    for epoch in range(50):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for inputs, targets in pbar:
            inputs, targets = inputs.to('cuda'), targets.to('cuda')

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / len(train_loader))

        scheduler.step()

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to('cuda'), targets.to('cuda')
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100.0 * correct / total
        print(f"Validation Accuracy after Epoch {epoch+1}: {acc:.2f}%")

    # Save Model
    torch.save(model.state_dict(), "cifar10_enhanced_model.pth")
    print("Model training complete. Saved as 'cifar10_enhanced_model.pth'")

if __name__ == "__main__":
    train_model()

Files already downloaded and verified
Files already downloaded and verified


Epoch 1: 100%|██████████| 391/391 [01:15<00:00,  5.18it/s, loss=1.63]


Validation Accuracy after Epoch 1: 60.80%


Epoch 2: 100%|██████████| 391/391 [01:08<00:00,  5.74it/s, loss=1.31]


Validation Accuracy after Epoch 2: 68.70%


Epoch 3: 100%|██████████| 391/391 [01:07<00:00,  5.80it/s, loss=1.17]


Validation Accuracy after Epoch 3: 73.74%


Epoch 4: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s, loss=1.09]


Validation Accuracy after Epoch 4: 78.38%


Epoch 5: 100%|██████████| 391/391 [01:06<00:00,  5.87it/s, loss=1.03]


Validation Accuracy after Epoch 5: 79.20%


Epoch 6: 100%|██████████| 391/391 [01:06<00:00,  5.88it/s, loss=0.992]


Validation Accuracy after Epoch 6: 81.02%


Epoch 7: 100%|██████████| 391/391 [01:06<00:00,  5.87it/s, loss=0.957]


Validation Accuracy after Epoch 7: 81.24%


Epoch 8: 100%|██████████| 391/391 [01:06<00:00,  5.85it/s, loss=0.931]


Validation Accuracy after Epoch 8: 83.24%


Epoch 9: 100%|██████████| 391/391 [01:07<00:00,  5.81it/s, loss=0.916]


Validation Accuracy after Epoch 9: 84.55%


Epoch 10: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s, loss=0.894]


Validation Accuracy after Epoch 10: 84.58%


Epoch 11: 100%|██████████| 391/391 [01:07<00:00,  5.78it/s, loss=0.878]


Validation Accuracy after Epoch 11: 85.64%


Epoch 12: 100%|██████████| 391/391 [01:08<00:00,  5.75it/s, loss=0.863]


Validation Accuracy after Epoch 12: 86.10%


Epoch 13: 100%|██████████| 391/391 [01:08<00:00,  5.74it/s, loss=0.849]


Validation Accuracy after Epoch 13: 85.88%


Epoch 14: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, loss=0.836]


Validation Accuracy after Epoch 14: 87.38%


Epoch 15: 100%|██████████| 391/391 [01:08<00:00,  5.69it/s, loss=0.823]


Validation Accuracy after Epoch 15: 86.61%


Epoch 16: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, loss=0.813]


Validation Accuracy after Epoch 16: 87.24%


Epoch 17: 100%|██████████| 391/391 [01:11<00:00,  5.46it/s, loss=0.8]


Validation Accuracy after Epoch 17: 87.07%


Epoch 18: 100%|██████████| 391/391 [01:10<00:00,  5.53it/s, loss=0.793]


Validation Accuracy after Epoch 18: 87.61%


Epoch 19: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, loss=0.785]


Validation Accuracy after Epoch 19: 87.78%


Epoch 20: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s, loss=0.779]


Validation Accuracy after Epoch 20: 88.27%


Epoch 21: 100%|██████████| 391/391 [01:07<00:00,  5.83it/s, loss=0.771]


Validation Accuracy after Epoch 21: 88.09%


Epoch 22: 100%|██████████| 391/391 [01:07<00:00,  5.81it/s, loss=0.768]


Validation Accuracy after Epoch 22: 88.71%


Epoch 23: 100%|██████████| 391/391 [01:07<00:00,  5.83it/s, loss=0.758]


Validation Accuracy after Epoch 23: 88.66%


Epoch 24: 100%|██████████| 391/391 [01:07<00:00,  5.78it/s, loss=0.753]


Validation Accuracy after Epoch 24: 88.71%


Epoch 25: 100%|██████████| 391/391 [01:06<00:00,  5.86it/s, loss=0.749]


Validation Accuracy after Epoch 25: 89.15%


Epoch 26: 100%|██████████| 391/391 [01:06<00:00,  5.87it/s, loss=0.739]


Validation Accuracy after Epoch 26: 89.08%


Epoch 27: 100%|██████████| 391/391 [01:06<00:00,  5.86it/s, loss=0.735]


Validation Accuracy after Epoch 27: 89.03%


Epoch 28: 100%|██████████| 391/391 [01:06<00:00,  5.86it/s, loss=0.74]


Validation Accuracy after Epoch 28: 88.65%


Epoch 29: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s, loss=0.723]


Validation Accuracy after Epoch 29: 88.74%


Epoch 30: 100%|██████████| 391/391 [01:07<00:00,  5.81it/s, loss=0.723]


Validation Accuracy after Epoch 30: 89.66%


Epoch 31: 100%|██████████| 391/391 [01:08<00:00,  5.75it/s, loss=0.72]


Validation Accuracy after Epoch 31: 89.34%


Epoch 32: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, loss=0.716]


Validation Accuracy after Epoch 32: 89.17%


Epoch 33: 100%|██████████| 391/391 [01:08<00:00,  5.72it/s, loss=0.711]


Validation Accuracy after Epoch 33: 89.22%


Epoch 34: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, loss=0.708]


Validation Accuracy after Epoch 34: 89.18%


Epoch 35: 100%|██████████| 391/391 [01:08<00:00,  5.70it/s, loss=0.704]


Validation Accuracy after Epoch 35: 90.28%


Epoch 36: 100%|██████████| 391/391 [01:08<00:00,  5.71it/s, loss=0.701]


Validation Accuracy after Epoch 36: 89.93%


Epoch 37: 100%|██████████| 391/391 [01:08<00:00,  5.68it/s, loss=0.698]


Validation Accuracy after Epoch 37: 89.21%


Epoch 38: 100%|██████████| 391/391 [01:08<00:00,  5.71it/s, loss=0.69]


Validation Accuracy after Epoch 38: 89.24%


Epoch 39: 100%|██████████| 391/391 [01:08<00:00,  5.67it/s, loss=0.693]


Validation Accuracy after Epoch 39: 90.24%


Epoch 40: 100%|██████████| 391/391 [01:09<00:00,  5.65it/s, loss=0.683]


Validation Accuracy after Epoch 40: 90.16%


Epoch 41: 100%|██████████| 391/391 [01:09<00:00,  5.63it/s, loss=0.686]


Validation Accuracy after Epoch 41: 89.49%


Epoch 42: 100%|██████████| 391/391 [01:09<00:00,  5.63it/s, loss=0.679]


Validation Accuracy after Epoch 42: 89.97%


Epoch 43: 100%|██████████| 391/391 [01:09<00:00,  5.64it/s, loss=0.679]


Validation Accuracy after Epoch 43: 89.75%


Epoch 44: 100%|██████████| 391/391 [01:08<00:00,  5.70it/s, loss=0.678]


Validation Accuracy after Epoch 44: 89.79%


Epoch 45: 100%|██████████| 391/391 [01:09<00:00,  5.66it/s, loss=0.673]


Validation Accuracy after Epoch 45: 90.68%


Epoch 46: 100%|██████████| 391/391 [01:11<00:00,  5.46it/s, loss=0.675]


Validation Accuracy after Epoch 46: 90.06%


Epoch 47: 100%|██████████| 391/391 [01:10<00:00,  5.58it/s, loss=0.667]


Validation Accuracy after Epoch 47: 89.93%


Epoch 48: 100%|██████████| 391/391 [01:08<00:00,  5.72it/s, loss=0.664]


Validation Accuracy after Epoch 48: 90.19%


Epoch 49: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s, loss=0.668]


Validation Accuracy after Epoch 49: 90.61%


Epoch 50: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s, loss=0.661]


Validation Accuracy after Epoch 50: 89.63%
Model training complete. Saved as 'cifar10_enhanced_model.pth'


In [6]:
from fvcore.nn import FlopCountAnalysis
model = CIFAR10EnhancedModel()
inputs = torch.randn(1, 3, 32, 32)
flops = FlopCountAnalysis(model, inputs)
print(f"FLOPs: {flops.total() / 1e6:.2f}M")



FLOPs: 96.31M


In [7]:
torch.save(model.state_dict(), "cifar10_model.pth")
import os
print(f"Model Size: {os.path.getsize('cifar10_model.pth') / 1e6:.2f} MB")


Model Size: 2.41 MB


In [9]:
model =  CIFAR10EnhancedModel() # Initialize your model

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {total_params}")


Total Trainable Parameters: 584718


In [10]:
pip install torchsummary




In [5]:
!pip install fvcore

Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/50.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath>=0.1.7->fvcore)
  Downloading portalocker-3.0.0-py3-none-any.whl.metadata (8.5 kB)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Downloading portalocker-3.0.0-py3-none-any.whl (19 kB)
Building wheels for collected packages: fvcore, iop