In [None]:
import os

import torch
import scipy.stats as stats
import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as LA
import scipy.io
import seaborn as sns
import pandas as pd
import scienceplots
# Global figure style: scienceplots' 'science' + 'ieee' sheets, then local overrides.
# ("sans" is matplotlib's alias for the generic sans-serif family.)
plt.style.use(['science', 'ieee'])
plt.rcParams.update({
    "font.family": "sans",
    "font.serif": ["cm"],
    "mathtext.fontset": "cm",
    "font.size": 22,
})

In [None]:
# %% compare the losses of LISTA
# Load the saved loss curves and time stamps of the two runs from ./data.

# --- run trained from ISTA weights ---
train_loss_list_ISTA, valid_loss_list_ISTA = (
    np.load('./data/train_loss_list_ISTA.npy'),
    np.load('./data/valid_loss_list_ISTA.npy'),
)
# TODO: Get the correct time_list_ for ISTA weights
# NOTE(review): until then the ISTA curves are plotted against the time stamps in
# time_list_constant.npy, which presumably belong to a different run — confirm.
time_list_ISTA = np.load('./data/time_list_constant.npy')

# --- gauss run ---
train_loss_list_gauss, valid_loss_list_gauss, time_list_gauss = (
    np.load('./data/train_loss_list_gauss.npy'),
    np.load('./data/valid_loss_list_gauss.npy'),
    np.load('./data/time_list_gauss.npy'),
)

In [None]:
# Shapes of the gauss run's arrays: (train loss, valid loss, time stamps)
(train_loss_list_gauss.shape,
 valid_loss_list_gauss.shape,
 time_list_gauss.shape)

In [None]:
# Shapes of the ISTA run's arrays: (train loss, valid loss, time stamps)
(train_loss_list_ISTA.shape,
 valid_loss_list_ISTA.shape,
 time_list_ISTA.shape)

In [None]:
# Last entry of each time list: (ISTA, gauss)
(time_list_ISTA[-1],
 time_list_gauss[-1])

