In [1]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
import matplotlib.pyplot as plt

from tensordict import tensordict, MemmapTensor
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.envs.transforms import CatTensors
from torchrl.data import TensorSpec, CompositeSpec

# from omni_drones.learning.utils.distributions import IndependentNormal
# from omni_drones.learning.utils.network import MLP
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data.replay_buffers import TensorDictReplayBuffer

# Replay buffer backed by a LazyMemmapStorage built with size argument 20 and
# scratch_dir="tmp".
buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(20, scratch_dir="tmp"))

  from .autonotebook import tqdm as notebook_tqdm
  "Constructing replay buffer without specifying behaviour is no longer "


In [2]:
# Load the first saved trajectory file from the relative "trajectories" directory.
# NOTE(review): torch.load may unpickle arbitrary Python objects depending on the
# torch version / weights_only setting — only load files from a trusted source.
data = torch.load("trajectories/0.pth")

In [15]:
class Actor(nn.Module):
    """Sample an action from a diagonal Gaussian and report its log-probability.

    The distribution is built from ``torch.distributions`` primitives because
    ``IndependentNormal`` is not defined anywhere in this notebook (its import
    is commented out), which made ``forward`` fail with ``NameError``.
    NOTE(review): presumably equivalent to the omni_drones ``IndependentNormal``
    — confirm (e.g. whether that class also clamps ``scale``).
    """

    def forward(self, loc, scale):
        """Draw one action per batch element.

        Args:
            loc: mean of the Gaussian; the last dimension is treated as the
                event (action) dimension.
            scale: standard deviation, broadcastable with ``loc``; must be
                strictly positive.

        Returns:
            ``(action, logp)`` where ``action`` has the shape of ``loc`` and
            ``logp`` is the log-density summed over the last dimension.
        """
        # Function-scope import keeps this cell runnable on a fresh kernel.
        from torch.distributions import Independent, Normal

        # Reinterpret the last batch dim as the event dim so that log_prob
        # returns one value per action vector rather than one per component.
        dist = Independent(Normal(loc, scale), 1)
        action = dist.sample()
        logp = dist.log_prob(action)
        return action, logp

class ActorEval(nn.Module):
    """Score given actions under a diagonal Gaussian policy.

    The distribution is built from ``torch.distributions`` primitives because
    ``IndependentNormal`` is not defined anywhere in this notebook (its import
    is commented out), which made ``forward`` fail with ``NameError``.
    NOTE(review): presumably equivalent to the omni_drones ``IndependentNormal``
    — confirm (e.g. whether that class also clamps ``scale``).
    """

    def forward(self, action, loc, scale):
        """Evaluate log-probability and entropy.

        Args:
            action: actions to score; same trailing (event) dimension as ``loc``.
            loc: mean of the Gaussian; the last dimension is the event dimension.
            scale: standard deviation, broadcastable with ``loc``; must be
                strictly positive.

        Returns:
            ``(logp, entropy)``, both summed over the last dimension.
        """
        # Function-scope import keeps this cell runnable on a fresh kernel.
        from torch.distributions import Independent, Normal

        # One event per action vector: log_prob/entropy reduce the last dim.
        dist = Independent(Normal(loc, scale), 1)
        logp = dist.log_prob(action)
        entropy = dist.entropy()
        return logp, entropy

class _ConcatFeatures(nn.Module):
    """Concatenate all positional tensor inputs along their last dimension."""

    def forward(self, *features):
        return torch.cat(features, dim=-1)


