# Imports

In [1]:
import functools
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm, trange

from diffusion_policy.common.normalize_util import get_image_range_normalizer
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.common.sampler import get_val_mask
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.model.common.normalizer import (
    LinearNormalizer,
    SingleFieldLinearNormalizer,
)
from vae.pusht_vae import VanillaVAE

# Load dataset

In [2]:
# Path to the PushT replay buffer. Set the PUSHT_ZARR_PATH environment variable
# to point elsewhere instead of editing this machine-specific default.
path = os.environ.get(
    "PUSHT_ZARR_PATH",
    "/nas/ucb/ebronstein/lsdp/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr",
)
dataset = PushTImageDataset(path)

In [3]:
# Fetch episode 0 and list every stored field with its array shape.
episode = dataset.replay_buffer.get_episode(0)
for name, arr in episode.items():
    print(name, arr.shape)

In [4]:
# Visualize a uniformly spaced subset of the episode's frames, each titled with
# its time step and the action recorded at that step.
n_rows, n_cols = 5, 5
episode_len = episode["img"].shape[0]
time_steps = np.linspace(0, episode_len - 1, n_rows * n_cols).astype(int)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
for ax, step in zip(axs.flat, time_steps):
    # Assumes raw pixel values in [0, 255]; imshow wants floats in [0, 1].
    ax.imshow(episode["img"][step] / 255.0)
    action_str = ", ".join(f"{a:.1f}" for a in episode["action"][step])
    ax.set_title(f"t={step}, action=({action_str})")
    ax.axis("off")
plt.show()

In [5]:
# Per-dimension (min, max) of the state vector over this episode.
np.min(episode["state"], axis=0), np.max(episode["state"], axis=0)

# Episode data sanity check

In [None]:
# Fetch episode 0 (a dict of per-step arrays) for the sanity checks below.
episode = dataset.replay_buffer.get_episode(0)

In [None]:
states = episode["state"]
actions = episode["action"]

# Normalize to [-1, 1] using the same scheme as `normalize_pn1` in the
# "Data normalizers" section: lower bound 0, upper bound = dataset-wide max.
# That section runs later in the notebook, so compute the bounds inline here
# instead of depending on names that do not exist yet on a fresh kernel.
state_upper = dataset.replay_buffer["state"].max(axis=0)
action_upper = dataset.replay_buffer["action"].max(axis=0)
norm_states = states / state_upper * 2 - 1
norm_actions = actions / action_upper * 2 - 1

In [None]:
# Align state[t + offset] with action[t]. If actions are targets that the first
# two state dimensions reach `offset` steps later, the paired curves overlap.
offset = 2
# Unlike actions[:-offset], this slice is also correct for offset == 0.
n_aligned = len(actions) - offset
fig, ax = plt.subplots()
for dim in range(2):
    ax.plot(states[offset:, dim], label=f"state[{dim}]")
    ax.plot(actions[:n_aligned, dim], linestyle="--", label=f"action[{dim}]")
ax.set_xlabel("Time step t")
ax.set_title(f"state[t + {offset}] vs. action[t]")
ax.legend()
plt.show()

In [None]:
# For each lag, measure how far state[t + offset] (first two dims) is from
# action[t], in raw units and after [-1, 1] normalization. Every metric uses the
# same reduction: squared / Euclidean distance over the 2 dims, then the mean
# over time. This keeps raw and normalized numbers directly comparable.
for offset in range(1, 21):
    diff = states[offset:, :2] - actions[:-offset]
    norm_diff = norm_states[offset:, :2] - norm_actions[:-offset]

    mse = (diff**2).sum(axis=-1).mean()
    me = np.linalg.norm(diff, axis=-1).mean()

    # Sum over dims like `mse` (the previous `.mean()` also averaged over the
    # 2 dims, halving the normalized value relative to the raw one).
    norm_mse = (norm_diff**2).sum(axis=-1).mean()
    norm_me = np.linalg.norm(norm_diff, axis=-1).mean()
    print(
        f"Offset {offset}: MSE {mse}, normalized MSE {norm_mse}, ME {me}, normalized ME {norm_me}"
    )

# Load VAE

In [6]:
# Load the pretrained VAE. `img_data` holds every frame of the replay buffer as
# float NCHW with raw pixel values (encode_images divides by 255 later).
img_data = torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2).float()
N, C, H, W = img_data.shape
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VanillaVAE(in_channels=C, in_height=H, in_width=W, latent_dim=32).to(device)
# The VAE is only used for inference in this notebook.
vae.eval()
save_dir = "models/pusht_vae"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
vae.load_state_dict(
    torch.load(os.path.join(save_dir, "vae_32_20240403.pt"), map_location=device)
)

