<a href="https://colab.research.google.com/github/nncliff/qwen-32B/blob/main/chapter-1/moe-top2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Top-2 Mixture of Experts (MoE) Audio Classifier

This notebook implements a Mixture of Experts (MoE) model for audio classification using a Top-2 gating mechanism. It is migrated from `moe-top2.py`.

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os
import random

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [2]:
class DummyAudioDataset(Dataset):
    def __init__(self, num_samples=300, sample_length=16000, num_classes=10):
        self.num_samples = num_samples
        self.sample_length = sample_length
        self.num_classes = num_classes
        self.data = []

        for _ in range(num_samples):
            label = random.randint(0, num_classes - 1)
            waveform = torch.randn(sample_length) + label * 0.1  # Simple pattern based on label
            self.data.append((waveform, label))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

In [3]:
class Top2Router(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(Top2Router, self).__init__()
        self.num_experts = num_experts
        self.fc = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        # x: (batch_size, input_dim) -> logits: (batch_size, num_experts)
        logits = self.fc(x)
        probabilities = F.softmax(logits, dim=-1)
        top2_vals, top2_indices = torch.topk(probabilities, 2, dim=-1)
        return top2_indices, top2_vals

In [4]:
class Top2MoE(nn.Module):
    def __init__(self, input_dim, expert_hidden_dim, num_experts=6):
        super(Top2MoE, self).__init__()
        self.num_experts = num_experts
        self.router = Top2Router(input_dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_hidden_dim),
                nn.ReLU(),
                nn.Linear(expert_hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        # x: (batch_size, input_dim)
        top2_indices, top2_vals = self.router(x)
        output = torch.zeros_like(x)

        for i in range(2):  # For top-2 experts
            expert_indices = top2_indices[:, i]
            expert_weights = top2_vals[:, i].unsqueeze(-1)

            for j in range(self.num_experts):
                mask = (expert_indices == j)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[j](expert_input)
                    output[mask] += expert_output * expert_weights[mask]

        return output

In [5]:
class MoEAudioClassifier(nn.Module):
    def __init__(self, in_channels=1, conv_dim=64, moe_dim=128, num_classes=10):
        super(MoEAudioClassifier, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, conv_dim, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(conv_dim, conv_dim, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        )
        self.proj = nn.Linear(conv_dim, moe_dim)
        self.moe = Top2MoE(moe_dim, moe_dim*2, num_experts=6)
        self.classifier = nn.Linear(moe_dim, num_classes)

    def forward(self, x):
        # x: (batch_size, time_steps) -> (batch_size, 1, time_steps)
        if x.dim() == 2:
            x = x.unsqueeze(1)  # Add channel dimension

        x = self.conv(x) # (batch_size, conv_dim, reduced_time_steps)
        x = x.mean(dim=-1)  # Global average pooling
        x = self.proj(x)
        x = self.moe(x)
        return self.classifier(x)

In [7]:
dataset = DummyAudioDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = MoEAudioClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(30):
    total_loss = 0
    correct = 0
    for batch in dataloader:
        waveforms, labels = batch
        waveforms, labels = waveforms.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(waveforms)  # Add channel dimension
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}, Accuracy: {correct/len(dataset):.4f}")

Epoch 1, Loss: 2.3003, Accuracy: 0.1067
Epoch 2, Loss: 2.1015, Accuracy: 0.1900
Epoch 3, Loss: 1.6462, Accuracy: 0.3000
Epoch 4, Loss: 1.1824, Accuracy: 0.5167
Epoch 5, Loss: 0.8417, Accuracy: 0.6633
Epoch 6, Loss: 0.5448, Accuracy: 0.8633
Epoch 7, Loss: 0.3512, Accuracy: 0.8967
Epoch 8, Loss: 0.4237, Accuracy: 0.8033
Epoch 9, Loss: 0.2632, Accuracy: 0.8900
Epoch 10, Loss: 0.1676, Accuracy: 0.9600
Epoch 11, Loss: 0.0904, Accuracy: 0.9933
Epoch 12, Loss: 0.0555, Accuracy: 1.0000
Epoch 13, Loss: 0.0365, Accuracy: 1.0000
Epoch 14, Loss: 0.0293, Accuracy: 1.0000
Epoch 15, Loss: 0.0275, Accuracy: 1.0000
Epoch 16, Loss: 0.0351, Accuracy: 0.9933
Epoch 17, Loss: 0.0322, Accuracy: 1.0000
Epoch 18, Loss: 0.0206, Accuracy: 1.0000
Epoch 19, Loss: 0.0145, Accuracy: 1.0000
Epoch 20, Loss: 0.0101, Accuracy: 1.0000
Epoch 21, Loss: 0.0099, Accuracy: 1.0000
Epoch 22, Loss: 0.0088, Accuracy: 1.0000
Epoch 23, Loss: 0.0062, Accuracy: 1.0000
Epoch 24, Loss: 0.0052, Accuracy: 1.0000
Epoch 25, Loss: 0.0056, A