In [1]:
import json
from itertools import permutations
from multiprocessing import cpu_count, Manager, Pool

import numpy as np
import pandas as pd
from tqdm.auto import tqdm, trange
from matplotlib import pyplot as plt

# Disables no GPU warning
# Comment out to see warning or use GPU
import jax
jax.config.update('jax_platform_name', 'cpu')

from environment import Env

In [2]:
data = pd.read_csv("SH_SDS_data_4.csv", index_col=0)
data.current_state = [('000'+str(cs))[-3:] for cs in data.current_state]
prob = pd.read_csv("SH_SDS_transition_matrix.csv", index_col=0)
prob.index = [('000'+str(idx))[-3:] for idx in prob.index]

In [3]:
def sniper(state_1, state_2, env, start_state):
    state = start_state
    done = False
    all_rewards = list()
    while not done:
        if state[0] == state_1:
            state, reward, done, _ = env.step(0)
        elif state[0] == state_2:
            state, reward, done, _ = env.step(1)
        else:
            state, reward, done, _ = env.step(2)
        all_rewards.append(reward)
    return all_rewards

In [4]:
def imap_sniper(states):
    env = Env(data, prob, steps=2000)
    mappings = env.mapping
    
    long = mappings[states[0]]
    short = mappings[states[1]]
    
    rewards = list()
    for _ in trange(5, leave=False, position=1):
        rewards.append(sniper(long, short, env, env.reset()))
        
    final = np.array(rewards)
    
    strategy_rewards[str((states[0], states[1]))] = final[:, -1].mean()

In [5]:
# This takes a while to run, ~2.5 hrs on 15 processes

pairs = list(permutations(Env.get_mapping().keys(), 2))

num_procs = cpu_count()-1
chunk = 1
print('Number of processes:', num_procs)

procs = list()
manager = Manager()
strategy_rewards = manager.dict()

print('Creating pool')
pool = Pool(processes=num_procs)

print('Start processing')
for _ in tqdm(pool.imap_unordered(imap_sniper, pairs, chunksize=chunk), total=len(pairs)//chunk):
    pass

strategy_rewards = dict(strategy_rewards)

# Order strategy_rewards from largest to smallest
strategy_rewards = {k: v for k, v in sorted(strategy_rewards.items(), key=lambda item: -item[1])}
with open('sniper_pairs.json', 'w') as file:
    json.dump(strategy_rewards, file)

Number of processes: 15
Creating pool
Start processing


HBox(children=(FloatProgress(value=0.0, max=2862.0), HTML(value='')))




In [6]:
try:
    strategy_rewards
except NameError:
    import json
    with open('sniper_pairs.json') as file:
        strategy_rewards = json.load(file)

In [7]:
key = list(strategy_rewards.keys())[0]
s1 = key[2:5]
s2 = key[9:12]
rewards = []
env = Env(data, prob, steps=2000)
mappings = env.mapping
plt.figure(figsize=(15, 10))
for _ in range(10):
    plt.plot(sniper(mappings[s1], mappings[s2], env, env.reset()), 'g-')
    plt.plot(sniper(mappings[s2], mappings[s1], env, env.reset()), 'r*')
    
plt.title(f'({s1}, {s2})')

State: (6, , )


KeyError: '6, '

<Figure size 1080x720 with 0 Axes>

In [None]:
key = list(strategy_rewards.keys())[-1]
s1 = key[9:12]
s2 = key[2:5]
rewards = []
env = Env(data, prob, steps=2000)
mappings = env.mapping
plt.figure(figsize=(15, 10))
for _ in range(10):
    plt.plot(sniper(mappings[s1], mappings[s2], env, env.reset()), 'g-')
    plt.plot(sniper(mappings[s2], mappings[s1], env, env.reset()), 'r*')
    
plt.title(f'({s1}, {s2})')

In [None]:
key = list(strategy_rewards.keys())[1500]
s1 = key[2:5]
s2 = key[9:12]
print(f'State: ({s1}, {s2})')
rewards = []
env = Env(data, prob, steps=2000)
mappings = env.mapping
plt.figure(figsize=(15, 10))
for _ in range(10):
    plt.plot(sniper(mappings[s1], mappings[s2], env, env.reset()), 'g-')
    plt.plot(sniper(mappings[s2], mappings[s1], env, env.reset()), 'r*')
    
plt.title(f'({s1}, {s2})')