## Visualizing Results

In [6]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import re
from typing import Dict, List, Tuple, Set, Optional

In [7]:
pwd

'c:\\Users\\super\\Downloads\\School Work\\MUSA 6950\\Final\\musa-6950-final'

In [8]:
def process_csv_files(folder_path: str):
    """Process all CSV files in folder and extract summary data."""
    total_counts = {}
    averages = {}
    classes = set()
    pattern = re.compile(r'ep(\d+)_summary\.csv')

    for filename in os.listdir(folder_path):
        match = pattern.match(filename)
        if not match:
            continue

        try:
            ep_num = int(match.group(1))
            filepath = os.path.join(folder_path, filename)
            df = pd.read_csv(filepath)

            if 'total_count' not in df.columns or 'average_per_image' not in df.columns:
                continue

            for _, row in df.iterrows():
                class_name = row.iloc[0]
                classes.add(class_name)
                
                total_counts.setdefault(class_name, []).append((ep_num, row['total_count']))
                averages.setdefault(class_name, []).append((ep_num, row['average_per_image']))

        except Exception as e:
            print(f"Error processing {filename}: {e}")

    return total_counts, averages, classes

In [9]:
def prepare_plot_data(total_counts: Dict[str, List[Tuple[int, float]]],
                     averages: Dict[str, List[Tuple[int, float]]]):
    """Sort and organize data for plotting."""
    sorted_classes = sorted(total_counts.keys())
    all_episodes = set()
    
    # Sort data by episode number and collect all episodes
    for class_name in sorted_classes:
        total_counts[class_name].sort(key=lambda x: x[0])
        averages[class_name].sort(key=lambda x: x[0])
        all_episodes.update(ep for ep, _ in total_counts[class_name])
    
    return sorted(all_episodes), total_counts, averages

In [10]:
def create_subplot(ax: plt.Axes, 
                  data: Dict[str, List[Tuple[int, float]]],
                  episodes: List[int],
                  ylabel: str,
                  title: str,
                  custom_labels: Optional[Dict[int, str]] = None):
    """Create a single subplot with custom x-axis labels."""
    for class_name in sorted(data.keys()):
        x, y = zip(*data[class_name])
        ax.plot(x, y, label=class_name, marker='o')
    
    if custom_labels:
        ax.set_xticks(episodes)
        ax.set_xticklabels([custom_labels.get(ep, str(ep)) for ep in episodes])
    else:
        ax.set_xlabel('Episode Number')
    
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)

In [11]:
def plot_summary_data(folder_path: str, 
                    custom_labels: Optional[Dict[int, str]] = None):
    """Main function to process data and create plots."""
    # Process and prepare data
    total_counts, averages, classes = process_csv_files(folder_path)
    if not classes:
        print("No valid summary CSV files found.")
        return
    
    episodes, total_counts, averages = prepare_plot_data(total_counts, averages)
    
    # Create figure and subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Create plots
    create_subplot(ax1, total_counts, episodes, 'Total Count', 'Total Count of Each Class', custom_labels)
    create_subplot(ax2, averages, episodes, 'Average per Image', 'Average Occurrence per Image', custom_labels)
    
    # Add unified legend
    handles, labels = ax1.get_legend_handles_labels()
    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.15, 0.5))
    
    plt.tight_layout()
    plt.show()



In [12]:

custom_labels = {
    1: "Pilot",
    2: "Setup",
    3: "Conflict",
    4: "Resolution"
}

plot_summary_data(folder_path = "youtube_frames",custom_labels=custom_labels)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'youtube_frames'