In [1]:
import os

# NOTE: the two star imports stay first and in their original order so that
# name shadowing between them (and the explicit imports below) is unchanged.
from Utils.train_batched_A2C import *
from Utils.inspection_plots import *
from AC_modules.BatchedA2C import SpatialA2C
import AC_modules.Networks as net
import torch

from Env.test_env import Sandbox

In [2]:
# Value forwarded to the environment as `res`; the model-selection cell also
# reuses it as the networks' `linear_size`.
RESOLUTION = 16
game_params = {"res": RESOLUTION}

In [3]:
# Pick the network architecture and build the hyper-parameter dict `HPs`
# that is later expanded into SpatialA2C(**HPs).
#   0 -> separate actor and critic networks (net.SpatialNet / net.CriticNet)
#   1 -> net.SharedNet with `shared=True` (presumably a common spatial trunk
#        feeding actor and critic heads -- confirm in AC_modules.Networks)
model_number = 1
model_names = {0:"Critic-semi-conv", 1:'shared-architecture'}

if model_number not in model_names:
    # Fail fast: previously this case only printed a message (which wrongly
    # listed 0 as the only option) and left `HPs` undefined, so the error
    # surfaced later as an unrelated NameError.
    raise ValueError(
        "Model number {} not available. Choose one of {}.".format(
            model_number, sorted(model_names)
        )
    )

print(model_names[model_number]+" selected.")

if model_number == 0:
    actor_model = net.SpatialNet
    critic_model = net.CriticNet
    actor_dict = {"in_channels":2, 'linear_size':RESOLUTION}
    critic_dict = {"in_channels":2, 'linear_size':RESOLUTION}

    HPs = dict(gamma=0.99, n_steps=5, H=1e-3,
               actor_model=actor_model, critic_model=critic_model,
               actor_dict=actor_dict, critic_dict=critic_dict)
else:  # model_number == 1
    n_channels = 32
    spatial_model = net.SharedNet
    spatial_dict = {'in_channels':2, 'n_channels':n_channels}
    # Actor and critic settings are identical, but kept as two separate dicts
    # so that neither side can mutate the other's configuration.
    shared_act_dict = {'n_channels':n_channels, 'linear_size':RESOLUTION}
    shared_crit_dict = {'n_channels':n_channels, 'linear_size':RESOLUTION}
    HPs = dict(gamma=0.99, n_steps=5, H=1e-3, shared=True,
               spatial_model=spatial_model, spatial_dict=spatial_dict,
               shared_act_dict=shared_act_dict, shared_crit_dict=shared_crit_dict)

shared-architecture selected.


