In [1]:
from model_components import GLU, MHA
from fatransformer import FATransformer
import torch.nn as nn
import torch.nn.functional as F
import torch
from helpers import get_nash_welfare
from typing import Optional

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
import wandb
import torch.optim as optim

# Define the sweep configuration for Bayesian optimization
sweep_config = {
    "method": "bayes",
    "metric": {"goal": "maximize", "name": "nash_welfare"},
    "parameters": {
        "d_model": {
            "values": [768, 1536]
        },
        "num_heads": {
            "values": [4, 8, 12, 16]
        },
        "lr": {
            "min": 1e-5,
            "max": 1e-4,
            "distribution": "log_uniform_values"
        },
        "batch_size": {
            "values": [512, 1024, 2048, 4096]
        },
        "num_output_layers": {
            "values": [1, 2, 3, 4, 5]
        },
        "weight_decay": {
            "min": 1e-4,
            "max": 1e-1,
            "distribution": "log_uniform_values"
        },
        "n": {
            "value": 10
        },
        "m": {
            "value": 14
        },
        "dropout": {
            "value": 0.0
        },
        "steps": {
            "value": 20000
        }
    }
}

# Register the sweep defined above with W&B under the "fa-transformer-sweep"
# project; the returned ID is what the `wandb.agent` call at the end of this
# cell is given.
sweep_id = wandb.sweep(sweep_config, project="fa-transformer-sweep")

def train(patience: int = 500, min_delta: float = 1e-5, ema_decay: float = 0.99):
    """Run one sweep trial: train an ``FATransformer`` to maximise mean Nash welfare.

    Called with no arguments by ``wandb.agent``; every model/optimiser
    hyperparameter is read from the run config populated by the sweep. Each
    step samples a fresh batch of utilities ``u`` uniformly in ``[0, 1)`` with
    shape ``(batch_size, n, m)``, runs the model to get an allocation, and
    minimises the negated mean Nash welfare returned by ``get_nash_welfare``.

    Early stopping tracks an exponential moving average (EMA) of the per-batch
    Nash welfare rather than the raw value. Every batch is freshly sampled, so
    the raw value is noisy. Comparing it against its own running maximum with a
    short patience treats noise as a plateau and ends runs long before
    ``config.steps``, and before the cosine LR schedule has annealed.

    Args:
        patience: Consecutive steps without the EMA improving by more than
            ``min_delta`` before training stops early.
        min_delta: Minimum EMA increase that counts as an improvement.
        ema_decay: EMA smoothing factor in ``[0, 1)``; closer to 1 is smoother.

    Raises:
        ValueError: If ``ema_decay`` is outside ``[0, 1)``.
    """
    if not 0.0 <= ema_decay < 1.0:
        raise ValueError(f"ema_decay must be in [0, 1), got {ema_decay}")

    # The context manager finishes the run (marking it failed on error) even if
    # training raises or is interrupted, e.g. with Ctrl+C from the notebook.
    with wandb.init() as run:
        config = run.config
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n, m = config.n, config.m

        model = FATransformer(
            n, m, config.d_model, config.num_heads, config.num_output_layers, config.dropout
        ).to(device)
        optimizer = optim.AdamW(
            model.parameters(), lr=config.lr, weight_decay=config.weight_decay
        )
        # One scheduler step per training step, so the cosine spans the full budget.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.steps)

        nash_welfare_ema = None
        best_ema = float("-inf")
        steps_without_improvement = 0

        for step in range(config.steps):
            u = torch.rand(config.batch_size, n, m, device=device)
            allocation = model(u)
            nash_welfare = get_nash_welfare(u, allocation, reduction="mean")
            loss = -nash_welfare  # maximise welfare == minimise its negation

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # A single host/device sync per step; the scalar is reused below.
            nash_welfare_value = nash_welfare.item()
            if nash_welfare_ema is None:
                nash_welfare_ema = nash_welfare_value
            else:
                nash_welfare_ema = (
                    ema_decay * nash_welfare_ema + (1.0 - ema_decay) * nash_welfare_value
                )

            run.log({
                "step": step,
                "loss": -nash_welfare_value,
                "nash_welfare": nash_welfare_value,
                "nash_welfare_ema": nash_welfare_ema,
                "lr": scheduler.get_last_lr()[0],
            })

            # Early stopping on the smoothed metric, not the noisy per-batch one.
            if nash_welfare_ema > best_ema + min_delta:
                best_ema = nash_welfare_ema
                steps_without_improvement = 0
            else:
                steps_without_improvement += 1

            if steps_without_improvement >= patience:
                print(
                    f"Early stopping at step {step} with best nash_welfare EMA {best_ema:.6f}"
                )
                break

# Launch a sweep agent that calls `train` once per trial of the sweep
# registered above (this can also be run from a plain script).
# NOTE(review): no `count` is passed, so the agent presumably keeps starting
# trials until it is interrupted; the saved output of this cell shows it being
# stopped with Ctrl+C. Consider `count=N` to bound a notebook run.
wandb.agent(sweep_id, function=train)


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: obmpqm8u
Sweep URL: https://wandb.ai/dieplstks/fa-transformer-sweep/sweeps/obmpqm8u


[34m[1mwandb[0m: Agent Starting Run: fexst8pn with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	d_model: 1536
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	lr: 3.429339118090045e-05
[34m[1mwandb[0m: 	m: 14
[34m[1mwandb[0m: 	n: 10
[34m[1mwandb[0m: 	num_heads: 16
[34m[1mwandb[0m: 	num_output_layers: 5
[34m[1mwandb[0m: 	steps: 20000
[34m[1mwandb[0m: 	weight_decay: 0.014630721343048138
[34m[1mwandb[0m: Currently logged in as: [33mdieplstks[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7fb8fc5f0710>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fb90af354c0, execution_count=2 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7fb90b256780, raw_cell="import wandb
import torch.optim as optim

# Define.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu/home/dipplestix/Projects/fair-allocation-transformer/fatransformer/faformer_sweep.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

Exception in thread Exception in threading.excepthook:
Exception ignored in thread started by: <bound method Thread._bootstrap of <Thread(Thread-5 (_run_job), stopped 140432411457088)>>
Traceback (most recent call last):
  File "/home/dipplestix/Projects/paper_reimps/.conda/lib/python3.12/threading.py", line 1030, in _bootstrap
    self._bootstrap_inner()
  File "/home/dipplestix/Projects/paper_reimps/.conda/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self._invoke_excepthook(self)
  File "/home/dipplestix/Projects/paper_reimps/.conda/lib/python3.12/threading.py", line 1389, in invoke_excepthook
    local_print("Exception in threading.excepthook:",
  File "/home/dipplestix/Projects/paper_reimps/.conda/lib/python3.12/site-packages/ipykernel/iostream.py", line 604, in flush
    self.pub_thread.schedule(self._flush)
  File "/home/dipplestix/Projects/paper_reimps/.conda/lib/python3.12/site-packages/ipykernel/iostream.py", line 267, in schedule
    self._event_pipe.send(