In [91]:
import os
import sys
import datetime
import functools
import math
import time
from typing import Optional

if "PyTorch_VAE" not in sys.path:
    sys.path.append("PyTorch_VAE")

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import trange

from PyTorch_VAE import models
from diffusion_policy.common.pytorch_util import compute_conv_output_shape
from diffusion_policy.dataset.pusht_image_dataset import PushTImageDataset


from lsdp_utils.EpisodeDataset import EpisodeDataset, EpisodeDataloaders
from lsdp_utils.VanillaVAE import VanillaVAE
from lsdp_utils.utils import bcolors
from lsdp_utils.utils import plot_losses, plot_samples, normalize_pn1, denormalize_pn1, bcolors

from types import SimpleNamespace

cfg = SimpleNamespace(dataset_path='/home/matteogu/ssd_data/data_diffusion/pusht/pusht_cchi_v7_replay.zarr',
                      # vae_model_path='/home/matteogu/Desktop/prj_deepul/repo_online/lsdp/models/pusht_vae/vae_32_20240403.pt',
                      vae_save_dir='/home/matteogu/ssd_data/diffusion_models/models/vae/',
                      n_obs_history=0,
                      n_pred_horizon=1,  # just trying to fit the latents to states

                      # kld_weight / latent_dim / hidden_dims / batch_size / epochs / lr all appear
                      # in the VAE checkpoint directory name built below. NOTE: cfg.lr is also
                      # reused as the MLP learning rate, and cfg is passed to EpisodeDataloaders.
                      kld_weight=1e-7,
                      latent_dim=128,
                      hidden_dims=[32, 64, 128, 256, 512],

                      train_split=0.8,

                      batch_size=512,
                      lr=5e-4,  # optimization params

                      epochs=100,
                      device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                      )

# Alternative dataset locations on other machines:
# path = "/nas/ucb/ebronstein/lsdp/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr"
# path = "/home/tsadja/data_diffusion/pusht/pusht_cchi_v7_replay.zarr"

# The replay buffer itself is loaded in the next cell, where it is used; this cell
# only resolves the directory of the pretrained VAE checkpoint.

# The directory name encodes the VAE training hyper-parameters, e.g.
# pusht_vae_klw_1.00e-07_ldim_128_bs_512_epochs_100_lr_0.0005_hdim_32_64_128_256_512
str_hidden = "_".join(str(dim) for dim in cfg.hidden_dims)
name = (f'pusht_vae_klw_{cfg.kld_weight:.2e}_ldim_{cfg.latent_dim}_'
        f'bs_{cfg.batch_size}_epochs_{cfg.epochs}_lr_{cfg.lr}_hdim_{str_hidden}')
# os.path.join is correct whether or not vae_save_dir ends with a slash.
save_dir = os.path.join(cfg.vae_save_dir, name)
os.makedirs(save_dir, exist_ok=True)
print(f"{bcolors.OKGREEN} ---------------------- {bcolors.ENDC}")
print(f"{bcolors.OKGREEN}   {name}   {bcolors.ENDC}")
print(f"{bcolors.OKGREEN} ---------------------- {bcolors.ENDC}")



In [92]:
dataset = PushTImageDataset(cfg.dataset_path)
# Replay-buffer images are stored channels-last; permute to (N, C, H, W) for the VAE.
full_dataset = torch.from_numpy(dataset.replay_buffer["img"]).permute(0, 3, 1, 2)
N, C, H, W = full_dataset.shape

# Make the state normalizer (per-dimension min/max over the whole replay buffer).
max_state = dataset.replay_buffer["state"].max(axis=0)
min_state = dataset.replay_buffer["state"].min(axis=0)

