In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class VariationalLinear(nn.Module):
    """A variational linear layer with mean field approximation"""
    def __init__(self, input_features, output_features, prior_mean=0, prior_var=1):
        super(VariationalLinear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        
        # Parameters for the means of the weights and biases
        self.weight_mu = Parameter(torch.Tensor(output_features, input_features))
        self.bias_mu = Parameter(torch.Tensor(output_features))
        
        # Parameters for the log variance of the weights and biases
        self.weight_logvar = Parameter(torch.Tensor(output_features, input_features))
        self.bias_logvar = Parameter(torch.Tensor(output_features))
        
        # Prior distributions for the weights and biases
        self.prior_mean = prior_mean
        self.prior_var = prior_var
        
        self.reset_parameters()
    
    def reset_parameters(self):
        self.weight_mu.data.normal_(0, 0.1)
        self.bias_mu.data.zero_()
        self.weight_logvar.data.fill_(-10)
        self.bias_logvar.data.fill_(-10)
    
    def forward(self, x, sample=True):
        if sample:
            # Sample weights and biases from their distributions
            weight = self.weight_mu + torch.randn_like(self.weight_mu) * torch.exp(0.5 * self.weight_logvar)
            bias = self.bias_mu + torch.randn_like(self.bias_mu) * torch.exp(0.5 * self.bias_logvar)
        else:
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)
    
    def kl_divergence(self):
        # KL divergence between the posterior and the prior distribution
        weight_var = torch.exp(self.weight_logvar)
        bias_var = torch.exp(self.bias_logvar)
        
        kl_weight = 0.5 * torch.sum(weight_var / self.prior_var + (self.weight_mu - self.prior_mean)**2 / self.prior_var - 1. - self.weight_logvar + torch.log(self.prior_var))
        kl_bias = 0.5 * torch.sum(bias_var / self.prior_var + (self.bias_mu - self.prior_mean)**2 / self.prior_var - 1. - self.bias_logvar + torch.log(self.prior_var))
        return kl_weight + kl_bias

class MFVI_NN(nn.Module):
    """A complex Bayesian Neural Network with flexible architecture"""
    def __init__(self, input_size, hidden_sizes, output_size, no_train_samples=10, no_pred_samples=100, prior_mean=0, prior_var=1):
        super(MFVI_NN, self).__init__()
        self.layers = nn.ModuleList()
        self.no_train_samples = no_train_samples
        self.no_pred_samples = no_pred_samples
        
        # Creating variational layers
        all_sizes = [input_size] + hidden_sizes + [output_size]
        for i in range(len(all_sizes) - 1):
            self.layers.append(VariationalLinear(all_sizes[i], all_sizes[i+1], prior_mean, prior_var))
    
    def forward(self, x, sample=True):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x, sample))
        x = self.layers[-1](x, sample)
        return x
    
    def kl_divergence(self):
        kl = 0
        for layer in self.layers:
            kl += layer.kl_divergence()
        return kl

# Assuming the model setup:
input_size = 28 * 28  # Example for MNIST
hidden_sizes = [512, 256]
output_size = 10  # Example for MNIST classification

model = MFVI_NN(input_size, hidden_sizes, output_size)

# Define a dummy dataset
x = torch.randn(64, input_features)
y = torch.randn(64, output_features)

# Loss function and optimizer
optimizer = torch.optim.Adam(bnn.parameters(), lr=0.01)

def train():
    bnn.train()
    optimizer.zero_grad()
    output = bnn(x)
    reconstruction_loss = F.mse_loss(output, y)
    kl_divergence = bnn.kl_divergence()
    # The final loss combines the reconstruction loss and the KL divergence term
    loss = reconstruction_loss + kl_divergence
    loss.backward()
    optimizer.step()
    return loss.item()

epochs = 100
for epoch in range(epochs):
    loss = train()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss}")


Epoch 0, Loss: 245.03265380859375
Epoch 10, Loss: 241.9197540283203
Epoch 20, Loss: 238.85108947753906
Epoch 30, Loss: 235.80706787109375
Epoch 40, Loss: 232.75558471679688
Epoch 50, Loss: 229.70718383789062
Epoch 60, Loss: 226.655517578125
Epoch 70, Loss: 223.60711669921875
Epoch 80, Loss: 220.5594482421875
Epoch 90, Loss: 217.5092315673828


In [33]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Example dataset (replace these with real data)
X_train = torch.randn(100, 10)  # 100 samples, 10 features each
Y_train = torch.randint(0, 2, (100,))  # Binary targets for this example

# Hyperparameters
learning_rate = 0.001
epochs = 20
batch_size = 32
no_train_samples = 10
no_pred_samples = 100

# Prior hyperparameters
prior_mean = 0.0
prior_var = 1.0

# Create a dataset and dataloader
dataset = TensorDataset(X_train, Y_train)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model initialization
model = MFVI_NN(input_size=10, hidden_sizes=[50, 50], output_size=2,
                no_train_samples=no_train_samples, no_pred_samples=no_pred_samples,
                prior_mean=prior_mean, prior_var=prior_var)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Loss function
criterion = torch.nn.CrossEntropyLoss()

# Training loop
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for X_batch, Y_batch in dataloader:
        optimizer.zero_grad()
        
        # Forward pass with a sample from the variational distribution
        outputs = model(X_batch, samples=no_train_samples)
        
        # Predictive loss
        loss = criterion(outputs, Y_batch)
        
        # KL divergence (regularization term)
        kl_div = model.kl_divergence()
        
        # Total loss
        total_loss = loss + kl_div / X_train.size(0)  # Normalize KL divergence by the dataset size
        
        # Backpropagation and optimization
        total_loss.backward()
        optimizer.step()
        
        total_loss += total_loss.item()
    
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader):.4f}')


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [32, 10] but got: [50, 10].