In [None]:
from RSSM import RSSM
import gymnasium as gym
import torch
from torch.optim.adam import Adam

# Build the CarRacing environment. continuous=False requests the
# discrete-action variant. The commented-out `action_dim` entries further
# down read `env.action_space.n`, which only exists for discrete spaces.
env = gym.make(
    id="CarRacing-v2",
    continuous=False,
)
# Optional observation resize (currently disabled):
# env = gym.wrappers.ResizeObservation(env, (96, 96))

from agent_configs.base_config import ConfigBase


class RSSMConfig(ConfigBase):
    def __init__(self, config_dict):
        super().__init__(config_dict)

        self.observation_dimensions = self.parse_field("observation_dimensions")
        self.state_dim = self.parse_field("state_dim")
        self.hidden_dim = self.parse_field("hidden_dim")
        # self.action_dim = self.parse_field("action_dim")
        self.embedding_dim = self.parse_field("embedding_dim")
        self.activation = self.parse_field("activation")
        self.norm = self.parse_field("norm")

        self.optimizer = self.parse_field("optimizer")
        self.learning_rate = self.parse_field("learning_rate")
        self.adam_epsilon = self.parse_field("adam_epsilon")
        self.weight_decay = self.parse_field("weight_decay")
        self.clipnorm = self.parse_field("clipnorm")

        self.prediction_loss_coeff = self.parse_field("prediction_loss_coeff")
        self.dynamics_loss_coeff = self.parse_field("dynamics_loss_coeff")
        self.representation_loss_coeff = self.parse_field("representation_loss_coeff")

        self.replay_buffer_size = self.parse_field("replay_buffer_size")
        self.batch_size = self.parse_field("batch_size")
        self.batch_length = self.parse_field("batch_length")

        self.training_steps = self.parse_field("training_steps")
        self.is_image = self.parse_field("is_image")


# Hyperparameters for this run. The keys mirror the fields that RSSMConfig
# parses.
rssm_config_dict = dict(
    # NOTE(review): shape[0] is only the first axis of the observation shape.
    # For image observations (is_image=True), confirm that RSSM expects this
    # single int rather than the full shape.
    observation_dimensions=env.observation_space.shape[0],
    state_dim=512,
    hidden_dim=128,
    # action_dim=env.action_space.n,
    embedding_dim=64,  # alternative: 256
    # NOTE(review): activation is passed as a module instance, while norm is
    # passed as a class. Confirm this matches how RSSM consumes each one.
    activation=torch.nn.ReLU(),
    # Other norm choices: torch.nn.LayerNorm, torch.nn.RMSNorm
    norm=torch.nn.BatchNorm2d,
    optimizer=Adam,
    learning_rate=1e-3,
    adam_epsilon=1e-8,
    weight_decay=0.0,
    clipnorm=1000.0,  # alternative: 0
    prediction_loss_coeff=1.0,
    dynamics_loss_coeff=1.0,
    representation_loss_coeff=0.1,
    replay_buffer_size=100000,
    batch_size=16,
    batch_length=64,
    training_steps=10000,
    is_image=True,
)

# Wrap the raw dict in the typed config, build the model and train it.
rssm_config = RSSMConfig(rssm_config_dict)
rssm = RSSM(env=env, config=rssm_config)

rssm.train()