In [1]:
import sys
from pathlib import Path

# Make the project's local sources importable (presumably where pytorch_impl lives).
# The path is resolved to an absolute one so it keeps working if the working
# directory changes later, and the membership check keeps re-runs of this cell
# from piling duplicate entries onto sys.path.
SRC_DIR = str(Path("../src").resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
# Load IPython's autoreload extension; a later cell sets `%autoreload 2` so that
# edits to the modules imported from ../src are re-imported without restarting the kernel.
%load_ext autoreload

In [3]:
%autoreload 2

import sys
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

import matplotlib.pyplot as plt
from pytorch_impl.nns import FCN
from pytorch_impl.nns.utils import warm_up_batch_norm, to_one_hot
from pytorch_impl.estimators import MatrixExpEstimator
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term

In [4]:
torch.manual_seed(0)

num_classes = 10

# CUDA is deliberately disabled for these experiments; set USE_CUDA = True to run
# on the first GPU when one is available.
USE_CUDA = False

if USE_CUDA and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Torch version: {}'.format(torch.__version__))
print('Device: {}'.format(device))

# Page-locked (pinned) host memory only speeds up host->GPU copies; on a CPU
# device it is pure overhead, so enable it only when training on CUDA.
pin_memory = device.type == 'cuda'

train_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=True, download=True,
                 transform=transforms.ToTensor()),
    batch_size=32, shuffle=True, pin_memory=pin_memory)

# download=True makes this loader self-contained; it is a no-op when the files exist.
test_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=False, download=True,
                 transform=transforms.ToTensor()),
    batch_size=32, shuffle=True, pin_memory=pin_memory)

device

Torch version: 1.3.1
Device: cpu


device(type='cpu')

In [5]:
# Sanity check for matrix_exp: a * [[0, -1], [1, 0]] is the generator of planar
# rotations, so its exponential must equal the rotation matrix for angle a (30 deg).
a = 30 / 180 * np.pi
# Build the generator as a float tensor explicitly instead of relying on
# int -> float type promotion when multiplying by `a`.
M = torch.tensor([[0., -1.], [1., 0.]]) * a
exp_M = matrix_exp(M, device).cpu().numpy()

expected_rotation = np.array([[np.cos(a), -np.sin(a)],
                              [np.sin(a),  np.cos(a)]])
# float32 result, so compare with a tolerance rather than exact equality.
np.testing.assert_allclose(exp_M, expected_rotation, atol=1e-5)
exp_M

array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [6]:
# Check the identity this notebook relies on: exp(M) == I + M @ compute_exp_term(M),
# using the same 30-degree rotation generator, so the result must again be the
# rotation matrix for angle a.
a = 30 / 180 * np.pi
M = torch.tensor([[0., -1.], [1., 0.]]) * a
M_clone = M.clone().to(device)
# Keep every operand on `device` (and move to CPU before .numpy()) so this
# cell keeps working when CUDA is enabled.
identity = torch.eye(2, device=device)
exp_M_via_term = (torch.matmul(M_clone, compute_exp_term(M, device)) + identity).cpu().numpy()

expected_rotation = np.array([[np.cos(a), -np.sin(a)],
                              [np.sin(a),  np.cos(a)]])
np.testing.assert_allclose(exp_M_via_term, expected_rotation, atol=1e-5)
exp_M_via_term

array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [7]:
# FashionMNIST images are D x D = 28 x 28 pixels. The FCN below is built with
# D * D (the flattened pixel count) -- presumably its input width; confirm
# against pytorch_impl.nns.FCN.
D = 28

model = FCN(1, D * D).to(device)

# Matrix-exponential estimator wrapping the network; hyper-parameters are the
# ones this notebook was run with.
estimator = MatrixExpEstimator(
    model,
    num_classes,
    device,
    learning_rate=1.,
    momentum=0.75,
)

In [8]:
# Smoke test: repeatedly fit the estimator on one training batch and print the
# predictions for its first five samples after every step.
N_FIT_STEPS = 30

# First batch of the (seeded, shuffled) training loader.
X, y = next(iter(train_loader))
X, y = X.to(device), y.to(device)

