In [186]:
import sys
import os

# Put the parent directory ("../", resolved to an absolute path) on the import
# path; presumably this is where the project packages imported by later cells
# live — confirm.
sys.path.append(os.path.abspath("../"))

In [187]:
from agents.models import models

# List every key of the project's model registry, one per line.
print("Available models:")
for model_name in models.keys():
    print(model_name)

Available models:
azure-gpt-4o
azure-gpt-4o-mini


In [188]:
from database.connection import PGConnection

# Obtain a database connection from the project's PGConnection helper and
# create the cursor that the query cells below reuse.
# NOTE(review): nothing in the visible notebook closes the cursor or the
# connection — confirm cleanup happens elsewhere.
conn = PGConnection().get_conn()
cursor = conn.cursor()

In [189]:
import pandas as pd


def get_df(cursor, table_name):
    """Load every row of ``table_name`` into a pandas DataFrame.

    Args:
        cursor: Open DB-API style cursor exposing ``execute``, ``fetchall``
            and ``description``.
        table_name: Name of the table to read. It is interpolated directly
            into the SQL text, so only pass trusted, hard-coded names.

    Returns:
        DataFrame with one row per fetched record; column names are taken
        from the first element of each ``cursor.description`` entry.
    """
    query = f"SELECT * FROM {table_name};"
    cursor.execute(query)
    records = cursor.fetchall()
    column_names = [column[0] for column in cursor.description]
    return pd.DataFrame(records, columns=column_names)


# Settings for this run: source table, export path and number of clusters.
table = "goods"
FILENAME = "./sample_data/goods_sample.json"
CLUSTERS = 35

# Pull the whole table into memory once; later cells add columns to this frame.
df = get_df(cursor, table)

In [190]:
import ast


def convert_to_list(vector_str):
    """Parse a stored vector literal (e.g. ``"[0.1, 0.2]"``) into a Python value.

    Values that are already a list or tuple are passed through as a list.
    Without this, ``ast.literal_eval`` raises ``ValueError`` on a non-string
    input and the fallback below would replace a perfectly good vector with
    ``[]`` — which is exactly what happened when the cell applying this
    function to a DataFrame column was run a second time.

    Args:
        vector_str: Text holding a Python literal, or an already-parsed
            list/tuple.

    Returns:
        The parsed literal; a list copy of an already-parsed sequence; or
        ``[]`` when the input cannot be parsed (best-effort fallback).
    """
    # Idempotence guard: do not try to re-parse something already parsed.
    if isinstance(vector_str, (list, tuple)):
        return list(vector_str)
    try:
        return ast.literal_eval(vector_str)
    except (ValueError, SyntaxError):
        # Unparseable input (including None) degrades to an empty vector.
        return []


# Replace each stored embedding value with the result of convert_to_list.
# NOTE(review): the column is overwritten in place; if convert_to_list does not
# accept already-parsed lists, re-running this cell would turn every vector
# into [] — confirm before re-executing.
df["useembed_description"] = df["useembed_description"].apply(convert_to_list)

In [191]:
import numpy as np
from sklearn.cluster import KMeans

# Row count of the loaded table (not used by the visible cells; kept in case
# later cells read it).
num_rows = df.shape[0]

# Stack the per-row embedding vectors into a single array.
# Assumes every row holds a vector of the same length; rows whose parsing
# failed are [] and would make vstack raise — TODO confirm none exist.
X = np.vstack(df.useembed_description.to_numpy())

# Fixed seed so cluster ids are reproducible across Restart & Run All
# (same seed value the sampling step uses).
kmeans = KMeans(n_clusters=CLUSTERS, random_state=42)
c = kmeans.fit_predict(X)
df["cluster_kmeans"] = c

In [192]:
# Number of rows assigned to each k-means cluster, ordered by cluster id.
df.cluster_kmeans.value_counts().sort_index()

cluster_kmeans
0      45
1     299
2      36
3      71
4      45
5      79
6      30
7      81
8      31
9      62
10     69
11     65
12     19
13      6
14     14
15     64
16     36
17     27
18     52
19     64
20     16
21     70
22     20
23     58
24     29
25     36
26     47
27     62
28     16
29     65
30     45
31     63
32     50
33     41
34     40
Name: count, dtype: int64

