In [None]:
import chromadb
from chromadb.utils import embedding_functions
import embedClustering
import numpy as np
import pandas as pd
import tensorflow as tf
import plotly.express as px


In [None]:
# List the GPUs TensorFlow can see; the list is displayed as the cell output.
# NOTE(review): the embedding function is provided by sentence-transformers
# (presumably PyTorch-backed), not TensorFlow, so this check does not tell you
# whether the embedding device is usable — confirm this cell is still needed.
tf_gpus = tf.config.list_physical_devices(device_type='GPU')
tf_gpus


In [None]:
# Relative path of the on-disk (persistent) Chroma store.
CHROMA_DATA_PATH = "chroma_data/"
# sentence-transformers model name handed to the embedding function.
EMBED_MODEL = "distiluse-base-multilingual-cased-v1"
# Collection analysed throughout this notebook.
COLLECTION_NAME = "WW2-Languages-Wiki-Limited"

# Open (or create) the persistent client rooted at CHROMA_DATA_PATH.
client = chromadb.PersistentClient(
    path=CHROMA_DATA_PATH,
)

In [None]:
# Embedding function used for both stored documents and query texts.
# The device used to be hard-coded to "mps" (Apple Metal), which fails on
# machines without it; prefer MPS, then CUDA, then fall back to CPU.
import torch  # local import: only needed to choose the device

if torch.backends.mps.is_available():
    EMBED_DEVICE = "mps"
elif torch.cuda.is_available():
    EMBED_DEVICE = "cuda"
else:
    EMBED_DEVICE = "cpu"

embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=EMBED_MODEL,
    device=EMBED_DEVICE,
)

In [None]:
# Attach to the existing collection, embedding queries with the same model as the stored documents.
collection = client.get_collection(
    name=COLLECTION_NAME,
    embedding_function=embedding_func,
)

In [None]:
# Show a single stored record as a quick sanity check of the collection contents.
collection.peek(limit=1)

In [None]:
# Smoke-test semantic search: the 5 nearest stored documents to a short probe phrase.
query_results = collection.query(
    n_results=5,
    query_texts=["The End"],
)
query_results

In [None]:
# Load every stored embedding into an array and display (row count, embedding length).
embeddings_response = collection.get(include=['embeddings'])
embeddings = np.array(embeddings_response['embeddings'])
(len(embeddings), len(embeddings[0]))

In [None]:
# Fetch documents and their metadata in ONE get() call so both lists come from the
# same response and share row order (ordering across separate get() calls is not
# documented), and so the collection is only read once here.
# NOTE(review): alignment with `embeddings` still relies on a separate get() call
# returning rows in the same order — confirm, or fetch all three together.
docs_and_metadata = collection.get(include=['documents', 'metadatas'])
documents = np.array(docs_and_metadata['documents'])
# One language value per document, read from each record's "language" metadata key.
languages = np.array([meta["language"] for meta in docs_and_metadata['metadatas']])

In [None]:
# Reduce the embeddings with the project's t-SNE helper for clustering and plotting
# (presumably to 2-D: the plotting cell names exactly two components — confirm).
# NOTE(review): t-SNE is stochastic; unless tsneReduceEMB fixes a random_state,
# each re-run yields different coordinates and therefore different clusters.
reduced_embeddings = embedClustering.tsneReduceEMB(embeddings)

In [None]:
# Neighbourhood size shared by the eps estimate (k=) and both clustering calls (min_samples=).
k_nearest_neighbors: int = 3

In [None]:
# Estimate an eps for density clustering with the project helper, using the
# shared neighbourhood size (presumably a k-distance heuristic — confirm), and display it.
optimal_eps = embedClustering.findEPS(
    reduced_embeddings,
    k=k_nearest_neighbors,
)
optimal_eps


In [None]:
# HDBSCAN cluster labels for the reduced embeddings.
# Stored under its own name: writing it to `labels` meant the DBSCAN cell that
# follows silently overwrote it, so this result was discarded on Run All.
# To plot HDBSCAN clusters instead of DBSCAN, use `hdbscan_labels` downstream.
hdbscan_labels = embedClustering.hdbscanEMB(reduced_embeddings, min_samples=k_nearest_neighbors)

In [None]:
# DBSCAN cluster labels; these are the `labels` used by the plotting cells.
# NOTE(review): eps is fixed at 2, so the `optimal_eps` value computed by
# embedClustering.findEPS is not used — confirm this override is intentional.
DBSCAN_EPS = 2
labels = embedClustering.dbscanEMB(
    reduced_embeddings,
    min_samples=k_nearest_neighbors,
    eps=DBSCAN_EPS,
)

In [None]:
# Cluster plot built by the project helper from the reduced coordinates, document
# texts and cluster labels.
# Uses its own name: the language scatter cell later reassigns `fig`, which shadowed
# this figure and made the HTML export's target depend on execution order.
# NOTE(review): nothing here displays the returned object — confirm plot_clusters
# shows the figure itself, otherwise add `cluster_fig.show()`.
cluster_fig = embedClustering.plot_clusters(reduced_embeddings, documents, labels)

In [None]:

# Assemble one row per document: t-SNE coordinates plus hover/colour metadata.
# NOTE(review): assumes `reduced_embeddings` has exactly two columns — confirm with tsneReduceEMB.
tsne_all_df = pd.DataFrame(reduced_embeddings, columns=['Component 1', 'Component 2'])
tsne_all_df['text'] = [doc[:40] for doc in documents]  # 40-char preview for hover tooltips
tsne_all_df['cluster'] = labels
tsne_all_df['language'] = languages

# Language codes to plot: Arabic, Chinese, Dutch, English, French, German, Italian,
# Korean, Polish, Portuguese, Russian, Spanish, Turkish.
selected_languages = ["ar", "zh", "nl", "en", "fr", "de", "it", "ko", "pl", "pt", "ru", "es", "tr"]
# Keep the unfiltered frame intact; `tsne_df` is the filtered view that gets plotted.
tsne_df = tsne_all_df[tsne_all_df["language"].isin(selected_languages)]

# `language` holds string codes, so plotly uses a discrete palette (a continuous
# colour scale would have no effect). The `cluster` column is now shown on hover
# instead of being built and never used.
fig = px.scatter(
    tsne_df,
    x='Component 1',
    y='Component 2',
    color='language',
    hover_data=['text', 'language', 'cluster'],
    title="t-SNE projection of documents, coloured by language",
)

fig.show()

In [None]:
# Export the most recent `fig` as a standalone interactive HTML file.
from pathlib import Path

EXPORT_PATH = Path("./exportClusters/interactive_plot2.html")
# write_html does not create missing parent directories; without this, a fresh
# checkout fails with FileNotFoundError.
EXPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.write_html(str(EXPORT_PATH))
