In [1]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem; force_remount re-mounts even if it is already attached.
drive.mount(mountpoint='/content/Drive', force_remount=True)

Mounted at /content/Drive


In [2]:

import torch
!pip install transformers
!pip install wandb
import numpy as np
from IPython.display import clear_output 
from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, AutoModelForCausalLM


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.9-py2.py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.30-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.0/184.0 KB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentry-s

In [77]:
# Name of the model whose token embeddings are analysed below.
# Any name containing 'gpt-j' loads EleutherAI/gpt-j-6B; any other value is passed to the GPT-2 classes.
model_name = 'gpt-j'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if 'gpt-j' in model_name:
    # NOTE: the GPT-J checkpoint is a ~24 GB download (see the progress output of this cell).
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(device)
else:
    tokenizer = GPT2Tokenizer.from_pretrained(model_name, padding_side='left')
    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device)

# Input token-embedding matrix, one row per vocabulary entry.
# detach() drops the Parameter's requires_grad flag so the similarity / clustering code
# below does not record an autograd graph for every operation.
embeddings = model.transformer.wte.weight.detach().to(device)

    

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/930 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

KeyboardInterrupt: ignored

In [None]:
def cos_sim(A, B, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of A and the rows of B.

    Adapted from https://stackoverflow.com/a/72369507

    Args:
        A: tensor whose rows are vectors (multiplied as ``A @ B.T``).
        B: tensor whose rows are vectors.
        dim: axis summed over to obtain each row's squared L2 norm.
        eps: lower bound applied to the product of norms, so all-zero rows
            give a similarity of 0 instead of a division by zero.

    Returns:
        Tensor with entry (i, j) equal to the cosine similarity of A[i] and B[j].
    """
    sq_norms_a = (A * A).sum(dim=dim)
    sq_norms_b = (B * B).sum(dim=dim)
    # sqrt(|a|^2 * |b|^2) == |a| * |b| for every (row of A, row of B) pair, floored at eps.
    norm_products = torch.maximum(torch.outer(sq_norms_a, sq_norms_b).sqrt(), torch.tensor(eps))
    return (A @ B.T) / norm_products


def cluster_info(clusters, tokenizer, get_centroid_tokens=True, get_vocab_tokens=False, vocab_embeddings=None):
    """Print the size of every cluster and the tokens closest to its centroid.

    All distances are cosine distances (``1 - cos_sim``).

    Args:
        clusters: sequence of 2-D tensors, one per cluster, each row an embedding
            (e.g. the first value returned by ``kkmeans``).
        tokenizer: object whose ``decode(index)`` turns a vocabulary index into text.
        get_centroid_tokens: if True, list the cluster members nearest to the
            cluster centroid (at most 20, fewer for smaller clusters).
        get_vocab_tokens: if True, list the same number of nearest tokens taken
            from the whole vocabulary instead of from the cluster.
        vocab_embeddings: full embedding matrix used to map an embedding back to
            its vocabulary index. Defaults to the notebook-level ``embeddings``.

    Returns:
        None. The report is written with ``print``.
    """
    if vocab_embeddings is None:
        # Backward-compatible default: the embedding matrix defined at notebook level.
        vocab_embeddings = embeddings

    row_format = "{:<7}{:35}{:<8}{:<8}{:<11}{:<15}"

    def print_token_row(embedding, distance):
        # Recover the vocabulary index of `embedding` as the vocabulary row closest to it.
        distances = 1 - cos_sim(vocab_embeddings, embedding.unsqueeze(0)).squeeze(-1)
        index = torch.argmin(distances)
        print(row_format.format('     Token:', repr(tokenizer.decode(index)), 'Index:', index.item(), 'Distance:', distance))

    for i, c in enumerate(clusters):
        print('Cluster', i, 'contains', c.shape[0], 'embeddings.\n')

        # Both report sections need k and the centroid, so compute them up front.
        # (Previously they were set only in the centroid section, which made
        # get_centroid_tokens=False, get_vocab_tokens=True fail with an unbound name.)
        k = min(c.shape[0], 20)
        centroid = c.mean(dim=0).unsqueeze(0)  # (1, dim) so it can be passed to cos_sim
        is_last_cluster = i == len(clusters) - 1

        if get_centroid_tokens:
            print('The nearest', k, 'tokens in cluster', i, 'to the cluster centroid:')
            cluster_distances = 1 - cos_sim(c, centroid).squeeze(-1)
            top_k, top_k_indices = torch.topk(cluster_distances, k, largest=False, sorted=True)
            for j in range(k):
                print_token_row(c[top_k_indices[j].item()], top_k[j].item())
            if not is_last_cluster:
                print('')

        if get_vocab_tokens:
            print('The nearest', k, 'tokens in the entire vocabulary to the cluster centroid:')
            vocab_distances = 1 - cos_sim(vocab_embeddings, centroid).squeeze(-1)
            top_k, top_k_indices = torch.topk(vocab_distances, k, largest=False, sorted=True)
            for j in range(k):
                print_token_row(vocab_embeddings[top_k_indices[j].item()], top_k[j].item())
            if not is_last_cluster:
                print('\n')

    return



def _cluster_means(clusters, fallback):
    """Stack the mean embedding of each cluster; an empty cluster reuses its row of `fallback`."""
    return torch.stack([
        members.mean(dim=0) if members.shape[0] > 0 else fallback[ix]
        for ix, members in enumerate(clusters)
    ])


def _rebalance_clusters(clusters, centroids, cluster_size):
    """Cap clusters at `cluster_size`, handing the overflow to the nearest smaller cluster.

    Clusters are visited from largest to smallest. Each oversized cluster keeps its
    first `cluster_size` rows; every extra row is appended to whichever later cluster
    has the closest centroid (cosine distance). The last cluster is never truncated,
    so it absorbs whatever is left over.

    Returns `(clusters, centroids)`: reordered largest-first if any row was moved,
    otherwise the inputs unchanged.
    """
    sizes = torch.tensor([c.shape[0] for c in clusters])
    order = torch.sort(sizes, descending=True).indices
    ordered = [clusters[ix] for ix in order.tolist()]
    ordered_centroids = centroids[order]

    moved = False
    for pos in range(len(ordered) - 1):
        if ordered[pos].shape[0] <= cluster_size:
            continue
        moved = True

        # Split the oversized cluster into the part it keeps and the overflow.
        overflow = ordered[pos][cluster_size:]
        ordered[pos] = ordered[pos][:cluster_size]

        # Cosine distance from each overflow row to the current centroid of every later cluster.
        later_centroids = _cluster_means(ordered[pos + 1:], ordered_centroids[pos + 1:])
        nearest = torch.min(1 - cos_sim(overflow, later_centroids), dim=-1).indices

        for offset in range(len(ordered) - pos - 1):
            target = pos + 1 + offset
            ordered[target] = torch.cat([ordered[target], overflow[nearest == offset]])

    if moved:
        return ordered, ordered_centroids
    return clusters, centroids


def kkmeans(embeddings, num_clusters, threshold=0, max_iter=300, seed=-1, constrain_size=True):
    """Cluster embeddings with k-means on cosine distance, optionally capping cluster sizes.

    Args:
        embeddings: 2-D tensor, one embedding per row.
        num_clusters: number of clusters to produce (1 .. number of rows).
        threshold: iteration stops once the mean centroid movement (L2 norm of the
            change, averaged over centroids) is no longer greater than this value.
        max_iter: maximum number of assignment/update iterations (at least 1).
        seed: value passed to ``torch.manual_seed`` for reproducible results;
            -1 (or None) leaves the random generator unseeded, so every run differs.
        constrain_size: if True, after convergence clusters larger than
            ``rows // num_clusters`` are truncated and their extra rows moved to the
            nearest smaller cluster (the last cluster takes any remainder, so sizes
            are capped rather than guaranteed equal). The returned clusters are then
            ordered by their pre-balancing size, largest first.

    Returns:
        (clusters, centroids): a list of ``num_clusters`` tensors holding the rows of
        each cluster, and a ``(num_clusters, dim)`` tensor of cluster means.

    Raises:
        ValueError: if ``num_clusters`` or ``max_iter`` is out of range.
    """
    num_embeddings = embeddings.shape[0]
    if not 1 <= num_clusters <= num_embeddings:
        raise ValueError(
            'num_clusters must be between 1 and the number of embeddings '
            '({}), got {}'.format(num_embeddings, num_clusters)
        )
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1, got {}'.format(max_iter))

    if seed is not None and seed != -1:
        torch.manual_seed(seed)

    cluster_size = num_embeddings // num_clusters
    # Initial centroids: one randomly chosen embedding per cluster.
    centroids = embeddings[torch.randperm(num_embeddings)[:num_clusters]]

    for _ in range(max_iter):
        # (num_embeddings, num_clusters) cosine distances from every embedding to every centroid.
        distances = 1 - cos_sim(embeddings, centroids)

        # Assign each embedding to its closest centroid.
        closest_centroid = torch.min(distances, dim=-1).indices
        clusters = [embeddings[closest_centroid == ix] for ix in range(num_clusters)]

        # An empty cluster keeps its previous centroid. Taking the mean of zero rows
        # would give NaN, which used to end the loop early and leak NaN centroids.
        new_centroids = _cluster_means(clusters, centroids)
        movement = torch.norm(new_centroids - centroids, dim=-1).mean()
        centroids = new_centroids

        if not movement > threshold:
            break

    # Balancing never feeds back into the centroid updates, so it only needs to run
    # once, on the final assignment.
    if constrain_size:
        clusters, centroids = _rebalance_clusters(clusters, centroids, cluster_size)

    centroids = _cluster_means(clusters, centroids)
    print([c.shape[0] for c in clusters])
    return clusters, centroids




In [None]:
# Split the vocabulary embeddings into 20 clusters; the fixed seed makes the run reproducible.
clusters, centroids = kkmeans(embeddings, num_clusters=20, seed=2)

In [None]:
# Report, for each cluster, the member tokens nearest its centroid (the function's default report).
cluster_info(clusters, tokenizer)
