# Setup

TPU setup:

* Log into blank TPU VM, forwarding ports:

```gcloud compute tpus tpu-vm ssh tpu-v3-8-two --project=tpu-research-cloud-project --zone=europe-west4-a -- -L 8080:localhost:8080 -L 9000:localhost:9000```

```screen -U```

```sudo docker run -p 127.0.0.1:9000:8080 --privileged europe-docker.pkg.dev/colab-images/public/runtime```

* Copy the token URL printed by `docker run` and paste it into Colab via Connect => "Connect to a local runtime"

E.g. http://127.0.0.1:9000/?token=95decc7a09a6f477347b83b866aa44a69c3572fa804731fb

* Upload data to `/content/generated-easy-attention-dataset.txt` (from ./github.com/houeland/tpu-ml/easy-attention/model/data/generated-easy-attention-dataset.txt)

In [None]:
!pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html optax==0.1.7 flax==0.7.4 ml-collections==0.1.1
# !pip install jax==0.4.16 jaxlib==0.4.16 optax==0.1.7 flax==0.7.4 ml-collections==0.1.1
# !pip install optax==0.1.7 flax==0.7.4 ml-collections==0.1.1


Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting jax[tpu]==0.4.16
  Downloading jax-0.4.16-py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Collecting flax==0.7.4
  Downloading flax-0.7.4-py3-none-any.whl (233 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.5/233.5 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ml-collections==0.1.1
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jaxlib==0.4.16 (from jax[tpu]==0.4.16)
  Downloading jaxlib-0.4.16-cp310-cp310-manylinux2014_x86_64.whl (84.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.5/84.5 MB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hColle

In [None]:
# Sanity check: show the installed JAX version and the devices JAX can see.
import jax
print(jax.__version__)
print(jax.devices())

# NOTE(review): this binds the module-level name `config` to JAX's config object.
# Functions in this file that take a `config` parameter shadow it locally.
from jax import config
# Presumably makes implicit rank-promoting broadcasts an error instead of silently
# succeeding — confirm against the JAX documentation for this option.
config.update("jax_numpy_rank_promotion", "raise")

0.4.16
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


In [None]:
import datetime

class LoggingModule:
  """Tiny stand-in for a logging module: prints timestamped lines to stdout."""

  @staticmethod
  def _timestamp():
    """Return the current local time formatted as YYYY-MM-DD HH:MM:SS."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

  def _emit(self, level, args):
    """Print one line: timestamp, fixed-width level tag, then the arguments."""
    print(self._timestamp(), level, *args)

  def info(self, *args):
    """Print `args` at INFO level."""
    self._emit("INFO   ", args)

  def warning(self, *args):
    """Print `args` at WARNING level."""
    self._emit("WARNING", args)

# Config

In [None]:
import ml_collections


# Probably want a higher learning rate now that batches are bigger?

def get_config():
    """Return the currently selected hyperparameter preset.

    Switch presets by pointing `active_preset` at one of `big_config`,
    `normal_config`, `small_config` or `tiny_config`.
    """
    active_preset = small_config
    return active_preset()


def big_config():
    """Return the frozen hyperparameters of the "big" preset (512-dim embeddings, 12 layers)."""
    settings = {
        "dataset_batch_size": 16,
        "dataset_sequence_length": 256,
        "learning_rate": 0.01,
        "momentum": 0.9,
        "model_embed_dim": 512,
        "transformer_num_layers": 12,
        "transformer_num_attention_heads": 16,
        "transformer_attention_size_per_head": 8,
        "transformer_dropout_rate": 0.1,
    }
    return ml_collections.FrozenConfigDict(ml_collections.ConfigDict(settings))


def normal_config():
    """Return the frozen hyperparameters of the "normal" preset (128-dim embeddings, 6 layers)."""
    settings = {
        "dataset_batch_size": 16,
        "dataset_sequence_length": 256,
        "learning_rate": 0.01,
        "momentum": 0.9,
        "model_embed_dim": 128,
        "transformer_num_layers": 6,
        "transformer_num_attention_heads": 8,
        "transformer_attention_size_per_head": 4,
        "transformer_dropout_rate": 0.1,
    }
    return ml_collections.FrozenConfigDict(ml_collections.ConfigDict(settings))


def small_config():
    """Return the frozen hyperparameters of the "small" preset (64-dim embeddings, 2 layers)."""
    settings = {
        "dataset_batch_size": 16,
        "dataset_sequence_length": 256,
        "learning_rate": 0.05,
        "momentum": 0.0,
        "model_embed_dim": 64,
        "transformer_num_layers": 2,
        "transformer_num_attention_heads": 2,
        "transformer_attention_size_per_head": 16,
        "transformer_dropout_rate": 0.1,
    }
    return ml_collections.FrozenConfigDict(ml_collections.ConfigDict(settings))


def tiny_config():
    """Return the frozen hyperparameters of the "tiny" preset (4-dim embeddings, 2 layers)."""
    settings = {
        "dataset_batch_size": 16,
        "dataset_sequence_length": 256,
        "learning_rate": 0.4,
        "momentum": 0.0,
        "model_embed_dim": 4,
        "transformer_num_layers": 2,
        "transformer_num_attention_heads": 2,
        "transformer_attention_size_per_head": 2,
        "transformer_dropout_rate": 0.1,
    }
    return ml_collections.FrozenConfigDict(ml_collections.ConfigDict(settings))


class ConfigModule:
  """Namespace object standing in for a `config` module that exposes `get_config`."""
  # staticmethod so that `config_lib.get_config()` is called without a bound `self`.
  get_config = staticmethod(get_config)


config_lib = ConfigModule()


# Dataset

In [None]:
# Load a text file from disk and create mini-batches for training a byte-based Transformer model.

import jax
import jax.numpy as jnp
from typing import Iterator, NamedTuple, List, Any
import ml_collections
import re

# Byte-level model: every token is one byte (8 bits), hence 256 possible values.
# A byte is not an inherently meaningful unit, but for current computers and UTF-8 text it works pretty well.
# TODO: consider predicting individual bits instead? That probably needs advances in automatic self-chunking first,
# just as e.g. SentencePiece tokenization is currently quite a bit more effective than byte-level models like this one.
VOCAB_SIZE = 256


# One mini-batch as produced by `load_from_file`:
#   inputs     - stacked rows with their last token dropped
#   targets    - the same rows with their first token dropped (next-token targets)
#   attn_datas - one entry per row from `determine_attention_for_row`;
#                callers also pass None here when the metadata is not needed
Batch = NamedTuple(
    "Batch",
    [
        ("inputs", jax.Array),
        ("targets", jax.Array),
        ("attn_datas", List[Any]),
    ],
)

# input example: {"question": "Show biggest of these numbers: 1 4 5", "solution": "5;X@#Y"}
# attn_data placeholders are already included in the solution string, but need to be extracted

# skip anything that doesn't have a complete solution, or doesn't show all the numbers before the solution

# SKIP: first solution doesn't have associated numbers (so can't say which number was most impactful since they're not considered)
# ", "solution": "9"}
# {"question": "Determine the maximum number among: 8 2 2 9", "solution": "9"}
# {"question": "Find the maximum n

# SKIP: present but incomplete solution field (would need extra conditional checks about whether to replace X and Y or not)
# minimum number in the list: 3 6 5 3 5", "solution": "3"}
# {"question": "Return the maximum number among: 8 1 7 1 9 9", "solution": "5;

# algo: foreach `"s`
# * make sure it's a complete `"solution": "?`
# * make sure it has a complete directly preceding list of numbers starting from the embedded `:`

# * simplified: find each "s, and associate with it: all chars until+including a `}` [requires two more additional chars but that's fine]
#                             and: all chars since the previous `}`

# * so let's split by `}`

def determine_attention_for_row(row):
  """Locate the solution digit and the number list of every complete example in `row`.

  `row` is decoded as UTF-8 and split into parts that each end with "}\\n".
  Text after the last such terminator is the truncated tail of the row.

  Args:
    row: array-like of byte values exposing `.tolist()`.

  Returns:
    A list with one dict per complete part, holding
      sol_idx:  index in the decoded row of the solution character,
      nums_idx: index in the decoded row of the first character of the number list,
      nums_cnt: number of entries in the list,
    or None when the row should be skipped: the tail already contains a
    `"solution": "` field, or some part has no `: <digit><digits/spaces>` number list.

  Notes:
    * `sol_idx` assumes every part ends with a 9-character `D;X@#Y"}\\n` suffix
      (a single solution digit followed by the placeholder) — TODO confirm against the dataset.
    * `nums_cnt` assumes single-digit numbers separated by single spaces.
    * Indices are character offsets of the decoded string; they equal byte offsets only
      for ASCII data.
  """
  s = bytes(row.tolist()).decode("utf-8")
  parts = re.findall(r'(.*?}\n)', s, re.DOTALL)

  leftover = s[sum(len(p) for p in parts):]
  if '"solution": "' in leftover:
    # Too much in the left-over part; avoid the messy partial-attention cases.
    return None

  attn_data = []
  idx = 0  # offset of the current part within `s`
  for p in parts:
    # One regex both validates and extracts the number list, so a part that passes
    # the check can never fail extraction (previously a single-number list raised
    # IndexError here).
    match = re.search(r': ([0-9][0-9 ]+)', p, re.DOTALL)
    if match is None:
      # Missing clear index of numbers.
      return None
    nums = match.group(1)
    attn_data.append(dict(
        sol_idx=idx + len(p) - 9,
        # Use the matched span itself; searching for the substring could hit an
        # identical run of digits earlier in the part.
        nums_idx=idx + match.start(1),
        nums_cnt=(len(nums) + 1) // 2,
    ))
    idx += len(p)
  return attn_data

def load_from_file(filename: str, config: ml_collections.ConfigDict) -> Iterator[Batch]:
    """Yield training batches built from the raw bytes of `filename`.

    The file is read as uint8, cut into consecutive rows of
    `config.dataset_sequence_length + 1` bytes, and each row is passed to
    `determine_attention_for_row`. Rows for which that returns None are skipped.
    Accepted rows are grouped into batches of `config.dataset_batch_size`,
    with `inputs = row[:-1]` and `targets = row[1:]`.

    Trailing bytes that do not fill a whole multiple of
    `dataset_batch_size * (dataset_sequence_length + 1)` are dropped, and a
    final partially filled batch is never yielded.

    Args:
        filename: path of the dataset file to read in binary mode.
        config: provides `dataset_sequence_length` and `dataset_batch_size`.

    Yields:
        `Batch` tuples with stacked `inputs`/`targets` and the per-row attention data.
    """
    with open(filename, "rb") as file:
        array = jnp.frombuffer(file.read(), dtype=jnp.uint8)
    print(f"{array.shape=}")

    crop_len = config.dataset_sequence_length + 1  # type: ignore
    # Plain Python arithmetic: the sizes are host-side ints, so there is no reason to
    # route them through jnp (which would use JAX's default 32-bit integer type).
    remainder = array.shape[0] % (config.dataset_batch_size * crop_len)
    if remainder:
        array = array[:-remainder]
    ds = array.reshape([-1, crop_len])
    print(f"{ds.shape=}")

    inputs, targets, attn_datas = [], [], []
    for row in ds:
        attn_data = determine_attention_for_row(row)
        if attn_data is None:
            continue
        inputs.append(row[:-1])
        targets.append(row[1:])
        attn_datas.append(attn_data)
        if len(inputs) == config.dataset_batch_size:
            yield Batch(inputs=jnp.stack(inputs), targets=jnp.stack(targets), attn_datas=attn_datas)
            inputs, targets, attn_datas = [], [], []

class DatasetModule:
  """Namespace object standing in for a `dataset` module."""
  VOCAB_SIZE = VOCAB_SIZE
  Batch = Batch
  # staticmethod so that `dataset.load_from_file(...)` is called without a bound `self`.
  load_from_file = staticmethod(load_from_file)


dataset = DatasetModule()

In [None]:
import time

def test_load_from_file():
  """Manual timing check: stream every batch and print the seconds spent per 100 batches.

  Uses the active config with the sequence length overridden to 128.
  To inspect a batch while debugging, print `array_to_text(batch.inputs[i])`,
  `array_to_text(batch.targets[i])` and `batch.attn_datas[i]` inside the loop.
  """
  cfg = ml_collections.ConfigDict(config_lib.get_config())
  cfg.dataset_sequence_length = 128
  cfg = ml_collections.FrozenConfigDict(cfg)
  path = "/content/generated-easy-attention-dataset_with_attn_data_10M.txt"
  t_last = time.time()
  for batch_no, _batch in enumerate(dataset.load_from_file(path, cfg), 1):
    if batch_no % 100 != 0:
      continue
    t_now = time.time()
    print(t_now - t_last, batch_no)
    t_last = t_now

# test_load_from_file()

In [None]:
# !time bunzip2 'with_attn_data_batches_bs=16_seqlen=128.pickle.bz2'


real	0m26.494s
user	0m24.952s
sys	0m0.917s


In [None]:
import time
import pickle

__cache = {}

def iter_raw_batches(config):
  """Yield all dataset batches for `config`, using an in-memory and an on-disk cache.

  Batches depend only on `config.dataset_batch_size` and
  `config.dataset_sequence_length`. When those differ from the cached config
  (or nothing is cached yet), the batches are loaded from the pickle file named
  after these two values in the current directory. If that file is missing or
  cannot be loaded, the batches are recomputed with `dataset.load_from_file`
  and written back to the pickle file.

  Being a generator, none of this runs until the first item is requested.

  Args:
    config: config object providing `dataset_batch_size` and `dataset_sequence_length`.

  Yields:
    The cached batches, in order.
  """
  if __cache.get("config") != config:
    logging = LoggingModule()
    old = __cache.get("config")
    if not old or old.dataset_batch_size != config.dataset_batch_size or old.dataset_sequence_length != config.dataset_sequence_length:
      pickle_filename = f'with_attn_data_batches_bs={config.dataset_batch_size}_seqlen={config.dataset_sequence_length}.pickle'
      batches = None
      try:
        with open(pickle_filename, 'rb') as handle:
          t_before = time.time()
          print("opened pickle file...")
          # NOTE(security): unpickling can execute arbitrary code. Only load cache
          # files that this notebook wrote itself.
          batches = pickle.load(handle)
          t_after = time.time()
          print(f"loading pickled batches...done in {t_after - t_before}!")
      except FileNotFoundError:
        # Expected on the first run: there is no cache file yet.
        pass
      except Exception as err:
        # Best effort: a corrupt or incompatible cache falls back to recomputing, but
        # the reason is reported and KeyboardInterrupt/SystemExit are no longer swallowed.
        logging.warning(f"could not load {pickle_filename} ({err!r})")
      if batches is None:
        print("recomputing batches...")
        t_before = time.time()
        batches = []
        # Alternative inputs used previously:
        #   "/content/generated-easy-attention-dataset.txt", "/content/parity-dataset.txt"
        for batch in dataset.load_from_file("/content/generated-easy-attention-dataset_with_attn_data_10M.txt", config):
          batches.append(batch)
          if len(batches) % 1000 == 0: logging.info(len(batches))
        t_after = time.time()
        print(f"recomputing batches...done in {t_after - t_before}!")
        with open(pickle_filename, 'wb') as handle:
          pickle.dump(batches, handle, protocol=pickle.HIGHEST_PROTOCOL)

      __cache["config"] = config
      __cache["batches"] = batches
  for batch in __cache["batches"]:
    yield batch

In [None]:
def make_all_batches():
  """Populate the batch cache for sequence length 128 and print how many batches exist."""
  cfg = ml_collections.ConfigDict(config_lib.get_config())
  cfg.dataset_sequence_length = 128
  cfg = ml_collections.FrozenConfigDict(cfg)
  batch_count = sum(1 for _ in iter_raw_batches(cfg))
  print(batch_count)

make_all_batches()


198667


# Model

In [None]:
from typing import Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
import pprint


class Embedder(nn.Module):
    """Turns token ids into the sum of learned token and learned positional embeddings."""
    embed_dim: int
    vocab_size: int

    @nn.compact
    def __call__(self, in_tokens: jax.Array) -> jax.Array:
        """Embed `in_tokens` of shape (BATCH_SIZE, SEQUENCE_LENGTH).

        Returns an array of shape (BATCH_SIZE, SEQUENCE_LENGTH, EMBEDDING_SIZE).
        """
        _, seq_len = in_tokens.shape
        init_fn = jax.nn.initializers.truncated_normal(stddev=0.02)

        # (BATCH_SIZE, SEQUENCE_LENGTH, EMBEDDING_SIZE)
        token_lookup = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.embed_dim,
            embedding_init=init_fn,
            name="embed__model_tokens_input",
        )
        token_embeddings = token_lookup(in_tokens)

        # (SEQUENCE_LENGTH, EMBEDDING_SIZE): one learned vector per position.
        # Alternatives to consider: fixed sinusoidal or rotary embeddings.
        positional_embeddings = self.param(
            "positional_embeddings", init_fn, (seq_len, self.embed_dim)
        )

        # Explicit broadcast over the batch axis. Open question: add dropout here?
        return token_embeddings + jnp.broadcast_to(
            positional_embeddings, token_embeddings.shape
        )

class Decoder(nn.Module):
    """Projects hidden states to one logit per vocabulary entry."""
    vocab_size: int

    @nn.compact
    def __call__(self, output: jax.Array) -> jax.Array:
        """Apply a single dense layer with `vocab_size` output features to `output`."""
        projection = nn.Dense(self.vocab_size)
        return projection(output)

class Transformer(nn.Module):
    """Stack of pre-LayerNorm self-attention + MLP blocks with residual connections.

    Each of the `num_layers` layers applies LayerNorm -> causal self-attention ->
    residual add, then LayerNorm -> Dense -> GELU -> Dense -> Dropout -> residual add.
    A final LayerNorm is applied to the output.
    """
    num_attention_heads: int
    num_layers: int  # each layer has attention + MLP
    attention_size_per_head: int
    dropout_rate: float
    is_training: bool  # dropout is active only when this is True
    widening_factor: int = 4  # factor for widening MLP hidden layer

    @nn.compact
    def __call__(self, embeddings: jax.Array) -> jax.Array:
        """Transform `embeddings` of shape (batch, seq_len, embedding_size).

        Returns an array of the same shape.
        """
        # Kernel initializer for the MLP Dense layers; scale shrinks with depth.
        initializer = nn.initializers.variance_scaling(
            2 / self.num_layers, "fan_in", "truncated_normal"
        )
        num_batches, seq_len, embedding_size = embeddings.shape
        # print(f"{num_batches=} {seq_len=} {embedding_size=}")

        # Compute causal mask for auto-regressive model
        # this (standard) approach produces inputs like:
        #   hello w0000
        #   hello wo000
        #   hello wor00
        # which doesn't sound right, something like this seems more reasonable for decoders that iterate:
        #   0000hello w
        #   000hello wo
        #   00hello wor
        # though given the first attention layer, maybe this doesn't matter?
        # depends on the position embedding, and does seem important for the one used here. ¯\_(ツ)_/¯
        # model can learn to work around it anyway, just seems odd
        # Lower-triangular ones of shape (1, 1, seq_len, seq_len).
        causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len)))
        # print(f"{causal_mask=}")

        h = embeddings
        for _ in range(self.num_layers):
            # First the attention block.
            attn_block = nn.SelfAttention(
                num_heads=self.num_attention_heads,
                qkv_features=self.attention_size_per_head * self.num_attention_heads,
                dropout_rate=self.dropout_rate,
                kernel_init=nn.initializers.variance_scaling(
                    scale=2 / self.num_layers,
                    mode="fan_in",
                    distribution="truncated_normal",
                ),
                # use built-in normalization?
                # use_bias False???
                # broadcast_dropout False???
                # use decode:true for speedup!
            )
            # Pre-norm: normalize before attention, add the result back onto `h`.
            h_norm = nn.LayerNorm()(h)
            h_attn = attn_block(
                inputs_q=h_norm, mask=causal_mask, deterministic=not self.is_training
            )

            # add another dropout layer here??? instead of the built-in one?
            h = h + h_attn

            # Then the MLP block
            h_norm = nn.LayerNorm()(h)
            h_dense = nn.Dense(
                self.widening_factor * embedding_size, kernel_init=initializer
            )(h_norm)
            # relu+dropout instead of gelu??
            h_dense = jax.nn.gelu(h_dense)
            h_dense = nn.Dense(embedding_size, kernel_init=initializer)(h_dense)
            h_dense = nn.Dropout(
                rate=self.dropout_rate, deterministic=not self.is_training
            )(h_dense)
            h = h + h_dense

        return nn.LayerNorm()(h)


class AutoregressiveTransformerModel(nn.Module):
    """Full model: embedder -> transformer -> decoder, composed from pre-built submodules."""
    embedder: Embedder
    transformer: Transformer
    decoder: Decoder

    @nn.compact
    def __call__(self, in_tokens: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Forward pass, producing a sequence of logits.

        Returns:
            A `(decoded, input_embeddings)` pair: the decoder output for every
            position, and the embeddings that were fed into the transformer.
        """
        input_embeddings = self.embedder(in_tokens)

        # The transformer output is passed straight to the decoder.
        output = self.transformer(input_embeddings)
        # print(f"{input_embeddings=}, {output=}")

        decoded = self.decoder(output)

        # print(f"{decoded=}")
        return decoded, input_embeddings


class ModelModule:
  """Namespace object standing in for a `model` module; re-exports the model classes."""
  Embedder = Embedder
  Transformer = Transformer
  Decoder = Decoder
  AutoregressiveTransformerModel = AutoregressiveTransformerModel


model = ModelModule()


# Train

In [None]:
import jax
import jax.experimental
import jax.numpy as jnp
import ml_collections
import optax
from flax.training import train_state
from flax import struct
import typing
import functools

if False:
    from flax.metrics import tensorboard
else:
    tensorboard = None

PRNGKeyArray = jax.Array

class TrainState(train_state.TrainState):
    """Train state extended with an example counter and extra apply functions.

    The `*_apply_fn` fields are declared with `pytree_node=False`.
    """
    # Running count of training examples seen; incremented by the batch size after each step.
    num_examples_trained_on: int
    # `apply` of the training model's transformer / embedder / decoder submodules.
    transformer_apply_fn: typing.Callable = struct.field(pytree_node=False)
    embedder_apply_fn: typing.Callable = struct.field(pytree_node=False)
    decoder_apply_fn: typing.Callable = struct.field(pytree_node=False)
    # `apply` of a second model instance built with is_training=False (used for evals).
    predict_apply_fn: typing.Callable = struct.field(pytree_node=False)


def create_model(config: ml_collections.ConfigDict, is_training):
    """Build the full autoregressive model described by `config`.

    Args:
        config: hyperparameters (`model_embed_dim` and the `transformer_*` fields).
        is_training: forwarded to the transformer, where it enables dropout.

    Returns:
        An unbound `AutoregressiveTransformerModel`.
    """
    return model.AutoregressiveTransformerModel(
        embedder=model.Embedder(
            embed_dim=config.model_embed_dim,
            vocab_size=dataset.VOCAB_SIZE,
        ),
        transformer=model.Transformer(
            num_attention_heads=config.transformer_num_attention_heads,  # type: ignore
            num_layers=config.transformer_num_layers,  # type: ignore
            attention_size_per_head=config.transformer_attention_size_per_head,  # type: ignore
            dropout_rate=config.transformer_dropout_rate,  # type: ignore
            is_training=is_training,
        ),
        decoder=model.Decoder(
            vocab_size=dataset.VOCAB_SIZE,
        ),
    )


def create_train_state(
    init_rng: PRNGKeyArray, config: ml_collections.ConfigDict
):
    """Initialise model parameters and the SGD optimizer, returning a fresh `TrainState`.

    Parameters are initialised by running `init` on an all-zero uint8 input of shape
    (dataset_batch_size, dataset_sequence_length). The state also carries the apply
    function of a second, non-training model instance for predictions.
    """
    train_model = create_model(config, is_training=True)
    predict_model = create_model(config, is_training=False)

    dummy_inputs = jnp.zeros(
        (config.dataset_batch_size, config.dataset_sequence_length), dtype=jnp.uint8
    )
    params_rng_key, dropout_rng_key = jax.random.split(init_rng)
    init_rngs = {"params": params_rng_key, "dropout": dropout_rng_key}
    variables = jax.jit(train_model.init)(init_rngs, dummy_inputs)

    # TODO: Use a learning rate schedule!!
    optimizer = optax.sgd(config.learning_rate, config.momentum)  # type: ignore

    return TrainState.create(
        apply_fn=train_model.apply,
        params=variables["params"],
        tx=optimizer,
        num_examples_trained_on=0,
        transformer_apply_fn=train_model.transformer.apply,
        embedder_apply_fn=train_model.embedder.apply,
        decoder_apply_fn=train_model.decoder.apply,
        predict_apply_fn=predict_model.apply,
    )


logging_module = LoggingModule()

def log_progress(num_examples_trained_on, loss, accuracy):
    """Log loss and accuracy whenever the example count is a multiple of 8192."""
    # Use `% 128` instead for more frequent progress lines.
    if num_examples_trained_on % 8192 != 0:
      return
    logging_module.info(f"[{num_examples_trained_on}] loss: {loss}, accuracy: {accuracy}")

@jax.jit
def apply_model(state, batch, apply_rng_key):
    """Compute gradients, loss and accuracy for one batch (jit-compiled).

    Args:
        state: train state providing `apply_fn` and `params`.
        batch: batch with `inputs` and `targets`.
        apply_rng_key: RNG key used for dropout.

    Returns:
        `(grads, loss, accuracy)` where `loss` is the mean negative log-likelihood per
        token and `accuracy` is the fraction of positions whose argmax equals the target.
    """
    # Runs while the function is being traced, so it shows up once per (re)compilation.
    print(f"apply_model() {batch.inputs.shape=}")

    def mean_negative_log_likelihood(params):
        logits, input_embeddings = state.apply_fn(
            {"params": params}, batch.inputs, rngs={"dropout": apply_rng_key}
        )
        onehot_targets = jax.nn.one_hot(batch.targets, dataset.VOCAB_SIZE)
        # [B, T, V] -> [B, T]: log-probability assigned to each target token.
        token_log_likelihood = jnp.sum(
            onehot_targets * jax.nn.log_softmax(logits), axis=-1
        )
        aux = {"logits": logits, "input_embeddings": input_embeddings}
        return -jnp.mean(token_log_likelihood), aux

    (loss, aux), grads = jax.value_and_grad(
        mean_negative_log_likelihood, has_aux=True
    )(state.params)
    predicted_tokens = jnp.argmax(aux["logits"], -1)
    accuracy = jnp.mean(predicted_tokens == batch.targets)
    # Optional progress logging from inside jit:
    # jax.experimental.io_callback(
    #     log_progress, None, state.num_examples_trained_on, loss, accuracy
    # )
    return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
    """Return the state produced by `state.apply_gradients(grads=grads)` (jit-compiled)."""
    # Runs while the function is being traced, so it shows up once per (re)compilation.
    print("update_model()")
    updated_state = state.apply_gradients(grads=grads)
    return updated_state


def run_training(
    dataset_filename: str,
    rng_key: PRNGKeyArray, config: ml_collections.ConfigDict, workdir: str
):
    """Train a freshly initialised model for one pass over `dataset_filename`.

    Args:
        dataset_filename: path handed to `dataset.load_from_file`.
        rng_key: RNG key; split into an init key and one dropout key per step.
        config: hyperparameters for the model, optimizer and dataset.
        workdir: directory for TensorBoard summaries (used only when the
            module-level `tensorboard` is enabled).

    Returns:
        The final train state.
    """
    if tensorboard:
        summary_writer = tensorboard.SummaryWriter(workdir)
        summary_writer.hparams(dict(config))
    init_rng, rng_key = jax.random.split(rng_key)
    state = create_train_state(init_rng, config)
    # print(f"{state=}")
    # `logging` is not defined at module scope; use the shared logger instance.
    logging_module.info("starting training...")

    # load_from_file needs the config for the batch size and sequence length.
    for batch in dataset.load_from_file(dataset_filename, config):
        rng_key, apply_rng_key = jax.random.split(rng_key)
        # Drop attn_datas before entering the jitted function: its contents vary per
        # batch, which would otherwise force jax.jit to recompile.
        jit_batch = dataset.Batch(
            inputs=batch.inputs, targets=batch.targets, attn_datas=None
        )
        grads, _loss, _accuracy = apply_model(state, jit_batch, apply_rng_key)
        state = update_model(state, grads)
        state = state.replace(
            num_examples_trained_on=state.num_examples_trained_on
            + batch.inputs.shape[0]
        )
    return state

class TrainModule:
  """Namespace object standing in for a `train` module that exposes `run_training`."""
  # staticmethod so that `train.run_training(...)` is called without a bound `self`.
  run_training = staticmethod(run_training)


train = TrainModule()


# Evals

In [None]:
def array_to_text(b):
  """Decode an array of byte values as UTF-8 for single-line display.

  Undecodable bytes become the replacement character and newlines become "_".
  """
  raw = bytes(b.tolist())
  text = raw.decode("utf-8", errors='replace')
  return text.replace("\n", "_")

eval_cases_str = [
    'rn biggest number in this list: 1 4 5 9", "solution": "9',
    'rn smallest number in this list: 1 4 5 9", "solution": "1',
    'number in this list: 1 4 5 9", "solution": "9',
    'rn biggest number in this list: 8 3 4 7", "solution": "8',
    'rn smallest number in this list: 8 3 4 7", "solution": "3',
    'number in this list: 8 3 4 7", "solution": "3',
    # 'ber among: 3 1 5", "solution": "5"}\n{"question": "Return the minimum number in this list: 7 1 9 8", "solution": "1"}\n{"question',
]

def preprocess_case(c, seq_len=128):
  """Turn an eval string into a zero-padded model input plus its expected answer.

  The last character of `c` is the expected answer. Everything before it is the
  prompt, encoded as UTF-8 bytes and right-padded with zeros to `seq_len`.

  Args:
    c: eval case text whose final character is the solution.
    seq_len: length to pad the prompt to; must match the sequence length the
      model is applied with (defaults to the 128 used by the runs in this file).

  Returns:
    `(arr, answer, c)`: the padded byte array, the code point of the answer
    character, and the original string.

  Raises:
    ValueError: if the encoded prompt is longer than `seq_len`.
  """
  prompt = c[:-1].encode('utf-8')
  if len(prompt) > seq_len:
    raise ValueError(f"eval prompt has {len(prompt)} bytes, more than seq_len={seq_len}")
  arr = jnp.array(list(prompt))
  arr = jnp.pad(arr, [(0, seq_len - arr.shape[0])])
  answer = ord(c[-1])
  return (arr, answer, c)

eval_cases = [preprocess_case(c) for c in eval_cases_str]

def eval_predictions(state, *, print_details=False):
  """Return, for every case in `eval_cases`, the probability the model gives the right answer.

  Args:
    state: train state providing `predict_apply_fn` and `params`.
    print_details: when True, also print each case, the predicted character,
      its probability and the argmax guess at every position.

  Returns:
    A list of floats, one per eval case.
  """
  stacked_inputs = jnp.stack([arr for (arr, _answer, _text) in eval_cases])
  all_logits, _input_embeddings = state.predict_apply_fn({"params": state.params}, stacked_inputs)
  probs = []
  for case_logits, (_arr, test_answer, test_str) in zip(all_logits, eval_cases):
    # The test string ends with the solution character, and the model input is that
    # string minus its last character, so the prediction sits at the last input position.
    pred_idx = len(test_str) - 2
    log_probs = jax.nn.log_softmax(case_logits)
    prob = jnp.exp(log_probs[pred_idx][test_answer])
    probs.append(float(prob))
    if not print_details:
      continue
    guesses = array_to_text(jnp.argmax(log_probs, axis=-1))
    predicted_solution = guesses[pred_idx]
    print("                     ", test_str.replace("\n", "_"))
    print(predicted_solution, '%.5f' % prob, len(guesses), f"{guesses=}")
    print()
  return probs

# eval_predictions(last_round['state'], print_details=True)


# Interactive

In [None]:
all_results = []

In [None]:
import random
import time

def initial_apply_model(state, config=None):
  """Call `apply_model` once on an all-zero dummy batch and discard the result.

  Intended as a warm-up so the jitted `apply_model` is traced before training starts.

  Args:
    state: train state to run the model with.
    config: hyperparameters giving the batch shape; must match the config `state`
      was created from. Defaults to `config_lib.get_config()`. The module-level
      name `config` cannot be used, because it refers to JAX's own config object.
  """
  if config is None:
    config = config_lib.get_config()
  shape = (config.dataset_batch_size, config.dataset_sequence_length)
  batch = dataset.Batch(
      inputs=jnp.zeros(shape, dtype=jnp.uint8),
      targets=jnp.ones(shape, dtype=jnp.uint8),
      # Batch has no default for this field; the metadata is not needed by apply_model.
      attn_datas=None,
  )
  rng_key = jax.random.key(12345)
  apply_model(state, batch, rng_key)

__latest = None

def runone(config):
  """Train a fresh model for one epoch over the cached batches and report the final metrics.

  Every 1000 batches the current loss and the `eval_predictions` probabilities are
  printed. Training stops after 100000 batches. The final state, the inputs of the
  last training step and the config are stored in the module-level `__latest`.

  Args:
    config: hyperparameters for the model, optimizer and dataset.

  Returns:
    dict with the `loss` and `accuracy` of the last processed batch (NaN for both
    when no batch was available).
  """
  global __latest
  rng_key = jax.random.key(12345)
  init_rng, rng_key = jax.random.split(rng_key)
  state = create_train_state(init_rng, config)
  # Optional warm-up: initial_apply_model(state)
  print(f"{state.params.keys()=}")
  last_round = None
  # Defined up front so the return statement also works when there are no batches.
  loss = accuracy = float("nan")
  for epoch in range(1, 2):
    for (i, batch) in enumerate(iter_raw_batches(config), 1):
      if i > 100000: break
      rng_key, apply_rng_key = jax.random.split(rng_key)
      last_round = dict(apply_rng_key=apply_rng_key, state=state, batch=batch)
      # workaround so jax.jit doesn't recompile due to attn_datas changing, when it doesn't matter
      grads, loss, accuracy = apply_model(state, Batch(inputs=batch.inputs,targets=batch.targets,attn_datas=None), apply_rng_key)
      if i % 1000 == 0:
        # Named eval_probs to avoid shadowing the builtin `eval`; the printed text is unchanged.
        eval_probs = eval_predictions(state)
        loss = float(loss)
        print(f'{epoch=} {i=} {loss=} eval={eval_probs!r}')
      # all_imp_val, all_imp_grads, latest_logits = analyze_model(state, batch, apply_rng_key)
      state = update_model(state, grads)
      state = state.replace(
          num_examples_trained_on=state.num_examples_trained_on
          + batch.inputs.shape[0]
      )
  # all_imp_val, all_imp_grads, latest_logits = analyze_model(last_round["state"], last_round["batch"], last_round["apply_rng_key"])
  __latest = dict(state=state, last_round=last_round, config=config)
  return dict(loss=float(loss), accuracy=float(accuracy))

def runit():
  """Run one timed training job and record its outcome.

  Copies the project config, sets ``dataset_sequence_length`` to 128, freezes
  it, calls ``runone`` and measures the wall-clock time around that call. The
  resulting record (config, metrics, time taken) is appended to the
  module-level ``all_results`` list and printed. Returns nothing.
  """
  base_config = config_lib.get_config()
  run_records = []

  # Hyper-parameter sweeps tried earlier; to repeat one, wrap the code below in
  # the loop and set the matching config field before freezing:
  # for learning_rate in [0.005, 0.0075, 0.01, 0.015, 0.02, 0.03]:
  # for learning_rate in [0.005, 0.01, 0.02, 0.03, 0.04, 0.05, 0.075, 0.1]:
  # for learning_rate in [0.0025, 0.005, 0.01, 0.02, 0.03, 0.04, 0.05]:
  # for learning_rate in [0.035, 0.045, 0.055, 0.06, 0.07, 0.08]:
  # for momentum in [0.0, 0.5, 0.75, 0.9, 0.95, 0.99]:
  editable_config = ml_collections.ConfigDict(base_config)
  editable_config.dataset_sequence_length = 128
  # Other overrides that have been toggled on and off:
  # editable_config.dataset_batch_size = 16
  # editable_config.model_embed_dim = 8
  # editable_config.learning_rate = learning_rate
  # editable_config.momentum = momentum
  frozen_config = ml_collections.FrozenConfigDict(editable_config)

  started_at = time.time()
  metrics = runone(frozen_config)
  finished_at = time.time()

  record = {
      "config": frozen_config,
      "result": metrics,
      "time_taken": finished_at - started_at,
  }
  all_results.append(record)
  run_records.append(record)
  print(record)
  # for idx,r in enumerate(run_records):
  #   print(r)

# Execute the training run; runone() stores its outcome in __latest.
runit()

# Keep a reference to this run's results under a second name; the fast-apply
# setup further down reads the config from __latest_raw.
__latest_raw = __latest

# == expected (example) ==
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.3165973126888275, 'accuracy': 0.8720703125}, 'time_taken': 24.62476372718811}


