In [1]:
import random
import pandas as pd

import torch
from common_net.moe import TopKMoE, TopKMoEConfig, MoEGateLossManager
from common_net.attentions import MultiHeadAttention, LinearAttention, MHAConfig
from common_net.att_block import MAB
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

In [2]:
seed = 42
set_size = 4
n_samples=10000

In [3]:
def generate_set_attention_dataset(n_samples=1000, set_size=5,
                                   task="sum", threshold=5, seed=None):
    """
    Generate a toy dataset for set attention tasks.

    Args:
        n_samples (int): Number of samples to generate.
        set_size (int): Number of elements in each set.
        task (str): Task type: 'sum', 'max', 'count_above', 'mean'.
        threshold (int): Used if task == 'count_above'.
        seed (int or None): Random seed for reproducibility.

    Returns:
        pd.DataFrame: Dataset with columns ['set', 'label'].
    """
    if seed is not None:
        random.seed(seed)

    data = []
    for _ in range(n_samples):
        s = [random.randint(0, 10) for _ in range(set_size)]

        if task == "sum":
            label = sum(s)
        elif task == "max":
            label = max(s)
        elif task == "mean":
            label = sum(s) / len(s)
        elif task == "count_above":
            label = sum(1 for x in s if x > threshold)
        else:
            raise ValueError("Unsupported task")

        data.append({"set": s, "label": label})

    return pd.DataFrame(data)


In [4]:
# Example usage
df = generate_set_attention_dataset(n_samples=n_samples, set_size=set_size, task="mean", seed=seed)
df


Unnamed: 0,set,label
0,"[10, 1, 0, 4]",3.75
1,"[3, 3, 2, 1]",2.25
2,"[10, 8, 1, 9]",7.00
3,"[6, 0, 0, 1]",1.75
4,"[3, 3, 8, 9]",5.75
...,...,...
9995,"[3, 1, 7, 5]",4.00
9996,"[5, 2, 9, 7]",5.75
9997,"[7, 6, 0, 1]",3.50
9998,"[4, 5, 8, 4]",5.25


In [6]:
class SetDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        x = torch.tensor(row['set'], dtype=torch.float32).unsqueeze(-1)
        y = torch.tensor(row['label'], dtype=torch.float32)
        return x, y

In [7]:
def acc(preds, targets, tol=0.01):
    return ((preds - targets).abs() < tol).float().mean().item()

def train(model: MAB, train_ds: Dataset, eval_ds: Dataset, epochs=20, batch_size=32, lr=1e-3, device="cpu"):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    gate_criterion = MoEGateLossManager(model)

    # Prepare data
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False)


    for epoch in range(1, epochs + 1):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).squeeze(-1)
            loss = criterion(preds, y) + gate_criterion()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 5 == 0 or epoch == 1:
            print(f"Epoch {epoch:02d} | Loss: {loss.item():.4f}")

        # Evaluation
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for x_val, y_val in eval_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                val_preds = model(x_val).squeeze(-1)
                all_preds.append(val_preds.cpu())
                all_targets.append(y_val.cpu())
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        evaluation_acc = acc(all_preds, all_targets)
        print(f"Evaluation Accuracy: {evaluation_acc:.4f}")
        if evaluation_acc >= 0.99:
            break

    return model

In [8]:
class PoolSet(nn.Module):
    def __init__(self, method="mean"):
        super().__init__()
        assert method in ["mean", "sum", "max"], "Unsupported pooling method"
        self.method = method

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.method == "mean":
            return x.mean(dim=1)
        elif self.method == "sum":
            return x.sum(dim=1)
        elif self.method == "max":
            return x.max(dim=1).values

In [9]:
batch_size = 64

model = nn.Sequential(
    nn.Linear(1, 16),
    MAB(
        16,
        d_ff=32,
        mha_config=MHAConfig(
            num_heads=4,
            attn_cls=LinearAttention,
            gated=True,
            num_k_heads=4//2,
            num_v_heads=4//2
        ),
        moe_cls=TopKMoE,
    ),
    PoolSet(method="mean"),
    nn.Linear(16, 1)
)

In [10]:
model

Sequential(
  (0): Linear(in_features=1, out_features=16, bias=True)
  (1): MAB(
    (inp_norm): ZeroCenteredRMSNorm((16,), eps=1e-06)
    (attn): MultiHeadAttention(
      (q_proj): Linear(in_features=16, out_features=16, bias=False)
      (k_proj): Linear(in_features=16, out_features=8, bias=False)
      (v_proj): Linear(in_features=16, out_features=8, bias=False)
      (gate): Gated(
        (gate): Sequential(
          (0): Linear(in_features=16, out_features=16, bias=False)
          (1): Sigmoid()
        )
      )
      (out_proj): Linear(in_features=16, out_features=16, bias=False)
      (attn): LinearAttention()
      (q_norm): ZeroCenteredRMSNorm((4,), eps=1e-06)
      (k_norm): ZeroCenteredRMSNorm((4,), eps=1e-06)
    )
    (ff_norm): ZeroCenteredRMSNorm((16,), eps=1e-06)
    (ff): TopKMoE(
      (experts): ModuleList(
        (0-3): 4 x GatedMLP(
          (up): Linear(in_features=16, out_features=32, bias=False)
          (down): Linear(in_features=32, out_features=16, bi

In [11]:
# model = MeanAttentionNet()

In [12]:
train_ds = SetDataset(df[:int(0.8*len(df))])
eval_ds = SetDataset(df[int(0.8*len(df)):])

In [None]:
train(model, train_ds=train_ds, eval_ds=eval_ds, epochs=1000, batch_size=batch_size, lr=1e-3, device="cpu")

Epoch 01 | Loss: 0.1060
Evaluation Accuracy: 0.0320
Evaluation Accuracy: 0.0500
Evaluation Accuracy: 0.0395
Evaluation Accuracy: 0.0620
Epoch 05 | Loss: 0.0279
Evaluation Accuracy: 0.0570
Evaluation Accuracy: 0.1035
Evaluation Accuracy: 0.1725
Evaluation Accuracy: 0.1670
Evaluation Accuracy: 0.1825
Epoch 10 | Loss: 0.0006
Evaluation Accuracy: 0.3230
Evaluation Accuracy: 0.3475
Evaluation Accuracy: 0.4540
Evaluation Accuracy: 0.4950
Evaluation Accuracy: 0.5320
Epoch 15 | Loss: 0.0003
Evaluation Accuracy: 0.5270
Evaluation Accuracy: 0.5130
Evaluation Accuracy: 0.6325
Evaluation Accuracy: 0.6605
Evaluation Accuracy: 0.6815
Epoch 20 | Loss: 0.0001
Evaluation Accuracy: 0.7175
Evaluation Accuracy: 0.6720
Evaluation Accuracy: 0.7545
Evaluation Accuracy: 0.6315
Evaluation Accuracy: 0.8250
Epoch 25 | Loss: 0.0000
Evaluation Accuracy: 0.8405
Evaluation Accuracy: 0.8670
Evaluation Accuracy: 0.8820
Evaluation Accuracy: 0.8820
Evaluation Accuracy: 0.8820
Epoch 30 | Loss: 0.0001
Evaluation Accuracy:

In [None]:
eval_ds[0]

(tensor([[ 2.],
         [10.],
         [ 0.],
         [ 6.]]),
 tensor(4.5000))

In [None]:
x, y = eval_ds[0]

pred = model(x.unsqueeze(0))

pred, y

(tensor([[4.5047]], grad_fn=<AddmmBackward0>), tensor(4.5000))