# Extract rules and membership functions from TGFNN model

In [None]:
from generalized_fuzzy_net import GeneralizedFuzzyClassifier as TGFNN
from rule_extraction import draw_membership_function, extract_encoding_intervals, extract_relations
import re
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from sklearn.preprocessing import MinMaxScaler
from utils import cal_acc

def get_thresholds(rules, params, encoding_value_details):
    """Build a table of membership-function breakpoints for every continuous feature.

    Args:
        rules: Rule table returned by ``extract_rules_from_net``. It is not used
            by the computation and is kept only so existing callers keep working.
        params: Object exposing ``category_info`` (an entry of 0 marks a
            continuous feature) and ``feature_names`` (may be ``None``, in which
            case features are named ``x<index>`` as in ``extract_rules_from_net``).
        encoding_value_details: Array-like whose transpose has one row per
            continuous feature (in feature order) and four columns.

    Returns:
        pandas.DataFrame indexed by continuous feature name with columns
        ``low``, ``medium_left``, ``medium_right`` and ``high``.
    """
    feature_names = params.feature_names
    continuous_names = [
        f'x{index}' if feature_names is None else feature_names[index]
        for index, n_levels in enumerate(params.category_info)
        if n_levels == 0
    ]

    return pd.DataFrame(
        np.transpose(encoding_value_details),
        index=continuous_names,
        columns=['low', 'medium_left', 'medium_right', 'high'],
    )

def _rule_row_names(params):
    """Return concept row labels: all continuous features first, then categorical levels."""
    continuous, categorical = [], []
    for i, n_levels in enumerate(params.category_info):
        base = f'x{i}' if params.feature_names is None else params.feature_names[i]
        if n_levels == 0:
            continuous += [f'{base}_low', f'{base}_medium', f'{base}_high']
        else:
            categorical += [f'{base}_level{j}' for j in range(n_levels)]
    return continuous + categorical


def extract_rules_from_net(net, params, scaler):
    """Assemble a table describing every rule learned by a fuzzy network.

    Args:
        net: Trained network. It is passed to ``extract_encoding_intervals`` and
            ``extract_relations``, and its ``layer3.weight`` is read.
        params: Object exposing ``category_info`` (array supporting boolean
            masking; 0 marks a continuous feature, k > 0 a categorical feature
            with k levels), ``feature_names`` (or ``None``), ``binary_pos_only``,
            ``n_classes`` and ``n_rules``.
        scaler: Passed through to ``extract_encoding_intervals``.

    Returns:
        Tuple ``(rules_all, attention_mask, connection_mask,
        extract_encoding_details, relation_mat)``.

        ``rules_all`` is a DataFrame with one row per concept:
        ``<feature>_low/_medium/_high`` for continuous features, then
        ``<feature>_level<j>`` for categorical ones. These are followed by the
        output-layer row(s): ``direction`` when ``binary_pos_only``, otherwise
        one row per class.

        Its columns are ``encoding`` plus ``Rule_0 .. Rule_{n_rules-1}``. The
        other elements are passed through from the ``rule_extraction`` helpers.
    """
    row_names = _rule_row_names(params)

    ## Continuous variables: how each variable is encoded (membership breakpoints),
    # flattened column-major so they line up with the continuous row names.
    encoding_values, extract_encoding_details = extract_encoding_intervals(net, scaler)
    encoding_column_continuous = np.expand_dims(encoding_values.flatten('F'), axis=1)

    ## Categorical variables: the level index itself is the "encoding".
    category_levels = params.category_info[params.category_info > 0]
    encoding_column_categorical = np.expand_dims(
        np.array([level for n_levels in category_levels for level in range(n_levels)]),
        axis=-1,
    )
    encoding_column = np.concatenate(
        [encoding_column_continuous, encoding_column_categorical], axis=0)

    # Extract the attention mask and connection mask.
    # Author's note: each entry of the relation matrix is the product of the
    # corresponding entries of the attention mask and the connection mask.
    attention_mask, connection_mask, relation_mat = extract_relations(net, params)

    # Output-layer weights. They are squared before normalisation (the original
    # author questioned why); after squaring, every value is non-negative.
    out_layer = net.layer3.weight.detach().cpu().numpy() ** 2
    weighted_out_layer = out_layer / np.max(np.abs(out_layer))

    # One table row per output row of the weight matrix. A 1-D weight becomes a
    # single row. NOTE(review): a 2-D weight is assumed to be (outputs, n_rules),
    # as in torch.nn.Linear; confirm for the multi-class network.
    out_rows = np.atleast_2d(weighted_out_layer)
    # The 'encoding' column has no meaning for output rows, so it is set to NaN.
    out_rows = np.concatenate(
        [np.full((out_rows.shape[0], 1), np.nan), out_rows], axis=1)

    if params.binary_pos_only:
        row_names.append('direction')
    else:
        row_names.extend(f'{i}' for i in range(params.n_classes))

    rules = np.concatenate([encoding_column, relation_mat], axis=-1)
    rules = np.concatenate([rules, out_rows], axis=0)

    all_column_names = ['encoding'] + [f'Rule_{i}' for i in range(params.n_rules)]

    # Table with all rules extracted from the network
    rules_all = pd.DataFrame(rules, columns=all_column_names, index=row_names)

    return rules_all, attention_mask, connection_mask, extract_encoding_details, relation_mat

