# AlexNet implementation

In [41]:
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

In [35]:
class AlexNet(nn.Module):
  """AlexNet-style CNN (Krizhevsky et al., 2012): five conv layers followed by three FC layers.

  The layer geometry is sized for 227x227 inputs: with c1 using kernel 11 / stride 4 / no padding,
  the spatial size goes 227 -> 55 -> (pool) 27 -> (pool) 13 -> (pool) 6, which yields the
  6*6*256 features that ``fc1`` expects. Other input sizes (e.g. 224x224) produce a different
  feature-map size and fail at ``fc1``.

  Args:
    in_channels: number of channels of the input images (3 for RGB).
    classes: number of output classes (10 by default, e.g. for CIFAR-10).
  """

  def __init__(self, in_channels=3, classes=10):
    super().__init__()
    # Convolutional feature extractor
    self.c1 = nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=11, stride=4, padding=0)
    self.c2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
    self.c3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
    self.c4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
    self.c5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
    # Fully-connected classifier
    self.fc1 = nn.Linear(in_features=6*6*256, out_features=4096)
    self.fc2 = nn.Linear(in_features=4096, out_features=4096)
    self.fc3 = nn.Linear(in_features=4096, out_features=classes)
    # Parameter-free layers; a single instance of each is reused at several points in forward()
    self.norm = nn.LocalResponseNorm(k=2, size=5, alpha=1e-4, beta=0.75)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=0.5)

    # Initialize weights
    self.weight_init()

  def forward(self, x):
    """Map a batch of images (N, in_channels, 227, 227) to raw class scores (N, classes)."""
    # Response normalization + pooling only after the first two conv layers
    x = self.maxpool(self.norm(self.relu(self.c1(x))))
    x = self.maxpool(self.norm(self.relu(self.c2(x))))
    x = self.relu(self.c3(x))
    x = self.relu(self.c4(x))
    x = self.maxpool(self.relu(self.c5(x)))
    # Keep the batch dimension, flatten everything else for the FC layers
    x = torch.flatten(x, 1)
    # Dropout is applied to the inputs of the two hidden FC layers
    x = self.relu(self.fc1(self.dropout(x)))
    x = self.relu(self.fc2(self.dropout(x)))
    # No softmax here: the output is logits (nn.CrossEntropyLoss expects raw scores)
    x = self.fc3(x)
    return x

  def weight_init(self):
    """Initialize all conv/linear parameters following the AlexNet paper.

    Weights are drawn from N(0, 0.01^2). Biases are set to 1 in the 2nd, 4th and 5th conv
    layers and in the two hidden FC layers, and to 0 in the remaining layers (c1, c3, fc3).
    """
    # Select layers by identity rather than by their position in self.modules():
    # that iterator yields the model itself first, which makes positional indices easy to get wrong.
    ones_bias_layers = (self.c2, self.c4, self.c5, self.fc1, self.fc2)
    for layer in self.modules():
      # isinstance, not `is`: `layer` is an instance, never the class object itself
      if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(layer.weight, mean=0.0, std=0.01)
        bias_value = 1.0 if any(layer is chosen for chosen in ones_bias_layers) else 0.0
        nn.init.constant_(layer.bias, bias_value)



### Training

In [47]:
# Load CIFAR-10 train dataset.
# Images are resized to 227x227: this is the input size for which AlexNet's conv/pool stack
# produces the 6x6x256 feature map expected by fc1, and it matches the test-time transform.
data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((227,227)),
    torchvision.transforms.ToTensor(),
    # Per-channel mean / std used for normalization (the commonly used ImageNet statistics)
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
cifar10_train_dataset = torchvision.datasets.CIFAR10(root='./data', download=True, train=True, transform=data_transforms)
train_dataloader = torch.utils.data.DataLoader(cifar10_train_dataset,
                                         batch_size = 32,
                                         shuffle=True)

# Load AlexNet (the model class defined above) and put it in training mode so dropout is active
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.train()

# Define Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

epoch_nums = 2 # small just for working demonstration
# Train the network
for epoch in range(epoch_nums):

  # running sum of the loss over the current window of 100 mini-batches
  current_loss = 0.0
  for i, data in tqdm(enumerate(train_dataloader, start = 0), unit="batch", total=len(train_dataloader), desc=f"Epoch {epoch}"):
    # get inputs -> data is a list of [inputs, labels]
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)

    # reset gradients
    optimizer.zero_grad()

    # forward pass + backward + optimize
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # compute statistics
    current_loss += loss.item()
    if i % 100 == 99:    # print the average loss every 100 mini-batches, then reset the window
      print(f'current loss: {current_loss / 100:.3f}')
      current_loss = 0.0

# Only the parameters (state_dict) are saved; the test cell rebuilds AlexNet and loads them
print("Training done! Saving trained model to: './trained_model.pth'")
torch.save(model.state_dict(), './trained_model.pth')
print("Saved.")






Files already downloaded and verified


Epoch 0:   6%|▋         | 101/1563 [00:10<02:58,  8.19batch/s]

current loss: 2.303


