In [None]:
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Register the local Arial font file with Matplotlib's font manager so that it
# can be selected by family name (assumes ARIAL.TTF sits in the working directory).
font_path = "./ARIAL.TTF"
font_registry = matplotlib.font_manager.fontManager
font_registry.addfont(font_path)

# Make Arial the default font family for every figure drawn afterwards.
# (plt.rcParams and matplotlib.rcParams are the same settings object.)
matplotlib.rcParams['font.family'] = "Arial"

In [None]:
# Palettes
# Every plotted series is keyed as '<split>_<polymer>_<metric>'. Its colour depends
# only on the polymer, its line style only on the split, and its legend label on
# both, so the three lookup tables are generated from one list of series.
_TRAINING_SERIES = (
    ('protein', 'accuracy'),
    ('dna', 'accuracy'),
    ('rna', 'accuracy'),
    ('dna', 'loss'),
    ('rna', 'loss'),
)
_POLYMER_COLORS = {'protein': '#D3D3D3', 'dna': '#FF7F7F', 'rna': '#7F7FFF'}
_POLYMER_DISPLAY_NAMES = {'protein': 'Protein', 'dna': 'DNA', 'rna': 'RNA'}
_SPLIT_LINESTYLES = {'train': 'solid', 'valid': 'dashed'}

TRAINING_PALETTE = {}            # series key -> line colour
TRAINING_LINESTYLE_PALETTE = {}  # series key -> line style
TRAINING_LEGEND_LABEL_MAP = {}   # series key -> legend label
for _polymer, _metric in _TRAINING_SERIES:
    for _split in ('train', 'valid'):
        _series_key = f'{_split}_{_polymer}_{_metric}'
        TRAINING_PALETTE[_series_key] = _POLYMER_COLORS[_polymer]
        TRAINING_LINESTYLE_PALETTE[_series_key] = _SPLIT_LINESTYLES[_split]
        TRAINING_LEGEND_LABEL_MAP[_series_key] = f'{_POLYMER_DISPLAY_NAMES[_polymer]} ({_split})'
# Do not leave the loop variables behind in the notebook namespace.
del _polymer, _metric, _split, _series_key

# Figure sizes below are written in millimetres and converted to inches.
_MM_PER_INCH = 25.4

# Default style dictionary for training metric plots
DEFAULT_TRAINING_STYLE = dict(
    # --- figure, fonts and titles ---
    figsize=(80 / _MM_PER_INCH, 80 / _MM_PER_INCH),
    dpi=300,
    title_fontsize=None,
    axis_title_fontsize=8,
    tick_labelsize=6,
    background_color='white',
    show_title=False,
    title=None,
    show_axis_labels=True,
    show_ticks=True,
    save_name=None,                 # output path of the figure; None disables saving
    # --- legend content and placement ---
    show_legend=True,
    legend_loc='lower right',
    legend_ncol=1,
    legend_title=None,
    legend_title_fontsize=None,
    legend_fontsize=6,
    legend_frameon=False,
    # --- legend layout (handle sizes and spacings) ---
    legend_handlelength=2,          # length of the legend handles
    legend_handleheight=0.2,        # height of the legend handles
    legend_labelspacing=0.5,        # vertical spacing between legend rows
    legend_borderpad=None,          # padding inside the legend border (passed to legend() as-is)
    legend_frame_linewidth=1,       # line width of the legend frame; applied only when legend_frameon is true
    legend_markerscale=1.0,         # scale factor for marker size in the legend
    legend_handletextpad=0.3,       # padding between a handle and its text
    # --- lines, spines and tick marks ---
    linewidth=0.5,
    running_mean_width=1,           # smoothing window; values <= 1 disable smoothing
    spine_linewidth=0.5,
    tick_width=0.5,
    hide_top_right_spines=True,
    # --- per-series lookups and axis labelling ---
    palette=TRAINING_PALETTE,                        # dict: dep_var -> color
    linestyle_palette=TRAINING_LINESTYLE_PALETTE,    # dict: dep_var -> linestyle
    legend_label_map=TRAINING_LEGEND_LABEL_MAP,      # dict: dep_var -> legend label
    x_label=None,
    y_label=None,
    ymin=None,
    ymax=None,
    # --- x tick controls ---
    xtick_step=10000,               # spacing between automatic x ticks
    xtick_max_mult=11,              # automatic ticks sit at 0 .. (xtick_max_mult - 1) * xtick_step
    custom_xticks=None,             # explicit tick positions; replaces the automatic ticks
    custom_xticklabels=None,        # labels for custom_xticks
    # --- separately saved legend ---
    save_legend_separately=True,
    legend_save_suffix='_legend',
    legend_figsize=(25 / _MM_PER_INCH, 25 / _MM_PER_INCH),
    legend_fill_figure=True,        # let the legend fill the entire legend figure
)

