# Inspection for GeneralA2C

New features to include:
- player info
- top 5 actions and their distributions

TODO:
- look at state trajectories: they should be a list of dictionaries with keys {"spatial", "player"}
- create wrapper functions for the state plotting if possible, otherwise do it from scratch
- print all player info in a formatted way
- in general, for some minigames it might be interesting to plot some player info step by step, together with either the critic value or the actor advantages
- change the way in which we deal with decision map plotting (?)

In [None]:
import os
import sys
sys.path.insert(0, "../")
# Custom modules
from AC_modules.BatchedA2C import GeneralA2C
from SC_Utils.game_utils import FullObsProcesser
import AC_modules.Networks as net
# change this inspection plots
from SC_Utils.inspection_plots_v2 import *
from SC_Utils.A2C_inspection_v2 import *
from SC_Utils.train_v4 import init_game, inspection_test

import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Define map_name and PID - ALWAYS

# Available mini-games, keyed by an integer id starting at 1.
game_names = dict(enumerate(
    [
        'MoveToBeacon',
        'CollectMineralShards',
        'DefeatRoaches',
        'FindAndDefeatZerglings',
        'DefeatZerglingsAndBanelings',
        'CollectMineralsAndGas',
        'BuildMarines',
    ],
    start=1,
))
# Mini-game whose results are inspected in this notebook.
map_name = game_names[1]

# Run identifier; used to build every result/checkpoint path below.
PID = "KGIA"

# Online plotting of the training curve

In [None]:
# Training log for this run: comma-separated values, first row skipped (header).
filename = f"../Results/{map_name}/Logging/{PID}.txt"
# NOTE(review): this name shadows the stdlib `logging` module; kept as-is
# because later cells may refer to it.
logging = np.loadtxt(filename, delimiter=',', skiprows=1)
plot_logging(logging, map_name)

# Checkpoints at a certain step

In [None]:
jump = 60  # env steps between two consecutive inspections
n = 3      # how many jumps separate step 0 from the step to inspect
# Environment step whose inspection/checkpoint files are loaded below.
step_idx = n * jump
step_idx  # display the resulting step index

In [None]:
load_dir = f"../Results/{map_name}/Inspection/"
# The .npy file holds a pickled object array; .item() unwraps the stored dict.
# allow_pickle=True is required for that, so only load files you produced yourself.
insp_dict = np.load(f"{load_dir}{PID}_{step_idx}.npy", allow_pickle=True).item()

### Trajectory update

In [None]:
# Plot the update curves stored in the inspection dict loaded from disk above
# (plot_update_curves comes from the star-imported SC_Utils inspection modules).
plot_update_curves(insp_dict)

## Agent-related visualizations

Here we either initialize a new agent and load a checkpoint into it, or directly load the fully trained agent if the training cycle has ended (when training with the run.py script, the whole agent class is saved automatically).

In [None]:
### Environment parameters ###
RESOLUTION = 32
# Screen and minimap features use the same resolution.
game_params = dict(
    feature_screen=RESOLUTION,
    feature_minimap=RESOLUTION,
    action_space="FEATURES",
)
env = init_game(game_params, map_name)

# Observation processor that keeps all layers (select_all=True).
obs_proc_params = dict(select_all=True)
op = FullObsProcesser(**obs_proc_params)

# Channel counts reported by the processor; the spatial input size is the
# sum of the screen and minimap channel counts.
screen_channels, minimap_channels, in_player = op.get_n_channels()
in_channels = screen_channels + minimap_channels

In [None]:
# True: load a whole pickled agent object; False: build a new agent and load
# the checkpoint saved at step_idx into it.
load = False

if not load:
    ### Agent architecture parameters ###
    spatial_model = net.FullyConvPlayerAndSpatial
    nonspatial_model = net.FullyConvNonSpatial

    # Internal features, passed inside a dictionary
    conv_channels = 32
    player_features = 16
    # Exposed features, passed outside of a dictionary
    n_channels = 48
    n_features = 256

    spatial_dict = {
        "in_channels": in_channels,
        'in_player': in_player,
        'conv_channels': conv_channels,
        'player_features': player_features,
    }
    nonspatial_dict = {
        'resolution': RESOLUTION,
        'kernel_size': 3,
        'stride': 2,
        'n_channels': n_channels,
    }

    ### A2C parameters ###
    HPs = dict(
        gamma=0.99,
        n_steps=20,
        H=1e-2,
        spatial_model=spatial_model,
        nonspatial_model=nonspatial_model,
        n_features=n_features,
        n_channels=n_channels,
        spatial_dict=spatial_dict,
        nonspatial_dict=nonspatial_dict,
    )
    HPs['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Using device " + HPs['device'])

    # Not used within this cell; kept so later cells can still refer to it.
    lr = 7e-4

    # Agent init
    agent = GeneralA2C(env=env, **HPs)
    # Load the checkpoint for step_idx (this branch is skipped when load=True).
    # Tensors are mapped to CPU first; load_state_dict copies them into the
    # agent's parameters on whatever device the agent was built on.
    checkpoint_path = f"../Results/{map_name}/Checkpoints/{PID}_{step_idx}"
    agent.AC.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
    # NOTE(review): this unpickles a whole agent object. Newer torch versions
    # default torch.load to weights_only=True, which may reject it; if so, pass
    # weights_only=False (only for files you trust).
    agent = torch.load(f"../Results/{map_name}/agent_{PID}", map_location='cpu')
    agent.device = 'cpu'

In [None]:
# Inspector object needed by the plotting helpers below.
# inspection_test is project code; presumably it runs the agent in env and
# records what it observes in its .dict attribute (verify in SC_Utils.train_v4).
inspector = inspection_test(
    step_idx, agent, env, PID, op,
    agent.AC.action_space,
)
# Replaces the inspection dict loaded from disk in an earlier cell.
insp_dict = inspector.dict

In [None]:
# Reset the environment once, only to get the layer names the observation
# processor reports alongside the processed state.
reset_obs = env.reset()
_, layer_names = op.get_state(reset_obs)
layer_names  # display the names

In [None]:
# Print action info for the first steps of the inspected trajectory.
# The step count is capped at the trajectory length, so a short episode cannot
# make print_action_info index past the recorded steps.
# NOTE(review): assumes insp_dict['state_traj'] holds one entry per recorded
# step (the original commented-out loop header used it as the step count);
# confirm against inspection_test.
max_steps_to_show = 50
n_steps_to_show = min(max_steps_to_show, len(insp_dict['state_traj']))
for t in range(n_steps_to_show):
    print_action_info(inspector, insp_dict, t)
    # Optional per-step plots; uncomment as needed:
    #plot_screen_and_decision(inspector, insp_dict, layer_names, t, show_minimap=True)
    #plot_screen_layers(insp_dict, layer_names, t)
    #plot_minimap_layers(insp_dict, layer_names, t)