In [None]:
from generalized_fuzzy_net import GeneralizedFuzzyClassifier as TGFNN
from rule_extraction import extract_rules_from_net
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from load_dataset import load_data
import pandas as pd
import numpy as np
import re
from sklearn.preprocessing import MinMaxScaler
import icd9cms.icd9 as icd9

In [None]:
def adjust_index(concepts, concept_thresholds):
    '''
    Rewrite the raw concept names of the rule table into more readable labels.

    Each name is expected to look like "<feature>_<quantity>". Underscores in the
    feature part become spaces, ICD codes written as "... (<code>)" are replaced by
    their description, and the quantity suffix is turned into wording
    ("level0" -> "No <feature>", "level1" -> "<feature>", anything else ->
    "<feature> is <quantity>"). 'directions' and names starting with 'weight_' are
    passed through untouched.

    Parameters
    ----------
    concepts : iterable of str
        Raw row labels of the rule table.
    concept_thresholds : any
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    list of str
        One label per input concept, in the same order.
    '''
    # Specific textual adjustments, applied in insertion order. Order matters:
    # 'No ETHNICITY ASIAN' must be replaced before the shorter 'ETHNICITY ASIAN'.
    swap = {
        'Mmhg' : 'mmHg',
        'Cvp' : 'CVP',
        'Map' : 'MAP',
        'Hr' : 'hr',
        'Fluid Based' : '(fluid based)',
        'Arterial Line' : '(arterial line)',
        'Mpv' : 'MPV',
        'AGE' : 'Age',
        'RDW' : 'Red Cell Distribution Width',
        'No ETHNICITY ASIAN' : 'Not of Asian ethnicity',
        'ETHNICITY ASIAN' : 'of Asian ethnicity',
        '(100 units/ml)' : '',
        'RX: ' : '',
        'Antineoplastic and immunosuppressive drugs causing adverse effects in therapeutic use' : 'Adverse effects of antineoplastic/immunosuppressive drugs',
        'Bacterial infection in conditions classified elsewhere and of unspecified site' : 'Other Bacterial Infection',
        'Persistent mental disorders due to conditions classified elsewhere' : 'Other persistent mental disorders',
        'Secondary malignant neoplasm of respiratory and digestive systems' : 'Secondary malignant neoplasm of respiratory/digestive systems'
    }

    new_concepts = []

    for concept in concepts:

        if concept == 'directions' or concept.startswith('weight_'):
            new_concepts.append(concept)
            continue

        # make quantity more legible: last '_' chunk is the quantity, the rest is the name
        *name_parts, quantity = concept.split('_')
        concept = ' '.join(name_parts)

        # process ICD codes (names ending in "(<code>)", prescriptions excluded)
        if concept.endswith(')') and 'RX:' not in concept:
            code = concept.split('(')[-1].split(')')[0]
            node = icd9.search(code)
            # The parenthesised text is not always a known ICD-9 code (e.g. a unit);
            # keep the original wording when the lookup yields nothing usable.
            desc = (node.long_desc or node.short_desc) if node is not None else None
            if desc:
                concept = '"' + desc + '" (' + code + ')'

        if quantity == 'level0':
            concept = 'No ' + concept
        elif quantity != 'level1':
            concept = concept + ' is ' + quantity

        for key, replacement in swap.items():
            concept = concept.replace(key, replacement)

        new_concepts.append(concept)

    return new_concepts


def del_vars(x, category_info, threshhold):
    '''
    Drop every variable whose rows never reach `threshhold` in any column.

    The rows of `x` are read positionally: first three rows per continuous
    variable (one per entry of `category_info` equal to 0), then two rows per
    categorical variable (one per entry greater than 0). Any rows after that are
    kept as they are.

    Parameters
    ----------
    x : pandas.DataFrame
        Concept-by-rule table; not modified.
    category_info : array-like
        One entry per variable; 0 marks a continuous variable, >0 a categorical one.
    threshhold : float
        A variable is removed when the maximum over all of its rows is below this.

    Returns
    -------
    pandas.DataFrame
        The surviving rows of `x`, in their original order.
    '''
    category_info = np.asarray(category_info)
    n_cont = int((category_info == 0).sum())
    n_cat = int((category_info > 0).sum())

    # Positional (start, stop) row span of each variable, computed against the
    # ORIGINAL frame so that removing one variable cannot shift the next one.
    cont_end = n_cont * 3
    spans = [(i * 3, (i + 1) * 3) for i in range(n_cont)]
    spans += [(cont_end + i * 2, cont_end + (i + 1) * 2) for i in range(n_cat)]

    keep = np.ones(len(x), dtype=bool)
    for start, stop in spans:
        peak = x.iloc[start:stop, :].max(axis=1).max()
        # an empty span yields NaN, which compares False and is therefore kept
        if peak < threshhold:
            keep[start:stop] = False

    return x.iloc[keep]