# Read Results

In [None]:
def running_mean(x, N):
    """Return the moving average of ``x`` over windows of ``N`` consecutive values.

    The window sums are obtained from a cumulative sum that is prefixed with a
    zero, so window ``i`` is ``totals[i + N] - totals[i]``. The result has
    ``len(x) - N + 1`` entries (no padding at the edges).
    """
    zero_prefixed = np.insert(x, 0, 0)
    totals = np.cumsum(zero_prefixed)
    window_sums = totals[N:] - totals[:-N]
    return window_sums / float(N)

def read_text_file(path):
    """Return the entire contents of the text file at ``path`` as one string."""
    with open(path, mode = "rt") as handle:
        contents = handle.read()
    return contents

def _parse_metric_value(field):
    """Return the numeric value of one ``"label: value"`` log field as a float.

    The value is the text after the first ``": "``; a field without a label is
    treated as a bare value. ``float()`` itself understands ``"nan"`` (in any
    letter case), so missing metrics come back as NaN without a special case.
    """
    pieces = field.split(": ")
    value_text = pieces[1] if len(pieces) > 1 else pieces[0]
    return float(value_text.strip())

def read_results(log_path):
    """Parse the metric lines of a training log.

    Only lines that start with ``"epoch"`` are used. Each such line is expected
    to look like ``"label: value, label: value, ..."``. The labels of the first
    metric line define the column order; later lines are assumed to list the
    same labels in the same order (this is not checked).

    Inputs:
        log_path (str): path of the log file.
    Returns:
        data_labels_to_index (dict): label -> column index into ``data``.
        data (np.ndarray): float64 array with one row per metric line and one
            column per label. It is an empty 1-D array if no metric line exists.
    Raises:
        ValueError: if a value cannot be converted to float, or if metric
            lines have different numbers of fields.
    """
    log_text = read_text_file(log_path)

    data_labels_to_index = dict()
    data = []
    for log_line in log_text.split("\n"):
        if not log_line.startswith("epoch"):
            continue
        fields = log_line.split(", ")
        if not data_labels_to_index:
            # The first metric line doubles as the header.
            labels = [field.split(": ")[0].strip() for field in fields]
            data_labels_to_index = {label: index for (index, label) in enumerate(labels)}
        # The original NaN guard compared the whole "label: value" field with
        # "nan" and therefore never matched; parse the value text directly.
        data.append([_parse_metric_value(field) for field in fields])
    return data_labels_to_index, np.array(data, dtype = np.float64)

# Helper Functions for Plotting

