In [10]:
# !pip install umap-learn
# !pip install plotly

import torch
from transformers import BertTokenizer, BertModel

import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP

import matplotlib.pyplot as plt
import plotly.graph_objects as go

import gradio as gr

In [11]:
# Checkpoint identifier handed to both BertTokenizer.from_pretrained and
# BertModel.from_pretrained, so the tokenizer and the model always match.
MODEL = 'bert-base-uncased'

### Get BERT Model

In [12]:
# Load the pre-trained BERT tokenizer and model for the checkpoint named by MODEL.
# The "Some weights ... were not used" notice printed by this cell lists only
# `cls.*` parameters from the checkpoint that BertModel does not use.
tokenizer = BertTokenizer.from_pretrained(MODEL)
model = BertModel.from_pretrained(MODEL)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# Bare expression: the cell output shows the module tree of the loaded model
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

### Example with words

In [14]:
# Occupation nouns to slot into the gendered sentence templates
occupations = ["doctor", "nurse", "engineer", "teacher"]

# Build one "He ..." and one "She ..." sentence per occupation
male_texts = []
female_texts = []
for occupation in occupations:
    male_texts.append(f"He is a {occupation}.")
    female_texts.append(f"She is a {occupation}.")

# Encode every sentence to token IDs, special tokens included
male_input_ids = [tokenizer.encode(male_text, add_special_tokens=True) for male_text in male_texts]
female_input_ids = [tokenizer.encode(female_text, add_special_tokens=True) for female_text in female_texts]

# Add a leading batch axis so each tensor is a batch holding one sequence
male_input_tensors = [torch.tensor(ids).unsqueeze(0) for ids in male_input_ids]
female_input_tensors = [torch.tensor(ids).unsqueeze(0) for ids in female_input_ids]

In [15]:
# Run each single-sentence batch through BERT with gradient tracking disabled
male_outputs = []
female_outputs = []
with torch.no_grad():
    for male_batch in male_input_tensors:
        male_outputs.append(model(male_batch).last_hidden_state)
    for female_batch in female_input_tensors:
        female_outputs.append(model(female_batch).last_hidden_state)

# Mean-pool over dim=1 (the token axis) to get one vector per sentence
male_avg_embeddings = [hidden.mean(dim=1) for hidden in male_outputs]
female_avg_embeddings = [hidden.mean(dim=1) for hidden in female_outputs]

# Compare the paired "He"/"She" sentence vectors occupation by occupation
similarities = []
for he_vector, she_vector in zip(male_avg_embeddings, female_avg_embeddings):
    similarities.append(torch.cosine_similarity(he_vector, she_vector))

# Report one similarity score per occupation
for occupation, similarity in zip(occupations, similarities):
    print(f"Occupation: {occupation}, Cosine Similarity: {similarity.item()}")


Occupation: doctor, Cosine Similarity: 0.9412516951560974
Occupation: nurse, Cosine Similarity: 0.9453706741333008
Occupation: engineer, Cosine Similarity: 0.917951762676239
Occupation: teacher, Cosine Similarity: 0.9249525666236877


### Visualize word embeddings

In [16]:
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go

# Example sentences whose individual tokens will be embedded and plotted
sentences = [ "He plays the guitar very well.", "She is a doctor", "He is a doctor"]

# Tokenize each sentence (special tokens included) and wrap it as a batch of one
tokenized_sentences = [tokenizer.encode(sentence, add_special_tokens=True) for sentence in sentences]
input_tensors = [torch.tensor([input_ids]) for input_ids in tokenized_sentences]

# Last-hidden-state embeddings: one array per sentence, one row per token
with torch.no_grad():
    sentence_outputs = [model(input_tensor).last_hidden_state.squeeze(0).numpy() for input_tensor in input_tensors]

# CLS / SEP / PAD are bookkeeping tokens, not words, so they are not plotted
non_word_tokens = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

filtered_embeddings = []
filtered_sentences = []
sentence_ids = []

