**AlexNet Architecture**

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,random_split
import torch.nn.functional as F

Build the network

In [3]:
# Define the AlexNet model
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),  # Input: 1 channel (grayscale)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # Output: 64 channels
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # Output: 192 channels
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),   # Output: 256 channels
            nn.Dropout()
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 3 * 3, 4096),  # Adjust input size accordingly
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10)  # 10 output classes for MNIST
        )

    def forward(self, x):
        x = self.features(x)
        # print("features",x.shape)
        x = x.view(x.size(0), -1)  # Flatten
        # x= nn.Flatten(x)
        # print("after flatten",x.shape)
        x = self.classifier(x)
        return x


In [4]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

load data

In [5]:
# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to fit AlexNet input size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_set, val_set = random_split(train_dataset, [0.9, 0.1])
train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:12<00:00, 806kB/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 18.9MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 8.85MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [6]:
# Initialize model, loss function, and optimizer
model = AlexNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [10]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for idx,(images, labels) in enumerate(train_loader):#train
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if(idx % 5 == 0):
          print(f'batch:{idx} , batch loss:{loss.item()}\n')
        total_loss += loss.item()

    acc=0
    for idx,(images, labels) in enumerate(val_loader): #validation
        images, labels = images.to(device), labels.to(device)
        outputs = torch.argmax(model(images),dim=1)
        comp=torch.eq(labels,outputs).float().to('cpu')
        acc+=torch.sum(comp)/len(comp)
        # torch.cat([acc,torch.sum(comp)/len(comp)])

    # print(f'batch : {idx} , accuracy : {acc/(idx+1):.4f} %')
    print(f'Epoch [{epoch + 1}/{num_epochs}], Total loss: {total_loss / len(train_loader):.4f} , Accuracy : {acc/(idx+1)} %')

print("Training complete!")

batch:0 , batch loss:0.021643318235874176

batch:5 , batch loss:0.021922262385487556

batch:10 , batch loss:0.01290686335414648

batch:15 , batch loss:0.02091917209327221

batch:20 , batch loss:0.012219661846756935

batch:25 , batch loss:0.026842720806598663

Epoch [1/5], Total loss: 0.0244 , Accuracy : 0.9877614974975586 %
batch:0 , batch loss:0.013479269109666348

batch:5 , batch loss:0.021226637065410614

batch:10 , batch loss:0.028376800939440727

batch:15 , batch loss:0.016954727470874786

batch:20 , batch loss:0.0236023161560297

batch:25 , batch loss:0.023316070437431335

Epoch [2/5], Total loss: 0.0242 , Accuracy : 0.9871719479560852 %
batch:0 , batch loss:0.02218485251069069

batch:5 , batch loss:0.03618291765451431

batch:10 , batch loss:0.014803973957896233

batch:15 , batch loss:0.0226740725338459

batch:20 , batch loss:0.024812743067741394

batch:25 , batch loss:0.018595093861222267

Epoch [3/5], Total loss: 0.0198 , Accuracy : 0.9878230094909668 %
batch:0 , batch loss:0.0