In [None]:
# Install the notebook's dependencies from requirements.txt.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -r requirements.txt

In [67]:
import json
import numpy as np
import plotly.express as px
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
import chromadb
from pprint import pprint

In [81]:
# Load the sample data. `posts` is expected to map each topic name to a list of
# texts (inferred from how later cells use it — verify against the file).
# Previously used dataset: 'assets/posts.json'.
DATA_PATH = 'assets/gd.json'

# JSON text is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform's default locale encoding.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    posts = json.load(f)

# Load the pre-trained sentence-embedding model.
model = SentenceTransformer('all-MiniLM-L6-v2')

In [82]:
# Encode each topic's texts with the sentence-transformer model; the result maps
# topic name -> embedding matrix, preserving the topic order of `posts`.
j_embeddings = {
    topic_name: model.encode(texts)
    for topic_name, texts in posts.items()
}

In [83]:
# Combine the per-topic embedding matrices into a single array, stacked in
# `topics` order; the plotting cell relies on this row order to label points.
# Earlier dataset topics: "apple", "ai", "3dprinting", "bioinformatics", "beer".
topics = ["svchost.exe", "wscript.exe", "msiexec.exe", "tiworker.exe"]

# One vstack over the whole list: avoids re-copying the growing array on every
# iteration and works for any non-empty number of topics (not just >= 2).
embeddings = np.vstack([j_embeddings[topic] for topic in topics])

In [84]:
# Perform t-SNE to reduce the stacked embeddings to 3 components and show them
# as an interactive 3-D scatter plot, one colour per topic.
# random_state fixes t-SNE's stochastic initialisation so the layout is reproducible.
tsne_model = TSNE(n_components=3, random_state=42)
tsne_embeddings_values = tsne_model.fit_transform(embeddings)

# Label every row with its topic using each topic's actual row count, in the
# same order the rows were stacked, instead of assuming a fixed size per topic.
col_topics = [topic for topic in topics for _ in range(len(j_embeddings[topic]))]
assert len(col_topics) == tsne_embeddings_values.shape[0], "topic labels out of sync with embeddings"
num_topics = len(topics)

# Fixed per-topic size the dataset was originally assumed to have; no longer
# used for plotting, kept for any later cell that may still reference it.
num_elements_per_topic = 51

fig = px.scatter_3d(
    x=tsne_embeddings_values[:, 0],
    y=tsne_embeddings_values[:, 1],
    z=tsne_embeddings_values[:, 2],
    color=col_topics,
    labels={"color": "topic"},
    title="t-SNE projection of sentence embeddings by topic",
)

fig.update_traces(marker=dict(size=13))  # Increase the marker size uniformly

# t-SNE axes are not directly interpretable, so hide tick labels and axis titles.
fig.update_layout(
    scene=dict(
        xaxis=dict(showticklabels=False, title=''),
        yaxis=dict(showticklabels=False, title=''),
        zaxis=dict(showticklabels=False, title=''),
    ),
    autosize=False,  # add width=/height= here for a fixed-size figure
    margin=dict(l=50, r=50, b=50, t=50, pad=4),  # Margins
)
fig.show()

In [None]:
# Set up the vector database: a Chroma client with default settings and a
# fresh collection configured for cosine distance ("hnsw:space": "cosine").
client = chromadb.Client()

collection_name = "stackoverflow-dump"

# Drop any collection left over from an earlier run so the fill step starts
# from an empty collection. delete_collection raises when there is no such
# collection (the exception class differs between chromadb versions), so catch
# Exception rather than using a bare `except:`, which would also swallow
# KeyboardInterrupt/SystemExit; the error is printed instead of hidden.
try:
    client.delete_collection(name=collection_name)
    print(f"COLLECTION {collection_name} DELETED")
except Exception as err:
    print(f"COLLECTION {collection_name} DIDNT EXIST YET ({err})")

collection = client.create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"},
)

In [None]:
# Fill the vector database: one record per text, holding its precomputed
# embedding, the raw text, and its topic as metadata.
for topic_name, topic_embeddings in j_embeddings.items():
    print(f"Add stuff for topic {topic_name}")
    topic_docs = posts[topic_name]
    # Size metadata/ids from the actual number of texts rather than a fixed
    # per-topic constant, so topics of any size are stored correctly.
    n_docs = len(topic_docs)
    collection.add(
        # Plain Python lists are accepted by chromadb's add across versions
        # (some older releases reject numpy arrays).
        embeddings=topic_embeddings.tolist(),
        documents=topic_docs,
        # A fresh dict per record instead of `[dict] * n`, which repeats one shared object.
        metadatas=[{"topic": topic_name} for _ in range(n_docs)],
        # The topic suffix keeps ids unique across topics.
        ids=[f"{i:02}__{topic_name}" for i in range(n_docs)],
    )

In [None]:
# Sanity-check the fill: report how many records are stored and preview a few,
# instead of echoing every document and metadata entry into the notebook output.
print(f"collection holds {collection.count()} records")
collection.get(limit=5)

In [None]:
# Semantic search: embed the question with the same SentenceTransformer `model`
# that produced the stored embeddings, rather than relying on the collection's
# implicit default embedding function (none is set when it is created), and
# return the 2 nearest records.
# NOTE(review): this query text may predate the switch to assets/gd.json —
# confirm it is still meaningful for the current topics.
query_embeddings = model.encode(["What is in the sky?"]).tolist()
results = collection.query(
    query_embeddings=query_embeddings,
    n_results=2,
)

pprint(results)

In [None]:
# Semantic search: embed the question with the same SentenceTransformer `model`
# that produced the stored embeddings, rather than relying on the collection's
# implicit default embedding function (none is set when it is created), and
# return the 2 nearest records.
# NOTE(review): "brewing" appears to target the earlier "beer" topic; the
# current topics are process names, so this query may need updating.
query_embeddings = model.encode(["Are there any posts about brewing?"]).tolist()
results = collection.query(
    query_embeddings=query_embeddings,
    n_results=2,
)

pprint(results)