# Walk token IDs and their embedding rows together; sentence_id records which
# sentence each kept token came from and is later used as the marker colour.
for sentence_id, (token_vectors, input_ids) in enumerate(zip(sentence_outputs, tokenized_sentences)):
    for token_id, token_vector in zip(input_ids, token_vectors):
        if token_id in non_word_tokens:
            continue
        filtered_sentences.append(tokenizer.decode(token_id))
        filtered_embeddings.append(token_vector)
        sentence_ids.append(sentence_id)

# Stack the kept rows into a single (num_tokens, hidden_size) array
filtered_embeddings = np.array(filtered_embeddings)

# Apply t-SNE to reduce the embeddings to three dimensions
tsne_embeddings = TSNE(n_components=3, perplexity=5, random_state=42).fit_transform(filtered_embeddings)

# Create an interactive 3D scatter plot using Plotly; hovering shows the token
fig = go.Figure(data=go.Scatter3d(
    x=tsne_embeddings[:, 0],
    y=tsne_embeddings[:, 1],
    z=tsne_embeddings[:, 2],
    mode='markers',
    text=filtered_sentences,
    marker=dict(
        size=8,
        color=sentence_ids,
        colorscale='Viridis',
        opacity=0.8
    )
))

# Title the figure and label the three t-SNE axes so the plot stands alone
fig.update_layout(
    title="BERT Embeddings - t-SNE Visualization",
    scene=dict(
        xaxis_title="t-SNE 1",
        yaxis_title="t-SNE 2",
        zaxis_title="t-SNE 3",
    )
)

# Show the interactive plot
fig.show()

### UI Example:

In [17]:
def generate_embeddings(sentences):
    """Embed every word token of the given sentences with the loaded BERT model.

    Args:
        sentences: Either one string holding one sentence per line, or an
            iterable of sentence strings. Blank / whitespace-only entries
            are ignored.

    Returns:
        A tuple ``(filtered_sentences, filtered_embeddings, sentence_ids)``:
        the decoded text of each kept token, a NumPy array with one
        last-hidden-state row per kept token, and for each kept token the
        index of the sentence it came from. CLS, SEP and PAD tokens are
        excluded from all three.
    """
    # Accept both the multi-line textbox string and an already-split list
    if isinstance(sentences, str):
        sentences = sentences.split("\n")

    # A blank line would only produce special tokens: skip it so it neither
    # costs a forward pass nor leaves a gap in the sentence numbering.
    sentences = [sentence for sentence in sentences if sentence.strip()]

    # Tokenize the sentences and wrap each one as a batch of one sequence
    tokenized_sentences = [tokenizer.encode(sentence, add_special_tokens=True) for sentence in sentences]
    input_tensors = [torch.tensor([input_ids]) for input_ids in tokenized_sentences]

    # Get the BERT embeddings: one array per sentence, one row per token
    with torch.no_grad():
        sentence_outputs = [model(input_tensor).last_hidden_state.squeeze(0).numpy() for input_tensor in input_tensors]

    # Special tokens are not words and are filtered out below
    non_word_tokens = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

    filtered_sentences = []
    filtered_embeddings = []
    sentence_ids = []

    for sentence_id, (token_vectors, input_ids) in enumerate(zip(sentence_outputs, tokenized_sentences)):
        for token_id, token_vector in zip(input_ids, token_vectors):
            if token_id in non_word_tokens:
                continue
            filtered_sentences.append(tokenizer.decode(token_id))
            filtered_embeddings.append(token_vector)
            sentence_ids.append(sentence_id)

    # Convert filtered embeddings to NumPy array
    filtered_embeddings = np.array(filtered_embeddings)

    return filtered_sentences, filtered_embeddings, sentence_ids



# Plot builder used by the UI: embeds the sentences, reduces them, draws them
def _choose_perplexity(n_tokens):
    """Return the t-SNE perplexity tier for ``n_tokens`` points (always below ``n_tokens`` when ``n_tokens >= 2``)."""
    if n_tokens > 10:
        return 6
    if n_tokens > 5:
        return 5
    return 1


