In [1]:
# %apt-get install -y python3-dev swig
# %pip install box2d-py

## Imports

In [2]:
# Envs
from collections import deque
import cv2
import imageio

## LunarLander
import gymnasium
import math

## CarRacing

# Network
import torch
import torch.nn as nn

# ReplayBuffer
import numpy as np

# MuZero
import os
from datetime import datetime
import matplotlib.pyplot as plt
import time

# Part 1

## Config

In [3]:
class Config:
    # Save graphs/anims
    GRAPH_DIR = "data/graphs/"
    ANIM_DIR = "data/anims/"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ### Game
    nb_episodes = 1000
    nb_test_chunk = 3
    nb_ep_bw_test = 20
    update_graph_interval = 5

    action_space = list(range(4))
    resolution = (96, 96)
    nb_channels = 1  # Number of channels in the observations, 1 for grayscale, 3 RGB
    stacked_observations = (
        4  # Number of previous observations and previous actions to add to the current observation
    )
    interval = 5  # Time-jumps between frames

    ### Self-Play
    max_moves = 400
    num_simulations = 30  # Number of future moves self-simulated
    discount = 0.999  # Chronological discount of the reward
    value_loss_weight = 0.25  # paper recommends 0.25

    # Root prior exploration noise
    root_dirichlet_alpha = 0.25
    root_exploration_fraction = 0.25

    # UCB formula
    pb_c_base = 19652
    pb_c_init = 1.25

    ### Network
    support_size = 20

    # Fully Connected Network
    encoding_size = 10
    fc_representation_layers = []
    fc_dynamics_layers = [64]
    fc_reward_layers = [64]
    fc_value_layers = [64]
    fc_policy_layers = [64]

    ### Training
    # Exponential learning rate schedule
    lr_init = 0.005
    lr_decay_rate = 1

    training_steps = 20000
    batch_size = 64

    ### Replay Buffer
    replay_buffer_size = 128  # Number of self-play games to keep in the replay buffer
    num_unroll_steps = 10  # Number of game moves to keep for every batch element
    td_steps = (
        15  # Number of steps in the future to take into account for calculating the target value
    )

    visit_softmax_temperature = 0.35

## Game

#### Lunar Lander

In [4]:
class LunarLander:
    def __init__(self, normalize=True, record=False):
        self.config = Config
        self.normalize = normalize
        self.record = record

        self.env = gymnasium.make("LunarLander-v3", render_mode="rgb_array")
        self.action_space = self.env.action_space

        self.frames = []
        self.observations = deque(maxlen=self.config.stacked_observations)

        self.step_count = 0
        self.total_reward = 0

    def step(self, action):
        observation, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated

        frame = np.array(self.render())
        if self.record:
            self.frames.append(frame)

        observation = self._process_frame(frame)
        self.observations.append(observation)

        step_obs = torch.cat(list(self.observations), dim=0).unsqueeze(0)

        self.step_count += 1
        self.total_reward += reward
        if self.step_count >= self.config.max_moves:
            done = True

        return step_obs, reward, done

    def _process_frame(self, observation):
        target_shape = (
            self.config.resolution[0],
            self.config.resolution[1],
            self.config.nb_channels,
        )
        observation = cv2.resize(observation, target_shape[:2], interpolation=cv2.INTER_AREA)

        # observation = np.transpose(observation, (2, 0, 1))

        if self.config.nb_channels == 1:
            observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
            # observation = cv2.cvtColor(observation.transpose(1, 2, 0), cv2.COLOR_RGB2GRAY)
            observation = np.expand_dims(observation, axis=0)

        if self.normalize:
            observation = observation / 255.0

        return torch.from_numpy(observation).to(torch.float32)

    def render(self):
        return self.env.render()

    def reset(self):
        self.frames = []
        self.observations = deque(maxlen=self.config.stacked_observations)

        self.step_count = 0
        self.total_reward = 0

        self.env.reset()
        observation = np.array(self.env.render())

        processed_obs = self._process_frame(observation)
        for _ in range(self.config.stacked_observations):
            self.observations.append(processed_obs)

        step_obs = torch.cat(list(self.observations), dim=0).unsqueeze(0)

        return step_obs

    def close(self):
        self.env.close()

    def save_video(self, VID_DIR, episode, fps=60, test_ep=None, prnt=True):
        # Save the frames as a video
        if test_ep is None:
            video_filename = f"{VID_DIR}/ep_{episode}.mp4"
        else:
            video_filename = f"{VID_DIR}/ep_{episode}_test_{test_ep}.mp4"
        if prnt:
            print(f"Saving video to {video_filename}")
        imageio.mimsave(video_filename, self.frames, fps=fps)

#### Car Racing

In [5]:
class CarRacing:
    def __init__(self, normalize=True, record=False):
        self.config = Config
        self.normalize = normalize
        self.record = record

        self.env = gymnasium.make(
            "CarRacing-v2", render_mode="rgb_array", domain_randomize=False, continuous=False
        )
        self.action_space = self.env.action_space

        self.frames = []
        self.observations = deque(maxlen=self.config.stacked_observations)

        self.step_count = 0
        self.total_reward = 0

    def step(self, action):
        observation, reward, terminated, truncated, _ = self.env.step(action)
        if self.record:
            self.frames.append(np.array(self.render()))

        observation = self._process_frame(observation)
        self.observations.append(observation)

        step_obs = torch.cat(list(self.observations), dim=0).unsqueeze(0)

        done = terminated or truncated
        if self.step_count >= self.config.max_moves:
            done = True

        self.step_count += 1
        self.total_reward += reward

        return step_obs, reward, done

    def _process_frame(self, observation):
        observation = np.transpose(observation, (2, 0, 1))

        if self.config.nb_channels == 1:
            observation = cv2.cvtColor(observation.transpose(1, 2, 0), cv2.COLOR_RGB2GRAY)
            observation = np.expand_dims(observation, axis=0)

        if self.normalize:
            observation = observation / 255.0

        return torch.from_numpy(observation).to(torch.float32)

    def render(self):
        return self.env.render()

    def reset(self):
        self.frames = []
        self.observations = deque(maxlen=self.config.stacked_observations)

        self.step_count = 0
        self.total_reward = 0

        observation, _ = self.env.reset()
        if self.record:
            self.frames.append(np.array(self.render()))

        observation = self._process_frame(observation)
        for _ in range(self.config.stacked_observations):
            self.observations.append(observation)

        step_obs = torch.cat(list(self.observations), dim=0).unsqueeze(0)

        return step_obs

    def close(self):
        self.env.close()

    def save_video(self, VID_DIR, episode, fps=60, test_ep=None, prnt=True):
        # Save the frames as a video
        if test_ep is None:
            video_filename = f"{VID_DIR}/ep_{episode}.mp4"
        else:
            video_filename = f"{VID_DIR}/ep_{episode}_test_{test_ep}.mp4"
        if prnt:
            print(f"Saving video to {video_filename}")
        imageio.mimsave(video_filename, self.frames, fps=fps)

## Models

