In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import torchbnn as bnn

from sklearn.utils import shuffle
import pickle as pkl
import datetime

import matplotlib.pyplot as plt
import pandas as pd
import os


%matplotlib inline

seed = 1
np.random.seed(seed)

In [8]:
# Create fake dimensions (lat, lon, months, time)
n_years = 10
lats = np.arange(30, 50, 0.5)
lons = np.arange(-10,20, 0.5)
# Change time so the in sample runs -1, 1
time = np.arange(-1,1, (2/12)/n_years) * 2 + 1
months = np.tile(np.arange(-1,1,2/12), n_years)

# Create smooth function to serve as the observational truth
def fun(time, lat, lon, month):
    f = (0.5 * (((lat / 90) ** 2) + 0.5 * np.sin(2 * np.pi * lon / 180)) - 0.2 * np.cos(np.pi * month) + 
         0.1 * np.sin(4 * np.pi * lat / 90) * np.cos(4 * np.pi * lon / 180) - 0.05 * np.cos(2 * np.pi * time) + 0.05 * np.sin(2 * np.pi * lat / 45) * np.sin(2 * np.pi * lon / 90)) 
    return f

obs = np.zeros([len(time), len(lats), len(lons)])
for i, t in enumerate(time):
    for j, lat in enumerate(lats):
        for k, lon in enumerate(lons):
            obs[i, j, k] = fun(t, lat, lon, months[i])
            
# Normalise obs
obs = 2 * (obs - obs.min())/(obs.max() - obs.min()) - 1


# Format data so it is suitable for NN input
data_len = obs.size
lon_data = np.tile(lons, int(data_len / len(lons))).reshape(-1,)
lat_data = np.tile(np.repeat(lats, len(lons)), len(time)).reshape(-1,)
time_data = np.repeat(time, int(data_len / len(time))).reshape(-1,)
mon_data = np.repeat(months, int(data_len / len(time))).reshape(-1,)

# model 1  True in the north with a -0.03 bias
mdl1 = obs.copy() - 0.03 + np.random.normal(size=[len(time), len(lats), len(lons)]) * 0.005
mdl1[:,:-16, :] = np.random.random([len(time), len(lats) - 16, len(lons)]) * 2 - 1

# model 2 True around equator with no bias
mdl2 = obs.copy() + np.random.normal(size=[len(time),len(lats), len(lons)]) * 0.005 
mdl2[:,:16, :] = np.random.random([len(time), 16, len(lons)]) * 2 - 1
mdl2[:,-16:, :] = np.random.random([len(time), 16, len(lons)]) * 2 - 1

# model 3 True around equator with no bias (as per model 2)
mdl3 = obs.copy() + np.random.normal(size=[len(time), len(lats), len(lons)]) * 0.005
mdl3[:,:16, :] = np.random.random([len(time), 16, len(lons)]) * 2 - 1
mdl3[:,-16:, :] = np.random.random([len(time), 16, len(lons)]) * 2 - 1

# model 4 True in the south with 0.03 bias. True only for months 1-6
mdl4 = obs.copy() + 0.03 + np.random.normal(size=[len(time), len(lats), len(lons)]) * 0.005
mdl4[:,16:, :] = np.random.random([len(time), len(lats) - 16, len(lons)]) * 2 - 1

# Add noise
# In the north we have 0.01 noise
obs[:,-16:, :] = obs[:,-16:, :] + np.random.normal(size=obs[:,-16:, :].shape) * 0.01
# Around the equator we have 0.02 noise
obs[:,16:-16, :] = obs[:,16:-16, :] + np.random.normal(size=obs[:,16:-16, :].shape) * 0.02
# In the south we have 0.03 noise
obs[:,:16, :] = obs[:,:16, :] + np.random.normal(size=obs[:,:16, :].shape) * 0.03

df = pd.DataFrame()
df['mdl1'] = mdl1.ravel()
df['mdl2'] = mdl2.ravel()
df['mdl3'] = mdl3.ravel()
df['mdl4'] = mdl4.ravel()

