In [2]:
from typing import List
from collections import defaultdict
import os
import pickle

import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import scipy.linalg
import torch
import torch.nn.functional as F
import sklearn.cluster

import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, GPTNeoXForCausalLM

In [3]:
# ----- load the_pile test set -----
# (a directory previously written with `Dataset.save_to_disk`, read back via `load_from_disk`)
pile_canonical = "/om/user/ericjm/the_pile/the_pile_test_canonical_200k"
dataset = datasets.load_from_disk(pile_canonical)

def tokenize_sample(sample, tokenizer=None, max_length=1024):
    """Tokenize ``sample["text"]`` and wrap the ids as ``{"input_ids": ...}``.

    Args:
        sample: a dataset row with a ``"text"`` field.
        tokenizer: a Hugging Face tokenizer. When omitted, falls back to a
            notebook-global ``tokenizer``. No visible cell creates one, so either
            pass it explicitly or define it (e.g. via ``AutoTokenizer``) first.
        max_length: truncate to at most this many tokens (previously fixed at 1024).

    Returns:
        ``{"input_ids": ids}``, where ``ids`` is the tokenizer's ``input_ids``
        output requested with ``return_tensors='pt'``.

    Raises:
        NameError: if no tokenizer is passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "tokenize_sample needs a tokenizer: pass one or define a global `tokenizer`"
            )
    tokens = tokenizer(sample["text"], return_tensors='pt',
                       max_length=max_length, truncation=True)["input_ids"]
    return {"input_ids": tokens}

# Flat-index offsets: sample k's entries occupy flat loss indices
# [starting_indexes[k], starting_indexes[k+1]); the last element is the total.
# `preds_len` is presumably the number of predictions per sample -- confirm.
# Built directly with np.concatenate instead of round-tripping the whole cumsum
# through a Python list; explicit int64 keeps the result integral even if empty.
starting_indexes = np.concatenate(([0], np.cumsum(dataset["preds_len"], dtype=np.int64)))

def loss_idx_to_dataset_idx(idx, starts=None):
    """Map a flat loss index to ``(sample_index, pred_in_sample_index)``.

    Flat loss indices enumerate every prediction sample by sample. ``starts``
    (default: the global ``starting_indexes``) holds each sample's first flat
    index followed by the total count, so valid indices are
    ``range(0, starts[-1])``. The token-in-sample index of the prediction is
    exactly ``pred_in_sample_index + 1``.

    NOTE(review): the old docstring quoted 10658635 losses / 20000 samples,
    which looks stale for the dataset loaded above (a 106M index is used
    elsewhere in this notebook).

    Raises:
        IndexError: if ``idx`` is negative or ``>= starts[-1]``. Such indices
            previously wrapped silently to the last sample or ran past the end.
    """
    if starts is None:
        starts = starting_indexes
    total = starts[-1]
    if not 0 <= idx < total:
        raise IndexError(f"loss index {idx} out of range [0, {total})")
    # side="right" selects the LAST sample whose start <= idx, which skips
    # samples contributing zero predictions (they share a start with the next).
    sample_index = np.searchsorted(starts, idx, side="right") - 1
    pred_in_sample_index = idx - starts[sample_index]
    return int(sample_index), int(pred_in_sample_index)

def get_context(idx):
    """Return ``(sample, token_idx)`` for flat loss index ``idx``.

    ``sample`` is the dataset row holding the prediction. ``token_idx`` is the
    position of the predicted token inside that row: one past the
    pred-in-sample index, so it is >= 1 for in-range ``idx``.
    """
    sample_index, pred_index = loss_idx_to_dataset_idx(idx)
    token_idx = pred_index + 1  # prediction p is made for token p + 1
    return dataset[sample_index], token_idx

def print_context(idx, context_length=-1):
    """Print the prompt leading up to the prediction at flat loss index ``idx``.

    The predicted token is appended on a red background (ANSI code 41).
    A positive ``context_length`` keeps only that many trailing prompt tokens;
    any other value prints the whole prompt.
    """
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    prompt_tokens = pieces[:token_idx]
    if context_length > 0:
        prompt_tokens = prompt_tokens[-context_length:]
    highlighted = "\033[41m" + pieces[token_idx] + "\033[0m"
    print("".join(prompt_tokens) + highlighted)



In [4]:
# print_context(106_396_003)

In [5]:
# Clustering results. pickle can execute code on load -- only unpickle trusted files.
with open(
    "/om/user/ericjm/results/the-everything-machine/clustering-0/clusters_full_more.pkl",
    'rb',
) as f:
    clusters = pickle.load(f)

In [6]:
# idxs: loss indices into The Pile test set that were clustered.
# C, C_abs are not used in any visible cell (presumably similarity matrices -- confirm).
idxs, C, C_abs = torch.load(
    "/om/user/ericjm/results/the-everything-machine/clustering-0/full_more.pt",
    map_location="cpu",
)

In [7]:
len(idxs) # number of loss indices into The Pile test set that were clustered

10000

In [8]:
# `clusters[400]` unpacks to (per-example labels, <unused>) -- presumably the
# 400-cluster solution; verify against the clustering script.
clusters_labels, _ = clusters[400]

# label -> positions (into `idxs`) of that cluster's members, in original order
indices = defaultdict(list)
for position, label in enumerate(clusters_labels):
    indices[label].append(position)

label_frequencies = defaultdict(int, ((label, len(members)) for label, members in indices.items()))

# Largest cluster first; ties keep first-appearance order because sorted() is stable.
labels_sorted_by_freq = sorted(indices, key=lambda label: len(indices[label]), reverse=True)

# Every position, regrouped cluster by cluster (largest cluster first).
permutation = [position for label in labels_sorted_by_freq for position in indices[label]]

In [100]:
def context_tokens(idx, context_length=100):
    """Return prompt tokens plus the predicted token for flat loss index ``idx``.

    The result is the last ``context_length`` prompt tokens followed by the
    predicted token, so the final element is always the target. A non-positive
    ``context_length`` keeps the whole prompt, matching ``print_context``.
    Previously 0 did so only by accident (``prompt[-0:]``), and negative
    values silently dropped tokens from the *start* of the prompt.
    """
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    prompt = pieces[:token_idx]
    if context_length > 0:
        prompt = prompt[-context_length:]
    return prompt + [pieces[token_idx]]

def contains_unicode(text: str) -> bool:
    """Return True if ``text`` has any non-ASCII character (code point above 127)."""
    return not text.isascii()

# Single-pass escape table for tokens_to_latex: each character maps to its LaTeX
# rendering, so replacement text (which itself contains braces) is never
# re-escaped. Backslash, &, ^ and ~ used to pass through unescaped, which breaks
# compilation (or, for ~, silently renders a space).
_TOKEN_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "_": r"\_",
    "#": r"\#",
    "$": r"\$",
    "%": r"\%",
    "&": r"\&",
    "{": r"\{",
    "}": r"\}",
    "^": r"\textasciicircum{}",
    "~": r"\textasciitilde{}",
    "\n": r"{\textbackslash}n",  # visible marker standing in for a newline
    "\t": r"\phantom{aaaa}",     # invisible text standing in for a tab
}


def _token_to_latex_text(token: str) -> str:
    """Return the LaTeX text placed after the strut for one token."""
    if all(c == " " for c in token):
        # LaTeX collapses runs of spaces; use invisible text of matching width instead
        return r"\phantom{" + "a" * len(token) + "}"
    return "".join(_TOKEN_LATEX_ESCAPES.get(c, c) for c in token)


def tokens_to_latex(tokens: List[str], highlight_index=-1, highlight_color="lightred") -> str:
    r"""Render ``tokens`` as a sequence of ``\tok[color]{{\strut}...}`` boxes.

    Args:
        tokens: token strings, e.g. from ``context_tokens``.
        highlight_index: position of the token drawn in ``highlight_color``. It is
            taken modulo ``len(tokens)``, so the default -1 highlights the last token.
        highlight_color: colour name for the highlighted box; all other boxes are
            "white". NOTE(review): "lightred" is not a predefined xcolor name --
            presumably the including document defines it.

    Returns:
        LaTeX source (``\tok`` must be defined by the including document).
        ``\allowbreak`` follows every token, and ``\\`` follows every token that
        contains a newline. Previously only an exact "\n" token got the line
        break; tokens such as "\n\n" emitted raw newlines inside the macro argument.
    """
    if not tokens:
        return ""
    highlight = highlight_index % len(tokens)
    parts = []
    for i, token in enumerate(tokens):
        color = highlight_color if i == highlight else "white"
        parts.append(r"\tok[" + color + r"]{{\strut}" + _token_to_latex_text(token) + "}")
        parts.append(r"\allowbreak ")  # allow line breaks between tokens
        if "\n" in token:
            parts.append(r"\\")
    return "".join(parts)

## Cluster 50

In [141]:
# choose diverse samples from each cluster: position within cluster -> context length.
# (position 7, context 100, was dropped to save page space)
context_lengths = {0: 99, 3: 100, 4: 102, 6: 100, 20: 71}
i_choices = list(context_lengths)  # same order as the dict, so the two never drift apart

In [142]:
cluster_number = 50  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[20], context_length=71)  # preview one chosen example

	QCBlockListMsg          = 0x0a
	GetLatestStatusMsg      = 0x0b
	LatestStatusMsg         = 0x0c
	PrepareBlockHashMsg     = 0x0d
	GetViewChangeMsg        = 0x0e
	PingMsg                 = 0x0[41mf[0m


In [143]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+). join() also
# places separators by position, so a repeated index no longer drops one.
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_lengths[i]))
    for i in i_choices
)

In [144]:
# Explicit UTF-8: exported tokens can be non-ASCII, and open()'s default encoding
# depends on the platform/locale.
with open("../texts/pile/cluster50.tex", 'w', encoding="utf-8") as f:
    f.write(text)

## Cluster 100

In [106]:
# choose diverse samples from each cluster: position within cluster -> context length.
# Dropped for width: 2 (78), 9 (79), 15 (73 -- really constrained by width here!)
context_lengths = {13: 97, 1: 73, 12: 75, 3: 46, 8: 77}
i_choices = list(context_lengths)  # same order as the dict, so the two never drift apart

In [107]:
cluster_number = 100  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[12], context_length=75)  # preview one chosen example

<!--
/**
 * Copyright (c) 2019, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.[41m
[0m


In [108]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+). join() also
# places separators by position, so a repeated index no longer drops one.
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_lengths[i]))
    for i in i_choices
)

In [109]:
# Explicit UTF-8: exported tokens can be non-ASCII, and open()'s default encoding
# depends on the platform/locale.
with open("../texts/pile/cluster100.tex", 'w', encoding="utf-8") as f:
    f.write(text)

## Cluster 146

In [110]:
# choose diverse samples from each cluster: position within cluster -> context length
context_lengths = {3: 75, 5: 75, 9: 75, 12: 100}
i_choices = list(context_lengths)  # same order as the dict, so the two never drift apart

In [111]:
cluster_number = 146  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[12], context_length=100)  # preview one chosen example

 United States Patent No. 6,073,124 (issued June 6, 2000) ("the '124 patent"). Microsoft in turn asserted counterclaims against NCI for infringement of three of its patentsUnited States Patent Nos. 5,822,526, 5,999,914 and 5,794,006. Only terms of the '124 patent are presently before the Court; interpretation of claims in Microsoft's patents will be interpreted in a separate Markman hearing to be held on November 15[41m,[0m


In [112]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+). join() also
# places separators by position, so a repeated index no longer drops one.
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_lengths[i]))
    for i in i_choices
)

In [113]:
# Explicit UTF-8: exported tokens can be non-ASCII, and open()'s default encoding
# depends on the platform/locale.
with open("../texts/pile/cluster146.tex", 'w', encoding="utf-8") as f:
    f.write(text)

## Cluster 269

In [137]:
# choose diverse samples from each cluster: {position within cluster: number of context tokens}
i_choices = {0: 100, 1: 40, 3: 96, 5: 16, 6: 24, 7: 38}

In [138]:
cluster_number = 269  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[7], context_length=38)  # preview one chosen example

 In 1954, the couple published Living the Good Life which inspired many young, educated Americans to create simpler, rural lifestyles and the back-to-the-land movement of the 1960[41ms[0m


In [139]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+).
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_length))
    for i, context_length in i_choices.items()
)

In [140]:
# Explicit UTF-8: exported tokens can be non-ASCII, and open()'s default encoding
# depends on the platform/locale.
with open("../texts/pile/cluster269.tex", 'w', encoding="utf-8") as f:
    f.write(text)

## Cluster 278

In [118]:
# choose diverse samples from each cluster: {position within cluster: number of context tokens}
i_choices = {2: 100, 3: 100, 5: 29}

In [119]:
cluster_number = 278  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[5], context_length=29)  # preview one chosen example

See: http://jsfiddle.net/mWFGZ/1/
html, body {
    margin: 0;
    padding[41m:[0m


In [120]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+).
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_length))
    for i, context_length in i_choices.items()
)

In [121]:
# Explicit UTF-8: exported tokens can be non-ASCII, and open()'s default encoding
# depends on the platform/locale.
with open("../texts/pile/cluster278.tex", 'w', encoding="utf-8") as f:
    f.write(text)

## Cluster 292

In [122]:
# choose diverse samples from each cluster: {position within cluster: number of context tokens}
i_choices = {0: 26, 1: 47, 2: 100, 5: 80}

In [123]:
cluster_number = 292  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[5], context_length=80)  # preview one chosen example

But as citizens, our responsibility is to look beyond the anecdote. We journalists try to make sure that if a tree falls in the forest, it won’t go unnoticed. Still, if we get so taken by the trees that we don’t see the forest, we’ll all be lost.

Rex Smith is editor of the Times Union. Share your thoughts at http[41m://[0m


In [124]:
# Render each chosen example as LaTeX tokens and put a thin gray rule between
# consecutive examples. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+).
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], context_length))
    for i, context_length in i_choices.items()
)

In [125]:
# Explicit UTF-8: exported tokens can be non-ASCII (e.g. curly quotes in this
# cluster), and open()'s default encoding depends on the platform/locale.
with open("../texts/pile/cluster292.tex", 'w', encoding="utf-8") as f:
    f.write(text)

In [62]:
"\N{THIN SPACE}"  # scratch check: the U+2009 thin-space character

'\u2009'

In [122]:
cluster_number = 100  # rank by size: 0 is the largest cluster
# `indices` gives member positions in `idxs`; translate them into Pile loss indices
cluster_idxs_is = indices[labels_sorted_by_freq[cluster_number]]
cluster_idxs = [idxs[pos] for pos in cluster_idxs_is]

print_context(cluster_idxs[8])  # whole prompt, no context truncation

******************************************************
  The ‘‘officially released’’ date that appears near the
beginning of each opinion is the date the opinion will
be published in the Connecticut Law Journal or the
date it was released as a slip opinion. The operative
date for the beginning of all time periods for filing
postopinion motions and petitions for certification is
the ‘‘officially released’’ date appearing in the opinion.
In no event will any such motions be accepted before
the ‘‘officially released’’ date.
  All opinions are subject to modification and technical
correction prior to official publication in the Connecti-
cut Reports and Connecticut Appellate Reports. In the
event of discrepancies between the electronic version
of an opinion and the print version appearing in the
Connecticut Law Journal and subsequently in the Con-
necticut Reports or Connecticut Appellate Reports, the
latest print version is to be considered authoritative.
  The syllabus and procedural his

In [123]:
# Cluster 100 again: the first 6 members with 50 context tokens each, separated by
# a thin gray rule. The rule is a raw string: the old literal relied on the
# invalid escapes "\c" and "\l" (SyntaxWarning on Python 3.12+).
separator = "\n\n" + r"{\color{gray}\rule{0.99\linewidth}{0.5pt}}" + "\n\n"
text = separator.join(
    tokens_to_latex(context_tokens(cluster_idxs[i], 50)) for i in range(6)
)

In [124]:
# NOTE(review): if this notebook runs from a subdirectory of
# /om2/user/ericjm/the-everything-machine/, this is the same file as
# "../texts/pile/cluster100.tex" written above. A Restart & Run All would then
# replace the curated cluster-100 export with this 6-example version -- confirm
# which one the paper should use, and rename or remove the other.
# Explicit UTF-8: exported tokens can be non-ASCII.
with open("/om2/user/ericjm/the-everything-machine/texts/pile/cluster100.tex", 'w', encoding="utf-8") as f:
    f.write(text)

In [None]:
from pathlib import Path
def context_str(idx):
    """Plain-text rendering of the prediction at flat loss index ``idx``.

    The prompt is cut to its last 300 characters (prefixed with "...") when it
    is longer than that. The predicted token is appended wrapped as ``<|token|>``.
    """
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    prompt = "".join(pieces[:token_idx])
    if len(prompt) > 300:
        prompt = f"...{prompt[-300:]}"
    return f"{prompt}<|{pieces[token_idx]}|>"

# Dump every cluster, largest first, into its own subdirectory named by its
# zero-padded size rank:
#   <clustersdir>/<rank>/prompts.txt       -- context of every member
#   <clustersdir>/<rank>/trajectories.png  -- one curve per member, log y-axis
# NOTE(review): the directory is named "00500" but the labels above come from
# clusters[400] -- confirm which clustering this dump is meant to hold.
clustersdir = Path("/om/user/ericjm/results/the-everything-machine/clustering-0/full_more/00500")
clustersdir.mkdir(parents=True, exist_ok=True)  # reuse if present; files inside get overwritten

# NOTE(review): `timeseries19m` is not defined in any visible cell (hidden kernel
# state) -- presumably per-loss-index curves sampled at the x values below; confirm.
trajectory_x = list(range(1000, 144000, 1000))

for i, label in tqdm(list(enumerate(labels_sorted_by_freq))):
    clusterdir = clustersdir / f"{i:05d}"
    clusterdir.mkdir(exist_ok=True)
    member_idxs = [idxs[idx_i] for idx_i in indices[label]]

    # Explicit UTF-8: Pile contexts routinely contain non-ASCII text.
    with open(clusterdir / "prompts.txt", "w", encoding="utf-8") as f:
        for idx in member_idxs:
            f.write(context_str(idx))
            f.write("\n" + "-" * 40 + "\n")

    fig, ax = plt.subplots()
    for idx in member_idxs:
        ax.plot(trajectory_x, timeseries19m[idx], color='black', alpha=0.1)
    ax.set_yscale('log')
    fig.savefig(clusterdir / "trajectories.png")
    plt.close(fig)  # avoid accumulating hundreds of open figures