# NOTE(review): cfg.hidden_dims is part of the checkpoint name but is not passed
# here; this relies on VanillaVAE's default hidden dims matching — confirm.
vae_model = VanillaVAE(
    in_channels=C, in_height=H, in_width=W, latent_dim=cfg.latent_dim
).to(cfg.device)
# 'vae_99.pt' is presumably the last epoch of a cfg.epochs=100 run — confirm.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
vae_model.load_state_dict(
    torch.load(os.path.join(save_dir, 'vae_99.pt'), map_location=cfg.device)
)
# The VAE is only used for inference below: eval() makes any normalization /
# dropout layers use their inference behavior instead of per-batch statistics.
vae_model.eval()

def get_latent(x, _vae_model, device):
    """Encode images with pixel values in [0, 255] into VAE latents.

    Pixels are rescaled to [-1, 1], passed to ``_vae_model.encode``, and the first
    element of its output is returned, detached from autograd.

    Args:
        x: image data as a numpy array or torch tensor. NOTE(review): the layout
            must be whatever ``_vae_model.encode`` expects — confirm at call sites.
        _vae_model: model exposing an ``encode`` method returning a sequence.
        device: device the model lives on; the input is moved there.

    Returns:
        torch.Tensor: ``encode(...)[0]`` on ``device``, float32, no grad.
    """
    # Cast to float32 up front: uint8 / 255.0 in numpy yields float64, which would
    # not match float32 model weights; as_tensor also accepts torch tensors.
    pixels = torch.as_tensor(x, dtype=torch.float32, device=device)
    pixels = 2.0 * (pixels / 255.0) - 1.0
    # The encoder is frozen here, so don't build an autograd graph at all.
    with torch.no_grad():
        latent = _vae_model.encode(pixels)[0]
    return latent.detach()

# Images are turned into VAE latents by get_latent (bound to the loaded VAE).
normalize_encoder_input = functools.partial(
    get_latent, _vae_model=vae_model, device=cfg.device
)

# States are rescaled with the replay-buffer min/max
# (normalize_pn1 presumably maps to [-1, 1] — confirm in lsdp_utils.utils).
state_normalizer = functools.partial(
    normalize_pn1, min_val=min_state, max_val=max_state
)

# One processing function per key in include_keys below.
process_fns = {"state": state_normalizer, "img": normalize_encoder_input}

# NOTE(review): cfg also carries train_split=0.8; confirm which of the two
# EpisodeDataloaders actually uses and what val_ratio means there (as written it
# may assign 90% of the data to validation).
VAL_RATIO = 0.9

print("Making datasets and dataloaders.")
train_loader, val_loader = EpisodeDataloaders(dataset=dataset,
                                              include_keys=["state", "img"],
                                              process_fns=process_fns,
                                              cfg=cfg,
                                              val_ratio=VAL_RATIO)


# Next: train an MLP mapping encoder latents to (normalized) states.

In [95]:
# Per-dimension min/max of the encoder latents, for optional min-max scaling of
# the MLP input. Statistics are taken from the *training* split so that no
# validation data leaks into preprocessing.
stats_samples = train_loader.dataset.samples
n_latents = len(stats_samples)

all_latents = torch.zeros(n_latents, cfg.latent_dim)
for idx_sample in trange(n_latents):
    # samples[i][1]['img'] is the processed (latent) image for sample i.
    all_latents[idx_sample] = stats_samples[idx_sample][1]['img']

latent_min_cpu = torch.min(all_latents, dim=0).values
latent_max_cpu = torch.max(all_latents, dim=0).values
# Guard constant latent dimensions, which would otherwise divide by zero.
latent_range_cpu = (latent_max_cpu - latent_min_cpu).clamp(min=1e-8)
min_latents = latent_min_cpu.to(cfg.device)
max_latents = latent_max_cpu.to(cfg.device)

# Between 0 and 1 (sanity check of the scaling on the same samples).
norm_latents = (all_latents - latent_min_cpu) / latent_range_cpu

norm_latents.min(), norm_latents.max()

In [27]:
# Peek at one training batch to check the latent and state tensor shapes.
first_batch = next(iter(train_loader))
batch_features = first_batch[1]
batch_features['img'].shape, batch_features['state'].shape


