# Imports and device

In [49]:
import gymnasium as gym
import torch
from torch import nn
import math
import numpy as np
import random
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import display
import os
import matplotlib.pyplot as plt
import copy
from torch import optim
from torch.nn import functional as F
import unittest
import sys
from IPython.display import clear_output
import time  # Add this import at the beginning of the file
from tqdm.notebook import trange
from collections import namedtuple, deque

# Set up device
# device = torch.device("mps")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")
print(device)
print(torch.__version__)
print(gym.__version__)

mps
2.0.1
0.29.0


# Networks

In [50]:
class RepresentationFunction(nn.Module):
    def __init__(self, input_size, representation_size):
        super(RepresentationFunction, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, representation_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [51]:

class InitialRepresentationFunction(nn.Module):
    def __init__(self, input_size, representation_size):
        super(InitialRepresentationFunction, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, representation_size)

    def forward(self, x):
        x = F.relu(self.fc1(x.to(self.fc1.weight.device)))
        x = F.relu(self.fc2(x))
        return x.view(1, -1)

In [52]:
class DynamicsFunction(nn.Module):
    def __init__(self, state_size, action_size, hidden_dim):
        super(DynamicsFunction, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_dim = hidden_dim

        # Define layers for the dynamics function
        self.fc_state = nn.Linear(state_size, hidden_dim)
        self.fc_action = nn.Linear(action_size, hidden_dim)
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_output = nn.Linear(hidden_dim, state_size)

    def forward(self, state, action):
        state = torch.flatten(state, start_dim=1)
        action = torch.flatten(action, start_dim=1)
    
        # Concatenate state and action along the last dimension
        x = torch.cat([state, action], dim=1)


        x_state = self.fc_state(state)
        x_action = self.fc_action(action)

        x = torch.relu(self.fc_hidden(torch.cat([x_state, x_action], dim=1)))
        x = self.fc_output(x)

        return x

In [53]:
class PredictionFunction(nn.Module):
    def __init__(self, representation_size, action_space_size):
        super(PredictionFunction, self).__init__()
        self.fc = nn.Linear(representation_size, 128)  # Note the input size here
        self.out = nn.Linear(128, action_space_size)

    def forward(self, x):
        x = F.relu(self.fc(x))
        x = self.out(x)
        return x


# Config

In [54]:
class Config:
    def __init__(
        self,
        observation_space_size,
        input_size,
        action_space_size,
        representation_size=128,
        batch_size=64,
        learning_rate_prediction=0.001,
        update_interval=100,
        checkpoint_interval=1000,
        replay_buffer_capacity=10000,
    ):
        self.input_size = input_size
        self.action_space_size = action_space_size
        self.representation_size = representation_size
        self.batch_size = batch_size
        self.learning_rate_prediction = learning_rate_prediction
        self.update_interval = update_interval
        self.checkpoint_interval = checkpoint_interval
        self.replay_buffer_capacity = replay_buffer_capacity
        self.observation_space_size = observation_space_size
        self.environment_name = "Breakout-v4"
        self.render_mode = 'human'
        self.max_total_steps = 100000
        self.epsilon = 0.1
        self.state_size = 4
        self.action_size = 2
        self.hidden_dim = 128
        self.gamma = 0.99



# MCTS

In [55]:
class Node:
    def __init__(self, hidden_state, reward, terminal, action_space):
        self.hidden_state = hidden_state
        self.reward = reward
        self.terminal = terminal
        self.children = [None] * action_space
        self.total_value = [0] * action_space
        self.visit_count = [0] * action_space

In [56]:
class MCTS:
    def __init__(self, action_space_size, representation_function, dynamics_function, prediction_function):
        self.action_space_size = action_space_size
        self.num_simulations = config.num_simulations
        self.discount = config.mcts_discount
        self.root = None
        self.dynamics_function = dynamics_function
        self.prediction_function = prediction_function
        self.exploration_constant = config.exploration_constant

    def UCB_score(self, node, action):
        if node.visit_count[action] == 0:
            return float('inf')
        else:
            # Use the model to predict the value of the action
            state = node.hidden_state
            predicted_values = self.prediction_function(state)
            Q = predicted_values[0][action]
            U = self.exploration_constant * math.sqrt(math.log(sum(node.visit_count)) / node.visit_count[action])
            return Q + U

    def expand(self, node, action):
        next_state, reward = self.dynamics_function(node.hidden_state, torch.tensor([action], dtype=torch.float32).to(device))
        next_state = next_state.clone().detach().to(device)
        reward = reward.item()
        return Node(next_state, reward, False, self.action_space_size)

    def backpropagate(self, leaf_value, path):
        for node, action in reversed(path):
            node.visit_count[action] += 1
            node.total_value[action] += leaf_value
            leaf_value *= self.discount

    def run(self, initial_state):
        # Create root node with initial state
        initial_hidden_state = initial_state
        self.root = Node(initial_hidden_state, 0, False, self.action_space_size)

        for _ in range(self.num_simulations):
            node, path = self.root, []
            while node is not None:
                best_action = max(range(self.action_space_size), key=lambda a: self.UCB_score(node, a))
                path.append((node, best_action))
                if node.children[best_action] is None:
                    node.children[best_action] = self.expand(node, best_action)
                node = node.children[best_action]

            leaf = path[-1][0]
            self.backpropagate(leaf.reward, path)

        return self.root



# Replay Buffer(s)

In [57]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def add(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)
        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def __len__(self):
        return len(self.buffer)


# Agent

In [58]:
class Agent:
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = gym.make("CartPole-v1")
        self.action_space = self.env.action_space.n
        self.action_space_size = self.env.action_space.n
        self.dynamics_input_size = self.config.representation_size + self.config.action_space_size
        self.representation_function = RepresentationFunction(
            input_size=self.config.input_size,
            representation_size=self.config.representation_size,
        ).to(self.device)
        self.dynamics_function = DynamicsFunction(
            state_size=self.config.representation_size,
            action_size=self.config.action_space_size,
            hidden_dim=64
        ).to(self.device)
        self.prediction_function = PredictionFunction(
            representation_size=self.config.representation_size,
            action_space_size=self.config.action_space_size
        ).to(self.device)
        self.replay_buffer = ReplayBuffer(config.replay_buffer_capacity)
        self.total_steps = 0
        self.training_steps_completed = 0
        self.optimizer_representation = optim.Adam(self.representation_function.parameters())
        self.optimizer_prediction = optim.Adam(self.prediction_function.parameters())
        self.epsilon_start = config.epsilon
        self.epsilon_final = 0.01
        self.epsilon_decay_duration = config.max_total_steps // 2
        self.best_loss = float('inf')

    
    def get_epsilon(self, step):
        # Linearly decay epsilon
        epsilon_start = 1.0
        epsilon_final = 0.1
        epsilon_decay_duration = self.config.max_total_steps // 2
        epsilon = epsilon_final + (epsilon_start - epsilon_final) * max(0, (epsilon_decay_duration - step) / epsilon_decay_duration)
        return epsilon

    def save_checkpoint(self):
        torch.save(self.representation_function.state_dict(), 'representation_function_best.pth')
        torch.save(self.dynamics_function.state_dict(), 'dynamics_function_best.pth')
        torch.save(self.prediction_function.state_dict(), 'prediction_function_best.pth')

    def concat_state_action(self, state_repr, action_tensor):
        action_one_hot = F.one_hot(action_tensor, num_classes=self.action_space).float()
        state_action_repr = torch.cat([state_repr, action_one_hot], dim=-1)
        return state_action_repr
    
    def get_action(self, state_repr, epsilon):
        if np.random.random() < epsilon:
            return np.random.choice(self.action_space_size)
    
        # Exploitation
        q_values = torch.zeros(self.action_space_size, device=self.device)
        for action in range(self.action_space_size):
            action_one_hot = F.one_hot(torch.tensor([action]), num_classes=self.action_space_size).to(self.device)
            # Concatenate state and action and feed to prediction function
            state_action = torch.cat([state_repr, action_one_hot], dim=1)
            state_action = state_action.unsqueeze(0)  # Add an extra dimension for the batch size
            q_value = self.prediction_function(state_action)
            q_values[action] = q_value.item()
    
        action = torch.argmax(q_values).item()  # Choose the action with the highest Q value
        assert action < self.action_space_size, f"Invalid action {action} selected"
        return action
    
    def select_action(self, state_repr, epsilon):
        if np.random.rand() < epsilon:  # Exploration
            return random.choice(range(self.action_space_size))

        # Exploitation
        q_values = self.prediction_function(state_repr)  # Using state_repr here
        action = torch.argmax(q_values).item()
        return action

    

    def train(self):
        # Sample a batch from the replay buffer
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.replay_buffer.sample(self.config.batch_size)
    
        # Convert batches to tensors
        state_batch = torch.tensor(state_batch, dtype=torch.float32).to(self.device)
        action_batch_tensor = torch.tensor(action_batch, dtype=torch.int64).to(self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float32).unsqueeze(1).to(self.device)
        next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32).to(self.device)
        done_batch = torch.tensor(done_batch, dtype=torch.float32).unsqueeze(1).to(self.device)
    
        # Compute state representations
        state_repr = self.representation_function(state_batch)
        next_state_repr = self.representation_function(next_state_batch)
    
        # One-hot encode the actions
        action_batch_one_hot = F.one_hot(action_batch_tensor.squeeze(), num_classes=self.config.action_space_size).float().to(self.device)
    
        # Forward pass through the dynamics function
        predicted_next_state_repr = self.dynamics_function(state_repr, action_batch_one_hot)
    
        # Compute the dynamics loss
        dynamics_loss = F.mse_loss(predicted_next_state_repr, next_state_repr)
    
        # Compute the prediction loss
        q_values = self.prediction_function(state_repr).gather(1, action_batch_tensor)
        next_q_values = self.prediction_function(next_state_repr).detach()
        max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
        target_q_values = reward_batch + self.config.gamma * max_next_q_values * (1 - done_batch)
        prediction_loss = F.mse_loss(q_values, target_q_values)
    
        # Zero the gradients for both the dynamics and prediction networks
        self.optimizer_representation.zero_grad()
        self.optimizer_prediction.zero_grad()
    
        # Backpropagate the total loss
        total_loss = dynamics_loss + prediction_loss
        total_loss.backward()

        if total_loss < self.best_loss:
            self.best_loss = total_loss
            self.save_checkpoint()
    
        # Update the parameters for both the dynamics and prediction networks
        self.optimizer_representation.step()
        self.optimizer_prediction.step()
    
        return total_loss.item()  # Return the combined loss

    def run(self):
        self.env = gym.make(self.config.environment_name, render_mode=self.config.render_mode)
        state, _ = self.env.reset()
        state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
        done = False
        episode_reward = 0
    
        # Progress bar initialization
        progress_bar = tqdm(total=self.config.max_total_steps, desc="Training progress")
    
        while self.total_steps < self.config.max_total_steps:
            # Get epsilon for this step
            epsilon = self.get_epsilon(self.total_steps)
    
            # Select action using the epsilon-greedy policy
            action = self.select_action(state_repr, epsilon)
            
            next_state, reward, done, _, _ = self.env.step(action)
            next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
    
            # Append the transition to the replay buffer
            self.replay_buffer.add(state, action, reward, next_state, done)  # Use state and next_state, not state_repr and next_state_repr
    
            state = next_state
            state_repr = next_state_repr
            episode_reward += reward
            self.total_steps += 1
    
            if done:
                state, _ = self.env.reset()
                state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
    
            if len(self.replay_buffer) >= self.config.batch_size and self.total_steps % self.config.update_interval == 0:
                print("Updating parameters...")
                loss = self.train()
                self.training_steps_completed += 1
                if self.training_steps_completed % self.config.checkpoint_interval == 0:
                    self.save_checkpoint()
    
                if loss is not None:
                    print(f"Step: {self.total_steps}, Loss: {loss:.4f}")
    
            # Update the progress bar
            progress_bar.update(1)
    
        # Close the progress bar
        progress_bar.close()
    
        self.env.close()



    def populate_initial_buffer(self):
        state, _ = self.env.reset()
        state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))

        for _ in range(self.config.initial_buffer_size):
            action = self.env.action_space.sample()
            next_state, reward, done, _, _ = self.env.step(action)
            next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
            # Append the transition to the replay buffer
            self.replay_buffer.add(state_repr.detach().cpu().numpy(), action, reward, next_state_repr.detach().cpu().numpy(), done)

            if done:
                state, _ = self.env.reset()
            else:
                state = next_state

    def update_parameters(self):
        # Sample a batch of transitions from the replay buffer
        batch = random.sample(self.replay_buffer, self.config.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
    
        # Convert the variables to tensors and move them to the appropriate device
        states = torch.stack(states).to(self.device)
        actions = torch.tensor(actions).unsqueeze(-1).to(self.device)  # Reshape actions to match states
        rewards = torch.tensor(rewards).to(self.device)
        next_states = torch.stack(next_states).to(self.device)
        dones = torch.tensor(dones).to(self.device)
    
        # Compute the loss for each of the three networks
        # 1. Representation loss
        predicted_state_repr = self.dynamics_function(states, actions)
        target_state_repr = self.representation_function(next_states)
        repr_loss = F.mse_loss(predicted_state_repr, target_state_repr.detach())
    
        # 2. Dynamics loss
        predicted_next_state_repr, predicted_reward = self.dynamics_function(states, actions)
        dynamics_loss = F.mse_loss(predicted_next_state_repr, target_state_repr.detach()) + F.mse_loss(predicted_reward, rewards.unsqueeze(-1).detach())
    
        # 3. Prediction loss
        predicted_action_values, predicted_reward = self.prediction_function(states)
        target_action_values = self.target_prediction_function(target_state_repr).detach()
        pred_loss = F.mse_loss(predicted_action_values, target_action_values) + F.mse_loss(predicted_reward, rewards.unsqueeze(-1).detach())
    
        # Combine the three losses and backpropagate
        loss = repr_loss + dynamics_loss + pred_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    
        # Update the target networks
        self._update_target_network(self.prediction_function, self.target_prediction_function, self.config.soft_update_tau)
    
        return loss.item()
 
    def play(self, num_episodes=1):
        # Load the saved models
        self.load_checkpoint()

        # Play the specified number of episodes
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
            done = False
            episode_reward = 0

            while not done:
                action = self.select_action(state_repr, epsilon=0.0)
                next_state, reward, done, _, _ = self.env.step(action)
                next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))

                state = next_state
                state_repr = next_state_repr
                episode_reward += reward

                self.env.render()  # Uncomment this line if you want to visualize the gameplay

            print(f"Episode {episode + 1}: Total Reward = {episode_reward}")

        self.env.close()

    def save_checkpoint(self):
        torch.save(self.representation_function.state_dict(), 'representation_function.pth')
        torch.save(self.dynamics_function.state_dict(), 'dynamics_function.pth')
        torch.save(self.prediction_function.state_dict(), 'prediction_function.pth')

    def load_checkpoint(self):
        if os.path.isfile('representation_function.pth'):
            self.representation_function.load_state_dict(torch.load('representation_function.pth'))
        if os.path.isfile('dynamics_function.pth'):
            self.dynamics_function.load_state_dict(torch.load('dynamics_function.pth'))
        if os.path.isfile('prediction_function.pth'):
            self.prediction_function.load_state_dict(torch.load('prediction_function.pth'))