# == 5-epoch with evals ==
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# epoch=0 i=1000 loss=0.5739200115203857 eval=[0.25576072931289673, 0.1909084916114807, 0.2073844075202942, 0.04420793801546097, 0.08059445768594742, 0.07843567430973053, 0.9624700546264648]
# epoch=0 i=2000 loss=0.4136265814304352 eval=[0.263443261384964, 0.19029481709003448, 0.24889640510082245, 0.12895368039608002, 0.06614433228969574, 0.0901700034737587, 0.9990789294242859]
# epoch=0 i=3000 loss=0.3505646288394928 eval=[0.23537684977054596, 0.22554592788219452, 0.21004092693328857, 0.1295713633298874, 0.07193823903799057, 0.069587841629982, 0.9996753931045532]
# epoch=0 i=4000 loss=0.3638574779033661 eval=[0.21900257468223572, 0.2750113904476166, 0.22962553799152374, 0.12180472165346146, 0.08235418796539307, 0.08164555579423904, 0.9996033310890198]
# epoch=0 i=5000 loss=0.3442617356777191 eval=[0.3437851369380951, 0.2896570861339569, 0.28338637948036194, 0.11324393004179001, 0.09359100461006165, 0.09555242955684662, 0.9997578859329224]
# epoch=0 i=6000 loss=0.3612757623195648 eval=[0.3321145474910736, 0.3115253150463104, 0.388074666261673, 0.16379046440124512, 0.10965581983327866, 0.09905251860618591, 0.9998824000358582]
# epoch=0 i=7000 loss=0.3155438303947449 eval=[0.519768238067627, 0.2946617901325226, 0.5052418112754822, 0.2036275863647461, 0.10580804944038391, 0.0836510956287384, 0.9998419284820557]
# epoch=0 i=8000 loss=0.33309099078178406 eval=[0.6334126591682434, 0.45864757895469666, 0.5498705506324768, 0.24986563622951508, 0.12717555463314056, 0.09224288910627365, 0.999918520450592]
# epoch=0 i=9000 loss=0.3243502080440521 eval=[0.6138125061988831, 0.37781932950019836, 0.47996675968170166, 0.2924855947494507, 0.09212594479322433, 0.0735761970281601, 0.9999344944953918]
# epoch=1 i=1000 loss=0.35357117652893066 eval=[0.7369169592857361, 0.5232179164886475, 0.5388957262039185, 0.2395087033510208, 0.14448748528957367, 0.09491223841905594, 0.9997111558914185]
# epoch=1 i=2000 loss=0.32078713178634644 eval=[0.8391317129135132, 0.4492957890033722, 0.5058247447013855, 0.348644495010376, 0.17749656736850739, 0.09224411100149155, 0.9999372363090515]
# epoch=1 i=3000 loss=0.29471877217292786 eval=[0.7215474247932434, 0.5509695410728455, 0.4120620787143707, 0.4432094991207123, 0.219400092959404, 0.0976300910115242, 0.9999505877494812]
# epoch=1 i=4000 loss=0.3263701796531677 eval=[0.6753780245780945, 0.6447336077690125, 0.3632242977619171, 0.4117605686187744, 0.24879731237888336, 0.11036688834428787, 0.9999779462814331]
# epoch=1 i=5000 loss=0.32446375489234924 eval=[0.6963419914245605, 0.8688401579856873, 0.31226545572280884, 0.5491999387741089, 0.20782560110092163, 0.10119607299566269, 0.9999755620956421]
# epoch=1 i=6000 loss=0.33036187291145325 eval=[0.6837486028671265, 0.813098669052124, 0.2693575918674469, 0.4401657283306122, 0.33529070019721985, 0.1376340389251709, 0.9999758005142212]
# epoch=1 i=7000 loss=0.2911761999130249 eval=[0.7587365508079529, 0.8477703928947449, 0.3281989097595215, 0.6431964635848999, 0.3842546045780182, 0.08905424922704697, 0.9999850392341614]
# epoch=1 i=8000 loss=0.31904229521751404 eval=[0.9006826281547546, 0.7769214510917664, 0.5496879816055298, 0.5099847316741943, 0.41413450241088867, 0.09657419472932816, 0.9999873042106628]
# epoch=1 i=9000 loss=0.3193903863430023 eval=[0.8368915915489197, 0.6778384447097778, 0.4698925018310547, 0.6070364117622375, 0.35666853189468384, 0.10728566348552704, 0.999973475933075]
# epoch=2 i=1000 loss=0.3231153190135956 eval=[0.9382563233375549, 0.712011992931366, 0.8130528926849365, 0.7326958179473877, 0.42362159490585327, 0.10162405669689178, 0.9999815821647644]
# epoch=2 i=2000 loss=0.3130776584148407 eval=[0.8540672659873962, 0.7809162139892578, 0.6200969815254211, 0.8193367719650269, 0.5511485934257507, 0.09916981309652328, 0.9999854564666748]
# epoch=2 i=3000 loss=0.2797461152076721 eval=[0.9398724436759949, 0.7283543348312378, 0.6749342679977417, 0.8329630494117737, 0.4646602272987366, 0.09345698356628418, 0.9999879002571106]
# epoch=2 i=4000 loss=0.319455087184906 eval=[0.9454924464225769, 0.9568725228309631, 0.504018247127533, 0.7465532422065735, 0.6363449096679688, 0.2609730064868927, 0.9999868869781494]
# epoch=2 i=5000 loss=0.31180888414382935 eval=[0.9875187277793884, 0.9856971502304077, 0.5815438628196716, 0.8782399296760559, 0.5585669279098511, 0.17823170125484467, 0.9999873042106628]
# epoch=2 i=6000 loss=0.3286677896976471 eval=[0.8597803115844727, 0.9745884537696838, 0.46322929859161377, 0.8213324546813965, 0.6781769394874573, 0.22114898264408112, 0.9999845623970032]
# epoch=2 i=7000 loss=0.2852470874786377 eval=[0.9865315556526184, 0.9705560803413391, 0.8133450150489807, 0.9582337737083435, 0.6071891188621521, 0.07707253098487854, 0.9999844431877136]
# epoch=2 i=8000 loss=0.3009129762649536 eval=[0.9868437051773071, 0.992186427116394, 0.7376627922058105, 0.9323488473892212, 0.6530531048774719, 0.09685704112052917, 0.9999840259552002]
# epoch=2 i=9000 loss=0.31120601296424866 eval=[0.988959789276123, 0.9872342348098755, 0.8608003258705139, 0.8953272104263306, 0.5900635719299316, 0.2272649109363556, 0.9999893307685852]
# epoch=3 i=1000 loss=0.3076467216014862 eval=[0.9913015961647034, 0.9783923625946045, 0.8200646042823792, 0.8432161211967468, 0.633089542388916, 0.18574026226997375, 0.9999899864196777]
# epoch=3 i=2000 loss=0.3079196512699127 eval=[0.9921027421951294, 0.9884726405143738, 0.7194697260856628, 0.8260654807090759, 0.6831430196762085, 0.13233381509780884, 0.9999898076057434]
# epoch=3 i=3000 loss=0.2817702293395996 eval=[0.9803481101989746, 0.9955952167510986, 0.7335648536682129, 0.982134997844696, 0.6524721384048462, 0.09374842792749405, 0.9999921321868896]
# epoch=3 i=4000 loss=0.313987135887146 eval=[0.9878073930740356, 0.9962072968482971, 0.6073234677314758, 0.9252861142158508, 0.8108097314834595, 0.4767504930496216, 0.9999916553497314]
# epoch=3 i=5000 loss=0.30439266562461853 eval=[0.9941520094871521, 0.9970490336418152, 0.5010340213775635, 0.9628649950027466, 0.7323124408721924, 0.2326633781194687, 0.9999921321868896]
# epoch=3 i=6000 loss=0.3341415822505951 eval=[0.9891661405563354, 0.9961928129196167, 0.491127073764801, 0.9249422550201416, 0.7883840799331665, 0.3188803493976593, 0.9999815821647644]
# epoch=3 i=7000 loss=0.28101611137390137 eval=[0.9927098751068115, 0.9921664595603943, 0.6426379084587097, 0.961493194103241, 0.7969250679016113, 0.24041303992271423, 0.9999803304672241]
# epoch=3 i=8000 loss=0.297041654586792 eval=[0.9884660840034485, 0.9963101148605347, 0.3774213492870331, 0.9664911031723022, 0.8411203622817993, 0.16945911943912506, 0.9999827742576599]
# epoch=3 i=9000 loss=0.3006150424480438 eval=[0.9895147085189819, 0.9962372183799744, 0.5557937622070312, 0.9701530933380127, 0.8128714561462402, 0.42682838439941406, 0.999988317489624]
# epoch=4 i=1000 loss=0.30540186166763306 eval=[0.992435872554779, 0.9964560270309448, 0.7079945206642151, 0.8108586072921753, 0.8836319446563721, 0.33212652802467346, 0.9999902248382568]
# epoch=4 i=2000 loss=0.2983374297618866 eval=[0.9948731660842896, 0.9955849647521973, 0.6550081968307495, 0.9731466770172119, 0.8628264665603638, 0.29651397466659546, 0.999985933303833]
# epoch=4 i=3000 loss=0.27236470580101013 eval=[0.9930411577224731, 0.9979089498519897, 0.5148472189903259, 0.9926022291183472, 0.8458859324455261, 0.17175333201885223, 0.9999879002571106]
# epoch=4 i=4000 loss=0.3109138011932373 eval=[0.992407500743866, 0.9982712268829346, 0.3773348927497864, 0.9817141890525818, 0.8689809441566467, 0.5541554689407349, 0.9999892711639404]
# epoch=4 i=5000 loss=0.30217716097831726 eval=[0.9965239763259888, 0.9975792765617371, 0.4456343948841095, 0.9880919456481934, 0.8514565229415894, 0.4222062826156616, 0.9999893307685852]
# epoch=4 i=6000 loss=0.3221867084503174 eval=[0.994474470615387, 0.99793940782547, 0.2711160182952881, 0.973802924156189, 0.8948591947555542, 0.5141062140464783, 0.9999846816062927]
# epoch=4 i=7000 loss=0.27481919527053833 eval=[0.9955040812492371, 0.9986401200294495, 0.47307905554771423, 0.9931105971336365, 0.8934237957000732, 0.19525732100009918, 0.9999902248382568]
# epoch=4 i=8000 loss=0.2961166203022003 eval=[0.9964576363563538, 0.9977209568023682, 0.5212557911872864, 0.9892817139625549, 0.8412736654281616, 0.1789352148771286, 0.9999916553497314]
# epoch=4 i=9000 loss=0.30616694688796997 eval=[0.995288610458374, 0.9985306859016418, 0.45903435349464417, 0.9967986941337585, 0.8838385343551636, 0.3318592309951782, 0.9999925494194031]
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.2778962552547455, 'accuracy': 0.88134765625}, 'time_taken': 102.8743155002594}