def _reduce_embeddings(embeddings, method, n_components):
    """Project ``embeddings`` down to ``n_components`` dimensions with the named method."""
    if method == 'PCA':
        reducer = PCA(n_components=n_components, random_state=42)
    elif method == 't-SNE':
        reducer = TSNE(n_components=n_components, perplexity=_choose_perplexity(len(embeddings)), random_state=42)
    elif method == 'UMAP':
        reducer = UMAP(n_components=n_components, random_state=42)
    else:
        raise ValueError(f"Unknown dimensionality reduction method: {method!r} (expected 'PCA', 't-SNE' or 'UMAP')")
    return reducer.fit_transform(embeddings)


def generate_tsne_plot(plot_type, sentences, method):
    """Build a scatter plot of BERT token embeddings after dimensionality reduction.

    Args:
        plot_type: ``'3D'`` for a 3D scatter; any other value gives a 2D scatter.
        sentences: Text passed on to ``generate_embeddings`` (one sentence per line).
        method: ``'PCA'``, ``'t-SNE'`` or ``'UMAP'``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker per token, coloured
        by the index of the sentence the token belongs to.

    Raises:
        ValueError: If fewer than two word tokens were produced, or if
            ``method`` is not one of the supported names.
    """
    filtered_sentences, filtered_embeddings, sentence_ids = generate_embeddings(sentences)

    # With fewer than two points no reducer can be fitted (and the t-SNE
    # perplexity would have to be <= 0), so fail early with a clear message.
    if len(filtered_sentences) < 2:
        raise ValueError("Need at least two word tokens to plot; enter a longer sentence or more sentences.")

    n_components = 3 if plot_type == '3D' else 2
    reduced_embeddings = _reduce_embeddings(filtered_embeddings, method, n_components)

    # Marker styling is shared by the 2D and 3D traces
    marker = dict(
        size=8,
        color=sentence_ids,
        colorscale='Viridis',
        opacity=0.8
    )

    if n_components == 3:
        # Interactive 3D scatter plot
        trace = go.Scatter3d(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            z=reduced_embeddings[:, 2],
            mode='markers',
            text=filtered_sentences,
            marker=marker
        )
    else:
        # Interactive 2D scatter plot
        trace = go.Scatter(
            x=reduced_embeddings[:, 0],
            y=reduced_embeddings[:, 1],
            mode='markers',
            text=filtered_sentences,
            marker=marker
        )

    fig = go.Figure(data=trace)

    # Set plot layout
    fig.update_layout(
        title=f"BERT Embeddings - {method} {plot_type} Visualization"
    )

    # Return the Plotly figure object
    return fig


# Build the Gradio interface: two option selectors, a text box, a Run button
# and the plot they all feed.
with gr.Blocks() as demo:
    # Choice between a 3D and a 2D scatter plot
    button = gr.Radio(label="Plot type",
                        choices=['3D', '2D'], value='3D')

    # Choice of dimensionality reduction algorithm
    method = gr.Radio(label="Dimensionality Reduction Method",
                        choices=['t-SNE', 'PCA', 'UMAP'], value='t-SNE')

    # Define the input component
    input_text = gr.Textbox(lines=3, label="Enter one sentence per line", value="He is a nurse \nShe is a nurse")

    # Define the button component
    btn = gr.Button(value="Run")

    # Define the output component
    plot = gr.Plot(label="Embeddings Plot")

    # Every trigger calls the same function with the same inputs and output
    plot_inputs = [button, input_text, method]
    plot_outputs = [plot]

    # Re-draw when the plot type changes
    button.change(generate_tsne_plot, inputs=plot_inputs, outputs=plot_outputs)

    # Re-draw when the reduction method changes
    method.change(generate_tsne_plot, inputs=plot_inputs, outputs=plot_outputs)

    # Re-draw when the Run button is clicked
    btn.click(generate_tsne_plot, inputs=plot_inputs, outputs=plot_outputs)

    # Draw once with the default values when the page loads
    demo.load(generate_tsne_plot, inputs=plot_inputs, outputs=plot_outputs)

# Launch only after the Blocks context has closed, i.e. once the layout and
# event wiring are complete.
demo.launch()


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
