In [1]:
import os

import haiku as hk
import jax
import jax.numpy as jnp
# sys.path.append(os.path.abspath('../'))
import numpy as np
import optax

import wandb
from losses import AssociativeRecallLoss
from model_rng import CustomTransformer  #, Transformer
from trainer_gd import TrainerGD

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

hk.vmap.require_split_rng = False

# Default CONFIG

| **Hyperparameter**                  | **Value**                                                                                     |
|------------------------------------|-----------------------------------------------------------------------------------------------|
| **Dataset**                        | Randomly generated binary value vectors with d = 5 and corresponding one-hot encoded keys     |
| **Tokenization & RSE**             | One token is the concatenated vector `[vi, ki, ri]`, where `ri` is the same random binary vector for all `i` with kd_r = 10` |
| **Context size**                   | Variable size from 8–20 (see Figure 1)                                                        |
| **Optimizer**                      | Adam (ε = 1e−5, β₁ = 0.9, β₂ = 0.95)                                                           |
| **Hyperparameters of our objective** | `m = 30` and `p = 100`                                                                        |
| **Batch size**                     | 512                                                                                           |
| **Gradient clipping**              | Global norm of 1                                                                              |
| **Positional encodings**           | Standard positional encodings                                                                 |
| **Architecture details**           | 2 transformer blocks with linear self-attention, 1 head, key size 5, token size 15, no input- but output-embedding |
| **Attention mask details**         | Causal mask                                                                                   |
| **Weight initialization**          | Truncated normal with variance from fan-in; biases zero; scale weight matrices before skip connection with `1 / (2√N)` where N = number of layers |
| **Learning rate scheduler**        | Linear warm-up from 0 to 0.003 for 2000 steps, then annealed to 0.0003                         |
| **Standard deviation / robustness**| Results averaged over 5 random seeds; standard deviation omitted due to negligible differences |


In [2]:
N = 5

cfg = {
    "qmc": False,
    "seed": 42,
    "num_jit_batches": 100,
    "num_steps": 10000,  #30000,

    # Env    
    "num_token": 1,
    "target_size": N,
    "batch_size": 512,
    "num_train_seed": 0,  # total: N * 2**B

    "probabilistic": "random",
    "data_pooling": "lp",  #"lp" "mean"
    "p": 1,
    "num_seed": 30,
    "loss": "bce",  # "contrastive_ce" "contrastive_hinge" "bce" "bce_mse" "mse"

    "hardcoded_randomness": False,
    "mlp": False,
    "widening_factor": 1,
    "first_mlp": True,
    "seed_size": 2 * N,
    'reverse_block': True,

    # RNG model
    "num_layers": 2,
    "num_heads": 1,
    "kq_dim": N,
    "v_dim": N,
    "embed_dim": 8 * N,  # 32,
    "softmax": "none",  #"all",
    "positional_embedding": False,
    "first_embedding_init_var": 1,
    "w_init_var": 1,

    "optim_algo": "gd",
    "GD_PARAM": {
        "lr": 3e-3,
        "betas": (0.9, 0.95),
        "eps": 1e-5,
        "grad_norm_clip": 1,
        "weight_decay": 0.01 * 0,
        "scheduler": "cosine",  # "warmup", "cosine"
        "lr_alpha": 0.1,
        "warmup_steps": 333, #1000,
    },

}


In [3]:
from itertools import product


def generate_one_hot_combinations(K, T):
    # Generate all possible sequences of indices with T elements where each element ranges from 0 to K-1
    indices = list(product(range(K), repeat=T))

    # Convert these indices to one-hot vectors
    def one_hot_encode(index):
        # Create a one-hot vector for each index in the sequence
        return jnp.eye(K)[jnp.array(index)]

    return jnp.array(indices)


def save_advanced_log(cfg, loss_fn, params):
    all_y_target = generate_one_hot_combinations(2, cfg["target_size"])
    all_query_idx = jnp.arange(cfg["num_token"])[:, None]
    rng_static, rng_dynamic = jax.random.split(jax.random.PRNGKey(0))

    all_log_dict = {}
    if cfg["probabilistic"] == "single_seed":
        num_seed = 100
    elif cfg["probabilistic"] == "deterministic":
        num_seed = 1
        rng_dynamic = jax.random.PRNGKey(cfg["seed"])
    else:
        num_seed = 100

    rng_list = jax.random.split(rng_dynamic, num_seed)
    for rng_seed in rng_list:
        all_log_dict_tmp = {}
        for rng_Y in jax.random.split(rng_static, 16):
            rng_Y = jax.vmap(lambda r: jax.random.split(r, all_y_target.shape[0]))(
                jax.random.split(rng_Y, all_query_idx.shape[0])
            )
            input, target = jax.vmap(
                jax.vmap(loss_fn.data_generator.build, in_axes=(0, None, 0, None)),
                in_axes=(0, None, None, 0),
            )(rng_Y, rng_seed, all_y_target, all_query_idx)
            loss, log_dict, prediction = jax.vmap(
                jax.vmap(loss_fn.get_loss_from_input, in_axes=(None, 0, 0)),
                in_axes=(None, 0, 0),
            )(params, input, target)
            log_dict["loss"] = loss
            log_dict["prediction"] = prediction
            log_dict["target"] = target[-1]

            for k, v in log_dict.items():
                if k not in all_log_dict_tmp:
                    all_log_dict_tmp[k] = []
                all_log_dict_tmp[k].append(v)

        for k, v in all_log_dict_tmp.items():
            if k not in all_log_dict:
                all_log_dict[k] = []
            all_log_dict[k].append(jnp.stack(v))

    for k, v in all_log_dict.items():
        all_log_dict[k] = jnp.stack(v)

    return all_log_dict


def train():
    run = wandb.init()
    cfg.update({k: v for (k, v) in wandb.config.items() if type(v) != dict})
    wandb.config.update(cfg)

    # create the transformer model
    model = hk.without_apply_rng(
        hk.transform(
            lambda x: CustomTransformer(
                out_dim=cfg["target_size"],
                num_layers=cfg["num_layers"],
                num_heads=cfg["num_heads"],
                kq_dim=cfg["kq_dim"],
                v_dim=cfg["v_dim"],
                embed_dim=cfg["embed_dim"],
                softmax=cfg["softmax"],
                positional_embedding=cfg["positional_embedding"],
                first_embedding_init_var=cfg["first_embedding_init_var"],
                w_init_var=cfg["w_init_var"],
                mlp=cfg["mlp"],
                widening_factor=cfg["widening_factor"],
                first_mlp=cfg["first_mlp"],
                reverse_block=cfg["reverse_block"],
            )(x)
        )
    )

    train_param = cfg["GD_PARAM"]
    if train_param["scheduler"] == "cosine":
        learning_rate = optax.join_schedules(
            [
                optax.linear_schedule(
                    0, train_param["lr"], train_param["warmup_steps"]
                ),
                optax.cosine_decay_schedule(
                    train_param["lr"],
                    cfg["num_steps"] - train_param["warmup_steps"],
                    alpha=train_param["lr"] * train_param["lr_alpha"],
                ),
            ],
            boundaries=[train_param["warmup_steps"]],
        )
    elif train_param["scheduler"] == "warmup":
        learning_rate = optax.linear_schedule(
            0, train_param["lr"], train_param["warmup_steps"]
        )
    else:
        learning_rate = train_param["lr"]

    gd_optimizer = optax.inject_hyperparams(
        lambda lr: optax.chain(
            optax.clip_by_global_norm(train_param["grad_norm_clip"]),
            optax.adamw(
                lr,
                weight_decay=train_param["weight_decay"],
                b1=train_param["betas"][0],
                b2=train_param["betas"][1],
                eps=train_param["eps"],
            ),
        )
    )(lr=learning_rate)

    loss_fn = AssociativeRecallLoss(model, cfg)
    trainer = TrainerGD(model, gd_optimizer, loss_fn, cfg)

    t = 0
    while t < cfg["num_steps"] // cfg["num_jit_batches"]:
        log_metric = trainer.train_iter(cfg["num_jit_batches"])
        eval_metric = loss_fn.eval_fn(trainer.get_params(), 10)

        log_dict = {}
        # log_dict.update({k: v.mean().item() for k, v in log_metric.items()})
        log_dict.update({k + "_eval": v.mean().item() for k, v in eval_metric.items()})
        wandb.log(log_dict)
        print(f"Step {t}: {log_dict}")
        if log_dict.get("data_loss_eval", 1) < 0.001:
            print("Early stopping due to low eval loss")
            break
        t += 1

    log_dict = {}
    train_metric = loss_fn.eval_fn(trainer.get_params(), 1000, eval_on_train=True)
    eval_metric = loss_fn.eval_fn(trainer.get_params(), 1000, eval_on_train=False)
    log_dict.update(train_metric)
    log_dict.update({k + "_eval": v for k, v in eval_metric.items()})
    wandb.log({"final_" + k: v for k, v in log_dict.items()})

    print(f"Saving run {run.id}")
    np.save(f"checkpoints/associative_recall_{run.id}_{cfg['probabilistic']}_q{cfg['p']}_c{cfg['num_token']}", (cfg, trainer.get_params()))
    adv_log_dict = save_advanced_log(cfg, loss_fn, trainer.get_params())
    np.save(f"checkpoints/associative_recall_{run.id}_log_dict", adv_log_dict)

In [10]:
train()

0,1
data_loss_eval,▁
data_loss_max_eval,▁
data_loss_median_eval,▁
inacc_eval,▁
inacc_harsh_eval,▁
inacc_harsh_max_eval,▁
inacc_harsh_median_eval,▁
inacc_lenient_eval,▁
inacc_lenient_max_eval,▁
inacc_lenient_median_eval,▁

0,1
data_loss_eval,0.4705
data_loss_max_eval,0.85633
data_loss_median_eval,0.48148
inacc_eval,0.64131
inacc_harsh_eval,0.98824
inacc_harsh_max_eval,1.0
inacc_harsh_median_eval,1.0
inacc_lenient_eval,0.08727
inacc_lenient_max_eval,1.0
inacc_lenient_median_eval,0.0


Number of parameters: 4.12k
Step 0: {'data_loss_eval': 0.47049883008003235, 'data_loss_max_eval': 0.8563265800476074, 'data_loss_median_eval': 0.48147857189178467, 'inacc_eval': 0.6413143873214722, 'inacc_harsh_eval': 0.9882421493530273, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.08726562559604645, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.6651670932769775}
Step 1: {'data_loss_eval': 0.0015417593531310558, 'data_loss_max_eval': 0.08072300255298615, 'data_loss_median_eval': 6.43824678263627e-05, 'inacc_eval': 0.00044726557098329067, 'inacc_harsh_eval': 0.013417968526482582, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 0.0, 'inacc_lenient_eval': 0.0, 'inacc_lenient_max_eval': 0.0, 'inacc_lenient_median_eval': 0.0, 'inacc_max_eval': 0.03333333507180214, 'inacc_median_eval': 0.0}
Step 2: {'data_loss_eval': 0.0007402288028970361, 'data_loss_max_eval': 0.02694686502218246

KeyboardInterrupt: 

In [None]:
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep_AR_token',
    'parameters': dict(
        probabilistic={'values': ["single_seed", "deterministic", "random"]},
        num_token={'values': [12, 14, 16, 18, 20]},
        p={'values': [1, 16, 32, 100]},
        # seed={'values': [10, 11, 12, 13, 14]},
    )
}
sweep_id = wandb.sweep(sweep=sweep_configuration)
wandb.agent(sweep_id=sweep_id, function=train, count=60)

Create sweep with ID: 87jv6nbf
Sweep URL: https://wandb.ai/team-epoch-iv/uncategorized/sweeps/87jv6nbf


wandb: Agent Starting Run: swey6mwp with config:
wandb: 	num_token: 12
wandb: 	p: 1
wandb: 	probabilistic: single_seed




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.695654571056366, 'data_loss_max_eval': 0.9721782803535461, 'data_loss_median_eval': 0.6952534317970276, 'inacc_eval': 0.96791011095047, 'inacc_harsh_eval': 0.96791011095047, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.96791011095047, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.6929435729980469, 'data_loss_max_eval': 0.9733137488365173, 'data_loss_median_eval': 0.6929234862327576, 'inacc_eval': 0.9653124809265137, 'inacc_harsh_eval': 0.9653124809265137, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9653124809265137, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6915267109870911, 'data_loss_max_eval': 1.0888051986694336, 'data_loss_median_eval': 0.6915991902351

0,1
data_loss_eval,███▅▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▂▃▆▅█▆▅▄▃▆▄▅▇▇██▆▆▅▄▆▄▅▃▅▃▃▃▃▃▃▂▂▁▃▂▁▁▁▁
data_loss_median_eval,██▇▆▆▅▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.0239
data_loss_max_eval,0.67185
data_loss_median_eval,0.00076
final_data_loss,0.02391
final_data_loss_eval,0.02391
final_data_loss_max,1.15958
final_data_loss_max_eval,1.15958
final_data_loss_median,0.00076
final_data_loss_median_eval,0.00076
final_inacc,0.07323


wandb: Agent Starting Run: gskyrpcx with config:
wandb: 	num_token: 12
wandb: 	p: 1
wandb: 	probabilistic: deterministic




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.681771993637085, 'data_loss_max_eval': 1.102307677268982, 'data_loss_median_eval': 0.6826086044311523, 'inacc_eval': 0.9495898485183716, 'inacc_harsh_eval': 0.9495898485183716, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9495898485183716, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.662310004234314, 'data_loss_max_eval': 1.2588285207748413, 'data_loss_median_eval': 0.6604387760162354, 'inacc_eval': 0.9226366877555847, 'inacc_harsh_eval': 0.9226366877555847, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9226366877555847, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6071664094924927, 'data_loss_max_eval': 1.9593490362167358, 'data_loss_median_eval': 0.599267542

0,1
data_loss_eval,█▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▃█▇▅▆█▄▇▅▆▇▆▅▄▅▅▄▆▆▇▄▄▄▃▄▃▄▃▂▂▃▃▂▄▂▁▂▁▁▁
data_loss_median_eval,██▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.01498
data_loss_max_eval,0.74943
data_loss_median_eval,0.0001
final_data_loss,0.01507
final_data_loss_eval,0.01507
final_data_loss_max,1.11914
final_data_loss_max_eval,1.11914
final_data_loss_median,0.0001
final_data_loss_median_eval,0.0001
final_inacc,0.04536


wandb: Agent Starting Run: n44x6f1b with config:
wandb: 	num_token: 12
wandb: 	p: 1
wandb: 	probabilistic: random




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.6947222948074341, 'data_loss_max_eval': 0.7439357042312622, 'data_loss_median_eval': 0.6946430802345276, 'inacc_eval': 0.9676685929298401, 'inacc_harsh_eval': 1.0, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.5380859375, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.8799999952316284, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.9960000514984131}
Step 1: {'data_loss_eval': 0.6784694790840149, 'data_loss_max_eval': 1.0439389944076538, 'data_loss_median_eval': 0.6795017719268799, 'inacc_eval': 0.9395832419395447, 'inacc_harsh_eval': 0.9993945360183716, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.6862499713897705, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6040310859680176, 'data_loss_max_eval': 1.4325844049453735, 'data_loss_median_eval':

0,1
data_loss_eval,█▇▆▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▁▅▄▄▄▅▅▅▅▄▄▅▅▄▅▄▄▄▄▃▃▄▄▄█▃▃▃▂▁▂▂▁▁▂▁▁▁▁▁
data_loss_median_eval,█▆▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.02357
data_loss_max_eval,0.82631
data_loss_median_eval,0.00067
final_data_loss,0.02352
final_data_loss_eval,0.02352
final_data_loss_max,1.07826
final_data_loss_max_eval,1.07826
final_data_loss_median,0.00065
final_data_loss_median_eval,0.00065
final_inacc,0.06937


wandb: Agent Starting Run: ci78c0kp with config:
wandb: 	num_token: 12
wandb: 	p: 16
wandb: 	probabilistic: single_seed




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.695116400718689, 'data_loss_max_eval': 0.8646141290664673, 'data_loss_median_eval': 0.6947752237319946, 'inacc_eval': 0.967578113079071, 'inacc_harsh_eval': 0.967578113079071, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.967578113079071, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.6933043599128723, 'data_loss_max_eval': 0.7486452460289001, 'data_loss_median_eval': 0.6931825876235962, 'inacc_eval': 0.9670116901397705, 'inacc_harsh_eval': 0.9670116901397705, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9670116901397705, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6907561421394348, 'data_loss_max_eval': 0.8367813229560852, 'data_loss_median_eval': 0.6915228366

0,1
data_loss_eval,██▇▇▆▅▅▅▅▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▅▅▅▇▆▅▄▆▅█▆▅▆▆▅▅▆▅▆▄▅▄▅▆▃▄▄▃▃▃▂▃▂▃▃▂▁▁▁▁
data_loss_median_eval,███▇▆▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.07519
data_loss_max_eval,0.52305
data_loss_median_eval,0.05704
final_data_loss,0.07481
final_data_loss_eval,0.07481
final_data_loss_max,0.69708
final_data_loss_max_eval,0.69708
final_data_loss_median,0.05681
final_data_loss_median_eval,0.05681
final_inacc,0.07513


wandb: Agent Starting Run: 1w2vjwbu with config:
wandb: 	num_token: 12
wandb: 	p: 16
wandb: 	probabilistic: deterministic




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.6883750557899475, 'data_loss_max_eval': 0.7647943496704102, 'data_loss_median_eval': 0.6888426542282104, 'inacc_eval': 0.9477148056030273, 'inacc_harsh_eval': 0.9477148056030273, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9477148056030273, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.6811529397964478, 'data_loss_max_eval': 0.8053514361381531, 'data_loss_median_eval': 0.6816191077232361, 'inacc_eval': 0.924121081829071, 'inacc_harsh_eval': 0.924121081829071, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.924121081829071, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6250389218330383, 'data_loss_max_eval': 0.9337337613105774, 'data_loss_median_eval': 0.626499414

0,1
data_loss_eval,█▇▇▆▅▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▃▆▅▆▆▄█▅▄▄▆▇▄▅▅▃▂▂▂▃▄▅▄▄▅▃▆▂▅▄▃▅▃▁▁▁▁▁▁▁
data_loss_median_eval,██▇▆▅▅▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.07125
data_loss_max_eval,0.6734
data_loss_median_eval,0.05413
final_data_loss,0.07149
final_data_loss_eval,0.07149
final_data_loss_max,0.6734
final_data_loss_max_eval,0.6734
final_data_loss_median,0.05425
final_data_loss_median_eval,0.05425
final_inacc,0.07489


wandb: Agent Starting Run: muuopk9g with config:
wandb: 	num_token: 12
wandb: 	p: 16
wandb: 	probabilistic: random




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.6946938633918762, 'data_loss_max_eval': 0.7392885088920593, 'data_loss_median_eval': 0.6947845816612244, 'inacc_eval': 0.9681358933448792, 'inacc_harsh_eval': 1.0, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.5093945264816284, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.5699999928474426, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.9856668710708618}
Step 1: {'data_loss_eval': 0.6884190440177917, 'data_loss_max_eval': 0.7547599077224731, 'data_loss_median_eval': 0.6890026926994324, 'inacc_eval': 0.9504683613777161, 'inacc_harsh_eval': 1.0, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.4812109172344208, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.3349999785423279, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.97783362865448}
Step 2: {'data_loss_eval': 0.6522217988967896, 'data_loss_max_eval': 0.8507391810417175, 'data

wandb: Network error (ConnectionError), entering retry loop.


Step 41: {'data_loss_eval': 0.21565498411655426, 'data_loss_max_eval': 1.1550358533859253, 'data_loss_median_eval': 0.20217901468276978, 'inacc_eval': 0.1758965104818344, 'inacc_harsh_eval': 0.35712888836860657, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 0.0, 'inacc_lenient_eval': 0.06246093660593033, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.0}


wandb: Network error (ConnectionError), entering retry loop.


Step 42: {'data_loss_eval': 0.20100079476833344, 'data_loss_max_eval': 0.8583104610443115, 'data_loss_median_eval': 0.1876661777496338, 'inacc_eval': 0.15807361900806427, 'inacc_harsh_eval': 0.3283007740974426, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 0.0, 'inacc_lenient_eval': 0.056328125298023224, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.0}
Step 43: {'data_loss_eval': 0.19296611845493317, 'data_loss_max_eval': 0.7065832018852234, 'data_loss_median_eval': 0.18050368130207062, 'inacc_eval': 0.13206776976585388, 'inacc_harsh_eval': 0.2913867235183716, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 0.0, 'inacc_lenient_eval': 0.041386716067790985, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.0}
Step 44: {'data_loss_eval': 0.20580463111400604, 'data_loss_max_eval': 0.8485163450241089, 'data_loss_median_eval': 0.19261400401592255, 'i

0,1
data_loss_eval,█▇▇▅▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▄▆▇▆▇▆▅▆▆▇▆▆▇▆▅█▄▆▄▅▅▅▄▄▄▅▅▅▄▃▃▃▂▂▂▁▁▂▁▁
data_loss_median_eval,█▇▆▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.07701
data_loss_max_eval,0.52878
data_loss_median_eval,0.06014
final_data_loss,0.07659
final_data_loss_eval,0.07659
final_data_loss_max,0.69064
final_data_loss_max_eval,0.69064
final_data_loss_median,0.05977
final_data_loss_median_eval,0.05977
final_inacc,0.07582


wandb: Agent Starting Run: rzj5nb6b with config:
wandb: 	num_token: 12
wandb: 	p: 32
wandb: 	probabilistic: single_seed




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.695035994052887, 'data_loss_max_eval': 0.8318287134170532, 'data_loss_median_eval': 0.6947414875030518, 'inacc_eval': 0.9683398008346558, 'inacc_harsh_eval': 0.9683398008346558, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9683398008346558, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.6934254765510559, 'data_loss_max_eval': 0.7470057606697083, 'data_loss_median_eval': 0.6933587193489075, 'inacc_eval': 0.9669726490974426, 'inacc_harsh_eval': 0.9669726490974426, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9669726490974426, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6928434371948242, 'data_loss_max_eval': 0.7237527370452881, 'data_loss_median_eval': 0.6928899

0,1
data_loss_eval,████▇▇▆▆▆▆▅▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▆▅▄▅▅▆▅▆▇█▆▇▇▅▇▆▇▆▇▆▅▅▆▄▄▅▄▄▃▂▃▄▂▃▂▂▂▂▁▁
data_loss_median_eval,████▆▆▆▅▅▄▄▄▄▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.09332
data_loss_max_eval,0.53412
data_loss_median_eval,0.07767
final_data_loss,0.09324
final_data_loss_eval,0.09324
final_data_loss_max,0.74068
final_data_loss_max_eval,0.74068
final_data_loss_median,0.07698
final_data_loss_median_eval,0.07698
final_inacc,0.07537


wandb: Agent Starting Run: g0famop6 with config:
wandb: 	num_token: 12
wandb: 	p: 32
wandb: 	probabilistic: deterministic




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.6914695501327515, 'data_loss_max_eval': 0.7289711236953735, 'data_loss_median_eval': 0.6915869116783142, 'inacc_eval': 0.9555078148841858, 'inacc_harsh_eval': 0.9555078148841858, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9555078148841858, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 1: {'data_loss_eval': 0.6851370334625244, 'data_loss_max_eval': 0.7383896708488464, 'data_loss_median_eval': 0.6858466863632202, 'inacc_eval': 0.9114062190055847, 'inacc_harsh_eval': 0.9114062190055847, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.9114062190055847, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 1.0, 'inacc_max_eval': 1.0, 'inacc_median_eval': 1.0}
Step 2: {'data_loss_eval': 0.6562421321868896, 'data_loss_max_eval': 0.8152233958244324, 'data_loss_median_eval': 0.657315

0,1
data_loss_eval,█▇▇▆▆▆▆▅▅▅▅▅▅▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
data_loss_max_eval,▄▆▅▅▅▆▆▅▆▆▅▇▇▆▅▄▇▆▆▅▄▅▄▇▄▅▄▄█▃▂▅▃▂▃▃▁▂▂▂
data_loss_median_eval,█▇▇▇▆▆▆▅▅▅▅▅▅▅▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
final_data_loss,▁
final_data_loss_eval,▁
final_data_loss_max,▁
final_data_loss_max_eval,▁
final_data_loss_median,▁
final_data_loss_median_eval,▁
final_inacc,▁

0,1
data_loss_eval,0.1016
data_loss_max_eval,0.60134
data_loss_median_eval,0.08601
final_data_loss,0.10116
final_data_loss_eval,0.10116
final_data_loss_max,0.69525
final_data_loss_max_eval,0.69525
final_data_loss_median,0.0855
final_data_loss_median_eval,0.0855
final_inacc,0.06688


wandb: Agent Starting Run: cfl1y4cr with config:
wandb: 	num_token: 12
wandb: 	p: 32
wandb: 	probabilistic: random




Number of parameters: 4.57k
Step 0: {'data_loss_eval': 0.6947854161262512, 'data_loss_max_eval': 0.7366889715194702, 'data_loss_median_eval': 0.6949082016944885, 'inacc_eval': 0.9683920741081238, 'inacc_harsh_eval': 1.0, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.5001171827316284, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.45499998331069946, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.9818335771560669}
Step 1: {'data_loss_eval': 0.6904545426368713, 'data_loss_max_eval': 0.7316902279853821, 'data_loss_median_eval': 0.690800130367279, 'inacc_eval': 0.955399215221405, 'inacc_harsh_eval': 1.0, 'inacc_harsh_max_eval': 1.0, 'inacc_harsh_median_eval': 1.0, 'inacc_lenient_eval': 0.4202539026737213, 'inacc_lenient_max_eval': 1.0, 'inacc_lenient_median_eval': 0.029999999329447746, 'inacc_max_eval': 1.0, 'inacc_median_eval': 0.9676669836044312}
Step 2: {'data_loss_eval': 0.6709215641021729, 'data_loss_max_eval': 0.763923168182373, 'da

In [None]:
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep_AR_P',
    'parameters': dict(
        probabilistic={'values': ["single_seed", "deterministic", "random"]},
        num_token={'values': [15]},
        p={'values': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, -1]},
        # seed={'values': [10, 11, 12, 13, 14]},
    )
}
sweep_id = wandb.sweep(sweep=sweep_configuration)

In [None]:
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep_AR_BS',
    'parameters': dict(
        probabilistic={'values': ["single_seed", "deterministic", "random"]},
        num_token={'values': [15]},
        batch_size={'values': [20, 30, 40, 60, 90, 120, 180, 270, 360, 540, 810, 1080]},
        p={'values': [1, 10, 100]},
        data_pooling={'values': ["lp"]},
        seed={'values': [10, 11, 12, 13, 14]},
    )
}
sweep_id = wandb.sweep(sweep=sweep_configuration)

In [None]:
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep_AR_M',
    'parameters': dict(
        probabilistic={'values': ["single_seed", "deterministic", "random"]},
        num_token={'values': [15]},
        num_seed={'values': [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]},
        p={'values': [1, 10, 100]},
        data_pooling={'values': ["lp"]},
        seed={'values': [10, 11, 12, 13, 14]},
    )
}
sweep_id = wandb.sweep(sweep=sweep_configuration)

In [None]:
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep_AR_numseed',
    'parameters': dict(
        probabilistic={'values': ["single_seed", "deterministic", "random"]},
        num_token={'values': [15]},
        seed_size={'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
        p={'values': [1, 10, 100]},
        data_pooling={'values': ["lp"]},
        seed={'values': [10, 11, 12, 13, 14]},
    )
}
sweep_id = wandb.sweep(sweep=sweep_configuration)