In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from fatransformer.exchangeable_layer import PoolLayer, ExchangeableLayer
import torch

In [3]:
layer = ExchangeableLayer(in_channels=1, out_channels=2, pool_config={'row': ['mean'], 'column': ['mean',], 'global': 'mean'})

In [4]:
x = torch.tensor([[[1., 2., 3.], [5., 6., 7.]]])
# x = x.unsqueeze(1)
print(x.shape)
y = layer(x)
y

torch.Size([1, 2, 3])


tensor([[[[-0.0120, -0.0173, -0.0245],
          [-0.0073, -0.0109, -0.0159]],

         [[-0.1372, -0.1049, -0.0735],
          [-0.1596, -0.1695, -0.1536]]]], grad_fn=<PermuteBackward0>)

In [5]:
x = torch.tensor([[[5., 6., 7.], [1., 2., 3.]]])
x = x.unsqueeze(1)  # Changed to BCHW format
print(x.shape)
y = layer(x)
print(y.shape)
y

torch.Size([1, 1, 2, 3])
torch.Size([1, 2, 2, 3])


tensor([[[[-0.0073, -0.0109, -0.0159],
          [-0.0120, -0.0173, -0.0245]],

         [[-0.1596, -0.1695, -0.1536],
          [-0.1372, -0.1049, -0.0735]]]], grad_fn=<PermuteBackward0>)

In [6]:
x = torch.tensor([[[2., 1., 3.], [6., 5., 7.]]])
x = x.unsqueeze(1)  # Changed to BCHW format
print(x.shape)
y = layer(x)
print(y)

torch.Size([1, 1, 2, 3])
tensor([[[[-0.0173, -0.0120, -0.0245],
          [-0.0109, -0.0073, -0.0159]],

         [[-0.1049, -0.1372, -0.0735],
          [-0.1695, -0.1596, -0.1536]]]], grad_fn=<PermuteBackward0>)


In [19]:
from fatransformer.fatransformer_exchangeable import FATransformer

In [20]:
n = 10
m = 20
d_model = 768
num_heads = 12
num_output_layers = 4
dropout = 0.0
initial_temperature = 1.0
final_temperature = 0.01
model = FATransformer(n, m, d_model, num_heads, num_output_layers, dropout, initial_temperature, final_temperature)

In [21]:
x = torch.rand(1, n, m)
model(x).shape

torch.Size([1, 20, 10])

In [22]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
import numpy as np
from fatransformer.helpers import get_nash_welfare

