This notebook expands the action-embedding table of a pretrained world model
to cover the new actions that enable goal-conditioned behavior. The new
actions are:
1. `ACTION_GET_GOAL` (`env.num_actions`): the state that follows this action is the desired goal state.
2. `ACTION_START_PLANNING` (`env.num_actions + 1`): the states and actions that follow this action form the goal-conditioned trajectory.

The new checkpoint is saved to `/path/to/goal-conditioned-iris/src/outputs/checkpoints/epoch_0/last.pt`.

In [None]:
# Compose the trainer configuration from ../config/trainer.yaml with Hydra
# and echo it as YAML so the run settings are visible in the notebook.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="../config/", version_base=None):
    cfg = compose(config_name="trainer.yaml")
    cfg_yaml = OmegaConf.to_yaml(cfg)
    print(cfg_yaml)

In [None]:
# Create an `outputs` working directory and switch into it. Trainer uses
# relative paths ("checkpoints", "media", "config"), so everything it writes
# lands under `outputs/`.
# NOTE(review): this cell is not idempotent. `mkdir` reports an error if
# `outputs` already exists, and re-running `%cd outputs` from inside
# `outputs` would try to enter `outputs/outputs`. Run it once per kernel.
!mkdir outputs
%cd outputs

In [None]:
from collections import defaultdict
from functools import partial
from pathlib import Path
import shutil
import sys
import time
from typing import Any, Dict, Optional, Tuple

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from tqdm import tqdm

from agent import Agent
from collector import Collector
from envs import SingleProcessEnv, MultiProcessEnv
from models.actor_critic import ActorCritic
from models.world_model import WorldModel
from utils import configure_optimizer, EpisodeDirManager, set_seed
from planning_utils import expand_world_model_embedding


