In [None]:
# SimpleNamespace is imported locally because this cell runs before the
# notebook's import cells; without it a fresh "Restart & Run All" stops here
# with a NameError.
from types import SimpleNamespace

# Settings that look like the VAE training setup (kld_weight, latent_dim,
# hidden_dims). Note that a later cell rebinds `cfg` to a different namespace.
cfg = SimpleNamespace(
    # Zarr store passed to PushTImageDataset further down.
    dataset_path='/home/matteogu/ssd_data/data_diffusion/pusht/pusht_cchi_v7_replay.zarr',
    # vae_model_path='/home/matteogu/Desktop/prj_deepul/repo_online/lsdp/models/pusht_vae/vae_32_20240403.pt',
    vae_save_dir='/home/matteogu/ssd_data/diffusion_models/models/vae/',

    kld_weight=1e-7,
    latent_dim=128,
    hidden_dims=[32, 64, 128, 256, 512],

    train_split=0.8,

    batch_size=512,
    lr=5e-4,  # optimization params

    epochs=100,
)


In [None]:
import collections
import copy
import datetime
import functools
import math
import os
import sys
import time
from typing import Callable, Optional

if "PyTorch_VAE" not in sys.path:
    sys.path.append("PyTorch_VAE")

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from tqdm.notebook import tqdm, trange

import wandb
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.common.sampler import get_val_mask
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset
from diffusion_policy.model.diffusion import conditional_unet1d
from ema import EMAHelper

# Custom imports
from PyTorch_VAE import models
from lsdp_utils.Diffusion import Diffusion
from lsdp_utils.VanillaVAE import VanillaVAE
from lsdp_utils.EpisodeDataset import EpisodeDataset, EpisodeDataloaders
from lsdp_utils.utils import plot_losses, plot_samples

# Run trained with the MLP that converts latents into states:
# pusht_unet1d_img_1024_2048_edim_256obs_0_pred_4_bs_64_lr_0.001_e_300/2024-04-30_12-22-07

In [None]:
import torch
import glob
import os

from types import SimpleNamespace
# Configuration for the diffusion model, built attribute by attribute.
cfg = SimpleNamespace()

# --- paths ---
cfg.dataset_path = '/home/matteogu/ssd_data/data_diffusion/pusht/pusht_cchi_v7_replay.zarr'
# Alternative checkpoint location: '/nas/ucb/ebronstein/lsdp/models/pusht_vae/vae_32_20240403.pt'
cfg.vae_model_path = '/home/matteogu/Desktop/prj_deepul/repo_online/lsdp/models/pusht_vae/vae_32_20240403.pt'
cfg.save_dir = '/home/matteogu/ssd_data/diffusion_models/models/diffusion/'

# --- data / sequence lengths ---
cfg.batch_size = 4096  # 3.8 Giga for state, better 512 for latents
cfg.n_obs_history = 8
cfg.n_pred_horizon = 8

# --- network size ---
cfg.down_dims = [256, 512, 1024]
cfg.diffusion_step_embed_dim = 128  # in the original paper was 256

# --- optimization ---
cfg.lr = 3e-4
cfg.epochs = 200

# --- runtime ---
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Which observation the model consumes; a later cell accepts "img" or "state".
cfg.obs_key = "img"

In [None]:

# `cfg_dict` was never assigned in the visible cells, so this cell raised a
# NameError on a fresh kernel. Build it from the active configuration instead.
# NOTE(review): assumes cfg_dict was meant to be a plain-dict view of `cfg` — confirm.
# dict(...) takes a snapshot so later edits to cfg_dict do not write through to cfg.
cfg_dict = dict(vars(cfg))
cfg_dict