In [4]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
HPs['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device "+HPs['device'])

# Learning rate handed to the training loop (not part of the agent's HPs).
lr = 1e-4

# Build the agent from the hyper-parameters assembled in the previous cell.
agent = SpatialA2C(**HPs)

Using device cuda


In [5]:
# Training-loop settings, expanded as keyword arguments into train_batched_A2C.
MAX_STEPS = 250
unroll_length = 240

train_dict = {
    "n_train_processes": 11,
    # Total budget: 5000 unrolls of `unroll_length` steps (1.2M steps).
    "max_train_steps": 5000 * unroll_length,
    "unroll_length": unroll_length,
    "max_episode_steps": MAX_STEPS,
    # Matches the 2400-step spacing of the "avg score" lines in the training log.
    "test_interval": 10 * unroll_length,
}

In [6]:
%%time
# Run the full training loop (the cell magic above must stay on the first line).
# The recorded run took about 1h16m of wall time on CUDA. `results` is unpacked
# in the next cell into (score, losses, trained_agent, PID).
results = train_batched_A2C(agent, game_params, lr, **train_dict)

Process ID:  FCSE
Step # : 2400, avg score : 0.8
Step # : 4800, avg score : 0.6
Step # : 7200, avg score : 0.6
Step # : 9600, avg score : 0.0
Step # : 12000, avg score : 0.4
Step # : 14400, avg score : 0.4
Step # : 16800, avg score : 0.2
Step # : 19200, avg score : 0.4
Step # : 21600, avg score : 0.2
Step # : 24000, avg score : 0.8
Step # : 26400, avg score : 0.4
Step # : 28800, avg score : 0.0
Step # : 31200, avg score : 0.2
Step # : 33600, avg score : 0.4
Step # : 36000, avg score : 0.4
Step # : 38400, avg score : 0.8
Step # : 40800, avg score : 1.0
Step # : 43200, avg score : 0.4
Step # : 45600, avg score : 0.2
Step # : 48000, avg score : 0.8
Step # : 50400, avg score : 0.4
Step # : 52800, avg score : 0.2
Step # : 55200, avg score : 1.0
Step # : 57600, avg score : 0.2
Step # : 60000, avg score : 0.4
Step # : 62400, avg score : 0.6
Step # : 64800, avg score : 0.8
Step # : 67200, avg score : 0.6
Step # : 69600, avg score : 1.0
Step # : 72000, avg score : 0.4
Step # : 74400, avg score 

Step # : 588000, avg score : 26.0
Step # : 590400, avg score : 27.6
Step # : 592800, avg score : 25.6
Step # : 595200, avg score : 25.4
Step # : 597600, avg score : 25.4
Step # : 600000, avg score : 25.0
Step # : 602400, avg score : 26.0
Step # : 604800, avg score : 24.8
Step # : 607200, avg score : 25.4
Step # : 609600, avg score : 25.0
Step # : 612000, avg score : 27.6
Step # : 614400, avg score : 23.8
Step # : 616800, avg score : 26.4
Step # : 619200, avg score : 26.2
Step # : 621600, avg score : 25.8
Step # : 624000, avg score : 25.0
Step # : 626400, avg score : 23.0
Step # : 628800, avg score : 26.0
Step # : 631200, avg score : 24.4
Step # : 633600, avg score : 27.2
Step # : 636000, avg score : 25.8
Step # : 638400, avg score : 25.0
Step # : 640800, avg score : 24.6
Step # : 643200, avg score : 23.4
Step # : 645600, avg score : 26.0
Step # : 648000, avg score : 23.8
Step # : 650400, avg score : 26.0
Step # : 652800, avg score : 25.6
Step # : 655200, avg score : 24.0
Step # : 65760

Step # : 1161600, avg score : 25.6
Step # : 1164000, avg score : 26.4
Step # : 1166400, avg score : 26.4
Step # : 1168800, avg score : 24.4
Step # : 1171200, avg score : 27.2
Step # : 1173600, avg score : 26.4
Step # : 1176000, avg score : 27.2
Step # : 1178400, avg score : 27.0
Step # : 1180800, avg score : 26.6
Step # : 1183200, avg score : 24.2
Step # : 1185600, avg score : 24.2
Step # : 1188000, avg score : 25.0
Step # : 1190400, avg score : 25.0
Step # : 1192800, avg score : 27.2
Step # : 1195200, avg score : 28.4
Step # : 1197600, avg score : 22.6
Step # : 1200000, avg score : 24.0
CPU times: user 3h 14min 50s, sys: 5min 54s, total: 3h 20min 44s
Wall time: 1h 16min 23s


In [7]:
# Split the training output into its four parts. Tuple unpacking also acts as
# an arity check: it raises if train_batched_A2C returns anything but 4 items.
# PID is the run identifier printed at the start of training ("Process ID").
score, losses, trained_agent, PID = results

In [8]:
# Persist the finished run: training curves and settings via np.save, and the
# trained agent via torch.save. Both files are tagged with the run's PID.
save = True
keywords = ["shared",'lr-1e-4','5-steps',"1.2M-env-steps","240-unroll-len",'working!']

if save:
    save_dir = 'Results/'
    # Create the output folder up front, so that a missing directory cannot
    # throw away a training run that took over an hour.
    os.makedirs(save_dir, exist_ok=True)

    keywords.append(PID)
    filename = 'S_' + '_'.join(keywords)
    print("Save at "+save_dir+filename)

    train_session_dict = dict(game_params=game_params, HPs=HPs, score=score,
                              n_epochs=len(score), keywords=keywords, losses=losses)
    # np.save pickles the dict into a 0-d object array and appends ".npy";
    # reading it back requires np.load(..., allow_pickle=True).
    np.save(save_dir+filename, train_session_dict)
    # Pickles the whole agent object (not just a state_dict), so loading needs
    # the same class definitions to be importable.
    torch.save(trained_agent, save_dir+"agent_"+PID)
else:
    print("Nothing saved")

Save at Results/S_shared_lr-1e-4_5-steps_1.2M-env-steps_240-unroll-len_working!_FCSE


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
# Roll out one inspection episode in a fresh sandbox environment and plot
# every recorded state of the trajectory.
step_idx = 0
test_env = Sandbox(max_steps=MAX_STEPS, **game_params)

# Use a dedicated name for this throw-away label ("prova" is Italian for
# "trial") instead of overwriting PID, which still identifies the training run
# and is used by the save cell to name the output files.
inspection_pid = 'prova'
R, insp = inspection_test(step_idx, agent, test_env, inspection_pid)

insp_dict = insp.dict
for t in range(len(insp_dict['state_traj'])):
    plot_state(insp_dict, t)
    plt.show()