In [2]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch.nn import TransformerEncoder, TransformerEncoderLayer


In [48]:
def get_activation(activation_name: str):
    """Map an activation name to its ``torch.nn`` module class (not an instance).

    Args:
        activation_name: case-sensitive name such as ``'relu'`` or ``'gelu'``.
            ``None`` and the aliases ``'id'``, ``'identity'``, ``'linear'`` and
            ``'none'`` all select ``nn.Identity``; ``'swish'`` and ``'silu'``
            both select ``nn.SiLU``.

    Returns:
        The ``nn.Module`` subclass; the caller instantiates it.

    Raises:
        NotImplementedError: if the name matches none of the known aliases.
    """
    # Ordered (aliases, class) pairs. Membership is checked with ``in`` on a
    # tuple, which compares by equality just like an if/elif ``==`` chain.
    alias_table = (
        (('tanh',), nn.Tanh),
        (('relu',), nn.ReLU),
        (('leakyrelu',), nn.LeakyReLU),
        (('prelu',), nn.PReLU),
        (('gelu',), nn.GELU),
        (('sigmoid',), nn.Sigmoid),
        ((None, 'id', 'identity', 'linear', 'none'), nn.Identity),
        (('elu',), nn.ELU),
        (('swish', 'silu'), nn.SiLU),
        (('softplus',), nn.Softplus),
    )
    for aliases, module_cls in alias_table:
        if activation_name in aliases:
            return module_cls
    raise NotImplementedError("hidden activation '{}' is not implemented".format(activation_name))

# RNN Policy Generator Speed Test

In [51]:
class RNNPolicyGenerator(nn.Module):
    """Bidirectional-GRU policy generator with batch normalisation.

    Each sample's shared feature vector is repeated ``action_dim`` times to form
    a sequence. The sequence runs through a single-layer bidirectional GRU. The
    forward and backward outputs are averaged, batch-normalised and passed
    through the configured final activation.

    NOTE(review): this notebook defines ``RNNPolicyGenerator`` more than once;
    whichever definition was executed last is the one later cells use.

    Args:
        hparams_cfg: config read at ``shared_networks.input_shared_dim``,
            ``shared_networks.output_shared_dim``,
            ``policy_generator.final_activation`` and
            ``policy_generator.initializer``.
    """

    def __init__(self, hparams_cfg: DictConfig):
        super().__init__()
        self.input_shared_dim = hparams_cfg.shared_networks.input_shared_dim
        self.output_shared_dim = hparams_cfg.shared_networks.output_shared_dim
        self.final_act = hparams_cfg.policy_generator.final_activation
        # NOTE(review): stored but not used anywhere in this class.
        self.initializer = hparams_cfg.policy_generator.initializer
        self.final_act_func = get_activation(self.final_act)()

        # batch_first GRU: input (N, L, H_in) -> output (N, L, 2 * H_out),
        # because bidirectional=True concatenates both directions' features.
        self.rnn = nn.GRU(input_size=self.output_shared_dim,
                          hidden_size=self.output_shared_dim,
                          num_layers=1,
                          bias=False,
                          batch_first=True,
                          bidirectional=True
                          )
        # Applied to (N, output_shared_dim, action_dim): one statistic per channel.
        self.batch_norm = nn.BatchNorm1d(self.output_shared_dim)

    def forward(self, action_dim, shared_feature):
        """Generate per-action weights from a shared feature vector.

        Args:
            action_dim: number of sequence steps (one per action dimension).
            shared_feature: tensor of shape [batch_size, output_shared_dim]
                (its last dimension must match the GRU's ``input_size``).

        Returns:
            Tensor of shape [batch_size, output_shared_dim, action_dim].
        """
        batch_size = shared_feature.size(0)
        # [batch_size, output_shared_dim] -> [batch_size, action_dim, output_shared_dim].
        # expand() returns a broadcast view instead of materialising action_dim
        # copies via a list + torch.cat; the GRU sees identical values and the
        # gradient w.r.t. shared_feature is the same sum over the repeats.
        sequence = shared_feature.unsqueeze(1).expand(-1, action_dim, -1)
        weights, _ = self.rnn(sequence)
        # weights: [batch_size, action_dim, 2 * output_shared_dim]
        weights = torch.permute(weights, (0, 2, 1))
        # weights: [batch_size, 2 * output_shared_dim, action_dim]. The channel
        # axis is laid out as [forward | backward], so split it into a
        # direction axis of size 2.
        weights = weights.reshape(batch_size, 2, self.output_shared_dim, action_dim)
        # Average the two directions -> [batch_size, output_shared_dim, action_dim].
        weights = weights.mean(dim=1, keepdim=False)
        weights = self.batch_norm(weights)
        weights = self.final_act_func(weights)
        return weights
    


