In [54]:
import numpy as np
# Print arrays with two decimals, no scientific notation, and lines wide
# enough (150 chars) that a full row of time steps is not wrapped.
np.set_printoptions(
    precision=2,
    suppress=True,
    linewidth=150,
)

In [62]:
class AddictionWorld():
    """Trial-by-trial temporal-difference (TD) simulation with rewards and addictive substances.

    All state lives in four arrays of shape ``(n_trials, max_time)``, indexed
    as ``[trial, time]``:

    * ``reward``           -- reward delivered at each time step of each trial
    * ``addiction``        -- addiction term added to the prediction error
    * ``value``            -- learned value estimate (filled in by ``run_TD``)
    * ``prediction_error`` -- prediction error recorded by ``run_TD``

    Trial ranges are ``(start, stop)`` pairs used as a slice, so ``stop`` is
    exclusive and may be equal to ``n_trials``.
    """

    def __init__(self, max_time, n_trials, ETA=1, GAMMA=1, rewards=(), substances=(), verbose=False):
        """Allocate the state arrays and register the initial rewards/substances.

        Args:
            max_time: number of time steps per trial.
            n_trials: number of trials to simulate.
            ETA: learning rate applied to the prediction error.
            GAMMA: discount rate applied to the next time step's value.
            rewards: iterable of argument tuples, each unpacked into
                ``add_reward`` (i.e. ``(reward, time[, trials[, verbose]])``).
            substances: iterable of argument tuples, each unpacked into
                ``add_substance``
                (i.e. ``(reward, addiction, time[, trials[, verbose]])``).
            verbose: accepted for interface compatibility; it is not used by
                the constructor and is not forwarded to the ``add_*`` calls.
        """
        self.max_time = max_time
        self.n_trials = n_trials
        self.reward = np.zeros((self.n_trials, self.max_time))
        self.addiction = np.zeros((self.n_trials, self.max_time))
        self.value = np.zeros((self.n_trials, self.max_time))
        self.prediction_error = np.zeros((self.n_trials, self.max_time))
        self.learning_rate = ETA
        self.discount_rate = GAMMA

        for args in rewards:
            self.add_reward(*args)

        for args in substances:
            self.add_substance(*args)

    def _resolve_trials(self, trials, action):
        """Return the ``(start, stop)`` slice bounds for ``trials``.

        ``None`` means every trial. ``action`` is only used to word the error
        message (e.g. ``"add rewards"``).

        Raises:
            ValueError: if ``trials[1]`` is larger than ``n_trials``.
        """
        if trials is None:
            return (0, self.n_trials)
        # The upper bound is an exclusive slice end, so n_trials itself is valid.
        if trials[1] > self.n_trials:
            raise ValueError(f"Tried to {action} for trials {trials[0]}-{trials[1]} when there are only {self.n_trials}")
        return trials

    def add_reward(self, reward, time, trials=None, verbose=False):
        """Set ``reward`` at time step ``time`` for a range of trials.

        Args:
            reward: reward value to store.
            time: time-step index; must be smaller than ``max_time``.
            trials: ``(start, stop)`` slice bounds, or ``None`` for all trials.
            verbose: if true, print the arguments and the resulting array.

        Raises:
            ValueError: if ``time`` or ``trials`` is out of range.
        """
        if time >= self.max_time:
            raise ValueError(f"Tried to add reward on time {time} when max_time is {self.max_time}")

        trial_range = self._resolve_trials(trials, "add rewards")
        if verbose:
            print(f"Adding reward of {reward} on time {time} for trials {trial_range}")
        self.reward[trial_range[0]:trial_range[1], time] = reward
        print('Reward added')
        if verbose:
            print(self.reward)

    def add_substance(self, reward, addiction, time, trials=None, verbose=False):
        """Add a reward together with an addiction term at the same time step.

        The reward part is delegated to ``add_reward`` (which also performs
        the range validation); the addiction value is then written into the
        same ``[trials, time]`` cells of ``addiction``.

        Args:
            reward: reward value to store.
            addiction: addiction value to store.
            time: time-step index; must be smaller than ``max_time``.
            trials: ``(start, stop)`` slice bounds, or ``None`` for all trials.
            verbose: if true, print the arguments and the resulting arrays.

        Raises:
            ValueError: if ``time`` or ``trials`` is out of range.
        """
        self.add_reward(reward, time, trials, verbose)

        trial_range = self._resolve_trials(trials, "add rewards")
        if verbose:
            print(f"Adding addiction of {addiction} on time {time} for trials {trial_range}")
        self.addiction[trial_range[0]:trial_range[1], time] = addiction
        print('Substance added')
        if verbose:
            print(self.addiction)

    def clear_reward(self, time, trials=None, verbose=False):
        """Reset the reward at time step ``time`` to zero for a range of trials.

        Only ``reward`` is touched; ``addiction`` is left unchanged.

        Args:
            time: time-step index; must be smaller than ``max_time``.
            trials: ``(start, stop)`` slice bounds, or ``None`` for all trials.
            verbose: if true, print the arguments and the resulting array.

        Raises:
            ValueError: if ``time`` or ``trials`` is out of range.
        """
        if time >= self.max_time:
            raise ValueError(f"Tried to clear rewards on time {time} when max_time is {self.max_time}")

        trial_range = self._resolve_trials(trials, "clear rewards")
        if verbose:
            print(f"Clearing rewards on time {time} for trials {trial_range[0]}-{trial_range[1]}")
        self.reward[trial_range[0]:trial_range[1], time] = 0
        print('Reward cleared')
        if verbose:
            print(self.reward)

    def run_TD(self):
        """Simulate all trials, updating values and recording prediction errors.

        For every trial and every time step ``t`` except the last one:

            delta = reward[t] + discount_rate * value[t + 1] - value[t]
            prediction_error[t] = max(delta + addiction[t], addiction[t])

        so the prediction error never drops below the addiction term. The
        value for the *next* trial at the same time step is then set to
        ``value[t] + learning_rate * prediction_error[t]``.

        The last time step is never visited (it has no successor), so its
        value and prediction error stay at zero, as does the whole first row
        of ``value``.
        """
        for trial in range(self.n_trials):
            for t in range(self.max_time-1):
                expected_value = self.value[trial][t]
                actual_value = self.reward[trial][t] + self.discount_rate * (self.value[trial][t+1])
                prediction_error = np.max([(actual_value - expected_value + self.addiction[trial][t]),
                                           self.addiction[trial][t]])
                self.prediction_error[trial][t] = prediction_error
                # Learning carries over to the following trial; the final
                # trial has no successor row to write into.
                if trial < (self.n_trials - 1):
                    self.value[trial+1][t] = expected_value + (self.learning_rate * prediction_error)
        print('TD learning complete')


