# Monte Carlo Tree Search for Attacker Traversal

###### First: Import the necessary packages

In [1]:
import numpy as np
from collections import deque, defaultdict
from scipy.stats.distributions import poisson, expon
from scipy.stats import truncexpon
import copy
from matplotlib import pyplot as plt
import operator
%matplotlib inline

###### MCTS Class

This class takes as an input the main game parameters and provides methods to determine the optimal
policy as well as output the expected utility.  

In [58]:
class MCTSAttacker(object):
        """
        Class for solving for attacker traversal strategy
        
        Parameters
        ----------
        
        lambdaB : 2-d array
            Matrix giving the rate of benign message transmission
        
        tbar : float
            Time the attacker enters the network
        
        tau : float
            Defender Window Length
            
        d : float
            Defender threshold
        
        gamma : float
            Time Between Attacker Moves
            
        lrdetector : bool
            If True, d is a threshold for a likelihood ratio.  Otherwise,
            it definies a threshold for the likelihood of the data under 
            normal conditions.
            
        c : float
            Tuning parameter for MCTS algorithm
        
        Rfirst : 2-d array
            A matrix the represents the reward the attacker earns the first time
            he sends a message.  i.e. Rfirst[i,j] is the reward the attacker earns 
            the first time the attacker sends a message from i to j,
            
        Radditional : 2-d boolean array
            Radditional[i,j] should be true if the attacker earns a reward everytime
            he sends a message from i to j, not just the first.  
            
        Attributes
        ----------
        vc : dict
            A nested dictionary (vc stands for values/counts). vc[s]. where
            s is a string representing a possible information state of the attacker,
            gives a dictionary keyed by possible actions the attacker can take in 
            vc[s].  vc[s][a], where 'a' is a possible action the attacker can take in
            state s is a tuple (v,c) where v is the average total reward from choosing
            action a in information state s and c is the number of times a was selected
            in state s.  vc[s]['n'] also gives the total number of times vc[s] was 
            visited.
        """
        def __init__(self, lambdaB, tbar, tau, d, gamma, lrdetector, c, Rfirst, Radditional, 
                     lookahead, **kwargs):
            self.lambdaB = lambdaB
            self.tbar = tbar
            self.tau = tau
            self.d = d
            self.c = c
            self.gamma = gamma
            self.lrdetector = lrdetector
            self.Rfirst = Rfirst
            self.Radditional = Radditional
            self.vc = {}
            self.state_visits = {}
            # Below is all precomputation to speed up the MCTS
            self._first_move = self._get_firstmoves()
            self._event_rate = 1. / np.sum(lambdaB)
            self._normed_rowsums = np.sum(lambdaB, axis=1) / float(np.sum(lambdaB))
            #self._normed_rowsums = self._normed_rowsums / np.sum(self._normed_rowsums)
            self._n_hosts = lambdaB.shape[0]
            self._hosts_arr = np.arange(self._n_hosts)
            self._network_state = self._initialize_network_state()
            self._action_counts = defaultdict(int)
            # _network_state at a given t is all messages sent in the
            # last t-tau
            self.infected = [0]
            if not lrdetector:
                self._intrusion_criteria = self._simple_detector
            self.empirical_transitions =defaultdict(list)
            self._lognfact = self._genlognfact()
            self.lookahead = lookahead
            
        def _genlognfact(self):
            lnf = {}
            for i in range(2,500):
                lnf[i-1] = np.sum(np.log(np.arange(1,i)))
            return lnf
        

        
        def _get_firstmoves(self):
            """
            Returns a list of strings of the attacker's first possible moves.
            """
            return [(0,x) for x in np.nonzero(self.lambdaB[0,:])[0]]
        
        def _initialize_network_state(self):
            network_state = {}
            for i in self._hosts_arr:
                for j in self._hosts_arr:
                    if self.lambdaB[i,j]>0:
                        network_state[(i,j)] = []
            return network_state
        
        def _simple_detector(self):
            logprob = 0
            for key, val in self._network_state.iteritems():
                count = len(val)
                #To account for low likelihoods due to inactivity
                if count >= self.tau*self.lambdaB[key]:
                    if count > 0:
                    #logprob += poisson.logpmf(count, self.tau * self.lambdaB[key])
                        logprob += np.log(self.lambdaB[key] * self.tau) * count \
                            - (self.lambdaB[key]*self.tau) - self._lognfact[count]
                        #np.sum(np.log(np.arange(1,count+1)))
                    else:
                        logprob +=-self.lambdaB[key]*self.tau
                else:
                    if count >0:
                        logprob += np.log(count)*count -count - self._lognfact[count]
            return logprob

        
        def get_action_at_state(self, informationstate, nmin, infected_at_state,t):
            """
            Loops over simulate function to find the optimal attacker policy.
            
            Parameters
            ----------
            
            varlimit : float
                The simulation stops when the running variance of the maximum 
                expected utility of the last 'varwindow' MCTS simulations
                is below varlimit.
                
            varwindow : float
                See varlimit
            """
            n=0
            maxall = 40
            while n < nmin:
                if n < nmin/5.:
                    self.c=100.
                if (n >nmin/5.) and (n<2*nmin/5.):
                    self.c = 1.
                if (n>2*nmin/5.) and (n<3*nmin/5.):
                    self.c=100.
                if (n>3*nmin/5.):
                    self.c = 1.
                #self.c = np.random.randint(200)
                #self.c = 50*(1+np.sin(.05*(n+1)))
                #self.c = 100 -n*(100./nmin)
                #self.c = -0.0007840*n**2+0.3920*n+1.000
                #self.c = -(4*maxall/float((nmin**2)))*n**2 + (4*maxall/float(nmin))*n
                self.infected = copy.deepcopy(infected_at_state)
                if informationstate =='start,':
                    self._draw_initial_state()
                else:
                    ix = np.random.randint(len(self.empirical_transitions[informationstate]))
                    self._network_state = copy.deepcopy(self.empirical_transitions[informationstate][ix])
                    
                self._action_counts = defaultdict(int)
                self.simulate(informationstate, t, False)
                n += 1
