In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from transformers import BertTokenizer, BertModel, AutoTokenizer, T5Tokenizer, T5EncoderModel, AutoModel
import torch
import json
import tqdm

In [None]:
# --- configuration ---------------------------------------------------------
# final_n: number of k-means clusters; one "centroid word" per cluster is kept
# for the reduced vocabulary written out further down.
final_n = 2000
# Hugging Face checkpoint whose encoder input-embedding table gets clustered.
model_name = 'google/flan-t5-base'

# Tokenizer and model for that checkpoint (tokenizer is loaded first, as before).
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)

In [None]:
def get_t5_embedding(word):
    """Embed `word` as the mean of its tokens' encoder input embeddings.

    The word is tokenized without special tokens, each token id is looked up in
    the global `model.encoder.embed_tokens` table, and the vectors are averaged
    over the token axis. Gradients are disabled for the lookup.

    Args:
        word: text to embed, passed straight to the global `tokenizer`.

    Returns:
        numpy array holding the averaged embedding, with size-1 axes squeezed
        away (one vector for a single word).
    """
    token_ids = tokenizer.encode(word, add_special_tokens=False, return_tensors='pt')
    with torch.no_grad():
        piece_vectors = model.encoder.embed_tokens(token_ids)
        mean_vector = piece_vectors.mean(dim=1)  # average over the token axis
        return mean_vector.squeeze().numpy()

In [None]:
def generate_embeddings(word_dict):
    """Embed every key of `word_dict` with `get_t5_embedding`.

    Only the keys are used; the dict values are ignored.

    Args:
        word_dict: mapping whose keys are the words to embed.

    Returns:
        (matrix, mapping) where `mapping` is word -> embedding and row i of
        `matrix` is the embedding of the i-th key in `word_dict`'s iteration order.
    """
    word_to_vector = {word: get_t5_embedding(word) for word in word_dict}
    # Dicts keep insertion order, so the rows line up with word_dict's key order.
    return np.array(list(word_to_vector.values())), word_to_vector

In [None]:
def remove_special_token(text: str, special_token: str) -> str:
    """Return `text` with every occurrence of `special_token` deleted."""
    stripped = text.replace(special_token, '')
    return stripped

In [None]:
# Reuse the tokenizer already loaded for `model_name` rather than loading it again. This also
# keeps the cell re-runnable after the JSON-export cell has rewritten `model_name` into a
# filename slug (with '-' instead of '/'), which from_pretrained could not resolve.
t5_tokenizer = tokenizer

vocab = t5_tokenizer.get_vocab()
# all_special_tokens lists every special token registered on the tokenizer. Listing the
# sep/cls attributes by hand contributed None entries when those tokens are unset (as they
# typically are for T5).
special_tokens = t5_tokenizer.all_special_tokens
special_space = '▁'  # SentencePiece word-boundary marker; only pieces containing it are kept
excluded_tokens = set(special_tokens)
# get_vocab() on a fast tokenizer is not guaranteed to be ordered. Its order becomes the row
# order of word_embeddings, which feeds the seeded KMeans below, so sort by token id to make
# clustering reproducible across kernel restarts.
vocab = {
    word: index
    for word, index in sorted(vocab.items(), key=lambda item: item[1])
    if word not in excluded_tokens and special_space in word
}
word_embeddings, embed_dict = generate_embeddings(vocab)


In [None]:
vocab_key, vocab_value = list(vocab.keys()), list(vocab.values())
# Row i of word_embeddings must describe vocab_key[i]; the centroid-word lookup relies on it.
# This catches out-of-order runs where vocab was rebuilt without regenerating embeddings.
assert len(vocab_key) == len(word_embeddings), 'vocab and word_embeddings rows are out of sync'
print(len(vocab_key), len(vocab_value))

In [None]:
# save embed_dict to json (numpy arrays are not JSON-serialisable, so store plain lists)
new_embed_dict = {word: vector.tolist() for word, vector in embed_dict.items()}

# Hub ids such as 'google/flan-t5-base' contain '/', which open() would treat as a directory
# separator. Flatten it for every model, not only T5 ones: the old non-T5 branch kept the '/'
# and failed to open the file. NOTE: this rebinds the global `model_name`; later cells build
# filenames from it.
model_name = model_name.replace('/', '-')
# ensure_ascii=False writes '▁' and other non-ASCII pieces verbatim, so force UTF-8 instead of
# relying on the platform's default encoding.
with open(str(final_n) + '-' + model_name + '-embed_dict.json', 'w', encoding='utf-8') as fp:
    json.dump(new_embed_dict, fp, indent=4, ensure_ascii=False)

