In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Using device: {device}')

Using device: cuda


# Dataset 
* MNIST handwritten digits, loaded from torchvision's built-in datasets

In [2]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
# Training split; download=True fetches the files into ./data if they are missing.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Test split, read from the same root directory with the same preprocessing.
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

In [4]:
# Human-readable label names, one entry per class index.
train_dataset.classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [5]:
# Number of target classes, taken from the dataset's label list.
NUM_CLASSES = len(train_dataset.classes)
NUM_CLASSES

10

In [6]:
# Each dataset item is an (image tensor, integer class label) pair.
ex_img, ex_target = train_dataset[0] # img, class label

print(ex_img.shape)
print(ex_target) 

torch.Size([1, 28, 28])
5


# Dataloader

In [7]:
# Loader settings shared by both splits, named once instead of repeated inline.
BATCH_SIZE = 32
NUM_WORKERS = 6

# Training batches are reshuffled every epoch so SGD sees them in a new order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

# Evaluation does not benefit from shuffling: the metrics are order-independent,
# and a fixed order keeps per-batch results reproducible between runs.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

In [8]:
# Pull a single batch from the training loader to check the batched shapes.
ex_img_batch, ex_target_batch = next(iter(train_loader))
print(ex_img_batch.shape)
print(ex_target_batch.shape)

torch.Size([32, 1, 28, 28])
torch.Size([32])


# Model

### Flatten Operation

In [9]:
# Random stand-in with the same shape as one image batch: 32 x 1 x 28 x 28.
x_r = torch.randn(32, 1, 28, 28)
x_r.shape

torch.Size([32, 1, 28, 28])

In [10]:
# The batch size is the first dimension of the tensor.
bz = x_r.shape[0]
bz

32

In [11]:
# Keep the batch dimension and merge the rest; -1 lets reshape infer the
# remaining size, which here equals the explicit form reshape(bz, 1*28*28).
x_r_r = x_r.reshape(bz, -1)
x_r_r.shape

torch.Size([32, 784])

In [12]:
class FlattenLayer(nn.Module):
    """Flatten every dimension after the first into a single feature axis.

    An input of shape ``(batch, d1, d2, ...)`` becomes ``(batch, d1*d2*...)``.
    The layer has no parameters, so the inherited ``nn.Module`` constructor
    is sufficient.
    """

    def forward(self, x):
        """Return ``x`` reshaped to two dimensions: ``(x.shape[0], features)``.

        The feature count is computed explicitly instead of passing ``-1`` to
        ``reshape``. With an empty batch (``x.shape[0] == 0``) the tensor has
        zero elements, ``-1`` is ambiguous, and ``reshape`` raises; an
        explicit size works for that case too.
        """
        batch_size = x.shape[0]
        # Product of the trailing dimensions; an empty tail (1-D input) gives 1.
        num_features = x.shape[1:].numel()
        return x.reshape(batch_size, num_features)

