In [10]:
from queue import Queue
from threading import Thread
from itertools import permutations

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from matplotlib import pyplot as plt

from environment import Env

In [3]:
def _pad_state(code):
    """Render a state code as a 3-character string.

    The value is stringified, left-padded with '0' to width 3, and truncated
    to its last 3 characters (e.g. 5 -> '005', 121 -> '121').
    """
    return str(code).rjust(3, '0')[-3:]


# Observations: state codes are normalised to 3-character zero-padded strings
# (presumably they lose leading zeros when parsed from CSV — TODO confirm).
data = pd.read_csv("SH_SDS_data_4.csv", index_col=0)
data['current_state'] = [_pad_state(code) for code in data['current_state']]

# Transition matrix: its row labels get the same normalisation.
prob = pd.read_csv("SH_SDS_transition_matrix.csv", index_col=0)
prob.index = [_pad_state(label) for label in prob.index]

In [4]:
def sniper(state_1, state_2, env, start_state):
    """Run one episode of a fixed two-state policy and return its reward trace.

    At every step the first element of the current observation is compared
    with ``state_1`` and then ``state_2``:

    * action 0 if it equals ``state_1``,
    * action 1 if it equals ``state_2``,
    * action 2 otherwise.

    The episode continues until ``env.step`` reports ``done``.

    Args:
        state_1: value of ``state[0]`` that triggers action 0.
        state_2: value of ``state[0]`` that triggers action 1.
        env: environment whose ``step(action)`` returns a 4-tuple
            ``(state, reward, done, info)``.
        start_state: initial observation, e.g. the value of ``env.reset()``.

    Returns:
        List of the reward returned by each ``env.step`` call, in order.
    """
    def choose_action(observation):
        # state_1 is checked first, so it takes precedence if state_1 == state_2.
        if observation[0] == state_1:
            return 0
        if observation[0] == state_2:
            return 1
        return 2

    rewards = []
    observation, finished = start_state, False
    while not finished:
        observation, step_reward, finished, _info = env.step(choose_action(observation))
        rewards.append(step_reward)
    return rewards

In [17]:
def thread_sniper(pairs, queue, n_episodes=10, bad_pairs_out=None):
    """Evaluate the `sniper` policy for each (p1, p2) pair and post results to `queue`.

    Meant to be used as a `threading.Thread` target. Each call builds its own
    `Env`, so the environment instance is not shared with other calls.

    Args:
        pairs: iterable of (p1, p2) tuples whose elements are keys of `env.mapping`.
        queue: receives exactly one dict ``{(p1, p2): [last-step reward per episode]}``.
            The dict is put even if an unexpected exception aborts the loop,
            so a consumer blocking on ``queue.get()`` does not hang.
            In that case the dict holds only the pairs finished so far.
        n_episodes: number of episodes run per pair (default 10).
        bad_pairs_out: optional list. If given, every pair for which at least
            one episode raised ``KeyError`` is appended to it once.
    """
    env = Env(data, prob, 0, [1000, -500], 1000)
    # Use this env's own mapping instead of the notebook-global `mappings`,
    # so the function does not depend on a variable defined in another cell.
    mapping = env.mapping
    strategy_rewards = dict()
    bad_pairs = list()

    try:
        for p1, p2 in pairs:
            episode_rewards = list()
            failed = False
            for _ in range(n_episodes):
                start = env.reset()
                try:
                    # Keep only the reward of the episode's final step.
                    episode_rewards.append(sniper(mapping[p1], mapping[p2], env, start)[-1])
                except KeyError:
                    failed = True
            if failed:
                # Record each failing pair once, not once per failing episode.
                bad_pairs.append((p1, p2))
            strategy_rewards[(p1, p2)] = episode_rewards
    finally:
        if bad_pairs_out is not None:
            bad_pairs_out.extend(bad_pairs)
        queue.put(strategy_rewards)

In [None]:
# Threaded evaluation of every ordered pair of distinct states using `thread_sniper`.
# NOTE(review): `sniper` appears to be pure Python. Under the CPython GIL these threads
# most likely interleave rather than run in parallel; a process pool may be needed
# for a real speed-up — TODO measure.
N_THREADS = 8

