In [0]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch
from torch import nn
import autograd
import autograd.numpy as np
import matplotlib.pylab as plt
import torch.nn.functional as F

from torchvision import datasets, models, transforms

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from scipy.integrate import solve_ivp
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import Ridge, LinearRegression
from sklearn.datasets import load_boston
from sklearn.metrics import r2_score
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from scipy.linalg import expm
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

## Load the data

In [2]:
torch.manual_seed(0)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Torch version: {}'.format(torch.__version__))
print('Device: {}'.format(device))

train_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=True, download=True,
          transform=transforms.ToTensor()),
    batch_size=10000, shuffle=True, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=False, transform=transforms.ToTensor()),
    batch_size=10000, shuffle=True, pin_memory=True)

Torch version: 1.4.0
Device: cuda:0


## Model

In [0]:
n_classes, D = 10, 28

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        hidden_size = 256

        self.net = nn.Sequential(
          nn.Linear(D ** 2, hidden_size),
          nn.ReLU(),
          nn.Linear(hidden_size, n_classes)
        )

    def forward(self, x):
        x = x.view(-1, D ** 2)
        return self.net(x)

In [0]:
model = Net()
model = model.to(device)

In [5]:
param_size = 0
for p in model.parameters():
    param_size += p.nelement()
param_size

203530

## Matrix utils

In [6]:
def matrix_exp(M, device, n_iter=30):
  M = M.to(device)
  
  n = M.size()[0]
  norm = torch.sqrt((M ** 2).sum())
  steps = 0
  while norm > 1e-8:
    M /= 2.
    norm /= 2.
    steps += 1
  
  
  series_sum = torch.eye(n, dtype=torch.float64).to(device)
  prod = M.to(device)
  for i in range(1, n_iter):
    series_sum = (series_sum + prod)
    prod = torch.matmul(prod, M) / i

  exp = series_sum
  for _ in range(steps):
    exp = torch.matmul(exp, exp)
  return exp

a = 30 / 180 * np.pi
M = torch.tensor([[0, -1], [1, 0]], dtype=torch.float64) * a
matrix_exp(M, device)

tensor([[ 0.8660, -0.5000],
        [ 0.5000,  0.8660]], device='cuda:0', dtype=torch.float64)

In [7]:
# compute M^-1 * (exp(M) - E) 
def compute_exp_term(M, device, n_iter=30):
  with torch.no_grad():
    M = M.to(device)
    
    n = M.size()[0]
    norm = torch.sqrt((M ** 2).sum())
    steps = 0
    while norm > 1e-8:
      M /= 2.
      norm /= 2.
      steps += 1
    
    series_sum = torch.zeros([n, n], dtype=torch.float64).to(device)
    prod = torch.eye(n, dtype=torch.float64).to(device)
    
    # series_sum: E + M / 2 + M^2 / 6 + ...
    for i in range(1, n_iter):
      series_sum = (series_sum + prod)
      prod = torch.matmul(prod, M) / (i + 1)

    # (exp 0) (exp 0) = (exp^2           0)  
    # (sum E) (sum E) = (sum * exp + sum E)
    exp = torch.matmul(M, series_sum) + torch.eye(n).to(device)
    for step in range(steps):
      series_sum = (torch.matmul(series_sum, exp) + series_sum) / 2.
      exp = torch.matmul(exp, exp)
    
    return series_sum

a = 30 / 180 * np.pi
M = torch.tensor([[0, -1], [1, 0]], dtype=torch.float64) * a
torch.matmul(M.to(device), compute_exp_term(M, device)) + torch.eye(2).to(device)

tensor([[ 0.8660, -0.5000],
        [ 0.5000,  0.8660]], device='cuda:0', dtype=torch.float64)

In [8]:
_, (X, y) = next(enumerate(train_loader))
X = X.to(device)
y_one_hot = torch.eye(n_classes)[y] * 2 - 1
y = y_one_hot.to(device)
y.size()

torch.Size([10000, 10])

## Predictor (via closed form solution)

