# ***Animations***
---

### **Wrapper**

In [2]:
import gymnasium as gym
from gymnasium import spaces
import torch
import numpy as np
import pygame

class WorldModelEnv(gym.Wrapper):
    """Gymnasium wrapper that turns raw frames into World Models ``[z, h]`` states.

    Every frame returned by the wrapped environment is encoded by
    ``vision_model`` into a latent ``z``; ``memory_model.rnn`` keeps a
    recurrent hidden state that is updated from ``(z, action)`` on every
    ``step``. ``reset`` and ``step`` return the concatenation of ``z`` and the
    hidden state as a NumPy array.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``. It is created
            with ``lap_complete_percent=1.0``, so it must accept that option
            (a CarRacing option).
        vision_model: Model exposing ``latent_dim``, ``encode``,
            ``reparameterize`` and ``decode``. Required.
        memory_model: Model exposing ``hidden_dim`` and an ``rnn`` attribute
            with ``init_hidden(batch, device)`` and a ``rnn(z, action, hidden)``
            call returning ``(output, new_hidden)``. Required.
        device: Torch device the models live on; inputs are moved there.
        render_mode: Render mode forwarded to ``gym.make``.

    Raises:
        ValueError: If ``vision_model`` or ``memory_model`` is ``None``.
    """

    def __init__(self, env_id='CarRacing-v3', vision_model=None, memory_model=None, device='cuda', render_mode='rgb_array'):
        # Fail early with a clear message instead of an AttributeError on ``None.latent_dim``.
        if vision_model is None or memory_model is None:
            raise ValueError("WorldModelEnv requires both vision_model and memory_model.")

        super().__init__(gym.make(env_id, lap_complete_percent=1.0, render_mode=render_mode))

        self.device = device
        self.vision_model = vision_model
        self.memory_model = memory_model

        self.latent_dim = vision_model.latent_dim
        self.hidden_dim = memory_model.hidden_dim

        # NOTE(review): _get_full_state concatenates tensors that appear to keep a
        # leading batch axis (``step`` treats z as (batch, latent)), so returned
        # observations may be shaped (1, latent_dim + hidden_dim) rather than the
        # shape declared here — confirm against the VAE/RNN implementations.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.latent_dim + self.hidden_dim,),
            dtype=np.float32
        )

        # First action component in [-1, 1], the other two in [0, 1]
        # (presumably steering, gas, brake as in CarRacing).
        self.action_space = spaces.Box(
            low=np.array([-1.0, 0.0, 0.0]),
            high=np.array([1.0, 1.0, 1.0]),
            dtype=np.float32
        )

        # Recurrent state and latent of the most recent frame; set by reset().
        self.hidden_state = None
        self.z = None

        # pygame state. No method of this class uses it (render() is not
        # overridden); the attributes are kept so existing code relying on them still works.
        pygame.init()
        self.screen = None
        self.clock = pygame.time.Clock()

    def reset(self, **kwargs):
        """Reset the wrapped env, re-initialise the RNN state and return ``([z, h], info)``."""
        obs, info = super().reset(**kwargs)

        # Fresh hidden state for a batch of one episode.
        self.hidden_state = self.memory_model.rnn.init_hidden(1, self.device)

        # Latent encoding of the first frame.
        self.z = self._encode_obs(obs)

        return self._get_full_state(self.z), info

    def step(self, action):
        """Step the wrapped env, update z and the RNN state, and return the ``[z, h]`` state.

        Returns the usual Gymnasium 5-tuple with the observation replaced by
        the concatenated latent/hidden state.
        """
        obs, reward, terminated, truncated, info = super().step(action)

        self.z = self._encode_obs(obs)

        # NOTE(review): the RNN is fed the latent of the *new* frame together with
        # the action that produced it. Kept as-is because the trained controller
        # was presumably optimised against this ordering — confirm before changing.
        with torch.no_grad():
            z_tensor = self.z.unsqueeze(1)  # add a sequence axis of length 1
            action_tensor = torch.as_tensor(action, dtype=torch.float32, device=self.device).unsqueeze(0).unsqueeze(1)
            _, self.hidden_state = self.memory_model.rnn(z_tensor, action_tensor, self.hidden_state)

        return self._get_full_state(self.z), reward, terminated, truncated, info

    def _encode_obs(self, obs, size=(96, 96), cropp=12):
        """Preprocess an HWC frame and encode it with the VAE.

        The frame is scaled to [0, 1], converted to NCHW, its bottom ``cropp``
        rows are removed, and it is resized back to ``size`` with bicubic
        interpolation. Returns a sample ``z`` drawn via
        ``vision_model.reparameterize``, so encoding is stochastic.
        """
        with torch.no_grad():
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
            obs_tensor = obs_tensor.permute(0, 3, 1, 2) / 255.0
            obs_tensor = obs_tensor[:, :, :-cropp, :]
            obs_tensor = torch.nn.functional.interpolate(obs_tensor, size=size, mode='bicubic')

            mu, logvar = self.vision_model.encode(obs_tensor)
            z = self.vision_model.reparameterize(mu, logvar)

        return z

    def _decode_obs(self, z, size=(600, 600)):
        """Decode a latent into an HWC ``uint8`` image resized to ``size``.

        The ``* 255`` scaling assumes the decoder outputs values in [0, 1].
        """
        with torch.no_grad():
            reconstructed_image = self.vision_model.decode(z)
            reconstructed_image = torch.nn.functional.interpolate(reconstructed_image, size=size, mode='bicubic').squeeze(0)
            # Bicubic interpolation can overshoot outside [0, 1]; casting such
            # floats to uint8 is not saturating and yields garbage/wrapped pixels.
            reconstructed_image = reconstructed_image.clamp(0.0, 1.0)
            reconstructed_image = reconstructed_image.permute(1, 2, 0).cpu().numpy()
            reconstructed_image = (reconstructed_image * 255).astype(np.uint8)

        return reconstructed_image

    def render_vae(self):
        """Return the VAE reconstruction of the most recent frame as an image."""
        return self._decode_obs(self.z)

    def _get_full_state(self, z):
        """Concatenate ``z`` with the first element of the hidden state and return it as NumPy."""
        # hidden_state[0] is presumably the LSTM ``h`` of an (h, c) tuple — confirm for other RNN types.
        h = self.hidden_state[0].squeeze(0)
        full_state = torch.cat([z, h], dim=-1)
        return full_state.cpu().numpy()