In [6]:
class MuZeroNetwork(nn.Module):
    def __init__(self):
        super(MuZeroNetwork, self).__init__()

        self.config = Config
        self.device = self.config.device
        self.to(self.device)

        self.fc_reward_layers = Config.fc_reward_layers
        self.fc_value_layers = Config.fc_value_layers
        self.fc_policy_layers = Config.fc_policy_layers
        self.fc_representation_layers = Config.fc_representation_layers
        self.fc_dynamics_layers = Config.fc_dynamics_layers
        self.full_support_size = 2 * self.config.support_size + 1
        self.support_values = (
            torch.arange(
                -self.config.support_size, self.config.support_size + 1, dtype=torch.float32
            )
            .unsqueeze(0)
            .to(self.device)
        )

        self.repr_conv = nn.Sequential(
            nn.Conv2d(self.config.nb_channels * self.config.stacked_observations, 32, 8, stride=4),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.Flatten(),
        )
        num_features = len(
            self.repr_conv(
                torch.zeros(
                    (
                        1,
                        self.config.nb_channels * self.config.stacked_observations,
                        self.config.resolution[0],
                        self.config.resolution[1],
                    )
                )
            )[0, :]
        )
        self.fc = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.Linear(128, self.config.encoding_size),
        )

        self.representation_network = nn.Sequential(
            *self.repr_conv,
            *self.fc,
        )

        self.dynamics_encoded_state_network = network_builder(
            self.config.encoding_size + len(self.config.action_space),
            self.fc_dynamics_layers,
            self.config.encoding_size,
        )
        self.dynamics_reward_network = network_builder(
            self.config.encoding_size, self.config.fc_reward_layers, self.full_support_size
        )
        self.prediction_policy_network = network_builder(
            self.config.encoding_size, self.fc_policy_layers, len(self.config.action_space)
        )
        self.prediction_value_network = network_builder(
            self.config.encoding_size, self.fc_value_layers, self.full_support_size
        )

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=Config.lr_init,
            weight_decay=Config.lr_decay_rate,
        )

    def representation_function(self, observation):
        encoded_state = self.representation_network(observation)

        # Scale encoded state between [0, 1]
        min_encoded_state = encoded_state.min(1, keepdim=True)[0]
        max_encoded_state = encoded_state.max(1, keepdim=True)[0]
        scale_encoded_state = max_encoded_state - min_encoded_state
        scale_encoded_state[scale_encoded_state < 1e-5] += 1e-5
        encoded_state_normalized = (encoded_state - min_encoded_state) / scale_encoded_state
        return encoded_state_normalized

    def prediction_function(self, encoded_state):
        policy_logits = self.prediction_policy_network(encoded_state)
        value = self.prediction_value_network(encoded_state)
        return policy_logits, value

    def dynamics_function(self, encoded_state, action):
        # Stack encoded_state with a game specific one hot encoded action
        action_one_hot = torch.zeros(
            (action.shape[0], len(self.config.action_space)), device=self.device
        ).float()
        action_one_hot.scatter_(1, action.long(), 1.0)

        x = torch.cat((encoded_state, action_one_hot), dim=1)

        next_encoded_state = self.dynamics_encoded_state_network(x)

        reward = self.dynamics_reward_network(next_encoded_state)

        # Scale encoded state between [0, 1]
        min_next_encoded_state = next_encoded_state.min(1, keepdim=True)[0]
        max_next_encoded_state = next_encoded_state.max(1, keepdim=True)[0]
        scale_next_encoded_state = max_next_encoded_state - min_next_encoded_state
        scale_next_encoded_state[scale_next_encoded_state < 1e-5] += 1e-5
        next_encoded_state_normalized = (
            next_encoded_state - min_next_encoded_state
        ) / scale_next_encoded_state

        return next_encoded_state_normalized, reward

    def initial_inference(self, observation):
        observation = observation.to(self.device)
        encoded_state = self.representation_function(observation)
        policy_logits, value = self.prediction_function(encoded_state)
        # reward equal to 0 for consistency
        reward = torch.log(
            (
                torch.zeros(1, self.full_support_size)
                .scatter(1, torch.tensor([[self.full_support_size // 2]]).long(), 1.0)
                .repeat(len(observation), 1)
                .to(self.device)
            )
        )
        return value, reward, policy_logits, encoded_state

    def recurrent_inference(self, encoded_state, action):
        encoded_state = encoded_state.to(self.config.device)
        action = action.to(self.config.device)
        next_encoded_state, reward = self.dynamics_function(encoded_state, action)
        policy_logits, value = self.prediction_function(next_encoded_state)
        return value, reward, policy_logits, next_encoded_state


def network_builder(
    input_size,
    layer_sizes,
    output_size,
    output_activation=torch.nn.Identity,
    activation=torch.nn.ELU,
    input_layers=[],
):
    sizes = [input_size] + layer_sizes + [output_size]
    layers = []
    layers.extend(input_layers)
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) - 2 else output_activation
        layers += [torch.nn.Linear(sizes[i], sizes[i + 1]), act()]
    return torch.nn.Sequential(*layers)


def support_to_scalar(logits, support_size):
    """
    Transform a categorical representation to a scalar
    """
    # Decode to a scalar
    probabilities = torch.softmax(logits, dim=1)
    support = (
        torch.tensor([x for x in range(-support_size, support_size + 1)], device=logits.device)
        .expand(probabilities.shape)
        .float()
    )

    x = torch.sum(support * probabilities, dim=1, keepdim=True)

    return torch.sign(x) * (
        ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001)) ** 2 - 1
    )


def scalar_to_support(x, support_size):
    """
    Transform a scalar to a categorical representation with (2 * support_size + 1) categories
    """
    # Reduce the scale (defined in https://arxiv.org/abs/1805.11593)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x

    # Encode on a vector
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob = x - floor
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1, device=x.device)
    logits.scatter_(2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1))
    indexes = floor + support_size + 1
    prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
    logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
    return logits

## Replay buffer

In [7]:
class ReplayBuffer:
    def __init__(self):
        """
        Initialize the replay buffer with a given size.
        """
        self.config = Config
        self.device = self.config.device

        self.buffer = deque(maxlen=self.config.replay_buffer_size)

        self.num_played_steps = 0
        self.num_played_games = 0
        self.total_samples = 0

    def __len__(self):
        return len(self.buffer)

    def save_game(self, game_history):
        self.buffer.append(game_history)
        self.num_played_games += 1
        self.num_played_steps += len(game_history.root_values)
        self.total_samples = sum(len(game.root_values) for game in self.buffer)

    def compute_target_value(self, game_history, index):
        # The value target is the discounted root value of the search tree td_steps into the future, plus the discounted sum of all rewards until then.
        bootstrap_index = index + self.config.td_steps
        if bootstrap_index < len(game_history.root_values):
            root_values = game_history.root_values
            last_step_value = root_values[bootstrap_index]

            value = last_step_value * self.config.discount**self.config.td_steps
        else:
            value = 0

        for i, reward in enumerate(game_history.reward_history[index + 1 : bootstrap_index + 1]):
            value += reward * self.config.discount**i

        return value

    def get_batch(self):
        index_batch = []
        observation_batch = []
        action_batch = []
        value_batch = []
        reward_batch = []
        policy_batch = []
        gradient_scale_batch = []

        for game_id, game_history in self.sample_game(self.config.batch_size):
            game_pos = self.sample_position(game_history)

            values, rewards, policies, actions = self.make_target(game_history, game_pos)

            observation_batch.append(game_history.observation_history[game_pos])

            index_batch.append([game_id, game_pos])
            action_batch.append(actions)
            value_batch.append(values)
            reward_batch.append(rewards)
            policy_batch.append(policies)
            gradient_scale_batch.append(
                [min(self.config.num_unroll_steps, len(game_history.action_history) - game_pos)]
                * len(actions)
            )

        # Move all tensors in observation_batch to the same device
        observation_batch = [obs.to(self.device) for obs in observation_batch]

        # Now stack the tensors
        observation_batch = torch.stack([obs.squeeze(0) for obs in observation_batch], axis=0)

        return (
            index_batch,
            (
                observation_batch,
                action_batch,
                value_batch,
                reward_batch,
                policy_batch,
                gradient_scale_batch,
            ),
        )

    def _generate_single_batch(self, game_history, game_id, game_pos):
        values, rewards, policies, actions, observations = self.make_target(game_history, game_pos)

        single_index = [game_id, game_pos]
        single_observation = observations
        single_action = actions
        single_value = values
        single_reward = rewards
        single_policy = policies
        single_gradient_scale = [
            min(self.config.num_unroll_steps, len(game_history.action_history) - game_pos)
        ] * len(actions)

        return (
            single_index,
            single_observation,
            single_action,
            single_value,
            single_reward,
            single_policy,
            single_gradient_scale,
        )

    def sample_game(self, n_games=1):
        """
        Sample n_games from the buffer uniformly.
        See paper appendix Training.
        """
        game_indices = np.random.choice(len(self.buffer), size=n_games, replace=True)
        game_ids = [
            self.num_played_games - len(self.buffer) + game_index for game_index in game_indices
        ]

        return [(game_index, self.buffer[game_index]) for game_index in game_indices]

    def sample_position(self, game_history):
        """
        Sample position from game uniformly.
        See paper appendix Training.
        """
        return np.random.choice(len(game_history.root_values))

    def make_target(self, game_history, state_index):
        """
        Generate targets for every unroll steps.
        """
        target_values = []
        target_rewards = []
        target_policies = []
        actions = []

        for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
            value = self.compute_target_value(game_history, current_index)

            if current_index < len(game_history.root_values):
                target_values.append(value)
                target_rewards.append(game_history.reward_history[current_index])
                target_policies.append(game_history.child_visits[current_index])
                actions.append(game_history.action_history[current_index])
            elif current_index == len(game_history.root_values):
                target_values.append(0)
                target_rewards.append(game_history.reward_history[current_index])
                # Uniform policy
                target_policies.append(
                    [
                        1 / len(game_history.child_visits[0])
                        for _ in range(len(game_history.child_visits[0]))
                    ]
                )
                actions.append(game_history.action_history[current_index])
            else:
                # States past the end of games are treated as absorbing states
                target_values.append(0)
                target_rewards.append(0)
                # Uniform policy
                target_policies.append(
                    [
                        1 / len(game_history.child_visits[0])
                        for _ in range(len(game_history.child_visits[0]))
                    ]
                )
                actions.append(np.random.choice(self.config.action_space))

        return target_values, target_rewards, target_policies, actions

## GameHistory

