In [1]:
from fatransformer.model_components import GLU, MHA
from fatransformer.fatransformer import FATransformer
import torch.nn as nn
import torch.nn.functional as F
import torch
from fatransformer.helpers import get_nash_welfare
from typing import Optional

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

ModuleNotFoundError: No module named 'fatransformer'

In [2]:
%%time

# Smoke test: push one random batch through a freshly built FATransformer and
# print the shape of what comes back.
n = 10
m = 14

# The four positional values after (n, m) are passed through unchanged. Going
# by the config keys used in the training cell they are presumably d_model,
# num_heads, num_output_layers and dropout -- TODO confirm against the
# FATransformer signature.
trans = FATransformer(n, m, 768, 12, 4, 0.0)
trans = trans.to(device)

# Uniform random batch of 2048 inputs, each of shape (n, m), moved to `device`.
x = torch.rand(2048, n, m)
x = x.to(device)

out = trans(x)
print(out.shape)

torch.Size([2048, 14, 10])
CPU times: user 399 ms, sys: 112 ms, total: 511 ms
Wall time: 637 ms


In [16]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
import numpy as np

# Open a Weights & Biases run. Every hyperparameter used further down in this
# cell is read back from wandb.config, so this mapping is the single place to
# tune them.
wandb.init(
    project="fa-transformer-temp",
    config=dict(
        # Problem size passed to FATransformer and used for the random batches.
        n=10,
        m=20,
        # Model hyperparameters.
        d_model=768,
        num_heads=12,
        dropout=0.0,
        # Optimiser settings (AdamW + cosine schedule over `steps`).
        lr=1e-4,
        weight_decay=1e-2,
        steps=20000,
        batch_size=512,
        num_output_layers=4,
        # Start and end points of the temperature annealing schedule.
        initial_temperature=1.0,
        final_temperature=0.001,
    ),
)

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Problem dimensions come from the run config so they are logged with the run.
n, m = wandb.config["n"], wandb.config["m"]

# Build the model from the logged hyperparameters and move it to `device`.
model = FATransformer(
    n,
    m,
    wandb.config["d_model"],
    wandb.config["num_heads"],
    wandb.config["num_output_layers"],
    wandb.config["dropout"],
    initial_temperature=wandb.config["initial_temperature"],
    final_temperature=wandb.config["final_temperature"],
).to(device)

# AdamW with a cosine learning-rate schedule spanning the whole run.
optimizer = optim.AdamW(
    model.parameters(),
    lr=wandb.config["lr"],
    weight_decay=wandb.config["weight_decay"],
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=wandb.config["steps"],
)

# Temperature annealing schedule: geometric interpolation that starts at
# initial_temp on step 0 and decays towards final_temp over the run.
initial_temp = wandb.config["initial_temperature"]
final_temp = wandb.config["final_temperature"]
total_steps = wandb.config["steps"]
temp_ratio = final_temp / initial_temp

for step in range(total_steps):
    # Fraction of training completed so far, in [0, 1).
    progress = step / total_steps
    current_temp = initial_temp * temp_ratio ** progress
    model.update_temperature(current_temp)

    # Fresh uniform-random batch of shape (batch_size, n, m) on every step.
    u = torch.rand(wandb.config["batch_size"], n, m, device=device)
    allocation = model(u)

    # Maximise the welfare value returned by get_nash_welfare by minimising
    # its negation.
    nash_welfare = get_nash_welfare(u, allocation, reduction="mean")
    loss = -nash_welfare

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the parameter update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    metrics = {
        "step": step,
        "loss": loss.item(),
        "nash_welfare": nash_welfare.item(),
        "lr": scheduler.get_last_lr()[0],
        "temperature": current_temp,
    }
    wandb.log(metrics)

wandb.finish()

0,1
loss,█▅▂▁▂▂▃▁▁▃▂▂▄▂▃▄▃▂▄▂▅▄▄▃▂▄▄▅▆▃▄▃▄▆▂▃▄▄▄▄
lr,█████████▇▇▇▇▇▆▆▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▁▁▁▁▁▁▁
nash_welfare,▁▄▆█▇▇▇█▇▆▇▇██▅▇▇▇█▆▇▆█▆▆▇▇▆▆▆▅▆▆▄▆▆▆▆▆▇
step,▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████
temperature,██▇▇▇▆▅▄▄▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,-1.76276
lr,0.0
nash_welfare,1.76276
step,19999.0
temperature,0.001


In [17]:
# Save the model's state_dict to a file.
# Only the weights are written; the constructor arguments are not stored, so
# whoever loads the file has to rebuild the model with matching arguments.
model_path = "fatransformer_model_20_10-new_001.pt"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

