In [1]:
import gymnasium as gym
import numpy as np
import torch
from einops import einsum, pack, rearrange, repeat
from torch import nn

In [2]:
# Load data from data.npz
data = np.load('data.npz')
torchify = lambda *xs: [torch.tensor(x, dtype=torch.float32, device='cuda') for x in xs]
# torchify = lambda *xs: [torch.tensor(x, dtype=torch.float32) for x in xs]
observations, actions = torchify(data['observations'], data['actions'])
# observations, actions = observations[:4096], actions[:4096]

In [3]:
class Perceptron(nn.Module):
    def __init__(self, input_dim, layer_sizes, output_dim):
        super().__init__()
        layer_sizes = [input_dim] + layer_sizes + [output_dim]
        self.layers = nn.ModuleList(
            [
                nn.Linear(layer_sizes[i], layer_sizes[i + 1])
                for i in range(len(layer_sizes) - 1)
            ]
        )

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        x = self.layers[-1](x)
        return x

In [4]:
class TransitionModel(nn.Module):
    def __init__(self, act_dim, state_dim, latent_dim, pe_wavelength_range=[1, 1024]):
        super().__init__()
        n_layers = 3
        self.sa_layers = nn.ModuleList(
            [
                nn.MultiheadAttention(latent_dim, 4, batch_first=True)
                for _ in range(n_layers)
            ]
        )
        self.up_scales = nn.ModuleList(
            [nn.Linear(latent_dim, latent_dim * 4) for _ in range(n_layers)]
        )
        self.down_scales = nn.ModuleList(
            [nn.Linear(latent_dim * 4, latent_dim) for _ in range(n_layers)]
        )
        self.up_scale = nn.Linear(state_dim + act_dim, latent_dim)
        self.down_scale = nn.Linear(latent_dim, state_dim)

        self.pe_wavelength_range = pe_wavelength_range

    def forward(self, initial_state, actions, start_indices=None, return_mask=False):
        # Concatenate actions to initial_state
        x = torch.cat([repeat(initial_state, "... e -> ... r e", r=actions.shape[-2]), actions], dim=-1)
        x = torch.relu(self.up_scale(x))

        if start_indices is None:
            start_indices = torch.zeros(
                initial_state.shape[0], dtype=torch.long, device=initial_state.device
            )

        embed_dim = x.shape[-1]
        # Do a log range of frequencies
        pe_freqs = 1 / torch.logspace(
            np.log(self.pe_wavelength_range[0]),
            np.log(self.pe_wavelength_range[1]),
            embed_dim // 2,
            base=2,
            device=x.device,
        )
        batch_size = x.shape[0]
        seq_len = actions.shape[-2]
        # Compute the positional encoding
        times = torch.arange(seq_len, device=x.device) - start_indices[..., None]
        pe_freq_mat = einsum(pe_freqs, times, "w, ... i -> i w")
        pe = torch.cat([torch.sin(pe_freq_mat), torch.cos(pe_freq_mat)], dim=-1)
        x = x + pe

        # Make a mask out to mask out the past
        mask = torch.zeros(batch_size, seq_len, device=x.device, dtype=torch.bool)
        mask[torch.arange(batch_size), start_indices] = True
        mask = ~mask.cumsum(dim=-1) > 0

        big_mask = repeat(mask, "i seq ... -> (i heads) seq seq_also ...", heads=4, seq_also=seq_len)

        for up_scale, sa_layer, down_scale in zip(
            self.up_scales, self.sa_layers, self.down_scales
        ):
            x = x + sa_layer(x, x, x, attn_mask=big_mask, need_weights=False)[0]
            z = torch.relu(up_scale(x))
            x = x + down_scale(z)

        x = self.down_scale(x)

        if return_mask:
            return x, mask
        else:
            return x

In [5]:
# Define networks

state_dim = observations.shape[-1]
action_dim = actions.shape[-1]

