# Defining an Example Model

In the next section, we define a simple 2-layer sparse DGP model for a regression task. We’ll be using this model to demonstrate the usage of the library.

In [121]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sparse_dgp as gp
from sparse_dgp.layers.linear import LinearReparameterization
from sparse_dgp.layers.activations import TMGP

## Defining a 2-layer DTMGP Model

First, we define a 2-layer DTMGP model with a single output dimension. The model consists of two layers, each with level-3 sparse grid design.

In [122]:
# Define a 2-layer DTMGP model for regression
class SparseDGP_grid(nn.Module):
    def __init__(self, input_dim, output_dim, design_class, kernel):
        super(SparseDGP_grid, self).__init__()
        
        # 1st layer of DGP: input:[n, input_dim] size tensor, output:[n, w1] size tensor
        self.tmk1 = TMGP(in_features=input_dim, n_level=3, design_class=design_class, kernel=kernel)
        self.fc1 = LinearReparameterization(
            in_features=self.tmk1.out_features, 
            out_features=8, 
            prior_mean=0.0, 
            prior_variance=1.0, 
            posterior_mu_init=0.0, 
            posterior_rho_init=-3.0, 
            bias=True,
        )

        # 2nd layer of DGP: input:[n, w1] size tensor, output:[n, output_dim] size tensor
        self.tmk2 = TMGP(in_features=8, n_level=3, design_class=design_class, kernel=kernel)
        self.fc2 = LinearReparameterization(
            in_features=self.tmk2.out_features, 
            out_features=output_dim, 
            prior_mean=0.0, 
            prior_variance=1.0, 
            posterior_mu_init=0.0, 
            posterior_rho_init=-3.0, 
            bias=True,
        )

    def forward(self, x):
        kl_sum = 0

        x = self.tmk1(x)
        x, kl = self.fc1(x)
        kl_sum += kl

        x = self.tmk2(x)
        x, kl = self.fc2(x)
        kl_sum += kl

        return torch.squeeze(x), kl_sum

## Preparing the Data

We set up the training data for this example. We'll be using 1000 regularly spaced points in the range [0, 1] as input data. The output data is generated by a function that takes the input data and adds Gaussian noise to get the training labels.

In [123]:
train_x = torch.linspace(0, 1, 300)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.1

class SineDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    
dataset = SineDataset(train_x, train_y)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

## Initializing the Model and the Optimizer

In [124]:
from sparse_dgp.utils.sparse_activation.design_class import HyperbolicCrossDesign
from sparse_dgp.kernels.laplace_kernel import LaplaceProductKernel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)

model = SparseDGP_grid(input_dim=1, 
                       output_dim=1, 
                       design_class=HyperbolicCrossDesign, 
                       kernel=LaplaceProductKernel(1.),
                       ).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Using:  cuda


## Training the Model

In [125]:
for epoch in range(50):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output_ = []
        kl_ = []
        for mc_run in range(1):
            output, kl = model(data)
            output_.append(output)
            kl_.append(kl)
        output = torch.mean(torch.stack(output_), dim=0)
        kl = torch.mean(torch.stack(kl_), dim=0)
        nll_loss = F.mse_loss(output, target)
        # ELBO loss
        loss = nll_loss + (kl / 8)
        loss.backward()
        optimizer.step()
        
    print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 1.2616976499557495
Epoch: 1, Loss: 1.100917935371399
Epoch: 2, Loss: 1.0582685470581055
Epoch: 3, Loss: 0.950246274471283
Epoch: 4, Loss: 0.8577843308448792
Epoch: 5, Loss: 0.8147437572479248
Epoch: 6, Loss: 0.6811959147453308
Epoch: 7, Loss: 1.0003430843353271
Epoch: 8, Loss: 0.723312497138977
Epoch: 9, Loss: 0.7210695147514343
Epoch: 10, Loss: 0.923382043838501
Epoch: 11, Loss: 0.6909069418907166
Epoch: 12, Loss: 0.6228176951408386
Epoch: 13, Loss: 0.7014390230178833
Epoch: 14, Loss: 0.6350182294845581
Epoch: 15, Loss: 0.5633167028427124
Epoch: 16, Loss: 0.6716229915618896
Epoch: 17, Loss: 0.6238070130348206
Epoch: 18, Loss: 0.5351642966270447
Epoch: 19, Loss: 0.5005326271057129
Epoch: 20, Loss: 0.7072726488113403
Epoch: 21, Loss: 0.6460894346237183
Epoch: 22, Loss: 0.5426713228225708
Epoch: 23, Loss: 0.49523478746414185
Epoch: 24, Loss: 0.42837417125701904
Epoch: 25, Loss: 1.3395843505859375
Epoch: 26, Loss: 1.6528364419937134
Epoch: 27, Loss: 0.40847164392471313
Epo