This notebook evaluates the final checkpoint and creates planning and imagination plots on the in-domain test set (and, optionally, on a custom out-of-domain dataset).

If you want to evaluate on the custom dataset, first record your own trajectories. The recording procedure is described in the `README.md` file.

In [None]:
from collections import defaultdict
from functools import partial
from pathlib import Path
import shutil
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb
from einops import rearrange

from agent import Agent
from collector import Collector
from env_collector import EnvCollector
from envs import SingleProcessEnv, MultiProcessEnv
from episode import Episode
from make_reconstructions import make_reconstructions_from_batch
from models.actor_critic import ActorCritic
from models.world_model import WorldModel
from utils import configure_optimizer, EpisodeDirManager, set_seed
from planning_utils import expand_world_model_embedding
from models.slicer import Embedder, Head
from evaluate import evaluate

In [None]:
# Compose the Hydra trainer config and point it at the checkpoint to evaluate.

# Replace this with the path to your own final checkpoint.
PATH_TO_CHECKPOINT = (
    "/path/to/goal-conditioned-iris/src/outputs/checkpoints/epoch_250/last.pt"
)

with initialize(version_base=None, config_path="../config/"):
    cfg = compose(config_name="trainer.yaml")

# The composed DictConfig stays usable after the Hydra context is closed,
# so the checkpoint override and the dump can happen out here.
cfg.initialization.path_to_checkpoint = PATH_TO_CHECKPOINT
print(OmegaConf.to_yaml(cfg))

In [None]:
%cd outputs
!mkdir eval