state_encoder = Perceptron(state_dim, [32, 32, 32], state_dim).cuda()
action_encoder = Perceptron(action_dim + state_dim, [32, 32, 32], action_dim).cuda()
transition_model = TransitionModel(2, 4, 64).cuda()
state_decoder = Perceptron(state_dim, [32, 32, 32], state_dim).cuda()
action_decoder = Perceptron(action_dim + state_dim, [32, 32, 32], action_dim).cuda()

nets = [
    state_encoder,
    action_encoder,
    transition_model,
    state_decoder,
    action_decoder,
]

In [6]:
class TransitionLoss(nn.Module):
    def __init__(self, loss_fn=nn.MSELoss()):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, latent_fut_states_prime, latent_fut_states_gt, mask):
        masked_latent_fut_states_prime = torch.where(
            mask[..., None], latent_fut_states_gt, latent_fut_states_prime
        )
        loss = self.loss_fn(masked_latent_fut_states_prime, latent_fut_states_gt)
        return loss

In [7]:
class SmoothnessLoss(nn.Module):
    def __init__(self, norm_p=1, discount=0.99):
        super().__init__()
        self.discount = discount
        self.norm_p = norm_p

    def forward(
        self,
        latent_actions,
        latent_next_states,
        latent_actions_perturbed,
        latent_next_states_perturbed,
        mask,
    ):
        action_diffs = latent_actions_perturbed - latent_actions
        action_dists = torch.norm(action_diffs, p=self.norm_p, dim=-1)

        state_diffs = latent_next_states_perturbed - latent_next_states
        state_dists = torch.norm(state_diffs, p=self.norm_p, dim=-1)

        future_indices = torch.cumsum(~mask, dim=-1, dtype=torch.float32)
        future_discounts = self.discount**future_indices
        dist_limits = action_dists / future_discounts
        state_violations = torch.relu(state_dists - dist_limits)
        losses = state_violations**2
        masked_losses = torch.where(mask, 0, losses)

        return masked_losses.mean()

In [8]:
class CoverageLoss(nn.Module):
    def __init__(
        self,
        state_space_size,
        action_space_size,
        latent_samples=2048,
        space_ball_p=1,
        selection_tail_size=4,
        far_sample_count=64,
        pushing_sample_size=4,
    ):
        super().__init__()
        self.state_space_size = state_space_size
        self.action_space_size = action_space_size

        self.latent_samples = latent_samples
        self.space_ball_p = space_ball_p
        self.selection_tail_size = selection_tail_size
        self.far_sample_count = far_sample_count
        self.pushing_sample_size = pushing_sample_size

    def forward(self, latent_states, latent_actions):
        latent_states = rearrange(latent_states, "... e -> (...) e")
        latent_actions = rearrange(latent_actions, "... e -> (...) e")

        state_norms = torch.norm(latent_states, p=self.space_ball_p, dim=-1)
        action_norms = torch.norm(latent_actions, p=self.space_ball_p, dim=-1)

        state_violations = torch.relu(state_norms - self.state_space_size)
        action_violations = torch.relu(action_norms - self.action_space_size)

        state_size_violations = state_violations**2
        action_size_violations = action_violations**2

        state_size_loss = state_size_violations.mean()
        action_size_loss = action_size_violations.mean()

        # penalize for empty space within the state space
        # Sample random points in the latent space
        if self.space_ball_p != 1:
            raise NotImplementedError("Only L1 norm is supported :(")

        state_space_samples = (
            torch.rand(
                self.latent_samples,
                latent_states.shape[-1],
                device=latent_states.device,
            )
            * 2
            - 1
        ) * self.state_space_size
        action_space_samples = (
            torch.rand(
                self.latent_samples,
                latent_actions.shape[-1],
                device=latent_actions.device,
            )
            * 2
            - 1
        ) * self.action_space_size

        # Find the state_space that is the farthest from any of the latent_states
        state_space_dists = torch.cdist(state_space_samples, latent_states, p=1)
        state_space_dist_tail = (
            state_space_dists.sort(dim=-1)
            .values[:, : self.selection_tail_size]
            .mean(dim=-1)
        )
        farthest_idxs = state_space_dist_tail.argsort(descending=True)[
            : self.far_sample_count
        ]
        farthest_state_samples = state_space_samples[farthest_idxs]

        action_space_dists = torch.cdist(action_space_samples, latent_actions, p=1)
        action_space_dist_tail = (
            action_space_dists.sort(dim=-1)
            .values[:, : self.selection_tail_size]
            .mean(dim=-1)
        )
        farthest_idxs = action_space_dist_tail.argsort(descending=True)[
            : self.far_sample_count
        ]
        farthest_action_samples = action_space_samples[farthest_idxs]

        # Now make the states by the farthest latent states closer to the farthest samples
        # Maybe in the future make just the few closest ones closer
        empty_state_space_dists = torch.cdist(
            farthest_state_samples, latent_states, p=1
        )
        close_empty_state_space_dists = empty_state_space_dists.sort(dim=-1).values[
            :, : self.pushing_sample_size
        ]
        state_coverage_losses = close_empty_state_space_dists**2

        empty_action_space_dists = torch.cdist(
            farthest_action_samples, latent_actions, p=1
        )
        close_empty_action_space_dists = empty_action_space_dists.sort(dim=-1).values[
            :, : self.pushing_sample_size
        ]
        action_coverage_losses = close_empty_action_space_dists**2

        return (
            state_size_loss
            + action_size_loss
            + state_coverage_losses.mean()
            + action_coverage_losses.mean()
        )