In [13]:
class FlatNet(nn.Module):
    """Two-layer fully connected classifier operating on flattened inputs.

    Architecture: flatten -> linear -> ReLU -> linear.

    Args:
        in_features: Number of values per sample after flattening. Defaults to
            ``1*28*28`` (784), matching the 1 x 28 x 28 images in this notebook.
        hidden_features: Width of the hidden layer. Defaults to 512.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=1*28*28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = FlattenLayer()
        self.linear1 = nn.Linear(in_features, hidden_features)  # default: 784 -> 512
        self.linear2 = nn.Linear(hidden_features, num_classes)  # default: 512 -> 10

        self.activation_fn = nn.ReLU()

    def forward(self, x):
        """Map a batch ``x`` to class logits of shape ``(batch, num_classes)``."""
        x_flat = self.flatten(x)
        hidden = self.activation_fn(self.linear1(x_flat))
        # No softmax is applied here: the raw logits are returned.
        class_logits = self.linear2(hidden)
        return class_logits

### Better implementation

### Dummy Input for Dimensional Testing

In [14]:
# Throwaway instance (not moved to `device`) used only for the shape checks.
model = FlatNet()

In [15]:
# A single random sample with the per-image shape 1 x 28 x 28.
dummy_input = torch.randn(1, 1, 28, 28)

In [16]:
# One forward pass; the result should hold one row of 10 class logits.
dummy_preds = model(dummy_input)
dummy_preds.shape

torch.Size([1, 10])

## Print Model Parameters

In [17]:
# Print the shape of every parameter tensor of the model.
for p in model.parameters():
    print(p.shape)

torch.Size([512, 784])
torch.Size([512])
torch.Size([10, 512])
torch.Size([10])


## Pretty Print with Names

In [18]:
# Same listing, this time with each parameter's name next to its shape.
for n, p in model.named_parameters():
    print(f'name: {n} and parameter data: {p.shape}')

name: linear1.weight and parameter data: torch.Size([512, 784])
name: linear1.bias and parameter data: torch.Size([512])
name: linear2.weight and parameter data: torch.Size([10, 512])
name: linear2.bias and parameter data: torch.Size([10])


# Optimizer & Loss

In [19]:
# Fresh model instance, moved to the selected device before the optimizer
# is built over its parameters.
model = FlatNet()
model = model.to(device)

# Plain SGD over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

# Cross-entropy between the model's class logits and the integer targets.
criterion = nn.CrossEntropyLoss()

# Training

In [20]:
# The training loop prints the current loss once every LOG_INTERVAL batches.
LOG_INTERVAL = 1000

In [21]:
def train(model, train_loader, optimizer, criterion, epoch):
    """Run one training epoch over ``train_loader``.

    For each batch: move it to the notebook-level ``device``, compute the loss,
    backpropagate, and take one optimizer step. Every ``LOG_INTERVAL`` batches
    the current batch loss is printed.

    Args:
        model: Module to train; switched to training mode here.
        train_loader: Iterable of ``(img, target)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(preds, target)`` returning the loss.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        # Forward pass and loss for this batch.
        logits = model(images)
        loss = criterion(logits, labels)

        # Backward pass computes the gradients ...
        loss.backward()
        # ... and the optimizer step uses them to update the weights.
        optimizer.step()

        if step % LOG_INTERVAL != 0:
            continue

        samples_seen = step * len(images)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, len(train_loader.dataset),
            percent_done, loss.item()))

# Testing

### Function Decorator

In [22]:
@torch.no_grad()
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    
    for img, target in test_loader:
        img, target = img.to(device), target.to(device)
        
        preds = model(img)
        test_loss += criterion(preds, target)
        
        pred_max = preds.argmax(dim=1, keepdim=True)  # get the index of the max probable class
        correct += pred_max.eq(target.view_as(pred_max)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

### Start Training
* Training consists of two steps: forward and backward propagation
* In forward propagation, we feed the data into the model and measure the error with the loss function
* In backward propagation, we compute the gradients of the loss and use them to adjust the internal parameters of the model so that it makes better predictions next time
* One complete pass over the whole dataset is called an "epoch"

In [23]:
# Number of full passes over the training set.
NUM_EPOCHS = 10

# Epochs are numbered from 1; the test set is evaluated after every epoch.
for epoch in range(1, NUM_EPOCHS+1):
    train(model, train_loader, optimizer, criterion, epoch)
    test(model, test_loader, criterion)


Test set: Average loss: 0.0083, Accuracy: 9212/10000 (92%)


Test set: Average loss: 0.0068, Accuracy: 9357/10000 (94%)


Test set: Average loss: 0.0049, Accuracy: 9548/10000 (95%)


Test set: Average loss: 0.0042, Accuracy: 9615/10000 (96%)


Test set: Average loss: 0.0037, Accuracy: 9654/10000 (97%)


Test set: Average loss: 0.0034, Accuracy: 9678/10000 (97%)


Test set: Average loss: 0.0031, Accuracy: 9707/10000 (97%)


Test set: Average loss: 0.0029, Accuracy: 9718/10000 (97%)


Test set: Average loss: 0.0028, Accuracy: 9731/10000 (97%)


Test set: Average loss: 0.0026, Accuracy: 9761/10000 (98%)



# Save/Load Model

In [24]:
# Store the model weights and the optimizer state together in one checkpoint file.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    ),
    'flatnet_checkpoint.pt',
)