# Convert the coordinates to have a 1:1 mapping
x = np.cos(lat_data * np.pi / 180)  * np.cos(lon_data * np.pi / 180)
y = np.cos(lat_data * np.pi / 180)  * np.sin(lon_data * np.pi / 180)
z = np.sin(lat_data * np.pi / 180)

rads = (mon_data * 360) * (np.pi / 180)
x_mon = np.sin(rads)
y_mon = np.cos(rads)

# Coordinate scaling
df['x'] = x * 2
df['y'] = y * 2
df['z'] = z * 2
df['x_mon'] = x_mon
df['y_mon'] = y_mon
df['time'] = time_data

df['obs'] = obs.ravel()

# Remove last 10 years to see extrapolation
df_in = df[:int(10 * len(df)/20)]
df_out = df[int(10 * len(df)/20):]

df_shuffled = df_in.sample(frac=1, random_state=seed)
split_idx = round(len(df_shuffled) * 0.85)
df_train = df_shuffled[:split_idx]
df_test = df_shuffled[split_idx:]

# In sample training
X_train = df_train.drop(['obs'],axis=1).values
y_train = df_train['obs'].values.reshape(-1,1)

# The in sample testing - this is not used for training
X_test = df_test.drop(['obs'],axis=1).values
y_test = df_test['obs'].values.reshape(-1,1)

# For out of sample extraploation
X_out = df_out.drop(['obs'],axis=1).values
y_out = df_out['obs'].values.reshape(-1,1)

In [4]:
df_in

Unnamed: 0,mdl1,mdl2,mdl3,mdl4,x,y,z,x_mon,y_mon,time,obs
0,0.308125,0.593262,0.855573,-0.081189,1.705737,-0.300767,1.000000,2.449294e-16,1.0,-1.000000,-0.152023
1,0.141808,-0.525250,0.467150,-0.084253,1.708297,-0.285871,1.000000,2.449294e-16,1.0,-1.000000,-0.125962
2,-0.003159,-0.219612,-0.537756,-0.078845,1.710726,-0.270952,1.000000,2.449294e-16,1.0,-1.000000,-0.135723
3,0.009407,0.625610,0.825712,-0.083226,1.713026,-0.256013,1.000000,2.449294e-16,1.0,-1.000000,-0.149991
4,-0.235543,0.764452,0.358199,-0.092031,1.715195,-0.241055,1.000000,2.449294e-16,1.0,-1.000000,-0.118857
...,...,...,...,...,...,...,...,...,...,...,...
143995,0.702084,-0.642518,-0.155602,-0.697434,1.238779,0.390586,1.520812,-8.660254e-01,0.5,0.966667,0.708805
143996,0.695548,0.261793,0.124474,-0.513974,1.235324,0.401381,1.520812,-8.660254e-01,0.5,0.966667,0.724898
143997,0.708643,0.068291,0.408332,-0.604038,1.231774,0.412146,1.520812,-8.660254e-01,0.5,0.966667,0.735268
143998,0.704098,-0.078892,-0.080094,-0.873401,1.228130,0.422879,1.520812,-8.660254e-01,0.5,0.966667,0.740708


In [6]:
X_train.shape

torch.Size([122400, 10])

In [9]:
base_dir = "C:/Users/Artgur/Desktop/Uni/MA/BNN/Models/ToyPyModels_BNN_v4/"


num_models = 4
bias_std = 0.01
noise_mean = 0.02
noise_std = 0.004
n_ensembles = 10
hidden_size = 100  # Tune this as needed

# Hyperparameters
n = X_train.shape[0]
x_dim = X_train.shape[1]
alpha_dim = x_dim - num_models
y_dim = y_train.shape[1]
learning_rate = 0.00005
n_epochs = 100  # Adjust as needed
batch_size = 2000

# Initialize standard deviations for weights and biases
init_stddev_1_w = np.sqrt(3 / (x_dim - num_models))
init_stddev_1_b = init_stddev_1_w
init_stddev_2_w = 1.3 / np.sqrt(hidden_size)
init_stddev_2_b = init_stddev_2_w
init_stddev_3_w = (bias_std * 1.3) / np.sqrt(hidden_size)
init_stddev_noise_w = (noise_std * 1.3) / np.sqrt(hidden_size)