# Test the networks

In [59]:
class TestFunctions(unittest.TestCase):
    def setUp(self):
        observation_space_size = 4
        input_size = 4
        action_space_size = 2
        self.config = Config(observation_space_size, input_size, action_space_size)
        self.env = gym.make(self.config.environment_name, render_mode=self.config.render_mode)
        self.representation_size = 128
        action_space_size = self.env.action_space.n
        self.agent = Agent(self.config, self.representation_size, self.action_space_size)
        self.representation_function = RepresentationFunction(self.env.observation_space.shape[0], self.representation_size).to(device)
        self.dynamics_input_size = representation_size + self.env.action_space.n
        self.dynamics_function = DynamicsFunction(self.dynamics_input_size, representation_size, self.env.action_space.n).to(device)
        self.prediction_function = PredictionFunction(self.representation_size, action_space_size).to(device)

    def test_representation_function(self):
        representation_function = RepresentationFunction(self.env.observation_space.shape[0], self.representation_size).to(device)
        input_tensor = torch.randn(1, self.env.observation_space.shape[0]).to(device)
        output_tensor = representation_function(input_tensor)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_initial_representation_function(self):
        state, _ = self.env.reset()
        state_tensor = torch.tensor(state).float().unsqueeze(0).to(device)
        output_tensor = self.agent.representation_function(state_tensor)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_dynamics_function(self):
        state_repr = torch.randn(1, self.representation_size).to(device)
        action = torch.randn(1, self.env.action_space.n).to(device)
        output_tensor = self.agent.dynamics_function(state_repr, action)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_prediction_function(self):
        input_tensor = torch.randn(1, self.representation_size).to(device)
        output_tensor = self.agent.prediction_function(input_tensor)
        self.assertEqual(output_tensor.shape, (1, self.env.action_space.n))


