In [2]:
import sys
sys.path.insert(0, "../")
import torch
from SC_Utils.train_v5 import *
from SC_Utils.game_utils import FullObsProcesser
import AC_modules.Networks as net
from AC_modules.BatchedA2C import *
import torch
import numpy as np

# dev modules
from AC_modules.ActorCriticArchitecture import *
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from pysc2.lib import actions as sc_actions

In [2]:
debug = True

class ParallelActorCritic(nn.Module):
    """
    Used in FullSpaceA2C_v2
    
    Description of some attributes:
    - action_table: numpy array of shape (n_actions,) 
        is a look-up table that associates an action index to its StarCraft action id
    - spatial_arg_mask: numpy array of shape (n_actions, n_spatial_args) 
        spatial_arg_mask[a] is a mask telling which of the n_spatial_args sampled args
        belong to action `a`. Same thing for categorical_arg_mask
    """
    def __init__(self, env, spatial_model, nonspatial_model, spatial_dict, nonspatial_dict, 
                 n_features, n_channels, action_names):
        super(ParallelActorCritic, self).__init__()
        
        self.action_names = action_names
        self._set_action_table() # creates self.action_table
        self.screen_res = env.observation_spec()[0]['feature_screen'][1:]
        self.all_actions = env.action_spec()[0][1]
        self.all_arguments = env.action_spec()[0][0]
        
        # Useful HyperParameters as attributes
        self.n_features = n_features
        self.n_channels = n_channels
        action_space = len(action_names)
        
        # Networks
        self.spatial_features_net = spatial_model(**spatial_dict)
        self.nonspatial_features_net = nonspatial_model(**nonspatial_dict) 
        self.actor = SharedActor(action_space, n_features)
        self.critic = SharedCritic(n_features)
        self._init_arg_names()
        self._set_spatial_arg_mask()
        self._set_categorical_arg_mask()
        self._init_params_nets()
    
    def _set_action_table(self):
        action_ids = [sc_actions.FUNCTIONS[a_name].id for a_name in self.action_names]
        action_table = np.array([action_ids[i] for i in range(len(action_ids))])
        self.action_table = action_table
    
    def _init_arg_names(self):
        spatial_arg_names = []
        categorical_arg_names = []
        categorical_sizes = []
        act_to_arg_names = {}

        for action_id in self.action_table:
            action = self.all_actions[action_id]
            args = action.args
            act_to_arg_names[action_id] = [str(action.name)+"/"+arg.name for arg in args]
            spatial = []
            categorical = []
            for arg in args:
                arg_name = str(action.name)+"/"+arg.name
                size = self.all_arguments[arg.id].sizes
                if len(size) == 1:
                    categorical.append(arg_name)
                    categorical_sizes.append(size[0])
                else:
                    spatial.append(arg_name)
            spatial_arg_names+=spatial
            categorical_arg_names+=categorical
    
        self.spatial_arg_names = spatial_arg_names
        self.n_spatial_args = len(spatial_arg_names)
        self.categorical_arg_names = categorical_arg_names
        self.n_categorical_args = len(categorical_arg_names)
        self.categorical_sizes = np.array(categorical_sizes)
        self.act_to_arg_names = act_to_arg_names 

    def _set_spatial_arg_mask(self):
        spatial_arg_mask = np.zeros((self.action_table.shape[0], self.n_spatial_args))
        for i, action_id in enumerate(self.action_table):
            action_arg_names = self.act_to_arg_names[action_id]
            spatial_arg_mask[i] = np.array([1 if self.spatial_arg_names[j] in action_arg_names else 0 \
                                            for j in range(self.n_spatial_args)])
        self.spatial_arg_mask = spatial_arg_mask
    
    def _set_categorical_arg_mask(self):
        categorical_arg_mask = np.zeros((self.action_table.shape[0], self.n_categorical_args))
        for i, action_id in enumerate(self.action_table):
            action_arg_names = self.act_to_arg_names[action_id]
            categorical_arg_mask[i] = np.array([1 if self.categorical_arg_names[j] in action_arg_names else 0 \
                                            for j in range(self.n_categorical_args)])
        self.categorical_arg_mask = categorical_arg_mask

    def _init_params_nets(self):
        n_arguments = len(self.spatial_arg_names)
        self.spatial_params_net = ParallelSpatialParameters(self.n_channels, self.screen_res[0], n_arguments)
        self.categorical_params_net = ParallelCategoricalNet(self.n_features, self.categorical_sizes, n_arguments)
        
    def pi(self, spatial_state, player_state, mask):
        spatial_features = self.spatial_features_net(spatial_state, player_state)
        nonspatial_features = self.nonspatial_features_net(spatial_features)
        logits = self.actor(nonspatial_features) 
        log_probs = F.log_softmax(logits.masked_fill((mask).bool(), float('-inf')), dim=-1) 
        return log_probs, spatial_features, nonspatial_features
    
    def V_critic(self, spatial_state, player_state):
        spatial_features = self.spatial_features_net(spatial_state, player_state)
        nonspatial_features = self.nonspatial_features_net(spatial_features)
        V = self.critic(nonspatial_features)
        return V
    
    def sample_spatial_params(self, spatial_features, actions):
        """
        Input
        -----
        spatial_features: tensor, (batch_size, n_channels, screen_res, screen_res)
        actions: array, (batch_size,)
        
        Returns
        -------
        arg_list: list of lists
        """
        batch_size = actions.shape[0]
        parallel_args, parallel_log_prob, _ = self.spatial_params_net(spatial_features)
        if debug:
            expected = actions.shape + (self.n_spatial_args, 2)
            actual = parallel_args.shape
            assert actual == expected, ("unexpected parallel_args shape; actual vs expected: ", \
                                        actual, expected)
            expected = actions.shape + (self.n_spatial_args,)
            actual = parallel_log_prob.shape
            assert actual == expected, ("unexpected parallel_log_prob shape; actual vs expected: ", \
                                        actual, expected)

        # Select only spatial arguments needed by sampled actions
        arg_mask = self.spatial_arg_mask[actions,:] # shape (batch_size, n_spatial_args)
        if debug:
            expected = actions.shape + (self.n_spatial_args,)
            actual = arg_mask.shape
            assert actual == expected, ("unexpected arg_mask shape; actual vs expected: ", \
                                        actual, expected)

        batch_pos = arg_mask.nonzero()[0]
        arg_pos = arg_mask.nonzero()[1]
        args = parallel_args[batch_pos, arg_pos]
        #arg_list = [args[batch_pos==i] for i in range(batch_size)]
        arg_list = [list(args[batch_pos==i]) for i in range(batch_size)]
        
        # Compute composite log_probs of selected arguments
        
        # Infer device from spatial_params_net output with parallel_log_prob.is_cuda
        if parallel_log_prob.is_cuda:
            device = 'cuda' # Assume only 1 GPU device is used 
        else:
            device = 'cpu'
            
        # for every arg index contains the index of the action that uses that parameter
        main_action_ids = torch.tensor(self.spatial_arg_mask.nonzero()[0]).to(device)
        if debug:
            expected = self.n_spatial_args
            actual = main_action_ids.shape[0]
            assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                        actual, expected)
            
        sum_log_prob = torch.zeros(batch_size, len(self.action_table)) # (batch_size, action_space)
        sum_log_prob.index_add_(1, main_action_ids, parallel_log_prob)
        sampled_actions = torch.tensor(actions) # of shape (batch_size,)
        # sum of log_probs of the relevant parameters by
        log_prob = sum_log_prob[torch.arange(batch_size), sampled_actions]
        if debug:
            expected = batch_size
            actual = log_prob.shape[0]
            assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                        actual, expected)
        return arg_list, log_prob
    
    def sample_categorical_params(self, categorical_features, actions):
        """
        Input
        -----
        categorical_features: tensor, (batch_size, n_channels, screen_res, screen_res)
        actions: array, (batch_size,)
        """
        batch_size = actions.shape[0]
        parallel_args, parallel_log_prob = self.categorical_params_net(categorical_features)
        arg_mask = self.categorical_arg_mask[actions,:] # shape (batch_size, n_spatial_args)
        
        # select correct arguments
        batch_pos = arg_mask.nonzero()[0]
        arg_pos = arg_mask.nonzero()[1]
        args = parallel_args[batch_pos, arg_pos]
        arg_list = [list(args[batch_pos==i]) for i in range(batch_size )]

        # select and sum correct log probs
        if parallel_log_prob.is_cuda:
            device = 'cuda' # Assume only 1 GPU device is used 
        else:
            device = 'cpu'

        # for every arg index contains the index of the action that uses that parameter
        main_action_ids = torch.tensor(self.categorical_arg_mask.nonzero()[0]).to(device)
        if debug:
            expected = self.n_categorical_args
            actual = main_action_ids.shape[0]
            assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                        actual, expected)



        sum_log_prob = torch.zeros(batch_size, len(self.action_table)) # (batch_size, action_space)
        sum_log_prob.index_add_(1, main_action_ids, parallel_log_prob)
        sampled_actions = torch.tensor(actions) # of shape (batch_size,)
        # sum of log_probs of the relevant parameters by
        log_prob = sum_log_prob[torch.arange(batch_size), sampled_actions]
        if debug:
            expected = batch_size
            actual = log_prob.shape[0]
            assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                        actual, expected)
        return arg_list, log_prob
    
    def sample_params(self, nonspatial_features, spatial_features, actions):
        categorical_arg_list, categorical_log_prob = self.sample_categorical_params(nonspatial_features, actions)
        spatial_arg_list, spatial_log_prob = self.sample_spatial_params(spatial_features, actions)
        
        # merge arg lists
        assert len(categorical_arg_list) == len(spatial_arg_list), ("Expected same length for arg lists", \
                                                                len(categorical_arg_list), len(spatial_arg_list))
        
        assert categorical_log_prob.shape == spatial_log_prob.shape, ("Expected same log_prob shape", \
                                                                 categorical_log_prob.shape, spatial_log_prob.shape)
        log_prob = categorical_log_prob + spatial_log_prob
        arg_list = []
        for cat, spa in zip(categorical_arg_list, spatial_arg_list):
            print("cat; ", cat)
            print("spa: ", spa)
            args = []
            if len(cat) != 0:
                args.append(cat)
            args += [list(s) for s in spa] # hopefully is the right format [[arg1],[arg2],...] x batch time
            print("args: ", args)
            arg_list.append(args)
            
        return arg_list, log_prob
    