In [None]:
def _glob_nonempty(pattern):
    """Return every path matching ``pattern``, insisting on at least one match.

    Args:
        pattern: Glob pattern handed to ``glob.glob``.

    Returns:
        The non-empty list of matching paths.

    Raises:
        FileNotFoundError: If nothing matches. This replaces the opaque
            ``ValueError: max() arg is an empty sequence`` the cell used to
            raise when the directory was empty or the path was wrong.
    """
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No files or directories match {pattern!r}")
    return matches


# First level: the most recent entry directly under the save directory.
# Note: getctime is the metadata-change time on Unix and the creation time on Windows.
list_of_files = _glob_nonempty(f'{cfg.save_dir}*')
latest_file = max(list_of_files, key=os.path.getctime)
print(latest_file)

# Second level: the most recent entry inside that directory.
list_of_files = _glob_nonempty(f'{latest_file}/*')
latest_file = max(list_of_files, key=os.path.getctime)
print(latest_file)


In [None]:
# Plot the loss curves stored in the most recent run directory located above.
load_dir = str(latest_file)

# Earlier runs kept for reference. `load_dir2` is only used by the disabled
# side-by-side comparison at the bottom of this cell.
#   state run: f'{cfg.save_dir}pusht_unet1d_state_256_512_1024_obs_8_pred_8/2024-04-29_19-20-23'
#   load_dir1: f'{cfg.save_dir}pusht_unet1d_img_256_512_1024_2048_edim_128obs_8_pred_8_bs_4096_lr_3e-05_e_200/2024-04-29_19-43-11'
load_dir2 = f'{cfg.save_dir}pusht_unet1d_img_256_512_1024_obs_8_pred_8/2024-04-29_18-59-25'

# Loading the weights is disabled here:
#   diffusion.load(os.path.join(load_dir, "diffusion_model_final.pt"))
train_losses, test_losses = (
    np.load(os.path.join(load_dir, losses_file))
    for losses_file in ("train_losses.npy", "test_losses.npy")
)
plot_losses(train_losses, test_losses)

# Disabled comparison against load_dir2:
#   train_losses2 = np.load(os.path.join(load_dir2, "train_losses.npy"))
#   test_losses2 = np.load(os.path.join(load_dir2, "test_losses.npy"))
#   plot_losses(train_losses1, test_losses1)
#   plot_losses(train_losses2, test_losses2)

In [None]:


# Only these two observation keys are handled by the cells below.
assert cfg.obs_key in ("img", "state")

dataset = PushTImageDataset(cfg.dataset_path)

# Every stored frame as one tensor; the permute moves axis 3 to position 1 so
# the shape unpacks as (N, C, H, W).
full_dataset = torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2)
N, C, H, W = full_dataset.shape

# Bounds of the state along axis 0, used later to build the state normalizer.
min_state = dataset.replay_buffer["state"].min(axis=0)
max_state = dataset.replay_buffer["state"].max(axis=0)

if cfg.obs_key == "img":
    # Image observations are fed to the diffusion model as VAE latents.
    # The latent size must match the checkpoint being loaded (its file name
    # contains "vae_32").
    latent_dim = 32
    vae_model = VanillaVAE(
        in_channels=3, in_height=H, in_width=W, latent_dim=latent_dim
    ).to(cfg.device)
    # map_location is required because cfg.device falls back to CPU when CUDA
    # is unavailable; without it a checkpoint saved from GPU tensors fails to
    # deserialize on a CPU-only machine.
    vae_model.load_state_dict(
        torch.load(cfg.vae_model_path, map_location=cfg.device)
    )
    # The VAE is only used as a fixed encoder below, so switch to eval mode.
    # If it contains BatchNorm/Dropout layers (its definition is not visible
    # here) this keeps encodings deterministic and leaves the loaded running
    # statistics untouched; otherwise it is a no-op.
    vae_model.eval()
    cfg.STATE_DIM = latent_dim

    def get_latent(x, vae_model, device):
        """Encode a batch of raw images with the pretrained VAE.

        Args:
            x: NumPy array of images (it is passed to ``torch.from_numpy``).
                Values are assumed to lie in [0, 255] — TODO confirm against
                the dataset.
            vae_model: Model exposing ``encode``; the first element of its
                return value is used (presumably the latent mean — verify
                against VanillaVAE).
            device: Device the input is moved to before encoding.

        Returns:
            A tensor on ``device`` that carries no autograd history.
        """
        x = x / 255.0
        x = 2 * x - 1  # rescale from [0, 1] to [-1, 1]
        # no_grad avoids building a graph that would be discarded anyway; the
        # result matches the previous `.detach()` output.
        with torch.no_grad():
            return vae_model.encode(torch.from_numpy(x).to(device))[0]

    normalize_encoder_input = functools.partial(
        get_latent, vae_model=vae_model, device=cfg.device
    )
else:
    # Raw state observations: no encoder, fixed state dimensionality.
    cfg.STATE_DIM = 5
    normalize_encoder_input = None

# Split the episodes into training and validation sets.
# The 0.1 is presumably the validation fraction — verify against get_val_mask.
val_mask = get_val_mask(dataset.replay_buffer.n_episodes, 0.1)
val_idxs = np.where(val_mask)[0]
train_idxs = np.where(~val_mask)[0]

# NOTE(review): normalize_pn1 is not imported in any visible cell — confirm it
# is defined before this cell runs, otherwise this line raises NameError.
state_normalizer = functools.partial(
    normalize_pn1,
    min_val=min_state,
    max_val=max_state,
)

# One processing function per observation key; the "img" entry is None when
# cfg.obs_key is "state".
process_fns = dict(state=state_normalizer, img=normalize_encoder_input)

print("Making datasets and dataloaders.")
train_loader, val_loader = EpisodeDataloaders(
    dataset=dataset,
    episode_train_idxs=train_idxs,
    episode_val_idxs=val_idxs,
    include_keys=[cfg.obs_key],  # one key only
    process_fns=process_fns,
    cfg=cfg,  # configuration params
)

# Size of the global conditioning vector: STATE_DIM values for each of the
# n_obs_history observed steps.
global_cond_dim = cfg.n_obs_history * cfg.STATE_DIM

# Denoising network, created and then moved to the configured device.
diff_model = conditional_unet1d.ConditionalUnet1D(
    input_dim=cfg.STATE_DIM,
    global_cond_dim=global_cond_dim,
    diffusion_step_embed_dim=cfg.diffusion_step_embed_dim,
    down_dims=cfg.down_dims,
)
diff_model = diff_model.to(cfg.device)

# Wrapper that receives the loaders, the network and the optimizer settings.
optim_kwargs = {"lr": cfg.lr}
diffusion = Diffusion(
    model=diff_model,
    train_data=train_loader,
    test_data=val_loader,
    n_epochs=cfg.epochs,
    optim_kwargs=optim_kwargs,
    device=cfg.device,
)


In [None]:
# Restore the final checkpoint and its loss histories from a fixed run
# directory under cfg.save_dir.
load_dir = f'{cfg.save_dir}pusht_unet1d_img_256_512_1024_obs_8_pred_8'
diffusion.load(os.path.join(load_dir, "diffusion_model_final.pt"))
train_losses, test_losses = (
    np.load(os.path.join(load_dir, losses_file))
    for losses_file in ("train_losses.npy", "test_losses.npy")
)

In [None]:
# Run directory to restore from, relative to the working directory.
# Set it to None to skip loading.
load_dir = "models/diffusion/pusht-1dconv_latent_128_256_512_1024-obs_8-pred_8/2024-04-28_00-43-07"

if load_dir is not None:
    # Weights first, then the recorded loss curves.
    diffusion.load(os.path.join(load_dir, "diffusion_model_final.pt"))
    train_losses, test_losses = (
        np.load(os.path.join(load_dir, losses_file))
        for losses_file in ("train_losses.npy", "test_losses.npy")
    )