In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

In [None]:
class FeedForward(nn.Module):
    """Two-layer perceptron that keeps the feature width unchanged.

    The block applies ``Linear -> GELU -> Linear``; both linear layers map
    ``embedding_dim`` input features to ``embedding_dim`` output features, so
    the output has the same shape as the input.

    Args:
        embedding_dim: Size of the last dimension of the input (and output).
    """

    def __init__(self, embedding_dim):
        super().__init__()

        # Sub-layers are instantiated in the order they are applied so that
        # parameter initialisation consumes the RNG in a stable sequence.
        project_in = nn.Linear(embedding_dim, embedding_dim)
        activation = nn.GELU()
        project_out = nn.Linear(embedding_dim, embedding_dim)

        self.layers: nn.Sequential = nn.Sequential(
            project_in,
            activation,
            project_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the Linear/GELU/Linear stack and return the result."""
        transformed = self.layers(x)
        return transformed

In [None]:
class Conv1DBlock(nn.Module):
    """Residual block: 1-D convolution stage followed by a feed-forward stage.

    The input is laid out as ``(batch, steps, embedding_dim)``: ``forward``
    swaps dims 1 and 2 before the convolution, whose channel count is
    ``embedding_dim``, and swaps them back afterwards.

    Args:
        embedding_dim: Channel count of the convolution and feature width of
            the feed-forward stage.
        kernel_size: Kernel size passed to ``nn.Conv1d``.
        stride: Stride passed to ``nn.Conv1d``.
            NOTE(review): the convolution uses ``padding='same'``, which
            PyTorch documents as unsupported for strides other than 1 -
            confirm whether this parameter is meant to vary.
        dropout_rate: Drop probability used after both stages.
    """

    def __init__(
        self,
        embedding_dim: int,
        kernel_size: int,
        stride: int,
        dropout_rate: float,
    ):
        super().__init__()

        # --- convolution stage ---
        self.conv = nn.Conv1d(
            in_channels=embedding_dim,
            out_channels=embedding_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding='same',
        )
        self.batch_norm = nn.BatchNorm1d(num_features=embedding_dim)
        self.relu = nn.ReLU()
        self.dropout_conv = nn.Dropout(dropout_rate)

        # --- feed-forward stage ---
        self.ff = FeedForward(embedding_dim=embedding_dim)
        self.dropout_ff = nn.Dropout(dropout_rate)

    def forward(self, X):
        """Apply both stages, adding the original input after each one."""
        residual = X

        # Conv1d convolves over the last dim, so move the steps axis there.
        conv_out = self.conv(X.transpose(1, 2))
        conv_out = self.dropout_conv(self.relu(self.batch_norm(conv_out)))
        # Restore the (batch, steps, embedding_dim) layout before the skip add.
        X = conv_out.transpose(1, 2) + residual

        # NOTE(review): the second skip connection re-adds the block's
        # original input rather than the output of the convolution stage;
        # confirm this is intended.
        return self.dropout_ff(self.ff(X)) + residual


In [63]:
# Shared settings for the experiments in this notebook.
# `embeddings` is the feature width: it is the last dimension of the sample
# tensor and the base width of the Linear stacks built in later cells.
embeddings = 4
# NOTE(review): the next three names mirror Conv1DBlock's constructor
# parameters, but no Conv1DBlock instantiation is visible here - confirm usage.
kernel_size = 3
stride = 1
dropout_rate = 0.1

In [None]:
# Sample input laid out as (batch, steps, embeddings), values uniform in [0, 1).
batch_size = 1
steps = 5
# Draw from a dedicated, seeded generator so the tensor shown below is the
# same on every "Restart & Run All" without touching the global RNG state.
X = torch.rand(
    batch_size,
    steps,
    embeddings,
    generator=torch.Generator().manual_seed(0),
)
X

In [None]:
class Output(nn.Module):
    """Output head: a GELU-activated MLP followed by a vocabulary projection.

    The hidden stack maps ``embedding_dim -> 2x -> 4x -> 4x -> 2x`` features
    with a GELU after every linear layer; the final linear layer maps the
    ``2 * embedding_dim`` features to ``vocabulary_size`` outputs.

    Args:
        embedding_dim: Size of the last dimension of the input.
        vocabulary_size: Number of output features produced per position.
    """

    def __init__(
            self,
            embedding_dim: int,
            vocabulary_size: int
            ):
        super().__init__()

        # The hidden stack lives under its own attribute. Previously it was
        # stored as ``self.output`` and then overwritten by the projection,
        # which made construction fail.
        self.hidden: nn.Sequential = nn.Sequential(OrderedDict([
            ('out_linear_1', nn.Linear(
                embedding_dim, embedding_dim * 2
            )),
            ('out_act_1', nn.GELU()),
            ('out_linear_2', nn.Linear(
                embedding_dim * 2, embedding_dim * 4
            )),
            ('out_act_2', nn.GELU()),
            ('out_linear_3', nn.Linear(
                embedding_dim * 4, embedding_dim * 4
            )),
            ('out_act_3', nn.GELU()),
            ('out_linear_4', nn.Linear(
                embedding_dim * 4, embedding_dim * 2
            )),
            ('out_act_4', nn.GELU()),
        ]))

        # No trailing comma here: a trailing comma would turn the value into a
        # 1-tuple, which nn.Module refuses to register as a child module.
        self.output: nn.Linear = nn.Linear(embedding_dim * 2, vocabulary_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the vocabulary projection of ``x``.

        Args:
            x: Tensor whose last dimension is ``embedding_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``vocabulary_size``.
        """
        return self.output(self.hidden(x))

In [70]:
# Build a stack of Linear layers whose width first doubles step by step,
# holds once at the widest point, and then halves back to the base width.
# Each layer's (in, out) multipliers of `embeddings` are printed as it is made.
n_layers = 4
layers = []
last = 0

# Widening half: n_layers + 1 layers, each doubling the feature width.
for n in range(1, n_layers + 2):
    last = 2 ** n
    layer = nn.Linear(embeddings * (last // 2), embeddings * last)
    layers.append(layer)
    print(f"emb * {last // 2}, emb * {last}")

# Middle layer: keeps the widest size (`last` times the base width).
layers.append(nn.Linear(embeddings * last, embeddings * last))
layer = layers[-1]
print(f"emb * {last}, emb * {last}")

# Narrowing half: mirror of the widening half, halving the width each step.
for n in reversed(range(1, n_layers + 2)):
    layer = nn.Linear(embeddings * 2 ** n, embeddings * 2 ** (n - 1))
    layers.append(layer)
    print(f"emb * {2**n}, emb * {2**(n-1)}")

print(layers)

emb * 1, emb * 2
emb * 2, emb * 4
emb * 4, emb * 8
emb * 8, emb * 16
emb * 16, emb * 32
emb * 32, emb * 32
emb * 32, emb * 16
emb * 16, emb * 8
emb * 8, emb * 4
emb * 4, emb * 2
emb * 2, emb * 1
[Linear(in_features=4, out_features=8, bias=True), Linear(in_features=8, out_features=16, bias=True), Linear(in_features=16, out_features=32, bias=True), Linear(in_features=32, out_features=64, bias=True), Linear(in_features=64, out_features=128, bias=True), Linear(in_features=128, out_features=128, bias=True), Linear(in_features=128, out_features=64, bias=True), Linear(in_features=64, out_features=32, bias=True), Linear(in_features=32, out_features=16, bias=True), Linear(in_features=16, out_features=8, bias=True), Linear(in_features=8, out_features=4, bias=True)]


In [71]:
# Chain the Linear layers collected in `layers` into one sequential module.
# NOTE(review): `layers` holds only nn.Linear modules, so there is no
# activation between them - confirm whether that is intended.
seq = nn.Sequential(*layers)
# Bare last expression so the notebook displays the module's repr.
seq

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=16, bias=True)
  (2): Linear(in_features=16, out_features=32, bias=True)
  (3): Linear(in_features=32, out_features=64, bias=True)
  (4): Linear(in_features=64, out_features=128, bias=True)
  (5): Linear(in_features=128, out_features=128, bias=True)
  (6): Linear(in_features=128, out_features=64, bias=True)
  (7): Linear(in_features=64, out_features=32, bias=True)
  (8): Linear(in_features=32, out_features=16, bias=True)
  (9): Linear(in_features=16, out_features=8, bias=True)
  (10): Linear(in_features=8, out_features=4, bias=True)
)

In [62]:
# Dry run of the widen / hold / narrow schedule: print the (in, out) width
# multipliers for each step without building any layers.
n_layers = 3

last = 0
# Widening steps: each pair is (2^(n-1), 2^n).
for n in range(1, n_layers + 2):
    last = 1 << n
    print(last >> 1, last)

# Middle step keeps the widest multiplier.
print(last, last)

# Narrowing steps: the widening pairs reversed.
for n in reversed(range(1, n_layers + 2)):
    print(1 << n, 1 << (n - 1))


1 2
2 4
4 8
8 16
16 16
16 8
8 4
4 2
2 1
