In [1]:
from collections import defaultdict
import pathlib
import os
import sys
import argparse

import numpy as np
# from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# from evaluate_pile_losses import evaluate_pile_losses
# from evaluate_pile_induction_criterias import evaluate_pile_induction_criterias

# import scipy.linalg
import torch
import torch.nn.functional as F
# import sklearn.cluster

import datasets
from transformers import AutoTokenizer, GPTNeoXForCausalLM

In [2]:
# Absolute locations of the model cache and the canonical Pile test split (adjust for your machine).
cache_dir = "/om/user/ericjm/quanta-discovery/cache/"
pile_canonical = "/om/user/ericjm/the_pile/the_pile_test_canonical_200k"
model_name = "pythia-70m-v0"
step = 143000  # training checkpoint to load
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
# ----- load model and tokenizer -----
assert "pythia" in model_name, "must be a Pythia model"
# The model and tokenizer are fetched from the same repo/revision and share one cache dir.
hub_repo = f"EleutherAI/{model_name}"
hub_kwargs = {
    "revision": f"step{step}",
    "cache_dir": os.path.join(cache_dir, model_name, f"step{step}"),
}
model = GPTNeoXForCausalLM.from_pretrained(hub_repo, **hub_kwargs).to(device)
tokenizer = AutoTokenizer.from_pretrained(hub_repo, **hub_kwargs)

In [4]:
# ----- load the_pile test set -----
dataset = datasets.load_from_disk(pile_canonical)

def tokenize_sample(sample):
    """Tokenize sample["text"] (truncated to 1024 tokens) and return {"input_ids": <tensor>}."""
    encoded = tokenizer(sample["text"], return_tensors='pt',
                        max_length=1024, truncation=True)
    return {"input_ids": encoded["input_ids"]}

# starting_indexes[k] is the flat loss index of sample k's first prediction; the final entry
# is the total number of predictions (one past the last valid flat index).
starting_indexes = np.array([0, *np.cumsum(dataset["preds_len"])])

def loss_idx_to_dataset_idx(idx, offsets=None):
    """Map a flat loss index to (sample index, pred-in-sample index).

    Per-prediction losses of all samples are stored back to back. `offsets[k]` is the flat
    index of sample k's first prediction, and `offsets[-1]` is the total number of predictions.
    For the cached run that is 10658635 predictions over 20000 samples, with pred-in-sample in
    range(0, 1023). Note that the token-in-sample idx is exactly pred-in-sample + 1.

    Args:
        idx (int): flat loss index, 0 <= idx < offsets[-1].
        offsets (array-like of int, optional): cumulative prediction offsets. Defaults to the
            module-level `starting_indexes`.

    Returns:
        tuple[int, int]: (sample_index, pred_in_sample_index).

    Raises:
        IndexError: if idx is outside [0, offsets[-1]).
    """
    offsets = starting_indexes if offsets is None else np.asarray(offsets)
    total = offsets[-1]
    if not 0 <= idx < total:
        # Without this check a negative idx wraps to the last sample, and idx >= total maps
        # to a sample index one past the end.
        raise IndexError(f"loss index {idx} out of range [0, {total})")
    # side="right" skips samples with zero predictions (their offset equals the next sample's).
    sample_index = np.searchsorted(offsets, idx, side="right") - 1
    pred_in_sample_index = idx - offsets[sample_index]
    return int(sample_index), int(pred_in_sample_index)

def get_context(idx):
    """Look up the dataset sample containing flat loss index `idx`.

    Returns:
        (sample, token_index): the dataset row, and the index of the predicted token within
        that sample (pred-in-sample + 1, so in range(1, 1024)).
    """
    sample_index, pred_index = loss_idx_to_dataset_idx(idx)
    token_index = pred_index + 1
    return dataset[sample_index], token_index