# with solution placeholders, larger dataset, 100k batches:
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# epoch=1 i=1000 loss=0.46857425570487976 eval=[0.29447174072265625, 0.20361736416816711, 0.27934327721595764, 0.07694032788276672, 0.06546784937381744, 0.0924861952662468, 0.9827739000320435]
# epoch=1 i=5000 loss=0.33055466413497925 eval=[0.35981935262680054, 0.27107465267181396, 0.323690801858902, 0.1450735330581665, 0.08937200903892517, 0.09934253245592117, 0.9998649954795837]
# epoch=1 i=10000 loss=0.3067989945411682 eval=[0.25776877999305725, 0.6053521037101746, 0.27096766233444214, 0.2684767544269562, 0.11884681135416031, 0.12769708037376404, 0.9997788071632385]
# epoch=1 i=20000 loss=0.3014799952507019 eval=[0.9014600515365601, 0.970745325088501, 0.5794517993927002, 0.4935365915298462, 0.4263138771057129, 0.17679323256015778, 0.9998611211776733]
# epoch=1 i=50000 loss=0.2758004367351532 eval=[0.9949303865432739, 0.9985706806182861, 0.6434107422828674, 0.9947646260261536, 0.9805237054824829, 0.1597745567560196, 0.9998638033866882]
# epoch=1 i=100000 loss=0.29574307799339294 eval=[0.9995472431182861, 0.9991129636764526, 0.6704465746879578, 0.9972497820854187, 0.98650723695755, 0.36667314171791077, 0.9999515414237976]
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.29574307799339294, 'accuracy': 0.873046875}, 'time_taken': 192.95904302597046}


