In [1]:
import torch
from torch import nn
import sys
sys.path.insert(0, "..")
from helper import train, test
from data import get_dataloader
import torch.nn.functional as F
from time import time

In [2]:
# Pick the best available accelerator, falling back to the CPU so the notebook
# still runs on machines that have neither CUDA nor Apple's MPS backend.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # `torch.backends.mps` is missing on older PyTorch builds, hence the getattr guard.
    device = "mps"
else:
    device = "cpu"

# LeNet-5
![LeNet-5](https://cdn.analyticsvidhya.com/wp-content/uploads/2021/03/Screenshot-from-2021-03-18-12-52-17.png)

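*We multiply the number of channels in every convolutional layer and the number of neurons in the hidden fully connected layer by 3, since the original LeNet-5 was designed for grayscale images and CIFAR-10 is RGB. The output layer keeps 10 units, one per class.*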

In [3]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN with every hidden width tripled, for 3-channel input.

    The network is three 5x5 convolutions (with Tanh activations and 2x2
    average pooling after the first two) followed by a two-layer classifier
    that emits 10 raw class scores.

    The classifier's first layer expects exactly ``120 * 3`` features, so the
    last convolution must reduce the input to a 1x1 spatial map. With the
    unpadded 5x5 kernels and 2x2 pools used here, that holds for 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 -> 1).
    """

    def __init__(self):
        super().__init__()

        # Convolutional trunk: 3 -> 18 -> 48 -> 360 channels.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6*3, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6*3, out_channels=16*3, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16*3, out_channels=120*3, kernel_size=5, stride=1),
            nn.Tanh()
        )

        # Fully connected head: 360 -> 252 -> 10 class scores.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=120*3, out_features=84*3),
            nn.Tanh(),
            nn.Linear(in_features=84*3, out_features=10)
        )

    def forward(self, x):
        """Return unnormalised class scores (logits) for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding raw logits. No softmax is
            applied: ``nn.CrossEntropyLoss`` expects logits and applies
            log-softmax itself. Call ``softmax(dim=1)`` on the result if
            probabilities are needed.
        """
        x = self.feature_extractor(x)
        # Collapse (channels, 1, 1) into a flat feature vector per sample.
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [5]:
# Seed PyTorch's global RNG before anything random happens so that weight
# initialisation (and, presumably, DataLoader shuffling) repeat across
# Restart & Run All. Bit-exact results are still not guaranteed on GPU backends.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters, named so they are easy to find and tune.
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
epochs = 10

# get the data
train_dl, test_dl = get_dataloader("cifar", batch_size=BATCH_SIZE)

# Training the model
model = LeNet5().to(device)

# CrossEntropyLoss takes the raw logits returned by the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

start = time()
for t in range(epochs):
    print(f"Epoch {t+1}\n---")
    # One pass over the training set, then an evaluation pass on the test set.
    train(train_dl, model, loss_fn, optimizer, device)
    test(test_dl, model, loss_fn, device)
print(f"Total time taken: {(time()-start):>0.1f} seconds")
    


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100.0%


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Epoch 1
---
loss: 2.303998  [    0/50000]
loss: 1.818795  [ 6400/50000]
loss: 1.603588  [12800/50000]
loss: 1.666484  [19200/50000]
loss: 1.841048  [25600/50000]
loss: 1.821961  [32000/50000]
loss: 1.853914  [38400/50000]
loss: 1.601500  [44800/50000]
Train Error: 
 Accuracy: 37.2%
Test Error: 
 Accuracy: 41.8%, Avg loss: 1.642597 

Epoch 2
---
loss: 1.629886  [    0/50000]
loss: 1.487332  [ 6400/50000]
loss: 1.377378  [12800/50000]
loss: 1.478075  [19200/50000]
loss: 1.554031  [25600/50000]
loss: 1.682760  [32000/50000]
loss: 1.644889  [38400/50000]
loss: 1.487047  [44800/50000]
Train Error: 
 Accuracy: 45.3%
Test Error: 
 Accuracy: 48.4%, Avg loss: 1.448469 

Epoch 3
---
loss: 1.502871  [    0/50000]
loss: 1.358214  [ 6400/50000]
loss: 1.184022  [12800/50000]
loss: 1.434438  [19200/50000]
loss: 1.385777  [25600/50000]
loss: 1.511878  [32000/50000]
loss: 1.557158  [38400/50000]
loss: 1.372005  [44800/