In [2]:
from importlib import reload
from pathlib import Path

import lets_plot as lp
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from exults.tok_utils import pretty_print_logits
from olmo.config import TrainConfig
from olmo.model import OLMo

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [3]:
# Hugging Face tokenizer for the OLMo 1.7 7B release; used below to encode the
# prompt and to decode the top-k logits.
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1.7-7B-hf")

tokenizer_config.json:   0%|          | 0.00/5.37k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.12M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

In [4]:
# Checkpoint directory (olmo_1_7_preanneal) containing config.yaml and model.pt.
# NOTE(review): absolute, machine-specific path — consider reading it from an
# environment variable or a config constant so the notebook runs elsewhere.
local_cache = Path('/scratch7/users/fjiahai/olmo_artifacts/olmo_1_7_preanneal/')

In [5]:
# Training config stored next to the checkpoint; only `cfg.model` is used below.
cfg = TrainConfig.load(local_cache / 'config.yaml')

In [8]:
# Inspect the architecture config (d_model, n_layers, vocab_size vs. embedding_size, ...).
cfg.model

ModelConfig(d_model=4096, n_heads=32, n_kv_heads=None, clip_qkv=8.0, n_layers=32, mlp_ratio=4, mlp_hidden_size=22016, activation_type='swiglu', block_type='sequential', block_group_size=1, alibi=False, alibi_bias_max=8.0, rope=True, rope_full_precision=True, flash_attention=True, attention_dropout=0.0, multi_query_attention=None, attention_layer_norm=False, residual_dropout=0.0, embedding_dropout=0.0, layer_norm_type='default', layer_norm_with_affine=False, attention_layer_norm_with_affine=False, max_sequence_length=4096, include_bias=False, bias_for_layer_norm=False, scale_logits=False, vocab_size=50280, embedding_size=50304, weight_tying=False, eos_token_id=50279, pad_token_id=1, init_device='meta', init_fn='mitchell', init_std=0.02, init_cutoff_factor=None, precision='amp_bf16')

In [9]:
# Build the model from its config and load the checkpoint weights.
# The config uses init_device='meta', so (presumably) parameters have no real
# storage until `assign=True` swaps in the loaded tensors directly.
# map_location="cpu": keep the mmap'd tensors on the host; the model is moved to
#   the GPU explicitly in a later cell.
# weights_only=True: the file is a plain tensor state dict, so refuse to
#   unpickle arbitrary objects.
model = OLMo(cfg.model)

model.load_state_dict(
    torch.load(str(local_cache / 'model.pt'), mmap=True, map_location="cpu", weights_only=True),
    assign=True,
)

<All keys matched successfully>

In [10]:
# Encode a factual-recall prompt as PyTorch tensors; `ret.input_ids` feeds the model below.
ret = tokenizer('Stephen Hawking was born in', return_tensors='pt')

In [12]:
# Move parameters and buffers to the current CUDA device; the cell output is the module tree.
model.cuda()

OLMo(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 4096)
    (emb_drop): Dropout(p=0.0, inplace=False)
    (ln_f): LayerNorm()
    (blocks): ModuleList(
      (0-31): 32 x OLMoSequentialBlock(
        (dropout): Dropout(p=0.0, inplace=False)
        (act): SwiGLU()
        (attn_out): Linear(in_features=4096, out_features=4096, bias=False)
        (ff_out): Linear(in_features=11008, out_features=4096, bias=False)
        (rotary_emb): RotaryEmbedding()
        (attn_norm): LayerNorm()
        (ff_norm): LayerNorm()
        (att_proj): Linear(in_features=4096, out_features=12288, bias=False)
        (ff_proj): Linear(in_features=4096, out_features=22016, bias=False)
      )
    )
    (ff_out): Linear(in_features=4096, out_features=50304, bias=False)
  )
)

In [13]:
# Forward pass on the prompt. Nothing here backpropagates, so run under
# no_grad to avoid keeping the autograd graph (and its activations) on the GPU.
# For this prompt the result has shape torch.Size([1, 6, 50304]).
with torch.no_grad():
    logits = model(input_ids=ret.input_ids.cuda()).logits

In [15]:
# Tokenizer vocabulary size (50280); smaller than the model's padded
# embedding/output size (embedding_size=50304).
len(tokenizer)

50280