In [3]:
import torch
from models.controller import Controller
from models.mdnrnn import MDNRNN
from models.vae import VAE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters; they must match the values the checkpoints were trained with.
LATENT_DIM = 32
ACTION_DIM = 3
HIDDEN_DIM = 256
NUM_GAUSSIANS = 5

path = 'checkpoints'

# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE: weights_only=False allows arbitrary unpickling — only load trusted checkpoints.

# VAE model
vision = VAE(3, LATENT_DIM).to(device)
vision.load_state_dict(torch.load(f'{path}/vae.pth', map_location=device, weights_only=False))
vision.eval()

# LSTM-MDN model
memory = MDNRNN(latent_dim=LATENT_DIM, action_dim=ACTION_DIM, hidden_dim=HIDDEN_DIM, num_gaussians=NUM_GAUSSIANS).to(device)
memory.load_state_dict(torch.load(f'{path}/memory.pth', map_location=device, weights_only=False))
memory.eval()

# Controller acting on the concatenated [z, h] state.
controller = Controller(LATENT_DIM + HIDDEN_DIM, ACTION_DIM).to(device)
controller.load_state_dict(torch.load(f'{path}/controller.pth', map_location=device, weights_only=False))
# Inference only: put the controller in eval mode like the other models.
controller.eval()

