In [None]:
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch
from einops import einsum, pack, rearrange, repeat
from torch import nn

from stuff import Perceptron, TransitionModel

In [None]:
# Load data from data.npz
def torchify(*xs):
    """Return a list with each input converted to a float32 tensor on the GPU."""
    return [torch.tensor(x, dtype=torch.float32, device="cuda") for x in xs]


# `np.load` on an .npz returns an NpzFile that keeps the archive open; the `with`
# block closes it as soon as the arrays have been copied into tensors.
with np.load("data.npz") as data:
    observations, actions = torchify(data["observations"], data["actions"])

# Set to an int (e.g. 4096) to keep only the first N entries along axis 0 for quick
# debugging runs; None trains on the full dataset.
DEBUG_SUBSET = None
if DEBUG_SUBSET is not None:
    observations, actions = observations[:DEBUG_SUBSET], actions[:DEBUG_SUBSET]

In [None]:
# Define networks
# Seed torch (all devices) before the first random draw in this notebook, which is
# the weight initialisation below. A fresh-kernel Run All then reproduces the same
# init and the same training batches/perturbations, up to nondeterministic GPU kernels.
SEED = 0
torch.manual_seed(SEED)

state_dim = observations.shape[-1]
action_dim = actions.shape[-1]

# Constructor signatures live in `stuff` (not shown); presumably
# Perceptron(in_dim, hidden_dims, out_dim) — TODO confirm.
# The action encoder/decoder take an action (or latent action) concatenated with a
# state (or latent state), hence the `action_dim + state_dim` input width.
state_encoder = Perceptron(state_dim, [32, 64, 32], state_dim).cuda()
action_encoder = Perceptron(action_dim + state_dim, [32, 32, 32], action_dim).cuda()
# NOTE(review): the std output further down shows 4-dim latent states and 2-dim
# latent actions, so (2, 4) presumably equals (action_dim, state_dim) — TODO confirm
# against TransitionModel and derive these from the variables instead.
transition_model = TransitionModel(2, 4, 64).cuda()
state_decoder = Perceptron(state_dim, [32, 64, 32], state_dim).cuda()
action_decoder = Perceptron(action_dim + state_dim, [32, 32, 32], action_dim).cuda()

# Every trainable network; the optimizer is built from the parameters of this list.
nets = [
    state_encoder,
    action_encoder,
    transition_model,
    state_decoder,
    action_decoder,
]

In [None]:
class TransitionLoss(nn.Module):
    """Score predicted latent future states against encoded ground truth.

    Wherever ``mask`` is True, the prediction is swapped for the ground truth
    before calling ``loss_fn``. Those positions therefore contribute zero error
    and zero gradient.

    NOTE(review): with a mean-reducing ``loss_fn`` (the default ``MSELoss``),
    masked positions still count in the denominator. The loss scale therefore
    shrinks as more positions are masked.

    Args:
        loss_fn: callable ``loss_fn(prediction, target)``; defaults to MSE.
    """

    def __init__(self, loss_fn=nn.MSELoss()):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, latent_fut_states_prime, latent_fut_states_gt, mask):
        # The mask lacks the trailing latent axis; add it so it broadcasts.
        use_ground_truth = mask.unsqueeze(-1)
        filled_predictions = torch.where(
            use_ground_truth, latent_fut_states_gt, latent_fut_states_prime
        )
        return self.loss_fn(filled_predictions, latent_fut_states_gt)