In [9]:
action_mse = nn.MSELoss()
state_mse = nn.MSELoss()
transition_loss_func = TransitionLoss()

smoothness_loss_func = SmoothnessLoss()

coverage_loss_func = CoverageLoss(
    state_space_size=1.5,
    action_space_size=1.75,
    # latent_state_samples=1024,
    # latent_action_samples=1024,
)

optimizer = torch.optim.AdamW(
    [param for net in nets for param in net.parameters()], lr=1e-2
)
# lr_scheduler = torch.optim.lr_scheduler.LinearLR(
#     optimizer, start_factor=1, end_factor=1e-2, total_iters=1024
# )
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer, gamma=0.995, last_epoch=-1,
)

encoder_batch_size = 1024
transition_batch_size = 32

for i in range(1024):

    # Yoink a batch of data
    encoder_batch_raveled_inds = torch.randperm(np.prod(observations.shape[:-1]))[
        :encoder_batch_size
    ]
    encoder_batch_inds = torch.unravel_index(
        encoder_batch_raveled_inds, observations.shape[:-1]
    )

    transition_traj_batch_inds = torch.randperm(observations.shape[0], device="cuda")[
        :transition_batch_size
    ]
    transition_start_inds = torch.randint(
        0, int(observations.shape[-2] // 1.1), (transition_batch_size,), device="cuda"
    )

    state_batch = observations[[torch.tensor(inds) for inds in encoder_batch_inds]]
    action_batch = actions[encoder_batch_inds]

    starting_states = observations[transition_traj_batch_inds, transition_start_inds]
    state_traj_batch = observations[transition_traj_batch_inds]
    action_traj_batch = actions[transition_traj_batch_inds]

    # Now do a forward pass
    optimizer.zero_grad()

    latent_states = state_encoder(state_batch)
    latent_actions = action_encoder(torch.cat([action_batch, state_batch], dim=-1))

    reconstructed_states = state_decoder(latent_states)
    reconstructed_actions = action_decoder(
        torch.cat([latent_actions, latent_states], dim=-1)
    )

    state_reconstruction_loss = state_mse(reconstructed_states, state_batch)
    action_reconstruction_loss = action_mse(reconstructed_actions, action_batch)

    latent_start_states = state_encoder(starting_states)
    latent_traj_actions = action_encoder(
        torch.cat([action_traj_batch, state_traj_batch], dim=-1)
    )
    latent_fut_states_prime, mask = transition_model(
        latent_start_states,
        latent_traj_actions,
        start_indices=transition_start_inds,
        return_mask=True,
    )
    latent_fut_states_gt = state_encoder(state_traj_batch)

    perturbations = torch.randn_like(latent_traj_actions)
    perturbations = perturbations / torch.norm(perturbations, p=1, dim=-1, keepdim=True)
    perturbations = perturbations * torch.rand(
        (*perturbations.shape[:-1], 1), device="cuda"
    )

    latent_traj_actions_perturbed = latent_traj_actions + perturbations
    # Normalize the perturbations if they are too large
    perturbed_action_norms = torch.norm(
        latent_traj_actions_perturbed, p=1, dim=-1, keepdim=True
    )
    latent_traj_actions_perturbed = torch.where(
        perturbed_action_norms > coverage_loss_func.action_space_size,
        latent_traj_actions_perturbed
        * coverage_loss_func.action_space_size
        / perturbed_action_norms,
        latent_traj_actions_perturbed,
    )
    latent_fut_states_prime_perturbed = transition_model(
        latent_start_states,
        latent_traj_actions_perturbed,
        start_indices=transition_start_inds,
    )

    transition_loss = transition_loss_func(
        latent_fut_states_prime, latent_fut_states_gt, mask
    )
    smoothness_loss = smoothness_loss_func(
        latent_actions,
        latent_fut_states_prime,
        latent_traj_actions_perturbed,
        latent_fut_states_prime_perturbed,
        mask,
    )
    coverage_loss = coverage_loss_func(latent_states, latent_actions) * 0.01

    loss = (
        state_reconstruction_loss
        + action_reconstruction_loss
        + transition_loss
        + smoothness_loss
        + coverage_loss
    )

    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    if i % 64 == 0:
        print(
            f"Iteration {i}, total loss: {loss.item()}, lr: {lr_scheduler.get_last_lr()}, transition loss: {transition_loss.item()}, smoothness loss: {smoothness_loss.item()}, coverage loss: {coverage_loss.item()}, state reconstruction loss: {state_reconstruction_loss.item()}, action reconstruction loss: {action_reconstruction_loss.item()}"
        )

  state_batch = observations[[torch.tensor(inds) for inds in encoder_batch_inds]]


Iteration 0, total loss: 5.028264999389648, lr: [0.00995], transition loss: 0.3088099956512451, smoothness loss: 1.4508627543818875e-07, coverage loss: 0.32831481099128723, state reconstruction loss: 4.0420355796813965, action reconstruction loss: 0.34910470247268677
Iteration 64, total loss: 2.071960926055908, lr: [0.007219385759785157], transition loss: 0.022135138511657715, smoothness loss: 2.125751450421376e-07, coverage loss: 0.18434357643127441, state reconstruction loss: 1.7123594284057617, action reconstruction loss: 0.1531226933002472
Iteration 128, total loss: 0.6194939017295837, lr: [0.005238143793828013], transition loss: 0.03545884042978287, smoothness loss: 6.955904154892778e-07, coverage loss: 0.12750311195850372, state reconstruction loss: 0.4347015619277954, action reconstruction loss: 0.021829720586538315
Iteration 192, total loss: 0.4290889799594879, lr: [0.003800621177172762], transition loss: 0.041192058473825455, smoothness loss: 1.1217654218853568e-06, coverage l

In [None]:
# Get stdev of encoded states and actions across each element
with torch.no_grad():
    encoded_states = rearrange(state_encoder(observations), "... e -> (...) e")
    state_std = torch.std(encoded_states, dim=0)

    encoded_actions = rearrange(
        action_encoder(torch.cat([actions, observations], dim=-1)), "... e -> (...) e"
    )
    
    recovered_states = state_decoder(encoded_states)
    recovered_actions = action_decoder(torch.cat([encoded_actions, encoded_states], dim=-1))
    
    action_std = torch.std(encoded_actions, dim=0)

In [None]:
state_std, action_std

(tensor([0.1045, 0.1314, 0.1917, 0.2648], device='cuda:0'),
 tensor([0.8885, 0.8942], device='cuda:0'))

(
    tensor([0.4034, 0.8038, 0.6741, 0.7653], device="cuda:0"),
    tensor([1.0077, 1.0054], device="cuda:0"),
)

In [None]:
recovered_states - rearrange(observations, "... e -> (...) e"), recovered_actions - rearrange(actions, "... e -> (...) e")

(tensor([[-0.0330,  0.0288,  0.0944,  0.1882],
         [-0.0304,  0.0295,  0.1164, -0.0094],
         [-0.0265,  0.0345,  0.0065, -0.1179],
         ...,
         [ 0.0164, -0.0168, -0.7266, -0.0929],
         [ 0.0190,  0.0030, -0.7976, -0.3223],
         [ 0.0162, -0.0236, -0.5972,  0.0540]], device='cuda:0'),
 tensor([[ 0.0602, -0.0286],
         [-0.0025, -0.0470],
         [ 0.0062,  0.0016],
         ...,
         [ 0.1599, -0.1029],
         [-0.0293,  0.0160],
         [-0.0170,  0.0015]], device='cuda:0'))

In [None]:
torch.stack([recovered_states, rearrange(observations, "... e -> (...) e")], dim=-1)

tensor([[[ 1.3718,  1.4048],
         [ 2.8031,  2.7743],
         [ 0.0139, -0.0805],
         [ 0.0267, -0.1615]],

        [[ 1.3733,  1.4038],
         [ 2.8042,  2.7747],
         [ 0.0139, -0.1025],
         [ 0.0267,  0.0362]],

        [[ 1.3773,  1.4038],
         [ 2.8106,  2.7761],
         [ 0.0140,  0.0076],
         [ 0.0269,  0.1447]],

        ...,

        [[ 1.9412,  1.9248],
         [-0.6199, -0.6031],
         [-0.0166,  0.7100],
         [-0.0083,  0.0847]],

        [[ 1.9516,  1.9326],
         [-0.5970, -0.5999],
         [-0.0159,  0.7817],
         [-0.0088,  0.3135]],

        [[ 1.9546,  1.9384],
         [-0.6242, -0.6006],
         [-0.0165,  0.5807],
         [-0.0084, -0.0624]]], device='cuda:0')

In [None]:
class ActorPolicy(nn.Module):
    def __init__(self, action_dim, horizon=128, iters=64):
        super().__init__()
        self.action_dim = action_dim
        self.horizon = horizon
        self.iters = iters

    def forward(self, state, target_state, prev_latent_action_plan=None):
        if prev_latent_action_plan is None:
            prev_latent_action_plan = torch.randn(
                state.shape[0], self.horizon, self.action_dim, device=state.device
            )
            prev_latent_action_plan = prev_latent_action_plan / torch.norm(
                prev_latent_action_plan, p=1, dim=-1, keepdim=True
            )
            prev_latent_action_plan = prev_latent_action_plan * torch.rand(
                (*prev_latent_action_plan.shape[:-1], 1), device=state.device
            )
            prev_latent_action_plan = (
                prev_latent_action_plan * coverage_loss_func.action_space_size
            )

        latent_action_plan = prev_latent_action_plan.clone()

        latent_state = state_encoder(state)
        latent_target_state = state_encoder(target_state)

        optim = torch.optim.Adam([latent_action_plan], lr=1e-2)

        for i in range(self.iters):
            optim.zero_grad()
            latent_fut_states = transition_model(latent_state, latent_action_plan)
            loss = state_mse(latent_fut_states, latent_target_state)
            loss.backward()
            optim.step()

        next_action = action_decoder(
            torch.cat([latent_action_plan, latent_state], dim=-1)
        )
        return next_action, latent_action_plan

In [None]:
# Now let's optimize a trajectory

state_encoder.eval()
action_encoder.eval()
transition_model.eval()
state_decoder.eval()
action_decoder.eval()

initial_state = observations[0, 0]

actions = torch.randn(32, 64, 2, device="cuda")
actions = actions / torch.norm(actions, p=1, dim=-1, keepdim=True)
actions = (
    actions
    * torch.rand((32, 64, 1), device="cuda")
    * coverage_loss_func.action_space_size
)

target_state = observations[0, -1]

actor = ActorPolicy(2, 128).cuda()

env = gym.make("PointMaze_Large-v3")
latent_action_plan = None

for i in range(1024):
    next_action, latent_action_plan = actor(initial_state[None], target_state[None], latent_action_plan)
    print(f"Iteration {i}, loss: {loss.item()}")

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.