env = Env(data, prob, 0, [1000, -500], 1000)
mappings = env.mapping
pairs = list(permutations(mappings.keys(), 2))
q = Queue()

# Stride the pairs across threads (pairs[i::N_THREADS]). Every pair is then assigned,
# even when len(pairs) is not a multiple of N_THREADS.
threads = [
    Thread(target=thread_sniper, args=(pairs[i::N_THREADS], q))
    for i in range(N_THREADS)
]

for t in threads:
    t.start()

for t in threads:
    t.join()

# All threads have finished, so the queue holds every result dict they posted.
# Drain it and merge the dicts.
strategy_rewards_threaded = dict()
while not q.empty():
    strategy_rewards_threaded.update(q.get_nowait())

In [None]:
# Sequential evaluation of every ordered pair of distinct states.
# Each pair gets N_EPISODES episodes of `sniper`; the last-step reward of each
# episode is kept.
# `base_reward` was not defined anywhere in this notebook (NameError on a fresh
# kernel). 0 matches the base reward passed by every other Env(...) call here.
BASE_REWARD = 0
N_EPISODES = 10

env = Env(data, prob, BASE_REWARD, [1000, -500], 1000)
mappings = env.mapping
strategy_rewards = dict()
pairs = list(permutations(mappings.keys(), 2))
bad_pairs = list()
for s1, s2 in tqdm(pairs):
    episode_rewards = list()
    failed = False
    for _ in tqdm(range(N_EPISODES), desc=f'({s1}, {s2})', leave=False, position=1):
        start = env.reset()
        try:
            episode_rewards.append(sniper(mappings[s1], mappings[s2], env, start)[-1])
        except KeyError:
            failed = True
    if failed:
        # Record each failing pair once, not once per failing episode.
        bad_pairs.append((s1, s2))
    strategy_rewards[(s1, s2)] = episode_rewards

In [None]:
# Number of distinct (s1, s2) pairs for which at least one episode raised KeyError.
# The set also de-duplicates pairs appended once per failing episode.
len(set(bad_pairs))

In [None]:
# Strategies ranked best-first by the mean of their per-episode last-step rewards.
# Pairs with no successful episodes (empty lists) are left out: np.mean([]) is NaN,
# and NaN sort keys make the resulting order arbitrary.
dict(sorted(
    ((pair, rewards) for pair, rewards in strategy_rewards.items() if rewards),
    key=lambda item: np.mean(item[1]),
    reverse=True,
))

In [None]:
# Strategies ranked worst-first by the mean of their per-episode last-step rewards.
# Pairs with no successful episodes (empty lists) are left out: np.mean([]) is NaN,
# and NaN sort keys make the resulting order arbitrary.
dict(sorted(
    ((pair, rewards) for pair, rewards in strategy_rewards.items() if rewards),
    key=lambda item: np.mean(item[1]),
))

In [None]:
# Collect complete per-step reward traces for 20 episodes of the
# (state_1='121', state_2='401') strategy.
# The last Env argument is 5000 here versus 1000 in other Env(...) calls in this
# notebook — presumably a longer episode; TODO confirm.
rewards = list()
env = Env(data, prob, 0, [1000, -500], 5000)
mappings = env.mapping
for _episode in range(20):
    trace = sniper(mappings['121'], mappings['401'], env, env.reset())
    rewards.append(trace)

In [None]:
# Compare per-step reward traces of a single episode for both orderings of
# the ('122', '401') state pair.
env = Env(data, prob, 0, [1000, -500], 3000)
# Take `mappings` from this env so the cell does not depend on an earlier cell having run.
mappings = env.mapping

fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(
    sniper(mappings['122'], mappings['401'], env, env.reset()),
    label="state_1='122', state_2='401'",
)
ax.plot(
    sniper(mappings['401'], mappings['122'], env, env.reset()),
    'r-',
    label="state_1='401', state_2='122'",
)
ax.set(title="Sniper strategy: per-step reward over one episode", xlabel="step", ylabel="reward")
ax.legend();