In [1]:
repo_id = "littlePanic99/nanochat"

base_filename = "base/d20/model_021400.pt"
mid_filename = "mid/d20/model_000809.pt"
sft_filename = "sft/d20/model_000700.pt"
tokenizer_filename = "tokenizer/latest/tokenizer.pkl"


In [2]:
import pickle
from huggingface_hub import hf_hub_download

with open(hf_hub_download(repo_id=repo_id, filename=tokenizer_filename, local_files_only=True), "rb") as f:
    tokenizer = pickle.load(f)

def token_str(enc, tid: int):
    return enc.decode_bytes([tid]).decode("utf-8", errors="replace")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans

def to_f32_cpu(x):
    return x.detach().cpu().float()

def center_inplace(x):
    x -= x.mean(dim=0, keepdim=True)
    return x

def l2_normalize_inplace(x, eps=1e-12):
    x /= (x.norm(dim=1, keepdim=True) + eps)
    return x

def fast_cluster_embedding_ids(X, k=256, batch_size=2048, seed=0):
    """
    X: (n, d) L2-normalized torch tensor on CPU
    returns: dict {cluster_id: [indices]}
    """
    if torch.is_tensor(X):
        X = X.numpy()

    km = MiniBatchKMeans(
        n_clusters=k,
        batch_size=batch_size,
        n_init="auto",
        random_state=seed,
    )

    labels = km.fit_predict(X)

    clusters = {}
    for i, c in enumerate(labels):
        clusters.setdefault(c, []).append(i)

    return clusters


In [5]:
from pprint import pprint
import torch
from huggingface_hub import hf_hub_download


def get_E(repo_id="littlePanic99/nanochat", filename="base/d20/model_021400.pt"):
    ckpt = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
    emb = torch.load(ckpt, map_location="cpu")["transformer.wte.weight"]
    return emb.detach()

E = get_E()

In [25]:
E

tensor([[ 0.1943,  2.1562, -0.1719,  ...,  0.9688,  1.1016,  0.4277],
        [-1.0938,  0.5820, -0.2734,  ...,  0.4004,  0.2402,  1.1328],
        [ 0.4902,  0.5078,  0.6445,  ...,  1.3047, -1.5547,  0.5508],
        ...,
        [-0.1289, -0.7031,  0.0728,  ..., -0.7383,  0.7812, -0.0806],
        [ 0.7812, -0.6992,  0.4043,  ...,  0.1426, -0.7656, -0.2656],
        [-0.7812,  0.9258,  0.9727,  ..., -0.4609, -0.3008, -1.3125]],
       dtype=torch.bfloat16)

In [6]:
Xu = l2_normalize_inplace(center_inplace(to_f32_cpu(E)))
clusters = fast_cluster_embedding_ids(Xu, k=256)

In [18]:
cluster_titles = {
  195: "titles_honorifics_degrees",
  248: "abbreviations_acronyms_misc",
  253: "name_fragments_foreign",
  230: "academic_historical_disciplines",
  99: "latin_prefixes_processes",
  201: "name_prefixes_particles",
  252: "proper_nouns_places_ethnic",
  193: "prepositions_function_words",
  82: "syllables_interjections",
  165: "latin_roots_scientific_terms",
  172: "adjectives_general_descriptors",
  146: "countries_world_regions",
  243: "institutional_process_terms",
  211: "discourse_adverbs",
  7: "institutions_organizations",
  32: "given_names_short_forms",
  123: "historical_religious_adjectives",
  202: "generic_nouns_collections",
  242: "common_verbs_actions",
  62: "common_nouns_objects",
  4: "surnames_family_names",
  182: "nationalities_languages",
  160: "acronyms_universities_agencies",
  36: "given_names_formal",
  28: "short_nouns_onomatopoeia",
  113: "gerunds_process_verbs",
  207: "cities_states_places",
  143: "subword_suffixes_fragments",
  209: "ancient_historical_figures",
  84: "south_asia_india_culture",
  149: "ethnic_groups_regions_global",
  184: "uk_aus_places_regions",
  222: "milk",
  128: "given_names_modern",
  78: "religions_islam_judaism",
  192: "internet",
  138: "international_organizations",
  185: "female_given_names_titles",
  45: "astronomy_disasters_events",
  116: "mesopotamia_ancient_near_east",
  237: "technology_brands_internet",
  34: "famous_people_figures",
  93: "surnames_place_based",
  71: "surname_singleton",
  157: "female_given_names_modern",
  47: "virtual_reality",
  80: "ethnic_conflict_groups"
}

In [35]:

import torch
import torch.nn.functional as F

def mean_embedding(E: torch.Tensor, ids) -> torch.Tensor:
    ids = torch.as_tensor(ids, device=E.device)
    return E.index_select(0, ids).mean(dim=0)


