In [1]:
import json

from multiprocessing import cpu_count, Manager, Pool

import numpy as np
import pandas as pd
from tqdm.auto import trange, tqdm
from matplotlib import pyplot as plt

# Disables no GPU warning
# Comment out to see warning or use GPU
import jax
jax.config.update('jax_platform_name', 'cpu')

from environment import Env

In [2]:
data = pd.read_csv("SH_SDS_data_4.csv", index_col=0)
data.current_state = [('000'+str(cs))[-3:] for cs in data.current_state]
prob = pd.read_csv("SH_SDS_transition_matrix.csv", index_col=0)
prob.index = [('000'+str(idx))[-3:] for idx in prob.index]

with open('sniper_pairs.json') as json_file:
    strategy_rewards = json.load(json_file)

ordered_rewards = {k: np.mean(v) for k, v in sorted(strategy_rewards.items(), key=lambda item: -np.mean(item[1]))}

In [3]:
entries = 16

top_ten_long = list()
for key in list(ordered_rewards.keys()):
    if len(top_ten_long) == entries//2:
        break
    elif key[2:5] not in top_ten_long:
        top_ten_long.append(key[2:5])
    
for key in list(ordered_rewards.keys())[::-1]:
    if len(top_ten_long) == entries:
        break
    elif key[9:12] not in top_ten_long:
        top_ten_long.append(key[9:12])
    
bottom_ten_short = list()
for key in list(ordered_rewards.keys()):
    if len(bottom_ten_short) == entries//2:
        break
    elif key[9:12] not in bottom_ten_short:
        bottom_ten_short.append(key[9:12])
    
for key in list(ordered_rewards.keys())[::-1]:
    if len(bottom_ten_short) == entries:
        break
    elif key[2:5] not in bottom_ten_short:
        bottom_ten_short.append(key[2:5])
        
print(top_ten_long)
print(bottom_ten_short)

['122', '111', '502', '402', '412', '400', '112', '120', '121', '221', '310', '210', '200', '311', '312', '411']
['401', '400', '210', '110', '211', '411', '121', '412', '502', '301', '212', '310', '300', '020', '010', '402']


In [4]:
for val in top_ten_long:
    if val in bottom_ten_short:
        top_ten_long.remove(val)
        bottom_ten_short.remove(val)
        
for val in bottom_ten_short:
    if val in top_ten_long:
        top_ten_long.remove(val)
        bottom_ten_short.remove(val)
        
print(top_ten_long)
print(bottom_ten_short)

['122', '111', '112', '120', '221', '210', '200', '311', '312']
['401', '210', '110', '211', '301', '212', '300', '020', '010']


In [5]:
all_lists = list()
for len1 in range(1, 7):
    long_states = top_ten_long[:len1]
    for len2 in range(1, 7):
        short_states = bottom_ten_short[:len2]
        all_lists.append((long_states, short_states))

In [6]:
def sniper_list(states_1, states_2, env, start_state):
    state = start_state
    done = False
    all_rewards = list()
    while not done:
        if int(state[0].item()) in states_1:
            state, reward, done, _ = env.step(0)
        elif int(state[0].item()) in states_2:
            state, reward, done, _ = env.step(1)
        else:
            state, reward, done, _ = env.step(2)
        all_rewards.append(reward)
    return all_rewards

In [7]:
def imap_sniper_list(list_of_input_tuples):
    env = Env(data, prob, steps=2000)
    mappings = env.mapping
    
    long = [mappings[s] for s in list_of_input_tuples[0]]
    short = [mappings[s] for s in list_of_input_tuples[1]]
    rewards = list()
    for _ in trange(10, leave=False, position=1):
        rewards.append(sniper_list(long, short, env, env.reset()))
    
    final = np.array(rewards)
    
    strategy_rewards[str(list_of_input_tuples)] = final[:, -1].mean()

In [8]:
num_procs = cpu_count()-1
chunk = 1
print('Number of processes:', num_procs)

procs = list()
manager = Manager()
strategy_rewards = manager.dict()

print('Creating pool')
pool = Pool(processes=num_procs)

print('Start processing')
for _ in tqdm(pool.imap_unordered(imap_sniper_list, all_lists, chunksize=chunk), total=len(all_lists)//chunk):
    pass

strategy_rewards = dict(strategy_rewards)
with open('selective_sniper.json', 'w') as file:
    json.dump(strategy_rewards, file)

Number of processes: 15
Creating pool
Start processing


HBox(children=(FloatProgress(value=0.0, max=36.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [9]:
try:
    strategy_rewards
except NameError:
    import json
    with open('selective_sniper.json') as file:
        strategy_rewards = json.load(file)

In [10]:
{k: v for k, v in sorted(strategy_rewards.items(), key=lambda item: -item[1])}

{"(['122', '111', '112', '120', '221', '210'], ['401'])": 33.111564672819085,
 "(['122', '111', '112', '120', '221', '210'], ['401', '210'])": 33.111564672819085,
 "(['122', '111', '112', '120', '221'], ['401'])": 32.894692980558354,
 "(['122', '111', '112', '120'], ['401'])": 31.330164300750123,
 "(['122', '111', '112', '120', '221', '210'], ['401', '210', '110', '211', '301', '212'])": 28.974188931609394,
 "(['122', '111', '112', '120', '221', '210'], ['401', '210', '110', '211', '301'])": 28.934693074742547,
 "(['122', '111', '112', '120', '221'], ['401', '210'])": 26.776223965081453,
 "(['122', '111', '112', '120'], ['401', '210'])": 26.44736883735852,
 "(['122', '111', '112', '120', '221', '210'], ['401', '210', '110'])": 26.346161790979664,
 "(['122', '111'], ['401'])": 25.57925792033755,
 "(['122', '111', '112'], ['401'])": 25.57925792033755,
 "(['122', '111', '112', '120', '221', '210'], ['401', '210', '110', '211'])": 25.1075768246358,
 "(['122', '111', '112', '120', '221'], [