In [1]:
import sys

# Make the project's local packages (presumably pytorch_impl, imported below)
# importable. The path is relative to the kernel's working directory.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if "../src" not in sys.path:
    sys.path.append("../src")

In [2]:
# Load IPython's autoreload extension (it is switched to mode 2 in the next
# cell) so edits to the local modules on sys.path are picked up without
# restarting the kernel.
%load_ext autoreload

In [3]:
%autoreload 2

import sys
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

import matplotlib.pyplot as plt
from pytorch_impl.nns import FCN
from pytorch_impl.nns.utils import warm_up_batch_norm, to_one_hot
from pytorch_impl.estimators import MatrixExpEstimator
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term

In [4]:
# Reproducibility and run configuration.
torch.manual_seed(0)

num_classes = 10   # FashionMNIST has 10 classes
BATCH_SIZE = 32
DATA_ROOT = '.'    # dataset is downloaded/cached relative to the kernel's CWD

# CUDA is deliberately disabled for this notebook; set USE_CUDA = True to run
# on the first GPU when one is available.
USE_CUDA = False

if USE_CUDA and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Torch version: {}'.format(torch.__version__))
print('Device: {}'.format(device))

# Pinned host memory only speeds up host-to-GPU copies, so only request it
# when batches will actually be moved to a CUDA device.
pin_memory = device.type == 'cuda'

train_loader = torch.utils.data.DataLoader(
    FashionMNIST(root=DATA_ROOT, train=True, download=True,
                 transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)

# NOTE(review): shuffling is not needed for evaluation; kept so batch order
# matches earlier runs of this notebook.
test_loader = torch.utils.data.DataLoader(
    FashionMNIST(root=DATA_ROOT, train=False, download=True,
                 transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)

device

Torch version: 1.3.1
Device: cpu


device(type='cpu')

In [5]:
# Sanity check for matrix_exp: the exponential of the 2x2 generator
# [[0, -1], [1, 0]] scaled by an angle `a` is the rotation matrix
# [[cos a, -sin a], [sin a, cos a]].
a = 30 / 180 * np.pi
# Float literals so M is a float tensor without relying on int*float promotion.
M = torch.tensor([[0., -1.], [1., 0.]]) * a

rotation_expected = np.array([[np.cos(a), -np.sin(a)],
                              [np.sin(a),  np.cos(a)]], dtype=np.float32)
rotation_computed = matrix_exp(M, device).cpu().numpy()
np.testing.assert_allclose(rotation_computed, rotation_expected, atol=1e-5)
rotation_computed

array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [6]:
# Check that I + M @ compute_exp_term(M) reproduces matrix_exp(M) on the same
# 30-degree rotation generator as the previous cell.
a = 30 / 180 * np.pi
M = torch.tensor([[0., -1.], [1., 0.]]) * a
M_clone = M.clone().to(device)

# Build the identity on `device` so the sum never mixes CPU and CUDA tensors,
# and move the result to the CPU before .numpy() (which rejects CUDA tensors).
identity = torch.eye(2, device=device)
reconstructed = (torch.matmul(M_clone, compute_exp_term(M, device)) + identity).cpu().numpy()
np.testing.assert_allclose(reconstructed, matrix_exp(M, device).cpu().numpy(), atol=1e-5)
reconstructed

array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [7]:
# FashionMNIST images are 28x28; D * D is passed to FCN, presumably as the
# flattened input size (TODO confirm against the FCN signature).
D = 28

model = FCN(10, D * D)
model = model.to(device)

estimator = MatrixExpEstimator(
    model,
    num_classes,
    device,
    criterion=nn.CrossEntropyLoss(),
    learning_rate=1.,
    momentum=0,
)

In [8]:
# Repeatedly fit the estimator on one fixed training batch and print the first
# 5 rows of estimator.predict(X) after every step.
N_FIT_STEPS = 30

# Take a single batch; next(iter(...)) avoids the throwaway enumerate index.
X, y = next(iter(train_loader))
X, y = X.to(device), y.to(device)

# NOTE(review): this result is discarded (it is not the cell's last expression,
# so nothing is displayed). Kept in case predict() has side effects on the
# estimator — TODO confirm and remove if it does not.
estimator.predict(X)
for _ in range(N_FIT_STEPS):
    estimator.fit(X, y)
    print(estimator.predict(X)[:5])

scale 63.699283599853516
accuracy 0.12500, loss 0.07153
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1071, -0.2889, -0.2074, -0.1390, -0.1515,  1.7243, -0.1966, -0.2087,
         -0.1710, -0.1846],
        [-0.1132, -0.2649, -0.2294, -0.1477, -0.1856, -0.2429, -0.1495, -0.1472,
         -0.1377,  1.7560],
        [-0.0604, -0.2112, -0.2596, -0.0952, -0.1142, -0.2777, -0.0861,  1.9017,
         -0.1418, -0.1971],
        [-0.0709, -0.2087, -0.2746, -0.0894, -0.1210, -0.2803, -0.1338,  1.8052,
         -0.1802, -0.1358],
        [ 1.7957, -0.2107, -0.3130, -0.1301, -0.1405, -0.3264, -0.1410, -0.1736,
         -0.1139, -0.1480]])