lambda_anchor = 1.0 / (np.array([init_stddev_1_w, init_stddev_1_b, init_stddev_2_w, init_stddev_2_b, init_stddev_3_w, init_stddev_noise_w])**2)


X_train = torch.tensor(X_train, dtype=torch.float32).squeeze()
y_train = torch.tensor(y_train, dtype=torch.float32).squeeze()
X_test = torch.tensor(X_test, dtype=torch.float32).squeeze()
X_out  = torch.tensor(X_out, dtype=torch.float32).squeeze()
y_test = torch.tensor(y_test, dtype=torch.float32).squeeze()
y_out= torch.tensor(y_out, dtype=torch.float32).squeeze()

In [11]:
X_train[:, num_models: num_models + alpha_dim]

tensor([[ 1.4263, -0.0249,  1.4018, -0.8660,  0.5000,  0.7667],
        [ 1.7143,  0.0000,  1.0301,  0.8660, -0.5000,  0.6667],
        [ 1.3841,  0.3451,  1.4018,  0.8660,  0.5000,  0.0333],
        ...,
        [ 1.6519,  0.2913,  1.0893, -0.8660, -0.5000,  0.1333],
        [ 1.6852, -0.1920,  1.0598, -0.8660, -0.5000,  0.1333],
        [ 1.7102,  0.1196,  1.0301,  0.8660,  0.5000,  0.2333]])

In [13]:
X_train

tensor([[-0.4844, -0.5857, -0.9773,  ..., -0.8660,  0.5000,  0.7667],
        [ 0.6078,  0.2427, -0.8042,  ...,  0.8660, -0.5000,  0.6667],
        [-0.3018, -0.8175, -0.5805,  ...,  0.8660,  0.5000,  0.0333],
        ...,
        [-0.9893,  0.7060,  0.1687,  ..., -0.8660, -0.5000,  0.1333],
        [-0.5335, -0.2287,  0.0188,  ..., -0.8660, -0.5000,  0.1333],
        [-0.7548, -0.5539,  0.1439,  ...,  0.8660,  0.5000,  0.2333]])