#### Load CV models

In [None]:
# params
random_state = 0
dir_name = 'experiment'

# Load the stored dataset split and the cross-validation models.
# NOTE: pickle can execute arbitrary code while loading; only open files this
# project produced itself.
with open(f'cv_results/{dir_name}/dataset.pkl', 'rb') as fh:
    dataset = pickle.load(fh)
with open('cv_output/experiment_60_all_GFN_cv_models.pkl', 'rb') as fh:
    models = pickle.load(fh)

### Evaluate

In [None]:
# Evaluate every cross-validation model on the stored train/test split and
# collect one row per (model, split) into a single `metrics` table.
all_metrics = []

# `i` is deliberately kept as the loop variable: later cells may read it.
for i, model in enumerate(models):
    # evaluate (the last argument flags a multi-class model)
    _, train_metrics, metric_names, _, _, fpr_train, tpr_train = cal_acc(
        model, dataset['X_train'], dataset['y_train'], model.n_classes > 2)
    _, test_metrics, _, _, _, fpr_test, tpr_test = cal_acc(
        model, dataset['X_test'], dataset['y_test'], model.n_classes > 2)

    # Look the AUC up by its column name instead of a hard-coded position.
    auc_index = list(metric_names).index('auc')
    roc_values = {'fpr_test': fpr_test, 'tpr_test': tpr_test, 'auc_test': test_metrics[auc_index],
                  'fpr_train': fpr_train, 'tpr_train': tpr_train, 'auc_train': train_metrics[auc_index]}

    # compile metrics
    metrics = pd.DataFrame([train_metrics, test_metrics],
                           columns=metric_names,
                           index=['Train', 'Test'])
    metrics = metrics.reset_index(names='Set')
    metrics['model_index'] = i

    all_metrics.append(metrics)

metrics = pd.concat(all_metrics, axis=0, ignore_index=True)

In [None]:
# Show every (model, split) row ranked by AUC, best first.
metrics.sort_values('auc', ascending=False)

#### Extract rules

In [None]:
def sort_index(x):
    """Return a copy of ``x`` with its rows grouped by concept type.

    Each row label is tested (case-sensitive substring match, first key in
    ``concept_types`` order wins) against the keys below and assigned a group
    letter. Rows are ordered by group letter; rows matching no key go last.
    Within a group the original row order is kept. The input DataFrame is not
    modified.
    """
    concept_types = {
        'pheno' : 'c',
        'blood' : 'b',
        'creatinine' : 'b',
        'CO2' : 'b',
        'BMI' : 'a',
        'pulse' : 'a',
        'height' : 'a',
        'temperature' : 'a',
        'calcium' : 'b',
    }

    def _group_of(concept):
        # First key contained in the label decides the group; None when nothing matches.
        for key, group in concept_types.items():
            if key in concept:
                return group
        return None

    groups = [_group_of(concept) for concept in x.index]

    # Sort key is (unmatched last, group letter, original position). Using the
    # integer position rather than a string such as 'a10' keeps row 2 ahead of
    # row 10 within a group.
    order = sorted(
        range(len(groups)),
        key=lambda pos: (groups[pos] is None, groups[pos] or '', pos),
    )
    return x.iloc[order]

