# In [None]:
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs.transforms import CatTensors
from omni_drones.learning.dreamer import models, networks
from omni_drones.learning.modules.networks import MLP
from omni_drones.learning.modules.distributions import (
    MultiOneHotCategorical, IndependentNormalModule
)
from omni_drones.learning.common import make_encoder

from torchrl.data import TensorSpec
from typing import Sequence

class ObsEncoder(nn.Module):
    """Map an observation and a deterministic state to distribution statistics.

    The observation is embedded by an encoder built with ``make_encoder``,
    concatenated with ``deter`` along the last dimension, passed through an
    MLP and finally projected by a linear layer whose flat output is reshaped
    to ``stat_shape``.

    Args:
        encoder_cfg: configuration forwarded to ``make_encoder``.
        observation_spec: observation spec forwarded to ``make_encoder``.
        embed_dim: width of the embedding produced by the encoder
            (assumed to match the encoder's output -- TODO confirm).
        deter_dim: width of the ``deter`` input.
        stat_shape: shape of the trailing dimensions of the output; any
            value accepted by ``torch.Size`` works.
        units: hidden layer widths of the MLP; any sequence of ints
            (list, tuple, ...) is accepted.
    """

    def __init__(
        self,
        encoder_cfg,
        observation_spec: TensorSpec,
        embed_dim: int,
        deter_dim: int,
        stat_shape: torch.Size,
        units: Sequence[int],
    ) -> None:
        super().__init__()
        self.encoder = make_encoder(encoder_cfg, observation_spec)
        # Unpack instead of `[...] + units` so that non-list sequences
        # (e.g. tuples) do not raise a TypeError on concatenation.
        self.obs_out_layers = MLP(
            [deter_dim + embed_dim, *units],
            nn.LayerNorm
        )
        self.stat_shape = torch.Size(stat_shape)
        self.stat_layer = nn.Linear(units[-1], self.stat_shape.numel())

    def forward(self, obs, deter):
        """Return statistics of shape ``(..., *stat_shape)``."""
        embed = self.encoder(obs)
        # Order matters: the MLP input is laid out as [deter, embed].
        x = torch.cat([deter, embed], dim=-1)
        x = self.obs_out_layers(x)
        stat = self.stat_layer(x).unflatten(-1, self.stat_shape)
        return stat


class ObsDecoder(nn.Module):
    """Decode a latent vector back into observation space.

    The decoder is a single ``MLP`` stage going from ``latent_dim`` to the
    size of the last dimension of ``observation_spec``; ``nn.LayerNorm`` is
    handed to the ``MLP`` as its second positional argument.
    """

    def __init__(
        self,
        observation_spec: TensorSpec,
        latent_dim: int
    ) -> None:
        super().__init__()
        obs_dim = observation_spec.shape[-1]
        layer_sizes = [latent_dim, obs_dim]
        self.decoder = MLP(layer_sizes, nn.LayerNorm)

    def forward(self, latent):
        """Run ``latent`` through the decoder and return the result."""
        return self.decoder(latent)


class SampleFromDist(nn.Module):
    """Draw a reparameterized sample from a distribution.

    Args:
        dist_cls: optional callable that builds a distribution from the
            positional arguments given to ``forward``. When ``None``,
            ``forward`` must receive exactly one argument, which is used as
            the distribution directly.
    """

    def __init__(
        self,
        dist_cls = None,
    ) -> None:
        super().__init__()
        self.dist_cls = dist_cls

    def forward(self, *args):
        """Return ``(samples, log_probs, entropy)`` for the distribution.

        Raises:
            ValueError: if no ``dist_cls`` was configured and the call does
                not pass exactly one (already built) distribution.
        """
        if self.dist_cls is not None:
            dist = self.dist_cls(*args)
        elif len(args) == 1:
            # No factory configured: the single argument is the distribution.
            dist = args[0]
        else:
            raise ValueError(
                "SampleFromDist without `dist_cls` expects exactly one "
                f"distribution argument, got {len(args)}."
            )
        samples = dist.rsample()
        log_probs = dist.log_prob(samples)
        entropy = dist.entropy()
        return samples, log_probs, entropy


def _make_head(in_dim: int, units: Sequence[int], head_type: str) -> nn.Module:
    """Build an MLP prediction head with layer normalization.

    The head has 255 outputs when ``head_type`` is ``"twohot_symlog"`` and a
    single output otherwise.
    """
    out_dim = 255 if head_type == "twohot_symlog" else 1
    return MLP([in_dim, *units, out_dim], normalization=nn.LayerNorm)


