In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap, jit

from matplotlib import pyplot as plt
import os
import sys

In [2]:
import model
import train

# Copy Task

## Problem

Given a repeating sequence of distinct tokens, continue the pattern. This equates to learning an induction head.

## Dataset Generation

The dataset consists of sequences of varying length that contain a repeating pattern and cut-off abruptly. The goal is to continue the sequence correctly. There is no semantic meaning behind tokens, so they can be randomly generated at init and frozen.

E.g. abcabcabca should be continued with bcabcabc

### Base case

The most simple case will simply have 64-character strings containing repeating character sequences of 4 to 15 distinct characters, so we see 8 to 4 repetitions. To start we can use 32 distinct tokens.

In [3]:
# create dataset; we need to add masking on the loss function!
key_d1 = jax.random.PRNGKey(0)

dataset_name = 'copytask'
token_arr = jnp.arange(32, dtype=jnp.uint16)
sample_len = 64
assert sample_len >= len(token_arr)
n_data = 2**19*3
patt_min = 4
patt_max = 15
assert patt_max >= patt_min
assert n_data % (patt_max-patt_min+1) == 0

pattern_inds = []
for patt_len in range(patt_min, patt_max+1):
  p = jnp.tile(jnp.arange(patt_len), 1+sample_len//patt_len)[:sample_len]
  pattern_inds.append(p)
pattern_inds = jnp.array(pattern_inds)

key_gen, key_shuffle = jax.random.split(key_d1)
key_perms = jax.random.split(key_gen, n_data)

tok_permutations = vmap(lambda k : jax.random.permutation(k, token_arr))(key_perms)
# tok_permutations = jax.random.choice(key_gen, token_arr, (n_data, sample_len))
pattern_inds_expanded = jnp.tile(pattern_inds, (n_data//len(pattern_inds), 1))

data = vmap(lambda i : tok_permutations[i][pattern_inds_expanded[i]])(jnp.arange(n_data))
data = data[jax.random.permutation(key_shuffle, jnp.arange(n_data))] # shuffle the data

val_data_len = n_data//16
train_ids = np.array(data, dtype=np.uint16)[:-val_data_len].flatten()
val_ids = np.array(data, dtype=np.uint16)[-val_data_len:].flatten()
try:
  os.mkdir(dataset_name)
except:
  print(f'dataset {dataset_name} already exists')
train_ids.tofile(os.path.join(dataset_name, 'train.bin'))
val_ids.tofile(os.path.join(dataset_name, 'val.bin'))

sample_len, n_data, val_data_len

dataset copytask already exists


(64, 1572864, 98304)

In [4]:
# !rm -rf /content/logs

In [10]:
# %%capture
config = train.TrainConfig(
    input_bin=f"{dataset_name}/val.bin",
    input_val_bin=f"{dataset_name}/val.bin",
    embd_dim = 128,
    head_dim = 128,
    n_layer = 2,
    block_size = sample_len, # should match the task sequence length so tasks are independently trained on
    batch_size = 64,
    gradient_accumulation_steps = 1,
    max_iters = 10_000,
    eval_iters = 10, # val_data_len // 64, # number of examples // batch_size
    learning_rate = 3e-3,
    warmup_iters = 100,
    lr_decay_iters = 10_000,
    vocab_size = len(token_arr),
    use_masked_loss = True,
    
    log_interval = 500,
    eval_interval = 1_000,
    # rope_base = 10*sample_len,

)
display(config)

AssertionError: use_masked_loss=True requires data with masks (filenames containing 'mask'), but input paths are 'copytask/val.bin' and 'copytask/val.bin'

In [6]:
params = train.train_loop(config)

[wandb] No credentials found. Falling back to offline mode.


[34m[1mwandb[0m: Number of parameters: 0.40M
[34m[1mwandb[0m: Loading training data...
[34m[1mwandb[0m: Process 0/1 prepared dataset from 1 file(s): 6,291,456 tokens, 0.01 GB on disk.
[34m[1mwandb[0m: Process 0/1 prepared loader with 10000 batches.
[34m[1mwandb[0m: Loaded 10000 training batches.
[34m[1mwandb[0m: Loading validation data...
[34m[1mwandb[0m: Process 0/1 prepared dataset from 1 file(s): 6,291,456 tokens, 0.01 GB on disk.
[34m[1mwandb[0m: Process 0/1 prepared loader with 10 batches.
[34m[1mwandb[0m: Loaded 10 validation batches.
[34m[1mwandb[0m: Starting Ahead-of-Time (AOT) compilation...


Number of parameters: 0.40M
Loading training data...
Process 0/1 prepared dataset from 1 file(s): 6,291,456 tokens, 0.01 GB on disk.
Process 0/1 prepared loader with 10000 batches.
Loaded 10000 training batches.
Loading validation data...
Process 0/1 prepared dataset from 1 file(s): 6,291,456 tokens, 0.01 GB on disk.
Process 0/1 prepared loader with 10 batches.
Loaded 10 validation batches.
Starting Ahead-of-Time (AOT) compilation...


[34m[1mwandb[0m: AOT compilation finished.
[34m[1mwandb[0m: Starting training...
[34m[1mwandb[0m: Running validation for step 0...


AOT compilation finished.
Starting training...
Running validation for step 0...
model/total_params: 397696 | model/attn_params: 131072 | model/mlp_params: 261888 | model/embed_params: 4096 | model/vocab_size: 32 | val_loss: 3.646 | step: 0 | lr: 0 | loss: 3.646
step: 500 | lr: 0.002988 | loss: 0.6871


[34m[1mwandb[0m: Running validation for step 1000...


Running validation for step 1000...
val_loss: 0.6234 | step: 1000 | lr: 0.00294 | loss: 0.592


[34m[1mwandb[0m: Cycling dataset...


step: 1500 | lr: 0.002856 | loss: 0.5835
Cycling dataset...


[34m[1mwandb[0m: Running validation for step 2000...


Running validation for step 2000...
val_loss: 0.5898 | step: 2000 | lr: 0.002738 | loss: 0.6058
step: 2500 | lr: 0.00259 | loss: 0.5522


[34m[1mwandb[0m: Running validation for step 3000...
[34m[1mwandb[0m: Cycling dataset...


Running validation for step 3000...
val_loss: 0.5789 | step: 3000 | lr: 0.002414 | loss: 0.6009
Cycling dataset...
step: 3500 | lr: 0.002216 | loss: 0.6025


[34m[1mwandb[0m: Running validation for step 4000...


Running validation for step 4000...
val_loss: 0.5659 | step: 4000 | lr: 0.002001 | loss: 0.588


[34m[1mwandb[0m: Cycling dataset...


step: 4500 | lr: 0.001773 | loss: 0.5663
Cycling dataset...


[34m[1mwandb[0m: Running validation for step 5000...


Running validation for step 5000...
val_loss: 0.5598 | step: 5000 | lr: 0.001539 | loss: 0.528
step: 5500 | lr: 0.001304 | loss: 0.5465


[34m[1mwandb[0m: Running validation for step 6000...
[34m[1mwandb[0m: Cycling dataset...


Running validation for step 6000...
val_loss: 0.5507 | step: 6000 | lr: 0.001074 | loss: 0.4925
Cycling dataset...
step: 6500 | lr: 0.0008556 | loss: 0.5283


[34m[1mwandb[0m: Running validation for step 7000...


Running validation for step 7000...
val_loss: 0.5459 | step: 7000 | lr: 0.0006536 | loss: 0.571
step: 7500 | lr: 0.0004733 | loss: 0.5358


[34m[1mwandb[0m: Cycling dataset...


Cycling dataset...


[34m[1mwandb[0m: Running validation for step 8000...


Running validation for step 8000...
val_loss: 0.5414 | step: 8000 | lr: 0.0003192 | loss: 0.5126
step: 8500 | lr: 0.0001951 | loss: 0.5281


[34m[1mwandb[0m: Running validation for step 9000...


Running validation for step 9000...
val_loss: 0.5409 | step: 9000 | lr: 0.0001041 | loss: 0.5269


[34m[1mwandb[0m: Cycling dataset...


Cycling dataset...
step: 9500 | lr: 4.865e-05 | loss: 0.54


[34m[1mwandb[0m: Final validation...
[34m[1mwandb[0m: Running validation for step 9999...
[34m[1mwandb[0m: Training finished.
[34m[1mwandb[0m: Saved checkpoint to logs/osc6wywd//state_step009999.pkl


Final validation...
Running validation for step 9999...
step: 9999 | val_loss: 0.5403
Training finished.
Saved checkpoint to logs/osc6wywd//state_step009999.pkl


0,1
loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁███▇▇▇▆▆▅▅▄▄▃▃▂▂▁▁▁
model/attn_params,▁
model/embed_params,▁
model/mlp_params,▁
model/total_params,▁
model/vocab_size,▁
step,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
val_loss,█▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.53996
lr,5e-05
model/attn_params,131072.0
model/embed_params,4096.0
model/mlp_params,261888.0
model/total_params,397696.0
model/vocab_size,32.0
step,9999.0
val_loss,0.54034


In [7]:
test_input = [1,2,3,4,5,6,7,8]*1

start_len = len(test_input)
for i in range(start_len, sample_len+1):
  padded_test_input = test_input + [0]*(sample_len - len(test_input))
  rope_params = model.precompute_rope(config.get_model_config(), None)
  preds = model.gpt_forward(params, rope_params,jnp.array(test_input)[None,:], config.get_model_config())
  new_ind = jnp.argmax(preds[0][i])
  test_input.append(new_ind.item())

In [8]:
print("Testing input:")
print(test_input[:start_len])

print("Predicted output:")
print(test_input[start_len:])

Testing input:
[1, 2, 3, 4, 5, 6, 7, 8]
Predicted output:
[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 21]


# Path Graph

## Problem

Given a 'goal' token, identify which unique 'path' in context it comes from and return all tokens in the path up to and including the goal.

Concretely, suppose your context has two paths: A,B,C and P,Q,R,S. Given a goal R, we would return P,Q,R. This tests the model's ability to build pointers between tokens based on context.

## Dataset Generation

The dataset consists of a list of sequences. Sequences can be one of three types:
- context
- input
- output

No nesting of sequences is allowed.

#### Tokenization

Overall will use 128 tokens.

Special Tokens (with index):
- Context Start: 0
- Input Start: 1
- Output start: 2

Each special token implicitly ends the previous sequence and starts a new one. These embeddings can be learned.

All remaining tokens are exchangeable, i.e. only act as pointers and contain no semantic meaning. Their embeddings will be fixed and initialized randomly.

1. Minimal example
- 0, 3,4,5,6, 1, 5, 2, 3,4,5

2. Two contexts
- 0, 3,4,5,6, 0, 7,8,9, 1, 4, 2, 3,4

3. Stream of problems (context grows and problems arrive independently)
- 0, 3,4,5,6, 0, 7,8,9, 1, 5, 2, 3,4,5, 0, 10,11,12,13,14, 1, 4, 2, 3,4, 1, 12, 2, 10,11,12

For position encoding we'll use RoPE which works well with QK-norm attention.

In [9]:
# TBD, need to add masking to loss_fn()!