<a href="https://colab.research.google.com/github/khanmhmdi/Moe-llm-edge-computing/blob/main/BEYSIAN_CIFAR_FROM_SCRATCH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Normal, kl
import numpy as np

# ----------------------------------------------------------------------
# 1. Define Hyperparameters
# ----------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1024
learning_rate = 0.001
num_epochs = 100
num_samples = 10 # Number of samples from the posterior for ELBO estimation
kl_weight = 0.01 # Weight for the KL divergence term
# ----------------------------------------------------------------------
# 2. Load and Preprocess CIFAR-10 Data
# ----------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ----------------------------------------------------------------------
# 3. Define the Bayesian Linear Layer
# ----------------------------------------------------------------------
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mu=0, prior_sigma=1):
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Prior distribution parameters
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.prior = Normal(torch.tensor([self.prior_mu]).to(device), torch.tensor([self.prior_sigma]).to(device)) # Create a Normal distribution object

        # Variational posterior parameters (learnable)
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4)) # Initialize rho to get a small initial std
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-5, -4))

    def forward(self, x):
        # Sample weights from the variational posterior
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        weight = weight_dist.rsample() # Reparameterization trick

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        bias = bias_dist.rsample()  # Reparameterization trick

        self.weight_log_prob = weight_dist.log_prob(weight).sum()
        self.bias_log_prob = bias_dist.log_prob(bias).sum()

        self.weight_prior_log_prob = self.prior.log_prob(weight).sum()
        self.bias_prior_log_prob = self.prior.log_prob(bias).sum()

        return nn.functional.linear(x, weight, bias)

    def kl_loss(self):
        """Calculates the KL divergence between the variational posterior and the prior."""
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)

        kl_weight = kl.kl_divergence(weight_dist, self.prior).sum()
        kl_bias = kl.kl_divergence(bias_dist, self.prior).sum()
        return kl_weight + kl_bias


