In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import tensor
from mindcraft.torch.module import FeedForward, Conv, ConvT

In [3]:
foo = "Sigmoid"
device = "cuda"  # "cpu"

In [4]:
def get_encoder(num_parameters, latent_size, kernels, strides, filters, foo, dropout=0.1, batch_norm=False, verbose=True):
    cnn_reduction = Conv.get_cnn_output_size(input_size=num_parameters, kernel_size=kernels, stride=strides)
    cnn_output_size = cnn_reduction[-1] * filters[-1]
    
    if verbose:
        print(f"CNN-reductions from input {num_parameters}: ", cnn_reduction)
        print("Final CNN-flatten-size to latent-size:", cnn_output_size, "->", latent_size)

    flatten_size = cnn_reduction[-1] * filters[-1]

    if dropout:
        dropout = [dropout] * len(kernels)

    if batch_norm:
        batch_norm = [True] * len(kernels)

    encoder = Conv(input_size=1, input_dim=1,
                   filters=filters,
                   strides=strides,
                   kernel_size=kernels,
                   dropout=dropout,
                   batch_norm=batch_norm,
                   activation=activation,
                   flatten=dict(cls=FeedForward, input_size=flatten_size, output_size=latent_size),
                  )

    if verbose:
        import torch
        x = torch.randn(1, 1, num_parameters)
        print("\n{0:<22s}".format("encoder input:"), x.shape)
        for cnn in encoder.cnn:
            x = cnn(x)
            print("- {0:<20s} output:".format(str(cnn.__class__.__name__)), x.shape)
    
        print()
        print(encoder.parameters_str)
        
    return encoder

In [10]:
def get_decoder(num_parameters, latent_size, kernels, strides, filters, foo, dropout=0.1, batch_norm=False, verbose=True,
                padding=None, output_padding=None):
    convt_unfold = ConvT.get_output_size(input_size=latent_size, kernel_size=kernels, stride=strides, 
                                         padding=padding, output_padding=output_padding)
    unfold_size = convt_unfold[-1] * filters[-1]

    if verbose:
        print(f"ConvT-unfold from latent {latent_size}: ", convt_unfold)

    if dropout:
        dropout = [dropout] * (len(kernels) - 1) + [None]

    if batch_norm:
        batch_norm = [True] * (len(kernels) - 1) + [None]

    decoder = ConvT(input_size=1, input_dim=1,
                    filters=filters,
                    kernel_size=kernels,
                    strides=strides,
                    activation=foo,
                    dropout=dropout,
                    batch_norm=batch_norm,
                    padding=padding or 0,
                    output_padding=output_padding or 0,
                    flatten=True,  # dict(cls=FeedForward, input_size=unfold_size, output_size=num_parameters),
                    )

    if verbose:
        x = torch.randn(1, 1, latent_size)
        print("\n{0:<22s}".format("decoder input:"), x.shape)
        for cnn in decoder.cnn_t:
            x = cnn(x)
            print("- {0:<20s} output:".format(str(cnn.__class__.__name__)), x.shape)  

        print()
        print(decoder.parameters_str)
    
    return decoder

## CartPole-v1

In [10]:
num_parameters = 178
latent_size = 9

In [11]:
encoder_strides = [1, 2, 2, 4]
encoder_filters = [8, 8, 8, 16]
encoder_kernels = [4, 4, 4, 6]

activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, 
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 178:  (175, 86, 42, 10)
Final CNN-flatten-size to latent-size: 160 -> 9