def print_context(idx):
    """Print the text preceding the prediction at flat loss index `idx`, followed by the
    predicted token highlighted with a red background (ANSI escape codes)."""
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    preceding = "".join(pieces[:token_idx])
    print(preceding + "\033[41m" + pieces[token_idx] + "\033[0m")




In [5]:
# Cached per-token losses and present-trigram filter for this model checkpoint. The directory
# is derived from the config cell (cache_dir / model_name / step), so it follows changes there.
_run_dir = os.path.join(cache_dir, model_name, f"step{step}")
losses = torch.load(os.path.join(_run_dir, "20000_docs_10658635_tokens_losses.pt"))
# NOTE: `filter` shadows the Python builtin of the same name; kept because the next cell uses it.
filter = torch.load(os.path.join(_run_dir, "20000_docs_10658635_tokens_present_trigram_filter.pt"))

In [6]:
# Candidate tokens to compute gradients for: relatively low loss (< 0.2) and False in `filter`.
# Judging by its file name, `filter` presumably marks tokens predictable from a trigram already
# present in context.
is_candidate = (losses < 0.2) & ~filter
low_loss_nontrigram_token_idxs = is_candidate.nonzero().flatten().tolist()

In [7]:
# Instead, use the tokens from a couple of precomputed clusters. That way we know which tokens
# should look similar (same cluster) and which dissimilar (different clusters).
token_idxs, C = torch.load("../results/paper-replication/pythia-70m-v0_143000_0.14426950408889636_50_10000_v1.pt")
_, cluster_labels = torch.load("../results/paper-replication/400_auto_pythia-70m-v0_143000_0.14426950408889636_50_10000_v1.pt")

# indices[label] -> positions (rows/cols of C) of that cluster's tokens, in original order.
indices = defaultdict(list)
for position, label in enumerate(cluster_labels):
    indices[label].append(position)

# Cluster sizes, keyed by label (keys in first-appearance order).
label_frequencies = defaultdict(int)
for label, members in indices.items():
    label_frequencies[label] += len(members)

# Labels from largest to smallest cluster; ties keep first-appearance order (stable sort).
labels_sorted_by_freq = sorted(label_frequencies, key=label_frequencies.__getitem__, reverse=True)

# Ordering of all positions that groups tokens cluster by cluster, largest cluster first.
permutation = [position for label in labels_sorted_by_freq for position in indices[label]]

In [None]:
# Tokens of the clusters ranked 200 and 201 by size (0-based), and their block of C.
first_label, second_label = labels_sorted_by_freq[200], labels_sorted_by_freq[201]
sim_idxs = indices[first_label] + indices[second_label]
C_part = C[sim_idxs][:, sim_idxs]
plt.imshow(C_part, cmap="CMRmap", vmin=-1, vmax=1)

In [9]:
subset_token_idxs = list(map(token_idxs.__getitem__, sim_idxs))  # flat loss indices of the selected tokens