scale 0.5782727003097534
accuracy 1.00000, loss 0.02503
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1083, -0.2899, -0.2085, -0.1402, -0.1526,  1.7346, -0.1977, -0.2098,
         -0.1722, -0.1858],
        [-0.1144, -0.2659, -0.2305, -0.1489, -0.1866, -0.2440, -0.1507, -0.1483,
         -0.1389,  1.7661],
      

computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1083, -0.2899, -0.2085, -0.1402, -0.1526,  1.7347, -0.1977, -0.2098,
         -0.1722, -0.1858],
        [-0.1144, -0.2659, -0.2305, -0.1489, -0.1866, -0.2440, -0.1507, -0.1483,
         -0.1389,  1.7661],
        [-0.0617, -0.2122, -0.2605, -0.0963, -0.1153, -0.2786, -0.0873,  1.9115,
         -0.1429, -0.1982],
        [-0.0721, -0.2098, -0.2756, -0.0906, -0.1222, -0.2812, -0.1350,  1.8153,
         -0.1813, -0.1369],
        [ 1.8056, -0.2117, -0.3140, -0.1312, -0.1416, -0.3274, -0.1422, -0.1748,
         -0.1151, -0.1491]])
scale -7.132670134524233e-07
accuracy 1.00000, loss 0.02484
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1083, -0.2899, -0.2085, -0.1402, -0.1526,  1.7347, -0.1977, -0.2098,
         -0.1722, -0.1858],
        [-0.1144, -0.2659, -0.2305, -0.1489, -0.1866, -0.2440, -0.1507, -0.1483,
         -0.1389,  1.7661],
        [-0.0617, -0.2122, -0.2605, -0.0963, -0.1153, -0.2

computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1083, -0.2899, -0.2085, -0.1402, -0.1526,  1.7347, -0.1977, -0.2098,
         -0.1722, -0.1858],
        [-0.1144, -0.2659, -0.2305, -0.1489, -0.1866, -0.2440, -0.1507, -0.1483,
         -0.1389,  1.7661],
        [-0.0617, -0.2122, -0.2605, -0.0963, -0.1153, -0.2786, -0.0873,  1.9115,
         -0.1429, -0.1982],
        [-0.0721, -0.2098, -0.2756, -0.0906, -0.1222, -0.2812, -0.1350,  1.8153,
         -0.1813, -0.1369],
        [ 1.8056, -0.2117, -0.3140, -0.1312, -0.1416, -0.3274, -0.1422, -0.1748,
         -0.1151, -0.1491]])
scale -7.132670134524233e-07
accuracy 1.00000, loss 0.02484
computing grads ... 0s
exponentiating kernel matrix ... 0s
tensor([[-0.1083, -0.2899, -0.2085, -0.1402, -0.1526,  1.7347, -0.1977, -0.2098,
         -0.1722, -0.1858],
        [-0.1144, -0.2659, -0.2305, -0.1489, -0.1866, -0.2440, -0.1507, -0.1483,
         -0.1389,  1.7661],
        [-0.0617, -0.2122, -0.2605, -0.0963, -0.1153, -0.2