##### Kaggle TPU VM v3-8 #####
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# epoch=1 i=1000 loss=0.46857425570487976 eval=[0.29447174072265625, 0.20361736416816711, 0.27934327721595764, 0.07694032788276672, 0.06546784937381744, 0.0924861952662468]
# epoch=1 i=2000 loss=0.3858732283115387 eval=[0.2718171775341034, 0.22649747133255005, 0.23751406371593475, 0.09634003788232803, 0.09410829842090607, 0.12707847356796265]
# epoch=1 i=5000 loss=0.33055466413497925 eval=[0.35981935262680054, 0.27107465267181396, 0.323690801858902, 0.1450735330581665, 0.08937200903892517, 0.09934253245592117]
# epoch=1 i=10000 loss=0.3067989945411682 eval=[0.25776877999305725, 0.6053521037101746, 0.27096766233444214, 0.2684767544269562, 0.11884681135416031, 0.12769708037376404]
# epoch=1 i=20000 loss=0.3014799952507019 eval=[0.9014600515365601, 0.970745325088501, 0.5794517993927002, 0.4935365915298462, 0.4263138771057129, 0.17679323256015778]
# epoch=1 i=50000 loss=0.2758004367351532 eval=[0.9949303865432739, 0.9985706806182861, 0.6434107422828674, 0.9947646260261536, 0.9805237054824829, 0.1597745567560196]
# epoch=1 i=100000 loss=0.29574307799339294 eval=[0.9995472431182861, 0.9991129636764526, 0.6704465746879578, 0.9972497820854187, 0.98650723695755, 0.36667314171791077]
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.29574307799339294, 'accuracy': 0.873046875}, 'time_taken': 178.05552220344543}


##### CPU #####
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# [2m]  epoch=1 i=1000 loss=0.4708341956138611 eval=[0.2885418236255646, 0.20252171158790588, 0.2837035655975342, 0.07080753147602081, 0.063657745718956, 0.08937117457389832]
# [4m]  epoch=1 i=2000 loss=0.38457387685775757 eval=[0.269186794757843, 0.22728782892227173, 0.2321525365114212, 0.09635068476200104, 0.09576919674873352, 0.13052456080913544]
# [6m]  epoch=1 i=3000 loss=0.34993505477905273 eval=[0.3772343397140503, 0.2725283205509186, 0.29736313223838806, 0.11464370042085648, 0.06595189869403839, 0.09163758158683777]
# [8m]  epoch=1 i=4000 loss=0.34004104137420654 eval=[0.40140414237976074, 0.25403761863708496, 0.4700442850589752, 0.18818339705467224, 0.07322150468826294, 0.0786784291267395]
# [10m] epoch=1 i=5000 loss=0.3301083445549011 eval=[0.4180288314819336, 0.27140316367149353, 0.4240249991416931, 0.18540892004966736, 0.08827342092990875, 0.09868050366640091]

##### T4 GPU #####
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# epoch=1 i=1000 loss=0.4688611626625061 eval=[0.2880088984966278, 0.20080436766147614, 0.2853744328022003, 0.07385311275720596, 0.06807087361812592, 0.0906631126999855]
# epoch=1 i=2000 loss=0.38366883993148804 eval=[0.2699344754219055, 0.23088432848453522, 0.23788470029830933, 0.09666123241186142, 0.09474606066942215, 0.1279018223285675]
# epoch=1 i=5000 loss=0.33229517936706543 eval=[0.4273291528224945, 0.26891598105430603, 0.45222875475883484, 0.18731510639190674, 0.08529336005449295, 0.09875991195440292]
# epoch=1 i=10000 loss=0.3064265251159668 eval=[0.31476539373397827, 0.5089015364646912, 0.3670322895050049, 0.2563089430332184, 0.11243242770433426, 0.11580734699964523]
# epoch=1 i=20000 loss=0.29950302839279175 eval=[0.9533458352088928, 0.9878413081169128, 0.5648325681686401, 0.42324820160865784, 0.5178518891334534, 0.2243034988641739]
# epoch=1 i=50000 loss=0.27319565415382385 eval=[0.9983587265014648, 0.9992724061012268, 0.6187055110931396, 0.9925269484519958, 0.9799752831459045, 0.2606170177459717]
# epoch=1 i=100000 loss=0.299211323261261 eval=[0.9996192455291748, 0.9993351697921753, 0.6472648978233337, 0.9983538389205933, 0.9911203980445862, 0.41760197281837463]
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.299211323261261, 'accuracy': 0.87109375}, 'time_taken': 325.2082769870758}


