In [None]:
import torch
import numpy as np
from data.garment_dataset import GarmentDataset
from torch.utils.data import DataLoader, random_split
import json
from torch import nn
import torch.nn.functional as F
import os
from accelerate import Accelerator
from tqdm.auto import tqdm
from accelerate import notebook_launcher
from diffusers import AutoencoderKL
from pprint import pprint
from utils.uv_tools import apply_displacement
from utils.visualization import imshow
import wandb, random
from losses.losses import uv_loss
from utils.geometry import l2dist

In [None]:
# Global runtime setup: allow TF32 matmul/cuDNN kernels, select the 'spawn'
# start method for worker processes and seed the torch and Python RNGs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# force=True keeps this cell re-runnable: without it a second execution in the
# same kernel raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)
torch.random.manual_seed(0)
random.seed(0)

In [None]:
# Location of the JSON experiment configuration. The GARMENT_DIFFUSION_CONFIG
# environment variable, when set, overrides the default path so the notebook
# is not tied to a single machine's home directory.
config_file = os.environ.get(
    "GARMENT_DIFFUSION_CONFIG",
    "/home/adumouli/Bureau/garment-diffusion/configs/config.json",
)
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)

In [None]:
# # RESET
# vae_orig = AutoencoderKL.from_pretrained(config["vae"]["folder"]) # we use this one
# # vae_orig = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix") # we use this one
# # vae_orig = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
# # vae_orig = AutoencoderKL.from_pretrained("ostris/OpenFLUX.1", subfolder='vae')
# # vae_orig = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder='vae')

# # vae_orig = AutoencoderKL(
# #     in_channels=3,
# #     out_channels=3,
# #     down_block_types=[
# #         "DownEncoderBlock2D",
# #         "DownEncoderBlock2D",
# #         "DownEncoderBlock2D",
# #         "DownEncoderBlock2D"
# #     ],
# #     up_block_types=[
# #         "UpDecoderBlock2D",
# #         "UpDecoderBlock2D",
# #         "UpDecoderBlock2D",
# #         "UpDecoderBlock2D"
# #     ],
# #     block_out_channels=[64, 128, 256, 256],
# #     layers_per_block=2,
# #     act_fn='silu',
# #     latent_channels=16,
# #     norm_num_groups=32,
# #     sample_size=config["model"]["image_size"],
# #     scaling_factor=0.3611,
# #     shift_factor=0,
# #     force_upcast=False,
# #     use_quant_conv=False,
# #     use_post_quant_conv=False,
# #     mid_block_add_attention=True
# # )
# print("VAE network size: " + str(vae_orig.num_parameters()/1e6))
# pprint(vae_orig.config)

# # vae_orig.save_pretrained(config["vae"]["folder"])

In [None]:
@torch.no_grad()
def eval(vae, dataset, index=500):
    """Compare the reconstruction of one sample by `vae` and by a reference VAE.

    NOTE: this function shadows Python's builtin `eval` for the rest of the
    notebook; the name is kept because other cells call it.

    Three pairs of numbers (mean, then max, of an absolute difference) are
    printed, in this order:
      1. `vae` reconstruction vs. the input sample,
      2. reference VAE reconstruction vs. the input sample,
      3. `vae` reconstruction vs. reference VAE reconstruction.
    For 1. and 2. the entries selected by the inverted dataset mask are set to
    zero before the statistics are taken; 3. is not masked.

    Args:
        vae: autoencoder exposing `device`, `encode(...).latent_dist.mode()`
            and `decode(...).sample` (an `AutoencoderKL` in this notebook).
        dataset: indexable dataset whose items contain a 'vdm' tensor.
            The tensor is permuted with (2, 0, 1) before encoding, so it is
            presumably channel-last (H, W, C) -- TODO confirm.
        index: position of the sample to evaluate (defaults to 500, the
            value that used to be hard-coded).

    Returns:
        None. Results are printed.
    """
    device = vae.device

    # The reference checkpoint is loaded from the repository root, exactly as
    # in the (commented-out) reset cell of this notebook: requesting
    # subfolder='vae' fails because that repository has no such subfolder.
    vae_orig = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to(device)

    img = dataset[index]['vdm'].to(device)

    # Reuse the mask of the dataset that was passed in instead of building a
    # second GarmentDataset just to read it. Fall back to the previous
    # behaviour for dataset objects that expose no `mask` attribute.
    dataset_mask = getattr(dataset, 'mask', None)
    if dataset_mask is None:
        dataset_mask = GarmentDataset(config, device='cuda').mask
    mask = ~dataset_mask.to(device)

    # Deterministic round trip through the VAE under test (distribution mode).
    encode = vae.encode(img.permute(2,0,1)[None,...]).latent_dist.mode()
    decode = vae.decode(encode).sample[0].permute(1,2,0)

    dist = torch.abs(decode-img)
    dist[mask]*=0
    print(dist.mean())
    print(dist.max())

    # Same round trip through the reference VAE.
    encode = vae_orig.encode(img.permute(2,0,1)[None,...]).latent_dist.mode()
    decode2 = vae_orig.decode(encode).sample[0].permute(1,2,0)

    dist2 = torch.abs(decode2-img)
    dist2[mask]*=0
    print(dist2.mean())
    print(dist2.max())

    # Disagreement between the two reconstructions (unmasked).
    dist3 = torch.abs(decode-decode2)
    print(dist3.mean())
    print(dist3.max())

    # imshow(torch.cat((decode, img, decode2), 1))
    # imshow(torch.cat((dist, dist2, dist3), 1)*10)