In [10]:
def get_flattened_gradient(model, param_subset):
    """Concatenate the `.grad` of selected parameters of `model` into one flat vector.

    Parameters are visited in `model.named_parameters()` order, not `param_subset` order.
    The layout of the result is therefore identical across calls with the same model and subset.

    Args:
        model (torch.nn.Module): model whose gradients were populated by `.backward()`.
        param_subset (Iterable[str]): names (as in `named_parameters()`) of parameters to include.

    Returns:
        torch.Tensor: 1-D tensor whose length is the total element count of the selected
        parameters. A selected parameter whose `.grad` is None contributes zeros, so the length
        is always the same. This happens when it took no part in the loss, or when grads were
        zeroed with set_to_none.

    Raises:
        ValueError: if no parameter of `model` is named in `param_subset`.
    """
    wanted = set(param_subset)  # O(1) membership instead of scanning a list per parameter
    grads = [
        (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
        for name, p in model.named_parameters()
        if name in wanted
    ]
    if not grads:
        raise ValueError("param_subset does not match any parameter of the model")
    return torch.cat(grads)
param_names = [n for n, _ in model.named_parameters()]

# Parameters whose gradients enter the similarity computation: everything except layernorm
# and embedding parameters, filtered by substring match on the parameter name.
# NOTE(review): in HF GPTNeoX the final norm is presumably named "final_layer_norm", which does
# not contain "layernorm" and so is NOT excluded here -- confirm whether that is intended.
highsignal_names = [
    name for name in param_names
    if 'layernorm' not in name and 'embed' not in name
]

# Total length of the flattened gradient over `highsignal_names`. Build the name -> parameter
# map once rather than materializing model.state_dict() for every name.
named_params = dict(model.named_parameters())
len_g = sum(named_params[name].numel() for name in highsignal_names)
S = len(subset_token_idxs)

In [11]:
# Gradient of each selected token's loss w.r.t. the high-signal parameters, one row per token.
# Rows are L2-normalized afterwards, so C holds pairwise cosine similarities.
Gs = torch.zeros((S, len_g), device=device)
for i, idx in enumerate(subset_token_idxs):
    model.zero_grad()
    sample, target_pos = get_context(idx)
    encoded = tokenizer(sample['text'], return_tensors='pt', max_length=1024, truncation=True).to(device)
    logits = model(**encoded).logits
    # per_token_loss[k] is the loss of predicting token k + 1, so the target token is at k = target_pos - 1
    per_token_loss = F.cross_entropy(logits[0, :-1, :], encoded.input_ids[0, 1:], reduction='none')
    per_token_loss[target_pos - 1].backward()
    Gs[i] = get_flattened_gradient(model, highsignal_names)

with torch.no_grad():
    Gs = F.normalize(Gs, p=2, dim=1)
    C = Gs @ Gs.T
del Gs

In [None]:
C_np = C.detach().to("cpu").numpy()  # full-gradient cosine similarities as a host array
plt.imshow(C_np, cmap='CMRmap', vmin=-1, vmax=1)

##### Now let's try random projections of varying dimension and sparsity

In [13]:
class SparseProjectionOperator:
    """Random sparse projection R: R^original_dim -> R^projection_dim with entries in {+1, 0, -1}.

    `sparsity` is the probability that an entry of R is zero; nonzero entries are +1 or -1 with
    equal probability. R is never materialized. Instead, each row stores the column indices of
    its +1 entries (`positives`) and of its -1 entries (`negatives`).

    The number of +1 and of -1 entries per row are drawn independently from
    Poisson(lambda_ / 2), with lambda_ = (1 - sparsity) * original_dim. This is Poisson
    thinning of a Poisson(lambda_) total, so the expected number of nonzeros per row is
    lambda_. Column indices are sampled with replacement, so repeats are possible; they are
    rare when R is sparse.

    Rows are right-padded to a common length with the index `original_dim`. `__call__` maps that
    index to an appended zero, so padding contributes exactly nothing.

    Memory: two int64 tensors of shape (projection_dim, max_row_count) are held, and each call
    gathers a float tensor of the same size (per batch element). Low sparsity combined with large
    projection_dim and original_dim can therefore exhaust device memory.

    Args:
        original_dim (int): input dimension.
        projection_dim (int): output dimension.
        sparsity (float): probability that an entry of R is zero.
        seed (int): seed applied to torch's global RNG (this mutates global RNG state).
        device (str or torch.device): device of the index tensors; inputs must be on it too.
    """
    def __init__(self, original_dim, projection_dim, sparsity, seed=0, device='cpu'):
        self.device = torch.device(device)  # accepts both strings and torch.device objects
        torch.manual_seed(seed)
        if self.device.type == 'cuda':
            torch.cuda.manual_seed(seed)
        self.original_dim = original_dim
        # expected number of nonzeros per row (+1 and -1 entries combined)
        self.lambda_ = original_dim * (1 - sparsity)
        self.positives = self._sample_row_indices(projection_dim, self.lambda_ / 2)
        self.negatives = self._sample_row_indices(projection_dim, self.lambda_ / 2)

    def _sample_row_indices(self, projection_dim, rate):
        """Return a (projection_dim, max_count) int64 tensor. Row r holds Poisson(rate) column
        indices followed by padding with the value `original_dim`."""
        counts = torch.poisson(torch.full((projection_dim,), float(rate), device=self.device)).long()
        max_count = int(counts.max()) if projection_dim > 0 else 0
        idx = torch.randint(0, self.original_dim, (projection_dim, max_count), device=self.device)
        valid = torch.arange(max_count, device=self.device).expand(projection_dim, max_count) < counts.unsqueeze(-1)
        return idx.masked_fill(~valid, self.original_dim)

    def __call__(self, x):
        """Project `x` of shape (..., original_dim) to shape (..., projection_dim)."""
        assert x.device == self.device, "device mismatch between projection and input"
        assert x.shape[-1] == self.original_dim, "input dimension mismatch"
        # Append a single zero so that padded indices (== original_dim) add exactly 0.
        x_padded = torch.cat([x, x.new_zeros((*x.shape[:-1], 1))], dim=-1)
        y = x_padded[..., self.positives].sum(-1) - x_padded[..., self.negatives].sum(-1)
        return y

In [None]:
ds = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000]
sparsities = [0.999999, 0.99999, 0.9999, 0.999, 0.99, 0.9]