#                 if n == nmin -1:
#                     print "The current rewards are", self.vc[informationstate]
#                     print "\n"
#                     print "Do you want to keep sampling"
#                     answer = raw_input()
#                     if answer == 'y':
#                         n = int(raw_input("what Would you like the new N to be?"))
#                         self.c = float(raw_input("What would you like the new C to be?"))
            try :
                self.empirical_transitions.pop(informationstate)
            except KeyError:
                pass
            bestaction = 'nm'
            besteu = 0
            for key, val in self.vc[informationstate].iteritems():
                if val[0] > besteu:
                    besteu = val[0]
                    bestaction = key
            print bestaction
            return bestaction
        
        def _draw_initial_state(self):
            self._network_state = self._initialize_network_state()
            t = np.random.exponential(self._event_rate)
            while t < self.tbar:
                nextsender = np.random.choice(self._n_hosts, p=self._normed_rowsums)
                nextreceiver = np.random.choice(self._n_hosts, 
                                                p=self.lambdaB[nextsender] / 
                                                float(np.sum(self.lambdaB[nextsender])) )
                self._network_state[(nextsender, nextreceiver)].append(t)
                t += np.random.exponential(self._event_rate)
                self._remove_old(t)
        
        def _remove_old(self, t):
            for key, val in self._network_state.iteritems():
                ix = np.searchsorted(val, t - self.tau)
                self._network_state[key] = val[ix:]
            
        def simulate(self, informationstate, t, alarm=False, depth= 1):
            if (alarm) or (depth>self.lookahead):
                return 0
            if informationstate not in self.vc.keys():
                availableactions = self._get_available_actions(self.infected)
                self.vc[informationstate] = {x: [0.,0.] for x in availableactions}
                self.vc[informationstate]['nm'] = [0.,0.] 
                self.state_visits[informationstate] = 0
                return self.rollout(informationstate, t, depth=depth)
            action = 'nm'
            weighted_v_max = 0