In [None]:
def vae_finetune(config):
    """Fine-tune the full VAE (encoder and decoder) on the garment dataset.

    The model is loaded from ``config["vae"]["folder"]`` and is saved back to
    that same folder every ``config["vae"]["eval_freq"]`` steps of each epoch
    (including step 0). Training metrics and periodic evaluation metrics are
    logged through the Accelerate tracker (wandb).

    The training objective is
        rec_loss + mesh_loss + kl_weight * kl_loss + uv_weight * uv_loss
    where kl_weight is ``config["vae"]["kl_loss"]`` and uv_weight is
    ``config["vae"]["uv_loss"]``. Entries of rec_loss and uv_loss selected by
    the inverted dataset mask are down-weighted by a factor 0.1.

    Args:
        config: configuration dictionary. Keys read here:
            ``vae.folder``, ``vae.batch_size``, ``vae.lr``, ``vae.max_epochs``,
            ``vae.eval_freq``, ``vae.kl_loss``, ``vae.uv_loss``,
            ``validation.batch_size`` and ``mixed_precision``.

    Returns:
        None.
    """

    # One CPU generator drives both the train/eval split and the choice of
    # evaluation samples, so both are reproducible.
    eval_generator = torch.Generator(device='cpu').manual_seed(0)

    vae = AutoencoderKL.from_pretrained(config["vae"]["folder"]).to('cuda')

    loss = dict()
    dataset = GarmentDataset(config, device='cuda')
    # dataset = VAETrainingDataset(dataset)
    train_dataset, eval_dataset = random_split(dataset, [0.95, .05], generator=eval_generator)

    # Inverted mask repeated along a leading batch dimension; it is used to
    # index per-element losses of a full batch (drop_last=True guarantees the
    # batch size always matches).
    mask = ~dataset.mask.expand(config["vae"]["batch_size"], -1, -1)

    train_dataloader = DataLoader(train_dataset, batch_size=config["vae"]["batch_size"], drop_last=True, num_workers=4, shuffle=True, persistent_workers=True, pin_memory=False)

    # The learning rate is scaled linearly with the batch size.
    optimizer = torch.optim.AdamW(vae.parameters(), lr=config["vae"]["lr"] * config["vae"]["batch_size"])

    accelerator = Accelerator(
        mixed_precision=config['mixed_precision'],
        gradient_accumulation_steps=1,
        log_with="wandb",
        project_dir=os.path.join(config["vae"]["folder"], "logs")
    )

    if accelerator.is_main_process:
        os.makedirs(config["vae"]["folder"], exist_ok=True)
        accelerator.init_trackers("vae",
            config={"config": config["vae"]}
            )
    
    optimizer, train_dataloader, vae, vae.encoder, vae.decoder = accelerator.prepare(
        optimizer, train_dataloader, vae, vae.encoder, vae.decoder
    )

    vae.requires_grad_(True)
    vae.train()

    global_step = 0
    progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
    
    # Training loop
    for epoch in range(config["vae"]["max_epochs"]):
        progress_bar.reset()
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            
            with accelerator.accumulate(vae):
                
                optimizer.zero_grad(set_to_none=True)

                clean_images = batch['vdm'].to(accelerator.device)
                clean_vertices = batch['cloth_vertices'].to(accelerator.device)

                # Images are permuted to channel-first for the VAE and the
                # reconstruction is permuted back afterwards.
                encode = vae.encode(clean_images.permute(0,3,1,2)).latent_dist
                decode = vae.decode(encode.sample()).sample.permute(0,2,3,1)

                # reconstruction loss
                rec_loss = F.mse_loss(decode.float(), clean_images.float(), reduction='none')
                rec_loss[mask] *= 0.1
                loss['rec_loss'] = rec_loss.mean()

                # Loss on the vertices obtained by applying the decoded
                # displacement to the dataset template.
                decoded_vertices = apply_displacement(dataset.template, decode)
                mesh_loss = F.mse_loss(decoded_vertices.float(), clean_vertices.float(), reduction='none')
                loss['mesh_loss'] = mesh_loss.mean()
                
                loss['uv_loss'] = uv_loss(decode.float(), dataset, decoded_vertices)
                loss['uv_loss'][mask] *= 0.1
                loss['uv_loss'] = loss['uv_loss'].mean()

                # Make the KL divergence loss (effectively regularize the latents)
                loss['kl_loss'] = encode.kl().mean()

                # Combine the losses
                loss['total_loss'] = loss['rec_loss'] + loss['mesh_loss'] + config["vae"]["kl_loss"] * loss['kl_loss'] + loss['uv_loss'] * config["vae"]["uv_loss"]

                accelerator.backward(loss['total_loss'])

                accelerator.clip_grad_norm_(vae.parameters(), 1.0)

                optimizer.step()

            progress_bar.update(1)

            # Replace the loss tensors by Python floats for logging/display.
            for k in loss.keys():
                elem = loss[k]
                if isinstance(elem, torch.Tensor):
                    loss[k] = elem.detach().item()

            logs = loss.copy()

            logs['epoch'] = epoch
            accelerator.log(logs, step=global_step)
            progress_bar.set_postfix(**loss)

            global_step += 1

            # Every `eval_freq` steps of an epoch (step 0 included): save the
            # model and log evaluation metrics on random held-out samples.
            if accelerator.is_main_process and step%config["vae"]["eval_freq"] == 0:
                
                eval_vae = accelerator.unwrap_model(vae)
                eval_vae.save_pretrained(config["vae"]["folder"])

                # eval
                with torch.inference_mode():

                    indices = torch.randint(high=len(eval_dataset), size=(config["validation"]["batch_size"],), generator=eval_generator)
                    ground_truth = torch.stack([eval_dataset[i]['vdm'] for i in indices])
                    ground_truth = ground_truth.to(accelerator.device).permute(0,3,1,2)
                    gt_mesh = torch.stack([eval_dataset[i]['cloth_vertices'] for i in indices])
                    gt_mesh = gt_mesh.to(accelerator.device)

                    # Deterministic reconstruction: use the mode of the
                    # latent distribution instead of a random sample.
                    encode = eval_vae.encode(ground_truth).latent_dist.mode()
                    decode = eval_vae.decode(encode).sample

                    dist = torch.abs(decode-ground_truth)

                    decoded_vertices = apply_displacement(dataset.template, decode.permute(0,2,3,1))
                    mesh_loss = F.mse_loss(decoded_vertices.float(), gt_mesh.float(), reduction='none')

                    # Named `eval_metrics` (not `eval`) so the notebook's
                    # `eval` function is not shadowed inside this function.
                    eval_metrics = dict()
                    eval_metrics['eval_mesh_mean'] = mesh_loss.mean().detach().item()
                    eval_metrics['eval_mesh_max'] = mesh_loss.max().detach().item()
                    eval_metrics['eval_dist_mean'] = dist.mean().detach().item()
                    eval_metrics['eval_dist_max'] = dist.max().detach().item()

                    accelerator.log(eval_metrics, step=global_step)

                del eval_vae

    # Give the trackers (wandb) the chance to flush and close the run.
    accelerator.end_training()

