# W track modeling

- observation space is the ordered (previous well, current well) pairs of distinct wells, i.e. the 2-permutations of the 3 wells = 6 states
- action space is the 3 wells
- reward table is the observation space × action space, with values that reinforce the alternation rule

In [431]:
import itertools
import random
import numpy as np

In [387]:
# Wells of the W-track. All three are tuples so that membership tests
# (`well in outer_wells`, `well in home_well`) compare whole well names.
all_wells = ('A', 'B', 'C')
outer_wells = ('A', 'C')
# The trailing comma is required: ('B') is just the string 'B', on which `in`
# performs substring matching instead of element membership.
home_well = ('B',)

In [388]:
# states: every ordered pair of two distinct wells (2-permutations of all_wells)
observation_space = [*itertools.permutations(all_wells, 2)]

# actions: from any state the agent may pick any of the wells
action_space = all_wells

print(
    'observation space {}\n'.format(observation_space),
    'action space {}'.format(action_space),
    sep='\n',
)

observation space [('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]

action space ('A', 'B', 'C')


In [389]:
# enumerate every three-visit well sequence; the rewarded subsets below
# are later used to populate the reward table
sequences = list(itertools.product(all_wells, repeat=3))

# outbound: outer well -> home well -> the other outer well
outbound_rewarded_sequences = [
    (first, middle, last)
    for first, middle, last in sequences
    if middle in home_well
    and first in outer_wells
    and last in outer_wells
    and first != last
]

# inbound: reach an outer well from a different well, then go to the home well
inbound_rewarded_sequences = [
    (first, middle, last)
    for first, middle, last in sequences
    if middle in outer_wells and first != middle and last in home_well
]

print('outbound rewarded sequences {}\n'.format(outbound_rewarded_sequences))
print('inbound rewarded sequences {}\n'.format(inbound_rewarded_sequences))

outbound rewarded sequences [('A', 'B', 'C'), ('C', 'B', 'A')]

inbound rewarded sequences [('A', 'C', 'B'), ('B', 'A', 'B'), ('B', 'C', 'B'), ('C', 'A', 'B')]



In [390]:
def make_reward_table(observations=None, actions=None, rewarded_sequences=None):
    """Build the deterministic transition/reward lookup for the W-track task.

    Parameters
    ----------
    observations : iterable of tuple, optional
        States to tabulate. Each state is a tuple of wells whose last element
        is the well the agent is currently at. Defaults to the notebook-level
        ``observation_space``.
    actions : iterable, optional
        Wells the agent may choose from every state. Defaults to the
        notebook-level ``action_space``.
    rewarded_sequences : iterable of sequences, optional
        Well sequences ``state + (action,)`` that earn a reward of 1. Defaults
        to ``outbound_rewarded_sequences + inbound_rewarded_sequences``.

    Returns
    -------
    dict
        ``table[state][action]`` is ``{'next_state': ..., 'reward': 0 or 1}``.
    """
    if observations is None:
        observations = observation_space
    if actions is None:
        actions = action_space
    if rewarded_sequences is None:
        rewarded_sequences = outbound_rewarded_sequences + inbound_rewarded_sequences
    rewarded = {tuple(seq) for seq in rewarded_sequences}

    reward_table = {}
    for obsv in observations:
        current_well = obsv[-1]
        reward_table[obsv] = {}
        for act in actions:
            if act == current_well:
                # Re-visiting the current well leaves the state unchanged.
                # No explicit penalty is applied: the reward is simply 0
                # unless the sequence is listed as rewarded.
                next_state = obsv
            else:
                next_state = (current_well, act)
            # (act,) keeps the well name whole; tuple(act) would split a
            # multi-character name into its individual characters.
            reward = 1 if obsv + (act,) in rewarded else 0
            reward_table[obsv][act] = {'next_state': next_state, 'reward': reward}
    return reward_table

In [391]:
# Build the lookup once; the learning loop reads 'reward' and 'next_state' from it.
reward_table = make_reward_table()

In [392]:
# initialize q table to zeros
def init_q_table():
    """Return a fresh Q-table ``{state: {action: 0}}``.

    One entry is created for every state in ``observation_space``, each
    holding its own dict with a zero value for every action in
    ``action_space``.
    """
    return {
        obsv: {act: 0 for act in action_space}
        for obsv in observation_space
    }

In [427]:
def check_q_table(q_table, rewarded_sequences=None):
    """Check whether acting greedily on ``q_table`` always follows the W rules.

    For every state the highest-valued action is selected (on ties, the first
    such action in the dict's iteration order) and appended to the state. The
    table is consistent when every resulting sequence is a rewarded one. An
    empty ``q_table`` is trivially consistent.

    Parameters
    ----------
    q_table : dict
        ``{state: {action: q_value}}`` with tuple states.
    rewarded_sequences : iterable of sequences, optional
        Sequences regarded as following the rules. Defaults to
        ``outbound_rewarded_sequences + inbound_rewarded_sequences``.

    Returns
    -------
    bool
        True if the greedy policy is consistent. The verdict is also printed.
    """
    if rewarded_sequences is None:
        rewarded_sequences = outbound_rewarded_sequences + inbound_rewarded_sequences
    rewarded = {tuple(seq) for seq in rewarded_sequences}
    # (best_action,) keeps the action name whole; list(best_action) would
    # split a multi-character name into its individual characters.
    greedy_sequences = {
        tuple(state) + (max(actions, key=actions.get),)
        for state, actions in q_table.items()
    }
    check = greedy_sequences <= rewarded
    if check:
        print('q_table consistent with W-TRACK rules')
    else:
        print('q_table NOT consistent with W-TRACK rules')
    return check

In [428]:
%%time

# Tabular Q-learning over the W-track reward table, run as one continuous
# walk: the state is drawn at random once and never reset inside the loop.
# NOTE(review): no random seed is set, so the learned table differs per run.

# alpha weights the new estimate against the old q value in the update below
alpha = 0.4
# gamma discounts the best q value available from the next state
gamma = 0.7
# epsilon is the probability of taking a random (exploratory) action
epsilon = 0.5

q_table = init_q_table()
state = random.choice(observation_space)
# reward received at each step, in order
performance_indicator_func = []

# range(1, 100) yields 99 steps; the counter i itself is not used
for i in range(1, 100):
    # act
    if random.uniform(0, 1) < epsilon:
        # Explore action space
        action = random.choice(all_wells)
    else:
        # Exploit: highest-valued action for this state (max keeps the first on ties)
        action = max(q_table[state], key=lambda act: q_table[state][act])
    # gather
    reward = reward_table[state][action]['reward']
    old_qvalue = q_table[state][action]
    next_state = reward_table[state][action]['next_state']
    next_max_action = max(q_table[next_state], key=lambda act: q_table[next_state][act])
    next_max_qvalue = q_table[next_state][next_max_action]
    # update
    # blend the old value with the bootstrapped target reward + gamma * max_a' Q(next_state, a')
    new_value = (1 - alpha) * old_qvalue + alpha * (reward + gamma * next_max_qvalue)
    # stored values are rounded to 4 decimal places
    q_table[state][action] = round(new_value,4)
    # iterate
    # same lookup as next_state above
    state = reward_table[state][action]['next_state']
    #save
    performance_indicator_func.append(reward)


CPU times: user 892 µs, sys: 96 µs, total: 988 µs
Wall time: 1.02 ms


In [429]:
q_table

{('A', 'B'): {'A': 1.2692, 'B': 1.253, 'C': 2.6036},
 ('A', 'C'): {'A': 0.6428, 'B': 1.5402, 'C': 0.4313},
 ('B', 'A'): {'A': 0.1823, 'B': 2.6246, 'C': 0.7329},
 ('B', 'C'): {'A': 0.3312, 'B': 2.6333, 'C': 1.3236},
 ('C', 'A'): {'A': 0, 'B': 1.9477, 'C': 0},
 ('C', 'B'): {'A': 2.5285, 'B': 0.8155, 'C': 0.6145}}

In [430]:
check_q_table(q_table)

q_table consistent with W-TRACK rules


# This wasn't supposed to work!? 