# ----------------------------------------------------------------------
# 4. Define the BNN Model
# ----------------------------------------------------------------------
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.flatten = nn.Flatten()

        # Use BayesianLinear layers
        self.fc1 = BayesianLinear(32 * 8 * 8, 128)  # Adjusted input size after pooling
        self.fc2 = BayesianLinear(128, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = self.dropout(x)
        x = self.flatten(x)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def kl_loss(self):
        """Calculates the total KL divergence for the network."""
        kl_loss = self.fc1.kl_loss() + self.fc2.kl_loss()
        return kl_loss

# ----------------------------------------------------------------------
# 5. Training Loop (MODIFIED)
# ----------------------------------------------------------------------
def train(model, optimizer, trainloader, num_epochs, kl_weight):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_ce_loss = 0.0  # Running Cross-Entropy Loss
        running_kl_loss = 0.0  # Running KL Divergence Loss
        correct = 0
        total = 0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            # Forward pass and calculate losses
            outputs = model(inputs)
            log_likelihood = nn.CrossEntropyLoss()(outputs, labels) # Cross-Entropy Loss
            kl_loss_val = model.kl_loss() # KL Divergence Loss
            elbo = log_likelihood + kl_weight * kl_loss_val # Minimize the negative ELBO

            elbo.backward()
            optimizer.step()

            running_loss += elbo.item()
            running_ce_loss += log_likelihood.item()
            running_kl_loss += kl_loss_val.item()

            _, predicted = torch.max(outputs.data, 1) # Calculate training accuracy
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # if i % 40 == 0:    # print every 200 mini-batches
            #     avg_elbo_loss = running_loss / 200
            #     avg_ce_loss = running_ce_loss / 200
            #     avg_kl_loss = running_kl_loss / 200
            #     accuracy = 100 * correct / total
            #     print(f'[{epoch + 1}, {i + 1:5d}] ELBO Loss: {avg_elbo_loss:.3f} CE Loss: {avg_ce_loss:.3f} KL Loss: {avg_kl_loss:.3f} Train Accuracy: {accuracy:.2f}%')
            #     running_loss = 0.0
            #     running_ce_loss = 0.0
            #     running_kl_loss = 0.0
            #     correct = 0
            #     total = 0


        # Print epoch-level s/ummary (optional, after each epoch)
        avg_epoch_elbo_loss = running_loss / len(trainloader)
        avg_epoch_ce_loss = running_ce_loss / len(trainloader)
        avg_epoch_kl_loss = running_kl_loss / len(trainloader)
        epoch_accuracy = 100 * correct / len(trainloader.dataset)
        print(f'Epoch [{epoch + 1}] Summary - ELBO Loss: {avg_epoch_elbo_loss:.3f} CE Loss: {avg_epoch_ce_loss:.3f} KL Loss: {avg_epoch_kl_loss:.3f} Train Accuracy: {epoch_accuracy:.2f}%')


    print('Finished Training')
# ----------------------------------------------------------------------
# 6. Evaluation Function
# ----------------------------------------------------------------------

def evaluate(model, testloader, num_samples):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            # Sample multiple times from the posterior
            predictions = torch.zeros(num_samples, images.size(0), 10).to(device) # (num_samples, batch_size, num_classes)

            for i in range(num_samples):
                outputs = model(images)
                predictions[i] = nn.functional.softmax(outputs, dim=1)

            # Average the predictions
            mean_predictions = torch.mean(predictions, dim=0)
            _, predicted = torch.max(mean_predictions.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')


# ----------------------------------------------------------------------
# 7.  Instantiate Model, Optimizer, and Train
# ----------------------------------------------------------------------

model = BNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train(model, optimizer, trainloader, num_epochs, kl_weight)

# ----------------------------------------------------------------------
# 8. Evaluate the Model
# ----------------------------------------------------------------------
evaluate(model, testloader, num_samples)


Files already downloaded and verified
Files already downloaded and verified
[1,     1] ELBO Loss: 52.926 CE Loss: 0.041 KL Loss: 5288.494 Train Accuracy: 9.28%
[1,    41] ELBO Loss: 2105.028 CE Loss: 0.606 KL Loss: 210442.245 Train Accuracy: 21.22%
[2,     1] ELBO Loss: 52.237 CE Loss: 0.010 KL Loss: 5222.644 Train Accuracy: 27.34%
[2,    41] ELBO Loss: 2078.369 CE Loss: 0.379 KL Loss: 207798.972 Train Accuracy: 31.93%
[3,     1] ELBO Loss: 51.575 CE Loss: 0.009 KL Loss: 5156.583 Train Accuracy: 34.86%
[3,    41] ELBO Loss: 2051.970 CE Loss: 0.338 KL Loss: 205163.152 Train Accuracy: 38.98%
[4,     1] ELBO Loss: 50.917 CE Loss: 0.008 KL Loss: 5090.945 Train Accuracy: 42.68%
[4,    41] ELBO Loss: 2025.761 CE Loss: 0.308 KL Loss: 202545.337 Train Accuracy: 43.85%
[5,     1] ELBO Loss: 50.265 CE Loss: 0.007 KL Loss: 5025.753 Train Accuracy: 47.17%
[5,    41] ELBO Loss: 1999.725 CE Loss: 0.287 KL Loss: 199943.899 Train Accuracy: 47.94%
[6,     1] ELBO Loss: 49.616 CE Loss: 0.007 KL Loss: 49

KeyboardInterrupt: 

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Normal, kl
import numpy as np

# ----------------------------------------------------------------------
# 1. Define Hyperparameters (MODIFIED)
# ----------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1024
learning_rate = 0.001
num_epochs = 100
num_samples = 10 # Number of samples from the posterior for ELBO estimation
kl_weight = 1  # Reduced KL weight to balance the loss terms
num_train = 50000  # Number of training examples in CIFAR-10

# ----------------------------------------------------------------------
# 2. Load and Preprocess CIFAR-10 Data (Unchanged)
# ----------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ----------------------------------------------------------------------
# 3. Define the Bayesian Linear Layer (MODIFIED)
# ----------------------------------------------------------------------
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mu=0, prior_sigma=0.1):  # Smaller prior_sigma
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Prior distribution parameters
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.prior = Normal(torch.tensor([self.prior_mu]).to(device),
                           torch.tensor([self.prior_sigma]).to(device))

        # Variational posterior parameters (learnable) with higher initial rho
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-3, -2))  # Higher rho
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-3, -2))  # Higher rho

    def forward(self, x):
        # Reparameterization
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        weight = weight_dist.rsample()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        bias = bias_dist.rsample()

        self.weight_log_prob = weight_dist.log_prob(weight).sum()
        self.bias_log_prob = bias_dist.log_prob(bias).sum()

        return nn.functional.linear(x, weight, bias)

    def kl_loss(self):
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        kl_weight = kl.kl_divergence(weight_dist, self.prior).sum()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        kl_bias = kl.kl_divergence(bias_dist, self.prior).sum()

        return kl_weight + kl_bias

