<a href="https://colab.research.google.com/github/nncliff/qwen-32B/blob/main/chapter-1/moe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mixture of Experts (MoE) Classifier on CIFAR-10

This notebook implements a Mixture of Experts (MoE) classifier using PyTorch on the CIFAR-10 dataset. It is migrated from a standalone Python script.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import random

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Mixture of Experts (MoE) Module

The `MixtureOfExperts` module is the core component. Here's a breakdown of the tensor shapes during the forward pass:

1.  **Input `x`**: `(batch_size, input_dim)`
2.  **Gating Network (`self.classifier`)**:
    *   Produces logits for each expert.
    *   Shape: `(batch_size, num_experts)`
3.  **Top-k Selection**:
    *   We select the top `k` experts for each sample.
    *   `topk_scores`: `(batch_size, k)` - The softmax probabilities of the selected experts (computed over all experts and not renormalized over the top `k`).
    *   `topk_indices`: `(batch_size, k)` - The indices of the selected experts.
4.  **Routing**:
    *   We iterate over the `k` ranks (each sample's 1st choice, 2nd choice, and so on).
    *   For each rank `i` (from 0 to k-1), we identify which expert each sample selected at that rank.
    *   `expert_mask`: `(batch_size, num_experts)` - A boolean mask indicating assignment.
5.  **Expert Computation**:
    *   For each expert `j`, we select the inputs assigned to it: `expert_input` shape `(num_assigned_samples, input_dim)`.
    *   The expert processes these inputs: `expert_out` shape `(num_assigned_samples, input_dim)`.
    *   We weight the output by the gating score and add it to the final output.
6.  **Output**: `(batch_size, input_dim)` - Same shape as input.
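The shapes in steps 1 to 3 can be checked on a toy batch. This is a minimal sketch with made-up sizes (4 samples, 6 features, 8 experts); the standalone `gate` layer stands in for `self.classifier`. It also makes one property of the module visible: the top-k weights come straight from the softmax over all experts and are not renormalized, so each row of `topk_scores` sums to at most 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, input_dim, num_experts = 4, 6, 8
k = max(1, num_experts // 4)  # same rule as the module: 8 experts -> k = 2

x = torch.randn(batch_size, input_dim)
gate = torch.nn.Linear(input_dim, num_experts)  # stand-in for self.classifier

logits = gate(x)                    # (batch_size, num_experts)
probs = F.softmax(logits, dim=1)    # (batch_size, num_experts), each row sums to 1
topk_scores, topk_indices = torch.topk(probs, k=k, dim=1)  # both (batch_size, k), sorted descending

print(logits.shape, topk_scores.shape, topk_indices.shape)
# The selected weights are not renormalized, so each row sums to at most 1
print(topk_scores.sum(dim=1))
```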

In [None]:
class MixtureOfExperts(nn.Module):
    def __init__(self, input_dim, expert_dim, num_experts):
        super(MixtureOfExperts, self).__init__()
        self.num_experts = num_experts
        self.expert_dim = expert_dim
        self.k = max(1, num_experts // 4)  # Experts used per sample (a quarter of them, at least 1); k=1 gives Top-1 routing

        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(input_dim, expert_dim),
            nn.ReLU(),
            nn.Linear(expert_dim, input_dim),
        ) for _ in range(num_experts)])

        self.classifier = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        # x shape: (batch_size, input_dim)

        # 1. Gating Network: Predict expert weights
        logits = self.classifier(x)
        # logits shape: (batch_size, num_experts)

        probs = F.softmax(logits, dim=1)
        # probs shape: (batch_size, num_experts)

        # 2. Select Top-k Experts
        topk_scores, topk_indices = torch.topk(probs, k=self.k, dim=1)
        # topk_scores shape: (batch_size, k)
        # topk_indices shape: (batch_size, k)

        output = torch.zeros_like(x)
        # output shape: (batch_size, input_dim)

        # 3. Route inputs to experts
        for i in range(self.k):
            # Get the i-th selected expert for each sample
            expert_idx = topk_indices[:, i]
            # expert_idx shape: (batch_size,)

            expert_weight = topk_scores[:, i].unsqueeze(1)
            # expert_weight shape: (batch_size, 1)

            # Create a mask for which expert is selected for which sample at this rank
            expert_mask = torch.zeros(x.shape[0], self.num_experts, dtype=torch.bool, device=x.device)
            # expert_mask shape: (batch_size, num_experts)

            # Set the mask for the selected expert
            # scatter_(dim, index, value) writes the scalar `value` (here True) into the tensor
            # at the positions given by `index` along dimension `dim`.
            expert_mask.scatter_(1, expert_idx.unsqueeze(1), True)
            # expert_mask is now True at [b, e] if sample b selected expert e as its i-th choice

            for j, expert in enumerate(self.experts):
                mask = expert_mask[:, j]
                # mask shape: (batch_size,) - True if sample assigned to expert j

                if mask.any():
                    expert_input = x[mask]
                    # expert_input shape: (num_assigned_samples, input_dim)

                    expert_out = expert(expert_input)
                    # expert_out shape: (num_assigned_samples, input_dim)

                    # Add weighted expert output to final output
                    # We use the mask to place the results back into the correct batch positions
                    output[mask] += expert_out * expert_weight[mask]

        return output
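One way to check that the nested loop routes correctly is to compare it with a dense reference that runs every expert on every sample and zeroes out the weights of non-selected experts. This sketch uses small made-up sizes and standalone `experts`/`gate` modules rather than the class itself; it also builds each per-expert mask directly as `expert_idx == j`, which selects the same samples as column `j` of the scattered `expert_mask`. The final line should print `True`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, input_dim, expert_dim, num_experts, k = 6, 4, 8, 4, 2

# Standalone stand-ins for self.experts and self.classifier (toy sizes)
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(input_dim, expert_dim), nn.ReLU(), nn.Linear(expert_dim, input_dim))
    for _ in range(num_experts)
])
gate = nn.Linear(input_dim, num_experts)
x = torch.randn(batch_size, input_dim)

with torch.no_grad():
    probs = F.softmax(gate(x), dim=1)
    topk_scores, topk_indices = torch.topk(probs, k=k, dim=1)

    # Sparse routing, mirroring the loop in MixtureOfExperts.forward
    routed = torch.zeros_like(x)
    for i in range(k):
        expert_idx = topk_indices[:, i]
        weight = topk_scores[:, i].unsqueeze(1)
        for j, expert in enumerate(experts):
            mask = expert_idx == j  # same samples as expert_mask[:, j]
            if mask.any():
                routed[mask] += expert(x[mask]) * weight[mask]

    # Dense reference: every expert sees every sample; non-selected experts get weight 0
    all_out = torch.stack([expert(x) for expert in experts], dim=1)                # (batch, num_experts, input_dim)
    dense_weights = torch.zeros_like(probs).scatter(1, topk_indices, topk_scores)  # (batch, num_experts)
    dense = (dense_weights.unsqueeze(-1) * all_out).sum(dim=1)                     # (batch, input_dim)

print(torch.allclose(routed, dense, atol=1e-5))
```

The dense version is simpler to read, but it performs `num_experts / k` times as many expert evaluations, which is exactly the cost that sparse top-k routing avoids.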

### Example: `scatter_` Operation

Let's trace `expert_mask.scatter_(1, expert_idx.unsqueeze(1), True)` with a concrete example.

**Setup:**
*   **Batch Size**: 3
*   **Num Experts**: 5
*   **Current Rank**: 1st choice (`i=0`)

**1. Initial State:**
`expert_mask` is initialized to all zeros (False). Shape `(3, 5)`.
```python
expert_mask = [
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0]
]
```

**2. Expert Indices (`expert_idx`):**
Let's say for this rank, the samples selected the following experts:
*   Sample 0 -> Expert **1**
*   Sample 1 -> Expert **4**
*   Sample 2 -> Expert **0**

`expert_idx` = `[1, 4, 0]` (Shape: `(3,)`)

**3. Unsqueeze:**
`expert_idx.unsqueeze(1)` becomes a column vector.
```python
index = [
    [1],
    [4],
    [0]
]
```
(Shape: `(3, 1)`)

**4. Scatter Operation:**
`expert_mask.scatter_(dim=1, index=index, value=True)`

This tells PyTorch: "For each row, go to the column specified by `index` and set the value to `True`."

*   **Row 0**: Index is `1`. Set `expert_mask[0, 1] = True`.
*   **Row 1**: Index is `4`. Set `expert_mask[1, 4] = True`.
*   **Row 2**: Index is `0`. Set `expert_mask[2, 0] = True`.

**5. Final Result:**
```python
expert_mask = [
    [0, 1, 0, 0, 0],  # Sample 0 assigned to Expert 1
    [0, 0, 0, 0, 1],  # Sample 1 assigned to Expert 4
    [1, 0, 0, 0, 0]   # Sample 2 assigned to Expert 0
]
```
This boolean mask is then used to efficiently select the correct input rows for each expert later in the loop.
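The walkthrough above can be reproduced directly. This minimal sketch uses the example's values (3 samples, 5 experts, `expert_idx = [1, 4, 0]`) and also shows two equivalent ways to obtain the same mask: `F.one_hot`, and comparing `expert_idx` against each expert index `j`, which yields one column of the mask at a time.

```python
import torch
import torch.nn.functional as F

batch_size, num_experts = 3, 5
expert_idx = torch.tensor([1, 4, 0])  # i-th choice of each sample, as in the example

expert_mask = torch.zeros(batch_size, num_experts, dtype=torch.bool)
expert_mask.scatter_(1, expert_idx.unsqueeze(1), True)  # scalar overload: writes True at [b, expert_idx[b]]
print(expert_mask.int())

# Equivalent construction via one-hot encoding
one_hot_mask = F.one_hot(expert_idx, num_classes=num_experts).bool()
print(torch.equal(expert_mask, one_hot_mask))  # True

# Column j of the mask is simply "which samples chose expert j"
for j in range(num_experts):
    assert torch.equal(expert_mask[:, j], expert_idx == j)
```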

### Example: `expert_input = x[mask]`

Continuing from the previous example, let's see how we select inputs for a specific expert.

**Setup:**
*   **Input `x`** (Batch Size 3, Input Dim 4):
    ```python
    x = [
        [0.1, 0.1, 0.1, 0.1], # Sample 0
        [0.2, 0.2, 0.2, 0.2], # Sample 1
        [0.3, 0.3, 0.3, 0.3]  # Sample 2
    ]
    ```
*   **Expert Mask** (from previous step):
    ```python
    expert_mask = [
        [0, 1, 0, 0, 0],  # Sample 0 -> Expert 1
        [0, 0, 0, 0, 1],  # Sample 1 -> Expert 4
        [1, 0, 0, 0, 0]   # Sample 2 -> Expert 0
    ]
    ```

**Scenario: Loop iteration for Expert 1 (`j=1`)**

1.  **Extract Mask Column:**
    We look at the column for Expert 1: `mask = expert_mask[:, 1]`.
    *   Result: `[True, False, False]` (Sample 0 is True).

2.  **Boolean Indexing (`x[mask]`):**
    We select the rows from `x` where `mask` is True.
    *   Since only the first element is True, we pick the first row of `x`.

3.  **Result (`expert_input`):**
    ```python
    expert_input = [
        [0.1, 0.1, 0.1, 0.1]
    ]
    ```
    *   Shape: `(1, 4)` (1 sample assigned, 4 features).
    *   This tensor is then passed to Expert 1's neural network.

**Scenario: Loop iteration for Expert 2 (`j=2`)**

1.  **Extract Mask Column:**
    `mask = expert_mask[:, 2]`.
    *   Result: `[False, False, False]`.

2.  **Check `mask.any()`:**
    Since there are no True values, `mask.any()` is False.
    *   We **skip** computation for Expert 2 to save resources. No samples need this expert.
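Both scenarios can be replayed in PyTorch. This minimal sketch rebuilds `x` and `expert_mask` from the example values; it shows that an expert with no assigned samples would otherwise receive an empty `(0, 4)` tensor, which is what the `mask.any()` check avoids passing to it.

```python
import torch

# Inputs from the example: 3 samples, 4 features each
x = torch.tensor([
    [0.1, 0.1, 0.1, 0.1],  # Sample 0
    [0.2, 0.2, 0.2, 0.2],  # Sample 1
    [0.3, 0.3, 0.3, 0.3],  # Sample 2
])

# Rebuild the mask from the scatter_ example (Sample 0 -> 1, Sample 1 -> 4, Sample 2 -> 0)
expert_idx = torch.tensor([1, 4, 0])
expert_mask = torch.zeros(3, 5, dtype=torch.bool)
expert_mask.scatter_(1, expert_idx.unsqueeze(1), True)

# Expert 1: exactly one sample is routed here
mask = expert_mask[:, 1]
expert_input = x[mask]
print(mask.tolist(), tuple(expert_input.shape))  # [True, False, False] (1, 4)

# Expert 2: no samples are routed here, so boolean indexing yields an empty tensor
mask = expert_mask[:, 2]
print(mask.any().item(), tuple(x[mask].shape))  # False (0, 4)
```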

### Explanation: ResNet18 Backbone Modification

In the `MoEClassifier` below, you will see these lines:

```python
self.backbone = torchvision.models.resnet18(weights=None)
self.backbone.fc = nn.Identity()
```

**Why do we do this?**

1.  **Standard ResNet18**: By default, ResNet18 ends with a fully connected layer (`fc`) that maps features to 1000 classes (for ImageNet).
2.  **Feature Extraction**: We want to use ResNet18 only to extract high-level features from the images, not to classify them immediately.
3.  **`nn.Identity()`**: By replacing `self.backbone.fc` with `nn.Identity()`, we effectively "delete" the classification layer. The data passes through this layer unchanged.
4.  **Output Dimension**: The layer immediately preceding the `fc` layer in ResNet18 outputs a vector of size **512**. This 512-dimensional vector serves as the input to our Mixture of Experts module.

In [3]:
class MoEClassifier(nn.Module):
    def __init__(self, input_dim=512, moe_hidden=1024, num_experts=8, num_classes=10):
        super(MoEClassifier, self).__init__()
        # ResNet18 backbone
        self.backbone = torchvision.models.resnet18(weights=None)  # random init; replaces the deprecated pretrained=False
        self.backbone.fc = nn.Identity()  # Remove the final classification layer
        # ResNet18 without its fc layer outputs 512-dimensional feature vectors

        self.moe = MixtureOfExperts(input_dim, moe_hidden, num_experts)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # x shape: (batch_size, 3, 224, 224) - Input images

        x = self.backbone(x)
        # x shape: (batch_size, 512) - Feature vectors from ResNet

        x = self.moe(x)
        # x shape: (batch_size, 512) - Refined features from MoE

        out = self.classifier(x)
        # out shape: (batch_size, num_classes) - Class logits

        return out

### Explanation: Data Transformations

The `transform` object defines a pipeline of operations to preprocess the images before they are fed into the model.

```python
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
```

1.  **`transforms.Compose([...])`**: This chains multiple transform steps together. The output of one step becomes the input of the next.
2.  **`transforms.Resize((224, 224))`**:
    *   **What it does**: Resizes the input image to 224x224 pixels.
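    *   **Why**: ResNet18 was designed for 224x224 ImageNet images. Because it ends in an adaptive average-pooling layer, it would technically accept 32x32 CIFAR-10 images as well, but its downsampling stages reduce spatial resolution by a factor of 32, so a 32x32 image would shrink to a 1x1 feature map before pooling. Upscaling to 224x224 yields a 7x7 map instead, preserving more spatial detail at the cost of extra compute.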
3.  **`transforms.ToTensor()`**:
    *   **What it does**: Converts the image (which is usually a PIL Image or NumPy array with values 0-255) into a PyTorch Tensor.
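    *   **Scaling**: It also converts the pixel values from integers in [0, 255] to floats in [0.0, 1.0]. No per-channel mean/std normalization is applied here; that would require an additional `transforms.Normalize` step.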
    *   **Channel Ordering**: It changes the dimension order from (Height, Width, Channels) to (Channels, Height, Width), which is the format PyTorch expects.

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

100%|██████████| 170M/170M [00:04<00:00, 34.7MB/s]


In [5]:
model = MoEClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()



In [6]:
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running training statistics, computed on the same batches used for the updates
        total_loss += loss.item()
        predicted = torch.argmax(outputs, dim=1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {100.*correct/total:.2f}%')

Epoch [1/5], Loss: 1.6121, Accuracy: 39.45%
Epoch [2/5], Loss: 1.0392, Accuracy: 62.81%
Epoch [3/5], Loss: 0.7854, Accuracy: 72.53%
Epoch [4/5], Loss: 0.6135, Accuracy: 78.92%
Epoch [5/5], Loss: 0.4960, Accuracy: 82.94%
