In [1]:
import sys, os
sys.path.append('/work/mflora/ROAD_SURFACE')

import sage 
import pandas as pd
from joblib import load
import numpy as np
import xarray as xr
import pickle

from skexplain.common.importance_utils import to_skexplain_importance
from wofs_ml_severe.wofs_ml_severe.common.multiprocessing_utils import run_parallel, to_iterator
from wofs_ml_severe.wofs_ml_severe.common.emailer import Emailer 

from probsr_config import TARGET_COLUMN, PREDICTOR_COLUMNS, FIGURE_MAPPINGS
from calibration_classifier import CalibratedClassifier

In [2]:
# Root directory of the explainability project; datasets and fitted models
# live in sibling sub-directories underneath it.
BASE_PATH = '/work/mflora/explainability_work/'
DATA_BASE_PATH, MODEL_BASE_PATH = (
    os.path.join(BASE_PATH, subdir) for subdir in ('datasets', 'models')
)

In [3]:
def compute_sage(model, X, y, background, n_background=100, seed=42):
    """Compute SAGE values for ``model`` evaluated on ``(X, y)``.

    A random subset of ``background`` rows is handed to
    ``sage.MarginalImputer`` together with ``model.predict_proba``; the
    resulting imputer drives a ``sage.PermutationEstimator`` that uses the
    ``'cross entropy'`` loss.

    Parameters
    ----------
    model : object
        Fitted classifier exposing ``predict_proba``.
    X : array-like
        Examples to explain; passed to the estimator unchanged.
    y : array-like
        Targets for ``X``; passed to the estimator unchanged.
    background : pandas.DataFrame or numpy.ndarray
        2-D pool of rows from which the imputer's background sample is drawn.
    n_background : int, default 100
        Number of background rows to draw. Capped at ``len(background)``
        because rows are drawn without replacement.
    seed : int, default 42
        Seed for the row selection, so repeated calls use the same rows.

    Returns
    -------
    The object returned by calling the ``sage.PermutationEstimator``.
    """
    # Draw from a locally seeded generator rather than the global NumPy RNG;
    # otherwise the background sample changes from call to call.
    random_state = np.random.RandomState(seed)
    n_rows = len(background)
    # Sampling without replacement cannot return more rows than exist.
    n_draw = min(n_background, n_rows)
    random_inds = random_state.choice(n_rows, size=n_draw, replace=False)

    # DataFrames expose their data through `.values`; anything else is
    # treated as (or converted to) a NumPy array.
    if hasattr(background, 'values'):
        background_array = background.values
    else:
        background_array = np.asarray(background)
    X_rand = background_array[random_inds, :]

    # Set up an imputer to handle missing (held-out) features.
    imputer = sage.MarginalImputer(model.predict_proba, X_rand)

    # Set up an estimator.
    estimator = sage.PermutationEstimator(imputer, 'cross entropy')

    print(np.shape(X))

    sage_values = estimator(X, y)

    return sage_values

In [4]:
def subsampler(X,y, size=100000):
    """Draw a reproducible random subsample of rows from ``X`` and ``y``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature table; rows are selected by position.
    y : numpy.ndarray, pandas.Series or sequence
        Targets aligned row-for-row with ``X``.
    size : int, default 100000
        Requested number of rows. When ``X`` has fewer rows, every row is
        returned (in shuffled order) instead of raising.

    Returns
    -------
    X_sub : pandas.DataFrame
        Selected rows with a fresh ``0..n-1`` index.
    y_sub : same kind as ``y`` (ndarray for non-pandas input)
        Targets for the selected rows, in the same order as ``X_sub``.
    """
    # Fixed seed: every call with the same inputs picks the same rows.
    random_state = np.random.RandomState(42)
    n_rows = len(X)
    # Sampling without replacement cannot return more rows than exist.
    n_draw = min(size, n_rows)
    inds = random_state.choice(n_rows, size=n_draw, replace=False)

    # Build a new frame rather than resetting the index of a slice in place.
    X_sub = X.iloc[inds].reset_index(drop=True)

    # `inds` are positions; index pandas objects positionally so a
    # non-default index on `y` cannot select the wrong rows.
    if hasattr(y, 'iloc'):
        y_sub = y.iloc[inds]
    else:
        y_sub = np.asarray(y)[inds]

    return X_sub, y_sub

In [5]:
# Lead-time tag that is interpolated into the training-dataset file name.
time = 'first_hour'

