In [None]:
import json
import numpy as np
import plotly.express as px
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

In [None]:
# Load the scraped posts. Later cells iterate `posts.items()` and pass
# `posts[topic]` as a list of documents, so this is presumably a dict of
# topic name -> list of post strings (TODO confirm against posts.json).
# Explicit UTF-8 avoids depending on the platform's default locale encoding
# when the post text contains non-ASCII characters.
with open('posts.json', 'r', encoding='utf-8') as f:
    posts = json.load(f)

# Load the pre-trained sentence-embedding model; it encodes the posts below
# and is reused to embed queries against the vector store.
model = SentenceTransformer('all-MiniLM-L6-v2')

In [None]:
# Encode each topic's posts into sentence embeddings, keyed by topic name.
j_embeddings = {topic: model.encode(texts) for topic, texts in posts.items()}

In [None]:
# Combine the per-topic embedding arrays into one array. Rows are grouped by
# topic in the fixed `topics` order, which the plot labels below rely on.
# A single vstack avoids re-copying a growing array on every loop iteration.
topics = ["apple", "ai", "3dprinting", "bioinformatics", "beer"]
embeddings = np.vstack([j_embeddings[topic] for topic in topics])

In [None]:
# Reduce the embeddings to 2 components with t-SNE for a visual check of how
# well topics separate. random_state pins the stochastic layout so re-runs match.
tsne_model = TSNE(n_components=2, random_state=42)
tsne_embeddings_values = tsne_model.fit_transform(embeddings)

# Expected number of posts per topic; a later cell also reads this constant.
num_elements_per_topic = 51
num_topics = len(topics)

# One colour label per row of `embeddings`, derived from each topic's actual
# embedding count (same `topics` order used to stack them) so labels stay
# aligned even if a topic has more or fewer than `num_elements_per_topic` posts.
col_topics = [topic for topic in topics for _ in range(len(j_embeddings[topic]))]
assert len(col_topics) == tsne_embeddings_values.shape[0], "labels misaligned with points"

fig = px.scatter(
    x=tsne_embeddings_values[:, 0],
    y=tsne_embeddings_values[:, 1],
    color=col_topics,
)

fig.update_traces(marker=dict(size=13))  # Increase the marker size uniformly

# Clean, square figure: hide axis ticks/titles (t-SNE coordinates carry no
# meaning) and the legend.
fig.update_layout(
    xaxis=dict(showticklabels=False, title=''),
    yaxis=dict(showticklabels=False, title=''),
    showlegend=False,
    autosize=False,
    width=600,  # Width of the plot
    height=600,  # Height of the plot
    margin=dict(l=50, r=50, b=50, t=50, pad=4),  # Margins
)
fig.show()

In [None]:
import chromadb
from pprint import pprint

# Chroma client with default settings (per the Chroma docs this is the
# in-memory client, so on a fresh kernel the collection does not exist yet).
client = chromadb.Client()

collection_name = "stackoverflow-dump"

# Drop any previous copy so re-running this cell starts from an empty
# collection. Only delete if it exists: delete_collection raises for a
# missing name, which would break Restart & Run All. Depending on the chromadb
# version, list_collections() yields Collection objects or plain names, so
# normalise both to names.
existing_names = {getattr(c, "name", c) for c in client.list_collections()}
if collection_name in existing_names:
    client.delete_collection(name=collection_name)

collection = client.create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"},  # use cosine distance for the HNSW index
)

In [None]:
# Sanity check: the freshly (re)created collection should hold 0 records.
collection.count()

In [None]:
# Store every topic's posts with their precomputed embeddings, tagging each
# record with its topic. Lengths come from the actual post list rather than
# the fixed `num_elements_per_topic`, so a topic with a different post count
# does not cause a length mismatch between ids/metadatas and documents.
for topic, topic_embeddings in j_embeddings.items():
    print(f"Add stuff for topic {topic}")
    n_posts = len(posts[topic])
    collection.add(
        # Convert the numpy array to nested Python lists, the plain type
        # Chroma's API documents for embeddings.
        embeddings=topic_embeddings.tolist(),
        documents=posts[topic],
        # One independent dict per record (no aliasing of a single dict).
        metadatas=[{"topic": topic} for _ in range(n_posts)],
        ids=[f"{i:02}__{topic}" for i in range(n_posts)],
    )

In [None]:
# Preview a handful of stored records instead of dumping every document.
collection.get(limit=5)

In [None]:
# Embed the query with the same SentenceTransformer that produced the stored
# document embeddings, instead of relying on Chroma's default embedding
# function, so query and documents are guaranteed to share one embedding space.
query_text = "What is in the sky?"
results = collection.query(
    query_embeddings=model.encode([query_text]).tolist(),
    n_results=2,
)

pprint(results)