<a href="https://colab.research.google.com/github/jessicamarycooper/Backwards/blob/main/clustering_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/Drive')

Mounted at /content/Drive


In [4]:
!pip install update transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils
import torch
import random
from matplotlib import pyplot as plt
%matplotlib inline
from IPython import display
import numpy as np
from tqdm import tqdm
from time import time
import json
utils.logging.set_verbosity_error()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_len= 50257
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side='left')
model = GPT2LMHeadModel.from_pretrained("gpt2",pad_token_id=tokenizer.eos_token_id, vocab_size=vocab_len).to(device)
model.eval()
# the model will be in evaluation, not training, mode throughout
word_embeddings = model.transformer.wte.weight.to(device)  
embedding_dim = word_embeddings.shape[-1] 
# 'word_embeddings' tensor gives emeddings for each token in the vocab for this model,
# has shape (vocab_len, embedding_dimension) which in this case = (50257, 768)


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
common_centroid_nearest_token_idxs = [30212, 187, 195, 216, 182, 179, 213, 39820, 199, 124,  208, 125, 23090, 554, 30208, 281,  3607,  7003, 192, 37528,  15524, 217, 39752, 42089, 183, 818, 210, 201, 209, 207, 211, 206,  1026, 189, 190,  1315, 219, 205, 212, 203,  287, 188, 30898, 45544, 14827, 218, 30897, 202, 181, 30905]
# Having run 500 batches of 50 and used kmeans, these were the 50 most common closest token embeddings to centroids that were found (from most to least common).


In [6]:
def normalise(x, min_max=[]):     
# normalises values of (array or tensor) x according to first (min) and second (max) values in list min_max. 
# This effectively defaults to [0,1] if the list doesn't contain exactly two elements. 
# The original code threw an error if min_max had length 1, so it's been changed slightly.

# First normalise x to [0,1]
    rnge = x.max() - x.min()
    if rnge > 0:
        x = (x - x.min())/rnge

# Now, if there's a min and max given in min_max list, multiply by difference and add minimum
    if len(min_max) > 1:
        rnge = min_max[1] - min_max[0]
        x = x * rnge + min_max[0]

    return x

def closest_tokens(emb, n=1):      
# This finds the n tokens in the vocabulary that are closest in the embedding space (in terms of Euclidean distance) to a given word embedding (‘emb’).
# Note that here 'emb' may or may not correspond to a token (i.e., it may or may not be a 'legal' embedding).
# Function returns a 4-tuple (list of the n tokens, list of their indices, list of their distances from emb, and list of their embedding vectors)
    torch.cuda.empty_cache()
    dists = torch.linalg.norm(word_embeddings - emb, dim=1)
    sorted_dists, ix = torch.sort(dists)	 
    # sorted_dists is a list of all embedding distances from 'emb', across entire vocab, sorted in increasing order, 
    # ix is a list of their corresponding 'vocab indices'
    tokens = [tokenizer.decode(i) for i in ix[:n]]
    # For each of the first n 'vocab indices' in ix, we decode it into the string version of the corresponding token. 
    # These strings then constitute the list 'tokens'.
    ixs = ix[:n]
    dists = sorted_dists[:n]
    embs = word_embeddings[ixs]  # Each of these n 'embeddings' is a tensor of shape (768,)
    return tokens, ixs, dists, embs  




In [7]:
# attempt to implement k-means algorithm myself ...