Epoch 0:  13%|█▎        | 201/1563 [00:21<02:24,  9.43batch/s]

current loss: 2.300


Epoch 0:  19%|█▉        | 301/1563 [00:32<02:11,  9.58batch/s]

current loss: 2.185


Epoch 0:  26%|██▌       | 401/1563 [00:43<02:00,  9.67batch/s]

current loss: 2.031


Epoch 0:  32%|███▏      | 501/1563 [00:54<01:50,  9.63batch/s]

current loss: 1.875


Epoch 0:  38%|███▊      | 601/1563 [01:05<01:39,  9.67batch/s]

current loss: 1.772


Epoch 0:  45%|████▍     | 701/1563 [01:16<01:40,  8.59batch/s]

current loss: 1.704


Epoch 0:  51%|█████     | 801/1563 [01:26<01:33,  8.15batch/s]

current loss: 1.653


Epoch 0:  58%|█████▊    | 901/1563 [01:37<01:11,  9.26batch/s]

current loss: 1.591


Epoch 0:  64%|██████▍   | 1001/1563 [01:48<00:59,  9.41batch/s]

current loss: 1.540


Epoch 0:  70%|███████   | 1101/1563 [01:59<00:49,  9.41batch/s]

current loss: 1.515


Epoch 0:  77%|███████▋  | 1201/1563 [02:10<00:37,  9.60batch/s]

current loss: 1.474


Epoch 0:  83%|████████▎ | 1301/1563 [02:21<00:27,  9.59batch/s]

current loss: 1.411


Epoch 0:  90%|████████▉ | 1401/1563 [02:32<00:19,  8.33batch/s]

current loss: 1.381


Epoch 0:  96%|█████████▌| 1501/1563 [02:43<00:06,  9.57batch/s]

current loss: 1.354


Epoch 0: 100%|██████████| 1563/1563 [02:49<00:00,  9.20batch/s]
Epoch 1:   6%|▋         | 101/1563 [00:11<02:35,  9.40batch/s]

current loss: 1.304


Epoch 1:  13%|█▎        | 201/1563 [00:22<02:44,  8.27batch/s]

current loss: 1.255


Epoch 1:  19%|█▉        | 301/1563 [00:33<02:31,  8.33batch/s]

current loss: 1.185


Epoch 1:  26%|██▌       | 401/1563 [00:43<02:05,  9.22batch/s]

current loss: 1.152


Epoch 1:  32%|███▏      | 501/1563 [00:54<01:50,  9.64batch/s]

current loss: 1.166


Epoch 1:  38%|███▊      | 601/1563 [01:05<01:42,  9.38batch/s]

current loss: 1.136


Epoch 1:  45%|████▍     | 701/1563 [01:16<01:29,  9.59batch/s]

current loss: 1.118


Epoch 1:  51%|█████     | 801/1563 [01:27<01:20,  9.48batch/s]

current loss: 1.103


Epoch 1:  58%|█████▊    | 901/1563 [01:38<01:21,  8.08batch/s]

current loss: 1.061


Epoch 1:  64%|██████▍   | 1001/1563 [01:49<00:59,  9.39batch/s]

current loss: 1.047


Epoch 1:  70%|███████   | 1101/1563 [02:00<00:49,  9.41batch/s]

current loss: 1.013


Epoch 1:  77%|███████▋  | 1201/1563 [02:11<00:38,  9.51batch/s]

current loss: 0.996


Epoch 1:  83%|████████▎ | 1301/1563 [02:22<00:28,  9.34batch/s]

current loss: 0.970


Epoch 1:  90%|████████▉ | 1401/1563 [02:33<00:17,  9.31batch/s]

current loss: 0.958


Epoch 1:  96%|█████████▌| 1501/1563 [02:44<00:07,  8.86batch/s]

current loss: 0.971


Epoch 1: 100%|██████████| 1563/1563 [02:50<00:00,  9.14batch/s]


Training done! Saving trained model to: './trained_model'
Saved.


### Testing

In [50]:
# Load CIFAR-10 test dataset (same preprocessing as used for the 227x227 AlexNet input)
data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((227,227)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
cifar10_test_dataset = torchvision.datasets.CIFAR10(root='./data', download=True, train=False, transform=data_transforms)
# Accuracy does not depend on sample order, so the test set is not shuffled
test_dataloader = torch.utils.data.DataLoader(cifar10_test_dataset,
                                         batch_size = 32,
                                         shuffle=False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa)
model.load_state_dict(torch.load("./trained_model.pth", map_location=device))
# Evaluation mode: disables dropout so predictions are deterministic and use the full network
model.eval()

correct = 0
total = 0
# no_grad: gradients are not needed for evaluation, which saves memory and time
with torch.no_grad():
  for data in tqdm(test_dataloader):
    images, labels = data
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)

    # predicted class = index of the largest logit along the class dimension
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f'\nAccuracy of AlexNet on CIFAR10: {100 * correct // total} %')

Files already downloaded and verified


100%|██████████| 313/313 [00:24<00:00, 12.67it/s]


Accuracy of AlexNet on CIFAR10: 68 %



