In [20]:
import torch 
import torch.nn as nn

torch.manual_seed(0)

<torch._C.Generator at 0x1a585802cd0>

In [23]:
charset_length = 34
max_length = 72

conv1 = nn.Conv1d(in_channels=charset_length, out_channels=9, kernel_size=9)
bn1 = nn.BatchNorm1d(9)

conv2 = nn.Conv1d(in_channels=9, out_channels=9, kernel_size=9)
bn2 = nn.BatchNorm1d(9)

conv3 = nn.Conv1d(in_channels=9, out_channels=10, kernel_size=11)
bn3 = nn.BatchNorm1d(10)

# Flatten
flatten = nn.Flatten()

# Dense Layers
dense1 = nn.Linear(10 * (max_length - 26), 436)
bn_dense1 = nn.BatchNorm1d(436, track_running_stats=False)
dropout1 = nn.Dropout(0.083)

# Decoder Layers
decode_dense1 = nn.Linear(436, 436)
decode_bn1 = nn.BatchNorm1d(436, track_running_stats=False)
decode_dropout1 = nn.Dropout(0.1)

# Repeat Vector (similar to TensorFlow's RepeatVector)
max_length = max_length

# Recurrent Layer
gru_hidden_size = 488
gru = nn.GRU(input_size=436, hidden_size=gru_hidden_size, num_layers=3, batch_first=True)

# Final layer to reconstruct one-hot encoded sequence
reconstruct = nn.Linear(gru_hidden_size, charset_length)

In [24]:
conv_layers = [
    conv1,
    bn1,
    nn.Tanh(),
    conv2,
    bn2,
    nn.Tanh(),
    conv3,
    bn3,
    nn.Tanh(),
]

dense_layers = [
    dense1,
    bn_dense1,
    dropout1,
    nn.Tanh(),
]

conv_layers = nn.Sequential(*conv_layers)
dense_layers = nn.Sequential(*dense_layers)

decoder_dense_layers = [
    decode_dense1,
    decode_bn1,
    decode_dropout1,
    nn.Tanh()
]

decoder_dense_layers = nn.Sequential(*decoder_dense_layers)
decor_gru = gru
reconstruct_layer = reconstruct

def forward_sequential(x):
    conv_layers.eval()
    dense_layers.eval()
    decoder_dense_layers.eval()
    gru.eval()
    reconstruct_layer.eval()

    # Encode
    x = conv_layers(x)
    x = flatten(x)
    x = dense_layers(x)

    # Decode
    x = decoder_dense_layers(x)
    x = x.unsqueeze(1).repeat(1, max_length, 1)
    x, _ = gru(x)
    x = reconstruct_layer(x)

    return x



In [25]:
def encode(x):
    x = conv1(x)
    x = bn1(x)
    x = torch.tanh(x)

    x = conv2(x)
    x = bn2(x)
    x = torch.tanh(x)

    x = conv3(x)
    x = bn3(x)
    x = torch.tanh(x)

    x = flatten(x)

    x = dense1(x)
    x = bn_dense1(x)
    x = torch.tanh(x)
    x = dropout1(x)

    return x

def decode(x):
    # Decoder Dense Layers
    x = decode_dense1(x)
    x = decode_bn1(x)
    x = torch.tanh(x)
    x = decode_dropout1(x)

    # Repeat Vector (similar to TensorFlow's RepeatVector)
    # Unsqueeze to add sequence dimension and repeat
    x = x.unsqueeze(1).repeat(1, max_length, 1)

    # GRU Layers
    x, _ = gru(x)

    # Reconstruct one-hot encoded sequence
    x = reconstruct(x)

    return x

def forward_manaul(x):
    # set to eval mode every layer
    conv1.eval()
    bn1.eval()
    conv2.eval()
    bn2.eval()
    conv3.eval()
    bn3.eval()
    
    dense1.eval()
    bn_dense1.eval()
    dropout1.eval()
    
    decode_dense1.eval()
    decode_bn1.eval()
    decode_dropout1.eval()
    
    gru.eval()
    reconstruct.eval()
    
    # Encode
    latent = encode(x)

    # Decode
    reconstructed = decode(latent)

    return reconstructed


In [26]:
x = torch.randn(2, 34, 72)
out_seq = forward_sequential(x)
out_seq.shape
out_man = forward_manaul(x)
out_man.shape
#comapre if equal
torch.allclose(out_seq, out_man)

True

In [27]:
dense_layers

Sequential(
  (0): Linear(in_features=460, out_features=436, bias=True)
  (1): BatchNorm1d(436, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (2): Dropout(p=0.083, inplace=False)
  (3): Tanh()
)

In [28]:
x = torch.randn(2, 34, 72)

y1 = conv_layers(x)
y2 = conv1(x)
y2 = bn1(y2)
y2 = torch.tanh(y2)
y2 = conv2(y2)
y2 = bn2(y2)
y2 = torch.tanh(y2)
y2 = conv3(y2)
y2 = bn3(y2)
y2 = torch.tanh(y2)

torch.allclose(y1, y2)

True

In [29]:
y_dense_1 = flatten(y1)
y_dense_2 = flatten(y2)
torch.allclose(y_dense_1, y_dense_2)

True

In [30]:
dense1
bn_dense1

BatchNorm1d(436, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

In [31]:
y_dense_1 = dense_layers(y_dense_1)

y_dense_2 = dense1(y_dense_2)
y_dense_2 = bn_dense1(y_dense_2)
y_dense_2 = torch.tanh(y_dense_2)

torch.allclose(y_dense_1, y_dense_2)

True

In [32]:
y_decoder_dense_1 = decoder_dense_layers(y_dense_1)
y_decoder_dense_2 = decode_dense1(y_dense_2)
y_decoder_dense_2 = decode_bn1(y_decoder_dense_2)
y_decoder_dense_2 = torch.tanh(y_decoder_dense_2)
y_decoder_dense_2 = decode_dropout1(y_decoder_dense_2)

torch.allclose(y_decoder_dense_1, y_decoder_dense_2)

True

In [33]:
y_decoder_dense_1 = y_decoder_dense_1.unsqueeze(1).repeat(1, max_length, 1)
y_decoder_dense_2 = y_decoder_dense_2.unsqueeze(1).repeat(1, max_length, 1)

torch.allclose(y_decoder_dense_1, y_decoder_dense_2)

True

In [34]:
y_gru_1, _ = decor_gru(y_decoder_dense_1)
y_gru_2, _ = gru(y_decoder_dense_2)

torch.allclose(y_gru_1, y_gru_2)

True

In [35]:
y_reconstruct_1 = reconstruct_layer(y_gru_1)
y_reconstruct_2 = reconstruct(y_gru_2)

torch.allclose(y_reconstruct_1, y_reconstruct_2)

True

In [13]:
from models import MOVVAELightning
import yaml

with open('../configs/movae_config.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    
    
int_to_char = {i: chr(i + 33) for i in range(34)}

model = MOVVAELightning(config['model']['args']['params'], charset_size=34, seq_len=72, loss='bce', lr=0.001, int_to_char=int_to_char)
model = model.model

x = torch.randn(2, 34, 72)
out = model(x)

torch.Size([2, 192])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x192 and 436x436)

In [12]:
out.shape

torch.Size([2, 72, 34])