##### TPU VM v2-8 #####
# state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
# apply_model() batch.inputs.shape=(16, 128)
# update_model()
# epoch=1 i=1000 loss=0.469958633184433 eval=[0.2897143065929413, 0.19931700825691223, 0.2833636403083801, 0.07226584851741791, 0.06572941690683365, 0.09199415892362595]
# epoch=1 i=2000 loss=0.38573798537254333 eval=[0.27043935656547546, 0.22066356241703033, 0.23920319974422455, 0.10247620195150375, 0.09362270683050156, 0.12571851909160614]
# epoch=1 i=5000 loss=0.33506259322166443 eval=[0.3878052830696106, 0.25854751467704773, 0.3601313829421997, 0.155623197555542, 0.09345501661300659, 0.095614954829216]
# epoch=1 i=10000 loss=0.30599725246429443 eval=[0.2533847987651825, 0.6219271421432495, 0.3063637912273407, 0.2897632122039795, 0.10406502336263657, 0.12918716669082642]
# epoch=1 i=20000 loss=0.300803542137146 eval=[0.9575787782669067, 0.9791516065597534, 0.5493866205215454, 0.5160269141197205, 0.5214440822601318, 0.1960601955652237]
# epoch=1 i=50000 loss=0.2731204032897949 eval=[0.9976479411125183, 0.9988945126533508, 0.6397044062614441, 0.9945805072784424, 0.9584788680076599, 0.32771873474121094]
# epoch=1 i=100000 loss=0.2970423996448517 eval=[0.9997036457061768, 0.9993190765380859, 0.6756388545036316, 0.9987111687660217, 0.9908581972122192, 0.3020014762878418]
# {'config': dataset_batch_size: 16
# dataset_sequence_length: 128
# learning_rate: 0.05
# model_embed_dim: 64
# momentum: 0.0
# transformer_attention_size_per_head: 16
# transformer_dropout_rate: 0.1
# transformer_num_attention_heads: 2
# transformer_num_layers: 2
# , 'result': {'loss': 0.2970423996448517, 'accuracy': 0.87548828125}, 'time_taken': 183.3177788257599}


state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
apply_model() batch.inputs.shape=(16, 128)
update_model()
epoch=1 i=1000 loss=0.469958633184433 eval=[0.2897143065929413, 0.19931700825691223, 0.2833636403083801, 0.07226584851741791, 0.06572941690683365, 0.09199415892362595]
epoch=1 i=2000 loss=0.38573798537254333 eval=[0.27043935656547546, 0.22066356241703033, 0.23920319974422455, 0.10247620195150375, 0.09362270683050156, 0.12571851909160614]
epoch=1 i=3000 loss=0.3461489677429199 eval=[0.3826325237751007, 0.2683303654193878, 0.29979297518730164, 0.11167490482330322, 0.06427467614412308, 0.09038758277893066]
epoch=1 i=4000 loss=0.34199845790863037 eval=[0.4121222198009491, 0.2559955418109894, 0.455400675535202, 0.1866670548915863, 0.0758262500166893, 0.07691669464111328]
epoch=1 i=5000 loss=0.33506259322166443 eval=[0.3878052830696106, 0.25854751467704773, 0.3601313829421997, 0.155623197555542, 0.09345501661300659, 0.095614954829216]
epoch=1 i=6000 loss=0.321956425

# Attention-processed dataset

In [None]:
# Inspect what the training cell left behind in __latest: the last batch, the
# non-dunder attribute names of the train state, and the examples-seen counter.
last_round = __latest["last_round"]
print(last_round['batch'])
print([name for name in dir(__latest['state']) if '__' not in name])
print("num_examples_trained_on", __latest['state'].num_examples_trained_on)

# Alias used by the cells below.
__latest_raw = __latest

