In [None]:
# Mount Google Drive at /content/drive. This cell needs the Colab runtime,
# which is what provides the google.colab package.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install update transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils
import torch
import random
from matplotlib import pyplot as plt
%matplotlib inline
from IPython import display
import numpy as np
from tqdm import tqdm
from time import time
import json
# Only show errors from the transformers library (hides download/info warnings).
utils.logging.set_verbosity_error()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_len= 50257
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side='left')
model = GPT2LMHeadModel.from_pretrained("gpt2",pad_token_id=tokenizer.eos_token_id, vocab_size=vocab_len).to(device)
model.eval()
# The model stays in evaluation (not training) mode throughout.
model.requires_grad_(False)
# The GPT-2 weights are never trained in this notebook - only *input embeddings*
# are optimised - so they are frozen. Gradients still flow back to any
# `inputs_embeds` that require grad, but no weight gradients are computed or
# accumulated, and tensors derived from the embedding matrix (e.g. k-means
# centroids) no longer drag an autograd graph along.
word_embeddings = model.transformer.wte.weight.to(device)
embedding_dim = word_embeddings.shape[-1]
# 'word_embeddings' holds the embedding of every token in the model's vocabulary;
# it has shape (vocab_len, embedding_dim), which for this model is (50257, 768).

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting update
  Downloading update-0.0.1-py2.py3-none-any.whl (2.9 kB)
Collecting style==1.1.0
  Downloading style-1.1.0-py2.py3-none-any.whl (6.4 kB)
Installing collected packages: style, update
Successfully installed style-1.1.0 update-0.0.1


Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [None]:
# code to produce plots of token embeddings

from torch.nn.functional import normalize
from sklearn.manifold import TSNE

# Scale every token embedding to unit L2 norm
tensor = normalize(word_embeddings, p=2, dim=1)

# Convert the tensor to a numpy array
tensor = tensor.detach().cpu().numpy()

# Project the embeddings to 2-D with t-SNE (fixed random_state for repeatability)
tsne = TSNE(n_components=2, random_state=0)
reduced_tensor = tsne.fit_transform(tensor)

plt.rcParams.update({'figure.figsize': (100,100)})

# Plot the reduced tensor (white markers: only the text labels are meant to be seen)
plt.scatter(reduced_tensor[:, 0], reduced_tensor[:, 1], marker='.', color='white')

# Label a random sample of 5026 tokens (roughly 10% of the vocabulary).
# The *indices* are shuffled, not the rows, so that each label is drawn at the
# position of the token it names. (Shuffling the rows and labelling them with
# the enumerate counter attaches tokens 0..5025 to unrelated points.)
for i in np.random.permutation(len(reduced_tensor))[:5026]:
    token_id = int(i)
    try:
        label = repr(tokenizer.decode(token_id))
        # Skip labels containing '$' - presumably because matplotlib would try
        # to parse them as mathtext.
        if '$' not in label:
            point = reduced_tensor[token_id]
            plt.annotate(label, (point[0], point[1]), fontsize=15, xytext=(0, 0), textcoords='offset points')
    except Exception:
        # Best effort: a label that cannot be decoded or drawn is simply left out.
        continue
plt.show()


In [None]:
# more code to produce plots of token embeddings

plt.rcParams.update({'figure.figsize': (100,100)})

# Plot the 2-D t-SNE projection computed in the previous cell ('reduced_tensor')
plt.scatter(reduced_tensor[:, 0], reduced_tensor[:, 1], marker='.', color='white')

# Label a fresh random sample of 5026 tokens. The *indices* are shuffled, not
# the rows, so each label lands on the point belonging to that token.
for i in np.random.permutation(len(reduced_tensor))[:5026]:
    token_id = int(i)
    try:
        label = repr(tokenizer.decode(token_id))
        # Skip labels containing '$' - presumably because matplotlib would try
        # to parse them as mathtext.
        if '$' not in label:
            point = reduced_tensor[token_id]
            plt.annotate(label, (point[0], point[1]), fontsize=15, xytext=(0, 0), textcoords='offset points')
    except Exception:
        # Best effort: a label that cannot be decoded or drawn is simply left out.
        continue