def get_thresholds(rules, params, encoding_value_details):
    '''
    Tabulate the encoding values of every continuous feature.

    Parameters
    ----------
    rules : any
        Accepted for call compatibility; not used.
    params : object
        Must expose `feature_names` and `category_info` (parallel sequences);
        features whose `category_info` entry equals 0 are treated as continuous.
    encoding_value_details : array-like
        Transposed before use, so it is expected to hold four rows
        (low, medium_left, medium_right, high) and one column per continuous feature.

    Returns
    -------
    pandas.DataFrame
        One row per continuous feature, columns
        ['low', 'medium_left', 'medium_right', 'high'].
    '''
    continuous_variable_names = [
        name
        for name, info in zip(params.feature_names, params.category_info)
        if info == 0
    ]

    return pd.DataFrame(
        np.transpose(encoding_value_details),
        index=continuous_variable_names,
        columns=['low', 'medium_left', 'medium_right', 'high'],
    )

In [None]:
# Location of the experiment artefacts to inspect.
dir = 'cv_results'  # NOTE: shadows the builtin `dir`; name kept because the figure-saving cell reads it
exp = 'run-11_newcontraregu'
n_repeats = 1  # used as the filename prefix of the trained-model pickle

# NOTE(review): unpickling can execute arbitrary code - only load files you produced or trust.
# Context managers close both files as soon as they are read.
with open(f'{dir}/{exp}/dataset.pkl', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

with open(f'{dir}/{exp}/{n_repeats}_trained_GFN.mdl', 'rb') as model_file:
    # the pickle holds a sequence; its first element is the model used below
    model = pickle.load(model_file)[0]

In [None]:
class Params(object):
    '''
    Settings bundle passed to extract_rules_from_net, copied from the loaded model and dataset.
    '''
    # All values are read once, when this class body executes, so `model` and
    # `dataset` must already be loaded.
    binary_pos_only = model.binary_pos_only
    n_rules = model.n_rules
    n_classes = model.n_classes
    category_info = dataset['category_info']
    # the model's min_epsilon* attributes are exposed under the shorter epsilon* names
    epsilon1 = model.min_epsilon1
    epsilon2 = model.min_epsilon2
    epsilon3 = model.min_epsilon3
    # NOTE(review): .get() rather than [] - presumably 'feature_names' can be absent
    # from some dataset pickles; confirm, since get_thresholds indexes into it.
    feature_names = dataset.get('feature_names')
    
params = Params()

# Extract rule data from a trained model
# Five values come back; the third and fourth are discarded here.
# NOTE(review): the meaning of keep_irrelevant_variable / filter_similarity_threshold
# is defined in rule_extraction and not visible in this notebook - confirm there.
rules, rules_all, _, _, encoding_value_details = extract_rules_from_net(
    model.estimator, 
    params, 
    model.scaler, 
    keep_irrelevant_variable=True,
    filter_similarity_threshold=1)

In [None]:
# Render floats with two decimals in displayed tables
pd.set_option('display.float_format', '{:.2f}'.format)

# Threshold table for the continuous features; shown as the cell output
encoding = get_thresholds(
    rules=rules,
    params=params,
    encoding_value_details=encoding_value_details,
)
encoding

In [None]:
# Concept-by-rule matrix: drop the threshold column and order the rules by their
# positive-class contribution (largest first).
rule_matrix = rules_all.drop(columns=['encoding'])
contribution = rule_matrix.loc['pos_class_contribution', :].sort_values(ascending=False)
rule_matrix = rule_matrix[contribution.index.tolist()]

# Column label = rank-based rule name on the first line, contribution value on the second.
rule_matrix.columns = [
    f'R{rank}\n{value}'
    for rank, value in enumerate(contribution.astype(str), start=1)
]

# The contribution row is metadata, not a concept: remove it by label rather than
# by assuming it is the last row.
rule_matrix = rule_matrix.drop(index='pos_class_contribution')

# remove variables with max value less than 0.1
rule_matrix = del_vars(rule_matrix, dataset['category_info'], threshhold=0.1)

# make index more readable
rule_matrix.index = adjust_index(rule_matrix.index.tolist(), encoding)

# Explicit figure/axes: avoids the stray empty figure that plt.clf() creates on a
# fresh kernel and keeps every call tied to the figure that is saved.
sns.set(font_scale=1.5)
fig, ax = plt.subplots(figsize=(23, 17))
sns.heatmap(
    rule_matrix,
    cmap='Reds',
    linewidths=0.75,
    linecolor='white',
    cbar_kws={'shrink': 0.5},
    ax=ax,
)
fig.tight_layout()
ax.set_xticklabels(ax.get_xmajorticklabels(), fontweight='bold')
fig.savefig(f'{dir}/{exp}/figure_rules_all_rules.png', dpi=200)

In [None]:
# Optional display unit per feature name, appended after the last value of a clause.
units = {}

# write rules
def write(k, v, rules):
    '''
    Turn the active levels of one feature into a readable clause.

    Parameters
    ----------
    k : str
        Feature name.
    v : list of str
        Active levels of the feature ('low' / 'medium' / 'high' for continuous
        features, 'level0' / 'level1' for categorical ones). Order is ignored.
    rules : pandas.DataFrame
        Table with 'feature', 'concept' and 'encoding' columns; the encoding of
        the (feature, concept) row supplies the numeric bound.

    Returns
    -------
    str
        The clause. Level combinations without a dedicated wording (for example
        all three continuous levels) fall back to "<k> is <a> or <b> ..." instead
        of returning None.

    Raises
    ------
    IndexError
        If `rules` has no row for a (feature, concept) pair that is needed.
    '''
    def encoding_of(concept):
        # first encoding value recorded for this feature at the given level
        rows = rules[(rules['feature'] == k) & (rules['concept'] == concept)]
        return rows['encoding'].values[0]

    unit = units.get(k, "")
    levels = frozenset(v)

    # low < encoding
    if levels == {'low'}:
        return f'{k} < {encoding_of("low")}{unit}'

    # high > encoding
    if levels == {'high'}:
        return f'{k} > {encoding_of("high")}{unit}'

    # medium > low and < high
    if levels == {'medium'}:
        return f'{k} > {encoding_of("low")} and < {encoding_of("high")}{unit}'

    # low and high < low and > high
    if levels == {'low', 'high'}:
        return f'{k} < {encoding_of("low")} or > {encoding_of("high")}{unit}'

    # low and medium < medium
    if levels == {'low', 'medium'}:
        return f'{k} < {encoding_of("medium")}{unit}'

    # high and medium > medium
    if levels == {'medium', 'high'}:
        return f'{k} > {encoding_of("medium")}{unit}'

    if levels == {'level0'}:
        return f'{k} is absent'

    if levels == {'level1'}:
        return f'{k} is present'

    # No dedicated wording: list the levels so the caller always gets a string.
    return f"{k} is {' or '.join(map(str, v))}"

# Minimum rule weight for a concept to be written into the rule text.
ACTIVATION_THRESHOLD = 0.1

# prep data: order the rule columns by positive-class contribution, drop that
# metadata row, and split each "<feature>_<level>" label into two columns
ordered_columns = rules_all.loc['pos_class_contribution', :].sort_values(ascending=False).index.tolist()
rules = rules_all[ordered_columns].drop('pos_class_contribution').reset_index(names='concept')
concepts = pd.DataFrame(
    [label.rsplit('_', 1) for label in rules['concept'].tolist()],
    columns=['feature', 'concept'],
)
rules = pd.concat([concepts, rules.drop(columns=['concept'])], axis=1)

# process rule: one printed line per rule column
for rule in rules.drop(['feature', 'concept', 'encoding'], axis=1).columns:
    active = rules.loc[rules[rule] > ACTIVATION_THRESHOLD, ['feature', 'concept', 'encoding', rule]]
    levels_by_feature = active.groupby('feature')['concept'].apply(list).to_dict()

    clauses = []
    for feature, levels in levels_by_feature.items():
        clause = write(feature, levels, rules)
        if clause is None:
            # write() has no wording for this level combination; list the levels
            # instead of failing on the string join
            clause = f"{feature} is {' or '.join(map(str, levels))}"
        clauses.append(clause)

    written_rule = ' and '.join(clauses)
    print(f'{rule}: {written_rule}')