## Preparing everything for real test

In [3]:
# Environment parameters
RESOLUTION = 32
game_params = dict(feature_screen=RESOLUTION, feature_minimap=RESOLUTION, action_space="FEATURES") 
game_names = {1:'MoveToBeacon',
              2:'CollectMineralShards',
              3:'DefeatRoaches',
              4:'FindAndDefeatZerglings',
              5:'DefeatZerglingsAndBanelings',
              6:'CollectMineralsAndGas',
              7:'BuildMarines'
              }

map_name = game_names[1]
obs_proc_params = {'select_all':True}
# 6 actions, 5 spatial params (select_rect has 2), 5 nonspatial params
action_names = ['no_op','select_army','select_rect','Move_screen','select_point','Attack_screen']

In [4]:
env = init_game(game_params, map_name)
op = FullObsProcesser(**obs_proc_params)
screen_channels, minimap_channels, in_player = op.get_n_channels()
in_channels = screen_channels + minimap_channels 

In [5]:
spatial_model = net.FullyConvPlayerAndSpatial
nonspatial_model = net.FullyConvNonSpatial
# Internal features, passed inside a dictionary
conv_channels = 32
player_features = 16
# Exposed features, passed outside of a dictionary
n_channels = 48
n_features = 256
spatial_dict = {"in_channels":in_channels, 'in_player':in_player, 
                'conv_channels':conv_channels, 'player_features':player_features}
