In [1]:
from transformers import BertModel, BertTokenizer
import numpy as np
from sklearn.manifold import TSNE
import pandas as pd
import torch
import plotly.express as px


# import hvplot.pandas
# import holoviews as hv

# from holoviews import dim, opts

# hv.extension('matplotlib')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Read the title list from disk and keep only the 'Title' column.
titles_path = "../data/nes_titles.csv"
titles = pd.read_csv(titles_path)['Title']

In [5]:
# Collect every whitespace-separated token from the titles, keeping only the
# first occurrence of each so that indices follow first-seen order.
unique_words = dict.fromkeys(
    token for title in titles for token in title.split()
)

# Forward (word -> index) and reverse (index -> word) lookup tables.
word_to_index = {token: position for position, token in enumerate(unique_words)}
index_to_word = {position: token for token, position in word_to_index.items()}

# Vocabulary as a list, ordered by index.
words = list(word_to_index)

### Initial Tokenization Embedding

In [13]:
# Name of the pre-trained checkpoint; the encoder and its tokenizer are both
# loaded from this same identifier so that they stay matched.
model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



In [14]:
# Vocabulary to embed: every distinct word collected from the titles.
# (For a quick experiment, substitute a short hand-picked list such as
#  ['king', 'queen', 'man', 'woman', 'doctor', 'nurse'].)
words = list(word_to_index)

# Run each word through the model on its own, with gradient tracking disabled,
# and average `last_hidden_state` over dim=1 to get one vector per word.
word_vectors = []
with torch.no_grad():
    for word in words:
        encoded = tokenizer(word, return_tensors='pt')
        hidden_states = model(**encoded).last_hidden_state
        pooled = hidden_states.mean(dim=1).squeeze()
        word_vectors.append(pooled.numpy())

# Stack the per-word vectors into a single 2-D array (one row per word).
word_embeddings = np.array(word_vectors)

In [15]:
# Sanity check: one row per entry in `words`, one column per embedding dimension.
word_embeddings.shape

(2710, 768)

In [16]:
# Persist the embedding matrix so later sessions can reload it instead of
# re-running the model. The path already ends in .npy, so np.save writes to
# exactly this file.
np.save('../data/BERT_embeddings.npy', word_embeddings)

### Preloaded embeddings

In [6]:
# Reload the embedding matrix previously written to disk and display its shape.
embeddings_path = '../data/BERT_embeddings.npy'
word_embeddings = np.load(embeddings_path)
word_embeddings.shape

(2710, 768)

In [10]:
# t-SNE settings: project to three components; random_state is fixed so the
# layout is the same on every run.
tsne_params = dict(n_components=3, perplexity=5, random_state=42)
tsne = TSNE(**tsne_params)

# Fit on the full embedding matrix and keep the 3-D coordinates.
word_embeddings_3d = tsne.fit_transform(word_embeddings)
word_embeddings_3d.shape

Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



(2710, 3)

In [31]:
# Shuffle the vocabulary indices with a seeded generator so that the sampled
# subset (and therefore the plot) is identical on every Restart & Run All.
# The seed matches the random_state used for t-SNE.
rng = np.random.default_rng(42)
shuffled_indices = rng.permutation(len(words))

# Number of words to draw from the shuffled order for plotting.
num_rows_to_sample = 100

In [32]:
# Keep the first `num_rows_to_sample` shuffled indices, then gather the
# matching 3-D coordinates and their words in the same order.
sample_idx = shuffled_indices[:num_rows_to_sample]
sampled_rows = word_embeddings_3d[sample_idx]
sampled_words = [words[i] for i in sample_idx]

In [33]:
# Show the full projection's shape and the sampled subset's shape, one per line.
print(word_embeddings_3d.shape, sampled_rows.shape, sep='\n')

(2710, 3)
(100, 3)


In [34]:
# sample_indices = np.random.choice(len(word_embeddings_3d), size=10, replace=False)

In [35]:
# Split the sampled coordinates into their three columns.
xs = sampled_rows[:, 0]
ys = sampled_rows[:, 1]
zs = sampled_rows[:, 2]

# 3-D scatter of the sampled points, each one labelled with its word.
fig = px.scatter_3d(x=xs, y=ys, z=zs, text=sampled_words)
fig.show()

In [24]:
# import plotly.graph_objects as go

# # Assuming word_embeddings_3d is a numpy array containing your word embeddings
# # and words is a list of corresponding words

# # Create 3D scatter plot
# fig = go.Figure(data=[go.Scatter3d(
#     x=word_embeddings_3d[:, 0],
#     y=word_embeddings_3d[:, 1],
#     z=word_embeddings_3d[:, 2],
#     mode='markers',
#     marker=dict(
#         size=8,
#         color='rgb(0,0,255)',  # Change color if needed
#     ),
#     text=words,  # Assign words as text
# )])

# # Add annotations
# annotations = []
# for i, word in enumerate(words):
#     annotation = dict(
#         x=word_embeddings_3d[i, 0],
#         y=word_embeddings_3d[i, 1],
#         z=word_embeddings_3d[i, 2],
#         text=word,
#         showarrow=False,
#         font=dict(
#             color='black',  # Adjust font color if needed
#             size=12,
#         ),
#     )
#     annotations.append(annotation)

# fig.update_layout(
#     scene=dict(
#         annotations=annotations,
#     ),
# )

# fig.show()
