In [1]:
import os
import sys
from typing import Optional

# Make the PyTorch_VAE checkout importable. The path is relative, so it is
# resolved against the kernel's current working directory. Presumably this is
# needed so modules inside PyTorch_VAE that use top-level imports resolve —
# confirm.
if "PyTorch_VAE" not in sys.path:
    sys.path.append("PyTorch_VAE")

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import trange

from PyTorch_VAE import models
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from vae.pusht_vae import VanillaVAE

In [2]:
# Location of the PushT replay buffer. Set the PUSHT_ZARR_PATH environment
# variable to point elsewhere instead of editing machine-specific paths here.
# Other known locations:
#   /nas/ucb/ebronstein/lsdp/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr
#   /home/tsadja/data_diffusion/pusht/pusht_cchi_v7_replay.zarr
path = os.environ.get(
    "PUSHT_ZARR_PATH",
    "/home/matteogu/ssd_data/data_diffusion/pusht/pusht_cchi_v7_replay.zarr",
)

dataset = PushTImageDataset(path)
# Images are stored channels-last (N, H, W, C); reorder to (N, C, H, W) for the
# conv encoder. torch.from_numpy shares memory with the replay-buffer array, so
# any in-place op on `full_dataset` would also modify that array.
full_dataset = torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2)

In [3]:
def normalize(data):
    """Map pixel values from [0, 255] to [-1, 1].

    A new array/tensor is returned and ``data`` is left untouched. An in-place
    division would rescale the caller's tensor, and for a tensor created with
    ``torch.from_numpy`` it would also rescale the backing NumPy buffer. It
    would also fail on integer (e.g. uint8) inputs.

    :param data: torch tensor or NumPy array of pixel intensities in [0, 255].
    :return: floating-point values of the same shape, in [-1, 1].
    """
    # x / 127.5 - 1 == 2 * (x / 255) - 1, with one fewer full-size temporary.
    return data / 127.5 - 1.0


def unnormalize(data):
    """Inverse of ``normalize``: map values in [-1, 1] back to [0, 255].

    A new array/tensor is returned; ``data`` itself is not modified.
    """
    # Shift/scale [-1, 1] -> [0, 1], then stretch to the 8-bit pixel range.
    return (data + 1) / 2 * 255

