In [3]:
from typing import List, Optional, Tuple

def rogue_net_param_est(d_model: int, n_layer: int, entity_feats: List[int]):
    # Attention params (assumes no relative position encoding)
    attn_params = 4 * d_model * d_model
    # MLP params
    mlp_params = 8 * d_model * d_model
    # Embedding params
    embedding_params = sum(entity_feats) * d_model

    return (attn_params + mlp_params) * n_layer + embedding_params

def rogue_net_flops_est(d_model: int, n_layer: int, entity_feats_count: List[Tuple[int, int]], max_entity_count: Optional[int] = None):
    n_entity = sum([n for (_, n) in entity_feats_count])
    # Conservatively use maximum entity count when computing flops, using mean would be understimate
    attn_flops = n_layer * (4 * d_model * d_model * n_entity + d_model * n_entity * (max_entity_count or n_entity))
    # MLP params
    mlp_flops = n_layer * n_entity * 8 * d_model * d_model
    # Embedding params
    embedding_flops = sum(n * d_model * entity_feats for (entity_feats, n) in entity_feats_count)
    total_flops = attn_flops + mlp_flops + embedding_flops
    # print(f"{attn_flops / total_flops * 100:.2f}% attention flops")
    return total_flops


def impala_resblock_param_est(c):
    # Two 3x3 stride 1 padding 1 convolutions
    return 2 * 3 * 3 * c * c

def impala_conv_seq_param_est(c0, c1):
    # One 3x3 convolution projecting channels from c0 to c1
    proj_params = 3 * 3 * c0 * c1
    # Two residual blocks
    resblock_params = 2 * impala_resblock_param_est(c1)
    return proj_params + resblock_params

def impala_param_est(h, w, c):
    channels = [c, 16, 32, 32]
    param_est = 0
    for (c0, c1) in zip(channels[:-1], channels[1:]):
        param_est += impala_conv_seq_param_est(c0, c1)
        h = (h + 1) // 2
        w = (w + 1) // 2
        c = c1
    # Final projection from (h,w,c) to 256
    param_est += h * w * c * 256
    return param_est

def impala_resblock_flops_est(h, w, c):
    return impala_resblock_param_est(c) * h * w

def impala_conv_seq_flops_est(h, w, c0, c1):
    # One 3x3 convolution projecting channels from c0 to c1
    proj_flops = 3 * 3 * h * w * c0 * c1
    h = (h + 1) // 2
    w = (w + 1) // 2
    # Two residual blocks
    resblock_flops = 2 * impala_resblock_flops_est(h, w, c1)
    return proj_flops + resblock_flops

def impala_flops_est(h, w, c):
    channels = [c, 16, 32, 32]
    flops_est = 0
    for (c0, c1) in zip(channels[:-1], channels[1:]):
        print(impala_conv_seq_flops_est(h, w, c0, c1))
        flops_est += impala_conv_seq_flops_est(h, w, c0, c1)
        h = (h + 1) // 2
        w = (w + 1) // 2
        c = c1
    # Final projection from (h,w,c) to 256
    flops_est += h * w * c * 256
    return flops_est

# Print table for d_model={16, 32, 64}, n_layer=2, entity_feats=8*[32], 1*[16]
# print("RogueNet param and flops est")
# print("d_model\tseqlen\tentity_feats\tparams\tflops")
# for d_model in [16, 32, 64]:
#     for seqlen in [10]:
#         for entity_feats in [1*[16], 8*[32]]:
#             print(f"{d_model}\t{seqlen}\t{sum(entity_feats)}\t{rogue_net_param_est(d_model, 2, entity_feats)}\t{rogue_net_flops_est(d_model, 2, entity_feats, seqlen)}")

print("Impala param est:", impala_param_est(64, 64, 3))
print("Impala flops est:", impala_flops_est(64, 64, 3))

# print("RogueNet(d_model=32)/IMPALA param est:", impala_param_est(64, 64, 3) / rogue_net_param_est(32, 2, [256]))
# print("RogueNet(d_model=32)/IMPALA flops est:", impala_flops_est(64, 64, 3) / rogue_net_flops_est(32, 2, [256], 10))

