In [None]:
import os
import re
import yaml
from matplotlib import pyplot as plt
plt.switch_backend("agg")
import numpy as np
import torch
import cv2
from rl import RL
from atari import AtariEnv
from experiment import DQNExperiment


root_dir = r"./results/pong0"

# Experiment configuration; the context manager guarantees the file handle is closed.
with open(os.path.join(root_dir, "config.yaml"), 'r') as config_file:
    params = yaml.safe_load(config_file)

# outputs: one sub-folder per saved figure type. `makedirs(..., exist_ok=True)` creates
# 'figs' when needed and also restores any sub-folder missing from an existing 'figs'.
for item in ['game', 'mu', 'gqtl', 'mugq', 'obs']:
    os.makedirs(os.path.join(root_dir, 'figs', item), exist_ok=True)

# Seed every RNG used below so the rollout is reproducible.
np.random.seed(params['random_seed'])
torch.manual_seed(params['random_seed'])
random_state = np.random.RandomState(params['random_seed'])
device = torch.device("cpu")

env = AtariEnv(game_name=params['game_name'], rendering=False, sticky_actions=False, frame_skip=params['frame_skip'],
               terminal_on_life_loss=params['terminal_on_life_loss'], screen_size=params['screen_size'])

network_size = params['network_size']
ai = RL(state_shape=env.state_shape, nb_actions=env.nb_actions, action_dim=params['action_dim'],
        reward_dim=params['reward_dim'], history_len=params['history_len'], gamma=params['gamma'],
        learning_rate=params['learning_rate'], epsilon=params['epsilon'], final_epsilon=params['final_epsilon'],
        test_epsilon=params['test_epsilon'], annealing_steps=params['annealing_steps'], minibatch_size=params['minibatch_size'],
        replay_max_size=params['replay_max_size'], update_freq=params['update_freq'],
        learning_frequency=params['learning_frequency'], ddqn=params['ddqn'], network_size=network_size,
        normalize=params['normalize'], event=params['event'], sided_Q=params['sided_Q'], rng=random_state, device="cpu")

# Trained Q-network weights to analyse.
network_weights_file = os.path.join(root_dir, 'ai/q_network_weights_201.pt')
ai.load_weights(weights_file_path=network_weights_file)

# make_folder=False: the experiment only drives the environment here, it does not log a new run.
expt = DQNExperiment(env=env, ai=ai, episode_max_len=params['episode_max_len'], annealing=params['annealing'],
                        history_len=params['history_len'], max_start_nullops=params['max_start_nullops'],
                        replay_min_size=params['replay_min_size'], test_epsilon=params['test_epsilon'],
                        folder_location=params['folder_location'], folder_name=params['folder_name'], score_window_size=100,
                        make_folder=False, rng=random_state)


def tensor_linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """Return `num` evenly spaced points on the straight segment from `start` to `stop`.

    Element-wise analogue of `torch.linspace` for tensor end points. Assuming `start`
    and `stop` share a shape, the result has shape `(num, *start.shape)`. `out[0]` is
    `start`, and `out[-1]` equals `stop` up to floating-point rounding.

    For `num == 1` the single point is `start` (as with `torch.linspace`). The original
    formula divided by `num - 1 == 0` and produced NaNs.
    """
    if num == 1:
        steps = torch.zeros(1, dtype=torch.float32, device=start.device)
    else:
        steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
    # Shape (num, 1, ..., 1) so the interpolation weights broadcast over start/stop.
    steps = steps.reshape((num,) + (1,) * start.ndim)
    return start[None] + steps * (stop - start)[None]


