In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10      # MNIST has ten digit classes (0-9)
input_size = 784      # 28 * 28 flattened pixels; the CNN consumes 2-D images instead
epochs = 2
learning_rate = 0.01
batch_size = 64

# Seed torch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling are repeatable across Restart & Run All. cuDNN kernels may still
# introduce small nondeterminism on GPU.
SEED = 42
torch.manual_seed(SEED)

In [28]:
class Normalize:
    """Divide a tensor by a constant scale to map raw pixel values into [0, 1].

    Note: ``transforms.ToTensor`` already returns float tensors in [0, 1] for
    uint8 images, so placing this transform after ToTensor would scale the
    data a second time. Only apply it to tensors that still hold raw 0-255
    values.

    Args:
        scale: divisor applied element-wise; defaults to 255.0.
    """

    def __init__(self, scale: float = 255.0) -> None:
        self.scale = scale

    def __call__(self, samples: torch.Tensor) -> torch.Tensor:
        """Return ``samples / scale``. This is pure arithmetic with no logging,
        so it is safe to run once per sample inside a DataLoader pipeline."""
        return samples / self.scale
        
    
# ToTensor converts uint8 PIL images to float tensors already scaled to [0, 1],
# so no extra division by 255 is needed here.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# download=True fetches MNIST on a fresh checkout; torchvision skips the
# download when the files already exist under root.
train_datasets = torchvision.datasets.MNIST(root='./datasets/', train=True, transform=transform, download=True)
test_datasets = torchvision.datasets.MNIST(root='./datasets/', train=False, transform=transform, download=True)

train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps runs comparable.
test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

In [79]:
class CNN(nn.Module):
    """Two-stage convolutional classifier for small images such as MNIST.

    Each stage is Conv2d (stride 1, no padding) -> ReLU -> 2x2 max-pool,
    followed by a single linear layer producing class logits. The flattened
    feature size feeding that layer is inferred from ``image_size`` rather
    than hard-coded, so kernel sizes and input resolutions other than
    3x3 / 28x28 also work.

    Args:
        in_channels: number of channels in the input images.
        out_channels: number of output logits (i.e. number of classes).
        kernel_size: conv kernel size (int or (h, w) tuple), shared by both convs.
        image_size: (height, width) of the input images; defaults to 28x28.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple,
                 image_size: tuple = (28, 28)) -> None:
        super().__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=16,
                      padding=0, kernel_size=kernel_size, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32,
                      padding=0, kernel_size=kernel_size, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        # Infer the flattened feature count with a dummy forward pass. For the
        # default 28x28 input and 3x3 kernels this is 32 * 5 * 5 = 800.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *image_size)
            n_features = self.cnn2(self.cnn1(dummy)).numel()
        self.fc = nn.Linear(n_features, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an (N, in_channels, H, W) batch to (N, out_channels) logits."""
        x = self.cnn2(self.cnn1(x))
        # Keep the batch dimension and flatten channels x height x width.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
        

In [66]:
# Pull a single training batch to inspect shapes and smoke-test the model.
# Use the public iterator protocol rather than the private _next_data().
examples = iter(train_loader)
samples, labels = next(examples)

In [67]:
samples.size()  # batch layout: (batch, channels, height, width)

torch.Size([64, 1, 28, 28])

In [80]:
    
# Place the model on the same device the training/eval loops move batches to.
# This must happen before the optimizer is built, so it holds the device params.
model = CNN(in_channels=1, out_channels=num_classes, kernel_size=(3, 3)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Smoke test: one forward pass on the peeked batch should yield one logit per
# class per sample. No gradients are needed for this check.
with torch.no_grad():
    test = model(samples.to(device))
assert test.shape == (samples.shape[0], num_classes)

In [81]:
# Training loop: `epochs` passes over the shuffled training set.
# Set train mode explicitly so the cell is correct even if the evaluation
# cell (which switches to eval mode) ran first.
model.train()
total_steps = len(train_loader)
for epoch in range(epochs):
    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward: clear gradients accumulated by the previous step first
        optimizer.zero_grad()
        loss.backward()

        # update weights
        optimizer.step()

        if (idx + 1) % 100 == 0:
            print(f"epoch {epoch + 1} / {epochs}, step {idx+1} / {total_steps}, loss = {loss.item():.4f}")
            

epoch 1 / 2, step 100 / 938, loss = 0.2328
epoch 1 / 2, step 200 / 938, loss = 0.0893
epoch 1 / 2, step 300 / 938, loss = 0.0323
epoch 1 / 2, step 400 / 938, loss = 0.1085
epoch 1 / 2, step 500 / 938, loss = 0.1349
epoch 1 / 2, step 600 / 938, loss = 0.0525
epoch 1 / 2, step 700 / 938, loss = 0.0700
epoch 1 / 2, step 800 / 938, loss = 0.2278
epoch 1 / 2, step 900 / 938, loss = 0.1083
epoch 2 / 2, step 100 / 938, loss = 0.0272
epoch 2 / 2, step 200 / 938, loss = 0.0248
epoch 2 / 2, step 300 / 938, loss = 0.0426
epoch 2 / 2, step 400 / 938, loss = 0.0212
epoch 2 / 2, step 500 / 938, loss = 0.0326
epoch 2 / 2, step 600 / 938, loss = 0.0480
epoch 2 / 2, step 700 / 938, loss = 0.0030
epoch 2 / 2, step 800 / 938, loss = 0.0564
epoch 2 / 2, step 900 / 938, loss = 0.0200


In [82]:
# Evaluate top-1 accuracy (in percent) on the MNIST test split.
# Switch to eval mode so layers such as dropout/batch-norm, if added later,
# behave deterministically.
model.eval()
with torch.no_grad():
    n_corrects = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        # The index of the largest logit is the predicted class.
        predictions = outputs.argmax(dim=1)
        n_samples += labels.shape[0]
        n_corrects += (predictions == labels).sum().item()

    acc = n_corrects / n_samples * 100.0
    print(f'accuracy = {acc:.2f}')

accuracy = 98.35
