This notebook was automatically generated from your MAST-ML run so you can recreate the
plots. Some things are a bit different from the usual way of creating plots: we use
matplotlib's [object-oriented
interface](https://matplotlib.org/tutorials/introductory/lifecycle.html) instead of
pyplot to create the `fig` and `ax` instances.


In [None]:
"""
This module contains a collection of functions which make plots (saved as png files) using matplotlib, generated from
some model fits and cross-validation evaluation within a MAST-ML run.

This module also contains a method to create python notebooks containing plotted data and the relevant source code from
this module, to enable the user to make their own modifications to the created plots in a straightforward way (useful for
tweaking plots for a presentation or publication).
"""

import math
import os
import pandas as pd
import itertools
import warnings
import logging
from collections import Iterable
from os.path import join
from collections import OrderedDict

# Ignore the harmless warning about the gelsd driver on mac.
warnings.filterwarnings(action="ignore", module="scipy",
                        message="^internal gelsd")
# Ignore matplotlib deprecation warning (set as all warnings for now)
warnings.filterwarnings(action="ignore")

import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure, figaspect
from matplotlib.animation import FuncAnimation
from matplotlib.font_manager import FontProperties
import matplotlib.mlab as mlab
from scipy.stats import gaussian_kde
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

# Needed imports for ipynb_maker
from mastml.utils import nice_range
from mastml.metrics import nice_names

import inspect
import textwrap
from pandas import DataFrame, Series

import nbformat

from functools import wraps

# Global matplotlib defaults, applied to every figure created after this point
matplotlib.rc('font', size=18, family='sans-serif') # larger sans-serif font for all plot text
matplotlib.rc('figure', autolayout=True) # turn on autolayout by default

# Resolution passed as dpi= when figures are saved; kept as a module-level
# constant so it can be changed in one place
DPI = 250

# Root logger; the plotting code below uses it for debug output (e.g. group statistics)
log = logging.getLogger()



In [None]:
def stat_to_string(name, value):
    """
    Method that converts a metric name/value pair into a string for displaying on a plot

    Args:

        name: (str), long name of a stat metric or quantity

        value: (None, int, float, tuple or str), value of the metric or quantity. None (or any other empty,
        non-numeric value such as '') means the entry has a name only; a 2-tuple is read as (mean, std)

    Return:

        (str), the metric name (adjusted to look nicer for inclusion on a plot) followed by its formatted value

    """

    # Prefer the curated display name; otherwise just make the raw key readable
    if name in nice_names:
        name = nice_names[name]
    else:
        name = name.replace('_', ' ')

    # has a name only
    # A numeric zero is a legitimate metric value and must still be printed, so
    # only empty *non-numeric* values (None, '', (), False) count as "no value".
    is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
    if not is_number and not value:
        return name
    # has a mean and std
    if isinstance(value, tuple):
        mean, std = value
        return f'{name}:' + '\n\t' + f'{mean:.3f}' + r'$\pm$' + f'{std:.3f}'
    # has a name and value only; whole numbers are shown without decimals
    if isinstance(value, int) or (isinstance(value, float) and value%1 == 0):
        return f'{name}: {int(value)}'
    if isinstance(value, float):
        return f'{name}: {value:.3f}'
    return f'{name}: {value}' # probably a string


def plot_stats(fig, stats, x_align=0.65, y_align=0.90, font_dict=dict(), fontsize=14):
    """
    Method that writes a block of statistics onto a figure, one stat per line. The text is not clipped,
    so it goes off screen if the stats are too long or too many in number.

    Args:

        fig: (matplotlib figure object), the figure to write the stats onto

        stats: (dict), dict of statistics (name -> value) to be included with a plot

        x_align: (float), x position (passed to fig.text) where the stats text starts

        y_align: (float), y position (passed to fig.text) of the top of the stats text

        font_dict: (dict), dict of matplotlib font options to alter display of stats on plot

        fontsize: (int), the fontsize of stats to display on plot

    Returns:

        None

    """

    # One formatted line per statistic, in the dict's iteration order
    stat_lines = [stat_to_string(stat_name, stat_value)
                  for stat_name, stat_value in stats.items()]
    stats_text = '\n'.join(stat_lines)

    text_font = FontProperties(size=fontsize)
    fig.text(x_align, y_align, stats_text,
             verticalalignment='top', wrap=True,
             fontdict=font_dict, fontproperties=text_font)


def make_fig_ax(aspect_ratio=0.5, x_align=0.65, left=0.10):
    """
    Method to build a matplotlib figure with a single, manually positioned axes. Uses the Object Oriented
    interface (no pyplot), see https://matplotlib.org/gallery/api/agg_oo_sgskip.html

    Args:

        aspect_ratio: (float), aspect ratio handed to figaspect to size the figure

        x_align: (float), x position where the axes region ends. Needed so stats can be displayed alongside the plot

        left: (float), the leftmost position to draw edge of the axes

    Returns:

        fig: (matplotlib fig object), a matplotlib figure object with the specified aspect ratio

        ax: (matplotlib ax object), the axes object added to that figure

    """
    # Size the figure from the requested aspect ratio and give it an Agg canvas
    fig_width, fig_height = figaspect(aspect_ratio)
    fig = Figure(figsize=(fig_width, fig_height))
    FigureCanvas(fig)

    # Custom positioning, see this guide for more details:
    # https://python4astronomers.github.io/plotting/advanced.html
    bottom_margin = 0.15
    right_margin = 0.01
    top_margin = 0.05
    axes_rect = (
        left,
        bottom_margin,
        x_align - left - right_margin,
        1 - bottom_margin - top_margin,
    )
    ax = fig.add_axes(axes_rect, frameon=True)

    # The axes rectangle is set by hand above, so switch tight layout off for this figure
    fig.set_tight_layout(False)

    return fig, ax


In [None]:
def plot_predicted_vs_true(train_quad, test_quad, outdir, label):
    """
    Method to create parity plots (predicted vs. true values): one png for the training data and one for the
    test data, both drawn on the same axis range

    Args:

        train_quad: (tuple), tuple containing 4 items: true training y data, predicted training y data,

        training metric data, and groups used in training (or None)

        test_quad: (tuple), tuple containing 4 items: true test y data, predicted test y data,

        testing metric data, and groups used in testing (or None)

        outdir: (str), path to save plots to

        label: (str), label used for axis labeling

    Returns:

        filenames: (list), names (not full paths) of the png files written into outdir, train plot first

    """
    y_train_true, y_train_pred, _, train_groups = train_quad
    y_test_true, y_test_pred, _, test_groups = test_quad

    # make diagonal line from absolute min to absolute max of any data point
    max1 = max(y_train_true.max(), y_train_pred.max(),
               y_test_true.max(), y_test_pred.max())
    min1 = min(y_train_true.min(), y_train_pred.min(),
               y_test_true.min(), y_test_pred.min())
    # Work out the rounding precision once, from the unrounded range, so both
    # limits are rounded identically. rounder is defined elsewhere in this module;
    # presumably it returns a number of decimal places suited to the range.
    digits = rounder(max1-min1)
    max1 = round(float(max1), digits)
    min1 = round(float(min1), digits)

    # Loop-invariant styling. Groups run through every color with the first marker,
    # then every color with the second marker, and so on.
    colors = ['blue', 'red', 'green', 'purple', 'orange', 'black']
    markers = ['o', 'v', '^', 's', 'p', 'h', 'D', '*', 'X', '<', '>', 'P']
    # use x_align when placing text so the axes and the stats don't overlap
    x_align = 0.64

    filenames = list()
    for y_true, y_pred, stats, groups, title_addon in \
            (train_quad+('train',), test_quad+('test',)):

        fig, ax = make_fig_ax(x_align=x_align)

        # set tick labels
        # notice that we use the same max and min for both plots. Don't calculate
        # those inside the loop, because they should be on the same scale and axis
        _set_tick_labels(ax, max1, min1)

        # plot diagonal line
        ax.plot([min1, max1], [min1, max1], 'k--', lw=2, zorder=1)

        # do the actual plotting
        if groups is None:
            ax.scatter(y_true, y_pred, color='blue', edgecolors='black', s=100, zorder=2, alpha=0.7)
        else:
            handles = dict()
            # Use the union of train and test groups so a group gets the same
            # color/marker in both plots
            unique_groups = np.unique(np.concatenate((train_groups, test_groups), axis=0))
            log.debug(' '*12 + 'unique groups: ' +str(list(unique_groups)))
            for groupcount, group in enumerate(unique_groups):
                mask = groups == group
                log.debug(' '*12 + f'{group} group_percent = {np.count_nonzero(mask) / len(groups)}')
                # Wrap around instead of raising IndexError once every
                # color/marker combination has been used
                color = colors[groupcount % len(colors)]
                marker = markers[(groupcount // len(colors)) % len(markers)]
                handles[group] = ax.scatter(y_true[mask], y_pred[mask], label=group, color=color,
                                            marker=marker, s=100, alpha=0.7)
            ax.legend(handles.values(), handles.keys(), loc='lower right', fontsize=12)

        # set axis labels
        ax.set_xlabel('True '+label, fontsize=16)
        ax.set_ylabel('Predicted '+label, fontsize=16)

        plot_stats(fig, stats, x_align=x_align, y_align=0.90)

        filename = 'predicted_vs_true_'+ title_addon + '.png'
        filenames.append(filename)
        fig.savefig(join(outdir, filename), dpi=DPI, bbox_inches='tight')

    return filenames


In [None]:
from numpy import array
from collections import OrderedDict
from io import StringIO
train_quad = (array([150.37102595, 202.0924573 , 325.49788774, ...,   0.        ,
        23.809066  ,   0.        ]), array([162.83452379, 142.86078628, 269.2296665 , ...,  73.09833392,
        52.17641948,   8.94889157]), OrderedDict([('R2', 0.7671637058292502), ('root_mean_squared_error', 46.93815869170976), ('mean_absolute_error', 31.81884078555601), ('rmse_over_stdev', 0.4868729886212817)]), None)
test_quad = (array([177.28302038, 200.26888144, 218.80057061, 193.03789201,
       201.40211278, 407.14718606, 259.72112607, 110.462825  ,
        95.94694389, 202.29096841, 121.012896  , 177.34632637,
        88.06968527, 181.77618797, 146.05982872, 170.82835436,
       142.48339282,  59.67920269,  72.36735708,  88.46084748,
       291.95555798,  37.9332462 ,  58.99391702, 109.47390069,
        63.04573585,  25.66391325, 121.57100715, 118.30376345,
        87.31329665, 236.17731536,  70.97342159, 417.7519315 ,
        99.06622089,  66.0766101 ,  73.01999556, 515.80533047,
       209.8692985 , 272.8043925 ,  78.67122954, 178.2357388 ,
       227.4912848 , 136.31607677, 122.31135567, 198.06185912,
       128.36434487,  32.0094037 ,   0.        , 179.25872313,
        30.40327127,   0.678471  ,  36.52821981,  56.67383667,
        15.621525  , 128.57645296, 197.42137823,   0.        ,
        74.58610407,  78.02361062, 149.61116727,  66.14597382,
        51.86444618,  50.49019186, 198.10666422, 206.15823625,
       288.48656282, 118.3026576 , 219.40213351,  59.63498938,
        59.05468284,  34.47065573,   0.        ,  27.99263983,
        29.94908075,  47.01813711,  16.66840972,   0.        ,
         0.        ,   0.59343723,   0.        ,   0.        ,
         0.        ,  94.50620971,  25.36959074,  15.44290641,
        22.3715509 , 108.22466848,  33.43611227,  67.08510148,
        53.84255407,   0.        ,  40.25443982,  40.83543394,
        44.85051669,  26.7022126 , 194.26853657,   0.        ,
        42.29589311,  29.653255  ,   0.        ,   0.        ,
         0.        ,  34.36901768,  89.91354036,   0.        ,
        13.80460838,  42.32480613,  88.38029969, 141.47054363,
        74.7468692 ,  79.23257647,  30.17024475,  44.856381  ,
        45.433214  ,   0.        ,  47.26926175,  35.46209575,
        31.378064  ,   0.        ,   0.        ,  78.00949578,
       132.10902428,  75.88549386, 278.98466522, 185.40421066,
       283.90645336, 138.89285472,  59.1838239 , 151.55367115,
       170.3886173 ,  69.97238548, 122.31135567,  15.00782838,
         0.        , 130.59739217, 166.18742887, 258.66530956,
        69.47351553, 284.89818957, 298.36883132, 232.45401086,
         0.        ,  48.46453063,  77.8684422 , 211.4607162 ,
       286.01491902, 139.34508133, 299.69248478, 113.90148661,
        48.04117053,  35.07688959,  29.18020703,   0.        ,
        33.65498379, 100.44019458,   0.        ,   0.        ,
         0.        ,   0.        ,  18.44019411,   0.        ,
        51.31971375,   0.        ,  52.52667425,  48.18377975,
        45.52619143,  47.92046585,  17.0719435 ,   0.        ,
        50.94866855,  75.46581208,  67.74394496,  82.92662983,
       135.99188516,  56.54071259,  43.1541225 ,  76.83212346,
        33.27336669,  37.03306019,  42.81588044,  28.24964357,
         0.        ,   0.        ,  65.03154095, 214.36416907,
         0.        , 141.4711659 , 154.69746659, 122.98024015,
       134.6716424 ,  83.02690466,  10.35277113,  33.67375834,
        51.95072185,  19.2994335 ,  95.22684338, 103.2663285 ,
        44.01037784,   0.        , 137.88193919,  57.14438148,
         0.        ,   9.36490231,   0.        ,  69.00741775,
         7.56564178, 129.60153223, 136.43144375, 245.69243888,
       202.26980845, 193.22945633, 207.90480672, 168.86730174,
       157.97280294, 342.04001938, 336.22469679,  90.31694834,
        91.53400436, 106.514784  ,  99.68134329, 178.93978486,
       187.87580019, 241.66495314, 181.29528486, 249.72472862,
       196.07415857, 131.76146218, 153.2465035 , 130.34718776,
       136.29376713, 126.70082241, 149.94833269, 189.88100424,
       292.39066312, 192.28465626, 220.16151104, 100.86488318,
       141.5352297 , 190.07525544, 205.09292614, 210.33572768,
       103.16515278, 231.48185529, 237.06234447, 294.31041312,
       227.87867171, 201.32142318, 337.2143761 , 205.56403043,
       197.12412955,  96.99151193, 149.64777796, 135.16475661,
       166.21281277, 247.00509507,  98.39446077, 141.92110132,
       140.26518176,  94.43350892, 188.51316397,  95.45473206,
       131.97759503, 213.33322768, 159.4014052 , 197.30900025,
       116.27871809, 259.08638129, 329.04283429, 245.48897712,
       250.99584507, 252.00105048, 247.47424665, 252.61149593,
       323.63415062, 346.90203425, 318.43464993, 141.45345548,
        84.93486971,  79.9859482 ,  55.94984366, 148.87388686,
       179.65931277, 109.38101591, 182.22098097, 123.03536001,
        64.02417181, 139.23135132, 158.4134052 , 165.45646046,
       125.57114361, 202.52553418, 312.11562515, 242.96546494,
       313.927543  , 347.29590062, 209.325439  ,  98.80490655,
       101.8838757 ,  82.11902013,  92.70646802,  46.57820663,
       115.14222467, 108.31218227, 110.48019418,  46.88266895,
        12.23389363,  24.231722  ,  62.02428034,  49.6708469 ,
        22.02491282,  29.54783309, 131.67141422,  48.8429076 ,
        75.98358167,  41.58359043,  58.54469733,  27.21757246,
       121.55357738, 102.11231532, 173.38559503, 197.17633245,
       188.27143603, 160.45672687, 185.414463  , 182.10184863,
       211.26795997, 218.07420311, 189.99865166, 122.37095632,
        30.77112483,  29.88706053,  91.46208074, 100.81566362,
       174.06027747, 106.74575874, 105.72217961, 258.82805093,
       117.24705906, 104.37297146, 184.59131561, 185.66169375,
       237.58337883, 291.09907332, 112.28738628, 144.11116592,
       137.36948495, 121.69241919, 191.31459627, 263.54337143,
       182.61311972, 290.22157108, 283.09312767,  31.31275089,
        37.32259412,  49.79361192,  35.28329862,  10.43143757,
         0.        , 129.16762848, 123.68971028,  97.25569603,
        90.435045  , 155.90986428,   6.92239663,   0.        ,
        17.84640765,  93.79066576,   0.        ,  72.67252036,
        34.2214622 ,   0.        ,   0.        ,  21.28138975,
        93.06800775,  11.01778613,   0.        ,  32.36845034,
       120.02645236,  31.76337063,  38.65176325,  37.78468525,
        29.90345073,  17.21211269,  33.75589221, 184.38740476,
        76.50326643,  82.99657195,  60.54108325,  69.473884  ,
        52.14150703,   0.        ,  44.40936128,   0.        ,
       152.57088652,  84.26131862,  86.51921261, 157.61174657,
        71.76198016,  36.22347961,   0.        ,   0.        ,
        94.60256203, 126.61616369,  27.25161269,  46.35809831,
        29.36909727,  27.15671258,  25.95621478,  94.00578421,
        36.61104469,  80.15612571,  66.50124863,  17.00030381,
         0.        ,  45.40761163,   7.0009685 ,  27.90730325,
        19.9708105 ,   5.974325  ,  25.1156351 ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,  33.52619036,
        47.20790294]), array([ 2.26790255e+02,  1.79142678e+02,  1.94692652e+02,  1.58660487e+02,
        1.56859722e+02,  2.74031827e+02,  3.30339589e+02,  9.59429435e+01,
        8.79104417e+01,  1.52439487e+02,  1.38825528e+02,  1.41135743e+02,
        1.08877799e+02,  1.71209209e+02,  1.65632099e+02,  1.73733515e+02,
        1.46514494e+02,  9.64498219e+01,  8.48421868e+01,  8.85164791e+01,
        2.87600849e+02,  9.18086919e+01,  1.29020242e+02,  2.06322772e+02,
        2.02123036e+01,  5.02927157e+01,  1.14120820e+02,  1.31508986e+02,
        7.93739243e+01,  1.00151929e+02,  8.69793471e+01,  2.45091087e+02,
        8.19896339e+01,  8.79377190e+01,  9.11616632e+01,  5.08587510e+02,
        6.62440336e+01,  9.39773298e+01,  7.84804512e+01,  8.22573629e+01,
        1.33227744e+02,  1.46222620e+02,  1.10997186e+02,  1.70049053e+02,
        1.66040811e+02,  1.05921922e+02,  1.08096177e+02,  1.23434723e+02,
        7.33577105e+01,  1.17368673e+02,  1.43344402e+02,  1.39231548e+02,
        1.10436602e+02,  1.55543977e+02,  1.98506901e+02,  2.74743190e+01,
        1.09148112e+02,  6.65805718e+01,  1.56555259e+02,  7.37678613e+01,
        9.58066560e+01,  5.56549206e+01,  1.54749103e+02,  1.20516308e+02,
        2.86458051e+02,  1.55373670e+02,  1.96680327e+02,  7.84283142e+01,
        4.93477109e+01,  5.82304579e+01,  5.87322915e+01,  5.64808133e+01,
        5.37348655e+01,  7.58040135e+01,  3.68253197e+01,  3.25885874e+01,
       -4.48490295e-01,  6.46971758e+01,  2.59117388e+01,  1.48096535e+01,
        2.15920935e+01,  9.12957886e+01,  4.88160125e+01,  4.29508550e+01,
        5.43920145e+01,  5.90348601e+01,  3.84822656e+01,  1.24291544e+02,
        3.74329325e+01,  1.40211032e-01,  3.09128403e+01,  3.61926619e+01,
        4.24949334e+01,  4.47452661e+01,  2.34850064e+02,  3.61818878e+01,
        8.68075165e+01,  2.71053878e+01,  6.24728687e+01,  2.80643081e+00,
        3.27280795e+01,  5.05076965e+01,  9.55310178e+01,  2.00942724e+01,
        3.40368649e+01,  2.82973979e+01,  6.37252626e+01,  7.57436869e+01,
        3.13721333e+01,  4.79871216e+01,  2.35400867e+01,  1.92267520e+01,
        2.99191328e+01,  1.72342255e+01,  2.87316113e+01,  9.92974909e+00,
        3.95176336e+01, -2.16128059e+00, -2.25157959e+00,  6.46329100e+01,
        1.49923951e+02,  9.80349826e+01,  3.32787081e+02,  1.73935965e+02,
        2.34157864e+02,  8.99141463e+01,  4.22350469e+01,  8.29008596e+01,
        1.25746356e+02,  7.25307262e+01,  1.10997186e+02,  1.27209246e+02,
        5.32801539e+01,  9.70457845e+01,  2.16808661e+02,  3.02775695e+02,
        9.48367906e+01,  3.27476501e+02,  1.59824003e+02,  9.81419119e+01,
        3.19403827e+01,  1.14063447e+02,  3.76383513e+01,  4.17430563e+01,
        2.23793947e+02,  1.29132206e+02,  2.10725549e+02,  1.14873140e+02,
        7.63962263e+01,  6.91455548e+01,  3.19722826e+01, -3.28461268e+00,
        3.96011796e+01,  5.94904696e+01,  1.80630429e+01,  1.01868275e+01,
        5.44098761e+00,  1.70120614e+01,  3.36311683e+01,  1.30740238e+01,
        6.93912367e+01, -1.04809000e+00,  9.27771503e+01,  3.97946535e+01,
        1.58881866e+01,  2.06759791e+01,  2.52128046e+01, -6.97702365e+00,
        3.59674782e+01,  8.30916141e+01,  8.10931692e+01,  8.14351503e+01,
        1.15880638e+02,  8.47226776e+01,  7.81995989e+01,  9.52655917e+01,
        5.83307393e+01,  6.04143500e+01,  5.59673918e+01,  5.88797930e+01,
        2.13995354e-01,  1.92953148e+01,  2.74159296e+01,  2.68614580e+02,
        9.88867727e+01,  1.18930692e+02,  1.36740596e+02,  1.14647483e+02,
        1.18289288e+02,  7.18766865e+01,  2.87772736e+01,  2.53956279e+01,
        1.13117592e+01,  3.14508145e+01,  3.50964191e+01,  2.85945790e+01,
        3.78417853e+01,  2.91355099e+01,  4.47602945e+01,  2.83811725e+01,
        1.99237108e+01,  3.89989236e+01,  1.99342810e+01,  1.17099103e+01,
        1.30504151e+01,  1.97472721e+02,  2.08629382e+01,  3.15498770e+02,
        2.87849738e+02,  1.76029977e+02,  2.01836625e+02,  1.25049907e+02,
        1.20972097e+02,  2.79008862e+02,  2.59598355e+02,  1.47848410e+02,
        1.55562551e+02,  1.39762898e+02,  1.15399944e+02,  2.55567600e+02,
        1.61233377e+02,  2.99301312e+02,  2.56344906e+02,  1.64055914e+02,
        2.26951804e+02,  1.08804132e+02,  9.30209343e+01,  9.89500171e+01,
        8.70000758e+01,  1.84541027e+02,  8.55788269e+01,  1.43708195e+02,
        2.81378623e+02,  1.86785797e+02,  2.70655238e+02,  8.64929196e+01,
        1.28343593e+02,  1.40933956e+02,  7.52772443e+01,  1.57708967e+02,
        1.37077881e+02,  1.46548723e+02,  2.25260637e+02,  2.81893704e+02,
        2.46356877e+02,  1.74515065e+02,  2.07990412e+02,  2.11329990e+02,
        1.85256118e+02,  8.33590484e+01,  1.47382521e+02,  1.49048695e+02,
        1.97241867e+02,  2.63837701e+02,  8.31058458e+01,  1.22899319e+02,
        1.21070384e+02,  7.19611157e+01,  7.61988346e+01,  6.58852001e+01,
        1.14418331e+02,  1.57763908e+02,  1.56963879e+02,  1.41886748e+02,
        1.36048626e+02,  1.38934976e+02,  3.13329875e+02,  2.31750331e+02,
        2.65099255e+02,  2.11850674e+02,  2.03018079e+02,  2.27088642e+02,
        2.97567554e+02,  2.30447234e+02,  1.85621171e+02,  1.79612606e+02,
        8.82900913e+01,  9.07286566e+01,  8.45137028e+01,  2.04752743e+02,
        2.01459517e+02,  1.37984955e+02,  2.29055020e+02,  1.44633293e+02,
        8.32433484e+01,  1.22496973e+02,  1.56956501e+02,  1.40813651e+02,
        1.26566050e+02,  2.38434192e+02,  3.03908831e+02,  3.07056765e+02,
        3.03460616e+02,  3.01924408e+02,  1.61977072e+02,  1.23062477e+02,
        1.31821000e+02,  1.21670330e+02,  1.11798861e+02,  1.10121854e+02,
        1.35659606e+02,  1.24192580e+02,  1.24021388e+02,  9.91193094e+01,
        3.81172200e+01,  8.98166578e+01,  9.48592903e+01,  9.25268408e+01,
        3.86907019e+01,  2.55813338e+01,  1.38125372e+02,  5.83548401e+01,
        7.31045983e+01,  6.67291891e+01,  6.40992247e+01,  2.69131238e+01,
        1.59688247e+02,  1.61090108e+02,  1.56642415e+02,  1.76535002e+02,
        1.72203616e+02,  1.52475864e+02,  1.23368934e+02,  1.75000226e+02,
        1.74368809e+02,  1.73434356e+02,  1.77517482e+02,  1.18293321e+02,
        5.90251971e+01,  5.34308513e+01,  1.22360965e+02,  1.21752817e+02,
        2.16536425e+02,  2.00546503e+02,  1.13880004e+02,  3.17567500e+02,
        1.17160310e+02,  1.11882420e+02,  1.14878011e+02,  1.20715214e+02,
        1.53029912e+02,  2.05093678e+02,  1.01589092e+02,  1.32510358e+02,
        1.05237862e+02,  1.05162813e+02,  1.50305626e+02,  3.16768713e+02,
        1.91259678e+02,  8.04456032e+01,  1.17137239e+02,  5.11785947e+01,
        4.31118683e+01,  7.80764021e+01,  2.97548777e+01,  5.32874582e+01,
        2.93925007e+01,  1.19729585e+02,  1.15881491e+02,  1.05603114e+02,
        1.28169775e+02,  5.29098673e+01,  2.63510694e+01,  2.26633741e+01,
        5.86632777e+01,  8.57731861e+01,  8.13929667e-01,  7.67803687e+01,
        4.30307433e+01,  8.28557354e+00,  9.72161012e-01,  1.14582232e+00,
        6.32178462e+01,  2.96407889e+01,  2.99373780e+01,  2.42093972e+01,
        6.50189639e+01,  2.76017767e+01,  2.98690148e+01,  2.91514263e+01,
        3.74804582e+01,  2.50956076e+01,  2.87503844e+01,  1.15565058e+02,
        5.80612115e+01,  7.42199605e+01,  4.15865994e+01,  3.51648894e+01,
        4.23659315e+01, -1.47364608e-01,  3.96799494e+01,  2.21814245e+01,
        1.63045818e+02,  1.35755803e+02,  1.45716546e+02,  2.12444046e+02,
        7.48519560e+01,  5.69952749e+01,  3.45863636e+01,  5.23055274e+01,
        5.96421570e+01,  9.93239187e+01,  2.56418897e+01,  2.95829267e+01,
        5.12500356e+00,  1.28933722e+01,  3.61548077e+01,  5.80785131e+01,
        8.16908874e+00,  7.73957845e+01,  3.33999825e+01,  2.53729979e+01,
        2.73884486e+01,  3.59077835e+01,  5.76729822e+01,  2.23035353e+01,
        2.06028980e+01,  1.74499305e+01,  3.43521371e+01,  2.10842863e+01,
        2.08095432e+01,  1.95718386e+01,  1.25770365e+01,  3.31806805e+01,
        3.23152136e+01, -2.22683026e+00, -5.14711947e+00,  3.03689394e+01,
        6.59927914e+01]), OrderedDict([('R2', 0.7391445611071926), ('root_mean_squared_error', 46.6760373162358), ('mean_absolute_error', 33.172489830495095), ('rmse_over_stdev', 0.4841540959119035)]), None)
# Directory the plots are saved into.
# NOTE(review): this is an absolute path recorded when the notebook was generated -
# presumably it has to be edited before running the notebook on another machine.
outdir = '/home/nerve/Desktop/skunkworks/lane_schultz/learning_curves/learning/StandardScaler/SequentialFeatureSelector/KernelRidge/RepeatedKFold/split_5'
# Quantity name for the axis labels; plot_predicted_vs_true prepends 'True ' and 'Predicted '
label = 'Energy above convex hull (meV/atom)'

In [None]:
import pandas as pd
from IPython.display import Image, display

# Re-create the parity plots from the data defined in the cell above, then show them inline.
# The data cell defines train_quad / test_quad, and the function also needs the axis label.
plot_paths = plot_predicted_vs_true(train_quad, test_quad, outdir, label)
for plot_path in plot_paths:
    # plot_predicted_vs_true returns bare file names; the images were saved under outdir
    display(Image(filename=join(outdir, plot_path)))