In [52]:
%%timeit
hparams_cfg = OmegaConf.create({
    "encoder_generator": {
        "weight_seed": 42,
        "final_activation": "identity",
    },
    "policy_generator": {
        "final_activation": "identity",
        "initializer": "uniform",
        },
    "shared_networks": {
        'input_shared_dim': 64,
        'output_shared_dim': 64,
        'hidden_dim': 64,
        'num_hidden_layers': 2,
        'activation': "relu"}
})
rnn_pg = RNNPolicyGenerator(hparams_cfg)

shared_feature = torch.ones([128, 64])
rnn_weights = rnn_pg(6, shared_feature)
print(shared_feature.shape)
print(rnn_weights.shape)


torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([1

In [40]:
class RNNPolicyGenerator(nn.Module):
    """Bidirectional-GRU policy generator without batch normalisation.

    Each sample's shared feature vector is broadcast to ``action_dim`` steps.
    The sequence runs through a single-layer bidirectional GRU. The forward and
    backward outputs are averaged and passed through the configured final
    activation.

    NOTE(review): this notebook defines ``RNNPolicyGenerator`` more than once;
    whichever definition was executed last is the one later cells use.

    Args:
        hparams_cfg: config read at ``shared_networks.input_shared_dim``,
            ``shared_networks.output_shared_dim``,
            ``policy_generator.final_activation`` and
            ``policy_generator.initializer``.
    """

    def __init__(self, hparams_cfg: DictConfig):
        super().__init__()
        self.input_shared_dim = hparams_cfg.shared_networks.input_shared_dim
        self.output_shared_dim = hparams_cfg.shared_networks.output_shared_dim
        self.final_act = hparams_cfg.policy_generator.final_activation
        # NOTE(review): stored but not used anywhere in this class.
        self.initializer = hparams_cfg.policy_generator.initializer
        self.final_act_func = get_activation(self.final_act)()

        # batch_first GRU: input (N, L, H_in) -> output (N, L, 2 * H_out),
        # because bidirectional=True concatenates both directions' features.
        self.rnn = nn.GRU(input_size=self.output_shared_dim,
                          hidden_size=self.output_shared_dim,
                          num_layers=1,
                          bias=False,
                          batch_first=True,
                          bidirectional=True
                          )

    def forward(self, action_dim, shared_feature):
        """Generate per-action weights from a shared feature vector.

        Args:
            action_dim: number of sequence steps (one per action dimension).
            shared_feature: tensor of shape [batch_size, output_shared_dim]
                (its last dimension must match the GRU's ``input_size``).

        Returns:
            Tensor of shape [batch_size, output_shared_dim, action_dim].
        """
        batch_size = shared_feature.size(0)
        # [batch_size, output_shared_dim] -> [batch_size, action_dim, output_shared_dim]
        # as a broadcast view (no copy). Using -1 keeps the batch and feature
        # sizes as given. A feature width that does not match the GRU therefore
        # fails there, instead of being silently broadcast from size 1.
        sequence = shared_feature.unsqueeze(1).expand(-1, action_dim, -1)
        weights, _ = self.rnn(sequence)
        # weights: [batch_size, action_dim, 2 * output_shared_dim]
        weights = torch.permute(weights, (0, 2, 1))
        # weights: [batch_size, 2 * output_shared_dim, action_dim]. The channel
        # axis is laid out as [forward | backward], so split it into a
        # direction axis of size 2.
        weights = weights.reshape(batch_size, 2, self.output_shared_dim, action_dim)
        # Average the two directions -> [batch_size, output_shared_dim, action_dim].
        weights = weights.mean(dim=1, keepdim=False)
        weights = self.final_act_func(weights)
        return weights
    


In [42]:
%%timeit
hparams_cfg = OmegaConf.create({
    "encoder_generator": {
        "weight_seed": 42,
        "final_activation": "identity",
    },
    "policy_generator": {
        "final_activation": "identity",
        "initializer": "uniform",
        },
    "shared_networks": {
        'input_shared_dim': 64,
        'output_shared_dim': 64,
        'hidden_dim': 64,
        'num_hidden_layers': 2,
        'activation': "relu"}
})
rnn_pg = RNNPolicyGenerator(hparams_cfg)

shared_feature = torch.ones([128, 64])
rnn_weights = rnn_pg(6, shared_feature)
print(shared_feature.shape)
print(rnn_weights.shape)


torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([128, 64])
torch.Size([128, 64, 6])
torch.Size([1

In [45]:
# Total number of parameter elements in `rnn_pg` (trainable or not).
# NOTE(review): needs `rnn_pg` to be bound by an earlier cell in this kernel.
params = sum(p.numel() for p in rnn_pg.parameters())

In [46]:
# Display the total parameter count stored in `params`.
params

49152

In [3]:

class TransformerModel(nn.Module):
    """Linear-embedding Transformer encoder followed by a linear decoder.

    This model is built upon
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html, with an
    ``nn.Linear`` input projection and no positional encoding.

    The encoder layers are created without ``batch_first``, so PyTorch's default
    sequence-first layout applies. ``src`` is expected as
    ``(seq_len, batch, feature_size)`` and the output is
    ``(seq_len, batch, output_size)``.

    Args:
        feature_size: size of the last dimension of ``src``.
        output_size: size of the last dimension of the output.
        d_model: width of the Transformer encoder.
        n_head: number of attention heads (must divide ``d_model``).
        dim_feed_forward: hidden width of each encoder layer's feed-forward block.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability inside the encoder layers.
        condition_decoder: if True, the raw ``src`` features are concatenated
            to the encoder output before decoding.
        transformer_norm: if True, a final ``LayerNorm(d_model)`` is applied
            after the encoder stack.
    """

    def __init__(
        self,
        feature_size,
        output_size,
        d_model,
        n_head,
        dim_feed_forward,
        nlayers,
        dropout=0.5,
        condition_decoder=False,
        transformer_norm=False,
    ):
        super().__init__()
        self.model_type = "Transformer"
        encoder_layers = TransformerEncoderLayer(d_model, n_head, dim_feed_forward, dropout)
        self.transformer_encoder = TransformerEncoder(
            encoder_layers,
            nlayers,
            norm=nn.LayerNorm(d_model) if transformer_norm else None,
        )
        self.encoder = nn.Linear(feature_size, d_model)
        self.d_model = d_model
        self.condition_decoder = condition_decoder
        # With conditioning, the decoder also sees the raw input features.
        self.decoder = nn.Linear(
            d_model + feature_size if condition_decoder else d_model, output_size
        )
        self.init_weights()

    def init_weights(self, initrange=0.1):
        """Re-initialise the input projection weight and the decoder.

        ``encoder.weight`` and ``decoder.weight`` are drawn from
        ``U(-initrange, initrange)`` and ``decoder.bias`` is zeroed.
        ``encoder.bias`` keeps the default ``nn.Linear`` initialisation.

        Args:
            initrange: half-width of the uniform distribution (default 0.1).
        """
        # nn.init.* runs under no_grad and avoids the discouraged `.data` access;
        # the RNG draw order (encoder weight, then decoder weight) is unchanged.
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src`` and project every position to ``output_size``.

        Args:
            src: tensor whose last dimension is ``feature_size``; sequence-first
                ``(seq_len, batch, feature_size)`` as described on the class.
            src_mask: optional attention mask forwarded to the encoder as ``mask``.
            src_key_padding_mask: optional ``(batch, seq_len)`` padding mask
                forwarded to the encoder.

        Returns:
            Tensor with the same leading dimensions as ``src`` and last
            dimension ``output_size``.
        """
        # Scale the projected features by sqrt(d_model), as in the tutorial.
        encoded = self.encoder(src) * math.sqrt(self.d_model)
        output = self.transformer_encoder(
            encoded, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        if self.condition_decoder:
            # Concatenate along the feature (last) axis.
            output = torch.cat([output, src], dim=-1)

        return self.decoder(output)

In [4]:
# NOTE(review): `tf` shadows the conventional TensorFlow alias; kept as-is.
tf = TransformerModel(
    feature_size=128, output_size=10,                 # input / output widths
    d_model=256, n_head=4, nlayers=2,                 # encoder shape
    dim_feed_forward=64, dropout=0.5,                 # per-layer FFN width, dropout
    condition_decoder=False, transformer_norm=False,  # optional extras off
)

In [8]:
# Shape check of the stock nn.Transformer (d_model defaults to 512).
torch.manual_seed(0)  # reproducible weights and inputs
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, batch_first=True)
transformer_model.eval()  # disable dropout so the output is deterministic
src = torch.rand((32, 10, 512))  # (batch, src_len, d_model)
tgt = torch.rand((32, 20, 512))  # (batch, tgt_len, d_model)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = transformer_model(src, tgt)

In [9]:
# nn.Transformer with batch_first=True returns (batch, tgt_len, d_model).
assert out.shape == (tgt.size(0), tgt.size(1), transformer_model.d_model)
out.shape

torch.Size([32, 20, 512])