In [None]:
class SmoothnessLoss(nn.Module):
    """Hinge penalty bounding how strongly predicted states react to action noise.

    At every step, the distance between perturbed and unperturbed predicted next
    latent states may be at most the distance between the perturbed and
    unperturbed latent actions divided by ``discount ** k``. Both distances use
    the ``norm_p``-norm over the last axis. ``k`` is the running count of
    unmasked steps along the mask's last axis, including the current step.

    With ``discount < 1``, the allowed deviation grows with ``k``. Violations are
    squared. Masked steps contribute 0 but are still counted in the final mean.

    Args:
        norm_p: order of the norm used for both distances.
        discount: per-step factor controlling how fast the allowance grows.
    """

    def __init__(self, norm_p=1, discount=0.99):
        super().__init__()
        self.discount = discount
        self.norm_p = norm_p

    def forward(
        self,
        latent_actions,
        latent_next_states,
        latent_actions_perturbed,
        latent_next_states_perturbed,
        mask,
    ):
        p = self.norm_p
        action_gap = (latent_actions_perturbed - latent_actions).norm(p=p, dim=-1)
        state_gap = (latent_next_states_perturbed - latent_next_states).norm(p=p, dim=-1)

        # k: inclusive running count of unmasked steps along the last axis.
        unmasked_count = (~mask).cumsum(dim=-1, dtype=torch.float32)
        allowed_gap = action_gap / (self.discount**unmasked_count)

        squared_excess = torch.relu(state_gap - allowed_gap) ** 2
        return squared_excess.masked_fill(mask, 0).mean()

In [None]:
class CoverageLoss(nn.Module):
    """Keep latent states/actions inside L1 balls and spread them across those balls.

    For each of the two spaces (states, actions) two terms are summed:

    * size term: the mean over latents of ``relu(||z||_p - radius) ** 2``. It pulls
      latents back inside the ball of the given radius.
    * coverage term:
      1. Draw ``latent_samples`` points uniformly from the L1 ball.
      2. Keep the ``far_sample_count`` samples whose mean L1 distance to their
         ``selection_tail_size`` nearest latents is largest (the emptiest regions).
      3. Penalize the squared L1 distance from each kept sample to its
         ``pushing_sample_size`` nearest latents, pulling those latents toward the gap.

    Coverage samples come from the same L1 ball that the size term enforces.
    Sampling from the enclosing hypercube ``[-r, r]^d`` instead would place all but
    1/d! of the samples outside the ball, including the "emptiest" ones near the
    cube corners. The coverage term would then fight the size term.

    Only ``space_ball_p == 1`` is supported; ``forward`` raises
    ``NotImplementedError`` otherwise.

    Args:
        state_space_size: radius of the latent-state ball.
        action_space_size: radius of the latent-action ball.
        latent_samples: uniform ball samples drawn per space per call.
        space_ball_p: norm order of the balls (must be 1).
        selection_tail_size: nearest latents averaged to score a sample's emptiness.
        far_sample_count: number of emptiest samples that exert a pull.
        pushing_sample_size: nearest latents pulled toward each emptiest sample.
    """

    def __init__(
        self,
        state_space_size,
        action_space_size,
        latent_samples=2048,
        space_ball_p=1,
        selection_tail_size=4,
        far_sample_count=64,
        pushing_sample_size=4,
    ):
        super().__init__()
        self.state_space_size = state_space_size
        self.action_space_size = action_space_size

        self.latent_samples = latent_samples
        self.space_ball_p = space_ball_p
        self.selection_tail_size = selection_tail_size
        self.far_sample_count = far_sample_count
        self.pushing_sample_size = pushing_sample_size

    @staticmethod
    def _sample_l1_ball(count, dim, radius, device=None, dtype=None):
        """Draw ``count`` points uniformly from the ``dim``-dim L1 ball of ``radius``."""
        # The first `dim` coordinates of (dim + 1) i.i.d. Exp(1) variables divided by
        # their sum are uniform on {x >= 0, sum(x) <= 1}. Independent random signs
        # reflect that simplex into every orthant, giving the full L1 ball.
        uniform = torch.rand(count, dim + 1, device=device, dtype=dtype)
        exponentials = -torch.log1p(-uniform)  # finite because rand() < 1
        totals = exponentials.sum(dim=-1, keepdim=True).clamp_min(
            torch.finfo(exponentials.dtype).tiny
        )
        simplex = exponentials[:, :dim] / totals
        signs = torch.randint(0, 2, (count, dim), device=device).to(simplex.dtype) * 2 - 1
        return simplex * signs * radius

    def _size_loss(self, points, radius):
        """Mean squared excess of each point's ``space_ball_p``-norm over ``radius``."""
        norms = torch.norm(points, p=self.space_ball_p, dim=-1)
        return torch.relu(norms - radius).pow(2).mean()

    def _coverage_loss(self, points, radius):
        """Coverage term for one space; ``points`` is an (N, D) matrix of latents."""
        samples = self._sample_l1_ball(
            self.latent_samples,
            points.shape[-1],
            radius,
            device=points.device,
            dtype=points.dtype,
        )
        num_points = points.shape[0]

        # Choosing the emptiest samples only needs indices, so keep the large
        # (latent_samples x N) distance matrix out of the autograd graph.
        # topk replaces a full sort; k is clamped so fewer than k points still works.
        selection_dists = torch.cdist(samples, points.detach(), p=1)
        tail_k = min(self.selection_tail_size, num_points)
        emptiness = selection_dists.topk(tail_k, dim=-1, largest=False).values.mean(dim=-1)
        far_k = min(self.far_sample_count, samples.shape[0])
        farthest_samples = samples[emptiness.topk(far_k).indices]

        # Pull the latents nearest to each empty region toward it.
        push_k = min(self.pushing_sample_size, num_points)
        nearest_dists = (
            torch.cdist(farthest_samples, points, p=1)
            .topk(push_k, dim=-1, largest=False)
            .values
        )
        return nearest_dists.pow(2).mean()

    def forward(self, latent_states, latent_actions):
        """Return the scalar sum of size and coverage terms for states and actions.

        Inputs may have any number of leading axes; they are flattened to (N, D).
        """
        latent_states = latent_states.reshape(-1, latent_states.shape[-1])
        latent_actions = latent_actions.reshape(-1, latent_actions.shape[-1])

        state_size_loss = self._size_loss(latent_states, self.state_space_size)
        action_size_loss = self._size_loss(latent_actions, self.action_space_size)

        # The uniform-ball sampler below is specific to the L1 ball.
        if self.space_ball_p != 1:
            raise NotImplementedError("Only L1 norm is supported :(")

        return (
            state_size_loss
            + action_size_loss
            + self._coverage_loss(latent_states, self.state_space_size)
            + self._coverage_loss(latent_actions, self.action_space_size)
        )