# ----------------------------------------------------------------------
# 4. Define the BNN Model (Unchanged)
# ----------------------------------------------------------------------
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.fc1 = BayesianLinear(32 * 8 * 8, 128)
        self.fc2 = BayesianLinear(128, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = self.dropout(x)
        x = self.flatten(x)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def kl_loss(self):
        # Scale KL loss by the number of training examples
        return (self.fc1.kl_loss() + self.fc2.kl_loss()) / num_train

# ----------------------------------------------------------------------
# 5. Training Loop (MODIFIED to use correct scaling)
# ----------------------------------------------------------------------
def train(model, optimizer, trainloader, num_epochs, kl_weight):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_ce_loss = 0.0
        running_kl_loss = 0.0
        correct = 0
        total = 0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            log_likelihood = nn.CrossEntropyLoss()(outputs, labels)
            kl_loss_val = model.kl_loss()
            elbo = log_likelihood + kl_weight * kl_loss_val

            elbo.backward()
            optimizer.step()

            running_loss += elbo.item()
            running_ce_loss += log_likelihood.item()
            running_kl_loss += kl_loss_val.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_accuracy = 100 * correct / total
        avg_ce = running_ce_loss / len(trainloader)
        avg_kl = running_kl_loss / len(trainloader)
        avg_elbo = running_loss / len(trainloader)
        print(f'Epoch {epoch+1}, ELBO: {avg_elbo:.3f}, CE: {avg_ce:.3f}, KL: {avg_kl:.3f}, Acc: {epoch_accuracy:.2f}%')
    print('Finished Training')

# ----------------------------------------------------------------------
# 6. Evaluation Function (Unchanged)
# ----------------------------------------------------------------------
def evaluate(model, testloader, num_samples):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            predictions = torch.zeros(num_samples, images.size(0), 10).to(device)
            for i in range(num_samples):
                outputs = model(images)
                predictions[i] = nn.functional.softmax(outputs, dim=1)
            mean_predictions = torch.mean(predictions, dim=0)
            _, predicted = torch.max(mean_predictions.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# ----------------------------------------------------------------------
# 7. Instantiate and Train (Unchanged)
# ----------------------------------------------------------------------
model = BNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train(model, optimizer, trainloader, num_epochs, kl_weight)
evaluate(model, testloader, num_samples)

Files already downloaded and verified
Files already downloaded and verified
Epoch 1, ELBO: 9.860, CE: 5.839, KL: 4.021, Acc: 13.46%
Epoch 2, ELBO: 6.315, CE: 2.364, KL: 3.952, Acc: 13.58%
Epoch 3, ELBO: 6.144, CE: 2.300, KL: 3.844, Acc: 13.97%
Epoch 4, ELBO: 5.991, CE: 2.274, KL: 3.717, Acc: 15.32%
Epoch 5, ELBO: 5.807, CE: 2.228, KL: 3.579, Acc: 16.88%
Epoch 6, ELBO: 5.621, CE: 2.187, KL: 3.434, Acc: 19.28%
Epoch 7, ELBO: 5.437, CE: 2.151, KL: 3.286, Acc: 20.26%
Epoch 8, ELBO: 5.230, CE: 2.093, KL: 3.137, Acc: 22.62%
Epoch 9, ELBO: 5.044, CE: 2.056, KL: 2.988, Acc: 24.05%
Epoch 10, ELBO: 4.863, CE: 2.021, KL: 2.842, Acc: 25.39%
Epoch 11, ELBO: 4.689, CE: 1.991, KL: 2.698, Acc: 26.73%
Epoch 12, ELBO: 4.502, CE: 1.943, KL: 2.559, Acc: 28.71%
Epoch 13, ELBO: 4.358, CE: 1.934, KL: 2.424, Acc: 29.51%
Epoch 14, ELBO: 4.185, CE: 1.890, KL: 2.295, Acc: 30.52%
Epoch 15, ELBO: 4.029, CE: 1.859, KL: 2.170, Acc: 32.19%
Epoch 16, ELBO: 3.874, CE: 1.824, KL: 2.051, Acc: 33.39%
Epoch 17, ELBO: 3.735

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Normal, kl
import numpy as np

# ----------------------------------------------------------------------
# 1. Define Hyperparameters (MODIFIED for MoE)
# ----------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1024
learning_rate = 0.001
num_epochs = 100
num_samples = 10  # Number of samples from the posterior for ELBO estimation
kl_weight = 1  # Reduced KL weight to balance the loss terms
num_train = 50000  # Number of training examples in CIFAR-10
num_experts = 3   # Number of experts in MoE

# ----------------------------------------------------------------------
# 2. Load and Preprocess CIFAR-10 Data (Unchanged)
# ----------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ----------------------------------------------------------------------
# 3. Define the Bayesian Linear Layer (Unchanged)
# ----------------------------------------------------------------------
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mu=0, prior_sigma=0.1):  # Smaller prior_sigma
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Prior distribution parameters
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.prior = Normal(torch.tensor([self.prior_mu]).to(device),
                           torch.tensor([self.prior_sigma]).to(device))

        # Variational posterior parameters (learnable) with higher initial rho
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-3, -2))  # Higher rho
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-3, -2))  # Higher rho

    def forward(self, x):
        # Reparameterization
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        weight = weight_dist.rsample()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        bias = bias_dist.rsample()

        self.weight_log_prob = weight_dist.log_prob(weight).sum()
        self.bias_log_prob = bias_dist.log_prob(bias).sum()

        return nn.functional.linear(x, weight, bias)

    def kl_loss(self):
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        kl_weight = kl.kl_divergence(weight_dist, self.prior).sum()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        kl_bias = kl.kl_divergence(bias_dist, self.prior).sum()

        return kl_weight + kl_bias

