In [2]:
import json

from collections import Counter
from itertools import permutations, product, chain, count

from threading import Thread
from multiprocessing import Process, cpu_count, Manager, Pool

import numpy as np
import pandas as pd
from tqdm.auto import trange, tqdm
from matplotlib import pyplot as plt

from environment import Env

In [3]:
def _pad_state(code):
    """Return ``code`` as a 3-character string left-padded with '0'.

    Same result as ``('000' + str(code))[-3:]``: values longer than 3
    characters keep only their last 3 characters.
    """
    return str(code).rjust(3, '0')[-3:]


data = pd.read_csv("SH_SDS_data_4.csv", index_col=0)
# State codes are presumably parsed as integers by read_csv (e.g. 22 for '022');
# restore the zero-padded 3-character string form used as keys elsewhere.
data['current_state'] = [_pad_state(cs) for cs in data['current_state']]
prob = pd.read_csv("SH_SDS_transition_matrix.csv", index_col=0)
prob.index = [_pad_state(idx) for idx in prob.index]

# Baseline results from a previous experiment. Deliberately NOT named
# `strategy_rewards`: that name is reserved for this notebook's own results
# (the Manager dict filled by the pool cell / the selective_sniper.json fallback).
# Reusing it here made the fallback loader's NameError guard never fire.
with open('result.json') as json_file:
    baseline_rewards = json.load(json_file)

# Mean reward per strategy key (values presumably lists of per-run rewards),
# ordered best-first. The sort is stable, so ties keep the JSON order.
mean_rewards = {k: np.mean(v) for k, v in baseline_rewards.items()}
ordered_rewards = dict(sorted(mean_rewards.items(), key=lambda item: -item[1]))

In [5]:
# Strategy keys are assumed to look like "('LLL', 'SSS')": a 3-char long-entry
# state at [2:5] and a 3-char short-entry state at [9:12] — TODO confirm
# against the key format written to result.json.
LONG_FIELD = slice(2, 5)
SHORT_FIELD = slice(9, 12)
N_FROM_EACH_END = 5  # states taken from the best end, then from the worst end


def _extend_unique(states, keys, field, limit):
    """Append ``key[field]`` for each key in order, skipping duplicates.

    Stops as soon as ``states`` holds ``limit`` entries. Mutates ``states``
    in place and also returns it for convenience.
    """
    for key in keys:
        if len(states) >= limit:
            break
        state = key[field]
        if state not in states:
            states.append(state)
    return states


ranked_keys = list(ordered_rewards)  # best mean reward first
worst_first = ranked_keys[::-1]

# Long candidates: long legs of the best strategies, topped up with the short
# legs of the worst ones (presumably: shorting there lost, so go long instead).
top_ten_long = _extend_unique([], ranked_keys, LONG_FIELD, N_FROM_EACH_END)
_extend_unique(top_ten_long, worst_first, SHORT_FIELD, 2 * N_FROM_EACH_END)

# Short candidates: the mirror image of the above.
bottom_ten_short = _extend_unique([], ranked_keys, SHORT_FIELD, N_FROM_EACH_END)
_extend_unique(bottom_ten_short, worst_first, LONG_FIELD, 2 * N_FROM_EACH_END)

In [6]:
# Every (long prefix, short prefix) combination of the ranked candidate lists,
# with prefix lengths 1..6 on each side: 6 x 6 = 36 strategies, long-major order.
all_lists = [
    (top_ten_long[:n_long], bottom_ten_short[:n_short])
    for n_long, n_short in product(range(1, 7), repeat=2)
]

In [None]:
def sniper_list(states_1, states_2, env, start_state):
    """Run one episode of a fixed state-triggered policy and collect rewards.

    At each step the integer code ``int(state[0].item())`` picks the action:
    0 if it is in ``states_1``, 1 if it is in ``states_2``, otherwise 2
    (presumably long / short / hold — confirm against ``Env``).

    Args:
        states_1: codes that trigger action 0.
        states_2: codes that trigger action 1.
        env: object whose ``step(action)`` returns ``(state, reward, done, info)``.
        start_state: the initial state (e.g. the value returned by ``env.reset()``).

    Returns:
        list of the rewards returned by every ``env.step`` call, in order.
    """
    collected = []
    state, finished = start_state, False
    while not finished:
        code = int(state[0].item())
        if code in states_1:
            action = 0
        elif code in states_2:
            action = 1
        else:
            action = 2
        state, reward, finished, _ = env.step(action)
        collected.append(reward)
    return collected