# Cs[(d_proj, sparsity)] -> (S, S) numpy array of cosine similarities between the randomly
# projected per-token gradients.
Cs = dict()

for sparsity in tqdm(sparsities):
    for d_proj in ds:
        # Build the operator on the notebook's `device` (where the gradients live), not a
        # hard-coded 'cuda:0', so the operator's device check also holds on other machines.
        R = SparseProjectionOperator(len_g, d_proj, sparsity, seed=0, device=str(device))
        Gs = torch.zeros((S, d_proj), device=device)
        # NOTE: gradients are recomputed for every (sparsity, d_proj) pair; caching them instead
        # would need S * len_g floats of memory.
        for i, idx in enumerate(subset_token_idxs):
            model.zero_grad()
            document, l = get_context(idx)
            prompt = document['text']
            tokens = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
            logits = model(**tokens).logits
            targets = tokens.input_ids
            ls = F.cross_entropy(logits[0, :-1, :], targets[0, 1:], reduction='none')
            ls_l = ls[l-1]  # loss of the token at position l
            ls_l.backward()
            g = get_flattened_gradient(model, highsignal_names)
            with torch.no_grad():
                Gs[i] = R(g)
        with torch.no_grad():
            Gs = F.normalize(Gs, p=2, dim=1)
            C = torch.matmul(Gs, Gs.T)
        Cs[(d_proj, sparsity)] = C.detach().cpu().numpy()
        # Release this operator's index tensors before the next one is built. Otherwise two
        # operators are alive at once, and their size grows with d_proj * (1 - sparsity) * len_g.
        del R


In [None]:
# Grid of projected-gradient similarity matrices: one row per sparsity, one column per d_proj
# (len(sparsities) x len(ds) = 6 x 9 panels). Configurations the sweep did not reach, e.g.
# because it hit a memory error, are absent from Cs and their panels are left blank.
n_rows, n_cols = len(sparsities), len(ds)
plt.figure(figsize=(15, 12))
for row, sparsity in enumerate(sparsities):
    for col, d_proj in enumerate(ds):
        key = (d_proj, sparsity)
        if key not in Cs:
            continue
        plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
        # Cs[key] is read directly so the full-gradient matrix `C` is not overwritten.
        plt.imshow(Cs[key], cmap='CMRmap', vmin=-1, vmax=1)
        plt.title(f"d_proj={d_proj}, sparsity={sparsity}", fontsize=6)
        # remove all ticks and labels from x and y axes
        plt.xticks([])
        plt.yticks([])

