In [19]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm, trange

from diffusion_policy.common.normalize_util import get_image_range_normalizer
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.common.sampler import get_val_mask
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.model.common.normalizer import LinearNormalizer

In [2]:
# Location of the PushT replay buffer (zarr). Set the PUSHT_ZARR_PATH environment
# variable to point at a local copy; the original cluster path is kept as the
# fallback so the notebook behaves as before when the variable is unset.
path = os.environ.get(
    "PUSHT_ZARR_PATH",
    "/nas/ucb/ebronstein/lsdp/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr",
)
dataset = PushTImageDataset(path)

In [3]:
# Pull the first episode out of the replay buffer and inspect its contents.
episode = dataset.replay_buffer.get_episode(0)
# A bare `episode.keys()` in the middle of a cell is evaluated and thrown away
# (only a cell's last expression is displayed), so print the keys explicitly.
print("episode keys:", list(episode.keys()))
print("episode['img']:", episode['img'].shape)
print("episode['action']:", episode['action'].shape)

In [4]:
# Show 25 frames taken at evenly spaced time steps across the episode (first and
# last frame included) on a 5x5 grid, each titled with the action at that step.
episode_len = episode["img"].shape[0]
time_steps = np.linspace(0, episode_len - 1, 25).astype(int)
fig, axs = plt.subplots(5, 5, figsize=(20, 20))
# `axs.flat` walks the grid row by row, i.e. the same cell as axs[i // 5, i % 5].
for i, (ax, step) in enumerate(zip(axs.flat, time_steps)):
    # Pixel values are divided by 255 before display, as in the rest of the notebook.
    ax.imshow(episode["img"][step] / 255.0)
    ax.set_title(episode["action"][step])
    ax.axis("off")

In [30]:
class EpisodeDataset(torch.utils.data.Dataset):
    """Flat dataset of one-step transitions ``(current_image, next_image, action)``.

    All selected episodes are read from ``dataset.replay_buffer`` up front and
    kept in memory; every pair of consecutive frames in an episode becomes one
    sample, paired with the action stored at the first frame's index.
    """

    def __init__(self, dataset, episode_idxs=None):
        """
        Args:
            dataset: Object exposing ``replay_buffer.n_episodes`` and
                ``replay_buffer.get_episode(idx)``; each episode must provide
                ``"img"`` and ``"action"`` entries of equal length.
            episode_idxs: Iterable of episode indices to include. ``None``
                (default) selects every episode in the replay buffer.
        """
        self.dataset = dataset
        self.episode_idxs = episode_idxs
        self.prepare_data()

    def prepare_data(self):
        """
        Preprocess the episodes to create a flat list of samples.
        Each sample is a tuple: (current_image, next_image, action).

        Raises:
            AssertionError: If an episode has a different number of images
                and actions.
        """
        self.samples = []

        if self.episode_idxs is None:
            self.episode_idxs = range(self.dataset.replay_buffer.n_episodes)

        for episode_idx in self.episode_idxs:
            episode = self.dataset.replay_buffer.get_episode(episode_idx)
            # Scale the images to [0, 1] and cache them as float32. Dividing an
            # integer (or float64) array by 255.0 yields float64, which would
            # double the memory held by the cached dataset even though
            # __getitem__ hands out float32 tensors anyway. The cast is a no-op
            # when the quotient is already float32, and the resulting tensors
            # are identical because the same float64 -> float32 rounding simply
            # happens here instead of in __getitem__.
            img = np.asarray(episode["img"] / 255.0, dtype=np.float32)
            actions = episode["action"]

            # Ensure there is a next image for each current image
            assert len(img) == len(actions)

            # The last frame has no successor, so an episode of length T
            # contributes T - 1 samples (none if T <= 1).
            for i in range(len(actions) - 1):
                self.samples.append((img[i], img[i + 1], actions[i]))

    def __len__(self):
        """
        Return the total number of samples across all episodes.
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return the idx-th sample as float32 tensors.

        Returns:
            Tuple ``(current_img, next_img, action)``; both images are permuted
            from HWC to CHW layout.
        """
        current_img, next_img, action = self.samples[idx]

        # Convert data to PyTorch tensors and ensure the data type is correct
        current_img_tensor = torch.tensor(current_img, dtype=torch.float32).permute(
            2, 0, 1
        )  # Convert HWC to CHW
        next_img_tensor = torch.tensor(next_img, dtype=torch.float32).permute(
            2, 0, 1
        )  # Convert HWC to CHW
        action_tensor = torch.tensor(action, dtype=torch.float32)

        return current_img_tensor, next_img_tensor, action_tensor

In [31]:
# Split by episode: the mask marks the validation episodes (second argument 0.1
# is passed to get_val_mask), and every unmarked episode goes to training.
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
train_idxs = np.nonzero(~val_mask)[0]
val_idxs = np.nonzero(val_mask)[0]