In [16]:
# Last dimension is the padded embedding_size (50304), not len(tokenizer);
# ids >= len(tokenizer) have no corresponding tokenizer entry.
logits.shape

torch.Size([1, 6, 50304])

In [22]:
# Top-k next-token predictions after the prompt. Drop the padded output columns
# (ids >= len(tokenizer)) so the reported probabilities are normalized over real
# tokens only and no undecodable ids can appear in the ranking.
pretty_print_logits(tokenizer, logits[0, -1, :len(tokenizer)])

Top 0th token. Logit: 14.05 Prob: 27.01% Token: | Oxford|
Top 1th token. Logit: 14.02 Prob: 26.18% Token: | 1942|
Top 2th token. Logit: 12.73 Prob:  7.23% Token: | the|
Top 3th token. Logit: 12.67 Prob:  6.78% Token: | October|
Top 4th token. Logit: 12.26 Prob:  4.50% Token: | London|
Top 5th token. Logit: 12.07 Prob:  3.70% Token: | Cambridge|
Top 6th token. Logit: 11.80 Prob:  2.84% Token: | England|
Top 7th token. Logit: 10.59 Prob:  0.85% Token: | January|
Top 8th token. Logit: 10.41 Prob:  0.70% Token: | 1948|
Top 9th token. Logit: 10.31 Prob:  0.64% Token: | February|


In [24]:
def get_olmo_mlp(layers, model=None):
    """Return the names of feed-forward (MLP) parameters in the given blocks.

    Args:
        layers: iterable of block indices, e.g. ``[2, 3, 4]``. Any iterable
            (including a generator) is accepted; it is consumed exactly once.
        model: module to inspect. Defaults to the notebook-global ``model``
            when omitted (the original behavior).

    Returns:
        List of parameter names, in ``named_parameters()`` order, that contain
        ``'ff'`` and sit under ``blocks.<layer>.`` for one of ``layers``.
    """
    if model is None:
        model = globals()['model']
    # Materialize once: the old inline generator re-iterated `layers` per
    # parameter, so an iterator argument was exhausted after the first names.
    # The trailing '.' keeps 'blocks.1.' from also matching 'blocks.10.'.
    prefixes = tuple(f'blocks.{layer}.' for layer in layers)
    return [
        name
        for name, _ in model.named_parameters()
        if 'ff' in name and any(prefix in name for prefix in prefixes)
    ]

In [25]:
# Example: feed-forward parameter names for blocks 2-4.
get_olmo_mlp([2, 3, 4])

['transformer.blocks.2.ff_out.weight',
 'transformer.blocks.2.ff_proj.weight',
 'transformer.blocks.3.ff_out.weight',
 'transformer.blocks.3.ff_proj.weight',
 'transformer.blocks.4.ff_out.weight',
 'transformer.blocks.4.ff_proj.weight']

# Old stuff: optimizer-state (exp_avg / exp_avg_sq) statistics from earlier OLMo checkpoints

In [1]:
# Earlier olmo-medium run (p067ktg9, step 558223, unsharded).
# NOTE(review): the next cell reassigns `load_path`, so on a top-to-bottom run
# this value is overwritten and never used.
load_path = "https://olmo-checkpoints.org/ai2-llm/olmo-medium/p067ktg9/step558223-unsharded/"

In [1]:
# Newer olmo-medium run (0rdfxd6d, step 477000, unsharded). Overrides the
# `load_path` set in the previous cell; this is the checkpoint loaded below.
load_path = "https://olmo-checkpoints.org/ai2-llm/olmo-medium/0rdfxd6d/step477000-unsharded/"

In [2]:
from olmo.checkpoint import load_state_dict

In [3]:
# Local directory passed to load_state_dict as `local_cache` (presumably a
# download cache for the remote checkpoint files).
# NOTE(review): this rebinds `local_cache` (a Path to a different directory
# earlier in the notebook) to a plain str; a distinct name would avoid mixing them.
local_cache = '/data/fjiahai/olmo_artifacts/'

In [4]:
# Load the unsharded optimizer state ("optim.pt") for `load_path` onto the CPU.
# As indexed below, the layout is ['state'][<param name>] -> dict holding
# 'exp_avg' and 'exp_avg_sq' tensors.
optim_state_dict_to_load = load_state_dict(
    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
)

