In [1]:
import os, sys
sys.path.append("..")
import torch
from torch.nn import functional as F

from mune import networks 

### Parameters

In [2]:
proprio_sensors_dim = 4
action_dim = 4
vision_dim = [3, 64, 64]
rnn_hidden_dim = 200
stoch_dim = 30

### Define networks

In [3]:
proprio_encoder = networks.ProprioEncoder(in_features=proprio_sensors_dim)
vision_encoder = networks.VisionEncoder()
fusion_layer = networks.FusionLayer(in_features=proprio_encoder.output_size+vision_encoder.output_size)

In [4]:
rssm = networks.Rssm(emb_dim=fusion_layer.output_size,
                     action_dim=action_dim,
                     rnn_hidden_dim=rnn_hidden_dim,
                    stoch_dim=stoch_dim)

In [5]:
proprio_decoder = networks.ProprioDecoder(determ_state_dim = rnn_hidden_dim, 
                                          stoch_state_dim = stoch_dim,
                                          out_features=proprio_sensors_dim)

In [6]:
vision_decoder = networks.VisionDecoder(determ_state_dim = rnn_hidden_dim, 
                                        stoch_state_dim = stoch_dim)

### Define inputs

In [7]:
x_proprio = torch.randn((1, proprio_sensors_dim))
x_vision = torch.randn([1] + vision_dim)
action = torch.randn((1, action_dim))
rnn_hidden = torch.randn((1, rnn_hidden_dim))
post_state = torch.randn((1, stoch_dim))

### Encode observation into features

In [8]:
h_proprio = proprio_encoder(x_proprio)
h_vision = vision_encoder(x_vision)
fused_state = fusion_layer(h_proprio, h_vision)

print("Fused state size : ", fused_state.size())

Fused state size :  torch.Size([1, 256])


### One-step imagination

#### Future prediction

In [9]:
preds = rssm(embed_state=fused_state, 
                               prev_action=action, 
                               prev_post_state=post_state, 
                               prev_hidden_state=rnn_hidden)

In [11]:
preds

{'determ_state': tensor([[ 3.1697e-01, -2.1390e-01,  1.0233e+00,  1.2514e-01, -3.1609e-01,
          -6.9335e-01,  3.3119e-01, -1.8346e-01, -8.0287e-01, -1.0589e-01,
           5.0827e-01, -2.2856e-01,  5.7494e-03,  1.0625e-01, -1.0246e+00,
           2.2749e-02, -1.9430e-01,  5.4993e-02, -3.0332e-01, -5.4308e-02,
          -3.2555e-01, -4.2091e-01, -4.3671e-02,  7.1450e-01, -1.2277e+00,
           1.8485e-01, -4.9815e-01,  3.2795e-01, -3.6535e-01, -1.1720e+00,
          -9.5449e-01,  1.1918e-01, -1.4111e-01,  1.2857e-01, -8.4649e-02,
           6.2601e-01,  5.1609e-01,  2.1296e-02, -2.5004e-01, -2.8920e-01,
           1.9056e-01,  5.0068e-01,  1.3280e+00,  1.4539e-01, -4.7477e-01,
           4.5976e-01,  1.8784e-01, -1.0351e+00, -3.9315e-01,  6.8481e-01,
           1.8137e-02, -1.3800e+00,  8.1879e-01, -1.7342e-01, -5.0790e-01,
          -3.4193e-01,  5.8718e-02, -5.3629e-02,  1.0432e+00, -2.5643e-02,
          -1.2338e-01,  1.8291e+00,  1.8183e-01, -3.4192e-02, -1.0693e+00,
         

In [12]:
determ_state = preds["determ_state"]
stoch_state_prior = preds["stoch_state_prior"]
stoch_state_posterior = preds["stoch_state_posterior"]

#### Reconstruction

In [13]:
recon_proprio = proprio_decoder(determ_state, stoch_state_posterior['stoch_state'])
recon_vision = vision_decoder(determ_state, stoch_state_posterior['stoch_state'])

print("Proprioception reconstruction size : ", recon_proprio.size())
print("Vision reconstruction size : ", recon_vision.size())

Proprioception reconstruction size :  torch.Size([1, 4])
Vision reconstruction size :  torch.Size([1, 3, 64, 64])


### Losses

reconstruction vision : sum over image, then mean over batch / time channel

In [14]:
recon_vision_loss = F.mse_loss(recon_vision, x_vision, reduction='none').sum((1, 2, 3)).mean(0)

In [15]:
recon_vision_loss

tensor(12425.9229, grad_fn=<MeanBackward1>)

In [16]:
recon_proprio_loss = F.mse_loss(recon_proprio, x_proprio, reduction='none').sum(1).mean(0)

In [17]:
recon_proprio_loss

tensor(6.1588, grad_fn=<MeanBackward1>)