# Notebook to plot results from 'sound_classification.py'

## Import Libraries

In [1]:
# Import the necessary libraries
import os
import sys
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt
import git

# Find the working tree of the enclosing git repository (parent directories
# are searched too) and append it to sys.path so that modules living at the
# project root can be imported from this notebook.
_repo = git.Repo(".", search_parent_directories=True)
ROOT_DIR = _repo.working_tree_dir
sys.path.append(ROOT_DIR)

## Get the result file names

In [None]:
# Function to parse information from the filename
def parse_filename(filename):
    base_filename = os.path.basename(filename)
    name_parts = base_filename.replace('.txt', '').split('_')
    embedding_name = name_parts[0]
    mode = name_parts[3]
    modality = name_parts[4]
    temperature = name_parts[5]
    return embedding_name, mode, modality, temperature

In [36]:
# Names of all .txt result files in <ROOT_DIR>/results. os.listdir returns
# entries in arbitrary order, so sort them to make the row order of the
# DataFrame built below deterministic.
results_files = sorted(
    f for f in os.listdir(os.path.join(ROOT_DIR, "results")) if f.endswith(".txt")
)

In [37]:
# Regex for the results line inside each file, of the form
#   Model=<model>, train_type=<train_type>, acc/mAP=<number>%
# Only the first match in each file is used.
pattern = re.compile(r"Model=(.*), train_type=(.*), acc/mAP=(.*)%")

# One row per parsed file:
# [model, train_type, accuracy, embedding_name, modality, temperature]
data = []
# Files with no matching results line; reported below instead of being
# dropped silently.
skipped_files = []

for filename in results_files:
    with open(os.path.join(ROOT_DIR, "results", filename), 'r', encoding='utf-8') as f:
        content = f.read()
    match = pattern.search(content)
    if match is None:
        skipped_files.append(filename)
        continue
    model, train_type, acc_text = match.groups()
    acc = float(acc_text)
    # NOTE(review): the mode parsed from the filename is not stored; the
    # 'Mode' column below is filled from train_type in the file contents.
    embedding_name, _file_mode, modality, temperature = parse_filename(filename)
    data.append([model, train_type, acc, embedding_name, modality, temperature])

if skipped_files:
    print(f"Skipped {len(skipped_files)} results file(s) with no "
          f"'Model=..., train_type=..., acc/mAP=...%' line: {skipped_files}")

# Create a DataFrame from the extracted data
df = pd.DataFrame(data, columns=['Model', 'Mode', 'Accuracy (%)', 'Embedding_Name', 'Modality', 'Temperature'])

In [None]:
# Preview the first five parsed result rows.
df.head(n=5)

## Plot the results

In [66]:
def _best_accuracy_table(df, mode, modalities):
    """Return the best accuracy per embedding (rows) and modality (columns) for ``mode``.

    Only rows whose 'Mode' equals ``mode`` are used. Embedding/modality pairs
    without a result are filled with 0. The columns are always exactly
    ``modalities`` in that order; a modality with no rows becomes an all-zero
    column instead of raising KeyError.
    """
    df_mode = df[df['Mode'] == mode]
    best_df = (
        df_mode.groupby(['Embedding_Name', 'Modality'])['Accuracy (%)']
        .max()
        .unstack(fill_value=0)
    )
    return best_df.reindex(columns=list(modalities), fill_value=0)


def _annotate_bars(ax, colors):
    """Write each non-zero bar's height above it, in the color of its bar series."""
    # pandas draws one BarContainer per DataFrame column, in column order, so
    # the containers line up one-to-one with ``colors`` regardless of how many
    # embeddings there are.
    for container, color in zip(ax.containers, colors):
        for bar in container:
            height = bar.get_height()
            # Zero-height bars stand for missing combinations; leave them unlabeled.
            if height > 0:
                ax.annotate(
                    f'{height:.1f}',
                    (bar.get_x() + bar.get_width() / 2 - 0.03, height + 0.01),
                    ha='center', va='bottom', fontsize=22, color=color, weight='bold',
                )


def plot_best_accuracy(df, mode, *, ylim=(40, 85), save_path=None):
    """Plot, for one mode, the best accuracy of each embedding for each modality.

    Rows of ``df`` whose 'Mode' equals ``mode`` are grouped by 'Embedding_Name'
    and 'Modality'. The maximum 'Accuracy (%)' of each group is drawn as a
    grouped bar chart with one bar series per modality ('None', 'text',
    'audio'), and each non-zero bar is labelled with its value.

    Args:
        df: Results table with columns 'Mode', 'Embedding_Name', 'Modality'
            and 'Accuracy (%)'.
        mode: Value of the 'Mode' column to plot (e.g. 'zs', 'tgap', 'sv').
        ylim: ``(bottom, top)`` y-axis limits, or None to let matplotlib
            choose. Bars below ``bottom`` are not visible.
        save_path: If given, the figure is also saved to this path. The format
            is inferred from the extension, e.g. 'plot_output.pdf'.

    Raises:
        ValueError: if no row of ``df`` has 'Mode' equal to ``mode``.
    """
    if not (df['Mode'] == mode).any():
        raise ValueError(f"No results with Mode == {mode!r} to plot")

    # Bar color for each modality. The dict order is the column order, and
    # therefore also the legend order.
    modality_colors = {
        'None': (232/255, 74/255, 59/255),    # red
        'text': (34/255, 34/255, 84/255),     # dark blue
        'audio': (246/255, 178/255, 79/255),  # mustard yellow
    }
    colors = list(modality_colors.values())
    best_df = _best_accuracy_table(df, mode, modality_colors.keys())

    fig, ax = plt.subplots(figsize=(20, 12), dpi=500)

    # Grouped bar plot: one group per embedding, one bar series per modality.
    # zorder=3 keeps the bars above the background grid.
    best_df.plot(kind='bar', ax=ax, width=0.91, color=colors, zorder=3)

    plt.xticks(rotation=45, ha='right', fontsize=30)
    ax.tick_params(axis='y', labelsize=30)
    ax.set_ylabel('Accuracy (%)', fontsize=32)
    ax.set_xlabel('Embedding Name', fontsize=32)
    ax.set_xlim(-0.6, len(best_df) - 0.4)  # Reduce the white space around the bars

    _annotate_bars(ax, colors)

    # Legend in a single row just above the axes.
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.01), fontsize=32, ncol=3)
    ax.grid(axis='y', linestyle='--', alpha=0.5, zorder=0)
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Lay out last so the final tick labels, legend and limits are accounted for.
    fig.tight_layout()

    if save_path is not None:
        # bbox_inches='tight' keeps the legend above the axes from being cropped.
        fig.savefig(save_path, bbox_inches='tight')

    plt.show()

In [None]:
# Plot ZS: best accuracy per embedding and modality for mode 'zs'
plot_best_accuracy(df, 'zs')

In [None]:
# Plot TGAP: best accuracy per embedding and modality for mode 'tgap'
plot_best_accuracy(df, 'tgap')

In [None]:
# Plot SV: best accuracy per embedding and modality for mode 'sv'
plot_best_accuracy(df, 'sv')