# Setup

In [1]:
import pandas as pd
from numpy import genfromtxt
from PIL import Image
import torch
import numpy as np
from copy import deepcopy as copy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [2]:
class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 28, 3)
            self.pool1 = nn.MaxPool2d((2, 2))
            self.fc1 = nn.Linear(4732, 10)
        
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = self.pool1(x)
            x = x.flatten(start_dim=1)
            x = self.fc1(x)
            return F.softmax(x, dim=1)

In [3]:
device = torch.device('cuda:0')
net = Net().to(device)
net.load_state_dict(torch.load('model1'))
net.eval()

Net(
  (conv1): Conv2d(1, 28, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=4732, out_features=10, bias=True)
)

# Relics

In [7]:
test = genfromtxt('../data/digit-recognizer/test.csv', delimiter=',')[1:]

In [8]:
np.save('testing_data.npy', test)

# Submission Code

In [4]:
test = np.load('testing_data.npy', allow_pickle=True)

In [5]:
BATCH_SIZE = 10
predictions = []
for i in range(0, len(test), BATCH_SIZE):
    batch_X = torch.Tensor(test[i:i+BATCH_SIZE]).view(-1, 1, 28, 28).to(device)
    with torch.no_grad():
        output = net(batch_X)
        softmax = torch.exp(output).cpu()
        prob = list(softmax.numpy())
        predictions.append(np.argmax(prob, axis=1))

In [6]:
flat_list = []
for sublist in predictions:
    for item in sublist:
        flat_list.append(item)

In [28]:
y = np.array(flat_list.copy())
results = pd.Series(y ,name="Label")
submission = pd.concat([pd.Series(range(1,28001),name = "ImageId"), results],axis = 1)
submission.to_csv("submission1.csv",index=False)

In [20]:
print(results)

0        0
1        0
2        0
3        0
4        0
        ..
27995    0
27996    0
27997    0
27998    0
27999    0
Name: Label, Length: 28000, dtype: int64
