In [106]:
import torch
from torch import nn
import numpy as np


In [107]:
class NotesDataEncoder(nn.Module):
    """MLP encoder: a stack of Linear+ReLU stages followed by a LayerNorm.

    Args:
        width: Sizes of the input layer and of each hidden layer, in order.
            ``width[0]`` is the expected input feature size. The sequence is
            copied, never modified.
        output_dim: Size of the final Linear output and of the LayerNorm.

    Attributes:
        output_dim: The ``output_dim`` passed in.
        width: New list ``[*width, output_dim]`` with every layer size.
        layers: Number of Linear+ReLU stages (``len(self.width) - 1``).
        encoder: ``nn.Sequential`` holding the stages and the trailing
            ``nn.LayerNorm``.
    """

    def __init__(self, width=(700, 800), output_dim=1024):
        super().__init__()
        self.output_dim = output_dim
        # Build a fresh list rather than appending to the argument: appending
        # would grow a shared default on every instantiation and would also
        # modify a list owned by the caller.
        layer_sizes = [*width, output_dim]
        self.width = layer_sizes
        self.layers = len(layer_sizes) - 1

        # Each stage stays wrapped in its own nn.Sequential so parameter names
        # (encoder.<i>.0.weight / .bias) are unchanged.
        stages = [
            nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.encoder = nn.Sequential(*stages, nn.LayerNorm(layer_sizes[-1]))

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the encoder and return the normalised output.

        ``x`` must have ``width[0]`` features in its last dimension, as
        required by the first ``nn.Linear``.
        """
        return self.encoder(x)

In [108]:
# Instantiate the encoder with its default layer sizes.
model = NotesDataEncoder()

In [109]:
# Display the module tree to check the layer layout.
model

NotesDataEncoder(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=700, out_features=800, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=800, out_features=1024, bias=True)
      (1): ReLU()
    )
    (2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)

In [110]:
# Smoke test: push a batch of 10 random 700-feature rows through the encoder.
model(torch.randn(10,700))

tensor([[ 1.8448,  0.4736, -0.3810,  ..., -0.6798, -0.4843, -0.6798],
        [ 4.1571, -0.4864,  0.1862,  ...,  1.1307,  0.8351, -0.6937],
        [ 2.7739, -0.6796,  1.3654,  ..., -0.6796, -0.4673, -0.6796],
        ...,
        [ 1.8897,  2.2067,  2.4408,  ..., -0.6936,  1.0803, -0.6936],
        [ 1.1715,  0.8213,  0.9643,  ...,  0.5293, -0.6456,  1.2543],
        [ 2.0443, -0.6871, -0.6871,  ...,  1.7267,  0.8826, -0.6871]],
       grad_fn=<NativeLayerNormBackward0>)

In [111]:
# Inspect the weight matrix of the first Linear layer (first stage, index 0 inside it).
model.encoder[0][0].weight

Parameter containing:
tensor([[-0.0015, -0.0163,  0.0274,  ..., -0.0345,  0.0230,  0.0330],
        [ 0.0157, -0.0021, -0.0176,  ...,  0.0221, -0.0014, -0.0250],
        [ 0.0006, -0.0147, -0.0312,  ..., -0.0358, -0.0062,  0.0112],
        ...,
        [ 0.0271,  0.0302,  0.0223,  ...,  0.0304, -0.0219,  0.0056],
        [-0.0350,  0.0167, -0.0313,  ...,  0.0232,  0.0192, -0.0195],
        [-0.0288,  0.0093,  0.0232,  ...,  0.0202, -0.0332,  0.0043]],
       requires_grad=True)

In [112]:
# Overwrite the weights of every Linear layer with normally distributed values.
# The last entry of `model.encoder` is the LayerNorm, so it is excluded.
# The loop variable stays named `i` because a later cell reads its final value.
std = 1.
num_linear_stages = len(model.encoder) - 1
for i in range(num_linear_stages):
    linear_layer = model.encoder[i][0]
    nn.init.normal_(linear_layer.weight, mean=0.0, std=std)

In [113]:
# Number of Linear+ReLU stages: encoder length minus the trailing LayerNorm.
len(model.encoder)-1

2

In [114]:
# NOTE: `i` is the loop variable left over from the re-initialisation loop, so
# this shows the Linear layer of the last stage that loop visited. The cell
# fails with NameError if that loop has not been run in the current kernel.
model.encoder[i][0]

Linear(in_features=800, out_features=1024, bias=True)