In [13]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import shap
import os
import sys

from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler

In [19]:
def plotR2(modelname):
    '''
    Script to visualize the performance of a given ML model.
    Requires outputs of model training in step3_train.py.
    Generates parity plots of actual permeability against predicted
    permeability for 6 gases (columns in the order
    'He','H2','O2','N2','CO2','CH4'), annotated with train and test R^2.

    Reads the header-less files Y_train.csv, Y_pred_train.csv, Y_test.csv
    and Y_pred_test.csv from models/<modelname>/. The working directory is
    left at that folder when the function returns.

    Arguments:
        modelname: name of the folder containing the trained model.
        i.e. 'DNN_BLR_fing'
    '''
    # reset the working directory
    os.chdir(sys.path[0])

    # load in the data from the model training
    os.chdir(os.path.join(os.getcwd(), 'models', modelname))
    Y_train = pd.read_csv('Y_train.csv', header=None).to_numpy()
    Y_pred_train = pd.read_csv('Y_pred_train.csv', header=None).to_numpy()
    Y_test = pd.read_csv('Y_test.csv', header=None).to_numpy()
    Y_pred_test = pd.read_csv('Y_pred_test.csv', header=None).to_numpy()

    # plot the performance of the model for six gases
    gases = ['He', 'H2', 'O2', 'N2', 'CO2', 'CH4']

    plt.figure(figsize=(12, 8))
    for i, gas in enumerate(gases):
        plt.subplot(2, 3, i + 1)
        y_train, y_pred_train = Y_train[:, i], Y_pred_train[:, i]
        y_test, y_pred_test = Y_test[:, i], Y_pred_test[:, i]

        plt.plot(y_train, y_pred_train, '.', color='tab:purple', label='Train')
        plt.plot(y_test, y_pred_test, 'g.', label='Test')
        plt.xlabel(gas)
        plt.ylabel("Predicted value")

        # Square window spanning every plotted value (train and test, actual
        # and predicted) so that no point is clipped out of the panel.
        values = np.concatenate([y_train, y_pred_train, y_test, y_pred_test])
        x0, x1 = values.min(), values.max()
        length = x1 - x0
        if length == 0:
            # constant data: avoid a zero-width axis
            length = 1.0
        x_start, x_end = x0 - 0.1 * length, x1 + 0.1 * length
        plt.xlim([x_start, x_end])
        plt.ylim([x_start, x_end])

        # the unit (y = x) line; two end points suffice for a straight line
        plt.plot([x_start, x_end], [x_start, x_end], '-', color='tab:gray')
        # the unit line has no label, so only Train/Test appear in the legend
        plt.legend(loc='best')

        # Only R^2 goes in math mode: mathtext ignores spaces and italicizes
        # plain words, which garbled the original "$Train R^2=...$" labels.
        plt.text(x_end - 0.7 * length, x_start + 0.15 * length,
                 "Train $R^2$={:.2f}".format(r2_score(y_train, y_pred_train)))
        plt.text(x_end - 0.7 * length, x_start + 0.05 * length,
                 "Test $R^2$={:.2f}".format(r2_score(y_test, y_pred_test)))

    plt.show()

def plotSHAP(modelname, top_n=12):
    '''
    Visualize the results of SHAP for a given ML model.
    Requires SHAP values to be saved from step3.5_SHAP.py.
    Generates (a) a SHAP summary plot for prediction of each permeability
    and (b) a bar graph showing the top `top_n` (default 12) most important
    chemical features overall with respective impacts on 6 gas permeabilities.

    Reads datasets/datasetAX_<features>.csv, where <features> is the third
    '_'-separated field of modelname, and the header-less files
    models/<modelname>/shap_<i>.csv for i = 0..5 (one per gas, in the order
    'He','H2','O2','N2','CO2','CH4'). The working directory is left at
    models/<modelname> when the function returns.

    Arguments:
        modelname: name of the folder containing the trained model.
        i.e. 'DNN_BLR_fing'
        top_n: number of features shown in the bar graph (default 12).
    '''
    gases = ['He', 'H2', 'O2', 'N2', 'CO2', 'CH4']

    # reset the working directory
    os.chdir(sys.path[0])

    # modelname is split on '_' and its third field selects the feature set
    features = modelname.split('_')[2]
    maindirectory = os.path.join(os.getcwd(), 'models', modelname)

    X_df = pd.read_csv(os.path.join(os.getcwd(), 'datasets',
                                    'datasetAX_' + features + '.csv'))
    # standardized feature values are only used to colour the summary plots
    X = StandardScaler().fit_transform(X_df.to_numpy())

    os.chdir(maindirectory)
    shap_values = np.zeros((len(gases), X.shape[0], X.shape[1]))
    for i in range(len(gases)):
        shap_values[i] = pd.read_csv('shap_' + str(i) + '.csv', header=None).to_numpy()

    # (a) one summary plot per gas, labelled with the real feature names
    feature_names = list(X_df.columns)
    for i, gas in enumerate(gases):
        plt.figure()
        shap.summary_plot(shap_values[i], X, feature_names=feature_names, show=False)
        plt.title(gas)
    plt.show()

    # (b) mean |SHAP| per feature (rows) and gas (columns), ranked by the sum
    # over all gases
    wts = pd.DataFrame(np.mean(np.abs(shap_values), axis=1).T,
                       index=X_df.columns, columns=gases)
    wts['sum'] = wts.sum(axis=1)
    top_wts = wts.sort_values('sum', ascending=False).iloc[:top_n]

    # DataFrame.plot opens its own figure when no ax is given, so pass the
    # size here instead of creating an extra (blank) figure beforehand.
    ax = top_wts[gases].plot(kind='bar', colormap='Set2', figsize=(20, 5))
    ax.legend(gases)
    ax.set_ylabel('Avg Magnitude of SHAP Value')
    plt.show()