def g_computer(x, x2, model, multiplier, compute_steps, filter_th):
    """Path-integrate the denoised greedy-action Q-gradient along the segment x -> x2.

    The segment is discretised into `compute_steps` points (via `tensor_linspace`) and
    the integral is evaluated with the trapezoidal rule. At every point the gradient
    `model.get_grad(s)[a*]` of the greedy action `a* = model.get_max_action(s)` is:
    - scaled by `multiplier`,
    - squeezed,
    - cleared of entries with `|grad| < filter_th`.

    Each interior point's gradient is computed once and shared by the two segments that
    meet there. This assumes `get_grad` / `get_max_action` are deterministic for a given
    state — NOTE(review): confirm (e.g. no dropout at inference).

    Args:
        x, x2: start / end states (anything `torch.Tensor(...)` accepts).
        model: object providing `get_max_action(s)`, `get_grad(s)` (indexable by
            action) and a `normalize` divisor.
        multiplier: scale/sign applied to the gradients (the caller uses -1 for grit).
        compute_steps: number of interpolation points; fewer than 2 yields zeros.
        filter_th: gradient entries with magnitude below this are set to zero.

    Returns:
        result: component-wise g-formula (same shape as `x`).
        measure: sum of all per-step contributions, i.e. the change of
            grit/reachability estimated by the g-formula (int 0 if no step was taken).
        ttl_grads: component-wise sum of the trapezoid-averaged gradients from x to x2.
    """
    x = torch.Tensor(x)
    x2 = torch.Tensor(x2)
    result = torch.zeros_like(x)
    ttl_grads = torch.zeros_like(x)
    measure = 0
    if compute_steps < 2:
        # Fewer than two points means no segment to integrate over.
        return result, measure, ttl_grads

    def _denoised_grad(s):
        # multiplier * dQ(s, a*)/ds for the greedy action a* at s, small entries zeroed.
        max_a = model.get_max_action(s)
        grad = multiplier * model.get_grad(s)[max_a].squeeze()
        grad[abs(grad) < filter_th] = 0  # denoising
        return grad

    states = tensor_linspace(x, x2, compute_steps)
    s = states[0].detach()
    grad = _denoised_grad(s)
    for i in range(1, compute_steps):
        s2 = states[i].detach()
        grad2 = _denoised_grad(s2)
        mu = (s2 - s) / model.normalize  # TODO get rid of all the ai.normalize (should instead be done at env)
        grads = 0.5 * (grad + grad2)  # trapezoidal rule
        y = torch.mul(mu, grads)  # Note: dt is cancelled: denominator of mu and multiplication in the integral
        result += y.detach()
        measure += torch.sum(y.detach())
        ttl_grads += grads.detach()
        # The end point of this segment is the start point of the next: reuse its gradient.
        s, grad = s2, grad2
    return result, measure, ttl_grads

In [None]:
GRAD_FILTER_TH = 0.1  # gradient entries with |grad| below this are zeroed in g_computer
MAX_STEPS = 70        # length of the rollout
NOOP_STEPS = 60       # progress is printed every 10 steps while step <= NOOP_STEPS
PLOT_CHANNEL = 3      # channel of the stacked state shown in the mu / gqtl / mugq panels


def _strip_axes(ax):
    """Remove ticks and tick labels from `ax` (every panel is a bare image)."""
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    ax.set_xticklabels([])
    ax.set_yticklabels([])


def _save_panel(fig, ax, subdir, fname):
    """Strip `ax`, tighten `fig`, save it to figs/<subdir>/<fname> and close it."""
    _strip_axes(ax)
    fig.tight_layout(pad=0, h_pad=0, w_pad=0)
    fig.savefig(os.path.join(root_dir, 'figs', subdir, fname))
    plt.close(fig)


def _save_heatmap(data, vlim, subdir, fname):
    """Save a vertically flipped 'seismic' pcolor of the 2-D array `data`, clipped to [-vlim, vlim]."""
    fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5), dpi=300)
    ax.pcolor(np.flip(data, axis=0), vmin=-vlim, vmax=vlim, cmap="seismic")
    _save_panel(fig, ax, subdir, fname)


d_grit = []          # actual change of grit between consecutive states
d_grit_decomp = []   # change of grit estimated by the decomposition (g_computer)
grits = []           # multiplier * Q-values at every visited state
max_actions = []     # greedy action at every visited state

print("Is grit: ", expt.ai.sided_Q)
multiplier = -1 if expt.ai.sided_Q == 'grit' else 1