# Encode dataset with VAE

In [19]:
def encode_images(img, img_normalizer, vae, device, batch_size: Optional[int] = None):
    """Encode raw images into VAE latent means.

    Args:
        img: Images with raw pixel values (divided by 255 here), in the layout
            ``vae.encode`` expects (NCHW in this notebook). NumPy array or tensor.
        img_normalizer: Callable applied to the [0, 1]-scaled images before
            encoding; must return a torch tensor.
        vae: Model whose ``encode`` returns ``(mu, log_var)``.
        device: Device on which the encoder runs.
        batch_size: If given, encode in chunks of at most this many images to
            bound GPU memory. ``None`` (default) encodes everything in one pass.

    Returns:
        NumPy array of latent means ``mu`` on the CPU, one row per input image.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size is None:
        chunks = [img]
    else:
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # An empty input still goes through a single (empty) pass.
        chunks = [img[s : s + batch_size] for s in range(0, len(img), batch_size)] or [img]

    mus = []
    with torch.no_grad():
        for chunk in chunks:
            normalized = img_normalizer(chunk / 255.0)
            mu, _log_var = vae.encode(normalized.to(device))
            mus.append(mu.cpu().numpy())
    return np.concatenate(mus, axis=0)

In [20]:
# Encode every frame in chunks of `batch_size` (one pass would exhaust CUDA
# memory) and keep only the latent means.
img_normalizer = get_image_range_normalizer()
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.eval()
mu_list = [
    encode_images(img_data[start : start + batch_size], img_normalizer, vae, device)
    for start in trange(0, N, batch_size)
]
encoded_imgs = np.concatenate(mu_list, axis=0)

# Data normalizers

In [108]:
# Per-dimension statistics of the VAE latent means.
encoded_imgs_mean = encoded_imgs.mean(axis=0)
encoded_imgs_std = encoded_imgs.std(axis=0)
encoded_imgs_stat = dict(
    mean=encoded_imgs_mean,
    std=encoded_imgs_std,
    min=encoded_imgs.min(axis=0),
    max=encoded_imgs.max(axis=0),
)

# z-score the latents: (x - mean) / std == x * scale + offset.
encoded_imgs_scale = (1.0 / encoded_imgs_std).astype(np.float32)
encoded_imgs_offset = -encoded_imgs_mean / encoded_imgs_std
latent_img_normalizer = SingleFieldLinearNormalizer.create_manual(
    scale=encoded_imgs_scale,
    offset=encoded_imgs_offset,
    input_stats_dict=encoded_imgs_stat,
)

# Bounds for the [-1, 1] action / state normalizers: the upper bound is the
# dataset-wide per-dimension max, and the lower bound is fixed at zero.
max_action = dataset.replay_buffer["action"].max(axis=0)
min_action = np.zeros_like(max_action)
max_state = dataset.replay_buffer["state"].max(axis=0)
min_state = np.zeros_like(max_state)


def normalize_pn1(x, min_val, max_val):
    """Affinely map ``x`` from ``[min_val, max_val]`` to ``[-1, 1]``, elementwise.

    Works for NumPy arrays and torch tensors; the bounds broadcast against ``x``.
    """
    span = max_val - min_val
    unit = (x - min_val) / span  # now in [0, 1]
    return 2 * unit - 1


def denormalize_pn1(nx, min_val, max_val):
    """Inverse of ``normalize_pn1``: map ``nx`` from ``[-1, 1]`` back to ``[min_val, max_val]``."""
    span = max_val - min_val
    unit = (nx + 1) / 2  # [-1, 1] -> [0, 1]
    return unit * span + min_val

# Episode Dataset

In [144]:
import torch
from torch.utils.data import Dataset


class EpisodeDataset(Dataset):
    """Sliding-window samples drawn from the episodes of a replay buffer.

    Each sample is a pair ``(obs_history, pred_horizon)`` of dicts with keys
    ``"img"``, ``"action"`` and ``"state"``. For a window starting at step ``i``
    of an episode, ``obs_history`` covers steps ``[i, i + n_obs_history)`` and
    ``pred_horizon`` covers the following ``n_pred_horizon`` steps. Windows never
    cross episode boundaries.
    """

    def __init__(
        self,
        dataset,
        n_obs_history=1,
        n_pred_horizon=1,
        episode_idxs=None,
        process_img_fn=None,
        device: str = "cpu",
    ):
        """
        Args:
            dataset: Object exposing ``replay_buffer`` with ``n_episodes`` and
                ``get_episode(idx)``, which returns a dict with ``"img"``
                (channels-last frames), ``"action"`` and ``"state"`` arrays.
            n_obs_history: Number of steps in each observation window.
            n_pred_horizon: Number of steps after the window to return as the
                prediction horizon (may be 0).
            episode_idxs: Episodes to include; ``None`` means all episodes.
            process_img_fn: Optional function applied to each episode's
                channels-first image array (e.g. VAE encoding) before windowing.
            device: Device for the action/state tensors. Images are kept as
                returned by ``process_img_fn`` (a NumPy array by default).
        """
        self.dataset = dataset
        self.n_obs_history = n_obs_history
        self.n_pred_horizon = n_pred_horizon
        # `list(None)` raises, so keep None here and resolve it in prepare_data.
        self.episode_idxs = None if episode_idxs is None else list(episode_idxs)
        self.process_img_fn = process_img_fn
        self.device = device
        self.prepare_data()

    def prepare_data(self):
        """Build ``self.samples``, a flat list of ``(obs_history, pred_horizon)`` tuples."""
        self.samples = []

        if self.episode_idxs is None:
            self.episode_idxs = list(range(self.dataset.replay_buffer.n_episodes))

        window = self.n_obs_history + self.n_pred_horizon
        for episode_idx in tqdm(self.episode_idxs, desc="Preparing data"):
            episode = self.dataset.replay_buffer.get_episode(episode_idx)
            img = episode["img"].transpose(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
            if self.process_img_fn is not None:
                img = self.process_img_fn(img)
            fields = {
                "img": img,
                "action": torch.tensor(
                    episode["action"], dtype=torch.float32, device=self.device
                ),
                "state": torch.tensor(
                    episode["state"], dtype=torch.float32, device=self.device
                ),
            }

            # Slices are views, so samples share memory with the episode arrays.
            n_steps = len(fields["action"])
            for i in range(n_steps - window + 1):
                obs_slice = slice(i, i + self.n_obs_history)
                pred_slice = slice(i + self.n_obs_history, i + window)
                obs_history = {key: value[obs_slice] for key, value in fields.items()}
                pred_horizon = {key: value[pred_slice] for key, value in fields.items()}
                self.samples.append((obs_history, pred_horizon))

    def __len__(self):
        """Total number of windows across all selected episodes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(obs_history, pred_horizon)`` pair at position ``idx``."""
        obs_history, pred_horizon = self.samples[idx]
        return obs_history, pred_horizon

# Make dataloaders

In [54]:
# Hold out 10% of episodes for validation (the split is per episode, not per step).
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
train_idxs, val_idxs = np.where(~val_mask)[0], np.where(val_mask)[0]

# Windows of 5 consecutive frames; no prediction horizon is used.
batch_size = 256
n_obs_history = 5
n_pred_horizon = 0
# Each episode's frames are encoded to VAE latent means once, at dataset build time.
process_img_fn = functools.partial(
    encode_images, img_normalizer=img_normalizer, vae=vae, device=device
)
train_episode_dataset, val_episode_dataset = (
    EpisodeDataset(
        dataset,
        n_obs_history=n_obs_history,
        n_pred_horizon=n_pred_horizon,
        episode_idxs=split_idxs,
        process_img_fn=process_img_fn,
    )
    for split_idxs in (train_idxs, val_idxs)
)
train_loader, val_loader = (
    torch.utils.data.DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)
    for split_dataset, shuffle in (
        (train_episode_dataset, True),
        (val_episode_dataset, False),
    )
)

In [55]:
# Show the batched shape of every field in both halves of one training sample.
obs_history, pred_horizon = next(iter(train_loader))
for part_name, part in (("obs_history", obs_history), ("pred_horizon", pred_horizon)):
    for k, v in part.items():
        print(f"{part_name}['{k}']", v.shape)

# Models

In [24]:
class InverseDynamicsMLP(nn.Module):
    """MLP mapping a flattened observation history to an action in [-1, 1].

    Architecture: ``[Linear -> ReLU] * len(hidden_dims) -> Linear -> Tanh``.
    """

    def __init__(
        self, n_obs: int, obs_dim: int, action_dim: int, hidden_dims: list[int]
    ):
        super().__init__()

        widths = [n_obs * obs_dim, *hidden_dims]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(widths[-1], action_dim), nn.Tanh()]
        # Attribute name is part of the checkpoint format ("model.<idx>.*" keys).
        self.model = nn.Sequential(*layers)

    def forward(self, obs_history: torch.Tensor) -> torch.Tensor:
        """Flatten all non-batch dims of ``obs_history`` and predict ``(batch, action_dim)``."""
        return self.model(torch.flatten(obs_history, start_dim=1))

In [25]:
class InverseDynamicsCNN(nn.Module):
    """CNN that predicts an action in [-1, 1] from a short history of images.

    Every frame goes through the same convolutional encoder (shared weights).
    The flattened per-frame features are concatenated in time order and fed to
    an MLP head.
    """

    def __init__(
        self,
        in_channels: int,
        in_height: int,
        in_width: int,
        action_dim: int,
        n_obs_history: int,
        hidden_dims: Optional[list[int]] = None,
    ):
        super().__init__()

        conv_widths = [32, 64, 128, 256, 512] if hidden_dims is None else hidden_dims
        kernel_size, stride, padding, dilation = 3, 2, 1, 1

        # Encoder: one Conv -> BatchNorm -> LeakyReLU stage per entry of conv_widths.
        stages = []
        prev_channels = in_channels
        for width in conv_widths:
            conv = nn.Conv2d(
                prev_channels,
                out_channels=width,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
            stages.append(nn.Sequential(conv, nn.BatchNorm2d(width), nn.LeakyReLU()))
            prev_channels = width
        # Attribute names/nesting define the checkpoint keys; keep them stable.
        self.conv_branch = nn.Sequential(*stages, nn.Flatten())

        # Size of one frame's flattened encoder output.
        conv_out_size = np.prod(
            compute_conv_output_shape(
                H=in_height,
                W=in_width,
                padding=padding,
                stride=stride,
                kernel_size=kernel_size,
                dilation=dilation,
                num_layers=len(conv_widths),
                last_hidden_dim=conv_widths[-1],
            )
        )

        head_in = n_obs_history * conv_out_size
        self.action_predictor = nn.Sequential(
            nn.Linear(head_in, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Tanh(),
        )

    def forward(self, obs_history: torch.Tensor):
        """Predict actions from ``obs_history`` of shape ``(batch, n_obs, C, H, W)``.

        Frames are encoded one time step at a time, so BatchNorm statistics are
        computed per frame index.
        """
        per_frame_features = [
            self.conv_branch(obs_history[:, t]) for t in range(obs_history.size(1))
        ]
        return self.action_predictor(torch.cat(per_frame_features, dim=1))

# Normalization sanity checks

In [29]:
# Sanity-check every normalizer on episode 0.
episode0 = dataset.replay_buffer.get_episode(0)

# Image normalization: VAE latents before and after z-scoring.
img = episode0["img"].transpose(0, 3, 1, 2)
encoded_img = encode_images(img, img_normalizer, vae, device)
print("img.shape:", img.shape)
print("encoded_img.shape:", encoded_img.shape)

print("encoded_img.mean():", encoded_img.mean())
print("encoded_img.std():", encoded_img.std())
normalized_encoded_img = latent_img_normalizer(encoded_img)
print("normalized_encoded_img.mean():", normalized_encoded_img.mean())
print("normalized_encoded_img.std():", normalized_encoded_img.std())

# Action normalization. Call normalize_pn1 directly, because the
# `action_normalizer` partial is only defined later in the notebook.
action = episode0["action"]
normalized_action = normalize_pn1(action, min_action, max_action)
print("action.min():", action.min())
print("action.max():", action.max())
print("normalized_action.min():", normalized_action.min())
print("normalized_action.max():", normalized_action.max())

# State normalization. There is no `normalize_state`; states are normalized
# with normalize_pn1 and the state bounds.
state = episode0["state"]
normalized_state = normalize_pn1(state, min_state, max_state)
print("state.min():", state.min(axis=0))
print("state.max():", state.max(axis=0))
print("normalized_state.min():", normalized_state.min(axis=0))
print("normalized_state.max():", normalized_state.max(axis=0))

# Training and eval

In [145]:
def train_epochs(
    model,
    train_loader,
    val_loader,
    obs_normalizer,
    action_normalizer,
    obs_key: str = "img",
    opt_kwargs: Optional[dict] = None,
    num_epochs=10,
    log_freq: Optional[int] = None,
    save_freq=2,
    save_dir: Optional[str] = None,
):
    """Train an inverse-dynamics model with Adam on an MSE action loss.

    Batches are ``(obs_history, pred_horizon)`` pairs; only ``obs_history`` is
    used. The model input is ``obs_history[obs_key]`` and the regression target
    is the second-to-last action in the history.

    Args:
        model: Module mapping normalized observations to normalized actions.
            Training runs on the device of its parameters.
        train_loader: Loader of training ``(obs_history, pred_horizon)`` batches.
        val_loader: Loader of validation batches, evaluated before training and
            after every epoch.
        obs_normalizer: Callable applied to observations once they are on the
            model's device.
        action_normalizer: Callable applied to target actions once they are on
            the model's device.
        obs_key: Key of ``obs_history`` used as the model input.
        opt_kwargs: Extra keyword arguments for ``torch.optim.Adam``.
        num_epochs: Number of passes over ``train_loader``.
        log_freq: If set, print the training loss every ``log_freq`` batches.
        save_freq: Save a checkpoint every ``save_freq`` epochs (and after the
            last epoch) when ``save_dir`` is set.
        save_dir: Checkpoint directory, created if missing. ``None`` disables saving.

    Returns:
        ``(train_losses, test_losses)``. ``train_losses`` has one entry per
        training batch. ``test_losses[0]`` is measured before training, and one
        entry is appended after every epoch.
    """
    # Use the model's own device rather than assuming CUDA whenever it exists.
    device = next(model.parameters()).device
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), **(opt_kwargs or {}))
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    def evaluate():
        # `eval` is this notebook's validation helper (it shadows the builtin).
        return eval(
            model,
            val_loader,
            obs_normalizer,
            action_normalizer,
            criterion,
            device,
            obs_key=obs_key,
        )

    train_losses = []
    test_losses = [evaluate()]
    with trange(num_epochs, desc="Epoch") as tepoch:
        for epoch in tepoch:
            model.train()
            with tqdm(train_loader, desc="Batch") as tbatch:
                # Prediction horizon is unused.
                for i, (obs_history, _) in enumerate(tbatch):
                    obs = obs_history[obs_key]
                    # The second-to-last action is the target action because it was
                    # applied to get the last image.
                    action = obs_history["action"][:, -2]

                    # Move to the model's device *before* normalizing, so that
                    # normalizers whose bounds are device tensors never mix
                    # devices. The trailing .to(device) covers normalizers that
                    # may return results on another device.
                    obs = obs_normalizer(obs.to(device)).to(device)
                    action = action_normalizer(action.to(device)).to(device)

                    optimizer.zero_grad()
                    action_pred = model(obs)
                    loss = criterion(action_pred, action)
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    train_losses.append(loss_value)
                    tbatch.set_postfix(loss=loss_value)
                    if log_freq is not None and i % log_freq == 0:
                        print(f"Epoch {epoch}, Batch {i}, Train Loss: {loss_value}")

            test_loss = evaluate()
            test_losses.append(test_loss)
            tepoch.set_postfix(test_loss=test_loss)

            is_last_epoch = epoch == num_epochs - 1
            if save_dir is not None and (epoch % save_freq == 0 or is_last_epoch):
                epoch_str = "final" if is_last_epoch else str(epoch)
                torch.save(
                    model.state_dict(),
                    os.path.join(save_dir, f"inverse_dynamics_{epoch_str}.pt"),
                )

    return train_losses, test_losses


def eval(
    model,
    val_loader,
    obs_normalizer,
    action_normalizer,
    criterion,
    device,
    obs_key: str = "img",
):
    """Per-sample average validation loss of an inverse-dynamics model.

    NOTE: this name shadows Python's builtin ``eval`` inside the notebook; it is
    kept for the existing callers.

    Args:
        model: Module mapping normalized observations to normalized actions.
            It is switched to eval mode.
        val_loader: Iterable of ``(obs_history, pred_horizon)`` batches.
        obs_normalizer: Callable applied to ``obs_history[obs_key]`` on ``device``.
        action_normalizer: Callable applied to the target action on ``device``.
        criterion: Loss with mean reduction. It is re-weighted by batch size,
            so the result is an average over samples, not over batches.
        device: Device the model runs on.
        obs_key: Key of ``obs_history`` used as the model input.

    Returns:
        The average loss over all samples actually yielded by ``val_loader``
        (correct even with ``drop_last=True``), or ``nan`` if it yields none.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for obs_history, _ in val_loader:
            obs = obs_history[obs_key]
            # The second-to-last action is the target action because it was
            # applied to get the last image.
            action = obs_history["action"][:, -2]

            # Move first, then normalize: normalizers with device-resident
            # bounds cannot operate on tensors from another device.
            obs = obs_normalizer(obs.to(device)).to(device)
            action = action_normalizer(action.to(device)).to(device)

            action_pred = model(obs)
            n_batch = obs.shape[0]
            total_loss += criterion(action_pred, action).item() * n_batch
            n_samples += n_batch

    if n_samples == 0:
        return float("nan")
    return total_loss / n_samples