In [8]:
class GameHistory:
    """
    Store only usefull information of a self-play game.
    """

    def __init__(self):
        self.observation_history = []
        self.action_history = []
        self.reward_history = []
        self.child_visits = []
        self.root_values = []

    def store_search_statistics(self, root, action_space):
        # Turn visit count from root into a policy
        if root is not None:
            sum_visits = sum(child.visit_count for child in root.children.values())
            self.child_visits.append(
                [
                    root.children[a].visit_count / sum_visits if a in root.children else 0
                    for a in action_space
                ]
            )

            self.root_values.append(root.value())
        else:
            self.root_values.append(None)

## MCTS

In [9]:
class Node:
    def __init__(self, prior):
        self.device = Config.device

        self.visit_count = 0
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, actions, reward, policy_logits, hidden_state):
        """
        # FROM PAPER CODE GITHUB
        We expand a node using the value, reward and policy prediction obtained from the
        neural network.
        """
        self.reward = reward
        self.hidden_state = hidden_state

        # policy_values = torch.softmax(torch.tensor([policy_logits[0][a] for a in actions]), dim=0).tolist()
        policy_values = torch.softmax(
            torch.tensor([policy_logits[0][a] for a in actions]).to(self.device), dim=0
        ).tolist()
        policy = {a: policy_values[i] for i, a in enumerate(actions)}

        for action, p in policy.items():
            self.children[action] = Node(p)

    def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
        """
        At the start of each search, we add dirichlet noise to the prior of the root to
        encourage the search to explore new actions.
        """
        actions = list(self.children.keys())
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior = self.children[a].prior * (1 - frac) + n * frac


class MinMaxStats:
    """
    COPIED
    A class that holds the min-max values of the tree.
    """

    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


def select_action(node, best=False):
    """
    Select action according to the visit count distribution and the temperature.
    The temperature is changed dynamically with the visit_softmax_temperature function
    in the config.
    """

    visit_counts = np.array([child.visit_count for child in node.children.values()], dtype="int32")
    actions = [action for action in node.children.keys()]

    visit_count_distribution = visit_counts ** (1 / Config.visit_softmax_temperature)
    visit_count_distribution = visit_count_distribution / sum(visit_count_distribution)

    if best:
        action = actions[np.argmax(visit_count_distribution)]
    else:
        action = np.random.choice(actions, p=visit_count_distribution)

    # if np.random.random() < 0.01:
    #     print(visit_count_distribution)
    #     print(action)

    return action


class MCTS:
    def __init__(self) -> None:
        self.config = Config

        self.discount = Config.discount

    def run(self, network: MuZeroNetwork, observation, add_exploration_noise=True):
        root = Node(0)

        root_predicted_value, reward, policy_logits, hidden_state = network.initial_inference(
            observation
        )

        root_predicted_value = support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = support_to_scalar(reward, self.config.support_size).item()

        root.expand(
            self.config.action_space,
            reward,
            policy_logits,
            hidden_state,
        )

        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        min_max_stats = MinMaxStats()

        max_tree_depth = 0
        for _ in range(self.config.num_simulations):
            node = root
            search_path = [node]
            current_tree_depth = 0

            while node.expanded():
                current_tree_depth += 1
                action, node = self.select_child(node, min_max_stats)
                search_path.append(node)

            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                parent.hidden_state, torch.tensor([[action]])
            )
            value = support_to_scalar(value, self.config.support_size).item()
            reward = support_to_scalar(reward, self.config.support_size).item()
            node.expand(
                self.config.action_space,
                reward,
                policy_logits,
                hidden_state,
            )

            self.backpropagate(search_path, value, min_max_stats)

            max_tree_depth = max(max_tree_depth, current_tree_depth)

        return root

    def select_child(self, node, min_max_stats):
        """
        COPIED FROM GITHUB
        Select the child with the highest UCB score.
        """
        max_ucb = max(
            self.ucb_score(node, child, min_max_stats) for action, child in node.children.items()
        )
        action = np.random.choice(
            [
                action
                for action, child in node.children.items()
                if self.ucb_score(node, child, min_max_stats) == max_ucb
            ]
        )
        return action, node.children[action]

    def ucb_score(self, parent, child, min_max_stats):
        """
        COPIED FROM GITHUB
        The score for a node is based on its value, plus an exploration bonus based on the prior.
        """
        pb_c = (
            math.log((parent.visit_count + self.config.pb_c_base + 1) / self.config.pb_c_base)
            + self.config.pb_c_init
        )
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

        prior_score = pb_c * child.prior

        if child.visit_count > 0:
            # Mean value Q
            value_score = min_max_stats.normalize(child.reward + self.discount * child.value())
        else:
            value_score = 0

        return prior_score + value_score

    def backpropagate(self, search_path, value, min_max_stats):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        for node in reversed(search_path):
            node.value_sum += value
            node.visit_count += 1
            min_max_stats.update(node.reward + self.discount * node.value())

            value = node.reward + self.discount * value

## MuZero

