In [None]:
#default_exp plotting

In [None]:
#export
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import ScalarFormatter
from collections import defaultdict

# Use STIX fonts for both mathtext ($...$ in labels) and regular text so they match.
matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
})

In [None]:
#hide
from nbdev.showdoc import show_doc

# Logs plotting

In [None]:
#export
def _to_numpy(values):
    "Convert `values` to a numpy array; torch tensors are detached and moved to the CPU first."
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)

def plot_curve(x:list, # The x coordinates of the data.
               y:torch.Tensor, # The y coordinates of the data.
               label:str=None, # The label of the curve.
               x_label:str=None, # The label of the x axis.
               y_label:str=None, # The label of the y axis.
               axis:"matplotlib.axes.AxesSubplot"=None, # An axes object, if already known. Otherwise a new axis is created.
               logplot:bool=False, # Whether to plot the x axis logarithmically.
               y_std:torch.Tensor=None, # The standard deviation of the y values, if known.
               std_alpha:float=0.2, # The alpha value for `y_std`. Only has an effect if `y_std` is not None.
               show_all_xticks:bool=True, # Whether to display all ticks of the data along the x axis.
               **kwargs):
    """
    A method to plot a simple one dimensional curve using matplotlib.

    `y` is drawn against `x` on `axis`. When `axis` is None, a new 7x3 inch figure
    is created. If `y_std` is given, the area between `y - y_std` and `y + y_std`
    is shaded in the same colour as the curve. Any extra `**kwargs` are passed
    on to `axis.plot`.

    Returns the axis that was drawn on.
    """
    if axis is None:
        _, axis = plt.subplots(1, figsize=(7, 3), dpi=200)

    if isinstance(y, torch.Tensor):
        # numpy conversion (which matplotlib performs) fails for tensors that
        # require grad or live on a GPU, so detach and move to the CPU first.
        y = _to_numpy(y)

    lines = axis.plot(x, y, label=label, **kwargs)

    if y_std is not None:
        y_arr, std_arr = _to_numpy(y), _to_numpy(y_std)
        # Reuse the colour of the drawn line so the band matches the curve even
        # when no explicit `color` kwarg was given. Without this, fill_between
        # would choose its own colour.
        axis.fill_between(x, y_arr - std_arr, y_arr + std_arr,
                          alpha=std_alpha, color=lines[0].get_color())

    if logplot:
        axis.set_xscale('log')
        # Show plain numbers on the log axis instead of the default power-of-ten labels.
        axis.get_xaxis().set_major_formatter(ScalarFormatter())

    if show_all_xticks:
        axis.set_xticks(x)

    axis.set_xlabel(x_label, fontsize=24)
    axis.set_ylabel(y_label, fontsize=24)

    return axis

In [None]:
#export
def plot_logs(
    x:list, # The x coordinates of the data.
    logs:defaultdict, # A dictionary that contains a list of values for each key. The keys are the same strings stored in `criterion_names`, optionally along with their standard deviations.
    criterion_names:list, # The criterion names determine which keys are accessed in the logs file for the plot.
    file_path:str=None, # The path where the plot should be saved (format inferred from the file extension).
    label:str=None, # The label of the plot.
    x_label:str=None, # The x label of the plot.
    share_x_axis:bool=True, # Whether all plots should share the same x axis or should be plotted in individual plots below one another.
    show_legend:bool=True, # Whether to show a legend for the plot.
    fig_and_axes:tuple=None, # A tuple containing a figure and an axes object, if already built.
    logplot:bool=False, # Whether to plot the x axis logarithmically.
    plot_with_std:bool=False, # Whether to also plot the standard deviations of the data, read from keys `{criterion_name}_std` in `logs`. Criteria without such a key are drawn without a band.
    std_alpha:float=0.2, # The alpha value for the standard deviation band. Only has an effect if `plot_with_std` is True.
    y_ticks:list=None, # The y ticks for the plots: one ascending sequence of ticks per criterion, in the order of `criterion_names`.
    **kwargs):
    """
    A method to plot criterion evaluation curves stored in a logs file.

    One subplot is drawn per entry of `criterion_names` by calling `plot_curve`.
    Extra `**kwargs` are forwarded to `plot_curve`. The figure is optionally saved
    to `file_path` and closed with `plt.close` before returning. The returned
    figure object can still be saved or displayed.

    Returns the `(fig, axes)` that were drawn on.
    """
    n_plots = len(criterion_names)

    if fig_and_axes is None:
        fig, axes = plt.subplots(n_plots, figsize=(7, 3*n_plots), dpi=800, sharex=share_x_axis)
    else:
        fig, axes = fig_and_axes

    # A single Axes (e.g. from `plt.subplots(1)`) is not indexable, so wrap it in a list.
    if hasattr(axes, 'plot'):
        axes = [axes]

    assert len(axes) >= n_plots, f'Need at least {n_plots} axes, got {len(axes)}.'

    for i, name in enumerate(criterion_names):
        # `.get` avoids triggering a defaultdict factory for missing `_std` keys.
        y_std = logs.get(f'{name}_std', None) if plot_with_std else None
        # plot_curve also sets the y label (fontsize 24) for this axis.
        plot_curve(
            x=x,
            y=logs[name],
            label=label,
            x_label=None if share_x_axis else x_label,
            y_label=f'{name}',
            axis=axes[i],
            logplot=logplot,
            y_std=y_std,
            std_alpha=std_alpha,
            **kwargs
        )
        # `is not None`: `!= None` is ambiguous (raises) when y_ticks is a numpy array.
        if y_ticks is not None:
            ticks = y_ticks[i]
            low, high = ticks[0], ticks[-1]
            # Pad by 5% of the tick range on each side. A padding based only on the
            # last tick would be negative for negative ticks and would cut them off.
            # For a single tick (zero range), fall back to the tick's magnitude.
            span = high - low
            pad = (span if span else abs(high)) / 20
            axes[i].set_yticks(ticks)
            axes[i].set_ylim(low - pad, high + pad)
            axes[i].set_yticklabels(ticks, fontsize=16)

    if share_x_axis:
        bottom_axis = axes[-1]
        bottom_axis.set_xlabel(x_label, fontsize=22)
        if kwargs.get('show_all_xticks', True):
            # plot_curve placed the ticks exactly at `x`, so label them with the raw x values.
            bottom_axis.set_xticklabels(x, fontsize=16)
        else:
            # The ticks were placed automatically. Relabelling them with `x` would misalign the labels.
            bottom_axis.tick_params(axis='x', labelsize=16)
        bottom_axis.tick_params(axis='x', labelrotation=45, labelbottom=True)

    if show_legend:
        handles, labels = axes[0].get_legend_handles_labels()
        # Skip the legend when no curve carries a label; otherwise an empty box is drawn.
        if handles:
            fig.legend(handles, labels, loc='upper center')

    fig.align_ylabels(axes)
    fig.tight_layout()

    if file_path is not None:
        # Save this figure explicitly; `plt.savefig` would save whichever figure is current.
        fig.savefig(file_path, bbox_inches='tight')

    plt.close(fig)

    return fig, axes