In [None]:
# plot training loss
# first read in the training loss from the log file
# then pick the training loss terms (note that I want the loss of each environment):
# eg: Loss of env 0: 0.350871
# Loss of env 1: 0.91359204
# Loss of env 2: 0.86774445
# Loss of env 3: 0.38117403
# then draw the training loss, where the x-axis is the epoch number and the y-axis is the training loss
# each log file corresponds to three runs. Calculate the average training loss for each environment across the three runs.

In [None]:
import re
import matplotlib.pyplot as plt
import numpy as np
def read_log_file(file_path):
    """Collect the per-environment training losses recorded in a log file.

    Every line containing ``Loss of env <id>: <value>`` contributes one value
    (the first such occurrence on the line) to the list kept for environment
    ``<id>``. Values are stored in the order in which they appear in the file.

    Args:
        file_path: Path of the text log to parse.

    Returns:
        dict mapping environment id (int) to the list of its loss values
        (float). The dict is empty when no line matches.
    """
    # Accept an optional sign, a decimal number and an optional exponent.
    # The former pattern ``[\d.]+`` stopped at the 'e' of values such as
    # ``1.5e-05`` (reading them as 1.5), skipped lines with negative values,
    # and captured a sentence-ending '.' which made float() raise ValueError.
    loss_pattern = re.compile(
        r'Loss of env (\d+): ([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    loss_data = {}
    # Stream the file line by line instead of loading it fully with readlines().
    with open(file_path, 'r') as file:
        for line in file:
            match = loss_pattern.search(line)
            if match is None:
                continue
            env_id = int(match.group(1))
            loss_data.setdefault(env_id, []).append(float(match.group(2)))

    return loss_data

# plot the per-environment loss curves (mean and std across runs) for checking
def _mean_and_std_over_runs(losses, num_runs, run_len):
    """Split ``losses`` into ``num_runs`` consecutive runs and reduce them per epoch."""
    runs = np.asarray(losses[:num_runs * run_len], dtype=float).reshape(num_runs, run_len)
    return runs.mean(axis=0), runs.std(axis=0)


def plot_loss(loss_data, title, save=False, save_path=None, num_runs=3):
    """Plot the mean training loss per environment with a +/- one std band.

    The loss list of every environment is treated as ``num_runs`` runs of equal
    length stored one after another. The run length is derived from the
    shortest environment list; values beyond ``num_runs * run_len`` are not
    used (a warning is issued when that happens, because run boundaries may
    then be misaligned).

    Args:
        loss_data: dict mapping environment id to a list of loss values.
        title: Title of the figure.
        save: When true (and ``save_path`` is given) the figure is written to disk.
        save_path: Destination passed to ``savefig``.
        num_runs: Number of consecutive runs contained in each loss list.

    Raises:
        ValueError: If ``loss_data`` is empty or ``num_runs`` is smaller than 1.
    """
    import os
    import warnings

    if num_runs < 1:
        raise ValueError(f"num_runs must be a positive integer, got {num_runs!r}")
    if not loss_data:
        # Previously this surfaced as an opaque IndexError from an unused lookup.
        raise ValueError("loss_data is empty: there are no losses to plot")

    env_ids = sorted(loss_data)
    shortest = min(len(loss_data[env_id]) for env_id in env_ids)
    run_len = shortest // num_runs
    used = num_runs * run_len

    truncated = [env_id for env_id in env_ids if len(loss_data[env_id]) != used]
    if truncated:
        warnings.warn(
            f"Only the first {used} loss values are used ({num_runs} runs x {run_len} epochs); "
            f"environments {truncated} have a different number of values, "
            "so runs may be misaligned."
        )

    epochs = np.arange(run_len)
    fig, ax = plt.subplots(figsize=(10, 6))

    for env_id in env_ids:
        mean_loss, std_loss = _mean_and_std_over_runs(loss_data[env_id], num_runs, run_len)
        ax.plot(epochs, mean_loss, label=f'Env {env_id} Mean')
        ax.fill_between(epochs, mean_loss - std_loss, mean_loss + std_loss, alpha=0.2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training Loss')
    ax.set_title(title)
    ax.legend()
    # grid() without arguments toggles visibility; request it explicitly so an
    # rc/style setting that already enables the grid does not switch it off.
    ax.grid(True)

    if save and save_path:
        # Create the target folder first so savefig does not fail on a fresh checkout.
        if isinstance(save_path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(save_path))
            if parent:
                os.makedirs(parent, exist_ok=True)
        fig.savefig(save_path)
        print(f"Plot saved to {save_path}")

    plt.show()

In [None]:
def plot_from_one_file(log_file, title, save=False):
    """Parse one training log and plot its per-environment loss curves.

    The losses are read with ``read_log_file`` and drawn with ``plot_loss``.
    The image path handed to ``plot_loss`` is
    ``../../images/<grandparent dir>_<parent dir>_<method>.png``, where
    ``<method>`` is the log file name without its ``_log.txt`` suffix, e.g.
    ``../logs/FourEnv/CMNIST/ERM_log.txt`` -> ``../../images/FourEnv_CMNIST_ERM.png``.

    Args:
        log_file: Path of the log file to parse.
        title: Title of the resulting figure.
        save: Forwarded to ``plot_loss``; when true the figure is saved.
    """
    from pathlib import Path

    loss_data = read_log_file(log_file)

    log_path = Path(log_file)
    log_suffix = '_log.txt'
    file_name = log_path.name
    # Fall back to the plain stem so the image always gets a .png extension.
    method = file_name[:-len(log_suffix)] if file_name.endswith(log_suffix) else log_path.stem
    # Use up to two enclosing directory names; shorter paths no longer raise IndexError.
    parent_names = [part for part in log_path.parts[:-1][-2:] if part != log_path.anchor]

    save_folder = '../../images'
    save_name = '_'.join(parent_names + [f'{method}.png'])
    save_path = f"{save_folder}/{save_name}"

    plot_loss(loss_data, title, save=save, save_path=save_path)

In [None]:
# Build the log paths and the matching figure titles from a single specification so
# that both lists stay aligned. Each entry is
# (log sub-directory, dataset name shown in the title, methods in plotting order).
log_files, titles = (
    list(column)
    for column in zip(*[
        (
            f'../logs/FourEnv/{dataset_dir}/{method}_log.txt',
            f'{method} on {dataset_label} (4 environments)',
        )
        for dataset_dir, dataset_label, methods in (
            ('CMNIST', 'CMNIST',
             ('ERM', 'CategoryReweightedERM', 'groupDRO', 'IRM')),
            ('SyntheticFolktables', 'Synthetic Folktables',
             ('CategoryReweightedERM', 'ERM', 'IRM', 'groupDRO', 'InvRat')),
        )
        for method in methods
    ])
)

# Plot (and save) one figure per log file.
for log_file, title in zip(log_files, titles):
    print(f'Processing log file: {log_file}')
    plot_from_one_file(log_file, title, save=True)
    print(f'Finished processing log file: {log_file}\n')

In [None]:
# Build the log paths and the matching figure titles from a single specification so
# that both lists stay aligned. Both datasets use the same methods in the same order;
# each dataset entry is (log sub-directory, dataset name shown in the title).
log_files, titles = (
    list(column)
    for column in zip(*[
        (
            f'../logs/TwoEnv/{dataset_dir}/{method}_log.txt',
            f'{method} on {dataset_label} (2 environments)',
        )
        for dataset_dir, dataset_label in (
            ('CMNIST', 'CMNIST'),
            ('SyntheticFolktables', 'Synthetic Folktables'),
        )
        for method in ('ERM', 'ChiSquareDRO', 'IRM', 'InvRat', 'REx', 'groupDRO')
    ])
)

# Plot (and save) one figure per log file.
for log_file, title in zip(log_files, titles):
    print(f'Processing log file: {log_file}')
    plot_from_one_file(log_file, title, save=True)
    print(f'Finished processing log file: {log_file}\n')

In [None]:
# combine the saved per-method graphs (six per dataset) into one figure
def plot_multiple_graphs(log_files, save_name=None, ncols=3):
    """Arrange previously saved per-log plots as panels of one figure.

    For every log path the image
    ``../../images/<grandparent dir>_<parent dir>_<method>.png`` is loaded,
    where ``<method>`` is the log file name without its ``_log.txt`` suffix;
    ``<method>`` is also used as the panel title. The number of grid rows
    follows from the number of images, so any count is supported (six images
    with the default ``ncols`` give the 2 x 3 layout on a 15 x 10 inch figure).

    Args:
        log_files: Paths of the log files whose saved plots are combined.
        save_name: File name (inside ``../../images``) for the combined
            figure; nothing is saved when it is empty or None.
        ncols: Number of panels per row.

    Raises:
        ValueError: If ``ncols`` is smaller than 1.
    """
    from pathlib import Path

    if ncols < 1:
        raise ValueError(f"ncols must be a positive integer, got {ncols!r}")

    image_folder = '../../images'
    log_suffix = '_log.txt'

    # Collect (panel title, image path) for every log file.
    panels = []
    for log_file in log_files:
        log_path = Path(log_file)
        file_name = log_path.name
        method = file_name[:-len(log_suffix)] if file_name.endswith(log_suffix) else log_path.stem
        # Use up to two enclosing directory names; shorter paths no longer raise IndexError.
        parent_names = [part for part in log_path.parts[:-1][-2:] if part != log_path.anchor]
        figure_name = '_'.join(parent_names + [f'{method}.png'])
        panels.append((method, f"{image_folder}/{figure_name}"))

    # Ceiling division; keep at least one row so an empty input still yields a figure.
    nrows = max(1, -(-len(panels) // ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, (method, figure_file_path) in zip(flat_axes, panels):
        ax.imshow(plt.imread(figure_file_path))
        ax.set_title(method)
    # Hide ticks and frames on every panel, including unused grid cells.
    for ax in flat_axes:
        ax.axis('off')

    fig.tight_layout()
    if save_name:
        save_path = f"{image_folder}/{save_name}"
        fig.savefig(save_path)
        print(f"Combined plot saved to {save_path}")

    plt.show()
    
# TwoEnv log paths: the six methods for CMNIST first, then the same six methods
# for SyntheticFolktables.
log_files = [
    f'../logs/TwoEnv/{dataset_dir}/{method}_log.txt'
    for dataset_dir in ('CMNIST', 'SyntheticFolktables')
    for method in ('ERM', 'ChiSquareDRO', 'IRM', 'InvRat', 'REx', 'groupDRO')
]

# One combined figure per dataset: entries 0-5 are CMNIST, entries 6-11 are SyntheticFolktables.
plot_multiple_graphs(log_files[:6], save_name='CMNIST_combined.png')
plot_multiple_graphs(log_files[6:], save_name='SyntheticFolktables_combined.png')