In [10]:
class MuZero:
    def __init__(self, seed=None):
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        self.config = Config
        self.device = self.config.device

        self.network = MuZeroNetwork()
        self.network.to(self.device)

        # self.env = CarRacing()
        self.env = LunarLander()
        self.buffer = ReplayBuffer()

        self.VID_DIR = f'{Config.ANIM_DIR}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
        self.GRAPH_DIR = f'{Config.GRAPH_DIR}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

        if not os.path.exists(f"{self.VID_DIR}/test"):
            os.makedirs(f"{self.VID_DIR}/test")
        if not os.path.exists(f"{self.GRAPH_DIR}"):
            os.makedirs(f"{self.GRAPH_DIR}")

        self.training_step = 0
        self.training_history = {
            "loss": [],
            "value_loss": [],
            "reward_loss": [],
            "policy_loss": [],
            "reward": [],
        }

    def train(self):
        for episode in range(self.config.nb_episodes):
            # Test the MuZero network
            if episode % self.config.nb_ep_bw_test == 0:
                self.test_play(episode)

            # Run self-play simulation
            game_history = self.play(episode)
            self.training_history["reward"].append(self.env.total_reward)

            # Add trajectory data to replay buffer
            self.buffer.save_game(game_history)

            # Train the MuZero network
            if self.training_step < self.config.training_steps:
                batch_id, batch = self.buffer.get_batch()

                (
                    total_loss,
                    value_loss,
                    reward_loss,
                    policy_loss,
                ) = self.update_weights(batch)

                self.training_history["loss"].append(total_loss)
                self.training_history["value_loss"].append(value_loss)
                self.training_history["reward_loss"].append(reward_loss)
                self.training_history["policy_loss"].append(policy_loss)

                print(
                    f"\nLoss: {total_loss:.4f} | Value loss: {value_loss:.4f} | Reward loss: {reward_loss:.4f} | Policy loss: {policy_loss:.4f}"
                )

            # Update graphs every few episodes
            if episode % self.config.update_graph_interval == 0:
                self.update_graphs()

        print("End of training. Last tests...\n")

        self.test_play(episode)
        self.env.close()

    def play(self, episode: int):
        start = time.time()

        print(f"\n\nTraining episode {episode}...")

        # Train the MuZero network
        observation = self.env.reset()
        observation = observation.to(self.device)

        done = False

        game_history = GameHistory()
        game_history.action_history.append(0)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(0)

        # Interaction loop
        while not done:
            # Plan and act using Monte-Carlo Tree Search (MCTS)
            root = MCTS().run(self.network, observation)
            action = select_action(root)

            # Interact with the environment
            observation, reward, done = self.env.step(action)

            game_history.store_search_statistics(root, self.config.action_space)
            game_history.action_history.append(action)
            game_history.observation_history.append(observation)
            game_history.reward_history.append(reward)

        print(
            f"Total reward: {self.env.total_reward:.2f} | Steps: {self.env.step_count} | Time: {time.time() - start:.2f}s"
        )

        return game_history

    def test_play(self, episode: int):
        print(f"\n\nTesting episode {episode}...")

        self.env.record = True

        # Test the MuZero network
        for test_ep in range(self.config.nb_test_chunk):
            # Reset the environment
            test_observation = self.env.reset()
            test_observation.to(self.device)
            test_done = False

            # Interaction loop for the test episode
            while not test_done:
                # Plan and act using Monte-Carlo Tree Search (MCTS)
                test_root = MCTS().run(self.network, test_observation, add_exploration_noise=False)
                test_action = select_action(test_root)  # , best=True)

                test_observation, test_reward, test_done = self.env.step(test_action)

            self.env.save_video(self.VID_DIR, episode, test_ep=test_ep)
            print(f"Total reward: {self.env.total_reward:.2f}")

        self.env.record = False

    def update_graphs(self):
        plt.figure(figsize=(12, 8))
        plt.plot(self.training_history["loss"], label="Total loss")
        plt.plot(self.training_history["value_loss"], label="Value loss")
        plt.plot(self.training_history["reward_loss"], label="Reward loss")
        plt.plot(self.training_history["policy_loss"], label="Policy loss")
        plt.legend()
        plt.savefig(f"{self.GRAPH_DIR}/loss.png")
        plt.close()

        smoothed_rewards = [self.training_history["reward"][0]]

        for i in range(1, len(self.training_history["reward"])):
            start = max(0, i - 6)
            mean_reward = sum(self.training_history["reward"][start : i + 1]) / (i - start + 1)
            smoothed_rewards.append(mean_reward)

        plt.figure(figsize=(12, 8))
        # plt.plot(self.training_history["reward"], label="Reward")
        plt.plot(smoothed_rewards, label="Smoothed reward")
        plt.legend()
        plt.savefig(f"{self.GRAPH_DIR}/reward.png")
        plt.close()

    def update_weights(self, batch):
        """
        Perform one training step.
        """
        (
            observation_batch,
            action_batch,
            target_value,
            target_reward,
            target_policy,
            gradient_scale_batch,
        ) = batch

        observation_batch = observation_batch.clone().detach()
        action_batch = torch.tensor(action_batch).to(self.device).long().unsqueeze(-1)
        target_value = torch.tensor(target_value).to(self.device).float()
        target_reward = torch.tensor(target_reward).to(self.device).float()
        target_policy = torch.tensor(target_policy).to(self.device).float()
        gradient_scale_batch = torch.tensor(gradient_scale_batch).to(self.device).float()

        target_value = scalar_to_support(target_value, self.config.support_size)
        target_reward = scalar_to_support(target_reward, self.config.support_size)

        ## Generate predictions
        value, reward, policy_logits, hidden_state = self.network.initial_inference(
            observation_batch
        )

        predictions = [(value, reward, policy_logits)]
        for i in range(1, action_batch.shape[1]):
            value, reward, policy_logits, hidden_state = self.network.recurrent_inference(
                hidden_state, action_batch[:, i]
            )

            # Scale the gradient at the start of the dynamics function (See paper appendix Training)

            # value = support_to_scalar(value, self.config.support_size).item()
            # reward = support_to_scalar(reward, self.config.support_size).item()

            hidden_state.register_hook(lambda grad: grad * 0.5)
            predictions.append((value, reward, policy_logits))

        ## Compute losses
        value_loss, reward_loss, policy_loss = (0, 0, 0)
        value, reward, policy_logits = predictions[0]
        # Ignore reward loss for the first batch step
        current_value_loss, _, current_policy_loss = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, 0],
            target_reward[:, 0],
            target_policy[:, 0],
        )
        value_loss += current_value_loss
        policy_loss += current_policy_loss

        for i in range(1, len(predictions)):
            value, reward, policy_logits = predictions[i]
            # print(f"Value: {value.squeeze(-1).shape}")
            # print(f"Reward: {reward.squeeze(-1).shape}")
            # print(f"Policy logits: {policy_logits.shape}")
            # print(f"Target value: {target_value[:, i].shape}")
            # print(f"Target reward: {target_reward[:, i].shape}")
            # print(f"Target policy: {target_policy[:, i].shape}")

            # print(support_to_scalar(value.squeeze(-1)[0], 41))
            # print(support_to_scalar(target_value[:, i][0], 41))

            # raise Exception

            (
                current_value_loss,
                current_reward_loss,
                current_policy_loss,
            ) = self.loss_function(
                value.squeeze(-1),
                reward.squeeze(-1),
                policy_logits,
                target_value[:, i],
                target_reward[:, i],
                target_policy[:, i],
            )

            # Scale gradient by the number of unroll steps (See paper appendix Training)
            current_value_loss.register_hook(lambda grad: grad / gradient_scale_batch[:, i])
            current_reward_loss.register_hook(lambda grad: grad / gradient_scale_batch[:, i])
            current_policy_loss.register_hook(lambda grad: grad / gradient_scale_batch[:, i])

            value_loss += current_value_loss
            reward_loss += current_reward_loss
            policy_loss += current_policy_loss

        # value_loss *= 0
        # reward_loss *= 0
        # policy_loss *= 1

        loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss

        # Mean over batch dimension (pseudocode do a sum)
        loss = loss.mean()

        # Optimize
        self.network.optimizer.zero_grad()
        loss.backward()
        self.network.optimizer.step()
        self.training_step += 1

        return (
            # For log purpose
            loss.item(),
            value_loss.mean().item(),
            reward_loss.mean().item(),
            policy_loss.mean().item(),
        )

    @staticmethod
    def loss_function(
        value,
        reward,
        policy_logits,
        target_value,
        target_reward,
        target_policy,
    ):
        value_loss = torch.sum(-target_value * torch.nn.LogSoftmax(dim=1)(value), dim=1)
        reward_loss = torch.sum(-target_reward * torch.nn.LogSoftmax(dim=1)(reward), dim=1)
        policy_loss = torch.sum(-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits), dim=1)

        return value_loss, reward_loss, policy_loss

# Part 2

## Run

In [11]:
muzero = MuZero()

muzero.train()



Testing episode 0...
Saving video to data/anims//2024-10-18_15-47-50/ep_0_test_0.mp4




Total reward: -237.97




Saving video to data/anims//2024-10-18_15-47-50/ep_0_test_1.mp4
Total reward: -88.16




Saving video to data/anims//2024-10-18_15-47-50/ep_0_test_2.mp4
Total reward: -175.52


Training episode 0...
Total reward: -173.39 | Steps: 100 | Time: 1.63s

Loss: 62.6021 | Value loss: 41.9713 | Reward loss: 37.0372 | Policy loss: 15.0721


Training episode 1...
Total reward: -120.59 | Steps: 69 | Time: 1.17s

Loss: 61.7332 | Value loss: 41.2988 | Reward loss: 36.3982 | Policy loss: 15.0102


Training episode 2...
Total reward: -70.28 | Steps: 77 | Time: 1.29s

Loss: 61.1594 | Value loss: 41.0437 | Reward loss: 35.8716 | Policy loss: 15.0269


Training episode 3...
Total reward: -100.46 | Steps: 111 | Time: 1.78s

Loss: 60.8251 | Value loss: 40.7979 | Reward loss: 35.5829 | Policy loss: 15.0427


Training episode 4...
Total reward: -271.31 | Steps: 102 | Time: 1.80s

Loss: 60.0985 | Value loss: 40.6188 | Reward loss: 34.8952 | Policy loss: 15.0487


Training episode 5...
Total reward: -348.86 | Steps: 127 | Time: 2.07s

Loss: 59.8299 | Value loss: 40.6390 | Reward loss: 34.6206 | Po



Saving video to data/anims//2024-10-18_15-47-50/ep_20_test_0.mp4
Total reward: -118.85




Saving video to data/anims//2024-10-18_15-47-50/ep_20_test_1.mp4
Total reward: -117.07




Saving video to data/anims//2024-10-18_15-47-50/ep_20_test_2.mp4
Total reward: -128.16


Training episode 20...
Total reward: -109.31 | Steps: 72 | Time: 1.46s

Loss: 57.6775 | Value loss: 39.5537 | Reward loss: 32.5753 | Policy loss: 15.2138


Training episode 21...
Total reward: -129.51 | Steps: 98 | Time: 1.87s

Loss: 57.5589 | Value loss: 39.4046 | Reward loss: 32.4831 | Policy loss: 15.2246


Training episode 22...
Total reward: -126.87 | Steps: 94 | Time: 1.69s

Loss: 57.6260 | Value loss: 39.4672 | Reward loss: 32.5347 | Policy loss: 15.2245


Training episode 23...
Total reward: -239.86 | Steps: 111 | Time: 1.74s

Loss: 57.5945 | Value loss: 39.4773 | Reward loss: 32.5172 | Policy loss: 15.2080


Training episode 24...
Total reward: -88.71 | Steps: 123 | Time: 1.90s

Loss: 57.6023 | Value loss: 39.5801 | Reward loss: 32.4879 | Policy loss: 15.2193


Training episode 25...
Total reward: -335.42 | Steps: 113 | Time: 1.92s

Loss: 57.4886 | Value loss: 39.4039 | Reward loss: 32.413



Saving video to data/anims//2024-10-18_15-47-50/ep_40_test_0.mp4
Total reward: -88.62




Saving video to data/anims//2024-10-18_15-47-50/ep_40_test_1.mp4
Total reward: -98.83




Saving video to data/anims//2024-10-18_15-47-50/ep_40_test_2.mp4
Total reward: -69.99


Training episode 40...
Total reward: -79.51 | Steps: 118 | Time: 1.91s

Loss: 58.7808 | Value loss: 39.9245 | Reward loss: 33.5639 | Policy loss: 15.2358


Training episode 41...
Total reward: -98.70 | Steps: 82 | Time: 1.23s

Loss: 58.7421 | Value loss: 39.8721 | Reward loss: 33.5333 | Policy loss: 15.2408


Training episode 42...
Total reward: -163.96 | Steps: 101 | Time: 1.53s

Loss: 58.9585 | Value loss: 40.0320 | Reward loss: 33.7130 | Policy loss: 15.2375