def plot_losses(train_losses, test_losses):
    """Plot per-batch train loss and test loss on a shared iteration axis.

    The ``test_losses`` are spread evenly from iteration 0 to
    ``len(train_losses)``. This lines them up with epoch boundaries when there
    is one test loss before training plus one per epoch.
    """
    test_x = np.linspace(0, len(train_losses), len(test_losses))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(train_losses, label="Train Loss")
    ax.plot(test_x, test_losses, label="Test Loss")
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

# State history to action

In [161]:
# Debug split: train on episode 0 and validate on episode 1. Set DEBUG_SPLIT to
# False to use the standard 10% episode-level validation split instead.
DEBUG_SPLIT = True
if DEBUG_SPLIT:
    train_idxs = [0]
    val_idxs = [1]
else:
    val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
    val_idxs = np.where(val_mask)[0]
    train_idxs = np.where(~val_mask)[0]

# Make the episode dataset and create a DataLoader.
batch_size = 256
n_obs_history = 10
n_pred_horizon = 0
process_img_fn = None
# Keep state/action tensors on the GPU when one exists (previously hard-coded
# "cuda", which fails on CPU-only machines).
dataset_device = "cuda" if torch.cuda.is_available() else "cpu"
train_episode_dataset = EpisodeDataset(
    dataset,
    n_obs_history=n_obs_history,
    n_pred_horizon=n_pred_horizon,
    episode_idxs=train_idxs,
    process_img_fn=process_img_fn,
    device=dataset_device,
)
val_episode_dataset = EpisodeDataset(
    dataset,
    n_obs_history=n_obs_history,
    n_pred_horizon=n_pred_horizon,
    episode_idxs=val_idxs,
    process_img_fn=process_img_fn,
    device=dataset_device,
)
train_loader = torch.utils.data.DataLoader(
    train_episode_dataset, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_episode_dataset, batch_size=batch_size, shuffle=False
)

