In [1]:
from collections import defaultdict
import pathlib
from pathlib import Path
import os
import sys
import argparse

from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import sklearn.cluster

from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets

In [2]:
# Tokenizer of the GPT-Neo 125M checkpoint; downloaded files are cached under data/.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neo-125M",
    cache_dir="data/",
)
# Dataset previously saved to disk. The cells below read its "preds_len" and
# "split_by_token" columns.
dataset = datasets.load_from_disk("data/tinystories_tokenized")



In [3]:
# Offsets into the flat prediction index: entry k is the running total of
# dataset["preds_len"] over samples 0..k-1, so sample k owns the half-open range
# [starting_indexes[k], starting_indexes[k + 1]).
starting_indexes = np.array([0, *np.cumsum(dataset["preds_len"])])
def loss_idx_to_dataset_idx(idx):
    """Translate a flat prediction index into (sample index, pred-in-sample index).

    For the dataset this notebook was written against, idx lies in
    range(0, 10658635), the sample index in range(0, 20000) and the
    pred-in-sample index in range(0, 1023). The token-in-sample index is
    always the pred-in-sample index plus one.
    """
    # side="right" places an idx equal to a sample's first offset inside that
    # sample (and past any zero-length samples sharing the same offset).
    insertion_point = np.searchsorted(starting_indexes, idx, side="right")
    sample_index = insertion_point - 1
    offset_in_sample = idx - starting_indexes[sample_index]
    return int(sample_index), int(offset_in_sample)

def get_context(idx):
    """Fetch the dataset sample that a flat prediction index belongs to.

    Returns a pair (sample, token index), where the token index is the
    position of the predicted token inside the sample: the pred-in-sample
    index shifted by one (range(1, 1024) for the dataset this notebook was
    written against, where idx is in range(0, 10658635)).
    """
    sample_pos, pred_pos = loss_idx_to_dataset_idx(idx)
    sample = dataset[sample_pos]
    # Prediction p targets token p + 1.
    token_pos = pred_pos + 1
    return sample, token_pos

def print_context(idx):
    """
    Print the text leading up to the prediction at flat index idx
    (range(0, 10658635) for the original dataset), followed by the predicted
    token rendered on a red background via ANSI escape codes.
    """
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    preceding_text = "".join(pieces[:token_idx])
    predicted = pieces[token_idx]
    # "\033[41m" switches the background to red; "\033[0m" resets the styling.
    print(preceding_text + "\033[41m" + predicted + "\033[0m")


In [4]:
# --- Input ---
# torch-saved file holding the (token_idxs, C) pair that the next cell loads.
matrix_path = "data/C-2.pt"

# --- Spectral clustering settings (forwarded to sklearn's SpectralClustering) ---
num_clusters = 400   # n_clusters; also part of the output file name
n_init = 30          # n_init for the k-means label assignment
eigen_tol = "auto"   # eigen_tol handed to the estimator unchanged

# --- Reproducibility ---
# Seeds NumPy's global RNG and is passed as the estimator's random_state.
random_state = 0

In [5]:
# Load the (token_idxs, C) pair stored at matrix_path.
# map_location="cpu" lets the file load on machines that lack the device the
# tensors were saved from; C is converted to a NumPy array right away, so
# nothing here needs to live on an accelerator.
# NOTE(review): torch.load unpickles the file -- only load files you trust.
token_idxs, C = torch.load(matrix_path, map_location="cpu")
C = C.cpu().numpy()
# Clamp into arccos' domain [-1, 1] so values slightly outside it do not turn into NaN.
# presumably C holds cosine similarities -- TODO confirm against the cell that produced it
C = np.clip(C, -1, 1)
# Angular similarity: 1 - angle / pi maps [-1, 1] onto [0, 1] (1 for identical
# direction, 0 for opposite). The result is used below as a precomputed affinity.
C = 1 - np.arccos(C) / np.pi

In [6]:
# Seed NumPy's global RNG in addition to handing random_state to the estimator.
np.random.seed(random_state)

# C is supplied as a precomputed affinity matrix; cluster labels are assigned
# by k-means with n_init restarts. fit(C).labels_ is exactly what
# fit_predict(C) returns.
clusters_labels = sklearn.cluster.SpectralClustering(
    n_clusters=num_clusters,
    affinity="precomputed",
    assign_labels="kmeans",
    n_init=n_init,
    eigen_tol=eigen_tol,
    random_state=random_state,
).fit(C).labels_

In [7]:
# Group the row positions of C by the cluster label they were assigned.
indices = defaultdict(list)
for position, cluster in enumerate(clusters_labels):
    indices[cluster].append(position)

# A cluster's frequency is the size of its group. Keys are inserted in order of
# first appearance in clusters_labels, which fixes how ties are ordered below.
label_frequencies = defaultdict(int)
for cluster, members in indices.items():
    label_frequencies[cluster] = len(members)

# Largest cluster first; the sort is stable, so equally sized clusters keep
# their first-appearance order.
labels_sorted_by_freq = sorted(label_frequencies, key=label_frequencies.get, reverse=True)

# Row ordering that places members of the same cluster next to each other,
# clusters in descending size.
permutation = [
    position
    for cluster in labels_sorted_by_freq
    for position in indices[cluster]
]

# For every cluster (keyed by its size rank, 0 = largest) collect one
# (tokens of the sample, index of the predicted token) pair per member.
clusters_data = defaultdict(list)
for rank, cluster in tqdm(list(enumerate(labels_sorted_by_freq)), desc="Finding contexts"):
    for position in indices[cluster]:
        sample, token_pos = get_context(token_idxs[position])
        clusters_data[rank].append((sample["split_by_token"], token_pos))

# Persist both the per-cluster contexts and the raw label array.
torch.save((clusters_data, clusters_labels), f"data/{num_clusters}clusters-2.pt")

Finding contexts:   0%|          | 0/400 [00:00<?, ?it/s]