In [4]:
class VanillaVAE(models.VanillaVAE):
    """Convolutional VAE whose layer sizes are derived from the input (H, W).

    Subclasses ``models.VanillaVAE`` (PyTorch_VAE) but builds the encoder, the
    latent projections and the decoder itself. The flattened conv output size
    is therefore computed from ``in_height``/``in_width`` rather than assumed.
    ``encode``/``forward``/``loss_function`` are not overridden here; only
    ``decode`` is, so it can reshape to the computed conv output shape.

    NOTE: this notebook also imports ``VanillaVAE`` from ``vae.pusht_vae``;
    whichever binding executes last is the one used downstream.
    """

    def __init__(
        self,
        in_channels: int,
        in_height: int,
        in_width: int,
        latent_dim: int,
        hidden_dims: Optional[list] = None,
        **kwargs
    ) -> None:
        """
        :param in_channels: channels of the input images; the reconstruction
            has the same number of channels.
        :param in_height: input image height.
        :param in_width: input image width.
        :param latent_dim: size of the latent code.
        :param hidden_dims: encoder channel widths, one stride-2 conv per entry
            (default ``[32, 64, 128, 256, 512]``). The list is copied, never
            modified.
        :param kwargs: accepted and ignored.
        """
        # Skip models.VanillaVAE.__init__ and build every layer here (presumably
        # because the parent's layers assume a fixed input resolution); only the
        # BaseVAE / nn.Module machinery is initialised.
        models.BaseVAE.__init__(self)

        self.latent_dim = latent_dim
        out_channels = in_channels  # reconstruct as many channels as we read

        # Copy so neither the default nor a caller-supplied list is mutated.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Build Encoder: stride-2 convs, each roughly halving H and W.
        kernel_size = 3
        stride = 2
        padding = 1
        dilation = 1
        modules = []
        channels = in_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(
                        channels,
                        out_channels=h_dim,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                    ),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            channels = h_dim

        # Shape of the last encoder feature map (presumably (C, H, W)); used to
        # size the linear layers and to reshape in ``decode``.
        self.conv_out_shape = compute_conv_output_shape(
            H=in_height,
            W=in_width,
            padding=padding,
            stride=stride,
            kernel_size=kernel_size,
            dilation=dilation,
            num_layers=len(hidden_dims),
            last_hidden_dim=hidden_dims[-1],
        )
        # np.prod returns a NumPy integer; nn.Linear is declared to take an int.
        conv_out_size = int(np.prod(self.conv_out_shape))

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(conv_out_size, latent_dim)
        self.fc_var = nn.Linear(conv_out_size, latent_dim)

        # Build Decoder: mirror the encoder; each transposed conv doubles H and W
        # ((in - 1) * 2 - 2 * 1 + 3 + 1 == 2 * in).
        self.decoder_input = nn.Linear(latent_dim, conv_out_size)

        reversed_dims = hidden_dims[::-1]
        modules = []
        for i in range(len(reversed_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        reversed_dims[i],
                        reversed_dims[i + 1],
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        output_padding=1,
                    ),
                    nn.BatchNorm2d(reversed_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        # Final upsampling stage, then project to image channels in [-1, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                reversed_dims[-1],
                reversed_dims[-1],
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=1,
            ),
            nn.BatchNorm2d(reversed_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(reversed_dims[-1], out_channels=out_channels, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Un-flatten to the encoder's final feature-map shape.
        result = result.view(-1, *self.conv_out_shape)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

In [4]:
from vae.pusht_vae import VanillaVAE

In [5]:
# Scale pixels to [-1, 1], the output range of the decoder's Tanh.
# NOTE: rebinding `full_dataset` makes this cell non-idempotent — re-running it
# normalizes twice. Re-run the data-loading cell first.
full_dataset = normalize(full_dataset)
N, C, H, W = full_dataset.shape

train_split = 0.8
batch_size = 32
# Fixed seed so the train/val split (and hence the validation loss) is
# reproducible across runs.
split_seed = 0
train_size = int(train_split * N)
val_size = N - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# shuffle=True presumably so `next(iter(val_loader))` yields varied examples
# when visualizing reconstructions.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [6]:
# Build the VAE and its optimizer on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
latent_dim = 8
model = VanillaVAE(
    in_channels=3,
    in_height=H,
    in_width=W,
    latent_dim=latent_dim,
)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train for a fixed number of epochs. `train_losses` gets one entry per
# iteration; `val_losses` gets one mean value per epoch.
epochs = 20
# Passed as M_N to loss_function (presumably the KL weight). Shared by the
# train and validation passes so the two losses stay comparable.
kl_weight = 1e-6
train_losses, val_losses = [], []

for epoch in trange(epochs):
    model.train()
    total_train_loss = 0.0
    for x in train_loader:
        x = x.to(device)
        result = model(x)
        loss = model.loss_function(*result, M_N=kl_weight)["loss"]
        # Read the scalar once: each .item() forces a device->host sync.
        loss_value = loss.item()
        total_train_loss += loss_value
        train_losses.append(loss_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    mean_train_loss = total_train_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{epochs} - Train loss: {mean_train_loss:.4f}")

    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for x in val_loader:
            x = x.to(device)
            result = model(x)
            loss = model.loss_function(*result, M_N=kl_weight)["loss"]
            total_val_loss += loss.item()
    val_losses.append(total_val_loss / max(len(val_loader), 1))
    print(f"Epoch {epoch + 1}/{epochs} - Validation loss: {val_losses[-1]:.4f}")


In [None]:
def show_reconstructions(
    model: VanillaVAE,
    val_loader: torch.utils.data.DataLoader,
    save_fig: bool = False,
    num_samples: int = 5,
    latent_dim: Optional[int] = None,
):
    """Plot ground-truth images (top row) against their reconstructions (bottom row).

    Uses the first batch yielded by ``val_loader``. The model is run in eval
    mode without gradients, so BatchNorm running statistics are not updated.
    Its previous train/eval mode is restored afterwards.

    :param model: VAE whose forward pass returns the reconstruction first.
    :param val_loader: loader yielding image batches normalized to [-1, 1].
    :param save_fig: if True, also save the figure to
        ``figs/pusht_vae/reconstructions_<latent_dim>.png``.
    :param num_samples: number of image columns (capped at the batch size).
    :param latent_dim: only used in the saved file name; defaults to
        ``model.latent_dim`` when the model has that attribute.
    """
    if latent_dim is None:
        latent_dim = getattr(model, "latent_dim", None)
    device = next(model.parameters()).device

    val_data = next(iter(val_loader)).to(device)
    num_samples = min(num_samples, val_data.shape[0])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            recon = model(val_data)[0]
    finally:
        model.train(was_training)

    def _to_image(img):
        # (C, H, W) in [0, 255] -> (H, W, C) uint8; round+clamp instead of truncating.
        return img.permute(1, 2, 0).clamp(0, 255).round().cpu().numpy().astype(np.uint8)

    originals = unnormalize(val_data)
    recon = unnormalize(recon)

    # squeeze=False keeps `ax` 2-D even when num_samples == 1.
    fig, ax = plt.subplots(2, num_samples, figsize=(num_samples * 2, 6), squeeze=False)
    for ii in range(num_samples):
        ax[0, ii].imshow(_to_image(originals[ii]))
        ax[1, ii].imshow(_to_image(recon[ii]))
        ax[0, ii].axis("off")
        ax[1, ii].axis("off")

    ax[0, 0].set_title("Ground Truth")
    ax[1, 0].set_title("Reconstruction")
    fig.tight_layout()
    if save_fig:
        os.makedirs("figs/pusht_vae", exist_ok=True)
        fig.savefig(f"figs/pusht_vae/reconstructions_{latent_dim}.png")
    plt.show()


show_reconstructions(model, val_loader, save_fig=True)

In [None]:
def plot_losses(train_losses, test_losses, latent_dim: Optional[int] = None):
    """Plot per-iteration training loss and per-epoch validation loss (log scale).

    :param train_losses: one loss value per training iteration.
    :param test_losses: one loss value per epoch, each recorded after that
        epoch's training iterations (as the training loop in this notebook does).
    :param latent_dim: shown in the title; defaults to the notebook-level
        ``latent_dim`` if one is defined.
    """
    if latent_dim is None:
        latent_dim = globals().get("latent_dim")

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.semilogy(train_losses, label="Train Loss")
    num_iters, num_evals = len(train_losses), len(test_losses)
    if num_evals > 0:
        # Validation loss k (0-based) is recorded after epoch k + 1, so place it
        # at the end of that epoch: (k + 1) * num_iters / num_evals.
        val_x = np.linspace(num_iters / num_evals, num_iters, num_evals)
        ax.semilogy(val_x, test_losses, label="Test Loss")
        title = f"[Latent {latent_dim}] Final Test loss: {test_losses[-1]:.4f}"
    else:
        title = f"[Latent {latent_dim}] No test losses recorded"
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


plot_losses(train_losses, val_losses)

In [None]:
# Reload the losses saved by an earlier run for the current `latent_dim`.
# Each file holds two consecutive arrays: per-iteration train losses, then
# per-epoch validation losses.
loss_dir = os.path.join("models/pusht_vae", "losses")
loss_files = sorted(
    name
    for name in os.listdir(loss_dir)
    if name.startswith(f"losses_{latent_dim}_") and name.endswith(".npy")
)
if not loss_files:
    raise FileNotFoundError(f"No losses_{latent_dim}_*.npy file in {loss_dir}")
# Assumes the suffix after the latent size is a sortable timestamp (e.g.
# YYYYMMDD), so the last file is the most recent run — TODO confirm.
with open(os.path.join(loss_dir, loss_files[-1]), "rb") as f:
    train_losses_l = np.load(f)
    val_losses_l = np.load(f)


# Plots for the report

In [None]:
# Compare train/validation loss curves across latent sizes from the saved files.
loss_path = "models/pusht_vae/losses/"
# File names look like losses_<latent_dim>_<suffix>.npy; sort numerically by
# latent size (then by name) so the legend reads 8, 16, 32, ... rather than
# lexicographically.
exps = sorted(
    (name for name in os.listdir(loss_path) if name.startswith("losses_") and name.endswith(".npy")),
    key=lambda name: (int(name.split("_")[1]), name),
)

fig, ax = plt.subplots(figsize=(12, 6))
for exp in exps:
    with open(os.path.join(loss_path, exp), "rb") as f:
        exp_train_losses = np.load(f)
        exp_val_losses = np.load(f)

    exp_latent_dim = int(exp.split("_")[1])
    ax.semilogy(exp_train_losses, label=f"[{exp_latent_dim}] Train Loss")
    num_iters, num_evals = len(exp_train_losses), len(exp_val_losses)
    if num_evals == 0:
        continue
    # Validation loss k is recorded after epoch k + 1, so place it at the end of
    # that epoch's iterations rather than starting at iteration 0.
    ax.semilogy(
        np.linspace(num_iters / num_evals, num_iters, num_evals),
        exp_val_losses,
        label=f"[{exp_latent_dim}] Test Loss, final_value: {exp_val_losses[-1]:.4f}",
    )
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Test loss vs Latent Dimension")
fig.tight_layout()
fig.savefig("Loss_vs_latent.png")
plt.show()

In [7]:
# Load a previously trained VAE (latent size 32) to encode the dataset.
latent_dim = 32
save_dir = "models/pusht_vae"
model = VanillaVAE(in_channels=3, in_height=H, in_width=W, latent_dim=latent_dim).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(
    os.path.join(save_dir, f"vae_{latent_dim}_20240403.pt"), map_location=device
)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Encode the full dataset
model.eval()
with torch.no_grad():
    mu, log_var = model.encode(full_dataset.to(device))
    mu = mu.cpu().detach().numpy()
    log_var = log_var.cpu().detach().numpy()

In [12]:
# Sanity check: both should have one row per image in `full_dataset`.
(mu.shape, log_var.shape)

((25650, 32), (25650, 32))