In [18]:
from SignalTemporalLogic.STLFactory import STLFactory
import matplotlib.pyplot as plt
plt.rc('font', size=12)
import copy
import pandas as pd
import seaborn as sns
import numpy as np
from collections import Counter
import warnings
import treelib
import re
import logger

# NOTE(review): 'ignore' with no category silences every warning for the whole
# kernel session (library deprecations included); consider narrowing it.
warnings.filterwarnings('ignore')
%matplotlib inline


In [63]:
#Helper Functions

def loadClientRules(popSize, dataFilename, maxConsecutiveMisses=50):
    """Load rule trees and rule strings for `popSize` clients.

    Client files are named ``<dataFilename><num>Rules.txt`` for num = 1, 2, 3, ...
    Numbers whose file cannot be loaded are skipped, so numbering gaps are
    tolerated. Loading continues until `popSize` clients have been read.

    Parameters
    ----------
    popSize : int
        Number of client rule files to load.
    dataFilename : str
        Path prefix; the client number and "Rules.txt" are appended to it.
    maxConsecutiveMisses : int or None, optional
        Give up after this many unloadable client numbers in a row. Without a
        bound the loop never ends when fewer than `popSize` files exist.
        Pass None to restore the unbounded search.

    Returns
    -------
    (clientTrees, clientRules)
        Flat lists concatenated over all loaded clients. Which client a rule
        came from is not kept.

    Raises
    ------
    RuntimeError
        If `maxConsecutiveMisses` client numbers in a row could not be loaded
        before `popSize` clients were found.
    """
    clientRules = []
    clientTrees = []
    num = 1
    clientsAdded = 0
    consecutiveMisses = 0
    while clientsAdded < popSize:
        fileName = dataFilename + repr(num) + "Rules.txt"
        fileFound, trees, rls = loadRuleSet(num, fileName)
        if fileFound:
            clientsAdded += 1
            consecutiveMisses = 0
            clientTrees.extend(trees)
            clientRules.extend(rls)
        else:
            consecutiveMisses += 1
            if maxConsecutiveMisses is not None and consecutiveMisses >= maxConsecutiveMisses:
                raise RuntimeError(
                    "Loaded only %d of %d clients from %r: client numbers %d-%d could not be loaded"
                    % (clientsAdded, popSize, dataFilename, num - consecutiveMisses + 1, num)
                )

        num += 1
    return clientTrees, clientRules


def loadRuleSet(num, textfile):
    """Parse one client's rules file into STL formula trees and rule strings.

    Every line of `textfile` is parsed with STLFactory.constructFormulaTree.
    If a line starts with "(" and its last character before the newline is
    ")", those two characters are removed first.

    Parameters
    ----------
    num : int
        Client number; only used in the printed messages.
    textfile : str
        Path of the rules file.

    Returns
    -------
    (loaded, ruleTrees, ruleSet) : (bool, list, list)
        `loaded` is False when the file is missing or a line fails to parse.
        loadClientRules skips the client in that case.
        `ruleSet` holds each rule's toString() with ">=" / "<=" rewritten to
        ">" / "<", so strict and non-strict variants count as the same string.
    """
    ruleSet = []
    ruleTrees = []
    try:
        # `with` closes the file even when parsing fails part-way through.
        with open(textfile, "r") as file:
            stlFac = STLFactory()
            for line in file:
                # Strip one wrapping pair of parentheses.
                # NOTE(review): assumes the line ends in "\n" and that the first
                # "(" matches the last ")"; a line like "(a) U (b)" is mangled.
                if line[0] == "(" and line[-2] == ")":
                    line = line[1:-2] + "\n"

                rule = stlFac.constructFormulaTree(line)
                # NOTE(review): return value unused; presumably called for its
                # effect on `rule` -- confirm.
                rule.getFormulaNoParams()

                ruleTrees.append(rule)

                # fix relop for string rule
                strRl = rule.toString()
                strRl = re.sub('>=', '>', strRl)
                strRl = re.sub('<=', '<', strRl)
                ruleSet.append(strRl)
    except FileNotFoundError:
        print("File not found for Client %d" % (num) )
        return False, ruleTrees, ruleSet
    except Exception as err:
        # A bare `except:` used to report parse errors (and even Ctrl-C) as
        # "File not found". Still skip the client, but say what went wrong.
        print("Failed to load rules for Client %d from %s: %r" % (num, textfile, err))
        return False, ruleTrees, ruleSet

    return True, ruleTrees, ruleSet
            


## Load Client Rules

In [64]:
# Configuration: which dataset's client rule files to load, and how many clients.
DATASET = "ICU"   # or "Sepsis"; resultsFilename in the LDP cell is switched separately
popSize = 100     # number of client rule files to load
dataFilename = "../Data/" + DATASET + "/Best/"

In [65]:
# Read every client's parsed rule trees (used for structural matching) and
# relop-normalized rule strings (used for counting).
clientTrees, clientRules = loadClientRules(popSize=popSize, dataFilename=dataFilename)

File not found for Client 8
File not found for Client 81


In [66]:
# Tabulate how often each normalized rule string occurs across all clients.
# NOTE(review): clientRules concatenates every client's rules, so "Rule Count"
# counts occurrences, not clients. "Percent of Population" (= count / popSize)
# is therefore a ratio, not a percent, and can exceed 1 (230 -> 2.30 below).
# Counting clients-with-rule would need per-client rule lists.
ruleCounts = Counter(clientRules)
clientDF = pd.DataFrame.from_dict(dict(ruleCounts), orient="index").reset_index()
clientDF.columns = ["Rule", "Rule Count"]
clientDF["Percent of Population"] = clientDF["Rule Count"].div(popSize)

clientDF.sort_values("Rule Count", ascending=False)