In [63]:
# 50 trials of 10 time steps each: a plain reward of 10 at t=4 and a
# substance (reward 5, addiction 0.5) at t=8, both present on every trial.
aw = AddictionWorld(
    max_time=10,
    n_trials=50,
    ETA=0.5,
    GAMMA=0.5,
    rewards=[(10, 4)],
    substances=[(5, 0.5, 8)],
)
aw.run_TD()

# Learned values, one row per trial and one column per time step.
print(aw.value)

Reward added
Reward added
Substance added
TD learning complete
[[ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.    5.    0.    0.    0.    2.75  0.  ]
 [ 0.    0.    0.    1.25  7.5   0.    0.    0.69  4.12  0.  ]
 [ 0.    0.    0.31  2.5   8.75  0.    0.17  1.38  4.81  0.  ]
 [ 0.    0.08  0.78  3.44  9.38  0.04  0.43  1.89  5.16  0.  ]
 [ 0.02  0.23  1.25  4.06  9.7   0.13  0.69  2.23  5.41  0.  ]
 [ 0.07  0.43  1.64  4.46  9.88  0.24  0.9   2.47  5.66  0.  ]
 [ 0.14  0.62  1.93  4.7  10.    0.34  1.07  2.65  5.91  0.  ]
 [ 0.23  0.8   2.14  4.85 10.09  0.44  1.2   2.8   6.16  0.  ]
 [ 0.31  0.93  2.28  4.95 10.15  0.52  1.3   2.94  6.41  0.  ]
 [ 0.39  1.04  2.38  5.01 10.21  0.58  1.38  3.07  6.66  0.  ]
 [ 0.45  1.11  2.44  5.06 10.25  0.64  1.46  3.2   6.91  0.  ]
 [ 0.51  1.17  2.49  5.09 10.28  0.68  1.53  3.33  7.16  0.  ]
 [ 0.54  1.2   2.52  5.12 10.31  0.72  1.6   3.45  7.41  0.  ]
 [ 0.57  1.23  2.54  5.14 10.34  0.76  1.66  3.58  7.66

In [61]:
# Show the prediction errors recorded by run_TD (rows = trials, columns = time steps).
print(aw.prediction_error)

[[ 0.    0.    0.    0.   10.    0.    0.    0.    5.5   0.  ]
 [ 0.    0.    0.    2.5   5.    0.    0.    1.38  2.75  0.  ]
 [ 0.    0.    0.62  2.5   2.5   0.    0.34  1.38  1.38  0.  ]
 [ 0.    0.16  0.94  1.88  1.25  0.09  0.52  1.03  0.69  0.  ]
 [ 0.04  0.31  0.94  1.25  0.65  0.17  0.52  0.69  0.5   0.  ]
 [ 0.1   0.39  0.78  0.79  0.37  0.21  0.43  0.47  0.5   0.  ]
 [ 0.15  0.39  0.59  0.48  0.24  0.21  0.33  0.36  0.5   0.  ]
 [ 0.17  0.34  0.41  0.3   0.17  0.19  0.26  0.3   0.5   0.  ]
 [ 0.17  0.27  0.28  0.19  0.13  0.16  0.2   0.28  0.5   0.  ]
 [ 0.15  0.21  0.19  0.13  0.11  0.13  0.17  0.26  0.5   0.  ]
 [ 0.13  0.15  0.13  0.09  0.09  0.11  0.15  0.26  0.5   0.  ]
 [ 0.1   0.11  0.09  0.07  0.07  0.09  0.14  0.25  0.5   0.  ]
 [ 0.08  0.08  0.06  0.05  0.06  0.08  0.13  0.25  0.5   0.  ]
 [ 0.06  0.05  0.04  0.04  0.05  0.07  0.13  0.25  0.5   0.  ]
 [ 0.04  0.04  0.03  0.03  0.04  0.07  0.13  0.25  0.5   0.  ]
 [ 0.03  0.03  0.02  0.03  0.04  0.07  0.13  0.25  0.5 