# ----------------------------------------------------------------------
# 4. Define the Expert Network (Simple CNN Experts)
# ----------------------------------------------------------------------
class CNNExpert(nn.Module):
    def __init__(self):
        super(CNNExpert, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = self.dropout(x)
        x = self.flatten(x)
        return x

# ----------------------------------------------------------------------
# 5. Define the Bayesian Gating Network (BNN as Gate)
# ----------------------------------------------------------------------
class BayesianGatingNetwork(nn.Module):
    def __init__(self, num_experts):
        super(BayesianGatingNetwork, self).__init__()
        self.expert_cnn = CNNExpert() # Shared CNN feature extractor
        self.fc1 = BayesianLinear(32 * 8 * 8, 128)
        self.fc2 = BayesianLinear(128, num_experts) # Output layer for gating, num_experts outputs

    def forward(self, x):
        x = self.expert_cnn(x) # Use shared CNN feature extractor
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return nn.functional.softmax(x, dim=1) # Softmax to get probabilities

    def kl_loss(self):
        return (self.fc1.kl_loss() + self.fc2.kl_loss())/num_train

# ----------------------------------------------------------------------
# 6. Define the MoE BNN Model
# ----------------------------------------------------------------------
class MoE_BNN(nn.Module):
    def __init__(self, num_experts):
        super(MoE_BNN, self).__init__()
        self.gate = BayesianGatingNetwork(num_experts)
        self.experts = nn.ModuleList([nn.Sequential(
            CNNExpert(), # Separate CNN feature extractor for each expert (can be shared or different)
            BayesianLinear(32 * 8 * 8, 128),
            nn.ReLU(),
            BayesianLinear(128, 10)
        ) for _ in range(num_experts)])

    def forward(self, x):
        gate_weights = self.gate(x) # Output: [batch_size, num_experts]
        expert_outputs = [expert(x) for expert in self.experts] # List of [batch_size, 10]

        # Mixture of Experts: Weighted sum of expert outputs
        Blended_output = torch.stack(expert_outputs, dim=2) # [batch_size, 10, num_experts]
        Blended_output = torch.matmul(Blended_output, gate_weights.unsqueeze(2)) # [batch_size, 10, 1]
        Blended_output = Blended_output.squeeze(2) # [batch_size, 10]
        return Blended_output

    def kl_loss(self):
        gate_kl_loss = self.gate.kl_loss()
        experts_kl_loss = torch.sum(torch.stack([expert[1].kl_loss() + expert[3].kl_loss() for expert in self.experts])) # Sum KL loss from all experts Bayesian layers
        return gate_kl_loss + (experts_kl_loss)/3

# ----------------------------------------------------------------------
# 7. Training Loop (MODIFIED for MoE BNN)
# ----------------------------------------------------------------------
def train(model, optimizer, trainloader, num_epochs, kl_weight):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_ce_loss = 0.0
        running_kl_loss = 0.0
        correct = 0
        total = 0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()

            outputs = model(inputs) # Get output from MoE_BNN
            log_likelihood = nn.CrossEntropyLoss()(outputs, labels)
            kl_loss_val = model.kl_loss() / num_train # Scale KL loss by num_train
            elbo = log_likelihood + kl_weight * kl_loss_val

            elbo.backward()
            optimizer.step()

            running_loss += elbo.item()
            running_ce_loss += log_likelihood.item()
            running_kl_loss += kl_loss_val.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_accuracy = 100 * correct / total
        avg_ce = running_ce_loss / len(trainloader)
        avg_kl = running_kl_loss / len(trainloader)
        avg_elbo = running_loss / len(trainloader)
        print(f'Epoch {epoch+1}, ELBO: {avg_elbo:.3f}, CE: {avg_ce:.3f}, KL: {avg_kl:.3f}, Acc: {epoch_accuracy:.2f}%')
    print('Finished Training')

# ----------------------------------------------------------------------
# 8. Evaluation Function (MODIFIED for MoE BNN)
# ----------------------------------------------------------------------
def evaluate(model, testloader, num_samples):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            predictions = torch.zeros(num_samples, images.size(0), 10).to(device)
            for i in range(num_samples):
                outputs = model(images) # Get output from MoE_BNN
                predictions[i] = nn.functional.softmax(outputs, dim=1)
            mean_predictions = torch.mean(predictions, dim=0)
            _, predicted = torch.max(mean_predictions.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# ----------------------------------------------------------------------
# 9. Instantiate and Train MoE BNN
# ----------------------------------------------------------------------
moe_bnn_model = MoE_BNN(num_experts=num_experts).to(device)
optimizer_moe = optim.Adam(moe_bnn_model.parameters(), lr=learning_rate)
train(moe_bnn_model, optimizer_moe, trainloader, num_epochs, kl_weight)
evaluate(moe_bnn_model, testloader, num_samples)


Files already downloaded and verified
Files already downloaded and verified
Epoch 1, ELBO: 8.852, CE: 4.852, KL: 4.000, Acc: 14.30%
Epoch 2, ELBO: 6.151, CE: 2.251, KL: 3.901, Acc: 15.86%
Epoch 3, ELBO: 5.953, CE: 2.189, KL: 3.764, Acc: 19.02%
Epoch 4, ELBO: 5.730, CE: 2.119, KL: 3.611, Acc: 20.87%
Epoch 5, ELBO: 5.502, CE: 2.051, KL: 3.451, Acc: 24.03%
Epoch 6, ELBO: 5.303, CE: 2.016, KL: 3.287, Acc: 24.94%
Epoch 7, ELBO: 5.086, CE: 1.962, KL: 3.124, Acc: 27.04%
Epoch 8, ELBO: 4.886, CE: 1.923, KL: 2.963, Acc: 28.12%
Epoch 9, ELBO: 4.680, CE: 1.872, KL: 2.807, Acc: 30.27%
Epoch 10, ELBO: 4.494, CE: 1.838, KL: 2.656, Acc: 31.83%
Epoch 11, ELBO: 4.284, CE: 1.772, KL: 2.511, Acc: 34.71%
Epoch 12, ELBO: 4.103, CE: 1.730, KL: 2.373, Acc: 36.14%
Epoch 13, ELBO: 3.954, CE: 1.714, KL: 2.240, Acc: 37.28%
Epoch 14, ELBO: 3.790, CE: 1.676, KL: 2.114, Acc: 39.10%
Epoch 15, ELBO: 3.639, CE: 1.644, KL: 1.995, Acc: 39.85%
Epoch 16, ELBO: 3.495, CE: 1.614, KL: 1.882, Acc: 41.24%
Epoch 17, ELBO: 3.364

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Normal, kl
import numpy as np

# ----------------------------------------------------------------------
# 1. Define Hyperparameters
# ----------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1024
learning_rate = 0.001
num_epochs = 1  # Reduced epochs for faster execution in this example
num_samples = 10  # Number of samples from the posterior for ELBO estimation (Bayesian MoE)
kl_weight = 1
num_train = 50000
num_experts = 3
noise_std = 0.1      # Standard deviation for Gaussian noise robustness test
epsilon_fgsm = 0.03  # Epsilon for FGSM adversarial attack

# ----------------------------------------------------------------------
# 2. Load and Preprocess CIFAR-10 Data
# ----------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ----------------------------------------------------------------------
# 3. Define the Bayesian Linear Layer
# ----------------------------------------------------------------------
class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mu=0, prior_sigma=0.1):
        super(BayesianLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Prior distribution parameters
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.prior = Normal(torch.tensor([self.prior_mu]).to(device),
                           torch.tensor([self.prior_sigma]).to(device))

        # Variational posterior parameters (learnable)
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-3, -2))
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-3, -2))

    def forward(self, x):
        # Reparameterization
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        weight = weight_dist.rsample()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        bias = bias_dist.rsample()

        self.weight_log_prob = weight_dist.log_prob(weight).sum()
        self.bias_log_prob = bias_dist.log_prob(bias).sum()

        return nn.functional.linear(x, weight, bias)

    def kl_loss(self):
        weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        weight_dist = Normal(self.weight_mu, weight_sigma)
        kl_weight = kl.kl_divergence(weight_dist, self.prior).sum()

        bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        bias_dist = Normal(self.bias_mu, bias_sigma)
        kl_bias = kl.kl_divergence(bias_dist, self.prior).sum()

        return kl_weight + kl_bias