Training episode 43...
Total reward: -196.02 | Steps: 112 | Time: 1.72s

Loss: 58.9019 | Value loss: 39.8366 | Reward loss: 33.7081 | Policy loss: 15.2346


Training episode 44...
Total reward: -127.52 | Steps: 107 | Time: 1.61s

Loss: 58.9680 | Value loss: 39.9619 | Reward loss: 33.7362 | Policy loss: 15.2413


Training episode 45...
Total reward: -66.62 | Steps: 101 | Time: 1.53s

Loss: 59.1329 | Value loss: 40.0311 | Reward loss: 33.8829



Saving video to data/anims//2024-10-18_15-47-50/ep_60_test_0.mp4
Total reward: -104.32




Saving video to data/anims//2024-10-18_15-47-50/ep_60_test_1.mp4
Total reward: -119.89




Saving video to data/anims//2024-10-18_15-47-50/ep_60_test_2.mp4
Total reward: -109.21


Training episode 60...
Total reward: -155.75 | Steps: 125 | Time: 2.10s

Loss: 59.7457 | Value loss: 40.3251 | Reward loss: 34.4222 | Policy loss: 15.2422


Training episode 61...
Total reward: -81.95 | Steps: 83 | Time: 1.32s

Loss: 59.7001 | Value loss: 40.3372 | Reward loss: 34.3699 | Policy loss: 15.2459


Training episode 62...
Total reward: -173.16 | Steps: 101 | Time: 1.54s

Loss: 59.8099 | Value loss: 40.3944 | Reward loss: 34.4682 | Policy loss: 15.2431


Training episode 63...
Total reward: -85.35 | Steps: 67 | Time: 1.05s

Loss: 59.9399 | Value loss: 40.4183 | Reward loss: 34.5918 | Policy loss: 15.2436


Training episode 64...
Total reward: -202.29 | Steps: 115 | Time: 1.76s

Loss: 59.7089 | Value loss: 40.3275 | Reward loss: 34.3832 | Policy loss: 15.2438


Training episode 65...
Total reward: -83.78 | Steps: 71 | Time: 1.10s

Loss: 59.7955 | Value loss: 40.3534 | Reward loss: 34.4638 



Saving video to data/anims//2024-10-18_15-47-50/ep_80_test_0.mp4
Total reward: -118.99




Saving video to data/anims//2024-10-18_15-47-50/ep_80_test_1.mp4
Total reward: -257.45




Saving video to data/anims//2024-10-18_15-47-50/ep_80_test_2.mp4
Total reward: -121.00


Training episode 80...
Total reward: -99.02 | Steps: 81 | Time: 1.25s

Loss: 59.8704 | Value loss: 40.4750 | Reward loss: 34.5056 | Policy loss: 15.2461


Training episode 81...
Total reward: -126.63 | Steps: 60 | Time: 0.97s

Loss: 59.6951 | Value loss: 40.4500 | Reward loss: 34.3361 | Policy loss: 15.2464


Training episode 82...
Total reward: -358.91 | Steps: 101 | Time: 1.54s

Loss: 59.7570 | Value loss: 40.4317 | Reward loss: 34.4044 | Policy loss: 15.2447


Training episode 83...
Total reward: -142.85 | Steps: 84 | Time: 1.27s

Loss: 59.6721 | Value loss: 40.4722 | Reward loss: 34.3086 | Policy loss: 15.2454


Training episode 84...
Total reward: -115.01 | Steps: 79 | Time: 1.20s

Loss: 59.7683 | Value loss: 40.4598 | Reward loss: 34.4083 | Policy loss: 15.2450


Training episode 85...
Total reward: -119.67 | Steps: 65 | Time: 0.98s

Loss: 59.6795 | Value loss: 40.4884 | Reward loss: 34.3153 



Saving video to data/anims//2024-10-18_15-47-50/ep_100_test_0.mp4
Total reward: -187.86




Saving video to data/anims//2024-10-18_15-47-50/ep_100_test_1.mp4
Total reward: -140.97




Saving video to data/anims//2024-10-18_15-47-50/ep_100_test_2.mp4
Total reward: -119.28


Training episode 100...
Total reward: -57.86 | Steps: 67 | Time: 1.06s

Loss: 59.5965 | Value loss: 40.5066 | Reward loss: 34.2223 | Policy loss: 15.2476


Training episode 101...
Total reward: -211.61 | Steps: 84 | Time: 1.26s

Loss: 59.6736 | Value loss: 40.4879 | Reward loss: 34.3043 | Policy loss: 15.2473


Training episode 102...
Total reward: -84.38 | Steps: 63 | Time: 0.97s

Loss: 59.5948 | Value loss: 40.5212 | Reward loss: 34.2160 | Policy loss: 15.2485


Training episode 103...
Total reward: -140.29 | Steps: 64 | Time: 0.97s

Loss: 59.5817 | Value loss: 40.4889 | Reward loss: 34.2118 | Policy loss: 15.2477


Training episode 104...
Total reward: -194.52 | Steps: 74 | Time: 1.16s

Loss: 59.5234 | Value loss: 40.4665 | Reward loss: 34.1590 | Policy loss: 15.2478


Training episode 105...
Total reward: -137.84 | Steps: 72 | Time: 1.09s

Loss: 59.5344 | Value loss: 40.5083 | Reward loss: 34.



Saving video to data/anims//2024-10-18_15-47-50/ep_120_test_0.mp4
Total reward: -142.51




Saving video to data/anims//2024-10-18_15-47-50/ep_120_test_1.mp4
Total reward: -110.01




Saving video to data/anims//2024-10-18_15-47-50/ep_120_test_2.mp4
Total reward: -430.11


Training episode 120...
Total reward: -96.03 | Steps: 97 | Time: 1.52s

Loss: 59.3902 | Value loss: 40.4909 | Reward loss: 34.0199 | Policy loss: 15.2476


Training episode 121...
Total reward: -452.26 | Steps: 92 | Time: 1.42s

Loss: 59.4492 | Value loss: 40.4925 | Reward loss: 34.0786 | Policy loss: 15.2475


Training episode 122...
Total reward: -103.11 | Steps: 115 | Time: 1.75s

Loss: 59.4082 | Value loss: 40.5235 | Reward loss: 34.0298 | Policy loss: 15.2475


Training episode 123...
Total reward: -110.66 | Steps: 67 | Time: 1.01s

Loss: 59.5048 | Value loss: 40.5329 | Reward loss: 34.1227 | Policy loss: 15.2489


Training episode 124...
Total reward: -143.66 | Steps: 114 | Time: 1.72s

Loss: 59.4207 | Value loss: 40.4892 | Reward loss: 34.0500 | Policy loss: 15.2485


Training episode 125...
Total reward: -7.98 | Steps: 92 | Time: 1.38s

Loss: 59.4201 | Value loss: 40.5499 | Reward loss: 34



Saving video to data/anims//2024-10-18_15-47-50/ep_140_test_0.mp4
Total reward: -158.83




Saving video to data/anims//2024-10-18_15-47-50/ep_140_test_1.mp4
Total reward: -124.36




Saving video to data/anims//2024-10-18_15-47-50/ep_140_test_2.mp4
Total reward: 10.90


Training episode 140...
Total reward: -71.92 | Steps: 67 | Time: 1.03s

Loss: 59.3927 | Value loss: 40.4725 | Reward loss: 34.0258 | Policy loss: 15.2488


Training episode 141...
Total reward: -157.33 | Steps: 111 | Time: 1.69s

Loss: 59.3778 | Value loss: 40.5448 | Reward loss: 33.9928 | Policy loss: 15.2489


Training episode 142...
Total reward: -462.85 | Steps: 120 | Time: 1.87s

Loss: 59.3287 | Value loss: 40.5309 | Reward loss: 33.9469 | Policy loss: 15.2491


Training episode 143...
Total reward: -426.63 | Steps: 94 | Time: 1.47s

Loss: 59.3082 | Value loss: 40.4791 | Reward loss: 33.9393 | Policy loss: 15.2491


Training episode 144...
Total reward: -95.07 | Steps: 80 | Time: 1.22s

Loss: 59.3268 | Value loss: 40.4844 | Reward loss: 33.9567 | Policy loss: 15.2490


Training episode 145...
Total reward: -310.91 | Steps: 91 | Time: 1.38s

Loss: 59.4201 | Value loss: 40.5094 | Reward loss: 34.



Saving video to data/anims//2024-10-18_15-47-50/ep_160_test_0.mp4
Total reward: -191.59




Saving video to data/anims//2024-10-18_15-47-50/ep_160_test_1.mp4
Total reward: -139.20




Saving video to data/anims//2024-10-18_15-47-50/ep_160_test_2.mp4
Total reward: -397.65


Training episode 160...
Total reward: -170.85 | Steps: 102 | Time: 1.63s

Loss: 59.3388 | Value loss: 40.5445 | Reward loss: 33.9537 | Policy loss: 15.2490


