## Dependencies

In [None]:
import random

import torch

import context_changers
import ct_model
import dmc
import drqv2
import utils
import numpy as np
import rl_model

import imageio
from matplotlib import pyplot as plt

## Hyperparameters

In [None]:
# --- Task / environment -----------------------------------------------------
task_name = 'reacher_hard'
xml_path = 'domain_xmls/reacher.xml'
seed = 432335
episode_len = 30
action_repeat = 2

# --- Frame stacking: the expert and the learner use different stacks --------
expert_frame_stack = 3
frame_stack = 1

# --- Cameras and rendering ---------------------------------------------------
context_camera_ids = [0]  # cameras the expert videos may be rendered from
learner_camera_id = 0
im_w, im_h = 64, 64
n_video = 64  # number of expert videos rolled out and averaged per prediction

# NOTE(review): this draw happens before `utils.set_seed_everywhere(seed)` runs,
# so it is not covered by the seed; make_expert_video draws its own camera per
# episode, so this value looks unused.
cam_id = random.choice(context_camera_ids)

In [None]:
# Seed the project's RNGs before any model loading or rollout; the helper
# presumably covers Python `random`, NumPy and torch (confirm in utils).
# NOTE(review): `cam_id` in the hyperparameter cell is drawn before this call.
utils.set_seed_everywhere(seed)

## Loading the trained models

In [None]:
# Checkpoints follow a '<kind>/<task_name>.pt' layout. Deriving the paths from
# `task_name` keeps the loaded models in sync with the task hyperparameter
# instead of silently loading 'reacher_hard' weights for another task.
expert: drqv2.DrQV2Agent = drqv2.DrQV2Agent.load(f'experts/{task_name}.pt')
expert.train(training=False)  # presumably switches the expert to inference mode

# Context translator: maps expert videos into the learner's context.
context_translator: ct_model.CTNet = ct_model.CTNet.load(f'ct/{task_name}.pt').to(utils.device())
context_translator.eval()

# Learner policy; it is put in eval mode per step in the rollout cell.
agent: rl_model.ACAgent = rl_model.ACAgent.load(f'ac/{task_name}.pt').to(utils.device())

## Loading and wrapping the environments

In [None]:
# Expert environment: stacks `expert_frame_stack` observations and is the one
# rolled out by make_expert_video.
expert_env = dmc.make(
    task_name,
    expert_frame_stack,
    action_repeat,
    seed,
    xml_path,
    episode_len=episode_len,
)
# Context changer applied when rendering the expert videos.
context_changer = context_changers.ReacherHardContextChanger()

# Learner evaluation environment: its own seed (seed + 1), the learner camera,
# im_w x im_h observations and a separate context-changer instance.
eval_env = dmc.make(
    task_name,
    frame_stack,
    action_repeat,
    seed + 1,
    xml_path,
    learner_camera_id,
    im_w,
    im_h,
    context_changers.ReacherHardContextChanger(),
    episode_len,
)

In [None]:
def make_expert_video():
    """Roll out the expert `n_video` times and return the rendered episodes.

    For each episode the global `context_changer` is reset, a camera is drawn
    from `context_camera_ids`, and every frame (including the reset frame) is
    rendered at `im_w` x `im_h` inside `utils.change_context`.

    Returns:
        np.uint8 array laid out as n_video x T x c x h x w.
    """
    def render_frame(camera):
        # Render with the context change active; the expert still acts on the
        # unmodified time-step observation.
        with utils.change_context(expert_env, context_changer):
            return expert_env.physics.render(im_w, im_h, camera_id=camera)

    with torch.no_grad():
        collected = []
        for _ in range(n_video):
            context_changer.reset()
            camera = random.choice(context_camera_ids)
            step = expert_env.reset()
            frames = [render_frame(camera)]
            while not step.last():
                step = expert_env.step(expert.act(step.observation, 1, eval_mode=True))
                frames.append(render_frame(camera))
            collected.append(frames)

        stacked = np.array(collected, dtype=np.uint8)  # n_video x T x h x w x c
        return stacked.transpose((0, 1, 4, 2, 3))     # n_video x T x c x h x w

def predict_avg_states_frames(fobs):
    """Translate freshly rendered expert videos into the learner's context and average them.

    Args:
        fobs: a learner-side observation; it is passed as the context argument
            to `context_translator.translate` for every expert video.

    Returns:
        (expert_videos, avg_states, avg_frames):
        - expert_videos: the expert videos as a float tensor on `utils.device()`.
        - avg_states: per-timestep mean of the translated states over videos
          (T x z numpy array).
        - avg_frames: per-timestep mean of the translated frames over videos
          (T x c x h x w numpy array).
    """
    rendered = make_expert_video()
    with torch.no_grad():
        context = torch.tensor(fobs, device=utils.device(), dtype=torch.float)
        expert_videos = torch.tensor(rendered, device=utils.device(), dtype=torch.float)

        translated = [
            context_translator.translate(video, context, keep_enc2=True)
            for video in expert_videos
        ]
        states = torch.stack([s for s, _ in translated])  # n x T x z
        frames = torch.stack([f for _, f in translated])  # n x T x c x h x w

        mean_states = states.mean(dim=0)  # T x z
        mean_frames = frames.mean(dim=0)  # T x c x h x w
        # Debug diagnostic: summed per-element variance of the translated
        # frames across the n expert videos.
        print(frames.flatten(start_dim=1).var(dim=0).sum())

    return expert_videos, mean_states.cpu().numpy(), mean_frames.cpu().numpy()


def change_step_observation(time_step, target_state, target_frame):
    """Return `time_step` with its observation replaced by [encoded obs, target_state].

    The current observation is batched, encoded by `context_translator.encode`,
    and the target state is appended to the encoding. `target_frame` is
    currently unused.
    """
    with torch.no_grad():
        batch = torch.tensor(
            time_step.observation, device=utils.device(), dtype=torch.float
        ).unsqueeze(0)
        encoded = context_translator.encode(batch)[0].cpu().numpy()
    return time_step._replace(observation=np.concatenate([encoded, target_state]))

## Building the agent video

In [None]:
# Roll out the learner once in `eval_env`. At step t the agent is conditioned on
# the averaged translated expert state for step t + 1 (held at the last
# predicted step once the predictions run out). Every observation frame is kept
# for the montage.
agent_video = []

time_step = eval_env.reset()
frame = time_step.observation

# Predict the expert trajectory in the learner's context from the reset frame.
expert_videos, avg_states, avg_frames = predict_avg_states_frames(frame)
target_state = avg_states[1]
target_frame = avg_frames[1]
time_step = change_step_observation(time_step, target_state, target_frame)
episode_step = 0
agent_video.append(frame.transpose((1, 2, 0)))  # channels-first -> channels-last

while not time_step.last():
    with torch.no_grad(), utils.eval_mode(agent):
        state = torch.tensor(time_step.observation, device=utils.device(), dtype=torch.float)
        action = agent.act(state, 1, eval_mode=True)

    time_step = eval_env.step(action)
    episode_step += 1

    # Advance the target while predictions remain; otherwise keep the last one.
    if episode_step + 1 < len(avg_states):
        target_state, target_frame = avg_states[episode_step + 1], avg_frames[episode_step + 1]

    agent_video.append(time_step.observation.transpose((1, 2, 0)))
    time_step = change_step_observation(time_step, target_state, target_frame)


In [None]:
# Bring the three streams to a common T x h x w x c layout for the montage.
agent_video = np.array(agent_video)
source_video = np.moveaxis(expert_videos[1].cpu().numpy(), 1, -1)  # expert video #1
predicted_video = np.moveaxis(avg_frames, 1, -1)

In [None]:
# Display the agent video shape (expected T x h x w x c).
np.shape(agent_video)

In [None]:
# Side-by-side montage: expert source | predicted translation | agent rollout.
# Each panel is as wide as the source frames, so the layout follows the actual
# frame size instead of a hard-coded 64 px.
# The montage is built as uint8. Float input with values outside [0, 1] is
# min/max-rescaled by imageio when it converts to uint8, which distorts colours.
# NOTE(review): assumes the predicted frames share the 0-255 scale of the source.
n_frames, frame_h, frame_w, n_channels = source_video.shape
all_video = np.zeros((n_frames, frame_h, frame_w * 3, n_channels), dtype=np.uint8)

for panel, video in enumerate((source_video, predicted_video, agent_video)):
    # Round and clip so averaged (float) predictions map onto valid pixel values.
    all_video[:, :, panel * frame_w:(panel + 1) * frame_w, :] = np.clip(np.rint(video), 0, 255)

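## Generating the final video

The video is written to `demo/demo_ifo.mp4`. Each frame shows, from left to right, the expert source video, the predicted (translated) video, and the agent's rollout.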

In [None]:
import os

# Write the montage. Create the output directory first so that a checkout
# without `demo/` does not make the writer fail.
output_path = 'demo/demo_ifo.mp4'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
imageio.mimwrite(output_path, all_video, format='mp4', fps=24)