In [1]:
import sys
sys.path.insert(0, "../")

import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [2]:
from pysc2.env import sc2_env
from pysc2.lib import actions as sc_actions
from SC_Utils.game_utils import IMPALA_ObsProcesser, FullObsProcesser

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
def init_game(game_params, map_name='MoveToBeacon', step_multiplier=8, **kwargs):
    """
    Create a single-player pysc2 environment with a Terran agent named "Testv0".

    game_params: keyword arguments for sc2_env.parse_agent_interface_format
        (e.g. feature_screen, feature_minimap, action_space)
    map_name: map to play
    step_multiplier: forwarded to SC2Env as step_mul
    **kwargs: extra keyword arguments forwarded to sc2_env.SC2Env
    Returns the sc2_env.SC2Env instance.
    """
    terran = sc2_env.Race(1)  # 1 = terran
    players = [sc2_env.Agent(terran, "Testv0")]  # list even for a single player
    interface_formats = [sc2_env.parse_agent_interface_format(**game_params)]  # list even for a single player

    env_params = {
        'map_name': map_name,
        'players': players,
        'game_steps_per_episode': 0,
        'step_mul': step_multiplier,
        'agent_interface_format': interface_formats,
    }
    return sc2_env.SC2Env(**env_params, **kwargs)

In [4]:
# Environment parameters
RESOLUTION = 32
game_params = dict(feature_screen=RESOLUTION, feature_minimap=RESOLUTION, action_space="FEATURES")
game_names = ['MoveToBeacon','CollectMineralShards','DefeatRoaches','FindAndDefeatZerglings',
              'DefeatZerglingsAndBanelings','CollectMineralsAndGas','BuildMarines']
map_name = game_names[1]  # 'CollectMineralShards'
# Observation processing: number of screen / minimap channels and of player features
obs_proc_params = {'select_all':True}
op = FullObsProcesser(**obs_proc_params)
screen_channels, minimap_channels, in_player = op.get_n_channels()
in_channels = screen_channels + minimap_channels

# Legacy A2C hyper-parameters, kept for reference only and NOT executed
# (`net` and `flags` are not defined in this notebook):
# spatial_model = net.FullyConvPlayerAndSpatial
# nonspatial_model = net.FullyConvNonSpatial
# # Internal features, passed inside a dictionary
# conv_channels = flags.conv_channels #32
# player_features = flags.player_features #16
# # Exposed features, passed outside of a dictionary
# n_channels = conv_channels + player_features #48
# n_features = flags.n_features #256
#
# spatial_dict = {"in_channels":in_channels, 'in_player':in_player,
#                 'conv_channels':conv_channels, 'player_features':player_features}
# nonspatial_dict = {'resolution':RESOLUTION, 'kernel_size':3, 'stride':2, 'n_channels':n_channels}
#
# HPs = dict(spatial_model=spatial_model, nonspatial_model=nonspatial_model,
#        n_features=n_features, n_channels=n_channels, action_names=flags.action_names,
#        spatial_dict=spatial_dict, nonspatial_dict=nonspatial_dict)
# game_params['HPs'] = HPs




In [5]:
# create the pysc2 environment for the selected map
env = init_game(game_params, map_name)

In [6]:
# one entry per agent (single agent here, accessed as obs[0] below)
obs = env.reset()

In [7]:
# Take one step selecting the whole army; the resulting `last_actions` below is [7] (select_army).
# Commented alternatives use the `sc_actions` alias imported at the top:
#action = sc_actions.FunctionCall(sc_actions.FUNCTIONS.no_op.id, [])
action = sc_actions.FunctionCall(sc_actions.FUNCTIONS.select_army.id, [[0]]) 
#action = sc_actions.FunctionCall(sc_actions.FUNCTIONS.Attack_screen.id, [[0],[1,1]])
obs = env.step(actions=[action])

In [8]:
# pysc2 function names exposed to the agent; the position in this list is the
# agent-side action index (index 0 = 'no_op').
action_names = [
    'no_op', 'move_camera', 'select_point', 'select_rect', 'select_idle_worker', 'select_army',
    'Attack_screen', 'Attack_minimap',
    'Build_Barracks_screen', 'Build_CommandCenter_screen',
    'Build_Refinery_screen', 'Build_SupplyDepot_screen',
    'Harvest_Gather_SCV_screen', 'Harvest_Return_SCV_quick',
    'HoldPosition_quick', 'Move_screen', 'Move_minimap',
    'Rally_Workers_screen', 'Rally_Workers_minimap',
    'Train_Marine_quick', 'Train_SCV_quick',
]

In [9]:
# pysc2 function id of each name, in action_names order
action_ids = [sc_actions.FUNCTIONS[name].id for name in action_names]
# action_table[i] = pysc2 function id of agent action i
action_table = np.array(action_ids)

In [10]:
# IMPALA observation processor, built from the action table (agent index -> pysc2 function id)
IMP_op = IMPALA_ObsProcesser(action_table, **obs_proc_params)

# 1. Last Action as additional input

Note: actions that are invalid or equivalent to a no-op are not recorded by the environment; in that case `last_actions` is an empty list `[]` (the cell below then falls back to 0, the `no_op` id).

In [11]:
# `last_actions` lists the pysc2 function ids recorded in the previous step; it is empty when
# nothing was recorded, in which case fall back to id 0 (no_op).
recorded_actions = obs[0].observation['last_actions']
print(recorded_actions)
last_action = recorded_actions[0] if len(recorded_actions) > 0 else 0
last_action

[7]


7

In [14]:
# Agent-side index of the last pysc2 action: np.where returns an index array
# (empty if the id is not in the action table).
last_action_idx = np.where(IMP_op.action_table == last_action)[0]
print(IMP_op.action_table)
#IMP_op.action_table[last_action_idx]

[  0   1   2   3   6   7  12  13  42  44  79  91 268 273 274 331 332 343
 344 477 490]


In [15]:
# array([5]): the last action is action_names[5] == 'select_army'
last_action_idx

array([5])

We also need to embed the last action in a meaningful way (an embedding layer with embedding dimension 10 should suffice). After that we can simply concatenate player\_info with the embedded last\_action.

The model already has the information about the action space, so we just need to pass the embed\_dim variable (we can actually keep it constant to 10 for simplicity).

# 2. Screen / Minimap / Categorical action 

Task: tile a binary mask to screen and minimap with ones if respectively last action was acting on the screen or on the minimap, with zeros otherwise.

How can we tell whether an action acts on the screen or on the minimap? At the moment there is only a spatial vs. categorical distinction at the argument level, and no screen vs. minimap vs. other distinction at the main-action level.

It makes sense to build a look-up table before the beginning of the training to answer this question as fast as possible during runtime.

In [17]:
# action spec of the (single) agent: [0] = argument types, [1] = functions
agent_action_spec = env.action_spec()[0]
all_actions = agent_action_spec[1]
all_arguments = agent_action_spec[0]

