# Test LSTM Encoder

In [132]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import gym
import torch as th
import torch.nn as nn

class CnnLSTMEncoder(nn.Module):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """
    def __init__(self, observation_space: gym.spaces.Box, 
                  features_dim: int = 256):
        
        rnn_hidden_size = 100
        rnn_num_layers = 1
        super().__init__()
        
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.rnn = nn.LSTM(n_flatten, rnn_hidden_size, rnn_num_layers)
        self.linear = nn.Sequential(nn.Linear(rnn_hidden_size, features_dim), nn.ReLU())
        
    def forward(self, observations: th.Tensor) :
        print(observations.shape, observations.dtype)
        b_z, ts, c, h, w = observations.shape
        ii = 0
        y = self.cnn((observations[:,ii]))
        
        out, (hn, cn) = self.rnn(y.unsqueeze(1))
        out = self.linear(out) 
        
        return out 

    

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x
        

In [133]:
## Test encoder
from src.simulation.env_wrapper.parsing_env_wrapper import ParsingEnv
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf, open_dict
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage

import os
def test_encoder(encoder_cls, device):
    num_envs = 3
    with initialize(version_base=None, config_path="src/simulation/conf"):
        cfg = compose(config_name="config")

    env_config = cfg["Environment"]
    with open_dict(env_config):
            object =  "ship" if env_config["use_ship"] else "fork"
            env_config["mode"] = "rest" + "-"+ object +"-"+env_config["background"]
            env_config["random_pos"] = True
            env_config["rewarded"] = True
            env_config["run_id"] = cfg["run_id"] + "_" + "test"
            env_config["rec_path"] = os.path.join(env_config["rec_path"] , f"agent_0/")   
    env = ParsingEnv(**env_config)
    e_gen = lambda : env
    train_env = make_vec_env(env_id=e_gen, n_envs=1)
    train_env = VecTransposeImage(train_env)
        

    
    device = th.device(device)
    encoder = encoder_cls(observation_space=train_env.observation_space, features_dim=50).to(device)
    time_step = train_env.reset()
    time_step = time_step.reshape(time_step.shape[0],1,time_step.shape[1],time_step.shape[2],time_step.shape[3])
    encoder(th.from_numpy(time_step).type(th.FloatTensor).to(device))
    
    print("Encoder test passed!")

In [134]:
test_encoder(CnnLSTMEncoder,"cuda")

[INFO] Connected to Unity environment with package version 2.0.1 and communication version 1.5.0
[INFO] Connected new brain: ChickAgent?team=0
0.26.2
torch.Size([1, 1, 3, 64, 64]) torch.float32
Encoder test passed!