## Visualization

In [193]:
import plotly.express as px
from plotly.express.colors import qualitative
from plotly.graph_objs import FigureWidget
from sklearn.manifold import TSNE

In [194]:
# Project the embeddings to 2-D for plotting. t-SNE is stochastic, so the seed
# is pinned to keep the layout (and everything clustered on it) reproducible.
X_tsne = TSNE(n_components=2, random_state=42).fit_transform(X)

In [195]:
# Cluster the 2-D t-SNE coordinates with a dedicated estimator.
# The fit must go through kmeans_tsne: calling fit_predict on `kmeans` would
# refit (and overwrite) the model trained on the full embeddings and leave
# kmeans_tsne unused. Seeded for reproducible labels.
kmeans_tsne = KMeans(n_clusters=CLUSTERS, random_state=42)
c_tsne = kmeans_tsne.fit_predict(X_tsne)

In [196]:
# Sanity check on the projection's dimensions (rows, components).
X_tsne.shape

(1853, 2)

In [197]:
# Keep the labels computed on the t-SNE projection alongside the rows; the
# JSON export step groups on this "cluster_tsne" column.
df["cluster_tsne"] = c_tsne

In [198]:
def print_samples(df, indices, n_samples=10):
    """Print the ``usevec_description`` text of selected rows.

    Args:
        df: DataFrame containing a ``usevec_description`` column.
        indices: Positional row indices (as accepted by ``df.iloc``).
        n_samples: Maximum number of rows to print, taken from the start of
            the selection.

    Each description is followed by a dashed separator line. Nothing is
    printed when the selection is empty.
    """
    chosen = df.iloc[indices].head(n_samples)
    for record in chosen.itertuples(index=False):
        print(record.usevec_description)
        print("---------------------------------------")


# Hover label for each point: the first 50 characters of the row's description.
annos = df.usevec_description.str.slice(0, 50)
# Scatter plot of the 2-D t-SNE coordinates, coloured by the t-SNE cluster label.
# NOTE(review): c_tsne is numeric and the rendered figure shows a 'coloraxis'
# on the marker (a continuous colour scale), so color_discrete_sequence
# presumably has no effect here — confirm whether discrete colours were intended.
fig = px.scatter(
    x=X_tsne[:, 0],
    y=X_tsne[:, 1],
    hover_name=annos,
    # color=c,  (alternative: colour by the labels from the embedding-space k-means)
    color=c_tsne,
    color_discrete_sequence=qualitative.Set1,
    labels={"color": "Cluster"},
    title="Chunk Embeddings",
    width=800,
    height=800,
)

# Set the marker size to 5 on traces drawn in "markers" mode (this only styles
# the points; it does not by itself enable selection).
fig.update_traces(marker=dict(size=5), selector=dict(mode="markers"))

# Wrap the figure in a FigureWidget so Python callbacks can be attached to it.
fig_widget = FigureWidget(fig)

# Global variable to store selected indices; starts empty and is overwritten
# by the selection callback.
selected_indices = []


def on_selection(trace, points, state):
    """Selection callback: remember which points are currently selected.

    Stores ``points.point_inds`` in the module-level ``selected_indices``.
    ``trace`` and ``state`` are part of the callback signature but unused.
    """
    global selected_indices
    picked = points.point_inds
    selected_indices = picked


# Attach the callback to the scatter trace (the first trace of the widget;
# presumably the only one — confirm if the colouring is ever changed).
scatter_trace = fig_widget.data[0]
scatter_trace.on_selection(on_selection)

# Display the interactive plot
fig_widget

