In [1]:
# Standard-library modules; `sys` is needed for the search-path tweak below.
import os, sys
# Add the parent directory to the module search path. This has to run before
# the `mune` import further down (presumably the package lives one level above
# this notebook — verify against the repository layout).
sys.path.append("..")
import torch

# Project module providing the encoder, fusion, RSSM and decoder classes
# instantiated in the following cells.
from mune import networks 

### Parameters

In [2]:
# Seed PyTorch's global RNG before anything random happens, so that the
# `torch.randn` demo inputs (and any random weight initialisation done by the
# network constructors) are repeatable across "Restart Kernel -> Run All".
random_seed = 0
torch.manual_seed(random_seed)

# Sizes shared by the network constructors and the dummy inputs.
proprio_sensors_dim = 4    # width of the proprioceptive input vector
action_dim = 4             # width of the action vector
vision_dim = [3, 64, 64]   # image shape without the batch axis (presumably C, H, W — confirm)
rnn_hidden_dim = 200       # width of the recurrent hidden state passed to the RSSM

### Define networks

In [3]:
# One encoder per input modality.
proprio_encoder = networks.ProprioEncoder(in_features=proprio_sensors_dim)
vision_encoder = networks.VisionEncoder()

# The fusion layer's input width is the sum of both encoders' output widths.
fusion_in_features = proprio_encoder.output_size + vision_encoder.output_size
fusion_layer = networks.FusionLayer(in_features=fusion_in_features)

In [4]:
# The RSSM's state width is tied to the fusion layer's output width; the
# action and hidden-state widths come from the parameters cell.
rssm = networks.Rssm(
    state_dim=fusion_layer.output_size,
    action_dim=action_dim,
    rnn_hidden_dim=rnn_hidden_dim,
)

In [5]:
# Decoder whose output width equals the proprioceptive input width.
proprio_decoder = networks.ProprioDecoder(
    state_dim=fusion_layer.output_size,
    rnn_hidden_dim=rnn_hidden_dim,
    out_features=proprio_sensors_dim,
)

In [6]:
# Decoder for the vision modality; it takes the same state / hidden widths
# as the proprioception decoder.
vision_decoder = networks.VisionDecoder(
    state_dim=fusion_layer.output_size,
    rnn_hidden_dim=rnn_hidden_dim,
)

### Define inputs

In [7]:
# Random stand-ins for one observation, one action and one hidden state; every
# tensor has a leading axis of size 1. The draw order is kept as-is so the RNG
# stream is consumed in the same sequence.
x_proprio = torch.randn(1, proprio_sensors_dim)
x_vision = torch.randn(1, *vision_dim)
action = torch.randn(1, action_dim)
rnn_hidden = torch.randn(1, rnn_hidden_dim)

### Encode observation into features

In [8]:
# Encode each modality on its own ...
h_proprio = proprio_encoder(x_proprio)
h_vision = vision_encoder(x_vision)
# ... then hand both feature tensors to the fusion layer.
fused_state = fusion_layer(h_proprio, h_vision)

print("Fused state size : ", fused_state.shape)

Fused state size :  torch.Size([1, 256])


### One-step imagination

#### Future prediction

In [9]:
# One call to the RSSM prior. Note that `rnn_hidden` is overwritten with the
# value returned by the call, so re-running this cell feeds the previously
# returned hidden state back in instead of the initial one.
pred_output, rnn_hidden = rssm.prior(
    state=fused_state,
    action=action,
    rnn_hidden=rnn_hidden,
)

print("Prediction output size : ", pred_output.shape)
print("Rnn hidden size : ", rnn_hidden.shape)

Prediction output size :  torch.Size([1, 200])
Rnn hidden size :  torch.Size([1, 200])


#### Reconstruction

In [10]:
# Both decoders receive the same pair: the fused state and the prior's output.
decoder_inputs = (fused_state, pred_output)
recon_proprio = proprio_decoder(*decoder_inputs)
recon_vision = vision_decoder(*decoder_inputs)

print("Proprioception reconstruction size : ", recon_proprio.shape)
print("Vision reconstruction size : ", recon_vision.shape)

Proprioception reconstruction size :  torch.Size([1, 4])
Vision reconstruction size :  torch.Size([1, 3, 64, 64])
