In [1]:
import os
import sys
import datetime
import functools
import math
import time
from typing import Optional

if "PyTorch_VAE" not in sys.path:
    sys.path.append("PyTorch_VAE")

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import trange

from PyTorch_VAE import models
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset

from lsdp_utils.VanillaVAE import VanillaVAE
from lsdp_utils.utils import bcolors

from types import SimpleNamespace


In [3]:

# Run configuration: every knob a reader might tune lives here.
# The paths default to the original author's machine; override them with the
# PUSHT_DATASET_PATH / PUSHT_VAE_SAVE_DIR environment variables instead of editing the notebook.
cfg = SimpleNamespace(
    dataset_path=os.environ.get(
        'PUSHT_DATASET_PATH',
        '/home/matteogu/ssd_data/data_diffusion/pusht/pusht_cchi_v7_replay.zarr'),
    # os.path.join(..., '') guarantees a trailing separator, so appending a run
    # name to this directory always yields a valid path.
    vae_save_dir=os.path.join(
        os.environ.get('PUSHT_VAE_SAVE_DIR',
                       '/home/matteogu/ssd_data/diffusion_models/models/diffusion/vae/'),
        ''),

    latent_dim=512,
    hidden_dims=[32, 64, 128, 256, 512],  # also encoded in the run name
    batch_size=128,
    down_dims=[1024, 2048],  # 512, 1024,
    train_split=0.8,  # fraction of frames used for training; the rest is validation
    # optimization params
    lr=1e-3,
    kld_weight=1e-7,  # weight of the KL-divergence term in the VAE loss
    epochs=100,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)


In [4]:

dataset = PushTImageDataset(cfg.dataset_path)
# permute(0, 3, 1, 2): channel-last image buffer -> (N, C, H, W) for the conv VAE.
# torch.from_numpy shares memory with the replay buffer, so downstream code must
# not modify this tensor in place.
full_dataset = torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2)

# Run name encodes the hyper-parameters, e.g. hidden dims [32, 64] -> "32_64".
str_hidden = "_".join(str(dim) for dim in cfg.hidden_dims)
# NOTE: `kld_{...:.2f}` renders tiny weights such as 1e-7 as "kld_0.00"; the weight
# is already captured by `klw_`. Kept unchanged so existing run directories still match.
name = (f'pusht_vae_klw_{cfg.kld_weight:.2e}_ldim_{cfg.latent_dim}_'
        f'bs_{cfg.batch_size}_epochs_{cfg.epochs}_lr_{cfg.lr}_kld_{cfg.kld_weight:.2f}_hdim_{str_hidden}')
# os.path.join works whether or not vae_save_dir ends with a separator.
save_dir = os.path.join(cfg.vae_save_dir, name)
# exist_ok: re-running the cell (or resuming a run) must not raise FileExistsError.
os.makedirs(save_dir, exist_ok=True)

banner = f"{bcolors.OKGREEN} ---------------------- {bcolors.ENDC}"
print(banner)
print(f"{bcolors.OKGREEN}   {name}   {bcolors.ENDC}")
print(banner)


def normalize(data):
    """Map pixel intensities from [0, 255] to [-1, 1].

    Works on torch tensors and numpy arrays, and returns a NEW object. The input is
    never modified. This matters because the dataset tensor shares memory with the
    replay buffer (via ``torch.from_numpy``). True division also promotes integer
    (e.g. uint8) inputs to float, which an in-place ``/=`` cannot do.
    """
    return 2 * (data / 255.0) - 1


def unnormalize(data):
    """Inverse of ``normalize``: map values from [-1, 1] back to the [0, 255] pixel range.

    Returns a new object; the input is left unchanged.
    """
    in_unit_range = (data + 1) / 2
    return in_unit_range * 255


# NOTE: rebinds full_dataset to its normalized version — run this cell only once per
# load, otherwise the data gets normalized twice.
full_dataset = normalize(full_dataset)
N, C, H, W = full_dataset.shape

SPLIT_SEED = 0  # fixed so the train/val split is identical across kernel restarts
train_size = int(cfg.train_split * N)
val_size = N - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
# random_split returns Subset objects, which have no .to(); the data stays on the CPU
# and each batch is moved to the device inside the training/validation loops.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=True)


In [7]:
# Tensor.to() is not in-place (it returns a new tensor), so calling it here without
# using the result had no effect. The dataset intentionally stays on the CPU; each
# batch is moved with x.to(device) inside the loops. Show where the data lives:
val_dataset.dataset.device

