## Visualizing the embeddings in 2D

We will use t-SNE to reduce the dimensionality of the embeddings from 1536 to 2. Once the embeddings are reduced to two dimensions, we can plot them in a 2D scatter plot. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).
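A CSV file cannot store a list directly. When `pd.read_csv` loads a file like this, the embeddings column usually comes back as the text of a Python list, so it has to be parsed before t-SNE can use it. The sketch below is a minimal illustration under that assumption. It uses a small synthetic DataFrame (the `article-a`/`article-b` rows are hypothetical) in place of `data/brightspot_articles_with_embeddings.csv`, round-trips 1536-dimensional embeddings through an in-memory CSV buffer, and parses them with `ast.literal_eval`, a safer alternative to `eval` for literal data.

```python
import ast
import io

import numpy as np
import pandas as pd

# Hypothetical stand-in for the real CSV: two articles with 1536-dimensional
# embeddings, written out the way pandas serializes a list column (as text).
rng = np.random.default_rng(0)
original = pd.DataFrame({
    "Label": ["article-a", "article-b"],
    "embedding": [rng.random(1536).tolist() for _ in range(2)],
})
buffer = io.StringIO()
original.to_csv(buffer, index=False)
buffer.seek(0)

# After reading back, each embedding is a string such as "[0.63, 0.27, ...]".
df = pd.read_csv(buffer)
assert isinstance(df.loc[0, "embedding"], str)

# ast.literal_eval parses the list literal without executing arbitrary code.
matrix = np.array(df["embedding"].apply(ast.literal_eval).to_list())
print(matrix.shape)  # (2, 1536)
```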

### 1. Reduce dimensionality

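We reduce the embeddings to two dimensions with t-SNE (t-distributed Stochastic Neighbor Embedding). t-SNE places the points so that articles that are close in the original 1536-dimensional space tend to stay close in the 2D layout.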

```python
import ast

import numpy as np
import pandas as pd
from sklearn.manifold import TSNE

# Load the articles and their embeddings
datafile_path = "data/brightspot_articles_with_embeddings.csv"
df = pd.read_csv(datafile_path)

# Each embedding is stored as the text of a list; parse it safely and
# stack the results into an (n_articles, 1536) NumPy array
matrix = np.array(df.embedding.apply(ast.literal_eval).to_list())

# Create a t-SNE model and project the embeddings to 2D
tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)
vis_dims = tsne.fit_transform(matrix)
vis_dims.shape
```

```
(41, 2)
```
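Perplexity loosely sets how many neighbours t-SNE considers for each point. Recent versions of scikit-learn reject a perplexity that is not strictly smaller than the number of samples, and with only 41 articles the default of 30 is already close to that limit. The sketch below tries a few candidate values on synthetic data of the same shape as the article matrix (random vectors, not the real embeddings). `valid_perplexities` is a hypothetical helper that filters out values scikit-learn would not accept.

```python
import numpy as np
from sklearn.manifold import TSNE


def valid_perplexities(n_samples, candidates):
    """Hypothetical helper: keep only perplexities below the sample count."""
    return [p for p in candidates if p < n_samples]


# Synthetic stand-in with the same shape as the article matrix (41 x 1536).
rng = np.random.default_rng(42)
matrix = rng.normal(size=(41, 1536))

# Fit one 2D layout per accepted perplexity; 50 is dropped because 50 >= 41.
layouts = {}
for perplexity in valid_perplexities(matrix.shape[0], [5, 15, 30, 50]):
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42,
                init="random", learning_rate=200)
    layouts[perplexity] = tsne.fit_transform(matrix)

print(sorted(layouts))  # [5, 15, 30]
```

Plotting the layouts side by side is a quick way to check whether the apparent structure survives a change of perplexity or is an artifact of one setting.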

### 2. Plotting the embeddings

Each point is one article. Hovering over a point shows the article's label in bold, with its authors below it.

Even after reduction to two dimensions, the layout shows a reasonable degree of separation between groups of articles.
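The separation above is judged by eye. One way to put a number on how faithful a 2D layout is, which the original notebook does not use, is scikit-learn's `trustworthiness` score. It checks whether each point's nearest neighbours in 2D were also near it in the original space. The sketch below runs it on synthetic data: 41 points in 1536 dimensions drawn from three clusters with `make_blobs`, standing in for the article embeddings. For the real data, `matrix` and `vis_dims` would take the place of `X` and `X_2d`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE, trustworthiness

# Synthetic stand-in: 41 points in 1536 dimensions from 3 well-separated clusters.
X, _ = make_blobs(n_samples=41, n_features=1536, centers=3, random_state=0)

X_2d = TSNE(n_components=2, perplexity=15, random_state=42,
            init="random", learning_rate=200).fit_transform(X)

# Score in [0, 1]; 1.0 means every point's 5 nearest neighbours in 2D were also
# among its 5 nearest neighbours in the original 1536-dimensional space.
score = trustworthiness(X, X_2d, n_neighbors=5)
print(f"trustworthiness: {score:.3f}")
```

A high score means that points that appear together in the plot were also neighbours in the embedding space. It does not measure how well separated the groups are.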

```python
import pandas as pd
import plotly.express as px

# `vis_dims` holds the 2D t-SNE coordinates; `df` has 'Label' and 'Authors' columns
plot_data = pd.DataFrame({
    'x': vis_dims[:, 0],
    'y': vis_dims[:, 1],
    'labels': df['Label'].tolist(),
    'authors': df['Authors'].tolist(),
})

# Scatter plot; custom_data makes customdata[0] the label and customdata[1] the authors
fig = px.scatter(plot_data, x='x', y='y', custom_data=['labels', 'authors'])

# Show the label in bold with the authors below it on hover, and enlarge the markers
fig.update_traces(
    hovertemplate='<b>%{customdata[0]}</b><br>%{customdata[1]}',
    marker=dict(size=8),
)

# Title and axis labels
fig.update_layout(
    title='2D Embeddings Visualization',
    xaxis_title='Dimension 1',
    yaxis_title='Dimension 2',
)

# Display the plot
fig.show()
```