wandb.init(
    project="fa-transformer-temp",   
    config={
        "n": 10,
        "m": 20,
        "d_model": 768,
        "num_heads": 12,
        "dropout": 0.0,
        "lr": 1e-4,
        "weight_decay": 1e-2,
        "steps": 500,
        "batch_size": 256,
        "num_output_layers": 4,
        "initial_temperature": 1.0,
        "final_temperature": 0.01,
    }
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n = wandb.config["n"]
m = wandb.config["m"]

model = FATransformer(n, m, 
                      wandb.config["d_model"], wandb.config["num_heads"], wandb.config["num_output_layers"], wandb.config["dropout"],
                      initial_temperature=wandb.config["initial_temperature"], final_temperature=wandb.config["final_temperature"]
                      ).to(device)

optimizer = optim.AdamW(model.parameters(), lr=wandb.config["lr"], weight_decay=wandb.config["weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=wandb.config["steps"])

# Temperature annealing schedule
initial_temp = wandb.config["initial_temperature"]
final_temp = wandb.config["final_temperature"]

for step in range(wandb.config["steps"]):
    # Calculate current temperature using exponential decay
    progress = step / wandb.config["steps"]
    current_temp = initial_temp * (final_temp / initial_temp) ** progress
    model.update_temperature(current_temp)
    
    u = torch.rand(wandb.config["batch_size"], n, m, device=device)

    allocation = model(u)
    nash_welfare = get_nash_welfare(u, allocation, reduction="mean")

    loss = -nash_welfare

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    wandb.log({
        "step": step,
        "loss": loss.item(),
        "nash_welfare": nash_welfare.item(),
        "lr": scheduler.get_last_lr()[0],
        "temperature": current_temp
    })

wandb.finish()

0,1
loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,██████████▇▇▇▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▁▁▁▁▁
nash_welfare,▁▁▅▆▇▇█▇█▇███████████████████████████▇▇▇
step,▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
temperature,█▆▆▆▆▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,-1.64003
lr,0.0
nash_welfare,1.64003
step,499.0
temperature,0.01009


In [19]:
from fatransformer.fatransformer_exchangeable import FATransformer

n = 10
m = 20
d_model = 768
num_heads = 12
num_output_layers = 4
dropout = 0.0
initial_temperature = 1.0
final_temperature = 0.01
model = FATransformer(n, m, d_model, num_heads, num_output_layers, dropout, initial_temperature, final_temperature)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params:,}")

Total number of parameters in the model: 66,092,545


In [21]:
from fatransformer.fatransformer import FATransformer

n = 10
m = 20
d_model = 768
num_heads = 12
num_output_layers = 4
dropout = 0.0
initial_temperature = 1.0
final_temperature = 0.01
model = FATransformer(n, m, d_model, num_heads, num_output_layers, dropout, initial_temperature, final_temperature)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params:,}")


Total number of parameters in the model: 49,595,146


In [17]:
print(595146 - 581322)
print(595146 - 578250)
print(768*20+768*10-768*8)

13824
16896
16896


In [6]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
import numpy as np
from fatransformer.helpers import get_nash_welfare

wandb.init(
    project="fa-transformer-temp",   
    config={
        "n": 10,
        "m": 20,
        "d_model": 768,
        "num_heads": 12,
        "dropout": 0.0,
        "lr": 1e-3,
        "weight_decay": 1e-2,
        "steps": 500,
        "batch_size": 512,
        "num_output_layers": 4,
        "initial_temperature": 1.0,
        "final_temperature": 0.01,
    }
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n = wandb.config["n"]
m = wandb.config["m"]

model = FATransformer(n, m, 
                      wandb.config["d_model"], wandb.config["num_heads"], wandb.config["num_output_layers"], wandb.config["dropout"],
                      initial_temperature=wandb.config["initial_temperature"], final_temperature=wandb.config["final_temperature"]
                      ).to(device)

optimizer = optim.AdamW(model.parameters(), lr=wandb.config["lr"], weight_decay=wandb.config["weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=wandb.config["steps"])

# Temperature annealing schedule
initial_temp = wandb.config["initial_temperature"]
final_temp = wandb.config["final_temperature"]

for step in range(wandb.config["steps"]):
    # Calculate current temperature using exponential decay
    progress = step / wandb.config["steps"]
    current_temp = initial_temp * (final_temp / initial_temp) ** progress
    model.update_temperature(current_temp)
    
    u = torch.rand(wandb.config["batch_size"], n, m, device=device)

    allocation = model(u)
    nash_welfare = get_nash_welfare(u, allocation, reduction="mean")

    loss = -nash_welfare

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    wandb.log({
        "step": step,
        "loss": loss.item(),
        "nash_welfare": nash_welfare.item(),
        "lr": scheduler.get_last_lr()[0],
        "temperature": current_temp
    })

wandb.finish()

0,1
loss,█▃▂▂▂▂▁▂▁▁▂▂▂▂▂▂▂▂▂▄▄▃▄▃▃▄▄▄▃▃▂▃▂▂▂▁▁▁▁▁
lr,█████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
nash_welfare,▇███▇█▇▇▇▇█▆▇▇█▇▆▇▇▇▅▅▇▄▁▆▅▅▇▄▇▇▇▇▇▆▇███
step,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██
temperature,█▇▆▆▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,-0.99161
lr,0.0
nash_welfare,0.99161
step,499.0
temperature,0.01009


In [28]:
import torch.nn as nn
n = AxisAttnPool1D(768)
x = torch.rand(10, 768, 10, 20)
n(x).shape

torch.Size([10, 10, 768])

In [26]:
class AxisAttnPool1D(nn.Module):
    """
    Reduces the width axis (W) of a (B, D, H, W) tensor to (B, H, D) via learned attention.
    Permutation-equivariant over the reduced axis and length-agnostic.
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.norm  = nn.RMSNorm(d_model)
        self.score = nn.Linear(d_model, 1, bias=False)  # per-element score

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D, H, W)
        B, D, H, W = x.shape
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        s = self.score(x).squeeze(-1)             # (B, H, W)
        a = s.softmax(dim=2)                      # softmax over width axis
        pooled = (a.unsqueeze(-1) * x).sum(dim=2) # (B, H, D)
        return pooled
