In [1]:
# !pip install pyarrow umap-learn seaborn altair
import os
import pandas as pd
import umap.umap_ as umap
import matplotlib.pyplot as plt
import seaborn as sns
import altair as alt

In [2]:
# Input locations used by this notebook:
#   data_dir  -> directory that get_embeddings() reads the *.parquet embedding tables from
#   data_repo -> data repository root; its "indexed" sub-folders list already-indexed charts
data_dir, data_repo = "..", "../../../data"

In [3]:
# Chart-class prefixes, one per line. Blank lines are skipped: an empty prefix
# would match every file name via str.startswith("").
with open("prefixes.txt", "r", encoding="utf-8") as f:
    prefixes = [line.strip() for line in f if line.strip()]

# Mapping chart class -> coarser group. Reading with "class" as the index directly
# (instead of a later in-place set_index) keeps this cell idempotent on re-run.
df_groups = pd.read_csv("class_group.csv", index_col="class")
class_to_group = df_groups["group"].to_dict()

# Last expression, so the preview is actually displayed.
df_groups.head()

In [4]:
# One UMAP reducer shared by the projection cells; get_embeddings() calls
# fit_transform on it, so it is refitted for every embedding table.
reducer = umap.UMAP(
    # output space
    n_components=2,
    metric='euclidean',
    # neighbourhood size and how tightly points may pack in the layout
    n_neighbors=120,
    min_dist=5.5,
    spread=6.5,
    # optimisation schedule
    init='spectral',
    n_epochs=200,
    learning_rate=1.0,
    # fixed seed so re-runs produce the same layout
    random_state=29,
)

In [5]:
def get_embeddings(parquet_file, reducer, class_prefixes=None, base_dir=None):
    """Project one embedding table with UMAP and label every row with its chart class.

    Parameters
    ----------
    parquet_file : str
        File name of the embeddings table, resolved relative to ``base_dir``.
        The table must have a ``file_name`` column; every other column is used
        as an embedding feature.
    reducer : object
        UMAP-like estimator exposing ``fit_transform``. It is refitted on this
        table, so any previously fitted state is overwritten.
    class_prefixes : list of str, optional
        Candidate class prefixes. Defaults to the module-level ``prefixes``.
    base_dir : str, optional
        Directory holding ``parquet_file``. Defaults to the module-level ``data_dir``.

    Returns
    -------
    pandas.DataFrame
        Columns ``UMAP_1`` .. ``UMAP_k`` (one per reducer output dimension),
        then ``file_name`` (final extension stripped) and ``class``. ``class``
        is the longest matching prefix, or ``None`` if no prefix matches. Rows
        keep the input order and get a fresh 0..n-1 index.
    """
    if class_prefixes is None:
        class_prefixes = prefixes
    if base_dir is None:
        base_dir = data_dir

    # Longest prefixes first, so e.g. "bar_stacked" wins over "bar" regardless of
    # file order. Empty prefixes would match everything, so they are dropped.
    ordered_prefixes = sorted((p for p in class_prefixes if p), key=len, reverse=True)

    df = pd.read_parquet(os.path.join(base_dir, parquet_file))
    features = df.drop(columns='file_name')
    projected = reducer.fit_transform(features)

    df_umap = pd.DataFrame(projected)
    df_umap.columns = [f'UMAP_{i + 1}' for i in range(df_umap.shape[1])]

    # Strip only the final extension (consistent with the ".png" stripping used for
    # the indexed lists). .to_numpy() assigns positionally: aligning on the parquet's
    # own index would yield NaN if that index is not 0..n-1.
    df_umap['file_name'] = df['file_name'].str.rsplit('.', n=1).str[0].to_numpy()

    def _match_class(name):
        # First hit in longest-first order == longest matching prefix.
        return next((p for p in ordered_prefixes if name.startswith(p)), None)

    df_umap['class'] = df_umap['file_name'].apply(_match_class)
    return df_umap

In [6]:
 # df_umap.to_csv("umap.csv")

In [7]:
# df_umap_image = get_embeddings("image_embeddings.parquet", reducer)
# df_umap_image.to_csv("umap_image.csv")

In [8]:
# df_umap_text = get_embeddings("text_embeddings.parquet", reducer)
# df_umap_text.to_csv("umap_text.csv")

