<a href="https://colab.research.google.com/github/jyotidabass/Mixture-of-Nested-Experts-MoNE-model/blob/main/Mixture_of_Nested_Experts_(MoNE)%C2%A0model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Define the MoNE model
class MoNE(nn.Module):
    def __init__(self, num_experts, input_dim, output_dim):
        super(MoNE, self).__init__()
        self.num_experts = num_experts
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Define the experts
        self.experts = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_experts)])

        # Define the router
        self.router = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        # Compute the router output
        router_output = torch.softmax(self.router(x), dim=1)

        # Compute the expert outputs
        expert_outputs = []
        for i in range(self.num_experts):
            expert_output = self.experts[i](x)
            expert_outputs.append(expert_output)

        # Compute the final output
        final_output = 0
        for i in range(self.num_experts):
            # Reshape router_output[:, i] to (100, 1) for broadcasting
            final_output += router_output[:, i].unsqueeze(1) * expert_outputs[i]
            # unsqueeze(1) adds a dimension of size 1 at dimension 1,
            # effectively changing the shape from (100,) to (100, 1).
            # This allows for proper broadcasting during the multiplication.

        return final_output

# Set the hyperparameters
num_experts = 3
input_dim = 784
output_dim = 10

# Initialize the MoNE model
model = MoNE(num_experts, input_dim, output_dim)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
for epoch in range(10):
    optimizer.zero_grad()
    inputs = torch.randn(100, input_dim)
    labels = torch.randint(0, output_dim, (100,))
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

Epoch 1, Loss: 2.398869514465332
Epoch 2, Loss: 2.396008014678955
Epoch 3, Loss: 2.3138465881347656
Epoch 4, Loss: 2.440911054611206
Epoch 5, Loss: 2.3770956993103027
Epoch 6, Loss: 2.4285318851470947
Epoch 7, Loss: 2.380089521408081
Epoch 8, Loss: 2.371851921081543
Epoch 9, Loss: 2.3508713245391846
Epoch 10, Loss: 2.3464956283569336
