### set-up

In [13]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm   
from torch import optim 
import os

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device.type)

# helper function to make getting another batch of data easier
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=64, drop_last=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=64, drop_last=True)

train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

batch_size = 64

num_batches_per_epoch = len(train_loader.dataset) // batch_size

num_of_epochs = 50000 // num_batches_per_epoch

print("Number of batches per epoch: ", num_batches_per_epoch)
print("Number of epochs: ", num_of_epochs)

cuda
Files already downloaded and verified
Files already downloaded and verified
> Size of training dataset 50000
> Size of test dataset 10000
Number of classes:  100
Number of batches per epoch:  781
Number of epochs:  64


### Model

In [14]:
class EnhancedLightweightAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Channel attention
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid()
        )
        # Spatial attention
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        channel_out = self.channel_attn(x) * x
        spatial_out = self.spatial_attn(channel_out) * channel_out
        return spatial_out

class QualityEnhancementBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.attention = EnhancedLightweightAttention(out_channels)
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        
    def forward(self, x):
        identity = self.residual_conv(x)
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attention(out)
        return out + identity

class ImageQualityNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Initial feature extraction
        self.init_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Encoder path
        self.enc1 = QualityEnhancementBlock(64, 128)
        self.enc2 = QualityEnhancementBlock(128, 256)
        self.enc3 = QualityEnhancementBlock(256, 512)
        
        # Bridge
        self.bridge = QualityEnhancementBlock(512, 512)
        
        # Decoder path with skip connections
        self.dec3 = QualityEnhancementBlock(1024, 256)  # 512 + 512 input channels
        self.dec2 = QualityEnhancementBlock(512, 128)   # 256 + 256 input channels
        self.dec1 = QualityEnhancementBlock(256, 64)    # 128 + 128 input channels
        
        # Final quality refinement
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=1)
        )
        
        # Residual connection for the entire network
        self.global_residual = nn.Identity()
        
    def forward(self, x):
        # Store input for global residual
        input_img = x
        
        # Initial features
        x1 = self.init_conv(x)
        
        # Encoder path
        e1 = self.enc1(x1)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        
        # Bridge
        bridge = self.bridge(e3)
        
        # Decoder path with skip connections
        d3 = self.dec3(torch.cat([bridge, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))
        
        # Final refinement
        out = self.final_conv(d1)
        
        # Global residual connection
        return out + self.global_residual(input_img)
    
def show_images(images, labels, predictions=None, class_names=None):
    num_images = min(25, len(images))  # Show up to 25 images in a 5x5 grid
    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        # Denormalize the images
        mean = np.array([0.5071, 0.4867, 0.4408])
        std = np.array([0.2675, 0.2565, 0.2761])
        img = std * img + mean
        img = np.clip(img, 0, 1)
        
        plt.imshow(img)
        if predictions is not None and class_names is not None:
            title = f'True: {class_names[labels[i]]}\nPred: {class_names[predictions[i]]}'
        elif class_names is not None:
            title = f'True: {class_names[labels[i]]}'
        else:
            title = f'Class: {labels[i]}'
        plt.title(title, fontsize=8)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
    
R = ImageQualityNet().to(device)

print("Number of parameters in model: ", sum(p.numel() for p in R.parameters()))

Number of parameters in model:  13970366


### Training

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(R.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_of_epochs)

R = R.to(device)

for epoch in range(num_of_epochs):
    R.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = R(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_of_epochs}], Step [{i+1}/{len(train_loader)}], '
                    f'Loss: {running_loss/100:.3f}, '
                    f'Acc: {100.*correct/total:.2f}%')
            running_loss = 0.0
    
    scheduler.step()

    # Display images and predictions at the end of each epoch
    if (epoch + 1) % 1 == 0:  # Change this number to control visualization frequency
        R.eval()
        with torch.no_grad():
            images, labels = next(iter(train_loader))
            images, labels = images[:9].to(device), labels[:9].to(device)
            outputs = R(images)
            _, predictions = outputs.max(1)
            show_images(images, labels.cpu(), predictions.cpu(), class_names)
        R.train()

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: : [64]