## Purpose

This notebook plots 2-D embeddings as an interactive scatter plot, color-coded by annotation class.

Input(s):
1. Path to the 2-D (dimensionality-reduced) embedding data, with columns [identification_number, x, y]
2. Path to the corresponding annotation data (e.g., if your embeddings are for 2 seconds of audio, the path to the 2-second audio annotations)

Output(s):
1. Scatter plot graph color-coded to the classes listed in the annotation file
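
As a rough illustration of the two inputs, the sketch below builds tiny stand-in tables. The column names `vggish_point` and `label` are the ones the merge and plotting cells in this notebook read; the rows and the label values ("bird", "rain") are made up for the example.

```python
import pandas as pd

# Hypothetical embedding rows: one ID plus its 2-D coordinates per point.
example_embedding_df = pd.DataFrame({
    "identification_number": [0, 1, 2],
    "x": [0.12, -1.40, 2.05],
    "y": [1.30, 0.25, -0.75],
})

# Hypothetical annotation rows: `vggish_point` holds the matching ID.
# Point 1 has no annotation, so it would be drawn as an unlabeled point.
example_annotation_df = pd.DataFrame({
    "vggish_point": [0, 2],
    "label": ["bird", "rain"],
})

print(example_embedding_df)
print(example_annotation_df)
```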
    
We will begin by importing the required Python libraries and setting the paths to the embedding and annotation CSVs.

In [None]:
#Python Libraries
import pandas as pd
import plotly.graph_objs as go
from ipywidgets import VBox
import random

#User-entered data files
embedding_csv_path = 'INSERT ABSOLUTE PATH TO EMBEDDING CSV HERE'
annotation_csv_path = 'INSERT ABSOLUTE PATH TO ANNOTATION CSV HERE'

Now we will read in the 2-D embedding and annotation datasets.

In [None]:
embedding_df = pd.read_csv(embedding_csv_path)
annotation_df = pd.read_csv(annotation_csv_path)
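
Before going further, it may help to confirm that both files contain the columns the rest of the notebook relies on. The helper below is a minimal sketch; `missing_columns` is a hypothetical name, not part of pandas, and `demo_df` is a stand-in for a loaded file.

```python
import pandas as pd

def missing_columns(df, required):
    """Return the required column names that are absent from df, in order."""
    return [col for col in required if col not in df.columns]

# Stand-in frame that lacks the `y` column.
demo_df = pd.DataFrame({"identification_number": [0], "x": [0.1]})
print(missing_columns(demo_df, ["identification_number", "x", "y"]))  # → ['y']
```

On the real data this could be called as `missing_columns(embedding_df, ['identification_number', 'x', 'y'])` and `missing_columns(annotation_df, ['vggish_point', 'label'])`; an empty list means nothing is missing.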

Next we merge the two dataframes on their shared ID and drop the redundant `vggish_point` column, which duplicates `identification_number` after the merge.

In [None]:
# Merge the two dataframes based on the vggish_point and identification_number column
merged = pd.merge(embedding_df, annotation_df, how='left', left_on='identification_number', right_on='vggish_point')
merged_df = merged.drop(columns=['vggish_point'])
merged_df.head()
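
Because this is a left merge, embedding points without an annotation are kept and simply get a missing label. One way to see how many points matched is the `indicator=True` option of `pd.merge`, which adds a `_merge` column. The sketch below uses small made-up frames (`emb`, `ann`) rather than the real data.

```python
import pandas as pd

# Made-up stand-ins for the embedding and annotation tables.
emb = pd.DataFrame({
    "identification_number": [0, 1, 2],
    "x": [0.1, 0.2, 0.3],
    "y": [1.0, 2.0, 3.0],
})
ann = pd.DataFrame({"vggish_point": [0, 2], "label": ["bird", "rain"]})

# indicator=True records where each row came from:
# 'both' = annotated point, 'left_only' = embedding with no annotation.
checked = pd.merge(
    emb, ann, how="left",
    left_on="identification_number", right_on="vggish_point",
    indicator=True,
)
match_counts = checked["_merge"].value_counts()
print(match_counts)
```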

More data points take more time to plot. Here we count how many points have no label; these are still plotted, as small grey markers. Feel free to limit the data as desired.

In [None]:
merged_df['label'].isna().sum()
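
If the plot is too slow, one option is to draw a random subset of the rows. The function below is a sketch under that assumption; `limit_points` and its parameters are hypothetical names. It resets the index so that row positions and index labels stay aligned after sampling.

```python
import pandas as pd

def limit_points(df, max_points, seed=0):
    """Return df unchanged if small enough, else a random sample of max_points rows."""
    if len(df) <= max_points:
        return df
    # A fixed seed makes the same subset come back on every run.
    return df.sample(n=max_points, random_state=seed).reset_index(drop=True)

# Stand-in data with 10 points, limited to 4.
demo = pd.DataFrame({"x": range(10), "y": range(10)})
limited = limit_points(demo, max_points=4)
print(len(limited))  # → 4
```

Applied to this notebook, that would look like `merged_df = limit_points(merged_df, max_points=50000)`, where the cap of 50000 is only a placeholder.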

Finally, we set colors for each of the classes and generate the final scatterplot.

In [None]:
# Function to pick a random selection of bright colors
def generate_bright_colors(num_colors):
    # Define a list of distinct bright colors (no repeated entries)
    bright_colors = ["#FF5733", "#FFBD33", "#33FF57", "#336DFF", "#FF33D5", "#FF3333", "#33FFF6", "#B233FF", "#FFBE33", "#FF33A7", "#B3FF33", "#5733FF", "#33FFD4", "#B23333", "#33B3FF", "#FF8A33", "#6D33FF", "#33FFB6", "#FF336D", "#FFC833", "#33FF7F", "#33A7FF", "#FF33B6", "#33FF8A", "#33FF33", "#33B6FF", "#D533FF", "#8A33FF", "#FF338A"]
    # Shuffle the list so the color assignment varies between runs
    random.shuffle(bright_colors)
    # Repeat the palette if there are more classes than colors, so every class still gets a color
    repeats = num_colors // len(bright_colors) + 1
    return (bright_colors * repeats)[:num_colors]

# Get unique sounds (ignoring unlabeled points) and assign bright colors
unique_sounds = merged_df['label'].dropna().unique()
num_unique_sounds = len(unique_sounds)
bright_colors = generate_bright_colors(num_unique_sounds)

# Map sounds to their respective bright colors
sound_colors = dict(zip(unique_sounds, bright_colors))

# Color unlabeled (NaN) points grey
colors = merged_df['label'].map(sound_colors).where(merged_df['label'].notna(), 'grey')

# Define size and opacity for each color
sizes = [2 if c == 'grey' else 5 for c in colors]
opacities = [0.4 if c == 'grey' else 1 for c in colors]

# Create legend entries
legend_entries = []
for sound, color in sound_colors.items():
    legend_entries.append(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(color=color),
        showlegend=True,
        name=sound
    ))

# Create main figure with scatter plot and legend entries
f = go.FigureWidget([go.Scatter(
    y=merged_df['y'],
    x=merged_df['x'],
    mode='markers',
    marker=dict(size=sizes, color=colors, opacity=opacities, line=dict(color='rgba(0,0,0,0)')),
    showlegend=False  # Don't show the legend for this trace
)] + legend_entries)  # Include legend entries in the main figure

# Update layout to display legend
f.update_layout(
    dragmode='lasso',
    autosize=False,
    width=1200,
    height=800,
    legend=dict(
        title='Legend',
        title_font_size=16,
        font=dict(size=12),  # Set the font size of the legend text
        traceorder='normal'  # Set the order of legend items
    )
)

def update_axes(xaxis, yaxis):
    scatter = f.data[0]
    scatter.x = merged_df[xaxis]
    scatter.y = merged_df[yaxis]
    with f.batch_update():
        f.layout.xaxis.title = xaxis
        f.layout.yaxis.title = yaxis

# Table showing the ID and label of each point
t = go.FigureWidget([go.Table(
    header=dict(values=['identification_number', 'label'],
                fill=dict(color='#C2D4FF'),
                align='left'),
    cells=dict(values=[merged_df[col] for col in ['identification_number', 'label']],
               fill=dict(color='#F5F8FF'),
               align='left'))])

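# Update the table to show only the points selected in the plot
def selection_fn(trace, points, selector):
    # point_inds are row positions within the trace, so index by position
    t.data[0].cells.values = [merged_df.iloc[points.point_inds][col] for col in ['identification_number', 'label']]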

f.data[0].on_selection(selection_fn)

# Display plot and table together
VBox((f, t))
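
A fixed palette can run out when there are more classes than colors, and a shuffled palette changes the class colors on every run. As an alternative sketch, the helper below walks the labels in sorted order and cycles through the palette, so the mapping is reproducible and every label gets a color (colors repeat once the palette is exhausted). `assign_colors` is a hypothetical helper, and the labels and two-color palette are made up.

```python
from itertools import cycle

def assign_colors(labels, palette):
    """Map each label to a palette color, reusing colors if labels outnumber them."""
    # Sorting makes the label-to-color mapping the same on every run.
    return dict(zip(sorted(labels), cycle(palette)))

demo_colors = assign_colors(["rain", "bird", "wind"], ["#FF5733", "#336DFF"])
print(demo_colors)
# → {'bird': '#FF5733', 'rain': '#336DFF', 'wind': '#FF5733'}
```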