In [None]:
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

In [None]:
def make_synthetic_dataset(n: int = 10000, rng_seed: int = 12435):
    """Generate a toy heteroscedastic regression dataset.

    Inputs are drawn uniformly from [-1, 1)^3. The target has the linear mean
    ``x0 - x1`` plus Gaussian noise whose standard deviation ``2 + x2``
    depends on the input, so the noise stddev ranges from 1 to 3.

    Args:
        n: Number of samples to draw.
        rng_seed: Seed for the NumPy ``RandomState``; the same seed always
            yields the same dataset.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape ``(n, 3)`` and ``y`` has shape
        ``(n,)``. ``y`` includes the input-dependent noise.
    """
    rng = np.random.RandomState(rng_seed)

    x = rng.uniform(low=-1.0, high=+1.0, size=(n, 3))

    # Mean of the target: a simple linear function of the first two features.
    y_zero_noise = x[:, 0] - x[:, 1]

    # Per-sample noise stddev driven by the third feature; ranges from 1 to 3.
    sigma = 2.0 + x[:, 2]
    # Draw one independent standard-normal sample per row. Without size=n a
    # single scalar would be shared by every row, making the "noise" a
    # deterministic function of x.
    eps = sigma * rng.normal(loc=0.0, scale=1.0, size=n)

    y_with_noise = y_zero_noise + eps

    # Return the noisy target: a model of the stddev has nothing to learn
    # from noiseless data.
    return x, y_with_noise

In [None]:
# Seed the global torch RNG so the initial weights are reproducible across
# Restart & Run All.
torch.manual_seed(0)

# Predicts the conditional mean of y given x; the output is unconstrained.
mu_model = nn.Linear(in_features=3, out_features=1, bias=True)

# Predicts the conditional standard deviation of y given x. Softplus,
# log(1 + exp(z)), maps the linear output to positive values (barring float
# underflow for extremely negative pre-activations), so log(sigma) and
# sigma**-2 in the loss stay finite instead of turning into NaN.
sigma_model = nn.Sequential(
    nn.Linear(in_features=3, out_features=1, bias=True),
    nn.Softplus(),
)

In [None]:
x_data, y_data = make_synthetic_dataset()
# torch.as_tensor with an explicit dtype replaces the legacy torch.Tensor(...)
# constructor. The NumPy arrays are float64; nn.Linear parameters default to
# float32, so convert once here.
dataset = TensorDataset(
    torch.as_tensor(x_data, dtype=torch.float32),
    torch.as_tensor(y_data, dtype=torch.float32),
)
# A dedicated, seeded generator makes the shuffle order reproducible across
# kernel restarts without depending on the global torch RNG state.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=128,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# One plain-SGD optimizer updates the mean and stddev models jointly.
optimizer = torch.optim.SGD(
    params=[*mu_model.parameters(), *sigma_model.parameters()],
    lr=1e-2,
)

In [None]:
n_epochs = 10

for epoch in tqdm(range(n_epochs)):
    # Train
    mu_model.train()
    sigma_model.train()

    # Sample-weighted running sum of the per-sample loss, so the epoch
    # average is exact even when the last batch is smaller than the others.
    total_loss = 0.0
    total_n = 0

    for x, y in dataloader:
        # Call the modules (not .forward) so any registered hooks run.
        y_hat = mu_model(x).squeeze(dim=1)
        sigma_hat = sigma_model(x).squeeze(dim=1)

        # Gaussian negative log-likelihood, up to an additive constant:
        #   log(sigma) + 0.5 * (y - mu)^2 / sigma^2
        # gaussian_nll_loss takes the variance and returns the batch mean of
        #   0.5 * (log(var) + (y - mu)^2 / var),
        # which is the same quantity. It clamps var from below by a small
        # eps, so sigma == 0 gives a finite loss rather than -inf/NaN. Passing
        # sigma_hat**2 also keeps the loss defined if sigma_hat is negative.
        loss = nn.functional.gaussian_nll_loss(y_hat, y, sigma_hat ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight it by batch size before accumulating.
        n_batch = x.size(0)
        total_loss += loss.item() * n_batch
        total_n += n_batch

    # tqdm.write prints without garbling the progress bar.
    tqdm.write(f'Epoch: {epoch}\tLoss: {total_loss / total_n}')