In [1]:
from SignalTemporalLogic.STLFactory import STLFactory
import matplotlib.pyplot as plt
plt.rc('font', size=12)
import copy
import pandas as pd
import seaborn as sns
import numpy as np
from collections import Counter
import warnings
import treelib
import re

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides pandas/seaborn deprecation and SettingWithCopy
# warnings; consider narrowing it with `category=`/`module=` — TODO confirm which
# warnings this was meant to suppress.
warnings.filterwarnings('ignore')
%matplotlib inline


In [2]:
#Helper Functions

def loadClientRules(popSize, dataFilename, maxConsecutiveMissing=100):
    """Load the rule files of ``popSize`` clients and pool their rules.

    Client files are named ``dataFilename + "<num>Rules.txt"`` with ``num`` counting
    up from 1. Numbers whose file cannot be loaded (see ``loadRuleSet``) are skipped,
    so the loaded clients need not be consecutive.

    Parameters
    ----------
    popSize : int
        Number of clients to load.
    dataFilename : str
        Path prefix prepended to each client's file name.
    maxConsecutiveMissing : int or None
        Give up after this many client numbers in a row fail to load. Without a
        bound the search never terminates when fewer than ``popSize`` files exist.
        ``None`` disables the bound.

    Returns
    -------
    (list, list)
        ``(clientTrees, clientRules)``: every loaded client's formula trees and
        their string forms, concatenated in client order.

    Raises
    ------
    FileNotFoundError
        If ``maxConsecutiveMissing`` consecutive client numbers fail to load before
        ``popSize`` clients have been found.
    """
    clientRules = []
    clientTrees = []
    num = 1
    clientsAdded = 0
    missingInARow = 0
    while clientsAdded < popSize:
        fileName = dataFilename + str(num) + "Rules.txt"
        fileFound, trees, rls = loadRuleSet(num, fileName)
        if fileFound:
            clientsAdded += 1
            missingInARow = 0
            clientTrees.extend(trees)
            clientRules.extend(rls)
        else:
            missingInARow += 1
            # Guard against spinning forever past the last existing client file
            if maxConsecutiveMissing is not None and missingInARow >= maxConsecutiveMissing:
                raise FileNotFoundError(
                    "Loaded only %d of %d clients; client files %d-%d under %r could not be loaded"
                    % (clientsAdded, popSize, num - missingInARow + 1, num, dataFilename))

        num += 1

    return clientTrees, clientRules


def loadRuleSet(num, textfile):
    """Parse one rule file into STL formula trees and their string forms.

    Each non-blank line of ``textfile`` is one formula. A line written as
    ``( ... )`` has that outer pair of parentheses removed before parsing.

    Parameters
    ----------
    num : int
        Client number; used only in diagnostic messages.
    textfile : str
        Path of the rule file to read.

    Returns
    -------
    (bool, list, list)
        ``(loaded, ruleTrees, ruleSet)``. ``loaded`` is False when the file could
        not be opened or a line failed to parse. ``ruleTrees`` holds the trees
        built by ``STLFactory.constructFormulaTree``, and ``ruleSet`` holds their
        ``toString()`` forms.
    """
    ruleSet = []
    ruleTrees = []

    # Only the open is treated as "file not found"; parse errors are reported separately below
    try:
        file = open(textfile, "r")
    except FileNotFoundError:
        print("File not found for Client %d" % (num) )
        return False, ruleTrees, ruleSet
    except OSError as err:
        print("Could not open rule file for Client %d: %s" % (num, err))
        return False, ruleTrees, ruleSet

    stlFac = STLFactory()
    with file:  # closes the file even if parsing raises
        for lineNo, line in enumerate(file, start=1):
            text = line.rstrip("\r\n")
            if not text.strip():
                continue  # blank lines carry no formula
            # Unwrap "( ... )"; unlike indexing line[-2], this also works on a final
            # line that has no trailing newline
            if text.startswith("(") and text.endswith(")"):
                text = text[1:-1]

            try:
                rule = stlFac.constructFormulaTree(text + "\n")
                rule.getFormulaNoParams()
            except Exception as err:
                print("Could not parse line %d of %s for Client %d: %s"
                      % (lineNo, textfile, num, err))
                return False, ruleTrees, ruleSet

            ruleTrees.append(rule)
            ruleSet.append(rule.toString())

    return True, ruleTrees, ruleSet
            


## Load Client Rules

In [3]:
# Experiment configuration: number of clients to load, and the path prefix of their
# rule files (each file is named "<prefix><client number>Rules.txt").
popSize = 10
dataFilename = "../Data/ICU/Best/"

