In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,random_split
import torch.nn.functional as F

In [2]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm stages wrapped by a skip connection."""

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        """
        Build both convolution stages and keep the optional shortcut module.

        Args:
        - in_channels (int): Channel count of the incoming feature map.
        - out_channels (int): Channel count produced by both convolutions.
        - stride (int): Stride of the first convolution only; the second always uses 1.
        - downsample (nn.Module): Module applied to the input before it is added back
          to the main path. When None, the input is added unchanged.
        """
        super().__init__()
        # Settings shared by both 3x3 convolutions; bias is dropped because each
        # convolution is immediately followed by a BatchNorm layer.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: conv -> bn -> relu -> conv -> bn.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Merge the two branches, then apply the final activation.
        main += shortcut
        return self.relu(main)


In [3]:
# ResNet-18 Definition
class ResNet18(nn.Module):
    """ResNet-18 style classifier: conv stem, four residual stages, pooled linear head."""

    def __init__(self, num_classes=1000):
        """
        Assemble the network.

        Args:
        - num_classes (int): Size of the final linear layer's output (default: 1000).
        """
        super().__init__()
        # Running channel count; _make_layer reads it and updates it after each stage.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages of two residual blocks each; stages 2-4 use stride 2.
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Head: global average pool down to 1x1, then a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride):
        """
        Build one stage as a sequence of residual blocks.

        Args:
        - out_channels (int): Channel count every block in the stage outputs.
        - blocks (int): How many residual blocks the stage contains.
        - stride (int): Stride applied by the first block; later blocks use stride 1.

        Returns:
        - nn.Sequential: The blocks of the stage, in order.
        """
        # The first block needs a projection shortcut whenever it changes the
        # stride or the channel count; a 1x1 conv + BN does the matching.
        needs_projection = stride != 1 or self.in_channels != out_channels
        shortcut = None
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        first_block = ResidualBlock(self.in_channels, out_channels, stride, shortcut)
        # Every block after the first receives out_channels as input.
        self.in_channels = out_channels
        remaining_blocks = [ResidualBlock(out_channels, out_channels) for _ in range(1, blocks)]

        return nn.Sequential(first_block, *remaining_blocks)

    def forward(self, x):
        """Run the stem, the four residual stages, and the classifier head on x."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        # Flatten everything except the leading dimension before the linear layer.
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)


In [4]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [8]:
# Data preprocessing for FashionMNIST
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to fit models like AlexNet/ResNet
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels for compatibility
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize to [-1, 1] range
])

# Load the FashionMNIST training split (downloaded into ./data on first use)
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Hold out 10% of the training data for validation. The split is driven by a
# seeded generator so the same samples land in each subset on every re-run.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(train_dataset, [0.9, 0.1], generator=split_generator)

# Shuffle only the training batches; validation order does not affect the metric.
train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
val_loader = DataLoader(val_set, batch_size=256, shuffle=False)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:03<00:00, 8.59MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 480kB/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz





Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:00<00:00, 4.63MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 22.4MB/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [9]:
# Initialize model, loss function, and optimizer
# FashionMNIST has 10 classes, so size the classifier head to match rather than
# using ResNet18's default of 1000 outputs (which would leave 990 logits that can
# never be a correct answer but can still be predicted by argmax).
NUM_CLASSES = 10
model = ResNet18(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
# Training loop: one optimisation pass over train_loader per epoch, followed by
# an accuracy measurement on val_loader.
num_epochs = 5
for epoch in range(num_epochs):
    # ---- training ----
    model.train()  # BatchNorm uses batch statistics and updates its running stats
    total_loss = 0.0
    for idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        if idx % 5 == 0:
            print(f'batch:{idx} , batch loss:{batch_loss}\n')
        total_loss += batch_loss

    # ---- validation ----
    # Switch to eval mode so BatchNorm uses its running statistics and is not
    # updated by validation data; disable autograd since no backward pass is needed.
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            predictions = torch.argmax(model(images), dim=1)
            # Count samples rather than averaging per-batch ratios, so a smaller
            # final batch is not over-weighted.
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)

    # max(..., 1) guards against an empty loader producing a ZeroDivisionError.
    avg_loss = total_loss / max(len(train_loader), 1)
    val_accuracy = 100.0 * correct / max(seen, 1)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Average loss: {avg_loss:.4f} , Accuracy : {val_accuracy:.2f} %')

print("Training complete!")

batch:0 , batch loss:0.7131824493408203

batch:5 , batch loss:0.4327884018421173

batch:10 , batch loss:0.6984161138534546

batch:15 , batch loss:0.6316965222358704

batch:20 , batch loss:0.3843723237514496

batch:25 , batch loss:0.6037749648094177

batch:30 , batch loss:0.7975050210952759

batch:35 , batch loss:0.5167608857154846

batch:40 , batch loss:0.458487331867218

batch:45 , batch loss:0.36329928040504456

batch:50 , batch loss:0.9189883470535278

batch:55 , batch loss:0.4188882112503052

batch:60 , batch loss:0.6051297187805176

batch:65 , batch loss:0.7752204537391663

batch:70 , batch loss:0.4102668762207031

batch:75 , batch loss:0.4014722406864166

batch:80 , batch loss:0.38441935181617737

batch:85 , batch loss:0.4529407024383545

batch:90 , batch loss:0.7015796899795532

batch:95 , batch loss:0.5108770132064819

batch:100 , batch loss:0.4466085135936737

batch:105 , batch loss:0.6917083859443665

batch:110 , batch loss:0.40941449999809265

batch:115 , batch loss:1.153517