Training episode 161...
Total reward: -107.20 | Steps: 84 | Time: 1.29s

Loss: 59.0744 | Value loss: 40.4957 | Reward loss: 33.7013 | Policy loss: 15.2491


Training episode 162...
Total reward: -149.27 | Steps: 80 | Time: 1.22s

Loss: 59.3374 | Value loss: 40.5575 | Reward loss: 33.9490 | Policy loss: 15.2491


Training episode 163...
Total reward: -314.48 | Steps: 100 | Time: 1.52s

Loss: 59.3543 | Value loss: 40.5334 | Reward loss: 33.9717 | Policy loss: 15.2492


Training episode 164...
Total reward: -115.79 | Steps: 104 | Time: 1.59s

Loss: 59.2348 | Value loss: 40.5149 | Reward loss: 33.8567 | Policy loss: 15.2493


Training episode 165...
Total reward: -432.60 | Steps: 112 | Time: 1.75s

Loss: 59.2485 | Value loss: 40.5759 | Reward los



Saving video to data/anims//2024-10-18_15-47-50/ep_180_test_0.mp4
Total reward: -207.62




Saving video to data/anims//2024-10-18_15-47-50/ep_180_test_1.mp4
Total reward: -111.42




Saving video to data/anims//2024-10-18_15-47-50/ep_180_test_2.mp4
Total reward: -114.87


Training episode 180...
Total reward: -126.67 | Steps: 75 | Time: 1.15s

Loss: 59.1231 | Value loss: 40.5094 | Reward loss: 33.7467 | Policy loss: 15.2491


Training episode 181...
Total reward: -266.07 | Steps: 83 | Time: 1.26s

Loss: 59.1848 | Value loss: 40.5168 | Reward loss: 33.8064 | Policy loss: 15.2492


Training episode 182...
Total reward: -117.26 | Steps: 98 | Time: 1.50s

Loss: 59.2135 | Value loss: 40.4992 | Reward loss: 33.8395 | Policy loss: 15.2492


Training episode 183...
Total reward: -145.21 | Steps: 127 | Time: 1.93s

Loss: 59.1813 | Value loss: 40.4796 | Reward loss: 33.8123 | Policy loss: 15.2490


Training episode 184...
Total reward: -73.22 | Steps: 130 | Time: 1.98s

Loss: 59.0986 | Value loss: 40.5087 | Reward loss: 33.7223 | Policy loss: 15.2491


Training episode 185...
Total reward: -83.22 | Steps: 92 | Time: 1.45s

Loss: 59.2319 | Value loss: 40.5197 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_200_test_0.mp4
Total reward: -136.85




Saving video to data/anims//2024-10-18_15-47-50/ep_200_test_1.mp4
Total reward: -98.08




Saving video to data/anims//2024-10-18_15-47-50/ep_200_test_2.mp4
Total reward: -93.11


Training episode 200...
Total reward: -40.69 | Steps: 91 | Time: 1.44s

Loss: 59.0636 | Value loss: 40.4802 | Reward loss: 33.6943 | Policy loss: 15.2492


Training episode 201...
Total reward: -67.25 | Steps: 58 | Time: 0.90s

Loss: 59.0602 | Value loss: 40.4817 | Reward loss: 33.6905 | Policy loss: 15.2492


Training episode 202...
Total reward: -94.19 | Steps: 94 | Time: 1.48s

Loss: 59.2909 | Value loss: 40.4937 | Reward loss: 33.9183 | Policy loss: 15.2492


Training episode 203...
Total reward: -103.84 | Steps: 78 | Time: 1.23s

Loss: 59.1326 | Value loss: 40.5173 | Reward loss: 33.7540 | Policy loss: 15.2492


Training episode 204...
Total reward: -91.55 | Steps: 66 | Time: 1.02s

Loss: 58.9521 | Value loss: 40.4874 | Reward loss: 33.5811 | Policy loss: 15.2492


Training episode 205...
Total reward: -151.78 | Steps: 92 | Time: 1.46s

Loss: 59.0544 | Value loss: 40.5012 | Reward loss: 33.679



Saving video to data/anims//2024-10-18_15-47-50/ep_220_test_0.mp4
Total reward: -67.23




Saving video to data/anims//2024-10-18_15-47-50/ep_220_test_1.mp4
Total reward: -291.58




Saving video to data/anims//2024-10-18_15-47-50/ep_220_test_2.mp4
Total reward: -300.97


Training episode 220...
Total reward: -312.75 | Steps: 116 | Time: 1.81s

Loss: 59.1079 | Value loss: 40.5052 | Reward loss: 33.7324 | Policy loss: 15.2491


Training episode 221...
Total reward: -75.58 | Steps: 80 | Time: 1.22s

Loss: 58.9691 | Value loss: 40.4516 | Reward loss: 33.6070 | Policy loss: 15.2492


Training episode 222...
Total reward: -108.27 | Steps: 58 | Time: 0.89s

Loss: 58.9033 | Value loss: 40.4414 | Reward loss: 33.5438 | Policy loss: 15.2492


Training episode 223...
Total reward: -116.28 | Steps: 125 | Time: 1.91s

Loss: 59.2256 | Value loss: 40.5241 | Reward loss: 33.8454 | Policy loss: 15.2492


Training episode 224...
Total reward: -49.84 | Steps: 103 | Time: 1.57s

Loss: 59.0320 | Value loss: 40.5191 | Reward loss: 33.6530 | Policy loss: 15.2493


Training episode 225...
Total reward: -405.71 | Steps: 127 | Time: 1.94s

Loss: 59.0697 | Value loss: 40.5389 | Reward loss:



Saving video to data/anims//2024-10-18_15-47-50/ep_240_test_0.mp4
Total reward: -236.21




Saving video to data/anims//2024-10-18_15-47-50/ep_240_test_1.mp4
Total reward: -368.64




Saving video to data/anims//2024-10-18_15-47-50/ep_240_test_2.mp4
Total reward: -362.04


Training episode 240...
Total reward: -79.38 | Steps: 94 | Time: 1.44s

Loss: 59.0900 | Value loss: 40.4827 | Reward loss: 33.7201 | Policy loss: 15.2491


Training episode 241...
Total reward: -227.47 | Steps: 113 | Time: 1.80s

Loss: 59.1035 | Value loss: 40.5017 | Reward loss: 33.7288 | Policy loss: 15.2493


Training episode 242...
Total reward: -134.91 | Steps: 105 | Time: 1.60s

Loss: 58.8390 | Value loss: 40.5051 | Reward loss: 33.4635 | Policy loss: 15.2492


Training episode 243...
Total reward: -114.70 | Steps: 94 | Time: 1.51s

Loss: 59.1899 | Value loss: 40.5246 | Reward loss: 33.8095 | Policy loss: 15.2492


Training episode 244...
Total reward: -336.69 | Steps: 89 | Time: 1.35s

Loss: 58.9298 | Value loss: 40.5041 | Reward loss: 33.5545 | Policy loss: 15.2493


Training episode 245...
Total reward: -21.06 | Steps: 111 | Time: 1.68s

Loss: 59.0407 | Value loss: 40.5064 | Reward loss: 



Saving video to data/anims//2024-10-18_15-47-50/ep_260_test_0.mp4
Total reward: -174.64




Saving video to data/anims//2024-10-18_15-47-50/ep_260_test_1.mp4
Total reward: -22.34




Saving video to data/anims//2024-10-18_15-47-50/ep_260_test_2.mp4
Total reward: -31.38


Training episode 260...
Total reward: -102.31 | Steps: 62 | Time: 0.96s

Loss: 58.9472 | Value loss: 40.4837 | Reward loss: 33.5770 | Policy loss: 15.2492


Training episode 261...
Total reward: -87.23 | Steps: 89 | Time: 1.36s

Loss: 59.0086 | Value loss: 40.4974 | Reward loss: 33.6350 | Policy loss: 15.2492


Training episode 262...
Total reward: -183.82 | Steps: 76 | Time: 1.16s

Loss: 58.8716 | Value loss: 40.4717 | Reward loss: 33.5044 | Policy loss: 15.2492


Training episode 263...
Total reward: -117.91 | Steps: 58 | Time: 0.90s

Loss: 58.9338 | Value loss: 40.4996 | Reward loss: 33.5596 | Policy loss: 15.2492


Training episode 264...
Total reward: -132.48 | Steps: 100 | Time: 1.58s

Loss: 59.0807 | Value loss: 40.5230 | Reward loss: 33.7007 | Policy loss: 15.2492


Training episode 265...
Total reward: -124.64 | Steps: 111 | Time: 1.70s

Loss: 58.9848 | Value loss: 40.4724 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_280_test_0.mp4
Total reward: -537.52




Saving video to data/anims//2024-10-18_15-47-50/ep_280_test_1.mp4
Total reward: -416.99