In [None]:
# Sanity check: every saved vector should be as wide as a row of word_embeddings.
assert all(len(vec) == word_embeddings.shape[1] for vec in new_embed_dict.values())
print(len(new_embed_dict['▁hello']))

In [None]:
# Display one stored vector; values in new_embed_dict are plain lists produced by .tolist().
new_embed_dict['▁hello']


In [None]:
# visualize the clusters: project the token embeddings to 2-D with PCA for plotting
from sklearn.decomposition import PCA

print(word_embeddings.shape)
# PCA's default 'auto' solver can pick a randomized SVD for large inputs; fix random_state
# (as KMeans and TSNE do below) so the 2-D layout is reproducible across runs.
pca = PCA(n_components=2, random_state=0)
pca_embeddings = pca.fit_transform(word_embeddings)
pca_embeddings.shape


In [None]:
# implement kmeans clustering
from sklearn.cluster import KMeans

# Partition the token embeddings into `final_n` clusters; random_state pins the seeding.
kmeans = KMeans(random_state=0, n_clusters=final_n)
kmeans.fit(word_embeddings)  # fit() returns the estimator itself, so `kmeans` is the fitted model
kmeans

In [None]:
# list of the centroid words: for each k-means centroid, the vocabulary word whose embedding
# is closest to it in Euclidean distance (row i of word_embeddings is vocab_key[i]).
centroid_words = []
for centroid in kmeans.cluster_centers_:
    nearest_index = int(np.argmin(np.linalg.norm(word_embeddings - centroid, axis=1)))
    centroid_words.append(vocab_key[nearest_index])

# Two centroids can share the same nearest word. The dict built from centroid_words then holds
# fewer than final_n entries, so report that instead of letting it pass silently.
n_distinct_centroid_words = len(set(centroid_words))
if n_distinct_centroid_words < len(centroid_words):
    print(f'warning: only {n_distinct_centroid_words} distinct centroid words '
          f'for {len(centroid_words)} centroids')

# Show a sample rather than all final_n words.
centroid_words[:20]

In [None]:
# use the new centroid words as a new vocab, keeping each word's original token id.
# Duplicate centroid words collapse into one key, so len(new_vocab) can be below final_n.
new_vocab = {word: vocab[word] for word in centroid_words}
# Flatten '/' from hub ids so the file lands in the working directory. This is a no-op if an
# earlier cell has already rewritten model_name, so the cell no longer depends on that one.
kmeans_vocab_path = str(final_n) + '-' + model_name.replace('/', '-') + '-kmeans-vocab.json'
# ensure_ascii=False writes '▁' verbatim, so write UTF-8 explicitly instead of the locale default.
with open(kmeans_vocab_path, 'w', encoding='utf-8') as fp:
    json.dump(new_vocab, fp, indent=4, ensure_ascii=False)

In [None]:
# plot the clusters in PCA space, coloured by k-means label
fig, ax = plt.subplots(figsize=(8, 8))
# Small markers: one point per vocabulary token, so default-size markers overplot heavily.
ax.scatter(pca_embeddings[:, 0], pca_embeddings[:, 1], c=kmeans.labels_, cmap='rainbow', s=2)
ax.set(title=f'{final_n}-means clusters of token embeddings (PCA)', xlabel='PC 1', ylabel='PC 2')
plt.show()

In [None]:
# visualize the clusters using t-sne
import inspect

# scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and later releases drop `n_iter`
# entirely, so pass the iteration budget under whichever keyword this installed version accepts.
tsne_iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=0,
            **{tsne_iter_kw: 1000})
tsne_embeddings = tsne.fit_transform(word_embeddings)


In [None]:
# plot the clusters in t-SNE space, coloured by k-means label
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(tsne_embeddings[:, 0], tsne_embeddings[:, 1], c=kmeans.labels_, cmap='rainbow', s=2)
ax.set(title=f'{final_n}-means clusters of token embeddings (t-SNE)', xlabel='t-SNE 1', ylabel='t-SNE 2')
# plt.show() renders the figure and keeps the PathCollection repr out of the cell output.
plt.show()

In [None]:
# The filtered vocab is large; show its size and the first few entries instead of dumping
# the whole dict into the notebook output.
len(vocab), dict(list(vocab.items())[:20])

In [None]:
# Shape of a single row of word_embeddings (one token's embedding vector).
np.shape(word_embeddings[6])