In [None]:
def _save_legend_from_axes(ax, parent_save_name, style_dict):
    """Save legend from ax as a separate figure and remove it from parent."""
    if not parent_save_name:
        return None

    # Grab the legend.
    lg = ax.get_legend()
    if lg is None:
        return None
    
    handles, labels = ax.get_legend_handles_labels()

    # Remove the legend from the figure.
    try:
        lg.remove()
    except Exception:
        pass
    
    # Determine the save options/path.
    dpi = style_dict.get('dpi', 300)
    suffix = style_dict.get('legend_save_suffix', '_legend')
    root, ext = os.path.splitext(parent_save_name)
    if not ext:
        ext = '.svg'
    save_path = f"{root}{suffix}{ext}"

    # Recreate the legend as a separate figure.
    fig_leg = plt.figure(
        figsize=style_dict.get('legend_figsize'), 
        dpi=dpi,
    )
    if style_dict.get('legend_fill_figure', True):
        ax_leg = fig_leg.add_axes([0, 0, 1, 1])
        ax_leg.axis('off')
        legend = ax_leg.legend(
            handles,
            labels,
            ncol=style_dict.get('legend_ncol', 1),
            loc='center',
            mode='expand',
            frameon=style_dict.get('legend_frameon', True),
            title=style_dict.get('legend_title', None),
            handlelength=style_dict.get('legend_handlelength', 2),
            handleheight=style_dict.get('legend_handleheight', 1),
            labelspacing=style_dict.get('legend_labelspacing', 0.5),
            borderpad=style_dict.get('legend_borderpad', 0.5),
            markerscale=style_dict.get('legend_markerscale', 1.0),
            handletextpad=style_dict.get('legend_handletextpad', 0.8),
        )
    else:
        ax_leg = fig_leg.add_subplot(111)
        ax_leg.axis('off')
        legend = ax_leg.legend(
            handles,
            labels,
            ncol=style_dict.get('legend_ncol', 1),
            loc='center',
            frameon=style_dict.get('legend_frameon', True),
            title=style_dict.get('legend_title', None),
            handlelength=style_dict.get('legend_handlelength', 2),
            handleheight=style_dict.get('legend_handleheight', 1),
            labelspacing=style_dict.get('legend_labelspacing', 0.5),
            borderpad=style_dict.get('legend_borderpad', 0.5),
            markerscale=style_dict.get('legend_markerscale', 1.0),
            handletextpad=style_dict.get('legend_handletextpad', 0.8),
        )

    # Set legend font sizes.
    if style_dict.get('legend_title') and style_dict.get('legend_title_fontsize'):
        legend.get_title().set_fontsize(style_dict['legend_title_fontsize'])
    if style_dict.get('legend_fontsize'):
        for txt in legend.get_texts():
            txt.set_fontsize(style_dict['legend_fontsize'])

    # Apply legend frame line width if frame is on
    if style_dict.get('legend_frameon', True) and style_dict.get('legend_frame_linewidth') is not None:
        frame = legend.get_frame()
        if frame is not None:
            frame.set_linewidth(style_dict.get('legend_frame_linewidth', 1.0))

    # Save the figure.
    fig_leg.savefig(save_path, dpi=dpi)
    
    return fig_leg

def _load_training_log(model_name):
    """Read ``log.txt`` of ``model_name`` from the first readable models directory.

    The primary location is tried first; if it cannot be opened (``OSError``,
    e.g. the directory does not exist on this machine) the second location is
    used. Errors from the second location, and parse errors from either, propagate.
    """
    primary_log = os.path.join('/projects/ml/na_mpnn/models', model_name, 'log.txt')
    fallback_log = os.path.join('/home/akubaney/projects/na_mpnn/models', model_name, 'log.txt')
    try:
        return read_results(primary_log)
    except OSError:
        return read_results(fallback_log)

