This notebook was automatically generated from your MAST-ML run so that you can recreate
(and tweak) the plots. Plotting here differs slightly from the usual approach: we use
matplotlib's [object-oriented
interface](https://matplotlib.org/tutorials/introductory/lifecycle.html) instead of
pyplot to create the `fig` and `ax` instances.


In [None]:
"""
This module contains a collection of functions which make plots (saved as png files) using matplotlib, generated from
some model fits and cross-validation evaluation within a MAST-ML run.

This module also contains a method to create python notebooks containing plotted data and the relevant source code from
this module, to enable the user to make their own modifications to the created plots in a straightforward way (useful for
tweaking plots for a presentation or publication).
"""

import math
import os
import pandas as pd
import itertools
import warnings
import logging
# ``Iterable`` lives in ``collections.abc``; the deprecated alias in ``collections`` was
# removed in Python 3.10, where ``from collections import Iterable`` raises ImportError.
from collections.abc import Iterable
from os.path import join
from collections import OrderedDict

# Ignore the harmless warning about the gelsd driver on mac.
warnings.filterwarnings(action="ignore", module="scipy",
                        message="^internal gelsd")
# Ignore matplotlib deprecation warnings. NOTE: this filter has no category/module restriction,
# so it silences *all* warnings process-wide (which also makes the scipy filter above redundant).
warnings.filterwarnings(action="ignore")

import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure, figaspect
from matplotlib.animation import FuncAnimation
from matplotlib.font_manager import FontProperties
import matplotlib.mlab as mlab
from scipy.stats import gaussian_kde
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

# Needed imports for ipynb_maker
from mastml.utils import nice_range
from mastml.metrics import nice_names

import inspect
import textwrap
from pandas import DataFrame, Series

import nbformat

from functools import wraps

# Global matplotlib styling: larger sans-serif text everywhere, automatic layout on by default.
matplotlib.rcParams.update({
    'font.size': 18,
    'font.family': 'sans-serif',
    'figure.autolayout': True,
})

# Resolution (dots per inch) for every saved png; kept as a module-level constant so it can be changed later.
DPI = 250

# Root logger; used for debug output while plotting (e.g. group statistics in plot_predicted_vs_true).
log = logging.getLogger()



In [None]:
def stat_to_string(name, value):
    """
    Method that converts a metric object into a string for displaying on a plot

    Args:

        name: (str), long name of a stat metric or quantity

        value: (float, int, tuple, str or None), value of the metric or quantity. A (mean, std) tuple is
            rendered as "mean +/- std"; None or an empty value renders the name alone. A numeric value of
            zero is rendered as a value, not treated as missing.

    Return:

        (str), a string of the metric name, adjusted to look nicer for inclusion on a plot

    """
    # Prefer the curated display name; otherwise make the raw name readable.
    if name in nice_names:
        name = nice_names[name]
    else:
        name = name.replace('_', ' ')

    # numpy scalars such as float32/int64 are not subclasses of float/int, so list them explicitly.
    is_number = isinstance(value, (int, float, np.integer, np.floating))

    # has a name only. Numbers are excluded from the truthiness test so that a legitimate
    # value of 0 / 0.0 is still displayed instead of being silently dropped.
    if value is None or (not is_number and not value):
        return name
    # has a mean and std
    if isinstance(value, tuple):
        mean, std = value
        return f'{name}:' + '\n\t' + f'{mean:.3f}' + r'$\pm$' + f'{std:.3f}'
    # has a name and value only; whole numbers are shown without decimals
    if isinstance(value, (int, np.integer)) or (isinstance(value, (float, np.floating)) and value % 1 == 0):
        return f'{name}: {int(value)}'
    if isinstance(value, (float, np.floating)):
        return f'{name}: {value:.3f}'
    return f'{name}: {value}'  # probably a string


def plot_stats(fig, stats, x_align=0.65, y_align=0.90, font_dict=None, fontsize=14):
    """
    Method that prints stats onto the plot. Goes off screen if they are too long or too many in number.

    Args:

        fig: (matplotlib figure object), a matplotlib figure object

        stats: (dict), dict of statistics to be included with a plot, one line per (name, value) pair

        x_align: (float), float denoting x position of where to align display of stats on a plot

        y_align: (float), float denoting y position of where to align display of stats on a plot

        font_dict: (dict or None), dict of matplotlib font options to alter display of stats on plot.
            None (the default) means no extra options.

        fontsize: (int), the fontsize of stats to display on plot

    Returns:

        None

    """
    # Use None as the default instead of dict(): a mutable default object would be shared by every call.
    if font_dict is None:
        font_dict = dict()

    stat_str = '\n'.join(stat_to_string(name, value)
                         for name, value in stats.items())

    # Text is placed in figure coordinates, hanging down from (x_align, y_align).
    fig.text(x_align, y_align, stat_str,
             verticalalignment='top', wrap=True, fontdict=font_dict,
             fontproperties=FontProperties(size=fontsize))


def make_fig_ax(aspect_ratio=0.5, x_align=0.65, left=0.10):
    """
    Create a matplotlib figure and a single axes with the object-oriented (Agg canvas) interface,
    see https://matplotlib.org/gallery/api/agg_oo_sgskip.html

    The axes spans horizontally from `left` to `x_align - 0.01` (figure fractions), leaving the
    region to the right of `x_align` free so statistics can be displayed next to the plot.

    Args:

        aspect_ratio: (float), aspect ratio passed to matplotlib's figaspect to size the figure

        x_align: (float), x position (figure fraction) that bounds the right side of the axes

        left: (float), x position (figure fraction) of the left edge of the axes

    Returns:

        fig: (matplotlib fig object), a matplotlib figure object with the specified aspect ratio

        ax: (matplotlib ax object), a matplotlib axes object placed inside that figure

    """
    # Figure dimensions derived from the requested aspect ratio.
    fig_width, fig_height = figaspect(aspect_ratio)
    fig = Figure(figsize=(fig_width, fig_height))
    # Attach an Agg canvas so the figure can be rendered/saved without pyplot.
    FigureCanvas(fig)

    # Fixed margins in figure-fraction units; positioning guide:
    # https://python4astronomers.github.io/plotting/advanced.html
    margin_bottom, margin_right, margin_top = 0.15, 0.01, 0.05
    axes_box = (
        left,
        margin_bottom,
        x_align - left - margin_right,
        1 - margin_bottom - margin_top,
    )
    ax = fig.add_axes(axes_box, frameon=True)
    # Keep the explicit placement above instead of letting tight_layout move the axes.
    fig.set_tight_layout(False)

    return fig, ax


In [None]:
def plot_predicted_vs_true(train_quad, test_quad, outdir, label):
    """
    Method to create parity plots (predicted vs. true values), one for the training data and one for the test data

    Args:

        train_quad: (tuple), tuple containing 4 items: true training y data (numpy array), predicted training
        y data (numpy array), training metric data (dict of stats displayed next to the plot), and groups used
        in training (numpy array, or None when the data is not grouped)

        test_quad: (tuple), tuple containing 4 items with the same layout as train_quad, for the test data

        outdir: (str), path to save plots to

        label: (str), label used for axis labeling

    Returns:

        filenames: (list), names of the png files saved in outdir, the train plot first and the test plot second

    """
    # NOTE(review): ``rounder`` and ``_set_tick_labels`` are not defined in any visible cell of this
    # notebook; they presumably come from MAST-ML's plotting module alongside these functions - confirm.
    filenames = list()
    y_train_true, y_train_pred, _, train_groups = train_quad
    y_test_true, y_test_pred, _, test_groups = test_quad

    # make diagonal line from absolute min to absolute max of any data point. Train and test plots
    # share these limits so they are drawn on the same scale and axis.
    # TODO: rounding may distort data whose values are all very small.
    max1 = max(y_train_true.max(), y_train_pred.max(),
               y_test_true.max(), y_test_pred.max())
    min1 = min(y_train_true.min(), y_train_pred.min(),
               y_test_true.min(), y_test_pred.min())
    # Round both limits to the same precision, derived once from the unrounded data range
    # (rounding max first and reusing it for min would give min a different precision).
    digits = rounder(max1 - min1)
    max1 = round(float(max1), digits)
    min1 = round(float(min1), digits)

    # Assign each group a (color, marker) pair once, using the groups of BOTH splits, so a group has
    # the same style in the train and test plots. Colors cycle fastest; the marker advances after each
    # full color cycle and wraps around instead of running past the end of the list.
    colors = ['blue', 'red', 'green', 'purple', 'orange', 'black']
    markers = ['o', 'v', '^', 's', 'p', 'h', 'D', '*', 'X', '<', '>', 'P']
    present_groups = [g for g in (train_groups, test_groups) if g is not None]
    unique_groups = np.unique(np.concatenate(present_groups, axis=0)) if present_groups else []
    group_styles = [(group, colors[i % len(colors)], markers[(i // len(colors)) % len(markers)])
                    for i, group in enumerate(unique_groups)]

    for y_true, y_pred, stats, groups, title_addon in \
            (train_quad+('train',), test_quad+('test',)):

        # make fig and ax, use x_align when placing text so things don't overlap
        x_align = 0.64
        fig, ax = make_fig_ax(x_align=x_align)

        # set tick labels from the shared limits computed above
        _set_tick_labels(ax, max1, min1)

        # plot diagonal line
        ax.plot([min1, max1], [min1, max1], 'k--', lw=2, zorder=1)

        # do the actual plotting
        if groups is None:
            ax.scatter(y_true, y_pred, color='blue', edgecolors='black', s=100, zorder=2, alpha=0.7)
        else:
            log.debug(' '*12 + 'unique groups: ' + str(list(unique_groups)))
            handles = dict()
            for group, color, marker in group_styles:
                mask = groups == group
                log.debug(' '*12 + f'{group} group_percent = {np.count_nonzero(mask) / len(groups)}')
                handles[group] = ax.scatter(y_true[mask], y_pred[mask], label=group, color=color,
                                            marker=marker, s=100, alpha=0.7)
            ax.legend(handles.values(), handles.keys(), loc='lower right', fontsize=12)

        # set axis labels
        ax.set_xlabel('True '+label, fontsize=16)
        ax.set_ylabel('Predicted '+label, fontsize=16)

        plot_stats(fig, stats, x_align=x_align, y_align=0.90)

        filename = 'predicted_vs_true_' + title_addon + '.png'
        filenames.append(filename)
        fig.savefig(join(outdir, filename), dpi=DPI, bbox_inches='tight')

    return filenames


In [None]:
from numpy import array
from collections import OrderedDict
from io import StringIO
# Training data recorded from the MAST-ML run: (true y, predicted y, training metrics, groups).
# NOTE(review): the arrays below contain literal `...` entries. In Python `...` is the Ellipsis
# object, not data: these arrays appear to have been truncated by numpy's repr summarization when
# this notebook was generated, so they are not the real training data and numeric operations on
# them (e.g. .max()/.min() in plot_predicted_vs_true) will fail.
# TODO: regenerate with np.set_printoptions(threshold=sys.maxsize) so the full arrays are written.
train_quad = (array([177.28302038, 150.37102595, 202.0924573 , ...,  47.20790294,
        23.809066  ,   0.        ]), array([231.05348537, 159.90478   , 136.55447063, ...,  55.12829832,
        45.58754203,   3.46382707]), OrderedDict([('R2', 0.7737463914126692), ('root_mean_squared_error', 46.57206418327311), ('mean_absolute_error', 31.84218911773704), ('rmse_over_stdev', 0.4804535233765936)]), None)
# Test data recorded from the MAST-ML run: (true y, predicted y, test metrics, groups=None).
test_quad = (array([2.45647989e+02, 1.82129033e+02, 2.96501210e+02, 1.93721694e+02,
       3.44799275e+02, 0.00000000e+00, 2.59721126e+02, 1.88139584e+02,
       1.21202059e+02, 1.14194239e+02, 1.20565464e+02, 1.20156284e+02,
       9.45736257e+01, 1.29676557e+02, 1.06566685e+02, 2.02290968e+02,
       1.81659746e+02, 7.75902143e+01, 1.78590946e+02, 8.80696853e+01,
       8.91149867e+01, 1.71809961e+02, 1.88725323e+02, 1.71679718e+02,
       0.00000000e+00, 1.95159840e+02, 1.90460889e+02, 1.96095876e+02,
       1.53899276e+02, 2.27914069e+02, 3.81107807e+02, 8.44124962e+01,
       1.85728826e+02, 1.19298374e+02, 9.33858524e+01, 1.37488659e+02,
       7.70450205e+01, 1.56941241e+02, 1.58367410e+02, 1.21204438e+02,
       2.91955558e+02, 1.09473901e+02, 1.21368422e+02, 2.70007913e+02,
       8.39207969e+01, 1.08666855e+02, 7.30199956e+01, 1.06616452e+02,
       1.56229730e+02, 2.36664682e+02, 4.08488586e+02, 7.19005615e+01,
       2.72804393e+02, 1.22311356e+02, 8.21644400e+00, 7.34978297e+01,
       7.14456318e+01, 3.20094037e+01, 7.66704600e+00, 9.04068336e+01,
       3.04032713e+01, 5.90096117e+00, 6.78471000e-01, 5.42920320e+01,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.16763996e+00,
       0.00000000e+00, 2.99180823e+01, 8.03088872e+01, 3.11290486e+01,
       1.77633032e+02, 7.03541024e+01, 8.81834984e+01, 1.27239967e+02,
       1.49850095e+02, 1.51797441e+02, 1.18302658e+02, 2.94893175e+02,
       7.12267570e+01, 3.39191772e+01, 5.90546828e+01, 6.01231126e+01,
       3.93315709e+01, 0.00000000e+00, 0.00000000e+00, 2.99490808e+01,
       1.66684097e+01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.60810735e+00,
       5.93437233e-01, 0.00000000e+00, 0.00000000e+00, 1.92094295e+00,
       0.00000000e+00, 1.99014786e+01, 1.39370004e+02, 3.35683107e+01,
       1.35168013e+02, 1.74713200e+00, 0.00000000e+00, 0.00000000e+00,
       6.37138607e+00, 5.38425541e+01, 4.34882312e+01, 4.48505167e+01,
       4.89943291e+01, 0.00000000e+00, 1.31304680e+02, 4.22958931e+01,
       8.73607131e+01, 2.96532550e+01, 0.00000000e+00, 9.92490625e-01,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       8.97062587e+00, 3.45198214e+01, 6.21108106e+01, 3.55034057e+01,
       4.07973166e+01, 1.17125778e+02, 7.92325765e+01, 3.73488338e+01,
       3.32000681e+01, 2.18803383e+01, 3.54620958e+01, 5.74995003e+01,
       6.12137026e+01, 4.91573257e+01, 3.51232216e+01, 0.00000000e+00,
       0.00000000e+00, 8.68896540e+01, 1.17714484e+02, 5.35428595e+01,
       1.24630840e+02, 1.32109024e+02, 1.00859781e+02, 1.23754016e+02,
       8.52336429e+01, 1.75579596e+02, 0.00000000e+00, 2.02289181e-01,
       2.24784003e+01, 1.51553671e+02, 2.11720874e+01, 5.24822446e+01,
       1.26881899e+02, 0.00000000e+00, 1.24480773e+01, 1.71956703e+02,
       1.66187429e+02, 2.76340269e+02, 0.00000000e+00, 2.51360287e+02,
       7.78684422e+01, 0.00000000e+00, 2.86014919e+02, 6.24940495e+01,
       4.80411705e+01, 5.08392287e+01, 3.50768896e+01, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 3.29759565e+01, 9.53324888e+01,
       5.14231024e+01, 0.00000000e+00, 0.00000000e+00, 5.25651706e+01,
       0.00000000e+00, 4.81001678e+01, 5.40734041e+01, 4.42032278e+01,
       3.81230146e+01, 1.29976759e+02, 4.52879705e+01, 0.00000000e+00,
       1.38758173e+02, 1.50562981e+02, 0.00000000e+00, 1.22980240e+02,
       1.34671642e+02, 0.00000000e+00, 0.00000000e+00, 6.07251297e+01,
       1.20596615e+02, 4.31287284e+01, 2.65092677e+01, 4.95815034e+01,
       2.50852147e+01, 6.33172566e+01, 1.02395743e+02, 3.63567935e+01,
       0.00000000e+00, 2.87242187e+01, 2.36410205e+01, 5.48277307e+01,
       5.76088495e+01, 4.94252903e+01, 1.77187703e+01, 1.25215370e+02,
       2.32281891e+01, 5.71443815e+01, 6.89012756e+00, 0.00000000e+00,
       5.31766390e-01, 1.12780820e+01, 0.00000000e+00, 7.56564178e+00,
       3.56929084e+01, 4.21875225e+01, 7.37535702e+01, 1.20413789e+02,
       1.29601532e+02, 2.44348237e+02, 1.85197855e+02, 2.15604850e+02,
       1.32871922e+02, 2.07904807e+02, 3.61579777e+02, 9.03169483e+01,
       1.10139318e+02, 1.14272223e+02, 1.42295818e+02, 2.52767162e+02,
       2.43428371e+02, 2.24622871e+02, 2.30118336e+02, 1.18103266e+02,
       1.11920471e+02, 1.00207954e+02, 2.35409225e+02, 2.92390663e+02,
       1.96409134e+02, 1.23246629e+02, 2.12302508e+02, 2.05092926e+02,
       2.31481855e+02, 1.54419624e+02, 2.36338029e+02, 2.26507616e+02,
       2.27878672e+02, 3.49453381e+02, 1.01887892e+02, 1.43517205e+02,
       1.97124130e+02, 2.10107835e+02, 1.95672754e+02, 1.77317832e+02,
       1.84516539e+02, 1.92088747e+02, 1.41921101e+02, 1.31977595e+02,
       1.08425980e+02, 2.34968862e+02, 1.85807040e+02, 3.13722881e+02,
       2.40576558e+02, 3.64707101e+02, 3.15514875e+02, 3.30939913e+02,
       1.79999332e+02, 2.50995845e+02, 3.32212137e+02, 2.30917642e+02,
       2.38988178e+02, 3.46902034e+02, 3.46087908e+02, 2.23307827e+02,
       3.05414508e+02, 1.20403530e+02, 8.85969515e+01, 8.95472627e+01,
       1.82220981e+02, 1.23035360e+02, 1.39928045e+02, 1.47525298e+02,
       1.62784213e+02, 6.02850907e+01, 6.45557199e+01, 1.58202913e+02,
       7.74665625e+01, 7.28852659e+01, 1.84330790e+02, 1.84907725e+02,
       1.25571144e+02, 3.24945084e+02, 2.83042429e+02, 2.47909595e+02,
       1.48340548e+02, 2.42965465e+02, 2.95582043e+02, 2.51703483e+02,
       9.57786406e+01, 1.07655170e+02, 9.27064680e+01, 1.15142225e+02,
       9.27983292e+01, 1.26955312e+02, 1.22185191e+02, 8.32175656e+01,
       3.90465522e+01, 1.22338936e+01, 3.51009606e+01, 5.18182329e+01,
       3.70668895e+01, 6.07829098e+01, 4.43281812e+01, 7.09524931e+01,
       6.20242803e+01, 1.12909303e+02, 4.34380611e+01, 2.95478331e+01,
       5.85446973e+01, 9.42635786e+01, 5.83780479e+01, 1.87848591e+02,
       1.47336842e+02, 1.02112315e+02, 1.24017817e+02, 1.78369049e+02,
       1.46552082e+02, 2.11267960e+02, 1.87207178e+02, 2.18074203e+02,
       1.79710356e+02, 1.89998652e+02, 1.59787224e+02, 1.47483468e+02,
       1.71603969e+02, 1.08537384e+02, 1.74060277e+02, 1.35954460e+02,
       1.22316043e+02, 1.13092108e+02, 1.84901449e+02, 1.83151200e+02,
       1.85051275e+02, 1.56027039e+02, 2.37583379e+02, 1.11066722e+02,
       1.15594393e+02, 1.13501499e+02, 2.98245707e+02, 1.35955483e+02,
       1.45310469e+02, 1.37369485e+02, 9.12669747e+01, 9.37481720e+01,
       1.34354983e+02, 1.82613120e+02, 4.45796202e+00, 2.00619416e+01,
       1.91667256e+01, 9.53050436e+01, 3.86328698e+01, 9.56649845e+01,
       2.84417281e+01, 6.35864630e+01, 1.07860026e+02, 1.11369551e+02,
       0.00000000e+00, 1.43739984e+02, 1.01020674e+02, 9.43179777e+01,
       1.82449930e+02, 9.55308332e+01, 8.45812329e+01, 0.00000000e+00,
       6.62624423e+01, 8.29134753e+01, 1.36252990e+02, 3.42214622e+01,
       0.00000000e+00, 0.00000000e+00, 3.97951736e+01, 4.87758689e+01,
       2.83147240e+01, 2.63539598e+01, 1.10177861e+01, 0.00000000e+00,
       2.35812529e+01, 4.92018258e+01, 7.07659236e+01, 3.74680233e+01,
       5.51408019e+01, 5.22740325e+01, 5.50886160e+01, 4.08206874e+01,
       5.66656498e+01, 8.61157898e+01, 0.00000000e+00, 0.00000000e+00,
       8.72670422e+01, 1.05624926e+02, 0.00000000e+00, 0.00000000e+00,
       9.46025620e+01, 1.26616164e+02, 3.74221989e+01, 2.71567126e+01,
       3.66110447e+01, 4.26966290e+01, 1.20061041e+02, 9.28082416e+01,
       1.99708105e+01, 7.83039861e+01, 1.21777150e+02, 1.73113038e+02,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.58851259e+00,
       8.66887169e+00, 2.89029515e+02, 0.00000000e+00, 3.35261904e+01,
       1.94101993e+01]), array([196.74672828, 179.96343226, 168.75352985, 156.40442596,
       273.77701946, 225.72008396, 316.34127477, 205.59192851,
       180.09772761, 113.14998818, 132.73255736,  90.43903759,
       113.88490365,  84.68087005,  93.16264277, 159.54779328,
       211.59708203, 123.60169395, 209.17665116, 111.45139882,
        87.63702803, 120.91630737, 179.50011957, 182.07011169,
       163.35071981, 193.97488098, 189.68955704, 175.10600403,
       169.73945274, 237.50084399, 301.55177274, 142.08040136,
       196.56579907, 175.28393585, 127.87785939, 113.40030126,
       109.36253063, 160.70867031, 169.06862745, 139.97844016,
       294.05329945, 208.0164062 ,  96.74102959, 147.95730472,
        84.60605246, 107.17232929,  91.38851833,  85.34219615,
       179.82476904, 241.46668221, 405.38187679,  85.81707137,
       100.17225831, 113.33641583,   8.18064242,  69.69978351,
        78.36320959, 109.89773518,  17.6737608 , 217.74890486,
        78.41818958,  98.56692262, 118.24997694,  70.86697496,
        79.73485999,  41.28665311,  42.5293059 ,   8.63343538,
        61.18015348,   1.76575808,  53.84360315,  30.30954453,
       146.24895215, 134.01931814, 100.49245422, 200.30980624,
       134.02664231, 135.35410927, 150.00621943, 250.06977183,
       112.73566954,  49.24347756,  61.14162029,  55.74948175,
        52.85838991,  12.26581488,  57.60465807,  50.38611859,
        33.95307636,   7.2285711 ,  13.69958975,  45.98934698,
        30.90449742,  21.67023574,  22.50376671,  64.26400544,
        61.51704502,  15.64451207,  23.6066614 ,  39.41992281,
        18.19956299,  24.00117244, 134.9965918 ,  84.01530807,
       136.03054915,  31.16171198,  53.51223248,  36.43905526,
        30.92215284,  35.58818497,  31.27828183,  40.96725032,
        40.13692634,  27.82211943, 106.0548893 ,  86.60220711,
        79.61955639,  30.59354156,  17.86945447,  29.64476596,
        17.67144638,  60.4059516 ,  66.2631009 ,  35.62888598,
        25.00366644,  31.07102018,  51.45999472,  26.4690028 ,
        30.65832715,  27.98438989,  50.05735672,  30.4514607 ,
        11.91999808,  22.91890011,   4.05991421,  57.73230844,
        50.55809245,  27.2215008 ,  50.26496188,  -2.75028589,
        -3.28424265,  29.4255791 ,  55.03381422,  70.1524295 ,
        69.36647545, 150.23578898, 116.36223649, 179.30246519,
        45.81328439, 231.10931683,  33.1627597 ,  34.09809191,
        31.18627454,  85.53701624,  32.71614572,  44.21380535,
       231.84721274,  57.69642941,  41.5844294 , 123.96516376,
       219.82730373, 279.23936779,  33.23406881,  47.05002928,
        34.10721752,   2.08196777, 246.43901824,  78.71143421,
        86.11107733, 136.71609452,  80.40147848,   1.19411514,
         1.27449432,  22.51581031,  63.18334093,  91.28285463,
        50.20674231,  -1.19851965,   6.31895254,  51.10112157,
        -3.77486272,  41.86771225,  31.11079673,  65.18621172,
        61.27007692, 172.56774741,  19.80863745,  18.31430874,
       114.81514045, 124.96111714, 103.31851875, 114.19865935,
       115.45095956,  14.4576202 ,  13.9916986 ,  74.64357226,
       113.9203179 ,  42.51585353,  40.1671763 ,  43.27171444,
        32.53694311,  58.80001211,  93.03318272,  76.58392357,
        21.5494446 ,  21.18135721,  31.96577699,  38.19118453,
        25.15661349,   9.6581481 ,  33.40911864,  76.89731273,
        26.70764249,  30.15282395,  40.23279607,  46.52072698,
        20.79387669,  22.49576812,  15.23067928,  12.25226043,
        13.91842139,  18.35089493,  32.73443649, 189.84935298,
       202.15925867, 275.55942926, 148.1061608 , 181.33634079,
       156.87696436, 201.34992013, 300.74360356, 153.05561591,
       125.15087462, 117.63753064, 148.12293689, 304.31756692,
       294.0115748 , 164.31773635, 160.18574835, 175.94698864,
        89.90386083,  80.44036772, 235.82700336, 277.29293634,
       138.45524604,  92.64421952,  73.80173797,  71.56698192,
       145.61187719, 137.80393465, 246.0952762 , 193.24643644,
       241.26017382, 256.56552897,  86.75811624, 186.54549186,
       188.51707061,  89.124713  , 147.01392826, 173.30864599,
       189.48277925,  95.64555686, 122.37896519, 112.53013749,
       114.30165604, 148.37597645, 147.80747549, 186.01922003,
       164.17333001, 305.29221824, 299.15440211, 286.01772136,
       174.06259538, 262.42156085, 305.60914692, 182.43389031,
       201.23702269, 231.6401871 , 273.51552787, 189.17972058,
       252.13508893, 197.94395565,  84.86840081, 144.05349549,
       227.57869851, 147.20478148, 183.00031187, 166.28609422,
       135.19689256,  85.17744291,  78.1348687 ,  93.57119736,
        87.15634938, 111.10310722, 147.82061261, 140.66800959,
       127.59663671, 306.95154136, 301.8247913 , 261.52918639,
       166.50508369, 301.46357714, 271.41241085, 295.74347442,
       119.89900012, 121.76941912, 109.46903877, 133.97628789,
       117.34000478, 149.09521423, 133.98503268, 116.54043556,
       108.55484287,  39.75644357,  46.91565633,  39.79405545,
       120.56354045, 124.40845032,  58.36967443,  55.94801396,
        95.66113781, 107.40679975, 103.91083304,  35.45813199,
        62.04048507,  95.51533282,  46.00636761, 171.16601552,
       180.61329048, 163.17415398, 156.70112383, 194.92980171,
       112.60352713, 175.53194398, 192.81078507, 174.94332979,
       169.13151741, 180.54940412, 156.03513991, 121.80060889,
       157.33857811, 132.93765034, 213.85066296, 148.61353275,
       221.9052794 , 143.41764625, 203.78229769, 142.93904172,
       118.56102031, 110.14742506, 151.39764449, 195.94267554,
       113.50146679, 130.36775592, 310.8446476 , 142.6663423 ,
       209.54300071, 107.43338986, 191.95610696, 127.19851798,
       146.45091616, 190.47069801,  34.73921303,  41.01748052,
        29.47063564,  56.73536179,  63.61688902,  90.91881617,
        24.44960933,  85.92527313,  98.80246309, 105.79322834,
        22.57133243, 116.18768843,  57.23387041, 114.39860736,
       155.52593613,  83.78006649,  90.85260147,   6.48611064,
       108.86021313, 112.58568799, 100.0925411 ,  37.64088639,
        11.30702734,  28.01933676,  47.76850437,  53.52228699,
         9.33493348,  15.60112148,  33.18350963,  31.22660852,
        22.69611377,  55.11852909,  64.57071919,  49.2703072 ,
        43.368525  ,  51.83451378,  37.80311611,  25.84721041,
        55.23284594, 100.1884682 ,  17.08137102,  49.94998778,
       104.84290707,  76.49122156,  23.90159294,  35.03738007,
        58.81440701,  99.94029585,  24.43060075,  15.62955825,
         6.84223306,  18.14434089,  52.9649208 ,  27.94811987,
        21.37278936,  29.99294851,  74.26760715,  85.10023714,
        13.14290014,  30.05262253,  20.82975992,  61.29437562,
        52.6048876 ,  69.67364824,  52.03905118,  27.00442728,
        23.76382143]), OrderedDict([('R2', 0.7398782659651179), ('root_mean_squared_error', 45.468316515071265), ('mean_absolute_error', 32.01962429444243), ('rmse_over_stdev', 0.46906688064546165)]), None)
# Inputs recorded from the original MAST-ML run. NOTE: outdir is an absolute path on the machine
# that produced this notebook; point it at your own copy of the run's output folder if needed.
outdir, label = (
    '/home/nerve/Desktop/skunkworks/lane_schultz/learning_curves/learning/StandardScaler/SequentialFeatureSelector/KernelRidge/RepeatedKFold/split_2',
    'Energy above convex hull (meV/atom)',
)

In [None]:
import pandas as pd
from IPython.display import Image, display

# Recreate the train/test parity plots from the data defined in the previous cell.
# The data variables are train_quad/test_quad, and the axis label is a required argument.
plot_paths = plot_predicted_vs_true(train_quad, test_quad, outdir, label)
# plot_predicted_vs_true returns file names relative to outdir, so join them before displaying.
for plot_path in plot_paths:
    display(Image(filename=os.path.join(outdir, plot_path)))