In [162]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# [-1, 1] normalizers whose bounds live on `device`; inputs must be on the same device.
state_normalizer = functools.partial(
    normalize_pn1,
    min_val=torch.tensor(min_state, dtype=torch.float32, device=device),
    max_val=torch.tensor(max_state, dtype=torch.float32, device=device),
)
action_normalizer = functools.partial(
    normalize_pn1,
    min_val=torch.tensor(min_action, dtype=torch.float32, device=device),
    max_val=torch.tensor(max_action, dtype=torch.float32, device=device),
)
# Size the MLP from the data instead of hard-coding the state/action widths.
state_dim = dataset.replay_buffer["state"].shape[-1]
action_dim = dataset.replay_buffer["action"].shape[-1]
id_model = InverseDynamicsMLP(
    n_obs=n_obs_history, obs_dim=state_dim, action_dim=action_dim, hidden_dims=[256, 256, 256]
).to(device)
train_losses, test_losses = train_epochs(
    id_model,
    train_loader,
    val_loader,
    state_normalizer,
    action_normalizer,
    obs_key="state",
    opt_kwargs={"lr": 1e-3, "weight_decay": 0},
    num_epochs=100,
    log_freq=None,
    save_freq=2,
    save_dir=None,
)

In [163]:
# Loss curves for the state-history MLP.
plot_losses(train_losses=train_losses, test_losses=test_losses)