def pub_fig_adjust_index(concepts, concept_thresholds):
    '''
    Use this function to adjust the names of concepts in the output rule table to make them more interpretable.

    Each concept label is split into ``<name>_<quantity>`` on its last underscore.
    - Labels whose first character is a digit are treated as phenotypes.
    - ``low``/``medium``/``high`` quantities are rendered with threshold values
      taken from ``concept_thresholds``, plus the unit when one is known.
    - ``level0``/``level1`` quantities become "No <name>" / "<name>".
    - ``direction`` (or ``directions``) and ``weight_*`` labels are passed
      through unchanged.

    Args:
        concepts: Iterable of row labels from the rule table.
        concept_thresholds: DataFrame indexed by continuous feature name with
            columns ``low``, ``medium_left``, ``medium_right`` and ``high``
            (as produced by ``get_thresholds``).

    Returns:
        List of display labels, one per input concept, in input order.

    Raises:
        KeyError: if a continuous concept's feature name is missing from
            ``concept_thresholds``.
    '''
    # Units appended to threshold values, keyed by the title-cased concept name.
    # Concepts not listed here are shown without a unit.
    units = {
        'Mean Creatinine' : 'mg/dL',
        'Min Creatinine' : 'mg/dL',
        'Std Co2' : 'mmol/L',
        'Mean Calcium' : 'mg/dL',
        'Std Bmi' : 'kg/m$^2$',
        'Min Bmi' : 'kg/m$^2$',
        'Mean Bmi' : 'kg/m$^2$',
        'Latest Pulse' : 'bpm',
        'Std Pulse' : 'bpm',
        'Mean Pulse' : 'bpm',
        'Min Height Cm' : 'cm',
        'Mean Temp' : 'F',
        'Std Age' : 'years',
        'Mean Bp Dia' : 'mmHg',
        'Mean Bp Sys' : 'mmHg',
        'Latest Bp Sys' : 'mmHg',
    }

    # Cosmetic substitutions, applied in insertion order. Order matters because
    # a later key can match text produced by an earlier replacement.
    swap = {
        'Mmhg' : 'mmHg',
        'Cvp' : 'CVP',
        'Map' : 'MAP',
        'Hr' : 'hr',
        'Fluid Based' : '(fluid based)',
        'Arterial Line' : '(arterial line)',
        'Mpv' : 'MPV',
        'Bmi' : 'BMI',
        'Std' : 'Std dev',
        'Bp Dia' : 'diastolic blood pressure',
        'Bp Sys' : 'systolic blood pressure',
        'Creatinine' : 'creatinine',
        'Calcium' : 'calcium',
        'Temp' : 'temperature',
        'Height Cm' : 'height',
        ' Mg' : 'mg',
        'Lv' : 'Lab/vital',
        'Dxrx' : 'Dx/Rx',
        'Age' : 'age',
        'Pulse' : 'pulse',
        'Co2' : 'CO2',
        'Tablet' : 'tablet',
        'History' : 'History of',
        ' Rx ' : ' ',
        'Family Cardiac Hx' : 'family history of cardiac disease',
        'Latest Smoking Status Never' : 'latest smoking status of \'never\'',
    }

    new_concepts = []

    for concept in concepts:
        # Output-layer rows ('direction', or the legacy spelling 'directions')
        # and weight rows are passed through untouched.
        if concept in ('direction', 'directions') or concept.startswith('weight_'):
            new_concepts.append(concept)
            continue

        phenotype = concept[:1].isdigit()

        # Thresholds are looked up with the raw feature name (the label minus
        # its final '_<quantity>' part) before any cosmetic cleanup.
        thresholds = None
        if not phenotype and '_level' not in concept:
            thresholds = concept_thresholds.loc['_'.join(concept.split('_')[:-1]), :]

        # remove bracketed source file name
        concept = re.sub(r'\[.*\]', '', concept)

        # split off the quantity and make the name more legible
        parts = concept.split('_')
        quantity = parts[-1]
        concept = ' '.join(parts[:-1]).title()

        unit = units.get(concept)
        unit_suffix = f' {unit}' if unit else ''

        if quantity == 'level0':
            concept = 'No ' + concept
        elif phenotype:
            if ' ' in concept:
                phe_number, phe_type = concept.split(' ')[:2]
            else:
                phe_number, phe_type = concept, 'Dx/Rx'
            concept = phe_type + ' phenotype ' + phe_number + ' is ' + quantity
        elif quantity == 'level1':
            pass  # keep the bare concept name
        elif quantity == 'low':
            # Midpoint between the 'low' and 'medium_left' breakpoints.
            t = str(round((thresholds['medium_left'] + thresholds['low']) / 2, 3))
            concept = f'{concept} is low (<{t}{unit_suffix})'
        elif quantity == 'medium':
            t1 = str(round(thresholds['medium_left'], 3))
            t2 = str(round(thresholds['medium_right'], 3))
            concept = f'{concept} is medium ({t1} - {t2}{unit_suffix})'
        elif quantity == 'high':
            # Midpoint between the 'medium_right' and 'high' breakpoints.
            t = str(round((thresholds['medium_right'] + thresholds['high']) / 2, 3))
            concept = f'{concept} is high (>{t}{unit_suffix})'
        else:
            concept = concept + ' is ' + quantity

        # specific adjustments
        for old, new in swap.items():
            concept = concept.replace(old, new)

        new_concepts.append(concept)

    return new_concepts

save = False

# Index of the cross-validation model whose rules are visualised. It is also
# used in the output file names, so they always match the model plotted.
model_index = 2
model = models[model_index]


