In [38]:
import torch
from torch import nn
from torch import utils
import torchvision
import os

In [39]:
def prepare_dataset(train=False):
    trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
    download = False
    if len(os.listdir('../Datasets/MNIST') ) == 0:
        download = True     
    mnist = torchvision.datasets.MNIST('../Datasets/MNIST',train=train,download=download,transform=trans)
    return mnist

In [60]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.conv_layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=1))
        self.attention1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.2)
        self.fc1 = nn.Linear(7 * 7 * 64, 256)
        self.fc2 = nn.Linear(256, 10)
    def forward(self,x):
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.conv_layer3(out)
        _, _, _ = self.attention1(out, out, out )
        out = torch.flatten(out,start_dim=1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

In [61]:
def train(model, dataset, n_epochs=13, lr=1e-4, batch_size=4):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
    loss_list= []
    acc_list = []
    accuracy_track = []
    correct = 0
    mnist_loader = utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    total_step = len(mnist_loader)
    for epoch in range(n_epochs):
        for i,(sample, labels) in enumerate(mnist_loader):
            outputs = model(sample)
            d = outputs.data
            _, predicted = torch.max(d,1)
            loss = criterion(outputs,labels)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            correct = correct + (predicted == labels).sum().item()
            if (i+1)%100 == 0:
                acc = correct / (100*batch_size) * 100
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                      .format(epoch + 1, n_epochs, i + 1, total_step, sum(loss_list)/len(loss_list), acc))
                correct = 0
                if acc <= 100:
                    accuracy_track.append(acc)   

In [62]:
def test(model, testset):
    print("Testing model ... ")
    correct = 0
    criterion = nn.CrossEntropyLoss()
    list_acc =[]
    loss_list = []
    mnist_loader = utils.data.DataLoader(testset, shuffle=False, num_workers=0)
    total_step = len(mnist_loader)
    for i,(sample, labels) in enumerate(mnist_loader):
        outputs = model(sample)
        #print(outputs)
        d = outputs.data
        _, predicted = torch.max(d,1)
        correct = correct + (predicted == labels).sum().item()
        loss = criterion(outputs,labels)
        loss_list.append(loss.item())
        if (i+1)%100 == 0:
            acc = correct
            list_acc.append(acc)
            print('Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(i + 1, total_step, sum(loss_list)/len(loss_list), acc))
            correct = 0
            if acc <= 100:
                list_acc.append(acc)
    print("Testing done")
    print("Average accuracy is: "+str(sum(list_acc)/len(list_acc)))

In [63]:
def main():
    MODEL_STORE_PATH = '/home/konrad/Kaggle/MNIST'
    mnist_train = prepare_dataset(train=True)
    print("Dataset succesfully loaded!")
    model = ConvNet()
    train(model, mnist_train)
    try:
        torch.save(model, MODEL_STORE_PATH + '/model_2.pt')
    except Exception as exception:
        print(type(exception).__name__)
        print("Exception occured while saving !!!")
    mnist_test = prepare_dataset(train=False)
    test(model, mnist_test)

In [64]:
main()

Dataset succesfully loaded!


ValueError: too many values to unpack (expected 3)