In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from virtual_rodent.environment import MAPPER
from virtual_rodent import VISION_DIM, PROPRI_DIM, ACTION_DIM
from virtual_rodent.network.vision_enc import ResNet18Enc
from virtual_rodent.network.propri_enc import MLPEnc
import virtual_rodent.network.Merel2019 as Merel2019
from virtual_rodent.utils import load_checkpoint

In [3]:
# --- Vision branch: ResNet-18 encoder and the embedding size it yields for VISION_DIM inputs.
vision_enc = ResNet18Enc()
vision_emb_dim = vision_enc.get_emb_dim(VISION_DIM)

# --- Proprioception branch: MLP with a single 50-unit hidden layer.
# presumably maps PROPRI_DIM[0] input features to a propri_emb_dim-sized embedding (arg order assumed in, out).
propri_emb_dim = 20
propri_enc = MLPEnc(PROPRI_DIM[0], propri_emb_dim, hidden_dims=(50,))

# --- Critic: input width is the sum of both embedding widths.
critic_in_dim = vision_emb_dim + propri_emb_dim
critic = Merel2019.Critic(critic_in_dim)

# --- Actor: input width additionally includes the raw proprioceptive size
# and the critic's hidden size.
actor_in_dim = critic_in_dim + PROPRI_DIM[0] + critic.hidden_dim
actor = Merel2019.Actor(actor_in_dim, ACTION_DIM, logit_scale=1)

# NOTE: construction order is kept as-is, since it may affect random weight initialisation.
model = Merel2019.MerelModel(
    vision_enc, propri_enc, VISION_DIM, PROPRI_DIM,
    actor, critic, ACTION_DIM,
)

In [4]:
# Optionally restore trained weights. Leave as None to run with the freshly
# initialised model; set to a checkpoint path (e.g. './results/weights1000.pth')
# to load it instead.
weights_path = None
if weights_path is not None:
    # weights_only=True restricts unpickling to tensors/primitive containers.
    state_dict = torch.load(weights_path, weights_only=True)
    model.load_state_dict(state_dict)

In [58]:
# Task to build from MAPPER (the other option tried here is 'maze').
env_name = 'gaps'
env_factory = MAPPER[env_name]
env, propri_attr = env_factory(physics_dt=0.002, ctrl_dt=0.02)

In [77]:
import time
from virtual_rodent.simulation import get_vision, get_propri
def simulate(env, model, propri_attr, max_step, device, reset=True, time_step=None,
             ext_cam=(0,), ext_cam_size=(200, 200)):
    """Roll out `model` in `env` for up to `max_step` steps.

    The rollout stops early as soon as the current time step reports
    `time_step.last()`.

    Args:
        env: environment exposing `reset()`, `step(action)`, `action_spec()`
            and `physics.render(camera_id, height, width)` (dm_control-style
            interface assumed).
        model: policy called as `model(vision=..., propri=...)` that returns
            `(value, (action, log_prob, _))`. If it has a `reset_rnn` method,
            that method is called when `reset` is True.
        propri_attr: passed through to `get_propri` to build the
            proprioceptive observation.
        max_step: maximum number of environment steps to take.
        device: torch device the observation tensors are moved to.
        reset: if True, reset the environment (and the model RNN) first.
        time_step: starting time step; required when `reset` is False.
        ext_cam: camera ids to render after every step.
        ext_cam_size: (height, width) of each rendered camera frame.

    Returns:
        dict with per-step lists under 'vision', 'propri', 'action', 'reward',
        'log_prob', 'value' and 'cam{i}' for each id in `ext_cam`. A 'touch'
        key is created but not filled by this function. Also contains
        'time' (wall-clock seconds) and 'T' (number of steps recorded).

    Raises:
        ValueError: if `reset` is False and `time_step` is None.
    """
    start_time = time.time()

    returns = dict(vision=[], propri=[], action=[], reward=[], log_prob=[], value=[], touch=[])
    returns.update({f'cam{i}': [] for i in ext_cam})

    if reset:
        time_step = env.reset()
        if hasattr(model, 'reset_rnn'):
            model.reset_rnn()
    elif time_step is None:
        raise ValueError('`time_step` must be given if not reset.')

    action_spec = env.action_spec()

    for _step in range(max_step):
        if time_step.last():
            break
        # Observation at the current time step.
        vision = torch.from_numpy(get_vision(time_step)).to(device)
        propri = torch.from_numpy(get_propri(time_step, propri_attr)).to(device)

        value, (action, log_prob, _) = model(vision=vision, propri=propri)

        # Clip to the environment's action bounds before stepping.
        time_step = env.step(np.clip(action.detach().cpu().squeeze().numpy(),
                                     action_spec.minimum, action_spec.maximum))

        # Record the observation/action fed to the model and the reward
        # returned by the step that action produced.
        returns['vision'].append(vision)
        returns['propri'].append(propri)
        returns['action'].append(action)
        returns['reward'].append(torch.tensor(time_step.reward))
        returns['log_prob'].append(log_prob)
        returns['value'].append(value)
        for i in ext_cam:
            cam = env.physics.render(camera_id=i,
                    height=ext_cam_size[0], width=ext_cam_size[1])
            returns[f'cam{i}'].append(cam)

    returns['time'] = time.time() - start_time
    # Count of recorded steps. Previously `step` was used, which was undefined
    # for max_step == 0 and one short when the loop ran to completion.
    returns['T'] = len(returns['action'])
    return returns

In [81]:
from virtual_rodent.visualization import video
# from virtual_rodent.simulation import simulate
# Demo settings: cameras to render, output folder, frame size and rollout length.
ext_cam = (0,)
save_dir = './'
ext_cam_size = (200, 200)
max_step = 100

# Inference only: no autograd graph is needed for the rollout.
with torch.no_grad():
    ret = simulate(env, model, propri_attr, max_step=max_step, device=torch.device('cpu'),
                   ext_cam=ext_cam, ext_cam_size=ext_cam_size)

# Save one GIF per external camera.
for i in ext_cam:
    anim = video(ret[f'cam{i}'])
    fname = f'demo_{env_name}_cam{i}.gif'
    anim.save(os.path.join(save_dir, fname), writer='pillow')