In [1]:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from collections import deque
from typing import Dict, List, Optional

import habitat
import numpy as np
import torch
import tqdm

from torch.optim.lr_scheduler import LambdaLR
from habitat import Config, logger
from habitat.utils.visualizations.utils import observations_to_image
from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.env_utils import construct_envs
from habitat_baselines.common.environments import get_env_class
from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.common.tensorboard_utils import TensorboardWriter
from habitat_baselines.common.utils import (
    batch_obs,
    generate_video,
    linear_decay,
)
from habitat_baselines.config.default import get_config
from habitat_baselines.rl.ppo import PPO, PointNavBaselinePolicy

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def evaluate(self, checkpoint_path, writer=None, checkpoint_index=0):
    r"""Evaluate one saved checkpoint and log the averaged episode statistics.

    This is a free function that takes ``self``: it reads ``self.config`` and
    calls ``self.load_checkpoint``, ``self._setup_eval_config``,
    ``self._setup_actor_critic_agent`` and ``self._pause_envs``, and it
    overwrites ``self.device``, ``self.envs``, ``self.actor_critic`` and
    ``self.metric_uuid``. It is presumably meant to be used with a PPO-style
    trainer object -- confirm against the caller.

    Episodes are rolled out in a single environment (``NUM_PROCESSES`` is
    forced to 1) until ``config.TEST_EPISODE_COUNT`` distinct episodes have
    finished or no environment is left running. The mean reward, mean
    "success" and mean value of the first configured measurement are then
    written to the logger.

    Args:
        checkpoint_path: path handed to ``self.load_checkpoint``.
        writer: forwarded to ``generate_video`` as ``tb_writer``; only used
            when ``config.VIDEO_OPTION`` is non-empty. Presumably it must be
            a real writer when videos are sent to TensorBoard -- TODO confirm
            against ``generate_video``.
        checkpoint_index: forwarded to ``generate_video`` as
            ``checkpoint_idx``; only used when videos are generated.

    Returns:
        None. Results are reported through ``logger`` only.
    """
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

    if self.config.EVAL.USE_CKPT_CONFIG:
        config = self._setup_eval_config(ckpt_dict["config"])
    else:
        config = self.config.clone()

    ppo_cfg = config.RL.PPO
    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    config.NUM_PROCESSES = 1
    config.freeze()

    # NOTE(review): this check reads self.config while every later video
    # check reads the (possibly checkpoint-derived) config -- confirm the two
    # are guaranteed to agree on VIDEO_OPTION.
    if len(self.config.VIDEO_OPTION) > 0:
        config.defrost()
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
        config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
        config.freeze()

    # The video setting does not change during the rollout, so test it once.
    record_video = len(config.VIDEO_OPTION) > 0

    logger.info(f"env config: {config}")
    self.envs = construct_envs(config, get_env_class(config.ENV_NAME))

    # Everything after the environments exist runs under try/finally so the
    # environments are closed even if the rollout raises.
    try:
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert (
            measure_type is not None
        ), "invalid measurement type {}".format(metric_cfg.TYPE)
        self.metric_uuid = measure_type._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs(observations, self.device)

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(
            config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
        )
        not_done_masks = torch.zeros(
            config.NUM_PROCESSES, 1, device=self.device
        )
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]
        if record_video:
            os.makedirs(config.VIDEO_DIR, exist_ok=True)

        self.actor_critic.eval()

        # The context manager closes the progress bar on exit.
        with tqdm.tqdm(total=config.TEST_EPISODE_COUNT) as pbar:
            while (
                len(stats_episodes) < config.TEST_EPISODE_COUNT
                and self.envs.num_envs > 0
            ):
                # Episodes being played *before* this step; used as the key
                # for the stats of any episode that finishes on this step.
                current_episodes = self.envs.current_episodes()

                with torch.no_grad():
                    (
                        _,
                        actions,
                        _,
                        test_recurrent_hidden_states,
                    ) = self.actor_critic.act(
                        batch,
                        test_recurrent_hidden_states,
                        prev_actions,
                        not_done_masks,
                        deterministic=False,
                    )

                    prev_actions.copy_(actions)

                outputs = self.envs.step([a[0].item() for a in actions])

                observations, rewards, dones, infos = [
                    list(x) for x in zip(*outputs)
                ]
                batch = batch_obs(observations, self.device)

                not_done_masks = torch.tensor(
                    [[0.0] if done else [1.0] for done in dones],
                    dtype=torch.float,
                    device=self.device,
                )

                rewards = torch.tensor(
                    rewards, dtype=torch.float, device=self.device
                ).unsqueeze(1)

                current_episode_reward += rewards
                # Episodes queued *after* the step; an env whose next episode
                # already has stats is paused below.
                next_episodes = self.envs.current_episodes()
                envs_to_pause = []
                n_envs = self.envs.num_envs

                for i in range(n_envs):
                    if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                    ) in stats_episodes:
                        envs_to_pause.append(i)

                    # episode ended
                    if not_done_masks[i].item() == 0:
                        pbar.update()
                        episode_stats = dict()
                        episode_stats[self.metric_uuid] = infos[i][
                            self.metric_uuid
                        ]
                        # NOTE(review): success is "metric < 1.0". If the
                        # metric is SPL (see the comment above), this marks
                        # failed episodes (SPL == 0) as successes -- confirm
                        # the intended metric and threshold.
                        episode_stats["success"] = int(
                            infos[i][self.metric_uuid] < 1.0
                        )
                        episode_stats["reward"] = current_episode_reward[
                            i
                        ].item()
                        current_episode_reward[i] = 0
                        # use scene_id + episode_id as unique id for storing
                        # stats
                        stats_episodes[
                            (
                                current_episodes[i].scene_id,
                                current_episodes[i].episode_id,
                            )
                        ] = episode_stats

                        if record_video:
                            generate_video(
                                video_option=config.VIDEO_OPTION,
                                video_dir=config.VIDEO_DIR,
                                images=rgb_frames[i],
                                episode_id=current_episodes[i].episode_id,
                                checkpoint_idx=checkpoint_index,
                                metric_name=self.metric_uuid,
                                metric_value=infos[i][self.metric_uuid],
                                tb_writer=writer,
                            )

                            rgb_frames[i] = []

                    # episode continues
                    elif record_video:
                        frame = observations_to_image(
                            observations[i], infos[i]
                        )
                        rgb_frames[i].append(frame)

                (
                    self.envs,
                    test_recurrent_hidden_states,
                    not_done_masks,
                    current_episode_reward,
                    prev_actions,
                    batch,
                    rgb_frames,
                ) = self._pause_envs(
                    envs_to_pause,
                    self.envs,
                    test_recurrent_hidden_states,
                    not_done_masks,
                    current_episode_reward,
                    prev_actions,
                    batch,
                    rgb_frames,
                )

        # Without at least one finished episode there is nothing to average
        # (and the divisions below would fail), so report and stop.
        if not stats_episodes:
            logger.warning(
                "No episodes were evaluated; skipping stats aggregation."
            )
            return

        num_episodes = len(stats_episodes)
        # Every per-episode dict is built with the same keys above, so the
        # first one is used to enumerate them.
        stat_keys = next(iter(stats_episodes.values())).keys()
        aggregated_stats = {
            stat_key: sum(v[stat_key] for v in stats_episodes.values())
            for stat_key in stat_keys
        }

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = (
            aggregated_stats[self.metric_uuid] / num_episodes
        )
        episode_success_mean = aggregated_stats["success"] / num_episodes

        logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
        logger.info(f"Average episode success: {episode_success_mean:.6f}")
        logger.info(
            f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}"
        )
    finally:
        self.envs.close()