In [18]:
def check_if_screen(env, sc_env_action, screen=True):
    """
    Tell whether pysc2 function `sc_env_action` takes a screen argument (screen=True)
    or a minimap argument (screen=False).

    The check is name-based: the result is True if any argument name of the function
    contains 'screen' (resp. 'minimap'). Returns a numpy bool.

    Modify this function in a method for some class that has access to the action specs
    (could be the wrapped Environment class, with self.env instead of env)
    """
    get_spec = env.action_spec
    functions = get_spec()[0][1]
    arg_types = get_spec()[0][0]

    keyword = 'screen' if screen else 'minimap'
    arg_names = [arg_types[arg.id].name for arg in functions[sc_env_action].args]
    return np.any([keyword in arg_name for arg_name in arg_names])

In [19]:
check_if_screen(env, last_action)

False

In [23]:
# screen_mask[i]: whether agent action i takes a screen-coordinate argument
screen_mask = [check_if_screen(env, action_id, True) for action_id in IMP_op.action_table]
screen_mask

[False,
 False,
 True,
 True,
 False,
 False,
 True,
 False,
 True,
 True,
 True,
 True,
 True,
 False,
 False,
 True,
 False,
 True,
 False,
 False,
 False]

In [24]:
# minimap_mask[i]: whether agent action i takes a minimap-coordinate argument
minimap_mask = [check_if_screen(env, action_id, False) for action_id in IMP_op.action_table]
minimap_mask

[False,
 True,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 True,
 False,
 False]

In [26]:
# processed observation: dict of layers/features (keys shown below) plus `names`
state_dict, names = IMP_op.get_state(obs)

In [27]:
# inspect which inputs the processor produces
state_dict.keys()

dict_keys(['screen_layers', 'minimap_layers', 'player_features'])

In [28]:
# (channels, resolution, resolution)
state_dict['screen_layers'].shape

(26, 32, 32)

In [29]:
# (channels, resolution, resolution)
state_dict['minimap_layers'].shape

(12, 32, 32)

In [30]:
# 1-D vector of player features
state_dict['player_features'].shape

(8,)

In [32]:
# simple access during run time:
# one extra channel, all ones if the last action acts on the screen, all zeros otherwise
screen_layers = state_dict['screen_layers']
screen_binary_mask = np.array([screen_mask[last_action_idx[0]]])
# tile to (1, H, W) using the layers' own resolution instead of a hard-coded 32,
# and match their dtype so the concatenation does not depend on bool upcasting
screen_binary_mask2D = np.tile(screen_binary_mask, [1, *screen_layers.shape[-2:]]).astype(screen_layers.dtype)
# NOTE: this overwrites state_dict['screen_layers']; re-running the cell appends another channel
state_dict['screen_layers'] = np.concatenate([screen_layers, screen_binary_mask2D])
state_dict['screen_layers'].shape

(27, 32, 32)

In [33]:
# tile shape (1, H, W) for a binary mask at the layers' resolution
[1,*state_dict['screen_layers'].shape[-2:]]

[1, 32, 32]

And of course same thing for minimap.

# 3. Spatial input processing

*Spatially encoded inputs (minimap and screen) are tiled with binary masks denoting
whether the previous action constituted a screen- or minimap-related action. These tensors are then fed to
independent residual convolutional blocks, each consisting of one convolutional layer (4 × 4 kernels and stride
2) followed by a residual block with 2 convolutional layers (3 × 3 kernels and stride 1), which process and
downsample the inputs to [8 × 8 × `#channels_1`] outputs. These tensors are concatenated along the depth
dimension to form a singular spatial input (`inputs_3D`).*

Differences with previous implementation: 
1. First process them, then merge them