def make_dreamer(
    config,
    observation_spec: TensorSpec,
    action_spec: TensorSpec
):
    """Assemble the modules of a Dreamer-style agent.

    Args:
        config: configuration object; the attributes read here are
            ``encoder``, ``obs_embed_dim``, ``porj_units``, ``deter_dim``,
            ``stoch_dim``, ``discrete_dim``, ``value_head``, ``reward_head``
            and the ``units`` of ``sequence_model``, ``dynamics_pred``,
            ``actor``, ``value_pred``, ``reward_pred`` and ``discont_pred``.
        observation_spec: spec of the observation; forwarded to the encoder
            and used to size the decoder output.
        action_spec: spec of the action; its last dimension is the action
            width.

    Returns:
        A tuple ``(policy_modules, train_modules, world_model, actor_critic)``:
        two ``TensorDictSequential`` pipelines and two ``nn.ModuleDict``
        containers grouping the underlying networks. ``train_modules`` has
        no stage that writes ``"latent"``, so that key must already be in
        the tensordict it is applied to.
    """
    # NOTE(review): the original only declared these names (bare annotations)
    # and never assigned them, so the function failed on first use. They are
    # read from `config` here; the attribute names are an assumption -- confirm
    # against the config schema.
    deter_dim: int = config.deter_dim
    stoch_dim: int = config.stoch_dim
    discrete_dim: int = config.discrete_dim
    action_dim = action_spec.shape[-1]

    # NOTE(review): `porj_units` looks like a typo of `proj_units`; the key is
    # kept as-is because the config schema is not visible here.
    obs_encoder = ObsEncoder(
        config.encoder,
        observation_spec,
        embed_dim=config.obs_embed_dim,
        deter_dim=deter_dim,
        stat_shape=(stoch_dim, discrete_dim),
        units=config.porj_units,
    )
    # NOTE(review): the input width is stoch_dim * discrete_dim + action_dim
    # while this network is wired below with in_keys ("deter", "action");
    # left unchanged -- confirm which inputs are intended.
    sequence_model = MLP(
        [stoch_dim * discrete_dim + action_dim, *config.sequence_model.units],
        nn.LayerNorm
    )
    dynamics_predictor = MLP(
        [deter_dim, *config.dynamics_pred.units],
        nn.LayerNorm
    )

    # Width of the concatenated ("deter", stochastic) latent feature.
    feat_size = stoch_dim * discrete_dim + deter_dim
    task_actor = nn.Sequential(
        MLP(
            [feat_size, *config.actor.units],
            nn.LayerNorm
        ),
        # The distribution head consumes the output of the MLP above, whose
        # width is the last entry of `units` (the original passed the whole
        # list). NOTE(review): assumes the first argument is the input width.
        IndependentNormalModule(config.actor.units[-1], action_dim),
        SampleFromDist()
    )

    value = _make_head(feat_size, config.value_pred.units, config.value_head)

    # The decoder reads the "latent" feature, hence `feat_size` inputs.
    obs_decoder = ObsDecoder(observation_spec, feat_size)

    reward_predictor = _make_head(
        feat_size, config.reward_pred.units, config.reward_head
    )

    # Separate network: the original chained assignment rebound
    # `reward_predictor` to this module, discarding the reward head.
    # NOTE(review): `discont_pred` looks like a typo of `discount_pred`; the
    # key is kept as-is because the config schema is not visible here.
    discount_predictor = MLP(
        [feat_size, *config.discont_pred.units, 1],
        normalization=nn.LayerNorm
    )

    policy_modules = TensorDictSequential(
        TensorDictModule(obs_encoder, ["obs", "deter"], ["post_logit"]),
        TensorDictModule(
            SampleFromDist(MultiOneHotCategorical),
            ["post_logit"], ["post_stoch", "_", "_"]
        ),
        # `out_key` is a single key, not a list of keys.
        CatTensors(["deter", "post_stoch"], "latent", del_keys=False),
        TensorDictModule(task_actor, ["latent"], ["action", "log_prob", "entropy"]),
        TensorDictModule(value, ["latent"], ["state_value"]),
        TensorDictModule(sequence_model, ["deter", "action"], [("next", "deter")]),
    )

    train_modules = TensorDictSequential(
        TensorDictModule(dynamics_predictor, ["deter"], ["prior_logit"]),
        TensorDictModule(
            SampleFromDist(MultiOneHotCategorical),
            ["prior_logit"], ["prior_stoch", "_", "_"]
        ),
        TensorDictModule(obs_decoder, ["latent"], ["obs_pred"]),
        TensorDictModule(reward_predictor, ["latent"], ["reward_pred"]),
        TensorDictModule(discount_predictor, ["latent"], ["discount_pred"]),
    )

    world_model = nn.ModuleDict({
        "obs_encoder": obs_encoder,
        "dynamics_predictor": dynamics_predictor,
        "sequence_model": sequence_model,
        "reward_predictor": reward_predictor,
        "discount_predictor": discount_predictor,
        "obs_decoder": obs_decoder
    })

    actor_critic = nn.ModuleDict({
        "actor": task_actor,
        "critic": value
    })

    return policy_modules, train_modules, world_model, actor_critic


# In [None]:
import torch.distributions as D

# Scratch cell: a unit Normal over a (32, 4, 16) tensor whose last dimension
# is reinterpreted as the event dimension.
loc = torch.zeros((32, 4, 16))
scale = loc.new_ones(loc.shape)
d = D.Independent(
    D.Normal(loc=loc, scale=scale),
    reinterpreted_batch_ndims=1,
)