Unnamed: 0,Rule,Rule Count,Percent of Population
69,"((MET > 0.000) U[0,0] (death = 0.000))",230,2.30
79,"G[0,0]((n_evts < 0.000 -> LOS > 0.000))",195,1.95
62,"F[0,0]((BLOOD_UREA_NITROGEN < 0.000 & CREATINI...",77,0.77
102,"G[0,0]((hr > 0.000 & Pulse > 0.000))",46,0.46
619,"((Mort > 0.000) U[0,0] (y = 0.000))",8,0.08
...,...,...,...
1766,"F[0,0]((CREATININE < 0.000 -> y = 0.000))",1,0.01
1768,"G[0,0]((CHLORIDE > 0.000 | GLUCOSE < 0.000))",1,0.01
1769,"F[0,0]((BLOOD_UREA_NITROGEN < 0.000 -> direct ...",1,0.01
1771,"G[0,0]((s2_hr < 0.000 & s8_hr > 0.000))",1,0.01


## Load LDP Ruleset

In [None]:
resultsFilename = "../Results/ICU_Ruleset_MCTS_Baseline.txt"
# resultsFilename = "../Results/Sepsis_Ruleset_MCTS_Baseline_1000pts_100iters.txt"

ldpTrees = []
ldpRules = []

stlFac = STLFactory()
# `with` closes the file even if a line fails to parse; the previous
# open()/close() pair left the handle open on an exception.
with open(resultsFilename, "r") as file:
    for line in file:
        # Strip one wrapping pair of parentheses (same check as loadRuleSet;
        # assumes the line ends in "\n").
        if line[0] == "(" and line[-2] == ")":
            line = line[1:-2] + "\n"

        rule = stlFac.constructFormulaTree(line)
        # NOTE(review): return value unused; presumably called for its effect
        # on `rule` -- confirm.
        rule.getFormulaNoParams()

        ldpTrees.append(rule)
        # Unlike loadRuleSet, ">=" / "<=" are not rewritten here. Matching in
        # getCoverage uses ldpTrees, so these strings are for display.
        ldpRules.append(rule.toString())

ldpRules

In [None]:
# Count how many LDP rule structures are matched by at least one client rule.

def getTemplateNodes(temp):
    """Return the node identifiers of `temp` in depth-first order, digits removed.

    `temp` must provide treelib's ``expand_tree``; children are visited in
    sorted order. Stripping digits reduces identifiers to bare labels for
    structural comparison. Presumably the digits only make identifiers
    unique; confirm against STLFactory.
    """
    depthFirstIds = temp.expand_tree(mode=treelib.Tree.DEPTH, sorting=True)
    return [re.sub('[0-9]', '', nodeId) for nodeId in depthFirstIds]

def findRuleMatch(template, clientTrees):
    """Return True if any client rule tree structurally contains `template`.

    A client tree matches when both of these hold:
      1. every variable returned by template.getAllVars() is among the
         client tree's getAllVars();
      2. the template's node labels (from getTemplateNodes) occur, in order,
         within the client's node labels (checked by nodeListMatch).
    Extra nodes in the client rule are allowed, so partial structural
    matches count as matches.
    """
    ldpNodes = getTemplateNodes(template)
    ldpVars = template.getAllVars()

    for clientTree in clientTrees:
        # Cheap variable check first; node lists are built only for candidates.
        clVars = clientTree.getAllVars()
        if not all(v in clVars for v in ldpVars):
            continue

        # Same node extraction as for the template (previously duplicated inline).
        if nodeListMatch(ldpNodes, getTemplateNodes(clientTree)):
            return True  # found match

    return False

# Check for a match between a list of template nodes and a list of client nodes.
def nodeListMatch(tempList, cList):
    """Return True if `tempList` is an ordered subsequence of `cList`.

    Each template label must appear in `cList` after the previous one's
    match, but not necessarily adjacent to it. Strict and non-strict
    comparison labels are treated as equal ("LT" ~ "LE", "GT" ~ "GE").

    Neither argument is modified. The previous version rewrote both lists
    in place as a side effect on the caller.
    """
    def normalizeRelop(label):
        if label == "LT":
            return "LE"
        if label == "GT":
            return "GE"
        return label

    clientLabels = (normalizeRelop(x) for x in cList)
    # `in` on an iterator consumes it up to and including the first match, so
    # each template label is searched for only after the previous match.
    return all(normalizeRelop(t) in clientLabels for t in tempList)

def getCoverage(ldpTrees, clientTrees):
    """Report how many LDP rules are structurally matched by some client rule.

    Each LDP tree is checked with findRuleMatch. The check is an ordered
    subsequence match, so a partial structural match counts as a full match.

    Prints the matched count ("Rules"), the unmatched count ("Non Rules") and
    precision = matched / total LDP rules.

    Returns
    -------
    (foundRules, nonRules, precision)
        `precision` is None when `ldpTrees` is empty; this case previously
        raised ZeroDivisionError.
    """
    foundRules = 0
    nonRules = 0

    for ldpTree in ldpTrees:
        if findRuleMatch(ldpTree, clientTrees):
            foundRules += 1
        else:
            nonRules += 1

    total = foundRules + nonRules
    precision = foundRules / total if total else None

    print("Found", foundRules, " Rules")
    print("Found", nonRules, "Non Rules")
    if total:
        print("Precision", precision)
    else:
        print("Precision undefined: no LDP rules")
    return foundRules, nonRules, precision
        


In [None]:
getCoverage(ldpTrees=ldpTrees, clientTrees=clientTrees)  # match LDP rules against all client rules

In [None]:
# TODO: compute coverage using a cutoff threshold on how many clients have each rule (e.g., 1% or 10% of the population)


In [None]:
# Total number of rule strings loaded across all clients (duplicates included).
len(clientRules)