This notebook was automatically generated from your MAST-ML run so that you can recreate
the plots. One thing differs from the usual way of creating plots: we use matplotlib's
[object-oriented
interface](https://matplotlib.org/tutorials/introductory/lifecycle.html) instead of
pyplot to create the `fig` and `ax` instances.


In [None]:
"""
This module contains a collection of functions which make plots (saved as png files) using matplotlib, generated from
some model fits and cross-validation evaluation within a MAST-ML run.

This module also contains a method to create python notebooks containing plotted data and the relevant source code from
this module, to enable the user to make their own modifications to the created plots in a straightforward way (useful for
tweaking plots for a presentation or publication).
"""

import math
import os
import pandas as pd
import itertools
import warnings
import logging
# The abstract base classes live in collections.abc; the aliases that used to be
# re-exported from `collections` were removed in Python 3.10.
from collections.abc import Iterable
from os.path import join
from collections import OrderedDict

# Ignore the harmless warning about the gelsd driver on mac.
warnings.filterwarnings(action="ignore", module="scipy",
                        message="^internal gelsd")
# Ignore matplotlib deprecation warning (set as all warnings for now).
# NOTE: no category/module/message is given here, so this filter matches every
# warning. While it is in place all warnings are silenced and the targeted
# scipy filter above is redundant.
warnings.filterwarnings(action="ignore")

import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure, figaspect
from matplotlib.animation import FuncAnimation
from matplotlib.font_manager import FontProperties
import matplotlib.mlab as mlab
from scipy.stats import gaussian_kde
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

# Needed imports for ipynb_maker
from mastml.utils import nice_range
from mastml.metrics import nice_names

import inspect
import textwrap
from pandas import DataFrame, Series

import nbformat

from functools import wraps

# Global matplotlib defaults: a larger sans-serif font everywhere, and autolayout on.
matplotlib.rc('font', family='sans-serif', size=18)
matplotlib.rc('figure', autolayout=True)

# Resolution passed to savefig; kept as a module-level constant so it can be
# changed in one place.
DPI = 250

# Root logger (presumably only needed by ipynb_maker, which is not part of this section).
log = logging.getLogger()



In [None]:
def stat_to_string(name, value):
    """
    Method that converts a metric name/value pair into a string for displaying on a plot

    Args:

        name: (str), long name of a stat metric or quantity

        value: (None, int, float, tuple or str), value of the metric or quantity. None (or an empty non-numeric
            value such as '') means there is only a name to display; a tuple is interpreted as (mean, std)

    Return:

        (str), the metric name, adjusted to look nicer for inclusion on a plot, followed by the formatted value (if any)

    """
    # Prefer the curated display name; otherwise just make the underscores readable
    if name in nice_names:
        name = nice_names[name]
    else:
        name = name.replace('_', ' ')

    is_number = isinstance(value, (int, float))

    # has a name only. A numeric zero is a legitimate value (e.g. an error of exactly 0.0),
    # so only None and empty non-numeric values count as "no value".
    if value is None or (not is_number and not value):
        return name
    # has a mean and std
    if isinstance(value, tuple):
        mean, std = value
        return f'{name}:' + '\n\t' + f'{mean:.3f}' + r'$\pm$' + f'{std:.3f}'
    # has a name and value only; whole numbers are shown without decimals
    if isinstance(value, int) or (isinstance(value, float) and value % 1 == 0):
        return f'{name}: {int(value)}'
    if isinstance(value, float):
        return f'{name}: {value:.3f}'
    return f'{name}: {value}'  # probably a string


def plot_stats(fig, stats, x_align=0.65, y_align=0.90, font_dict=None, fontsize=14):
    """
    Method that prints stats onto the plot. Goes off screen if they are too long or too many in number.

    Args:

        fig: (matplotlib figure object), a matplotlib figure object

        stats: (dict), dict of statistics to be included with a plot

        x_align: (float), float denoting x position of where to align display of stats on a plot

        y_align: (float), float denoting y position of where to align display of stats on a plot

        font_dict: (dict or None), dict of matplotlib font options to alter display of stats on plot. None (the
            default) means no extra font options

        fontsize: (int), the fontsize of stats to display on plot

    Returns:

        None

    """
    # Use a None sentinel instead of a mutable default so no dict object is shared between calls
    if font_dict is None:
        font_dict = {}

    # One line per statistic, in the iteration order of the stats dict
    stat_str = '\n'.join(stat_to_string(name, value)
                         for name, value in stats.items())

    fig.text(x_align, y_align, stat_str,
             verticalalignment='top', wrap=True, fontdict=font_dict, fontproperties=FontProperties(size=fontsize))


def make_fig_ax(aspect_ratio=0.5, x_align=0.65, left=0.10):
    """
    Method to build a matplotlib figure with a single, manually positioned axes. Uses the Object Oriented interface,
    see https://matplotlib.org/gallery/api/agg_oo_sgskip.html

    Args:

        aspect_ratio: (float), aspect ratio passed to figaspect to size the figure

        x_align: (float), x position (figure fraction) where the axes region ends, leaving room to its right so
            stats can be displayed alongside the plot

        left: (float), x position (figure fraction) of the left edge of the axes

    Returns:

        fig: (matplotlib fig object), the created figure

        ax: (matplotlib ax object), the axes added to the figure

    """
    # Margins in figure-fraction units. For background on manual axes placement see:
    # https://python4astronomers.github.io/plotting/advanced.html
    bottom_margin, right_margin, top_margin = 0.15, 0.01, 0.05

    fig_width, fig_height = figaspect(aspect_ratio)
    fig = Figure(figsize=(fig_width, fig_height))
    FigureCanvas(fig)  # instantiating the canvas binds it to the figure

    # (left, bottom, width, height) of the axes
    axes_rect = (
        left,
        bottom_margin,
        x_align - left - right_margin,
        1 - bottom_margin - top_margin,
    )
    ax = fig.add_axes(axes_rect, frameon=True)

    # The axes are positioned by hand, so turn automatic tight layout off for this figure
    fig.set_tight_layout(False)

    return fig, ax


In [None]:
def plot_best_worst_split(y_true, best_run, worst_run, savepath,
                          title='Best Worst Overlay', label='target_value'):
    """
    Method to create a parity plot (predicted vs. true values) of just the best scoring and worst scoring CV splits

    Args:

        y_true: (numpy array), array of true y data; its min and max set the extent of the parity line

        best_run: (dict), the best scoring split_result from mastml_driver. The keys 'y_test_true', 'y_test_pred'
            and 'test_metrics' are read

        worst_run: (dict), the worst scoring split_result from mastml_driver. The keys 'y_test_true', 'y_test_pred'
            and 'test_metrics' are read

        savepath: (str), path to save plots to

        title: (str), title of the best_worst_split plot. Accepted for interface compatibility; it is not currently
            drawn on the figure

        label: (str), label used for axis labeling

    Returns:

        None

    """
    # make fig and ax, use x_align when placing text so things don't overlap
    x_align = 0.64
    fig, ax = make_fig_ax(x_align=x_align)

    # Round both limits to the same precision, chosen once from the raw data range. Previously the
    # precision for the minimum was derived from the already-rounded maximum, so the two limits
    # could end up rounded differently.
    raw_max = max(y_true)
    raw_min = min(y_true)
    digits = rounder(raw_max - raw_min)
    maxx = round(float(raw_max), digits)
    minn = round(float(raw_min), digits)

    # dashed y = x reference line, drawn underneath the scatter points
    ax.plot([minn, maxx], [minn, maxx], 'k--', lw=2, zorder=1)

    # set tick labels
    _set_tick_labels(ax, maxx, minn)

    # do the actual plotting; the worst split is drawn on top with smaller markers
    ax.scatter(best_run['y_test_true'],  best_run['y_test_pred'],  c='red',
               alpha=0.7, label='best test',  edgecolor='darkred',  zorder=2, s=100)
    ax.scatter(worst_run['y_test_true'], worst_run['y_test_pred'], c='blue',
               alpha=0.7, label='worst test', edgecolor='darkblue', zorder=3, s=60)
    ax.legend(loc='lower right', fontsize=12)

    # set axis labels
    ax.set_xlabel('True '+label, fontsize=16)
    ax.set_ylabel('Predicted '+label, fontsize=16)

    # Copy the stats into new dicts that start with a heading entry; the None value makes
    # the heading render as a name-only line
    best_stats = OrderedDict([('Best Run', None)])
    best_stats.update(best_run['test_metrics'])
    worst_stats = OrderedDict([('worst Run', None)])
    worst_stats.update(worst_run['test_metrics'])

    # Stack the two stat columns to the right of the axes
    plot_stats(fig, best_stats, x_align=x_align, y_align=0.90)
    plot_stats(fig, worst_stats, x_align=x_align, y_align=0.60)

    fig.savefig(savepath, dpi=DPI, bbox_inches='tight')


In [None]:
from numpy import array
from collections import OrderedDict
from io import StringIO
# Example inputs for plot_best_worst_split, captured when this notebook was generated.
# NOTE(review): the `...` inside the array(...) literals looks like an abbreviated repr of long
# arrays rather than real data. The elided values cannot be recovered from this cell, so
# regenerate or reload the full arrays before relying on the plot.
y_true = array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ])
# best_run and worst_run appear to hold identical contents here (both are the single 'NoSplit' split, split_num 0).
best_run = OrderedDict([('normalizer', 'StandardScaler'), ('selector', 'SequentialFeatureSelector'), ('model', 'KernelRidge'), ('splitter', 'NoSplit'), ('split_num', 0), ('y_train_true', array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ])), ('y_train_pred', array([227.47526251, 162.31305945, 140.78376572, ...,  62.25564739,
        50.22694807,   5.51842863])), ('y_test_true', array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ])), ('y_test_pred', array([227.47526251, 162.31305945, 140.78376572, ...,  62.25564739,
        50.22694807,   5.51842863])), ('train_metrics', OrderedDict([('R2', 0.771823494407079), ('root_mean_squared_error', 45.89988069094602), ('mean_absolute_error', 31.369644013893062), ('rmse_over_stdev', 0.4811262117767574)])), ('test_metrics', OrderedDict([('R2', 0.771823494407079), ('root_mean_squared_error', 45.89988069094602), ('mean_absolute_error', 31.369644013893062), ('rmse_over_stdev', 0.4811262117767574)])), ('train_indices', array([   0,    1,    2, ..., 2142, 2143, 2144])), ('test_indices', array([   0,    1,    2, ..., 2142, 2143, 2144])), ('train_groups', None), ('test_groups', None), ('prediction_metrics', None)])