In [4]:
# Parse every client's rule file: formula trees plus their string forms
clientTrees, clientRules = loadClientRules(popSize=popSize, dataFilename=dataFilename)

File not found for Client 8


In [5]:
# Tally how often each distinct rule string occurs across all loaded clients
countsByRule = dict(Counter(clientRules))
clientDF = pd.DataFrame.from_dict(countsByRule, orient='index').reset_index()
clientDF.columns = ["Rule", "Rule Count"]
# NOTE: this is a ratio, not a percentage. It can exceed 1 (e.g. 1.7) because a
# single client can contribute the same rule more than once.
clientDF['Percent of Population'] = clientDF['Rule Count'].div(popSize)

clientDF.sort_values(by="Rule Count", ascending=False)

Unnamed: 0,Rule,Rule Count,Percent of Population
79,"G[0,0]((n_evts <= 0.000 -> LOS >= 0.000))",17,1.7
69,"((MET >= 0.000) U[0,0] (death = 0.000))",13,1.3
62,"F[0,0]((BLOOD_UREA_NITROGEN <= 0.000 & CREATIN...",11,1.1
134,"((death >= 0.000) U[0,0] (y >= 0.000))",3,0.3
4,"((Glasgow_Coma_Scale_Total >= 0.000) U[0,0] (n...",3,0.3
...,...,...,...
194,"((n_edrk >= 0.000) U[0,0] (PROTIME_INR >= 0.000))",1,0.1
193,"((Glasgow_Coma_Scale_Total >= 0.000) U[0,0] (H...",1,0.1
192,"F[0,0]((PARTIAL_THROMBOPLASTIN_TIME >= 0.000 &...",1,0.1
191,"((ICU_Pt_Days >= 0.000) U[0,0] (PHOSPHORUS <= ...",1,0.1


## Load LDP Ruleset

In [6]:
# Load the rules produced by the MCTS baseline; each non-blank line is one STL formula
resultsFilename = "../Results/ICU_Ruleset_MCTS_Baseline.txt"
ldpTrees = []
ldpRules = []

stlFac = STLFactory()
# `with` guarantees the file is closed even if a formula fails to parse
with open(resultsFilename, "r") as file:
    for line in file:
        text = line.rstrip("\r\n")
        if not text.strip():
            continue  # skip blank lines (e.g. a trailing empty line)
        # Unwrap "( ... )"; unlike indexing line[-2], this also handles a final
        # line with no trailing newline
        if text.startswith("(") and text.endswith(")"):
            text = text[1:-1]

        rule = stlFac.constructFormulaTree(text + "\n")
        rule.getFormulaNoParams()

        ldpTrees.append(rule)
        ldpRules.append(rule.toString())

ldpRules

