<a href="https://colab.research.google.com/github/khanmhmdi/Moe-llm-edge-computing/blob/main/BNN_Gate_MOE_Simple_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# -------------------------
# Bayesian Linear Layer
# -------------------------
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mu=0.0, prior_sigma=1.0):
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Learnable variational parameters for weights and biases.
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.Tensor(out_features))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features))

        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / self.in_features ** 0.5
        self.weight_mu.data.uniform_(-stdv, stdv)
        # Initialize rho to negative values to start with small sigma (via softplus)
        self.weight_rho.data.fill_(-3)
        self.bias_mu.data.uniform_(-stdv, stdv)
        self.bias_rho.data.fill_(-3)

    def forward(self, input):
        # Transform rho to sigma using softplus
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        # Sample epsilon from standard normal distribution
        weight_eps = torch.randn_like(self.weight_mu)
        bias_eps = torch.randn_like(self.bias_mu)
        # Reparameterization trick: sample = mu + sigma * epsilon
        weight = self.weight_mu + weight_sigma * weight_eps
        bias = self.bias_mu + bias_sigma * bias_eps

        # Compute output and KL divergence for this layer.
        output = F.linear(input, weight, bias)
        self.kl = self.kl_divergence(self.weight_mu, weight_sigma) + self.kl_divergence(self.bias_mu, bias_sigma)
        return output

    def kl_divergence(self, mu, sigma):
        """
        Closed-form KL divergence between q(w) = N(mu, sigma^2)
        and p(w) = N(prior_mu, prior_sigma^2):
            KL = log(prior_sigma / sigma) + (sigma^2 + (mu - prior_mu)^2) / (2 * prior_sigma^2) - 1/2
        """
        prior_sigma = self.prior_sigma
        prior_mu = self.prior_mu
        kl = torch.log(prior_sigma / sigma) + (sigma.pow(2) + (mu - prior_mu).pow(2)) / (2 * prior_sigma**2) - 0.5
        return kl.sum()


# -------------------------
# Bayesian Gate Network
# -------------------------
class BayesianGate(nn.Module):
    def __init__(self, input_dim, num_experts, hidden_dim=128):
        super(BayesianGate, self).__init__()
        # Two BayesianLinear layers with a non-linearity.
        self.fc1 = BayesianLinear(input_dim, hidden_dim)
        self.fc2 = BayesianLinear(hidden_dim, num_experts)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        # Apply softmax to get a probability distribution over experts.
        gate_weights = F.softmax(logits, dim=-1)
        # Sum the KL divergence from both Bayesian layers.
        kl = self.fc1.kl + self.fc2.kl
        return gate_weights, kl


# -------------------------
# Deterministic Expert Network
# -------------------------
class Expert(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(Expert, self).__init__()
        # A simple two-layer MLP for classification.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# -------------------------
# Mixture of Experts (MoE) Model
# -------------------------
class MixtureOfExperts(nn.Module):
    def __init__(self, input_dim, num_experts, output_dim,
                 expert_hidden_dim=256, gate_hidden_dim=128):
        super(MixtureOfExperts, self).__init__()
        self.num_experts = num_experts
        # Create a list (ModuleList) of deterministic expert networks.
        self.experts = nn.ModuleList([
            Expert(input_dim, output_dim, hidden_dim=expert_hidden_dim)
            for _ in range(num_experts)
        ])
        # Create the Bayesian gating network.
        self.gate = BayesianGate(input_dim, num_experts, hidden_dim=gate_hidden_dim)

    def forward(self, x):
        # Compute gate weights and its KL divergence.
        gate_weights, gate_kl = self.gate(x)  # gate_weights: (batch_size, num_experts)

        # Get output from each expert.
        expert_outputs = [expert(x) for expert in self.experts]  # Each: (batch_size, output_dim)
        # Stack expert outputs into shape: (batch_size, num_experts, output_dim)
        expert_outputs = torch.stack(expert_outputs, dim=1)

        # Reshape gate weights to match the expert outputs.
        gate_weights = gate_weights.unsqueeze(-1)  # (batch_size, num_experts, 1)
        # Weighted sum over experts.
        output = torch.sum(gate_weights * expert_outputs, dim=1)
        return output, gate_kl


# -------------------------
# Training and Testing Functions
# -------------------------
def train(model, device, train_loader, optimizer, epoch, kl_weight):
    model.train()
    total_loss = 0
    total_correct = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten the MNIST images into vectors of size 28*28 = 784.
        data = data.view(data.size(0), -1).to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output, gate_kl = model(data)
        # Compute cross-entropy loss on predictions.
        loss_pred = F.cross_entropy(output, target)
        # Total loss includes the KL divergence weighted by kl_weight.
        loss = loss_pred + kl_weight * gate_kl
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1)
        total_correct += pred.eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f"Train Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] "
                  f"Loss: {loss.item():.4f} (Pred Loss: {loss_pred.item():.4f}, KL: {gate_kl.item():.4f})")

    avg_loss = total_loss / len(train_loader.dataset)
    accuracy = total_correct / len(train_loader.dataset)
    print(f"Train Epoch {epoch} Average Loss: {avg_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")