In [126]:
# Grab the first validation batch: an (obs_history, pred_horizon) pair of dicts.
batch = next(iter(val_loader))

In [127]:
# Validation loss of the state-history MLP, in normalized units.
eval(
    model=id_model,
    val_loader=val_loader,
    obs_normalizer=state_normalizer,
    action_normalizer=action_normalizer,
    criterion=nn.MSELoss(),
    device=device,
    obs_key="state",
)

In [128]:
# Run the trained model on one validation batch, in normalized units.
states = batch[0]["state"].to(device)
norm_states = state_normalizer(states)

actions = batch[0]["action"].to(device)
# Target = second-to-last action of the history, as in train_epochs/eval.
target_action = actions[:, -2]
norm_actions = action_normalizer(actions)
norm_target_action = norm_actions[:, -2]

# Inference only: eval mode and no autograd graph.
id_model.eval()
with torch.no_grad():
    norm_pred_action = id_model(norm_states)

In [130]:
# Shapes: action history, raw target action, normalized target, prediction.
actions.shape, target_action.shape, norm_target_action.shape, norm_pred_action.shape

In [131]:
# MSE between predicted and target actions in normalized [-1, 1] units.
F.mse_loss(input=norm_pred_action, target=norm_target_action)

In [133]:
# The same error in original action units: map the prediction back with the
# action bounds (placed on the prediction's device and dtype rather than
# hard-coding .cuda()) and compare it with the raw target.
F.mse_loss(
    denormalize_pn1(
        norm_pred_action,
        torch.as_tensor(min_action, dtype=norm_pred_action.dtype, device=norm_pred_action.device),
        torch.as_tensor(max_action, dtype=norm_pred_action.dtype, device=norm_pred_action.device),
    ),
    target_action.to(norm_pred_action),
    reduction="mean",
)