['ALT_GPT > 0.000',
 'CREATININE > 0.000',
 'ALKALINE_PHOSPHATASE > 0.000',
 'CHLORIDE > 0.000',
 'CO > 0.000',
 'Temp > 0.000',
 'F[0,0](SBP > 0.000)',
 'F[0,0](CREATININE > 0.000)',
 'PARTIAL_THROMBOPLASTIN_TIME > 0.000',
 'CALCIUM > 0.000',
 'F[0,0](srr > 0.000)',
 'death > 0.000',
 'F[0,0](cosen > 0.000)',
 'tte > 0.000',
 'Pulse > 0.000',
 'F[0,0](GLUCOSE > 0.000)',
 'F[0,0](dfa > 0.000)',
 'BLOOD_UREA_NITROGEN > 0.000',
 'F[0,0](MAGNESIUM > 0.000)',
 's_hr > 0.000',
 'F[0,0](ICU_Pt_Days > 0.000)',
 'F[0,0](CO > 0.000)',
 'SBP > 0.000',
 'AST_GOT > 0.000',
 'F[0,0](CHLORIDE > 0.000)',
 'F[0,0](SpO > 0.000)',
 'SpO > 0.000',
 'n_edrk > 0.000',
 'F[0,0](lds > 0.000)',
 'ICU_Pt_Days > 0.000',
 'ALBUMIN > 0.000',
 'TROPONIN_I > 0.000',
 'F[0,0](Temp > 0.000)',
 'F[0,0](POTASSIUM > 0.000)',
 'edrk > 0.000',
 'Mort > 0.000',
 'F[0,0](Mort > 0.000)',
 'F[0,0](Resp > 0.000)',
 'LOS > 0.000',
 'y > 0.000',
 'F[0,0](LOS > 0.000)',
 'WHITE_BLOOD_CELL_COUNT > 0.000',
 'F[0,0](O_Flow > 0.000)'

In [14]:
#Get count of the number of true structures matched in client rules

def getTemplateNodes(temp):
    """Return the identifiers of ``temp``'s nodes in depth-first order, with every
    decimal digit removed from each identifier.

    ``temp`` must provide treelib's ``expand_tree``; presumably it is a formula tree
    built by ``STLFactory``. Digits are stripped presumably so that nodes of the same
    kind compare equal regardless of a numeric id suffix. TODO confirm against
    STLFactory's node-id scheme.
    """
    depthFirstIds = temp.expand_tree(mode=treelib.Tree.DEPTH, sorting=True)
    return [re.sub('[0-9]', '', nodeId) for nodeId in depthFirstIds]

def findRuleMatch(template, clientTrees, verbose=True):
    """Return True if any client formula tree structurally matches ``template``.

    A client tree is a candidate only if it contains every variable of the template,
    as reported by ``getAllVars()``. A candidate matches when the template's
    digit-stripped, depth-first node list (``getTemplateNodes``) occurs as an ordered
    subsequence of the candidate's node list (``nodeListMatch``). Partial structural
    matches therefore count as matches.

    Parameters
    ----------
    template :
        Formula tree to look for (an LDP rule).
    clientTrees : iterable
        Client formula trees to search.
    verbose : bool
        Print a trace of variable and structural matches. The default (True) keeps
        the previous output.

    Returns
    -------
    bool
        True on the first matching client tree, otherwise False.
    """
    ldpNodes = getTemplateNodes(template)
    ldpVars = template.getAllVars()

    for c in clientTrees:
        clVars = c.getAllVars()
        # Every template variable must appear in the client rule
        if not all(v in clVars for v in ldpVars):
            continue

        if verbose:
            print("\nLDP vars", ldpVars)
            print("VAR MATCH", clVars)

        # Same digit-stripped depth-first listing that is used for the template
        clientNodes = getTemplateNodes(c)
        if nodeListMatch(ldpNodes, clientNodes):
            if verbose:
                print("Found match ldp rule", template.toString())
                print("client match", c.toString())
            return True

    return False

def nodeListMatch(tempList, cList):
    """Return True if ``tempList`` occurs in ``cList`` as an ordered subsequence.

    Strict and non-strict relational operators are treated as the same node
    ("LT" is compared as "LE", and "GT" as "GE"). Neither argument is modified.

    Parameters
    ----------
    tempList : list of str
        Template node labels, e.g. from ``getTemplateNodes``.
    cList : list of str
        Client node labels, in the same order convention.

    Returns
    -------
    bool
        True if every template node can be matched, in order, against a later
        client node. An empty template always matches.
    """
    relop = {"LT": "LE", "GT": "GE"}
    template = [relop.get(x, x) for x in tempList]
    remaining = iter([relop.get(x, x) for x in cList])
    # `node in remaining` consumes the iterator up to and including the first match,
    # so each template node must appear after the previous one was matched.
    return all(node in remaining for node in template)

def getCoverage(ldpTrees, clientTrees):
    """Count how many LDP rules are matched by at least one client rule.

    Each tree in ``ldpTrees`` is checked with ``findRuleMatch``. That check counts a
    partial structural match as a full match.

    Parameters
    ----------
    ldpTrees : iterable
        Formula trees from the LDP ruleset.
    clientTrees : iterable
        Client formula trees to search.

    Returns
    -------
    (int, int, float)
        ``(foundRules, nonRules, precision)``. ``precision`` is the fraction of LDP
        rules that were matched, or 0.0 when there are no LDP rules.
    """
    foundRules = 0
    nonRules = 0

    for ldpTree in ldpTrees:
        if findRuleMatch(ldpTree, clientTrees):
            foundRules += 1
        else:
            nonRules += 1

    total = foundRules + nonRules
    precision = foundRules / total if total else 0.0

    print("Found ", foundRules, " Rules")
    print("Found ", nonRules, "Non Rules")
    print("Precision ", precision)
    return foundRules, nonRules, precision
        


In [15]:
getCoverage(ldpTrees, clientTrees)


LDP vars ['ALT_GPT']
VAR MATCH ['ALT_GPT', 'y']

LDP vars ['CREATININE']
VAR MATCH ['CREATININE', 'srr']

LDP vars ['CREATININE']
VAR MATCH ['CREATININE', 'PARTIAL_THROMBOPLASTIN_TIME']
Found match ldp rule CREATININE > 0.000
client match F[0,0]((CREATININE >= 0.000 & PARTIAL_THROMBOPLASTIN_TIME >= 0.000))

LDP vars ['ALKALINE_PHOSPHATASE']
VAR MATCH ['ALKALINE_PHOSPHATASE', 'y']
Found match ldp rule ALKALINE_PHOSPHATASE > 0.000
client match G[0,0]((ALKALINE_PHOSPHATASE > 0.000 -> y = 0.000))

LDP vars ['CHLORIDE']
VAR MATCH ['CHLORIDE', 'SpO2']

LDP vars ['CHLORIDE']
VAR MATCH ['CHLORIDE', 's8_hr']
Found match ldp rule CHLORIDE > 0.000
client match G[0,0]((CHLORIDE >= 0.000 | s8_hr >= 0.000))

LDP vars ['Temp']
VAR MATCH ['Temp', 'y']
Found match ldp rule Temp > 0.000
client match ((Temp >= 0.000) U[0,0] (y >= 0.000))

LDP vars ['SBP']
VAR MATCH ['SBP', 'tte']

LDP vars ['SBP']
VAR MATCH ['n_edrk', 'SBP']

LDP vars ['SBP']
VAR MATCH ['direct', 'SBP']
Found match ldp rule F[0,0](SBP >

VAR MATCH ['Pulse', 's24_edrk']

LDP vars ['Pulse']
VAR MATCH ['Pulse', 's24_hr']

LDP vars ['Pulse']
VAR MATCH ['GLUCOSE', 'Pulse']

LDP vars ['Pulse']
VAR MATCH ['Pulse', 'SBP']

LDP vars ['Pulse']
VAR MATCH ['hr', 'Pulse']

LDP vars ['Pulse']
VAR MATCH ['hr', 'Pulse']
Found match ldp rule F[0,0](Pulse > 0.000)
client match F[0,0]((hr >= 0.000 & Pulse >= 0.000))

LDP vars ['Glasgow_Coma_Scale_Total']
VAR MATCH ['Glasgow_Coma_Scale_Total', 'n_evts']
Found match ldp rule Glasgow_Coma_Scale_Total > 0.000
client match ((Glasgow_Coma_Scale_Total >= 0.000) U[0,0] (n_evts >= 0.000))

LDP vars ['n_evts']
VAR MATCH ['n_evts', 'PROTIME_INR']
Found match ldp rule n_evts > 0.000
client match G[0,0]((n_evts >= 0.000 -> PROTIME_INR <= 0.000))

LDP vars ['AF']
VAR MATCH ['AF', 'Glasgow_Coma_Scale_Total']

LDP vars ['AF']
VAR MATCH ['AF', 'lds']

LDP vars ['AF']
VAR MATCH ['AF', 'y']
Found match ldp rule F[0,0](AF > 0.000)
client match F[0,0]((AF >= 0.000 & y = 0.000))

LDP vars ['hr']
VAR MATCH ['h

VAR MATCH ['n_edrk', 'SODIUM']

LDP vars ['SODIUM']
VAR MATCH ['SODIUM', 'Temp']

LDP vars ['SODIUM']
VAR MATCH ['GLUCOSE', 'SODIUM']

LDP vars ['SODIUM']
VAR MATCH ['O2_Flow', 'SODIUM']
Found match ldp rule SODIUM > 0.000
client match G[0,0]((O2_Flow >= 0.000 -> SODIUM >= 0.000))

LDP vars ['PHOSPHORUS']
VAR MATCH ['PHOSPHORUS', 'PLATELET_COUNT']

LDP vars ['PHOSPHORUS']
VAR MATCH ['HEMOGLOBIN', 'PHOSPHORUS']

LDP vars ['PHOSPHORUS']
VAR MATCH ['PHOSPHORUS', 'tte']

LDP vars ['PHOSPHORUS']
VAR MATCH ['n_evts', 'PHOSPHORUS']
Found match ldp rule PHOSPHORUS > 0.000
client match F[0,0]((n_evts >= 0.000 & PHOSPHORUS >= 0.000))

LDP vars ['ALKALINE_PHOSPHATASE']
VAR MATCH ['ALKALINE_PHOSPHATASE', 'y']

LDP vars ['lds']
VAR MATCH ['AF', 'lds']

LDP vars ['lds']
VAR MATCH ['lds', 'y']

LDP vars ['lds']
VAR MATCH ['lds', 'Pulse']
Found match ldp rule lds > 0.000
client match G[0,0]((lds >= 0.000 | Pulse <= 0.000))

LDP vars ['direct']
VAR MATCH ['direct', 'n_evts']

LDP vars ['direct']
VAR MA