# ----------------------------------------------------------------------
# 4. Define the CNN Expert Network
# ----------------------------------------------------------------------
class CNNExpert(nn.Module):
    def __init__(self):
        super(CNNExpert, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = self.dropout(x)
        x = self.flatten(x)
        return x

# ----------------------------------------------------------------------
# 5. Define the Bayesian Gating Network (BNN as Gate)
# ----------------------------------------------------------------------
class BayesianGatingNetwork(nn.Module):
    def __init__(self, num_experts):
        super(BayesianGatingNetwork, self).__init__()
        self.expert_cnn = CNNExpert() # Shared CNN feature extractor
        self.fc1 = BayesianLinear(32 * 8 * 8, 128)
        self.fc2 = BayesianLinear(128, num_experts) # Output layer for gating, num_experts outputs

    def forward(self, x):
        x = self.expert_cnn(x) # Use shared CNN feature extractor
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return nn.functional.softmax(x, dim=1) # Softmax to get probabilities

    def kl_loss(self):
        return (self.fc1.kl_loss() + self.fc2.kl_loss())/num_train

# ----------------------------------------------------------------------
# 6. Define the Simple Gating Network (Non-Bayesian Gate)
# ----------------------------------------------------------------------
class SimpleGatingNetwork(nn.Module):
    def __init__(self, num_experts):
        super(SimpleGatingNetwork, self).__init__()
        self.expert_cnn = CNNExpert() # Shared CNN feature extractor
        self.fc1 = nn.Linear(32 * 8 * 8, 128) # Standard Linear Layer
        self.fc2 = nn.Linear(128, num_experts) # Standard Linear Layer for gating outputs

    def forward(self, x):
        x = self.expert_cnn(x) # Use shared CNN feature extractor
        x = nn.functional.relu(self.fc1(x)) # ReLU activation
        x = self.fc2(x)
        return nn.functional.softmax(x, dim=1) # Softmax to get probabilities

# ----------------------------------------------------------------------
# 7. Define the MoE BNN Model (Bayesian Gate)
# ----------------------------------------------------------------------
class MoE_BNN(nn.Module):
    def __init__(self, num_experts):
        super(MoE_BNN, self).__init__()
        self.gate = BayesianGatingNetwork(num_experts)
        self.experts = nn.ModuleList([nn.Sequential(
            CNNExpert(), # Separate CNN feature extractor for each expert
            BayesianLinear(32 * 8 * 8, 128),
            nn.ReLU(),
            BayesianLinear(128, 10)
        ) for _ in range(num_experts)])

    def forward(self, x):
        gate_weights = self.gate(x) # Output: [batch_size, num_experts]
        expert_outputs = [expert(x) for expert in self.experts] # List of [batch_size, 10]

        # Mixture of Experts: Weighted sum of expert outputs
        Blended_output = torch.stack(expert_outputs, dim=2) # [batch_size, 10, num_experts]
        Blended_output = torch.matmul(Blended_output, gate_weights.unsqueeze(2)) # [batch_size, 10, 1]
        Blended_output = Blended_output.squeeze(2) # [batch_size, 10]
        return Blended_output

    def kl_loss(self):
        gate_kl_loss = self.gate.kl_loss()
        experts_kl_loss = torch.sum(torch.stack([expert[1].kl_loss() + expert[3].kl_loss() for expert in self.experts])) # Sum KL loss from all experts Bayesian layers
        return gate_kl_loss + (experts_kl_loss)/3

# ----------------------------------------------------------------------
# 8. Define the MoE with Simple Gate Model
# ----------------------------------------------------------------------
class MoE_SimpleGate(nn.Module):
    def __init__(self, num_experts):
        super(MoE_SimpleGate, self).__init__()
        self.gate = SimpleGatingNetwork(num_experts) # Use SimpleGatingNetwork
        self.experts = nn.ModuleList([nn.Sequential(
            CNNExpert(), # Separate CNN feature extractor for each expert
            nn.Linear(32 * 8 * 8, 128), # Standard Linear Layer
            nn.ReLU(),
            nn.Linear(128, 10) # Standard Linear Layer
        ) for _ in range(num_experts)])

    def forward(self, x):
        gate_weights = self.gate(x) # Output: [batch_size, num_experts]
        expert_outputs = [expert(x) for expert in self.experts] # List of [batch_size, 10]

        # Mixture of Experts: Weighted sum of expert outputs
        Blended_output = torch.stack(expert_outputs, dim=2) # [batch_size, 10, num_experts]
        Blended_output = torch.matmul(Blended_output, gate_weights.unsqueeze(2)) # [batch_size, 10, 1]
        Blended_output = Blended_output.squeeze(2) # [batch_size, 10]
        return Blended_output


# ----------------------------------------------------------------------
# 9. Training Function (General for both MoE types)
# ----------------------------------------------------------------------
def train_moe(model, optimizer, trainloader, num_epochs, kl_weight=None, is_bayesian=True): # Added is_bayesian flag
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_ce_loss = 0.0
        running_kl_loss = 0.0
        correct = 0
        total = 0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()

            outputs = model(inputs) # Get output from MoE
            log_likelihood = nn.CrossEntropyLoss()(outputs, labels)
            loss = log_likelihood # Initialize loss with CE

            if is_bayesian: # Add KL loss only for Bayesian MoE
                kl_loss_val = model.kl_loss() / num_train # Scale KL loss by num_train
                loss = log_likelihood + kl_weight * kl_loss_val
                running_kl_loss += kl_loss_val.item()

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_ce_loss += log_likelihood.item()


            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_accuracy = 100 * correct / total
        avg_ce = running_ce_loss / len(trainloader)
        avg_loss = running_loss / len(trainloader) # Avg Loss (ELBO or CE)
        if is_bayesian:
            avg_kl = running_kl_loss / len(trainloader)
            print(f'Epoch {epoch+1}, Loss: {avg_loss:.3f} (ELBO), CE: {avg_ce:.3f}, KL: {avg_kl:.3f}, Acc: {epoch_accuracy:.2f}%')
        else:
            print(f'Epoch {epoch+1}, Loss: {avg_loss:.3f} (CE), CE: {avg_ce:.3f}, Acc: {epoch_accuracy:.2f}%')
    print('Finished Training')


# ----------------------------------------------------------------------
# 10. Evaluation Function (General for both MoE types)
# ----------------------------------------------------------------------
def evaluate_moe(model, testloader, num_samples=1, is_bayesian=True): # Modified for num_samples and is_bayesian
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            if is_bayesian:
                predictions = torch.zeros(num_samples, images.size(0), 10).to(device)
                for i in range(num_samples):
                    outputs = model(images) # Get output from MoE
                    predictions[i] = nn.functional.softmax(outputs, dim=1)
                mean_predictions = torch.mean(predictions, dim=0)
                _, predicted = torch.max(mean_predictions.data, 1)
            else: # For Simple Gate MoE, no sampling needed
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')


# ----------------------------------------------------------------------
# 11. Function to Add Gaussian Noise to Images
# ----------------------------------------------------------------------
def add_gaussian_noise(images, std=0.1):
    noise = torch.randn_like(images) * std
    noisy_images = images + noise
    return torch.clip(noisy_images, -1, 1) # Clip to the valid range [-1, 1] for normalized CIFAR-10


# ----------------------------------------------------------------------
# 12. Function for FGSM Adversarial Attack
# ----------------------------------------------------------------------
def fgsm_attack(model, images, labels, epsilon=0.03):
    images_clone = images.clone().requires_grad_(True) # Clone and enable grad
    outputs = model(images_clone)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    model.zero_grad()
    loss.backward()
    grad_sign = images_clone.grad.data.sign()
    adversarial_images = images + epsilon * grad_sign
    adversarial_images = torch.clip(adversarial_images, -1, 1) # Clip to valid range
    return adversarial_images


# ----------------------------------------------------------------------
# 13. Evaluation for Robustness and Adversarial Attacks
# ----------------------------------------------------------------------
def evaluate_robustness(model, testloader, num_samples=1, is_bayesian=True, noise_std=0.1, epsilon_fgsm=0.03):
    model.eval()
    correct_clean = 0
    correct_noisy = 0
    correct_adversarial = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)

            # --- Clean Data Evaluation ---
            if is_bayesian:
                predictions_clean = torch.zeros(num_samples, images.size(0), 10).to(device)
                for i in range(num_samples):
                    outputs_clean = model(images)
                    predictions_clean[i] = nn.functional.softmax(outputs_clean, dim=1)
                mean_predictions_clean = torch.mean(predictions_clean, dim=0)
                _, predicted_clean = torch.max(mean_predictions_clean.data, 1)
            else:
                outputs_clean = model(images)
                _, predicted_clean = torch.max(outputs_clean.data, 1)
            correct_clean += (predicted_clean == labels).sum().item()


            # --- Noisy Data Evaluation ---
            noisy_images = add_gaussian_noise(images, std=noise_std)
            if is_bayesian:
                predictions_noisy = torch.zeros(num_samples, noisy_images.size(0), 10).to(device)
                for i in range(num_samples):
                    outputs_noisy = model(noisy_images)
                    predictions_noisy[i] = nn.functional.softmax(outputs_noisy, dim=1)
                mean_predictions_noisy = torch.mean(predictions_noisy, dim=0)
                _, predicted_noisy = torch.max(mean_predictions_noisy.data, 1)
            else:
                outputs_noisy = model(noisy_images)
                _, predicted_noisy = torch.max(outputs_noisy.data, 1)
            correct_noisy += (predicted_noisy == labels).sum().item()


            # --- Adversarial Attack Evaluation (FGSM) ---
            adversarial_images = fgsm_attack(model, images, labels, epsilon=epsilon_fgsm)
            if is_bayesian:
                predictions_adv = torch.zeros(num_samples, adversarial_images.size(0), 10).to(device)
                for i in range(num_samples):
                    outputs_adv = model(adversarial_images)
                    predictions_adv[i] = nn.functional.softmax(outputs_adv, dim=1)
                mean_predictions_adv = torch.mean(predictions_adv, dim=0)
                _, predicted_adv = torch.max(mean_predictions_adv.data, 1)
            else:
                outputs_adv = model(adversarial_images)
                _, predicted_adv = torch.max(outputs_adv.data, 1)
            correct_adversarial += (predicted_adv == labels).sum().item()

            total += labels.size(0)

    print(f'--- Robustness Evaluation ---')
    print(f'Clean Accuracy:      {100 * correct_clean / total:.2f}%')
    print(f'Noisy Accuracy (std={noise_std}): {100 * correct_noisy / total:.2f}%')
    print(f'Adv Accuracy (FGSM, eps={epsilon_fgsm}): {100 * correct_adversarial / total:.2f}%')