# Wrap each split in an EpisodeDataset and a DataLoader. Only the training
# loader shuffles; validation keeps a fixed order.
batch_size = 256
train_episode_dataset = EpisodeDataset(dataset, episode_idxs=train_idxs)
train_loader = torch.utils.data.DataLoader(
    train_episode_dataset, batch_size=batch_size, shuffle=True, num_workers=1
)
val_episode_dataset = EpisodeDataset(dataset, episode_idxs=val_idxs)
val_loader = torch.utils.data.DataLoader(
    val_episode_dataset, batch_size=batch_size, shuffle=False, num_workers=1
)

# Smoke test: draw one training batch and report the tensor shapes.
cur_img, next_img, action = next(iter(train_loader))
print("cur_img.shape:", cur_img.shape)
print("next_img.shape:", next_img.shape)
print("action.shape:", action.shape)

In [7]:
# Image normalizer: provided ready-made by the project helper, nothing to fit.
img_normalizer = get_image_range_normalizer()

# Action normalizer: fitted (mode="limits") on every action in the replay buffer,
# i.e. on training and validation episodes alike.
action_normalizer = LinearNormalizer()
action_normalizer.fit(
    data=dataset.replay_buffer["action"],
    last_n_dims=1,
    mode="limits",
)

In [8]:
# Run the sample batch through the normalizers; the outputs are expected to lie
# in [-1, 1], which the assertions below verify.
cur_img_normalized = img_normalizer(cur_img)
next_img_normalized = img_normalizer(next_img)
action_normalized = action_normalizer(action)

# Range checks: lower bound first, then upper bound, for each normalized tensor.
assert -1 <= cur_img_normalized.min() and 1 >= cur_img_normalized.max()
assert -1 <= next_img_normalized.min() and 1 >= next_img_normalized.max()
assert -1 <= action_normalized.min() and 1 >= action_normalized.max()

In [90]:
class InverseDynamicsCNN(nn.Module):
    """Predict the action taken between two images.

    Both images go through the same stack of strided Conv2d -> BatchNorm2d ->
    LeakyReLU stages; the two flattened feature vectors are concatenated and fed
    to a small MLP ending in Tanh, so every output component lies in [-1, 1].
    """

    def __init__(
        self,
        in_channels: int,
        in_height: int,
        in_width: int,
        action_dim: int,
        hidden_dims: list[int] = None,
    ):
        """
        Args:
            in_channels: Number of channels of each input image.
            in_height: Input image height, used to size the first dense layer.
            in_width: Input image width, used to size the first dense layer.
            action_dim: Number of action components to predict.
            hidden_dims: Output channels of each conv stage, in order.
                Defaults to [32, 64, 128, 256, 512].
        """
        super().__init__()

        widths = [32, 64, 128, 256, 512] if hidden_dims is None else hidden_dims

        # Hyperparameters shared by every convolution (and by the shape helper).
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1, dilation=1)

        # Stage k maps channels[k] -> channels[k + 1].
        channels = [in_channels, *widths]
        stages = [
            nn.Sequential(
                nn.Conv2d(c_in, out_channels=c_out, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]

        # Shared per-image feature extractor; output is flattened for the MLP.
        self.conv_branch = nn.Sequential(*stages, nn.Flatten())

        # Shape of the conv output before flattening, as reported by the
        # project helper; its product is taken as the per-image feature count.
        conv_out_shape = compute_conv_output_shape(
            H=in_height,
            W=in_width,
            padding=conv_kwargs["padding"],
            stride=conv_kwargs["stride"],
            kernel_size=conv_kwargs["kernel_size"],
            dilation=conv_kwargs["dilation"],
            num_layers=len(widths),
            last_hidden_dim=widths[-1],
        )
        conv_out_size = np.prod(conv_out_shape)

        # MLP head over the concatenated features of both images (hence 2x).
        self.action_predictor = nn.Sequential(
            nn.Linear(2 * conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Tanh(),
        )

    def forward(self, image1, image2):
        """Return the predicted action for the transition image1 -> image2."""
        # Encode image1 first, then image2, each with the shared branch.
        features = [self.conv_branch(image) for image in (image1, image2)]
        # Join along the feature dimension and regress the action.
        return self.action_predictor(torch.cat(features, dim=1))

In [91]:
def eval(model, val_loader, criterion, device):
    """Return the per-sample average of `criterion` over `val_loader`.

    Note that this notebook-level function shadows Python's builtin `eval`.
    It uses the notebook globals `img_normalizer` and `action_normalizer` to
    normalize each batch before moving it to `device`, and it leaves `model`
    in evaluation mode.

    Args:
        model: Network called as ``model(cur_img, next_img)``.
        val_loader: Loader yielding ``(cur_img, next_img, action)`` batches.
        criterion: Loss called as ``criterion(prediction, target)``; the
            weighting below assumes it returns a per-batch mean.
        device: Device the normalized batches are moved to.

    Returns:
        Sum of batch-size-weighted batch losses divided by
        ``len(val_loader.dataset)``.
    """
    model.eval()
    weighted_losses = []
    with torch.no_grad():
        for batch in val_loader:
            cur_img, next_img, action = batch

            # Normalize first, then transfer to the target device.
            cur_img = img_normalizer(cur_img).to(device)
            next_img = img_normalizer(next_img).to(device)
            action = action_normalizer(action).to(device)

            batch_loss = criterion(model(cur_img, next_img), action)
            # Weight by batch size so a smaller final batch counts proportionally.
            weighted_losses.append(len(cur_img) * batch_loss.item())

    # Average over all samples rather than over batches.
    return np.sum(weighted_losses) / len(val_loader.dataset)

In [92]:
# Image and action dimensions needed to size the model.
N, H, W, C = dataset.replay_buffer["img"].shape
N, action_dim = dataset.replay_buffer["action"].shape

# The train/validation split, the EpisodeDatasets and the DataLoaders
# (val_mask, train_idxs, val_idxs, batch_size, train_episode_dataset,
# val_episode_dataset, train_loader, val_loader) are already built earlier in
# the notebook with exactly these settings. They are reused here rather than
# rebuilt: reconstructing the in-memory EpisodeDatasets re-reads every episode
# and briefly holds two full copies of the image data.

In [93]:
# Training schedule.
num_epochs = 10
log_freq = 10  # print the training loss every `log_freq` batches
save_freq = 2  # checkpoint every `save_freq` epochs (plus the final epoch)

# Model, loss and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dims = None
model = InverseDynamicsCNN(C, H, W, action_dim, hidden_dims=hidden_dims).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train_losses = []  # one entry per training batch
test_losses = []  # one entry per epoch
for epoch in trange(num_epochs, desc="Epoch"):
    # eval() switches the model to evaluation mode, so re-enable training mode.
    model.train()
    for i, (cur_img, next_img, action) in enumerate(
        tqdm(train_loader, desc="Batch", leave=False)
    ):
        # Normalize image and action, then move the batch to the device.
        cur_img = img_normalizer(cur_img).to(device)
        next_img = img_normalizer(next_img).to(device)
        action = action_normalizer(action).to(device)

        # One optimization step on the action-prediction loss.
        optimizer.zero_grad()
        action_pred = model(cur_img, next_img)
        loss = criterion(action_pred, action)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        if i % log_freq == 0:
            print(f"Epoch {epoch}, Batch {i}, Train Loss: {loss.item()}")

    # Validation loss at the end of the epoch.
    test_loss = eval(model, val_loader, criterion, device)
    test_losses.append(test_loss)
    print(f"Epoch {epoch}, Test Loss: {test_loss}")

    # Checkpoint periodically; the last epoch is always saved under "final".
    epoch_str = "final" if epoch == num_epochs - 1 else str(epoch)
    if epoch_str == "final" or epoch % save_freq == 0:
        torch.save(model.state_dict(), f"inverse_dynamics_cnn_{epoch_str}.pt")

In [95]:
# Plot train and test losses against the number of training iterations.
plt.figure(figsize=(12, 6))
plt.plot(train_losses, label="Train Loss")
# One test loss is recorded at the END of each epoch, so the k-th value belongs
# at the iteration where epoch k finished (k * iterations-per-epoch). Spreading
# the points with linspace from 0 would put the first one at iteration 0, before
# any training had happened, and misalign every point except the last.
if test_losses:
    iters_per_epoch = len(train_losses) / len(test_losses)
    test_iters = iters_per_epoch * np.arange(1, len(test_losses) + 1)
    plt.plot(test_iters, test_losses, label="Test Loss")
# Clip the y-axis so large early losses do not flatten the rest of the curve.
plt.ylim(0, 0.01)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [96]:
# Evaluate the model on the validation set.
# `eval` is the notebook's own evaluation function (it shadows the builtin); it
# returns the loss averaged per sample over `val_loader.dataset`.
test_loss = eval(model, val_loader, criterion, device)
print("Final Test Loss:", test_loss)

In [97]:
# Compare true and predicted actions on the first validation batch.
model.eval()
with torch.no_grad():
    cur_img, next_img, action = next(iter(val_loader))

    # Same preprocessing as in training: normalize, then move to the device.
    cur_img = img_normalizer(cur_img).to(device)
    next_img = img_normalizer(next_img).to(device)
    action = action_normalizer(action).to(device)

    action_pred = model(cur_img, next_img)

# Map both the target and the prediction back to the original action scale.
unnormalized_action = action_normalizer.unnormalize(action)
unnormalized_action_pred = action_normalizer.unnormalize(action_pred)

# Error in normalized space and in the original action units.
normalized_mse = criterion(action, action_pred)
unnormalized_mse = criterion(unnormalized_action, unnormalized_action_pred)

print("Normalized MSE:", normalized_mse.item())
print("Unnormalized MSE:", unnormalized_mse.item())

In [98]:
# Display the first five true actions next to the first five predictions.
unnormalized_action[:5], unnormalized_action_pred[:5]

In [43]:
# Load the model.
model = InverseDynamicsCNN(C, H, W, action_dim).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one
# (without it, torch.load tries to restore tensors to the device they were
# saved from).
model.load_state_dict(
    torch.load("inverse_dynamics_cnn_final.pt", map_location=device)
)