In [2]:
import gymnasium as gym
import minigrid
from minigrid.wrappers import ImgObsWrapper, RGBImgObsWrapper
from src.modules.environment.minigrid_wrappers import FullyObsWrapper
import numpy as np

# minigrid.register_minigrid_envs()

In [8]:
env = gym.make("MiniGrid-Empty-5x5-v0")
observation, info = env.reset(seed=42)

env = FullyObsWrapper(env)
env = ImgObsWrapper(env)

observation, info = env.reset(seed=42)
print(observation)

observation, reward, terminated, truncated, info = env.step(0)
print(observation)


[[[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [10  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 8  1  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]]
[[[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [10  0  3]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 8  1  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]]


In [None]:
import torch
import torch.nn as nn
from gymnasium import spaces

from src.modules.training.datasets.utils import TokenIndex
from src.modules.training.models.cnn import CNN

model = CNN()
info = {
    'env_build': {
        'observation_space': spaces.Box(
            low=0,
            high=255,
            shape=(5, 5, 3),
            dtype="uint8",
        ),
        'action_space': spaces.Discrete(7),
    },
    'token_index': TokenIndex({
        'observation': [(0, 10), (1, 6), (2, 3), (3,5)],
        'action': [(0, 7)],
        'reward': [(0, 0)],
    }),
}
model.setup(info)

hidden_dims: list[int] = [32, 64, 128]
reversed_dims = list(reversed(hidden_dims))
decode = nn.Sequential(
    nn.ConvTranspose2d(
            reversed_dims[0],
            reversed_dims[1],
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1
        ),
    nn.ReLU(),
    nn.BatchNorm2d(reversed_dims[1]),
    nn.ConvTranspose2d(
            reversed_dims[1],
            reversed_dims[2],
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1
        ),
    nn.ReLU(),
    nn.BatchNorm2d(reversed_dims[2]),
)

final_decoder = nn.ConvTranspose2d(
    reversed_dims[2],
    24,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1
)

input = torch.randn(252, 24, 5, 5)
input_enc = model.encode(input)
print(input_enc.shape) # torch.Size([252, 128, 1, 1])

input_dec = decode(input_enc)

# ValueError: requested an output size of torch.Size([5, 5]), but valid sizes range from [7, 7] to [8, 8] (for an input of torch.Size([4, 4]))
# input_dec = final_decoder(input_dec, output_size=input.size())

print(input_dec.shape) # torch.Size([252, 32, 4, 4])


# Sample Code from Torch Documentation
input = torch.randn(1, 16, 12, 12)
print(input.size()) # torch.Size([1, 16, 12, 12])
downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
h = downsample(input)
print(h.size()) # torch.Size([1, 16, 6, 6])
output = upsample(h)
print(output.size()) # torch.Size([1, 16, 11, 11])
output = upsample(h, output_size=input.size())
print(output.size()) # torch.Size([1, 16, 12, 12])

torch.Size([252, 128, 1, 1])


ValueError: requested an output size of torch.Size([5, 5]), but valid sizes range from [7, 7] to [8, 8] (for an input of torch.Size([4, 4]))

In [29]:
import torch
import torch.nn as nn


# Setup Network: Encoder
hidden_dims: list[int] = [32, 64, 128]
encode = nn.Sequential(
    nn.Conv2d(24, hidden_dims[0], kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(hidden_dims[0]),
    nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(hidden_dims[1]),
    nn.Conv2d(hidden_dims[1], hidden_dims[2], kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(hidden_dims[2]),
)

# Setup Network: Decoder
reversed_dims = list(reversed(hidden_dims))
relu = nn.ReLU()
deconv1 = nn.ConvTranspose2d(
    reversed_dims[0],
    reversed_dims[1],
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1
)
bn1 = nn.BatchNorm2d(reversed_dims[1])
deconv2 = nn.ConvTranspose2d(
    reversed_dims[1],
    reversed_dims[2],
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=0
)
bn2 = nn.BatchNorm2d(reversed_dims[2])
deconv3 = nn.ConvTranspose2d(
    reversed_dims[2],
    24,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=0
)

# Input
input = torch.randn(252, 24, 5, 5)

# Encode
input_enc = encode(input)
print(input_enc.shape) # torch.Size([252, 128, 1, 1])

# Decode
input_dec = relu(bn1(deconv1(input_enc)))
print(input_dec.shape) 
input_dec = relu(bn2(deconv2(input_dec)))
print(input_dec.shape)
input_dec = deconv3(input_dec)
print(input_dec.shape)


# # Sample Code from Torch Documentation
# input = torch.randn(1, 16, 12, 12)
# print(input.size()) # torch.Size([1, 16, 12, 12])
# downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
# upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
# h = downsample(input)
# print(h.size()) # torch.Size([1, 16, 6, 6])
# output = upsample(h)
# print(output.size()) # torch.Size([1, 16, 11, 11])
# output = upsample(h, output_size=input.size())
# print(output.size()) # torch.Size([1, 16, 12, 12])

torch.Size([252, 128, 1, 1])
torch.Size([252, 64, 2, 2])
torch.Size([252, 32, 3, 3])
torch.Size([252, 24, 5, 5])


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Encoder
        self.conv1 = nn.Conv2d(24, 32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Decoder
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1)
        self.dbn1 = nn.BatchNorm2d(64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1)
        self.dbn2 = nn.BatchNorm2d(32)
        self.deconv3 = nn.ConvTranspose2d(32, 24, kernel_size=3, stride=2, padding=1)
        
    def forward(self, x):
        # Save input size for later
        input_size = x.size()
        
        # Encoder
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Decoder with dynamic output padding
        x = F.relu(self.dbn1(self.deconv1(x, output_size=(input_size[0], 64, 2, 2))))
        x = F.relu(self.dbn2(self.deconv2(x, output_size=(input_size[0], 32, 3, 3))))
        x = self.deconv3(x, output_size=input_size)
        
        return x

# Create model and test
model = Autoencoder()
input = torch.randn(252, 24, 5, 5)

# Test forward pass
with torch.no_grad():
    output = model(input)
    
print("Input shape:", input.shape)
print("Output shape:", output.shape)
print("Shapes match:", input.shape == output.shape)

# Optional: Test the encoding part separately
with torch.no_grad():
    # Manual encoding
    x = input
    x = F.relu(model.bn1(model.conv1(x)))
    x = F.relu(model.bn2(model.conv2(x)))
    encoded = F.relu(model.bn3(model.conv3(x)))
    print("Encoded shape:", encoded.shape)

Input shape: torch.Size([252, 24, 5, 5])
Output shape: torch.Size([252, 24, 5, 5])
Shapes match: True
Encoded shape: torch.Size([252, 128, 1, 1])


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class Autoencoder(nn.Module):
    def __init__(self, input_channels: int = 24, hidden_dims: List[int] = [32, 64, 128]):
        super().__init__()
        
        # Save configuration
        self.input_channels = input_channels
        self.hidden_dims = hidden_dims
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList()
        in_channels = input_channels
        
        for hidden_dim in hidden_dims:
            self.encoder_layers.extend([
                nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU()
            ])
            in_channels = hidden_dim
            
        # Decoder layers
        self.decoder_layers = nn.ModuleList()
        reversed_dims = list(reversed(hidden_dims))
        
        for i in range(len(reversed_dims) - 1):
            self.decoder_layers.extend([
                nn.ConvTranspose2d(reversed_dims[i], reversed_dims[i + 1], 
                                 kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(reversed_dims[i + 1]),
                nn.ReLU()
            ])
            
        # Final decoder layer (no BatchNorm or ReLU)
        self.decoder_layers.append(
            nn.ConvTranspose2d(reversed_dims[-1], input_channels, 
                             kernel_size=3, stride=2, padding=1)
        )
        
    def forward(self, x):
        # Store input size
        input_size = x.size()
        
        # Store intermediate sizes for decoder
        intermediate_sizes = []
        
        # Encoding
        for i in range(0, len(self.encoder_layers), 3):
            # Apply conv
            x = self.encoder_layers[i](x)
            # Store size after conv
            intermediate_sizes.append(x.size())
            # Apply BatchNorm and ReLU
            x = self.encoder_layers[i + 2](self.encoder_layers[i + 1](x))
        
        print("Intermediate sizes:", intermediate_sizes)
        
        # Decoding
        for i in range(0, len(self.decoder_layers) - 1):
            # Get corresponding size from encoding phase
            print(f'intermediate_sizes[-i-1 = {-i-1}]: {intermediate_sizes[-i-1]}')
            output_size = intermediate_sizes[-i-1]
            # Apply transposed conv with specific output size
            x = self.decoder_layers[i](x, output_size=output_size)
            # Apply BatchNorm and ReLU
            x = self.decoder_layers[i + 2](self.decoder_layers[i + 1](x))
            
        # Final decoder layer
        x = self.decoder_layers[-1](x, output_size=input_size)
        
        return x


def test_autoencoder(input_size: int, hidden_dims: List[int]):
    print(f"\nTesting with input_size={input_size}, hidden_dims={hidden_dims}")
    
    model = Autoencoder(input_channels=24, hidden_dims=hidden_dims)
    input = torch.randn(252, 24, input_size, input_size)
    
    # Print model structure
    print("\nModel structure:")
    print("Encoder layers:", len(model.encoder_layers)//3)
    print("Decoder layers:", len(model.decoder_layers)//3 + 1)
    
    with torch.no_grad():
        # Track shapes through encoding
        x = input
        print(f"\nEncoding shapes:")
        print(f"Input: {x.shape}")
        
        for i in range(0, len(model.encoder_layers), 3):
            x = model.encoder_layers[i](x)
            x = model.encoder_layers[i + 2](model.encoder_layers[i + 1](x))
            print(f"After encoder {i//3 + 1}: {x.shape}")
        
        encoded = x
        
        # Decode
        output = model(input)
        
        print(f"\nFinal shapes:")
        print(f"Encoded: {encoded.shape}")
        print(f"Output: {output.shape}")
        print(f"Shapes match: {input.shape == output.shape}")

# Test different configurations
test_autoencoder(input_size=5, hidden_dims=[32, 64, 128])
test_autoencoder(input_size=9, hidden_dims=[16, 32, 64, 128])
test_autoencoder(input_size=16, hidden_dims=[32, 64])


Testing with input_size=5, hidden_dims=[32, 64, 128]

Model structure:
Encoder layers: 3
Decoder layers: 3

Encoding shapes:
Input: torch.Size([252, 24, 5, 5])
After encoder 1: torch.Size([252, 32, 3, 3])
After encoder 2: torch.Size([252, 64, 2, 2])
After encoder 3: torch.Size([252, 128, 1, 1])
Intermediate sizes: [torch.Size([252, 32, 3, 3]), torch.Size([252, 64, 2, 2]), torch.Size([252, 128, 1, 1])]
intermediate_sizes[-i-1 = -1]: torch.Size([252, 128, 1, 1])
intermediate_sizes[-i-1 = -2]: torch.Size([252, 64, 2, 2])


TypeError: _BatchNorm.forward() got an unexpected keyword argument 'output_size'