Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
!pip install gcsfs
!pip install sentencepiece

In [None]:
from typing import Sequence

from collections import defaultdict
import itertools
import numpy as np
import pandas as pd

import gcsfs
import sentencepiece as spm

fs = gcsfs.GCSFileSystem('transformer-ngrams')

TOKENIZER_PATH = 'gs://transformer-ngrams/32768.model'
VOCAB_SIZE = 32768
BOS_TOKEN = 1

See all files

In [None]:
fs.ls('gs://transformer-ngrams')

['transformer-ngrams/32768.model',
 'transformer-ngrams/TinyStories',
 'transformer-ngrams/Wikipedia']

For TinyStories, we provide

*   rules data (`train_data_rules` and `eval_data_rules`) computed from 100 random train stories and 100 random eval stories (one parquet file per story in each directory). See the Rules section below for how to interpret this data. Creating these files required a computationally expensive aggregation of n-gram statistics.
*   training data (`training_data`), chunked so that training on these chunks yields exactly the n-gram statistics used in *Understanding Transformers via Simple N-Gram Statistics* and in forming the rules above. A different chunking of the training data changes the n-gram statistics, and hence which rules appear in the rules data. Using the rules data to analyze a trained model therefore requires training on the training data provided here.

Eval data can be obtained from Hugging Face and processed as desired; it does not contribute to the training-set statistics.

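Why chunking matters can be seen on a toy example. The sketch below (illustrative token ids and a hypothetical helper `bigram_counts`, not the authors' pipeline) counts bigrams only within chunks, so a bigram that straddles a chunk boundary is never observed, and two chunkings of the same token stream produce different counts. Longer n-grams are affected in the same way.

```python
from collections import Counter


def bigram_counts(chunks: list[list[int]]) -> Counter:
  """Counts (previous token, next token) pairs that lie entirely within one chunk."""
  counts = Counter()
  for chunk in chunks:
    counts.update(zip(chunk[:-1], chunk[1:]))
  return counts


tokens = [1, 5, 6, 7, 5, 6, 8]  # hypothetical token ids

# Two ways of chunking the same token stream.
chunking_a = [tokens[:4], tokens[4:]]  # [1, 5, 6, 7], [5, 6, 8]
chunking_b = [tokens[:5], tokens[5:]]  # [1, 5, 6, 7, 5], [6, 8]

counts_a = bigram_counts(chunking_a)
counts_b = bigram_counts(chunking_b)

print(counts_a[(7, 5)], counts_b[(7, 5)])  # 0 1: (7, 5) straddles the boundary in chunking_a
print(counts_a[(5, 6)], counts_b[(5, 6)])  # 2 1: the second (5, 6) straddles the boundary in chunking_b
```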


In [None]:
fs.ls('gs://transformer-ngrams/TinyStories')

['transformer-ngrams/TinyStories/',
 'transformer-ngrams/TinyStories/eval_data_rules',
 'transformer-ngrams/TinyStories/train_data_rules',
 'transformer-ngrams/TinyStories/training_data']

For Wikipedia, we provide

*   rules data (`eval_rules_data`) computed from 10 random eval chunks (10 parquet files, sharded into 20 files named `[chunk_number]_[shard_number].parquet`). See the Rules section below for how to interpret this data. Creating these files required a computationally expensive aggregation of n-gram statistics.
*   training data (`train_data`), chunked so that training on these chunks yields exactly the n-gram statistics used in *Understanding Transformers via Simple N-Gram Statistics* and in forming the rules above. A different chunking of the training data changes the n-gram statistics, and hence which rules appear in the rules data. Using the rules data to analyze a trained model therefore requires training on the training data provided here.
*   eval chunks for validation (`eval_data`), since our train/eval split of Wikipedia MTV-5 is non-canonical. The particular way the eval data is chunked plays no special role.

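Because each Wikipedia rules chunk is spread over several shard files, it can help to group the shard paths by chunk before loading them together. The sketch below assumes file names of the form `[chunk_number]_[shard_number].parquet`, as described above; the helper `group_shards_by_chunk` and the example listing are hypothetical, and the exact names in the bucket (for example, any zero-padding) should be checked with `fs.ls` first.

```python
import re
from collections import defaultdict

# Assumed naming: `[chunk_number]_[shard_number].parquet` at the end of the path.
SHARD_PATTERN = re.compile(r'(\d+)_(\d+)\.parquet$')


def group_shards_by_chunk(paths: list[str]) -> dict[int, list[str]]:
  """Maps each chunk number to its shard paths, ordered by shard number."""
  shards = defaultdict(list)
  for path in paths:
    match = SHARD_PATTERN.search(path)
    if match is None:
      continue  # not a `[chunk]_[shard].parquet` file
    chunk, shard = int(match.group(1)), int(match.group(2))
    shards[chunk].append((shard, path))
  return {chunk: [path for _, path in sorted(items)] for chunk, items in sorted(shards.items())}


# Hypothetical listing, deliberately out of order.
paths = [
    'transformer-ngrams/Wikipedia/eval_rules_data/7_1.parquet',
    'transformer-ngrams/Wikipedia/eval_rules_data/3_0.parquet',
    'transformer-ngrams/Wikipedia/eval_rules_data/7_0.parquet',
    'transformer-ngrams/Wikipedia/eval_rules_data/3_1.parquet',
]
for chunk, shard_paths in group_shards_by_chunk(paths).items():
  print(chunk, [p.rsplit('/', 1)[-1] for p in shard_paths])
# 3 ['3_0.parquet', '3_1.parquet']
# 7 ['7_0.parquet', '7_1.parquet']
```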
In [None]:
fs.ls('gs://transformer-ngrams/Wikipedia')

['transformer-ngrams/Wikipedia/',
 'transformer-ngrams/Wikipedia/eval_data',
 'transformer-ngrams/Wikipedia/eval_rules_data',
 'transformer-ngrams/Wikipedia/train_data']

# Load Tokenizer

In [None]:
with fs.open(TOKENIZER_PATH) as f:
  tokenizer = spm.SentencePieceProcessor(model_proto=f.read())

# encode: text => id
print(tokenizer.encode_as_pieces('This is a test'))
print(tokenizer.encode_as_ids('This is a test'))

# decode: id => text
print(tokenizer.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
print(tokenizer.decode_ids([209, 31, 9, 375, 586]))

# Rules

In [None]:
MAX_CONTEXT_SIZE = 7

RULE_TYPES = {0: [()]}
RULE_TYPES.update({i+1: [('+',) + x for x in itertools.product(['-', '*', '+'], repeat=i)] for i in range(MAX_CONTEXT_SIZE)})

def get_rules_index(max_context_size):
  rules = {i: rule_list for i, rule_list in RULE_TYPES.items() if i <= max_context_size}
  return {i: r for i, r in enumerate(sum(rules.values(), []))}

RULES_INDEX = get_rules_index(MAX_CONTEXT_SIZE)
RULE_TO_INDEX = {v: k for k, v in RULES_INDEX.items()}
# 1094 rules (note: some are redundant, e.g. ('+', '-', '*') and ('+', '*', '-') always yield the same context)
NUM_RULES = len(RULES_INDEX)

RULES_INDEX

{0: (),
 1: ('+',),
 2: ('+', '-'),
 3: ('+', '*'),
 4: ('+', '+'),
 5: ('+', '-', '-'),
 6: ('+', '-', '*'),
 7: ('+', '-', '+'),
 8: ('+', '*', '-'),
 9: ('+', '*', '*'),
 10: ('+', '*', '+'),
 11: ('+', '+', '-'),
 12: ('+', '+', '*'),
 13: ('+', '+', '+'),
 14: ('+', '-', '-', '-'),
 15: ('+', '-', '-', '*'),
 16: ('+', '-', '-', '+'),
 17: ('+', '-', '*', '-'),
 18: ('+', '-', '*', '*'),
 19: ('+', '-', '*', '+'),
 20: ('+', '-', '+', '-'),
 21: ('+', '-', '+', '*'),
 22: ('+', '-', '+', '+'),
 23: ('+', '*', '-', '-'),
 24: ('+', '*', '-', '*'),
 25: ('+', '*', '-', '+'),
 26: ('+', '*', '*', '-'),
 27: ('+', '*', '*', '*'),
 28: ('+', '*', '*', '+'),
 29: ('+', '*', '+', '-'),
 30: ('+', '*', '+', '*'),
 31: ('+', '*', '+', '+'),
 32: ('+', '+', '-', '-'),
 33: ('+', '+', '-', '*'),
 34: ('+', '+', '-', '+'),
 35: ('+', '+', '*', '-'),
 36: ('+', '+', '*', '*'),
 37: ('+', '+', '*', '+'),
 38: ('+', '+', '+', '-'),
 39: ('+', '+', '+', '*'),
 40: ('+', '+', '+', '+'),
 41: ('+',

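The ordering of RULES_INDEX follows from its construction: rules are grouped by length, and within a length the symbols after the leading '+' are enumerated by `itertools.product` in base-3 order with '-' < '*' < '+'. A rule's index can therefore be computed without the lookup table. The function `rule_to_index` is our own sketch of this observation, not part of the original code; it is checked against a copy of RULES_INDEX rebuilt exactly as in the cell above.

```python
import itertools

MAX_CONTEXT_SIZE = 7
SYMBOL_TO_DIGIT = {'-': 0, '*': 1, '+': 2}  # order used by itertools.product for RULE_TYPES


def rule_to_index(rule: tuple[str, ...]) -> int:
  """Closed-form position of `rule` in RULES_INDEX (sketch based on the construction above)."""
  if len(rule) == 0:
    return 0
  assert rule[0] == '+', 'every non-empty rule keeps its first token'
  length = len(rule)
  # Index of the first rule of this length: 1 (empty rule) + 3^0 + ... + 3^(length - 2).
  offset = 1 + (3 ** (length - 1) - 1) // 2
  # Read the remaining symbols as a base-3 number, most significant first.
  value = 0
  for symbol in rule[1:]:
    value = 3 * value + SYMBOL_TO_DIGIT[symbol]
  return offset + value


# Rebuild RULES_INDEX as in the notebook so the sketch can be checked on its own.
rule_types = {0: [()]}
rule_types.update({
    i + 1: [('+',) + x for x in itertools.product(['-', '*', '+'], repeat=i)]
    for i in range(MAX_CONTEXT_SIZE)
})
rules_index = dict(enumerate(sum(rule_types.values(), [])))

assert all(rule_to_index(rule) == i for i, rule in rules_index.items())
print(len(rules_index))                # 1094 = 1 + (3**7 - 1) // 2
print(rule_to_index(('+', '-', '*')))  # 6
```

For example, the all-'+' rule of length L lands at (3**L - 1) // 2, which gives the keys 1, 4, 13, ..., 1093 listed for RULES_SUFFIX.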
In [None]:
RULES_SUFFIX = {i: rule for i, rule in RULES_INDEX.items() if all([r == '+' for r in rule])}
RULES_SUFFIX.update({0: ()})

# Rules which only use the suffix
RULES_SUFFIX

{0: (),
 1: ('+',),
 4: ('+', '+'),
 13: ('+', '+', '+'),
 40: ('+', '+', '+', '+'),
 121: ('+', '+', '+', '+', '+'),
 364: ('+', '+', '+', '+', '+', '+'),
 1093: ('+', '+', '+', '+', '+', '+', '+')}

In [None]:
RULES_SUBGRAM = {i: rule for i, rule in RULES_INDEX.items() if all([r in ['-','+'] for r in rule])}
RULES_SUBGRAM.update({0: ()})

# Rules which only keep or drop tokens
RULES_SUBGRAM

{0: (),
 1: ('+',),
 2: ('+', '-'),
 4: ('+', '+'),
 5: ('+', '-', '-'),
 7: ('+', '-', '+'),
 11: ('+', '+', '-'),
 13: ('+', '+', '+'),
 14: ('+', '-', '-', '-'),
 16: ('+', '-', '-', '+'),
 20: ('+', '-', '+', '-'),
 22: ('+', '-', '+', '+'),
 32: ('+', '+', '-', '-'),
 34: ('+', '+', '-', '+'),
 38: ('+', '+', '+', '-'),
 40: ('+', '+', '+', '+'),
 41: ('+', '-', '-', '-', '-'),
 43: ('+', '-', '-', '-', '+'),
 47: ('+', '-', '-', '+', '-'),
 49: ('+', '-', '-', '+', '+'),
 59: ('+', '-', '+', '-', '-'),
 61: ('+', '-', '+', '-', '+'),
 65: ('+', '-', '+', '+', '-'),
 67: ('+', '-', '+', '+', '+'),
 95: ('+', '+', '-', '-', '-'),
 97: ('+', '+', '-', '-', '+'),
 101: ('+', '+', '-', '+', '-'),
 103: ('+', '+', '-', '+', '+'),
 113: ('+', '+', '+', '-', '-'),
 115: ('+', '+', '+', '-', '+'),
 119: ('+', '+', '+', '+', '-'),
 121: ('+', '+', '+', '+', '+'),
 122: ('+', '-', '-', '-', '-', '-'),
 124: ('+', '-', '-', '-', '-', '+'),
 128: ('+', '-', '-', '-', '+', '-'),
 130: ('+', '-

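The suffix and subgram families above are easiest to read by applying a rule to a concrete context. In this sketch (the helper `apply_rule` is hypothetical, written to mirror the loop in `_get_possible_modified_contexts_for_rules`), a rule of length k acts on the last k tokens of the context: its leading '+' keeps the first of them, each later symbol keeps ('+'), marginalizes ('*', written as token 0), or drops ('-') a token, and anything earlier is discarded. A suffix rule therefore returns the most recent tokens, as an ordinary n-gram model would, while a subgram rule can skip tokens but never yields a 0.

```python
MARGINALIZED = 0  # token id used for a marginalized position


def apply_rule(rule: tuple[str, ...], context: tuple[int, ...]) -> tuple[int, ...]:
  """Applies `rule` to the last len(rule) tokens of `context` (illustrative sketch)."""
  if len(rule) == 0:
    return ()  # the empty rule keeps no tokens
  assert len(rule) <= len(context), 'a rule cannot be longer than the context'
  assert rule[0] == '+', 'the first token of the tail is always kept'
  tail = context[len(context) - len(rule):]
  result = []
  for symbol, token in zip(rule, tail):
    if symbol == '+':
      result.append(token)
    elif symbol == '*':
      result.append(MARGINALIZED)
    # '-' drops the token
  return tuple(result)


context = (2, 3, 4)
print(apply_rule(('+', '+'), context))       # (3, 4): suffix rule, the last two tokens
print(apply_rule(('+', '-', '+'), context))  # (2, 4): subgram rule dropping the middle token
print(apply_rule(('+', '*', '+'), context))  # (2, 0, 4): the middle token is marginalized
```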
In [None]:
def _get_possible_modified_contexts_for_rules(base_context: list[int]):
  """
  Given a base context (t1, ..., tn), picks a starting token ti, which is always kept (+), and then keeps (+),
  marginalizes (*, written as token 0), or drops (-) each later token.
  Returns the unique resulting contexts and a mapping from each context to the indices (in RULES_INDEX) of the rules producing it.
  """
  ret = []
  n = len(base_context)
  context_to_rule = defaultdict(list)
  for i in range(n): # contexts whose first (kept) token is base_context[i]; earlier tokens are dropped
    choices = itertools.product(['-', '*', '+'], repeat=n-i-1) # '-': drop token, '*': marginalize, '+': keep
    remaining_context = base_context[i+1:]
    for choice in choices:
      # prefix is always a '+'
      context = [base_context[i]]
      assert len(choice) == len(remaining_context)
      for c, token in zip(choice,remaining_context):
        if c == '-':
          continue
        elif c == '*':
          context.append(0)
        elif c == '+':
          context.append(token)
      context = tuple(context)
      ret.append(context)
      context_to_rule[context].append(RULE_TO_INDEX[('+',) + choice])
  unique_contexts = sorted(list(set(ret)), key=lambda x: (len(x), x), reverse=True)
  return unique_contexts, context_to_rule

def _get_contexts_and_rules(tokens, max_context_size) -> tuple[list[list[tuple[int, ...]]], list[dict[tuple[int, ...], list[int]]]]:
  """
  For each position i (except the last, which has no target), builds the base context of at most
  max_context_size tokens ending at tokens[i] and returns all of its modified contexts, plus the context -> rules map.
  The base context is cut at the most recent BOS_TOKEN so that no context spans two documents.
  """
  contexts_per_token = []
  context_to_rules_per_token = []
  for current_idx in range(1, len(tokens)):
    min_idx = max(0, current_idx-max_context_size)
    raw_context = list(tokens[min_idx:current_idx])
    # Start the context at the last BOS_TOKEN (if any). Checking only whether BOS_TOKEN survives in a
    # modified context is not enough: dropping or marginalizing it would reach into the previous document.
    if BOS_TOKEN in raw_context:
      last_bos = len(raw_context) - 1 - raw_context[::-1].index(BOS_TOKEN)
      raw_context = raw_context[last_bos:]
    unique_contexts, context_to_rule = _get_possible_modified_contexts_for_rules(raw_context)
    contexts_per_token.append(unique_contexts)
    context_to_rules_per_token.append(context_to_rule)
  return contexts_per_token, context_to_rules_per_token

def get_df_ctx_rules_data(tokens: Sequence[int], max_context_size: int):
  """
  tokens: A sequence of nonzero integers (0 denotes a marginalized token). We assume BOS_TOKEN = 1.
  max_context_size: The maximum context to consider for rules. At most MAX_CONTEXT_SIZE = 7, the bound used to build RULES_INDEX.

  Returns a dataframe with 3 columns:
    index: the position of the current token, i.e. the last token of the base context (the target is the next token)
    context: a rule context (a tuple of ints) obtained from keeping, dropping, or marginalizing tokens; a 0 token corresponds to a marginalized token
    rules: the list of rule indices (keys of RULES_INDEX) yielding that context
  """
  assert max_context_size <= MAX_CONTEXT_SIZE
  contexts_data = []
  rules_data = []
  index_data = []
  ctxs_per_token, ctx_to_rule_per_token = _get_contexts_and_rules(tokens, max_context_size)
  for i, (current_ctxs, current_ctx_to_rules) in enumerate(zip(ctxs_per_token, ctx_to_rule_per_token)):
    contexts_data.extend(current_ctxs)
    index_data.extend([i for _ in range(len(current_ctxs))])
    for ctx in current_ctxs:
      rules_data.append(current_ctx_to_rules[ctx])
  df = pd.DataFrame({'index': index_data, 'context': contexts_data, 'rules': rules_data})
  df = df.sort_values(by=['index', 'rules'], key=lambda col: col if col.name == 'index' else col.apply(lambda x: x[0])).reset_index(drop=True)
  return df

**Example:** Suppose we have `tokens = [2, 3, 4, 6, 1, 7, 8]` and `max_context_size=3`. The rows with `index = i` list the contexts ending at the current token `tokens[i]`, which are used to predict `tokens[i+1]`; the last token has no target and gets no rows.

For the first token (index = 0), there is a single rule, RULES_INDEX[1], which keeps the entire context (2,).

For the second token (index = 1), the base context is (2, 3). There are four rules (RULES_INDEX[i], i = 1, 2, 3, 4), corresponding respectively to keeping only the last token (3,), keeping only the first token (2,), marginalizing the last token (2, 0), or keeping both tokens (2, 3).

For the third token (index = 2), the base context is (2, 3, 4), the full max_context_size=3 tokens. All valid keep, drop, and marginalize operations give 13 rules but only 12 distinct contexts, since RULES_INDEX[6] = ('+', '-', '*') and RULES_INDEX[8] = ('+', '*', '-') both yield (2, 0).

Likewise, for the fourth token (index = 3), the base context is (3, 4, 6), again with 12 distinct contexts.

The fifth token (index = 4) is BOS_TOKEN=1, so there is only a single context, (1,): a context cannot extend to tokens from a different document.

Finally, the sixth token (index = 5) is analogous to index = 1: its base context is (1, 7), giving four rules.

In [None]:
get_df_ctx_rules_data([2, 3, 4, 6, 1, 7, 8], max_context_size=3)

,index,context,rules
0,0,"(2,)",[1]
1,1,"(3,)",[1]
2,1,"(2,)",[2]
3,1,"(2, 0)",[3]
4,1,"(2, 3)",[4]
5,2,"(4,)",[1]
6,2,"(3,)",[2]
7,2,"(3, 0)",[3]
8,2,"(3, 4)",[4]
9,2,"(2,)",[5]

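The counts in the example above can be checked with a small standalone sketch. A base context of n tokens admits 1 + 3 + ... + 3^(n-1) = (3^n - 1)/2 rules (one per starting token and per choice for each later token), but different rules can produce the same context, so the number of distinct contexts can be smaller. `rules_and_contexts` is a hypothetical helper that re-derives the mapping without the notebook's globals, enumerating choices in the same order as `_get_possible_modified_contexts_for_rules`.

```python
import itertools
from collections import defaultdict


def rules_and_contexts(base_context: tuple[int, ...]) -> dict[tuple[int, ...], list[tuple[str, ...]]]:
  """Maps each distinct modified context to the rules that produce it."""
  context_to_rules = defaultdict(list)
  n = len(base_context)
  for start in range(n):  # the token at `start` is the kept first token
    rest = base_context[start + 1:]
    for choice in itertools.product(['-', '*', '+'], repeat=len(rest)):
      context = [base_context[start]]
      for symbol, token in zip(choice, rest):
        if symbol == '*':
          context.append(0)  # marginalized
        elif symbol == '+':
          context.append(token)
        # '-' drops the token
      context_to_rules[tuple(context)].append(('+',) + choice)
  return dict(context_to_rules)


for base in [(2,), (2, 3), (2, 3, 4)]:
  mapping = rules_and_contexts(base)
  num_rules = sum(len(rules) for rules in mapping.values())
  print(base, num_rules, len(mapping))
# (2,) 1 1
# (2, 3) 4 4
# (2, 3, 4) 13 12

print(rules_and_contexts((2, 3, 4))[(2, 0)])  # [('+', '-', '*'), ('+', '*', '-')]
```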

## Load Rules Data

Rules data for 100 random train stories and 100 random eval stories from TinyStories are stored at

`gs://transformer-ngrams/TinyStories/train_data_rules/{i}.parquet`

`gs://transformer-ngrams/TinyStories/eval_data_rules/{i}.parquet`

for `i = "001", "002", ..., "100"`.

The following discussion on how to interpret the rules data also applies to the corresponding Wikipedia rules data.
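The 100 paths per split follow directly from the zero-padded naming above. This sketch only builds the path strings (the helper `rules_paths` is our own name); each path can then be opened with the `fs` object exactly as in the next cell.

```python
BUCKET = 'gs://transformer-ngrams'


def rules_paths(split: str) -> list[str]:
  """Paths of the TinyStories rules files for split 'train' or 'eval' ('001' ... '100')."""
  assert split in ('train', 'eval')
  return [f'{BUCKET}/TinyStories/{split}_data_rules/{i:03d}.parquet' for i in range(1, 101)]


eval_paths = rules_paths('eval')
print(eval_paths[0])    # gs://transformer-ngrams/TinyStories/eval_data_rules/001.parquet
print(eval_paths[-1])   # gs://transformer-ngrams/TinyStories/eval_data_rules/100.parquet
print(len(eval_paths))  # 100
```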

In [None]:
SAMPLE_RULES_PATH = 'gs://transformer-ngrams/TinyStories/eval_data_rules/001.parquet'

with fs.open(SAMPLE_RULES_PATH, 'rb') as f:
  df_rules = pd.read_parquet(f)

This loads the rules dataframe for a single TinyStories story. Its columns are:

**record_num:** The story index (the row number in the corresponding train/eval Hugging Face parquet file)

**index:** Token index within the story

**token:** Current token

**target:** Target token (the ground-truth next token)

**context_size_used:** How many tokens were considered before selecting a subcontext, i.e. the number of most recent tokens the rule acts on

**context:** The rule context selected (tokens are kept, dropped, or marginalized, with 0 denoting a marginalized token)

**rules:** The rules (indices into RULES_INDEX) that produce the context

**next_token_counter:** The next-token statistics of the context in `[k1, v1, k2, v2, ...]` format, where each k_i is a next token and v_i is its number of occurrences

**rule_prediction:** The k_i's from next_token_counter whose v_i's are maximal

**text:** Detokenized current token

**target_text:** Detokenized target token

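The `next_token_counter` and `rule_prediction` columns are related as follows: the counter interleaves keys and counts, and the rule prediction lists every key whose count is maximal (several tokens when there is a tie). The sketch below recomputes the rule prediction from a counter; the helper names are ours. On the complete counters visible in the table below it reproduces the listed predictions, e.g. `[383, 40, 8130, 1, 328, 9, 5130, 1]` gives `[383]`.

```python
import numpy as np


def counter_to_dict(counter) -> dict[int, int]:
  """Converts an interleaved [k1, v1, k2, v2, ...] counter into {next_token: count}."""
  counter = np.asarray(counter)
  assert counter.size % 2 == 0
  return dict(zip(counter[::2].tolist(), counter[1::2].tolist()))


def counter_argmax(counter) -> list[int]:
  """All next tokens whose count is maximal (the rule prediction)."""
  counts = counter_to_dict(counter)
  best = max(counts.values())
  return [token for token, count in counts.items() if count == best]


print(counter_argmax([383, 40, 8130, 1, 328, 9, 5130, 1]))             # [383]
print(counter_argmax([280, 33, 1076, 83, 841, 73, 2086, 1, 2850, 1]))  # [1076]
print(counter_argmax([5, 2, 9, 2]))  # [5, 9]: a tie (the stored order may differ)
```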
In [None]:
df_rules

,record_num,index,token,target,context_size_used,context,rules,next_token_counter,rule_prediction,text,target_text
0,0,0,1,31402,1,[1],[1],"[32606, 38614, 32624, 132017, 32318, 40357, 41...",[4146],,Spot
3,0,1,31402,32599,1,[31402],[1],"[1358, 56, 2508, 453, 305, 935, 11655, 2, 3260...",[32600],Spot,.
4,0,1,31402,32599,2,[1],[2],"[32606, 38614, 32624, 132017, 32318, 40357, 41...",[4146],Spot,.
2,0,1,31402,32599,2,"[1, 0]",[3],"[2904, 36075, 7097, 16560, 305, 151603, 1726, ...",[1726],Spot,.
1,0,1,31402,32599,2,"[1, 31402]",[4],"[383, 40, 8130, 1, 328, 9, 5130, 1]",[383],Spot,.
...,...,...,...,...,...,...,...,...,...,...,...
27660,0,84,1076,2086,7,"[3240, 1724, 476, 1003, 305, 0, 0]",[1089],"[1076, 38, 1397, 48, 922, 173, 996, 6, 301, 36...",[301],best,friends
27659,0,84,1076,2086,7,"[3240, 1724, 476, 1003, 305, 0, 1076]",[1090],"[2086, 84]",[2086],best,friends
27722,0,84,1076,2086,7,"[3240, 1724, 476, 1003, 305, 2315]",[1091],"[280, 33, 1076, 83, 841, 73, 2086, 1, 2850, 1]",[1076],best,friends
27658,0,84,1076,2086,7,"[3240, 1724, 476, 1003, 305, 2315, 0]",[1092],"[1076, 33, 2086, 157, 32642, 1]",[2086],best,friends


In [None]:
def convert_counter_to_probs(counter: list[int]):
  "counter = [k1 v1 ...] a sequence of key values, key = next token, value = count of next token"
  assert len(counter) % 2 == 0
  probs = np.zeros(VOCAB_SIZE)
  ks = counter[::2]
  assert BOS_TOKEN not in ks
  vs = counter[1::2]
  mass = sum(vs)
  for k, v in zip(ks, vs):
    probs[k] = v
  probs = probs / mass
  return probs

def dist(counter: Sequence[int], model_probs: np.ndarray):
  """
  Computes the total variation distance (half the L1 distance) between the next-token distribution from the counter and the given model_probs.
  """
  probs = convert_counter_to_probs(counter)
  return 0.5 * np.sum(np.abs(probs - model_probs))
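As a sanity check of `dist`, here is the same total variation computation, 0.5 * sum over tokens of |p(token) - q(token)|, on a toy vocabulary of 4 tokens (an assumption for readability; the notebook uses `VOCAB_SIZE = 32768`). The helper names are ours. One design difference: `np.add.at` accumulates counts if a key were ever repeated, whereas `convert_counter_to_probs` assigns them.

```python
import numpy as np

TOY_VOCAB_SIZE = 4  # hypothetical small vocabulary for illustration


def counter_to_probs(counter, vocab_size: int = TOY_VOCAB_SIZE) -> np.ndarray:
  """Normalizes an interleaved [k1, v1, ...] counter into a probability vector."""
  counter = np.asarray(counter)
  probs = np.zeros(vocab_size)
  np.add.at(probs, counter[::2], counter[1::2])  # accumulate count v_i at index k_i
  return probs / probs.sum()


def total_variation(p: np.ndarray, q: np.ndarray) -> float:
  """Half the L1 distance; lies in [0, 1] for probability vectors."""
  return 0.5 * float(np.abs(p - q).sum())


p = counter_to_probs([0, 3, 2, 1])  # [0.75, 0., 0.25, 0.]
uniform = np.full(TOY_VOCAB_SIZE, 0.25)
print(total_variation(p, uniform))  # 0.5
print(total_variation(p, p))        # 0.0
```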

The cell below defines a dummy `df_model_preds` with random next-token probability distributions, for illustration only. Replace it with the predictions of your own model, obtained by running inference on the tokens in `df_rules`.

In [None]:
n = len(set(df_rules['index']))
index = list(range(n))

random_probs = np.random.uniform(size=(n, VOCAB_SIZE))
random_probs = random_probs / np.sum(random_probs, axis=1)[:, None]

# Replace df_model_preds['model_probs'] with your model evaluated on the tokens given by df_rules['token'].
# Row `index = i` should hold the model's next-token distribution after reading the tokens at indices 0, ..., i
# of df_rules, i.e. its prediction for the `target` at index i.
df_model_preds = pd.DataFrame({'index': index, 'model_probs': random_probs.tolist()})

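To replace the dummy predictions, run your model once over the story's tokens (in `index` order) and keep one next-token distribution per position. The sketch below assumes a hypothetical causal model `toy_model_logits(tokens)` returning an array of shape `(len(tokens), vocab_size)` whose row i depends only on tokens 0..i; it stands in for your model and is not part of this repository.

```python
import numpy as np
import pandas as pd

TOY_VOCAB_SIZE = 8  # hypothetical; use VOCAB_SIZE with a real model


def toy_model_logits(tokens: list[int]) -> np.ndarray:
  """Hypothetical stand-in for a causal LM: row i depends only on tokens[: i + 1]."""
  rng = np.random.default_rng(0)
  table = rng.normal(size=(TOY_VOCAB_SIZE, TOY_VOCAB_SIZE))
  return table[np.asarray(tokens)]  # bigram-style: logits depend on the current token only


def softmax(logits: np.ndarray) -> np.ndarray:
  shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
  exp = np.exp(shifted)
  return exp / exp.sum(axis=-1, keepdims=True)


def make_model_preds(tokens: list[int]) -> pd.DataFrame:
  """One row per position i: `model_probs` is the prediction for the token after position i."""
  probs = softmax(toy_model_logits(tokens))
  return pd.DataFrame({'index': list(range(len(tokens))), 'model_probs': probs.tolist()})


# In the notebook, the tokens could be taken from the rules data, e.g.
# df_rules.drop_duplicates('index').sort_values('index')['token'].tolist()
story_tokens = [1, 3, 5, 2]  # hypothetical token ids
df_preds = make_model_preds(story_tokens)
print(df_preds.shape)  # (4, 2)
```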
In [None]:
df_joined = df_rules.merge(df_model_preds, on='index')
df_joined['distance'] = df_joined.apply(lambda x: dist(x.next_token_counter, x.model_probs), axis=1)


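# For each token, keep the rule whose distribution is closest to the model's (the optimal rule) and score the model's
# top-1 prediction against that rule's prediction, with fractional credit 1/m when m tokens tie for the maximum.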
df_optimal = df_joined.loc[df_joined.groupby('index')['distance'].idxmin()].reset_index(drop=True)
df_optimal['top_1_acc'] = df_optimal.apply(lambda x: (np.argmax(x['model_probs']) in x['rule_prediction']) / len(x['rule_prediction']), axis=1)
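The optimal-rule selection and the tie-aware top-1 credit can be seen on a toy frame with hypothetical numbers (not real rules data). `groupby('index')['distance'].idxmin()` returns, per token, the row label of the smallest distance, and `.loc` keeps those rows. Averaging `top_1_acc` over tokens is one natural summary, shown here as an assumption rather than the notebook's prescribed metric.

```python
import numpy as np
import pandas as pd

# Hypothetical rows: two token positions, each with several candidate rules.
df_toy = pd.DataFrame({
    'index': [0, 0, 1, 1, 1],
    'rules': [[1], [2], [1], [3], [4]],
    'distance': [0.40, 0.10, 0.30, 0.05, 0.20],
    'rule_prediction': [[7], [5], [6], [4, 6], [9]],
    'model_probs': [[0.1, 0.0, 0.0, 0.0, 0.2, 0.7, 0.0, 0.0, 0.0, 0.0]] * 2
                   + [[0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.9, 0.0, 0.0, 0.0]] * 3,
})

# Keep, for each index, the row with the smallest distance (the optimal rule).
df_best = df_toy.loc[df_toy.groupby('index')['distance'].idxmin()].reset_index(drop=True)


def top1_credit(model_probs, rule_prediction) -> float:
  """1/m if the model's argmax is among the m tied rule predictions, else 0."""
  return float(int(np.argmax(model_probs)) in list(rule_prediction)) / len(rule_prediction)


df_best['top_1_acc'] = df_best.apply(lambda x: top1_credit(x['model_probs'], x['rule_prediction']), axis=1)
print(df_best[['index', 'rules', 'top_1_acc']])
print(df_best['top_1_acc'].mean())  # 0.75
```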

# Load Training Data

Load a shard for training data or eval data.

For `TinyStories`, a model should be trained with `tokens[:-1]` as inputs and `tokens[1:]` as targets. BOS tokens (equal to `1`) should not be targets, however; mask them out of the loss.

For `Wikipedia`, a model should be trained with `observation` as inputs and `target` as targets. The difference is that TinyStories has no EOS token (equal to `2`) whereas Wikipedia does. Omitting EOS from TinyStories was an oversight in the dataset preparation that we would correct if the data were regenerated. The EOS token appears in the `target` field of Wikipedia but not in the `observation` field.

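For TinyStories, the input/target shift and the BOS masking described above can be sketched with NumPy as follows. The function name and the 0/1 loss-weight array are our own choices; how the mask enters the loss depends on your training code.

```python
import numpy as np

BOS_TOKEN = 1


def tinystories_example(tokens) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Returns (inputs, targets, loss_mask) for one TinyStories training row."""
  tokens = np.asarray(tokens)
  inputs = tokens[:-1]
  targets = tokens[1:]
  loss_mask = (targets != BOS_TOKEN).astype(np.float32)  # BOS is never a prediction target
  return inputs, targets, loss_mask


inputs, targets, loss_mask = tinystories_example([1, 5, 6, 1, 7])
print(inputs.tolist())     # [1, 5, 6, 1]
print(targets.tolist())    # [5, 6, 1, 7]
print(loss_mask.tolist())  # [1.0, 1.0, 0.0, 1.0]
```

For Wikipedia no shift is needed, since the `observation` and `target` fields already hold the inputs and targets.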
In [None]:
TINYSTORIES_TRAINING_DATA_PATH = (
    'gs://transformer-ngrams/TinyStories/training_data/'
)
WIKIPEDIA_TRAINING_DATA_PATH = (
    'gs://transformer-ngrams/Wikipedia/train_data/'
)
WIKIPEDIA_EVAL_DATA_PATH = 'gs://transformer-ngrams/Wikipedia/eval_data/'

In [None]:
with fs.open('gs://transformer-ngrams/TinyStories/training_data/001.parquet', 'rb') as f:
  df = pd.read_parquet(f)

sample_text = tokenizer.decode_ids(df.tokens.iloc[0].tolist())

In [None]:
df

,tokens
0,"[2178, 769, 280, 4922, 32600, 3746, 3031, 351,..."
1,"[603, 275, 13556, 1071, 5199, 360, 16875, 305,..."
2,"[340, 280, 7094, 32599, 6027, 305, 26360, 1792..."
3,"[22565, 11111, 360, 280, 7320, 2888, 4550, 200..."
4,"[1220, 383, 529, 3031, 32642, 4, 4, 13666, 305..."
...,...
995,"[12205, 32599, 12682, 32600, 26360, 21886, 275..."
996,"[611, 383, 275, 1244, 2561, 3965, 26360, 32599..."
997,"[1656, 481, 326, 2955, 1181, 32599, 2204, 481,..."
998,"[32599, 7086, 908, 275, 15613, 305, 275, 1994,..."
