# Params extractor

In [44]:
import torch
import torch.nn as nn

## Constants

In [45]:
GLOVE_DIM = 300

EMBED_DIM = GLOVE_DIM
DROPOUT = 0.5


NUM_HEADS = 15  # EMBED_DIM (300) should be divisible by NUM_HEADS
LSTM_LAYERS = 1
LSTM_H_DIM = EMBED_DIM


OUT_CHANNELS = 3
KERNEL_SIZE = 4
MAX_POOL_KERNEL = 2
MAX_POOL_STRIDE = 2
MAX_LEN = 300

## Utils

In [46]:
def get_trainable_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Models

In [47]:
class LSTMNet(nn.Module):
    def __init__(
        self,
        attention: bool,
        bidirectional: bool,
        output_dim: int = 1,
        hidden_dim: int = 128,
    ) -> None:
        super(LSTMNet, self).__init__()
        self._attention = attention
        self.bidirectional = bidirectional

        if self._attention:
            self.attention = nn.MultiheadAttention(EMBED_DIM, NUM_HEADS)

        self.lstm = nn.LSTM(
            EMBED_DIM,
            LSTM_H_DIM,
            num_layers=LSTM_LAYERS,
            bidirectional=self.bidirectional,
        )

        self.rnet = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(
                LSTM_H_DIM if not self.bidirectional else LSTM_H_DIM * 2, hidden_dim
            ),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        if self._attention:
            x, _ = self.attention(x, x, x)
        _, (x, __) = self.lstm(x)
        x = torch.swapaxes(x, 0, 1)
        return self.rnet(x).squeeze()

In [48]:
class CNN(nn.Module):
    def __init__(self, cnn_dim: int, output_dim: int = 1, hidden_dim: int = 128) -> None:
        super(CNN, self).__init__()
        self.cnn_dim = cnn_dim

        self.conv_out_dim = (MAX_LEN - KERNEL_SIZE + 1) * OUT_CHANNELS
        self.cnn_out_dim = (
            int((self.conv_out_dim - MAX_POOL_KERNEL) / MAX_POOL_STRIDE) + 1
        )

        if self.cnn_dim == 1:
            self.cnn = nn.Sequential(
                nn.Conv1d(
                    in_channels=EMBED_DIM,
                    out_channels=OUT_CHANNELS,
                    kernel_size=KERNEL_SIZE,
                ),
                nn.ReLU(),
                nn.MaxPool1d(MAX_POOL_KERNEL, stride=MAX_POOL_STRIDE),
            )
        else:
            self.cnn = nn.Sequential(
                nn.Conv2d(
                    in_channels=EMBED_DIM,
                    out_channels=OUT_CHANNELS,
                    kernel_size=(2, KERNEL_SIZE),
                ),
                nn.ReLU(),
                nn.Flatten(start_dim=1, end_dim=2),
                nn.MaxPool1d(MAX_POOL_KERNEL, stride=MAX_POOL_STRIDE),
            )

        self.cnet = nn.Sequential(
            nn.Linear(self.cnn_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = torch.swapaxes(x, 1, 2) if self.cnn_dim == 1 else torch.swapaxes(x, 1, 3)

        x = self.cnn(x)
        x = torch.flatten(x, 1)
        return self.cnet(x).squeeze()

In [49]:
class RNetNN(nn.Module):
    def __init__(
        self, input_dim: int = 600, output_dim: int = 1, hidden_dim: int = 128
    ) -> None:
        super(RNetNN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


class SRModelNN(nn.Module):
    def __init__(
        self, input_dim: int = 300, hidden_size: int = 300, num_layers: int = LSTM_LAYERS
    ) -> None:
        super(SRModelNN, self).__init__()

        self.net = nn.LSTM(
            input_dim,
            hidden_size,
            num_layers=num_layers,
            bidirectional=False,
        )

    def forward(self, *x):
        return self.net(*x)

In [50]:
class A2CNet(nn.Module):
    def __init__(
        self, input_dim: int = 900, output_dim: int = 2, hidden_dim: int = 16
    ) -> None:
        super(A2CNet, self).__init__()

        self.body = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )

        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        self.value = nn.Sequential(
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        body_out = self.body(x)
        return self.policy(body_out), self.value(body_out)


class PGN(nn.Module):
    def __init__(
        self, input_dim: int = 900, output_dim: int = 2, hidden_dim: int = 16
    ) -> None:
        super(PGN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)


class DQN(nn.Module):
    def __init__(
        self, input_dim: int = 900, output_dim: int = 2, hidden_dim: int = 16
    ) -> None:
        super(DQN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

## Params

In [51]:
print(f"LSTM: {get_trainable_params(LSTMNet(attention=False, bidirectional=False))}")
print(f"biLSTM: {get_trainable_params(LSTMNet(attention=False, bidirectional=True))}")
print(
    f"LSTM+Attention: {get_trainable_params(LSTMNet(attention=True, bidirectional=False))}"
)
print(
    f"biLSTM+Attention: {get_trainable_params(LSTMNet(attention=True, bidirectional=True))}"
)

LSTM: 761057
biLSTM: 1521857
LSTM+Attention: 1122257
biLSTM+Attention: 1883057


In [52]:
print(f"CNN1D: {get_trainable_params(CNN(cnn_dim=1))}")
print(f"CNN2D: {get_trainable_params(CNN(cnn_dim=2))}")

CNN1D: 60820
CNN2D: 64420


In [53]:
print(
    f"REINFORCE: {get_trainable_params(RNetNN()) + get_trainable_params(SRModelNN()) + get_trainable_params(PGN())}, {get_trainable_params(PGN())}"
)
print(
    f"A2C: {get_trainable_params(RNetNN()) + get_trainable_params(SRModelNN()) + get_trainable_params(A2CNet())}, {get_trainable_params(A2CNet())}"
)
print(
    f"DQN: {get_trainable_params(RNetNN()) + get_trainable_params(SRModelNN()) + get_trainable_params(DQN())}, {get_trainable_params(DQN())}"
)

REINFORCE: 813907, 14450
A2C: 814196, 14739
DQN: 813907, 14450