In [None]:
# args = (config, )
# notebook_launcher(vae_finetune, args, num_processes=1)

In [None]:
def decoder_finetune(config):
    """Fine-tune only the decoder of the VAE; the encoder stays frozen.

    The model is loaded from ``config["vae"]["folder"]`` and checkpoints are
    written to that folder name with a ``'2'`` suffix every
    ``config["vae"]["eval_freq"]`` steps of each epoch (including step 0).
    Every ``config["vae"]["save_image_freq"]`` steps one held-out sample and
    its reconstruction are logged as images and displayed with `imshow`.

    The training objective is
        rec_loss + mesh_weight * mesh_loss
    with mesh_weight = ``config["vae"]["mesh_loss"]``. Both terms are summed
    (not averaged) over the batch; entries of rec_loss selected by the
    inverted dataset mask are down-weighted by a factor 0.1.

    Args:
        config: configuration dictionary. Keys read here:
            ``vae.folder``, ``vae.batch_size``, ``vae.lr``, ``vae.max_epochs``,
            ``vae.eval_freq``, ``vae.save_image_freq``, ``vae.mesh_loss``,
            ``validation.batch_size`` and ``mixed_precision``.

    Returns:
        None.
    """

    # One CPU generator drives both the train/eval split and the choice of
    # evaluation batches, so both are reproducible.
    eval_generator = torch.Generator(device='cpu').manual_seed(0)

    vae = AutoencoderKL.from_pretrained(config["vae"]["folder"])

    loss = dict()
    dataset = GarmentDataset(config, device='cpu')
    train_dataset, eval_dataset = random_split(dataset, [0.95, .05], generator=eval_generator)

    # Inverted mask repeated along a leading batch dimension; it is used to
    # index per-element losses of a full batch (drop_last=True guarantees the
    # batch size always matches).
    mask = ~dataset.mask.expand(config["vae"]["batch_size"], -1, -1)

    train_dataloader = DataLoader(train_dataset, batch_size=config["vae"]["batch_size"], drop_last=True, num_workers=4, shuffle=True, persistent_workers=True, pin_memory=True)
   
    # Only decoder parameters are optimized; the learning rate is scaled
    # linearly with the batch size.
    optimizer = torch.optim.AdamW(vae.decoder.parameters(), lr=config["vae"]["lr"] * config["vae"]["batch_size"])

    accelerator = Accelerator(
        mixed_precision=config['mixed_precision'],
        gradient_accumulation_steps=1,
        log_with="wandb",
        project_dir=os.path.join(config["vae"]["folder"], "logs")
    )

    # The dataset is created with device='cpu', so the mask presumably lives
    # on the CPU; move it once to the training device instead of letting every
    # masked indexing operation in the loop transfer it again.
    mask = mask.to(accelerator.device)

    if accelerator.is_main_process:
        os.makedirs(config["vae"]["folder"], exist_ok=True)
        accelerator.init_trackers("vae",
            config={"config": config["vae"]}
            )
    

    vae.decoder.requires_grad_(True)
    vae.decoder.train()
    vae.encoder.requires_grad_(False)
    vae.encoder.eval()
    
    vae.to(accelerator.device)

    # The encoder is frozen, so its encode path is compiled once.
    vae.encode = torch.compile(vae.encode, mode="max-autotune", fullgraph=True, dynamic=True)

    optimizer, train_dataloader, vae.decoder = accelerator.prepare(
        optimizer, train_dataloader, vae.decoder
    )

    global_step = 0
    progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
    
    # Training loop
    for epoch in range(config["vae"]["max_epochs"]):
        progress_bar.reset()
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            
            optimizer.zero_grad(set_to_none=True)

            clean_images = batch['vdm'].to(accelerator.device)
            clean_vertices = batch['cloth_vertices'].to(accelerator.device)

            # No gradient is needed through the frozen encoder.
            with torch.no_grad():
                encode = vae.encode(clean_images.permute(0,3,1,2)).latent_dist

            with accelerator.accumulate(vae.decoder):

                decode = vae.decode(encode.sample()).sample.permute(0,2,3,1)

                # reconstruction loss
                rec_loss = F.mse_loss(decode.float(), clean_images.float(), reduction='none')
                rec_loss[mask] *= 0.1
                loss['rec_loss'] = rec_loss.sum()

                # Distance between the vertices obtained from the decoded
                # displacement and the ground-truth vertices.
                decoded_vertices = apply_displacement(dataset.template, decode)
                mesh_loss = l2dist(decoded_vertices.float(), clean_vertices.float())
                loss['mesh_loss'] = mesh_loss.sum()

                # loss['uv_loss'] = uv_loss(decode.float(), dataset, decoded_vertices)
                # loss['uv_loss'][mask] *= 0.1
                # loss['uv_loss'] = loss['uv_loss'].sum()

                # Combine the losses
                loss['total_loss'] = loss['rec_loss'] + loss['mesh_loss'] * config["vae"]["mesh_loss"]# + loss['uv_loss'] * config["vae"]["uv_loss"]

                accelerator.backward(loss['total_loss'])

                accelerator.clip_grad_norm_(vae.decoder.parameters(), 1.0)

                optimizer.step()

            progress_bar.update(1)

            # Replace the loss tensors by Python floats for logging/display.
            for k in loss.keys():
                elem = loss[k]
                if isinstance(elem, torch.Tensor):
                    loss[k] = elem.detach().item()

            logs = loss.copy()

            logs['epoch'] = epoch
            accelerator.log(logs, step=global_step)
            progress_bar.set_postfix(**loss)

            global_step += 1

            # Every `eval_freq` steps of an epoch (step 0 included): save the
            # model and log evaluation metrics on random held-out samples.
            if accelerator.is_main_process and step%config["vae"]["eval_freq"] == 0:
                
                eval_vae = accelerator.unwrap_model(vae)
                eval_vae.save_pretrained(config["vae"]["folder"]+'2')

                # eval
                with torch.inference_mode():

                    indices = torch.randint(high=len(eval_dataset), size=(config["validation"]["batch_size"],), generator=eval_generator)
                    ground_truth = torch.stack([eval_dataset[i]['vdm'] for i in indices])
                    ground_truth = ground_truth.to(accelerator.device).permute(0,3,1,2)
                    gt_mesh = torch.stack([eval_dataset[i]['cloth_vertices'] for i in indices])
                    gt_mesh = gt_mesh.to(accelerator.device)

                    # Deterministic reconstruction: use the mode of the
                    # latent distribution instead of a random sample.
                    encode = vae.encode(ground_truth).latent_dist.mode()
                    decode = vae.decode(encode).sample

                    dist = torch.abs(decode-ground_truth)

                    decoded_vertices = apply_displacement(dataset.template, decode.permute(0,2,3,1))
                    mesh_loss = F.mse_loss(decoded_vertices.float(), gt_mesh.float(), reduction='none')

                    # Named `eval_metrics` (not `eval`) so the notebook's
                    # `eval` function is not shadowed inside this function.
                    eval_metrics = dict()
                    eval_metrics['eval_mesh_mean'] = mesh_loss.mean().detach().item()
                    eval_metrics['eval_mesh_max'] = mesh_loss.max().detach().item()
                    eval_metrics['eval_dist_mean'] = dist.mean().detach().item()
                    eval_metrics['eval_dist_max'] = dist.max().detach().item()

                    accelerator.log(eval_metrics, step=global_step)

            # Every `save_image_freq` steps: log and display one held-out
            # sample next to its reconstruction.
            if accelerator.is_main_process and step%config["vae"]["save_image_freq"] == 0:

                # eval
                with torch.inference_mode():
                    
                    # randrange excludes the upper bound; random.randint is
                    # inclusive on both ends and could return
                    # len(eval_dataset), which is out of range.
                    indice = random.randrange(len(eval_dataset))
                    ground_truth = eval_dataset[indice]['vdm'][None, ...]
                    ground_truth = ground_truth.to(accelerator.device).permute(0,3,1,2)

                    encode = vae.encode(ground_truth).latent_dist.mode()
                    decode = vae.decode(encode).sample
                    
                    examples = []
                    image = wandb.Image(ground_truth[0], caption="Ground truth")
                    examples.append(image)
                    image = wandb.Image(decode[0], caption="Decoded")
                    examples.append(image)

                    accelerator.log({"Images": examples}, step=global_step)
                    print(indice)

                    # Back to channel-last for display.
                    decode = decode[0].permute(1,2,0)
                    ground_truth = ground_truth[0].permute(1,2,0)
                    dist = (decode-ground_truth)

                    imshow(torch.cat((decode, ground_truth), 1))
                    imshow(dist)

    # Give the trackers (wandb) the chance to flush and close the run.
    accelerator.end_training()


In [None]:
# Run the decoder fine-tuning in a single process through Accelerate's
# notebook launcher; `args` is the positional-argument tuple handed to it.
args = (config,)
notebook_launcher(
    function=decoder_finetune,
    args=args,
    num_processes=1,
)

In [None]:
# Load the VAE stored in the configured folder and a CUDA-backed dataset,
# then print the reconstruction statistics computed by `eval`.
eval(
    AutoencoderKL.from_pretrained(config["vae"]["folder"]).to('cuda'),
    GarmentDataset(config, device='cuda'),
)