In [None]:
class Trainer:
    """Build the goal-conditioned IRIS agent from a Hydra config and drive
    world-model training and evaluation.

    Only the world model is actually trained (see ``train_world_model``).
    Optimizers for the tokenizer and actor-critic are still created because
    they are written to and restored from the checkpoint files.

    Args:
        cfg: The composed ``trainer.yaml`` config.
    """

    def __init__(self, cfg: DictConfig) -> None:
        wandb.init(
            config=OmegaConf.to_container(cfg, resolve=True),
            reinit=True,
            resume=True,
            **cfg.wandb,
        )

        if cfg.common.seed is not None:
            set_seed(cfg.common.seed)

        self.cfg = cfg
        self.start_epoch = 1
        self.device = torch.device(cfg.common.device)

        # All output paths are relative to the current working directory.
        self.ckpt_dir = Path("checkpoints")
        self.media_dir = Path("media")
        self.episode_dir = self.media_dir / "episodes"
        self.reconstructions_dir = self.media_dir / "reconstructions"

        if not cfg.common.resume:
            config_dir = Path("config")
            config_path = config_dir / "trainer.yaml"
            config_dir.mkdir(exist_ok=True, parents=False)
            wandb.save(str(config_path))

            self.ckpt_dir.mkdir(exist_ok=True, parents=False)
            self.media_dir.mkdir(exist_ok=True, parents=False)
            self.episode_dir.mkdir(exist_ok=True, parents=False)
            self.reconstructions_dir.mkdir(exist_ok=True, parents=False)

        episode_manager_train = EpisodeDirManager(
            self.episode_dir / "train",
            max_num_episodes=cfg.collection.train.num_episodes_to_save,
        )
        episode_manager_test = EpisodeDirManager(
            self.episode_dir / "test",
            max_num_episodes=cfg.collection.test.num_episodes_to_save,
        )
        self.episode_manager_imagination = EpisodeDirManager(
            self.episode_dir / "imagination",
            max_num_episodes=cfg.evaluation.actor_critic.num_episodes_to_save,
        )

        def create_env(cfg_env, num_envs):
            # Several envs -> MultiProcessEnv; a single env -> SingleProcessEnv.
            env_fn = partial(instantiate, config=cfg_env)
            return (
                MultiProcessEnv(env_fn, num_envs, should_wait_num_envs_ratio=1.0)
                if num_envs > 1
                else SingleProcessEnv(env_fn)
            )

        if self.cfg.training.should:
            train_env = create_env(cfg.env.train, cfg.collection.train.num_envs)
            self.train_dataset = instantiate(cfg.datasets.train)
            self.train_collector = Collector(
                train_env, self.train_dataset, episode_manager_train
            )

        if self.cfg.evaluation.should:
            test_env = create_env(cfg.env.test, cfg.collection.test.num_envs)
            self.test_dataset = instantiate(cfg.datasets.test)
            self.test_collector = EnvCollector(
                test_env, self.test_dataset, episode_manager_test
            )

        assert self.cfg.training.should or self.cfg.evaluation.should
        env = train_env if self.cfg.training.should else test_env

        # Size of the real action space. It is cached from whichever env was
        # created, instead of being read from ``self.train_collector``, because
        # that attribute only exists when training is enabled.
        self.num_actions = env.num_actions

        tokenizer = instantiate(cfg.tokenizer)
        world_model = WorldModel(
            obs_vocab_size=tokenizer.vocab_size,
            act_vocab_size=self.num_actions,
            config=instantiate(cfg.world_model),
        )
        actor_critic = ActorCritic(**cfg.actor_critic, act_vocab_size=self.num_actions)
        self.agent = Agent(tokenizer, world_model, actor_critic).to(self.device)
        for component_name in ("tokenizer", "world_model", "actor_critic"):
            component = getattr(self.agent, component_name)
            num_params = sum(p.numel() for p in component.parameters())
            print(f"{num_params} parameters in agent.{component_name}")

        # When the first embedding table has exactly one row per real action,
        # expand it. This presumably makes room for the special planning action
        # ids used in ``get_planning_batch`` (TODO confirm in planning_utils).
        # The expansion runs before the checkpoint below is loaded.
        if (
            self.agent.world_model.embedder.embedding_tables[0].weight.shape[0]
            == self.num_actions
        ):
            print("Reshaping embedding tables")
            expand_world_model_embedding(self.agent.world_model)

        if cfg.initialization.path_to_checkpoint is not None:
            self.agent.load(**cfg.initialization, device=self.device)

        self.optimizer_tokenizer = torch.optim.Adam(
            self.agent.tokenizer.parameters(), lr=cfg.training.learning_rate
        )
        self.optimizer_world_model = configure_optimizer(
            self.agent.world_model,
            cfg.training.learning_rate,
            cfg.training.world_model.weight_decay,
        )
        self.optimizer_actor_critic = torch.optim.Adam(
            self.agent.actor_critic.parameters(), lr=cfg.training.learning_rate
        )

        if cfg.common.resume:
            self.load_checkpoint()

    def run(self) -> None:
        """Collect initial episodes, then train/evaluate for the configured epochs.

        While training, a checkpoint is saved every 10 epochs. Per-epoch metrics,
        plus the epoch duration in hours, are logged to wandb.
        """
        # Only use the collectors that were actually created in ``__init__``.
        if self.cfg.training.should:
            self.train_collector.collect(
                self.agent, epoch=1, **self.cfg.collection.train.config
            )
        if self.cfg.evaluation.should:
            self.test_collector.collect(
                self.agent, epoch=1, **self.cfg.collection.test.config
            )
        for epoch in range(self.start_epoch, 1 + self.cfg.common.epochs):
            print(f"\nEpoch {epoch} / {self.cfg.common.epochs}\n")
            start_time = time.time()
            to_log = []

            if self.cfg.training.should:
                to_log += self.train_world_model(epoch)

            if self.cfg.evaluation.should and (epoch % self.cfg.evaluation.every == 0):
                to_log += self.eval_agent(epoch)

            if self.cfg.training.should and (epoch % 10 == 0):
                self.save_checkpoint(
                    epoch, save_agent_only=not self.cfg.common.do_checkpoint
                )

            to_log.append({"duration": (time.time() - start_time) / 3600})
            for metrics in to_log:
                wandb.log({"epoch": epoch, **metrics})

        self.finish()

    def train_world_model(self, epoch: int) -> List[Dict[str, Any]]:
        """Run one epoch of world-model training.

        Training is skipped until ``epoch`` exceeds
        ``training.world_model.start_after_epochs``.

        Returns:
            A one-element list holding the epoch number and the world-model
            training metrics. The metrics are absent when training was skipped.
        """
        self.agent.train()
        self.agent.zero_grad()

        metrics_world_model = {}
        cfg_world_model = self.cfg.training.world_model
        sampling_weights = self.cfg.training.sampling_weights

        if epoch > cfg_world_model.start_after_epochs:
            metrics_world_model = self.train_component(
                self.agent.world_model,
                self.optimizer_world_model,
                sequence_length=self.cfg.common.sequence_length,
                sample_from_start=True,
                sampling_weights=sampling_weights,
                tokenizer=self.agent.tokenizer,
                **cfg_world_model,
            )
        self.agent.world_model.eval()

        return [{"epoch": epoch, **metrics_world_model}]

    def get_planning_batch(
        self, batch: Dict[str, torch.Tensor], planning_steps: int
    ) -> Dict[str, torch.Tensor]:
        """Rearrange a sampled batch into planning sequences.

        For each sample, the first 5 timesteps are copied unchanged as context.
        Planning segments are then appended while there is room and the next
        ``planning_steps`` timesteps are all unpadded and contain no episode
        end. Each segment consists of:

        * timestep ``i``, with its action replaced by ``ACTION_GET_GOAL``;
        * timestep ``i + planning_steps - 1`` (the goal), with its action
          replaced by ``ACTION_START_PLANNING``;
        * timesteps ``i`` .. ``i + planning_steps - 2``.

        ``i`` then advances to the goal timestep. The remaining slots are
        filled with consecutive timesteps, so every output sequence has the
        same length as the input.

        Args:
            batch: Tensors indexed as ``[batch, time, ...]``. Must contain
                ``observations``, ``actions``, ``mask_padding`` and ``ends``.
            planning_steps: Length of the planning horizon.

        Returns:
            A dict with the same keys and batch/time sizes as ``batch``.
        """
        # Special action ids placed right after the real action range.
        ACTION_GET_GOAL = self.num_actions
        ACTION_START_PLANNING = self.num_actions + 1
        NUM_CONTEXT_STEPS = 5

        new_batch = {k: [] for k in batch}

        batch_size, sequence_length = batch["mask_padding"].shape
        for batch_idx in range(batch_size):
            for k in new_batch:
                new_batch[k].append([])

            # Copy the first timesteps unchanged as context.
            for i in range(NUM_CONTEXT_STEPS):
                for k in new_batch:
                    new_batch[k][batch_idx].append(batch[k][batch_idx, None, i])

            # Keep appending segments (planning_steps + 1 entries each) while
            # more than planning_steps + 2 slots are still free.
            i = NUM_CONTEXT_STEPS
            while len(new_batch["observations"][batch_idx]) < sequence_length - (
                planning_steps + 2
            ):
                window = slice(i, i + planning_steps)
                segment_is_valid = torch.all(
                    batch["mask_padding"][batch_idx, window]
                ) and not torch.any(batch["ends"][batch_idx, window])
                if not segment_is_valid:
                    break

                # ``.clone()`` so overwriting the actions below does not write
                # into the caller's ``batch``.
                for k in batch:
                    new_batch[k][batch_idx].append(
                        batch[k][batch_idx, None, i].clone()
                    )
                    new_batch[k][batch_idx].append(
                        batch[k][batch_idx, None, i + planning_steps - 1].clone()
                    )
                    for j in range(i, i + planning_steps - 1):
                        new_batch[k][batch_idx].append(
                            batch[k][batch_idx, None, j].clone()
                        )

                # First entry of the segment -> "get goal", second (the goal
                # timestep) -> "start planning". Assumes ``actions`` is
                # (batch, time) so each entry has shape (1,) -- TODO confirm.
                new_batch["actions"][batch_idx][-planning_steps - 1][
                    0
                ] = ACTION_GET_GOAL
                new_batch["actions"][batch_idx][-planning_steps][
                    0
                ] = ACTION_START_PLANNING
                i += planning_steps - 1

            # Fill the remaining slots with consecutive timesteps.
            while len(new_batch["observations"][batch_idx]) < sequence_length:
                for k in batch:
                    new_batch[k][batch_idx].append(batch[k][batch_idx, None, i].clone())
                i += 1

            for k in new_batch:
                new_batch[k][batch_idx] = torch.cat(new_batch[k][batch_idx], dim=0)

        # Stack per-sample sequences back into batched tensors.
        for k in new_batch:
            new_batch[k] = torch.stack(new_batch[k])

        return new_batch

    def train_component(
        self,
        component: nn.Module,
        optimizer: torch.optim.Optimizer,
        steps_per_epoch: int,
        batch_num_samples: int,
        grad_acc_steps: int,
        max_grad_norm: Optional[float],
        sequence_length: int,
        sampling_weights: Optional[Tuple[float]],
        sample_from_start: bool,
        **kwargs_loss: Any,
    ) -> Dict[str, float]:
        """Train ``component`` for one epoch on planning batches.

        Each optimizer step accumulates gradients over ``grad_acc_steps``
        batches sampled from the training dataset and rearranged with
        ``get_planning_batch``. Gradients are optionally clipped to
        ``max_grad_norm``.

        Returns:
            The epoch-averaged total loss and intermediate losses, keyed as
            ``"<component>/train/<name>"``.
        """
        loss_total_epoch = 0.0
        intermediate_losses = defaultdict(float)

        for _ in tqdm(
            range(steps_per_epoch), desc=f"Training {str(component)}", file=sys.stdout
        ):
            optimizer.zero_grad()
            for _ in range(grad_acc_steps):
                batch = self.train_dataset.sample_batch(
                    batch_num_samples,
                    sequence_length,
                    sampling_weights,
                    sample_from_start,
                )
                batch = self.get_planning_batch(batch, self.cfg.common.planning_steps)

                batch = self._to_device(batch)

                # Scale so the accumulated gradient is an average over batches.
                losses = component.compute_loss(batch, **kwargs_loss) / grad_acc_steps
                loss_total_step = losses.loss_total
                loss_total_step.backward()
                loss_total_epoch += loss_total_step.item() / steps_per_epoch

                for loss_name, loss_value in losses.intermediate_losses.items():
                    intermediate_losses[f"{str(component)}/train/{loss_name}"] += (
                        loss_value / steps_per_epoch
                    )

            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(component.parameters(), max_grad_norm)

            optimizer.step()

        metrics = {
            f"{str(component)}/train/total_loss": loss_total_epoch,
            **intermediate_losses,
        }
        return metrics

    @torch.no_grad()
    def eval_agent(self, epoch: int) -> List[Dict[str, Any]]:
        """Evaluate world-model planning/imagination on the test set.

        Calls ``evaluate`` twice: once with plots saved, and once with
        ``random=True`` (presumably a random baseline). The second run's goal
        distance is reported as ``random/eval/goal_distance``. Nothing is
        evaluated until ``epoch`` exceeds
        ``evaluation.world_model.start_after_epochs``.

        Returns:
            A one-element list holding the metrics dict, which is empty when
            evaluation was skipped.
        """
        self.agent.eval()

        metrics_world_model = {}
        cfg_world_model = self.cfg.evaluation.world_model

        if epoch > cfg_world_model.start_after_epochs:
            metrics_world_model = evaluate(
                self.test_collector.env,
                self.test_dataset,
                self.agent,
                self.cfg.common.planning_steps,
                save_plots=True,
            )
            metrics_random = evaluate(
                self.test_collector.env,
                self.test_dataset,
                self.agent,
                self.cfg.common.planning_steps,
                save_plots=False,
                random=True,
            )
            metrics_world_model["random/eval/goal_distance"] = metrics_random[
                "world_model/eval/goal_distance"
            ]

        return [metrics_world_model]

    @torch.no_grad()
    def eval_component(
        self,
        component: nn.Module,
        batch_num_samples: int,
        sequence_length: int,
        **kwargs_loss: Any,
    ) -> Dict[str, float]:
        """Average ``component``'s losses over one pass of the test dataset.

        Returns:
            The mean total loss and intermediate losses, keyed as
            ``"<component>/eval/<name>"``.
        """
        loss_total_epoch = 0.0
        intermediate_losses = defaultdict(float)

        steps = 0
        # Context manager so the progress bar is always closed.
        with tqdm(desc=f"Evaluating {str(component)}", file=sys.stdout) as pbar:
            for batch in self.test_dataset.traverse(batch_num_samples, sequence_length):
                batch = self.get_planning_batch(batch, self.cfg.common.planning_steps)
                batch = self._to_device(batch)

                losses = component.compute_loss(batch, **kwargs_loss)
                loss_total_epoch += losses.loss_total.item()

                for loss_name, loss_value in losses.intermediate_losses.items():
                    intermediate_losses[f"{str(component)}/eval/{loss_name}"] += loss_value

                steps += 1
                pbar.update(1)

        # NOTE: if the test dataset yields no batch, ``steps`` is 0 and the
        # averaging below raises ZeroDivisionError.
        intermediate_losses = {k: v / steps for k, v in intermediate_losses.items()}
        metrics = {
            f"{str(component)}/eval/total_loss": loss_total_epoch / steps,
            **intermediate_losses,
        }
        return metrics

    def _save_checkpoint(self, epoch: int, save_agent_only: bool) -> None:
        """Write the agent weights and, unless ``save_agent_only``, the epoch,
        the optimizer states, a per-epoch snapshot and the dataset checkpoint."""
        torch.save(self.agent.state_dict(), self.ckpt_dir / "last.pt")

        if save_agent_only:
            return

        optimizer_state = {
            "optimizer_tokenizer": self.optimizer_tokenizer.state_dict(),
            "optimizer_world_model": self.optimizer_world_model.state_dict(),
            "optimizer_actor_critic": self.optimizer_actor_critic.state_dict(),
        }
        torch.save(epoch, self.ckpt_dir / "epoch.pt")
        torch.save(optimizer_state, self.ckpt_dir / "optimizer.pt")

        # Per-epoch copy next to the rolling top-level files.
        ckpt_epoch_dir = self.ckpt_dir / f"epoch_{epoch}"
        ckpt_epoch_dir.mkdir(exist_ok=True, parents=False)
        torch.save(self.agent.state_dict(), ckpt_epoch_dir / "last.pt")
        torch.save(epoch, ckpt_epoch_dir / "epoch.pt")
        torch.save(optimizer_state, ckpt_epoch_dir / "optimizer.pt")

        ckpt_dataset_dir = self.ckpt_dir / "dataset"
        ckpt_dataset_dir.mkdir(exist_ok=True, parents=False)
        self.train_dataset.update_disk_checkpoint(ckpt_dataset_dir)
        if self.cfg.evaluation.should:
            torch.save(
                self.test_dataset.num_seen_episodes,
                self.ckpt_dir / "num_seen_episodes_test_dataset.pt",
            )

    def save_checkpoint(self, epoch: int, save_agent_only: bool) -> None:
        """Save a checkpoint, keeping a temporary backup while writing.

        The current checkpoint directory (without ``dataset``) is copied to
        ``checkpoints_tmp``. The backup is removed only after the save
        finishes, so it survives if writing fails.
        """
        tmp_checkpoint_dir = Path("checkpoints_tmp")
        shutil.copytree(
            src=self.ckpt_dir,
            dst=tmp_checkpoint_dir,
            ignore=shutil.ignore_patterns("dataset"),
        )
        self._save_checkpoint(epoch, save_agent_only)
        shutil.rmtree(tmp_checkpoint_dir)

    def load_checkpoint(self) -> None:
        """Restore epoch counter, agent, optimizers and datasets from ``ckpt_dir``."""
        assert self.ckpt_dir.is_dir()
        self.start_epoch = torch.load(self.ckpt_dir / "epoch.pt") + 1
        self.agent.load(self.ckpt_dir / "last.pt", device=self.device)
        ckpt_opt = torch.load(self.ckpt_dir / "optimizer.pt", map_location=self.device)
        self.optimizer_tokenizer.load_state_dict(ckpt_opt["optimizer_tokenizer"])
        self.optimizer_world_model.load_state_dict(ckpt_opt["optimizer_world_model"])
        self.optimizer_actor_critic.load_state_dict(ckpt_opt["optimizer_actor_critic"])
        self.train_dataset.load_disk_checkpoint(self.ckpt_dir / "dataset")
        if self.cfg.evaluation.should:
            self.test_dataset.num_seen_episodes = torch.load(
                self.ckpt_dir / "num_seen_episodes_test_dataset.pt"
            )
        print(
            f"Successfully loaded model, optimizer and {len(self.train_dataset)} episodes from {self.ckpt_dir.absolute()}."
        )

    def _to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move every tensor in ``batch`` to ``self.device``."""
        return {k: batch[k].to(self.device) for k in batch}

    def finish(self) -> None:
        """Close the wandb run."""
        wandb.finish()

In [None]:
# Build the trainer: this starts a wandb run, creates the environments and
# datasets, builds the agent, and loads the weights from
# ``cfg.initialization.path_to_checkpoint`` when that path is set.
trainer = Trainer(cfg=cfg)

#### In-domain Evaluation

In [None]:
# Keep the in-domain run short by collecting only a couple of test episodes.
# The trainer holds a reference to this same ``cfg`` object, so it sees the change.
NUM_TEST_EPISODES = 2
cfg.collection.test.config.num_episodes = NUM_TEST_EPISODES

In [None]:
# Roll the agent out in the test environment. The collection settings
# (including the episode count) come from ``cfg.collection.test.config``.
test_collect_kwargs = trainer.cfg.collection.test.config
to_log = trainer.test_collector.collect(trainer.agent, epoch=1, **test_collect_kwargs)
print(to_log)

In [None]:
# ``eval_agent`` only evaluates when ``epoch`` exceeds
# ``evaluation.world_model.start_after_epochs``; the value is otherwise unused.
IN_DOMAIN_EVAL_EPOCH = 2
to_log = trainer.eval_agent(epoch=IN_DOMAIN_EVAL_EPOCH)
print(to_log)

Check out the `src/outputs/eval` folder to see the planning / imagination plots.

#### OOD Evaluation (Evaluation on Custom Dataset)

In [None]:
# Swap the in-domain test episodes for the recorded custom trajectories.
# The path is relative to the current working directory.
custom_trajectories_dir = Path("../../custom_trajectories/")
trainer.test_dataset.clear()
trainer.test_dataset.load_custom_trajectories(custom_trajectories_dir)

In [None]:
# Same evaluation as in-domain, now on the custom (out-of-domain) episodes.
# ``epoch`` only needs to exceed ``evaluation.world_model.start_after_epochs``.
OOD_EVAL_EPOCH = 3
to_log = trainer.eval_agent(epoch=OOD_EVAL_EPOCH)
print(to_log)