#             #Boltzman Selection
#             utils = [v[0] for v in self.vc[informationstate].values()]
#             boltzman = np.exp(self.c * np.asarray(utils))
#             probs = boltzman /sum(boltzman)
#             action = self.vc[informationstate].keys()[np.random.choice(len(utils), p=probs)]
            
            ##3 Uncomment for UCT
            for key, val in self.vc[informationstate].iteritems():
                if key != 'n':
                    v = val[0] 
                    correction = np.log(self.state_visits[informationstate] +1) / (val[1] + 1)
                    weighted_v = v + self.c * correction**.5
                    if weighted_v >= weighted_v_max:
                        weighted_v_max = weighted_v
                        action = key
            self._action_counts[action] += 1
            if action != 'nm':
                if action[1] not in self.infected:
                    self.infected.append(action[1])
                self._network_state[action].append(t)
                newinformationstate =  informationstate + str(action[0]) + '-' + str(action[1]) +',' # Update observation
            else:
                newinformationstate = informationstate + 'nm,' # Update observation
            thisis = copy.deepcopy(informationstate)
            self._remove_old(t)
            pdata = self._intrusion_criteria()
            if pdata < self.d:
                immediate_r = 0 
                alarm = True
                #print "action", action, "set off the alarm at", t
            else:
                immediate_r = self.get_reward(self._action_counts, action)
                alarm = self._sample_until_next_action(t) # Update state
            if ((not alarm) and (depth < 5)):
                self.empirical_transitions[newinformationstate].append(copy.deepcopy(self._network_state))
            newt = t+ self.gamma
            r = immediate_r + self.simulate(newinformationstate, newt, alarm, depth=depth+1)
            self.vc[informationstate][action][1] += 1
            self.vc[informationstate][action][0] += (r - self.vc[thisis][action][0]) \
                /float(self.vc[thisis][action][1])

            self.state_visits[thisis] += 1
            return r
                
            
        def _get_available_actions(self, infected):
            actions = []
            for u in infected:
                toadd = [(u,v) for v in self._hosts_arr if self.lambdaB[u,v] > 0]
                actions.extend(toadd)
            return actions
        
        def rollout(self, informationstate, t, alarm=False, depth=1):
            if alarm or (depth > self.lookahead):
                return 0
            else:
                availableactions = self._get_available_actions(self.infected)
                action = availableactions[np.random.choice(len(availableactions))]
                self._action_counts[action] += 1
                if action != 'nm':
                    if action[1] not in self.infected:
                        self.infected.append(action[1])
                    self._network_state[action].append(t)
                    newinformationstate = informationstate + str(action[0]) \
                        + '-' + str(action[1]) +',' # Update observation
                else:
                    newinformationstate = informationstate + 'nm,' # Update observation
                self._remove_old(t)
                pdata = self._intrusion_criteria()
                if pdata < self.d:
                    r = 0#self.rollout(informationstate, t, True)
                    alarm = True
                else:
                    r = self.get_reward(self._action_counts, action)
                    alarm = self._sample_until_next_action(t) # Update state
                if ((not alarm) and (depth < 5)):
                    self.empirical_transitions[newinformationstate].append(copy.deepcopy(self._network_state))
                newt = t + self.gamma
                return r + self.rollout(newinformationstate, newt, alarm, depth+1)
            
#         def _sample_until_next_action(self, t):
#             """
#             Draws samples until the attacker's next time to act.   First, it
#             samples from a distribution to determine if there is a message before
#             the next attacker action.  Then, conditional on the event occuring, it
#             samples the event time.
            