def plotRobeson(filelist):
    '''
    Visualize permeability data in the O2/N2, CO2/CH4, CO2/N2, and H2/CO2 Robeson spaces.
    Data must be a csv file with 6 columns in the order of ['He','H2','O2','N2','CO2','CH4'] permeabilities.
    Update the filelist list with paths to each of the .csv files that are desired.
    Works with the .csv outputs of screen.py and train.py.

    Values are plotted as-is while the upper-bound lines are drawn in log10
    space, so the files are expected to hold log10 permeabilities; the
    selectivity is the difference of the two log10 columns. Each file gets
    its own colour and is labelled in the legend by the path given.

    Arguments:
        filelist: list containing the paths to each of the output prediction files
        i.e. ['models/DNN_BLR_fing/Y_pred_datasetCX_fing_0.csv', ..]
    '''
    # reset the working directory so the relative paths in filelist resolve
    os.chdir(sys.path[0])

    gases = ['He', 'H2', 'O2', 'N2', 'CO2', 'CH4']

    def col(gas):
        # column index counted from the right (e.g. O2 -> -4, CH4 -> -1)
        return gases.index(gas) - len(gases)

    # One entry per panel: (fast gas, slow gas, x-limits, y-limits, bounds);
    # each bound is (log10 selectivity as a function of x, line style, label).
    panels = [
        ('O2', 'N2', (-4, 7), (-1, 2), [
            (lambda x: np.log10(9.2008) - 0.1724 * x, '-k', '1991 upper bound'),
            (lambda x: np.log10(12.148) - 0.1765 * x, '--k', '2008 upper bound'),
            (lambda x: np.log10(18.50) - 0.1754 * x, ':k', '2015 upper bound'),
        ]),
        ('CO2', 'CH4', (-2, 7), (-2, 4), [
            (lambda x: np.log10(197.81) - 0.3807 * x, '-k', '1991 upper bound'),
            (lambda x: np.log10(357.33) - 0.3794 * x, '--k', '2008 upper bound'),
            (lambda x: np.log10(1155.60) - 0.4165 * x, ':k', '2019 upper bound'),
        ]),
        ('CO2', 'N2', (-2, 7), (-1, 3), [
            (lambda x: -1 / 2.888 * (-np.log10(30967000) + x), '--k', '2008 upper bound'),
            (lambda x: -1 / 3.409 * (-np.log10(755.58e6) + x), ':k', '2019 upper bound'),
        ]),
        ('H2', 'CO2', (-2, 7), (-1.5, 2), [
            (lambda x: -1 / 1.9363 * (-np.log10(1200) + x), '-k', '1991 upper bound'),
            (lambda x: -1 / 2.302 * (-np.log10(4515) + x), '--k', '2008 upper bound'),
        ]),
    ]

    # plot permeability values on a Robeson plot
    sns.set_palette("bright")
    _, axes = plt.subplots(2, 2, figsize=(12, 8))
    axes = axes.ravel()

    # scatter each dataset in every panel; the n-th file gets the n-th palette
    # colour in each panel, so colours match across panels
    for dataset in filelist:
        Y_pred = pd.read_csv(dataset, header=None).to_numpy()
        for ax, (fast, slow, _, _, _) in zip(axes, panels):
            permeability = Y_pred[:, col(fast)]
            selectivity = permeability - Y_pred[:, col(slow)]
            ax.plot(permeability, selectivity, '.', alpha=0.2, label=str(dataset))

    # format the plots and add Robeson upper bounds
    for ax, (fast, slow, (xmin, xmax), (ymin, ymax), bounds) in zip(axes, panels):
        ax.set_xlim([xmin, xmax])
        ax.set_ylim([ymin, ymax])
        for logalpha, style, bound_label in bounds:
            ax.plot([xmin, xmax], [logalpha(xmin), logalpha(xmax)], style,
                    label=bound_label)
        ax.set_title('{}/{} Separations'.format(fast, slow))
        ax.set_xlabel('log10 P({})'.format(fast))
        ax.set_ylabel('log10 selectivity {}/{}'.format(fast, slow))
        # fixed corner: 'best' is slow with large point clouds, and the region
        # above the upper bounds is normally sparse
        ax.legend(loc='upper right', fontsize='small')

    plt.show()