In [1]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import SGD

In [7]:
class Model(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.act2 = nn.ReLU()
        self.fc1 = nn.Linear(9216, 120)
        self.act3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.act4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)

    def forward(self, input_data):
        input_data = self.conv1(input_data)
        input_data = self.act1(input_data)
        input_data = self.conv2(input_data)
        input_data = self.act2(input_data)
        input_data = input_data.reshape(input_data.size(0), -1)
        input_data = self.fc1(input_data)
        input_data = self.act3(input_data)
        input_data = self.fc2(input_data)
        input_data = self.act4(input_data)
        return self.fc3(input_data)

# class Model(nn.Module):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(9216, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)
#
#     def forward(self, input_data):
#         input_data = torch.square(self.conv1(input_data))
#         input_data = torch.square(self.conv2(input_data))
#         input_data = input_data.reshape(input_data.size(0), -1)
#         input_data = torch.square(self.fc1(input_data))
#         input_data = torch.square(self.fc2(input_data))
#         return self.fc3(input_data)

In [8]:
batch_size = 128
num_of_epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

target_transform = transforms.Lambda(lambda y: torch.zeros(10, dtype=torch.float)
                                     .scatter_(0, torch.tensor(y), 1))

# load data
train_dataset = MNIST('./data', train=True, transform=transform, target_transform=target_transform,
                      download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MNIST('./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

last_batch_idx = int(len(train_dataset) / batch_size)
if len(train_dataset) % batch_size != 0:
    last_batch_idx = last_batch_idx + 1

model = Model().to(device)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
running_loss = []
running_acc = []
running_curr_loss = []
for epoch in range(num_of_epochs):
    curr_loss = torch.zeros(1).to(device)
    for idx, (data, label) in enumerate(train_loader):
        model.train()
        # data = data.reshape(data.size(0), -1).to(device)
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(label, out)
        loss.backward()
        optimizer.step()
        curr_loss += loss * 10
        running_curr_loss.append(loss.item() * 10)

        with torch.no_grad():
            if idx == 0 or (idx + 1) % 10 == 0 or (idx + 1) == last_batch_idx:
                curr_acc = 0
                model.eval()
                if idx == 0:
                    running_loss.append(curr_loss.item())
                elif (idx + 1) == last_batch_idx:
                    running_loss.append((curr_loss / ((idx + 1) % 10)).item())
                else:
                    running_loss.append((curr_loss / 10).item())
                test_total = 0
                for test_data, test_label in test_loader:
                    test_data = test_data.to(device)
                    test_label = test_label.to(device)
                    test_out = model(test_data)
                    pred_label = torch.argmax(test_out, dim=1)
                    curr_acc = curr_acc + torch.count_nonzero(pred_label == test_label)
                    test_total = test_total + test_data.size(0)
                running_acc.append(curr_acc / test_total)
                if idx == 0 or (idx + 1) % 10 == 0:
                    print('epoch: {}, loss: {}, acc: {}'.format(epoch, running_loss[-1], running_acc[-1]))
                curr_loss = torch.zeros(1).to(device)

epoch: 0, loss: 0.9613454937934875, acc: 0.09600000083446503
epoch: 0, loss: 0.8504909873008728, acc: 0.11099999397993088
epoch: 0, loss: 0.8939313292503357, acc: 0.2425999939441681
epoch: 0, loss: 0.8693649172782898, acc: 0.5686999559402466
epoch: 0, loss: 0.8418002128601074, acc: 0.6753000020980835
epoch: 0, loss: 0.8045352101325989, acc: 0.7166000008583069
epoch: 0, loss: 0.7514105439186096, acc: 0.7691999673843384
epoch: 0, loss: 0.6886343955993652, acc: 0.7831000089645386
epoch: 0, loss: 0.6162659525871277, acc: 0.7364000082015991
epoch: 0, loss: 0.5748128890991211, acc: 0.7584999799728394
epoch: 0, loss: 0.5603702664375305, acc: 0.8073999881744385
epoch: 0, loss: 0.5052667260169983, acc: 0.7872999906539917
epoch: 0, loss: 0.5294490456581116, acc: 0.8208000063896179
epoch: 0, loss: 0.4763886630535126, acc: 0.793999969959259
epoch: 0, loss: 0.4682949185371399, acc: 0.8070999979972839
epoch: 0, loss: 0.4460977613925934, acc: 0.7804999947547913
epoch: 0, loss: 0.4448614716529846, acc

KeyboardInterrupt: 