In [None]:
def plot_losses_combined(time_ISTA, train_loss_ISTA, valid_loss_ISTA,
                         time_gauss, train_loss_gauss, valid_loss_gauss,
                         filename_name):
    """Overlay the training and validation loss curves of both runs and save them as a PDF.

    Every loss curve is drawn as ``log10(1 + loss)`` against its time stamps on a
    single axes: ISTA run in red, gauss run in blue; solid lines for training,
    dashed lines for validation.

    Parameters
    ----------
    time_ISTA, train_loss_ISTA, valid_loss_ISTA : array-like
        Time stamps (seconds, per the x-axis label) and training / validation
        losses of the ISTA run. Both loss sequences must have the same length
        as ``time_ISTA``.
    time_gauss, train_loss_gauss, valid_loss_gauss : array-like
        The same three sequences for the gauss run.
    filename_name : str, path-like or file-like
        Destination handed to ``Figure.savefig``. For a path, missing parent
        directories are created first.

    Returns
    -------
    The ``filename_name`` argument, unchanged.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # (x values, losses, matplotlib format string, legend label) for each curve.
    curves = (
        (time_ISTA, train_loss_ISTA, '-r^', 'training loss ISTA'),
        (time_ISTA, valid_loss_ISTA, '--ro', 'validation loss ISTA'),
        (time_gauss, train_loss_gauss, '-b+', 'training loss gauss'),
        (time_gauss, valid_loss_gauss, '--bo', 'validation loss gauss'),
    )
    for times, losses, fmt, label in curves:
        # np.asarray lets plain Python lists work too (``1 + [...]`` raises TypeError);
        # the +1 keeps zero losses finite under log10.
        ax.plot(times, np.log10(1 + np.asarray(losses)), fmt, label=label)

    ax.set_xlabel('TIME (SECONDS)')
    ax.set_ylabel('LOG RECONSTRUCTION LOSS')
    ax.legend()
    ax.grid(True)

    # Operate on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout()
    if isinstance(filename_name, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename_name))
        if parent_dir:
            # savefig does not create directories; a fresh checkout would fail here.
            os.makedirs(parent_dir, exist_ok=True)
    fig.savefig(filename_name, format='pdf', bbox_inches='tight')
    # The figure is deliberately left open so the notebook renders it inline;
    # close it (plt.close(fig)) when calling this repeatedly in a loop.
    return filename_name

# Overlay the first 20 entries of both runs' loss curves and save the figure.
combined_loss_path = plot_losses_combined(
    time_ISTA=time_list_ISTA[:20],
    train_loss_ISTA=train_loss_list_ISTA[:20],
    valid_loss_ISTA=valid_loss_list_ISTA[:20],
    time_gauss=time_list_gauss[:20],
    train_loss_gauss=train_loss_list_gauss[:20],
    valid_loss_gauss=valid_loss_list_gauss[:20],
    filename_name='./figures/07_convergence/convergence_plot.pdf',
)

# Display the path the PDF was written to
combined_loss_path


In [None]:
def plot_losses_side_by_side(time_ISTA, train_loss_ISTA, valid_loss_ISTA,
                             time_gauss, train_loss_gauss, valid_loss_gauss, filename_prefix):
    """Plot training (left) and validation (right) losses of both runs and save each panel.

    Each panel shows ``log10(1 + loss)`` against time for the ISTA run (red
    triangles) and the gauss run (blue crosses).

    Parameters
    ----------
    time_ISTA, train_loss_ISTA, valid_loss_ISTA : array-like
        Time stamps (seconds, per the x-axis label) and training / validation
        losses of the ISTA run; the loss sequences match ``time_ISTA`` in length.
    time_gauss, train_loss_gauss, valid_loss_gauss : array-like
        The same three sequences for the gauss run.
    filename_prefix : str
        Not used: the output file names are fixed (see Returns). Kept so
        existing calls keep working.

    Returns
    -------
    (train_loss_path, valid_loss_path) : tuple of str
        ``./figures/07_convergence/train_loss.pdf`` holds the training panel and
        ``./figures/07_convergence/valid_loss.pdf`` the validation panel; both
        files are written by this call.
    """
    fig, (train_ax, valid_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axes, title, ISTA losses, gauss losses, legend label stem) for each panel.
    panels = (
        (train_ax, 'Training Loss', train_loss_ISTA, train_loss_gauss, 'training loss'),
        (valid_ax, 'Validation Loss', valid_loss_ISTA, valid_loss_gauss, 'validation loss'),
    )
    for ax, title, loss_ISTA, loss_gauss, label_stem in panels:
        # np.asarray lets plain Python lists work too (``1 + [...]`` raises TypeError).
        ax.plot(time_ISTA, np.log10(1 + np.asarray(loss_ISTA)), '-r^', label=f'{label_stem} ISTA')
        ax.plot(time_gauss, np.log10(1 + np.asarray(loss_gauss)), '-b+', label=f'{label_stem} gauss')
        ax.set_title(title)
        ax.set_xlabel('TIME (SECONDS)')
        # The curves are log10(1 + loss), so the axis is a log quantity
        # (same label as in plot_losses_combined).
        ax.set_ylabel('LOG RECONSTRUCTION LOSS')
        ax.legend()
        ax.grid(True)

    fig.tight_layout()

    train_loss_path = './figures/07_convergence/train_loss.pdf'
    valid_loss_path = './figures/07_convergence/valid_loss.pdf'
    # savefig does not create directories; a fresh checkout would fail here.
    os.makedirs(os.path.dirname(train_loss_path), exist_ok=True)

    # Write each panel to its own file by cropping the figure to that axes'
    # tight bounding box (converted from display units to inches). Both crop
    # boxes are computed before saving, because savefig temporarily swaps the
    # canvas/dpi for PDF output.
    renderer = fig.canvas.get_renderer()
    to_inches = fig.dpi_scale_trans.inverted()
    crops = [
        (path, ax.get_tightbbox(renderer).transformed(to_inches).padded(0.05))
        for ax, path in ((train_ax, train_loss_path), (valid_ax, valid_loss_path))
    ]
    for path, crop in crops:
        fig.savefig(path, format='pdf', bbox_inches=crop)

    # The figure is left open so the notebook still renders both panels inline.
    return train_loss_path, valid_loss_path

# Draw the training / validation panels for the first 20 entries of each run.
# NOTE(review): plot_losses_side_by_side does not use filename_prefix.
train_loss_path, valid_loss_path = plot_losses_side_by_side(
    time_ISTA=time_list_ISTA[:20],
    train_loss_ISTA=train_loss_list_ISTA[:20],
    valid_loss_ISTA=valid_loss_list_ISTA[:20],
    time_gauss=time_list_gauss[:20],
    train_loss_gauss=train_loss_list_gauss[:20],
    valid_loss_gauss=valid_loss_list_gauss[:20],
    filename_prefix='gauss_vs_ISTA',
)

# Display the two returned output paths
train_loss_path, valid_loss_path