def test(model, device, test_loader, kl_weight):
    model.eval()
    total_loss = 0
    total_correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(data.size(0), -1).to(device)
            target = target.to(device)
            output, gate_kl = model(data)
            loss_pred = F.cross_entropy(output, target)
            loss = loss_pred + kl_weight * gate_kl
            total_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            total_correct += pred.eq(target).sum().item()

    avg_loss = total_loss / len(test_loader.dataset)
    accuracy = total_correct / len(test_loader.dataset)
    print(f"\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {accuracy * 100:.2f}%\n")


# -------------------------
# Main Function
# -------------------------

# Hyperparameters:
batch_size = 64
test_batch_size = 1000
epochs = 10         # You may increase to get better performance.
learning_rate = 1e-3
kl_weight = 0.001   # Adjust weight for the KL divergence regularization.

# Define device: use CUDA if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare MNIST dataset and DataLoader with standard normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Normalize MNIST dataset with mean and std. You can also flatten later.
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

# Input dimension is 28*28 = 784; number of classes is 10.
input_dim = 28 * 28
output_dim = 10
num_experts = 2  # You can experiment with more experts.

# Initialize the Mixture of Experts model.
model = MixtureOfExperts(input_dim, num_experts, output_dim,
                          expert_hidden_dim=256, gate_hidden_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop.
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, kl_weight)
    test(model, device, test_loader, kl_weight)





Train Epoch 1 [0/60000] Loss: 256.7855 (Pred Loss: 2.3425, KL: 254442.9375)
Train Epoch 1 [6400/60000] Loss: 244.9105 (Pred Loss: 0.2305, KL: 244680.0312)
Train Epoch 1 [12800/60000] Loss: 235.1261 (Pred Loss: 0.1693, KL: 234956.7656)
Train Epoch 1 [19200/60000] Loss: 225.5084 (Pred Loss: 0.2246, KL: 225283.8438)
Train Epoch 1 [25600/60000] Loss: 215.9463 (Pred Loss: 0.2817, KL: 215664.5625)
Train Epoch 1 [32000/60000] Loss: 206.2173 (Pred Loss: 0.1092, KL: 206108.0938)
Train Epoch 1 [38400/60000] Loss: 196.8170 (Pred Loss: 0.1949, KL: 196622.0938)
Train Epoch 1 [44800/60000] Loss: 187.3491 (Pred Loss: 0.1336, KL: 187215.5000)
Train Epoch 1 [51200/60000] Loss: 178.1880 (Pred Loss: 0.2917, KL: 177896.3438)
Train Epoch 1 [57600/60000] Loss: 168.7834 (Pred Loss: 0.1087, KL: 168674.6562)
Train Epoch 1 Average Loss: 209.6124, Accuracy: 93.14%

Test set: Average loss: 165.3168, Accuracy: 96.30%

Train Epoch 2 [0/60000] Loss: 165.2966 (Pred Loss: 0.1009, KL: 165195.7344)
Train Epoch 2 [6400/6

In [9]:
def predict_with_uncertainty(model, data, num_samples=30):
    """
    Runs multiple forward passes on the data using the Bayesian gate's stochasticity
    and returns the mean predictions and uncertainty (e.g., standard deviation).

    Args:
        model (nn.Module): The trained MoE model.
        data (Tensor): Input tensor (e.g., batch of flattened MNIST images).
        num_samples (int): Number of Monte Carlo samples.

    Returns:
        avg_output (Tensor): Averaged predictions across samples.
        uncertainty (Tensor): Standard deviation across predictions.
    """
    model.eval()  # Ensure model is in evaluation mode
    preds = []

    with torch.no_grad():
        for _ in range(num_samples):
            # Each forward pass samples new weights for the Bayesian layers.
            output, _ = model(data)
            preds.append(output)

    # Stack the predictions to shape (num_samples, batch_size, num_classes)
    preds = torch.stack(preds)
    # Average predictions along the sample axis.
    avg_output = preds.mean(dim=0)
    # Compute the uncertainty (e.g., standard deviation) across the samples.
    uncertainty = preds.std(dim=0)
    return avg_output, uncertainty

# Suppose you have a batch of MNIST test data:
data, target = next(iter(test_loader))
data = data.view(data.size(0), -1).to(device)

# Get the predictions and uncertainty using Monte Carlo sampling.
mean_predictions, prediction_uncertainty = predict_with_uncertainty(model, data, num_samples=50)

# You can now use mean_predictions for decision-making and prediction_uncertainty
# for risk evaluation or further analysis.
print("Mean Predictions:", mean_predictions)
print("Prediction Uncertainty:", prediction_uncertainty)


Mean Predictions: tensor([[ -9.5471, -13.9054,  -7.6679,  ...,  13.6950, -10.8544,  -5.2255],
        [-12.1189,  -2.1980,  18.3954,  ..., -23.8359,  -6.2128, -32.9885],
        [-17.3906,   7.7076,  -0.2277,  ...,  -4.0017,  -6.7653, -18.7809],
        ...,
        [ 29.8338, -31.1928,  -0.6697,  ...,  -6.5207, -30.6078, -13.4928],
        [-10.5167, -18.5244,   0.5654,  ...,  -7.8078,  10.4129, -10.8197],
        [ -8.3470, -15.1543,  -8.8521,  ...,   6.5622, -12.0683,   9.1366]])
Prediction Uncertainty: tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.3681, 1.7487, 0.3635,  ..., 0.4994, 0.4111, 1.3095],
        [0.4162, 1.6816, 0.8238,  ..., 0.6839, 1.7301, 0.9185]])


In [13]:
(prediction_uncertainty[2])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])