Skip to content

Commit

Permalink
added RSSM posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
Kara Moraw committed Sep 30, 2021
1 parent 76f4337 commit 19d1432
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
27 changes: 26 additions & 1 deletion lab/models/dynamics_models/recurrent_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def _prior(self, prev_state, prev_action):

def _posterior(self, prev_state, prev_action, emb_observation):
"""s_t ~ q(s_t | h_t, e_t)"""
pass
prior = self._prior(prev_state, prev_action)
post_mean, post_stddev = self.posterior_model(prior["det_state"], emb_observation)
post_state = post_mean + post_stddev*torch.randn_like(post_mean) # sample from normal distribution
return {"det_state": prior["det_state"], "stoch_state": post_state, "mean": post_mean, "stddev": post_stddev}

def dec(self, state):
"""o_t ~ p(o_t | h_t, s_t)"""
Expand Down Expand Up @@ -55,3 +58,25 @@ def forward(self, state, action):
mean, stddev = output.chunk(2, dim=-1)
stddev = nn.functional.softplus(stddev) + self.min_stddev
return det_state, mean, stddev

class RSSMPosterior(nn.Module):
    """Posterior network of an RSSM: q(s_t | h_t, e_t).

    Maps the deterministic state h_t and the embedded observation e_t
    to the mean and stddev of a diagonal-Gaussian stochastic state s_t.
    """

    def __init__(self, min_stddev, state_size, embedded_size, hidden_size=None):
        """
        Args:
            min_stddev: lower bound added to the softplus-activated
                stddev so it stays strictly positive.
            state_size: dict with "det_state" and "stoch_state" sizes.
            embedded_size: dimensionality of the embedded observation.
            hidden_size: width of the hidden layers; defaults to twice
                the stochastic state size (the output width).
        """
        super().__init__()
        self.input_size = state_size["det_state"] + embedded_size
        if hidden_size is None:
            hidden_size = 2 * state_size["stoch_state"]
        self.min_stddev = min_stddev
        # Two hidden layers; the last layer emits mean and raw
        # (pre-softplus) stddev concatenated along the last dimension.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(self.input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2 * state_size["stoch_state"]),
        )

    def forward(self, det_state, embedded):
        """Return (mean, stddev) of q(s_t | h_t, e_t).

        Args:
            det_state: deterministic state, shape (..., det_state_size).
            embedded: embedded observation, shape (..., embedded_size).
        """
        # Concatenate on the last dim (was dim=1): identical for 2-D
        # input, but also supports extra leading batch dimensions and
        # is consistent with the dim=-1 chunk below. Avoids shadowing
        # the builtin `input` as well.
        features = torch.cat((det_state, embedded), dim=-1)
        output = self.linear_relu_stack(features)
        mean, stddev = output.chunk(2, dim=-1)
        # Softplus keeps stddev positive; min_stddev prevents collapse to 0.
        stddev = nn.functional.softplus(stddev) + self.min_stddev
        return mean, stddev
20 changes: 20 additions & 0 deletions tests/test_rssm_posterior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from lab.models.dynamics_models.recurrent_model import RSSMPosterior

# Smoke test for RSSMPosterior: build the model, run one posterior step,
# and assert the resulting state shapes (previously this only printed
# shapes, so the test could never fail).
min_stddev = 0.1
state_size = {"stoch_state": 30, "det_state": 25}
embedded_size = 15

posterior_model = RSSMPosterior(min_stddev, state_size, embedded_size)

# Single-element batch of random previous state and embedded observation.
stoch_state = torch.randn(state_size["stoch_state"]).unsqueeze(0)
det_state = torch.randn(state_size["det_state"]).unsqueeze(0)
prev_state = {"stoch_state": stoch_state, "det_state": det_state}
emb_observation = torch.randn(embedded_size).unsqueeze(0)

post_mean, post_stddev = posterior_model(prev_state["det_state"], emb_observation)
# Reparameterization trick: sample from N(mean, stddev^2).
post_state = post_mean + post_stddev * torch.randn_like(post_mean)
state = {"det_state": prev_state["det_state"], "stoch_state": post_state,
         "mean": post_mean, "stddev": post_stddev}

assert state["det_state"].shape == (1, state_size["det_state"])
for key in ("stoch_state", "mean", "stddev"):
    assert state[key].shape == (1, state_size["stoch_state"]), key
# stddev is softplus-activated plus min_stddev, hence strictly positive.
assert bool((state["stddev"] >= min_stddev).all())

for k in state.keys():
    print("Shape of {}: {}".format(k, state[k].shape))

0 comments on commit 19d1432

Please sign in to comment.