if __name__ == "__main__":
    # Load and run the test case
    suite = unittest.TestLoader().loadTestsFromTestCase(TestFunctions)
    result = unittest.TextTestRunner().run(suite)

EEEE
ERROR: test_dynamics_function (__main__.TestFunctions)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/l5/r3g8kwp54k9f9vc409rvbkrh0000gn/T/ipykernel_9196/1157417371.py", line 7, in setUp
    self.env = gym.make(self.config.environment_name, render_mode=self.config.render_mode)
  File "/Users/bigmaggi/miniforge3/lib/python3.10/site-packages/gymnasium/envs/registration.py", line 741, in make
    env_spec = _find_spec(id)
  File "/Users/bigmaggi/miniforge3/lib/python3.10/site-packages/gymnasium/envs/registration.py", line 527, in _find_spec
    _check_version_exists(ns, name, version)
  File "/Users/bigmaggi/miniforge3/lib/python3.10/site-packages/gymnasium/envs/registration.py", line 393, in _check_version_exists
    _check_name_exists(ns, name)
  File "/Users/bigmaggi/miniforge3/lib/python3.10/site-packages/gymnasium/envs/registration.py", line 370, in _check_name_exists
    raise error.NameNotFound(
gym

# Run it all

In [60]:
if __name__ == "__main__":
      # Create the environment
    env = gym.make("Breakout-v4", render_mode="human")

    # Define configuration parameters
    config = Config(
        observation_space_size=env.observation_space.shape[0],
        action_space_size=env.action_space.n,
        representation_size=128,
        batch_size=64,
        learning_rate_prediction=0.001,
        update_interval=100,
        checkpoint_interval=1000,
        replay_buffer_capacity=10000,
        input_size=env.observation_space.shape[0],
    )

    # Create an instance of the Agent class
    agent = Agent(config=config )

    # Run the agent
    agent.run()

    # let the agent play the game
    agent.play(num_episodes=10)


NameNotFound: Environment `Breakout` doesn't exist.