In [20]:
# Reload the full-gradient similarity matrix; the projection sweep above reassigned `C`.
token_idxs, C = torch.load(
    "../results/paper-replication/pythia-70m-v0_143000_0.14426950408889636_50_10000_v1.pt"
)

In [None]:
# 126, 60
# 71 is number sequence continuation 

In [None]:
# Tokens of the clusters ranked 70 and 59 by size (0-based), and their block of C.
first_label, second_label = labels_sorted_by_freq[70], labels_sorted_by_freq[59]
sim_idxs = indices[first_label] + indices[second_label]
C_part = C[sim_idxs][:, sim_idxs]
plt.imshow(C_part, cmap="CMRmap", vmin=-0.2, vmax=0.2)

In [39]:
subset_token_idxs = list(map(token_idxs.__getitem__, sim_idxs))  # flat loss indices of the selected tokens
S = len(subset_token_idxs)  # number of tokens compared

In [None]:
ds = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000]
sparsities = [0.999999, 0.99999, 0.9999, 0.999, 0.99, 0.9]

# Cs[(d_proj, sparsity)] -> (S, S) numpy array of cosine similarities between the randomly
# projected per-token gradients.
Cs = dict()

for sparsity in tqdm(sparsities):
    for d_proj in ds:
        # Build the operator on the notebook's `device` (where the gradients live), not a
        # hard-coded 'cuda:0', so the operator's device check also holds on other machines.
        R = SparseProjectionOperator(len_g, d_proj, sparsity, seed=0, device=str(device))
        Gs = torch.zeros((S, d_proj), device=device)
        # NOTE: gradients are recomputed for every (sparsity, d_proj) pair; caching them instead
        # would need S * len_g floats of memory.
        for i, idx in enumerate(subset_token_idxs):
            model.zero_grad()
            document, l = get_context(idx)
            prompt = document['text']
            tokens = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
            logits = model(**tokens).logits
            targets = tokens.input_ids
            ls = F.cross_entropy(logits[0, :-1, :], targets[0, 1:], reduction='none')
            ls_l = ls[l-1]  # loss of the token at position l
            ls_l.backward()
            g = get_flattened_gradient(model, highsignal_names)
            with torch.no_grad():
                Gs[i] = R(g)
        with torch.no_grad():
            Gs = F.normalize(Gs, p=2, dim=1)
            C = torch.matmul(Gs, Gs.T)
        Cs[(d_proj, sparsity)] = C.detach().cpu().numpy()
        # Release this operator's index tensors before the next one is built. Otherwise two
        # operators are alive at once, and their size grows with d_proj * (1 - sparsity) * len_g.
        del R


In [None]:
# Grid of projected-gradient similarity matrices: one row per sparsity, one column per d_proj
# (len(sparsities) x len(ds) = 6 x 9 panels). Configurations the sweep did not reach, e.g.
# because it hit a memory error, are absent from Cs and their panels are left blank.
n_rows, n_cols = len(sparsities), len(ds)
plt.figure(figsize=(15, 12))
for row, sparsity in enumerate(sparsities):
    for col, d_proj in enumerate(ds):
        key = (d_proj, sparsity)
        if key not in Cs:
            continue
        plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
        # Cs[key] is read directly so the full-gradient matrix `C` is not overwritten.
        plt.imshow(Cs[key], cmap='CMRmap', vmin=-0.2, vmax=0.2)
        plt.title(f"d_proj={d_proj}, sparsity={sparsity}", fontsize=6)
        # remove all ticks and labels from x and y axes
        plt.xticks([])
        plt.yticks([])

In [43]:
# Reload the full-gradient similarity matrix (the projection sweep reassigned `C`) and keep
# only the rows/cols of the selected tokens.
token_idxs, C = torch.load(
    "../results/paper-replication/pythia-70m-v0_143000_0.14426950408889636_50_10000_v1.pt"
)
C_part = C[sim_idxs][:, sim_idxs]