In [96]:
class LatentsToStateMLP(nn.Module):
    """Plain MLP regressing a state vector from a latent vector.

    Architecture: for each width in ``hidden_dims`` a ``Linear`` followed by a
    ``ReLU``, then a final ``Linear`` to ``out_dim`` with no output activation.
    All layers live in ``self.model`` (an ``nn.Sequential``), so checkpoint keys
    are ``model.<idx>.weight`` / ``model.<idx>.bias``.

    Args:
        in_dim: size of the input (latent) vector.
        out_dim: size of the predicted state vector.
        hidden_dims: widths of the hidden layers, in order (may be empty).
    """

    def __init__(
        self, in_dim, out_dim, hidden_dims: list[int]
    ):
        super().__init__()

        widths = [in_dim, *hidden_dims]
        blocks = []
        # Consecutive width pairs give each hidden Linear its fan-in/fan-out.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        blocks.append(nn.Linear(widths[-1], out_dim))
        self.model = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map latents of shape (..., in_dim) to states of shape (..., out_dim)."""
        return self.model(x)

In [100]:
# Seed the global torch RNG so the MLP weight init is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Output width = state dimensionality read from the replay buffer
# (instead of a hard-coded 5).
state_dim = dataset.replay_buffer["state"].shape[-1]
model = LatentsToStateMLP(in_dim=cfg.latent_dim,
                          out_dim=state_dim,
                          hidden_dims=[256, 256, 128, 16]).to(cfg.device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)


In [101]:
EPOCHS = 400
LOG_EVERY = 30
# Min-max scale latents with min_latents / max_latents from the latent-stats cell.
# Off by default, matching the previous run where this scaling was commented out.
NORMALIZE_LATENTS = False


def _prepare_batch(batch):
    """Return (latents, states) from a loader batch as float32 tensors on cfg.device."""
    features = batch[1]
    latents = features['img'].to(cfg.device, dtype=torch.float32)
    states = features['state'].to(cfg.device, dtype=torch.float32)
    if NORMALIZE_LATENTS:
        # clamp guards latent dimensions that are constant over the stats set.
        latent_range = (max_latents - min_latents).clamp(min=1e-8)
        latents = (latents - min_latents) / latent_range
    return latents, states


train_losses, test_losses = [], []  # per-batch train losses, per-epoch mean val losses
train_start_time = time.time()
for epoch in trange(EPOCHS):
    model.train()
    total_train_loss = 0.0
    for batch in train_loader:
        latents, states = _prepare_batch(batch)
        loss = criterion(model(latents), states)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_train_loss += batch_loss
        train_losses.append(batch_loss)

    model.eval()
    total_test_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            latents, states = _prepare_batch(batch)
            total_test_loss += criterion(model(latents), states).item()
    test_losses.append(total_test_loss / len(val_loader))

    if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
        print(f"[Epoch {epoch}] train loss: {total_train_loss / len(train_loader):.2e}  "
              f"val loss: {test_losses[-1]:.2e}")
print(f'Training time: {datetime.timedelta(seconds=time.time() - train_start_time)}')


In [102]:
# train_losses holds one entry per training batch, test_losses one mean
# validation loss per epoch, as collected by the training loop.
plot_losses(train_losses, test_losses)

In [112]:
# Freeze the trained MLP: turn off gradient tracking on all of its parameters.
# (Trailing ';' keeps the notebook from echoing the module repr.)
model.requires_grad_(False);
    


In [103]:
# Save the latent->state MLP weights. The file name encodes the MLP's input and
# output widths, read from the model itself (e.g. mlp_128to5.pt).
latent_to_state_dir = "/home/matteogu/ssd_data/diffusion_models/models/latent_to_state"
# torch.save does not create missing directories.
os.makedirs(latent_to_state_dir, exist_ok=True)
mlp_in_dim = model.model[0].in_features
mlp_out_dim = model.model[-1].out_features
torch.save(model.state_dict(),
           os.path.join(latent_to_state_dir, f"mlp_{mlp_in_dim}to{mlp_out_dim}.pt"))