In [None]:
class NN(nn.Module):
    def __init__(self, x_dim, y_dim, hidden_size, learning_rate, lambda_anchor, init_stddev_1_w, init_stddev_1_b, init_stddev_2_w, init_stddev_2_b, init_stddev_3_w, init_stddev_noise_w):
        super(NN, self).__init__()
        self.spacetime_dim = alpha_dim
        self.modelpred_dim = num_models
        self.y_dim = y_dim
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.lambda_anchor = torch.tensor(lambda_anchor, dtype=torch.float32)
        


        self.layer_1 = bnn.BayesLinear(prior_mu=0.0, prior_sigma=init_stddev_1_w, in_features=self.spacetime_dim, out_features=self.hidden_size, bias=True)
        # ?? Bayesian batch normalization
        self.layer_2 = bnn.BayesLinear(prior_mu=0.0, prior_sigma=init_stddev_2_w, in_features=self.hidden_size, out_features=self.modelpred_dim, bias=True)
        self.modelbias = bnn.BayesLinear(prior_mu=0.0, prior_sigma=init_stddev_3_w, in_features=self.hidden_size, out_features=self.y_dim, bias=False)
        self.noise_pred_layer = bnn.BayesLinear(prior_mu=0.0, prior_sigma=init_stddev_noise_w, in_features=self.hidden_size, out_features=self.y_dim, bias=False)


        # Activation Function
        self.tanh = nn.Tanh()
        
        # Optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        # Save initial weights for anchoring
        self.initial_weights = [param.data.clone() for param in self.parameters()]




    def forward(self, x):
        modelpred = x[:, :self.modelpred_dim]
        spacetime = x[:, self.modelpred_dim:]
        #x = self.tanh(self.layer_1_bn(self.layer_1(spacetime)))
        x = self.tanh(self.layer_1(spacetime))

        modelpred = 

        model_coeff = torch.softmax(self.layer_2(x), dim=1)
        modelbias = self.modelbias(x).squeeze(-1)
        output = torch.sum(model_coeff * modelpred, dim=1) + modelbias
        noise_pred = self.noise_pred_layer(x).squeeze(-1)
        
        return output, noise_pred, model_coeff, modelbias


    
    def calculate_loss(self, x, y_target):
        output, noise_pred, _, _ = self.forward(x)
        
        noise_sq = torch.square(noise_pred + noise_mean) + 1e-6
        err_sq = torch.square(y_target.squeeze(-1) - output)
        mse_ = torch.sum(err_sq)/ x.shape[0]
        loss_ = (torch.sum(err_sq / noise_sq) + torch.sum(torch.log(noise_sq)))/ x.shape[0]

        # Anchoring Loss
        loss_anchor = 0
        for lambda_val, param, initial_param in zip(self.lambda_anchor, self.parameters(), self.initial_weights):
            loss_anchor += lambda_val * torch.sum((param - initial_param)**2) / x.shape[0]


        
        total_loss =  loss_ + loss_anchor

        return mse_, loss_, loss_anchor, noise_sq, err_sq, total_loss


    def train_step(self, x, y_target):
        self.train()
        self.optimizer.zero_grad()
        mse_loss, noise_loss, loss_anchor, noise_sq, err_sq, total_loss  = self.calculate_loss(x, y_target)
        #total_loss = noise_loss + loss_anchor
        total_loss.backward()
        self.optimizer.step()
        return total_loss.item(), mse_loss.item(), noise_loss.item(), loss_anchor.item() ,noise_sq.mean().item(), err_sq.mean().item()
     
    def get_noise_sq(self, x):
        self.eval()
        with torch.no_grad():
            _, noise_pred, _, _ = self.forward(x)
            noise_sq = torch.square(noise_pred + noise_mean) + 1e-6
        return noise_sq

    def get_alphas(self, x):
        self.eval()
        with torch.no_grad():
            _, _, model_coeff, _ = self.forward(x)
        return model_coeff

    def get_betas(self, x):
        self.eval()
        with torch.no_grad():
            _, _, _, modelbias = self.forward(x)
        return modelbias

    def get_alpha_w(self, x):
        self.eval()
        with torch.no_grad():
            x = self.tanh(self.layer_1(x[:, self.modelpred_dim:]))
            alpha_w = self.layer_2(x)
        return alpha_w

    def get_w1(self, x):
        self.eval()
        with torch.no_grad():
            x = self.tanh(self.layer_1(x[:, self.modelpred_dim:]))
        return x

def predict_ensemble(models, X):
    with torch.no_grad():
        all_preds = torch.stack([model(X)[0] for model in models])
        all_noise_sq = torch.stack([model.get_noise_sq(X) for model in models])
        all_alphas = torch.stack([model.get_alphas(X) for model in models])
        all_betas = torch.stack([model.get_betas(X) for model in models])
        all_alpha_w = torch.stack([model.get_alpha_w(X) for model in models])
        all_w1 = torch.stack([model.get_w1(X) for model in models])

    preds_mu = torch.mean(all_preds, dim=0)
    preds_std = torch.std(all_preds, dim=0)

    return all_preds, preds_mu, preds_std, all_noise_sq, all_alphas, all_betas, all_alpha_w, all_w1