nonspatial_dict = {'resolution':RESOLUTION, 'kernel_size':3, 'stride':2, 'n_channels':n_channels}

In [6]:
HPs = dict(gamma=0.99, n_steps=20, H=1e-2, 
           spatial_model=spatial_model, nonspatial_model=nonspatial_model,
           n_features=n_features, n_channels=n_channels, action_names=action_names,
           spatial_dict=spatial_dict, nonspatial_dict=nonspatial_dict)

if torch.cuda.is_available():
    HPs['device'] = 'cuda'
else:
    HPs['device'] = 'cpu'
    
print("Using device "+HPs['device'])

lr = 7e-4

Using device cuda


In [7]:
class FullSpaceA2C_v2(FullSpaceA2C):
    def __init__(self, env, spatial_model, nonspatial_model, spatial_dict, nonspatial_dict, 
                 n_features, n_channels, action_names, gamma=0.99, H=1e-2, n_steps=20, device='cpu'):
        self.gamma = gamma
        self.n_steps = n_steps
        self.H = H
        self.AC = ParallelActorCritic(env, spatial_model, nonspatial_model, spatial_dict, 
                                     nonspatial_dict, n_features, n_channels, action_names)
        self.device = device 
        self.AC.to(self.device) 
        
    def step(self, state, action_mask):
        spatial_state = state['spatial']
        player_state = state['player']
        spatial_state = torch.from_numpy(spatial_state).float().to(self.device)
        player_state = torch.from_numpy(player_state).float().to(self.device)
        action_mask = torch.tensor(action_mask).to(self.device)
        
        log_probs, spatial_features, nonspatial_features = self.AC.pi(spatial_state, player_state, action_mask)
        entropy = self.compute_entropy(log_probs)
        probs = torch.exp(log_probs)
        a = Categorical(probs).sample()
        a = a.detach().cpu().numpy()
        log_prob = log_probs[range(len(a)), a]
        
        args, args_log_prob = self.AC.sample_params(nonspatial_features, spatial_features, a)
        assert args_log_prob.shape == log_prob.shape, ("Shape mismatch between arg_log_prob and log_prob ",\
                                                      args_log_prob.shape, log_prob.shape)
        log_prob = log_prob + args_log_prob
        
        action_id = np.array([self.AC.action_table[act] for act in a])
        action = [actions.FunctionCall(action_id[i], args[i]) for i in range(len(action_id))]

        return action, log_prob, torch.mean(entropy)

In [8]:
agent = FullSpaceA2C_v2(env = env, **HPs)
env.close()

In [9]:
n_train_processes = 1
envs = ParallelEnv(n_train_processes, game_params, map_name, obs_proc_params, agent.AC.action_table)

In [10]:
state, action_mask = envs.reset()

# Inspection step

In [11]:
from SC_Utils.A2C_inspection_v3 import *

In [12]:
inspector = InspectionDict(0, "prova", agent)

In [13]:
spatial_state = state['spatial']
player_state = state['player']
spatial_state = torch.from_numpy(spatial_state).float().to(agent.device)
player_state = torch.from_numpy(player_state).float().to(agent.device)
action_mask = torch.tensor(action_mask).to(agent.device)

In [14]:
log_probs, spatial_features, nonspatial_features = agent.AC.pi(spatial_state, player_state, action_mask)
entropy = agent.compute_entropy(log_probs)
probs = torch.exp(log_probs)
a = Categorical(probs).sample()
a = a.detach().cpu().numpy()
log_prob = log_probs[range(len(a)), a]

In [15]:
### Inspection ###
step_dict = {}
p = probs.detach().cpu().numpy() 
step_dict['action_distr'] = p
step_dict['action_sel'] = a

# Choose top 5 actions from the probabilities - check about the batch dim
top_5 = np.argsort(p)[:,-5:]
top_5_actions = np.array(top_5[:,::-1])[0] # some issues in accessing p if I don't call np.array()
step_dict['top_5_actions'] = top_5_actions

In [21]:
# Save SPATIAL distributions only of the top 5 actions + THEIR NAMES
_, _, log_probs = agent.AC.spatial_params_net(spatial_features)
log_probs = log_probs.detach().cpu().numpy()[0] # batch dim 1 during inspection
log_probs.shape

(5, 1024)

In [18]:
top_5_actions # use it as a batch of actions

array([2, 4, 0, 1, 5])

In [25]:
arg_mask = agent.AC.spatial_arg_mask[top_5_actions,:] 
arg_mask.shape #(n_actions, n_spatial_args)

(5, 5)

In [26]:
arg_mask

array([[1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1.]])

In [29]:
rows, cols = arg_mask.nonzero()

In [30]:
distributions = log_probs[cols]
distributions.shape

(4, 1024)

In [39]:
step_dict['top_5_action_distr'] = {}
for act in top_5_actions:
    step_dict['top_5_action_distr'][act] = {}
    arg_mask = agent.AC.spatial_arg_mask[act,:].astype(bool)
    arg_names = np.array(agent.AC.spatial_arg_names)[arg_mask]
    print(arg_names)
    distr = log_probs[arg_mask].reshape((-1,)+agent.AC.screen_res)
    print(distr)
    for i, name in enumerate(arg_names):
        step_dict['top_5_action_distr'][act][name+'_distr'] = distr[i]