encoder input:         torch.Size([1, 1, 178])
- Conv1d               output: torch.Size([1, 8, 175])
- BatchNorm1d          output: torch.Size([1, 8, 175])
- Sigmoid              output: torch.Size([1, 8, 175])
- Dropout              output: torch.Size([1, 8, 175])
- Conv1d               output: torch.Size([1, 8, 86])
- BatchNorm1d          output: torch.Size([1, 8, 86])
- Sigmoid              output: torch.Size([1, 8, 86])
- Dropout              output: torch.Size([1, 8, 86])
- Conv1d               output: torch.Size([1, 8, 42])
- BatchNorm1d          output: torch.Size([1, 8, 42])
- Sigmoid              output: torch.Size([1, 8, 42])
- Dropout              output: torch.Size([1, 8, 42])
- Conv1d               output: torch.Size([1, 16, 10])
- BatchNorm1d          output: torch.Size([1, 16, 10])
- Sigmoid              output: torch.Size([1, 16, 10])
- Dropout              output: torch.

In [13]:
num_parameters = 178
latent_size = 16

In [14]:
encoder_strides = [2, 2, 2, 1]
encoder_kernels = [3, 3, 3, 4]
encoder_filters = [4, 8, 16, 1]
activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, 
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 178:  (88, 43, 21, 18)
Final CNN-flatten-size to latent-size: 18 -> 16

encoder input:         torch.Size([1, 1, 178])
- Conv1d               output: torch.Size([1, 4, 88])
- BatchNorm1d          output: torch.Size([1, 4, 88])
- Sigmoid              output: torch.Size([1, 4, 88])
- Dropout              output: torch.Size([1, 4, 88])
- Conv1d               output: torch.Size([1, 8, 43])
- BatchNorm1d          output: torch.Size([1, 8, 43])
- Sigmoid              output: torch.Size([1, 8, 43])
- Dropout              output: torch.Size([1, 8, 43])
- Conv1d               output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Sigmoid              output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- Conv1d               output: torch.Size([1, 1, 18])
- BatchNorm1d          output: torch.Size([1, 1, 18])
- Sigmoid              output: torch.Size([1, 1, 18])
- Dropout              output: torch.Size

In [19]:
decoder_kernels = [3, 4, 3, 4, 4]
decoder_strides = [1, 1, 2, 2, 2]
decoder_filters = [32, 16, 16, 8, 1]
decoder_activations = ["Sidmoig", "Sidmoig", "Sidmoig", "Sidmoig", None]

decoder = get_decoder(num_parameters, latent_size, kernels=decoder_kernels, strides=decoder_strides, filters=decoder_filters,
                      foo=decoder_activations)

ConvT-unfold from latent 16:  (18, 21, 43, 88, 178)

decoder input:         torch.Size([1, 1, 16])
- ConvTranspose1d      output: torch.Size([1, 32, 18])
- BatchNorm1d          output: torch.Size([1, 32, 18])
- Dropout              output: torch.Size([1, 32, 18])
- ConvTranspose1d      output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- ConvTranspose1d      output: torch.Size([1, 16, 43])
- BatchNorm1d          output: torch.Size([1, 16, 43])
- Dropout              output: torch.Size([1, 16, 43])
- ConvTranspose1d      output: torch.Size([1, 8, 88])
- BatchNorm1d          output: torch.Size([1, 8, 88])
- Dropout              output: torch.Size([1, 8, 88])
- ConvTranspose1d      output: torch.Size([1, 1, 178])
- BatchNorm1d          output: torch.Size([1, 1, 178])
- Flatten              output: torch.Size([1, 178])

ConvT
cnn_t.0.weight	(1, 32, 3)
cnn_t.0.bias	(32,)
cnn_t.3.weight	(32, 16, 4)
cnn

In [132]:
num_parameters = 178
latent_size = 4

In [133]:
encoder_strides = [2, 2, 2, 1]
encoder_kernels = [3, 3, 3, 4]
encoder_filters = [4, 8, 16, 1]
activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, 
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 178:  (88, 43, 21, 18)
Final CNN-flatten-size to latent-size: 18 -> 4

