In [1]:
import getpass
import os

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

from nb_utils import read_params, filter_paths, sweepable_params, parse_logs, plot_measure, visu_diff

# Current user name, used to build the checkpoint path below. Prefer $USER, but
# fall back to getpass.getuser() so the notebook still runs where USER is unset.
user = os.environ.get('USER') or getpass.getuser()

## Minimal

This notebook loads and plots the per-run training and validation logs of an experiment folder created with clutils (see [README](../../clutils/README.md)).


In [None]:
jobname = "minimal"
all_exp_dir = f"/checkpoint/{user}/2024_logs/{jobname}"
# Optional sub-selection of runs handed to `filter_paths`; presumably
# {param name: [values to keep]} — uncomment an entry to restrict the plots.
filters = {
    # "scaling_w": [0.4],
}

# Read the parameters of every run, then keep the runs that match `filters`.
params_file = os.path.join(all_exp_dir, 'params.txt')
params = read_params(params_file)
paths = filter_paths(params, filters)

# Report what `sweepable_params` extracts from the sweep, one "name: values" per line.
sweepable = sweepable_params(params)
for param_name, param_values in sweepable.items():
    print(f"{param_name}: {param_values}")

# Load each selected run's log.txt into a DataFrame, keyed by run path.
# A run whose log is missing or unreadable is reported and skipped, so one bad
# run does not abort the whole notebook (but Ctrl-C still interrupts it).
dfs = {}
list_scaling = []  # NOTE(review): not used by any cell shown here
for path in paths:
    log_path = os.path.join(all_exp_dir, path, 'log.txt')
    try:
        logs = parse_logs(log_path, '')
        # from_dict makes the outer keys of `logs` columns; transpose turns them
        # into rows. NOTE(review): assumes parse_logs returns {row: {column: value}}.
        dfs[path] = pd.DataFrame.from_dict(logs).transpose()
    except FileNotFoundError:
        print(f"{path} is not found")
    except Exception as exc:  # still best-effort, but surface the real cause
        print(f"{path}: failed to load {log_path} ({exc!r})")
# Keep only the runs that loaded; a list snapshot rather than a live dict view.
paths = list(dfs)

colors = plt.cm.tab20.colors  # 20 distinct colors, cycled when there are more runs


def plot_measure(measures, title=None, max_epoch=None, runs=None):
    """Plot each measure in its own panel, one curve per run, against `epoch`.

    NOTE(review): this definition shadows the `plot_measure` imported from
    nb_utils; consider renaming one of the two.

    Args:
        measures: list of DataFrame column names; one subplot per name, laid
            out at most 8 per row.
        title: optional figure title.
        max_epoch: if given, keep only the first `max_epoch` non-NaN rows of
            each measure (a positional cut, not a filter on the `epoch` value).
        runs: optional {run name: DataFrame} to plot. Defaults to the
            notebook-level `dfs`, in the order given by `paths`.

    Each DataFrame must contain an `epoch` column and every name in
    `measures`; a missing column raises KeyError. Returns None — the figure
    is displayed with `plt.show()`.
    """
    if not measures:
        return
    if runs is None:
        runs = {path: dfs[path] for path in paths}

    ncols = min(len(measures), 8)
    nrows = (len(measures) + ncols - 1) // ncols  # ceil division
    # squeeze=False keeps `axes` a 2-D array even for a single panel, so
    # flatten() also works when only one measure is requested.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, measure in zip(axes, measures):
        for ii, (path, df) in enumerate(runs.items()):
            # Drop NaNs for this measure only: reusing `df` across measures
            # would discard rows that the other panels still need.
            curve = df.dropna(subset=[measure]).iloc[:max_epoch]
            ax.plot(curve['epoch'], curve[measure],
                    label=f"{path.replace('_', '')}",
                    color=colors[ii % len(colors)])
        ylabel = measure.replace('_', ' ').capitalize()
        if 'acc' in ylabel:
            # Accuracy-like measures: zoom on the upper half of [0, 1].
            ax.set_ylim(0.5, 1.01)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.grid(True)

    # Hide the empty panels left over in the last row.
    for spare_ax in axes[len(measures):]:
        spare_ax.set_visible(False)

    if title is not None:
        fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    if runs:
        # Single shared legend, anchored below the last measure's panel.
        axes[len(measures) - 1].legend(loc='upper right', bbox_to_anchor=(1.04, -0.1),
                                       fontsize=16, ncol=1, fancybox=True, shadow=True)
    plt.show()

# One figure for the training curves, then one for validation bit accuracy
# under the identity / crop / jpeg settings.
for measures in (
    ['train_total_loss', 'train_loss_decode', 'train_psnr', 'train_bit_acc'],
    ['val_bit_acc_mask=0_aug=identity_0', 'val_bit_acc_mask=0_aug=crop_0.75', 'val_bit_acc_mask=0_aug=jpeg_60'],
):
    plot_measure(measures)