<a href="https://colab.research.google.com/github/issam9/alexnet-implementation-pytorch/blob/main/AlexNet_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [98]:
import torch
import torch.nn as nn 
import torchvision 
import torchvision.transforms as tfms
from fastai.vision import *
import tqdm

In [85]:
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selected device as the cell output

device(type='cuda', index=0)

In [87]:
# Imagenette is a small subset of ImageNet that contains only 10 classes.
path = untar_data(URLs.IMAGENETTE)
path.ls()  # list the dataset root; shows the 'train' and 'val' split folders

[PosixPath('/root/.fastai/data/imagenette2/val'),
 PosixPath('/root/.fastai/data/imagenette2/train')]

In [88]:
# Preprocessing shared by both splits: resize the shorter side to 227,
# center-crop to the 227 x 227 input the model expects, then convert to a
# tensor and normalize with the ImageNet channel statistics.
# Resize/crop are applied to the PIL image *before* ToTensor: PIL resizing is
# always antialiased and avoids resizing a full-resolution float tensor.
transform = tfms.Compose([
    tfms.Resize(227),
    tfms.CenterCrop(227),
    tfms.ToTensor(),
    tfms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # imagenet stats
])

In [89]:
# Image datasets read from the 'train' and 'val' folders under the dataset
# root, both using the same preprocessing pipeline.
trainset = torchvision.datasets.ImageFolder(root=path/'train', transform=transform)
validset = torchvision.datasets.ImageFolder(root=path/'val', transform=transform)

# Batches of 64; only the training loader reshuffles each epoch.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True)
validloader = torch.utils.data.DataLoader(dataset=validset, batch_size=64, shuffle=False)

In [97]:
class AlexNet(nn.Module):
    """Single-tower AlexNet-style convolutional classifier.

    The network expects input images of size (3 x 227 x 227); the
    per-layer output sizes noted in the comments follow from that size.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # input should be of size (3 x 227 x 227)
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),  # (96 x 55 x 55)
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (96 x 27 x 27)
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # (256 x 27 x 27)
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (256 x 13 x 13)
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # (384 x 13 x 13)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # (384 x 13 x 13)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (256 x 13 x 13)
            # The fifth conv also needs its non-linearity before pooling;
            # without it this would be the only conv layer with none. ReLU has
            # no parameters, so state_dict keys are unaffected.
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (256 x 6 x 6)
            nn.Flatten(),  # (256*6*6,) per sample; batch dim is kept
        )

        # Fully-connected head with dropout before each of the two hidden
        # layers; the last Linear emits raw logits (no softmax).
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, num_classes) for a batch of images."""
        x = self.features(x)
        return self.classifier(x)

In [91]:
alexnet = AlexNet(num_classes=10).to(device)  # 10 output classes, moved to the selected device

In [93]:
# Cross-entropy loss over the model's logits, optimised with Adam at lr 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=alexnet.parameters(), lr=1e-3)

In [94]:
# Batches per epoch for each split; used to average the summed per-batch losses.
train_len, valid_len = len(trainloader), len(validloader)

In [None]:
# Training loop: each epoch runs one optimisation pass over `trainloader`,
# then an evaluation pass over `validloader`, printing the mean per-batch
# loss and the accuracy for each split.
for epoch in range(20):
    # Report epochs 1-based so this line matches the summary lines below.
    print(f'training on epoch {epoch+1}...')
    alexnet.train()  # training mode: dropout layers are active

    sum_loss = 0.0
    total = 0
    correct = 0
    for imgs, labels in tqdm.tqdm(trainloader, position=0):
        imgs, labels = imgs.to(device), labels.to(device)

        # zero the gradients accumulated by the previous step
        optimizer.zero_grad()

        output = alexnet(imgs)

        # calculate loss
        train_loss = criterion(output, labels)

        # backward propagation and parameter update
        train_loss.backward()
        optimizer.step()

        sum_loss += train_loss.item()

        # Score every batch so the reported accuracy covers the whole epoch,
        # not just its final batch. `detach` keeps this bookkeeping out of the
        # autograd graph. The model is in train mode here, so this is accuracy
        # measured with dropout active.
        _, preds = torch.max(output.detach(), dim=1)
        total += labels.size(0)
        correct += torch.sum(preds == labels).item()

    # sum_loss / train_len is the mean of the per-batch losses.
    print(f'\n epoch : {epoch+1} \t training loss : {sum_loss/train_len} \t training accuracy : {correct/total} ')

    alexnet.eval()  # evaluation mode: dropout disabled
    total = 0
    correct = 0
    sum_loss = 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for imgs, labels in tqdm.tqdm(validloader, position=0):
            imgs, labels = imgs.to(device), labels.to(device)
            output = alexnet(imgs)
            valid_loss = criterion(output, labels)

            sum_loss += valid_loss.item()

            _, preds = torch.max(output, dim=1)
            total += labels.size(0)
            correct += torch.sum(preds == labels).item()

    print(f'\n epoch : {epoch+1} \t valid loss : {sum_loss/valid_len} \t valid accuracy : {correct/total} ')