class Params(object):
    # Minimal parameter container mirroring the selected model's configuration,
    # in the shape expected by extract_rules_from_net / get_thresholds.
    binary_pos_only = model.binary_pos_only
    n_rules = model.n_rules
    n_classes = model.n_classes
    category_info = dataset['category_info']
    epsilon1 = model.min_epsilon1
    epsilon2 = model.min_epsilon2
    epsilon3 = model.min_epsilon3
    feature_names = dataset.get('feature_names')


params = Params()

# Extract rule data from a trained model
rules_all, attention_mask, connection_mask, encoding_value_details, relation_mat = extract_rules_from_net(
    model.estimator,
    params,
    model.scaler)

# Membership-function breakpoints for each continuous feature
pd.options.display.float_format = '{:.6f}'.format
encoding = get_thresholds(rules_all, params, encoding_value_details)
if save:
    encoding.to_csv(f'cv_results/{dir_name}/{model_index}_encoding.csv')

# Keep only the rule columns
rules_all = rules_all.drop(columns=['encoding'])

# Min-max scale all concept weights jointly (flattened to a single feature),
# leaving the last (output-layer) row untouched.
scaler = MinMaxScaler()
concept_block = rules_all.iloc[:-1, :]
rules_all.iloc[:-1, :] = scaler.fit_transform(
    concept_block.values.reshape(-1, 1)).reshape(concept_block.shape)

# Cut-offs for what is shown in the heatmap
MIN_RULE_IMPORTANCE = 0.1
MIN_CONCEPT_IMPORTANCE = 0.1

# filter rules and concepts
x = rules_all.copy()
x = x[x.loc['direction', :].sort_values(ascending=False).index.tolist()]  # sort rules by importance
x.columns = [f'R{rank}' for rank in range(x.shape[1])]  # rename rules by importance rank
rule_importance = x.loc['direction', :]
x.columns = (rule_importance.index + '\n'
             + rule_importance.astype(float).round(3).astype(str)).values.tolist()  # add rule importance to column names
x = x.loc[:, x.loc['direction', :] >= MIN_RULE_IMPORTANCE]  # remove rules with low importance
x = x.iloc[:-1, :]  # remove the output-layer ('direction') row
x = x.astype(float)
x = x.loc[x.max(axis=1) >= MIN_CONCEPT_IMPORTANCE, :]  # remove concepts without any importance
x.index = pub_fig_adjust_index(x.index.tolist(), encoding)  # make index more legible
x = sort_index(x)

# plot
plt.clf()
sns.set(font_scale=1.3)
plt.figure(figsize=(17, 15))
g = sns.heatmap(x, cmap='Reds', linewidths=0.5, linecolor='white', cbar_kws={'shrink': 0.5})
g.figure.tight_layout()
g.set_xticklabels(g.get_xmajorticklabels(), fontweight='bold')

if save:
    plt.savefig(f'cv_results/{dir_name}/fold_{model_index}_rules.png', dpi=300, bbox_inches='tight')
else:
    plt.show()

In [None]:
# Display the membership-function breakpoints table (columns low,
# medium_left, medium_right, high) built by get_thresholds in the cell above.
encoding

### Plot membership functions

In [None]:

# Draw the membership functions of the continuous features.
# NOTE(review): `params`, `encoding` and `rules_all` come from the rule-extraction
# cell above and describe the model selected there; keep `i` in sync with it.
i = 2
model = models[i]
save_prefix = f'cv_results/{dir_name}/model_{i}_membership_functions/'
os.makedirs(save_prefix, exist_ok=True)

# Continuous variables that have a '<name>_low' row in the rule table. Only the
# trailing '_low' is stripped, so names containing '_low' elsewhere
# (e.g. 'x_lower_low' -> 'x_lower') are recovered intact.
rule_var_name = {name[:-len('_low')] for name in rules_all.index if name.endswith('_low')}

# Continuous feature names in feature order (row order of `encoding`), falling
# back to 'x<index>' like extract_rules_from_net when no names are available.
continous_variable_name = [
    f'x{index}' if params.feature_names is None else params.feature_names[index]
    for index, n_levels in enumerate(params.category_info)
    if n_levels == 0
]

for index, var_name in enumerate(continous_variable_name):
    if var_name not in rule_var_name:
        continue
    breakpoints = encoding.iloc[index]
    draw_membership_function(
        breakpoints['low'],
        breakpoints['medium_left'],
        breakpoints['medium_right'],
        breakpoints['high'],
        output_path=save_prefix,
        n_points=1000,
        epsilon=params.epsilon1,
        # '/' would act as a path separator in the output file name.
        variable_name=var_name.replace('/', 'backslash'),
    )