In [4]:
import math, copy

base_config = {
        "dataset": "c4_subset",
        "batch_size": 32,  # physical batch size 256
        "learning_rate": 0.001 * math.sqrt(4),
        "min_lr": 1e-5,
        "lr_schedule": "cosine",
        "warmup_epochs": 1,
        "warmup_epochs_frac": 0.1,
        "weight_decay": 0.1,
        "hidden_dim": 64,  # Base hidden dimension
        "num_layers": 4,  # Base number of layers
        "num_heads": 4,
        "dropout": 0.0,
        "seq_length": 128,
        "wikitext_limit": 5 * 10**7,
        "pos_encoding": "rotary",
        "init_scheme": "transformer_scaled",
        "stride": 64,
        "pin_memory": True,
        "compile": False,
        "prefetch_factor": 8,
        "min_epochs": 2,
        "max_epochs": 2,
        "use_gradient_clipping": True,
        "gradient_clip_val": 1.0,
        "label_smoothing": 0.0,
        "gradient_accumulation_steps": 16,
        "optimizer": "adamw",
        "activation": "gelu",
        "norm_type": "layer",
        "results_folder": "Former_Experiments_Folder",
        "csv_log_interval": 50,
        "seed": 789,
    }

def chinchilla_scale(base_cfg, hidden_dims):
    """
    Return a list of configs that satisfy:
      • tokens ≈ 20 × parameters
      • per-step compute budget unchanged vs. baseline
      • depth/width ratio fixed (layers ∝ hidden_dim)
    """

    def param_count(d, L):
        # crude but width-dominant: 12·L·d²  (ignores embeddings/out-proj)
        return 12 * L * d**2

    base_d = base_cfg["hidden_dim"]
    base_L = base_cfg["num_layers"]
    base_bsz = base_cfg["batch_size"]
    base_lr = base_cfg["learning_rate"]
    base_clip = base_cfg["gradient_clip_val"]
    seq_len = base_cfg["seq_length"]

    out = []
    for d in hidden_dims:
        width_scale = d / base_d

        # 1) Depth: keep L ∝ d   (so aspect-ratio is preserved)
        L = max(1, int(round(base_L * width_scale)))

        # 2) Keep per-step FLOPs ≈ const ⇒ batch ∝ 1 / (width² · depth/base_depth)
        flops_scale = (width_scale**2) * (L / base_L)
        bsz = max(1, int(round(base_bsz / flops_scale)))

        # 3) LR & grad-clip heuristics
        lr = base_lr * (base_d / d) ** 0.5
        clip = base_clip * math.sqrt(width_scale)

        # 4) Chinchilla target tokens  (≈ 20 × parameters)
        params = param_count(d, L)
        tgt_tok = int(20 * params)

        # 5) Convert token target into epochs
        tokens_per_step = bsz * seq_len
        est_steps = math.ceil(tgt_tok / tokens_per_step)
        max_epochs = math.ceil(
            est_steps / (len(base_cfg.get("dataset", [])) or 1)
        )  # adjust as needed

        cfg = copy.deepcopy(base_cfg)
        cfg.update(
            {
                "hidden_dim": d,
                "num_layers": L,
                "num_heads": max(1, d // 16),
                "batch_size": bsz,
                "learning_rate": lr,
                "gradient_clip_val": clip,
                "target_tokens": tgt_tok,
                "max_epochs": max(max_epochs, cfg.get("min_epochs", 1)),
            }
        )
        out.append(cfg)
    return out


print(chinchilla_scale(base_config, [256]))

[{'dataset': 'c4_subset', 'batch_size': 1, 'learning_rate': 0.001, 'min_lr': 1e-05, 'lr_schedule': 'cosine', 'warmup_epochs': 1, 'warmup_epochs_frac': 0.1, 'weight_decay': 0.1, 'hidden_dim': 256, 'num_layers': 16, 'num_heads': 16, 'dropout': 0.0, 'seq_length': 128, 'wikitext_limit': 50000000, 'pos_encoding': 'rotary', 'init_scheme': 'transformer_scaled', 'stride': 64, 'pin_memory': True, 'compile': False, 'prefetch_factor': 8, 'min_epochs': 2, 'max_epochs': 218454, 'use_gradient_clipping': True, 'gradient_clip_val': 2.0, 'label_smoothing': 0.0, 'gradient_accumulation_steps': 16, 'optimizer': 'adamw', 'activation': 'gelu', 'norm_type': 'layer', 'results_folder': 'Former_Experiments_Folder', 'csv_log_interval': 50, 'seed': 789, 'target_tokens': 251658240}]


In [9]:
import math

# ---------- common settings ----------
COMMON = dict(
    dataset             = "c4_subset",
    lr_schedule         = "cosine",
    warmup_epochs       = 1,
    warmup_epochs_frac  = 0.10,
    weight_decay        = 0.10,
    dropout             = 0.0,          # bump to 0.1-0.2 for >100 M tokens if needed
    seq_length          = 128,
    pos_encoding        = "rotary",
    init_scheme         = "transformer_scaled",
    stride              = 64,
    pin_memory          = True,
    compile             = False,
    prefetch_factor     = 8,
    min_epochs          = 1,
    max_epochs          = 1,
    use_gradient_clipping = True,
    gradient_clip_val   = 1.0,
    label_smoothing     = 0.0,
    optimizer           = "adamw",
    activation          = "gelu",
    norm_type           = "layer",
    results_folder      = "Former_Experiments_Folder",
    csv_log_interval    = 50,
    seed                = 789,
)

GPT2_VOCAB_SIZE = 50257

def make_cfg(d_model, n_layers, vocab:int = GPT2_VOCAB_SIZE):
    heads = max(1, d_model // 16)
    lr    = 0.001 * math.sqrt(d_model / 16)
    params = 12 * d_model * d_model * n_layers + vocab * d_model    # rough GPT-style count
    tokens = 20 * params                           # Chin. optimal compute
    eff_bs = 256                                   # keep effective batch ~constant
    phys_bs = 32     
    
    print("Overrides:")     
    print(dict(
        hidden_dim  = d_model,
        num_layers  = n_layers,
        num_heads   = heads,
        learning_rate = lr,
        batch_size  = phys_bs,
        gradient_accumulation_steps = eff_bs // phys_bs,
        train_tokens = tokens,         # ← 1 epoch budget
    ))        
    print(params, "for ", d_model, "d model")                 # fits typical 40 GB A100 w/ acc-16

    return dict(
        COMMON,
        hidden_dim  = d_model,
        num_layers  = n_layers,
        num_heads   = heads,
        learning_rate = lr,
        batch_size  = phys_bs,
        gradient_accumulation_steps = eff_bs // phys_bs,
        train_tokens = tokens,         # ← 1 epoch budget
    )


#choose layer depth so that it is roughly proportional to the hidden dimension cubed
CONFIGS = {
    "dim16"  : make_cfg(16,  2),
    "dim32"  : make_cfg(32,  3),
    "dim64"  : make_cfg(64,  4),   # original width/depth
    "dim128" : make_cfg(128, 8),
}

# Pretty-print if you run this file directly
if __name__ == "__main__":
    from pprint import pprint
    pprint(CONFIGS, width=120, sort_dicts=False)


Overrides:
{'hidden_dim': 16, 'num_layers': 2, 'num_heads': 1, 'learning_rate': 0.001, 'batch_size': 32, 'gradient_accumulation_steps': 8, 'train_tokens': 16205120}
810256 for  16 d model
Overrides:
{'hidden_dim': 32, 'num_layers': 3, 'num_heads': 2, 'learning_rate': 0.0014142135623730952, 'batch_size': 32, 'gradient_accumulation_steps': 8, 'train_tokens': 32901760}
1645088 for  32 d model
Overrides:
{'hidden_dim': 64, 'num_layers': 4, 'num_heads': 4, 'learning_rate': 0.002, 'batch_size': 32, 'gradient_accumulation_steps': 8, 'train_tokens': 68261120}
3413056 for  64 d model
Overrides:
{'hidden_dim': 128, 'num_layers': 8, 'num_heads': 8, 'learning_rate': 0.0028284271247461905, 'batch_size': 32, 'gradient_accumulation_steps': 8, 'train_tokens': 160115200}
8005760 for  128 d model
{'dim16': {'dataset': 'c4_subset',
           'lr_schedule': 'cosine',
           'warmup_epochs': 1,
           'warmup_epochs_frac': 0.1,
           'weight_decay': 0.1,
           'dropout': 0.0,
           