Saving video to data/anims//2024-10-18_15-47-50/ep_280_test_2.mp4
Total reward: -255.74


Training episode 280...
Total reward: -213.79 | Steps: 106 | Time: 1.66s

Loss: 58.8886 | Value loss: 40.4835 | Reward loss: 33.5185 | Policy loss: 15.2492


Training episode 281...
Total reward: -113.62 | Steps: 64 | Time: 1.16s

Loss: 59.0747 | Value loss: 40.5277 | Reward loss: 33.6936 | Policy loss: 15.2492


Training episode 282...
Total reward: -138.23 | Steps: 119 | Time: 1.81s

Loss: 58.9561 | Value loss: 40.4993 | Reward loss: 33.5820 | Policy loss: 15.2492


Training episode 283...
Total reward: -196.57 | Steps: 82 | Time: 1.29s

Loss: 59.1226 | Value loss: 40.5355 | Reward loss: 33.7395 | Policy loss: 15.2492


Training episode 284...
Total reward: -112.45 | Steps: 62 | Time: 0.96s

Loss: 58.9808 | Value loss: 40.5306 | Reward loss: 33.5989 | Policy loss: 15.2492


Training episode 285...
Total reward: -153.70 | Steps: 87 | Time: 1.34s

Loss: 59.0494 | Value loss: 40.4822 | Reward loss:



Saving video to data/anims//2024-10-18_15-47-50/ep_300_test_0.mp4
Total reward: -60.33




Saving video to data/anims//2024-10-18_15-47-50/ep_300_test_1.mp4
Total reward: -175.52




Saving video to data/anims//2024-10-18_15-47-50/ep_300_test_2.mp4
Total reward: -64.49


Training episode 300...
Total reward: -252.97 | Steps: 116 | Time: 1.78s

Loss: 58.8672 | Value loss: 40.4710 | Reward loss: 33.5002 | Policy loss: 15.2493


Training episode 301...
Total reward: -224.67 | Steps: 81 | Time: 1.23s

Loss: 58.8288 | Value loss: 40.4318 | Reward loss: 33.4717 | Policy loss: 15.2492


Training episode 302...
Total reward: -88.59 | Steps: 78 | Time: 1.23s

Loss: 59.0735 | Value loss: 40.5176 | Reward loss: 33.6950 | Policy loss: 15.2492


Training episode 303...
Total reward: -232.59 | Steps: 113 | Time: 1.72s

Loss: 58.9194 | Value loss: 40.4727 | Reward loss: 33.5520 | Policy loss: 15.2492


Training episode 304...
Total reward: -320.68 | Steps: 95 | Time: 1.45s

Loss: 58.9913 | Value loss: 40.4756 | Reward loss: 33.6231 | Policy loss: 15.2492


Training episode 305...
Total reward: 109.77 | Steps: 400 | Time: 6.06s

Loss: 58.9183 | Value loss: 40.4472 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_320_test_0.mp4
Total reward: -363.01




Saving video to data/anims//2024-10-18_15-47-50/ep_320_test_1.mp4
Total reward: -181.82




Saving video to data/anims//2024-10-18_15-47-50/ep_320_test_2.mp4
Total reward: -262.62


Training episode 320...
Total reward: -67.65 | Steps: 81 | Time: 1.31s

Loss: 58.9488 | Value loss: 40.5099 | Reward loss: 33.5721 | Policy loss: 15.2493


Training episode 321...
Total reward: -92.60 | Steps: 61 | Time: 0.97s

Loss: 58.9384 | Value loss: 40.4949 | Reward loss: 33.5654 | Policy loss: 15.2493


Training episode 322...
Total reward: -98.08 | Steps: 85 | Time: 1.34s

Loss: 59.0166 | Value loss: 40.5108 | Reward loss: 33.6396 | Policy loss: 15.2493


Training episode 323...
Total reward: -83.93 | Steps: 116 | Time: 1.83s

Loss: 59.0276 | Value loss: 40.5362 | Reward loss: 33.6444 | Policy loss: 15.2492


Training episode 324...
Total reward: -310.78 | Steps: 95 | Time: 1.55s

Loss: 58.9314 | Value loss: 40.4764 | Reward loss: 33.5631 | Policy loss: 15.2492


Training episode 325...
Total reward: -85.45 | Steps: 127 | Time: 2.80s

Loss: 59.0438 | Value loss: 40.5426 | Reward loss: 33.6



Saving video to data/anims//2024-10-18_15-47-50/ep_340_test_0.mp4
Total reward: -337.18




Saving video to data/anims//2024-10-18_15-47-50/ep_340_test_1.mp4
Total reward: -47.79




Saving video to data/anims//2024-10-18_15-47-50/ep_340_test_2.mp4
Total reward: -385.72


Training episode 340...
Total reward: -86.62 | Steps: 72 | Time: 1.14s

Loss: 58.9369 | Value loss: 40.4386 | Reward loss: 33.5780 | Policy loss: 15.2492


Training episode 341...
Total reward: -125.11 | Steps: 88 | Time: 1.36s

Loss: 58.9368 | Value loss: 40.4692 | Reward loss: 33.5702 | Policy loss: 15.2493


Training episode 342...
Total reward: -327.13 | Steps: 116 | Time: 1.84s

Loss: 58.8965 | Value loss: 40.4679 | Reward loss: 33.5303 | Policy loss: 15.2492


Training episode 343...
Total reward: -156.92 | Steps: 104 | Time: 1.67s

Loss: 58.8862 | Value loss: 40.4679 | Reward loss: 33.5198 | Policy loss: 15.2494


Training episode 344...
Total reward: -132.48 | Steps: 92 | Time: 1.90s

Loss: 59.0802 | Value loss: 40.4577 | Reward loss: 33.7164 | Policy loss: 15.2494


Training episode 345...
Total reward: -68.32 | Steps: 71 | Time: 1.27s

Loss: 58.9342 | Value loss: 40.4422 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_360_test_0.mp4
Total reward: -69.35




Saving video to data/anims//2024-10-18_15-47-50/ep_360_test_1.mp4
Total reward: -461.84




Saving video to data/anims//2024-10-18_15-47-50/ep_360_test_2.mp4
Total reward: -45.22


Training episode 360...
Total reward: -111.00 | Steps: 94 | Time: 1.44s

Loss: 59.0295 | Value loss: 40.4935 | Reward loss: 33.6569 | Policy loss: 15.2492


Training episode 361...
Total reward: -363.90 | Steps: 115 | Time: 1.74s

Loss: 59.0768 | Value loss: 40.4966 | Reward loss: 33.7034 | Policy loss: 15.2493


Training episode 362...
Total reward: -294.08 | Steps: 96 | Time: 1.52s

Loss: 58.8976 | Value loss: 40.4856 | Reward loss: 33.5268 | Policy loss: 15.2493


Training episode 363...
Total reward: -126.92 | Steps: 73 | Time: 1.12s

Loss: 58.9950 | Value loss: 40.4894 | Reward loss: 33.6234 | Policy loss: 15.2492


Training episode 364...
Total reward: -160.48 | Steps: 71 | Time: 1.10s

Loss: 58.8237 | Value loss: 40.4581 | Reward loss: 33.4599 | Policy loss: 15.2492


Training episode 365...
Total reward: -99.63 | Steps: 68 | Time: 1.04s

Loss: 58.9896 | Value loss: 40.4681 | Reward loss: 33



Saving video to data/anims//2024-10-18_15-47-50/ep_380_test_0.mp4
Total reward: -464.62




Saving video to data/anims//2024-10-18_15-47-50/ep_380_test_1.mp4
Total reward: -307.77




Saving video to data/anims//2024-10-18_15-47-50/ep_380_test_2.mp4
Total reward: -489.06


Training episode 380...
Total reward: -130.17 | Steps: 99 | Time: 1.52s

Loss: 58.8854 | Value loss: 40.4969 | Reward loss: 33.5119 | Policy loss: 15.2493


Training episode 381...
Total reward: -117.84 | Steps: 70 | Time: 1.07s

Loss: 58.7895 | Value loss: 40.4347 | Reward loss: 33.4317 | Policy loss: 15.2491


Training episode 382...
Total reward: 2.95 | Steps: 100 | Time: 1.61s

Loss: 58.9760 | Value loss: 40.4669 | Reward loss: 33.6100 | Policy loss: 15.2493


Training episode 383...
Total reward: -130.52 | Steps: 108 | Time: 1.68s

Loss: 58.9336 | Value loss: 40.4728 | Reward loss: 33.5661 | Policy loss: 15.2492


Training episode 384...
Total reward: -77.93 | Steps: 84 | Time: 1.33s

Loss: 58.9328 | Value loss: 40.4343 | Reward loss: 33.5751 | Policy loss: 15.2491


Training episode 385...
Total reward: -315.48 | Steps: 97 | Time: 1.80s

Loss: 58.9499 | Value loss: 40.5080 | Reward loss: 33.