def _draw_training_figure(data_labels_to_index, data, independent_variable, dependent_variables, merged_style):
    """Draw (and optionally save) the training-curve figure; returns ``(fig, ax)``."""
    line_width = merged_style.get('linewidth', 1.0)
    smooth_width = merged_style.get('running_mean_width', 1)
    fig_size = merged_style.get('figsize')
    tick_fs = merged_style.get('tick_labelsize', 12)
    # The tick label size is the fallback for the other font sizes when their key is absent.
    axis_title_fs = merged_style.get('axis_title_fontsize', tick_fs)
    title_fs = merged_style.get('title_fontsize', tick_fs)

    # Validate the requested columns before any figure is created, so a typo
    # does not leave an empty figure behind.
    dependent_variables = list(dependent_variables)
    if independent_variable not in data_labels_to_index:
        raise KeyError(f"Independent variable '{independent_variable}' not in log header")
    for dep in dependent_variables:
        if dep not in data_labels_to_index:
            raise KeyError(f"Dependent variable '{dep}' not found in log header")

    # Figure creation
    if fig_size:
        fig, ax = plt.subplots(figsize = fig_size, dpi = merged_style.get('dpi'), constrained_layout = True)
    else:
        fig, ax = plt.subplots(dpi = merged_style.get('dpi'))

    palette_dict   = merged_style.get('palette') or {}
    linestyle_dict = merged_style.get('linestyle_palette') or {}
    legend_map     = merged_style.get('legend_label_map') or {}

    # Colours for series that are not listed in the palette.
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])

    x_raw = data[:, data_labels_to_index[independent_variable]]

    for idx, dep in enumerate(dependent_variables):
        y_raw = data[:, data_labels_to_index[dep]]

        # Drop the rows where this metric was not logged (NaN).
        logged = ~np.isnan(y_raw)
        x_data = x_raw[logged]
        y_data = y_raw[logged]

        if smooth_width > 1:
            if smooth_width < len(x_data):
                # x is averaged with the same window so both arrays keep equal length.
                x_data = running_mean(x_data, smooth_width)
                y_data = running_mean(y_data, smooth_width)
            else:
                raise ValueError('Smoothing window is larger than data length.')

        color = palette_dict.get(dep, color_cycle[idx % len(color_cycle)] if color_cycle else None)
        ls = linestyle_dict.get(dep, 'solid')
        ax.plot(x_data, y_data, color = color, linestyle = ls, linewidth = line_width, label = legend_map.get(dep, dep))

    # Axis labels
    if merged_style.get('show_axis_labels', True):
        if merged_style.get('x_label'):
            ax.set_xlabel(merged_style['x_label'], fontsize = axis_title_fs)
        if merged_style.get('y_label'):
            ax.set_ylabel(merged_style['y_label'], fontsize = axis_title_fs)
    else:
        ax.set_xlabel('')
        ax.set_ylabel('')

    # X ticks
    custom_xticks = merged_style.get('custom_xticks')
    custom_xticklabels = merged_style.get('custom_xticklabels')
    if custom_xticks is not None:
        ax.set_xticks(custom_xticks)
        if custom_xticklabels is not None:
            ax.set_xticklabels(custom_xticklabels, fontsize = tick_fs)
        else:
            ax.tick_params(axis = 'x', labelsize = tick_fs)
    else:
        # Ticks at 0, step, 2*step, ... labelled with the multiplier only.
        step = merged_style.get('xtick_step', 1000)
        max_mult = merged_style.get('xtick_max_mult', 100)
        multipliers = np.arange(0, max_mult, 1)
        ax.set_xticks(multipliers * step)
        ax.set_xticklabels(multipliers, fontsize = tick_fs)

    # Y ticks
    if merged_style.get('show_ticks', True):
        ax.tick_params(axis = 'y', labelsize = tick_fs)
    else:
        ax.set_xticks([])
        ax.set_yticks([])

    # y limits: a limit left as None keeps its current (autoscaled) value.
    ymin = merged_style.get('ymin')
    ymax = merged_style.get('ymax')
    if (ymin is not None) or (ymax is not None):
        cur_ymin, cur_ymax = ax.get_ylim()
        ax.set_ylim(bottom = ymin if ymin is not None else cur_ymin,
                    top = ymax if ymax is not None else cur_ymax)

    # Legend
    if merged_style.get('show_legend', True):
        legend = ax.legend(
            loc=merged_style.get('legend_loc', 'upper left'),
            ncol=merged_style.get('legend_ncol', 1),
            frameon=merged_style.get('legend_frameon', True),
            title=merged_style.get('legend_title'),
            handlelength=merged_style.get('legend_handlelength', 2),
            handleheight=merged_style.get('legend_handleheight', 1),
            labelspacing=merged_style.get('legend_labelspacing', 0.5),
            borderpad=merged_style.get('legend_borderpad', 0.5),
            markerscale=merged_style.get('legend_markerscale', 1.0),
            handletextpad=merged_style.get('legend_handletextpad', 0.8),
        )
        if merged_style.get('legend_title') and merged_style.get('legend_title_fontsize'):
            legend.get_title().set_fontsize(merged_style['legend_title_fontsize'])
        if merged_style.get('legend_fontsize'):
            for txt in legend.get_texts():
                txt.set_fontsize(merged_style['legend_fontsize'])
        # Apply legend frame line width if frame is on
        if merged_style.get('legend_frameon', True) and merged_style.get('legend_frame_linewidth') is not None:
            frame = legend.get_frame()
            if frame is not None:
                frame.set_linewidth(merged_style.get('legend_frame_linewidth', 1.0))
        # Save legend separately if requested; this also removes it from `ax`.
        parent_save = merged_style.get('save_name')
        if merged_style.get('save_legend_separately', False) and parent_save:
            _save_legend_from_axes(ax, parent_save, merged_style)
    else:
        lg = ax.get_legend()
        if lg:
            lg.remove()

    # Title
    if merged_style.get('show_title') and merged_style.get('title'):
        ax.set_title(merged_style['title'], fontsize = title_fs)

    # Background
    ax.set_facecolor(merged_style.get('background_color', 'white'))
    fig.patch.set_facecolor(merged_style.get('background_color', 'white'))

    # Spines / ticks
    spine_lw = merged_style.get('spine_linewidth')
    if spine_lw is not None:
        for spine in ax.spines.values():
            spine.set_linewidth(spine_lw)
    tick_w = merged_style.get('tick_width')
    if tick_w is not None:
        ax.tick_params(width = tick_w)
    if merged_style.get('hide_top_right_spines', False):
        if 'top' in ax.spines:
            ax.spines['top'].set_visible(False)
        if 'right' in ax.spines:
            ax.spines['right'].set_visible(False)

    if merged_style.get('save_name'):
        fig.savefig(merged_style['save_name'], dpi = merged_style.get('dpi', 300))

    return fig, ax

