# A2C derivate classes

In [1]:
import sys
sys.path.insert(0, "../")
from SC_Utils.game_utils import ObsProcesser, get_action_dict
from SC_Utils.train_v3 import *
from AC_modules.BatchedA2C import SpatialA2C, SpatialA2C_v1, SpatialA2C_v2, SpatialA2C_v3, SpatialA2C_MaxEnt
import AC_modules.Networks as net
import torch

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Environment parameters
RESOLUTION = 32
game_params = dict(feature_screen=RESOLUTION, feature_minimap=RESOLUTION, action_space="FEATURES") 
game_names = {1:'MoveToBeacon',
              2:'CollectMineralShards',
              3:'DefeatRoaches',
              4:'FindAndDefeatZerglings',
              5:'DefeatZerglingsAndBanelings',
              6:'CollectMineralsAndGas',
              7:'BuildMarines'
              }
map_name = game_names[1]

# Observation Processer parameters
#screen_names = ['visibility_map', 'player_relative', 'selected', 'unit_density', 'unit_density_aa']
#minimap_names = []
#obs_proc_params = {'screen_names':screen_names, 'minimap_names':minimap_names}
obs_proc_params = {'select_all':True}

In [3]:
env = init_game(game_params, map_name)
op = ObsProcesser(**obs_proc_params)
screen_channels, minimap_channels = op.get_n_channels()
in_channels = screen_channels + minimap_channels 

In [4]:
action_names = ['select_army', 'Move_screen','Attack_screen']
#action_names = ['select_army','Move_screen']
#action_names = ['select_army', 'Attack_screen', 'Move_screen', 'select_point', 'select_rect',
#                'move_camera','Stop_quick','Move_minimap','Attack_minimap','HoldPosition_quick']
action_dict = get_action_dict(action_names)
action_space = len(action_dict)

In [5]:
spatial_model = net.FullyConvSpatial
nonspatial_model = net.FullyConvNonSpatial
embed_dim = 8
n_channels = 32
n_features = 256
spatial_dict = {"in_channels":in_channels}
nonspatial_dict = {'resolution':RESOLUTION, 'kernel_size':3, 'stride':2}

In [6]:
HPs = dict(action_space=action_space, gamma=0.99, n_steps=20, H=1e-2, 
           spatial_model=spatial_model, nonspatial_model=nonspatial_model,
           n_features=n_features, n_channels=n_channels, 
           spatial_dict=spatial_dict, nonspatial_dict=nonspatial_dict, 
           action_dict=action_dict)

if torch.cuda.is_available():
    HPs['device'] = 'cuda'
else:
    HPs['device'] = 'cpu'
    
print("Using device "+HPs['device'])

lr = 7e-4

Using device cuda


In [7]:
version = 2
if version == 1:
    HPs = {**HPs, 'embed_dim':embed_dim}
    agent = SpatialA2C_v1(env=env, **HPs)
elif version == 2:
    # no action embedding
    agent = SpatialA2C_v2(env=env, **HPs)
elif version == 3:
    agent = SpatialA2C_v3(env=env, **HPs)
elif version == 4:
    # no action embedding
    agent = SpatialA2C_MaxEnt(env=env, **HPs)
else:
    raise Exception("Version not implemented.")

In [8]:
unroll_length = 60

train_dict = dict(n_train_processes = 11,
                  max_train_steps = unroll_length*1000,#60000,
                  unroll_length = unroll_length,
                  test_interval = unroll_length*10,
                  inspection_interval = unroll_length*10
                  )

In [9]:
%%time
results = train_batched_A2C(agent, game_params, map_name, lr, 
                            obs_proc_params=obs_proc_params, action_dict=action_dict, **train_dict)

Process ID:  EOEV
CPU times: user 4h 40min 12s, sys: 13min 6s, total: 4h 53min 19s
Wall time: 56min 6s


In [10]:
score, losses, trained_agent, PID = results

In [11]:
from Utils import utils
save = True
keywords = ['A2C', 'CMS','embed-action',"conv-net",'lr-7e-4','20-steps', '32x32',"1.2M-env-steps","120-unroll-len",'7-channels', 'select-point'] 

if save:
    save_dir = '../Results/CollectMineralShards/'
    keywords.append(PID)
    filename = '_'.join(keywords)
    filename = 'S_'+filename
    print("Save at "+save_dir+filename)
    train_session_dict = dict(game_params=game_params, HPs=HPs, score=score, n_epochs=len(score), keywords=keywords, losses=losses)
    np.save(save_dir+filename, train_session_dict)
    torch.save(trained_agent, save_dir+"agent_"+PID)
else:
    print("Nothing saved")
    pass

Save at ../Results/CollectMineralShards/S_A2C_CMS_embed-action_conv-net_lr-7e-4_20-steps_32x32_1.2M-env-steps_120-unroll-len_7-channels_select-point_VYPP


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


# FiLM layer for action conditioning

How I condition actions right now:
- action = (b,)
- embedded_a = (b,d), d is embedding dimension (e.g. 8)
- for nonspatial features (b, n_features) just concatenate the embedded action: (b, n_features+d)
- for spatial features (b, n_channels, L, L) broadcast (actually repeat) the embedded action to (b,d,L,L) and then concatenate along the channel dimension: (b, n_channels+d, L, L) 

Then I use them as always, so processing from 1 or 2 layers before applying the softmax and sampling the argument.

With FiLM for spatial arguments:
- still use embedding
- use ReLU + linear layer to extract $\gamma_c$ and $\beta_c$ for each channel c in n_channels
- transform the spatial features as $\gamma_c F_c(x) + \beta_c$
- apply final convolution layer
- apply softmax

Same thing can be done for nonspatial arguments, assuming them as 1x1 images. So in this case the FiLM layer would be much more expressive, because it can achieve any possible output by means of scaling and shifting, so maybe is no more a good idea and a simple concatenation of the embedding action could do.

# Tests

In [16]:
class C1():
    def __init__(self):
        self._init_stuff()
        self.C = 10
        
    def _init_stuff(self):
        self.N = 10

In [17]:
class C2(C1):
    def __init__(self):
        self.embed = 10
        super().__init__()
        print(self.N)
        print(self.C)
        
    def _init_stuff(self):
        self.N = 10 + self.embed

In [18]:
C2()

20
10


<__main__.C2 at 0x7fa00f0bbf50>

In [16]:
B = 2
n_channels = 12
n_features = 256
embed_dim = 16
res = 32
spatial = torch.rand(B, n_channels, res, res)
nonspatial = torch.rand(B, n_features)
embedded_a = torch.rand(B, embed_dim)


In [29]:
print('embedded_a.shape; ', embedded_a.shape)
spatial_a = embedded_a.reshape((embedded_a.shape[:2]+(1,1,))).repeat(1,1,res,res)
for i in range(res):
    for j in range(res):
        print(torch.all(spatial_a[:,:,i,j] == embedded_a)) # copied correctly if True

embedded_a.shape;  torch.Size([2, 16])
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

In [30]:
cat_spatial = torch.cat([spatial, spatial_a], dim=1)

In [31]:
cat_spatial.shape

torch.Size([2, 28, 32, 32])

In [32]:
cat_nonspatial = torch.cat([nonspatial, embedded_a], dim=1)
cat_nonspatial.shape

torch.Size([2, 272])