plt.show()

In [None]:
def normalise(x, min_max=[]):
    """Rescale the values of an array or tensor to a target range.

    `x` is first mapped linearly onto [0, 1] using its own global minimum and
    maximum (this step is skipped when all values are equal). If `min_max`
    holds at least two entries, the result is then stretched to
    [min_max[0], min_max[1]]; any further entries are ignored. With fewer than
    two entries the [0, 1] result is returned as is.

    The bounds need not be scalars: they may be arrays/tensors that broadcast
    against `x` (this is how per-dimension bounds are supplied when random
    starting embeddings are built).
    """
    lowest = x.min()
    span = x.max() - lowest
    if span > 0:
        x = (x - lowest) / span

    if len(min_max) > 1:
        target_min, target_max = min_max[0], min_max[1]
        x = x * (target_max - target_min) + target_min

    return x


def closest_tokens(emb, n=1):
    """Find the `n` vocabulary tokens nearest (Euclidean distance) to `emb`.

    `emb` is a single embedding vector; it does not have to be the embedding
    of an actual token.

    Returns a 4-tuple, each element ordered from nearest to farthest:
        (decoded token strings, vocab indices, distances from `emb`,
         embedding vectors of those tokens).
    """
    torch.cuda.empty_cache()
    # Distance from `emb` to every row of the embedding matrix, then a full sort
    # so that the first n entries are the nearest tokens.
    all_dists = torch.linalg.norm(word_embeddings - emb, dim=1)
    ordered_dists, order = torch.sort(all_dists)
    nearest_ix = order[:n]
    nearest_tokens = [tokenizer.decode(token_ix) for token_ix in nearest_ix]
    return nearest_tokens, nearest_ix, ordered_dists[:n], word_embeddings[nearest_ix]


def model_emb(inputs_embeds, output_len):
    """Greedily generate `output_len` tokens from a batch of input embeddings.

    No target output is involved: the embeddings are fed to the model, the
    highest-logit next token is chosen, its embedding is appended to the
    sequence, and the process repeats `output_len` times.

    Args:
        inputs_embeds: tensor of shape (batch_size, input_len, embedding_dim).
        output_len: number of tokens to generate (must be at least 1).

    Returns:
        logits: (batch_size, output_len, vocab_size) - the next-token logits
            produced at each generation step.
        embs: (batch_size, input_len + output_len, embedding_dim) - the input
            embeddings followed by the embeddings of the generated tokens.
        perp: (batch_size,) - `perplexity` of the logits from the first forward
            pass, i.e. the pass that saw only the input embeddings.

    Raises:
        ValueError: if `output_len` is smaller than 1.
    """
    if output_len < 1:
        raise ValueError('output_len must be at least 1, got {}'.format(output_len))

    embs = inputs_embeds   # grows by one position along dim 1 on every step
    step_logits = []
    input_logits = None
    for step in range(output_len):
        # Forward pass on the whole sequence so far; only the logits are used.
        model_out = model(inputs_embeds=embs, return_dict=True)

        if step == 0:
            # Logits for every input position: (batch_size, input_len, vocab_size).
            input_logits = model_out.logits

        # Logits at the last position, kept 3-D: (batch_size, 1, vocab_size).
        last_logits = model_out.logits[:, -1].unsqueeze(1)
        step_logits.append(last_logits)

        # Greedy choice of next token, shape (batch_size, 1); indexing the
        # embedding matrix with it gives (batch_size, 1, embedding_dim).
        next_ix = torch.argmax(last_logits, dim=-1)
        embs = torch.cat([embs, word_embeddings[next_ix]], dim=1)

    logits = torch.cat(step_logits, dim=1)
    perp = perplexity(input_logits)
    return logits, embs, perp