In [None]:
def imap_sniper(list_of_input_tuples, n_episodes=10, steps=2000):
    """Evaluate one (long_states, short_states) strategy; meant for Pool.imap.

    Builds a fresh ``Env(data, prob, steps=steps)``, translates the state
    strings to the codes ``sniper_list`` compares against (via ``env.mapping``),
    runs ``n_episodes`` episodes and averages each episode's final reward.

    The result is written to the module-level ``strategy_rewards`` mapping
    (the Manager dict created in the pool cell) under
    ``str(list_of_input_tuples)``, exactly as before. It is also returned, so
    callers may collect results without a shared dict.

    Args:
        list_of_input_tuples: ``(long_states, short_states)`` — two sequences of
            3-character state strings.
        n_episodes: number of episodes to average over.
        steps: value passed to ``Env(..., steps=...)``.

    Returns:
        ``(key, mean_final_reward)``.
    """
    env = Env(data, prob, steps=steps)
    mappings = env.mapping

    long_codes = [mappings[s] for s in list_of_input_tuples[0]]
    short_codes = [mappings[s] for s in list_of_input_tuples[1]]

    # Keep only each episode's last reward. Stacking full reward traces into a
    # 2-D array fails (ragged array) if episodes end after different step counts.
    final_rewards = []
    for _ in trange(n_episodes, leave=False, position=1):
        episode_rewards = sniper_list(long_codes, short_codes, env, env.reset())
        final_rewards.append(episode_rewards[-1])

    key = str(list_of_input_tuples)
    value = np.mean(final_rewards)
    strategy_rewards[key] = value
    return key, value

In [12]:
# Evaluate every candidate strategy in parallel. Workers write into a
# Manager-backed dict (`strategy_rewards`); it is copied to a plain dict
# before saving because a DictProxy is not JSON-serializable.
num_procs = max(1, cpu_count() - 1)  # leave one core free, but never request 0 workers
chunk = 1
print('Number of proccesses:', num_procs)
manager = Manager()
strategy_rewards = manager.dict()

print('Creating pool')

# The context manager guarantees the worker processes are torn down, even if a
# task raises; all results are consumed inside the block before it exits.
with Pool(processes=num_procs) as pool:
    print('Starting processing')
    # imap_unordered yields one result per input item regardless of chunksize,
    # so the progress total is the number of items, not the number of chunks.
    for _ in tqdm(pool.imap_unordered(imap_sniper, all_lists, chunksize=chunk), total=len(all_lists)):
        pass

# Copy out of the proxy before shutting the manager process down.
strategy_rewards = dict(strategy_rewards)
manager.shutdown()

with open('selective_sniper.json', 'w') as file:
    json.dump(strategy_rewards, file)

Number of proccesses: 15
Creating pool
Starting processing


HBox(children=(FloatProgress(value=0.0, max=36.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




TypeError: Object of type DictProxy is not JSON serializable

In [None]:
# Fallback for a fresh kernel where the (slow) pool cell was skipped:
# reload the results that cell saved to disk.
if 'strategy_rewards' not in globals():
    import json
    with open('selective_sniper.json') as saved_file:
        strategy_rewards = json.load(saved_file)

In [16]:
dict(sorted(strategy_rewards.items(), key=lambda pair: -pair[1]))

{"(['122', '111'], ['401', '400'])": 29.974609132586686,
 "(['122', '111'], ['401'])": 29.876032265813926,
 "(['122', '111'], ['401', '400', '210'])": 25.864675251198758,
 "(['122', '111', '502'], ['401', '400'])": 24.786415304914396,
 "(['122', '111', '502'], ['401'])": 24.6158441883519,
 "(['122', '111', '502', '402', '412', '121'], ['401'])": 23.081572279289436,
 "(['122', '111', '502', '402', '412', '121'], ['401', '400'])": 22.98482327201328,
 "(['122', '111', '502', '402'], ['401'])": 21.991270515708475,
 "(['122', '111', '502', '402'], ['401', '400'])": 21.95988320427346,
 "(['122', '111', '502', '402', '412'], ['401'])": 19.161891013775467,
 "(['122', '111', '502', '402', '412'], ['401', '400'])": 19.13620133753137,
 "(['122', '111'], ['401', '400', '210', '110'])": 17.687100060232495,
 "(['122', '111'], ['401', '400', '210', '110', '211'])": 17.100378021771615,
 "(['122', '111', '502'], ['401', '400', '210'])": 16.274406914803045,
 "(['122', '111', '502', '402'], ['401', '400'