In [20]:
import os, sys
sys.path.append("..")
import torch
from torch.nn import functional as F

from mune import networks 

### Parameters

In [2]:
# Seed the global RNG before any network is built so that weight
# initialisation, the random demo inputs and any sampling the model may
# perform are repeatable on a fresh "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Observation / action sizes.
proprio_sensors_dim = 4   # in/out width of the proprioception encoder / decoder
action_dim = 4            # width of the action vector handed to the RSSM
vision_dim = [3, 64, 64]  # image shape without batch axis; a list, because it is concatenated with [1] later

# RSSM state sizes.
rnn_hidden_dim = 200      # deterministic state width
stoch_dim = 30            # stochastic state width

### Define networks

In [3]:
# One encoder per modality.
proprio_encoder = networks.ProprioEncoder(in_features=proprio_sensors_dim)
vision_encoder = networks.VisionEncoder()

# The fusion layer's input width is the sum of the two encoder output widths.
fusion_in_features = proprio_encoder.output_size + vision_encoder.output_size
fusion_layer = networks.FusionLayer(in_features=fusion_in_features)

In [4]:
# Recurrent state-space model; its embedding width is taken from the fusion
# layer's output size.
rssm = networks.Rssm(
    emb_dim=fusion_layer.output_size,
    action_dim=action_dim,
    rnn_hidden_dim=rnn_hidden_dim,
    stoch_dim=stoch_dim,
)

In [7]:
# Decoder from the (deterministic, stochastic) state pair back to a
# proprioception vector of the original sensor width.
proprio_decoder = networks.ProprioDecoder(
    determ_state_dim=rnn_hidden_dim,
    stoch_state_dim=stoch_dim,
    out_features=proprio_sensors_dim,
)

In [9]:
# Image decoder, configured with the same state widths as the proprio decoder.
vision_decoder = networks.VisionDecoder(
    determ_state_dim=rnn_hidden_dim,
    stoch_state_dim=stoch_dim,
)

### Define inputs

In [10]:
# Random stand-ins for one observation, one action and the previous RSSM
# state. Every tensor has a leading axis of size 1.
x_proprio = torch.randn(1, proprio_sensors_dim)
x_vision = torch.randn(1, *vision_dim)
action = torch.randn(1, action_dim)
rnn_hidden = torch.randn(1, rnn_hidden_dim)
post_state = torch.randn(1, stoch_dim)

### Encode observation into features

In [11]:
# Encode each modality separately, then pass both embeddings to the fusion layer.
h_proprio = proprio_encoder(x_proprio)
h_vision = vision_encoder(x_vision)
fused_state = fusion_layer(h_proprio, h_vision)

print("Fused state size : ", fused_state.shape)

Fused state size :  torch.Size([1, 256])


### One-step imagination

#### Future prediction

In [12]:
# Run the RSSM once on the fused embedding, the previous action and the
# previous posterior / hidden state.
preds = rssm(
    embed_state=fused_state,
    prev_action=action,
    prev_post_state=post_state,
    prev_hidden_state=rnn_hidden,
)

In [15]:
# Pull the two state entries out of the RSSM output mapping.
determ_state, stoch_state = preds["determ_state"], preds["stoch_state"]

#### Reconstruction

In [16]:
# Both decoders receive the same (deterministic, stochastic) state pair.
recon_proprio = proprio_decoder(determ_state, stoch_state)
recon_vision = vision_decoder(determ_state, stoch_state)

print("Proprioception reconstruction size : ", recon_proprio.shape)
print("Vision reconstruction size : ", recon_vision.shape)

Proprioception reconstruction size :  torch.Size([1, 4])
Vision reconstruction size :  torch.Size([1, 3, 64, 64])


### Losses

Vision reconstruction loss: sum the squared error over each image (all channels and pixels), then take the mean over the batch / time dimension.

In [34]:
# Vision reconstruction loss: element-wise squared error, summed over the
# three trailing image axes, then averaged over every remaining leading axis.
# For the (batch, C, H, W) tensors used here this equals the former
# .sum((1, 2, 3)).mean(0); addressing the image axes from the end also keeps
# the reduction valid if a time axis is placed in front of the batch axis.
recon_vision_loss = (
    F.mse_loss(recon_vision, x_vision, reduction='none')
    .sum(dim=(-3, -2, -1))
    .mean()
)

In [36]:
# Bare last expression so the notebook renders the loss tensor as cell output.
recon_vision_loss

tensor(12171.2266, grad_fn=<MeanBackward1>)