#              generates  a sample from
#             the exponential distirbuiton with parameter lambda truncated at T
#             """
#             deltat = 0
#             u = np.random.random()
#             benign_message = u < (1-np.exp(-1./self._event_rate * (self.gamma - deltat)))
#             # way slower
#             benign_messageold = u < expon.cdf(self.gamma - deltat, 
#                                                  scale = self._event_rate )
#             alarm = False
#             while ((benign_message) and (not alarm)):
#                 R = np.random.random()*(1-np.exp(-(self.gamma-deltat)*1./self._event_rate))
#                 change =-np.log(1-R)*self._event_rate
#                 deltat += change
#                 nextsender = np.random.choice(self._n_hosts, p=self._normed_rowsums)
#                 nextreceiver = np.random.choice(self._n_hosts, 
#                                                 p=self.lambdaB[nextsender] / 
#                                                 float(np.sum(self.lambdaB[nextsender])) )
#                 self._network_state[(nextsender, nextreceiver)].append(t + deltat )
#                 self._remove_old(t+deltat)
#                 alarm = self._intrusion_criteria() < self.d
#                 u2 = np.random.random()
#                 benign_message = u2 < (1-np.exp((-1./self._event_rate) * (self.gamma - deltat)))
#                 benign_messageold = u2 < expon.cdf(self.gamma - deltat, 
#                      scale = self._event_rate )
#             return alarm
                    
        def _sample_until_next_action(self, t):
            """
            Draws samples until the attacker's next time to act.   First, it
            samples from a distribution to determine if there is a message before
            the next attacker action.  Then, conditional on the event occuring, it
            samples the event time.
            
             generates  a sample from
            the exponential distirbuiton with parameter lambda truncated at T
            """
            deltat = 0
            u = np.random.random()
            #benign_message = np.random.random() < expon.cdf(self.gamma - deltat, 
            #                                                scale = self._event_rate )
            #way slower
            #benign_message = u < expon.cdf(self.gamma - deltat, 
            #                                                scale = self._event_rate )
            benign_message = u < (1-np.exp(-1./self._event_rate * (self.gamma - deltat)))
            #print(benign_message==benign_messagenew)

            alarm = False
            while ((benign_message) and (not alarm)):
                R = np.random.random()*(1-np.exp(-(self.gamma-deltat)*1./self._event_rate))
                change =-np.log(1-R)*self._event_rate
                deltat += change
                nextsender = np.random.choice(self._n_hosts, p=self._normed_rowsums)
                nextreceiver = np.random.choice(self._n_hosts, 
                                                p=self.lambdaB[nextsender] / 
                                                float(np.sum(self.lambdaB[nextsender])) )
                self._network_state[(nextsender, nextreceiver)].append(t + deltat )
                self._remove_old(t+deltat)
                alarm = self._intrusion_criteria() < self.d
                u2 = np.random.random()
                benign_message = u2 < (1-np.exp(-1./self._event_rate * (self.gamma - deltat)))
            return alarm
            
        def get_reward(self, counts, action):
            if action == 'nm':
                r = 0.
            elif counts[action] == 1:
                r = self.Rfirst[action]
            elif counts[action] > 1:
                if self.Radditional[action]==True:
                    r = self.Rfirst[action]
                else:
                    r = 0
            return r
                
        def get_strategy(self, nsteps, nmin):
            informationstate ='start,'
            infected=[0]
            moves = []
            for j in range(nsteps):
                #print(informationstate, infected)
                tinput = self.tbar +j*self.gamma
                next_move = self.get_action_at_state(informationstate, nmin, infected,tinput)
                if next_move != 'nm':
                    informationstate += str(next_move[0]) + '-' + str(next_move[1]) +','
                    if next_move[1] not in infected:
                        infected.append(next_move[1])
                else:
                    informationstate += next_move+','
                moves.append(next_move)
            return moves
        
        def sample_attacker_utility(self, attackerstrat):
            """
            Returns the attackers expected utility from playing attackerstrat
            and the defenders utility when an attacker is present.
            """
            self._draw_initial_state()
            t = self.tbar
            counts = defaultdict(int)
            alarm=False
            attackerreward  = 0
            for a in attackerstrat:
                if a != 'nm':
                    self._network_state[a].append(t)
                if self._intrusion_criteria() < self.d:
                    return attackerreward
                if not alarm:
                    counts[a] += 1
                    attackerreward  += self.get_reward(counts, a)
                    alarm = self._sample_until_next_action(t)
                    if alarm:
                        return attackerreward
                    else:
                        t += self.gamma
            return attackerreward
        
        def compute_attacker_eu(self, attackerstrat,N):
            sumus = 0
            for i in xrange(N):
                sumus += self.sample_attacker_utility(attackerstrat)
            return sumus/float(N)

            
#         def extract_strategy(self):
#             """
#             Using the entries of self.vc, this function returns
#             a list of moves in the order the attacker chooses his.
#             """
#             infostate = 'start,'
#             actions = []
#             while True:
#                 try:
#                     maxaction ='nm'
#                     maxvalue = 0
#                     for key, val in self.vc[infostate].iteritems():
#                         if val[0] > maxvalue:
#                             maxaction = key
#                             maxvalue = val[0]
#                     actions.append(maxaction)
#                     if maxaction != 'nm':
#                         infostate += str(maxaction[0]) + '-'+ str(maxaction[1]) +','
#                     else:
#                         infostate +='nm,'
#                     print(infostate)
#                 except KeyError:
#                     return actions
                    
            
            

In [59]:
params = {'Radditional' : np.array([[False,False,False],[False,False,True], [False,False,False]]),
          'Rfirst' : np.array([[0,0,0],[0,0,1], [0,0,0]]),
          'lambdaB' : np.array([[0,.5,0], [.5,0,.1],[.5,.5,0]]),
          'gamma' : 1.,
          'c' : 1000,
          'd' : -18,
          'tbar': 20,
          'tau': 5,
          'lrdetector' : False,
          'lookahead':30
          }
          