def _make_mlp(num_units, normalization=None, activate_last=True):
    """Build a ``Linear (-> normalization) -> ELU`` stack as ``nn.Sequential``.

    Stand-in for the ``MLP`` class whose import is commented out in this
    notebook. NOTE(review): the layer layout (norm placement, ELU activation)
    is an assumption — confirm against the original ``MLP`` before comparing
    results or loading checkpoints.

    Args:
        num_units: layer widths, ``[in_dim, hidden..., out_dim]``.
        normalization: optional callable ``width -> nn.Module`` (e.g.
            ``nn.LayerNorm``) applied after each activated linear layer.
        activate_last: when False the final linear layer is left raw (no
            normalization, no activation), as needed for a parameter head.
    """
    layers = []
    last_index = len(num_units) - 2
    for index, (in_dim, out_dim) in enumerate(zip(num_units[:-1], num_units[1:])):
        layers.append(nn.Linear(in_dim, out_dim))
        if index < last_index or activate_last:
            if normalization is not None:
                layers.append(normalization(out_dim))
            layers.append(nn.ELU())
    return nn.Sequential(*layers)


def make_model(
    input_spec: CompositeSpec,
    action_spec: TensorSpec,
    device: torch.device
):
    """Build the sampling actor and the evaluation actor over a shared encoder.

    Fixes relative to the previous version:
      * ``MLP`` was undefined (import commented out); a local builder is used.
      * The concatenation and policy-head stages were wrapped in
        ``TensorDictModule`` without ``in_keys``/``out_keys``; every stage now
        declares the keys it reads and writes.
      * ``NormalParamExtractor`` now reads the policy head's output instead of
        the concatenated feature vector.

    Args:
        input_spec: spec indexed with ``"state"``; the last dimension of that
            entry's shape sizes the state encoder input.
            NOTE(review): the data is read from ``("drone.obs", "state")`` —
            confirm that ``input_spec`` is the matching sub-spec.
        action_spec: spec whose last shape dimension is the action size.
        device: device the shared encoder is moved to.

    Returns:
        ``(actor, actor_eval)``. ``actor`` writes ``"action"`` and ``"logp"``;
        ``actor_eval`` reads ``"drone.action"`` and writes ``"logp"`` and
        ``"entropy"``. Both wrap the same encoder instance, so they share
        parameters.
    """
    feature_dim = 128
    state_dim = input_spec["state"].shape[-1]
    action_dim = action_spec.shape[-1]

    # The 1x1 convolution lifts a single-channel image to the 3 channels that
    # MobileNetV3 consumes. NOTE(review): assumes the camera tensor is
    # (batch, 1, H, W) — TODO confirm.
    visual_encoder = nn.Sequential(
        nn.Conv2d(1, 3, 1),
        mobilenet_v3_small(num_classes=feature_dim)
    )
    state_encoder = _make_mlp([state_dim, feature_dim], normalization=nn.LayerNorm)
    # Final layer is left unactivated: NormalParamExtractor splits it into
    # loc/scale and applies its own positivity mapping to the scale half.
    policy_head = _make_mlp(
        [feature_dim + feature_dim, feature_dim, action_dim * 2],
        activate_last=False,
    )

    encoder = TensorDictSequential(
        TensorDictModule(visual_encoder, [("drone.obs", "distance_to_camera")], ["visual_feature"]),
        TensorDictModule(state_encoder, [("drone.obs", "state")], ["state_feature"]),
        TensorDictModule(_ConcatFeatures(), ["state_feature", "visual_feature"], ["observation_vector"]),
        TensorDictModule(policy_head, ["observation_vector"], ["policy_params"]),
        TensorDictModule(NormalParamExtractor(), ["policy_params"], ["loc", "scale"]),
    ).to(device)
    actor = TensorDictSequential(
        encoder,
        TensorDictModule(Actor(), ["loc", "scale"], ["action", "logp"])
    )
    actor_eval = TensorDictSequential(
        encoder,
        TensorDictModule(ActorEval(), ["drone.action", "loc", "scale"], ["logp", "entropy"])
    )
    return actor, actor_eval

# NOTE(review): make_model declares three required parameters (input_spec,
# action_spec, device) but is called with none here, so this line raises
# TypeError as written. The specs are not defined in the visible cells —
# TODO pass the environment's observation/action specs and a torch.device.
actor, actor_eval = make_model()