# Imports and device

In [467]:
import gymnasium as gym
import torch
from torch import nn
import math
import numpy as np
import random
from tqdm.notebook import tqdm
import ipywidgets as widgets
from IPython.display import display
import os
import matplotlib.pyplot as plt
import copy
from torch import optim
from torch.nn import functional as F
import unittest
import sys
from IPython.display import clear_output
import time  # Add this import at the beginning of the file
from tqdm.notebook import trange

# Set up device
# device = torch.device("mps")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)
print(gym.__version__)

cuda
2.0.1+cu117
0.29.0


In [468]:
def scalar_to_onehot(self, action, action_space_size):
    action_onehot = torch.zeros(action_space_size).to(device)
    action_onehot[action] = 1
    return action_onehot.view(1, -1)


# Networks

In [469]:
class RepresentationFunction(nn.Module):
    def __init__(self, input_size, representation_size):
        super(RepresentationFunction, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, representation_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [470]:

class InitialRepresentationFunction(nn.Module):
    def __init__(self, input_size, representation_size):
        super(InitialRepresentationFunction, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, representation_size)

    def forward(self, x):
        x = F.relu(self.fc1(x.to(self.fc1.weight.device)))
        x = F.relu(self.fc2(x))
        return x.view(1, -1)

In [471]:
class DynamicsFunction(nn.Module):
    def __init__(self, input_size, representation_size, action_space_size):
        super(DynamicsFunction, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, representation_size)
        self.action_space_size = action_space_size

    def forward(self, state_repr, action):
        batch_size = state_repr.size(0)
        action_onehot = self.scalar_to_onehot(action, self.action_space_size)
        x = torch.cat((state_repr, action_onehot), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def scalar_to_onehot(self, action, action_space_size):
        action_onehot = torch.zeros(action_space_size).to(device)
        action_onehot[action] = 1
        return action_onehot.view(1, -1)

In [472]:
class PredictionFunction(nn.Module):
    def __init__(self, representation_size, action_space_size):
        super(PredictionFunction, self).__init__()
        self.fc1 = nn.Linear(representation_size, 128)
        self.fc2 = nn.Linear(128, action_space_size)

    def forward(self, state_repr):
        x = F.relu(self.fc1(state_repr))
        x = self.fc2(x)
        return x


# Config

In [473]:
class Config:
    def __init__(self):
        self.environment_name = "LunarLander-v2"
        self.render_mode = "human"  # "human" or "rgb_array"
        self.seed = 42069
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.training_steps = 10000
        self.lr = 0.001
        self.representation_size = 64
        self.action_space_size = 4
        self.min_priority = 0.1
        self.checkpoint_interval = 1000
        self.num_unroll_steps = 5
        self.td_steps = 10
        self.discount = 0.997
        self.exploration_constant = 1.0
        self.learning_rate = 0.001
        self.representation_size = 64
        self.batch_size = 128
        self.num_epochs = 10
        self.replay_buffer_capacity = 1000
        self.alpha = 0.6
        self.beta = 0.4
        self.beta_increment = 0.001
        self.eps = 0.01
        self.num_simulations = 50
        self.temperature = 1.0
        self.dirichlet_alpha = 0.25
        self.noise_weight = 0.25
        self.gradient_clip = 40.0
        self.max_moves = 27000
        self.mcts_discount = self.discount
        self.episodes = 100000
        self.gamma = 0  # Discount factor for the Bellman equation
        self.lr_repr = 0.0001
        self.lr_dyn = 0.0001
        self.lr_pred = 0.0001
        self.update_interval = 10
        self.max_steps = 200
        self.max_episode_length = 200
        self.epsilon_start = 1.0
        self.epsilon_end = 0.1
        self.epsilon_decay = 200
        self.start_steps = 100
        self.reset_intervall = 750
        self.action_space_size = 4
        self.replay_initial = 1000

config = Config()


# MCTS

In [474]:
class Node:
    def __init__(self, hidden_state, reward, terminal, action_space):
        self.hidden_state = hidden_state
        self.reward = reward
        self.terminal = terminal
        self.children = [None] * action_space
        self.total_value = [0] * action_space
        self.visit_count = [0] * action_space

In [475]:
class MCTS:
    def __init__(self, action_space_size, representation_function, dynamics_function, prediction_function):
        self.action_space_size = action_space_size
        self.num_simulations = config.num_simulations
        self.discount = config.mcts_discount
        self.root = None
        self.dynamics_function = dynamics_function
        self.prediction_function = prediction_function
        self.exploration_constant = config.exploration_constant

    def UCB_score(self, node, action):
        if node.visit_count[action] == 0:
            return float('inf')
        else:
            # Use the model to predict the value of the action
            state = node.hidden_state
            predicted_values = self.prediction_function(state)
            Q = predicted_values[0][action]
            U = self.exploration_constant * math.sqrt(math.log(sum(node.visit_count)) / node.visit_count[action])
            return Q + U

    def expand(self, node, action):
        next_state, reward = self.dynamics_function(node.hidden_state, torch.tensor([action], dtype=torch.float32).to(device))
        next_state = next_state.clone().detach().to(device)
        reward = reward.item()
        return Node(next_state, reward, False, self.action_space_size)

    def backpropagate(self, leaf_value, path):
        for node, action in reversed(path):
            node.visit_count[action] += 1
            node.total_value[action] += leaf_value
            leaf_value *= self.discount

    def run(self, initial_state):
        # Create root node with initial state
        initial_hidden_state = initial_state
        self.root = Node(initial_hidden_state, 0, False, self.action_space_size)

        for _ in range(self.num_simulations):
            node, path = self.root, []
            while node is not None:
                best_action = max(range(self.action_space_size), key=lambda a: self.UCB_score(node, a))
                path.append((node, best_action))
                if node.children[best_action] is None:
                    node.children[best_action] = self.expand(node, best_action)
                node = node.children[best_action]

            leaf = path[-1][0]
            self.backpropagate(leaf.reward, path)

        return self.root



# Replay Buffer(s)

In [476]:

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity
        self.position = 0

    def push(self, experience):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = experience
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, eps=0.01):
        super().__init__(capacity)
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.eps = eps
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def push(self, experience):
        max_priority = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        if len(self.buffer) == self.capacity:
            priorities = self.priorities
        else:
            priorities = self.priorities[:self.position]
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        experiences = [self.buffer[idx] for idx in indices]
        weights = (len(self.buffer) * probabilities[indices]) ** (-self.beta)
        weights /= weights.max()
        self.beta = min(1.0, self.beta + self.beta_increment)
        return experiences, indices, np.array(weights, dtype=np.float32)

    def update_priorities(self, batch_indices, batch_priorities):
        for idx, priority in zip(batch_indices, batch_priorities):
            self.priorities[idx] = priority


# Agent

In [477]:
class Agent:
    def __init__(self, config, representation_size, action_space_size):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = gym.make(config.environment_name, render_mode=config.render_mode)
        self.representation_function = RepresentationFunction(representation_size=representation_size, input_size=self.env.observation_space.shape[0]).to(self.device)
        self.representation_size = representation_size
        self.dynamics_input_size = representation_size + action_space_size
        self.action_size = action_space_size
        self.dynamics_function = DynamicsFunction(self.dynamics_input_size, representation_size, self.action_size).to(device)
        self.prediction_function = PredictionFunction(self.representation_size, action_space_size).to(self.device)
        self.action_space_size = action_space_size

        self.optimizer_representation = optim.Adam(self.representation_function.parameters(), lr=config.learning_rate)
        self.optimizer_dynamics = optim.Adam(self.dynamics_function.parameters(), lr=config.learning_rate)
        self.optimizer_prediction = optim.Adam(self.prediction_function.parameters(), lr=config.learning_rate)
        self.total_steps = 0
        self.training_steps_completed = 0
        self.mcts = MCTS(action_space_size=config.action_space_size, representation_function=self.representation_function, dynamics_function=self.dynamics_function, prediction_function=self.prediction_function)
        self.replay_buffer = ReplayBuffer(config.replay_buffer_capacity)

    def get_action(self, state_repr, epsilon):
        if np.random.random() < epsilon:
            return np.random.randint(self.env.action_space.n)
        else:
            with torch.no_grad():
                # Create a tensor to store the Q values for all possible actions
                q_values = torch.empty(self.env.action_space.n).to(self.device)
                # Calculate Q value for each action
                for action in range(self.env.action_space.n):
                    action_one_hot = F.one_hot(torch.tensor([action]), self.env.action_space.n).float().to(self.device)
                    # Concatenate state and action and feed to prediction function
                    state_action = torch.cat([state_repr, action_one_hot], dim=1)
                    q_values[action] = self.prediction_function(state_action)
                # Choose the action with the highest Q value
                action = torch.argmax(q_values).item()  # Returns a Python integer
                return action

    def print_progress(self, current_step, total_steps):
        print(f"Training Agent: {current_step}/{total_steps} steps completed")

    def animate(self):
        state = self.env.reset()
        print(state)
        state_repr = self.representation_function(torch.tensor(state[0], dtype=torch.float32).unsqueeze(0).to(self.device))
        frames = []
        done = False  # Initialize done before using it in the loop
        while not done:
            action = self.get_action(state_repr, max(self.config.epsilon_end, self.config.epsilon_start - self.total_steps / self.config.epsilon_decay))
            next_state, _, done, _, _ = self.env.step(action)  # Extract the integer action value using .item()
            next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
            # Render the environment as RGB array
            frame = self.env.render()
            frames.append([plt.imshow(frame, animated=True)])
            state_repr = next_state_repr
            self.total_steps += 1
        plt.show()
        return animation.ArtistAnimation(plt.gcf(), frames, interval=50, blit=True, repeat_delay=1000)

    def train(self):
        if len(self.replay_buffer) < self.config.replay_initial:
            return
    
        state, _ = self.env.reset()
        state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
        done = False
    
        while not done:
            action = self.get_action(state_repr, self.config.epsilon_end)
            next_state, reward, done, _, _ = self.env.step(action)
            next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
    
            self.replay_buffer.push((state_repr.detach().cpu().numpy(), action, reward, next_state_repr.detach().cpu().numpy(), done))
            state_repr = next_state_repr
    
            if done:
                state, _ = self.env.reset()
                state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
    
            if len(self.replay_buffer) >= self.config.replay_initial:
                self.update_parameters()


    
    def run(self):
        while self.training_steps_completed < self.config.training_steps:
            state = self.env.reset()
            state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).to(self.device))  # Remove unsqueeze(0)
            done = False
            while not done:
                action = self.get_action(state_repr, max(self.config.epsilon_end, self.config.epsilon_start - self.total_steps / self.config.epsilon_decay))
                next_state, reward, done, _, _ = self.env.step(action)  # Convert action tensor to integer
                next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
                self.replay_buffer.push((state_repr.detach().cpu().numpy(), action, reward, next_state_repr.detach().cpu().numpy(), done))
                state_repr = next_state_repr
                self.total_steps += 1
                if self.total_steps % self.config.update_interval == 0:
                    self.train()
            self.training_steps_completed += 1
            if self.training_steps_completed % self.config.checkpoint_interval == 0:
                self.save_checkpoint()

    def populate_initial_buffer(self):
        state, _ = self.env.reset()
        print(type(state), state)
        print("Initial state:", state)  
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
        print("State tensor shape:", state_tensor.shape)  # Add this line
        state_repr = self.representation_function(state_tensor)
        done = False
    
        while len(self.replay_buffer) < self.config.replay_initial:
            action = self.get_action(state_repr, max(self.config.epsilon_end, self.config.epsilon_start - self.total_steps / self.config.epsilon_decay))
            next_state, reward, done, _, _ = self.env.step(action)  # Convert action tensor to integer
            next_state_repr = self.representation_function(torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(self.device))
            self.replay_buffer.push((state_repr.detach().cpu().numpy(), action, reward, next_state_repr.detach().cpu().numpy(), done))
    
            if done:
                state, _ = self.env.reset()
                state_repr = self.representation_function(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
            else:
                state_repr = next_state_repr

    def save_checkpoint(self):
        torch.save(self.representation_function.state_dict(), 'representation_function.pth')
        torch.save(self.dynamics_function.state_dict(), 'dynamics_function.pth')
        torch.save(self.prediction_function.state_dict(), 'prediction_function.pth')

    def load_checkpoint(self):
        if os.path.isfile('representation_function.pth'):
            self.representation_function.load_state_dict(torch.load('representation_function.pth'))
        if os.path.isfile('dynamics_function.pth'):
            self.dynamics_function.load_state_dict(torch.load('dynamics_function.pth'))
        if os.path.isfile('prediction_function.pth'):
            self.prediction_function.load_state_dict(torch.load('prediction_function.pth'))


# Test the networks

In [478]:
class TestFunctions(unittest.TestCase):
    def setUp(self):
        self.config = Config()
        self.env = gym.make(self.config.environment_name, render_mode=self.config.render_mode)
        self.representation_size = 128
        action_space_size = self.env.action_space.n
        self.agent = Agent(self.config, self.representation_size)
        self.representation_function = RepresentationFunction(self.env.observation_space.shape[0], self.representation_size).to(device)
        self.dynamics_function = DynamicsFunction(self.representation_size + action_space_size, self.representation_size).to(device)
        self.prediction_function = PredictionFunction(self.representation_size, action_space_size).to(device)

    def test_representation_function(self):
        representation_function = RepresentationFunction(self.env.observation_space.shape[0], self.representation_size).to(device)
        input_tensor = torch.randn(1, self.env.observation_space.shape[0]).to(device)
        output_tensor = representation_function(input_tensor)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_initial_representation_function(self):
        state, _ = self.env.reset()
        state_tensor = torch.tensor(state).float().unsqueeze(0).to(device)
        output_tensor = self.agent.representation_function(state_tensor)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_dynamics_function(self):
        state_repr = torch.randn(1, self.representation_size).to(device)
        action = torch.randn(1, self.env.action_space.n).to(device)
        output_tensor = self.agent.dynamics_function(state_repr, action)
        self.assertEqual(output_tensor.shape, (1, self.representation_size))

    def test_prediction_function(self):
        input_tensor = torch.randn(1, self.representation_size).to(device)
        output_tensor = self.agent.prediction_function(input_tensor)
        self.assertEqual(output_tensor.shape, (1, self.env.action_space.n))


if __name__ == "__main__":
    # Load and run the test case
    suite = unittest.TestLoader().loadTestsFromTestCase(TestFunctions)
    result = unittest.TextTestRunner().run(suite)

EEEE
ERROR: test_dynamics_function (__main__.TestFunctions.test_dynamics_function)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_3270/3011834528.py", line 7, in setUp
    self.agent = Agent(self.config, self.representation_size)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Agent.__init__() missing 1 required positional argument: 'action_space_size'

ERROR: test_initial_representation_function (__main__.TestFunctions.test_initial_representation_function)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_3270/3011834528.py", line 7, in setUp
    self.agent = Agent(self.config, self.representation_size)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Agent.__init__() missing 1 required positional argument: 'action_space_size'

ERROR: test_prediction_function (__main__.TestFunct

# Run it all

In [479]:
if __name__ == "__main__":
    config = Config()
    env = gym.make(config.environment_name, render_mode=config.render_mode)
    print("Environment observation space:", env.observation_space.shape)
    input_size = env.observation_space.shape[0]
    representation_size = 128
    action_size = env.action_space.n
    representation_function = RepresentationFunction(input_size, representation_size).to(device)
    dynamics_input_size = representation_size + action_size
    dynamics_function = DynamicsFunction(dynamics_input_size, representation_size, action_size).to(device)
    prediction_input_size = representation_size + action_size
    prediction_output_size = 65
    prediction_function = PredictionFunction(prediction_input_size, prediction_output_size).to(device)
    
    agent = Agent(config, representation_size, action_size)  # Add 'action_size' argument here

    if os.path.isfile('representation_function.pth'):
        representation_function.load_state_dict(torch.load('representation_function.pth'))
    if os.path.isfile('dynamics_function.pth'):
        dynamics_function.load_state_dict(torch.load('dynamics_function.pth'))
    if os.path.isfile('prediction_function.pth'):
        prediction_function.load_state_dict(torch.load('prediction_function.pth'))

    agent.representation_function = representation_function
    agent.dynamics_function = dynamics_function
    agent.prediction_function = prediction_function

    print("Agent representation function:", agent.representation_function)

    
    agent.populate_initial_buffer()
    agent.train()

    # Visualization of agent's interactions with the environment

Environment observation space: (8,)
Agent representation function: RepresentationFunction(
  (fc1): Linear(in_features=8, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
)
<class 'numpy.ndarray'> [ 0.00586824  1.4191151   0.59436554  0.36420003 -0.00679295 -0.13463262
  0.          0.        ]
Initial state: [ 0.00586824  1.4191151   0.59436554  0.36420003 -0.00679295 -0.13463262
  0.          0.        ]
State tensor shape: torch.Size([1, 8])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x128 and 132x128)