Saving video to data/anims//2024-10-18_15-47-50/ep_400_test_0.mp4
Total reward: -115.89




Saving video to data/anims//2024-10-18_15-47-50/ep_400_test_1.mp4
Total reward: -363.32




Saving video to data/anims//2024-10-18_15-47-50/ep_400_test_2.mp4
Total reward: -280.24


Training episode 400...
Total reward: -131.21 | Steps: 91 | Time: 1.47s

Loss: 59.0354 | Value loss: 40.5429 | Reward loss: 33.6504 | Policy loss: 15.2493


Training episode 401...
Total reward: -243.84 | Steps: 100 | Time: 1.53s

Loss: 58.9456 | Value loss: 40.4861 | Reward loss: 33.5748 | Policy loss: 15.2493


Training episode 402...
Total reward: -288.64 | Steps: 83 | Time: 1.28s

Loss: 58.9502 | Value loss: 40.5228 | Reward loss: 33.5703 | Policy loss: 15.2492


Training episode 403...
Total reward: -13.59 | Steps: 131 | Time: 2.04s

Loss: 58.9700 | Value loss: 40.5231 | Reward loss: 33.5900 | Policy loss: 15.2493


Training episode 404...
Total reward: -97.65 | Steps: 77 | Time: 1.20s

Loss: 59.1201 | Value loss: 40.5553 | Reward loss: 33.7320 | Policy loss: 15.2493


Training episode 405...
Total reward: -109.21 | Steps: 91 | Time: 1.38s

Loss: 59.1106 | Value loss: 40.5279 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_420_test_0.mp4
Total reward: -173.61




Saving video to data/anims//2024-10-18_15-47-50/ep_420_test_1.mp4
Total reward: -365.93




Saving video to data/anims//2024-10-18_15-47-50/ep_420_test_2.mp4
Total reward: -94.28


Training episode 420...
Total reward: -12.83 | Steps: 144 | Time: 2.20s

Loss: 58.8047 | Value loss: 40.5193 | Reward loss: 33.4256 | Policy loss: 15.2493


Training episode 421...
Total reward: -277.64 | Steps: 90 | Time: 1.40s

Loss: 59.0278 | Value loss: 40.5466 | Reward loss: 33.6419 | Policy loss: 15.2492


Training episode 422...
Total reward: -214.69 | Steps: 119 | Time: 1.86s

Loss: 58.9264 | Value loss: 40.5069 | Reward loss: 33.5505 | Policy loss: 15.2492


Training episode 423...
Total reward: -95.96 | Steps: 93 | Time: 1.42s

Loss: 58.8788 | Value loss: 40.5280 | Reward loss: 33.4975 | Policy loss: 15.2493


Training episode 424...
Total reward: -124.03 | Steps: 84 | Time: 1.28s

Loss: 58.8262 | Value loss: 40.4505 | Reward loss: 33.4643 | Policy loss: 15.2493


Training episode 425...
Total reward: -306.39 | Steps: 113 | Time: 1.74s

Loss: 58.9464 | Value loss: 40.4989 | Reward loss: 3



Saving video to data/anims//2024-10-18_15-47-50/ep_440_test_0.mp4
Total reward: -101.79




Saving video to data/anims//2024-10-18_15-47-50/ep_440_test_1.mp4
Total reward: -178.77




Saving video to data/anims//2024-10-18_15-47-50/ep_440_test_2.mp4
Total reward: -96.99


Training episode 440...
Total reward: -185.00 | Steps: 67 | Time: 1.06s

Loss: 58.9239 | Value loss: 40.5195 | Reward loss: 33.5448 | Policy loss: 15.2492


Training episode 441...
Total reward: -98.60 | Steps: 93 | Time: 1.45s

Loss: 58.9568 | Value loss: 40.5003 | Reward loss: 33.5824 | Policy loss: 15.2493


Training episode 442...
Total reward: -264.05 | Steps: 130 | Time: 2.05s

Loss: 58.9481 | Value loss: 40.5214 | Reward loss: 33.5684 | Policy loss: 15.2493


Training episode 443...
Total reward: -195.28 | Steps: 109 | Time: 1.71s

Loss: 58.8811 | Value loss: 40.5205 | Reward loss: 33.5017 | Policy loss: 15.2493


Training episode 444...
Total reward: -92.73 | Steps: 83 | Time: 1.27s

Loss: 58.9625 | Value loss: 40.5078 | Reward loss: 33.5863 | Policy loss: 15.2492


Training episode 445...
Total reward: -151.02 | Steps: 93 | Time: 1.41s

Loss: 58.9750 | Value loss: 40.4941 | Reward loss: 33



Saving video to data/anims//2024-10-18_15-47-50/ep_460_test_0.mp4
Total reward: -8.10




Saving video to data/anims//2024-10-18_15-47-50/ep_460_test_1.mp4
Total reward: -236.82




Saving video to data/anims//2024-10-18_15-47-50/ep_460_test_2.mp4
Total reward: -348.48


Training episode 460...
Total reward: -118.92 | Steps: 97 | Time: 1.49s

Loss: 59.1393 | Value loss: 40.5327 | Reward loss: 33.7569 | Policy loss: 15.2492


Training episode 461...
Total reward: -226.63 | Steps: 126 | Time: 2.00s

Loss: 58.9170 | Value loss: 40.4833 | Reward loss: 33.5469 | Policy loss: 15.2493


Training episode 462...
Total reward: -121.08 | Steps: 68 | Time: 1.11s

Loss: 58.9555 | Value loss: 40.5331 | Reward loss: 33.5729 | Policy loss: 15.2493


Training episode 463...
Total reward: -127.89 | Steps: 73 | Time: 1.17s

Loss: 58.9822 | Value loss: 40.4822 | Reward loss: 33.6124 | Policy loss: 15.2493


Training episode 464...
Total reward: -440.72 | Steps: 88 | Time: 1.43s

Loss: 59.0256 | Value loss: 40.5075 | Reward loss: 33.6496 | Policy loss: 15.2492


Training episode 465...
Total reward: -116.40 | Steps: 75 | Time: 1.23s

Loss: 58.9427 | Value loss: 40.4613 | Reward loss: 



Saving video to data/anims//2024-10-18_15-47-50/ep_480_test_0.mp4
Total reward: -557.33




Saving video to data/anims//2024-10-18_15-47-50/ep_480_test_1.mp4
Total reward: -295.30




Saving video to data/anims//2024-10-18_15-47-50/ep_480_test_2.mp4
Total reward: -220.19


Training episode 480...
Total reward: -90.50 | Steps: 72 | Time: 1.17s

Loss: 59.0031 | Value loss: 40.5166 | Reward loss: 33.6247 | Policy loss: 15.2493


Training episode 481...
Total reward: -106.03 | Steps: 78 | Time: 1.27s

Loss: 59.0291 | Value loss: 40.5443 | Reward loss: 33.6439 | Policy loss: 15.2492


Training episode 482...
Total reward: -265.96 | Steps: 98 | Time: 1.54s

Loss: 58.9614 | Value loss: 40.5157 | Reward loss: 33.5831 | Policy loss: 15.2493


Training episode 483...
Total reward: -88.85 | Steps: 63 | Time: 0.98s

Loss: 59.1285 | Value loss: 40.5699 | Reward loss: 33.7367 | Policy loss: 15.2493


Training episode 484...
Total reward: -443.49 | Steps: 107 | Time: 1.75s

Loss: 58.9769 | Value loss: 40.4751 | Reward loss: 33.6088 | Policy loss: 15.2493


Training episode 485...
Total reward: -90.28 | Steps: 63 | Time: 0.99s

Loss: 59.0192 | Value loss: 40.5275 | Reward loss: 33.



Saving video to data/anims//2024-10-18_15-47-50/ep_500_test_0.mp4
Total reward: -153.76




Saving video to data/anims//2024-10-18_15-47-50/ep_500_test_1.mp4
Total reward: -143.65




Saving video to data/anims//2024-10-18_15-47-50/ep_500_test_2.mp4
Total reward: -302.95


Training episode 500...
Total reward: -147.98 | Steps: 115 | Time: 1.90s

Loss: 58.9543 | Value loss: 40.5315 | Reward loss: 33.5721 | Policy loss: 15.2493


Training episode 501...
Total reward: -102.02 | Steps: 76 | Time: 1.27s

Loss: 58.9086 | Value loss: 40.5433 | Reward loss: 33.5235 | Policy loss: 15.2492


Training episode 502...
Total reward: -113.02 | Steps: 85 | Time: 1.30s

Loss: 58.9777 | Value loss: 40.5433 | Reward loss: 33.5927 | Policy loss: 15.2492


Training episode 503...
Total reward: -223.39 | Steps: 72 | Time: 1.16s

Loss: 58.9792 | Value loss: 40.5307 | Reward loss: 33.5972 | Policy loss: 15.2493


Training episode 504...


KeyboardInterrupt: 