# To give the model to someone else, share the file named by `model_path`
# (here "fatransformer_model_20_10-new_001.pt").
# They can load it by rebuilding the model with the same constructor arguments
# that were used for training and then restoring the weights:
# model = FATransformer(n, m, wandb.config["d_model"], wandb.config["num_heads"], wandb.config["num_output_layers"], wandb.config["dropout"])
# model.load_state_dict(torch.load("fatransformer_model_20_10-new_001.pt", map_location=device))
# model.to(device)
# model.eval()


Model saved to fatransformer_model_20_10-new_001.pt


In [11]:
# Switch the model to evaluation mode, then overwrite its `temperature`
# attribute with 0.001 (the same value as "final_temperature" in the
# fa-transformer-temp training config).
# NOTE(review): this assigns the attribute directly, whereas the training loop
# goes through model.update_temperature(...) -- confirm the two are equivalent
# for FATransformer, and that eval() does not reset the temperature afterwards.
model.eval()
model.temperature = .001

In [15]:
# Forward a small random batch (10 instances) through the current model and
# show the output shape.
# The trailing dimensions come from the notebook-level `n` and `m` that the
# model was built with. The previous hard-coded (10, 10, 14) shape only fits a
# model constructed with n=10, m=14 and stops matching as soon as the training
# cell is run with a different `m` (it uses m=20).
sample = torch.rand(10, n, m).to(device)

model(sample).shape

torch.Size([10, 14, 10])

In [4]:
# Display the module tree (sub-module names and layer sizes) of the current model.
model