['select_rect/screen' 'select_rect/screen2']
[[[-6.853835  -6.899518  -6.9091325 ... -6.901828  -6.9142013 -6.890389 ]
  [-6.871352  -6.936567  -6.9470887 ... -6.9370604 -6.9423475 -6.8993316]
  [-6.872617  -6.925705  -6.933766  ... -6.9363503 -6.935559  -6.8935723]
  ...
  [-6.8781548 -6.938509  -6.945193  ... -6.936241  -6.9353194 -6.8852797]
  [-6.876353  -6.935577  -6.9377956 ... -6.9275503 -6.9265475 -6.8783255]
  [-6.8391137 -6.8665137 -6.867539  ... -6.8609843 -6.8571305 -6.836329 ]]

 [[-6.889978  -6.8994184 -6.8955946 ... -6.8968673 -6.9060254 -6.917488 ]
  [-6.8972144 -6.9172516 -6.918076  ... -6.920451  -6.930906  -6.947459 ]
  [-6.900403  -6.927324  -6.933854  ... -6.934857  -6.9377623 -6.955583 ]
  ...
  [-6.8931293 -6.9201465 -6.927556  ... -6.929954  -6.934314  -6.9463725]
  [-6.896717  -6.9192557 -6.9322405 ... -6.92431   -6.932617  -6.9462247]
  [-6.913361  -6.9318533 -6.940413  ... -6.93039   -6.934319  -6.9381313]]]