def plot_results(model_name, independent_variable, dependent_variables, style = None):
    """
    Plot training metrics using a unified style dict.
    Inputs:
        model_name (str)  # sub-directory of the models directory that holds log.txt
        independent_variable (str)  # column name in log (e.g., 'step')
        dependent_variables (list[str])  # metric names to plot
        style (dict) optional overrides of DEFAULT_TRAINING_STYLE
    Returns:
        (fig, ax) of the training-curve figure.
    Raises:
        KeyError: a requested variable is not a column of the log.
        ValueError: the smoothing window is not smaller than the data length.
        OSError: the log file could not be read from either models directory.
    Side effects:
        Saves the figure (and, if requested, a separate legend figure) when
        'save_name' is set. rcParams['legend.fontsize'] is changed only for
        the duration of the call and is restored even if plotting fails.
    Underlying data extraction, smoothing, NaN handling unchanged.
    """
    merged_style = DEFAULT_TRAINING_STYLE.copy()
    if style:
        merged_style.update(style)

    data_labels_to_index, data = _load_training_log(model_name)

    # Necessary for legend handle/text alignment.
    # NOTE: only an explicit 'legend_fontsize' in `style` changes the rcParam;
    # the default from DEFAULT_TRAINING_STYLE is applied to the legend texts
    # only. This keeps existing figures unchanged.
    saved_legend_fs = matplotlib.rcParams['legend.fontsize']
    legend_fs_override = style.get('legend_fontsize') if style else None
    if legend_fs_override is not None:
        matplotlib.rcParams['legend.fontsize'] = legend_fs_override
    try:
        return _draw_training_figure(
            data_labels_to_index,
            data,
            independent_variable,
            dependent_variables,
            merged_style,
        )
    finally:
        matplotlib.rcParams['legend.fontsize'] = saved_legend_fs

# Plot Training Curves

In [None]:
# Accuracy curves of the design model: train/valid for protein, DNA and RNA.
design_accuracy_metrics = [
    'train_protein_accuracy', 'valid_protein_accuracy',
    'train_dna_accuracy', 'valid_dna_accuracy',
    'train_rna_accuracy', 'valid_rna_accuracy',
]
style_acc = dict(
    x_label = 'Number of Batches (x10000)',
    y_label = 'Accuracy',
    save_name = "/home/akubaney/projects/na_mpnn/figures/matplotlib/design_training_curve_accuracy.svg",
)
plot_results('design_model', 'step', design_accuracy_metrics, style = style_acc)

In [None]:
# Loss curves of the specificity model (DNA and RNA, train/valid). The legend
# sits in the upper right and is saved on a smaller separate figure.
specificity_loss_metrics = ['train_dna_loss', 'valid_dna_loss', 'train_rna_loss', 'valid_rna_loss']
style_loss = dict(
    x_label = 'Number of Batches (x10000)',
    y_label = 'Loss',
    legend_loc = 'upper right',
    save_name = "/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_training_curve_loss.svg",
    legend_figsize = (22 / 25.4, 17 / 25.4),
)
plot_results('specificity_model', 'step', specificity_loss_metrics, style = style_loss)

# Accuracy curves of the specificity model (DNA and RNA, train/valid).
specificity_accuracy_metrics = ['train_dna_accuracy', 'valid_dna_accuracy', 'train_rna_accuracy', 'valid_rna_accuracy']
style_acc_spec = dict(
    x_label = 'Number of Batches (x10000)',
    y_label = 'Accuracy',
    save_name = "/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_training_curve_accuracy.svg",
)
plot_results('specificity_model', 'step', specificity_accuracy_metrics, style = style_acc_spec)