In [6]:
from vae.pusht_vae import VanillaVAE  # same VAE of lsdp_utils/VanillaVAE 

In [9]:
# Build everything the loop uses in this cell, so it also works on a fresh kernel
# instead of relying on `device`, `model` and `optimizer` leaking in from other cells.
device = cfg.device
latent_dim = cfg.latent_dim
kld_weight = cfg.kld_weight  # passed to loss_function as M_N
# NOTE(review): the run name records cfg.epochs (100) but this cell trains for 20 — TODO reconcile.
epochs = 20

model = VanillaVAE(in_channels=C, in_height=H, in_width=W, latent_dim=latent_dim).to(device)
# NOTE(review): optimizer type is an assumption (Adam with cfg.lr) — confirm it matches the intended setup.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

# train_* lists get one value per training step; val_losses gets one mean per epoch;
# val_recons/kld lists get one value per validation batch.
train_losses, val_losses = [], []
train_recons_losses, train_kld_losses = [], []
val_recons_losses, val_kld_losses = [], []

for epoch in trange(epochs):
    total_train_loss = 0.0
    model.train()
    for x in train_loader:
        x = x.to(device)
        result = model(x)
        # loss_function returns a dict with 'loss', 'Reconstruction_Loss' and 'KLD'.
        loss = model.loss_function(*result, M_N=kld_weight)
        step_loss = loss['loss'].item()
        total_train_loss += step_loss

        train_losses.append(step_loss)
        train_recons_losses.append(loss['Reconstruction_Loss'].item())
        train_kld_losses.append(loss['KLD'].item())

        optimizer.zero_grad()
        loss['loss'].backward()
        optimizer.step()

    print(f"Train loss: {total_train_loss / len(train_loader):.4f}")

    total_val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x in val_loader:
            x = x.to(device)
            result = model(x)
            loss = model.loss_function(*result, M_N=kld_weight)

            val_recons_losses.append(loss['Reconstruction_Loss'].item())
            val_kld_losses.append(loss['KLD'].item())
            total_val_loss += loss["loss"].item()
    val_losses.append(total_val_loss / len(val_loader))
    print(f"Validation loss: {val_losses[-1]:.4f}")


In [None]:
def show_reconstructions(model: VanillaVAE, val_loader: torch.utils.data.DataLoader, save_fig: bool = False,
                         num_samples: int = 5):
    """Plot images from one validation batch (top row) above their VAE reconstructions (bottom row).

    Args:
        model: VAE whose forward pass returns a sequence with the reconstruction at index 0.
        val_loader: loader yielding batches of images normalized with ``normalize``.
        save_fig: if True, also save the figure to ``figs/pusht_vae/reconstructions_<latent_dim>.png``
            (``latent_dim`` is read from the notebook's global namespace).
        num_samples: number of image columns; clipped to the batch size.
    """
    # Run on whatever device the model lives on instead of a notebook-global `device`.
    device = next(model.parameters()).device
    val_data = next(iter(val_loader)).to(device)
    num_samples = min(num_samples, val_data.shape[0])

    # Inference only: no autograd graph, eval-mode layers; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        recon = model(val_data)[0]
    model.train(was_training)

    recon = unnormalize(recon)
    val_data = unnormalize(val_data)

    def to_uint8_image(img):
        # (C, H, W) -> (H, W, C) uint8; clip first so out-of-range values don't wrap around.
        return np.clip(img.permute(1, 2, 0).cpu().numpy(), 0, 255).astype(np.uint8)

    # squeeze=False keeps `ax` 2-D even when num_samples == 1.
    fig, ax = plt.subplots(2, num_samples, figsize=(num_samples * 2, 6), squeeze=False)
    for ii in range(num_samples):
        ax[0, ii].imshow(to_uint8_image(val_data[ii]))
        ax[1, ii].imshow(to_uint8_image(recon[ii]))
        ax[0, ii].axis('off')
        ax[1, ii].axis('off')

    ax[0, 0].set_title('Ground Truth')
    ax[1, 0].set_title('Reconstruction')
    fig.tight_layout()
    if save_fig:
        os.makedirs('figs/pusht_vae', exist_ok=True)
        fig.savefig(f'figs/pusht_vae/reconstructions_{latent_dim}.png')
    plt.show()

show_reconstructions(model, val_loader, save_fig=True)

In [None]:
def plot_losses(train_losses, test_losses, latent_dim: Optional[int] = None):
    """Plot per-iteration training loss and per-epoch validation loss on a log y-axis.

    Args:
        train_losses: one loss value per training iteration.
        test_losses: one mean validation loss per epoch, each computed after that epoch's training.
        latent_dim: shown in the title; defaults to the notebook-global ``latent_dim`` if present.
    """
    if latent_dim is None:
        latent_dim = globals().get("latent_dim")
    n_iters, n_epochs = len(train_losses), len(test_losses)
    # The k-th validation loss is measured after epoch k+1, i.e. at iteration (k+1) * n_iters / n_epochs.
    # (np.linspace(0, n_iters, n_epochs) would place epoch 1's loss at iteration 0.)
    val_x = np.arange(1, n_epochs + 1) * n_iters / max(n_epochs, 1)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.semilogy(train_losses, label="Train Loss")
    ax.semilogy(val_x, test_losses, label="Test Loss")
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    final_loss = f"{test_losses[-1]:.4f}" if n_epochs else "n/a"
    ax.set_title(f'[Latent {latent_dim}] Final Test loss: {final_loss}')
    fig.tight_layout()
    plt.show()

plot_losses(train_losses, val_losses)

In [None]:
import glob

# `now` (the run timestamp) is not defined anywhere in this notebook, so load the most
# recently written losses file for the current latent dim instead.
loss_pattern = os.path.join(save_dir, 'losses', f'losses_{latent_dim}_*.npy')
loss_files = glob.glob(loss_pattern)
if not loss_files:
    raise FileNotFoundError(f'No losses file matches {loss_pattern}')
# Each file holds two consecutively saved arrays: train losses, then validation losses.
with open(max(loss_files, key=os.path.getmtime), 'rb') as f:
    train_losses_l = np.load(f)
    val_losses_l = np.load(f)


# Plots for the report

In [None]:
import os 
loss_path = 'models/pusht_vae/losses/'


def _latent_dim_of(fname):
    # File names look like losses_<latent_dim>_<suffix>.npy.
    return int(fname.split('_')[1])


# Only .npy files (stray files would crash the int() parse); sort numerically by latent
# dim so the legend reads 32 < 64 < 128 rather than lexicographically.
exps = sorted((f for f in os.listdir(loss_path) if f.endswith('.npy')),
              key=lambda f: (_latent_dim_of(f), f))

fig, ax = plt.subplots(figsize=(12, 6))
for exp in exps:
    # Each file holds two consecutively saved arrays: train losses, then validation losses.
    with open(os.path.join(loss_path, exp), 'rb') as f:
        exp_train_losses = np.load(f)
        exp_val_losses = np.load(f)

    exp_latent_dim = _latent_dim_of(exp)
    n_iters, n_epochs = len(exp_train_losses), len(exp_val_losses)
    # The k-th validation loss is measured after epoch k+1 -> iteration (k+1) * n_iters / n_epochs.
    val_x = np.arange(1, n_epochs + 1) * n_iters / max(n_epochs, 1)
    ax.semilogy(exp_train_losses, label=f"[{exp_latent_dim}] Train Loss")
    ax.semilogy(
        val_x,
        exp_val_losses,
        label=f"[{exp_latent_dim}] Test Loss, final_value: {exp_val_losses[-1]:.4f}",
    )
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title('Test loss vs Latent Dimension')
fig.tight_layout()
fig.savefig('Loss_vs_latent.png')
plt.show()

In [7]:
latent_dim = 32
device = cfg.device
# Load the pretrained VAE (checkpoint file name encodes the latent dim).
model = VanillaVAE(in_channels=3, in_height=H, in_width=W, latent_dim=latent_dim).to(device)
# NOTE: rebinds the notebook-global save_dir to the pretrained-model directory.
save_dir = "models/pusht_vae"
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
state_dict = torch.load(os.path.join(save_dir, f"vae_{latent_dim}_20240403.pt"), map_location=device)
model.load_state_dict(state_dict)

In [11]:
# Encode the full dataset in mini-batches: one forward pass over every frame at once
# would need the whole dataset (plus activations) resident on the device.
encode_device = next(model.parameters()).device
model.eval()
mu_chunks, log_var_chunks = [], []
with torch.no_grad():
    for start in range(0, len(full_dataset), cfg.batch_size):
        batch = full_dataset[start:start + cfg.batch_size].to(encode_device)
        batch_mu, batch_log_var = model.encode(batch)
        mu_chunks.append(batch_mu.cpu())
        log_var_chunks.append(batch_log_var.cpu())
mu = torch.cat(mu_chunks).numpy()
log_var = torch.cat(log_var_chunks).numpy()

In [12]:
# Sanity check: one latent mean / log-variance row per encoded frame.
(mu.shape, log_var.shape)