expt.env.reset()
expt._episode_reset()
game_over = False
for step in range(MAX_STEPS):
    if game_over:
        print("Game-over before max steps.")
        break

    plt.close("all")
    expt.last_episode_steps += 1

    if step <= NOOP_STEPS and step % 10 == 0:
        print("step:", step)
    grit = multiplier * expt.ai.get_q(expt.last_state)[0]
    max_a = expt.ai.get_max_action(expt.last_state)
    # The no-op action is used for the whole rollout (previously only implicit after
    # step NOOP_STEPS). For manual control use e.g. `action = int(input('action >> '))`
    # (reject values >= expt.env.nb_actions), or `action = max_a` for greedy play.
    action = 0

    grits.append(grit)
    max_actions.append(max_a)
    grit = grit[max_a]

    new_obs, reward, game_over, _ = expt.env.step(action)
    expt.env.save_img(os.path.join(root_dir, 'figs', 'game', 'pong'+str(expt.last_episode_steps)+".png"))

    prev_state = expt.last_state.copy()
    expt._update_state(new_obs)  # --> expt.last_state is now x'

    mu = (expt.last_state - prev_state) / expt.ai.normalize  # x_k - x_{k-1} --> for plotting

    g, measure, grad_total = g_computer(x=prev_state, x2=expt.last_state, model=expt.ai, multiplier=multiplier,
                                        compute_steps=10, filter_th=GRAD_FILTER_TH)

    torch.clamp_(g, min=0, max=1)  # only the causal ones (ignoring negative g) for plotting
    d_grit_decomp.append(measure)     # change of grit computed via Decomp Lemma
    grit2 = multiplier * np.max(expt.ai.get_q(expt.last_state)[0])
    d_grit.append(grit2 - grit)       # actual change of grit

    ## PLOTS
    # raw observation
    fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5), dpi=300)
    ax.imshow(new_obs, cmap="gray")
    _save_panel(fig, ax, 'obs', 'obs__'+str(expt.last_episode_steps)+".png")
    # mu: state difference
    _save_heatmap(mu[PLOT_CHANNEL, :, :], 1, 'mu', 'mu__'+str(expt.last_episode_steps)+".png")
    # total (path-summed) gradient
    _save_heatmap(grad_total[PLOT_CHANNEL, :, :].numpy(), 1, 'gqtl', 'gqtl'+str(expt.last_episode_steps)+".png")
    # g: component-wise decomposition
    _save_heatmap(g[PLOT_CHANNEL, :, :].numpy(), 0.05, 'mugq', 'mugq'+str(expt.last_episode_steps)+".png")

    plt.close("all")
    if not game_over and expt.last_episode_steps >= expt.episode_max_len:
        print('Reaching maximum number of steps in the current episode.')
        game_over = True

In [None]:
def atoi(text):
    """Return `text` as an int if it consists only of decimal digits, else unchanged.

    `str.isdecimal` is used instead of `str.isdigit` because `isdigit` also accepts
    characters such as superscripts ('²') that `int()` rejects with ValueError.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Sort key for natural ordering, e.g. "pong2.png" before "pong10.png"."""
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

def make(image_folder, subfolders=('game', 'mu', 'gqtl', 'mugq', 'obs')):
    """Load every saved frame of each figure type, naturally sorted by file name.

    Args:
        image_folder: directory containing one sub-folder per figure type.
        subfolders: names of the sub-folders to load (defaults to all panel types).

    Returns:
        dict mapping each sub-folder name to the list of its ".jpg"/".png" images, as
        returned by `cv2.imread` (BGR channel order), ordered by `natural_keys`.

    Raises:
        OSError: if a matching file cannot be read or decoded. `cv2.imread` returns
            None on failure, which would otherwise surface later as an obscure
            plotting error.
    """
    images = {}
    for sf in subfolders:
        folder = os.path.join(image_folder, sf)
        # The tuple form checks both extensions including the dot ("xpng" is rejected).
        names = sorted((name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))),
                       key=natural_keys)
        loaded = []
        for name in names:
            path = os.path.join(folder, name)
            img = cv2.imread(path)
            if img is None:
                raise OSError(f"could not read image {path}")
            loaded.append(img)
        images[sf] = loaded
    return images


