In [1]:
import random
import pandas as pd

import torch
from common_net.moe import TopKMoE, TopKMoEConfig, MoEGateLossManager
from common_net.attentions import MultiHeadAttention, LinearAttention
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

In [2]:
seed = 42
set_size = 4
n_samples=10000

In [3]:
def generate_set_attention_dataset(n_samples=1000, set_size=5,
                                   task="sum", threshold=5, seed=None):
    """
    Generate a toy dataset for set attention tasks.

    Args:
        n_samples (int): Number of samples to generate.
        set_size (int): Number of elements in each set.
        task (str): Task type: 'sum', 'max', 'count_above', 'mean'.
        threshold (int): Used if task == 'count_above'.
        seed (int or None): Random seed for reproducibility.

    Returns:
        pd.DataFrame: Dataset with columns ['set', 'label'].
    """
    if seed is not None:
        random.seed(seed)

    data = []
    for _ in range(n_samples):
        s = [random.randint(0, 10) for _ in range(set_size)]

        if task == "sum":
            label = sum(s)
        elif task == "max":
            label = max(s)
        elif task == "mean":
            label = sum(s) / len(s)
        elif task == "count_above":
            label = sum(1 for x in s if x > threshold)
        else:
            raise ValueError("Unsupported task")

        data.append({"set": s, "label": label})

    return pd.DataFrame(data)


In [4]:
# Example usage
df = generate_set_attention_dataset(n_samples=n_samples, set_size=set_size, task="mean", seed=seed)
df


Unnamed: 0,set,label
0,"[10, 1, 0, 4]",3.75
1,"[3, 3, 2, 1]",2.25
2,"[10, 8, 1, 9]",7.00
3,"[6, 0, 0, 1]",1.75
4,"[3, 3, 8, 9]",5.75
...,...,...
9995,"[3, 1, 7, 5]",4.00
9996,"[5, 2, 9, 7]",5.75
9997,"[7, 6, 0, 1]",3.50
9998,"[4, 5, 8, 4]",5.25


