In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from virtual_rodent.environment import MAPPER
from virtual_rodent import VISION_DIM, PROPRI_DIM, ACTION_DIM
from virtual_rodent.network.vision_enc import ResNet18Enc
from virtual_rodent.network.propri_enc import MLPEnc
import virtual_rodent.network.Merel2019 as Merel2019
from virtual_rodent.utils import load_checkpoint

In [3]:
# Assemble the encoders, the critic, the actor and the combined MerelModel.

# Vision encoder; its embedding width is obtained from get_emb_dim(VISION_DIM).
vision_enc = ResNet18Enc()
vision_emb_dim = vision_enc.get_emb_dim(VISION_DIM)

# Proprioception encoder built from PROPRI_DIM[0], the embedding width below and a
# single hidden width of 50.
propri_emb_dim = 20 # embedding width used for proprioception (also added into critic_in_dim)
propri_enc = MLPEnc(PROPRI_DIM[0], propri_emb_dim, hidden_dims=(50,))

# Critic input width = vision embedding width + proprioception embedding width.
# NOTE(review): presumably the two embeddings are concatenated inside the model — confirm.
critic_in_dim = vision_emb_dim + propri_emb_dim
critic = Merel2019.Critic(critic_in_dim)

# Actor input width = critic input width + PROPRI_DIM[0] + the critic's hidden_dim.
actor_in_dim = critic_in_dim + PROPRI_DIM[0] + critic.hidden_dim
actor = Merel2019.Actor(actor_in_dim, ACTION_DIM, logit_scale=1)

# Combined model wrapping both encoders and both heads.
model = Merel2019.MerelModel(vision_enc, propri_enc, VISION_DIM, PROPRI_DIM, 
                             actor, critic, ACTION_DIM) 

In [4]:
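# NOTE: checkpoint loading is currently disabled. Unless weights are loaded
# elsewhere, the cells below run the model with the parameters it was
# constructed with. Uncomment the two lines below to restore saved weights.
# state_dict = torch.load('./results/weights600.pth', weights_only=True)
# model.load_state_dict(state_dict)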

In [5]:
# Set the actor's `mode` attribute to 'test' before running the rollout.
# NOTE(review): presumably this switches the actor to evaluation-time action
# selection; the rollout cell still passes train=True to simulate() — confirm
# that the two settings are meant to be combined.
model.actor.mode = 'test'

In [6]:
# Name of the environment to build; it is used as the key into MAPPER and also
# ends up in the output GIF file name.
env_name = 'gaps'
# env_name = 'maze'
# The entry selected from MAPPER is called to produce the environment together
# with `propri_attr`, which is later handed to simulate().
# NOTE(review): physics_dt / ctrl_dt are presumably time steps in seconds — confirm.
env, propri_attr = MAPPER[env_name](physics_dt=0.002, ctrl_dt=0.02)

In [7]:
from virtual_rodent.visualization import video
from virtual_rodent.simulation import simulate

# --- Rollout / export settings ---------------------------------------------
save_dir = './'            # directory the GIFs are written to
ext_cam = (0,)             # external camera ids to record and export
# NOTE(review): ext_cam_size is defined but not used in this cell — confirm
# whether it was meant to be forwarded to simulate().
ext_cam_size = (200, 200)

# --- Run the model in the environment without tracking gradients ------------
# NOTE(review): train=True is passed even though gradients are disabled here;
# verify this is the intended flag for a demo rollout.
with torch.no_grad():
    ret = simulate(
        env, model, propri_attr,
        max_step=100,
        device=torch.device('cpu'),
        ext_cam=ext_cam,
        train=True,
    )

# --- Export one GIF per recorded camera --------------------------------------
# Frames are read from ret['cam<i>'], wrapped by video(), and saved through the
# 'pillow' writer under save_dir.
for i in ext_cam:
    fname = f'demo_{env_name}_cam{i}.gif'
    anim = video(ret[f'cam{i}'])
    anim.save(os.path.join(save_dir, fname), writer='pillow')

In [8]:
# Display the actions recorded during the rollout. In the output shown below this
# is a list of tensors, each printed with three nesting levels and 15 values.
# NOTE(review): this dumps every entry; consider slicing (e.g. ret['action'][:5])
# to keep the notebook output short.
ret['action']

[tensor([[[ 0.6013,  0.6785,  0.1034,  0.0672, -0.0227, -0.3955, -0.3123,
            0.0903, -0.7544,  0.2798,  0.6499, -0.4419, -0.0799,  0.7046,
           -1.1489]]]),
 tensor([[[-0.4169,  0.0491,  0.2552,  0.3856,  0.5906,  0.8359,  0.5504,
           -0.0274, -0.4384, -0.0359, -0.6327, -0.5479,  0.3698,  0.1866,
           -0.2980]]]),
 tensor([[[ 0.7486,  0.1615,  0.8470,  0.3633,  0.7092, -0.3759, -0.9847,
            0.2784, -0.3791,  0.6106,  0.1590,  0.5559,  0.5292, -0.3293,
           -0.4418]]]),
 tensor([[[ 0.0138, -0.5064, -1.0270,  0.4625,  0.1696,  0.7072, -1.5648,
            0.9563, -0.1411,  0.4229, -0.6984, -0.4092, -0.2352,  0.4008,
           -0.7544]]]),
 tensor([[[ 0.3257,  0.0396, -0.2481, -0.6394, -0.3637, -0.2602,  0.1436,
           -0.8160, -0.2955,  0.8312,  0.2237,  0.0135, -0.2202,  1.2267,
            0.1286]]]),
 tensor([[[-0.0500,  0.1770, -1.1487,  0.2805,  0.3228,  0.1161,  0.2158,
            0.6621, -0.1464,  0.4529, -0.1170, -0.5696, -0.8591, -