# Train CNN Model 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.transforms import transforms

torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb45548d410>

In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [3]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
    def forward(self, t):
        t = t
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)
        t = F.relu(self.fc1(t.reshape(-1, 12*4*4)))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

In [4]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [5]:
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

In [6]:
total_loss = 0
total_correct = 0

## Training With All Batches

In [8]:
for batch in train_loader:
    images, labels = batch
    
    preds = network(images)
    loss = F.cross_entropy(preds, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    total_loss += loss
    total_correct += get_num_correct(preds, labels)

print(
    "epoch", 0,
    "total_correct:", total_correct,
    "loss:", total_loss
)

epoch 0 total_correct: 98497 loss: tensor(572.7956, grad_fn=<AddBackward0>)


Accuracy

In [9]:
total_correct/len(train_set)

1.6416166666666667

## Training With Multople Epochs 

In [11]:
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(10):
    
    total_loss = 0
    total_correct = 0
    
    for batch in train_loader:
        images, labels = batch
        
        preds = network(images)
        loss = F.cross_entropy(preds, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss
        total_correct += get_num_correct(preds, labels)
    print( 
        "epoch", epoch,
        "total_correct:", total_correct,
        "accuracy:", total_correct/len(train_set),
        "loss:", total_loss.item()
    )

epoch 0 total_correct: 45951 accuracy: 0.76585 loss: 365.08197021484375
epoch 1 total_correct: 50834 accuracy: 0.8472333333333333 loss: 247.05551147460938
epoch 2 total_correct: 51800 accuracy: 0.8633333333333333 loss: 221.4763641357422
epoch 3 total_correct: 52168 accuracy: 0.8694666666666667 loss: 211.18914794921875
epoch 4 total_correct: 52375 accuracy: 0.8729166666666667 loss: 204.8790283203125
epoch 5 total_correct: 52475 accuracy: 0.8745833333333334 loss: 202.6754608154297
epoch 6 total_correct: 52698 accuracy: 0.8783 loss: 196.08602905273438
epoch 7 total_correct: 52722 accuracy: 0.8787 loss: 196.39512634277344
epoch 8 total_correct: 52887 accuracy: 0.88145 loss: 193.18121337890625
epoch 9 total_correct: 52946 accuracy: 0.8824333333333333 loss: 192.04408264160156


# Analyze The Results 

Locally Disabling PyTorch Gradient Tracking

In [12]:
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for batch in loader:
        images, labels = batch
        
        preds = model(images)
        all_preds = torch.cat(
            (all_preds, preds),
            dim=0
        )
    return all_preds