# In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim


# Latent Dynamics Model Components
class Encoder(nn.Module):
    """
    Encodes observations into a latent state.

    Two-layer MLP: obs_dim -> hidden_dim -> latent_dim, with a ReLU in between.

    Args:
        obs_dim: Size of the last dimension of the observation tensor.
        latent_dim: Size of the produced latent state.
        hidden_dim: Width of the hidden layer (default 128, the previous fixed width).
    """

    def __init__(self, obs_dim, latent_dim, hidden_dim=128):
        super().__init__()
        # Attribute name and layer order are kept so state_dict keys
        # ("fc.0.*", "fc.2.*") remain compatible with existing checkpoints.
        self.fc = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, obs):
        """Map observations of shape (..., obs_dim) to latents of shape (..., latent_dim)."""
        return self.fc(obs)


class TransitionModel(nn.Module):
    """
    Predicts the next latent state given the current latent state and action.

    The latent and the action are concatenated along the last dimension and fed
    through a two-layer MLP (latent_dim + action_dim -> hidden_dim -> latent_dim).

    Args:
        latent_dim: Size of the latent state.
        action_dim: Size of one action.
        hidden_dim: Width of the hidden layer (default 128, the previous fixed width).
    """

    def __init__(self, latent_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Same attribute name / layer order as before, so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z, action):
        """
        Return the predicted next latent.

        ``z`` (..., latent_dim) and ``action`` (..., action_dim) must share their
        leading dimensions, since they are concatenated along the last axis.
        """
        return self.fc(torch.cat([z, action], dim=-1))


class RewardModel(nn.Module):
    """
    Predicts the reward in the latent space.

    Two-layer MLP: latent_dim -> hidden_dim -> 1, so each latent yields one
    scalar reward in a trailing dimension of size 1.

    Args:
        latent_dim: Size of the latent state.
        hidden_dim: Width of the hidden layer (default 128, the previous fixed width).
    """

    def __init__(self, latent_dim, hidden_dim=128):
        super().__init__()
        # Same attribute name / layer order as before, so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z):
        """Map latents of shape (..., latent_dim) to rewards of shape (..., 1)."""
        return self.fc(z)


# Policy and Value Network
class Policy(nn.Module):
    """
    Outputs actions based on latent states.

    Two-layer MLP with a final tanh, so raw outputs lie in [-1, 1]; they are
    then multiplied by ``action_scale`` to reach environments whose action
    range is symmetric but wider (e.g. [-2, 2] needs ``action_scale=2.0``).

    Args:
        latent_dim: Size of the latent state.
        action_dim: Size of one action.
        hidden_dim: Width of the hidden layer (default 128, the previous fixed width).
        action_scale: Multiplier applied to the tanh output (default 1.0,
            i.e. the previous [-1, 1] range).
    """

    def __init__(self, latent_dim, action_dim, hidden_dim=128, action_scale=1.0):
        super().__init__()
        # Plain float attribute (not a buffer/parameter), so state_dict keys are unchanged.
        self.action_scale = action_scale
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # bounds raw outputs to [-1, 1]
        )

    def forward(self, z):
        """Map latents (..., latent_dim) to actions (..., action_dim) in [-action_scale, action_scale]."""
        return self.action_scale * self.fc(z)


class ValueNetwork(nn.Module):
    """
    Outputs state values based on latent states.

    Two-layer MLP: latent_dim -> hidden_dim -> 1, producing one scalar per
    latent in a trailing dimension of size 1.

    Args:
        latent_dim: Size of the latent state.
        hidden_dim: Width of the hidden layer (default 128, the previous fixed width).
    """

    def __init__(self, latent_dim, hidden_dim=128):
        super().__init__()
        # Same attribute name / layer order as before, so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z):
        """Map latents of shape (..., latent_dim) to values of shape (..., 1)."""
        return self.fc(z)


# Model Predictive Control in Latent Space
class MPC:
    """
    Uses a learned model to plan optimal actions in latent space.

    Random-shooting planner: sample ``num_samples`` Gaussian action sequences
    of length ``horizon``, roll each one out through the learned transition
    model starting from the encoded observation, score it by the summed
    predicted reward, and return the first action of the best sequence.

    Args:
        encoder: Callable mapping an observation tensor to a latent tensor.
        transition_model: Callable ``(z, action) -> next z``.
        reward_model: Callable mapping latents to predicted rewards; it must
            yield exactly one value per candidate sequence.
        horizon: Number of steps rolled out per candidate; must be >= 1.
        num_samples: Number of candidate action sequences; must be >= 1.
        action_dim: Dimensionality of one action.
        action_low: Optional lower action bound (scalar or broadcastable to
            ``(action_dim,)``); sampled actions are clipped to it.
        action_high: Optional upper action bound, analogous to ``action_low``.

    Raises:
        ValueError: If ``horizon`` or ``num_samples`` is smaller than 1.
    """

    def __init__(self, encoder, transition_model, reward_model, horizon, num_samples, action_dim,
                 action_low=None, action_high=None):
        if horizon < 1:
            raise ValueError(f"horizon must be >= 1, got {horizon}")
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.encoder = encoder
        self.transition_model = transition_model
        self.reward_model = reward_model
        self.horizon = horizon
        self.num_samples = num_samples
        self.action_dim = action_dim
        self.action_low = action_low
        self.action_high = action_high

    def _clip_actions(self, actions):
        """Clip sampled actions into [action_low, action_high] for whichever bounds are set."""
        if self.action_low is not None:
            low = torch.as_tensor(self.action_low, dtype=actions.dtype, device=actions.device)
            actions = torch.maximum(actions, low)
        if self.action_high is not None:
            high = torch.as_tensor(self.action_high, dtype=actions.dtype, device=actions.device)
            actions = torch.minimum(actions, high)
        return actions

    def plan(self, obs):
        """
        Plan optimal actions by sampling and evaluating trajectories.

        All candidates are rolled out as one batch, and no autograd graph is
        recorded (planning never back-propagates), so memory no longer grows
        with ``num_samples * horizon`` on every call.

        Args:
            obs: Observation tensor for the encoder (``train`` passes shape
                ``(1, obs_dim)``). Assumes the encoder returns a single latent,
                i.e. shape ``(latent_dim,)`` or ``(1, latent_dim)``.

        Returns:
            Tensor of shape ``(action_dim,)``: the first action of the
            highest-scoring sequence.
        """
        with torch.no_grad():
            z = self.encoder(obs)  # Encode observation to latent state
            action_sequences = torch.randn(
                (self.num_samples, self.horizon, self.action_dim), device=z.device
            )
            action_sequences = self._clip_actions(action_sequences)

            # Broadcast the single start latent across every candidate sequence.
            z_t = z.expand(self.num_samples, -1)
            returns = torch.zeros(self.num_samples, device=z.device)
            for t in range(self.horizon):
                z_t = self.transition_model(z_t, action_sequences[:, t])
                returns = returns + self.reward_model(z_t).reshape(self.num_samples)

            # argmax returns the first maximal index, matching the strict ">"
            # tie-break of a sequential best-so-far scan.
            best = int(torch.argmax(returns))
        return action_sequences[best, 0]  # Return the first action in the sequence