encoder input:         torch.Size([1, 1, 178])
- Conv1d               output: torch.Size([1, 4, 88])
- BatchNorm1d          output: torch.Size([1, 4, 88])
- Sigmoid              output: torch.Size([1, 4, 88])
- Dropout              output: torch.Size([1, 4, 88])
- Conv1d               output: torch.Size([1, 8, 43])
- BatchNorm1d          output: torch.Size([1, 8, 43])
- Sigmoid              output: torch.Size([1, 8, 43])
- Dropout              output: torch.Size([1, 8, 43])
- Conv1d               output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Sigmoid              output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- Conv1d               output: torch.Size([1, 1, 18])
- BatchNorm1d          output: torch.Size([1, 1, 18])
- Sigmoid              output: torch.Size([1, 1, 18])
- Dropout              output: torch.Size(

In [144]:
decoder_kernels = [6, 4, 3, 4, 4]
decoder_strides = [4, 1, 2, 2, 2]
decoder_filters = [32, 16, 16, 8, 1]
decoder_activations = ["Sidmoig", "Sidmoig", "Sidmoig", "Sidmoig", None]

decoder = get_decoder(num_parameters, latent_size, kernels=decoder_kernels, strides=decoder_strides, filters=decoder_filters,
                      foo=decoder_activations)

ConvT-unfold from latent 4:  (18, 21, 43, 88, 178)
Final ConvT-flatten-size to latent-size: 21 -> 4

decoder input:         torch.Size([1, 1, 4])
- ConvTranspose1d      output: torch.Size([1, 32, 18])
- BatchNorm1d          output: torch.Size([1, 32, 18])
- Dropout              output: torch.Size([1, 32, 18])
- ConvTranspose1d      output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- ConvTranspose1d      output: torch.Size([1, 16, 43])
- BatchNorm1d          output: torch.Size([1, 16, 43])
- Dropout              output: torch.Size([1, 16, 43])
- ConvTranspose1d      output: torch.Size([1, 8, 88])
- BatchNorm1d          output: torch.Size([1, 8, 88])
- Dropout              output: torch.Size([1, 8, 88])
- ConvTranspose1d      output: torch.Size([1, 1, 178])
- BatchNorm1d          output: torch.Size([1, 1, 178])
- Flatten              output: torch.Size([1, 178])

ConvT
cnn_t.0.weight	(1, 32, 6)
cn

In [148]:
num_parameters = 178
latent_size = 2

In [149]:
encoder_strides = [2, 2, 2, 1]
encoder_kernels = [3, 3, 3, 4]
encoder_filters = [4, 8, 16, 1]
activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, 
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 178:  (88, 43, 21, 18)
Final CNN-flatten-size to latent-size: 18 -> 2

encoder input:         torch.Size([1, 1, 178])
- Conv1d               output: torch.Size([1, 4, 88])
- BatchNorm1d          output: torch.Size([1, 4, 88])
- Sigmoid              output: torch.Size([1, 4, 88])
- Dropout              output: torch.Size([1, 4, 88])
- Conv1d               output: torch.Size([1, 8, 43])
- BatchNorm1d          output: torch.Size([1, 8, 43])
- Sigmoid              output: torch.Size([1, 8, 43])
- Dropout              output: torch.Size([1, 8, 43])
- Conv1d               output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Sigmoid              output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- Conv1d               output: torch.Size([1, 1, 18])
- BatchNorm1d          output: torch.Size([1, 1, 18])
- Sigmoid              output: torch.Size([1, 1, 18])
- Dropout              output: torch.Size(

In [144]:
decoder_kernels = [6, 4, 3, 4, 4]
decoder_strides = [4, 1, 2, 2, 2]
decoder_filters = [32, 16, 16, 8, 1]
decoder_activations = ["Sidmoig", "Sidmoig", "Sidmoig", "Sidmoig", None]

decoder = get_decoder(num_parameters, latent_size, kernels=decoder_kernels, strides=decoder_strides, filters=decoder_filters,
                      foo=decoder_activations)

ConvT-unfold from latent 4:  (18, 21, 43, 88, 178)
Final ConvT-flatten-size to latent-size: 21 -> 4

decoder input:         torch.Size([1, 1, 4])
- ConvTranspose1d      output: torch.Size([1, 32, 18])
- BatchNorm1d          output: torch.Size([1, 32, 18])
- Dropout              output: torch.Size([1, 32, 18])
- ConvTranspose1d      output: torch.Size([1, 16, 21])
- BatchNorm1d          output: torch.Size([1, 16, 21])
- Dropout              output: torch.Size([1, 16, 21])
- ConvTranspose1d      output: torch.Size([1, 16, 43])
- BatchNorm1d          output: torch.Size([1, 16, 43])
- Dropout              output: torch.Size([1, 16, 43])
- ConvTranspose1d      output: torch.Size([1, 8, 88])
- BatchNorm1d          output: torch.Size([1, 8, 88])
- Dropout              output: torch.Size([1, 8, 88])
- ConvTranspose1d      output: torch.Size([1, 1, 178])
- BatchNorm1d          output: torch.Size([1, 1, 178])
- Flatten              output: torch.Size([1, 178])

ConvT
cnn_t.0.weight	(1, 32, 6)
cn

In [156]:
num_parameters = 792
latent_size = 24

In [162]:
encoder_strides = [2, 2, 2, 2, 1]
encoder_kernels = [3, 3, 3, 3, 4]
encoder_filters = [8, 16, 32, 64, 1]
activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, batch_norm=None, dropout=None,
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 792:  (395, 197, 98, 48, 45)
Final CNN-flatten-size to latent-size: 45 -> 24

encoder input:         torch.Size([1, 1, 792])
- Conv1d               output: torch.Size([1, 8, 395])
- Sigmoid              output: torch.Size([1, 8, 395])
- Conv1d               output: torch.Size([1, 16, 197])
- Sigmoid              output: torch.Size([1, 16, 197])
- Conv1d               output: torch.Size([1, 32, 98])
- Sigmoid              output: torch.Size([1, 32, 98])
- Conv1d               output: torch.Size([1, 64, 48])
- Sigmoid              output: torch.Size([1, 64, 48])
- Conv1d               output: torch.Size([1, 1, 45])
- Sigmoid              output: torch.Size([1, 1, 45])
- Flatten              output: torch.Size([1, 45])
- FeedForward          output: torch.Size([1, 24])

Conv
cnn.0.weight	(8, 1, 3)
cnn.0.bias	(8,)
cnn.2.weight	(16, 8, 3)
cnn.2.bias	(16,)
cnn.4.weight	(32, 16, 3)
cnn.4.bias	(32,)
cnn.6.weight	(64, 32, 3)
cnn.6.bias	(64,)
cnn.8.weight	(1, 64, 4)
cnn

In [181]:
decoder_kernels = [2, 4, 3, 3, 4]
decoder_strides = [2, 2, 2, 2, 2]
decoder_filters = [64, 32, 32, 32, 1]
decoder_activations = ["Sidmoig", "Sidmoig", "Sidmoig", "Sidmoig", None]

decoder = get_decoder(num_parameters, latent_size, kernels=decoder_kernels, strides=decoder_strides, filters=decoder_filters,
                      foo=decoder_activations)

ConvT-unfold from latent 24:  (48, 98, 197, 395, 792)
Final ConvT-flatten-size to latent-size: 21 -> 24

decoder input:         torch.Size([1, 1, 24])
- ConvTranspose1d      output: torch.Size([1, 64, 48])
- BatchNorm1d          output: torch.Size([1, 64, 48])
- Dropout              output: torch.Size([1, 64, 48])
- ConvTranspose1d      output: torch.Size([1, 32, 98])
- BatchNorm1d          output: torch.Size([1, 32, 98])
- Dropout              output: torch.Size([1, 32, 98])
- ConvTranspose1d      output: torch.Size([1, 32, 197])
- BatchNorm1d          output: torch.Size([1, 32, 197])
- Dropout              output: torch.Size([1, 32, 197])
- ConvTranspose1d      output: torch.Size([1, 32, 395])
- BatchNorm1d          output: torch.Size([1, 32, 395])
- Dropout              output: torch.Size([1, 32, 395])
- ConvTranspose1d      output: torch.Size([1, 1, 792])
- BatchNorm1d          output: torch.Size([1, 1, 792])
- Flatten              output: torch.Size([1, 792])

ConvT
cnn_t.0.weight

## BARS

### FF-Flat-FF

In [37]:
num_parameters = 103
latent_size = 3

In [38]:
encoder_strides = [2, 2, 2, 2, 1]
encoder_kernels = [3, 3, 3, 3, 3]
encoder_filters = [8, 16, 32, 64, 1]
activation = ["Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid", "Sigmoid"]

encoder = get_encoder(num_parameters, latent_size, batch_norm=None, dropout=None,
                      kernels=encoder_kernels, strides=encoder_strides, filters=encoder_filters, foo=activation)

CNN-reductions from input 103:  (51, 25, 12, 5, 3)
Final CNN-flatten-size to latent-size: 3 -> 3

encoder input:         torch.Size([1, 1, 103])
- Conv1d               output: torch.Size([1, 8, 51])
- Sigmoid              output: torch.Size([1, 8, 51])
- Conv1d               output: torch.Size([1, 16, 25])
- Sigmoid              output: torch.Size([1, 16, 25])
- Conv1d               output: torch.Size([1, 32, 12])
- Sigmoid              output: torch.Size([1, 32, 12])
- Conv1d               output: torch.Size([1, 64, 5])
- Sigmoid              output: torch.Size([1, 64, 5])
- Conv1d               output: torch.Size([1, 1, 3])
- Sigmoid              output: torch.Size([1, 1, 3])
- Flatten              output: torch.Size([1, 3])
- FeedForward          output: torch.Size([1, 3])

Conv
cnn.0.weight	(8, 1, 3)
cnn.0.bias	(8,)
cnn.2.weight	(16, 8, 3)
cnn.2.bias	(16,)
cnn.4.weight	(32, 16, 3)
cnn.4.bias	(32,)
cnn.6.weight	(64, 32, 3)
cnn.6.bias	(64,)
cnn.8.weight	(1, 64, 3)
cnn.8.bias	(1,)
cnn

In [42]:
decoder_kernels = [3, 3, 3, 3, 3]
decoder_strides = [1, 2, 2, 2, 2]
output_padding  = [0, 1, 0, 0, 0]
decoder_filters = [64, 32, 32, 32, 1]
decoder_activations = ["Sidmoig", "Sidmoig", "Sidmoig", "Sidmoig", None]

decoder = get_decoder(num_parameters, latent_size, kernels=decoder_kernels, strides=decoder_strides, filters=decoder_filters,
                      foo=decoder_activations, output_padding=output_padding)

ConvT-unfold from latent 3:  (5, 12, 25, 51, 103)

decoder input:         torch.Size([1, 1, 3])
- ConvTranspose1d      output: torch.Size([1, 64, 5])
- BatchNorm1d          output: torch.Size([1, 64, 5])
- Dropout              output: torch.Size([1, 64, 5])
- ConvTranspose1d      output: torch.Size([1, 32, 12])
- BatchNorm1d          output: torch.Size([1, 32, 12])
- Dropout              output: torch.Size([1, 32, 12])
- ConvTranspose1d      output: torch.Size([1, 32, 25])
- BatchNorm1d          output: torch.Size([1, 32, 25])
- Dropout              output: torch.Size([1, 32, 25])
- ConvTranspose1d      output: torch.Size([1, 32, 51])
- BatchNorm1d          output: torch.Size([1, 32, 51])
- Dropout              output: torch.Size([1, 32, 51])
- ConvTranspose1d      output: torch.Size([1, 1, 103])
- BatchNorm1d          output: torch.Size([1, 1, 103])
- Flatten              output: torch.Size([1, 103])

ConvT
cnn_t.0.weight	(1, 64, 3)
cnn_t.0.bias	(64,)
cnn_t.3.weight	(64, 32, 3)
cnn_t.