def cluster_quality(E: torch.Tensor, ids):
    X = E[ids]
    X = F.normalize(X, dim=1)
    c = F.normalize(X.mean(dim=0), dim=0)

    cos = X @ c
    return {
        "n": X.shape[0],
        "cos_mean": cos.mean().item(),
        "cos_min": cos.min().item(),
        "cos_max": cos.max().item(),
        "cos_std": cos.std(unbiased=False).item(),
    }

In [102]:
reference_groups = []


for cluster_id, ids in clusters.items():
    tokens = [token_str(tokenizer, tid).strip() for tid in ids]
    pct_start_upper = len([t for t in  tokens if len(t) > 1 and t[0].isupper()])/len(tokens)
    
    reference_groups.append({
        "cluster_id": cluster_id,
        "ids": ids,
        "tokens": tokens,
        "title": cluster_titles.get(cluster_id),
        "embedding": mean_embedding(E, ids),
        "cluster_quality": cluster_quality(E, ids),
        "probable_reference": len(ids) > 3 and pct_start_upper > 0.5,
        "pct_start_upper": pct_start_upper
    })        

# reference_groups = sorted(reference_groups, key=lambda g: g["cluster_quality"]["cos_max"])
reference_groups = sorted(reference_groups, key=lambda g: -g["pct_start_upper"])

In [103]:
for reference_group in reference_groups:
    cos_max = reference_group["cluster_quality"]["cos_max"]
    cos_std = reference_group["cluster_quality"]["cos_std"]
    title = str(reference_group["title"])
    tokens = reference_group["tokens"]
    pct_start_upper = reference_group["pct_start_upper"]
    
    print(f"{cos_max:0.3f}  {cos_std:0.3f}  {pct_start_upper:0.3f}  {title:<30}   {', '.join([str(v) for v in tokens[0:10]])}")


0.887  0.000  1.000  surname_singleton                Hart, Hart
0.891  0.000  1.000  virtual_reality                  VR, VR
1.000  0.000  1.000  None                             Roma
0.906  0.002  1.000  None                             Sounds, Sounds
1.000  0.000  1.000  None                             Disclaimer
1.000  0.000  1.000  None                             Shap
1.000  0.000  1.000  None                             Raja
1.000  0.000  1.000  None                             Yoruba
0.996  0.000  1.000  None                             Engl
1.000  0.000  1.000  None                             Matilda
1.000  0.000  1.000  None                             Experienced
1.000  0.000  1.000  None                             Charon
1.000  0.000  1.000  None                             IMS
1.000  0.000  1.000  None                             Simplified
0.410  0.039  0.992  subword_suffixes_fragments       IN, IS, ON, AS, HE, RE, IS, RO, VID, ION
0.379  0.037  0.991  common_nouns_ob

In [106]:
for reference_group in reference_groups:
    cos_max = reference_group["cluster_quality"]["cos_max"]
    cos_std = reference_group["cluster_quality"]["cos_std"]
    title = str(reference_group["title"])
    tokens = reference_group["tokens"]
    if 'name' in title:        
        print(f"{cos_max:0.3f}   {cos_std:0.3f}   {title:<30}   {', '.join(tokens[0:10])}")

0.887   0.000   surname_singleton                Hart, Hart
0.523   0.053   given_names_modern               Pat, David, Mark, Carol, Sam, Ken, Alex, Rob, Dan, Peter
0.463   0.043   given_names_short_forms          jor, Ed, Rober, Jer, Ab, Hen, Franc, Scot, Ed, Jul
0.408   0.043   surnames_place_based             Lee, Franklin, Lewis, ilton, Grant, Rand, Ford, Ross, Hamilton, Newton
0.555   0.048   female_given_names_modern        Susan, leen, Barbara, Lisa, Laura, Karen, Jennifer, Nancy, Linda, nda
0.453   0.047   given_names_formal               John, William, James, George, Paul, Thomas, Robert, Louis, Jose, Charles
0.457   0.048   name_prefixes_particles          �, �, �, �, de, le, cl, Th, St, qu
0.480   0.055   female_given_names_titles        Mary, Ann, abeth, izabeth, Elizabeth, Mrs, Rose, Anne, Jane, Maria
0.471   0.049   surnames_family_names            ohn, ley, augh, erson, son, berg, ston, Smith, inson, stein
0.490   0.039   name_fragments_foreign           �, esc, adel, F

# Token Cluster Usage

In [84]:
def get_token_cluster_map(clusters):
    inv = {}
    for k, vs in clusters.items():
        for v in vs:
            inv[v] = int(k)
    return inv

token_cluster_map = get_token_cluster_map(clusters)

cluster_lookup = {c["cluster_id"]: c for c in reference_groups}