FATransformer(
  (agent_proj): Linear(in_features=10, out_features=768, bias=True)
  (item_proj): Linear(in_features=14, out_features=768, bias=True)
  (output_proj): Linear(in_features=768, out_features=14, bias=True)
  (agent_transformer): FASelfAttentionBlock(
    (attn): MHA(
      (q_proj): Linear(in_features=768, out_features=768, bias=False)
      (k_proj): Linear(in_features=768, out_features=768, bias=False)
      (v_proj): Linear(in_features=768, out_features=768, bias=False)
      (o_proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0, inplace=False)
    )
    (glu): GLU(
      (activation): SiLU()
      (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
      (up_proj): Linear(in_features=768, out_features=2048, bias=False)
      (down_proj): Linear(in_features=2048, out_features=768, bias=False)
    )
    (attn_norm): RMSNorm((768,), eps=None, elementwise_affine=True)
    (glu_norm): RMSNorm((768,), eps=None, elementw

In [None]:
# Updated training loop with temperature annealing
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
import numpy as np

# Open a Weights & Biases run. Every hyperparameter used further down in this
# cell is read back from wandb.config, so this mapping is the single place to
# tune them.
wandb.init(
    project="fa-transformer",
    config={
        "n": 10,
        "m": 14,
        "d_model": 768,
        "num_heads": 12,
        "dropout": 0.0,
        "lr": 1e-4,
        "weight_decay": 1e-2,
        "steps": 20000,
        "batch_size": 100,
        # Required by the FATransformer call below; 4 is the value used by the
        # other cells in this notebook that construct the model positionally.
        "num_output_layers": 4,
        "initial_temperature": 2.0,
        "final_temperature": 0.1,
    }
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n = wandb.config["n"]
m = wandb.config["m"]

# Positional order follows the other constructor calls in this notebook:
# (n, m, d_model, num_heads, num_output_layers, dropout). Previously
# num_output_layers was omitted, which shifted `dropout` (0.0) into its slot.
# final_temperature is now forwarded as well, so the model is built with the
# same end-of-schedule value the annealing loop uses.
model = FATransformer(
    n,
    m,
    wandb.config["d_model"],
    wandb.config["num_heads"],
    wandb.config["num_output_layers"],
    wandb.config["dropout"],
    initial_temperature=wandb.config["initial_temperature"],
    final_temperature=wandb.config["final_temperature"],
).to(device)

# AdamW with a cosine learning-rate schedule spanning the whole run.
optimizer = optim.AdamW(model.parameters(), lr=wandb.config["lr"], weight_decay=wandb.config["weight_decay"])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=wandb.config["steps"])

# Temperature annealing schedule: geometric interpolation that starts at
# initial_temp on step 0 and decays towards final_temp over the run.
initial_temp = wandb.config["initial_temperature"]
final_temp = wandb.config["final_temperature"]
total_steps = wandb.config["steps"]
temp_ratio = final_temp / initial_temp

for step in range(total_steps):
    # Fraction of training completed so far, in [0, 1).
    progress = step / total_steps
    current_temp = initial_temp * temp_ratio ** progress
    model.update_temperature(current_temp)

    # Fresh uniform-random batch of shape (batch_size, n, m) on every step.
    u = torch.rand(wandb.config["batch_size"], n, m, device=device)
    allocation = model(u)

    # Maximise the welfare value returned by get_nash_welfare by minimising
    # its negation.
    nash_welfare = get_nash_welfare(u, allocation, reduction="mean")
    loss = -nash_welfare

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the parameter update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    metrics = {
        "step": step,
        "loss": loss.item(),
        "nash_welfare": nash_welfare.item(),
        "lr": scheduler.get_last_lr()[0],
        "temperature": current_temp,
    }
    wandb.log(metrics)

wandb.finish()


In [None]:
# Test the updated FATransformer with final_temperature parameter
import torch

# Build a small model with an explicit final_temperature, plus one random input.
model = FATransformer(n=3, m=4, d_model=64, num_heads=4, final_temperature=0.01)
x = torch.rand(1, 3, 4)


def _report(output, show_shape=False):
    """Print the model's current temperature and the first row of `output`.

    When `show_shape` is true the output shape is printed between the two.
    """
    print(f"Temperature used: {model.temperature}")
    if show_shape:
        print(f"Output shape: {output.shape}")
    print(f"Sample output (first item): {output[0, 0, :]}")


print("Model parameters:")
print(f"Initial temperature: {model.initial_temperature}")
print(f"Final temperature: {model.final_temperature}")
print(f"Current temperature: {model.temperature}")

# Forward pass in the model's default (training) mode.
print("\nTraining mode:")
output_train = model(x)
_report(output_train, show_shape=True)

# Forward pass after eval(). The original note says eval mode automatically
# uses the final temperature -- that behaviour lives in FATransformer and is
# not verifiable from this cell.
print("\nEvaluation mode:")
model.eval()
output_eval = model(x)
_report(output_eval)

# Back in training mode, but with the temperature explicitly switched through
# set_eval_temperature() before the forward pass.
print("\nExplicit eval temperature:")
model.train()
model.set_eval_temperature()
output_explicit = model(x)
_report(output_explicit)

print("\nNotice how the evaluation outputs are more binary-like (closer to 0 or 1)")


In [None]:
# Test temperature annealing implementation
import torch
import torch.nn.functional as F

# Create a simple test model
class TestModel(torch.nn.Module):
    """Minimal stand-in for a temperature-scaled softmax head.

    The input is treated as logits, divided by the current temperature and
    passed through a softmax over the last dimension. Because the output is a
    deterministic function of the input and the temperature, calling the model
    on the same `x` at different temperatures isolates the temperature effect.
    (Previously `forward` ignored `x` and drew new random logits on each call,
    so successive outputs were not comparable.)

    The base class is referenced as `torch.nn.Module` so the cell only needs
    its own `import torch` rather than an `nn` alias imported elsewhere.

    Args:
        initial_temperature: Temperature used until `update_temperature` is
            called.
    """

    def __init__(self, initial_temperature=1.0):
        super().__init__()
        # Plain Python attribute (not a parameter or buffer).
        self.temperature = initial_temperature

    def update_temperature(self, temperature):
        """Replace the temperature used by subsequent forward passes."""
        self.temperature = temperature

    def forward(self, x):
        """Return softmax(x / temperature) over the last dimension of `x`."""
        return F.softmax(x / self.temperature, dim=-1)

def _mean_entropy(probs):
    """Entropy of each row of `probs` (1e-8 added inside the log), averaged over rows."""
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


# Exercise TestModel at a high and a low temperature, then along a short
# annealing schedule, printing the output and its mean entropy each time.
model = TestModel(initial_temperature=2.0)
x = torch.randn(2, 3)

print("Temperature = 2.0 (high, more uniform):")
model.update_temperature(2.0)
output1 = model(x)
print(output1)
print(f"Entropy: {_mean_entropy(output1):.4f}")

print("\nTemperature = 0.1 (low, more binary):")
model.update_temperature(0.1)
output2 = model(x)
print(output2)
print(f"Entropy: {_mean_entropy(output2):.4f}")

# Same geometric schedule as the training loop, sampled every 20 steps out of
# 100: temperature goes from 2.0 towards 0.1.
print("\nTemperature annealing schedule:")
for step in range(0, 100, 20):
    progress = step / 100
    temp = 2.0 * (0.1 / 2.0) ** progress
    model.update_temperature(temp)
    output = model(x)
    entropy = _mean_entropy(output)
    print(f"Step {step:3d}: Temp={temp:.3f}, Entropy={entropy:.4f}")
