# In [4]:
import numpy as np
import pandas as pd
#import matplotlib
import matplotlib.pyplot as plt
import os

# To force a non-interactive backend (e.g. for headless batch saving), uncomment `import matplotlib` above and the next line.
#matplotlib.use('Agg')

script_dir = os.getcwd()  # only used by the commented-out savefig code at the end of the loop

test_names = ['1 - 1d sde', '2 - 2d sde', '3 - 1d sde-cubic', '3 - 2d sde-cubic', '4 - 3d sde-LT', '4 - 3d sde-spd', '6 - SIRGillespie', '8 - nonGaussian']
#test_names = ['4 - 3d sde-LT']

# Examples with no 'TMLE ex<name>.txt' file; no expected-minimum-loss line is drawn for them.
tests_without_tmle = {'6 - SIRGillespie'}


def _read_first_line_floats(path):
    """Parse the first line of a text file as comma-separated floats.

    Parameters
    ----------
    path : str
        Path of the text file to read.

    Returns
    -------
    tuple of float
        One float per comma-separated field on the first line.

    Raises
    ------
    ValueError
        If any field is not a valid float (including an empty first line).
    """
    with open(path, 'r', encoding='utf-8') as file:
        fields = file.readline().strip().split(',')
    return tuple(float(field) for field in fields)


def _plot_single_point(time, loss, time_below, time_above, loss_below, loss_above, color, label):
    """Draw one (time, loss) marker with asymmetric x/y error bars on the current axes.

    The error extents are passed as (2, 1) nested lists: the first row is the
    lower (minus) extent and the second row is the upper (plus) extent, which is
    the layout ``plt.errorbar`` expects for asymmetric errors.
    """
    plt.errorbar(
        time, loss,
        xerr=[[time_below], [time_above]],
        yerr=[[loss_below], [loss_above]],
        fmt='o', color=color, ecolor=color, elinewidth=1.0, capsize=3, label=label,
    )


for test_name in test_names:
    # Load SGD: the first six CSV columns are read positionally as
    # time, loss, time+std, time-std, loss+std, loss-std.
    file_path = f'example{test_name} data.csv'
    data = pd.read_csv(file_path)

    SGD_time = data.iloc[:, 0]
    SGD_loss = data.iloc[:, 1]
    SGD_time_above_std = data.iloc[:, 2]
    SGD_time_below_std = data.iloc[:, 3]
    SGD_loss_above_std = data.iloc[:, 4]
    SGD_loss_below_std = data.iloc[:, 5]

    # Load ARFF: one summary line with the same six fields, in the same order.
    # Unpacking raises ValueError if the line does not have exactly six values.
    (ARFF_time, ARFF_loss, ARFF_time_above_std, ARFF_time_below_std,
     ARFF_loss_above_std, ARFF_loss_below_std) = _read_first_line_floats(f'ex{test_name} data.txt')

    # Load ARFF wo resampling (same file layout).
    (ARFF_wo_time, ARFF_wo_loss, ARFF_wo_time_above_std, ARFF_wo_time_below_std,
     ARFF_wo_loss_above_std, ARFF_wo_loss_below_std) = _read_first_line_floats(f'ex{test_name} wo resampling data.txt')

    # Create the plot
    plt.figure(figsize=(8, 6))

    # TMLE: a single reference value drawn as a horizontal line, when available.
    if test_name not in tests_without_tmle:
        (TMLE,) = _read_first_line_floats(f'TMLE ex{test_name}.txt')
        plt.axhline(y=TMLE, color='green', linestyle='--', linewidth=1.5, label='Expected Min Loss')

    # errorbar takes asymmetric errors as [lower, upper], hence below before above.
    plt.errorbar(
        SGD_time, SGD_loss,
        xerr=[SGD_time_below_std, SGD_time_above_std],
        yerr=[SGD_loss_below_std, SGD_loss_above_std],
        fmt='o', color='blue', ecolor='blue', elinewidth=1.0, capsize=3, label='SGD'
    )

    _plot_single_point(
        ARFF_time, ARFF_loss,
        time_below=ARFF_time_below_std, time_above=ARFF_time_above_std,
        loss_below=ARFF_loss_below_std, loss_above=ARFF_loss_above_std,
        color='red', label='ARFF with resampling',
    )

    _plot_single_point(
        ARFF_wo_time, ARFF_wo_loss,
        time_below=ARFF_wo_time_below_std, time_above=ARFF_wo_time_above_std,
        loss_below=ARFF_wo_loss_below_std, loss_above=ARFF_wo_loss_above_std,
        color='orange', label='ARFF',
    )

    # Add labels and title
    plt.xlabel('Time (s)')
    plt.ylabel('Validation Loss')
    #plt.title(f'Example {test_name}')
    plt.legend()

    # Show the plot
    plt.grid(True)
    plt.show()

    # To save instead of (or as well as) displaying, uncomment the lines below.
    # NOTE: savefig should run BEFORE plt.show(); some backends (e.g. notebook
    # inline) close the figure on show, so saving afterwards can write a blank image.
    # output_file = os.path.join(script_dir, f'Example {test_name}')
    # plt.savefig(output_file, dpi=300, bbox_inches='tight')
    # plt.close()
    

