In [1]:
import torch
from functorch import make_functional
from torch.func import functional_call, vmap, jacrev, jvp
    ## NUQLS
import posteriors.nuqls as nuqls
from importlib import reload
reload(nuqls)
from torch.utils.data import DataLoader, Dataset
import tqdm

class toy_dataset(Dataset):
    """Map-style dataset that pairs inputs ``x`` with targets ``y`` by index."""

    def __init__(self,x,y):
        # Keep references to the caller's containers; nothing is copied.
        self.x, self.y = x, y

    def __len__(self):
        # The dataset size is the leading dimension of the inputs.
        n_samples = self.x.shape[0]
        return n_samples

    def __getitem__(self, i):
        # Return the (input, target) pair stored at position ``i``.
        sample = (self.x[i], self.y[i])
        return sample
    
class variable_mlp(torch.nn.Module):
    """Fully connected network with a configurable hidden stack and a scalar output.

    Consecutive entries of ``layer_width`` define the hidden ``Linear`` layers
    (``layer_width[i] -> layer_width[i+1]``); a final ``Linear`` maps
    ``layer_width[-1]`` features to a single output. All weight matrices are
    re-initialised from N(0, 1); biases keep the default ``torch.nn.Linear``
    initialisation.

    Args:
        layer_width: sequence of ints, input dimension first, then hidden widths.
        nonlin: activation name, either ``'tanh'`` or ``'relu'``.

    Raises:
        ValueError: if ``nonlin`` is not one of the supported activation names.
    """

    def __init__(self,layer_width,nonlin):
        super().__init__()
        self.layer_width = layer_width
        self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(layer_width[i],layer_width[i+1], bias=True) for i in range(len(self.layer_width)-1)])
        self.lin_out = torch.nn.Linear(self.layer_width[-1],1, bias=True)
        if nonlin=='tanh':
            self.act = torch.nn.Tanh()
        elif nonlin=='relu':
            self.act = torch.nn.ReLU()
        else:
            # Without this branch `self.act` would never be set and the mistake
            # would only surface later as an AttributeError inside forward().
            raise ValueError(f"Unsupported nonlin {nonlin!r}; expected 'tanh' or 'relu'")

        # Overwrite the default weight initialisation with standard normal draws.
        for lin in self.linear_layers:
            torch.nn.init.normal_(lin.weight, 0, 1)
        torch.nn.init.normal_(self.lin_out.weight, 0, 1)

    # Return full output of nn
    def forward(self,x):
        """Apply every hidden layer followed by the output layer and return the result."""
        for i, lin in enumerate(self.linear_layers):
            # Each hidden activation is divided by sqrt(layer_width[i]), i.e. the
            # square root of the input width of the layer that was just applied.
            # NOTE(review): the output layer is left unscaled and the factor uses the
            # fan-in of the current layer rather than of the next one -- confirm this
            # is the intended NTK-style parameterisation.
            x = self.act(lin(x)) / (self.layer_width[i]**0.5)
        return self.lin_out(x)

In [6]:
# Problem size and architecture for the toy regression experiment.
n, d = 100, 5            # samples per random draw (train and test), input dimension
layer_widths = [d, 20]   # passed to variable_mlp: input dimension followed by one hidden width
nonlin = 'relu'          # activation name passed to variable_mlp

In [27]:
# Seed torch's global RNG so the synthetic data, the weight initialisation and
# anything below that draws from that RNG are repeatable on Restart & Run All.
torch.manual_seed(0)

# Synthetic regression problem: Gaussian inputs and independent Gaussian targets.
X = torch.randn((n,d))
Y = torch.randn((n,1))

X_test = torch.randn((n,d))

net = variable_mlp(layer_width=layer_widths,nonlin=nonlin)

# Full-batch SGD with momentum on the MSE loss.
optimizer = torch.optim.SGD(net.parameters(),lr=0.1,momentum=0.9)
loss_fn = torch.nn.MSELoss()
for step in range(5000):
    optimizer.zero_grad()
    pred = net(X)
    # Kept under its own name so it is not confused with the NUQLS loss below.
    nn_loss = loss_fn(pred,Y)
    nn_loss.backward()
    optimizer.step()
