## Dependencies

In [None]:
import random

import torch

import context_changers
import ct_model
import dmc
import drqv2
import utils
import numpy as np

import imageio
from matplotlib import pyplot as plt

## Hyperparameters

In [None]:
task_name = 'reacher_hard'  # Name of the task
expert_frame_stack = 3  # Size of the frame stack with which the expert was trained
action_repeat = 2  # Number of action repeat
seed = 3
xml_path = 'domain_xmls/reacher.xml'  # XML of the task with some updates
episode_len = 100  # Length of the episode
context_camera_ids = [0]  # Candidate camera ids for rendering the source/target videos
learner_camera_id = 0  # Camera id used to render the learner's first observation
im_w = 64  # Rendered frame width (pixels)
im_h = 64  # Rendered frame height (pixels)

# Draw the demonstration camera from a generator seeded with `seed`: the global
# `random` module is only seeded in the next cell, so `random.choice` here would
# pick a different camera on every kernel restart once several ids are listed.
cam_id = random.Random(seed).choice(context_camera_ids)

In [None]:
# Seed the RNGs through the project helper so the rollouts below are repeatable
# (presumably Python/NumPy/torch — see utils.set_seed_everywhere to confirm).
utils.set_seed_everywhere(seed)

## Loading the trained models

In [None]:
# Checkpoints are looked up by `task_name`, so changing the task above loads the
# matching expert and context translator. A hard-coded path would silently reuse
# the reacher_hard models instead.
expert: drqv2.DrQV2Agent = drqv2.DrQV2Agent.load(f'experts/{task_name}.pt')
expert.train(training=False)  # put the expert policy in evaluation mode

context_translator: ct_model.CTNet = ct_model.CTNet.load(f'ct/{task_name}.pt').to(utils.device())
context_translator.eval()  # inference mode for the translator network

## Loading and wrapping the environment

In [None]:
# Environment the expert acts in, built from the hyperparameters above and the
# customised task XML, with episodes of length `episode_len`.
expert_env = dmc.make(
    task_name,
    expert_frame_stack,
    action_repeat,
    seed,
    xml_path,
    episode_len=episode_len,
)
# Applied around every render call below via `utils.change_context`.
context_changer = context_changers.ReacherHardContextChanger()

## Expert video recording

In [None]:
# Roll out the expert (eval mode, no gradients). Record the initial frame plus
# one frame per step, rendered from `cam_id` inside `context_changer`'s context.
# `expert_env.physics` is presumably a dm_control Physics, whose render() takes
# (height, width, ...) positionally -- keywords keep im_w/im_h from being swapped.
source_video = []
with torch.no_grad():

    time_step = expert_env.reset()
    with utils.change_context(expert_env, context_changer):
        source_video.append(expert_env.physics.render(height=im_h, width=im_w, camera_id=cam_id))
    while not time_step.last():
        action = expert.act(time_step.observation, 1, eval_mode=True)
        time_step = expert_env.step(action)
        with utils.change_context(expert_env, context_changer):
            source_video.append(expert_env.physics.render(height=im_h, width=im_w, camera_id=cam_id))

# Stack into one array of frames (channels last, as the later transposes expect).
source_video = np.array(source_video)

In [None]:
num_frames = 6  # number of frames shown per video strip
# squeeze=False keeps `axes` 2-D, so indexing also works when num_frames == 1.
_, axes = plt.subplots(nrows=1, ncols=num_frames, figsize=(30, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(source_video[i * 4])  # every 4th frame of the source video

plt.show()

## Generation of the predicted video

In [None]:
# Start a new episode and reset the context changer (presumably sampling the
# learner's context -- confirm in context_changers).
time_step = expert_env.reset()
context_changer.reset()

# First learner observation from the learner camera, converted to channels-first.
# Keyword arguments match the presumed dm_control render(height, width, ...) order.
with utils.change_context(expert_env, context_changer):
    fobs = expert_env.physics.render(height=im_h, width=im_w, camera_id=learner_camera_id).copy().transpose((2, 0, 1))
fobs = torch.tensor(fobs, device=utils.device(), dtype=torch.float)
# Expert demonstration as a (T, C, H, W) float tensor.
expert_video = torch.tensor(source_video.transpose((0, 3, 1, 2)), device=utils.device(), dtype=torch.float)

# Inference only: no_grad avoids building an autograd graph over the whole video.
with torch.no_grad():
    state, frame = context_translator.translate(expert_video, fobs)
# Back to channels-last integer frames for plotting / video writing.
predicted_video = frame.int().detach().cpu().numpy().transpose((0, 2, 3, 1))

In [None]:
# squeeze=False keeps `axes` 2-D, so indexing also works when num_frames == 1.
_, axes = plt.subplots(nrows=1, ncols=num_frames, figsize=(30, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(predicted_video[i * 4])  # every 4th predicted frame

plt.show()

## Building the target video

In [None]:
# Continue the episode reset in the previous cell: the same state that produced
# `fobs`, since nothing steps the env in between. Record the ground-truth frames
# under the context set up there.
# Do not re-run this cell on its own: `time_step` would already be last and only
# a single frame would be recorded.
# Keyword arguments match the presumed dm_control render(height, width, ...) order.
target_video = []
with torch.no_grad():
    with utils.change_context(expert_env, context_changer):
        # NOTE(review): rendered from cam_id, while `fobs` used learner_camera_id.
        # Both are 0 today -- confirm which camera the target should use.
        target_video.append(expert_env.physics.render(height=im_h, width=im_w, camera_id=cam_id))
    while not time_step.last():
        action = expert.act(time_step.observation, 1, eval_mode=True)
        time_step = expert_env.step(action)
        with utils.change_context(expert_env, context_changer):
            target_video.append(expert_env.physics.render(height=im_h, width=im_w, camera_id=cam_id))

target_video = np.array(target_video)

In [None]:
# squeeze=False keeps `axes` 2-D, so indexing also works when num_frames == 1.
_, axes = plt.subplots(nrows=1, ncols=num_frames, figsize=(30, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(target_video[i * 4])  # every 4th frame of the target video

plt.show()

In [None]:
np.shape(source_video)  # frames x H x W x channels (channels last)

In [None]:
np.shape(predicted_video)  # should match source_video's shape

In [None]:
np.shape(target_video)  # should match source_video's shape

In [None]:
# Side-by-side demo frames: [source | predicted | target] along the width axis.
# The result is uint8 so imageio writes pixel values unchanged. A float canvas
# makes imageio rescale each frame by its own min/max, which causes contrast
# flicker. Predicted frames come from a network, so clip them to the valid range.
if not (len(source_video) == len(predicted_video) == len(target_video)):
    raise ValueError(
        f'Video lengths differ: source={len(source_video)}, '
        f'predicted={len(predicted_video)}, target={len(target_video)}'
    )
all_video = np.concatenate(
    (
        source_video.astype(np.uint8),
        np.clip(predicted_video, 0, 255).astype(np.uint8),
        target_video.astype(np.uint8),
    ),
    axis=2,  # width axis of (frames, H, W, channels); no hard-coded 64-px slices
)

## Generation of the final demo video

The video is written to `demo/demo_ct.mp4`.

In [None]:
import os

# Write the composite demo video. Create the output directory first, because
# writing the file fails if `demo/` does not exist.
os.makedirs('demo', exist_ok=True)
imageio.mimwrite('demo/demo_ct.mp4', all_video, format='mp4', fps=24)