In [1]:
import gc
import itertools
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt

import comet_ml
import tensorflow as tf
import torch
import torch.optim as optim
from load_data import load_data
from model import init_model
from rsl_depth_completion.conditional_diffusion.config import cfg as cfg_cls
from rsl_depth_completion.conditional_diffusion.custom_trainer import ImagenTrainer
from rsl_depth_completion.conditional_diffusion.utils import (
    dict2mdtable,
    log_params_to_exp,
)
from rsl_depth_completion.diffusion.utils import set_seed
from rsl_depth_completion.conditional_diffusion.pipeline_utils import (
    get_ds_kwargs,
    setup_train_pipeline,
)
from rsl_depth_completion.conditional_diffusion.data_utils import (
    update_eval_batch_file,
    fill_eval_batch_with_coco,
)

from rsl_depth_completion.conditional_diffusion.pipeline_utils import create_tracking_exp

torch.backends.cudnn.benchmark = True

In [2]:



logdir_name = "debug"

out_dir = f"/tmp/{logdir_name}" 
os.makedirs(out_dir, exist_ok=True)
# logdir_name = "standalone_trainer"
cfg, train_logdir = setup_train_pipeline(logdir_name, use_ssl=False)
cfg.disabled = True
cfg.input_res = 128
cfg.unets_output_res = [64,128]
cfg.use_triplet_loss=False

experiment = create_tracking_exp(cfg)

ds_kwargs = get_ds_kwargs(cfg)

ds, train_dataloader, val_dataloader = load_data(
    ds_name=cfg.ds_name, do_overfit=cfg.do_overfit, cfg=cfg, **ds_kwargs
)
x=ds[0]
eval_batch = ds.eval_batch
batch=next(iter(train_dataloader))
# eval_batch = torch.utils.data.default_collate([ds[210],ds[40]])
# eval_batch = {k:v[:cfg.batch_size] for k,v in torch.load("eval_batch_rand_sdm.pt")[cfg.input_res].items()}

print(x['cond_img'].shape)



torch.Size([1, 128, 128])


In [3]:
torch.load("eval_batch.pt").keys()

dict_keys([256, 64, 128, 160])

In [4]:
cfg.use_super_res = True
unets, model = init_model(experiment, ds_kwargs, cfg) 
trainer_kwargs = dict(
    imagen=model,
    use_lion=False,
    lr=0.5,
    max_grad_norm=1.0,
    fp16=cfg.fp16,
    use_ema=False,
    accelerate_log_with="comet_ml",
    accelerate_project_dir="logs",
)
trainer = ImagenTrainer(**trainer_kwargs)

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


In [5]:
if cfg.do_overfit:
    batch = eval_batch
is_multi_unet_training = (trainer.num_unets) > 1
images = batch["input_img"]
if "text_embed" in batch:
    text_embeds = batch["text_embed"]
else:
    text_embeds = None
if "cond_img" in batch:
    cond_images = batch["cond_img"]
else:
    cond_images = None

validity_map_depth = torch.where(
    batch["sdm"] > 0, torch.ones_like(batch["sdm"]), batch["sdm"]
).bool()

def step_unet_i(images, text_embeds, cond_images, validity_map_depth, i, trainer):
    loss = trainer(
            images=images,
            text_embeds=text_embeds,
            cond_images=cond_images,
            unet_number=i,
            max_batch_size=cfg.max_batch_size,
            validity_map_depth=validity_map_depth
            if i == (trainer.num_unets) and cfg.use_validity_map_depth
            else None,
        )
    trainer.update(unet_number=i)
    return loss

i=1
trainer = ImagenTrainer(**trainer_kwargs)
for epoch in range(1, 5):
    loss = step_unet_i(images, text_embeds, cond_images, validity_map_depth, i, trainer)
    print(f"unet {i}\tloss: {loss}")
ckpt_path1 = f"{out_dir}/checkpoint_{i}.pt"
trainer.save(ckpt_path1)

i=2
trainer = ImagenTrainer(**trainer_kwargs)
for epoch in range(1, 5):
    loss = step_unet_i(images, text_embeds, cond_images, validity_map_depth, i, trainer)
    print(f"unet {i}\tloss: {loss}")

ckpt_path2 = f"{out_dir}/checkpoint_{i}.pt"
trainer.save(ckpt_path2)

unet 1	loss: 0.6485889703035355
unet 1	loss: 32.736530780792236
unet 1	loss: 68.83356475830078
unet 1	loss: 24.02565371990204
checkpoint saved to /tmp/debug/checkpoint_1.pt
unet 2	loss: 1.0122721195220947
unet 2	loss: 25.605119466781616
unet 2	loss: 52.253578186035156
unet 2	loss: 65.95992279052734
checkpoint saved to /tmp/debug/checkpoint_2.pt


In [6]:
from rsl_depth_completion.conditional_diffusion.train_imagen_loop import sample
samples_lowres = sample(
                        cfg,
                    trainer,
                    experiment,
                    out_dir,
                    eval_batch,
                    1,
                    start_at_unet_number=1,
                    start_image_or_video=None,
                    stop_at_unet_number=1,
)

ensure unet 1 is trained. this trainer instance hasn't trained it
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
samples_sr_256 = sample(
                        cfg,
                    trainer,
                    experiment,
                    out_dir,
                    eval_batch,
                    2,
                    start_at_unet_number=2,
                    start_image_or_video=samples_lowres[0],
                    stop_at_unet_number=2,
)

ensure unet 1 is trained. this trainer instance hasn't trained it
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]

In [9]:
samples_lowres[0].shape, samples_sr_256[0].shape

(torch.Size([2, 1, 64, 64]), torch.Size([2, 1, 128, 128]))