def train_ensemble(X_train, y_train, n_ensembles, n_epochs, batch_size):
    models = []
    num_batches = X_train.shape[0] // batch_size

    for ensemble_index in range(n_ensembles):
        model = NN(x_dim, y_dim, hidden_size, learning_rate, lambda_anchor, init_stddev_1_w, init_stddev_1_b, init_stddev_2_w, init_stddev_2_b, init_stddev_3_w, init_stddev_noise_w)

        losses = {
            'total_losses': [],
            'mse_losses': [],
            'noise_losses': [],
            'anchor_losses': [],
        }

        for epoch in range(n_epochs):
            # Shuffle data before each epoch
            permutation = torch.randperm(X_train.shape[0])
            X_train_shuffled = X_train[permutation]
            y_train_shuffled = y_train[permutation]

            epoch_losses = {
                'total_loss': 0,
                'mse_loss': 0,
                'noise_loss': 0,
                'anchor_loss': 0,
                'count': 0
            }

            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = start_idx + batch_size
                batch_x = X_train_shuffled[start_idx:end_idx]
                batch_y = y_train_shuffled[start_idx:end_idx]

                loss, mse_loss, noise_loss, anchor_loss, noise_sq, err_sq = model.train_step(batch_x, batch_y)

                epoch_losses['total_loss'] += loss
                epoch_losses['mse_loss'] += mse_loss
                epoch_losses['noise_loss'] += noise_loss
                epoch_losses['anchor_loss'] += anchor_loss
                epoch_losses['count'] += 1

            for key in epoch_losses:
                if key != 'count':
                    corrected_key = key + 'es' if key.endswith('loss') else key + 's'
                    losses[corrected_key].append(epoch_losses[key] / epoch_losses['count'])

            if epoch % 10 == 0:
                print(f"Ensemble {ensemble_index + 1}, Epoch {epoch}, Total Loss: {losses['total_losses'][-1]}, MSE Loss: {np.round(np.sqrt(losses['mse_losses'][-1]),5)}, Noise Loss: {losses['noise_losses'][-1]}, Anchoring Loss: {losses['anchor_losses'][-1]}")
                print(f"noise_sq {noise_sq}, err_sq {err_sq}")

        save_path = f"{base_dir}model_{ensemble_index}_checkpoint.pth"
        save_model(model, model.optimizer, losses, save_path)

        models.append(model)

    return models


def save_model(model, optimizer, losses, save_path="model_checkpoint.pth"):
    save_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,  
    }
    torch.save(save_dict, save_path)
    print(f"Model, optimizer, and losses saved to {save_path}")

    
def recube(in_tensor, lat_len, lon_len, time_len):
    """
    Reshape a flat tensor to a 3D tensor based on dimensions of time, latitude, and longitude.
    """
    output = in_tensor.reshape(time_len, lat_len, lon_len)
    return output

def report_on_percentiles_tensor(y, y_pred, y_std):
    """
    Report the percentage of data points within 1, 2, and 3 standard deviations.
    """
    y, y_pred, y_std = y.flatten(), y_pred.flatten(), y_std.flatten()
    diffs = torch.abs(y_pred - y)
    within_1_std = torch.sum(diffs <= y_std * 1).item() / y.shape[0]
    within_2_std = torch.sum(diffs <= y_std * 2).item() / y.shape[0]
    within_3_std = torch.sum(diffs <= y_std * 3).item() / y.shape[0]

    print(f'Using {y.shape[0]} data points')
    print(f'{within_1_std * 100:.2f}% within 1 std')
    print(f'{within_2_std * 100:.2f}% within 2 std')
    print(f'{within_3_std * 100:.2f}% within 3 std')

def report_on_percentiles(y, y_pred, y_std):

    n = len(y.ravel())

    n1 = np.sum(np.abs(y_pred.ravel() - y.ravel()) <= y_std.ravel() * 1)
    n2 = np.sum(np.abs(y_pred.ravel() - y.ravel()) <= y_std.ravel() * 2)
    n3 = np.sum(np.abs(y_pred.ravel() - y.ravel()) <= y_std.ravel() * 3)
    print('Using {} data points'.format(n))
    print('{} within 1 std'.format(100 * n1 / n))
    print('{} within 2 std'.format(100 * n2 / n))
    print('{} within 3 std'.format(100 * n3 / n))

    return

def get_betas(models, X):
    with torch.no_grad():  
        betas = torch.stack([model(X)[3] for model in models]) 
    return betas


def get_alphas(models, X):
    with torch.no_grad():
        coeffs = [model(X)[2] for model in models]  
    coeffs_tensor = torch.stack(coeffs)
    return coeffs_tensor

def calculate_rmse(true, pred_mu):
    rmse = torch.sqrt(torch.mean(torch.square(pred_mu - true)))
    return rmse

def calculate_nll(true, pred_mu, pred_std):
    nll = 0.5 * (((pred_mu - true) ** 2) / (pred_std ** 2)) + torch.log(pred_std ** 2) + torch.log(2 * torch.tensor(np.pi))
    return torch.mean(nll)