In [None]:
# Side by side: full-gradient similarities vs. one projected configuration.
d_proj, sparsity = 2000, 0.999
C_sparse = Cs[(d_proj, sparsity)]
panel_style = dict(cmap="CMRmap", vmin=-0.2, vmax=0.2)

plt.subplot(1, 2, 1)
plt.imshow(C_part, **panel_style)
plt.title("full gradient")

plt.subplot(1, 2, 2)
plt.imshow(C_sparse, **panel_style)
plt.title(f"d_proj={d_proj}, sparsity={sparsity}")

In [48]:
idx  # NOTE(review): shows the loop variable left over from the most recently executed gradient loop; result depends on cell execution order

5854332

In [57]:
# For every 100th flat loss index in [10000, 20000), record a histogram of the gradient
# components of that token's loss.
# NOTE(review): with an integer `bins`, np.histogram spans each gradient's own min..max, so bin
# widths (and hence raw counts) differ from sample to sample.
idxs = list(range(10000, 20000, 100))
histograms = []
for idx in tqdm(idxs):
    model.zero_grad()
    sample, target_pos = get_context(idx)
    encoded = tokenizer(sample['text'], return_tensors='pt', max_length=1024, truncation=True).to(device)
    logits = model(**encoded).logits
    per_token_loss = F.cross_entropy(logits[0, :-1, :], encoded.input_ids[0, 1:], reduction='none')
    per_token_loss[target_pos - 1].backward()
    grad_np = get_flattened_gradient(model, highsignal_names).detach().cpu().numpy()
    counts, edges = np.histogram(grad_np, bins=10000)
    histograms.append(((edges[:-1] + edges[1:]) / 2, counts))

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Overlay up to 40 per-token gradient histograms on a log count scale. Slicing avoids an
# IndexError when the histogram loop was interrupted before collecting 40 entries.
for centers, hist in histograms[:40]:
    plt.plot(centers, hist, alpha=0.3)
# Axis configuration is global to the plot, so set it once rather than on every iteration.
plt.yscale('log')
plt.xlabel("gradient component value")
plt.ylabel("frequency")
plt.xlim(-10, 10)

In [None]:
def sparse_projection_matrix_sloppy(original_dim, projection_dim, sparsity):
    """
    Build a random sparse (projection_dim, original_dim) COO matrix with entries in {+1, 0, -1}.

    For speed, uses a Poisson distribution to approximate the binomial number of nonzeros per
    row, and doesn't check for collisions in the indices. Columns are sampled with replacement,
    so a (row, col) pair may repeat; repeats are summed if the tensor is coalesced.

    Args:
        original_dim (int): input dimension of projection
        projection_dim (int): output dimension of projection
        sparsity (float): probability that an entry is zero

    Returns:
        torch.Tensor: uncoalesced sparse COO tensor of shape (projection_dim, original_dim),
        with int64 indices and int64 values in {-1, +1}.
    """
    lambda_ = original_dim * (1 - sparsity)
    # number of nonzeros in each row
    counts = torch.poisson(torch.full((projection_dim,), float(lambda_))).long()
    n_entries = int(counts.sum())
    # Keep row indices int64. Float row indices would promote the stacked int64 column indices
    # to float32, which cannot represent every column index above 2**24 exactly.
    row_indices = torch.repeat_interleave(torch.arange(projection_dim), counts)
    # randint's upper bound is exclusive, so this covers every column 0..original_dim-1
    col_indices = torch.randint(0, original_dim, (n_entries,))
    values = 2 * torch.randint(0, 2, (n_entries,)) - 1
    indices = torch.stack([row_indices, col_indices])
    return torch.sparse_coo_tensor(indices, values, (projection_dim, original_dim))