In [75]:
from glob import glob
import pyarrow.parquet as pq

def iter_parquet_text(glob_pattern: str, field: str = "text"):
    for path in glob(glob_pattern):
        table = pq.read_table(path, columns=[field])
        for value in table[field].to_pylist():
            if value is not None:
                yield value



In [76]:
!ls ../../../../data/fineweb-edu-100b-shuffle

shard_00000.parquet


In [89]:
for i, text in enumerate(iter_parquet_text("../../../../data/fineweb-edu-100b-shuffle/*.parquet")):
    sample_text =  text
    break

In [90]:
from collections import Counter

sample_text_clusters = [token_cluster_map[idx] for idx in tokenizer.encode_ordinary(sample_text)]

for cluster_id, count in Counter(sample_text_clusters).most_common(20):
    reference_group = cluster_lookup[cluster_id]
    title = str(reference_group["title"])
    tokens = reference_group["tokens"]
    
    print(f"{cluster_id:<5}   {count:<5}   {title:<30}   {', '.join([str(v) for v in tokens[0:10]])}")
    


58      465     None                             , !, ", ', (, *, +, ,, -, /
193     214     prepositions_function_words      of, in, to, for, on, with, by, at, from, In
95      145     None                             is, be, are, can, was, have, will, has, do, were
127     95      None                             , ), ., ., ?, )., ),, :, .”, )
131     84      None                             0, 1, 2, 3, 4, 5, 6, 7, 8, 9
11      74      None                             man, work, time, act, years, form, way, effect, count, sign
27      42      None                             ation, ulation, development, ization, addition, treatment, olution, ruction, growth, production
137     40      None                             important, interest, impact, success, results, future, ways, changes, symptoms, types
220     35      None                             dam, road, traffic, rail, concrete, mining, roads, grid, pit, mine
155     34      None                             import, $, pay, mone

In [123]:
from collections import Counter

cluster_count = Counter()
prev_pair_count = {c: Counter() for c in cluster_lookup.keys()}
next_pair_count = {c: Counter() for c in cluster_lookup.keys()}

total = 0

for text in iter_parquet_text("../../../../data/fineweb-edu-100b-shuffle/*.parquet"):
    ids = tokenizer.encode_ordinary(text)
    clusters = [token_cluster_map[i] for i in ids]

    for j, c in enumerate(clusters):
        cluster_count[c] += 1

        if j > 0:
            prev_pair_count[c][clusters[j - 1]] += 1
        if j < len(clusters) - 1:
            next_pair_count[c][clusters[j + 1]] += 1

        total += 1
        if total >= 1_000_000:
            break
    if total >= 1_000_000:
        break


In [125]:
for cluster_id, count in Counter(token_cluster_counts).most_common():
    reference_group = cluster_lookup[cluster_id]
    if reference_group["probable_reference"]:
        title = str(reference_group["title"])
        if not "name" in title:
            continue
        tokens = reference_group["tokens"]
        
        print(f"{cluster_id:<5}   {count:<10}   {title:<30}   {', '.join([str(v) for v in tokens[0:10]])}")
        for prev_cluster_id, prev_count in prev_pair_count[cluster_id].most_common(3):
            prev_reference_group = cluster_lookup[prev_cluster_id]
            prev_title = str(prev_reference_group["title"])
            prev_tokens = prev_reference_group["tokens"]
            print(f"   {prev_cluster_id:<5}   {prev_count:<10}   {prev_title:<30}   {', '.join([str(v) for v in prev_tokens[0:10]])}")
        print()
        for next_cluster_id, next_count in next_pair_count[cluster_id].most_common(3):
            next_reference_group = cluster_lookup[next_cluster_id]
            next_title = str(next_reference_group["title"])
            next_tokens = next_reference_group["tokens"]
            print(f"   {next_cluster_id:<5}   {next_count:<10}   {next_title:<30}   {', '.join([str(v) for v in next_tokens[0:10]])}")
        print("-"*100)

201     4482         name_prefixes_particles          �, �, �, �, de, le, cl, Th, St, qu
   58      1469         None                             , !, ", ', (, *, +, ,, -, /
   127     810          None                             , ), ., ., ?, )., ),, :, .”, )
   193     529          prepositions_function_words      of, in, to, for, on, with, by, at, from, In

   174     900          None                             U, X, `, j, q, u, v, w, �, �
   114     529          None                             E, Y, a, b, c, e, f, g, h, i
   187     286          None                             ople, ween, iron, omen, reen, gin, ledge, chie, oura, ison
----------------------------------------------------------------------------------------------------
4       2113         surnames_family_names            ohn, ley, augh, erson, son, berg, ston, Smith, inson, stein
   58      547          None                             , !, ", ', (, *, +, ,, -, /
   127     251          None                    