class Trainer:
    """Build the agent, environments, datasets and optimizers, and manage checkpoints.

    The agent bundles a tokenizer, a world model and an actor-critic. After
    an optional pretrained checkpoint is loaded, the world model's action
    embedding is expanded (via ``expand_world_model_embedding``) so it can
    represent the extra goal-conditioning actions.
    """

    def __init__(self, cfg: DictConfig) -> None:
        """Construct the trainer from a composed Hydra config.

        Args:
            cfg: Trainer configuration. At least one of ``cfg.training.should``
                and ``cfg.evaluation.should`` must be true.

        Raises:
            AssertionError: if neither training nor evaluation is enabled.
        """
        if cfg.common.seed is not None:
            set_seed(cfg.common.seed)

        self.cfg = cfg
        self.start_epoch = 1
        self.device = torch.device(cfg.common.device)

        # All output paths are relative to the current working directory.
        self.ckpt_dir = Path("checkpoints")
        self.media_dir = Path("media")
        self.episode_dir = self.media_dir / "episodes"
        self.reconstructions_dir = self.media_dir / "reconstructions"

        if not cfg.common.resume:
            # Fresh run: create the output layout. parents=False means the
            # current working directory itself must already exist.
            config_dir = Path("config")
            config_dir.mkdir(exist_ok=True, parents=False)

            self.ckpt_dir.mkdir(exist_ok=True, parents=False)
            self.media_dir.mkdir(exist_ok=True, parents=False)
            self.episode_dir.mkdir(exist_ok=True, parents=False)
            self.reconstructions_dir.mkdir(exist_ok=True, parents=False)

        episode_manager_train = EpisodeDirManager(
            self.episode_dir / "train",
            max_num_episodes=cfg.collection.train.num_episodes_to_save,
        )
        episode_manager_test = EpisodeDirManager(
            self.episode_dir / "test",
            max_num_episodes=cfg.collection.test.num_episodes_to_save,
        )
        self.episode_manager_imagination = EpisodeDirManager(
            self.episode_dir / "imagination",
            max_num_episodes=cfg.evaluation.actor_critic.num_episodes_to_save,
        )

        def create_env(cfg_env, num_envs):
            # One env -> run in-process; several -> run them in worker processes.
            env_fn = partial(instantiate, config=cfg_env)
            return (
                MultiProcessEnv(env_fn, num_envs, should_wait_num_envs_ratio=1.0)
                if num_envs > 1
                else SingleProcessEnv(env_fn)
            )

        if self.cfg.training.should:
            train_env = create_env(cfg.env.train, cfg.collection.train.num_envs)
            self.train_dataset = instantiate(cfg.datasets.train)
            self.train_collector = Collector(
                train_env, self.train_dataset, episode_manager_train
            )

        if self.cfg.evaluation.should:
            test_env = create_env(cfg.env.test, cfg.collection.test.num_envs)
            self.test_dataset = instantiate(cfg.datasets.test)
            self.test_collector = Collector(
                test_env, self.test_dataset, episode_manager_test
            )

        assert self.cfg.training.should or self.cfg.evaluation.should
        # Within this method, `env` is only used to read `num_actions`.
        env = train_env if self.cfg.training.should else test_env

        tokenizer = instantiate(cfg.tokenizer)
        world_model = WorldModel(
            obs_vocab_size=tokenizer.vocab_size,
            act_vocab_size=env.num_actions,
            config=instantiate(cfg.world_model),
        )
        actor_critic = ActorCritic(**cfg.actor_critic, act_vocab_size=env.num_actions)
        self.agent = Agent(tokenizer, world_model, actor_critic).to(self.device)
        print(
            f"{sum(p.numel() for p in self.agent.tokenizer.parameters())} parameters in agent.tokenizer"
        )
        print(
            f"{sum(p.numel() for p in self.agent.world_model.parameters())} parameters in agent.world_model"
        )
        print(
            f"{sum(p.numel() for p in self.agent.actor_critic.parameters())} parameters in agent.actor_critic"
        )

        if cfg.initialization.path_to_checkpoint is not None:
            self.agent.load(**cfg.initialization, device=self.device)

        # Expand only while the table still has exactly env.num_actions rows.
        # That presumably means the goal-conditioning actions have not been
        # added yet, so a table that was already expanded is left alone.
        # Compare against `env` rather than `self.train_collector`: the latter
        # does not exist when only evaluation is enabled.
        # NOTE(review): assumes embedding_tables[0] is the action-token table.
        action_table = self.agent.world_model.embedder.embedding_tables[0]
        if action_table.weight.shape[0] == env.num_actions:
            print("Reshaping embedding tables")
            expand_world_model_embedding(self.agent.world_model)

        # Optimizers are built only after the expansion so they are created
        # over the world model's final parameter set.
        self.optimizer_tokenizer = torch.optim.Adam(
            self.agent.tokenizer.parameters(), lr=cfg.training.learning_rate
        )
        self.optimizer_world_model = configure_optimizer(
            self.agent.world_model,
            cfg.training.learning_rate,
            cfg.training.world_model.weight_decay,
        )
        self.optimizer_actor_critic = torch.optim.Adam(
            self.agent.actor_critic.parameters(), lr=cfg.training.learning_rate
        )

        if cfg.common.resume:
            self.load_checkpoint()

    def _save_checkpoint(self, epoch: int, save_agent_only: bool) -> None:
        """Write checkpoint files into ``self.ckpt_dir``.

        Always writes the agent weights to ``last.pt``. Unless
        ``save_agent_only`` is set, it also writes the following:

        - ``epoch.pt`` and ``optimizer.pt`` at the top level.
        - A full snapshot (agent, epoch, optimizers) under ``epoch_<epoch>/``.
        - The train dataset's disk checkpoint under ``dataset/``, when
          training is enabled.
        - The test dataset's episode counter, when evaluation is enabled.
        """
        torch.save(self.agent.state_dict(), self.ckpt_dir / "last.pt")

        if save_agent_only:
            return

        optimizer_states = {
            "optimizer_tokenizer": self.optimizer_tokenizer.state_dict(),
            "optimizer_world_model": self.optimizer_world_model.state_dict(),
            "optimizer_actor_critic": self.optimizer_actor_critic.state_dict(),
        }

        torch.save(epoch, self.ckpt_dir / "epoch.pt")
        torch.save(optimizer_states, self.ckpt_dir / "optimizer.pt")

        # Per-epoch snapshot, kept alongside the rolling top-level files.
        ckpt_epoch_dir = self.ckpt_dir / f"epoch_{epoch}"
        ckpt_epoch_dir.mkdir(exist_ok=True, parents=False)
        torch.save(self.agent.state_dict(), ckpt_epoch_dir / "last.pt")
        torch.save(epoch, ckpt_epoch_dir / "epoch.pt")
        torch.save(optimizer_states, ckpt_epoch_dir / "optimizer.pt")

        # The train dataset only exists when training is enabled (mirrors the
        # evaluation guard below).
        if self.cfg.training.should:
            ckpt_dataset_dir = self.ckpt_dir / "dataset"
            ckpt_dataset_dir.mkdir(exist_ok=True, parents=False)
            self.train_dataset.update_disk_checkpoint(ckpt_dataset_dir)
        if self.cfg.evaluation.should:
            torch.save(
                self.test_dataset.num_seen_episodes,
                self.ckpt_dir / "num_seen_episodes_test_dataset.pt",
            )

    def save_checkpoint(self, epoch: int, save_agent_only: bool) -> None:
        """Save a checkpoint, keeping a temporary backup while writing.

        The current checkpoint directory (excluding ``dataset``) is copied to
        ``checkpoints_tmp`` before writing and deleted afterwards. If writing
        fails, the backup is left on disk. ``shutil.copytree`` raises if
        ``checkpoints_tmp`` already exists, so a stale backup must be
        inspected and removed by hand.
        """
        tmp_checkpoint_dir = Path("checkpoints_tmp")
        shutil.copytree(
            src=self.ckpt_dir,
            dst=tmp_checkpoint_dir,
            ignore=shutil.ignore_patterns("dataset"),
        )
        self._save_checkpoint(epoch, save_agent_only)
        shutil.rmtree(tmp_checkpoint_dir)

    def load_checkpoint(self) -> None:
        """Restore epoch counter, agent, optimizers and datasets from ``self.ckpt_dir``.

        Sets ``self.start_epoch`` to the saved epoch + 1.

        NOTE(review): this reads ``self.train_dataset`` unconditionally, so it
        requires ``cfg.training.should`` to be true.
        """
        assert self.ckpt_dir.is_dir()
        self.start_epoch = torch.load(self.ckpt_dir / "epoch.pt") + 1
        self.agent.load(self.ckpt_dir / "last.pt", device=self.device)
        ckpt_opt = torch.load(self.ckpt_dir / "optimizer.pt", map_location=self.device)
        self.optimizer_tokenizer.load_state_dict(ckpt_opt["optimizer_tokenizer"])
        self.optimizer_world_model.load_state_dict(ckpt_opt["optimizer_world_model"])
        self.optimizer_actor_critic.load_state_dict(ckpt_opt["optimizer_actor_critic"])
        self.train_dataset.load_disk_checkpoint(self.ckpt_dir / "dataset")
        if self.cfg.evaluation.should:
            self.test_dataset.num_seen_episodes = torch.load(
                self.ckpt_dir / "num_seen_episodes_test_dataset.pt"
            )
        print(
            f"Successfully loaded model, optimizer and {len(self.train_dataset)} episodes from {self.ckpt_dir.absolute()}."
        )

    def _to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a new dict with every tensor in ``batch`` moved to ``self.device``."""
        return {k: batch[k].to(self.device) for k in batch}

In [None]:
# Build the trainer. This creates the envs/datasets, loads the configured
# initial checkpoint (if any), and expands the world model's action
# embedding when it still has only env.num_actions rows.
trainer = Trainer(cfg=cfg)

In [None]:
# Persist the expanded agent as a full epoch-0 checkpoint, written to
# checkpoints/last.pt plus a snapshot under checkpoints/epoch_0/. It
# includes the epoch marker and optimizer states.
trainer.save_checkpoint(save_agent_only=False, epoch=0)