In [None]:
import torch
from torch.utils.data import Dataset

class toy_dataset(Dataset):
    """Map-style dataset pairing row ``i`` of ``x`` with row ``i`` of ``y``.

    Indexing returns the tuple ``(x[i], y[i])``; the length is the size of the
    first dimension of ``x``. ``x`` and ``y`` are stored by reference (no copy).
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        n_rows = self.x.shape[0]
        return n_rows

    def __getitem__(self, i):
        features, target = self.x[i], self.y[i]
        return features, target
    
class variable_mlp(torch.nn.Module):
    """Bias-free fully-connected network with a single output unit.

    One ``torch.nn.Linear(layer_width[i], layer_width[i+1], bias=False)`` is
    built for every consecutive pair in ``layer_width``, followed by an output
    layer ``torch.nn.Linear(layer_width[-1], 1, bias=False)``. All weights are
    re-initialised from N(0, 1).

    Args:
        layer_width: sequence of layer sizes ``[d_in, h_1, ..., h_k]``; must be
            non-empty. A single entry gives a purely linear model.
        nonlin: activation applied after every hidden layer, ``'tanh'`` or
            ``'relu'``.

    Raises:
        ValueError: if ``layer_width`` is empty or ``nonlin`` is unsupported.
    """

    def __init__(self, layer_width, nonlin):
        super().__init__()
        activations = {'tanh': torch.nn.Tanh, 'relu': torch.nn.ReLU}
        # Validate up front: an unknown name used to leave ``self.act`` unset,
        # so the error only surfaced later in ``forward`` as an AttributeError.
        if nonlin not in activations:
            raise ValueError(
                f"nonlin must be one of {sorted(activations)}, got {nonlin!r}"
            )
        if len(layer_width) == 0:
            raise ValueError("layer_width must contain at least one size")

        # Construction order is kept (hidden layers, then output, then
        # activation) so seeded runs consume the RNG exactly as before.
        self.layer_width = layer_width
        self.linear_layers = torch.nn.ModuleList([
            torch.nn.Linear(layer_width[i], layer_width[i + 1], bias=False)
            for i in range(len(layer_width) - 1)
        ])
        self.lin_out = torch.nn.Linear(layer_width[-1], 1, bias=False)
        self.act = activations[nonlin]()

        # Replace the default nn.Linear initialisation with N(0, 1) weights;
        # the width scaling is applied in ``forward`` instead.
        for lin in self.linear_layers:
            torch.nn.init.normal_(lin.weight, 0, 1)
        torch.nn.init.normal_(self.lin_out.weight, 0, 1)

    def forward(self, x):
        """Return the full network output for input ``x``.

        Args:
            x: tensor whose last dimension equals ``layer_width[0]``.

        Returns:
            Tensor of the same leading shape as ``x`` with last dimension 1.
        """
        # zip stops at the shorter sequence, so ``width`` is the fan-in
        # ``layer_width[i]`` of the i-th hidden layer.
        for width, lin in zip(self.layer_width, self.linear_layers):
            # NOTE(review): the activation output is divided by the sqrt of the
            # layer's *input* width, and ``lin_out`` is not scaled at all. In
            # the standard NTK parameterisation each pre-activation is divided
            # by sqrt(fan-in) of its own layer (so the output layer would use
            # sqrt(layer_width[-1])). Confirm this shifted scaling is intended.
            x = self.act(lin(x)) / (width ** 0.5)
        return self.lin_out(x)

In [2]:
# Train MLP on Gaussian Data
#
# Inputs and targets are drawn independently from N(0, 1), so the network is
# fitting pure noise; the reported loss shows how much of it the model memorises.

SEED = 0          # fixed so the reported loss is reproducible on Restart & Run All
N_STEPS = 5000    # full-batch SGD steps
LR = 0.1
MOMENTUM = 0.9

torch.manual_seed(SEED)

n = 1000
d = 5
d1 = 20
layer_widths = [d, d1]

X = torch.randn((n, d))
Y = torch.randn((n, 1))

X_test = torch.randn((n, d))

net = variable_mlp(layer_width=layer_widths, nonlin='tanh')

optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM)
loss_fn = torch.nn.MSELoss()
for _ in range(N_STEPS):
    optimizer.zero_grad()
    pred = net(X)
    loss = loss_fn(pred, Y)
    loss.backward()
    optimizer.step()

# ``loss`` above was computed before the last optimizer step; evaluate the
# trained network once more so the printed value reflects the final weights.
with torch.no_grad():
    final_loss = loss_fn(net(X), Y)

print(f'nn loss : {final_loss.item():.4}')

nn loss : 0.8223


In [3]:
# Test NUQLS: fit the NUQLS posterior around the trained network `net`.
# `reload` re-executes the nuqls module so source edits are picked up without
# restarting the kernel.
from importlib import reload

import posteriors.nuqlsPosterior.nuqls as nqls

nqls = reload(nqls)

train = toy_dataset(X, Y)

nuqls_hparams = dict(
    train_bs=50,
    scale=0.1,
    S=10,
    epochs=100,
    lr=0.1,
    mu=0.9,
    verbose=True,
)

nuqls_posterior = nqls.Nuqls(net, task='regression', full_dataset=False)
res = nuqls_posterior.train(train=train, **nuqls_hparams)




  warn_deprecated('make_functional', 'torch.func.functional_call')
100%|██████████| 100/100 [00:04<00:00, 22.46it/s, max_loss=0.884, resid_norm=0.00149, gpu_mem=0]

Posterior samples computed!