In [None]:
# Loss terms
state_mse = nn.MSELoss()
action_mse = nn.MSELoss()
transition_loss_func = TransitionLoss()
smoothness_loss_func = SmoothnessLoss()
# Latents are penalized for leaving L1 balls of these radii. The training loop
# also reads `action_space_size` back from this object to clip perturbed actions.
coverage_loss_func = torch.compile(
    CoverageLoss(state_space_size=1.5, action_space_size=1.75, latent_samples=16_384)
)

# One AdamW optimizer over every network's parameters. The learning rate is
# multiplied by 0.995 on each scheduler step.
optimizer = torch.optim.AdamW(
    [p for module in nets for p in module.parameters()],
    lr=1e-2,
)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995, last_epoch=-1)

# Batch sizes:
# - encoder_batch_size: individual (trajectory, step) samples for the autoencoders;
# - transition_batch_size: whole trajectories for the transition model.
encoder_batch_size = 4096
transition_batch_size = 32

# Main training loop: each of the 1024 iterations draws fresh batches, runs every
# network, combines five loss terms, and takes one optimizer and scheduler step.
for i in range(1024):

    # Yoink a batch of data
    # Encoder batch: `encoder_batch_size` distinct positions drawn from all index
    # combinations of observations.shape[:-1]. Indices are raveled, then unraveled
    # back into one index tensor per leading axis. No device is passed, so these
    # index tensors are created on the default device (CPU unless changed) and then
    # used to index the CUDA data tensors.
    encoder_batch_raveled_inds = torch.randperm(np.prod(observations.shape[:-1]))[
        :encoder_batch_size
    ]
    encoder_batch_inds = torch.unravel_index(
        encoder_batch_raveled_inds, observations.shape[:-1]
    )

    # Transition batch: `transition_batch_size` distinct indices along axis 0. Each is
    # paired with a random start index along axis -2 in [0, int(T // 1.1)), where
    # T = observations.shape[-2].
    # NOTE(review): the 1.1 divisor presumably keeps the last ~9% of steps out of the
    # start range so each rollout has future steps to predict — TODO confirm and name it.
    transition_traj_batch_inds = torch.randperm(observations.shape[0], device="cuda")[
        :transition_batch_size
    ]
    transition_start_inds = torch.randint(
        0, int(observations.shape[-2] // 1.1), (transition_batch_size,), device="cuda"
    )

    state_batch = observations[encoder_batch_inds]
    action_batch = actions[encoder_batch_inds]

    # The start state of each sampled rollout, plus the full state/action sequences
    # of the sampled trajectories.
    starting_states = observations[transition_traj_batch_inds, transition_start_inds]
    state_traj_batch = observations[transition_traj_batch_inds]
    action_traj_batch = actions[transition_traj_batch_inds]

    # Now do a forward pass
    optimizer.zero_grad()

    # Autoencoder path: encode, decode, and score the reconstructions. The action
    # encoder/decoder are conditioned on the (latent) state via concatenation.
    latent_states = state_encoder(state_batch)
    latent_actions = action_encoder(torch.cat([action_batch, state_batch], dim=-1))

    reconstructed_states = state_decoder(latent_states)
    reconstructed_actions = action_decoder(
        torch.cat([latent_actions, latent_states], dim=-1)
    )

    state_reconstruction_loss = state_mse(reconstructed_states, state_batch)
    action_reconstruction_loss = action_mse(reconstructed_actions, action_batch)

    # Transition path. The start latents and latent actions are detached, so no
    # gradient reaches the encoders through them. `latent_fut_states_gt` below is
    # NOT detached, so transition_loss also trains state_encoder.
    latent_start_states = state_encoder(starting_states).detach()
    latent_traj_actions = action_encoder(
        torch.cat([action_traj_batch, state_traj_batch], dim=-1)
    ).detach()
    # `mask` comes from TransitionModel (defined in `stuff`, not shown). Both
    # TransitionLoss and SmoothnessLoss treat True entries as excluded positions,
    # presumably the steps before each start index — TODO confirm against TransitionModel.
    latent_fut_states_prime, mask = transition_model(
        latent_start_states,
        latent_traj_actions,
        start_indices=transition_start_inds,
        return_mask=True,
    )
    latent_fut_states_gt = state_encoder(state_traj_batch)

    # Random perturbation per latent action: a random direction normalized to unit
    # L1 norm over the last axis, then scaled by a U[0, 1) magnitude.
    perturbations = torch.randn_like(latent_traj_actions)
    perturbations = perturbations / torch.norm(perturbations, p=1, dim=-1, keepdim=True)
    perturbations = perturbations * torch.rand(
        (*perturbations.shape[:-1], 1), device="cuda"
    )

    latent_traj_actions_perturbed = latent_traj_actions + perturbations
    # Normalize the perturbations if they are too large
    # (more precisely: any perturbed action whose L1 norm exceeds the coverage loss's
    # action-ball radius is rescaled back onto the surface of that ball)
    perturbed_action_norms = torch.norm(
        latent_traj_actions_perturbed, p=1, dim=-1, keepdim=True
    )
    latent_traj_actions_perturbed = torch.where(
        perturbed_action_norms > coverage_loss_func.action_space_size,
        latent_traj_actions_perturbed
        * coverage_loss_func.action_space_size
        / perturbed_action_norms,
        latent_traj_actions_perturbed,
    )
    # Roll out the same start states under the perturbed actions; used only by the
    # smoothness penalty.
    latent_fut_states_prime_perturbed = transition_model(
        latent_start_states,
        latent_traj_actions_perturbed,
        start_indices=transition_start_inds,
    )

    transition_loss = transition_loss_func(
        latent_fut_states_prime, latent_fut_states_gt, mask
    )
    smoothness_loss = smoothness_loss_func(
        latent_traj_actions,
        latent_fut_states_prime,
        latent_traj_actions_perturbed,
        latent_fut_states_prime_perturbed,
        mask,
    )
    # Coverage is computed on the encoder batch's latents (which keep their gradients).
    coverage_loss = coverage_loss_func(latent_states, latent_actions)

    # Weighted sum of the loss terms.
    # NOTE(review): 0.1 and 0.01 are magic constants — consider moving them to a
    # config cell.
    loss = (
        state_reconstruction_loss
        + action_reconstruction_loss
        + transition_loss * 0.1
        + smoothness_loss
        + coverage_loss * 0.01
    )

    loss.backward()
    optimizer.step()
    # ExponentialLR: the learning rate is multiplied by gamma on every iteration.
    lr_scheduler.step()
    if i % 64 == 0:
        print(
            f"Iteration {i}, total loss: {loss.item()}, lr: {lr_scheduler.get_last_lr()}, transition loss: {transition_loss.item()}, smoothness loss: {smoothness_loss.item()}, coverage loss: {coverage_loss.item()}, state reconstruction loss: {state_reconstruction_loss.item()}, action reconstruction loss: {action_reconstruction_loss.item()}"
        )

# Delete everything related to optimization
# (clear gradients, then drop the optimizer and scheduler objects and their state)
optimizer.zero_grad()
del optimizer
del lr_scheduler

In [None]:
# Save Networks
# torch.save pickles each whole nn.Module. Loading the files therefore needs the
# defining classes (from `stuff`) importable, and torch.load unpickles arbitrary
# objects, so only load these files from a trusted source.
SAVE_DIR = Path("trained_net_params")
SAVE_DIR.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories

for net_name, net in {
    "state_encoder": state_encoder,
    "action_encoder": action_encoder,
    "transition_model": transition_model,
    "state_decoder": state_decoder,
    "action_decoder": action_decoder,
}.items():
    torch.save(net, SAVE_DIR / f"{net_name}.pt")

In [None]:
# Encode the full dataset once (no autograd), measure the per-dimension spread of
# the latent codes, and decode them for the reconstruction checks further down.
# NOTE(review): no .eval() call appears in this notebook; if Perceptron uses dropout
# or batch-norm, switch the networks to eval mode first — TODO confirm.
with torch.no_grad():
    # Flatten all leading axes into a single sample axis.
    encoded_states = rearrange(state_encoder(observations), "... e -> (...) e")
    encoded_actions = rearrange(
        action_encoder(torch.cat([actions, observations], dim=-1)),
        "... e -> (...) e",
    )

    # Decode back to observation / action space.
    recovered_states = state_decoder(encoded_states)
    recovered_actions = action_decoder(
        torch.cat([encoded_actions, encoded_states], dim=-1)
    )

    # Standard deviation of each latent dimension over every encoded sample.
    state_std = encoded_states.std(dim=0)
    action_std = encoded_actions.std(dim=0)

In [None]:
# Per-dimension std of the latent codes; compare with the L1-ball radii given to
# CoverageLoss in the training cell (1.5 for states, 1.75 for actions).
(state_std, action_std)

(
    tensor([0.4034, 0.8038, 0.6741, 0.7653], device="cuda:0"),
    tensor([1.0077, 1.0054], device="cuda:0"),
)

In [None]:
# Per-dimension mean absolute reconstruction error over the whole dataset. The raw
# difference tensors have one row per sample and are too large to inspect directly.
(
    (recovered_states - rearrange(observations, "... e -> (...) e")).abs().mean(dim=0),
    (recovered_actions - rearrange(actions, "... e -> (...) e")).abs().mean(dim=0),
)

In [None]:
# First few samples side by side; last axis: index 0 = reconstruction,
# index 1 = ground truth.
torch.stack([recovered_states, rearrange(observations, "... e -> (...) e")], dim=-1)[:8]