<a href="https://colab.research.google.com/github/jessicamarycooper/Backwards/blob/main/clustering_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab filesystem at /content/Drive.
# NOTE(review): no cell visible in this notebook reads from or writes to the mount -
# confirm it is still needed.
from google.colab import drive
drive.mount('/content/Drive')

Mounted at /content/Drive


In [2]:
!pip install update transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils
import torch
import random
from matplotlib import pyplot as plt
%matplotlib inline
from IPython import display
import numpy as np
from tqdm import tqdm
from time import time
import json
utils.logging.set_verbosity_error()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_len= 50257
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side='left')
model = GPT2LMHeadModel.from_pretrained("gpt2",pad_token_id=tokenizer.eos_token_id, vocab_size=vocab_len).to(device)
model.eval()
# the model will be in evaluation, not training, mode throughout
word_embeddings = model.transformer.wte.weight.to(device)  
embedding_dim = word_embeddings.shape[-1] 
# 'word_embeddings' tensor gives emeddings for each token in the vocab for this model,
# has shape (vocab_len, embedding_dimension) which in this case = (50257, 768)


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting update
  Downloading update-0.0.1-py2.py3-none-any.whl (2.9 kB)
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m79.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting style==1.1.0
  Downloading style-1.1.0-py2.py3-none-any.whl (6.4 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m113.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.4/182.4 KB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, style, 

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [3]:
def normalise(x, min_max=[]):
    """Linearly rescale the values of ``x`` (array or tensor).

    ``x`` is first mapped onto [0, 1] using its overall minimum and maximum. If all
    values of ``x`` are equal (zero range) this first step is skipped and ``x`` is
    left as it is.

    If ``min_max`` holds at least two elements, the result is then mapped onto
    [min_max[0], min_max[1]] by multiplying by their difference and adding
    ``min_max[0]``. With fewer than two elements ``min_max`` is ignored, which
    effectively leaves the target range at [0, 1] (a one-element list is tolerated
    rather than raising an error).

    Returns the rescaled array/tensor; ``x`` itself is not modified in place.
    """
    lowest = x.min()
    span = x.max() - lowest

    # Step 1: squash onto [0, 1] (only possible when the values are not all identical).
    if span > 0:
        x = (x - lowest) / span

    # Step 2: stretch/shift onto the requested interval, when one was supplied.
    if len(min_max) > 1:
        target_low = min_max[0]
        target_span = min_max[1] - target_low
        x = x * target_span + target_low

    return x

def closest_tokens(emb, n=1):
    """Find the ``n`` vocabulary tokens whose embeddings lie nearest to ``emb``.

    Distance is Euclidean, measured against every row of the notebook-level
    ``word_embeddings`` tensor. ``emb`` does not have to be the embedding of an actual
    token (i.e. it may or may not be a 'legal' embedding).

    Returns a 4-tuple, each entry ordered from nearest to furthest:
        tokens - list of the n decoded token strings
        ixs    - tensor of their n vocabulary indices
        dists  - tensor of their n distances from ``emb``
        embs   - tensor with their n embedding vectors, one per row
    """
    torch.cuda.empty_cache()

    # Distance from 'emb' to every token embedding in the vocabulary ...
    all_dists = torch.linalg.norm(word_embeddings - emb, dim=1)
    # ... sorted in increasing order, together with the matching vocabulary indices.
    ordered_dists, ordered_ixs = torch.sort(all_dists)

    nearest_ixs = ordered_ixs[:n]

    # Decode each of the n nearest vocabulary indices into its string form.
    nearest_tokens = []
    for vocab_ix in nearest_ixs:
        nearest_tokens.append(tokenizer.decode(vocab_ix))

    return nearest_tokens, nearest_ixs, ordered_dists[:n], word_embeddings[nearest_ixs]




In [93]:
# attempt to implement k-means algorithm myself ...

def kmeans(num_clusters, threshold=0):
    """Cluster the vocabulary's token embeddings with k-means and print a report per cluster.

    The algorithm alternates between assigning every token embedding to its nearest centroid
    (Euclidean distance) and moving each centroid to the mean of the embeddings assigned to it,
    until no centroid moves by more than ``threshold``.

    Uses the notebook-level names ``word_embeddings``, ``embedding_dim``, ``device``,
    ``tokenizer`` and ``normalise``.

    Args:
        num_clusters: number of clusters / centroids (must be at least 1).
        threshold: Euclidean distance; iteration stops once the largest centroid movement in
            one update is <= this value. The default of 0 means "until the centroids stop moving".

    Returns:
        Tensor of shape (num_clusters, embedding_dim) holding the final centroids. It is computed
        from a detached copy of the embeddings, so it carries no autograd graph.

    Raises:
        ValueError: if ``num_clusters`` is smaller than 1.

    Side effects:
        For every cluster, prints its size, the (up to) 10 member tokens nearest its centroid and
        the 10 tokens of the whole vocabulary nearest its centroid, each with index and distance.
    """
    if num_clusters < 1:
        raise ValueError('num_clusters must be at least 1')

    # word_embeddings is the model's weight parameter; detaching it stops every k-means step from
    # being recorded by autograd (which would keep a growing graph alive for no benefit).
    embeddings = word_embeddings.detach()
    row_format = "{:<7}{:35}{:<8}{:<8}{:<11}{:<15}"

    def _assign(current_centroids):
        """Return a list with, for each centroid, the vocab indices of the tokens nearest to it."""
        # distances has shape (vocab_len, num_clusters)
        distances = torch.cdist(embeddings, current_centroids, p=2)
        _, closest_centroid = torch.min(distances, dim=-1)
        return [torch.nonzero(closest_centroid == i, as_tuple=True)[0] for i in range(num_clusters)]

    def _print_token(index, distance):
        """Print one report row for the token at vocab position ``index``."""
        print(row_format.format('     Token: [', tokenizer.decode(index)+']', 'Index:', index.item(), 'Distance:', distance.item()))

    # Per-dimension span of the vocabulary embeddings, used to place the random initial centroids.
    lower = embeddings.min(dim=0)[0]
    upper = embeddings.max(dim=0)[0]

    # Re-draw random initial centroids until every one of them has at least one token assigned.
    while True:
        centroids = torch.rand(num_clusters, embedding_dim).to(device)
        centroids = normalise(centroids, [lower, upper])
        members = _assign(centroids)
        if all(ids.numel() > 0 for ids in members):
            break

    # Iterate centroid positions until they stop moving (to within 'threshold').
    # Invariant: 'members' always holds the assignment of tokens to the current 'centroids'.
    while True:
        new_centroids = torch.stack([
            # A cluster can become empty during iteration; its mean would be NaN, which would
            # contaminate every later distance and prevent termination. Keep its old centroid.
            embeddings[ids].mean(dim=0) if ids.numel() > 0 else centroids[i]
            for i, ids in enumerate(members)
        ])
        shift = torch.norm(new_centroids - centroids, dim=-1)
        if torch.max(shift) <= threshold:
            break
        centroids = new_centroids
        members = _assign(centroids)

    for i, ids in enumerate(members):
        print('Cluster', i, 'contains', len(ids), 'embeddings.\n')

        # Nearest 10 tokens within the cluster (or all of them, if the cluster is smaller).
        k = min(len(ids), 10)
        print('The nearest', k, 'tokens in cluster', i, 'to the cluster centroid:')
        centroid = centroids[i].unsqueeze(0)  # shape (1, embedding_dim) for broadcasting
        cluster_distances = torch.norm(embeddings[ids] - centroid, dim=1)
        nearest_dists, nearest_positions = torch.topk(cluster_distances, k, largest=False, sorted=True)
        for j in range(k):
            # 'ids' maps a position inside the cluster back to its vocabulary index.
            _print_token(ids[nearest_positions[j]], nearest_dists[j])
        if i < num_clusters - 1:
            print('')

        # Nearest 10 tokens in the entire vocabulary, whichever cluster they belong to.
        print('The nearest 10 tokens in the entire vocabulary to the cluster centroid:')
        vocab_distances = torch.norm(embeddings - centroid, dim=1)
        nearest_dists, nearest_ids = torch.topk(vocab_distances, 10, largest=False, sorted=True)
        for j in range(10):
            _print_token(nearest_ids[j], nearest_dists[j])
        if i < num_clusters - 1:
            print('\n')

    return centroids

In [94]:
# Cluster the token embeddings into 25 groups. kmeans prints a report for each cluster and
# returns the centroid tensor, which is displayed as this cell's output.
kmeans(25)

Cluster 0 contains 2640 embeddings.

The nearest 10 tokens in cluster 0 to the cluster centroid:
     Token: [Although]                          Index:  7003    Distance:  2.1066014766693115
     Token: [Even]                              Index:  6104    Distance:  2.108825922012329
     Token: [While]                             Index:  3633    Distance:  2.132721424102783
     Token: [Despite]                           Index:  8332    Distance:  2.165256977081299
     Token: [Several]                           Index:  14945   Distance:  2.172441244125366
     Token: [There]                             Index:  1858    Distance:  2.174029588699341
     Token: [Some]                              Index:  4366    Distance:  2.2021584510803223
     Token: [What]                              Index:  2061    Distance:  2.203798294067383
     Token: [That]                              Index:  2504    Distance:  2.220250129699707
     Token: [When]                              Index:  2215    

tensor([[ 0.0331, -0.0527,  0.1413,  ...,  0.0145,  0.0307,  0.0059],
        [-0.0265, -0.0292,  0.1175,  ...,  0.0493,  0.0241,  0.0274],
        [-0.0531, -0.0611,  0.1549,  ...,  0.0364,  0.1804,  0.1168],
        ...,
        [ 0.0083, -0.0652,  0.1819,  ...,  0.0363,  0.0038,  0.0430],
        [ 0.0059, -0.0110,  0.1039,  ...,  0.0677,  0.0156, -0.0135],
        [-0.0318, -0.0335,  0.1328,  ...,  0.0175,  0.0224,  0.0166]],
       device='cuda:0', grad_fn=<StackBackward0>)

In [4]:
# After getting bogged down in trying to reinvent the wheel, keeping track of recursively reshuffled indices, 
# I asked ChatGPT3 and it wrote this (only the kmeans = KMeans... line had to be adjusted due to CPU/GPU issues).
# It's using Scikit-learn or sklearn, which is a popular python package for this kind of thing.

# Unless we can get this to run on some faster hardware, this is futile. It will take all day to find these centroids.
# ChatGPT suggested using cuml or rapidsai, but pip installation failed for both of these.


def kkmeans(k):
    """Cluster the token embeddings with scikit-learn's KMeans, then trim over-full clusters.

    After fitting, clusters are visited in index order. Whenever a cluster holds more than
    ``len(word_embeddings) // k`` points, its surplus points are moved out, starting with the
    points that lie closest to some *other* fitted cluster centre; each moved point goes to that
    nearest other cluster. A cluster that receives points after it has been visited is not
    revisited, so the final sizes are only approximately equal.

    Uses the notebook-level names ``word_embeddings`` and ``device``.

    Args:
        k: number of clusters.

    Returns:
        Tensor of shape (k, embedding_dim) on ``device``: the mean embedding of each cluster
        after rebalancing.
    """
    from sklearn.cluster import KMeans

    # Detach so nothing below is tracked by autograd.
    embeddings = word_embeddings.detach()

    # Apply k-means clustering
    print('Applying k-means clustering...')
    estimator = KMeans(n_clusters=k, random_state=0).fit(embeddings.cpu().numpy())
    cluster_assignments = estimator.labels_

    # Target cluster size
    target_size = len(embeddings) // k
    print('Target cluster size =', target_size)

    # clusters[c] lists the vocab indices of the points currently in cluster c
    clusters = [[] for _ in range(k)]
    for i, c in enumerate(cluster_assignments):
        clusters[c].append(i)

    # The fitted centres never change during rebalancing, so convert them once.
    centres = torch.from_numpy(estimator.cluster_centers_).to(device=embeddings.device, dtype=embeddings.dtype)

    # Reassign points to (approximately) equalise cluster sizes
    for c in range(k):
        members = clusters[c]
        excess = len(members) - target_size
        if excess <= 0:
            continue
        print('Working on cluster ', c, ' size =', len(members))

        member_ids = torch.tensor(members, dtype=torch.long, device=embeddings.device)
        # distances[p, j] = distance from the p-th member of cluster c to the centre of cluster j
        distances = torch.cdist(embeddings[member_ids], centres, p=2)
        distances[:, c] = float('inf')  # a point may not be "moved" to its own cluster
        nearest_distance, nearest_cluster = torch.min(distances, dim=1)

        # Because the centres are fixed, repeatedly picking "the member closest to another
        # cluster" is the same as taking members in increasing order of nearest_distance.
        # A stable sort keeps the list order among exact ties.
        _, order = torch.sort(nearest_distance, stable=True)
        leaving = order[:excess].tolist()
        destinations = nearest_cluster[order[:excess]].tolist()
        for position, destination in zip(leaving, destinations):
            clusters[destination].append(members[position])

        gone = set(leaving)
        clusters[c] = [point for position, point in enumerate(members) if position not in gone]

    # Centroid of each rebalanced cluster; should a cluster be empty, keep its fitted centre.
    centroids = torch.stack([
        embeddings[torch.tensor(ids, dtype=torch.long, device=embeddings.device)].mean(dim=0) if ids else centres[c]
        for c, ids in enumerate(clusters)
    ])

    return centroids.to(device)

In [None]:
# Run the scikit-learn based clustering (kkmeans) with k = 25 clusters.
kkmeans(25)

In [None]:
# Runs 500 k-means batches of 50 clusters each. For every resulting centroid, the vocabulary token
# nearest to it gets one count, so token_counts records how many appearances each token makes
# as a "closest token to a centroid" across all batches.
token_counts = torch.zeros(vocab_len)
for batch in range(500):
    print('batch', batch)
    centroids = kmeans(50)
    for centroid in centroids:
        _, nearest_ixs, _, _ = closest_tokens(centroid)
        token_counts[nearest_ixs.item()] += 1

# Most frequently appearing tokens first; show the top 50 indices and their counts.
values, indices = torch.sort(token_counts, descending=True)
print(indices[:50], values[:50])

In [None]:
# Vocabulary indices of 50 tokens, hard-coded here.
# NOTE(review): presumably copied from the sorted counts printed by the batch-counting cell - confirm.
most_common_tokens_idxs = [
    30212, 187, 195, 216, 182, 179, 213, 39820, 199, 124,
    208, 125, 23090, 554, 30208, 281, 3607, 7003, 192, 37528,
    15524, 217, 39752, 42089, 183, 818, 210, 201, 209, 207,
    211, 206, 1026, 189, 190, 1315, 219, 205, 212, 203,
    287, 188, 30898, 45544, 14827, 218, 30897, 202, 181, 30905,
]
# Decode every index into its token string.
most_common_tokens = list(map(tokenizer.decode, most_common_tokens_idxs))
print(most_common_tokens)
# Show the embedding vector of the first token in the list (index 30212).
print(word_embeddings[most_common_tokens_idxs[0]])