# Train MLP

In [None]:
# Hold out 10% of episodes for validation (the split is per episode, not per step).
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
train_idxs, val_idxs = np.where(~val_mask)[0], np.where(val_mask)[0]

# Windows of 5 consecutive frames; no prediction horizon is used.
batch_size = 256
n_obs_history = 5
n_pred_horizon = 0
# Each episode's frames are encoded to VAE latent means once, at dataset build time.
process_img_fn = functools.partial(
    encode_images, img_normalizer=img_normalizer, vae=vae, device=device
)
train_episode_dataset, val_episode_dataset = (
    EpisodeDataset(
        dataset,
        n_obs_history=n_obs_history,
        n_pred_horizon=n_pred_horizon,
        episode_idxs=split_idxs,
        process_img_fn=process_img_fn,
    )
    for split_idxs in (train_idxs, val_idxs)
)
train_loader, val_loader = (
    torch.utils.data.DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)
    for split_dataset, shuffle in (
        (train_episode_dataset, True),
        (val_episode_dataset, False),
    )
)

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The MLP consumes VAE latent means, so each frame contributes one latent vector.
# Derive its width (and the action width) from the data instead of hard-coding.
latent_dim = encoded_imgs.shape[-1]
action_dim = dataset.replay_buffer["action"].shape[-1]
id_model = InverseDynamicsMLP(
    n_obs=n_obs_history, obs_dim=latent_dim, action_dim=action_dim, hidden_dims=[256, 256, 256]
).to(device)
train_losses, test_losses = train_epochs(
    id_model,
    train_loader,
    val_loader,
    latent_img_normalizer,
    action_normalizer,
    opt_kwargs={"lr": 1e-3, "weight_decay": 1e-5},
    num_epochs=10,
    log_freq=None,
    save_freq=2,
    save_dir=None,
)