def kmeans(num_clusters):

    empty_cluster = True

    # euclidean distance threshold to break out of loop (may not be necessary as 0 seems to work fine)
    threshold = 0

    while empty_cluster == True:    # We're going to loop on random initialisation of centroids until all clusters are non-empty.
        # randomly generate num_clusters centroids, stack these into a tensor of shape (num_clusters, 768)
        # then normalise to vocab embedding span
        centroids = torch.rand(num_clusters, embedding_dim).to(device)
        centroids = normalise(centroids,[word_embeddings.min(dim=0)[0], word_embeddings.max(dim=0)[0]])

        distances = torch.cdist(word_embeddings, centroids, p=2)
        # This will be of shape (vocab_len, num_clusters), recording distances of all token embeddings from each of the centroids.
        closest_distance, closest_centroid = torch.min(distances, dim = -1)
        # These will be of shape (vocab_len,), recording the distance of each token embedding to nearest centroid, and the index of that centroid.

        clusters = []
        # We iterate over the centroids to build a list of clusters (tokens which share a common nearest centroid)
        for i in range(num_clusters):
            mask = closest_centroid == i  # This builds a Boolean mask over the complete set of tokens, True iff a token's nearest centroid is the ith.
            clusters.append(word_embeddings[mask]) 
            # word_embeddings[mask] is a subtensor of the shape (50257, 768) tensor involving only tokens whose nearest centroid is the ith.
            # This subtensor gets appended to the list 'clusters', which has num_clusters elements.

        cluster_sizes = [clusters[i].shape[0] for i in range(num_clusters)]  
        if 0 not in cluster_sizes:
            empty_cluster = False 

    # Now we have a set of num_clusters non-empty clusters, we begin iterating centroid positions:

    centroids_stable = False
    iterations = 0

    while centroids_stable == False:    # We're going to loop until the centroids stop moving.
        iterations += 1
        distances = torch.cdist(word_embeddings, centroids, p=2)
        # This will be of shape (vocab_len, num_clusters), recording distances of all token embeddings from each of the centroids.
        closest_distance, closest_centroid = torch.min(distances, dim = -1)
        # These will be of shape (vocab_len,), for each token embedding recording the distance to the nearest centroid, and the index of that centroid.

        clusters = []
        # We iterate over the centroids to build a list of clusters (tokens which share a common nearest centroid)
        for i in range(centroids.shape[0]):
            mask = closest_centroid == i  # This builds a Boolean mask over the complete set of tokens, True iff a token's nearest centroid is the ith.
            clusters.append(word_embeddings[mask]) 
            # word_embeddings[mask] is a subtensor of the (50257, 768) shape tensor involving only tokens nearest ith centroid.
            # This subtensor gets appended to the list 'clusters' (num_clusters elements), which has num_clusters elements.

        cluster_sizes_list = [clusters[i].shape[0] for i in range(num_clusters)]  
        cluster_sizes = torch.tensor(cluster_sizes_list)  # tensorise the list
        _, sort_indices = torch.sort(cluster_sizes, descending=True)
        sorted_clusters = [clusters[i] for i in sort_indices]  # list of clusters in decreasing size order
        sorted_centroids = [centroids[i] for i in sort_indices] # corresponding list of nearest centroids to the sorted clusters

        # Now we redistribute embeddings between clusters until clusters are of ~equal size [see https://stackoverflow.com/a/5452702]:
        # We want them all to have size as close as possible to vocab_len/num_clusters.

        for i in range(num_clusters - 1):  # iterating through all sorted_clusters except the last/smallest
            while abs(sorted_clusters[i].shape[0] - vocab_len//num_clusters) > 1:
                print('abs(sorted_clusters[i].shape[0] - vocab_len//num_clusters) = ', abs(sorted_clusters[i].shape[0] - vocab_len//num_clusters))
                dists = []
                for j in range(i+1, num_clusters):    # ranges from i + 1 to num_clusters - 1
                    diff = sorted_clusters[i] - sorted_centroids[j] # tensors like sorted_clusters[i], but where each row records vector difference from sorted_centroids[j]
                    dists.append(torch.norm(diff, dim=1)) # append to list 'dists' a tensor of norms of those vector differences
                centroid_distances = torch.stack(dists, dim=1)  # stack these together to make a tensor of shape (sorted_clusters[i].shape[0], num_clusters - i - 1)
                print('cluster index i = ', i, '; centroid_distances has shape = ', centroid_distances.shape)
                min_idx = torch.argmin(centroid_distances)
                indices = torch.tensor(np.unravel_index(min_idx.item(), centroid_distances.shape))
                row_idx, col_idx = indices

                # We're getting an error here because the clusters are being redistributed wrongly, we're ending up with the smallest cluster getting shrunk to size 0.
                # Clusters should be moving from largest to smallest clusters, but that's not happening.

                removed_row = sorted_clusters[i][row_idx]
                print('removing row ', row_idx.item(), ' from cluster ', i, ' appending it to cluster', col_idx.item() + i + 1)
                sorted_clusters[i] = torch.cat([sorted_clusters[i][:row_idx], sorted_clusters[i][row_idx+1:]], dim=0)
                sorted_clusters[col_idx + i + 1] = torch.cat([sorted_clusters[col_idx + i + 1], removed_row.unsqueeze(0)], dim=0)
                dists.clear() # clear list

            # Now we want to resort the remaining clusters indexed (i+1, ... num_clusters - 1)
            rem_cluster_sizes_list = [sorted_clusters[h].shape[0] for h in range(i+1, num_clusters)]  
            rem_cluster_sizes = torch.tensor(rem_cluster_sizes_list)  # tensorise the list
            _, rem_sort_indices = torch.sort(rem_cluster_sizes, descending=True) # tensor of index positions
            sorted_rem_clusters = [sorted_clusters[h] for h in rem_sort_indices]  # list of remaining clusters in decreasing size order
            sorted_rem_centroids = [sorted_centroids[h] for h in rem_sort_indices] # corresponding list of nearest remaining centroids to the sorted clusters
            
            # We now need to adjust sort_indices which is tracking which cluster is which and will be used for unsorting once we've 
            # completed the redistribution
            new_sort_indices = [h for h in range(i+1)] + [h + i + 1 for h in rem_sort_indices]
            # This builds num_cluster length index set prepended by [0,...,i]
            # this is going to be used to reshuffle sort_indices, but first both lists have to be tensorised
            sort_indices = torch.tensor(sort_indices)
            new_sort_indices = torch.tensor(new_sort_indices)
            sort_indices = new_sort_indices[sort_indices]
            # now convert back to a list
            sort_indices = final_sort_indices.tolist()

            #NOTE THAT THIS REDISTRIBUTION LOOKS LIKE IT COULD TAKE HOURS! BATCH SIZE 25 took almost half an hour to do the sorting
            #AND THAT'S FOR A SINGLE CENTROID ITERATION; WE MAY NEED
            #COULD LOOK FOR WAYS TO SPEED THIS UP, OR COULD RE-STRATEGISE APPROACH TO SETTING UP CENTROIDS
            #ONE WAY MIGHT BE TO ITERATE UNTIL THEY'RE ALL "ROUGHLY" THE SAME SIZE (BIGGEST MINUS SMALLEST < SOME THRESHOLD)
            

        # Now we need to "unsort" the sorted_clusters and sorted_centroids to get everything back in the original order.
        
        unsort_indices = [i for i, x in sorted(enumerate(sort_indices), key=lambda x: x[1])]
        clusters = [sorted_clusters[i] for i in unsort_indices]
        centroids = [sorted_centroids[i] for i in unsort_indices]
        centroids = torch.stack(centroids)  # stack the centroids list into a tensor of shape (num_clusters, 768)

        new_centroids = []
        for i in range(num_clusters):
            new_centroids.append(clusters[i].mean(dim=0))
            # clusters[i].mean(dim=0) is the centroid of the set of vectors encoded in the tensor clusters[i]
            # These all get put in a list.

        new_centroids = torch.stack(new_centroids)  # stack the list into a tensor of shape (num_clusters, 768)

        # We now compute the euclidean distance between old and new centroids
        distance = torch.norm(new_centroids - centroids, dim=-1)
        print('iteration: ', iterations, '\n max distance between old and new centroids:', torch.max(distance).item())
        print('cluster sizes:', cluster_sizes_list) 
        print('number of shuffles', shuffle_steps)

        if torch.max(distance) <= threshold:
            centroids_stable == True
            break
        centroids = new_centroids # if we're still outside the distance threshold, keep iterating
    


    for i in range(len(clusters)):
        print('Cluster', i, ' contains ', len(clusters[i]), ' embeddings.')
    for j in range(num_clusters):
        print('Closest token to centroid ', j,': ', closest_tokens(centroids[j])[0], closest_tokens(centroids[j])[1].item())
    return centroids

In [59]:
# Badly thought-out approach to clustering (just keep randomly re-initialising until the range of cluster sizes is 'small enough', i.e.
# less than some cutoff). Doesn't work.

def roughly_even_clusters(num_clusters):

    too_much_variation = True  
    size_cutoff = 2000 
    
    # Difference in cluster sizes is assumed to be intially :too big". 
    # This algorithm keep regenerating num_clusters centroids until the difference in size 
    # between the biggest and smallest cluster < size_cutoff

    # euclidean distance threshold to break out of cenroid iteration loop (may not be necessary as 0 seems to work fine)
    threshold = 0
    smallest_gap_yet = 100000
    current_max = 2010
    current_min = 2010

    while too_much_variation == True:    # We're going to loop on random initialisation of centroids until size range is acceptable.
        # randomly generate num_clusters centroids, stack these into a tensor of shape (num_clusters, 768)
        # then normalise to vocab embedding span
        centroids = torch.rand(num_clusters, embedding_dim).to(device)
        centroids = normalise(centroids,[word_embeddings.min(dim=0)[0], word_embeddings.max(dim=0)[0]])

        distances = torch.cdist(word_embeddings, centroids, p=2)
        # This will be of shape (vocab_len, num_clusters), recording distances of all token embeddings from each of the centroids.
        closest_distance, closest_centroid = torch.min(distances, dim = -1)
        # These will be of shape (vocab_len,), recording the distance of each token embedding to nearest centroid, and the index of that centroid.

        clusters = []
        # We iterate over the centroids to build a list of clusters (tokens which share a common nearest centroid)
        for i in range(num_clusters):
            mask = closest_centroid == i  # This builds a Boolean mask over the complete set of tokens, True iff a token's nearest centroid is the ith.
            clusters.append(word_embeddings[mask]) 
            # word_embeddings[mask] is a subtensor of the shape (50257, 768) tensor involving only tokens whose nearest centroid is the ith.
            # This subtensor gets appended to the list 'clusters', which has num_clusters elements.

        cluster_sizes = [clusters[i].shape[0] for i in range(num_clusters)]  
        if max(cluster_sizes) - min(cluster_sizes) < size_cutoff:
            too_much_variation = False 
        else:
            if max(cluster_sizes) - min(cluster_sizes) < smallest_gap_yet:
                smallest_gap_yet = max(cluster_sizes) - min(cluster_sizes)
                current_max = max(cluster_sizes)
                current_min = min(cluster_sizes)

            print('span =', max(cluster_sizes) - min(cluster_sizes), '; smallest span yet = ', smallest_gap_yet, '= ', current_max, ' - ', current_min)

    # Now we have a set of num_clusters clusters of roughly equal size, we begin iterating centroid positions:

    centroids_stable = False
    iterations = 0

    while centroids_stable == False:    # We're going to loop until the centroids stop moving.
        iterations += 1
        distances = torch.cdist(word_embeddings, centroids, p=2)
        # This will be of shape (vocab_len, num_clusters), recording distances of all token embeddings from each of the centroids.
        closest_distance, closest_centroid = torch.min(distances, dim = -1)
        # These will be of shape (vocab_len,), for each token embedding recording the distance to the nearest centroid, and the index of that centroid.

        clusters = []
        # We iterate over the centroids to build a list of clusters (tokens which share a common nearest centroid)
        for i in range(centroids.shape[0]):
            mask = closest_centroid == i  # This builds a Boolean mask over the complete set of tokens, True iff a token's nearest centroid is the ith.
            clusters.append(word_embeddings[mask]) 
            # word_embeddings[mask] is a subtensor of the (50257, 768) shape tensor involving only tokens nearest ith centroid.
            # This subtensor gets appended to the list 'clusters' (num_clusters elements), which has num_clusters elements.

        cluster_sizes_list = [clusters[i].shape[0] for i in range(num_clusters)]  
        cluster_sizes = torch.tensor(cluster_sizes_list)  # tensorise the list

        centroids = torch.stack(centroids)  # stack the centroids list into a tensor of shape (num_clusters, 768)

        new_centroids = []
        for i in range(num_clusters):
            new_centroids.append(clusters[i].mean(dim=0))
            # clusters[i].mean(dim=0) is the centroid of the set of vectors encoded in the tensor clusters[i]
            # These all get put in a list.

        new_centroids = torch.stack(new_centroids)  # stack the list into a tensor of shape (num_clusters, 768)

        # We now compute the euclidean distance between old and new centroids
        distance = torch.norm(new_centroids - centroids, dim=-1)
        print('iteration: ', iterations, '\n max distance between old and new centroids:', torch.max(distance).item())
        print('cluster sizes:', cluster_sizes_list) 

        if torch.max(distance) <= threshold:
            centroids_stable == True
            break
        centroids = new_centroids # if we're still outside the distance threshold, keep iterating
    


    for i in range(len(clusters)):
        print('Cluster', i, ' contains ', len(clusters[i]), ' embeddings.')
    for j in range(num_clusters):
        print('Closest token to centroid ', j,': ', closest_tokens(centroids[j])[0], closest_tokens(centroids[j])[1].item())
    return centroids

In [8]:
# After getting bogged down in trying to reinvent the wheel, keeping track of recursively reshuffled indices, 
# I asked ChatGPT3 and it wrote this (only the kmeans = KMeans... line had to be adjusted due to CPU/GPU issues).
# It's using Scikit-learn or sklearn, which is a popular python package for this kind of thing.

# Unless we can get this to run on some faster hardware, this is futile. It will take all day to find these centroids.
# ChatGPT suggested using cuml or rapidsai, but pip installation failed for both of these.


def kkmeans(k):

    from sklearn.cluster import KMeans

    # k = number of clusters

    # Apply k-means clustering
    print('Applying k-means clustering...')
    kmeans = KMeans(n_clusters=k, random_state=0).fit(word_embeddings.detach().cpu().numpy())
    
    # Get the cluster assignments for each point
    cluster_assignments = kmeans.labels_
    print('Getting the cluster assignments for each point...')

    # Count the number of points in each cluster
    cluster_sizes = np.bincount(cluster_assignments)
    print('Counting the number of points in each cluster...')

    # Initialize an array to store the new cluster assignments
    new_assignments = np.zeros(cluster_assignments.shape)
    print('Initializing an array to store the new cluster assignments...')

    # Target cluster size
    target_size = len(word_embeddings) // k
    print('Target cluster size =', target_size)

    # Initialize a list to store the indices of the points in each cluster
    clusters = [[] for _ in range(k)]
    for i, c in enumerate(cluster_assignments):
        clusters[c].append(i)

    # Reassign points to achieve equal cluster sizes
    for c in range(k):
        while len(clusters[c]) > target_size:
            # Find the point in the current cluster that is closest to another cluster
            min_distance = float('inf')
            min_point = None
            min_cluster = None
            for i in clusters[c]:
                print('Working on cluster ', c, ' size =', len(clusters[c]), '; token embedding: ', i)
                point = word_embeddings[i]
                for j in range(k):
                    if j == c:
                        continue
                    centroid = kmeans.cluster_centers_[j]
                    distance = torch.norm(point.to(device) - torch.from_numpy(centroid).to(device))
                    if distance < min_distance:
                        min_distance = distance
                        min_point = i
                        min_cluster = j
            # Move the point to the closest cluster
            clusters[c].remove(min_point)
            clusters[min_cluster].append(min_point)
            new_assignments[min_point] = min_cluster

    # Get the new cluster assignments
    kmeans.labels_ = new_assignments

    # Get the centroids of each cluster
    centroids = kmeans.cluster_centers_
    centroids = torch.from_numpy(centroids)

    # return the centroids

In [None]:
kkmeans(25)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Working on cluster  0  size = 2518 ; token embedding:  3616
Working on cluster  0  size = 2518 ; token embedding:  3619
Working on cluster  0  size = 2518 ; token embedding:  3623
Working on cluster  0  size = 2518 ; token embedding:  3626
Working on cluster  0  size = 2518 ; token embedding:  3629
Working on cluster  0  size = 2518 ; token embedding:  3632
Working on cluster  0  size = 2518 ; token embedding:  3638
Working on cluster  0  size = 2518 ; token embedding:  3641
Working on cluster  0  size = 2518 ; token embedding:  3650
Working on cluster  0  size = 2518 ; token embedding:  3660
Working on cluster  0  size = 2518 ; token embedding:  3661
Working on cluster  0  size = 2518 ; token embedding:  3662
Working on cluster  0  size = 2518 ; token embedding:  3664
Working on cluster  0  size = 2518 ; token embedding:  3667
Working on cluster  0  size = 2518 ; token embedding:  3668
Working on cluster  0  size = 2518 

In [None]:
# This runs 500 batches of 50 and keeps track of the most common closest tokens to centroid embeddings, and how many appeareances they make
token_counts = torch.zeros(vocab_len)
for j in range(500):
    print('batch', j)
    centroids = kmeans(50)
    for i in range(50):
        token_counts[closest_tokens(centroids[i])[1].item()] +=1

values, indices = torch.sort(token_counts, descending=True)
print(indices[:50], values[:50])

In [None]:
most_common_tokens_idxs = [30212,   187,   195,   216,   182,   179,   213, 39820,   199,   124,
          208,   125, 23090,   554, 30208,   281,  3607,  7003,   192, 37528,
        15524,   217, 39752, 42089,   183,   818,   210,   201,   209,   207,
          211,   206,  1026,   189,   190,  1315,   219,   205,   212,   203,
          287,   188, 30898, 45544, 14827,   218, 30897,   202,   181, 30905]
most_common_tokens = [tokenizer.decode(i) for i in most_common_tokens_idxs]
print(most_common_tokens)
print(word_embeddings[30212])