# Testing data pipeline

Build a dataset and dataloader from a given experiment config, then inspect a batch.

In [1]:
from pathlib import Path

from hydra import initialize_config_dir, compose
from omegaconf import OmegaConf
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

from arithmetic_lm.dataset import DATASET_CLASSES
from arithmetic_lm.model.utils import answer_mask
from arithmetic_lm.tokenizer import CharTokenizer

## Get config

In [2]:
# Experiment and dataset config to compose; an empty value skips that override.
EXPERIMENT = "15/exp15_trans_dec.yaml"
DATASET = "add_generalize_to_longer_19"
# Config directory, resolved relative to the parent of the current working directory.
abs_config_dir = Path("..").resolve() / "arithmetic_lm" / "conf"

# add overrides if needed
overrides = [
    "training.only_answer_loss=true",
    "data.format.operand_random_spaces_amount=0.3",
]

# Append the optional experiment / dataset selections after the fixed overrides.
overrides += [f"+experiment={EXPERIMENT}"] if EXPERIMENT else []
overrides += [f"data={DATASET}"] if DATASET else []
print(overrides)

# Compose train.yaml with the overrides and dump the resulting config as YAML.
with initialize_config_dir(config_dir=str(abs_config_dir), version_base=None):
    cfg = compose(overrides=overrides, config_name="train.yaml")
    print(OmegaConf.to_yaml(cfg))

['training.only_answer_loss=true', 'data.format.operand_random_spaces_amount=0.3', '+experiment=15/exp15_trans_dec.yaml', 'data=add_generalize_to_longer_19']
data:
  train_ds_class: ArithmeticLMSequenceDataset
  format:
    pad: $
    pad_ops_zero: null
    pad_ans_zero: null
    reverse_ops: false
    reverse_ans: false
    encdec: false
    filler_tokens_prompt: 0
    filler_tokens_ans: 0
    operand_random_spaces_amount: 0.3
    answer_random_spaces_amount: 0
    scratchpad: false
  train: data/addition/generalize_to_longer_19/train_add_1-19_except18_1M.txt
  test:
    in_dist: data/addition/generalize_to_longer_19/test_add_in_distribution_2000.txt
    ood_18digit: data/addition/generalize_to_longer_19/test_add_ood_18digit_100.txt
    ood_20digit: data/addition/generalize_to_longer_19/test_add_ood_20digit_100.txt
    ood_21digit: data/addition/generalize_to_longer_19/test_add_ood_21digit_100.txt
    ood_22digit: data/addition/generalize_to_longer_19/test_add_ood_22digit_100.txt
    

## Initialize dataset and dataloader

In [3]:
tokenizer = CharTokenizer()

# dataset
# Keyword arguments for the dataset class, taken from the composed config.
# NOTE(review): the composed config also contains
# data.format.answer_random_spaces_amount, which is not forwarded here —
# confirm whether that is intended.
ds_kwargs = dict(
    tokenizer=tokenizer,
    seq_len=cfg.model.args.context_len,
    pad=cfg.data.format.pad,
    pad_ops_zero=cfg.data.format.pad_ops_zero,
    pad_ans_zero=cfg.data.format.pad_ans_zero,
    reverse_ops=cfg.data.format.reverse_ops,
    reverse_ans=cfg.data.format.reverse_ans,
    filler_tokens_prompt=cfg.data.format.filler_tokens_prompt,
    filler_tokens_ans=cfg.data.format.filler_tokens_ans,
    # "=" is part of the prompt exactly when the format is not encoder-decoder
    equal_in_prompt=not cfg.data.format.encdec,
    # optional keys: fall back to a default when absent from the config
    scratchpad=cfg.data.format.get("scratchpad", False),
    operand_random_spaces_amount=cfg.data.format.get(
        "operand_random_spaces_amount", 0
    ),
)

# Look up the dataset class by its configured name and build it on the train file.
ds_class = DATASET_CLASSES[cfg.data.train_ds_class]
train_path = Path("..").resolve() / cfg.data.train
ds = ds_class(train_path, **ds_kwargs)
print(ds, len(ds), ds[0])

<arithmetic_lm.dataset.arithmetic_dataset.ArithmeticLMSequenceDataset object at 0x130d26e40> 896200 (tensor([65,  0, 72,  0, 80,  0]), tensor([ 0, 72,  0, 80,  0, 65]))


In [4]:
# dataloader
# Seed the shuffle so the batch order is the same after Restart & Run All.
# This only fixes the sampling order; any randomness inside the dataset
# itself is not controlled here (NOTE(review): confirm whether the dataset
# needs its own seed, e.g. for random operand spaces).
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=ds.collate_fn,
    generator=torch.Generator().manual_seed(0),
)

## Get a batch of data and inspect the answer mask

In [18]:
# Take exactly one batch from the dataloader (no loop-and-break needed).
batch = next(iter(dl))
# The second element of the batch is used as the target sequence.
tgt = batch[1]
print(tgt.shape)
# decode and print tgt before
for t in tgt:
    print(repr(tokenizer.decode(t)))

# Token id of "=", passed to answer_mask as equal_token_id.
equal_token_id = tokenizer.encode("=")[0]
answer_masked_tgt = answer_mask(
    tgt,
    pad_token_id=tokenizer.pad_token_id,
    equal_token_id=equal_token_id,
)

print("AFTER")
print(answer_masked_tgt.shape)

# decode and print tgt after; "\x0c" characters are shown as "." for readability
for t in answer_masked_tgt:
    print(repr(tokenizer.decode(t).replace("\x0c", ".")))

torch.Size([4, 62])
'6536776816 8+3733  8784 039=102706552207$\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c'
'1  38824+301959=440783$\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c'
'2425+2887=5312$\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c\x0c'
'944 60 9009978 41453+516100  87449437 93 2=146070988447279385$'
AFTER
torch.Size([4, 62])
'...........................=102706552207$.....................'
'...............=440783$.......................................'
'.........=5312$...............................................'
'..........................................=146070988447279385$'


In [12]:
# for GPT2
# Values passed positionally to the parameter-count helper below.
context_len = 1024
n_tokens = 50257
hidden_dim = 768
n_layers = 12
mlp_mult = 4

# NOTE(review): `transformer_decoder_param_count` is neither defined nor
# imported in the visible cells, so this cell raises NameError on a fresh
# kernel unless the name is brought into scope first — confirm where it
# lives and import it in the import cell.
n_params = transformer_decoder_param_count(
    context_len, n_tokens, hidden_dim, n_layers, mlp_mult
)

# Report the count in millions, with two decimals.
print(f"GPT2 params: {n_params / 1e6:.2f}M")

GPT2 params: 163.09M