def print_param_flops_ratio(tile_feats: int, tile_count: int, entity_feats: int, entity_count: int, entity_types: int, max_entity_count: Optional[int] = None):
    feats = [tile_feats] + [entity_feats] * entity_types
    entity_feats_count = [(tile_feats, tile_count), (entity_feats, entity_count)]
    impala_params = impala_param_est(64, 64, 3)
    impala_flops = impala_flops_est(64, 64, 3)
    rogue_net_params = rogue_net_param_est(16, 2, feats)
    rogue_net_flops = rogue_net_flops_est(16, 2, entity_feats_count, max_entity_count)
    print(f"params: {rogue_net_params} ({impala_params / rogue_net_params:.0f}x)")
    print(f"flops: {rogue_net_flops} ({impala_flops / rogue_net_flops:.0f}x)")
    # print("RogueNet(d_model=32)/IMPALA param est:", impala_param_est(64, 64, 3) / rogue_net_param_est(32, 2, feats))
    # print("RogueNet(d_model=32)/IMPALA flops est:", impala_flops_est(64, 64, 3) / rogue_net_flops_est(32, 2, entity_feats_count, max_entity_count))


print("Coinrun")
print_param_flops_ratio(tile_feats=7 * 25, tile_count=25, entity_feats=31 + 7, entity_count=20, entity_types=8, max_entity_count=200)

print("Miner")
print_param_flops_ratio(tile_feats=6 * 25, tile_count=4, entity_feats=31+1, entity_count=2, entity_types=3)

print("BossFight")
print_param_flops_ratio(tile_feats=0, tile_count=0, entity_feats=31 + 23, entity_count=80, entity_types=9, max_entity_count=200)

print("FruitBot")
print_param_flops_ratio(tile_feats=0, tile_count=0, entity_feats=31 + 3, entity_count=50, entity_types=8)

print("Dodgeball")
print_param_flops_ratio(tile_feats=0, tile_count=0, entity_feats=31 + 7, entity_count=7, entity_types=8)

print("Typical")
print_param_flops_ratio(tile_feats=0, tile_count=0, entity_feats=50, entity_count=10, entity_types=5)

Impala param est: 621488
11206656
14155776
4718592
Impala flops est: 30605312
Coinrun
11206656
14155776
4718592
params: 13808 (45x)
flops: 646640 (47x)
Miner
11206656
14155776
4718592
params: 10080 (62x)
flops: 48640 (629x)
BossFight
11206656
14155776
4718592
params: 13920 (45x)
flops: 1072640 (29x)
FruitBot
11206656
14155776
4718592
params: 10496 (59x)
flops: 414400 (74x)
Dodgeball
11206656
14155776
4718592
params: 11008 (56x)
flops: 48832 (627x)
Typical
11206656
14155776
4718592
params: 10144 (61x)
flops: 72640 (421x)


## Sanity checks

In [None]:
import torch.nn as nn
import torch
import numpy as np
from torch.distributions.categorical import Categorical

# taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)

    def forward(self, x):
        inputs = x
        x = nn.functional.relu(x)
        x = self.conv0(x)
        x = nn.functional.relu(x)
        x = self.conv1(x)
        return x + inputs


