In [1]:
import sys, os
sys.path.append("..")
import torch

from mune.networks import AgentModel

### config

In [2]:
modalities_config = {
    "vision": {
        "type": "vision"
    },
    "proprio": {
        "type": "proprio",
        "in_features": 4
    }
}

In [3]:
config = {
    "action_output_dim": 4,
    "modalities_config": modalities_config,
    "determ_state_dim": 200,
    "stoch_state_dim": 30,
    "min_stddev": 0.1,
    "reward_hidden_dim": 100,
    "reward_n_layers": 2,
    "value_hidden_dim": 100,
    "value_n_layers": 2,
    "action_hidden_dim": 100,
    "action_n_layers": 2
}

In [4]:
agent_model = AgentModel(**config)

In [5]:
agent_model

AgentModel(
  (encoder_modalities): ModuleDict(
    (vision): VisionEncoder(
      (conv_encoder): Sequential(
        (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
        (relu1): ReLU()
        (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
        (relu2): ReLU()
        (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
        (relu3): ReLU()
        (conv4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
        (relu4): ReLU()
      )
    )
    (proprio): ProprioEncoder(
      (fc): Linear(in_features=4, out_features=32, bias=True)
    )
  )
  (decoder_modalities): ModuleDict(
    (vision): VisionDecoder(
      (fc_deter_stoch): Linear(in_features=230, out_features=1024, bias=True)
      (convt_decoder): Sequential(
        (convt1): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
        (relu1): ReLU()
        (convt2): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
        (relu2): ReLU()
        (convt3): Conv

In [6]:
observations = {
    "vision": torch.randn((1, 3, 64, 64)),
    "proprio": torch.randn((1, 4))
}

prev_action = torch.zeros(1, 4)
prev_action[:, 1] = 1

prev_determ_state = torch.randn((1, 200))
prev_stoch_state = torch.randn((1, 30))

rollout_values = []
for step in range(12):
    output = agent_model(observations=observations, 
                         prev_action=prev_action, 
                         prev_determ_state=prev_determ_state, 
                         prev_stoch_state=prev_stoch_state)
    rollout_values.append(output)
    prev_action = output['pred_action'].logits
    prev_determ_state = output['transitions']['determ_state']
    prev_stoch_state = output['transitions']['stoch_state_posterior']['stoch_state']

In [7]:
output.keys()

dict_keys(['pred_reward', 'pred_action', 'pred_value', 'pred_observations', 'transitions'])

In [8]:
output['pred_action'].logits

tensor([[-1.3598, -1.2564, -1.4358, -1.5111]], grad_fn=<SubBackward0>)

In [9]:
output['transitions']

{'determ_state': tensor([[-0.1610,  0.2567, -0.1118,  0.2478, -0.1000, -0.3470, -0.1523,  0.2782,
           0.2418,  0.0108,  0.0677, -0.1351, -0.0194,  0.0990,  0.1326, -0.0329,
           0.1082,  0.2206,  0.1680,  0.0755, -0.2854,  0.0392,  0.1008, -0.0412,
           0.4411,  0.0633, -0.4529,  0.3051, -0.3048,  0.2041,  0.0408,  0.0260,
          -0.2129, -0.0600,  0.0670, -0.1764,  0.1989,  0.1122, -0.0139, -0.0710,
           0.3072, -0.3006, -0.0211, -0.0022,  0.0881,  0.0523, -0.3669, -0.1972,
          -0.1646,  0.0779, -0.1723,  0.2695, -0.2377, -0.1299,  0.2289,  0.4681,
           0.2451,  0.2806,  0.2071, -0.2395, -0.0756,  0.0377, -0.0299,  0.0622,
           0.1092, -0.1334,  0.2260,  0.2305, -0.1351,  0.0246,  0.2574, -0.1058,
          -0.0774, -0.1229,  0.2091,  0.1196, -0.1252, -0.2136,  0.2892,  0.1243,
          -0.2165, -0.2635,  0.1131,  0.1697, -0.1212,  0.0235, -0.4012, -0.1236,
           0.1458,  0.4456,  0.0113, -0.1152, -0.1768,  0.0511, -0.0257, -0.3544,