['select_point/screen']
[[[-6.912881  -6.9016047 -6

In [40]:
step_dict['top_5_action_distr'].keys()

dict_keys([2, 4, 0, 1, 5])

In [41]:
step_dict['top_5_action_distr'][2].keys()

dict_keys(['select_rect/screen_distr', 'select_rect/screen2_distr'])

In [None]:
def inspection_step(agent, inspector, state, action_mask):
    spatial_state = state['spatial']
    player_state = state['player']
    spatial_state = torch.from_numpy(spatial_state).float().to(agent.device)
    player_state = torch.from_numpy(player_state).float().to(agent.device)
    action_mask = torch.tensor(action_mask).to(agent.device)

    log_probs, spatial_features, nonspatial_features = agent.AC.pi(spatial_state, player_state, action_mask)
    entropy = agent.compute_entropy(log_probs)
    probs = torch.exp(log_probs)
    a = Categorical(probs).sample()
    a = a.detach().cpu().numpy()
    log_prob = log_probs[range(len(a)), a]

    ### Inspection ###
    step_dict = {}
    p = probs.detach().cpu().numpy() 
    step_dict['action_distr'] = p
    step_dict['action_sel'] = a
    
    # Choose top 5 actions from the probabilities - check about the batch dim
    top_5 = np.argsort(p)[:,-5:]
    top_5_actions = np.array(top_5[:,::-1])[0] # some issues in accessing p if I don't call np.array()
    step_dict['top_5_actions'] = top_5_actions
    
    # Save distributions only of the top 5 actions
    step_dict['top_5_action_distr'] = {}
    with torch.no_grad():
        for act in top_5_actions:
            step_dict['top_5_action_distr'][act] = {} # first nested level
            arg_names = inspector.act_to_arg_names[act]
            for arg_name in arg_names:
                if inspector.arguments_type[arg_name] == 'spatial': # it's either 'spatial' or 'categorical'
                    insp_arg, insp_log_prob, insp_distr = agent.AC.sample_param(spatial_features, arg_name)
                    p = insp_distr.detach().cpu().numpy().reshape(spatial_state.shape[-2:]) 
                else:
                    insp_arg, insp_log_prob, insp_distr = agent.AC.sample_param(nonspatial_features, arg_name)
                    p = insp_distr.detach().cpu().numpy() 
                    
                step_dict['top_5_action_distr'][act][arg_name+'_distr'] = p # second nested level
                
    ### End inspection ###
   
    args, args_log_prob = agent.AC.sample_params(nonspatial_features, spatial_features, a)
    step_dict['args'] = args
    
    log_prob = log_prob + args_log_prob

    action_id = np.array([agent.AC.action_table[act] for act in a])
    action = [actions.FunctionCall(action_id[i], args[i]) for i in range(len(action_id))]

    inspector.store_step(step_dict)
    return action, log_prob, torch.mean(entropy)

In [25]:
a, log_prob, entropy = agent.step(state, action_mask)
a

cat;  []
spa:  []
args:  []
cat;  [0]
spa:  [array([10,  5]), array([ 5, 27])]
args:  [[0], [10, 5], [5, 27]]


[FunctionCall(function=0, arguments=[]),
 FunctionCall(function=3, arguments=[[0], [10, 5], [5, 27]])]

In [26]:
stuff = envs.step(a)

In [11]:
action_mask # only Move_screen and Attack_screen not available

array([[False, False, False,  True, False,  True],
       [False, False, False,  True, False,  True]])

In [12]:
spatial_state = state['spatial']
player_state = state['player']
spatial_state = torch.from_numpy(spatial_state).float().to(agent.device)
player_state = torch.from_numpy(player_state).float().to(agent.device)
action_mask = torch.tensor(action_mask).to(agent.device)

In [13]:
log_probs, spatial_features, nonspatial_features = agent.AC.pi(spatial_state, player_state, action_mask)
entropy = agent.compute_entropy(log_probs)
probs = torch.exp(log_probs)
a = Categorical(probs).sample()
a = a.detach().cpu().numpy()
log_prob = log_probs[range(len(a)), a]

In [14]:
a

array([2, 1])

In [15]:
categorical_arg_list, spatial_arg_list = agent.AC.sample_params(nonspatial_features, spatial_features, a)

In [16]:
categorical_arg_list

[[0], [1]]

In [17]:
spatial_arg_list

[[array([11, 30]), array([ 3, 21])], []]

In [20]:
arg_list = []
for cat, spa in zip(categorical_arg_list, spatial_arg_list):
    args = [cat]+[list(s) for s in spa]
    arg_list.append(args)
arg_list

[[[0], [11, 30], [3, 21]], [[1]]]

In [22]:
action_id = np.array([agent.AC.action_table[act] for act in a])
action_id

array([3, 7])

In [23]:
action = [sc_actions.FunctionCall(action_id[i], arg_list[i]) for i in range(len(action_id))]

In [24]:
result = envs.step(action)

In [95]:
agent.AC.spatial_params_net.size

32

In [96]:
spatial_args, spatial_log_prob = agent.AC.sample_spatial_params(spatial_features, a)

In [98]:
spatial_args

[array([[ 1, 23]]), array([[15,  3]])]

In [99]:
spatial_log_prob

tensor([-6.9371, -6.7936], grad_fn=<IndexBackward>)

In [100]:
log_prob # these 2 can be summed up easily

tensor([-1.3695, -1.3670], grad_fn=<IndexBackward>)

Now we do the same thing for non-spatial params

Reference code:
```python
class CategoricalNet(nn.Module):
    
    def __init__(self, n_features, size, hiddens=[256]):
        super(CategoricalNet, self).__init__()
        layers = []
        
        layers.append(nn.Linear(n_features, hiddens[0]))
        layers.append(nn.ReLU())
            
        for i in range(0,len(hiddens)-1):
            layers.append(nn.Linear(hiddens[i], hiddens[i+1]))
            layers.append(nn.ReLU())
        
        layers.append(nn.Linear(hiddens[-1], size))
        self.net = nn.Sequential(*layers)
        
    def forward(self, state_rep):
        logits = self.net(state_rep)
        log_probs = F.log_softmax(logits, dim=(-1))
        probs = torch.exp(log_probs)
        arg = Categorical(probs).sample()
        arg = arg.detach().cpu().numpy()
        return arg.reshape(-1,1), log_probs[range(len(arg)), arg], log_probs
```

Changes:
- receive an array of sizes
- consider the max size and get the logits of shape (batch_size, n_arguments, sizes)
- make a mask of shape (n_arguments, sizes) and repeat it along batch_size
    to use like logits.masked_fill(mask.bool(), float('-inf')) before the softmax
- 

In [102]:
sizes = agent.AC.categorical_sizes
categorical_params_net = ParallelCategoricalNet(256, sizes, agent.AC.n_categorical_args)

In [107]:
parallel_args, parallel_log_prob = categorical_params_net(nonspatial_features)
parallel_args

array([[1, 0, 0, 0, 0],
       [1, 0, 0, 2, 0]])

In [108]:
arg_mask = agent.AC.categorical_arg_mask[a,:] # shape (batch_size, n_spatial_args)
if debug:
    expected = a.shape + (agent.AC.n_categorical_args,)
    actual = arg_mask.shape
    assert actual == expected, ("unexpected arg_mask shape; actual vs expected: ", \
                                actual, expected)

arg_mask      

array([[0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.]])

In [109]:
a

array([4, 4])

In [111]:
batch_size = arg_mask.shape[0]

In [118]:
# select correct action
batch_pos = arg_mask.nonzero()[0]
print("batch_pos: ", batch_pos)
arg_pos = arg_mask.nonzero()[1]
print("arg_pos: ", arg_pos)
args = parallel_args[batch_pos, arg_pos]
print("args: ", args)
arg_list = [list(args[batch_pos==i]) for i in range(batch_size )]
arg_list

batch_pos:  [0 1]
arg_pos:  [3 3]
args:  [0 2]


[[0], [2]]

In [113]:
parallel_log_prob

tensor([[-0.7394, -0.7430, -0.7037, -1.3580, -0.7493],
        [-0.7369, -0.7440, -0.7031, -1.3919, -0.7464]], grad_fn=<ViewBackward>)

In [116]:
# select and sum correct log probs

if parallel_log_prob.is_cuda:
    device = 'cuda' # Assume only 1 GPU device is used 
else:
    device = 'cpu'
    
# for every arg index contains the index of the action that uses that parameter
main_action_ids = torch.tensor(agent.AC.categorical_arg_mask.nonzero()[0]).to(device)
if debug:
    expected = agent.AC.n_categorical_args
    actual = main_action_ids.shape[0]
    assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                actual, expected)


            
sum_log_prob = torch.zeros(batch_size, len(agent.AC.action_table)) # (batch_size, action_space)
sum_log_prob.index_add_(1, main_action_ids, parallel_log_prob)
sampled_actions = torch.tensor(a) # of shape (batch_size,)
# sum of log_probs of the relevant parameters by
log_prob = sum_log_prob[torch.arange(batch_size), sampled_actions]
if debug:
    expected = batch_size
    actual = log_prob.shape[0]
    assert actual == expected, ("unexpected main_action_ids shape; actual vs expected: ", \
                                actual, expected)


In [117]:
log_prob

tensor([-1.3580, -1.3919], grad_fn=<IndexBackward>)

In [22]:
args, args_log_prob, args_entropy = self.get_arguments(spatial_features, nonspatial_features, a)
log_prob = log_prob + args_log_prob
# Use only entropy of main actions for regularization
#entropy = entropy + args_entropy

action_id = np.array([self.AC.action_dict[act] for act in a])
action = [actions.FunctionCall(action_id[i], args[i]) for i in range(len(action_id))]

NameError: name 'self' is not defined

In [19]:
envs.close()

# Dev

In [7]:
all_actions = env.action_spec()[0][1]
all_arguments = env.action_spec()[0][0]

In [8]:
def get_action_table(action_names):
    action_ids = [sc_actions.FUNCTIONS[a_name].id for a_name in action_names]
    action_table = np.array([action_ids[i] for i in range(len(action_ids))])
    return action_table

In [9]:
action_names = ['no_op','select_army','select_rect','Move_screen','select_point','Attack_screen']
action_table = get_action_table(action_names) 
action_table

array([  0,   7,   3, 331,   2,  12])

In [20]:
# now we need to get all the possible arguments counted twice if they belong to different actions
def _init_arg_names(action_table):
    """
    Add self as first argument and use it as a class method in the AC.
    Assume all_actions and all_arguments as attributes of the class (add self. in front of them)
    Also instead of returning the results, add them as attributes of the class.
    """
    spatial_arg_names = []
    categorical_arg_names = []
    categorical_sizes = []
    act_to_arg_names = {}

    for action_id in action_table:
        action = all_actions[action_id]
        args = action.args
        act_to_arg_names[action_id] = [str(action.name)+"/"+arg.name for arg in args]
        spatial = []
        categorical = []
        for arg in args:
            arg_name = str(action.name)+"/"+arg.name
            size = all_arguments[arg.id].sizes
            if len(size) == 1:
                categorical.append(arg_name)
                categorical_sizes.append(size)
            else:
                spatial.append(arg_name)
        spatial_arg_names+=spatial
        categorical_arg_names+=categorical
    return spatial_arg_names, categorical_arg_names, act_to_arg_names

spatial_arg_names, categorical_arg_names, act_to_arg_names = _init_arg_names(action_table)
        
print('spatial_arg_names: ', spatial_arg_names)
print("Number of args: ", len(spatial_arg_names))
print('categorical_arg_names: ', categorical_arg_names)
print("Number of args: ", len(categorical_arg_names))

spatial_arg_names:  ['select_rect/screen', 'select_rect/screen2', 'Move_screen/screen', 'select_point/screen', 'Attack_screen/screen']
Number of args:  5
categorical_arg_names:  ['select_army/select_add', 'select_rect/select_add', 'Move_screen/queued', 'select_point/select_point_act', 'Attack_screen/queued']
Number of args:  5


In [10]:
def get_spatial_arg_mask(action_table, spatial_arg_names):
    spatial_arg_mask = np.zeros((action_table.shape[0], len(spatial_arg_names)))
    for i, action_id in enumerate(action_table):
        action_arg_names = act_to_arg_names[action_id]
        spatial_arg_mask[i] = np.array([1 if spatial_arg_names[j] in action_arg_names else 0 for j in range(len(spatial_arg_names))])
    return spatial_arg_mask

In [11]:
spatial_arg_mask = get_spatial_arg_mask(action_table, spatial_arg_names)
spatial_arg_mask

NameError: name 'spatial_arg_names' is not defined

In [23]:
def get_categorical_arg_mask(action_table, categorical_arg_names):
    categorical_arg_mask = np.zeros((action_table.shape[0], len(categorical_arg_names)))
    for i, action_id in enumerate(action_table):
        action_arg_names = act_to_arg_names[action_id]
        categorical_arg_mask[i] = np.array([1 if categorical_arg_names[j] in action_arg_names else 0 for j in range(len(categorical_arg_names))])
    return categorical_arg_mask

In [24]:
categorical_arg_mask = get_categorical_arg_mask(action_table, categorical_arg_names)
categorical_arg_mask

array([[0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

In [25]:
actions = np.array([2,1,0])
spatial_arg_mask[actions,:] # use this to access args or mask log probs

array([[1., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

In [27]:
B = 3
n_channels = 16
res = 32
n_arguments = 5
x = torch.rand(B,n_channels,res,res)
print(x.shape)
spatial_net = ParallelSpatialParameters(n_channels, res, n_arguments)
args, log_prob, log_probs = spatial_net(x) # I can forget about log_probs, because I don't compute their entropy
print('args: ', args.shape, '\n', args) # lipst of lists
print('log_prob: ', log_prob)
print('log_probs.shape: ', log_probs.shape) # (B, res**2)

torch.Size([3, 16, 32, 32])
args:  (3, 5, 2) 
 [[[14 16]
  [13 25]
  [ 6 25]
  [28  4]
  [ 3 20]]

 [[13 12]
  [16 11]
  [ 9 18]
  [20 29]
  [11 15]]

 [[15 18]
  [11 20]
  [ 5 21]
  [18 28]
  [ 7  9]]]
log_prob:  tensor([[-6.7258, -6.9170, -7.1031, -7.2827, -7.0330],
        [-6.7674, -7.1977, -6.7109, -6.7441, -6.9934],
        [-6.9490, -6.9059, -6.5702, -6.8146, -7.0296]], grad_fn=<ViewBackward>)
log_probs.shape:  torch.Size([3, 5, 1024])


In [32]:
actions = np.array([2,1,3])
arg_mask = spatial_arg_mask[actions,:] # use this to access args or mask log probs
arg_mask

array([[1., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]])

In [33]:
batch_pos = arg_mask.nonzero()[0]
batch_pos

array([0, 0, 2])

In [34]:
arg_pos = arg_mask.nonzero()[1]
arg_pos

array([0, 1, 2])

In [35]:
args_selected = args[batch_pos, arg_pos]
args_selected

array([[ 7, 16],
       [ 9, 15],
       [21,  3]])

In [36]:
# make a list of them based on batch_pos
arg_list = [args_selected[batch_pos==i] for i in range(B)]
arg_list

[array([[ 7, 16],
        [ 9, 15]]), array([], shape=(0, 2), dtype=int64), array([[21,  3]])]

In [37]:
spatial_arg_mask

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

In [38]:
main_action_ids = torch.tensor(spatial_arg_mask.nonzero()[0])
main_action_ids

tensor([2, 2, 3, 4, 5])

In [39]:
n_actions = len(action_table)
sum_log_prob = torch.zeros(B, n_actions)
sum_log_prob.index_add_(1, main_action_ids, log_prob)

tensor([[  0.0000,   0.0000, -14.0128,  -7.1679,  -6.9187,  -7.1077],
        [  0.0000,   0.0000, -14.0146,  -6.6906,  -7.0085,  -6.6435],
        [  0.0000,   0.0000, -13.7723,  -6.5021,  -6.9189,  -6.6367]],
       grad_fn=<IndexAddBackward>)

In [40]:
log_prob.get_device()

-1

In [41]:
log_prob.is_cuda

False

In [42]:
# Now if you sample actions
samples = torch.tensor(actions) # of shape (batch_size,)
# you can get the sum of log_probs of the relevant parameters by
sum_log_prob[torch.arange(B), samples]

tensor([-14.0128,   0.0000,  -6.5021], grad_fn=<IndexBackward>)

# Alexander code

In [1]:
#which positions are owned by which argument
main_action_ids = torch.tensor([0, 0, 1, 1, 2])
n_arguments = len(main_action_ids)
n_actions = 3 # all possible unique elements in main_action_ids
batch_size = 4

NameError: name 'torch' is not defined

In [18]:
log_probs = torch.arange(batch_size*n_arguments).float().view(batch_size, n_arguments)

In [25]:
log_probs

tensor([[[-7.1397, -7.0619, -6.8159,  ..., -6.9571, -6.8822, -6.6614],
         [-7.2100, -6.8245, -6.8680,  ..., -7.1253, -7.2476, -6.8876],
         [-6.7696, -7.1995, -6.9974,  ..., -6.8698, -6.7482, -6.2425],
         [-6.5937, -6.6011, -6.9123,  ..., -6.9024, -7.0177, -6.8117],
         [-6.7422, -6.7434, -6.5691,  ..., -7.0115, -7.1403, -7.3158]],

        [[-6.9153, -6.7841, -6.7303,  ..., -6.9872, -6.8430, -6.7925],
         [-6.9692, -6.6834, -7.1690,  ..., -6.9372, -7.2004, -6.9948],
         [-6.7372, -7.2126, -7.0391,  ..., -6.6308, -6.5861, -6.3031],
         [-6.7129, -6.7453, -6.8583,  ..., -6.9967, -7.1100, -6.7289],
         [-6.8226, -6.8818, -6.7086,  ..., -7.1040, -6.8994, -7.3203]],

        [[-7.0819, -6.9388, -7.0197,  ..., -6.8330, -6.8524, -7.0321],
         [-6.9040, -6.7202, -7.0060,  ..., -7.0316, -7.0382, -6.7653],
         [-6.7028, -6.9682, -6.9849,  ..., -6.4215, -6.4388, -6.1832],
         [-6.6856, -6.8304, -6.7160,  ..., -7.1000, -7.0478, -6.8546],
  

In [20]:
sum_log_probs = torch.zeros(batch_size, n_actions)
sum_log_probs.index_add_(1, main_action_ids, log_probs)

tensor([[ 1.,  5.,  4.],
        [11., 15.,  9.],
        [21., 25., 14.],
        [31., 35., 19.]])

In [21]:
# Now if you sample actions
samples = torch.tensor([0, 1, 0, 2]) # of shape (batch_size,)
# you can get the sum of log_probs of the relevant parameters by
sum_log_probs[torch.arange(batch_size), samples]

tensor([ 1., 15., 21., 19.])

## Other part

Values are all the possible values that each argument can assume; n_values in reality would be res**2

In [None]:
batch_size = 4
n_values = 6 # possible range of sample_value_ids
log_probs = torch.arange(batch_size*n_arguments*n_values).float().view(batch_size, n_arguments, n_values)

# We sample based on these log_probs
sample_value_ids = torch.tensor([
    [0, 0, 1, 1, 2],
    [3, 3, 4, 4, 5],
    [0, 1, 2, 3, 4],
    [5, 4, 3, 2, 1],
])

In [26]:
log_probs

tensor([[[  0.,   1.,   2.,   3.,   4.,   5.],
         [  6.,   7.,   8.,   9.,  10.,  11.],
         [ 12.,  13.,  14.,  15.,  16.,  17.],
         [ 18.,  19.,  20.,  21.,  22.,  23.],
         [ 24.,  25.,  26.,  27.,  28.,  29.]],

        [[ 30.,  31.,  32.,  33.,  34.,  35.],
         [ 36.,  37.,  38.,  39.,  40.,  41.],
         [ 42.,  43.,  44.,  45.,  46.,  47.],
         [ 48.,  49.,  50.,  51.,  52.,  53.],
         [ 54.,  55.,  56.,  57.,  58.,  59.]],

        [[ 60.,  61.,  62.,  63.,  64.,  65.],
         [ 66.,  67.,  68.,  69.,  70.,  71.],
         [ 72.,  73.,  74.,  75.,  76.,  77.],
         [ 78.,  79.,  80.,  81.,  82.,  83.],
         [ 84.,  85.,  86.,  87.,  88.,  89.]],

        [[ 90.,  91.,  92.,  93.,  94.,  95.],
         [ 96.,  97.,  98.,  99., 100., 101.],
         [102., 103., 104., 105., 106., 107.],
         [108., 109., 110., 111., 112., 113.],
         [114., 115., 116., 117., 118., 119.]]])

In [24]:
arg_log_probs = log_probs.view(batch_size*n_arguments, n_values) \
[torch.arange(batch_size*n_arguments), sample_value_ids.flatten()] \
.view(batch_size, n_arguments)

In [25]:
arg_log_probs

tensor([[  0.,   6.,  13.,  19.,  26.],
        [ 33.,  39.,  46.,  52.,  59.],
        [ 60.,  67.,  74.,  81.,  88.],
        [ 95., 100., 105., 110., 115.]])

### IMPALA stuff 

In [87]:
batch_size = 2
n_actions = 4 # main actions
n_args = 5 # spatial args
res = 32

# like if they were no_op, Move_Screen, select_rect, select_point for example
spatial_arg_mask = torch.tensor([
    [0,0,0,0,0],
    [0,1,0,0,0],
    [0,0,1,1,0],
    [0,0,0,0,1]
], dtype=torch.bool)

# assume this is the output of the learner net and we need to access only the log_probs of the arguments 
# already sampled by the actors
logits = torch.rand((batch_size, n_args, res**2))
log_probs = F.log_softmax(logits, dim=-1)
log_probs

tensor([[[-6.6373, -6.6354, -6.5852,  ..., -7.3167, -7.3601, -7.0510],
         [-7.2242, -7.4536, -6.9138,  ..., -6.5512, -7.1419, -6.7743],
         [-7.2772, -6.5754, -7.0199,  ..., -7.3377, -7.2435, -7.1724],
         [-6.6585, -7.2519, -6.6346,  ..., -7.3207, -7.0605, -7.4516],
         [-6.6625, -7.1998, -6.5355,  ..., -6.8263, -7.0948, -6.7885]],

        [[-7.2413, -7.2254, -6.4852,  ..., -6.7026, -6.8807, -6.5622],
         [-7.4441, -6.4725, -7.1917,  ..., -6.8488, -7.1549, -7.1132],
         [-6.9159, -6.9486, -6.6743,  ..., -6.5905, -6.8584, -6.6929],
         [-7.4537, -7.3384, -6.8046,  ..., -6.9212, -6.6085, -7.2997],
         [-7.1154, -6.9548, -6.8218,  ..., -7.2685, -6.5803, -6.5897]]])

In [65]:
# test
probs = torch.exp(log_probs)
torch.sum(probs, axis =-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])

In [80]:
# say we sample as main actions 
main_actions = torch.tensor([2,2])
# and as sampled arguments we have
indexes = Categorical(probs).sample() # sample all args for batch_size times
mask = spatial_arg_mask[main_actions,:]
sampled_args_indexes = indexes[mask]
sampled_args_indexes

tensor([796, 698, 677, 127])

In [81]:
mask.nonzero()

tensor([[0, 2],
        [0, 3],
        [1, 2],
        [1, 3]])

In [82]:
# to use them we need more info, because it's not a 1 to 1 correspondence, since every action can have between 0 and 2 spatial arguments 
# and up to 3 arguments in total (considering the categorical once)
# basically we need to know how to split them, associating to them a batch index and an arg index
batch_indexes = mask.nonzero()[:,0]
# first arg belongs to second batch element, second and third args to third batch element and 
# fourth and fifth args to fourth batch element
print("batch_indexes: ", batch_indexes)
arg_indexes = mask.nonzero()[:,1]
# this one says from with position along the argument axis the argument comes from
print("arg_indexes: ", arg_indexes)

batch_indexes:  tensor([0, 0, 1, 1])
arg_indexes:  tensor([2, 3, 2, 3])


In [83]:
sampled_log_prob = log_probs[batch_indexes, arg_indexes, sampled_args_indexes]
sampled_log_prob

tensor([-6.8043, -6.8512, -6.5093, -6.5309])

In [84]:
# the other thing that we have to do is to sum together the log_probs of the arguments 
# that belong to the same action
sum_log_prob = torch.zeros(batch_size)
sum_log_prob.index_add_(0, batch_indexes, sampled_log_prob)
# actions[0] = no_op has no arguments -> default value of 0
# actions[1]=select_point has 1 argument (the first one stored in sampled_log_prob)
# actions[2]=select_rect has 2 arguments (second and third stored)
# actions[3]=select_rect has 2 arguments (fourth and fifth stored)
sum_log_prob # no padding needed for this, it's always of shape (batch_size,)

tensor([-13.6555, -13.0402])

In [85]:
# last problem that we have is how to communicate a variable number of arguments through a torch buffer 
# (tensor of fixed size) - padding?
max_num_spatial_args = torch.max(spatial_arg_mask.sum(axis=1)) * batch_size
max_num_spatial_args

tensor(4)

In [86]:
def pad_to_len(t, length, fill_value=-1):
    """ Assuming t of shape L <= length """
    assert t.shape[0] <= length, "tensor too long to be padded"
    padding = torch.ones(length-len(t), dtype=torch.int64)*-1
    padded_t = torch.cat([t, padding])
    return padded_t

paddded_sampled_args_indexes = pad_to_len(sampled_args_indexes, max_num_spatial_args)
print(paddded_sampled_args_indexes)
sampled_args_indexes = paddded_sampled_args_indexes[paddded_sampled_args_indexes!=-1]
sampled_args_indexes

tensor([796, 698, 677, 127])


tensor([796, 698, 677, 127])

List of variables to be passed through the buffer 
-> they also need to be given as output in the agent_output dictionary
- main_actions # sampling involved
- sum_log_prob # actor weights used
- spatial_sampled_args_indexes # sampling involved

Tecnically speaking we can recompute batch_indexes and arg_indexes inside the learner's step, as long as 
we have access to the spatial_arg_mask

In [98]:
# Now let's simulate the learner 
B = 2 # batch of 2 samples
T = 3 # 3 timesteps
# time first
main_actions = torch.tensor([
                            [0,3],
                            [2,1],
                            [2,2]
])

padded_spatial_indexes = torch.tensor([
    [213,-1,-1,-1],
    [1019,595,913,-1],
    [796, 698, 677, 127]
])

spatial_log_probs = F.log_softmax(torch.rand((T, B, n_args, res**2)), dim=-1).view(B*T, n_args, res**2)

Our goal is to compute the new sum_log_prob, assuming that spatial_log_probs are coming from the learner's weight instead of the actor ones

In [96]:
mask = spatial_arg_mask[main_actions,:].view(-1,n_args) # merge time and batch dims
mask

tensor([[False, False, False, False, False],
        [False, False, False, False,  True],
        [False, False,  True,  True, False],
        [False,  True, False, False, False],
        [False, False,  True,  True, False],
        [False, False,  True,  True, False]])

In [94]:
spatial_indexes = padded_spatial_indexes[padded_spatial_indexes!=-1]
spatial_indexes

tensor([ 213, 1019,  595,  913,  796,  698,  677,  127])

In [99]:
batch_index = mask.nonzero()[:,0]
arg_index = mask.nonzero()[:,1]

In [100]:
spatial_log_prob = spatial_log_probs[batch_index, arg_index, spatial_indexes]
spatial_log_prob

tensor([-7.2616, -7.2981, -6.6641, -6.9051, -6.7400, -7.3465, -6.8952, -6.6582])

In [102]:
sum_log_prob = torch.zeros(B*T)
sum_log_prob.index_add_(0, batch_index, spatial_log_prob)
sum_log_prob

tensor([  0.0000,  -7.2616, -13.9621,  -6.9051, -14.0865, -13.5534])