# In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import warnings

def calculate_and_smooth_errors(df, window_size):
    """
    Add raw and rolling-mean-smoothed error columns for the MI estimates.

    The columns 'absolute_error', 'relative_error', 'smoothed_absolute_error'
    and 'smoothed_relative_error' are written into `df` itself (the frame is
    modified in place) and the same object is returned.

    :param df: DataFrame with 'estimated_mi', 'exact_mi' and 'time_steps_ahead' columns.
    :param window_size: Number of rows in the rolling window used for smoothing.
    :return: The input DataFrame with the four error columns added.
    """
    def rolling_mean(series):
        # min_periods=1 lets the first rows average over a shorter window
        # instead of producing NaN.
        return series.rolling(window=window_size, min_periods=1).mean()

    # Raw errors of the estimate against the exact value
    deviation = df['estimated_mi'] - df['exact_mi']
    df['absolute_error'] = deviation.abs()
    df['relative_error'] = df['absolute_error'] / df['exact_mi'].abs()

    # Smooth each error separately within every 'time_steps_ahead' group
    for error_column in ('absolute_error', 'relative_error'):
        per_horizon = df.groupby('time_steps_ahead')[error_column]
        df[f'smoothed_{error_column}'] = per_horizon.transform(rolling_mean)

    return df

def plot_data(df, y_column, title, sample_size, dimension, burn_in=200, true_mi_lines=False):
    """
    Plot `y_column` against 'iter' with one curve per 'time_steps_ahead' group, skipping the first
    'burn_in' rows of every group. Optionally plot horizontal lines for the true mutual information.

    The y-axis is logarithmic, the legend is placed to the right of the axes, and the figure is
    shown immediately via plt.show().

    :param df: DataFrame containing the filtered experiment data; must provide the columns
               'time_steps_ahead', 'iter', `y_column` and (if true_mi_lines is True) 'exact_mi'.
    :param y_column: Column name to be used as the y-axis in the plot.
    :param title: Title for the plot; also used as the y-axis label.
    :param sample_size: Sample size to display in the title.
    :param dimension: Dimension to display in the title.
    :param burn_in: Number of initial rows per group to skip in the plot.
    :param true_mi_lines: Boolean, if True, plot horizontal lines for true mutual information.
    """
    # One color per distinct 'time_steps_ahead' value
    colors = plt.cm.viridis(np.linspace(0, 1, df['time_steps_ahead'].nunique()))
    markers = ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', 'h', 'H', '+', 'x', 'X', 'd', '|', '_']

    plt.figure(figsize=(10, 4))
    for idx, (time_steps_ahead, group_df) in enumerate(df.groupby('time_steps_ahead')):
        # Skip the first 'burn_in' values
        group_df = group_df.iloc[burn_in:]
        color = colors[idx % len(colors)]
        plt.plot(group_df['iter'], group_df[y_column],
                 label=f'Estimated MI (delta t = {time_steps_ahead})',
                 marker=markers[idx % len(markers)], color=color, markersize=0.7)

        # A group with no more than 'burn_in' rows is empty after the slice; .iloc[0] would
        # raise IndexError there, so the reference line is only drawn when rows remain.
        if true_mi_lines and not group_df.empty:
            # The true MI is read from the first row that survives the burn-in
            true_mi = group_df['exact_mi'].iloc[0]
            plt.axhline(y=true_mi, color=color, linestyle='--', label=f'True MI (delta t = {time_steps_ahead})')

    plt.xlabel('Epochs')
    plt.ylabel(title)
    plt.yscale('log')
    plt.title(f"{title}\nSample Size: {sample_size}, Dimension: {dimension}")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, linestyle='--', linewidth=0.5)
    plt.tick_params(direction='in', top=True, right=True, which='both')
    # Leave the right 15% of the figure free for the legend placed outside the axes
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.show()

