In [1]:
import random
import time
from itertools import product, islice
from collections import defaultdict
from tqdm.auto import tqdm

import numpy as np
import matplotlib.pyplot as plt

from scipy import interpolate, optimize

import torch
import torch.nn as nn
import torchvision

from precisionml.optimizers import ConjugateGradients

In [2]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
torch.set_default_dtype(torch.float64)
device

'cuda:0'

### define losses

In [3]:
rmse_loss_fn_torch = lambda x, y: torch.sqrt(torch.mean(torch.pow(x-y, 2)))
qmqe_loss_fn_torch = lambda x, y: torch.pow(torch.mean(torch.pow(x-y, 4)), 1/4)
smse_loss_fn_torch = lambda x, y: torch.pow(torch.mean(torch.pow(x-y, 6)), 1/6)
mse_loss_fn_torch = nn.MSELoss()

def lp_norm(p):
    def loss(x, y):
        return torch.pow(torch.mean(torch.pow(torch.abs(x-y), p)), 1/p)
    return loss

def dl_loss(epsilon):
    def loss(x, y):
        return torch.mean(0.5 * torch.log2(1 + torch.pow((x-y) / epsilon, 2)))
    return loss

In [4]:
def loss(param_vector, lenghts, shapes, 
             mlp, loss_fn, x, y, device=device):
    l = 0
    for i, param in enumerate(mlp.parameters()):
        param_data = param_vector[l:l+lenghts[i]]
        l += lenghts[i]
        param_data_shaped = param_data.reshape(shapes[i])
        param.data = torch.tensor(param_data_shaped).to(device)
    return loss_fn(mlp(x.to(device)), y).detach().cpu().numpy()

def gradient(param_vector, lenghts, shapes, 
             mlp, loss_fn, x, y, device=device):
    l = 0
    for i, param in enumerate(mlp.parameters()):
        param_data = param_vector[l:l+lenghts[i]]
        l += lenghts[i]
        param_data_shaped = param_data.reshape(shapes[i])
        param.data = torch.tensor(param_data_shaped).to(device)
    loss_fn(mlp(x.to(device)), y).backward()
    grads = []
    for param in mlp.parameters():
        grads.append(param.grad.detach().clone().cpu().numpy().flatten())
        param.grad = None
    return np.concatenate(grads)

In [5]:
n = 300
x_line = y_line = np.linspace(-1, 1, n)
X, Y = np.meshgrid(x_line, y_line)
Z = np.exp(X + Y)
x_2d = np.stack([X, Y], axis=2)

x = torch.from_numpy(x_2d.reshape((n*n, 2))).to(device)
y = torch.from_numpy(Z.reshape((n*n, 1))).to(device)

n = 257
x_line = y_line = np.linspace(-1, 1, n)
X, Y = np.meshgrid(x_line, y_line)
Z = np.exp(X + Y)
x_2d = np.stack([X, Y], axis=2)

x_test = torch.from_numpy(x_2d.reshape((n*n, 2))).to(device)
y_test = torch.from_numpy(Z.reshape((n*n, 1))).to(device)

### train first network

In [7]:
width = 10
mlp_N = nn.Sequential(
    nn.Linear(2, width),
    nn.ReLU(),
    nn.Linear(width, width),
    nn.ReLU(),
    nn.Linear(width, 1)
).to(device)

params = []
shapes = []
lenghts = []
for param in mlp_N.parameters():
    param_np = param.data.detach().clone().cpu().numpy()
    shapes.append(param_np.shape)
    param_np_flat = param_np.flatten()
    lenghts.append(len(param_np_flat))
    params.append(param_np_flat)

param_vector = np.concatenate(params)
N = len(param_vector)

result = optimize.minimize(loss,
                           param_vector, 
                           args=(lenghts, shapes, mlp_N, rmse_loss_fn_torch, x, y, device),
                           jac=gradient,
                           method='BFGS',
                           options={
                               'disp': True,
                               'gtol': 1e-18,
                               'maxiter': 25000,
#                                'finite_diff_rel_step': 1e-15
                           },
                        )
    
l = 0
for i, param in enumerate(mlp_N.parameters()):
    param_data = result.x[l:l+lenghts[i]]
    l += lenghts[i]
    param_data_shaped = param_data.reshape(shapes[i])
    param.data = torch.tensor(param_data_shaped).to(device)

rmse_N = rmse_loss_fn_torch(mlp_N(x_test), y_test).item()
rmse_N


         Current function value: 0.004184
         Iterations: 283
         Function evaluations: 393
         Gradient evaluations: 381


0.004199931166830488

### train second network

In [8]:

width = 30
mlp_M = nn.Sequential(
    nn.Linear(2, width),
    nn.ReLU(),
    nn.Linear(width, width),
    nn.ReLU(),
    nn.Linear(width, 1)
).to(device)

params = []
shapes = []
lenghts = []
for param in mlp_M.parameters():
    param_np = param.data.detach().clone().cpu().numpy()
    shapes.append(param_np.shape)
    param_np_flat = param_np.flatten()
    lenghts.append(len(param_np_flat))
    params.append(param_np_flat)

param_vector = np.concatenate(params)
M = len(param_vector)

result = optimize.minimize(loss,
                           param_vector, 
                           args=(lenghts, shapes, mlp_M, rmse_loss_fn_torch, x, y, device),
                           jac=gradient,
                           method='BFGS',
                           options={
                               'disp': True,
                               'gtol': 1e-18,
                               'maxiter': 25000,
#                                'finite_diff_rel_step': 1e-15
                           },
                        )

l = 0
for i, param in enumerate(mlp_M.parameters()):
    param_data = result.x[l:l+lenghts[i]]
    l += lenghts[i]
    param_data_shaped = param_data.reshape(shapes[i])
    param.data = torch.tensor(param_data_shaped).to(device)

rmse_M = rmse_loss_fn_torch(mlp_M(x_test), y_test).item()
rmse_M


         Current function value: 0.000581
         Iterations: 2904
         Function evaluations: 3108
         Gradient evaluations: 3097


0.0005884489711409657

In [9]:
a = - np.log(rmse_N / rmse_M) / np.log(N / M)

In [10]:
def mlp_R(x):
    numerator = mlp_N(x) - (mlp_M(x) * np.power(N/M, -a))
    denominator = 1 - np.power(N/M, -a)
    return numerator / denominator

In [11]:
rmse_loss_fn_torch(mlp_R(x_test), y_test).item()

0.0009555428767046097