In [22]:
# Loss curves for the latent-history MLP.
plot_losses(train_losses=train_losses, test_losses=test_losses)

# Train CNN

In [34]:
# Hold out 10% of episodes for validation (the split is per episode, not per step).
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
train_idxs, val_idxs = np.where(~val_mask)[0], np.where(val_mask)[0]

# Two-frame image windows for the CNN; no prediction horizon is used.
batch_size = 256
n_obs_history = 2
n_pred_horizon = 0


def process_img_fn(img):
    """Scale raw pixel values by 1/255."""
    return img / 255.0


train_episode_dataset, val_episode_dataset = (
    EpisodeDataset(
        dataset,
        n_obs_history=n_obs_history,
        n_pred_horizon=n_pred_horizon,
        episode_idxs=split_idxs,
        process_img_fn=process_img_fn,
    )
    for split_idxs in (train_idxs, val_idxs)
)
train_loader, val_loader = (
    torch.utils.data.DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)
    for split_dataset, shuffle in (
        (train_episode_dataset, True),
        (val_episode_dataset, False),
    )
)

In [45]:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
hidden_dims = None  # use InverseDynamicsCNN's default conv widths
# Replay-buffer images are stored channels-last: (N, H, W, C).
N, H, W, C = dataset.replay_buffer["img"].shape
N, action_dim = dataset.replay_buffer["action"].shape
cnn_id_model = InverseDynamicsCNN(
    in_channels=C,
    in_height=H,
    in_width=W,
    action_dim=action_dim,
    n_obs_history=n_obs_history,
    hidden_dims=hidden_dims,
).to(device)
img_normalizer = get_image_range_normalizer()
train_losses, test_losses = train_epochs(
    model=cnn_id_model,
    train_loader=train_loader,
    val_loader=val_loader,
    obs_normalizer=img_normalizer,
    action_normalizer=action_normalizer,
    opt_kwargs={"lr": 1e-4},
    num_epochs=10,
    log_freq=None,
    save_freq=2,
    save_dir=None,
)