def perplexity(logits):
    """Uncertainty score for a batch of sequences, computed from their logits.

    For every position the largest softmax probability is taken; the score is
    the reciprocal of the geometric mean of those probabilities, minus 1. It is
    0 when the model is fully confident at every position and grows as the
    predictions get flatter. (Note this uses the *most likely* token at each
    position, not the token that actually follows.)

    Taking the geometric mean over the sequence length makes scores comparable
    across sequences of different lengths.

    Args:
        logits: tensor of shape (batch_size, sequence_length, vocab_size).

    Returns:
        Tensor of shape (batch_size,).
    """
    # Work in log space: 1 / prod(p)**(1/n) == exp(-mean(log p)). Multiplying
    # the raw probabilities underflows to 0 for long sequences, which would
    # turn the result into inf.
    max_log_probs = torch.log_softmax(logits, dim=-1).max(dim=-1).values
    perp = torch.exp(-max_log_probs.mean(dim=-1)) - 1
    return perp


# The key function that optimises for a sequence of input embeddings, given a target_output string:
def optimise_input(epochs=100, 
                   lr=0.1, 
                   rand_after=False,    # Do we re-initialise inputs tensor with random entries when an optimal input is found?
                   w_freq=10,           # logging (write) frequency
                   base_input=None,      # If none, start_inputs will be entirely random; 
                                         # otherwise it will be built by stacking this tensor and then gently "noising" all but the first copies
                   batch_size=1, 
                   input_len=1, 
                   target_output=tokenizer.eos_token,    # Default target output is the "end-of-string" token; this won't generally be used
                   output_len=None,
                   dist_reg=1,       # distance regularisation coefficient
                   perp_reg=0,       # perplexity regularisation coefficient; setting to 0 means perplexity loss isn't a thing
                   plt_loss=False,   # Do we plot loss?
                   loss_type='log_prob_loss', 
                   seed=0,
                   return_early=True,    # finishes if single optimised input is found
                   verbose=0,            # Controls how much info gets logged.
                   lr_decay=False,       # Use learning rate decay? If so, a scheduler gets invoked.
                   noise_coeff = 0.01):     # Introduced for generality in the construction of start_input[1:] below.
    """Optimise a batch of input embeddings so the model outputs `target_output`.

    A batch of `input_len` embeddings is trained with Adam. Each epoch the
    (clamped) embeddings are run through `model_emb`, a loss rewarding the
    target tokens is combined with a distance-to-nearest-token penalty and an
    optional perplexity penalty, and the embeddings are snapped to their
    nearest real tokens to check, via `model.generate`, whether those tokens
    actually produce the target.

    Args:
        epochs: number of optimisation steps.
        lr: Adam learning rate.
        rand_after: if True, an input whose nearest tokens produced the target
            is re-initialised with uniform random values so the search continues.
        w_freq: log every `w_freq` epochs.
        base_input: None for a random start, or a tensor of shape
            (input_len, embedding_dim) that is repeated `batch_size` times; all
            copies but the first get uniform noise in
            [-noise_coeff/2, noise_coeff/2] added.
        batch_size: number of inputs optimised in parallel.
        input_len: number of embeddings per input.
        target_output: string the model should produce.
        output_len: number of tokens to generate; raised to the token length of
            `target_output` when None or smaller than that.
        dist_reg: weight of the distance-to-nearest-token term.
        perp_reg: weight of the input perplexity term.
        plt_loss: plot the loss curves when logging.
        loss_type: 'logit_loss', 'log_prob_loss', 'prob_loss' or 'CE'.
        seed: seed for torch's random number generator.
        return_early: return as soon as one optimised input has been found.
        verbose: 0-3, amount of logging.
        lr_decay: halve the learning rate on loss plateaus (ReduceLROnPlateau).
        noise_coeff: width of the noise added to the copies of `base_input`.

    Returns:
        The set of decoded input strings whose tokens made the model generate
        the target, or None if `loss_type` is not recognised.

    Raises:
        TypeError: if `base_input` is neither None nor a tensor.
        ValueError: if `base_input` does not provide `input_len` embeddings.
    """
    torch.manual_seed(seed)               # sets up PyTorch random number generator

    if plt_loss:
        plt.rcParams.update({'figure.figsize': (40,6)})

    # Per-epoch histories (plain floats), only used for plotting.
    total_losses = []
    losses = []
    dists = []
    perps = []
    optimised_inputs = set()
    done = None    # decoded text of the most recently found optimised input

    # 1-D tensor with the vocab indices of the tokens encoding `target_output`.
    output_ix = tokenizer.encode(target_output, return_tensors='pt')[0].to(device)
    n_target = output_ix.shape[0]

    if output_len is None or output_len < n_target:
        output_len = n_target

    print('Optimising input of length {} to maximise output logits for "{}"'.format(input_len, target_output))

    if base_input is None:
        # Random start, rescaled per embedding dimension to the range spanned by
        # the vocabulary, i.e. restricted to the bounding box of the real tokens.
        start_input = torch.rand(batch_size, input_len, word_embeddings.shape[-1]).to(device)
        start_input = normalise(start_input,[word_embeddings.min(dim=0)[0], word_embeddings.max(dim=0)[0]])
    elif not torch.is_tensor(base_input):
        raise TypeError('base_input must be None or a tensor of shape (input_len, embedding_dim), got {}'.format(type(base_input).__name__))
    else:
        # Stack batch_size copies of the base input...
        start_input = base_input.to(device).repeat(batch_size, 1, 1)
        if start_input.shape[1] != input_len:
            raise ValueError('base_input provides {} embeddings but input_len is {}'.format(start_input.shape[1], input_len))
        # ...and add small zero-centred uniform noise to all copies but the first.
        if batch_size > 1:
            start_input[1:] += (torch.rand_like(start_input[1:]) + torch.full_like(start_input[1:], -0.5)) * noise_coeff

    # Wrapping in a Parameter makes the embeddings a trainable leaf tensor.
    input_embs = torch.nn.Parameter(start_input, requires_grad=True)

    optimiser = torch.optim.Adam([input_embs], lr=lr)
    # Only stepped when lr_decay is True: halves the lr after 20 epochs without
    # improvement, then waits 20 epochs before monitoring again.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5)

    # Global bounds of the embedding matrix, used to clamp the inputs on every
    # forward pass. They never change, so they are computed once; detach() keeps
    # them out of the per-epoch autograd graphs.
    clamp_min = word_embeddings.min().detach()
    clamp_max = word_embeddings.max().detach()

    for epoch in range(epochs):

        # Forward pass on a clamped copy of the inputs.
        # logits: (batch_size, output_len, vocab_size)
        # emb:    (batch_size, input_len + output_len, embedding_dim)
        # perp:   (batch_size,)
        logits, emb, perp = model_emb(torch.clamp(input_embs, clamp_min, clamp_max), output_len)
        probs = torch.softmax(logits, dim=-1)

        # Min-max normalise the logits over the vocabulary to [0, 1] at every
        # output position.
        logit_min = logits.min(dim=-1)[0].unsqueeze(-1)
        logit_max = logits.max(dim=-1)[0].unsqueeze(-1)
        logits = (logits - logit_min) / (logit_max - logit_min)

        perp_loss = perp.mean() * perp_reg

        if output_len > n_target:
            # More output positions than target tokens: at every position take
            # the best score among all target tokens -> (batch_size, output_len).
            target_logits = torch.stack([logits[:, :, ix] for ix in output_ix], dim=-1)
            target_logits = torch.max(target_logits, dim=-1)[0]
            target_probs = torch.stack([probs[:, :, ix] for ix in output_ix], dim=-1)
            target_probs = torch.max(target_probs, dim=-1)[0]
        else:
            # output_len == n_target: score of the i-th target token at the i-th
            # output position -> (batch_size, output_len).
            target_logits = torch.stack([logits[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)
            target_probs = torch.stack([probs[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)

        # Distance from every input embedding to its nearest token embedding and
        # the index of that token, both of shape (batch_size, input_len). One
        # cdist call replaces two full-vocabulary sorts per embedding per epoch;
        # the exact (non-matmul) mode is used because these distances are driven
        # towards zero, where the matmul shortcut loses precision.
        flat_inputs = input_embs.reshape(-1, input_embs.shape[-1])
        all_dists = torch.cdist(flat_inputs, word_embeddings, p=2, compute_mode='donot_use_mm_for_euclid_dist')
        nearest_dist, nearest_ix = all_dists.min(dim=-1)
        token_dist = nearest_dist.reshape(input_embs.shape[0], input_embs.shape[1])
        closest_ix = nearest_ix.reshape(input_embs.shape[0], input_embs.shape[1])
        mean_token_dist = token_dist.mean() * dist_reg

        if loss_type == 'logit_loss':
            loss = torch.mean(1-target_logits)
        elif loss_type == 'log_prob_loss':
            loss = -torch.log(target_probs).mean()
        elif loss_type == 'prob_loss':
            loss = 1-torch.mean(target_probs)
        elif loss_type == 'CE':
            # NOTE: operates on the min-max normalised logits and needs
            # output_len to equal the number of target tokens.
            loss = torch.nn.functional.cross_entropy(logits.swapaxes(-1,-2), output_ix.repeat(batch_size, 1))
        else:
            print(loss_type + ' is not implemented.')
            return

        # Unweighted mean of the three (already individually weighted) terms.
        total_loss = torch.stack([mean_token_dist, loss, perp_loss]).mean()

        total_losses.append(total_loss.item())
        losses.append(loss.item())
        dists.append(mean_token_dist.item())
        perps.append(perp_loss.item())

        # Feed the nearest *real* tokens to the model and let it generate; the
        # result holds the input tokens followed by the generated tokens.
        model_outs = model.generate(closest_ix, max_length = output_len+input_len)

        for b in range(batch_size):
            generated_text = tokenizer.decode(model_outs[b][input_len:])
            if output_len > n_target:
                # Longer output than target: the target must appear somewhere in it.
                found = target_output in generated_text
            else:
                found = generated_text == target_output

            if found:
                done = tokenizer.decode(model_outs[b][:input_len])
                optimised_inputs.add(done)
                if rand_after:
                    # Restart this batch element from uniform [0, 1) noise so the
                    # search can look for further inputs. Done only for an element
                    # that has just succeeded.
                    # NOTE(review): unlike the initial random start this is not
                    # rescaled to the vocabulary's range.
                    input_embs.data[b] = torch.rand_like(input_embs[b])

        stop_now = done is not None and return_early
        if ((epoch+1) % w_freq == 0) or stop_now:
            display.clear_output(wait=True)

            if plt_loss:
                plt.plot(range(len(total_losses)), total_losses, label='Total Loss', color='black')
                plt.plot(range(len(losses)), losses, label='Output Loss')
                plt.plot(range(len(dists)), dists, label='Emb Dist Loss')
                plt.plot(range(len(perps)), perps, label='Perp Loss')
                plt.yscale('log')
                plt.legend()

                plt.show()

            print('Inputs found: ', optimised_inputs)
            print('{}/{} Output Loss: {} Emb Dist Loss: {} Perp Loss: {} LR: {}'.format(epoch+1, epochs, loss, mean_token_dist, perp_loss, optimiser.param_groups[0]['lr']))
            if verbose == 3:
                print('Target Probs: {}\nTarget Logits: {}\nInput Dists: {}\nInput Perplexity: {}\n'.format(target_probs.detach().cpu().numpy(), target_logits.detach().cpu().numpy(), token_dist.detach().cpu().numpy(), perp.detach().reshape(-1).cpu().numpy()))

            for b in range(batch_size):
                if verbose > 0:
                    if verbose == 2:
                        # Raw input embeddings were pushed through the model; every
                        # position (input and output) is then snapped to its nearest
                        # token for display.
                        print(b, repr(' Raw embeddings: {}'.format(''.join([closest_tokens(vec)[0][0] for vec in emb[b]]))))
                    # Inputs were snapped to their nearest tokens *before* being
                    # pushed through the model, so the generated part may differ
                    # from the 'Raw embeddings' line.
                    print(b, repr(' Closest embeddings: {}'.format(tokenizer.decode(model_outs[b]))))
                else:
                    print(repr(tokenizer.decode(model_outs[b])), end=' ')

            if stop_now:
                print('\nOptimised Input: "{}"'.format(done))
                return optimised_inputs

        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()

        if lr_decay:
            scheduler.step(total_loss)

    return optimised_inputs
    # that's a set of strings


In [None]:
# attempt to implement k-means algorithm

@torch.no_grad()
def kmeans(num_clusters, threshold=0, max_iterations=None):
    """Cluster the token embeddings with k-means and report on the result.

    Centroids are initialised at random inside the per-dimension range of the
    vocabulary embeddings (re-drawn until no cluster is empty), then updated
    until no centroid moves by more than `threshold`.

    Runs without autograd: nothing here needs gradients.

    Args:
        num_clusters: number of clusters / centroids.
        threshold: stop once the largest centroid movement (Euclidean) in an
            iteration is <= this value. The default 0 means "until nothing moves".
        max_iterations: optional hard cap on the number of update iterations
            (None = no cap).

    Returns:
        Tensor of shape (num_clusters, embedding_dim) with the final centroids.
    """
    emb_min = word_embeddings.min(dim=0)[0]
    emb_max = word_embeddings.max(dim=0)[0]

    # --- Initialisation: redraw random centroids until every one of them is
    # the nearest centroid of at least one token embedding.
    while True:
        centroids = torch.rand(num_clusters, embedding_dim).to(device)
        centroids = normalise(centroids, [emb_min, emb_max])

        # distances: (vocab_len, num_clusters); closest_centroid: (vocab_len,)
        distances = torch.cdist(word_embeddings, centroids, p=2)
        closest_centroid = torch.min(distances, dim=-1)[1]

        cluster_sizes = [int((closest_centroid == i).sum()) for i in range(num_clusters)]
        if 0 not in cluster_sizes:
            break

    # --- Iterate: assign tokens to nearest centroid, move centroids to the mean.
    iterations = 0
    while True:
        iterations += 1
        distances = torch.cdist(word_embeddings, centroids, p=2)
        closest_centroid = torch.min(distances, dim=-1)[1]

        # clusters[i] holds the embeddings of all tokens whose nearest centroid is i.
        clusters = [word_embeddings[closest_centroid == i] for i in range(num_clusters)]

        # New centroid = mean of its cluster. A cluster that has become empty
        # keeps its previous centroid (the mean of no rows would be NaN).
        new_centroids = torch.stack([
            cluster.mean(dim=0) if cluster.shape[0] > 0 else centroids[i]
            for i, cluster in enumerate(clusters)
        ])

        max_shift = torch.norm(new_centroids - centroids, dim=-1).max().item()
        print('iteration: ', iterations, '\n max distance between old and new centroids:', max_shift)

        centroids = new_centroids
        if max_shift <= threshold:
            break   # centroids have stabilised
        if max_iterations is not None and iterations >= max_iterations:
            break   # give up waiting for full convergence

    # --- Report
    for i in range(len(clusters)):
        print('Cluster', i, ' contains ', len(clusters[i]), ' embeddings.')
        print('Cluster', i, 'has centroid ', centroids[i][range(8)], '...')
    for j in range(num_clusters):
        cluster = clusters[j]
        if cluster.shape[0] == 0:
            print('Cluster', j, 'is empty.')
            print('\n')
            continue
        # Sanity check: find the embedding in the cluster nearest to its centroid.
        centroid = centroids[j].unsqueeze(0)
        cluster_distances = torch.norm(cluster - centroid, dim=1)

        min_distance, min_index = torch.min(cluster_distances, dim=0)
        print('Smallest distance from centroid', j, 'to a token embedding in cluster', j, '= ', min_distance.item())
        print('The relevant token is ', closest_tokens(cluster[min_index])[0])
        print('Its embedding starts ', cluster[min_index][range(8)], '...')
        # Nearest token over the whole vocabulary (looked up once, reused below).
        nearest_tokens, nearest_ix, _, _ = closest_tokens(centroids[j])
        print('Closest token from entire vocab to centroid', j,':', nearest_tokens,', vocab index', nearest_ix.item())
        print('Distance from centroid', j, 'to this closest token =', torch.norm(word_embeddings[nearest_ix.item()] - centroid))
        print('\n')
    # Vocab index 30212 is the token the message below calls 'externalToEVA'.
    print('Distance from externalToEVA to centroids = ', [torch.norm(word_embeddings[30212] - centroids[j]).item() for j in range(num_clusters)])
    return centroids

In [None]:
# Cluster the token embeddings into 25 groups. The call prints its progress and
# per-cluster report; as the cell's last expression, the returned centroid
# tensor is also shown as the cell output.
kmeans(25)

iteration:  1 
 max distance between old and new centroids: 9.38183879852295
iteration:  2 
 max distance between old and new centroids: 0.857846736907959
iteration:  3 
 max distance between old and new centroids: 0.9770517349243164
iteration:  4 
 max distance between old and new centroids: 0.6848060488700867
iteration:  5 
 max distance between old and new centroids: 0.4010399281978607
iteration:  6 
 max distance between old and new centroids: 0.25260230898857117
iteration:  7 
 max distance between old and new centroids: 0.26106947660446167
iteration:  8 
 max distance between old and new centroids: 0.3878091275691986
iteration:  9 
 max distance between old and new centroids: 0.5297049880027771
iteration:  10 
 max distance between old and new centroids: 0.27062681317329407
iteration:  11 
 max distance between old and new centroids: 0.13183221220970154
iteration:  12 
 max distance between old and new centroids: 0.08200440555810928
iteration:  13 
 max distance between old and n

tensor([[-0.0271, -0.0444,  0.1192,  ...,  0.0087,  0.0022,  0.0960],
        [-0.0097, -0.0867,  0.2000,  ...,  0.0221,  0.0038,  0.0425],
        [ 0.0087, -0.0629,  0.1017,  ...,  0.0319, -0.0032,  0.0456],
        ...,
        [ 0.0060, -0.0658,  0.1873,  ...,  0.0349,  0.0032,  0.0447],
        [ 0.0136, -0.0706,  0.1417,  ...,  0.0286, -0.0112,  0.0104],
        [ 0.0091, -0.0927,  0.0833,  ...,  0.0421, -0.0138,  0.0427]],
       device='cuda:0', grad_fn=<StackBackward0>)

In [None]:
# Push a prompt through the model and print prompt + continuation.
prompt = ""
ix = tokenizer.encode(prompt)
# list of 'vocab indices'
print(ix)
print([tokenizer.decode(i) for i in ix])
# prints reconstruction of input string, token by token
print(len(ix))
# prints number of tokens
output_len=2
# An empty prompt encodes to no tokens, and torch.tensor([]) would be an empty
# *float* tensor that generate() cannot consume. In that case start from the
# beginning-of-sequence token; the dtype is always forced to long.
prompt_ix = ix if len(ix) > 0 else [tokenizer.bos_token_id]
input_ids = torch.tensor(prompt_ix, dtype=torch.long).unsqueeze(0).to(device)
model_out = model.generate(input_ids, max_length = output_len + len(prompt_ix))
print(tokenizer.decode(model_out[0]))
# iteratively produces output_len tokens, then prints input + output.

In [None]:
from time import time

# Target continuation text and number of input tokens for this run.
target_output = " a lot of data."
input_len = 3

# All keyword arguments for this optimise_input call, gathered in one place.
# NOTE(review): the experiments cell further down uses ' a lot of data' (no
# trailing period) with the same output_len=4; confirm which is intended.
optimise_kwargs = dict(
    base_input=True,
    plt_loss=False,
    verbose=2,
    epochs=500,
    lr_decay=False,
    return_early=False,
    lr=0.1,
    batch_size=20,
    target_output=target_output,
    output_len=4,
    input_len=input_len,
    w_freq=20,
    dist_reg=1,
    perp_reg=0,
    loss_type='log_prob_loss',
    noise_coeff=0.75,
)

# Time the optimisation with wall-clock timestamps taken around the call.
tic = time()
oi = optimise_input(**optimise_kwargs)
toc = time()
tt = toc - tic
print('Time Taken: ', tt)

In [None]:
# Each entry is one complete set of keyword arguments; the experiment loop
# unpacks an entry straight into optimise_input(**entry).
experiments = [
    dict(
        base_input=True,
        plt_loss=False,
        verbose=1,
        epochs=1000,
        lr_decay=False,
        return_early=False,
        lr=0.1,
        batch_size=50,
        target_output=' a lot of data',
        output_len=4,
        input_len=3,
        w_freq=20,
        dist_reg=1,
        perp_reg=0,
        loss_type='log_prob_loss',
        note='',
    ),
]

# Starts out empty; nothing in this cell fills it.
experiment_log = dict()

In [None]:

# Run every configured experiment, time it, and append its results to a log file.
for e in experiments:
    tick = time()
    results = optimise_input(**e)
    tock = time()
    rt = tock - tick
    results.update({'runtime':rt})

    # Write one JSON object per line (JSON Lines). The file is opened in append
    # mode, so without a separator successive records run together into text
    # that no JSON parser can read back. The start timestamp `tick` keys each
    # record (json.dumps turns the float key into a string).
    with open("backwards_results.json","a") as f:
        f.write(json.dumps({tick:results}) + "\n")


In [None]:
# Runs k-means 500 times with 50 centroids per run and tallies, for every vocab
# index, how many times it comes back as the closest token to a centroid embedding.
token_counts = torch.zeros(vocab_len)
for batch_no in range(500):
    print('batch', batch_no)
    centroids = kmeans(50)
    for centroid_no in range(50):
        # Element [1] of closest_tokens' result is used as the vocab index to count.
        nearest_idx = closest_tokens(centroids[centroid_no])[1].item()
        token_counts[nearest_idx] += 1

# Sort the tallies from most to least frequent and show the top 50 indices and counts.
values, indices = token_counts.sort(descending=True)
print(indices[:50], values[:50])

In [None]:
# Hard-coded vocab indices of the most common tokens (presumably copied from the
# output of the token-count cell above; verify before relying on them).
most_common_tokens_idxs = [
    30212, 187, 195, 216, 182, 179, 213, 39820, 199, 124,
    208, 125, 23090, 554, 30208, 281, 3607, 7003, 192, 37528,
    15524, 217, 39752, 42089, 183, 818, 210, 201, 209, 207,
    211, 206, 1026, 189, 190, 1315, 219, 205, 212, 203,
    287, 188, 30898, 45544, 14827, 218, 30897, 202, 181, 30905,
]
# Decode each index individually so every token is shown as its own string.
most_common_tokens = list(map(tokenizer.decode, most_common_tokens_idxs))
print(most_common_tokens)
# Embedding of the first index in the list (30212).
print(word_embeddings[most_common_tokens_idxs[0]])

In [None]:
# Run kmeans with argument 50 (presumably the number of centroids; confirm against
# the kmeans definition). As the last expression in the cell, the call's return
# value is displayed as the cell output.
kmeans(50)



iteration:  1 
 max distance between old and new centroids: 9.975008010864258
iteration:  2 
 max distance between old and new centroids: 1.3475738763809204
iteration:  3 
 max distance between old and new centroids: 1.0740748643875122
iteration:  4 
 max distance between old and new centroids: 0.8603959679603577
iteration:  5 
 max distance between old and new centroids: 0.9484898447990417
iteration:  6 
 max distance between old and new centroids: 0.6111389994621277
iteration:  7 
 max distance between old and new centroids: 0.33316004276275635
iteration:  8 
 max distance between old and new centroids: 0.25298264622688293
iteration:  9 
 max distance between old and new centroids: 0.1368548721075058
iteration:  10 
 max distance between old and new centroids: 0.1660463809967041
iteration:  11 
 max distance between old and new centroids: 0.18448098003864288
iteration:  12 
 max distance between old and new centroids: 0.3939077854156494
iteration:  13 
 max distance between old and n

tensor([[ 0.0104, -0.0040,  0.1911,  ...,  0.0722,  0.0364,  0.0326],
        [-0.0345, -0.0401,  0.1126,  ..., -0.0033,  0.0024,  0.0945],
        [ 0.0398, -0.0695,  0.1101,  ...,  0.0185,  0.0169,  0.0675],
        ...,
        [-0.0405, -0.0464,  0.1556,  ..., -0.0407, -0.0773,  0.1004],
        [ 0.0124, -0.0839,  0.1270,  ...,  0.0342,  0.0066,  0.0149],
        [ 0.0070, -0.0486,  0.0823,  ...,  0.0153, -0.0139,  0.0207]],
       device='cuda:0', grad_fn=<StackBackward0>)