def plot_normalized_mi_last_iter(df, title, sample_size, dimension):
    """
    Scatter the final true and estimated mutual information against the
    prediction horizon ('time_steps_ahead').

    For every 'time_steps_ahead' group the last entry is taken (GroupBy.last).
    Both series are divided by the same value, the maximum of those final
    estimated MI values, so they stay on a common scale.

    :param df: DataFrame with 'time_steps_ahead', 'exact_mi' and 'estimated_mi' columns.
    :param title: Title for the plot.
    :param sample_size: Sample size to display in the title.
    :param dimension: Dimension to display in the title.
    """
    # One row per horizon holding the final values
    last_per_horizon = df.groupby('time_steps_ahead').last().reset_index()
    horizons = last_per_horizon['time_steps_ahead']
    # Shared normalization constant for both series
    scale = last_per_horizon['estimated_mi'].max()

    # (column, legend label, marker) for each series, in drawing order
    series_specs = (
        ('exact_mi', 'True Mutual Information', 'o'),
        ('estimated_mi', 'Estimated Mutual Information', 'x'),
    )

    plt.figure(figsize=(10, 4))
    for column, legend_label, marker_shape in series_specs:
        plt.scatter(horizons, last_per_horizon[column] / scale,
                    label=legend_label, marker=marker_shape)

    plt.xlabel('Time Steps Ahead')
    plt.ylabel('Normalized Mutual Information')
    plt.title(f"{title}\nSample Size: {sample_size}, Dimension: {dimension}")
    plt.legend()
    plt.grid(True, linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()

def main(csv_path="2023-11-24-2_experiment_log.csv", window_size=50, burn_in=200):
    """
    Load an experiment log and plot it for every (sample_size, dimension) combination.

    For each combination three figures are shown: the smoothed estimated mutual information
    (with true-MI reference lines), the smoothed relative error, and the normalized mutual
    information at the last iteration.

    Smoothing is computed separately for each combination, so a rolling window never mixes
    rows that belong to different experiments but share a 'time_steps_ahead' value.

    :param csv_path: Path of the CSV experiment log to read.
    :param window_size: Size of the rolling window used for smoothing.
    :param burn_in: Number of initial rows per 'time_steps_ahead' group skipped in the line plots.
    """
    df = pd.read_csv(csv_path)

    if burn_in <= window_size:
        warnings.warn("Consider burn_in > window size to avoid artifacts of windowing in plot.")

    def rolling_mean(series):
        return series.rolling(window=window_size, min_periods=1).mean()

    # sort=False keeps the combinations in their order of first appearance in the log
    experiments = df.groupby(['sample_size', 'dimension'], sort=False)
    for (sample_size_filter, dimension_filter), experiment_df in experiments:
        # Work on a copy: the helper below adds columns in place
        filtered_df = experiment_df.copy()

        # Calculate and smooth the errors within this experiment only
        filtered_df = calculate_and_smooth_errors(filtered_df, window_size)

        # Smooth the mutual information column
        filtered_df['smoothed_estimated_mi'] = (
            filtered_df.groupby('time_steps_ahead')['estimated_mi'].transform(rolling_mean)
        )

        # Optionally filter the DataFrame for specific "time steps ahead" for cleaner plot.
        # filtered_df = filtered_df.query('time_steps_ahead in [0, 3, 9]')

        # Plot the smoothed mutual information with true MI horizontal lines
        plot_data(filtered_df, 'smoothed_estimated_mi', 'Smoothed Estimated Mutual Information',
                  sample_size_filter, dimension_filter, burn_in=burn_in, true_mi_lines=True)

        # Plot the smoothed errors without true MI horizontal lines.
        # burn_in is forwarded so this figure honors the configured value, not plot_data's default.
        # plot_data(filtered_df, 'smoothed_absolute_error', 'Smoothed Absolute Error in Mutual Information', sample_size_filter, dimension_filter, burn_in=burn_in)
        plot_data(filtered_df, 'smoothed_relative_error', 'Smoothed Relative Error in Mutual Information',
                  sample_size_filter, dimension_filter, burn_in=burn_in)

        # Plot the normalized mutual information for the last iteration
        plot_normalized_mi_last_iter(filtered_df, 'Normalized Mutual Information at Last Iteration',
                                     sample_size_filter, dimension_filter)
    
# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()