In [9]:
# df_umap_text = get_embeddings("text_0_2_4_embeddings.parquet", reducer)
# df_umap_text.to_csv("umap_text_0_2_4.csv")

In [10]:
# df_umap_text = get_embeddings("text_0_2_4_llm_fs_single_embeddings.parquet", reducer)
# df_umap_text.to_csv("umap_text_0_2_4_llm_fs_single.csv")

In [11]:
# df_umap_spec_freq = get_embeddings("spec_frequency.parquet", reducer)
# df_umap_spec_freq.to_csv("umap_spec_freq.csv")

In [12]:
# df_umap_spec_oh = get_embeddings("spec_onehot.parquet", reducer)
# df_umap_spec_oh.to_csv("umap_spec_oh.csv")

In [13]:
# Load the cached UMAP projections from CSV instead of refitting UMAP.
# NOTE(review): no visible cell writes "umap_text_0_2_2.csv" (the text cells write
# umap_text.csv / umap_text_0_2_4.csv) — confirm this is the intended file.
(
    df_umap_image,
    df_umap_text,
    df_umap_text_llm,
    df_umap_spec_freq,
    df_umap_spec_oh,
) = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        "umap_image.csv",
        "umap_text_0_2_2.csv",
        "umap_text_0_2_4_llm_fs_single.csv",
        "umap_spec_freq.csv",
        "umap_spec_oh.csv",
    )
)

In [14]:
# Add each row's coarse group, looked up from its class (NaN when the class
# has no entry in class_to_group). Frames are updated in place.
for _frame in (df_umap_image, df_umap_text, df_umap_text_llm, df_umap_spec_freq, df_umap_spec_oh):
    _frame['group'] = _frame['class'].map(class_to_group)
del _frame

In [15]:
# File stems (".png" stripped) of charts already indexed as single or multiple charts.
index_single = [f.split(".png")[0] for f in os.listdir(os.path.join(data_repo, "indexed", "single_chart"))]
index_multiple = [f.split(".png")[0] for f in os.listdir(os.path.join(data_repo, "indexed", "multiple_chart", "imgs"))]

# Set lookups make each membership test O(1) instead of a scan over the whole list.
_single_set = set(index_single)
_multiple_set = set(index_multiple)


def _indexed_flag(name):
    """Return 1 if `name` is a single-chart index entry, 2 if multiple-chart, else 0.

    Single takes precedence when a name appears in both lists.
    """
    if name in _single_set:
        return 1
    if name in _multiple_set:
        return 2
    return 0


# Same encoding for every projection: 0 = not indexed, 1 = single, 2 = multiple.
for _frame in (df_umap_image, df_umap_text, df_umap_text_llm, df_umap_spec_freq, df_umap_spec_oh):
    _frame['indexed'] = _frame['file_name'].apply(_indexed_flag)
del _frame

In [16]:
# Image-embedding UMAP projection, one colour per chart class.
g1 = (
    alt.Chart(df_umap_image)
    .mark_point(filled=True)
    .encode(
        x=alt.X('UMAP_1'),
        y=alt.Y('UMAP_2'),
        color=alt.Color('class'),
        tooltip=[alt.Tooltip('class')],
    )
    .properties(width=800, height=800)
    .interactive()
)

g1

In [17]:
# Same image-embedding projection, coloured by coarse group instead of class.
g2 = (
    alt.Chart(df_umap_image)
    .mark_point(filled=True)
    .encode(
        x=alt.X('UMAP_1'),
        y=alt.Y('UMAP_2'),
        color=alt.Color('group'),
        tooltip=[alt.Tooltip('group')],
    )
    .properties(width=800, height=800)
    .interactive()
)

g2

In [18]:
domain = [0, 1, 2]
range_ = ['lightgray', 'seagreen', 'blue']
range2_ = ['circle', 'square', 'cross']

alt.Chart(df_umap_image).mark_point(filled=True).encode(
    x='UMAP_1',
    y='UMAP_2',
    color=alt.Color('indexed:N').scale(domain=domain, range=range_),
    shape=alt.Shape('indexed:N').scale(domain=domain, range=range2_),
    tooltip=['class']
).properties(
    width=800,
    height=800
).interactive()