# Training Loop
def train(num_episodes=1000, env_id="Pendulum-v1", latent_dim=16, lr=1e-3):
    """
    Train the latent dynamics, reward and value models online while acting with MPC.

    On every environment step an action is planned with random-shooting MPC in
    latent space and executed. Then one gradient step is taken on the sum of:
      - a latent consistency loss (predicted next latent vs. encoded next obs),
      - a reward-prediction loss,
      - a value-regression loss.
    The episode's total reward is printed after each episode.

    Args:
        num_episodes: Number of episodes to run (default 1000, as before).
        env_id: Gymnasium environment id; assumes Box observation and action
            spaces (``.shape[0]`` is read from both).
        latent_dim: Size of the latent state.
        lr: Adam learning rate used by every optimizer.
    """
    env = gym.make(env_id)
    try:
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        # Instantiate models
        encoder = Encoder(obs_dim, latent_dim)
        transition_model = TransitionModel(latent_dim, action_dim)
        reward_model = RewardModel(latent_dim)
        # NOTE(review): `policy` is never trained or used to act; MPC picks the actions.
        policy = Policy(latent_dim, action_dim)
        value_network = ValueNetwork(latent_dim)

        # Optimizers
        encoder_optim = optim.Adam(encoder.parameters(), lr=lr)
        transition_optim = optim.Adam(transition_model.parameters(), lr=lr)
        reward_optim = optim.Adam(reward_model.parameters(), lr=lr)
        policy_optim = optim.Adam(policy.parameters(), lr=lr)
        value_optim = optim.Adam(value_network.parameters(), lr=lr)
        all_optims = (encoder_optim, transition_optim, reward_optim, policy_optim, value_optim)
        stepped_optims = (encoder_optim, transition_optim, reward_optim, value_optim)

        mse = nn.MSELoss()
        mpc = MPC(encoder, transition_model, reward_model, horizon=10, num_samples=100, action_dim=action_dim)

        for episode in range(num_episodes):
            # Gymnasium's reset() returns (observation, info).
            obs, _ = env.reset()
            done = False
            total_reward = 0.0

            while not done:
                obs_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                action = mpc.plan(obs_tensor).detach().cpu().numpy()
                # Gymnasium's step() returns a 5-tuple; an episode ends on either
                # termination or truncation (Pendulum-v1 only ever truncates).
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += float(reward)

                # Encode states and train latent dynamics
                z = encoder(obs_tensor)
                next_obs_tensor = torch.as_tensor(next_obs, dtype=torch.float32).unsqueeze(0)
                # Stop-gradient on the target latent: otherwise the encoder can drive
                # the consistency loss to zero by mapping every observation to one point.
                with torch.no_grad():
                    z_next_target = encoder(next_obs_tensor)
                action_tensor = torch.as_tensor(action, dtype=torch.float32).unsqueeze(0)
                predicted_z_next = transition_model(z, action_tensor)

                # Shape (1, 1) to match the model outputs and avoid MSELoss broadcasting.
                reward_target = torch.tensor([[float(reward)]], dtype=torch.float32)

                # Compute losses
                latent_loss = mse(predicted_z_next, z_next_target)
                reward_loss = mse(reward_model(z), reward_target)
                # NOTE(review): the value network regresses onto the immediate reward
                # (the same target as the reward model), not a discounted return.
                value_loss = mse(value_network(z), reward_target)

                # Update models
                for opt in all_optims:
                    opt.zero_grad()
                (latent_loss + reward_loss + value_loss).backward()
                for opt in stepped_optims:
                    opt.step()

                obs = next_obs

            print(f"Episode {episode}, Total Reward: {total_reward}")
    finally:
        env.close()


if __name__ == "__main__":
    # Guard so importing this module does not launch a full training run. A
    # Jupyter/IPython kernel also executes cells with __name__ == "__main__",
    # so running the cell still trains.
    train()