print(f'nn loss : {nn_loss.item():.4}')

# Fit NUQLS on top of the trained network
# NOTE(review): functorch.make_functional is deprecated in favour of torch.func
# (functional_call is already imported at the top of the notebook). fnet/params
# are not used in this cell; they are kept in case a later cell reads them.
fnet, params = make_functional(net)
train_data = toy_dataset(X,Y)
# No test targets are generated; Y is reused as a placeholder here.
# NOTE(review): presumably test_linear ignores the targets -- confirm.
test_data = toy_dataset(X_test,Y)

nuql = nuqls.small_regression_parallel_width(net, train=train_data, S = 10, epochs=100, lr=0.01, bs=n, bs_test=n, width=layer_widths[-1], init_scale=0.1)
loss,resid = nuql.train_linear(mu=0,weight_decay=0,my=0,sy=1,threshold=None,verbose=True, progress_bar=False)
nuql_test_preds = nuql.test_linear(test=test_data)
print(f'nuqls loss : {loss.item():.4}, resid: {resid.item():.4}')

nn loss : 0.1147

-----------------
Epoch 0 of 100
max l2 loss = 0.171140132178374
Residual of normal equation l2 = 66.43049008736743

-----------------
Epoch 10 of 100
max l2 loss = 0.15094746597356123
Residual of normal equation l2 = 36.23925029642465

-----------------
Epoch 20 of 100
max l2 loss = 0.13980772999655647
Residual of normal equation l2 = 20.333203677402793

-----------------
Epoch 30 of 100
max l2 loss = 0.13345945120752165
Residual of normal equation l2 = 11.854384749783945

-----------------
Epoch 40 of 100
max l2 loss = 0.12968237052121956
Residual of normal equation l2 = 7.257818476934101

-----------------
Epoch 50 of 100
max l2 loss = 0.12731239100211658
Residual of normal equation l2 = 4.706451255743707

-----------------
Epoch 60 of 100
max l2 loss = 0.1257335254787276
Residual of normal equation l2 = 3.244755265614347

-----------------
Epoch 70 of 100
max l2 loss = 0.12461548944378933
Residual of normal equation l2 = 2.373002238494821

-----------------
Epoch 

  fnet, params = make_functional(net)


In [29]:
# Display the `theta` attribute of the NUQLS object fitted above.
# NOTE(review): presumably the learned linearised parameters with one column per
# realisation (S = 10) -- confirm against posteriors.nuqls.
nuql.theta

tensor([[-2.8533e+00, -2.8282e+00, -3.0499e+00, -2.8780e+00, -3.0864e+00,
         -2.7528e+00, -2.9497e+00, -2.8749e+00, -2.8151e+00, -2.7828e+00],
        [-3.0797e+00, -3.0206e+00, -2.8722e+00, -2.9797e+00, -2.9706e+00,
         -3.0470e+00, -3.1822e+00, -2.9610e+00, -3.0396e+00, -2.9097e+00],
        [-1.8688e+00, -1.8718e+00, -1.9400e+00, -1.8867e+00, -1.8595e+00,
         -2.1044e+00, -2.0652e+00, -2.0570e+00, -1.9594e+00, -1.8543e+00],
        [-2.1421e+00, -2.0711e+00, -2.1153e+00, -2.2527e+00, -2.2065e+00,
         -1.9830e+00, -2.0545e+00, -2.1593e+00, -2.1605e+00, -2.2915e+00],
        [-1.1083e-01,  1.4973e-02, -4.2185e-02,  1.9202e-02,  7.3365e-03,
          8.4283e-02,  1.1332e-01,  1.9734e-01, -4.6046e-02,  1.6546e-01],
        [ 3.4585e+00,  3.5403e+00,  3.6417e+00,  3.5730e+00,  3.7155e+00,
          3.4660e+00,  3.6503e+00,  3.5354e+00,  3.5289e+00,  3.5700e+00],
        [-1.0016e+00, -1.0440e+00, -1.1570e+00, -1.0084e+00, -9.9891e-01,
         -9.9917e-01, -8.8161e-0