In [None]:
def sparse_projection_operator(original_dim, projection_dim, sparsity, seed=0):
    """
    Closure that returns a function computing a random sparse {+1, 0, -1} projection.

    Instead of using a sparse matrix implementation, this stores, for each row, a list of +1
    column indices and a list of -1 column indices. `sparsity` is the probability that an entry
    of the implied matrix is zero. The +1 and -1 counts per row are each Poisson(lambda_ / 2),
    with lambda_ = (1 - sparsity) * original_dim, so the expected nonzeros per row is lambda_.
    Columns are sampled with replacement.

    Args:
        original_dim (int): input dimension.
        projection_dim (int): output dimension.
        sparsity (float): probability that an entry is zero.
        seed (int): seed applied to torch's global RNG (this mutates global RNG state).

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: maps x of shape (..., original_dim) to shape
        (..., projection_dim).
    """
    torch.manual_seed(seed)
    lambda_ = original_dim * (1 - sparsity)

    def sample_row_indices():
        # (projection_dim, max_count) int64 indices. Row r holds Poisson(lambda_ / 2) columns
        # followed by padding with the value `original_dim`, which `project` maps to a zero.
        counts = torch.poisson(torch.full((projection_dim,), float(lambda_) / 2)).long()
        max_count = int(counts.max()) if projection_dim > 0 else 0
        idx = torch.randint(0, original_dim, (projection_dim, max_count))
        valid = torch.arange(max_count).expand(projection_dim, max_count) < counts.unsqueeze(-1)
        return idx.masked_fill(~valid, original_dim)

    positives = sample_row_indices()
    negatives = sample_row_indices()

    def project(x):
        # Append one zero so that padded indices contribute exactly 0.
        x_padded = torch.cat([x, x.new_zeros((*x.shape[:-1], 1))], dim=-1)
        y = x_padded[..., positives].sum(-1) - x_padded[..., negatives].sum(-1)
        return y
    return project

In [None]:
# let's not worry too much about efficiency right now
# we want to perform a sparse random projection, so the entries are just +1, (0), and -1

def get_indices(n_total, n_subset):
    """
    Sample `n_subset` distinct integers uniformly at random from range(0, n_total).

    Args:
        n_total (int): size of the population to sample from.
        n_subset (int): number of distinct indices to draw; 0 <= n_subset <= n_total.

    Returns:
        torch.Tensor: int64 tensor of shape (n_subset,), distinct values, sorted ascending.

    Raises:
        ValueError: if n_subset is negative or exceeds n_total (sampling could never finish).
    """
    if not 0 <= n_subset <= n_total:
        raise ValueError(f"n_subset must be in [0, n_total]; got n_subset={n_subset}, n_total={n_total}")
    if 2 * n_subset > n_total:
        # Dense subset: a prefix of a random permutation is cheaper than rejection sampling.
        return torch.randperm(n_total)[:n_subset].sort().values
    # Sparse subset: draw the missing number of candidates in one batch and de-duplicate, until
    # enough distinct values are collected. The procedure is symmetric in the labels, so the
    # final subset is uniform.
    chosen = torch.empty(0, dtype=torch.long)
    while chosen.numel() < n_subset:
        draws = torch.randint(0, n_total, (n_subset - chosen.numel(),))
        chosen = torch.unique(torch.cat([chosen, draws]))
    return chosen

