In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
import statistics
from torch.distributions.categorical import Categorical
from IPython import display
import copy
from torch.utils.data import DataLoader
from tqdm import trange
import torch.nn.functional as F
from tqdm import tqdm
from scipy import spatial
import math
import cv2
from matplotlib import colors
from pathlib import Path
from tqdm import trange
from ipywidgets import interact
from collections import defaultdict
from src.env import ACTION_MAPPER
from src.rl_utils import get_project_folder

In [None]:
# Execute main.py inside this kernel with the experiment configuration below.
# After %run finishes, main.py's top-level names are loaded into the notebook
# namespace; GLOBAL_INFO (used by the following cells but not defined in this
# notebook) presumably comes from there — TODO confirm against main.py.
%run main.py --load-version 0 --n-rollouts 10 --n-samples 200_000 --epochs 300 --reg-coef 5e-2  \
    --num-envs 10 --n-clf 25 --clf-type "none" --mode "nmf" --algo "ppo" --env-id PongNoFrameskip-v4 --learning-rate 5e-3 --n-experts 4 \
        --n-concepts 2 --target-layer 3 --batch-size 512 --val-interval 10 --max-patience 5 --ccp-alpha 5e-5 --agent-version 0 \
            --save-start 0

In [None]:
# Choose which trained entry of GLOBAL_INFO["clf"] to explain. Every branch
# below leaves `model` holding a whole entry, indexed here as
# (wrapper, score) — assumed layout, TODO confirm against main.py. The final
# line then unwraps the wrapper.
if GLOBAL_INFO["args"].clf_type == "dt":
    # Highest score wins; among equal scores prefer the tree with fewer leaves
    # (the key negates the leaf count and sorted() is ascending, so the last
    # element is the best). Keep the whole entry, not entry[0]: the shared
    # `model = model[0]` below performs the unwrapping.
    model = sorted(GLOBAL_INFO["clf"], key=lambda x: (
        x[1], -x[0].clf.get_n_leaves()))[-1]
    clf = model[0].clf
elif GLOBAL_INFO["args"].clf_type == "none":
    # No scoring available: take the first entry as-is.
    model = GLOBAL_INFO["clf"][0]
else:
    # Highest-scoring entry (last after an ascending sort by score).
    model = sorted(GLOBAL_INFO["clf"], key=lambda x: x[1])[-1]
    clf = model[0].clf
model = model[0]

## Create Concept-based Explanations

In [None]:
# Per-environment preprocessing that turns a raw ``render()`` frame into an
# 84x84 image (presumably to line up with the agent's 84x84 observation —
# TODO confirm). Crop offsets are environment-specific. CarRacing is resized
# first and then cropped; the Atari games are cropped first and then resized.
_RENDER_PREPROCESSORS = {
    "CarRacing-v2": lambda frame: cv2.resize(frame, dsize=(96, 96))[:84, 6:90],
    "PongNoFrameskip-v4": lambda frame: cv2.resize(
        frame[34:-16], dsize=(84, 84)),
    "MsPacmanNoFrameskip-v4": lambda frame: cv2.resize(
        frame[:-39], dsize=(84, 84)),
    "BreakoutNoFrameskip-v4": lambda frame: cv2.resize(
        frame[32:, 8:-8], dsize=(84, 84)),
}


def evaluate_im(args, fabric, envs, model, num_eps: int, num_steps: int = 2_000):
    """Roll out ``model`` in ``envs`` and record per-step data for one environment.

    All environments in the vector are stepped together. Observations, logits,
    rendered frames and actions are recorded only for the environment at
    index ``env_idx`` (0).

    Args:
        args: Run configuration; only ``args.seed`` is read here.
        fabric: Object providing ``seed_everything`` and ``device``
            (presumably a Lightning Fabric instance).
        envs: Vectorised environment exposing ``reset``, ``step`` and an
            ``envs`` list of sub-environments (used for ``spec.id`` and
            ``render()``).
        model: Wrapper whose ``agent.get_action_and_value(obs)`` returns a
            5-tuple whose first two items are the action and the logits.
        num_eps: Currently unused; kept so existing callers keep working.
        num_steps: Number of vectorised environment steps to roll out.
            Defaults to 2_000, the previously hard-coded value.

    Returns:
        dict of numpy arrays. ``"obs"``, ``"logits"``, ``"img"`` and
        ``"action"`` each have one leading entry per step. ``"return"`` holds
        the returns of every episode that finished, in any environment, during
        the rollout; it is absent if none finished.

    Raises:
        NotImplementedError: if the environment id has no entry in
            ``_RENDER_PREPROCESSORS``.
    """
    env_idx = 0

    # Resolve the frame preprocessor up front so unsupported environments fail
    # before any rollout work is done.
    env_id = envs.envs[env_idx].spec.id
    if env_id not in _RENDER_PREPROCESSORS:
        raise NotImplementedError(
            f"No render preprocessing defined for env id {env_id!r}")
    preprocess = _RENDER_PREPROCESSORS[env_id]

    model.eval()
    fabric.seed_everything(args.seed)
    result = defaultdict(list)

    next_obs = torch.tensor(envs.reset(seed=args.seed)[0],
                            device=fabric.device)

    for _ in trange(num_steps):
        with torch.no_grad():
            action, logits, _, _, _ = model.agent.get_action_and_value(
                next_obs)

        # Record observation, logits and frame *before* stepping, so each entry
        # lines up with the action chosen from that same state.
        result["obs"].append(
            next_obs[env_idx].numpy(force=True)[np.newaxis, ...])
        result["logits"].append(
            logits[env_idx].numpy(force=True)[np.newaxis, ...])
        result["img"].append(
            preprocess(envs.envs[env_idx].render())[np.newaxis, ...])

        next_obs, _, _, _, infos = envs.step(action.numpy(force=True))
        next_obs = torch.tensor(next_obs, device=fabric.device)

        result["action"].append(action[env_idx].numpy(force=True)[np.newaxis])

        # "final_info" is only present when some environment reported a final step.
        for info in infos.get("final_info", ()):
            # Skip the envs that are not done.
            if info is None or "episode" not in info:
                continue
            # atleast_1d lets np.concatenate accept scalar returns as well as
            # 1-element arrays.
            result["return"].append(np.atleast_1d(info["episode"]["r"]))

    return {key: np.concatenate(value, axis=0)
            for key, value in result.items()}

In [None]:
# Roll out the selected model and collect frames, logits and actions for
# the first environment.
im_results = evaluate_im(
    args=GLOBAL_INFO["args"],
    fabric=GLOBAL_INFO["fabric"],
    envs=GLOBAL_INFO["envs"],
    model=model,
    num_eps=GLOBAL_INFO["args"].num_envs,
)

In [None]:
# Save every recorded frame as an SVG. The action taken from that frame is
# encoded in the file name.
folder = get_project_folder(
) / f"experiment-data/bb/states/{GLOBAL_INFO['args'].env_id}"
# exist_ok replaces the check-then-create pair and keeps the cell idempotent.
folder.mkdir(parents=True, exist_ok=True)

img = im_results["img"]

for i, (frame, frame_action) in enumerate(zip(img, im_results["action"])):
    fig, ax = plt.subplots(layout="constrained")
    ax.imshow(frame)
    ax.axis("off")

    fig.savefig(
        folder / f"fig{i}__action={frame_action}.svg", bbox_inches='tight', pad_inches=0)
    # Close this specific figure so thousands of figures don't pile up in memory.
    plt.close(fig)