In [17]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import autograd
import torch.nn.functional as F

# Configuration: image geometry, number of classes, and training length.
HEIGHT = 26
WIDTH = 26
NUM_CLASSES = 5
D_H = 100  # currently unused by TwoSimpleConvNN
NUM_OPT_STEPS = 2000
SEED = 0

# Seed both RNGs so mini-batch sampling (np.random) and weight initialisation
# (torch) are reproducible on Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the dataset from the current working directory.
images = np.load("images.npy")
labels = np.load("labels.npy")

# One float64 row per image, all pixel axes flattened into columns.
# reshape covers every image, including the last one.
flat_images = images.reshape(len(images), -1).astype(np.float64)

assert len(flat_images) == len(labels), "images and labels differ in length"
assert flat_images.shape[1] == HEIGHT * WIDTH, "images are not HEIGHT x WIDTH"

In [18]:
class TwoSimpleConvNN(torch.nn.Module):
    """Small convolutional classifier for single-channel HEIGHT x WIDTH images.

    Pipeline: 3x3 conv (1 -> 8) + ReLU -> 3x3 stride-2 conv (8 -> 16) + ReLU
    -> global average pool -> 1x1 conv (16 -> NUM_CLASSES) giving logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 8, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, stride=2)
        # Tie the output width to NUM_CLASSES (instead of a literal 5) so it
        # always agrees with the logits shape produced in forward().
        self.final_conv = torch.nn.Conv2d(16, NUM_CLASSES, kernel_size=1)

    def forward(self, x):
        """Return class logits of shape (batch, NUM_CLASSES).

        `x` must hold a multiple of HEIGHT * WIDTH elements, e.g. a batch of
        flattened images (batch, HEIGHT * WIDTH) or one flattened image.
        """
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 1, HEIGHT, WIDTH)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Average over whatever spatial extent remains -> (batch, 16, 1, 1).
        x = F.adaptive_avg_pool2d(x, 1)
        return self.final_conv(x).flatten(start_dim=1)

In [19]:
# Instantiate the network and an Adam optimiser (learning rate 1e-3) over all
# of its parameters.
model = TwoSimpleConvNN()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [20]:
def train(batch_size):
    """Run one optimiser step on a random mini-batch and return its loss.

    Samples `batch_size` distinct rows of the global `flat_images` / `labels`,
    computes cross-entropy of the global `model`, and updates it with the
    global `optimizer`.

    Returns:
        float: the batch loss computed before the parameter update.
    """
    idx = np.random.choice(flat_images.shape[0], size=batch_size, replace=False)
    x = torch.from_numpy(flat_images[idx].astype(np.float32))
    # cross_entropy needs int64 (Long) targets; `np.int` no longer exists in
    # NumPy >= 1.24.
    y = torch.from_numpy(labels[idx].astype(np.int64))
    loss = F.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [21]:
def accuracy(y, y_hat):
    """Return the fraction of positions where `y_hat` equals `y`.

    Args:
        y: true labels (array-like).
        y_hat: predicted labels (array-like, element-wise comparable with `y`).

    Returns:
        numpy.float64 between 0 and 1.
    """
    # `np.float` was removed in NumPy 1.24; the mean of a boolean array is
    # already a float64 fraction. asarray also lets plain lists be passed.
    return np.mean(np.asarray(y) == np.asarray(y_hat))

In [22]:
def approx_train_accuracy(n_samples=1000):
    """Estimate the global `model`'s accuracy on a random subset of the data.

    Args:
        n_samples: number of distinct rows of `flat_images` to score; capped
            at the dataset size. Defaults to 1000.

    Returns:
        Fraction of sampled images whose argmax prediction matches `labels`.
    """
    n_samples = min(n_samples, flat_images.shape[0])
    idx = np.random.choice(flat_images.shape[0], size=n_samples, replace=False)
    x = torch.from_numpy(flat_images[idx].astype(np.float32))
    y = labels[idx].astype(np.int64)
    # One batched forward pass with autograd disabled, instead of n_samples
    # separate passes that each recorded a gradient graph.
    with torch.no_grad():
        y_hat = model(x).argmax(dim=1).numpy()
    return accuracy(y, y_hat)

In [None]:
def val_accuracy(features=None, targets=None, batch_size=1024):
    """Accuracy of the global `model` on a labelled set, scored in batches.

    Args:
        features: 2-D array of flattened images; defaults to the global
            `flat_images`.
        targets: matching class labels; defaults to the global `labels`.
        batch_size: rows per forward pass, bounding peak memory.

    Returns:
        Fraction of rows whose argmax prediction matches `targets`.

    NOTE(review): with the defaults this scores the same arrays `train`
    samples its mini-batches from, so it measures full-training-set accuracy,
    not held-out accuracy. Pass a separate split for a true validation score.
    """
    if features is None:
        features = flat_images
    if targets is None:
        targets = labels
    x = torch.from_numpy(np.asarray(features).astype(np.float32))
    y = np.asarray(targets).astype(np.int64)
    # Batched, gradient-free inference; no per-sample debug printing.
    with torch.no_grad():
        y_hat = torch.cat(
            [model(chunk).argmax(dim=1) for chunk in torch.split(x, batch_size)]
        )
    return accuracy(y, y_hat.numpy())

In [None]:
# Train for NUM_OPT_STEPS mini-batch steps of 10 images; on every 100th step,
# record the step and both accuracy scores and print one table row.
train_accs = []
val_accs = []
steps = []
for i in range(NUM_OPT_STEPS):
    train(10)
    if i % 100:
        continue  # evaluate only on every 100th step
    train_accs.append(approx_train_accuracy())
    val_accs.append(val_accuracy())
    steps.append(i)
    print("%6d %5.2f %5.2f" % (i, train_accs[-1], val_accs[-1]))

     0  0.20  0.20
   100  0.46  0.44
   200  0.54  0.57
   300  0.56  0.57
   400  0.59  0.62
   500  0.62  0.60
   600  0.66  0.67
   700  0.63  0.62
   800  0.64  0.63
   900  0.74  0.73
  1000  0.71  0.73


In [None]:
# Accuracy curves recorded by the training loop. Without a legend the two
# lines were indistinguishable.
# NOTE(review): val_accuracy() scores the same flat_images/labels that train()
# samples from, so the second curve is not a held-out measurement.
fig, ax = plt.subplots()
ax.plot(steps, train_accs, label="train (random 1000-sample estimate)")
ax.plot(steps, val_accs, label="val_accuracy()")
ax.set(xlabel="optimisation step", ylabel="accuracy",
       title="TwoSimpleConvNN accuracy during training")
ax.legend();

In [None]:
# Score the trained model once more with val_accuracy(); as the cell's last
# expression, the returned accuracy is displayed by Jupyter.
# NOTE(review): val_accuracy() reads flat_images/labels, the same arrays
# train() samples from, so this is not a held-out score.
val_accuracy()