In [None]:
import numpy as np
import pickle
import umap
import matplotlib.pyplot as plt
import pandas as pd

from matplotlib.colors import to_hex
from pyprojroot import here
from sklearn.preprocessing import StandardScaler

from mpl_lego.labels import bold_text
from mpl_lego.style import use_latex_style

from normative_evaluation_llms_everyday_dilemmas import keys

In [None]:
# Apply mpl_lego's LaTeX plotting style before any figure is created below.
# NOTE(review): presumably switches matplotlib to LaTeX text rendering, which the
# bold_text(...) axis/legend labels rely on — confirm in mpl_lego.style.
use_latex_style()

In [None]:
# Load the everyday-dilemmas dataset; here() resolves the path from the project
# root, so this cell works regardless of the notebook's working directory.
df = pd.read_csv(
    here('data/normative_evaluation_everyday_dilemmas_dataset.csv')
)

In [None]:
# Load the pickled embeddings. The next cell calls .values() on the result, so it
# is a mapping — presumably keyed by model name (TODO confirm), with one array of
# row embeddings per model.
# Resolve the path with here(), like the dataset CSV, instead of '../data/...', so
# the notebook does not depend on being run from a particular subdirectory.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load this file from
# a trusted source.
with open(here('data/embedding.pkl'), 'rb') as embedding_file:
    embeddings = pickle.load(embedding_file)

In [None]:
# Stack the per-key embedding arrays row-wise (in dict iteration order) into a
# single matrix. NOTE: this rebinds `embeddings` from a dict to an ndarray, so
# re-run the loading cell before re-running this one.
embeddings = np.vstack(tuple(embeddings.values()))

In [None]:
# Collect one list of values per column named in keys.LABEL_COLS, preserving the
# order of LABEL_COLS.
labels = [list(df[label_col].values) for label_col in keys.LABEL_COLS]

In [None]:
# Flatten the per-column lists into one 1-D array, column blocks laid end to end
# in LABEL_COLS order. NOTE: this rebinds `labels` from a list to an ndarray —
# re-run the collection cell before re-running this one.
labels = np.concatenate(labels, axis=0)

In [None]:
# Per-point colors: matplotlib cycle color 'C{idx}' for the idx-th model, repeated
# for every row in that model's block of `embeddings`. The legend in the plot
# cell uses the same 'C{idx}' per entry of keys.MODEL_LABELS_PLOT, so the color
# count is derived from that list rather than hard-coded, and the per-model row
# count is derived from the stacked matrix rather than hard-coded.
# NOTE(review): assumes the pickle's blocks are equally sized and stacked in the
# same order as keys.MODEL_LABELS_PLOT — confirm.
n_models = len(keys.MODEL_LABELS_PLOT)
n_rows_per_model, leftover_rows = divmod(embeddings.shape[0], n_models)
assert leftover_rows == 0, (
    f'{embeddings.shape[0]} embedding rows cannot be split evenly across '
    f'{n_models} models')
colors = np.repeat([f'C{idx}' for idx in range(n_models)], n_rows_per_model)
# NOTE(review): not used by the plot below, which passes `colors` directly.
colors_hex = [to_hex(color) for color in colors]

In [None]:
# Z-score each embedding dimension (zero mean, unit variance over all rows)
# before the UMAP projection; the fitted scaler is kept for later reuse.
scaler = StandardScaler().fit(embeddings)
standardized_embeddings = scaler.transform(embeddings)

In [None]:
# UMAP hyperparameters for the 2-D projection of the standardized embeddings.
n_neighbors = 30
min_dist = 0.3
n_components = 2
metric = 'cosine'
random_state = 2332  # fixed seed so the projection (and the figure) is reproducible

# umap-learn runs single-threaded whenever random_state is set (recent releases
# warn that any n_jobs != 1 is overridden), so request one job explicitly instead
# of asking for 8 workers that are never used.
reducer = umap.UMAP(n_neighbors=n_neighbors,
                    min_dist=min_dist,
                    n_components=n_components,
                    metric=metric,
                    n_jobs=1,
                    random_state=random_state)
reduced_embeddings = reducer.fit_transform(standardized_embeddings)

In [None]:
# Visualization: scatter of the 2-D UMAP projection, one point per embedding row,
# colored by `colors` ('C{idx}' per model block), plus a legend mapping each
# 'C{idx}' to its label in keys.MODEL_LABELS_PLOT. Saved to fig4_umap.png.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

ax.scatter(
    reduced_embeddings[:, 0],
    reduced_embeddings[:, 1],
    s=0.5,
    alpha=0.2,
    c=colors)
ax.set_xlabel(bold_text("UMAP Dimension 1"), fontsize=12)
ax.set_ylabel(bold_text("UMAP Dimension 2"), fontsize=12)
# Fixed view window. NOTE(review): any points outside [-5, 20] are not shown.
ax.set_ylim([-5, 20])
ax.set_xlim([-5, 20])
# Legend proxies: empty scatters carry a label, color and a larger marker size
# without drawing anything. Previously a real dummy point at (-10, -10) was used;
# that only stayed hidden because the axis limits above exclude it.
for idx, label in enumerate(keys.MODEL_LABELS_PLOT):
    ax.scatter([],
               [],
               s=20,
               alpha=1,
               c=f'C{idx}',
               label=bold_text(label))
ax.legend(loc='center left', bbox_to_anchor=(1.03, 0.5), prop={'size': 12})
fig.savefig('fig4_umap.png', bbox_inches='tight')
plt.show()