FigureWidget({
    'data': [{'hovertemplate': ('<b>%{hovertext}</b><br><br>x=%' ... '%{marker.color}<extra></extra>'),
              'hovertext': array(['[Omitted]', 'All goods [other than fresh or chilled] pre-packag',
                                  'Niobium, tantalum, vanadium or zirconium ores and ', ...,
                                  'Yacht and other vessels for pleasure or sports',
                                  'All goods other than those mentioned at S. Nos. 1 ',
                                  'Note: *Notification No. 21/2018 -Central Tax (Rate'], dtype=object),
              'legendgroup': '',
              'marker': {'color': {'bdata': ('AAAAABMAAAAWAAAAEwAAAAAAAAAAAA' ... 'AACwAAAAsAAAAbAAAAFQAAABoAAAA='),
                                   'dtype': 'i4'},
                         'coloraxis': 'coloraxis',
                         'size': 5,
                         'symbol': 'circle'},
              'mode': 'markers',
              'name': '',
              'sho

In [203]:
# Print up to 10 descriptions for the points currently selected in the widget.
# NOTE(review): the output depends on a manual selection made in the plot (with
# no selection, selected_indices is [] and nothing is printed), and this cell's
# execution count is out of order relative to the cells below — it is not
# reproducible from a fresh Run All.
print_samples(df, selected_indices, n_samples=10)

Coffee roasted, whether or not decaffeinated coffee  husks and skins; coffee substitutes containing coffee in any proportion [other than coffee beans not roasted]
---------------------------------------
Soya beans, whether or not broken other than of seed quality.
---------------------------------------
Linseed, whether or not broken other than of seed quality.
---------------------------------------
Rape or colza seeds, whether or not broken other than of  seed quality.
---------------------------------------
Sunflower seeds, whether or not broken other than of seed  quality
---------------------------------------
Other oil seeds and oleaginous fruits (i.e. Palm nuts and kernels, cotton seeds, Castor oil seeds, Sesamum seeds,  Mustard seeds, Saffower (Carthamustinctorius) seeds, Melon seeds, Poppy seeds, Ajams, Mango kernel, Niger seed, Kokam) whether or not broken, other than of seed quality
---------------------------------------
Flour and meals of oil seeds or oleaginous fruits, ot

In [200]:
from decimal import Decimal


def decimal_serializer(obj):
    """Convert a ``Decimal`` to ``float``; reject every other type.

    Shaped like a ``default=`` hook for ``json.dump``/``json.dumps``.

    Raises:
        TypeError: If ``obj`` is not a ``Decimal``.
    """
    if not isinstance(obj, Decimal):
        raise TypeError(f"Type {type(obj)} not serializable")
    return float(obj)

In [201]:
import json


def smart_sample_to_json(
    df,
    samples_per_cluster,
    output_file,
    as_jsonl=False,
    cluster_col="cluster_tsne",
    random_state=42,
):
    """Write a per-cluster random sample of ``df`` to a JSON or JSONL file.

    Args:
        df: DataFrame with string column names, including ``cluster_col``.
        samples_per_cluster: Maximum rows to draw from each cluster; clusters
            smaller than this contribute all of their rows.
        output_file: Destination path. Missing parent directories are created.
        as_jsonl: When True, write one JSON object per line; otherwise write a
            single JSON array indented by 4 spaces.
        cluster_col: Column to group on (defaults to the original hard-coded
            ``"cluster_tsne"``).
        random_state: Seed passed to ``DataFrame.sample`` (defaults to the
            original hard-coded 42).

    Columns whose names start with ``"useembed_"`` are dropped from the
    output. Values the json module cannot encode natively are written via
    ``str()`` (``default=str``), so e.g. ``Decimal`` values become strings.
    """
    all_samples = []

    # Only the rows of each group are needed; the group key is ignored.
    for _, group in df.groupby(cluster_col):
        sample = group.sample(
            n=min(samples_per_cluster, len(group)), random_state=random_state
        )
        # Drop columns that start with 'useembed_'
        sample_filtered = sample.loc[:, ~sample.columns.str.startswith("useembed_")]
        all_samples.extend(sample_filtered.to_dict(orient="records"))

    # Make sure the target directory exists so open() does not fail with
    # FileNotFoundError on a fresh checkout.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_file, "w") as f:
        if as_jsonl:
            for item in all_samples:
                f.write(json.dumps(item, default=str) + "\n")
        else:
            json.dump(all_samples, f, indent=4, default=str)

In [202]:
# Export up to 6 rows per "cluster_tsne" group to FILENAME as a single
# indented JSON array (as_jsonl is left at its default of False).
smart_sample_to_json(df, 6, output_file=FILENAME)