In [9]:
class Estimator:
  def __init__(self, model, device, learning_rate):
    self.model = model
    self.device = device
    self.lr = learning_rate
    self.ws = None
    self.step = 2000

  def one_grad(self, pred_elem):
    self.model.zero_grad()
    pred_elem.backward(retain_graph=True)
    grads = []
    for param in self.model.parameters():
        cur_grad = param.grad
        grads.append(cur_grad.view(-1))
    grad = torch.cat(grads).view(-1).detach()
    return grad

  def grads_x(self, x):
    pred = self.model.forward(x)
    return torch.stack([self.one_grad(elem) for elem in pred[0]])

  def grads(self, X):
    return torch.cat([self.grads_x(x) for x in X])

  def grads_logit(self, X, logit_i):
    pred = self.model.forward(X)[:,logit_i]
    return torch.stack([self.one_grad(elem) for elem in pred]).detach()

  def compute_Theta_0(self, X, logit_i):
    n = X.size()[0]
    
    Theta_0 = torch.zeros([n,n]).double().to(self.device)
    for li in range(0, n, self.step):
      ri = min(li + self.step, n)
      grads_i = self.grads_logit(X[li:ri], logit_i).double()

      for lj in range(0, n, self.step):
        rj = min(lj + self.step, n)
        grads_j = self.grads_logit(X[lj:rj], logit_i).double()

        Theta_0[li:ri, lj:rj] = torch.matmul(grads_i, grads_j.T)
        del grads_j
      del grads_i
    return Theta_0

  def fit(self, X, y):
    n = len(X)
    n_classes = y.size()[1]

    ws = []
    for i in range(n_classes):
      w = None

      print(f"computing grads {i}/{n_classes} ...")
      Theta_0 = self.compute_Theta_0(X, i)
      
      with torch.no_grad():
        f0 = self.model.forward(X)[:,i]
        cur_y = y[:,i]
        n = len(X)
        print(f"exponentiating kernel matrix ...")
        exp_term = - self.lr * compute_exp_term(- self.lr * Theta_0, self.device)
        del Theta_0
        right_vector = torch.mv(exp_term, (f0 - cur_y).double())
      
      for l in range(0, n, self.step):
        r = min(l + self.step, n)
        grads = self.grads_logit(X[l:r], i).double()
        cur_w = torch.mv(grads.T.double(), right_vector[l:r]).detach()
        w = cur_w if (w is None) else (w + cur_w)
      ws.append(w)
    self.ws = torch.stack(ws)
  
  def predict(self, X):
    if self.ws is None:
      return self.model.forward(X)
    
    def predict_one(x):
      with torch.no_grad():
        f0 = self.model.forward(x)[0]
      return f0.double() + (self.grads_x(x).double() * self.ws).sum(dim=1)
    return torch.stack([predict_one(x) for x in X])

print("fitting the model ...")
estimator = Estimator(model, device, learning_rate=1e5)
estimator.fit(X, y)

print(estimator.predict(X).size())
((estimator.predict(X) - y) ** 2).mean()

fitting the model ...
computing grads 0/10 ...
exponentiating kernel matrix ...
computing grads 1/10 ...
exponentiating kernel matrix ...
computing grads 2/10 ...
exponentiating kernel matrix ...
computing grads 3/10 ...
exponentiating kernel matrix ...
computing grads 4/10 ...
exponentiating kernel matrix ...
computing grads 5/10 ...
exponentiating kernel matrix ...
computing grads 6/10 ...
exponentiating kernel matrix ...
computing grads 7/10 ...
exponentiating kernel matrix ...
computing grads 8/10 ...
exponentiating kernel matrix ...
computing grads 9/10 ...
exponentiating kernel matrix ...
torch.Size([10000, 10])


tensor(2.5947e-06, device='cuda:0', dtype=torch.float64)

## Accuracy

In [12]:
_, (X, y) = next(enumerate(train_loader))
X = X.to(device)
y_one_hot = torch.eye(n_classes)[y] * 2 - 1
y = y_one_hot.to(device)

(torch.argmax(estimator.predict(X), dim=1) == torch.argmax(y, dim=1)).double().mean()

tensor(0.8729, device='cuda:0', dtype=torch.float64)