def sparse_projection_matrix(original_dim, projection_dim, sparsity):
    """
    Build a random sparse (projection_dim, original_dim) COO matrix with entries in {+1, 0, -1}.

    Each row gets a Binomial(original_dim, 1 - sparsity) number of nonzeros, at distinct columns
    chosen by `get_indices`. Each nonzero is +1 or -1 with equal probability.

    Args:
        original_dim (int): input dimension of projection
        projection_dim (int): output dimension of projection
        sparsity (float): probability that an entry is zero

    Returns:
        torch.Tensor: sparse COO tensor of shape (projection_dim, original_dim), with int64
        indices and int64 values in {-1, +1}.
    """
    nnz_per_row = torch.distributions.Binomial(original_dim, 1-sparsity)  # built once, sampled per row
    row_indices = []
    col_indices = []
    values = []
    for i in range(projection_dim):
        rowi_num_entries = int(nnz_per_row.sample().item())
        # defensive .long(): guarantees int64 columns whatever dtype get_indices returns
        rowi_col_indices = get_indices(original_dim, rowi_num_entries).long()
        rowi_values = 2*torch.randint(0, 2, (rowi_num_entries,)) - 1
        # Keep row indices int64. Float rows would promote the stacked int64 columns to
        # float32, which cannot represent every column index above 2**24 exactly.
        row_indices.append(torch.full((rowi_num_entries,), i, dtype=torch.long))
        col_indices.append(rowi_col_indices)
        values.append(rowi_values)
    # torch.cat([]) raises, so projection_dim == 0 falls back to empty int64 tensors.
    empty = torch.zeros(0, dtype=torch.long)
    row_indices = torch.cat(row_indices) if row_indices else empty
    col_indices = torch.cat(col_indices) if col_indices else empty
    values = torch.cat(values) if values else empty
    indices = torch.stack([row_indices, col_indices])
    return torch.sparse_coo_tensor(indices, values, (projection_dim, original_dim))

def sparse_projection_matrix_sloppy(original_dim, projection_dim, sparsity):
    """
    Build a random sparse (projection_dim, original_dim) COO matrix with entries in {+1, 0, -1}.

    For speed, uses a Poisson distribution to approximate the binomial number of nonzeros per
    row, and doesn't check for collisions in the indices. Columns are sampled with replacement,
    so a (row, col) pair may repeat; repeats are summed if the tensor is coalesced.

    Args:
        original_dim (int): input dimension of projection
        projection_dim (int): output dimension of projection
        sparsity (float): probability that an entry is zero

    Returns:
        torch.Tensor: uncoalesced sparse COO tensor of shape (projection_dim, original_dim),
        with int64 indices and int64 values in {-1, +1}.
    """
    lambda_ = original_dim * (1 - sparsity)
    # number of nonzeros in each row
    counts = torch.poisson(torch.full((projection_dim,), float(lambda_))).long()
    n_entries = int(counts.sum())
    # Keep row indices int64. Float row indices would promote the stacked int64 column indices
    # to float32, which cannot represent every column index above 2**24 exactly.
    row_indices = torch.repeat_interleave(torch.arange(projection_dim), counts)
    # randint's upper bound is exclusive, so this covers every column 0..original_dim-1
    col_indices = torch.randint(0, original_dim, (n_entries,))
    values = 2 * torch.randint(0, 2, (n_entries,)) - 1
    indices = torch.stack([row_indices, col_indices])
    return torch.sparse_coo_tensor(indices, values, (projection_dim, original_dim))

def sparse_projection_matrix_very_sloppy(original_dim, projection_dim, sparsity):
    """
    Build a random sparse (projection_dim, original_dim) COO matrix with entries in {+1, 0, -1}.

    Every row gets exactly int((1 - sparsity) * original_dim) entries, at columns sampled with
    replacement (so a (row, col) pair may repeat). Each entry is +1 or -1 with equal probability.

    Args:
        original_dim (int): input dimension of projection
        projection_dim (int): output dimension of projection
        sparsity (float): probability that an entry is zero

    Returns:
        torch.Tensor: uncoalesced sparse COO tensor of shape (projection_dim, original_dim),
        with int64 indices and int64 values in {-1, +1}.
    """
    per_row = int((1-sparsity) * original_dim)
    n_entries = projection_dim * per_row
    # Keep row indices int64. Float row indices would promote the stacked int64 column indices
    # to float32, which cannot represent every column index above 2**24 exactly.
    row_indices = torch.arange(projection_dim).repeat_interleave(per_row)
    col_indices = torch.randint(0, original_dim, (n_entries,))
    values = 2 * torch.randint(0, 2, (n_entries,)) - 1
    indices = torch.stack([row_indices, col_indices])
    return torch.sparse_coo_tensor(indices, values, (projection_dim, original_dim))