In [60]:
Model1.vc['start,0-1,']

{'nm': [4.7124735729386789, 473.0],
 (0, 1): [4.7130242825607072, 453.0],
 (1, 0): [4.7127071823204405, 362.0],
 (1, 2): [4.7074148296593181, 499.0]}

In [61]:
results = {}

In [62]:
Model1 = MCTSAttacker(**params)

In [63]:
results['experiment1'] = [copy.deepcopy(params), Model1.get_strategy(20,5000)]

nm
(0, 1)
(1, 2)
(1, 2)
(1, 2)
(1, 2)
nm
(1, 2)
(1, 2)
(1, 2)
(1, 2)
(1, 2)
nm
(1, 2)
(1, 2)
(1, 2)
(1, 2)
(1, 2)
(2, 0)
(1, 2)


In [71]:
params['gamma'] = .5
params['c'] = 20

In [72]:
Model2 = MCTSAttacker(**params)

In [None]:
results['experiment2'] = [copy.deepcopy(params), Model2.get_strategy(40,10000)]

In [67]:
Model1.compute_attacker_eu(results['experiment1'][1], 4000)

7.68075

In [68]:
Model2.compute_attacker_eu(results['experiment2'][1], 4000)

5.1387499999999999

In [70]:
Model2.compute_attacker_eu(os, 4000)

7.7395

In [327]:
Model2.gamma

0.5

In [None]:
params['gamma'] = 2.
Model3 = MCTSAttacker(**params)
results['experiment3'] = [copy.deepcopy(params), Model3.get_strategy(20,6000)]
Model3.compute_attacker_eu(results['experiment3'][1], 4000)

In [69]:
os = []
for x in results['experiment1'][1]:
    os.append(x)
    os.append('nm')

In [334]:
results

{'experiment1': [{'Radditional': array([[False, False, False],
          [False, False,  True],
          [False, False, False]], dtype=bool), 'Rfirst': array([[0, 0, 0],
          [0, 0, 1],
          [0, 0, 0]]), 'c': 10, 'd': -18, 'gamma': 1.0, 'lambdaB': array([[ 0. ,  0.5,  0. ],
          [ 0.5,  0. ,  0.1],
          [ 0.5,  0.5,  0. ]]), 'lookahead': 30, 'lrdetector': False, 'tau': 5, 'tbar': 20},
  [(0, 1),
   (1, 2),
   (1, 2),
   (1, 2),
   'nm',
   (1, 2),
   (1, 2),
   (1, 2),
   (1, 2),
   'nm',
   (1, 2),
   (2, 0),
   (1, 2),
   (1, 2),
   (1, 2),
   'nm',
   (1, 2),
   (0, 1),
   (1, 2),
   (1, 2)]],
 'experiment2': [{'Radditional': array([[False, False, False],
          [False, False,  True],
          [False, False, False]], dtype=bool), 'Rfirst': array([[0, 0, 0],
          [0, 0, 1],
          [0, 0, 0]]), 'c': 20, 'd': -18, 'gamma': 0.5, 'lambdaB': array([[ 0. ,  0.5,  0. ],
          [ 0.5,  0. ,  0.1],
          [ 0.5,  0.5,  0. ]]), 'lookahead': 30, 'lrdetecto

In [197]:
results['experiment2']

[{'Radditional': array([[False, False, False],
         [False, False,  True],
         [False, False, False]], dtype=bool), 'Rfirst': array([[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]), 'c': 10, 'd': -18, 'gamma': 0.5, 'lambdaB': array([[ 0. ,  0.5,  0. ],
         [ 0.5,  0. ,  0.1],
         [ 0.5,  0.5,  0. ]]), 'lookahead': 30, 'lrdetector': False, 'tau': 5, 'tbar': 20},
 [(0, 1),
  (1, 2),
  (1, 2),
  (1, 2),
  'nm',
  'nm',
  'nm',
  'nm',
  (1, 2),
  'nm',
  (2, 1),
  'nm',
  (1, 2),
  (1, 2),
  'nm',
  'nm',
  (1, 2),
  'nm',
  'nm',
  (1, 2)]]