In [38]:
import torch
from torch import nn
import torch.functional as F
from torch import optim

print(torch.__version__)

2.0.0+cu117


In [50]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim: list = [500,200,100], activation: nn.Module = nn.Tanh(), use_batchnorm: bool = True):
        super().__init__()
        self.n_hidden = len(hidden_dim)
        self.hidden_dim: int = hidden_dim
        self.hidden: nn.ModuleList  = nn.ModuleList([nn.Linear(hidden_dim[i],hidden_dim[i+1]) for i in range(len(hidden_dim)-1)]) #store linear layers
        self.batchnorm: nn.ModuleList = nn.ModuleList([nn.BatchNorm1d(hidden_dim[i],hidden_dim[i+1]) for i in range(1,len(hidden_dim)-1)]) #store batchnorm layers
        self.activation: nn.Module = activation
        self.use_batchnorm: bool = use_batchnorm
        self.sequential = self._get_sequential()

    def _get_sequential(self): #compile to nn.Sequential
        res = nn.Sequential()
        for i, lin in enumerate(self.hidden):
            res.append(lin)
            res.append(self.activation)
            if self.use_batchnorm and i != self.n_hidden-2:
                res.append(self.batchnorm[i])
        return res

    def forward(self, x):
        out = x
        print(out.shape)
        for i, lin in enumerate(self.hidden):
            out = lin(out)
            out = self.activation(out)
            if self.use_batchnorm and i != self.n_hidden-2:
                out = self.batchnorm[i](out)
        return out

In [51]:
#Test
N = 20 #batch_size
input_dim = 500
factor_dim = 20
enc = Encoder(hidden_dim = [input_dim, 200, 100, factor_dim])
X = torch.randn(N, input_dim)
res = enc.forward(X)
print(enc)
for name, param in enc.named_parameters():
    print(name, param.size())

torch.Size([20, 500])
Encoder(
  (hidden): ModuleList(
    (0): Linear(in_features=500, out_features=200, bias=True)
    (1): Linear(in_features=200, out_features=100, bias=True)
    (2): Linear(in_features=100, out_features=20, bias=True)
  )
  (batchnorm): ModuleList(
    (0): BatchNorm1d(200, eps=100, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(100, eps=20, momentum=0.1, affine=True, track_running_stats=True)
  )
  (activation): Tanh()
  (sequential): Sequential(
    (0): Linear(in_features=500, out_features=200, bias=True)
    (1): BatchNorm1d(200, eps=100, momentum=0.1, affine=True, track_running_stats=True)
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): BatchNorm1d(100, eps=20, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=100, out_features=20, bias=True)
  )
)
hidden.0.weight torch.Size([200, 500])
hidden.0.bias torch.Size([200])
hidden.1.weight torch.Size([100, 200])
hidden.1.bias torch.Siz