In [8]:
# Optimizer state of a single weight (block 0 attention projection), used as a quick probe.
test = optim_state_dict_to_load['state']['transformer.blocks.0.att_proj.weight']

In [9]:
# Display the probe's state dict ('exp_avg', 'exp_avg_sq', ...).
test

{'exp_avg': tensor([[ 6.5499e-08,  8.0335e-08,  1.5783e-08,  ..., -4.8924e-08,
           8.3826e-08, -4.1052e-08],
         [ 2.4778e-08,  8.3342e-08,  8.9644e-08,  ..., -1.3339e-07,
           3.9266e-08, -8.0391e-08],
         [ 7.0408e-08,  3.9740e-08,  5.3293e-08,  ..., -1.2388e-07,
           1.8799e-08, -3.0253e-09],
         ...,
         [ 2.1911e-06, -2.9228e-06, -2.9721e-07,  ...,  1.3703e-06,
          -1.4783e-06,  9.2868e-07],
         [ 5.4719e-06, -1.4839e-06, -4.9751e-06,  ...,  2.4974e-06,
           1.2438e-06,  1.6049e-07],
         [-2.0320e-06,  3.3357e-06,  2.2788e-06,  ..., -2.5484e-07,
          -4.8298e-06, -2.9821e-06]]),
 'exp_avg_sq': tensor([[2.3876e-13, 1.2379e-13, 8.2509e-14,  ..., 1.1497e-13, 1.3299e-13,
          1.1867e-13],
         [2.1846e-13, 1.4645e-13, 1.6955e-13,  ..., 1.9589e-13, 1.4935e-13,
          1.5493e-13],
         [1.4841e-13, 1.0610e-13, 7.8815e-14,  ..., 7.2453e-14, 8.4843e-14,
          8.8198e-14],
         ...,
         [1.2453e-

In [11]:
# Elementwise exp_avg**2 / exp_avg_sq: squared first moment over second moment
# (the square of the `noise_to_var` metric). NaN/inf wherever exp_avg_sq == 0.
noise_ratio = test['exp_avg']**2/test['exp_avg_sq']

In [13]:
# Mean of the ratio over every element of the probed weight.
noise_ratio.mean()

tensor(0.0523)

In [32]:
def quantile(x, q, num_samples=10000, generator=None):
    """Estimate the ``q``-quantile(s) of ``x`` from a random subsample.

    torch.quantile rejects inputs above a size limit (2**24 elements in current
    releases), so roughly ``num_samples`` elements are kept by independent
    Bernoulli sampling with p = num_samples / x.numel(). When ``x`` has at most
    ``num_samples`` elements every element is kept and the result is exact.

    Args:
        x: tensor of any shape (it is flattened); must have a floating dtype
            accepted by ``Tensor.quantile``.
        q: float or tensor of quantile levels in [0, 1], forwarded to ``Tensor.quantile``.
        num_samples: expected subsample size.
        generator: optional ``torch.Generator`` (on ``x``'s device) for
            reproducible sampling.

    Returns:
        ``samples.quantile(q)``.

    Raises:
        ValueError: if ``x`` is empty.
    """
    flat = x.flatten()
    if flat.numel() == 0:
        raise ValueError("quantile() of an empty tensor is undefined")
    p = num_samples / flat.numel()
    # Draw the mask on the input's device so CUDA tensors don't hit a device mismatch.
    keep = torch.rand(flat.shape, generator=generator, device=flat.device) < p
    return flat[keep].quantile(q)

In [50]:
def get_metrics(metrics_fn, state=None, num_layers=32, summarize=None):
    """Summarize a per-element optimizer-state metric for every block weight.

    For each layer in ``range(num_layers)`` and each of the four block weight
    matrices (att_proj, attn_out, ff_out, ff_proj), ``metrics_fn`` is applied
    to that parameter's state dict. ``summarize`` then reduces the result to
    ``{metric_label: value}``.

    Args:
        metrics_fn: callable taking one parameter's state dict (e.g. with
            'exp_avg' / 'exp_avg_sq') and returning a tensor.
        state: mapping param name -> state dict. Defaults to the notebook-global
            ``optim_state_dict_to_load['state']``.
        num_layers: number of transformer blocks to scan (default 32).
        summarize: callable tensor -> dict. Defaults to the notebook-global
            ``get_quartiles``, looked up at call time (it is defined in a later cell).

    Returns:
        List of ``{'layer', 'name', 'metric', 'value'}`` records, layer-major,
        then weight name, then metric; ready for ``pd.DataFrame``.
    """
    if state is None:
        state = optim_state_dict_to_load['state']
    if summarize is None:
        summarize = get_quartiles
    weight_names = ('att_proj', 'attn_out', 'ff_out', 'ff_proj')
    records = []
    for layer in range(num_layers):
        for name in weight_names:
            param_key = f'transformer.blocks.{layer}.{name}.weight'
            summary = summarize(metrics_fn(state[param_key]))
            records.extend(
                {'layer': layer, 'name': name, 'metric': metric, 'value': value}
                for metric, value in summary.items()
            )
    return records
            

In [51]:
def get_quartiles(x, num_samples=10_000, qs=(0.1, 0.25, 0.5, 0.75, 0.9)):
    """Estimate several quantiles of ``x`` from one random subsample.

    Roughly ``num_samples`` elements are kept by Bernoulli sampling with
    p = num_samples / x.numel(), because torch.quantile rejects very large
    inputs. All quantiles come from the same subsample in a single
    ``quantile`` call, so the subsample is sorted only once.

    Args:
        x: tensor of any shape (flattened), floating dtype.
        num_samples: expected subsample size (default 10_000).
        qs: quantile levels to report.

    Returns:
        Dict mapping ``str(q)`` (e.g. '0.25') to a 0-d tensor holding that quantile.

    Raises:
        ValueError: if ``x`` is empty.
    """
    flat = x.flatten()
    if flat.numel() == 0:
        raise ValueError("get_quartiles() of an empty tensor is undefined")
    p = num_samples / flat.numel()
    # Mask on the input's device so CUDA tensors work too.
    samples = flat[torch.rand(flat.shape, device=flat.device) < p]
    # Tensor q must match the input's dtype and device.
    levels = torch.tensor(qs, dtype=samples.dtype, device=samples.device)
    values = samples.quantile(levels)
    return {str(q): values[i] for i, q in enumerate(qs)}

In [52]:
import pandas as pd

In [55]:
def var_to_eps(d, eps=1e-5):
    """Elementwise ``sqrt(exp_avg_sq) / eps`` for one parameter's optimizer state.

    Assuming this state is from Adam/AdamW, values well below 1 mean ``eps``
    dominates the update denominator ``sqrt(v) + eps`` for that element.

    Args:
        d: the parameter's optimizer-state dict; must contain 'exp_avg_sq'.
        eps: epsilon to compare against (default 1e-5, the value previously
            hard-coded). NOTE(review): presumably the optimizer's eps — confirm
            against the training config.

    Returns:
        Tensor with the same shape as ``d['exp_avg_sq']``.
    """
    return d['exp_avg_sq'].pow(0.5) / eps

In [56]:
def noise_to_var(d):
    """Elementwise ``|exp_avg| / sqrt(exp_avg_sq)`` for one parameter's optimizer state.

    Compares the magnitude of the first-moment estimate with the root of the
    second-moment estimate. Produces NaN/inf wherever 'exp_avg_sq' is 0.
    """
    first_moment = d['exp_avg']
    second_moment = d['exp_avg_sq']
    return first_moment.abs().div(second_moment.pow(0.5))

In [57]:
# Long-format table: one row per (layer, weight name, quantile) of sqrt(exp_avg_sq) / eps.
df = pd.DataFrame(get_metrics(var_to_eps))

In [67]:
# Wide format for plotting: one row per (layer, name), one column per quantile label.
wider_df = df.pivot(index=['layer', 'name'], values='value', columns=['metric']).reset_index()
wider_df

metric,layer,name,0.1,0.25,0.5,0.75,0.9
0,0,att_proj,0.015577,0.027051,0.065611,0.209084,0.343591
1,0,attn_out,0.188178,0.255414,0.335484,0.457230,0.656300
2,0,ff_out,0.323821,0.351649,0.386895,0.434107,0.497275
3,0,ff_proj,0.238291,0.261760,0.290944,0.326912,0.371479
4,1,att_proj,0.060574,0.092586,0.170461,0.391833,0.540019
...,...,...,...,...,...,...,...
123,30,ff_proj,0.234143,0.259287,0.296859,0.344094,0.414386
124,31,att_proj,0.051172,0.079858,0.142700,0.289525,0.441328
125,31,attn_out,0.242589,0.296128,0.357167,0.425082,0.502822
126,31,ff_out,0.336558,0.381530,0.443886,0.559444,0.870701


In [69]:
# Quantile columns that bound the shaded band in the plots (here the interquartile range).
low = '0.25'
high = '0.75'

In [86]:
# Per-layer distribution of sqrt(exp_avg_sq) / eps for each weight type:
# median (solid), band between the `low`/`high` quantiles, 10th/90th percentiles (dashed).
# The ribbon is drawn first and semi-transparent so it no longer hides the median line.
(
    lp.ggplot(wider_df, lp.aes(x='layer', color='name'))
    + lp.geom_ribbon(lp.aes(ymin=low, ymax=high, fill='name'), alpha=0.2, tooltips='none')
    + lp.geom_line(lp.aes(y='0.5'), tooltips=lp.layer_tooltips())
    + lp.geom_line(lp.aes(y='0.1'), linetype='dashed')
    + lp.geom_line(lp.aes(y='0.9'), linetype='dashed')
    + lp.labs(
        title=f'sqrt(exp_avg_sq) / eps by layer (band: {low}-{high} quantiles, dashed: 0.1/0.9)',
        x='layer',
        y='sqrt(exp_avg_sq) / eps',
    )
    + lp.ggsize(1000, 600)
)

In [63]:
# Enable lets-plot HTML rendering in the notebook.
# NOTE(review): in file order this cell comes after the first plot cell that uses `lp`;
# a fresh top-to-bottom run needs lets_plot imported (and set up) before that cell.
import lets_plot as lp
lp.LetsPlot.setup_html()

In [92]:
# Unwrap the 0-d tensors produced by get_quartiles into plain floats.
# `float` accepts both tensors and floats, so re-running this cell is harmless;
# the previous `x.item()` raised AttributeError once values were already floats.
# NOTE(review): depends on `df2`, which is created in the next cell (file order).
df2['value'] = df2['value'].apply(float)

In [87]:
# Long-format table like `df`, but for the |exp_avg| / sqrt(exp_avg_sq) metric.
df2 = pd.DataFrame(get_metrics(noise_to_var))

In [93]:
# Wide format of df2: one row per (layer, name), one column per quantile label.
wider_df2 = df2.pivot(index=['layer', 'name'], values='value', columns=['metric']).reset_index()
wider_df2

metric,layer,name,0.1,0.25,0.5,0.75,0.9
0,0,att_proj,0.030493,0.076938,0.159378,0.270215,0.377729
1,0,attn_out,0.030308,0.075721,0.159070,0.266128,0.377912
2,0,ff_out,0.029085,0.075508,0.158135,0.262823,0.375893
3,0,ff_proj,0.030380,0.074795,0.157721,0.265215,0.369034
4,1,att_proj,0.027331,0.071774,0.155073,0.262610,0.368357
...,...,...,...,...,...,...,...
123,30,ff_proj,0.029988,0.075553,0.160171,0.267904,0.374199
124,31,att_proj,0.030345,0.077297,0.157685,0.262800,0.371267
125,31,attn_out,0.029070,0.076741,0.160625,0.269649,0.375819
126,31,ff_out,0.030253,0.075483,0.160050,0.268123,0.372747


In [94]:
# Per-layer distribution of |exp_avg| / sqrt(exp_avg_sq) for each weight type:
# median (solid), band between the `low`/`high` quantiles, 10th/90th percentiles (dashed).
# The ribbon is drawn first and semi-transparent so it no longer hides the median line.
(
    lp.ggplot(wider_df2, lp.aes(x='layer', color='name'))
    + lp.geom_ribbon(lp.aes(ymin=low, ymax=high, fill='name'), alpha=0.2, tooltips='none')
    + lp.geom_line(lp.aes(y='0.5'), tooltips=lp.layer_tooltips())
    + lp.geom_line(lp.aes(y='0.1'), linetype='dashed')
    + lp.geom_line(lp.aes(y='0.9'), linetype='dashed')
    + lp.labs(
        title=f'|exp_avg| / sqrt(exp_avg_sq) by layer (band: {low}-{high} quantiles, dashed: 0.1/0.9)',
        x='layer',
        y='|exp_avg| / sqrt(exp_avg_sq)',
    )
    + lp.ggsize(1000, 600)
)