In [1]:
import torch

A = torch.tensor([ 512, 768, 1024, 1280, 1600])
B = (A/768) ** -0.5
# Python
C = {int(a): float(b) for a, b in zip(A.tolist(), B.tolist())}
print(f"dim by lr scale ratio: {C}")

dim by lr scale ratio: {512: 1.2247447967529297, 768: 1.0, 1024: 0.866025447845459, 1280: 0.7745966911315918, 1600: 0.6928203701972961}


In [1]:
from daisy.daisy_core import next_multiple_of_n

def estimate_params(L: int, d: int, V: int, tied_head: bool = False) -> dict:
    """
    Returns dict with per-layer, embeddings, and total params (in integers and millions).
    Formula:
      - Attention (Q,K,V,O): 4 * d * d
      - d_fc: 4 * d
      - MLP: 2 * d * d_fc
      - Per-layer â‰ˆ 4d^2 + 3 d d_fc  (defaults to 16 d^2 when d_fc=4d)
      - Embeddings: V * d
      - Output head: + V * d if untied
    """
    d_fc = 4 * d
    mlp = 2 * d * d_fc
    attn = 4 * d * d
    per_layer = mlp + attn
    layers_total = L * per_layer
    embed = V * d
    head = 0 if tied_head else next_multiple_of_n(V, n=128) * d
    ve = 3 * V * d
    scalars = 5 * L
    total = layers_total + embed + head + ve + scalars
    to_m = lambda x: round(x / 1e6, 3)
    return {
        "per_layer": per_layer,
        "per_layer_M": to_m(per_layer),
        "layers_total": layers_total,
        "layers_total_M": to_m(layers_total),
        "embeddings": embed,
        "embeddings_M": to_m(embed),
        "output_head": head,
        "output_head_M": to_m(head),
        "total": total,
        "total_M": to_m(total),
    }

In [4]:
import json
from training.hparams import load_hparams_from_yaml
hparams = load_hparams_from_yaml("../config/pretrain_pico.yml")
L = hparams.num_layers
d = hparams.model_dim
V = hparams.vocab_size
res = estimate_params(L=L, d=d, V=V, tied_head=False)
print(json.dumps(res, indent=2, sort_keys=True))

{
  "embeddings": 25731584,
  "embeddings_M": 25.732,
  "layers_total": 18874368,
  "layers_total_M": 18.874,
  "output_head": 25755648,
  "output_head_M": 25.756,
  "per_layer": 3145728,
  "per_layer_M": 3.146,
  "total": 147556382,
  "total_M": 147.556
}


In [11]:
target_param_data_ratio = 20
num_params = 350
target_tokens = target_param_data_ratio * num_params
print(f"Target tokens (M): {target_tokens}")

Target tokens (M): 7000


In [12]:
scale = 0.866025447845459
head_lr = 0.004 * scale
embed_lr = 0.2 * scale
scalar_lr = 0.015 * scale

print(f"Head LR: {head_lr}")
print(f"Embed LR: {embed_lr}")
print(f"Scalar LR: {scalar_lr}")


Head LR: 0.003464101791381836
Embed LR: 0.1732050895690918
Scalar LR: 0.012990381717681885


In [13]:
sft_scale = 1/50
print(f"Head LR: {head_lr*sft_scale}")
print(f"Embed LR: {embed_lr*sft_scale}")
print(f"Scalar LR: {scalar_lr*sft_scale}")

Head LR: 6.928203582763672e-05
Embed LR: 0.003464101791381836
Scalar LR: 0.0002598076343536377


In [16]:
0.015 * sft_scale

0.0003