In [17]:
"""
# BERT Embedding Visualization in 3D
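This notebook demonstrates how to visualize BERT (Bidirectional Encoder Representations from Transformers)
sentence embeddings in an interactive 3D plot. It uses the `transformers` library to load a pre-trained BERT
model and takes each sentence's final-layer `[CLS]` vector as that sentence's embedding. It then reduces those
vectors to three dimensions with PCA (`scikit-learn`) and draws them with `plotly`.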

## Overview

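BERT embeddings are contextual: the vector produced for a word or sentence depends on the surrounding text,
unlike static word embeddings that assign each word a single fixed vector. This notebook projects the sentence
vectors onto their top three principal components, so you can rotate and zoom the plot to see which sentences
the model places close together. PCA discards most of the original dimensions, so distances in the plot are
only an approximation of distances in the full embedding space.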
"""
# `torch` is imported directly in the next cell but is not installed as a dependency of `transformers`,
# so list it explicitly for a fresh environment.
# NOTE(review): pin versions (e.g. `transformers==X.Y.Z`) for reproducible runs.
%pip install -q torch transformers scikit-learn plotly

Note: you may need to restart the kernel to use updated packages.


In [24]:
import torch
from transformers import BertModel, BertTokenizer
import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA

# One checkpoint name shared by tokenizer and model, so the vocabulary always
# matches the weights it was trained with.
MODEL_NAME = 'bert-base-uncased'

# Load the pre-trained tokenizer (vocabulary) and model for that checkpoint.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

# from_pretrained already returns the model in evaluation mode; calling eval()
# here makes the inference-only intent (dropout disabled) explicit.
model.eval()

# Sentences to embed, one per entry. An explicit list avoids the empty "" entry
# that splitting a newline-terminated string produces; that empty entry would
# still be tokenized, embedded, included in the PCA fit and plotted as an
# unlabeled point.
text = [
    "king rules the kingdom",
    "queen loves the king",
    "man drives the car",
    "woman rides the bicycle",
    "apple is a fruit",
    "orange is a citrus fruit",
    "I am learning about computer science",
    "The student is eating an apple",
]

# Tokenize every sentence on its own, so each encoding is a batch of one.
tokens = {}
for sentence in text:
    tokens[sentence] = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True)

# Run the model without autograd tracking and keep only position 0 of the
# first output (the [CLS] token) for each sentence.
embeddings = {}
with torch.no_grad():
    for sentence in text:
        outputs = model(**tokens[sentence])
        hidden_states = outputs[0]
        embeddings[sentence] = hidden_states[:, 0, :].numpy()

# Stack one flattened [CLS] vector per sentence into a matrix. Rows follow the
# order of `text`, so row i of the projection lines up with label i in the plot.
all_embeddings = np.array([embeddings[sentence].ravel() for sentence in text])

# Project the embedding matrix onto its top three principal components.
pca = PCA(n_components=3)
reduced_embeddings = pca.fit_transform(all_embeddings)

def _padded_range(values, pad_fraction=0.1):
    """Return ``[low, high]`` covering ``values``, widened by ``pad_fraction`` of the span per side.

    Used for the 3D scene axes so every projected point (and its text label)
    stays visible. PCA coordinates depend on the input sentences, and component
    signs are arbitrary, so hand-tuned fixed ranges can silently clip points.
    """
    low, high = float(np.min(values)), float(np.max(values))
    span = (high - low) or 1.0  # all values equal -> still return a non-degenerate range
    return [low - pad_fraction * span, high + pad_fraction * span]


# Interactive 3D scatter: one labeled marker per sentence at its PCA coordinates.
fig = go.Figure(data=[go.Scatter3d(
    x=reduced_embeddings[:, 0],
    y=reduced_embeddings[:, 1],
    z=reduced_embeddings[:, 2],
    mode='markers+text',
    text=text,
    marker=dict(
        size=5,
        color=np.linspace(0, 1, len(reduced_embeddings)),  # color encodes each sentence's position in `text`
        colorscale='Viridis',  # change this for different color schemes
        opacity=0.8
    ),
    textposition='top center'
)])

fig.update_layout(
    height=800,
    title_text='BERT Contextual Embeddings Visualized in 3D',
    scene=dict(
        xaxis_title='PCA Component 1',
        yaxis_title='PCA Component 2',
        zaxis_title='PCA Component 3',
        # Ranges are derived from the data rather than hard-coded limits that can hide points.
        xaxis=dict(range=_padded_range(reduced_embeddings[:, 0])),
        yaxis=dict(range=_padded_range(reduced_embeddings[:, 1])),
        zaxis=dict(range=_padded_range(reduced_embeddings[:, 2])),
    )
)

fig.show()