In [46]:
# Loss curves for the 2-frame image CNN.
plot_losses(train_losses=train_losses, test_losses=test_losses)

In [50]:
# Same CNN pipeline, with a longer (4-frame) image history.
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
train_idxs, val_idxs = np.where(~val_mask)[0], np.where(val_mask)[0]

batch_size = 256
n_obs_history = 4
n_pred_horizon = 0


def process_img_fn(img):
    """Scale raw pixel values by 1/255."""
    return img / 255.0


train_episode_dataset, val_episode_dataset = (
    EpisodeDataset(
        dataset,
        n_obs_history=n_obs_history,
        n_pred_horizon=n_pred_horizon,
        episode_idxs=split_idxs,
        process_img_fn=process_img_fn,
    )
    for split_idxs in (train_idxs, val_idxs)
)
train_loader, val_loader = (
    torch.utils.data.DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)
    for split_dataset, shuffle in (
        (train_episode_dataset, True),
        (val_episode_dataset, False),
    )
)

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
hidden_dims = None  # use InverseDynamicsCNN's default conv widths
# Replay-buffer images are stored channels-last: (N, H, W, C).
N, H, W, C = dataset.replay_buffer["img"].shape
N, action_dim = dataset.replay_buffer["action"].shape
cnn_id_model = InverseDynamicsCNN(
    in_channels=C,
    in_height=H,
    in_width=W,
    action_dim=action_dim,
    n_obs_history=n_obs_history,
    hidden_dims=hidden_dims,
).to(device)
img_normalizer = get_image_range_normalizer()
train_losses, test_losses = train_epochs(
    model=cnn_id_model,
    train_loader=train_loader,
    val_loader=val_loader,
    obs_normalizer=img_normalizer,
    action_normalizer=action_normalizer,
    opt_kwargs={"lr": 1e-4},
    num_epochs=10,
    log_freq=None,
    save_freq=2,
    save_dir=None,
)

In [51]:
# Loss curves for the 4-frame image CNN.
plot_losses(train_losses=train_losses, test_losses=test_losses)

In [96]:
# Evaluate the most recently trained CNN on the validation set. `eval` takes the
# observation and action normalizers explicitly, and the trained model is
# `cnn_id_model` (`model` and `criterion` are not defined at this point).
test_loss = eval(
    cnn_id_model,
    val_loader,
    img_normalizer,
    action_normalizer,
    nn.MSELoss(),
    device,
)
print("Final Test Loss:", test_loss)

In [97]:
# Get the true and predicted action for one validation batch, in both
# normalized [-1, 1] units and original action units.
criterion = nn.MSELoss()
cnn_id_model.eval()
with torch.no_grad():
    # Batches are (obs_history, pred_horizon) dicts. The target is the
    # second-to-last action of the history, as in train_epochs/eval.
    obs_history, _ = next(iter(val_loader))
    obs = img_normalizer(obs_history["img"].to(device)).to(device)
    action = action_normalizer(obs_history["action"][:, -2].to(device)).to(device)
    action_pred = cnn_id_model(obs)

# Unnormalize with the inverse of the [-1, 1] action normalization.
action_min_t = torch.as_tensor(min_action, dtype=action.dtype, device=device)
action_max_t = torch.as_tensor(max_action, dtype=action.dtype, device=device)
unnormalized_action = denormalize_pn1(action, action_min_t, action_max_t)
unnormalized_action_pred = denormalize_pn1(action_pred, action_min_t, action_max_t)
normalized_mse = criterion(action_pred, action)
unnormalized_mse = criterion(unnormalized_action_pred, unnormalized_action)

print("Normalized MSE:", normalized_mse.item())
print("Unnormalized MSE:", unnormalized_mse.item())

In [98]:
# First five true vs. predicted actions, in unnormalized units.
(unnormalized_action[:5], unnormalized_action_pred[:5])

In [43]:
# Load a saved CNN checkpoint. The constructor arguments (including
# n_obs_history, which is required) must match the ones used for training.
# NOTE(review): train_epochs saves "inverse_dynamics_<epoch>.pt"; confirm this
# filename refers to the intended run.
model = InverseDynamicsCNN(C, H, W, action_dim, n_obs_history).to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load("inverse_dynamics_cnn_final.pt", map_location=device))