In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision.transforms import transforms

# Explicitly switch autograd recording on so that loss.backward() further down can compute gradients.
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd03a684210>

## Prepare The Data.

In [31]:
# Transform pipeline applied to every sample: a single step that converts the image to a tensor.
to_tensor = transforms.Compose([transforms.ToTensor()])

# Training split of Fashion-MNIST stored under ./data/FashionMNIST (download=True fetches it if needed).
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=to_tensor,
)

## Build The Model

In [5]:
class Network(nn.Module):
    """Small convolutional classifier: two conv + max-pool stages followed by three linear layers.

    The layer sizes fit single-channel 28x28 inputs: each 5x5 convolution removes
    4 pixels per spatial dimension and each 2x2 pool halves it
    (28 -> 24 -> 12 -> 8 -> 4), leaving 12 feature maps of size 4x4 for ``fc1``.
    """

    # Shape (channels, height, width) that the conv stack must produce for ``fc1``.
    _FC_INPUT_SHAPE = (12, 4, 4)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute the 10 raw class scores for each input sample.

        Args:
            t: image tensor with one channel whose spatial size reduces to 4x4
                after the two conv/pool stages (28x28 for this architecture).

        Returns:
            Tensor of shape (batch, 10) with unnormalised scores (no softmax applied).

        Raises:
            RuntimeError: if the conv stack does not yield 12x4x4 features per sample.
        """
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

        # reshape(-1, 192) would happily merge several mis-sized samples into one
        # row (e.g. four 12x2x2 maps -> a single 192-vector), silently changing the
        # batch size. Fail loudly instead.
        if tuple(t.shape[-3:]) != self._FC_INPUT_SHAPE:
            raise RuntimeError(
                "expected feature maps of shape "
                f"{self._FC_INPUT_SHAPE} before fc1, got {tuple(t.shape[-3:])}"
            )

        t = F.relu(self.fc1(t.reshape(-1, 12 * 4 * 4)))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

## Train The Model 

### Creating an instance of our Network class.

In [6]:
# Instantiate the model; the bare name on the last line makes the notebook display its layer summary.
network = Network()
network

Network(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (out): Linear(in_features=60, out_features=10, bias=True)
)

### Get a batch from dataloader.

In [32]:
# Wrap the training set in a loader that yields batches of 100 samples.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
# Pull only the first batch; the cells below reuse it.
batch = next(iter(train_loader))
# Unpack the batch into its image tensor and label tensor.
images, labels = batch

In [33]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose highest-scoring column equals the matching entry of ``labels``.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of class indices, one per sample.

    Returns:
        The number of matching rows as a plain Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()

### Calculating the Loss: Using F.cross_entropy

In [41]:
# Forward pass on the batch, then cross-entropy between the raw network outputs and the labels.
preds = network(images)
loss = F.cross_entropy(preds, labels)
loss.item()  # show the loss as a plain Python number

2.2859325408935547

In [42]:
# How many samples of the 100-image batch the current predictions classify correctly.
get_num_correct(preds, labels)

15

### Calculating The Gradients: Using loss.backward()

In [43]:
# Backpropagate the loss; this populates the .grad attribute of the network's parameters.
loss.backward()

In [44]:
# Inspect the gradient of the first conv layer's weight; it has the same shape as the weight itself.
network.conv1.weight.grad.shape

torch.Size([6, 1, 5, 5])

### Updating The Weights: Using optimizer.step()

In [22]:
# Adam optimizer over all network parameters with learning rate 0.01.
optimizer = optim.Adam(network.parameters(), lr=0.01)
# Apply one update using the gradients that loss.backward() stored on the parameters.
# NOTE(review): optimizer.zero_grad() is never called in this walkthrough, so running
# loss.backward() again accumulates into the existing gradients — confirm that is intended.
optimizer.step()

In [24]:
# Forward pass on the same batch with the current weights.
preds = network(images)
# NOTE(review): `loss` is not recomputed in this cell, so the number displayed here is
# not the loss of the `preds` computed above; the next cell recomputes it.
loss.item()

2.312589168548584

In [26]:
# Recompute the loss from the current predictions and display it.
loss = F.cross_entropy(preds, labels)
loss.item()

2.2898976802825928

In [27]:
# Correct-prediction count for the current `preds` on the same batch.
get_num_correct(preds, labels)

10

## A Single Batch

In [30]:
# Start from a freshly initialised network so updates from earlier cells do not carry over.
network = Network()

# Loader yielding batches of 100, and an Adam optimizer (lr=0.01) bound to the new network's parameters.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

# Take the first batch only.
batch = next(iter(train_loader))
images, labels = batch

# Forward pass and loss before any weight update.
preds = network(images)
loss = F.cross_entropy(preds, labels)

# Compute gradients, then apply a single optimizer step.
loss.backward()
optimizer.step()

# loss1 is the loss computed before the step (it is printed after the step but was not
# recomputed); loss2 is computed on the same batch with the updated weights.
print('loss1:', loss.item())
preds = network(images)
loss = F.cross_entropy(preds, labels)
print('loss2:', loss.item())

loss1: 2.296959638595581
loss2: 2.2859325408935547