worst_run = OrderedDict([('normalizer', 'StandardScaler'), ('selector', 'SequentialFeatureSelector'), ('model', 'KernelRidge'), ('splitter', 'NoSplit'), ('split_num', 0), ('y_train_true', array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ])), ('y_train_pred', array([227.47526251, 162.31305945, 140.78376572, ...,  62.25564739,
        50.22694807,   5.51842863])), ('y_test_true', array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ])), ('y_test_pred', array([227.47526251, 162.31305945, 140.78376572, ...,  62.25564739,
        50.22694807,   5.51842863])), ('train_metrics', OrderedDict([('R2', 0.771823494407079), ('root_mean_squared_error', 45.89988069094602), ('mean_absolute_error', 31.369644013893062), ('rmse_over_stdev', 0.4811262117767574)])), ('test_metrics', OrderedDict([('R2', 0.771823494407079), ('root_mean_squared_error', 45.89988069094602), ('mean_absolute_error', 31.369644013893062), ('rmse_over_stdev', 0.4811262117767574)])), ('train_indices', array([   0,    1,    2, ..., 2142, 2143, 2144])), ('test_indices', array([   0,    1,    2, ..., 2142, 2143, 2144])), ('train_groups', None), ('test_groups', None), ('prediction_metrics', None)])
# Absolute path on the machine that produced the notebook; change it to a writable location before re-running.
savepath = '/home/nerve/Desktop/skunkworks/lane_schultz/learning_curves/learning/StandardScaler/SequentialFeatureSelector/KernelRidge/NoSplit/best_worst_split.png'
# Quantity name used for the axis labels ('True ...' / 'Predicted ...')
label = 'Energy above convex hull (meV/atom)'

In [None]:
import pandas as pd
from IPython.display import Image, display

# `label` has to be passed by keyword: the fifth positional parameter of
# plot_best_worst_split is `title`, so a positional `label` never reaches the axis labels.
plot_best_worst_split(y_true, best_run, worst_run, savepath, label=label)
# Show the file that was just written, rather than a name resolved against the working directory.
display(Image(filename=savepath))
