In [93]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm

In [80]:
class Net(nn.Module):
    """Small convolutional classifier that outputs per-class log-probabilities.

    Architecture: two unpadded 3x3 convolutions (32 then 64 channels) with ReLU,
    a 2x2 max-pool, dropout, a 128-unit hidden layer and a 10-way output layer.
    The hard-coded flattened width 9216 == 64 * 12 * 12 only matches
    single-channel 28x28 inputs (28 -> 26 -> 24 after the convs, 12 after pooling),
    e.g. MNIST digits.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: stride-1, unpadded 3x3 convolutions.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Regularization: lighter dropout on conv features, heavier on the hidden layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head; 9216 = 64 channels * 12 * 12 spatial positions.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch ``x`` of 1-channel images to log-probabilities over 10 classes.

        Returns a tensor of shape ``(batch, 10)`` whose rows are log-softmax values
        (suitable as input to ``F.nll_loss``).
        """
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [81]:
# Preprocessing shared by both splits: PIL image -> float tensor, then standardize.
# (0.1307, 0.3081) are the commonly quoted MNIST mean/std for the single grayscale
# channel — presumably computed on the training split.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [82]:
# Both splits live under ./data (downloaded on first run) and share one transform.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [103]:
# Reshuffle the training set every epoch so batches are not always drawn in the
# same fixed order; evaluation order does not affect the metrics, so the test
# loader stays unshuffled.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256)

In [107]:
def train(model, optimizer, device, dataloader, print_interval):
    """Train ``model`` for a single pass over ``dataloader``.

    Each ``(data, target)`` batch is moved to ``device``, scored with
    ``F.nll_loss`` (so the model must emit log-probabilities, as ``Net`` does),
    back-propagated, and followed by one ``optimizer.step()``.

    Args:
        model: module to train; switched to training mode first.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        dataloader: iterable of ``(data, target)`` batches.
        print_interval: refresh the progress-bar loss every this many batches
            (must be non-zero; batch 0 is always reported).
    """
    model.train()
    progress = tqdm(dataloader)
    for step, (inputs, labels) in enumerate(progress):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Only touch the bar on reporting steps; .item() forces a device sync.
        if step % print_interval:
            continue
        progress.set_description(f'Batch {step} Loss: {batch_loss.item():.5f}')

In [108]:
@torch.inference_mode()
def test(model, device, dataloader):
    """Evaluate ``model`` over ``dataloader``, print and return the metrics.

    Args:
        model: module returning per-class log-probabilities (scored with
            ``F.nll_loss``); switched to eval mode first.
        device: device the batches are moved to.
        dataloader: iterable of ``(data, target)`` batches.

    Returns:
        ``(loss, accuracy)`` as Python floats: the mean per-sample NLL loss and
        the fraction of correctly classified samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_total = 0
    for data, target in tqdm(dataloader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Sum (not mean) per batch so dividing by the sample count below gives a
        # true per-sample average even when the last batch is smaller.
        total_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        # Count samples actually seen instead of len(dataloader.dataset), so the
        # averages stay correct with drop_last=True or a subset sampler.
        n_total += target.size(0)
    if n_total == 0:
        raise ValueError('dataloader yielded no samples')
    test_loss = total_loss / n_total
    accuracy = correct / n_total
    print(f'Loss: {test_loss:.5f}, Accuracy: {accuracy*100:.2f}')
    return test_loss, accuracy

In [109]:
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the global torch RNG (all devices) before building the model so weight
# initialization and later RNG draws (e.g. dropout) repeat across runs; some CUDA
# kernels may still introduce nondeterminism.
torch.manual_seed(0)
model = Net().to(device)
optimizer = Adam(model.parameters(), lr=0.001)
# Refresh the training progress-bar loss every this many batches.
print_interval = 100

In [112]:
# One training epoch over the full training set.
train(model=model, optimizer=optimizer, device=device,
      dataloader=train_dataloader, print_interval=print_interval)

Batch 900 Loss: 0.06815: 100%|██████████| 938/938 [00:12<00:00, 74.04it/s]


In [113]:
# Evaluate on the held-out test split; metrics are printed by test(), so the
# trailing ';' keeps IPython from echoing the call's return value as well.
test(model=model, device=device, dataloader=test_dataloader);

100%|██████████| 40/40 [00:01<00:00, 24.40it/s]

Loss: 0.03525, Accuracy: 98.92