Batch(inputs=Array([[117, 109,  98, ..., 113, 117, 101],
       [116, 105, 111, ..., 114,  32, 105],
       [ 32, 116, 104, ..., 108, 117, 116],
       ...,
       [104, 101,  32, ...,  89,  34, 125],
       [123,  34, 113, ...,  32, 116, 104],
       [ 32, 108, 105, ...,  34, 125,  10]], dtype=uint8), targets=Array([[109,  98, 101, ..., 117, 101, 115],
       [105, 111, 110, ...,  32, 105, 110],
       [116, 104, 105, ..., 117, 116, 105],
       ...,
       [101,  32, 108, ...,  34, 125,  10],
       [ 34, 113, 117, ..., 116, 104, 101],
       [108, 105, 115, ..., 125,  10, 123]], dtype=uint8), attn_datas=[[{'sol_idx': 28, 'nums_idx': 7, 'nums_cnt': 3}, {'sol_idx': 114, 'nums_idx': 83, 'nums_cnt': 8}], [{'sol_idx': 66, 'nums_idx': 43, 'nums_cnt': 4}], [{'sol_idx': 41, 'nums_idx': 12, 'nums_cnt': 7}], [{'sol_idx': 57, 'nums_idx': 30, 'nums_cnt': 6}], [{'sol_idx': 60, 'nums_idx': 29, 'nums_cnt': 8}], [{'sol_idx': 54, 'nums_idx': 23, 'nums_cnt': 8}], [{'sol_idx': 57, 'nums_idx': 26, 'num

In [None]:
# %%time

import math

def explore_determining_importance(input_batch, state):
  """Print how strongly each listed number affects the predicted solution token.

  For the first row of ``input_batch`` (the outer loop ends in an unconditional
  ``break``), every entry of ``attn_datas`` describes one problem through the
  keys ``sol_idx``, ``nums_idx`` and ``nums_cnt``. For each of its number
  positions, ten variants of the row are run through the model in one call:
  the unmodified row plus nine copies in which that position is overwritten
  with the characters '1'..'9'. The probability the model gives to the real
  solution token is then compared between the unmodified row and each variant,
  and four summed difference measures are printed.

  Args:
    input_batch: Object with ``inputs``, ``targets`` and ``attn_datas``; the
      latter is indexed per row and yields dicts with the keys named above.
    state: Object providing ``predict_apply_fn`` and ``params``.

  Returns:
    None. All results are written to stdout.
  """
  for idx in range(len(input_batch.attn_datas)):
    inp = input_batch.inputs[idx]
    trg = input_batch.targets[idx]
    # array_to_text is defined elsewhere; presumably it decodes token ids into
    # a string with one character per token (it is indexed by position below).
    s_inp = array_to_text(inp)
    s_trg = array_to_text(trg)
    # Never filled: the only append (near the end of the inner loop) is
    # commented out, so the final "for s in summaries" loop prints nothing.
    summaries = []
    # Only referenced by the commented-out debugging filters below.
    p_ctr = 0
    for p in input_batch.attn_datas[idx]:
      p_ctr += 1
      print(idx, p, f'{s_inp=} {s_trg=}')
      inp_sol_idx = p['sol_idx']
      # The prediction for input position sol_idx is read one position earlier
      # in the model output (targets look like inputs shifted by one token).
      trg_sol_idx = inp_sol_idx-1
      for n_ctr in range(p['nums_cnt']):
        # print(f"{p_ctr=} {n_ctr=}")
        # if p_ctr != 2: continue
        # if n_ctr != 1: continue
        # Step of 2 between numbers -- presumably single digits separated by
        # one space; verify against the dataset generator.
        n_idx = p['nums_idx'] + n_ctr*2
        print("n_idx", n_idx, "s@n_idx", s_inp[n_idx], "s@sol_idx", s_inp[inp_sol_idx])
        sub_inp = []
        sub_trg = []
        # Row 0 of the sub-batch is the unmodified input (the baseline).
        sub_inp.append(inp)
        base_idx = 0
        # Rows 1..9: the same input with position n_idx replaced by '1'..'9',
        # so row `num` holds the digit `num`.
        for c in range(ord('1'), ord('9')+1):
          sub_inp.append(inp.at[n_idx].set(c))
        # Matching target variants; built and stacked but not used afterwards.
        for c in range(ord('1'), ord('9')+1):
          sub_trg.append(trg.at[n_idx-1].set(c))
        sub_inp = jnp.stack(sub_inp)
        sub_trg = jnp.stack(sub_trg)
        # One forward pass over all ten variants; input_embeddings is unused.
        logits, input_embeddings = state.predict_apply_fn({"params": state.params}, sub_inp)
        log_probs = jax.nn.log_softmax(logits)
        # visualize the replacements with num==>5, which should often be impactful for both smallest and largest
        # s_subinp = "".join(chr(sub_inp[4][idx]) for idx in range(128))
        # print("s_subinp5:", s_subinp.replace("\n", "_"))
        # guesses = "".join(chr(jnp.argmax(log_probs[4][idx])) for idx in range(128))
        # print("  guesses:", "_" + guesses.replace("\n", "_"))
        # print("prediction", chr(jnp.argmax(log_probs[4][sol_idx-1])))
        # print(logits.shape, jnp.exp(log_probs)[:,sol_idx-1,inp[sol_idx]])
        print()
        # Token id of the real answer, used as an index into the vocab axis.
        sol = inp[inp_sol_idx]
        # Probability of the real answer for the unmodified row.
        base_prob_sol = jnp.exp(log_probs[base_idx][trg_sol_idx][sol])
        # Accumulators for the four difference measures, summed over '1'..'9'.
        sum_diff_abs = jnp.array(0.0)
        sum_diff_sqr = jnp.array(0.0)
        sum_diff_sol_entropy = jnp.array(0.0)
        sum_diff_cross_entropy = jnp.array(0.0)
        all_probs = []
        for num in range(1,10):
          # Most likely token at the solution position for this variant.
          p_sol = jnp.argmax(log_probs[num][trg_sol_idx])

          # s_subinp = "".join(chr(sub_inp[num][idx]) for idx in range(128))
          # print(f"s_subinp{num}:", s_subinp.replace("\n", "_"))
          # guesses = "".join(chr(jnp.argmax(log_probs[num][idx])) for idx in range(128))
          # print("  guesses:", "_" + guesses.replace("\n", "_"))
          # print("prediction", chr(p_sol))

          # Probability this variant still assigns to the real answer.
          num_prob = jnp.exp(log_probs[num,trg_sol_idx,sol])
          all_probs.append(num_prob)

          diff_abs = jnp.abs(base_prob_sol - num_prob)
          diff_sqr = (base_prob_sol - num_prob) ** 2
          # Binary cross-entropy between baseline and variant probability of
          # the real answer (answer vs. everything else).
          diff_sol_entropy = -(base_prob_sol*jnp.log(num_prob)+(1-base_prob_sol)*jnp.log(1-num_prob))
          # Full-vocabulary cross-entropy: the variant's logits scored against
          # the baseline's softmax distribution used as soft labels.
          diff_cross_entropy = optax.softmax_cross_entropy(
              logits[num,trg_sol_idx,:],
              jax.nn.softmax(logits[base_idx,trg_sol_idx,:]),
          )

          sum_diff_abs += diff_abs
          sum_diff_sqr += diff_sqr
          sum_diff_sol_entropy += diff_sol_entropy
          sum_diff_cross_entropy += diff_cross_entropy

          print(f"real answer {s_inp[inp_sol_idx]} at {num_prob:.10f} ==> now {chr(p_sol)}{'' if s_inp[inp_sol_idx] == chr(p_sol) else ' (diff)'}")
          # print(f"  {diff_abs=:.10f}")
          # print(f"  {diff_sqr=:.10f}")
          # print(f"  {diff_sol_entropy=:.10f}")
          # print(f"  {diff_cross_entropy=:.10f}")
          # which = jnp.array([ord(c) for c in "123456789"])
          # print(f"{num=}", jnp.exp(log_probs)[num,trg_sol_idx,which])
          # print()
        # Converted to Python floats, but only used by commented-out code.
        all_probs = [float(x) for x in all_probs]
        # print(f"{all_probs=}")
        print(f"{sum_diff_abs=:.10f}")
        print(f"{sum_diff_sqr=:.10f}")
        print(f"{sum_diff_sol_entropy=:.10f}")
        print(f"{sum_diff_cross_entropy=:.10f}")
        # summaries.append(dict(all_probs=all_probs, diff_abs=diff_abs, diff_sqr=diff_sqr, diff_sol_entropy=diff_sol_entropy))
      print()
    for s in summaries:
      print(s)
    # Stop after the first batch row.
    break

# explore_determining_importance(last_round['batch'], last_round['state'])


## EXAMPLE OUTPUT
# 0 {'sol_idx': 32, 'nums_idx': 11, 'nums_cnt': 3} s_inp='ber among: 3 4 5", "solution": "5"}\n{"question": "Return the minimum number in this list: 7 1 9 8", "solution": "1"}\n{"question"' s_trg='er among: 3 4 5", "solution": "5"}\n{"question": "Return the minimum number in this list: 7 1 9 8", "solution": "1"}\n{"question":'
# n_idx 11 s@n_idx 3 s@sol_idx 5

# real answer 5 at 0.1603630334 ==> now 1 (diff)
# real answer 5 at 0.1405527741 ==> now 2 (diff)
# real answer 5 at 0.1478881836 ==> now 3 (diff)
# real answer 5 at 0.2092338949 ==> now 4 (diff)
# real answer 5 at 0.2722366154 ==> now 4 (diff)
# real answer 5 at 0.2194128782 ==> now 6 (diff)
# real answer 5 at 0.1721600443 ==> now 7 (diff)
# real answer 5 at 0.0946281105 ==> now 8 (diff)
# real answer 5 at 0.1337747872 ==> now 9 (diff)
# sum_diff_abs=0.3686744273
# sum_diff_sqr=0.0281759873
# sum_diff_sol_entropy=3.8621878624
# sum_diff_cross_entropy=20.1720008850
# n_idx 13 s@n_idx 4 s@sol_idx 5

# real answer 5 at 0.1429803520 ==> now 1 (diff)
# real answer 5 at 0.1935782433 ==> now 2 (diff)
# real answer 5 at 0.0796194971 ==> now 3 (diff)
# real answer 5 at 0.1478881836 ==> now 3 (diff)
# real answer 5 at 0.1425110847 ==> now 3 (diff)
# real answer 5 at 0.1408456713 ==> now 3 (diff)
# real answer 5 at 0.0484608710 ==> now 7 (diff)
# real answer 5 at 0.0436881743 ==> now 3 (diff)
# real answer 5 at 0.0417443700 ==> now 3 (diff)
# sum_diff_abs=0.4410573244
# sum_diff_sqr=0.0388607346
# sum_diff_sol_entropy=4.0446205139
# sum_diff_cross_entropy=15.4104194641
# n_idx 15 s@n_idx 5 s@sol_idx 5

# real answer 5 at 0.1345619857 ==> now 1 (diff)
# real answer 5 at 0.0836355463 ==> now 2 (diff)
# real answer 5 at 0.0687709898 ==> now 3 (diff)
# real answer 5 at 0.1442581862 ==> now 3 (diff)
# real answer 5 at 0.1478881836 ==> now 3 (diff)
# real answer 5 at 0.1463320106 ==> now 3 (diff)
# real answer 5 at 0.0457798168 ==> now 7 (diff)
# real answer 5 at 0.0374081060 ==> now 3 (diff)
# real answer 5 at 0.0282506160 ==> now 3 (diff)
# sum_diff_abs=0.4941082001
# sum_diff_sqr=0.0475262292
# sum_diff_sol_entropy=4.1410827637
# sum_diff_cross_entropy=16.8980541229

# 0 {'sol_idx': 113, 'nums_idx': 90, 'nums_cnt': 4} s_inp='ber among: 3 4 5", "solution": "5"}\n{"question": "Return the minimum number in this list: 7 1 9 8", "solution": "1"}\n{"question"' s_trg='er among: 3 4 5", "solution": "5"}\n{"question": "Return the minimum number in this list: 7 1 9 8", "solution": "1"}\n{"question":'
# n_idx 90 s@n_idx 7 s@sol_idx 1

# real answer 1 at 0.9995328188 ==> now 1
# real answer 1 at 0.9992763996 ==> now 1
# real answer 1 at 0.9994804859 ==> now 1
# real answer 1 at 0.9995002151 ==> now 1
# real answer 1 at 0.9994872808 ==> now 1
# real answer 1 at 0.9994792938 ==> now 1
# real answer 1 at 0.9991625547 ==> now 1
# real answer 1 at 0.9995072484 ==> now 1
# real answer 1 at 0.9993562102 ==> now 1
# sum_diff_abs=0.0023195148
# sum_diff_sqr=0.0000007272
# sum_diff_sol_entropy=0.0615321547
# sum_diff_cross_entropy=0.0892656520
# n_idx 92 s@n_idx 1 s@sol_idx 1

# real answer 1 at 0.9991625547 ==> now 1
# real answer 1 at 0.0011199772 ==> now 2 (diff)
# real answer 1 at 0.0008807029 ==> now 3 (diff)
# real answer 1 at 0.0004429024 ==> now 4 (diff)
# real answer 1 at 0.0005722365 ==> now 4 (diff)
# real answer 1 at 0.0005210863 ==> now 4 (diff)
# real answer 1 at 0.0007700303 ==> now 4 (diff)
# real answer 1 at 0.0006012401 ==> now 4 (diff)
# real answer 1 at 0.0010839972 ==> now 4 (diff)
# sum_diff_abs=7.9873085022
# sum_diff_sqr=7.9746370316
# sum_diff_sol_entropy=57.9478569031
# sum_diff_cross_entropy=58.0061531067
# n_idx 94 s@n_idx 9 s@sol_idx 1

# real answer 1 at 0.9992839098 ==> now 1
# real answer 1 at 0.9984142780 ==> now 1
# real answer 1 at 0.9991236329 ==> now 1
# real answer 1 at 0.9990704656 ==> now 1
# real answer 1 at 0.9989698529 ==> now 1
# real answer 1 at 0.9989487529 ==> now 1
# real answer 1 at 0.9983630776 ==> now 1
# real answer 1 at 0.9991226196 ==> now 1
# real answer 1 at 0.9991625547 ==> now 1
# sum_diff_abs=0.0022465587
# sum_diff_sqr=0.0000013082
# sum_diff_sol_entropy=0.0614459515
# sum_diff_cross_entropy=0.0882511139
# n_idx 96 s@n_idx 8 s@sol_idx 1

# real answer 1 at 0.9992237687 ==> now 1
# real answer 1 at 0.9985373616 ==> now 1
# real answer 1 at 0.9990318418 ==> now 1
# real answer 1 at 0.9992146492 ==> now 1
# real answer 1 at 0.9990714788 ==> now 1
# real answer 1 at 0.9990147352 ==> now 1
# real answer 1 at 0.9987911582 ==> now 1
# real answer 1 at 0.9991625547 ==> now 1
# real answer 1 at 0.9989220500 ==> now 1
# sum_diff_abs=0.0017200112
# sum_diff_sqr=0.0000006403
# sum_diff_sol_entropy=0.0612159483
# sum_diff_cross_entropy=0.0878705680


In [None]:
# %%time

import math

class TrgIdxTransformer(nn.Module):
    """Stack of pre-LayerNorm self-attention + MLP blocks with residual adds.

    Each layer applies LayerNorm, masked self-attention and a residual add,
    then LayerNorm, a widened two-layer GELU MLP with dropout and another
    residual add. A final LayerNorm is applied to the output.
    """

    num_attention_heads: int
    num_layers: int  # each layer has attention + MLP
    attention_size_per_head: int
    dropout_rate: float
    is_training: bool
    widening_factor: int = 4  # factor for widening MLP hidden layer

    @nn.compact
    def __call__(self, embeddings: jax.Array, whichidx) -> jax.Array:
        """Run the stack over ``embeddings`` (a rank-3 array).

        ``whichidx`` is accepted but not used here; the experiment that would
        have sliced the mask down to that single query row is disabled:
        # causal_mask = causal_mask[:,:,whichidx:whichidx+1,:]
        """
        _, seq_len, embed_dim = embeddings.shape
        deterministic = not self.is_training
        qkv_features = self.attention_size_per_head * self.num_attention_heads
        mlp_hidden_dim = self.widening_factor * embed_dim

        # The same initializer settings are used for attention and MLP kernels.
        kernel_init = nn.initializers.variance_scaling(
            scale=2 / self.num_layers,
            mode="fan_in",
            distribution="truncated_normal",
        )

        # Lower-triangular mask: each position attends to itself and earlier
        # positions only. The two leading axes have size 1.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))[None, None, :, :]

        residual = embeddings
        for _layer in range(self.num_layers):
            # Attention sub-block.
            # Open questions left by the author: use built-in normalization?
            # use_bias False? broadcast_dropout False? decode=True for speed?
            self_attention = nn.SelfAttention(
                num_heads=self.num_attention_heads,
                qkv_features=qkv_features,
                dropout_rate=self.dropout_rate,
                kernel_init=kernel_init,
            )
            attended = self_attention(
                inputs_q=nn.LayerNorm()(residual),
                mask=causal_mask,
                deterministic=deterministic,
            )
            # add another dropout layer here??? instead of the built-in one?
            residual = residual + attended

            # MLP sub-block.
            mlp = nn.LayerNorm()(residual)
            mlp = nn.Dense(mlp_hidden_dim, kernel_init=kernel_init)(mlp)
            # relu+dropout instead of gelu??
            mlp = jax.nn.gelu(mlp)
            mlp = nn.Dense(embed_dim, kernel_init=kernel_init)(mlp)
            mlp = nn.Dropout(
                rate=self.dropout_rate, deterministic=deterministic
            )(mlp)
            residual = residual + mlp

        return nn.LayerNorm()(residual)

class FastIdxAutoregressiveTransformerModel(nn.Module):
    """Embedder -> transformer -> decoder pipeline that threads a target index.

    `__call__` takes an extra `whichidx` argument and passes it to the
    transformer together with the input embeddings.
    """

    # Applied to the raw input tokens to produce the input embeddings.
    embedder: Embedder
    # Called as `transformer(input_embeddings, whichidx)`.
    transformer: TrgIdxTransformer
    # trgIdxTransformer: TrgIdxTransformer
    # Applied to the transformer output to produce the returned `decoded` value.
    decoder: Decoder

    @nn.compact
    def __call__(self, in_tokens: jax.Array, whichidx) -> tuple[jax.Array, jax.Array]:
        """Forward pass, producing a sequence of logits.

        Args:
          in_tokens: input token array handed to the embedder.
          whichidx: target index forwarded unchanged to the transformer
            (presumably the sequence position of interest -- TODO confirm
            against TrgIdxTransformer).

        Returns:
          A `(decoded, input_embeddings)` pair: the decoder output and the
          embeddings the embedder produced for `in_tokens`.
        """
        input_embeddings = self.embedder(in_tokens)
        output = self.transformer(input_embeddings, whichidx)
        decoded = self.decoder(output)
        return decoded, input_embeddings

def mkfast(config):
  """Build a jitted function returning the model's logits at a single position.

  The model is a FastIdxAutoregressiveTransformerModel assembled from `config`
  with `is_training=False` on the transformer (presumably so that dropout is
  inactive -- confirm against TrgIdxTransformer).

  Args:
    config: object exposing the `transformer_*` and `model_embed_dim` settings.

  Returns:
    `fastapply(state, numbatch, whichidx)`: applies the model to `numbatch`
    with variables `state` and returns `logits[:, whichidx, :]`.
  """
  vocab_size = dataset.VOCAB_SIZE
  idx_transformer = TrgIdxTransformer(
      num_attention_heads=config.transformer_num_attention_heads,  # type: ignore
      num_layers=config.transformer_num_layers,  # type: ignore
      attention_size_per_head=config.transformer_attention_size_per_head,  # type: ignore
      dropout_rate=config.transformer_dropout_rate,  # type: ignore
      is_training=False,
  )
  token_embedder = model.Embedder(embed_dim=config.model_embed_dim, vocab_size=vocab_size)
  token_decoder = model.Decoder(vocab_size=vocab_size)
  fastatm = FastIdxAutoregressiveTransformerModel(
      embedder=token_embedder,
      transformer=idx_transformer,
      decoder=token_decoder,
  )

  def fastapply(state, numbatch, whichidx):
    # The model also returns the input embeddings; only the logits are needed.
    logits, _ = fastatm.apply(state, numbatch, whichidx)
    return logits[:,whichidx,:]

  return jax.jit(fastapply)


dofastapply = mkfast(__latest_raw["config"])

def generate_introspective_batch(input_batch, state):
  """Fill each example with the number that most influences the model's answer.

  For every attention record `p` of every sequence, each number position
  (`p['nums_idx']`, `p['nums_idx'] + 2`, ...) is perturbed in turn: the
  character there is replaced by each of '1'..'9', and the logits at the
  position predicting the solution (`p['sol_idx'] - 1`) are compared with the
  logits of the unperturbed sequence via softmax cross-entropy. The number
  with the largest summed cross-entropy is written back into the sequence:
  its character at `sol_idx + 2` and its 1-based rank at `sol_idx + 5` in the
  inputs, and at one position earlier in the targets.

  Args:
    input_batch: batch exposing `inputs`, `targets` and `attn_datas`, where
      `attn_datas[i]` is a list of dicts with keys `sol_idx`, `nums_idx` and
      `nums_cnt` for sequence `i`.
    state: train state; only `state.params` is read.

  Returns:
    A new `Batch` with the rewritten inputs/targets and the original
    `attn_datas`.
  """
  # Replacement characters tried at every number position.
  digit_codes = range(ord('1'), ord('9')+1)
  new_inputs = []
  new_targets = []
  for idx, attn_data in enumerate(input_batch.attn_datas):
    inp = input_batch.inputs[idx]
    new_inp = inp
    new_trg = input_batch.targets[idx]
    for p in attn_data:
      # Targets are indexed one position earlier than inputs throughout.
      trg_sol_idx = p['sol_idx']-1
      num_impacts_list = []
      for n_ctr in range(p['nums_cnt']):
        # Numbers sit two characters apart (digit + separator).
        n_idx = p['nums_idx'] + n_ctr*2
        # Row 0 is the unperturbed sequence, rows 1..9 the perturbed variants.
        sub_inp = [inp]
        sub_inp.extend(inp.at[n_idx].set(c) for c in digit_codes)
        logits = dofastapply({"params": state.params}, jnp.stack(sub_inp), trg_sol_idx)
        perturbed_logits = logits[1:,:]
        # Shape is taken from the logits instead of a hard-coded [9, 256].
        impact = optax.softmax_cross_entropy(
            perturbed_logits,
            jnp.broadcast_to(jax.nn.softmax(logits[0,:]), perturbed_logits.shape),
        )
        num_impacts_list.append(jnp.sum(impact))
      num_impacts = jnp.stack(num_impacts_list)
      highest_impact_idx = int(jnp.argmax(num_impacts))
      impact_char = inp[p['nums_idx'] + highest_impact_idx * 2]
      # The rank must be a single character, so this only works for up to nine
      # numbers per question.
      rank_char = ord(str(highest_impact_idx+1))
      # NOTE(review): these offsets presumably address the X and Y placeholders
      # of a `"solution": "7;X@#Y"` answer -- confirm against the dataset format.
      impact_num_idx = p['sol_idx'] + 2
      impact_idx_idx = impact_num_idx + 3
      new_inp = new_inp.at[impact_num_idx].set(impact_char)
      new_inp = new_inp.at[impact_idx_idx].set(rank_char)

      new_trg = new_trg.at[impact_num_idx-1].set(impact_char)
      new_trg = new_trg.at[impact_idx_idx-1].set(rank_char)
    new_inputs.append(new_inp)
    new_targets.append(new_trg)
  return Batch(inputs=jnp.stack(new_inputs), targets=jnp.stack(new_targets), attn_datas=input_batch.attn_datas)

generate_introspective_batch(last_round['batch'], __latest['state'])


Batch(inputs=Array([[117, 109,  98, ..., 113, 117, 101],
       [116, 105, 111, ..., 114,  32, 105],
       [ 32, 116, 104, ..., 108, 117, 116],
       ...,
       [104, 101,  32, ...,  50,  34, 125],
       [123,  34, 113, ...,  32, 116, 104],
       [ 32, 108, 105, ...,  34, 125,  10]], dtype=uint8), targets=Array([[109,  98, 101, ..., 117, 101, 115],
       [105, 111, 110, ...,  32, 105, 110],
       [116, 104, 105, ..., 117, 116, 105],
       ...,
       [101,  32, 108, ...,  34, 125,  10],
       [ 34, 113, 117, ..., 116, 104, 101],
       [108, 105, 115, ..., 125,  10, 123]], dtype=uint8), attn_datas=[[{'sol_idx': 28, 'nums_idx': 7, 'nums_cnt': 3}, {'sol_idx': 114, 'nums_idx': 83, 'nums_cnt': 8}], [{'sol_idx': 66, 'nums_idx': 43, 'nums_cnt': 4}], [{'sol_idx': 41, 'nums_idx': 12, 'nums_cnt': 7}], [{'sol_idx': 57, 'nums_idx': 30, 'nums_cnt': 6}], [{'sol_idx': 60, 'nums_idx': 29, 'nums_cnt': 8}], [{'sol_idx': 54, 'nums_idx': 23, 'nums_cnt': 8}], [{'sol_idx': 57, 'nums_idx': 26, 'num

# Self-reflection training

In [None]:
all_results = []

In [None]:
import random
import time
from flax.training import orbax_utils
import orbax

def initial_apply_model(state):
  """Run `apply_model` once on a dummy batch shaped by the module-level `config`.

  The dummy batch has all-zero inputs and all-one targets (both uint8) of
  shape (dataset_batch_size, dataset_sequence_length); the RNG key is fixed.
  """
  dummy_shape = (config.dataset_batch_size, config.dataset_sequence_length)
  zero_inputs = jnp.zeros(dummy_shape, dtype=jnp.uint8)
  one_targets = jnp.ones(dummy_shape, dtype=jnp.uint8)
  dummy_batch = dataset.Batch(inputs=zero_inputs, targets=one_targets)
  apply_model(state, dummy_batch, jax.random.key(12345))

__latest = None

def runone(config):
  """Train a fresh model for one epoch on introspective batches.

  Each raw batch from `iter_raw_batches(config)` is rewritten by
  `generate_introspective_batch` using the current state, then used for one
  `apply_model` / `update_model` step. Every 100 steps the evaluation scores
  and one target/prediction pair are printed; every 1000 steps a checkpoint is
  saved under /content/flax_ckpt/orbax/. Training stops after 100000 batches.

  Side effects: sets the module-level `__latest` to
  `dict(state=..., last_round=...)`.

  Args:
    config: configuration passed to `create_train_state` and
      `iter_raw_batches`.

  Returns:
    `dict(loss=..., accuracy=...)` of the last training step as floats; both
    are NaN if no batch was produced.
  """
  global __latest
  rng_seed = 12345
  # Only the disabled logging.info() calls used this object; it is still
  # constructed in case LoggingModule() has side effects.
  logging = LoggingModule()
  rng_key = jax.random.key(rng_seed)
  init_rng, rng_key = jax.random.split(rng_key)
  state = create_train_state(init_rng, config)

  @jax.jit
  def pred(params, batch):
    # Greedy prediction: the most likely token at every position.
    logits, input_embeddings = state.predict_apply_fn({"params": params}, batch)
    log_probs = jax.nn.log_softmax(logits)
    return jnp.argmax(log_probs, axis=-1)

  print(f"{state.params.keys()=}")
  last_round = None
  # Defined up front so the final return cannot hit unbound names when the
  # batch iterator is empty.
  loss = accuracy = float("nan")
  for epoch in range(1,2):
    for (i,raw_batch) in enumerate(iter_raw_batches(config),1):
      if i > 100000: break
      rng_key, apply_rng_key = jax.random.split(rng_key)
      attn_batch = generate_introspective_batch(raw_batch, state)
      last_round = dict(apply_rng_key=apply_rng_key, state=state, raw_batch=raw_batch, attn_batch=attn_batch)
      # workaround so jax.jit doesn't recompile due to attn_datas changing, when it doesn't matter
      grads, loss, accuracy = apply_model(state, Batch(inputs=attn_batch.inputs,targets=attn_batch.targets,attn_datas=None), apply_rng_key)
      if i % 100 == 0:
        # Named eval_scores so the builtin eval() is not shadowed; the printed
        # label stays "eval=".
        eval_scores = eval_predictions(state)
        loss = float(loss)
        print(f'{epoch=} {i=} {loss=} eval={eval_scores!r}')
        print('   trg', array_to_text(attn_batch.targets[0]).replace('\n','_'))
        p_out = pred(state.params, attn_batch.inputs)
        print('  pred', array_to_text(p_out[0]).replace('\n','_'))
        print()
      # all_imp_val, all_imp_grads, latest_logits = analyze_model(state, batch, apply_rng_key)
      state = update_model(state, grads)
      state = state.replace(
          num_examples_trained_on=state.num_examples_trained_on
          + attn_batch.inputs.shape[0]
      )
      if i % 1000 == 0:
        save_name = f"/content/flax_ckpt/orbax/save_i={i}"
        ckpt = {'trainstate': state, 'config': config, 'rng_key_start': rng_seed, 'i': i}
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(ckpt)
        orbax_checkpointer.save(save_name, ckpt, save_args=save_args)

  # all_imp_val, all_imp_grads, latest_logits = analyze_model(last_round["state"], last_round["batch"], last_round["apply_rng_key"])
  # __latest = dict(state=state, batch=batch, all_imp_val=all_imp_val, all_imp_grads=all_imp_grads, latest_logits=latest_logits, last_round=last_round)
  __latest = dict(state=state, last_round=last_round)
  return dict(loss=float(loss), accuracy=float(accuracy))

def runit():
  """Run a single training job with sequence length 128 and record the outcome.

  The default config is copied, `dataset_sequence_length` is overridden,
  and the frozen copy is handed to `runone`. The result record (config,
  metrics, wall-clock seconds) is appended to the module-level `all_results`
  and printed.
  """
  base_config = config_lib.get_config()
  run_records = []
  # Earlier experiments wrapped the body below in a sweep, e.g.
  #   for learning_rate in [0.005, 0.0075, 0.01, 0.015, 0.02, 0.03]:
  #   for momentum in [0.0, 0.5, 0.75, 0.9, 0.95, 0.99]:
  # setting config.learning_rate / config.momentum / config.dataset_batch_size.
  editable_config = ml_collections.ConfigDict(base_config)
  editable_config.dataset_sequence_length = 128
  frozen_config = ml_collections.FrozenConfigDict(editable_config)

  started_at = time.time()
  metrics = runone(frozen_config)
  finished_at = time.time()

  record = dict(config=frozen_config, result=metrics, time_taken=finished_at - started_at)
  all_results.append(record)
  run_records.append(record)
  print(record)

runit()

__latest_attn = __latest


state.params.keys()=dict_keys(['decoder', 'embedder', 'transformer'])
apply_model() batch.inputs.shape=(16, 128)
update_model()
epoch=1 i=100 loss=1.8924028873443604 eval=[0.028565993532538414, 0.01532052457332611, 0.02274836041033268, 0.013358487747609615, 0.012247969396412373, 0.014840974472463131]
   trg  which is the minimum number in this list: 1 9 6 1", "solution": "1;1@#1"}_{"question": "Show the minimum number in the list: 4 
  pred    emeemmt ie lbs sbmbt m e  5s 5ies 5isti 5 5 5 5 s ",tlutinn", ", 5"#1",_{":uestinn": ":ien "ie "bo"obmb" mbe  "ou"ie "uoti " 

epoch=1 i=200 loss=1.3659120798110962 eval=[0.04807756096124649, 0.04171866178512573, 0.039141349494457245, 0.013522585853934288, 0.02492089569568634, 0.03819461911916733]
   trg {"question": "Determine the largest number: 5 2 6 3", "solution": "6;2@#2"}_{"question": "Return the maximum number among: 3 4 2
  pred e"iuestinn": ":e h mine n e nim e t:n mbe   2 7 7 7 , "sonution": ":@1"#1"}_{"querthon": ":e  mm nhe nbminbmbn

In [None]:
def ttegetet():
  """Print targets next to the raw model's greedy predictions for the last batch.

  Uses the module-level `__latest_raw['state']` and `last_round['batch']`.
  Prints the shape of the per-token log-probabilities, then for up to 16 rows
  the target text followed by the predicted text, then the prediction shape.
  """
  state = __latest_raw['state']

  @jax.jit
  def log_probs_fn(params, batch):
    # Return the full log-probabilities; the argmax is taken by the caller so
    # that both the score shape and the predictions can be reported.
    logits, _ = state.predict_apply_fn({"params": params}, batch)
    return jax.nn.log_softmax(logits)

  batch = last_round['batch']
  log_probs = log_probs_fn(state.params, batch.inputs)
  print(log_probs.shape)
  p = jnp.argmax(log_probs, axis=-1)
  # Cap at the batch size so smaller batches do not index out of range.
  for x in range(min(16, p.shape[0])):
    print(x, array_to_text(batch.targets[x]))
    print(x, array_to_text(p[x]))
  print(p.shape)

ttegetet()


(16, 128, 256)
0 mber: 1 7 2", "solution": "7;X@#Y"}_{"question": "Determine biggest number among: 4 3 7 4 9 2 8 8", "solution": "9;X@#Y"}_{"ques
0 mber  3 2 7 , "solution": "7;X@#Y"}_{"question": "Determine tiggest number imong: 1 6 6 4 9 7 7 1", "solution": "9;X@#Y"}_{"ques
1 ion": "Find the largest of these numbers: 5 3 4 1", "solution": "5;X@#Y"}_{"question": "Determine which is the largest number in
1 ion": "Dind ahe lorgest nf these numbers: 1 7 9 9 , "solution": "5;X@#Y"}_{"question": "Determine thich is the lorgest number in
2 this list: 4 7 8 1 2 2 9", "solution": "9;X@#Y"}_{"question": "Find and show largest of these numbers: 5 6 4 8 6 6 4 9", "soluti
2 thes list: 7 7 6 7 7 6 1 , "solution": "9;X@#Y"}_{"question": "Dind and show torgest nf these numbers: 3 7 7 7 1 1 7"6", "soluti
3  largest number in the list: 1 1 7 3 1 9", "solution": "9;X@#Y"}_{"question": "Find the biggest of these numbers: 9 5 3 6 6 2 2"
3 slirgest number in thi list: 3 1 1 1 1 1 , "solution": "9;X@#Y"}_{

In [None]:
!ls -l

# Unused stuff

In [None]:
def checklatest():
  """Dump decoded guesses and per-position gradient statistics for sample 0.

  Reads the module-level `__latest`, `latest_logits`, `all_imp_grads` and
  `START_IDX`. For 16 positions starting at START_IDX it prints the guessed
  character and, for every input position before START_IDX+16, the gradient
  norm and the gradient/embedding dot product, each scaled by its sum.
  """
  inputs = __latest["batch"].inputs
  targets = __latest["batch"].targets
  state = __latest["state"]
  print(state.params.keys())
  decoded = state.decoder_apply_fn({"params":state.params["decoder"]}, latest_logits)
  print(f"{inputs=}")
  print(f"{targets=}")
  print(f"{latest_logits.shape=}")
  print(f"{decoded[0].shape=}")
  inputs_dec0 = "".join(map(chr, inputs[0]))
  targets_dec0 = "".join(map(chr, targets[0]))
  guesses = "".join([chr(jnp.argmax(decoded[0][pos])) for pos in range(128)])
  print(f"{inputs_dec0=}")
  print(f"{targets_dec0=}")
  print(f"     {guesses=}")
  print(f"{all_imp_grads.shape=}")
  first_sample = jnp.stack([inputs[0]])
  input_embeddings = state.embedder_apply_fn({"params": state.params['embedder']}, first_sample)[0]
  print(f"{input_embeddings.shape=}")
  context_positions = range(START_IDX+16)
  for idx in range(START_IDX,START_IDX+16):
    guess = chr(jnp.argmax(decoded[0][idx]))
    print(f"{idx=} {repr(targets_dec0[:idx+1])} <- {repr(guess)}")
    grads_for_idx = all_imp_grads[idx-START_IDX][0]
    norms = []
    # so-called "saliency" but it's not very good
    sortasaliency = []
    for pos in context_positions:
      norms.append(float(jnp.linalg.norm(grads_for_idx[pos])))
      sortasaliency.append(float(jnp.dot(grads_for_idx[pos], input_embeddings[pos])))
    normsum = sum(norms)
    sortasaliencysum = sum(sortasaliency)
    for j, (norm, saliency) in enumerate(zip(norms, sortasaliency)):
      scalednorm = norm / normsum
      scaledsal = saliency / sortasaliencysum
      print(f"{j=} {repr(inputs_dec0[j])} {scalednorm=} {scaledsal=}")

checklatest()

In [None]:
@jax.jit
def analyze_model_2(state, inputs, apply_rng_key):
    """Measure how zeroing each input embedding changes the decoded output.

    The single sequence `inputs` is embedded, then N variants are built in
    which the embedding at exactly one position is multiplied by zero. Each
    variant is run through the transformer and decoder, and the squared
    difference to the unmodified output is returned.

    The print() calls run while JAX traces the function, not on every call.

    Args:
      state: train state providing `embedder_apply_fn`,
        `transformer_apply_fn`, `decoder_apply_fn` and `params`.
      inputs: one token sequence; it is stacked into a batch of one.
      apply_rng_key: RNG key passed to the transformer as its "dropout" rng.

    Returns:
      `(base, diff, xnx, batch, input_embeddings)`: the unmodified decoded
      output, the squared differences, the decoded outputs of the N variants,
      the N masked embedding sequences, and the original embeddings.
    """
    dropout_rng_key = apply_rng_key

    def doit(batch):
      output = state.transformer_apply_fn(
          {"params": state.params['transformer']}, batch, rngs={"dropout": dropout_rng_key}
      )
      decoded = state.decoder_apply_fn(
          {"params": state.params['decoder']}, output
      )
      return decoded

    input_embeddings = state.embedder_apply_fn({"params": state.params['embedder']}, jnp.stack([inputs]))
    N = input_embeddings.shape[1]
    print(f"{input_embeddings.shape=}")
    base = doit(input_embeddings)
    print(f"{base.shape=}")
    # mult[i, j, 0] is 0 where i == j and 1 elsewhere, so variant i has
    # position i zeroed out.
    mult = (1 - jnp.reshape(jnp.identity(N), (N,N,1)))
    print(f"{mult.shape=}")
    batch = input_embeddings * mult
    print(f"{batch.shape=}")
    # One variant per call rather than doit(batch) in one go; iterate over all
    # N variants instead of a hard-coded 128.
    stuff = [doit(batch[idx:idx+1])[0] for idx in range(N)]
    xnx = jnp.stack(stuff)
    print(f"{xnx.shape=}")
    diff = (base - xnx) ** 2
    print(f"{diff.shape=}")
    return base, diff, xnx, batch, input_embeddings

foo = None
def doityo():
  """Run `analyze_model_2` on sample 0 of the last round and print a summary.

  The five analysis outputs are stored in the module-level `foo` dict, then
  the function prints a few sums, the input/target/guess strings and, for the
  first 8 output positions, the mean squared change caused by zeroing each of
  the first 10 input positions.
  """
  global foo
  last_round = __latest["last_round"]
  base, diff, xnx, batch, input_embeddings = analyze_model_2(last_round["state"], last_round["batch"].inputs[0], last_round["apply_rng_key"])
  foo = {
      "base": base,
      "diff": diff,
      "xnx": xnx,
      "batch": batch,
      "input_embeddings": input_embeddings,
  }
  print(f"{jnp.sum(input_embeddings[0,:16], axis=-1)=}")
  print(f"{jnp.sum(batch[:4,:16], axis=-1)=}")
  print(f"{jnp.sum(base[:,:8], axis=-1)=}")
  print(f"{jnp.sum(xnx[:8,:8], axis=-1)=}")
  print(f"{base.shape=}")
  print(f"{diff.shape=}")
  sample = last_round["batch"]
  chr_inputs = "".join(map(chr, sample.inputs[0]))
  chr_targets = "".join(map(chr, sample.targets[0]))
  # Greedy guess per position from the unmodified output.
  chr_guess = "".join([chr(jnp.argmax(base[0][pos])) for pos in range(128)])
  print(f" {chr_inputs=}")
  print(f"{chr_targets=}")
  print(f"  {chr_guess=}")
  for idx, (t, g) in enumerate(zip(chr_targets[:8], chr_guess[:8])):
    print(f"{idx=} {t=} {g=}")
    for frmidx, f in enumerate(chr_inputs[:10]):
      print(f"{frmidx=} {f=} {jnp.mean(diff[frmidx,idx,:])}")
      print(f"  base:{jnp.mean(base[0,idx])}")
      print(f"   xnx:{jnp.mean(xnx[frmidx,idx,:])}")

doityo()


In [None]:
def doitlo():
  """Print per-input-position impact on the guessed token for positions 95..104.

  Works on the analysis results cached in the module-level `foo` dict and on
  sample 0 of the last round's batch.
  """
  last_round = __latest["last_round"]
  base = foo["base"]
  diff = foo["diff"]
  xnx = foo["xnx"]
  batch = foo["batch"]
  input_embeddings = foo["input_embeddings"]

  print(f"{base.shape=}")
  print(f"{diff.shape=}")
  print(f"{xnx.shape=}")
  print(f"{batch.shape=}")
  print(f"{input_embeddings.shape=}")


  print(f"{jnp.sum(input_embeddings[0,:16], axis=-1)=}")
  print(f"{jnp.sum(batch[:4,:16], axis=-1)=}")
  print(f"{jnp.sum(base[:,:8], axis=-1)=}")
  print(f"{jnp.sum(xnx[:8,:8], axis=-1)=}")
  print(f"{base.shape=}")
  print(f"{diff.shape=}")
  sample = last_round["batch"]
  chr_inputs = "".join(map(chr, sample.inputs[0]))
  chr_targets = "".join(map(chr, sample.targets[0]))
  chr_guess = "".join([chr(jnp.argmax(base[0][pos])) for pos in range(128)])
  print(f" {chr_inputs=}")
  print(f"{chr_targets=}")
  print(f"  {chr_guess=}")
  for idx in range(95,105):
    # Index of the guessed token at this output position.
    g_idx = jnp.argmax(base[0][idx])
    t = chr_targets[idx]
    g = chr(g_idx)
    print(f"{idx=} {t=} {g=} after: {repr(chr_targets[:idx])}")
    for frmidx in range(110):
      f = chr_inputs[frmidx]
      print(f"  {frmidx=} {f=} {diff[frmidx,idx,g_idx]}")

doitlo()


# Unused stuff

In [None]:
def run_forever():
  """Endlessly train models with randomly sampled hyperparameters.

  Each iteration samples a learning rate in [0, 1) and random attention
  head count / per-head size, keeps the other settings fixed, runs `runone`
  and appends the timed result to the module-level `all_results`. The loop
  never terminates on its own.
  """
  orig_config = config_lib.get_config()
  results = []
  while True:
    # Evaluated top to bottom, so the random draws happen in this order.
    overrides = {
        "dataset_batch_size": 16,
        "learning_rate": random.random(),
        # previously sampled from [0.0,0.25,0.5,0.75,0.9,0.925,0.95,0.975,0.99]
        "momentum": 0.0,
        # previously sampled from [16,32,64,128,256,512,1024]
        "model_embed_dim": 64,
        # previously sampled from [2,4,6,8,10]
        "transformer_num_layers": 2,
        "transformer_num_attention_heads": random.choice([2,4,6,8,12,16]),
        "transformer_attention_size_per_head": random.choice([2,4,6,8,12,16]),
        # previously sampled from [0.0,0.05,0.1,0.15,0.2,0.25,0.5]
        "transformer_dropout_rate": 0.1,
    }
    config = ml_collections.ConfigDict(orig_config)
    for field_name, field_value in overrides.items():
      setattr(config, field_name, field_value)
    config = ml_collections.FrozenConfigDict(config)

    t_before = time.time()
    metrics = runone(config)
    t_after = time.time()
    record = dict(config=config, result=metrics, time_taken=t_after - t_before)
    all_results.append(record)
    print(record)

# run_forever()


In [None]:
def checkresult():
  """Print the recorded result for every (learning rate, momentum) pair.

  Pairs are listed in ascending order of learning rate, then momentum; when a
  pair occurs several times in `all_results` the latest result wins. Pairs
  without a (truthy) result are skipped. The total number of recorded runs is
  printed last.
  """
  result_by_key = {}
  learning_rates = set()
  momenta = set()
  for entry in all_results:
    cfg = entry['config']
    lr = cfg.learning_rate
    learning_rates.add(lr)
    m = cfg.momentum
    momenta.add(m)
    result_by_key[f"lr={lr} m={m}"] = entry['result']

  for learning_rate in sorted(learning_rates):
    for momentum in sorted(momenta):
      found = result_by_key.get(f"lr={learning_rate} m={momentum}")
      if not found:
        continue
      print(learning_rate, momentum, found)
  print(len(all_results))


checkresult()

In [None]:
from google.colab import data_table
data_table.enable_dataframe_formatter()
import pandas as pd

def mkpandas():
  """Tabulate all timed results as a DataFrame.

  Runs without a `time_taken` entry are left out. Columns are `loss`,
  `accuracy`, `time_taken`, followed by one `c_<name>` column per config key
  of the first recorded run.
  """
  config_keys = all_results[0]['config'].keys()
  timed_runs = [run for run in all_results if run.get('time_taken') is not None]

  alldict = dict(
      loss=[run['result']['loss'] for run in timed_runs],
      accuracy=[run['result']['accuracy'] for run in timed_runs],
      time_taken=[run.get('time_taken') for run in timed_runs],
  )
  for k in config_keys:
    alldict[f"c_{k}"] = [run['config'][k] for run in timed_runs]
  return pd.DataFrame(alldict)


mkpandas()


Unnamed: 0,loss,accuracy,time_taken,c_dataset_batch_size,c_dataset_sequence_length,c_learning_rate,c_model_embed_dim,c_momentum,c_transformer_attention_size_per_head,c_transformer_dropout_rate,c_transformer_num_attention_heads,c_transformer_num_layers
0,,0.000000,136.508780,16,256,0.466423,512,0.950,12,0.25,6,4
1,,0.000000,70.807991,16,256,0.549519,256,0.925,2,0.50,16,2
2,,0.000000,77.242211,16,256,0.610740,512,0.950,6,0.25,8,2
3,0.263198,0.892334,141.738806,16,256,0.259931,256,0.900,16,0.50,8,6
4,,0.000000,731.489069,16,256,0.442092,1024,0.250,6,0.25,6,10
...,...,...,...,...,...,...,...,...,...,...,...,...
1007,0.273123,0.886230,51.125581,16,256,0.916712,64,0.000,2,0.10,16,2
1008,,0.000000,43.224882,16,256,0.941589,64,0.000,2,0.10,12,2
1009,0.262928,0.887939,52.436880,16,256,0.342335,64,0.000,8,0.10,16,2
1010,0.267322,0.885986,30.867207,16,256,0.852922,64,0.000,12,0.10,6,2


In [None]:
from google.colab import data_table
data_table.enable_dataframe_formatter()
import pandas as pd

def mk2dpandas():
  """Build a learning-rate x momentum table of accuracies.

  Only runs with `model_embed_dim == 128` and `dataset_batch_size == 16` are
  considered; when a (learning rate, momentum) pair occurs several times the
  latest result wins.

  Returns:
    A DataFrame with an `lr` column (ascending) followed by one column per
    momentum value in ascending order. Cells hold the accuracy formatted with
    five decimals, or "" where no result exists for that pair.
  """
  lookup = {}
  all_lr = set()
  all_m = set()
  for x in all_results:
    cfg = x['config']
    if cfg.model_embed_dim != 128: continue
    if cfg.dataset_batch_size != 16: continue
    lr = cfg.learning_rate
    m = cfg.momentum
    all_lr.add(lr)
    all_m.add(m)
    lookup[f"lr={lr} m={m}"] = x['result']

  # Sort once and reuse for both column creation and row filling, so columns
  # come out in ascending momentum order rather than set-iteration order.
  momenta = sorted(all_m)
  data = {'lr': []}
  for m in momenta:
    data[m] = []
  for lr in sorted(all_lr):
    data['lr'].append(lr)
    for m in momenta:
      v = lookup.get(f"lr={lr} m={m}")
      data[m].append("%.5f" % v['accuracy'] if v else "")
  return pd.DataFrame(data)


mk2dpandas()

Unnamed: 0,lr,0.9,0.75,0.925,0.99,0.0,0.95,0.25,0.5,0.975
0,0.005,0.88623,,,,,,,,
1,0.007058,,,0.87695,,,,,,
2,0.0075,0.88794,,,,,,,,
3,0.01,0.88867,,,,,,,,
4,0.015,0.89038,,,,,,,,
5,0.02,0.8894,,,,,,,,
6,0.023018,0.89111,,,,,,,,
7,0.03,0.88965,,,,,,,,
8,0.044894,,,,0.0,,,,,
9,0.059026,,,,,,0.89233,,,


In [None]:
import json

def showasjson():
  """Print every recorded result as one JSON line.

  The `config` entry is converted with its `to_dict()` method on a shallow
  copy, so the records in `all_results` are left untouched.
  """
  for record in all_results:
    printable = {**record, 'config': record['config'].to_dict()}
    print(json.dumps(printable))

showasjson()

{"config": {"dataset_batch_size": 256, "dataset_sequence_length": 256, "learning_rate": 0.005, "model_embed_dim": 32, "momentum": 0.0, "transformer_attention_size_per_head": 4, "transformer_dropout_rate": 0.1, "transformer_num_attention_heads": 4, "transformer_num_layers": 4}, "result": {"loss": 2.4015908241271973, "accuracy": 0.448394775390625}}
{"config": {"dataset_batch_size": 256, "dataset_sequence_length": 256, "learning_rate": 0.005, "model_embed_dim": 32, "momentum": 0.5, "transformer_attention_size_per_head": 4, "transformer_dropout_rate": 0.1, "transformer_num_attention_heads": 4, "transformer_num_layers": 4}, "result": {"loss": 1.938565731048584, "accuracy": 0.4912872314453125}}
{"config": {"dataset_batch_size": 256, "dataset_sequence_length": 256, "learning_rate": 0.005, "model_embed_dim": 32, "momentum": 0.75, "transformer_attention_size_per_head": 4, "transformer_dropout_rate": 0.1, "transformer_num_attention_heads": 4, "transformer_num_layers": 4}, "result": {"loss": 1.58

In [None]:
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logging.warning("start")

import jax
import sys

config = config_lib.get_config()


def main():
  """Placeholder entry point: returns immediately without doing anything.

  The commented-out lines below are the remains of an earlier script-style
  driver (dataset loading and `train.run_training`); they are never executed.
  """
  return
  # print("hello, world!")
  # ds = dataset.load_from_file("generated-easy-attention-dataset.txt")
  # for idx, batch in enumerate(ds):
  #     print(idx, batch)
  # print("super down")
  # train_rng = jax.random.key(12345)
  # workdir = "./training_workdir/"
  # train.run_training("/generated-easy-attention-dataset.txt", train_rng, config, workdir)


# if __name__ == "__main__":
#     sys.exit(main())


