In [None]:
import torch
import torch.nn as nn
from gymnasium import spaces

from src.modules.training.datasets.utils import TokenIndex
from src.modules.training.models.cnn import CNN

# Build the project CNN and pass it the environment / token metadata through
# `setup` before its encoder is used below.
model = CNN()
env_build = {
    'observation_space': spaces.Box(low=0, high=255, shape=(5, 5, 3), dtype="uint8"),
    'action_space': spaces.Discrete(7),
}
token_index = TokenIndex({
    'observation': [(0, 10), (1, 6), (2, 3), (3, 5)],
    'action': [(0, 7)],
    'reward': [(0, 0)],
})
info = {'env_build': env_build, 'token_index': token_index}
model.setup(info)

# Hand-built decoder: walk the hidden widths from widest to narrowest.
hidden_dims: list[int] = [32, 64, 128]
reversed_dims = hidden_dims[::-1]

# All transposed convolutions in this cell share one geometry. With
# kernel 3 / stride 2 / padding 1 / output_padding 1 each layer exactly doubles
# the spatial size: (n - 1) * 2 - 2 + 3 + 1 == 2 * n.
upsample_kwargs = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

# One (ConvTranspose2d -> ReLU -> BatchNorm2d) stage per adjacent pair of widths,
# i.e. 128 -> 64 and 64 -> 32.
decode_layers: list[nn.Module] = []
for in_dim, out_dim in zip(reversed_dims, reversed_dims[1:]):
    decode_layers += [
        nn.ConvTranspose2d(in_dim, out_dim, **upsample_kwargs),
        nn.ReLU(),
        nn.BatchNorm2d(out_dim),
    ]
decode = nn.Sequential(*decode_layers)

# Last stage back to the 24 input channels (only referenced by the disabled call below).
final_decoder = nn.ConvTranspose2d(reversed_dims[2], 24, **upsample_kwargs)

input = torch.randn(252, 24, 5, 5)
input_enc = model.encode(input)
print(input_enc.shape) # torch.Size([252, 128, 1, 1])

input_dec = decode(input_enc)

# Left disabled on purpose: a stride-2 layer cannot map the 4x4 feature map back
# to 5x5, so asking for the input's size fails with
# ValueError: requested an output size of torch.Size([5, 5]), but valid sizes range from [7, 7] to [8, 8] (for an input of torch.Size([4, 4]))
# input_dec = final_decoder(input_dec, output_size=input.size())

print(input_dec.shape) # torch.Size([252, 32, 4, 4])


# Sample Code from Torch Documentation
# A stride-2 convolution maps both 11x11 and 12x12 inputs to 6x6, so the
# transposed convolution needs `output_size` to know which one to restore.
input = torch.randn(1, 16, 12, 12)
print(input.shape)  # torch.Size([1, 16, 12, 12])

downsample = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1)

h = downsample(input)
print(h.shape)  # torch.Size([1, 16, 6, 6])

# Without a hint the smallest valid size (11x11) is produced ...
output = upsample(h)
print(output.shape)  # torch.Size([1, 16, 11, 11])

# ... while passing the original size selects 12x12.
output = upsample(h, output_size=input.shape)
print(output.shape)  # torch.Size([1, 16, 12, 12])

In [None]:
import torch
import torch.nn as nn

# Shape walk-through of a small two-stage convolutional encoder/decoder.
# The layers are applied by hand (no nn.Module wrapper) so the tensor shape can
# be printed after every step.
input_channels: int = 24
hidden_channels: list[int] = [32, 64]

relu = nn.ReLU()
channels = [input_channels] + hidden_channels  # [24, 32, 64]
reversed_channels = list(reversed(channels))   # [64, 32, 24]

# Setup Network: Encoder
## Encoder 0: two 3x3 convolutions (padding=1 keeps the spatial size), 24 -> 32 -> 32
encode00 = nn.Conv2d(channels[0], channels[1], 3, padding=1)
enc_bn00 = nn.BatchNorm2d(channels[1])
encode01 = nn.Conv2d(channels[1], channels[1], 3, padding=1)
enc_bn01 = nn.BatchNorm2d(channels[1])

## Encoder 1: 2x2 max-pool, then two 3x3 convolutions, 32 -> 64 -> 64
mp01 = nn.MaxPool2d(2)
encode10 = nn.Conv2d(channels[1], channels[2], 3, padding=1)
enc_bn10 = nn.BatchNorm2d(channels[2])
encode11 = nn.Conv2d(channels[2], channels[2], 3, padding=1)
enc_bn11 = nn.BatchNorm2d(channels[2])

# Setup Network: Decoder
## Decoder 1: stride-2 up-convolution (64 -> 32) followed by a 3x3 transposed
## convolution back to the input channel count (32 -> 24)
deconv_up_01 = nn.ConvTranspose2d(reversed_channels[0], reversed_channels[0]//2, kernel_size=2, stride=2)
deconv_bn1 = nn.BatchNorm2d(reversed_channels[1])
deconv2 = nn.ConvTranspose2d(reversed_channels[1], reversed_channels[2], 3, padding=1)
bn2 = nn.BatchNorm2d(reversed_channels[2])

# Input
input = torch.randn(252, 24, 5, 5)
print(f'input  : {input.shape}')

# Encode
# The forward pass now refers to the layers by the names they were created
# with above; the previous names (encode1, bn01, mp1, deconv1, bn1, ...) were
# never defined and raised NameError on a fresh kernel.
input_enc = relu(enc_bn00(encode00(input)))
print(f'encode1: {input_enc.shape}')
input_enc = relu(enc_bn01(encode01(input_enc)))
print(f'encode2: {input_enc.shape}')

# Remember the spatial size before pooling: 2x2 pooling floors odd sizes
# (5 -> 2), so the up-convolution cannot infer the original size on its own.
pre_pool_size = input_enc.shape[-2:]
input_enc = mp01(input_enc)
print(f'maxpool: {input_enc.shape}')

# Encoder 1 must run before decoding: `deconv_up_01` expects 64 input
# channels, whereas the pooled tensor only has 32.
input_enc = relu(enc_bn10(encode10(input_enc)))
print(f'encode3: {input_enc.shape}')
input_enc = relu(enc_bn11(encode11(input_enc)))
print(f'encode4: {input_enc.shape}')

# Decode
# For a 2x2 input, kernel 2 / stride 2 allows output sizes 4..5; `output_size`
# selects 5x5 so the decoder mirrors the encoder exactly.
input_dec = relu(deconv_bn1(deconv_up_01(input_enc, output_size=pre_pool_size)))
print(f'decode1: {input_dec.shape}')
input_dec = relu(bn2(deconv2(input_dec)))
print(f'decode2: {input_dec.shape}')

# The decoder has to reproduce the input shape exactly.
assert input_dec.shape == input.shape, (input_dec.shape, input.shape)

# # Sample Code from Torch Documentation
# input = torch.randn(1, 16, 12, 12)
# print(input.size()) # torch.Size([1, 16, 12, 12])
# downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
# upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
# h = downsample(input)
# print(h.size()) # torch.Size([1, 16, 6, 6])
# output = upsample(h)
# print(output.size()) # torch.Size([1, 16, 11, 11])
# output = upsample(h, output_size=input.size())
# print(output.size()) # torch.Size([1, 16, 12, 12])