In [5]:
class MeanAttentionNet(nn.Module):
    def __init__(self, embed_dim=16, num_heads=4):
        super().__init__()
        # Embed each scalar number into a vector
        self.embedding = nn.Linear(1, embed_dim)

        # Multihead Attention
        self.attn = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_cls=LinearAttention,
            gated=True,
            num_k_heads=num_heads//2,
            num_v_heads=num_heads//2,
        )

        # Output head (predict scalar mean)
        self.ff = TopKMoE(
            embed_dim=embed_dim,
            d_ff=32,
        )

        self.fc = nn.Linear(embed_dim, 1)

        self.gate_criterion = MoEGateLossManager()
        self.gate_criterion.register_moe(self.ff)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, set_size)
        """
        # Add embedding dimension
        x = x.unsqueeze(-1).float()                # (batch, set_size, 1)
        x = self.embedding(x)                      # (batch, set_size, embed_dim)

        # Attention: use same tensor as query, key, value
        attn_out, _ = self.attn(x, x, x)           # (batch, set_size, embed_dim)

        # Pool across the set (order-invariant)
        pooled = attn_out.mean(dim=1)              # (batch, embed_dim)

        expert_out = self.ff(pooled)                       # (batch, embed_dim)
        # Predict mean
        out = self.fc(expert_out).squeeze(-1)          # (batch,)
        return out

In [6]:
class SetDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        x = torch.tensor(row['set'], dtype=torch.float32)
        y = torch.tensor(row['label'], dtype=torch.float32)
        return x, y

In [7]:
def acc(preds, targets, tol=0.01):
    return ((preds - targets).abs() < tol).float().mean().item()

def train(model: MeanAttentionNet, train_ds: Dataset, eval_ds: Dataset, epochs=20, batch_size=32, lr=1e-3, device="cpu"):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # Prepare data
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False)


    for epoch in range(1, epochs + 1):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            loss = criterion(preds, y) + model.gate_criterion()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 5 == 0 or epoch == 1:
            print(f"Epoch {epoch:02d} | Loss: {loss.item():.4f}")

        # Evaluation
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for x_val, y_val in eval_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                val_preds = model(x_val)
                all_preds.append(val_preds.cpu())
                all_targets.append(y_val.cpu())
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        evaluation_acc = acc(all_preds, all_targets)
        print(f"Evaluation Accuracy: {evaluation_acc:.4f}")
        if evaluation_acc >= 0.99:
            break

    return model

In [8]:
batch_size = 64
model = MeanAttentionNet()

In [9]:
train_ds = SetDataset(df[:int(0.8*len(df))])
eval_ds = SetDataset(df[int(0.8*len(df)):])

In [10]:
train(model, train_ds=train_ds, eval_ds=eval_ds, epochs=1000, batch_size=batch_size, lr=1e-3, device="cpu")

Epoch 01 | Loss: 16.3603
Evaluation Accuracy: 0.0000
Evaluation Accuracy: 0.0985
Evaluation Accuracy: 0.1885
Evaluation Accuracy: 0.2355
Epoch 05 | Loss: 0.0014
Evaluation Accuracy: 0.2650
Evaluation Accuracy: 0.3750
Evaluation Accuracy: 0.4580
Evaluation Accuracy: 0.3900
Evaluation Accuracy: 0.5930
Epoch 10 | Loss: 0.0001
Evaluation Accuracy: 0.6730
Evaluation Accuracy: 0.7655
Evaluation Accuracy: 0.7990
Evaluation Accuracy: 0.8300
Evaluation Accuracy: 0.7705
Epoch 15 | Loss: 0.0000
Evaluation Accuracy: 0.8880
Evaluation Accuracy: 0.8940
Evaluation Accuracy: 0.8780
Evaluation Accuracy: 0.9425
Evaluation Accuracy: 0.9205
Epoch 20 | Loss: 0.0001
Evaluation Accuracy: 0.9605
Evaluation Accuracy: 0.9640
Evaluation Accuracy: 0.9670
Evaluation Accuracy: 0.9670
Evaluation Accuracy: 0.9750
Epoch 25 | Loss: 0.0000
Evaluation Accuracy: 0.9755
Evaluation Accuracy: 0.9750
Evaluation Accuracy: 0.9775
Evaluation Accuracy: 0.9800
Evaluation Accuracy: 0.9875
Epoch 30 | Loss: 0.0000
Evaluation Accuracy

MeanAttentionNet(
  (embedding): Linear(in_features=1, out_features=16, bias=True)
  (attn): MultiHeadAttention(
    (q_proj): Linear(in_features=16, out_features=16, bias=False)
    (k_proj): Linear(in_features=16, out_features=8, bias=False)
    (v_proj): Linear(in_features=16, out_features=8, bias=False)
    (gate): Gated(
      (gate): Sequential(
        (0): Linear(in_features=16, out_features=16, bias=False)
        (1): Sigmoid()
      )
    )
    (out_proj): Linear(in_features=16, out_features=16, bias=False)
    (attn): LinearAttention()
    (q_norm): ZeroCenteredRMSNorm((4,), eps=1e-06)
    (k_norm): ZeroCenteredRMSNorm((4,), eps=1e-06)
  )
  (ff): TopKMoE(
    (experts): ModuleList(
      (0-3): 4 x GatedMLP(
        (up): Linear(in_features=16, out_features=32, bias=False)
        (down): Linear(in_features=32, out_features=16, bias=False)
        (gate): Gated(
          (gate): Sequential(
            (0): Linear(in_features=16, out_features=32, bias=False)
            (

In [13]:
eval_ds[0]

(tensor([ 2., 10.,  0.,  6.]), tensor(4.5000))

In [14]:
x, y = eval_ds[3]

pred = model(x.unsqueeze(0))

pred, y

(tensor([6.2461], grad_fn=<SqueezeBackward1>), tensor(6.2500))