# NOTE(review): the result of this call is discarded; it is kept in case
# predict() has side effects the estimator relies on -- TODO confirm and drop if not.
estimator.predict(X)
for _ in range(N_FIT_STEPS):
    estimator.fit(X, y)
    print(estimator.predict(X)[:5])

accuracy 0.06250, loss 5.00000
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-1.0057, -0.9994, -0.9934, -1.0053, -1.0007, -0.9994, -1.0000, -0.9996,
          0.9975, -1.0011],
        [-0.9829, -0.9973, -1.0049, -1.0070, -0.9966, -0.9967,  0.9929, -0.9958,
         -0.9967, -0.9981],
        [-1.0140, -1.0039, -0.9873, -0.9902, -1.0042, -0.9982, -1.0024,  0.9662,
         -1.0063, -0.9936],
        [-1.0081, -1.0035, -0.9940, -0.9878, -1.0037,  0.9603, -0.9991, -0.9706,
         -1.0010, -1.0048],
        [-0.9999, -0.9981, -1.0006, -0.9998, -0.9977, -0.9969,  1.0022, -0.9995,
         -0.9983, -0.9972]])
accuracy 1.00000, loss 0.00222
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-1.7551, -1.7494, -1.7438, -1.7546, -1.7506, -1.7496, -1.7500, -1.7498,
          1.7480, -1.7509],
        [-1.7358, -1.7478, -1.7541, -1.7562, -1.7472, -1.7473,  1.7446, -1.7464,
         -1.7473, -1.7484],
        [-1.7631, -1.7532, -1.7383, -1.7417, -1.7537, -1

accuracy 1.00000, loss 0.08785
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-1.1342, -1.1318, -1.1288, -1.1337, -1.1326, -1.1329, -1.1326, -1.1335,
          1.1331, -1.1323],
        [-1.1347, -1.1329, -1.1326, -1.1344, -1.1329, -1.1331,  1.1361, -1.1323,
         -1.1335, -1.1327],
        [-1.1407, -1.1319, -1.1267, -1.1324, -1.1338, -1.1363, -1.1320,  1.1428,
         -1.1322, -1.1379],
        [-1.1385, -1.1333, -1.1280, -1.1297, -1.1338,  1.1461, -1.1340, -1.1522,
         -1.1338, -1.1324],
        [-1.1297, -1.1335, -1.1359, -1.1327, -1.1331, -1.1332,  1.1317, -1.1315,
         -1.1329, -1.1336]])
accuracy 1.00000, loss 0.08787
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.9935, -0.9986, -1.0026, -0.9932, -0.9976, -0.9992, -0.9982, -0.9995,
          1.0010, -0.9970],
        [-1.0170, -1.0013, -0.9934, -0.9916, -1.0019, -1.0018,  1.0074, -1.0026,
         -1.0022, -1.0003],
        [-0.9885, -0.9936, -1.0069, -1.0087, -0.9945, -1

computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.9805, -0.9820, -0.9831, -0.9802, -0.9817, -0.9822, -0.9818, -0.9823,
          0.9824, -0.9816],
        [-0.9868, -0.9827, -0.9806, -0.9795, -0.9828, -0.9827,  0.9839, -0.9832,
         -0.9828, -0.9824],
        [-0.9785, -0.9804, -0.9840, -0.9850, -0.9805, -0.9827, -0.9811,  0.9918,
         -0.9799, -0.9836],
        [-0.9805, -0.9805, -0.9823, -0.9859, -0.9807,  0.9940, -0.9820, -0.9908,
         -0.9816, -0.9803],
        [-0.9815, -0.9824, -0.9821, -0.9818, -0.9824, -0.9827,  0.9812, -0.9818,
         -0.9823, -0.9826]])
accuracy 1.00000, loss 0.00160
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.9817, -0.9828, -0.9840, -0.9815, -0.9824, -0.9825, -0.9824, -0.9823,
          0.9824, -0.9825],
        [-0.9830, -0.9826, -0.9822, -0.9809, -0.9826, -0.9825,  0.9816, -0.9830,
         -0.9823, -0.9825],
        [-0.9787, -0.9823, -0.9847, -0.9836, -0.9816, -0.9813, -0.9825,  0.9811,
      