In [None]:

"""
Usage:

1:  Create a dataset myDataset = dataset(setName, sigVars, bkdVars, sigWts, bkdWts, varNums, varNames)

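        sigVars, bkdVars are lists of per-event variable values (one row per
        event, one column per variable).
        sigWts, bkdWts are the matching per-event weights; each entry is a
        sequence whose first element is the weight, e.g. [w].
        varNums is a list of integer variable numbers.  The code uses them as
        column indices into the variable rows and as bit positions in path
        IDs, so they should be 0, 1, ..., n-1.
        varNames is a list of names for the variables, in the same order as
        varNums.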

2:  Create a run and pass it a dataset: myRun = run(myDataset).

        A run holds all information about all nodes and all playouts that
        you have run on it (see below).

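3:  Execute one or more playouts for your run: playouts(myRun, 100, [-1], numPlays=50)

        myRun: is the run to execute playouts on
        100: is the number of signal events, and also of background events,
            to use per playout (0 means use as many as the smaller set has)
        [-1]: describes the type of playouts to execute (see below)
        numPlays=50: is the number of playouts to execute.  Pass it by
            keyword: the fourth positional argument is cpParm, the bandit
            exploration constant (default 0.7071), so
            playouts(myRun,100,[-1],50) would run ONE playout with cpParm=50.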

        playout types:
            [-1] is compressed tree.  I.e. nodes are included/excluded
            based on their performance independent of what other
            variables were inc/exc.

            [-2] is full tree, but see parameter settings for details on
            this.

            [i,j,k] Any other list for this argument is interpreted as
            fixed-path running.  All playouts will use only variable
            numbers i,j,k (counting from zero)

4:  Execute more playouts on the run as needed.

    N.B. Once a run object is created, you can continue to run playouts
    on it, and it will continue to accrue statistics from the playouts
    and add them to the run's stored information.  So, you could keep
    calling playouts(myRun,100,[-1],numPlays=50) if after looking at the
    reports (see below), you decide that you want to continue with more
    playouts starting from the current state of the run.

5:  Print text-based report textReports(myRun,i)

        myRun specifies the run you want to report on, and  i indicates
        how many playouts you want to include in the per-playout portion
        of the report.  0 means all.

6:  Make plot-based reports plotReports(myRun,0,makePDF=False)

        Makes a bunch of plots to the screen if you're in something like
        jupyter or have XQuartz running.  If you set makePDF=True, then
        it will also send the plots to a file.

Extra:

There are additional parameters you can include when creating a run:

    gateEvalParm (default 20): Both sides of a node must be visited this
        many times before either gate can be closed.  This is a subtle and
        important parameter for full-tree mode.  In full-tree mode,
        closing a gate on one of the variables means that if any node
        anywhere in the tree has minVisitsForGateEval visits to both
        its include and exclude sides, then the variable will be
        evaluated for possibly closing one of its gates.  However, when
        that gate is closed, that variable is now always included or
        excluded no matter where it appears in the tree in future path
        searches.  The nice feature of this approach is that initially,
        only variables near the top of the tree (i.e., small var
        numbers) can be shut off, since they are the nodes that will
        reach the threshold first.  However, as low-numbered variables
        have one of their gates closed, gates further down the tree will
        get more visits.  So the gate actions will propagate down the
        tree as playouts increase.

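        If one wants a proper full-tree mode with no gates ever being
        closed, just set this parameter to be very large.  There is
        currently no way to turn off gates in individual nodes in a full
        tree.

        NOTE(review): in the code as written, full-tree mode reads the gates
        of each tree node itself (path.chooseNodes passes the tree node to
        includeVarPolicy, and setGates runs per node), so a gate closed on
        one tree node does not obviously apply to other nodes of the same
        variable.  Confirm the description above against the code.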

    nodeEvalParm (default 5): Only used in compressed-tree mode.
        Both sides of a node must be visited this many
        times before policy can be implemented to choose a side. Below
        this value, we flip a coin to choose a side.

    nParallelParm (default 8, stored as run.nParallel): Number of playouts
        that will be batched in parallel processing.  Speeds the code
        significantly, but do not set this to be greater than 8 without
        reading the longer comment in the code itself, even if you are on a
        machine with more than eight cores.

"""

# Jesse Ernst: Version 2.1 16Nov2019
import platform
import sys
import numpy as np
import random
import pandas as pd
import scipy
import operator
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import multiprocessing as mp
import math
import MLP

# Report the interpreter and key library versions, one "name : version" per line.
print(
    "python : " + platform.python_version(),
    "numpy : " + np.__version__,
    "pandas : " + pd.__version__,
    "scipy : " + scipy.__version__,
    sep="\n",
)

In [None]:
class dataset:
    """Container for one labelled event sample.

    Holds the signal and background variable rows, their per-event weights and
    the numbering/naming of the variables.  It also precomputes the weight
    extrema and the expected acceptance ("sample efficiency") of weighted
    sampling for each class.
    """

    def __init__(self, setName, sigVars, bkdVars, sigWts, bkdWts, varNums, varNames):
        # Raw inputs, stored as given.
        self.setName = setName
        self.sigVars, self.bkdVars = sigVars, bkdVars
        self.sigWts, self.bkdWts = sigWts, bkdWts

        # Weight entries are indexed as w[0]; take element 0 of the extreme entries.
        self.sigWtMax = max(sigWts)[0]
        self.sigWtMin = min(sigWts)[0]
        self.bkdWtMax = max(bkdWts)[0]
        self.bkdWtMin = min(bkdWts)[0]

        self.varNums, self.varNames = varNums, varNames

        # Sample efficiency = mean of (weight / max weight): the fraction of events
        # expected to survive accept/reject weighted sampling.  It is 1 when all
        # weights are equal and tends toward 0 when a few events carry most of the
        # weight.
        self.sigSampleEff = np.mean([entry[0] / self.sigWtMax for entry in sigWts])
        self.bkdSampleEff = np.mean([entry[0] / self.bkdWtMax for entry in bkdWts])

        # Consistency checks between the parallel input lists.
        assert len(varNames) == len(varNums)
        assert len(sigVars) == len(sigWts)
        assert len(bkdVars) == len(bkdWts)

In [None]:
class run:
    """State of one search over a dataset: tree nodes, per-path statistics and
    every playout executed so far.

    Calling ``playouts`` repeatedly on the same run keeps accumulating into
    this object.
    """

    def __init__(self, locDataset, gateEvalParm=20, nodeEvalParm=5, nParallelParm=8):
        self.dataset = locDataset
        self.varNums, self.varNames = locDataset.varNums, locDataset.varNames

        # Visits required on BOTH the include and exclude side of a node before
        # setGates may close either gate.
        self.minVisitsForGateEval = gateEvalParm

        # Visits required on BOTH sides before the bandit policy is used to pick a
        # side; below this a coin is flipped.
        self.minVisitsForNodeEval = nodeEvalParm

        # Playouts per multiprocessing batch.  Each playout in a batch prepares its
        # own event subsample restricted to its variables, so memory grows with this
        # number; around the number of cores is a reasonable starting point (lower
        # if memory is short).
        # CAUTION: node statistics are only updated once a whole batch completes, so
        # a large value means many playouts chosen without feedback.  Depending on
        # the tree type and policy, the same path may be run nParallel times: better
        # statistics for that path at no time cost, but no new paths explored.
        self.nParallel = nParallelParm

        # Accumulated state shared by all playouts of this run.
        self.nodesColl = nodesCollection(self)
        self.pathStatsColl = pathStatsColl()
        self.playoutList = []
        # Number of paths that picked no variable and had to be redrawn in coin-flip mode.
        self.nFallbackPath = 0

In [None]:
class nodesCollection:
    """All tree nodes of a run, plus one summary node per variable."""

    def __init__(self, locRun):
        self.run = locRun
        varNums = locRun.dataset.varNums
        varNames = locRun.dataset.varNames

        # varNum -> varName, using the same positional pairing as the summary nodes
        # below, so tree nodes and summary nodes always agree on a variable's name.
        self._varNameOf = dict(zip(varNums, varNames))

        # Nested dictionary of tree nodes: outer key = variable number, inner key =
        # nodeID.  Each inner dictionary starts empty and is filled by getNode as
        # playouts explore the tree.
        self.all = {varNum: {} for varNum in varNums}

        # One summary node per variable (keyed by variable number) to hold the
        # variable's aggregated statistics.
        self.smry = {varNum: node(varNum, varName, locRun)
                     for varNum, varName in zip(varNums, varNames)}

    def getNode(self, varNum, nodeID):
        """Return the tree node for (varNum, nodeID), creating and storing it on first request.

        Args:
            varNum: variable number the node belongs to.
            nodeID: position of the node in that variable's layer (see getNodeID).

        Returns:
            The node object held in ``self.all[varNum][nodeID]``.
        """
        varDict = self.all[varNum]
        if nodeID not in varDict:
            # Name comes from the varNum -> name map rather than varNames[varNum],
            # which is only the same thing when varNums happen to be 0..n-1.
            varDict[nodeID] = node(varNum, self._varNameOf[varNum], self.run, nodeID)
        return varDict[nodeID]

In [None]:
class node:
    """One node of the search tree, bound to one variable.

    Every path either includes or excludes the node's variable.  The node keeps
    separate visit counts, score sums and squared-score sums for the two cases,
    and a pair of gates: while both are open the include/exclude choice is made
    by the path-type policy; once one gate is closed the variable is forced to
    the side that remains open.
    """

    def __init__(self, varNum, varName, locRun, nodeID=-999):
        self.varNum = varNum
        self.varName = varName
        self.run = locRun  # read for minVisitsForGateEval / minVisitsForNodeEval
        self.nodeID = nodeID  # -999 when no tree position is given (e.g. initial summary nodes)
        # Statistics for playouts that included (Inc) or excluded (Exc) this variable.
        self.nVisitsIncPath = 0
        self.nVisitsExcPath = 0
        self.sumScoreIncPath = 0
        self.sumScoreExcPath = 0
        self.sumScoreSqIncPath = 0
        self.sumScoreSqExcPath = 0
        # Both gates start open.
        self.incPathOpen = True
        self.excPathOpen = True

    def __copy__(self):
        """Return an independent snapshot of this node's statistics and gates (the run is shared)."""
        newNode = node(self.varNum, self.varName, self.run, self.nodeID)
        newNode.nVisitsIncPath = self.nVisitsIncPath
        newNode.nVisitsExcPath = self.nVisitsExcPath
        newNode.sumScoreIncPath = self.sumScoreIncPath
        newNode.sumScoreExcPath = self.sumScoreExcPath
        newNode.sumScoreSqIncPath = self.sumScoreSqIncPath
        newNode.sumScoreSqExcPath = self.sumScoreSqExcPath
        newNode.incPathOpen = self.incPathOpen
        newNode.excPathOpen = self.excPathOpen
        return newNode

    def nodeOverallScore(self, included):
        """Return (mean, errOnMean) of this node's scores on one side.

        Args:
            included: True for the include side, False for the exclude side.
        """
        if included:
            mean, errOnMean, _stdev = \
                meanSigma(self.nVisitsIncPath, self.sumScoreIncPath, self.sumScoreSqIncPath)
        else:
            mean, errOnMean, _stdev = \
                meanSigma(self.nVisitsExcPath, self.sumScoreExcPath, self.sumScoreSqExcPath)
        return mean, errOnMean

    def includeVarPolicy(self, pathType, cpParm):
        """Decide whether this node's variable is included in the path being built.

        A closed gate wins: if only one side is open, that side is chosen.  With
        both gates open the choice depends on pathType[0]:
          -1   compressed tree (called on the variable's summary node): coin flip
               until both sides have run.minVisitsForNodeEval visits, then the
               bandit score decides.
          -2   full tree (called on the tree node itself): coin flip if neither
               side has been visited, the unvisited side if only one has, else
               the bandit score decides.
          -9   coin flip.
          >=0  fixed path: the variable is included iff its number is in
               pathType.  This also overrides any closed gate.

        Args:
            pathType: list whose first element selects the mode (see above).
            cpParm: exploration constant passed to banditScore.

        Returns:
            True to include the variable, False to exclude it.

        Calls sys.exit() after printing an error for an unknown mode or if both
        gates are closed.
        """
        includeVar = None  # initialize to prevent uninitialized warning
        if self.incPathOpen and not self.excPathOpen:  # exclude gate is closed
            includeVar = True
        elif self.excPathOpen and not self.incPathOpen:  # include gate is closed
            includeVar = False
        elif self.incPathOpen and self.excPathOpen:  # both open: apply the mode's policy
            mode = pathType[0]
            if mode == -1:  # compressed-tree mode
                minVisits = self.run.minVisitsForNodeEval
                if self.nVisitsIncPath < minVisits or self.nVisitsExcPath < minVisits:
                    includeVar = np.random.uniform(0.0, 1.0, 1)[0] >= 0.5  # flip coin
                else:  # both sides visited enough: compare bandit scores
                    banditScoreInc, banditScoreExc = banditScore(self, cpParm)
                    includeVar = (banditScoreInc >= banditScoreExc)
            elif mode == -2:  # full-tree mode
                if self.nVisitsIncPath == 0 and self.nVisitsExcPath == 0:
                    includeVar = (np.random.uniform(0.0, 1.0, 1)[0] >= 0.5)  # flip a coin
                elif self.nVisitsIncPath == 0:  # only exclude has been visited
                    includeVar = True
                elif self.nVisitsExcPath == 0:  # only include has been visited
                    includeVar = False
                else:  # both visited: compare bandit scores
                    banditScoreInc, banditScoreExc = banditScore(self, cpParm)
                    includeVar = (banditScoreInc >= banditScoreExc)
            elif mode == -9:  # coin-flip mode
                includeVar = (np.random.uniform(0.0, 1.0, 1)[0] >= 0.5)  # flip a coin
            elif mode >= 0:  # fixed-path mode: decided below
                pass
            else:  # unknown mode
                print("ERROR: unknown variable-comparison mode")
                sys.exit()
        else:
            print("ERROR: Something is wrong, the gates on both sides are closed")
            sys.exit()

        if pathType[0] >= 0:  # fixed-path running ignores everything above and follows the list
            includeVar = self.varNum in pathType
        return includeVar

    def nodeStatsUpdate(self, runScore, included):
        """Record one playout score on the include or exclude side, then re-evaluate the gates.

        Args:
            runScore: the playout's score.
            included: True if the playout included this node's variable.
        """
        if included:
            self.sumScoreIncPath += runScore
            self.sumScoreSqIncPath += runScore ** 2
            self.nVisitsIncPath += 1
        else:
            self.sumScoreExcPath += runScore
            self.sumScoreSqExcPath += runScore ** 2
            self.nVisitsExcPath += 1

        self.setGates()  # update which gates should be open for the node
        return

    def nodeStatsAdd(self, secondNode):
        """Add the visit counts and score sums of ``secondNode`` to this node (gates untouched)."""
        self.nVisitsIncPath += secondNode.nVisitsIncPath
        self.nVisitsExcPath += secondNode.nVisitsExcPath
        self.sumScoreIncPath += secondNode.sumScoreIncPath
        self.sumScoreExcPath += secondNode.sumScoreExcPath
        self.sumScoreSqIncPath += secondNode.sumScoreSqIncPath
        self.sumScoreSqExcPath += secondNode.sumScoreSqExcPath
        return

    def setGates(self):
        """Open or close this node's gates from its current statistics.

        Until both sides have at least run.minVisitsForGateEval visits, both gates
        stay open.  After that exactly one gate is closed: the include gate if
        excluding scores higher by more than one standard error of the
        difference of means, otherwise the exclude gate (the variable is kept
        unless it is clearly harmful).  The decision is recomputed on every call,
        so a closed gate can reopen if the statistics change.
        """
        minVisits = self.run.minVisitsForGateEval
        if self.nVisitsIncPath < minVisits or self.nVisitsExcPath < minVisits:
            self.incPathOpen = True
            self.excPathOpen = True
            return

        meanInc, meanIncErr = self.nodeOverallScore(True)
        meanExc, meanExcErr = self.nodeOverallScore(False)
        diffOfMeans = meanInc - meanExc
        errOnDiffOfMeans = (meanIncErr ** 2 + meanExcErr ** 2) ** 0.5
        if errOnDiffOfMeans == 0:
            # Zero spread on both sides (presumably every score identical): avoid the
            # division.  A non-zero difference counts as infinitely significant and
            # equal means as not significant.
            nSigmaDiff = math.inf if diffOfMeans != 0 else 0.0
        else:
            nSigmaDiff = abs(diffOfMeans / errOnDiffOfMeans)

        # Close the include gate only if excluding is meaningfully better than
        # including; otherwise close the exclude gate.
        if (meanExc > meanInc) and nSigmaDiff > 1.0:
            self.incPathOpen = False
            self.excPathOpen = True
        else:
            self.incPathOpen = True
            self.excPathOpen = False
        return

In [None]:
class pathStatsColl:
    """Per-path statistics accumulated over all playouts of a run, keyed by path ID string."""

    def __init__(self):
        self.pathStatsDict = {}

    def pathStatsUpdate(self, locPath, score):
        """Fold one playout score into the statistics of the path it used.

        Creates the path's pathStats entry on first sight, then updates its visit
        count, extrema, running sums and the derived mean/error/stdev.
        """
        pathKey = locPath.pathId()[0]  # binary-string form of the path ID
        stats = self.pathStatsDict.get(pathKey)
        if stats is None:
            stats = pathStats(locPath)
            self.pathStatsDict[pathKey] = stats

        stats.nVisits += 1
        stats.minScore = min(stats.minScore, score)
        stats.maxScore = max(stats.maxScore, score)
        stats.sumScore += score
        stats.sumScoreSq += score ** 2
        stats.mean, stats.errOnMean, stats.stdev = \
            meanSigma(stats.nVisits, stats.sumScore, stats.sumScoreSq)
        return

In [None]:
class pathStats:
    """Running statistics for one path (one specific include/exclude pattern)."""

    def __init__(self, locPath):
        self.pathIdStr, self.pathIdDec = locPath.pathId()
        self.nVisits = 0
        # Start the extrema at -/+ infinity so the first recorded score always replaces
        # them.  Nothing guarantees scores are non-negative, and a start of 0 for
        # maxScore would stick whenever every score is below zero.
        self.maxScore = -math.inf
        self.minScore = math.inf
        self.sumScore = 0
        self.sumScoreSq = 0
        self.mean = 0
        self.errOnMean = 0
        self.stdev = 0

In [None]:
class path:
    """A path through the list of all variables: one include/exclude decision per variable.

    The path is built at construction time by walking the variables in order
    and asking the relevant node's policy whether to include each one.
    """

    def __init__(self, locRun, pathType, cpParm):
        self.locRun = locRun
        self.pathType = pathType
        self.nodesColl = locRun.nodesColl
        self.allVars = locRun.dataset.varNums
        if len(self.allVars) == 0:
            # With no variables the zero-variable fallback below could never succeed
            # and would loop forever.
            raise ValueError("cannot build a path: the dataset has no variables")
        self.nodesList, self.incVars, self.excVars = self.chooseNodes(self.pathType, cpParm)
        if len(self.incVars) == 0:
            # No variable chosen: count the fallback, then redraw in coin-flip mode
            # until at least one variable is included.
            self.locRun.nFallbackPath += 1
            while len(self.incVars) == 0:
                self.nodesList, self.incVars, self.excVars = self.chooseNodes([-9], cpParm)

    def chooseNodes(self, locPathType, cpParm):
        """Walk the variables top-down and decide include/exclude for each.

        For each variable, the node to visit is determined by the variables
        included so far (getNodeID).  That node's decision is then appended to
        incVars/excVars, which in turn selects the node for the next variable.

        Args:
            locPathType: path-type directive; its first element selects the mode.
            cpParm: exploration constant passed to the policy.

        Returns:
            (nodesList, incVars, excVars): the tree nodes visited, and the
            variable numbers included and excluded.
        """
        assert len(locPathType) > 0, 'No path-type directive given'
        nodesList = []
        incVars = []
        excVars = []
        for varNum in self.allVars:
            # Tree node reached by the inclusion decisions made so far.
            currNodeID = getNodeID(incVars, varNum)
            currNode = self.nodesColl.getNode(varNum, currNodeID)
            nodesList.append(currNode)

            # Compressed-tree mode decides from the variable's summary node; every
            # other mode decides from the tree node itself.
            if locPathType[0] == -1:
                evalNode = self.nodesColl.smry[varNum]
            else:
                evalNode = currNode

            if evalNode.includeVarPolicy(locPathType, cpParm):
                incVars.append(varNum)
            else:
                excVars.append(varNum)
        return nodesList, incVars, excVars

    def getIncVarNames(self):
        """Return the names of the variables included in this path, in path order."""
        # incVars holds variable numbers, so map them to names through the dataset.
        dataset = self.locRun.dataset
        nameOf = dict(zip(dataset.varNums, dataset.varNames))
        return [nameOf[varNum] for varNum in self.incVars]

    def pathId(self):
        """Return (binary string, integer) IDs of this path.

        Bit i is set when variable number i is included; the string is padded to
        one character per variable.
        """
        pathIdDecimal = sum(2 ** varNum for varNum in self.incVars)
        pathIdStr = np.binary_repr(pathIdDecimal, width=len(self.allVars))
        return pathIdStr, pathIdDecimal

In [None]:
class playouts:
    """Run ``numPlays`` playouts on a run, in batches of ``locRun.nParallel``.

    For each batch: build the playouts (paths are chosen from the current node
    statistics), prepare and score their datasets in a process pool, then fold
    the scores into the nodes, the path statistics and the run's playout list.
    Node statistics therefore only change between batches.
    """

    @staticmethod
    def _reseedWorkerRng():
        """Pool initializer: reseed NumPy's global RNG from OS entropy in a worker process."""
        np.random.seed()

    def __init__(self, locRun, eventsPerPlayout, pathType, cpParm=0.7071, numPlays=1):
        # cpParm is the bandit exploration constant.  1/sqrt(2) is the default from the
        # Browne et al. MCTS survey; cpParm=0 only follows the best path, and cpParm well
        # above typical score values balances visits regardless of score.  The paper
        # notes tuning is usually needed.
        nParallel = locRun.nParallel
        numPlaysRemaining = numPlays

        while numPlaysRemaining > 0:
            batchSize = min(numPlaysRemaining, nParallel)
            # Build the batch's playouts; paths are chosen here, in the parent process.
            playList = [playout(locRun, eventsPerPlayout, pathType, cpParm)
                        for _ in range(batchSize)]

            # One pool serves both data preparation and scoring for this batch, and
            # is always cleaned up, even if a worker raises.  Workers are reseeded
            # because forked processes otherwise inherit an identical NumPy RNG state,
            # which would make the np.random draws in weightedSample the same for
            # every playout in the batch.
            pool = mp.Pool(processes=batchSize, initializer=playouts._reseedWorkerRng)
            try:
                datasetList = pool.map(playout.prepData, playList)
                scores = pool.map(go, datasetList)
            finally:
                pool.close()
                pool.join()

            # Fold results back in playout order.
            for play, score in zip(playList, scores):
                play.score = score
                updateNodes(play.path, play.score)
                locRun.pathStatsColl.pathStatsUpdate(play.path, play.score)
                locRun.playoutList.append(play)

            numPlaysRemaining -= batchSize

In [None]:
class playout:
    """A single pass over a weighted random subsample of the data along one path."""

    def __init__(self, locRun, eventsPerPlayout, pathType, cpParm):
        self.pathType = pathType
        self.cpParm = cpParm
        self.dataset = locRun.dataset
        self.path = path(locRun, self.pathType, self.cpParm)

        if pathType[0] == -1:
            # Compressed-tree mode keeps references to the current summary nodes
            # (the nodes its policy consulted), not copies.
            self.nodesList = list(locRun.nodesColl.smry.values())
        else:
            # Other modes snapshot copies of the tree nodes along the path, so later
            # playouts cannot alter what this playout recorded.
            self.nodesList = [treeNode.__copy__() for treeNode in self.path.nodesList]

        # Fixed-path running (pathType[0] >= 0) starts each playout's node snapshot
        # with all gates open, so repeated fixed-path playouts are comparable.
        if pathType[0] >= 0:
            openAllGates(self.nodesList)

        self.nPathsOpen = pathsCount(self.nodesList)
        self.score = -999  # placeholder until the playout is scored
        # Events per class (nSig == nBkd); 0 means "all", limited by the smaller class.
        self.nEvents = (eventsPerPlayout if eventsPerPlayout != 0
                        else min(len(self.dataset.sigWts), len(self.dataset.bkdWts)))

    def prepData(self):
        """Build this playout's (signal rows, background rows) pair.

        Keeps only the columns of the variables included in the path, draws
        ``self.nEvents`` rows per class by weighted sampling, and strips the
        weight column again before returning plain nested lists.
        """
        dropCols = self.path.excVars
        # Remove the columns of excluded variables, then return to nested lists.
        sigKept = np.delete(np.array(self.dataset.sigVars), dropCols, 1).tolist()
        bkdKept = np.delete(np.array(self.dataset.bkdVars), dropCols, 1).tolist()

        # Attach the weights as the last column so weightedSample can read them.
        sigJoined = MLP.joint_space(sigKept, self.dataset.sigWts)
        bkdJoined = MLP.joint_space(bkdKept, self.dataset.bkdWts)

        sigDrawn = weightedSample(sigJoined, self.nEvents, self.dataset.sigSampleEff, self.dataset.sigWtMax)
        bkdDrawn = weightedSample(bkdJoined, self.nEvents, self.dataset.bkdSampleEff, self.dataset.bkdWtMax)

        # Kept variables occupy columns 0..len(incVars)-1, so the appended weight
        # column has index len(incVars).
        weightCol = len(self.path.incVars)
        return (np.delete(sigDrawn, weightCol, 1).tolist(),
                np.delete(bkdDrawn, weightCol, 1).tolist())

In [None]:
def weightedSample(origSample, nReq, samplingEff, weightMax):
    """Select ``nReq`` events from ``origSample`` by weighted accept/reject sampling.

    Each row of ``origSample`` has its weight in the last column.  Candidate rows
    are visited in random order (without replacement) and each is accepted with
    probability ``weight / weightMax`` until ``nReq`` rows have been accepted.

    Args:
        origSample: list of event rows, weight as the last element of each row.
        nReq: number of events requested.
        samplingEff: expected acceptance fraction (mean weight / max weight); it
            sizes the first batch of candidates.
        weightMax: the largest weight, used to normalise acceptance probabilities.

    Returns:
        ``origSample`` itself when the scaled-up request needs more candidates than
        exist (or the sample is empty).  Otherwise a new list of accepted rows,
        which is shorter than ``nReq`` only if every row was visited first.
    """
    nEvtsInSample = len(origSample)
    # Size the first batch of candidates so that, at the expected acceptance rate,
    # it should yield nReq events with 20% headroom.
    sampleSize = int(1.20 * nReq / samplingEff)
    if sampleSize > nEvtsInSample or nEvtsInSample == 0:
        # Request too close to the full sample size (or nothing to sample): return as-is.
        return origSample

    # Weights are the last column of each row (columns count from 0).
    weightColumn = len(origSample[0]) - 1
    resultSample = []

    def _acceptFrom(indices):
        """Accept/reject origSample[j] for j in indices; return True once nReq rows are held."""
        for j in indices:
            if len(resultSample) >= nReq:
                return True
            if origSample[j][weightColumn] / weightMax > np.random.uniform(0, 1):
                resultSample.append(origSample[j])
        return len(resultSample) >= nReq

    # Try a random subset of rows first.
    firstBatch = random.sample(range(nEvtsInSample), sampleSize)
    if not _acceptFrom(firstBatch):
        # The headroom was not enough.  Instead of running off the end of the
        # candidate list (IndexError), continue through the not-yet-visited rows
        # in random order.
        visited = set(firstBatch)
        remaining = [j for j in range(nEvtsInSample) if j not in visited]
        random.shuffle(remaining)
        _acceptFrom(remaining)
    return resultSample

In [None]:
def go(locData):
    """Score one prepared playout dataset with MLP.mi_binary.

    locData[0] holds the signal rows and locData[1] the background rows.  The rows
    are stacked (signal first) and labelled [1] for signal and [0] for background
    in the same order.
    """
    sigSet = locData[0]
    bkdSet = locData[1]
    labels = [[1]] * len(sigSet) + [[0]] * len(bkdSet)

    # k for the estimator is fixed at 1.  Earlier versions hinted that k > 1 can
    # make MI grow even when only random variables are added unless the event
    # density is high enough; that was older code and may be worth revisiting.
    ksgKParameter = 1
    merged = np.concatenate((sigSet, bkdSet), axis=0).tolist()
    return MLP.mi_binary(merged, labels, ksgKParameter)

In [None]:
def banditScore(locNode, cpParm):
    """Return UCB-style (include, exclude) scores for the choice at a node; higher is better.

    Each score is that side's mean playout score plus the exploration bonus
    2 * cpParm * sqrt(2 * ln(n) / n_side), where n is the node's total visits.

    Args:
        locNode: node providing nVisits*/sumScore* for both sides.
        cpParm: exploration constant.

    Raises:
        AssertionError: if either side has never been visited.
    """
    nVisitsInc = locNode.nVisitsIncPath
    nVisitsExc = locNode.nVisitsExcPath
    # Validate before dividing so a zero count raises this message rather than
    # a ZeroDivisionError from the mean computation.
    assert nVisitsInc > 0, 'Too few visits on include side to use bandit formula'
    assert nVisitsExc > 0, 'Too few visits on exclude side to use bandit formula'

    nVisitsBoth = nVisitsInc + nVisitsExc
    meanScoreInc = locNode.sumScoreIncPath / nVisitsInc
    meanScoreExc = locNode.sumScoreExcPath / nVisitsExc

    banditInc = meanScoreInc + 2 * cpParm * np.sqrt(2 * np.log(nVisitsBoth) / nVisitsInc)
    banditExc = meanScoreExc + 2 * cpParm * np.sqrt(2 * np.log(nVisitsBoth) / nVisitsExc)

    return banditInc, banditExc

In [None]:
def getNodeID(varList, varNum):
    """Return the tree-position ID of the node for variable ``varNum``.

    The ID is a bitmask with bit i set for every included variable i that lies
    above ``varNum`` in the tree (i < varNum).  Nodes of different variables can
    share an ID, so a node is identified by the pair (varNum, nodeID).

    Args:
        varList: variable numbers included so far on the path.
        varNum: the variable whose node is being located.
    """
    return sum(2 ** included for included in varList if included < varNum)

In [None]:
def updateNodes(locPath, score):
    """Record ``score`` on every tree node of ``locPath`` and rebuild the summary nodes.

    Two stores are updated:
      * the tree nodes in ``locPath.nodesList``.  These are references to the nodes
        held in the run's nested ``nodesColl.all`` dictionary, so updating them
        updates that dictionary as well;
      * ``nodesColl.smry``, which holds one summary node per variable.  Each one is
        replaced by a freshly built node that aggregates all tree nodes of that
        variable.
    """
    # Each tree node on the path is told whether its variable was included in this path.
    for treeNode in locPath.nodesList:
        treeNode.nodeStatsUpdate(score, treeNode.varNum in locPath.incVars)

    # Rebuild every variable's summary node from scratch out of its tree nodes.
    smryDict = locPath.locRun.nodesColl.smry
    for varIdx in locPath.allVars:
        treeNodesForVar = locPath.locRun.nodesColl.all[varIdx]
        varName = locPath.locRun.dataset.varNames[varIdx]
        freshSmry = node(varIdx, varName, locPath.locRun, varIdx)  # summary node uses the var number as its ID
        for treeNode in treeNodesForVar.values():
            freshSmry.nodeStatsAdd(treeNode)
        # Tree nodes get their gates set inside nodeStatsUpdate; summary nodes need it explicitly.
        freshSmry.setGates()
        smryDict[varIdx] = freshSmry
    return

In [None]:
def gatesStatus(locNodeList):
    """Encode the gate status of every node in ``locNodeList`` as one base-4 integer.

    Digit ``i`` (counting from the least-significant end) describes node ``i``:
    0 = both gates closed (should never happen), 1 = only the exclude gate open,
    2 = only the include gate open, 3 = both gates open.  Passing 20 nodes therefore
    gives an integer whose base-4 representation shows one 0-3 digit per node.
    """
    code = 0
    placeValue = 1  # 4 ** position of the current node
    for gateNode in locNodeList:
        digit = (2 if gateNode.incPathOpen else 0) + (1 if gateNode.excPathOpen else 0)
        code += digit * placeValue
        placeValue *= 4
    return code


def openAllGates(locNodeList):
    """Open both the include and the exclude gate on every node in ``locNodeList``, in place."""
    for gateNode in locNodeList:
        gateNode.incPathOpen = gateNode.excPathOpen = True
    return

In [None]:
def meanSigma(nEntries, sumVals, sumValsSq):
    """Return mean, error on the mean and standard deviation from running sums.

    Args:
        nEntries: number of entries accumulated.
        sumVals: sum of the entries.
        sumValsSq: sum of the squared entries.

    Returns:
        Tuple ``(mean, errOnMean, stdev)``, where ``stdev`` is the population
        standard deviation sqrt(E[x^2] - E[x]^2) and
        ``errOnMean = stdev / sqrt(nEntries)``.  All three are 0 when
        ``nEntries`` is 0.
    """
    if nEntries == 0:
        return 0, 0, 0
    mean = sumVals / nEntries
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation when the entries are (nearly) all equal.  In Python 3 a
    # negative float raised to the 0.5 power is a complex number, so clamp at 0.
    variance = max((sumValsSq / nEntries) - mean ** 2, 0.0)
    stdev = variance ** 0.5
    errOnMean = stdev / (nEntries ** 0.5)
    return mean, errOnMean, stdev

In [None]:
def replayBest(locRun, nBest, nEvents):
    """Re-run the paths of the ``nBest`` highest-scoring playouts of ``locRun``.

    For each selected playout, ``playouts`` is called on the same run with
    ``nEvents`` events and the playout's list of included variables as the path.
    The ranking is taken once, before any replay is added to the run.  Replays
    execute one after another; the current code structure does not readily
    allow them to run in parallel.
    """
    rankedPlayouts = sorted(locRun.playoutList, key=lambda po: po.score, reverse=True)
    for bestPlayout in rankedPlayouts[:nBest]:
        playouts(locRun, nEvents, bestPlayout.path.incVars)
    return

In [9]:
def pathsCount(nodeList):
    """Return how many distinct paths are still available through ``nodeList``.

    Each node contributes a factor equal to its number of open gates (0, 1 or 2).
    A single fully closed node therefore makes the total 0, and an empty list
    gives 1.
    """
    openPaths = 1
    for gateNode in nodeList:
        openGates = (1 if gateNode.incPathOpen else 0) + (1 if gateNode.excPathOpen else 0)
        openPaths *= openGates
    return openPaths

In [None]:
def textReports(locRun, playoutsToReport=100):
    """Print every text-based report for ``locRun`` in sequence.

    Sections, in order: run info, summary nodes, tree nodes, paths sorted by pathId,
    paths sorted by mean score, the first ``playoutsToReport`` playouts, and the
    ``playoutsToReport`` highest-scoring playouts (0 means all playouts).  Sections
    are separated by a line of 80 '=' characters.
    """
    divider = '=' * 80
    # (heading pieces passed to print, callable that prints the section body)
    sections = (
        (('Run Info',),
         lambda: runReport(locRun)),
        (('Summary-Nodes Info',),
         lambda: smryNodesReport(list(locRun.nodesColl.smry.values()))),
        (('Tree-Nodes Info',),
         lambda: treeNodesReport(locRun)),
        (('Paths Info (sorted by pathId)',),
         lambda: pathsReport(locRun.pathStatsColl)),
        (('Paths Info (sorted by mean score)',),
         lambda: pathsReport(locRun.pathStatsColl, 'mean')),
        (('First', playoutsToReport, 'Playouts (0 means all)'),
         lambda: playoutsReport(locRun, playoutsToReport, False)),
        (('Highest', playoutsToReport, 'Playouts (0 means all)'),
         lambda: playoutsReport(locRun, playoutsToReport, True)),
    )
    lastIdx = len(sections) - 1
    for idx, (heading, emitSection) in enumerate(sections):
        print(*heading)
        emitSection()
        if idx < lastIdx:
            print('\n', divider, '\n')
        else:
            print(divider)

In [None]:
def runReport(locRun):
    """Print a summary of ``locRun``.

    The summary shows the gate/node evaluation parameters and nParallel, the dataset
    name with its variable, signal and background counts, one line per variable
    (number and name), and the number of playouts and fallback paths so far.
    """
    ds = locRun.dataset
    headerFmt = "{0:<15s} {1:>20s} {2:>7s} {3:>7s} {4:>7s}"
    countsFmt = "{0:<15s} {1:>20s} {2:>7d} {3:>7d} {4:>7d}"
    varLineFmt = "{0:>12s} {1:3d} {2:<48s}"

    print('gateEvalParm =', locRun.minVisitsForGateEval, 'nodeEvalParm =', locRun.minVisitsForNodeEval,
          'nParallel =', locRun.nParallel)
    print(headerFmt.format("Dataset Info:", "dataset name", "nVar", "nSig", "nBkd"))
    print(countsFmt.format(" ", ds.setName, len(ds.varNums), len(ds.sigVars), len(ds.bkdVars)))
    print('-' * 60)
    for idx, varNum in enumerate(ds.varNums):
        print(varLineFmt.format(" ", varNum, ds.varNames[idx]))
    print('-' * 40)
    print('Number of playouts = ', len(locRun.playoutList))
    print('Number of fallback paths = ', locRun.nFallbackPath)
    return

In [None]:
def pathsReport(pStatsColl, sortBy=None):
    """Print one line of statistics per path in ``pStatsColl.pathStatsDict``.

    Args:
        pStatsColl: collection whose ``pathStatsDict`` maps a key to a path-stats
            object exposing pathIdStr, nVisits, maxScore, minScore, stdev, mean and
            errOnMean.
        sortBy: 'nVisits', 'mean' or 'maxScore' orders the rows by that attribute,
            highest first.  Any other value, including the default None, orders them
            by pathIdStr, descending.  Rows that tie on the chosen attribute keep
            ascending dictionary-key order, because Python's sort is stable.
    """
    # Pre-sort by dictionary key; this only fixes the order of ties in the final sort.
    keyOrdered = [stats for _, stats in sorted(pStatsColl.pathStatsDict.items(),
                                               key=operator.itemgetter(0))]

    # `in` on a tuple compares with ==, so unhashable sortBy values are handled safely.
    sortAttr = sortBy if sortBy in ('nVisits', 'mean', 'maxScore') else 'pathIdStr'
    rows = sorted(keyOrdered, key=operator.attrgetter(sortAttr), reverse=True)

    print("{0:>50s}{1:>8s}{2:>9s}{3:>9s}{4:>6s}{5:>9s}".
          format("987654321098765432109876543210", "nVisits", "maxScore", "minScore", "stdev", "mean"))
    for stats in rows:
        # One line per path: id, visits, score range, spread, then mean +/- error on mean.
        print("{0:>50s}{1:>8d}{2:>9.2f}{3:>9.2f}{4:>6.2f}{5:>5.2f}{6:>3s}{7:>5.2f}".
              format(stats.pathIdStr, stats.nVisits, stats.maxScore, stats.minScore, stats.stdev,
                     stats.mean, "+/-", stats.errOnMean))
    return

In [None]:
def smryNodesReport(nodeList):
    """Print a report for a given set of summary nodes.

    One row per node shows: variable name and number, visits through the include
    and exclude sides, the node's overall score when the variable was included and
    when it was excluded (elements [0] and [1] of ``nodeOverallScore``, printed as
    value +/- uncertainty), and whether each gate is open.  A footer gives the number
    of paths still open (``pathsCount``) and that number as a percentage of all
    ``2 ** len(nodeList)`` possible paths.
    """
    print("{0:>10s} {1:>4s} {2:>5s} {3:>5s} {4:>11s} {5:>11s} {6:>7s} {7:>7s}".
          format("var", "var", "times", "times", "score when", "score when", "incGate", "excGate"))
    print("{0:>10s} {1:>4s} {2:>5s} {3:>5s} {4:>11s} {5:>11s} {6:>7s} {7:>7s}".
          format("name", "num", "V Inc", "V Exc", "var inc", "var exc", "Open", "Open"))
    for smryNode in nodeList:
        # Evaluate each side once; the previous version called nodeOverallScore
        # twice per side just to read its two elements.
        incResult = smryNode.nodeOverallScore(True)
        excResult = smryNode.nodeOverallScore(False)
        print("{0:>10s} {1:>4d} {2:>5d} {3:>5d} {4:>4.2f}{5:3s}{6:>4.2f} {7:>4.2f}{8:3s}{9:>4.2f} {10:>7} {11:>7}".
              format(smryNode.varName, smryNode.varNum,
                     smryNode.nVisitsIncPath, smryNode.nVisitsExcPath,
                     incResult[0], '+/-', incResult[1],
                     excResult[0], '+/-', excResult[1],
                     smryNode.incPathOpen, smryNode.excPathOpen))

    currPathsOpen = pathsCount(nodeList)
    pctPathsOpen = 100.0 * (currPathsOpen / 2 ** len(nodeList))
    print("{0:35s} {1:.3e} {2:1s} {3:4.1f} {4:1s}".
          format("Current number of paths available = ", currPathsOpen, "(", pctPathsOpen, "%)"))
    return

In [None]:
def treeNodesReport(locRun):
    """Print, for each variable, how many in-tree nodes currently have each gate status.

    Gate status codes come from ``gatesStatus``: 3 = both gates open, 2 = only the
    include gate open, 1 = only the exclude gate open, 0 = both closed.  The report
    is a snapshot, so it can usefully be printed several times during a set of
    playouts, e.g. ``if playoutNum % x == 0: treeNodesReport(run)``.
    """
    nodesD = locRun.nodesColl.all
    # NOTE(review): other reports read varNums/varNames from locRun.dataset;
    # this relies on locRun exposing them directly — confirm against the run class.
    varNumList = locRun.varNums
    varNameList = locRun.varNames

    # gateSum[varNum][code] = number of tree nodes of that variable with gate code `code`.
    gateSum = {}
    for iVar in varNumList:
        codeCounts = {code: 0 for code in range(4)}
        for layerNode in nodesD[iVar].values():
            codeCounts[gatesStatus([layerNode])] += 1
        gateSum[iVar] = codeCounts

    print('Number of nodes with each gate status for each variable after', len(locRun.playoutList), 'playouts')
    print("{0:>7s} {1:>20s} {2:>9s} {3:>9s} {4:>9s} {5:>9s} {6:>9s} {7:>9s}".
          format(' ', ' ', 'max', 'nodes', 'nodes w/', 'nodes w/', 'nodes w/', 'nodes w/'))
    print("{0:>7s} {1:>20s} {2:>9s} {3:>9s} {4:>9s} {5:>9s} {6:>9s} {7:>9s}".
          format('varNum', 'varName', 'nodes', 'in lyr', 'status=3', 'status=2', 'status=1', 'status=0'))
    for i in varNumList:
        # 2 ** i is the most nodes layer i can hold: one per include/exclude
        # combination of the i variables above it (cf. getNodeID's bitmask).
        print("{0:>7d} {1:>20s} {2:>9d} {3:>9d} {4:>9d} {5:>9d} {6:>9d} {7:>9d}"
              .format(i, varNameList[i], 2 ** i, len(nodesD[i]),
                      gateSum[i][3], gateSum[i][2], gateSum[i][1], gateSum[i][0]))
    return

In [None]:
def playoutsReport(locRun, numToReport, sortByScore):
    """Print one row per playout, optionally ordered by score (highest first).

    At most ``numToReport`` rows are printed; 0, or any value larger than the number
    of playouts, prints them all.  Each row shows the row number, the base-4 gate
    status code of the playout's nodes, the path id, the playout type ('-2', '-1'
    or 'fixed'), cpParm, the number of paths open, nEvents and the score.
    """
    if sortByScore:
        playoutRows = sorted(locRun.playoutList, key=lambda po: po.score, reverse=True)
    else:
        playoutRows = locRun.playoutList

    nAvailable = len(playoutRows)
    nRows = nAvailable if (numToReport > nAvailable or numToReport == 0) else numToReport

    headerFmt = "{0:>4s} {1:>32s} {2:>32s} {3:>6s} {4:>7s} {5:>10s} {6:>9s} {7:>8s}"
    rowFmt = "{0:>4d} {1:>32s} {2:>32s} {3:>6s} {4:>7.3f} {5:>10.3e} {6:>9d} {7:>8.2f}"
    print(headerFmt.format(" ", "gates status for each var", "inc/exc status for each var", "play", " ", "number", " ",
                           " "))
    print(headerFmt.format("#", "987654321098765432109876543210", "987654321098765432109876543210", "type", "cpParm",
                           "Paths Open", "nEvents", "score"))
    for rowNum in range(nRows):
        po = playoutRows[rowNum]
        firstType = po.pathType[0]
        typeLabel = '-2' if firstType == -2 else ('-1' if firstType == -1 else 'fixed')
        gateDigits = np.base_repr(gatesStatus(po.nodesList), 4)  # one base-4 digit per node
        # Fixed-path playouts always report 1; the -1 and -2 modes report nPathsOpen.
        # NOTE(review): an earlier comment said nPathsOpen is only meaningful for mode -1 — confirm for -2.
        openPaths = 1 if typeLabel == 'fixed' else po.nPathsOpen
        print(rowFmt.format(rowNum, gateDigits, po.path.pathId()[0], typeLabel, po.cpParm, openPaths,
                            po.nEvents, po.score))
    return

In [None]:
def plotReports(locRun, pdfFileName=None):
    """Draw all plot-based reports for ``locRun``, optionally saving them to a PDF.

    Figures, in order: mean score of the path taken vs playout number, log2 of the
    number of open paths vs playout number, gate status vs playout number for each
    variable, and a signal/background overlay for each variable.  Each figure is
    shown non-blocking and then closed.

    Args:
        locRun: the run to report on.
        pdfFileName: if not None, every figure is also written, one per page, to this
            PDF file.  The file is closed even if a plotting step raises.
    """
    # Per-playout quantities (i.e. vs. time).
    meanResults, nPathsResults = pathTracker(locRun)
    x = list(range(len(meanResults)))
    y1 = list(meanResults)
    # NOTE(review): math.log2 raises ValueError if any playout recorded 0 open paths.
    y2 = [math.log2(nPaths) for nPaths in nPathsResults]

    pdf = PdfPages(pdfFileName) if pdfFileName is not None else None  # Open file for saving plots
    try:
        figA1 = plt.figure()
        plt.title('mean path score vs playout number')
        plt.scatter(x, y1)
        plt.show(block=False)

        figA2 = plt.figure()
        plt.title('Log (base 2) of paths open vs playout number')
        plt.scatter(x, y2)
        plt.show(block=False)

        if pdf is not None:
            pdf.savefig(figA1)  # Save mean score of path vs playout number
            pdf.savefig(figA2)  # Save number of paths remaining vs playout number

        plt.close(figA1)
        plt.close(figA2)

        # Gate status vs playout number for each variable.
        for i in range(len(locRun.dataset.varNums)):
            figGateCodes = plotGateTracker(locRun, i)
            plt.show(block=False)
            if pdf is not None:
                pdf.savefig(figGateCodes)
            plt.close(figGateCodes)

        # Signal/background overlay for each variable.
        for i in range(len(locRun.dataset.varNums)):
            figV = plotVarCompare(locRun.dataset, i)
            plt.show(block=False)
            if pdf is not None:
                pdf.savefig(figV)
            plt.close(figV)
    finally:
        # Always release the PDF file, even if a plotting step above failed.
        if pdf is not None:
            pdf.close()

    return

In [None]:
def pathTracker(locRun):
    """Collect per-playout quantities, in playout order, to show how the run evolved.

    Returns:
        Tuple ``(meanResults, nPathsResults)``.  For each playout in
        ``locRun.playoutList``, ``meanResults`` holds the mean score of the path the
        playout took, looked up in ``locRun.pathStatsColl.pathStatsDict`` by the
        path's id string.  This is the path's mean as of now, not as of that
        playout.  ``nPathsResults`` holds the playout's ``nPathsOpen``.

    Raises:
        KeyError: if a playout's path has no entry in the path-statistics dictionary.
    """
    statsByPathId = locRun.pathStatsColl.pathStatsDict
    meanResults = []
    nPathsResults = []

    for curPlayout in locRun.playoutList:
        pathIdStr = curPlayout.path.pathId()[0]
        # Direct lookup: the former `.get(...).mean` turned a missing id into an
        # opaque "'NoneType' object has no attribute 'mean'" error.
        try:
            pathStats = statsByPathId[pathIdStr]
        except KeyError:
            raise KeyError('no path statistics recorded for path {!r}'.format(pathIdStr)) from None
        meanResults.append(pathStats.mean)
        nPathsResults.append(curPlayout.nPathsOpen)
    return meanResults, nPathsResults

In [None]:
def plotGateTracker(locRun, varNum):
    """Scatter-plot the gate status of variable ``varNum`` against playout number.

    For each playout, base-4 digit ``varNum`` of ``gatesStatus(playout.nodesList)``
    is plotted (0 = both gates closed, 1 = only exclude open, 2 = only include open,
    3 = both open).

    Returns:
        The new matplotlib figure.  It is not shown, saved or closed here.
    """
    # Extract the digit arithmetically.  Indexing np.base_repr's string from the end
    # raised IndexError whenever the higher-order digits were 0, because base_repr
    # drops leading zeros.
    statusList = [(gatesStatus(po.nodesList) // 4 ** varNum) % 4 for po in locRun.playoutList]

    x = list(range(len(locRun.playoutList)))
    myFig = plt.figure()
    plt.title('Gate status vs playout for var ' + str(varNum) + ': ' + locRun.dataset.varNames[varNum])
    plt.ylim(0.0, 3.1)
    plt.scatter(x, statusList)
    return myFig

In [None]:
def plotVarCompare(locDataset, varNum):
    """Return a new figure overlaying signal and background histograms of variable ``varNum``.

    Column ``varNum`` is taken from every row of ``sigVars`` and ``bkdVars``.  Both
    samples share 50 bins, are drawn with alpha 0.5, and are labelled 'sig' and 'bkd'.
    """
    def column(rows):
        return [row[varNum] for row in rows]

    sigValues = column(locDataset.sigVars)
    bkdValues = column(locDataset.bkdVars)

    figure = plt.figure()
    plt.title(str(varNum) + ': ' + locDataset.varNames[varNum])
    plt.hist([sigValues, bkdValues], bins=50, alpha=0.5, label=['sig', 'bkd'])
    plt.legend(loc='upper right')
    return figure

In [None]:
def banditBonusDif(nA, nB, cpParm):
    """Return the bandit-formula exploration bonuses for two sides and their difference.

    With ``nA`` and ``nB`` visits to the two sides, each bonus is
    ``2 * cpParm * sqrt(2 * ln(nA + nB) / n_side)``.  Not used by the main code, but
    handy for picking a sensible cpParm for a problem.

    Returns:
        Tuple ``(bonusA, bonusB, bonusA - bonusB)``.
    """
    scale = 2 * cpParm
    twoLogTotal = 2 * np.log(nA + nB)
    bonusA = scale * np.sqrt(twoLogTotal / nA)
    bonusB = scale * np.sqrt(twoLogTotal / nB)
    return bonusA, bonusB, bonusA - bonusB