<a href="https://colab.research.google.com/github/joaosMart/fish-species-class-siglip/blob/update-readme-comprehensive/Code/extra-analysis/Learning_Curves.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

def plot_model_curves(data, model_name):
    """
    Create a learning curve plot comparing the available methods for a single model.

    Every method listed in ``method_configs`` that is present in ``data`` and
    has an entry for ``model_name`` is drawn as a line of mean test scores with
    a shaded band spanning mean - std to mean + std.

    Args:
        data: Mapping of method key (e.g. ``'averaged_features'``) to a mapping
            of model name to a record holding the keys ``"train_sizes"``,
            ``"test_scores_mean"`` and ``"test_scores_std"``.
        model_name: Name of the model whose curves are plotted (e.g. ``'SVM'``).

    Returns:
        The ``matplotlib.pyplot`` module. The newly created figure is left as
        the current figure, so callers can use ``savefig`` / ``close`` on the
        returned object.

    Note:
        The 'default' matplotlib style is applied globally, as a side effect.
    """
    # Apply the white-background default style BEFORE creating the figure so
    # that the figure/axes themselves are built with it. Calling it after
    # figure creation only affected artists created later and made the first
    # figure of a session differ from subsequent ones.
    plt.style.use('default')

    fig, ax = plt.subplots(figsize=(12, 9))

    # Hide the frame around the plotting area.
    for side in ('right', 'top', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Define colors and labels for each method
    method_configs = {
        'averaged_features': {
            'color': 'blue',
            'label': 'Temporal Pooling'
        },
        'center_frame': {
            'color': 'red',
            'label': 'Central Frame'
        },
        'temporal_voting': {
            'color': 'green',
            'label': 'Temporal Voting'
        },
        'ResNet50': {
            'color': 'purple',
            'label': 'ResNet50 (Feature\nExtraction)'
        },
        'ResNet50-finetuned': {
            'color': 'orange',
            'label': 'ResNet50 (Fine-tuned)'
        },
    }

    # Plot each method that has results for this model, remembering the
    # training sizes of what was actually drawn.
    plotted_sizes = {}
    for method, config in method_configs.items():
        if method not in data or model_name not in data[method]:
            continue

        curve = data[method][model_name]
        train_sizes = curve["train_sizes"]
        mean = np.asarray(curve["test_scores_mean"])
        std = np.asarray(curve["test_scores_std"])

        ax.plot(train_sizes, mean, '-',
                color=config['color'],
                label=config['label'],
                linewidth=2)
        ax.fill_between(train_sizes, mean - std, mean + std,
                        alpha=0.15,
                        color=config['color'])
        plotted_sizes[method] = train_sizes

    # Add total samples text box. The count is the last training size of the
    # 'averaged_features' curve when it exists; otherwise fall back to the
    # first plotted method instead of raising KeyError. Nothing is drawn when
    # no curve (or an empty one) is available.
    reference_sizes = plotted_sizes.get('averaged_features')
    if reference_sizes is None and plotted_sizes:
        reference_sizes = next(iter(plotted_sizes.values()))
    if reference_sizes is not None and len(reference_sizes) > 0:
        ax.text(0.02, 0.98, f'Total samples: {int(reference_sizes[-1])}',
                transform=ax.transAxes,
                bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'),
                verticalalignment='top',
                fontsize=20)

    # Customize plot
    ax.grid(True)
    ax.set_xlabel('Number of Training Samples', fontsize=25)
    ax.set_ylabel('Macro F1 Score', fontsize=25)

    # Set axis limits
    ax.set_ylim(0.84, 0.99)
    ax.set_xlim(0, 3200)

    # Format axis ticks
    # NOTE(review): the y ticks start at 0.83, below the 0.84 lower limit set
    # above, and stop short of the 0.99 upper limit — confirm which range is
    # intended. Values are kept unchanged to preserve the existing figures.
    y_ticks = np.arange(0.83, 0.98, 0.01)
    plt.xticks(np.arange(0, 3193, 500), fontsize=20)
    plt.yticks(y_ticks, fontsize=20)

    # Place the legend inside the axes (anchored in axes coordinates) and
    # show all plotted entries; skip it when nothing was plotted.
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels,
                  bbox_to_anchor=(0.53, 0.18),
                  loc='center left',
                  fontsize=22)

    # Adjust layout
    fig.tight_layout()

    return plt

def main():
    """Load the learning-curve data and save one comparison figure per model."""
    # The JSON file has to be created beforehand from the learning curves of
    # all the evaluation runs in Evaluation.ipynb.
    data_path = 'path/to/learning-curve-data.json'
    with open(data_path, 'r') as handle:
        curves = json.load(handle)

    # Output settings shared by every saved figure.
    save_options = {
        'bbox_inches': 'tight',
        'dpi': 300,
        'facecolor': 'white',
        'edgecolor': 'none',
    }

    # One PNG per classifier; close each figure once it has been written.
    for model in ('SVM', 'LogisticRegression'):
        rendered = plot_model_curves(curves, model)
        rendered.savefig(f'learning_curves_{model.lower()}_comparison.png',
                         **save_options)
        plt.close()

# Entry point: run the plotting pipeline only when this code is executed
# directly (``__name__`` is "__main__"), not when it is imported.
if __name__ == "__main__":
    main()