env = WorldModelEnv(env_id='CarRacing-v3', device=device, vision_model=vision, memory_model=memory)


  gym.logger.warn(
  gym.logger.warn(


### **CMA-ES Parallel Animation**

In [None]:
import cv2
from tqdm import tqdm

env = WorldModelEnv(env_id='CarRacing-v3', device=device, vision_model=vision, memory_model=memory)

NUM_EPISODES = 16  # one per cell of the 4x4 video grid
MAX_STEPS = 1000   # frames recorded per episode
render_dataset = []

for _ in tqdm(range(NUM_EPISODES)):
    episode_images = []
    obs, info = env.reset()
    for _ in range(MAX_STEPS):
        # Drop the bottom 50 rows of the rendered frame and resize to 300x200 (width x height).
        img = env.render()
        img = cv2.resize(img[:-50, :, :], (300, 200))
        episode_images.append(np.asarray(img, dtype=np.uint8))

        # Pure inference: no autograd graph, so .numpy() below is always allowed.
        with torch.no_grad():
            action = controller.get_action(torch.tensor(obs, dtype=torch.float32).to(device))
        obs, reward, terminated, truncated, info = env.step(action.cpu().numpy())

        # A finished episode must not be stepped again without reset().
        if terminated or truncated:
            break

    # Hold the last frame so every episode has MAX_STEPS frames for the grid video.
    episode_images.extend([episode_images[-1]] * (MAX_STEPS - len(episode_images)))
    render_dataset.append(episode_images)
    
    

In [None]:
import cv2
import numpy as np

def create_combined_video(render_dataset, output_path='combined_video.mp4', fps=30, grid_shape=(4, 4)):
    """
    Tile several episodes into one video shown as a grid (4x4 by default).

    Output frame ``k`` shows frame ``k`` of every episode at once, so all
    episodes play back simultaneously. Episode ``e`` is placed at row
    ``e // cols`` and column ``e % cols``.

    Args:
        render_dataset (list): One entry per episode; each entry is a list of
            RGB frames of identical shape ``(H, W, 3)``.
        output_path (str): Destination of the video file. Missing parent
            directories are created.
        fps (int): Frames per second of the output video.
        grid_shape (tuple[int, int]): ``(rows, cols)`` of the grid; the number
            of episodes must equal ``rows * cols``.

    Raises:
        ValueError: If the dataset is empty, the episode count does not match
            the grid, or any episode has no frames.
        IOError: If OpenCV cannot open ``output_path`` for writing.
    """
    from pathlib import Path

    rows, cols = grid_shape
    expected = rows * cols
    if not render_dataset:
        raise ValueError("render_dataset is empty.")
    if len(render_dataset) != expected:
        raise ValueError(f"render_dataset must contain {expected} episodes, got {len(render_dataset)}.")
    if any(len(episode) == 0 for episode in render_dataset):
        raise ValueError("Every episode in render_dataset must contain at least one frame.")

    # All frames are assumed to share the first frame's size.
    frame_height, frame_width = render_dataset[0][0].shape[:2]
    grid_height = frame_height * rows
    grid_width = frame_width * cols

    # Only write frames that every episode has, so shorter episodes cannot raise IndexError.
    num_frames = min(len(episode) for episode in render_dataset)

    # cv2.VideoWriter does not create directories and fails silently without them.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (grid_width, grid_height))
    if not out.isOpened():
        raise IOError(f"Could not open video writer for: {output_path}")

    try:
        for frame_idx in range(num_frames):
            grid_frame = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)
            for episode_idx, episode in enumerate(render_dataset):
                row, col = divmod(episode_idx, cols)
                # Rendered frames are RGB while VideoWriter expects BGR.
                bgr_frame = cv2.cvtColor(episode[frame_idx], cv2.COLOR_RGB2BGR)
                y0, x0 = row * frame_height, col * frame_width
                grid_frame[y0:y0 + frame_height, x0:x0 + frame_width] = bgr_frame
            out.write(grid_frame)
    finally:
        # Always finalise the container, even if a frame fails mid-way.
        out.release()
    print(f"Video creado en: {output_path}")

# Usage: write the recorded episodes as a single 4x4 grid video.
create_combined_video(
    render_dataset,
    output_path='videos/cma_training.mp4',
    fps=30,
)


Video creado en: render_cma_train.mp4


# Frames