class ConvSequence(nn.Module):
    def __init__(self, input_shape, out_channels):
        super().__init__()
        self._input_shape = input_shape
        self._out_channels = out_channels
        self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1)
        self.res_block0 = ResidualBlock(self._out_channels)
        self.res_block1 = ResidualBlock(self._out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = self.res_block0(x)
        x = self.res_block1(x)
        # assert x.shape[1:] == self.get_output_shape()
        return x

    def get_output_shape(self):
        _c, h, w = self._input_shape
        return (self._out_channels, (h + 1) // 2, (w + 1) // 2)

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class ImpalaAgent(nn.Module):
    def __init__(self, obs_space_shape, n_action):
        super().__init__()
        h, w, c = obs_space_shape
        shape = (c, h, w)
        conv_seqs = []
        for out_channels in [16, 32, 32]:
            conv_seq = ConvSequence(shape, out_channels)
            shape = conv_seq.get_output_shape()
            conv_seqs.append(conv_seq)
        conv_seqs += [
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*conv_seqs)
        self.actor = layer_init(nn.Linear(256, 15), std=0.01)
        self.critic = layer_init(nn.Linear(256, 1), std=1)

    def get_value(self, x):
        return self.critic(self.network(x.permute((0, 3, 1, 2)) / 255.0))  # "bhwc" -> "bchw"

    def get_action_and_value(self, x):
        hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0)  # "bhwc" -> "bchw"
        logits = self.actor(hidden)
        #probs = Categorical(logits=logits)
        value = self.critic(hidden)
        #entropy = probs.entropy()
        # action = probs.sample()
        # return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)
        return value
    
    def forward(self, x):
        return self.get_action_and_value(x)

impala_agent = ImpalaAgent((64, 64, 3), n_action=15)
params = sum(p.numel() for p in impala_agent.parameters() if p.requires_grad)
print("ImpalaAgent params:", params)


from pthflops import count_ops

# Create a network and a corresponding input
device = 'cuda:0'
model = impala_agent.to(device)
inp = torch.rand(1,64,64,3).to(device)

# Count the number of FLOPs
count_ops(model, inp)

ImpalaAgent params: 626256
Operation                    OPS       
---------------------------  --------  
network_0_conv               1835008   
network_0_res_block0_conv0   2375680   
network_0_res_block0_conv1   2375680   
add                          32768     
network_0_res_block1_conv0   2375680   
network_0_res_block1_conv1   2375680   
add_1                        32768     
network_1_conv               4751360   
network_1_res_block0_conv0   2367488   
network_1_res_block0_conv1   2367488   
add_2                        16384     
network_1_res_block1_conv0   2367488   
network_1_res_block1_conv1   2367488   
add_3                        16384     
network_2_conv               2367488   
network_2_res_block0_conv0   591872    
network_2_res_block0_conv1   591872    
add_4                        4096      
network_2_res_block1_conv0   591872    
network_2_res_block1_conv1   591872    
add_5                        4096      
network_4                    4096      
network_5    

(30933776,
 [['network_0_conv', 1835008],
  ['network_0_res_block0_conv0', 2375680],
  ['network_0_res_block0_conv1', 2375680],
  ['add', 32768],
  ['network_0_res_block1_conv0', 2375680],
  ['network_0_res_block1_conv1', 2375680],
  ['add_1', 32768],
  ['network_1_conv', 4751360],
  ['network_1_res_block0_conv0', 2367488],
  ['network_1_res_block0_conv1', 2367488],
  ['add_2', 16384],
  ['network_1_res_block1_conv0', 2367488],
  ['network_1_res_block1_conv1', 2367488],
  ['add_3', 16384],
  ['network_2_conv', 2367488],
  ['network_2_res_block0_conv0', 591872],
  ['network_2_res_block0_conv1', 591872],
  ['add_4', 4096],
  ['network_2_res_block1_conv0', 591872],
  ['network_2_res_block1_conv1', 591872],
  ['add_5', 4096],
  ['network_4', 4096],
  ['network_5', 524544],
  ['network_6', 512],
  ['actor', 3855],
  ['critic', 257]])

In [2]:
from rogue_net.rogue_net import RogueNet, RogueNetConfig
from entity_gym.env import ObsSpace, Entity, CategoricalActionSpace

net = RogueNet(
    RogueNetConfig(
        d_model=32,
        n_layer=2,
    ),
    obs_space=ObsSpace({
        f"e{i}": Entity(features=[f"f{f}" for f in range(20)]) for i in range(10)
    }),
    action_space={"a": CategoricalActionSpace([f"a{i}" for i in range(15)])}
)

params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("RogueNet params:", params)

RogueNet params: 25903