# Every (target, opt) pair is processed as a separate job.
targets = [
    'tornado',
    'severe_hail',
    'severe_wind',
    'road_surface',
]
opts = [
    'original',
    'reduced',
]

def worker(target, opt):
    """Compute SAGE values for one ``(target, opt)`` pair and write them to disk.

    The model and training data for the requested combination are loaded,
    the data are subsampled with ``subsampler``, SAGE values are computed
    with ``compute_sage`` (using the full feature table as the background
    pool) and the result is pickled to
    ``DATA_BASE_PATH/sage_results_{opt}_{target}.nc``. An e-mail is sent
    once the file has been written.

    Parameters
    ----------
    target : str
        ``'road_surface'`` selects the road-surface random forests; any
        other value selects the logistic-regression model for that target.
    opt : str
        ``'original'`` selects the full model; any other value selects the
        reduced-feature model.

    Returns
    -------
    None

    Notes
    -----
    NOTE(review): the output is written with ``pickle.dump`` even though
    the file name ends in ``.nc``. An existing file is overwritten; there
    is no skip-if-exists guard.
    """
    notifier = Emailer()
    began = notifier.get_start_time()

    results_path = os.path.join(DATA_BASE_PATH, f'sage_results_{opt}_{target}.nc')

    print(f'Running {target} {opt}...')
    if target != 'road_surface':
        # Reduced models carry a feature-selection tag in their file name;
        # the original model uses an empty tag.
        suffix = 'L1_based_feature_selection_with_manual' if opt != 'original' else ''
        frame = pd.read_pickle(
            os.path.join(DATA_BASE_PATH, f'{time}_training_matched_to_{target}_0km_dataset')
        )

        # The stored bundle holds both the estimator and its feature list.
        bundle_path = os.path.join(
            MODEL_BASE_PATH,
            f'LogisticRegression_first_hour_{target}_under_standard_{suffix}.pkl',
        )
        bundle = load(bundle_path)
        model = bundle['model']
        X = frame[bundle['features']].astype(float)
        y = frame[f'matched_to_{target}_0km'].astype(float).values
    else:
        frame = pd.read_csv('/work/mflora/ROAD_SURFACE/probsr_training_data.csv')
        if opt == 'original':
            # Original model: random forest wrapped with its isotonic calibrator.
            calibrator = load(os.path.join(MODEL_BASE_PATH, 'JTTI_ProbSR_RandomForest_Isotonic.pkl'))
            forest = load(os.path.join(MODEL_BASE_PATH, 'JTTI_ProbSR_RandomForest.pkl'))
            model = CalibratedClassifier(forest, calibrator)
            feature_names = PREDICTOR_COLUMNS
        else:
            # Reduced model: estimator and feature list are stored together.
            bundle = load(os.path.join(MODEL_BASE_PATH, 'RandomForest_manualfeatures_12.joblib'))
            model = bundle['model']
            feature_names = bundle['features']
        X = frame[feature_names].astype(float)
        y = frame[TARGET_COLUMN].astype(float).values

    # Calculate SAGE values on a subsample; the full X is the background pool.
    X_sub, y_sub = subsampler(X, y)
    sage_values = compute_sage(model, X_sub.values, y_sub, X)

    with open(results_path, 'wb') as handle:
        pickle.dump(sage_values, handle)

    notifier.send_email(f'SAGE for {target} {opt} is done', began)

In [6]:
#worker('road_surface', 'reduced')

In [7]:
# Fan the (target, opt) jobs out over up to 8 worker processes.
run_parallel(
    worker,
    args_iterator=to_iterator(targets, opts),
    nprocs_to_use=8,
    description='SAGE Compute',
)

SAGE Compute:   0%|                                                                                                                 | 0/8 [00:00<?, ?it/s]

Running tornado original...Running tornado reduced...Running severe_hail original...Running severe_hail reduced...Running severe_wind reduced...Running severe_wind original...Running road_surface reduced...Running road_surface original...









Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


(100000, 30)


Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


(100000, 14)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


(100000, 32)


  0%|          | 0/1 [00:00<?, ?it/s]

(100000, 25)


  0%|          | 0/1 [00:00<?, ?it/s]

(100000, 11)


  0%|          | 0/1 [00:00<?, ?it/s]

(100000, 113)


  0%|          | 0/1 [00:00<?, ?it/s]

(100000, 113)


  0%|          | 0/1 [00:00<?, ?it/s]

(100000, 113)


  0%|          | 0/1 [00:00<?, ?it/s]

SAGE Compute: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [1:47:29<00:00, 806.17s/it]


[None, None, None, None, None, None, None, None]