In [34]:
class ResidualConvLayer(nn.Module):
    """
    Pre-activation residual convolution: out = Conv(ReLU(x)) + x.

    Pre-activations as in Identity Mappings in Deep Residual Networks
    https://arxiv.org/abs/1603.05027. The convolution has stride 1 and symmetric
    'same' padding, so the output has the same shape as the input.

    res: input resolution (unused, kept for interface compatibility)
    n_channels: number of input and output channels
    kernel_size: odd kernel size
    """

    def __init__(self, res, n_channels, kernel_size=3):
        super(ResidualConvLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Provide odd kernel size to use this layer'
        same_padding = (kernel_size - 1) // 2
        layers = [
            nn.ReLU(),
            nn.Conv2d(n_channels, n_channels, kernel_size, stride=1, padding=same_padding),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return residual + x

In [76]:
# add a maxpool to shrink each linear dimension by a factor of 4 (2 from stride of first conv, 2 for pooling)
class StateEncodingConvBlock(nn.Module):
    """ 
    Encode a spatial input (screen or minimap) and downsample it:
    - First conv layer halves the spatial dimensions (with the default kernel_size/stride/padding)
    - 2 residual convolutional layers with ReLU pre-activations (they act on the input and not the output)
    - 2x2 MaxPool to halve again the dimensions

    Input: (batch, in_channels, res, res). Output: (batch, out_channels, out_res, out_res).

    Attributes
    ----------
    new_res: resolution after the first convolution, i.e. BEFORE max-pooling
        (kept for backward compatibility; it is NOT the output resolution).
    out_res: resolution of the block output, i.e. after max-pooling.
    """
    def __init__(self, res, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super(StateEncodingConvBlock, self).__init__()
        new_res = (res - kernel_size + 2*padding)//stride + 1
        self.new_res = new_res # resolution after the first conv (before pooling)
        self.out_res = new_res // 2 # MaxPool2d(2) floors odd resolutions
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            ResidualConvLayer(new_res, out_channels, kernel_size=3),
            ResidualConvLayer(new_res, out_channels, kernel_size=3),
            nn.MaxPool2d(2),
        )
        
    def forward(self, x):
        return self.net(x)

In [75]:
# sanity check: a 2x2 max-pool halves each spatial dimension (32 -> 16)
layer = nn.MaxPool2d(kernel_size=2)
img = torch.rand((1, 1, 32, 32))
out = layer(img)
out.shape

torch.Size([1, 1, 16, 16])

In [44]:
# screen encoder; +1 input channel for the screen binary mask appended above
res = 32
in_channels = screen_channels + 1
out_channels = 32
conv_block = StateEncodingConvBlock(res, in_channels, out_channels)

In [40]:
# add a batch dimension: (1, channels, res, res), channels including the appended screen mask
screen_tensor = torch.tensor(state_dict['screen_layers']).float().unsqueeze(0)
#screen_tensor = torch.rand((1,in_channels, res, res))
x_screen = conv_block(screen_tensor)

In [41]:
# NOTE(review): the 16x16 output shown below comes from In[41], executed before MaxPool2d was
# added to StateEncodingConvBlock (In[76]); with the current class 8x8 is expected — re-run to confirm.
x_screen.shape

torch.Size([1, 32, 16, 16])

In [54]:
# Output resolution of a conv layer (first conv of StateEncodingConvBlock):
# new_res = floor((res + 2*padding - kernel_size) / stride) + 1
res, kernel_size, stride, padding = 32, 4, 2, 1
new_res = (res + 2*padding - kernel_size)//stride + 1
new_res

16

So we will have 2 convolutional blocks, one for the minimap and one for the screen. As a shortcut we could merge the two inputs and use a single convolutional block, but separate blocks make better use of domain knowledge (i.e. they don't treat spatial information at two different scales as if it came from the same scale).

# 4. Variable dimensionality recap

**screen**: tiling binary mask and adding batch dim: (1, screen_channels, res, res) <br>
**minimap**: tiling binary mask and adding batch dim: (1, minimap_channels, res, res) <br>
**player**:  adding batch dim: (1, in_player) <br>
**last_action**:  adding batch dim: (1,) (good like this because it needs to be embedded afterwards)<br>

### After state encoding:

**inputs_3D**: (1, #channels_1, new_res, new_res) <br>
with `new_res = ((res - kernel_size + 2*padding)//stride + 1)//2` (strided conv followed by a 2x2 max-pool, i.e. res/4 = 8 for res = 32) and `#channels_1 = out_channels*2` (with the default out_channels = 32, #channels_1 = 64) <br>
Uses **StateEncodingConvBlock**.

**inputs_2D**: (1, in_player+embed_dim) -> (1,128) -> ReLU -> (1, 64) <br>
with default value of embed_dim equal to 10. <br>
Uses **Inputs2D_Net**.

### After memory processing:
Note that inputs_3D are used as input to the Conv2D LSTM and not inputs_2D! Also call the output outputs_3D.

Conv2D LSTM: kernel size 3x3, stride 1, (padding of 1 to keep resolution constant), #output_channels 96

**outputs_3D**: (1, #output_channels, new_res, new_res) (same spatial resolution) <br>
Uses **ConvLSTM**.

### Main processing (control part for the relational processing)

Input: outputs_3D

2 flows:

SPATIAL: 12-layer deep residual model ( 4 blocks of 3 convolutional layers each )
   - first: kernel 4x4, stride 1 (?)
   - second and third: kernel 3x3, stride 1 
   - "interleaved with ReLU activations and skip-connections" (I don't know if after every layer or every block; <br> ReLUs make sense after each layer, skip connections after every block maybe) 
    
Since the output should have the same shape as the relational-spatial outputs, i.e. [8, 8, #channels2], and inputs_3D already have that spatial resolution in the paper, I would substitute 8 with a more generic new_res and deduce that the whole spatial architecture uses padding so that the resolution remains unchanged. #channels2 is not specified further (we can keep it equal to #output_channels for simplicity) <br>
Uses **DeepResidualBlock**.

NON-SPATIAL: flattened (all three dimensions I guess) and passed to a 2-layer MLP (512 units per layer, ReLU activations) to produce what we refer to above as relational-nonspatial <br>
(1, 64 x #channels2) -> (1, 512) -> ReLU -> (1, 512) -> ReLU <br>
Uses **NonSpatialBlock**.

### Output processing

inputs 2D and relational-nonspatial are concatenated to form a set of shared features. <br>
Shared features used to produce log probs and V. <br>
shared features: 
- (1, 64 + 512) -> (1, 256) -> ReLU -> (1, #actions) for logits
- (1, 64 + 512) -> (1, 256) -> ReLU -> (1, 1) for value

Actions are sampled using computed policy logits and embedded into a 16 dimensional vector. 

This embedding is used to condition shared features and generate logits for non-spatial arguments (Args) through independent linear combinations (one for each argument). [Basically we concatenate the action to shared features and then pass it through a 2-layers MLP?]

Finally, spatial arguments (Args x,y) are obtained by first deconvolving relational-spatial to [32 × 32 × `#channels_3`] tensors using Conv2DTranspose layers, conditioned by tiling the action embedding along the depth dimension and passed to 1 × 1 × 1 convolution layers (one for each spatial argument). 
1 x 1 x 1 means a Conv2d with kernel size of 1 and output channels of 1.

#channels_3 = 16 <br>
Conv2DTranspose: kernel size 4x4, stride 2, padding 1 — each layer doubles the spatial resolution (8 → 16 → 32, as verified in the cells below).


# Conv2D_LSTM layer
From https://github.com/ndrplz/ConvLSTM_pytorch

In [47]:
from ConvLSTM_pytorch.convlstm import ConvLSTM

In [89]:
# ConvLSTM smoke test with a single layer.
# NOTE(review): `out_channels` is reassigned later in the notebook (ConvTranspose2d cell), so its
# value here depends on execution order. Also, SpatialProcessingBlock feeds encoding_channels*2
# (screen + minimap encodings), not the channels of a single encoder.
input_channels = out_channels
# in case of a single layer
num_layers = 1
kernel_size = 3 # overwrites the kernel_size used in the resolution cell above
hidden_channels = [96 for _ in range(num_layers)] # per-layer hidden channels; 96 as in the notes above (TODO confirm)
kernel = [(kernel_size, kernel_size) for _ in range(num_layers)] # per-layer kernel sizes as (h, w) tuples
conv_lstm = ConvLSTM(input_channels, 
                     hidden_channels, 
                     kernel, 
                     num_layers,
                     batch_first=False,
                     bias=True,
                     return_all_layers=True
                    )

In [90]:
# random 5d input; with batch_first=False dim 0 is time and dim 1 is batch (t=5, b=10)
# NOTE(review): new_res is the 16 computed above, while the encoder output is pooled to 8
#input_test1 = torch.rand(1, 1, input_channels, new_res, new_res)
input_test2 = torch.rand(5, 10, input_channels, new_res, new_res)

In [91]:
# input must be 5d (t, b, channels, w, h) or (b, t, channels, w, h) 
# returns one output per layer (return_all_layers=True) and one [h, c] pair per layer
layer_output_list, last_state_list = conv_lstm(input_test2)

In [77]:
# one output per layer, each batch-first (b, t, c, h, w)
# NOTE(review): the output shown below comes from In[77], executed before the current
# conv_lstm / input_test2 cells (In[89]-[91]); re-run to refresh it.
assert len(layer_output_list) == num_layers, "len(layer_output_list) is %d"%len(layer_output_list)
layer_output_list[0].shape

torch.Size([5, 10, 96, 16, 16])

In [78]:
# last_state_list holds one [h, c] pair per layer; h equals the last time step of the layer's
# batch-first output (index -1 along dim 1).
# NOTE(review): the output shown below comes from In[78], executed before In[89]-[91].
assert len(last_state_list) == num_layers, "len(last_state_list) is %d"%len(last_state_list)
assert len(last_state_list[0]) == 2, "len(last_state_list[0]) is %d"%len(last_state_list[0])
assert last_state_list[0][0].shape == last_state_list[0][1].shape, 'h and c have different shapes'
assert torch.all(last_state_list[0][0] == layer_output_list[0][:,-1,...]), 'they both should be last h of the first layer'
last_state_list[0][0].shape

torch.Size([5, 96, 16, 16])

**layer_output_list**: <br>
[(b,t,c,w,h), ..., (b,t,c,w,h)] <br>
List length equal to number of layers

**last_state_list**: <br>
[[h,c], ..., [h,c]] <br>
with both c and h of shape (b,c,w,h) - no time dimension in hidden and cell states (because it is understood that they refer to the last timestep)

In [93]:
# Feed the (h, c) states returned by each call back into the next call.
# NOTE(review): every iteration passes the whole input_test2 sequence (5 time steps), not a single
# step; to step one time step per call use input_test2[t:t+1] with T = input_test2.shape[0].
# Passing hidden_states relies on the local ConvLSTM accepting an initial state (upstream
# ConvLSTM_pytorch raised NotImplementedError for it — confirm).
T = 10
hidden_states = None
for t in range(T):
    layer_output_list, hidden_states = conv_lstm(input_test2, hidden_states)
out = layer_output_list[-1]
out.shape

torch.Size([10, 5, 96, 16, 16])

In [65]:
# backward smoke test through the ConvLSTM output
# NOTE(review): this cell ran as In[65], i.e. before `out` was last redefined (In[93]).
R = out.sum()
R.backward()

After this step it would make sense (maybe) to merge batch and time dimensions (I'm thinking about the learner).

In case of the actors, we just have a batch dimension of 1, so we will add a fake time dimension in front of it (or as second dimension, in a coherent way with the tensors coming out of the buffers that will be used by the learner) and we will shrink it again to 4d afterwards.

Just use something like:

out = out.view((-1,*out.shape[2:]))

In [95]:
# LSTM output is batch-first (b, t, ...): swap to time-first, then merge (t, b) into one dim
out = out.transpose(0, 1).flatten(0, 1)
out.shape

torch.Size([50, 96, 16, 16])

Since we're working with time first and then batch but we receive the output of the lstm as batch first, we need to permute it again before collapsing the two dimensions.

# Main processing

It is not clear how to apply a 4x4 convolution without losing resolution (an even kernel cannot be padded symmetrically to keep the size), so a 5x5 kernel with padding 2 is used in the meantime.

Also, I use a skip connection for every layer, since it's not clear whether the paper uses one per layer or one per block.

**How can they decide that the output is going to have 32 channels if it is a residual block?**

At the moment residual layers are lacking BatchNormalization (not possible) and LayerNormalization (possible but not in the original implementation)

In [104]:
class ResidualConvLayer(nn.Module):
    """
    Pre-activation residual convolution: out = Conv(ReLU(x)) + x.
    (This cell redefines ResidualConvLayer with the same behavior as its earlier definition.)

    Pre-activations as in Identity Mappings in Deep Residual Networks
    https://arxiv.org/abs/1603.05027. Stride 1 and symmetric 'same' padding keep the
    input shape unchanged.

    res: input resolution (unused, kept for interface compatibility)
    n_channels: number of input and output channels
    kernel_size: odd kernel size
    """

    def __init__(self, res, n_channels, kernel_size=3):
        super(ResidualConvLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Provide odd kernel size to use this layer'
        same_padding = (kernel_size - 1) // 2
        layers = [
            nn.ReLU(),
            nn.Conv2d(n_channels, n_channels, kernel_size, stride=1, padding=same_padding),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return residual + x

In [106]:
class ResidualConvBlock(nn.Module):
    """
    Three stacked pre-activation residual conv layers with kernels 5, 3, 3, each with its own
    skip connection; channels and resolution are preserved.

    A 5x5 kernel stands in for the paper's 4x4 first layer, which cannot keep the resolution
    unchanged with symmetric padding.

    in_channels: number of input (and output) channels
    res: input resolution (forwarded to ResidualConvLayer)
    """
    def __init__(self, in_channels, res):
        super(ResidualConvBlock, self).__init__()
        layers = [ResidualConvLayer(res, in_channels, kernel_size=k) for k in (5, 3, 3)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [107]:
class DeepResidualBlock(nn.Module):
    """
    Sequence of `n_blocks` ResidualConvBlock modules (3 residual conv layers each);
    channels and resolution are preserved.

    NOTE(review): the paper's 12-layer description (4 blocks of 3 layers) suggests
    n_blocks=4; the default here is 3.
    """
    def __init__(self, in_channels, res, n_blocks=3):
        super(DeepResidualBlock, self).__init__()
        blocks = []
        for _ in range(n_blocks):
            blocks.append(ResidualConvBlock(in_channels, res))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

In [109]:
# we got 96 and 16, we should have 32 and 8 somehow
# (96 = ConvLSTM hidden channels, 16 = new_res computed in the resolution cell)
deep_residual_spatial = DeepResidualBlock(in_channels=hidden_channels[-1], res=new_res) 

In [110]:
# residual layers use 'same' padding, so the shape of `out` is preserved
spatial_features = deep_residual_spatial(out)
spatial_features.shape

In [112]:
class NonSpatialBlock(nn.Module):
    """
    Flatten a spatial feature map and process it with a 2-layer MLP (512 units, ReLU).

    in_channels: number of channels of the input feature map
    res: spatial resolution (height = width) of the input; an int or an integral float
        (e.g. res/4) is accepted

    forward input: tensor whose elements group into rows of in_channels*res**2 values,
        e.g. (batch, in_channels, res, res); non-contiguous tensors (e.g. produced by
        transpose) are accepted.
    forward output: (batch, 512)
    """
    def __init__(self, in_channels, res):
        super(NonSpatialBlock, self).__init__()
        flattened_size = in_channels*(res**2)
        if flattened_size != int(flattened_size):
            raise ValueError("in_channels*res**2 must be integral, got %s" % flattened_size)
        # int() so that a float resolution still yields a valid Linear size
        self.flattened_size = int(flattened_size)
        self.out_features = 512
        self.net = nn.Sequential(
            nn.Linear(self.flattened_size, self.out_features),
            nn.ReLU(),
            nn.Linear(self.out_features, self.out_features),
            nn.ReLU()
        )
    
    def forward(self, x):
        # reshape instead of view: view raises on non-contiguous inputs
        x = x.reshape(-1, self.flattened_size)
        return self.net(x)

In [151]:
class Inputs2D_Net(nn.Module):
    """
    Non-spatial input processing (inputs_2D): embed the last action, concatenate the embedding
    with the player features and process the result with a 2-layer MLP:
    (batch, in_player + embedding_dim) -> 128 -> ReLU -> 64 -> ReLU

    in_player: number of player features
    n_actions: number of actions in the action table (embedding vocabulary size)
    embedding_dim: size of the last-action embedding (default 10)
    """
    def __init__(
        self, 
        in_player, 
        n_actions, 
        embedding_dim=10
    ):
        super(Inputs2D_Net, self).__init__()
        self.out_features = 64 # in case needed from outside
        # padding_idx=0: index 0 (no_op, first entry of the action table) gets an all-zero
        # embedding that receives no gradient
        self.embedding = nn.Embedding(n_actions, embedding_dim, padding_idx=0)
        self.MLP = nn.Sequential(
            nn.Linear(in_player+embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.out_features),
            nn.ReLU()
        )
        
    def forward(self, player_info, last_action):
        """
        player_info: (batch, in_player) float tensor
        last_action: (batch,) or (batch, 1) integer tensor of action-table indexes
        Returns: (batch, 64)
        """
        # flatten so that a (batch, 1) index tensor also yields (batch, embedding_dim)
        embedded_action = self.embedding(last_action.reshape(-1)).float()
        nonspatial_input = torch.cat([player_info, embedded_action], dim=1)
        return self.MLP(nonspatial_input)

In [152]:
# one embedding row per entry of action_names
n_actions = len(action_names)
inputs2d_net = Inputs2D_Net(in_player=in_player, n_actions=n_actions)

In [153]:
# batch of one: player features (1, in_player) and last-action index (1,)
player_tensor = torch.tensor(state_dict['player_features']).view(1,-1).float()
# flat (batch,) LongTensor, the layout Inputs2D_Net expects, whatever the shape of last_action_idx
# (avoids the legacy torch.LongTensor([ndarray]) constructor, which nests the array)
last_action_tensor = torch.as_tensor(last_action_idx, dtype=torch.long).view(-1)
inputs2d = inputs2d_net(player_tensor, last_action_tensor)
inputs2d.shape

torch.Size([1, 64])

# Output processing

In [46]:
class ActorHead(nn.Module):
    """
    Policy head: MLP n_shared_features -> 256 -> ReLU -> n_actions, returning log-probabilities.

    The default n_shared_features=576 corresponds to 512 non-spatial + 64 inputs_2D features.

    forward(shared_features, mask):
        shared_features: (batch, n_shared_features)
        mask: broadcastable to the logits; non-zero / True entries mark excluded actions,
              whose logits are set to -inf before the log-softmax
        returns: (batch, n_actions) log-probabilities
    """
    def __init__(self, n_actions, n_shared_features=576):
        super(ActorHead, self).__init__()
        self.n_actions = n_actions
        hidden_units = 256
        self.actor_net = nn.Sequential(
            nn.Linear(n_shared_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        )

    def forward(self, shared_features, mask):
        logits = self.actor_net(shared_features)
        masked_logits = logits.masked_fill(mask.bool(), float('-inf'))
        return F.log_softmax(masked_logits, dim=-1)

    
class CriticHead(nn.Module):
    """
    Value head: MLP n_shared_features -> 256 -> ReLU -> 1.

    The default n_shared_features=576 corresponds to 512 non-spatial + 64 inputs_2D features.
    forward(shared_features): (batch, n_shared_features) -> (batch, 1) state values.
    """
    def __init__(self, n_shared_features=576):
        super(CriticHead, self).__init__()
        layers = [nn.Linear(n_shared_features, 256), nn.ReLU(), nn.Linear(256, 1)]
        self.critic_net = nn.Sequential(*layers)

    def forward(self, shared_features):
        value = self.critic_net(shared_features)
        return value

These were some methods used in the batchedA2C to condition the parameter sampling with the sampled main action. We just need to decide where to define this embedding layer.

``` python
def _embed_action(self, action):
    a = torch.LongTensor(action).to(self.device)
    a = self.AC.embedding(a)
    return a

def _cat_action_to_spatial(self, embedded_action, spatial_repr):
    """ 
    Assume spatial_repr of shape (B, n_channels, res, res).
    Cast embedded_action from (B, embedd_dim) to (B, embedd_dim, res, res)
    Concatenate spatial_repr with the broadcasted embedded action along the channel dim.
    """
    res = spatial_repr.shape[-1]
    embedded_action = embedded_action.reshape((embedded_action.shape[:2]+(1,1,)))
    spatial_a = embedded_action.repeat(1,1,res,res)
    spatial_repr = torch.cat([spatial_repr, spatial_a], dim=1)
    return spatial_repr

def _cat_action_to_nonspatial(self, embedded_action, nonspatial_repr):
    """
    nonspatial_repr: (B, n_features)
    embedded_action: (B, embedd_dim)
    Concatenate them so that the result is of shape (B, n_features+embedd_dim)
    """
    return torch.cat([nonspatial_repr, embedded_action], dim=1)
```

In [60]:
# two stacked transposed convolutions (kernel 4, stride 2, padding 1): each doubles H and W
in_channels, out_channels = 96, 16  # last hidden channel of ConvLSTM, #channels_3 (default 16)
kernel_size, stride, padding = 4, 2, 1
conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
conv_transp2 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride, padding)

In [62]:
# 8x8 input (the relational-spatial resolution in the paper): first layer gives 16x16
test_tensor = torch.rand(10, in_channels, 8, 8).float()
out = conv_transp(test_tensor)
out.shape

torch.Size([10, 16, 16, 16])

In [63]:
# second transposed conv: 16x16 -> 32x32, the screen resolution
out2 = conv_transp2(out)
out2.shape

torch.Size([10, 16, 32, 32])

In [57]:
# resolution after one stride-2 conv (StateEncodingConvBlock then max-pools it to new_res // 2)
new_res

16

Finally tile the embedded action and use SpatialIMPALA to sample the spatial arguments all in parallel.

Actually, we don't really need to condition on the sampled action if we use separate networks for each parameter, and it is simpler to keep at least this part unchanged.

Same goes for the categorical parameters...

```python
class SpatialIMPALA(ParallelSpatialParameters):
    def __init__(self, n_channels, linear_size, n_arguments):
        super(SpatialIMPALA, self).__init__(n_channels, linear_size, n_arguments)
        
    def forward(self, x, x_first=True):
        B = x.shape[0]
        log_probs = self.get_log_probs(x)
        probs = torch.exp(log_probs)
        index = Categorical(probs).sample() # shape (B, n_args)
        # method inherited from ParallelSpatialParameters
        y, x = self.unravel_index(index, (self.size,self.size)) # both x and y of shape (B, n_args)
        if x_first:
            arg_lst = np.array([[xi.detach().numpy(),yi.detach().numpy()] for xi, yi in zip(x,y)])
        else:
            arg_lst = np.array([[yi.detach().numpy(),xi.detach().numpy()] for xi, yi in zip(x,y)])
        arg_lst = arg_lst.transpose(0,2,1)  #shape (batch, n_arguments, [y,x]) (or [x,y])                 
        log_prob = log_probs.view(B*self.n_args, self.size**2)[torch.arange(B*self.n_args), index.flatten()]\
                    .view(B, self.n_args) 
        return arg_lst, log_prob, index
    
    def get_log_probs(self, x):
        """Compute flatten log_probs for all arguments - shape: (batch_size, n_args, size**2)"""
        x = self.conv(x)
        x = x.reshape((x.shape[0],self.n_args,-1))
        log_probs = F.log_softmax(x, dim=(-1))
        return log_probs
```

# Previous flow

``` python
class IMPALA_AC(ParallelActorCritic)
    class ParallelActorCritic(nn.Module)
        self.spatial_features_net = spatial_model(**spatial_dict) # custom net defined in main
        self.nonspatial_features_net = nonspatial_model(**nonspatial_dict) # custom net defined in main
        self.actor = SharedActor(action_space, n_features)
        self.critic = SharedCritic(n_features)
        # these 2 are overridden by the IMPALA_AC init - I report directly the final version
        self.spatial_params_net = SpatialIMPALA(self.n_channels, 
                                                self.screen_res[0], 
                                                self.n_spatial_args
                                               )
        self.categorical_params_net = CategoricalIMPALA(self.n_features, 
                                                        self.categorical_sizes, 
                                                        self.n_categorical_args
                                                       )
    def actor_step(self, env_output)
        """
        Input
        -----
        env_output: dict with keys [spatial_state, player_state, action_mask]
        
        Return
        ------
        actor_output: dict with keys [log_prob, main_action, sc_env_action, 
                                        categorical_args_indexes, spatial_args_indexes]
        """
        
    def learner_step(self, batch):
        """
        Input
        -----
        batch: dict contianing tensors of shape (T, B, *other_dims), where 
                - T = unroll_length (number of steps in the trajectory)
                - B = batch_size

                Keys: 
                - spatial_state
                - player_state
                - spatial_state_trg
                - player_state_trg
                - action_mask
                - main_action
                - categorical_indexes
                - spatial_indexes
                
        Return
        ------
        dict(log_prob=log_prob.view(T,B), 
             baseline=baseline.view(T,B), 
             baseline_trg=baseline_trg.view(T,B), 
             entropy=entropy)
        """
       
```

After this everything should be the same once again. 

Changes to do:
- new ParallelActorCritic tailored to the new architecture (already all networks inside it)
- provide actor step with 'last_action_idx' and 'current_lstm_state'
- output 'new_lstm_state' in addition

In [66]:
from AC_modules.Networks import *

In [67]:
# spatial-argument head from AC_modules.Networks: 2 spatial arguments on a 32x32 grid
# (presumably upsampling its 8x8 input — TODO confirm against SpatialIMPALA_v2)
spatial_params_net = SpatialIMPALA_v2(n_channels=32, linear_size=32, n_arguments=2)

In [72]:
# smoke test on a random (1, 32, 8, 8) input: sampled coordinates, their log-probabilities
# and (presumably) the flat grid indexes of the sampled cells
x_test = torch.rand(1,32,8,8)
arg_lst, log_prob, index = spatial_params_net(x_test)
print("arg_lst: ", arg_lst)
print("log_prob: ", log_prob)
print("index: ", index)

arg_lst:  [[[27 26]
  [ 4 27]]]
log_prob:  tensor([[-6.9320, -6.9223]], grad_fn=<ViewBackward>)
index:  tensor([[859, 868]])


In [86]:
class SpatialProcessingBlock(nn.Module):
    """
    Spatial part of the relational-RL control architecture:
    screen and minimap are encoded separately (StateEncodingConvBlock, resolution / 4),
    concatenated along the channel dim (inputs_3D), processed by a single-layer ConvLSTM
    (outputs_3D) and then by a DeepResidualBlock (spatial features) and a NonSpatialBlock
    (non-spatial features).

    Parameters
    ----------
    res: int, or a sequence of equal sides such as (height, width) (e.g. the [1:] slice of
        the 'feature_screen' observation spec); must be divisible by 4
    screen_channels: channels of the screen input
    minimap_channels: channels of the minimap input
    encoding_channels: output channels of each StateEncodingConvBlock
    lstm_channels: hidden channels of the ConvLSTM (default 96)
    """
    def __init__(
        self, 
        res, 
        screen_channels, 
        minimap_channels,
        encoding_channels,
        lstm_channels=96,
    ):
        # nn.Module.__init__ must run before any sub-module is assigned as an attribute
        super(SpatialProcessingBlock, self).__init__()
        if hasattr(res, '__len__'):
            assert len(res) > 0 and all(r == res[0] for r in res), \
                ("Only square resolutions are supported", res)
            res = res[0]
        res = int(res)
        assert res%4 == 0, "Provide an input with resolution divisible by 4"
        self.res = res
        # integer division: new_res is used as a layer size (NonSpatialBlock's Linear)
        self.new_res = res//4
        self.lstm_channels = lstm_channels
        self.screen_state_enc_net = StateEncodingConvBlock(res, screen_channels, encoding_channels)
        self.minimap_state_enc_net = StateEncodingConvBlock(res, minimap_channels, encoding_channels)
        # attribute name (typo of conv_lstm) kept for backward compatibility;
        # hidden_dim and kernel_size are given per layer, kernel sizes as (h, w) tuples
        self.conv_lsmt = ConvLSTM(
                     encoding_channels*2, 
                     [lstm_channels], 
                     kernel_size=[(3, 3)], 
                     num_layers=1,
                     batch_first=False, # first time dimension, but return is with batch first
                     bias=True,
                     return_all_layers=True
                    )
        self.deep_residual_block = DeepResidualBlock(lstm_channels, self.new_res)
        self.nonspatial_block = NonSpatialBlock(lstm_channels, self.new_res)
        
    def forward(self, screen_layers, minimap_layers, hidden_state=None, cell_state=None):
        """
        Inputs
        ------
        screen_layers: (batch_size, screen_channels, res, res)
        minimap_layers: (batch_size, minimap_channels, res, res)
        hidden_state: (batch_size, lstm_channels, new_res, new_res) or None
        cell_state: (batch_size, lstm_channels, new_res, new_res); required if hidden_state is given
        
        Intermediate variables
        ----------------------
        inputs_3D: (batch_size, encoding_channels*2, new_res, new_res)
        outputs_3D: (batch_size, lstm_channels, new_res, new_res) - one time step is processed

        Returns
        -------
        spatial_features: (batch_size, lstm_channels, new_res, new_res)
        nonspatial_features: (batch_size, 512)
        hidden_state, cell_state: new ConvLSTM states, (batch_size, lstm_channels, new_res, new_res)
        """
        # State Encoding
        screen_enc = self.screen_state_enc_net(screen_layers)
        minimap_enc = self.minimap_state_enc_net(minimap_layers)
        inputs_3D = torch.cat([screen_enc, minimap_enc], dim=1) # concatenate along channel dim
        
        # Memory Processing
        # ConvLSTM needs a 5d input; with batch_first=False that is (t, b, c, h, w): add t = 1
        lstm_input = inputs_3D.unsqueeze(0)
        if hidden_state is None:
            layer_output_list, last_state_list = self.conv_lsmt(lstm_input)
        else:
            assert cell_state is not None, \
                "hidden_state provided, but cell_state is None"
            assert hidden_state.shape == cell_state.shape, \
                ("hidden_state and cell_state have different shapes", hidden_state.shape, cell_state.shape)
            # one (h, c) pair per layer
            # NOTE(review): relies on a ConvLSTM version that accepts an initial state
            layer_output_list, last_state_list = self.conv_lsmt(lstm_input, [(hidden_state, cell_state)])
        # output is 5d with batch-first (b, t, c, h, w): go time-first and merge t and b -> 4d
        lstm_output = layer_output_list[-1]
        outputs_3D = lstm_output.transpose(1, 0).reshape(-1, *lstm_output.shape[2:])
        # states of the last (and only) layer
        hidden_state, cell_state = last_state_list[-1][0], last_state_list[-1][1]
        
        # Spatial and Non-Spatial Processing
        spatial_features = self.deep_residual_block(outputs_3D)
        nonspatial_features = self.nonspatial_block(outputs_3D)
        
        return spatial_features, nonspatial_features, hidden_state, cell_state

### All high-level networks with their parameters

``` python
class SpatialProcessingBlock(nn.Module):
    def __init__(
        self, 
        res, 
        screen_channels, 
        minimap_channels,
        encoding_channels,
        lstm_channels=96,
    )
    
class Inputs2D_Net(nn.Module):
    def __init__(
        self, 
        in_player, 
        n_actions, 
        embedding_dim=10
    )
        
class ActorHead(nn.Module):
    def __init__(
        self, 
        n_actions, 
        n_shared_features=576 
    )
    
class CriticHead(nn.Module):
    def __init__(
        self, 
        n_shared_features=576
    )
    
n_shared_features = SpatialProcessingBlock.NonSpatialBlock.out_features + Inputs2D_Net.out_features

class SpatialIMPALA_v2(SpatialIMPALA):
    def __init__(
        self, 
        n_channels, 
        linear_size, 
        n_arguments
    ):
    
n_channels = SpatialProcessingBlock.lstm_channels
linear_size = SpatialProcessingBlock.res
n_arguments = IMPALA_AC.n_spatial_args (still to be defined)

class CategoricalIMPALA(ParallelCategoricalNet):
    def __init__(
        self, 
        n_features, 
        sizes, 
        n_arguments
    ):
        
n_features = n_shared_features
sizes = IMPALA_AC.categorical_sizes
n_arguments = IMPALA_AC.n_categorical_args
```

In [87]:
from AC_modules.IMPALA import IMPALA_AC
from AC_modules.ActorCriticArchitecture import ParallelActorCritic

In [101]:
class ParallelActorCritic_v2(ParallelActorCritic, nn.Module):
    """
    Uses as hard-coded architecture the control architecture of paper 
    Relational Deep Reinforcement Learning [https://arxiv.org/abs/1806.01830]
    """
    def __init__(
        self, 
        env,
        action_names,
        screen_channels, 
        minimap_channels,
        encoding_channels,
        in_player
    ):
        # init nn.Module but not ParallelActorCritic 
        nn.Module.__init__(self)
        
        self.action_names = action_names
        self._set_action_table() # creates self.action_table
        self.screen_res = env.observation_spec()[0]['feature_screen'][1:]
        self.all_actions = env.action_spec()[0][1]
        self.all_arguments = env.action_spec()[0][0]
        self.action_space = len(action_names)
        
        # Networks
        self.spatial_processing_block = SpatialProcessingBlock(self.screen_res, 
                                                               screen_channels, 
                                                               minimap_channels,
                                                               encoding_channels
                                                              )
        self.inputs2d_net = Inputs2D_Net(in_player, self.action_space)
        
        n_shared_features = self.spatial_processing_block.NonSpatialBlock.out_features + \
                            self.inputs2d_net.out_features
        
        self.actor = ActorHead(self.action_space, n_shared_features)
        self.critic = CriticHead(n_shared_features)
        
        # take care of computing some useful arguments-related attributes before initializing argument networks
        self._init_arg_names()
        self._set_spatial_arg_mask()
        self._set_categorical_arg_mask()
        
        self.spatial_params_net = SpatialIMPALA_v2(self.spatial_processing_block.lstm_channels,
                                                   self.screen_res[0], 
                                                   self.n_spatial_args
                                                  )
        self.categorical_params_net = CategoricalIMPALA(n_shared_features, 
                                                        self.categorical_sizes, 
                                                        self.n_categorical_args
                                                       )
        
    def compute_features(
        self,
        screen_layers, 
        minimap_layers, 
        player_info, 
        last_action, 
        hidden_state, 
        cell_state
    ):
        results = self.spatial_processing_block(screen_layers, minimap_layers, hidden_state, cell_state)
        spatial_features, nonspatial_features, hidden_state, cell_state = results
        
        inputs_2D = self.inputs2d_net(player_info, last_action)
        
        shared_features = torch.cat([nonspatial_features, inputs_2D], dim=1)
        
        return spatial_features, shared_features, hidden_state, cell_state
    
    def pi(self, shared_features, action_mask):
        logits = self.actor(shared_features) 
        log_probs = F.log_softmax(logits.masked_fill((action_mask).bool(), float('-inf')), dim=-1) 
        return log_probs
    
    def V_critic(self, shared_features):
        return self.critic(shared_features)

In [None]:
class IMPALA_AC_v2(ParallelActorCritic_v2, IMPALA_AC):
    """
    IMPALA actor-critic built on the ParallelActorCritic_v2 architecture.

    Notes
    -----
    - ParallelActorCritic_v2 takes precedence over IMPALA_AC in the inheritance
      of methods and attributes in case of conflict (e.g. self.pi, self.V_critic,
      self.compute_features).
    - From IMPALA_AC we keep
        sample_spatial_params,
        sample_categorical_params,
        sample_params,
        pad_to_len
    """
    def __init__(
        self, 
        env,
        action_names,
        screen_channels, 
        minimap_channels,
        encoding_channels,
        in_player,
        device
    ):
        """
        Parameters
        ----------
        env, action_names, screen_channels, minimap_channels, encoding_channels,
        in_player: forwarded to ParallelActorCritic_v2.__init__
        device: torch device the inputs are moved to in actor_step/learner_step
        """
        ParallelActorCritic_v2.__init__(self, 
                                        env,
                                        action_names,
                                        screen_channels, 
                                        minimap_channels,
                                        encoding_channels,
                                        in_player
                                       )
        self.device = device 
        
        # number of categorical and spatial arguments that we expect at most for any action 
        # used for padding arguments and writing them into the buffers always with the same length
        self.max_num_categorical_args = int((self.categorical_arg_mask).sum(axis=1).max()) # should be 1
        self.max_num_spatial_args = int((self.spatial_arg_mask).sum(axis=1).max()) # should be 2 because of select_rect
        
    def actor_step(self, env_output, hidden_state=None, cell_state=None):
        """
        Sample an action (main action + arguments) for one environment step.

        Parameters
        ----------
        env_output: dict with 'screen_layers', 'minimap_layers', 'player_state',
            'last_action' and 'action_mask' tensors. A batch dim is prepended to the
            first three; 'last_action' and 'action_mask' are used as they are
            (presumably they already carry a batch dim - TODO confirm).
        hidden_state, cell_state: recurrent state from the previous step, or None.

        Returns
        -------
        dict with 'log_prob' (main action + arguments), 'main_action',
        'sc_env_action' (list of sc_actions.FunctionCall), the updated
        'hidden_state' / 'cell_state', plus the entries of args_indexes
        ('categorical_args_indexes', 'spatial_args_indexes').
        """
        screen_layers = env_output['screen_layers'].unsqueeze(0).to(self.device)
        minimap_layers = env_output['minimap_layers'].unsqueeze(0).to(self.device)
        player_state = env_output['player_state'].unsqueeze(0).to(self.device)
        last_action = env_output['last_action'].to(self.device) # add it to the output of the environment
        action_mask = env_output['action_mask'].to(self.device)

        results = self.compute_features(screen_layers, 
                                        minimap_layers, 
                                        player_state, 
                                        last_action, 
                                        hidden_state, 
                                        cell_state
                                       )
        spatial_features, shared_features, hidden_state, cell_state = results
        
        log_probs = self.pi(shared_features, action_mask)
        probs = torch.exp(log_probs)
        main_action_torch = Categorical(probs).sample() # check probs < 0?!
        main_action = main_action_torch.detach().cpu().numpy()
        log_prob = log_probs[range(len(main_action)), main_action]
        
        args, args_log_prob, args_indexes = self.sample_params(shared_features, spatial_features, main_action)
        assert args_log_prob.shape == log_prob.shape, ("Shape mismatch between arg_log_prob and log_prob ",\
                                                      args_log_prob.shape, log_prob.shape)
        log_prob = log_prob + args_log_prob
        
        action_id = np.array([self.action_table[act] for act in main_action])
        sc2_env_action = [sc_actions.FunctionCall(action_id[i], args[i]) for i in range(len(action_id))]
        
        # NOTE: the comma after cell_state is required, otherwise the entry becomes
        # `cell_state ** args_indexes` (power operator) and fails at runtime.
        actor_output = {'log_prob':log_prob.flatten(),
                        'main_action':main_action_torch.flatten(),
                        'sc_env_action':sc2_env_action,
                        'hidden_state':hidden_state,
                        'cell_state':cell_state,
                        **args_indexes} # args_indexes = {'categorical_args_indexes', 'spatial_args_indexes'}
        
        return actor_output
    
    def learner_step(self, batch, debug=False):
        """
        Recompute log-probabilities, baselines and entropy for a batch of trajectories.

        batch contains tensors of shape (T, B, *other_dims), where 
        - T = unroll_length (number of steps in the trajectory)
        - B = batch_size
        
        Keywords needed:
        - screen_layers, minimap_layers, player_state, last_action
        - screen_layers_trg, minimap_layers_trg, player_state_trg
        - action_mask
        - main_action
        - categorical_indexes
        - spatial_indexes
        - hidden_states, cell_states
        Optional:
        - last_action_trg (if missing, main_action is used: the action taken at
          step t is the last action of the target state)

        Parameters
        ----------
        batch: dict of tensors as described above
        debug: if True, print the shapes of the inputs

        Returns
        -------
        dict with log_prob, baseline, baseline_trg of shape (T, B) and entropy,
        the sum of p*log(p) over the sampled main actions (a negative entropy).
        """
        screen_layers = batch['screen_layers'].to(self.device)
        minimap_layers = batch['minimap_layers'].to(self.device)
        player_state = batch['player_state'].to(self.device)
        last_action = batch['last_action'].to(self.device)
        
        screen_layers_trg = batch['screen_layers_trg'].to(self.device)
        minimap_layers_trg = batch['minimap_layers_trg'].to(self.device)
        player_state_trg = batch['player_state_trg'].to(self.device)
        
        action_mask = batch['action_mask'].to(self.device)
        main_action = batch['main_action'].to(self.device)
        categorical_indexes = batch['categorical_indexes'].to(self.device)
        spatial_indexes = batch['spatial_indexes'].to(self.device)
        hidden_states = batch['hidden_states'].to(self.device)
        cell_states = batch['cell_states'].to(self.device)
        
        if debug:
            print("screen_layers.shape ", screen_layers.shape)
            print("minimap_layers.shape ", minimap_layers.shape)
            print("player_state.shape ", player_state.shape)
            print("last_action.shape ", last_action.shape)
            print("action_mask.shape ", action_mask.shape)
            print("main_action.shape ", main_action.shape)
            print("categorical_indexes.shape ", categorical_indexes.shape)
            print("spatial_indexes.shape ", spatial_indexes.shape)
            print("hidden_states.shape ", hidden_states.shape)
            print("cell_states.shape ", cell_states.shape)
            print("self.device ", self.device)
            
        # useful dimensions
        T, B = screen_layers.shape[:2]
        res = self.screen_res[0]
        
        def merge_TB(x):
            # (T, B, *dims) -> (T*B, *dims)
            return x.reshape((-1,) + x.shape[2:])
        
        # Merge time and batch dims: every step is processed independently, starting
        # from the recurrent state stored in the batch for that step.
        # NOTE(review): assumes hidden_states[t] / cell_states[t] are the states to feed
        # at step t, shaped (T, B, lstm_channels, res, res) - TODO confirm against the buffer.
        screen_layers = merge_TB(screen_layers)
        minimap_layers = merge_TB(minimap_layers)
        player_state = merge_TB(player_state)
        last_action = merge_TB(last_action)
        screen_layers_trg = merge_TB(screen_layers_trg)
        minimap_layers_trg = merge_TB(minimap_layers_trg)
        player_state_trg = merge_TB(player_state_trg)
        action_mask = merge_TB(action_mask)
        main_action = merge_TB(main_action)
        categorical_indexes = merge_TB(categorical_indexes)
        spatial_indexes = merge_TB(spatial_indexes)
        hidden_states = merge_TB(hidden_states)
        cell_states = merge_TB(cell_states)
        
        if 'last_action_trg' in batch:
            last_action_trg = merge_TB(batch['last_action_trg'].to(self.device))
        else:
            # the action taken at step t is the "last action" seen at the target step
            last_action_trg = main_action.reshape(last_action.shape)
        
        if debug:
            print("After view: ")
            print("screen_layers.shape ", screen_layers.shape)
            print("minimap_layers.shape ", minimap_layers.shape)
            print("player_state.shape ", player_state.shape)
            print("last_action.shape ", last_action.shape)
            print("action_mask.shape ", action_mask.shape)
            print("main_action.shape ", main_action.shape)
            print("categorical_indexes.shape ", categorical_indexes.shape)
            print("spatial_indexes.shape ", spatial_indexes.shape)
            print("hidden_states.shape ", hidden_states.shape)
            print("cell_states.shape ", cell_states.shape)
        
        results = self.compute_features(screen_layers, 
                                        minimap_layers, 
                                        player_state, 
                                        last_action, 
                                        hidden_states, 
                                        cell_states
                                       )
        spatial_features, shared_features, hidden_state_trg, cell_state_trg = results
        
        log_probs = self.pi(shared_features, action_mask)
        log_prob = log_probs[range(len(main_action)), main_action]
        entropy = torch.sum(torch.exp(log_prob) * log_prob) # negative entropy of the main actions
        
        # categorical arguments: the categorical net is built on the shared features
        categorical_log_probs = self.categorical_params_net.get_log_probs(shared_features)\
            .view(B*T, self.n_categorical_args, self.categorical_params_net.max_size)
        # self.categorical_arg_mask is a numpy array -> convert it on the fly to a device tensor
        categorical_arg_mask = torch.tensor(self.categorical_arg_mask).to(self.device)
        categorical_mask = categorical_arg_mask[main_action,:].view(-1,self.n_categorical_args)
        categorical_indexes = categorical_indexes[categorical_indexes!=-1] # remove padding
        nonzero = categorical_mask.nonzero()
        batch_index, arg_index = nonzero[:,0], nonzero[:,1]
        categorical_log_prob = categorical_log_probs[batch_index, arg_index, categorical_indexes]
        log_prob = log_prob.index_add(0, batch_index, categorical_log_prob)
        
        # repeat for spatial params
        spatial_log_probs = self.spatial_params_net.get_log_probs(spatial_features)\
            .view(B*T, self.n_spatial_args, res**2)
        spatial_arg_mask = torch.tensor(self.spatial_arg_mask).to(self.device)
        spatial_mask = spatial_arg_mask[main_action,:].view(-1,self.n_spatial_args)
        spatial_indexes = spatial_indexes[spatial_indexes!=-1] # remove padding
        nonzero = spatial_mask.nonzero()
        batch_index, arg_index = nonzero[:,0], nonzero[:,1]
        spatial_log_prob = spatial_log_probs[batch_index, arg_index, spatial_indexes]
        log_prob = log_prob.index_add(0, batch_index, spatial_log_prob)
        
        baseline = self.V_critic(shared_features)
        
        # the target state is processed starting from the recurrent state produced at step t
        _, shared_features_trg, _, _ = self.compute_features(screen_layers_trg, 
                                                             minimap_layers_trg, 
                                                             player_state_trg, 
                                                             last_action_trg, 
                                                             hidden_state_trg, 
                                                             cell_state_trg
                                                            )
        baseline_trg = self.V_critic(shared_features_trg)
        
        return dict(log_prob=log_prob.view(T,B), 
                    baseline=baseline.view(T,B), 
                    baseline_trg=baseline_trg.view(T,B), 
                    entropy=entropy)