# ----------------------------------------------------------------------
# 14. Instantiate, Train and Evaluate both MoE Models
# ----------------------------------------------------------------------

# --- MoE with Bayesian Gate ---
print("--- Training MoE with Bayesian Gate ---")
moe_bnn_model = MoE_BNN(num_experts=num_experts).to(device)
optimizer_bnn_moe = optim.Adam(moe_bnn_model.parameters(), lr=learning_rate)
train_moe(moe_bnn_model, optimizer_bnn_moe, trainloader, num_epochs, kl_weight, is_bayesian=True) # is_bayesian=True
print("\n--- Evaluation MoE with Bayesian Gate ---")
evaluate_robustness(moe_bnn_model, testloader, num_samples, is_bayesian=True, noise_std=noise_std, epsilon_fgsm=epsilon_fgsm) # is_bayesian=True


# --- MoE with Simple Gate ---
print("\n--- Training MoE with Simple Gate ---")
moe_simple_gate_model = MoE_SimpleGate(num_experts=num_experts).to(device)
optimizer_simple_moe = optim.Adam(moe_simple_gate_model.parameters(), lr=learning_rate)
train_moe(moe_simple_gate_model, optimizer_simple_moe, trainloader, num_epochs, is_bayesian=False) # is_bayesian=False, no kl_weight
print("\n--- Evaluation MoE with Simple Gate ---")
evaluate_robustness(moe_simple_gate_model, testloader, is_bayesian=False, noise_std=noise_std, epsilon_fgsm=epsilon_fgsm) # is_bayesian=False


Files already downloaded and verified
Files already downloaded and verified
--- Training MoE with Bayesian Gate ---


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79a1d1368220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79a1d1368220>^
^^^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
^^    ^self._shutdown_workers()^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
^^    ^^if w.is_alive():^
^ ^^^^ ^  ^^ ^ ^

Epoch 1, Loss: 9.224 (ELBO), CE: 5.233, KL: 3.992, Acc: 12.94%
Finished Training

--- Evaluation MoE with Bayesian Gate ---


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn