<a href="https://colab.research.google.com/github/lorenzosteccanella/HRL-MDP/blob/main/Example_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A simple example of how to learn a representation in the Nine Rooms environment

In [None]:
# Only required when running this notebook on Google Colab.
# Clones the project, installs its requirements and makes the repository the working directory.

!git clone https://github.com/lorenzosteccanella/HRL-MDP.git
# `!` commands run in a subshell, so this `cd` only applies to the `ls` and `pip install` chained on the same line.
!cd HRL-MDP && ls && pip install -r requirements.txt
# `%cd` changes the working directory of the notebook itself; presumably needed so that the
# `utils` / `model` imports in the next cell resolve — verify that they live at the repo root.
%cd HRL-MDP

In [None]:
# Suppress old gym deprecation warnings.
# The filters are registered before the remaining imports of this cell, so they are
# already active while those modules are being loaded.
import warnings
# Ignore every warning of category DeprecationWarning.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# Ignore warnings whose message starts with this text (the pattern is a regex matched at the start of the message).
warnings.filterwarnings('ignore', message='Parameters to load are deprecated.*')

import numpy as np
from utils import collect_trajectories, representation_score, wandb_plot
from model import SoftClusterNetwork
from torch import optim
import torch
import matplotlib.pyplot as plt

In [None]:
# Experiment settings shared by the data-collection and training cells.
config = dict(
    # Seed handed to torch and numpy before the network is created.
    seed=0,
    # Environment id and data-collection options. These are read by the project
    # helpers that receive `config`; their exact semantics are not visible here.
    env="MiniGrid-NineRoomsDet-v0",
    load_data=False,
    p_random_action=0,
    max_len_episode=100,
    pos_or_image="image",
    n_episodes_env=1000,
    # Positional constructor arguments of SoftClusterNetwork.
    n_abstract_states=9,
    width=19,
    height=19,
    # Optimisation: Adam learning rate, number of training iterations and
    # number of samples requested from the memory per iteration.
    lr=1e-4,
    epochs=2000,
    batch_size=32,
    # Weights of the three loss terms combined in the training loop
    # (wl1: compression, wl2: entropy, wl3: per-sample entropy).
    wl1=1,
    wl2=0.4,
    wl3=0.1,
)

In [None]:
# First, let's collect some random trajectory data.
# `collect_trajectories` is a project helper imported from `utils`; it receives only `config`.
# Of the four returned values, `memory` is what the training loop samples its batches from, and
# `print_states` / `annotations` are forwarded to `wandb_plot` for the diagnostic figure.
# `trajectories_dataset` is not used again in the cells visible here.
memory, trajectories_dataset, print_states, annotations = collect_trajectories(config)


In [None]:
# Seed numpy and torch from the shared config.
# NOTE(review): this cell runs after the data-collection cell, so these calls only
# govern the randomness that follows them — confirm whether collect_trajectories
# seeds itself from config["seed"].
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])

# Build the soft-clustering network and the Adam optimizer that trains it.
network = SoftClusterNetwork(
    config["n_abstract_states"],
    config["width"],
    config["height"],
)
optimizer = torch.optim.Adam(network.parameters(), lr=config["lr"])


In [None]:
# Training loop: one sampled batch and one optimizer step per iteration.
# Every 100 iterations the current loss terms, the representation score and two
# diagnostic figures (cluster plot, smoothed loss curve) are shown.
losses = []  # total loss of every iteration, in order
for i in range(config["epochs"]):
    network.train()
    # `idx` and `b_is_weights` are returned by the memory but not used below.
    idx, batch_x1, batch_x2, b_is_weights = memory.sample(config["batch_size"])
    x1 = torch.stack(batch_x1)
    x2 = torch.stack(batch_x2)
    # assumes each row of z1 / z2 is a probability distribution over the
    # abstract states — TODO confirm against SoftClusterNetwork.pred
    # NOTE(review): `.log()` yields -inf for entries that are exactly 0, which
    # would turn the losses into NaN; consider clamping if that ever shows up.
    z1 = network.pred(x1, 1)
    z2 = network.pred(x2, 1)

    # Cross-entropy between the two predictions, summed over axis 1 and
    # averaged over the batch.
    compression_loss = ((-(z1 * z2.log())).sum(axis=1)).mean(axis=0)
    # NOTE(review): the batch mean above is divided by the batch size once more
    # here; kept as is because the loss weights may be tuned for it — confirm.
    compression_loss = compression_loss / config["batch_size"]
    # sum(p * log p) of the batch-averaged prediction, i.e. its negative entropy.
    entropy_loss = (z1.mean(dim=0) * (z1.mean(dim=0).log())).sum()
    # Mean entropy of the individual predictions.
    det_entropy_loss = (- (z1 * z1.log()).sum(dim=1)).mean()
    loss = config["wl1"] * compression_loss + config["wl2"] * entropy_loss + config["wl3"] * det_entropy_loss
    losses.append(loss.item())

    if i % 100 == 0:
        print(f"Epoch {i}, Loss: {loss.item(), compression_loss.item(), entropy_loss.item(), det_entropy_loss.item()}")
        # network.eval() switches the module to eval mode for the diagnostics;
        # network.train() at the top of the next iteration switches it back.
        error, squared_error, abs_error = representation_score(config, network.eval())
        print(f"Error: {error}, Squared Error: {squared_error}, Abs Error: {abs_error}")
        fig = wandb_plot(print_states, annotations, network.eval(), d=2)
        # Draw figure on canvas
        fig.canvas.draw()
        plt.show()
        # Release the figure explicitly so repeated logging does not pile up
        # open figures (assumes wandb_plot returns a matplotlib Figure).
        plt.close(fig)

        # Moving average of the loss history. np.convolve swaps its operands
        # when the kernel is longer than the signal, so a fixed 100-wide kernel
        # on a shorter history (e.g. the single value at i == 0) would plot
        # scaled copies of the kernel instead of an average. Cap the window at
        # the number of recorded losses.
        window = min(100, len(losses))
        losses_to_plot = np.convolve(losses, np.ones(window) / window, mode='valid')

        # Create a new figure for losses
        loss_fig = plt.figure(figsize=(10, 6))
        # plot the losses graph with labels and title
        plt.plot(losses_to_plot, color='blue', linewidth=2)
        plt.xlabel(f'Training Steps (Moving Average Window={window})')
        plt.ylabel('Total Loss')
        plt.title('Training Loss Over Time')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.show()
        plt.close(loss_fig)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm (rather than clamping individual gradients).
    torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.1)
    optimizer.step()