fig_dir = os.path.join(root_dir, 'figs')
frames = make(fig_dir)
min_grit = [np.min(k) for k in grits]  # smallest multiplier-scaled Q-value at each step
num_frames = 11  # 8
start = 46  # 41, 49
last = start + num_frames - 1
GRID_LEVELS = (0, 0.25, 0.5, 0.75, 1)  # light horizontal guides in the grit bar panel

# Fail with a clear message instead of an IndexError when the rollout ended early.
n_available = min([len(min_grit)] + [len(frames[key]) for key in ('obs', 'mu', 'gqtl', 'mugq')])
assert last < n_available, f"frames {start}..{last} requested but only {n_available} are available"

# squeeze=False keeps `axs` 2-D even when num_frames == 1.
fig, axs = plt.subplots(5, num_frames, figsize=(9, 4), dpi=300, squeeze=False)  # (6.5, 4)

for i in range(num_frames):
    k = start + i
    # Row 0 shows 'obs'; the RGB game screen would be np.flip(frames['game'][k], axis=2).
    axs[0, i].imshow(frames['obs'][k])
    # cv2 loads BGR: flipping the channel axis gives RGB for imshow.
    axs[1, i].imshow(np.flip(frames['mu'][k], axis=2))
    axs[2, i].imshow(np.flip(frames['gqtl'][k], axis=2))
    axs[3, i].imshow(np.flip(frames['mugq'][k], axis=2))
    for level in GRID_LEVELS:
        axs[4, i].axhline(y=level, color="lightgrey", ls="-", lw=0.5)
    axs[4, i].plot(min_grit[k], "ko", ms=3)
    axs[4, i].bar(0, min_grit[k], width=0.2, color="orange")
    for ax in axs[:, i]:
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.margins(0)
    axs[4, i].set_xlim(-0.5, 0.5)
    axs[4, i].set_ylim(-0.02, 1.02)
fig.tight_layout(pad=0.08, h_pad=0.2, w_pad=0.2)
fig.savefig(os.path.join(fig_dir, "full_" + str(start) + "_" + str(last) + ".pdf"))
plt.close("all")

In [None]:
min_grit = [np.min(k) for k in grits]  # smallest multiplier-scaled Q-value at each step

start = 0
# The figure folders are never cleared, so they may hold extra frames from an older,
# longer run. Render only indices present in every panel type and in `min_grit`.
num_frames = max(0, min([len(min_grit)] + [len(v) for v in frames.values()]) - start)
GRID_LEVELS = (0, 0.25, 0.5, 0.75, 1)  # light horizontal guides in the grit bar panel

video_dir = os.path.join(fig_dir, 'video')
os.makedirs(video_dir, exist_ok=True)

for i in range(num_frames):
    k = start + i
    fig, axs = plt.subplots(2, 3, figsize=(6, 4), dpi=300)
    # cv2 loads BGR: flipping the channel axis gives RGB for imshow.
    axs[0, 0].imshow(np.flip(frames['game'][k], axis=2))
    axs[0, 1].imshow(frames['obs'][k])
    axs[0, 2].imshow(np.flip(frames['mu'][k], axis=2))
    axs[1, 0].imshow(np.flip(frames['gqtl'][k], axis=2))
    axs[1, 1].imshow(np.flip(frames['mugq'][k], axis=2))
    for level in GRID_LEVELS:
        axs[1, 2].axhline(y=level, color="lightgrey", ls="-", lw=0.5)
    axs[1, 2].plot(min_grit[k], "ko", ms=3)
    axs[1, 2].bar(0, min_grit[k], width=0.2, color="orange")
    for ax in axs.ravel():
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.margins(0)
    axs[0, 0].set_frame_on(False)
    axs[1, 2].set_xlim(-0.5, 0.5)
    axs[1, 2].set_ylim(-0.02, 1.02)
    fig.tight_layout(pad=0.3, h_pad=0.5, w_pad=0.5)
    fig.savefig(os.path.join(video_dir, "v_" + str(i) + ".png"))
    plt.close("all")