In [1]:
import gc
import itertools
import os
import shutil
import subprocess
from pathlib import Path

import comet_ml
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
import torch.optim as optim

from load_data import load_data
from model import init_model
from rsl_depth_completion.conditional_diffusion.config import cfg as cfg_cls
from rsl_depth_completion.conditional_diffusion.custom_trainer import ImagenTrainer
from rsl_depth_completion.conditional_diffusion.train import train
from rsl_depth_completion.conditional_diffusion.utils import (
    dict2mdtable,
    log_params_to_exp,
)
from rsl_depth_completion.diffusion.utils import set_seed

# Turn on cuDNN benchmark mode for the rest of the session.
# NOTE(review): benchmark mode may pick non-deterministic kernels, which works
# against the set_seed(...) reproducibility in the next cell — confirm acceptable.
setattr(torch.backends.cudnn, "benchmark", True)





  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:

cfg = cfg_cls(path=cfg_cls.default_file)

set_seed(cfg.seed)

# On the cluster, unpack the dataset archive into cfg.tmpdir once; the presence of
# "<tmpdir>/cluster" is used as the "already extracted" marker.
if cfg.is_cluster and not os.path.exists(f"{cfg.tmpdir}/cluster"):
    # Argument list (no shell) so cfg.tmpdir is never shell-interpolated, and
    # check=True so a failed extraction stops the run instead of being ignored.
    subprocess.run(
        [
            "tar",
            "-xvf",
            "/cluster/project/rsl/kzaitsev/dataset.tar",
            "-C",
            f"{cfg.tmpdir}",
        ],
        check=True,
    )

logdir = Path(cfg.cluster_logdir) if cfg.is_cluster else Path("./logs")
logdir = logdir / ("standalone_trainer" if cfg.do_overfit else "train")

# Uncomment to wipe all previous runs under `logdir` before starting.
# shutil.rmtree(logdir, ignore_errors=True)

# Per-dataset conditioning flags.
best_params = {
    "kitti": {
        "use_text_embed": False,
        "use_cond_image": True,
        "use_rgb_as_cond_image": False,
    },
    "mnist": {
        "use_text_embed": True,
        "use_cond_image": False,
        "use_rgb_as_cond_image": True,
    },
}

# Copy so the extra keys added below do not leak back into best_params.
ds_kwargs = dict(best_params[cfg.ds_name])
# RGB is used as the text embedding exactly when it is not used as the cond image.
ds_kwargs["use_rgb_as_text_embed"] = not ds_kwargs["use_rgb_as_cond_image"]
ds_kwargs["include_sdm_and_rgb_in_sample"] = True
ds_kwargs["do_crop"] = True
print(ds_kwargs)

ds, train_dataloader, val_dataloader = load_data(
    ds_name=cfg.ds_name, do_overfit=cfg.do_overfit, cfg=cfg, **ds_kwargs
)

# The API key is read from the environment instead of being committed to the notebook.
# NOTE(review): the key that used to be hardcoded here is in version history — rotate it.
experiment = comet_ml.Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="rsl_depth_completion",
    auto_metric_logging=True,
    auto_param_logging=True,
    auto_histogram_tensorboard_logging=True,
    log_env_details=True,
    log_env_host=False,
    log_env_gpu=True,
    log_env_cpu=True,
    disabled=cfg.disabled,
)

# Attach the source files of this run to the experiment.
for source_file in ("model.py", "load_data.py", "train.py", "config.py"):
    experiment.log_asset(source_file, copy_to_tmp=False)

log_params_to_exp(experiment, ds_kwargs, "dataset")
log_params_to_exp(experiment, cfg.params(), "base_config")

# batches * batch_size — an upper bound on the sample count if the last batch is
# partial (assumes train_dataloader is a torch DataLoader — TODO confirm).
print(
    "Number of train samples",
    len(train_dataloader) * train_dataloader.batch_size,
)

unets, model = init_model(experiment, ds_kwargs, cfg)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(
    "Number of parameters in model",
    num_params,
)

input_name = "interp_sdm"


def _cond_source(enabled: bool, use_rgb: bool) -> str:
    """Return the conditioning source tag: "none" if disabled, else "rgb" or "sdm"."""
    if not enabled:
        return "none"
    return "rgb" if use_rgb else "sdm"


img_cond = _cond_source(ds_kwargs["use_cond_image"], ds_kwargs["use_rgb_as_cond_image"])
text_cond = _cond_source(ds_kwargs["use_text_embed"], ds_kwargs["use_rgb_as_text_embed"])

# Self-describing tag, e.g. "img_cond='sdm'_text_cond='none'" (f-string `=` keeps the repr quotes).
cond = f"{img_cond=}_{text_cond=}"
# Runs are numbered 001, 002, ... by counting the entries already present in logdir.
exp_dir = f"{len(os.listdir(logdir)) + 1:03d}" if os.path.isdir(logdir) else "001"
train_logdir = logdir / exp_dir / cond
train_logdir.mkdir(parents=True, exist_ok=True)

trainer = ImagenTrainer(
    model,
    use_lion=False,
    lr=cfg.lr,
    max_grad_norm=1.0,
    fp16=cfg.fp16,
    use_ema=False,
    accelerate_log_with="tensorboard",
    # NOTE(review): hardcoded "logs" ignores logdir/train_logdir computed above
    # (e.g. cfg.cluster_logdir on the cluster) — confirm this is intended.
    accelerate_logging_dir="logs",
)
trainer.accelerator.init_trackers("train_example")

{'use_text_embed': False, 'use_cond_image': True, 'use_rgb_as_cond_image': False, 'use_rgb_as_text_embed': True, 'include_sdm_and_rgb_in_sample': True, 'do_crop': True}




Number of train samples 1
The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/
Number of parameters in model 16231893


In [3]:
# Restore weights from a previous run. The default path is relative to the notebook's
# directory (the same cwd the local `model` / `load_data` imports rely on); set
# RSL_CHECKPOINT_PATH to load a different checkpoint.
checkpoint_path = Path(
    os.environ.get("RSL_CHECKPOINT_PATH", "models/odd_atoll_8115-model-last.pt")
)
# NOTE(review): the recorded output of this cell reports "Failed loading state dict.
# Trying partial load" with shape mismatches (e.g. 64 vs 32 channels in unets.0), so the
# checkpoint was trained with a different unet size than the current cfg builds and the
# mismatched layers are presumably left at their fresh initialisation — align cfg first.
trainer.load(str(checkpoint_path))

Failed loading state dict. Trying partial load
layer unets.0.null_text_embed(torch.Size([1, 256, 64]) different than target: torch.Size([1, 256, 32])
layer unets.0.null_text_hidden(torch.Size([1, 256]) different than target: torch.Size([1, 128])
layer unets.0.init_conv.convs.0.weight(torch.Size([32, 2, 3, 3]) different than target: torch.Size([16, 2, 3, 3])
layer unets.0.init_conv.convs.0.bias(torch.Size([32]) different than target: torch.Size([16])
layer unets.0.init_conv.convs.1.weight(torch.Size([16, 2, 7, 7]) different than target: torch.Size([8, 2, 7, 7])
layer unets.0.init_conv.convs.1.bias(torch.Size([16]) different than target: torch.Size([8])
layer unets.0.init_conv.convs.2.weight(torch.Size([16, 2, 15, 15]) different than target: torch.Size([8, 2, 15, 15])
layer unets.0.init_conv.convs.2.bias(torch.Size([16]) different than target: torch.Size([8])
layer unets.0.to_time_hiddens.1.weight(torch.Size([256, 17]) different than target: torch.Size([128, 17])
layer unets.0.to_time_hi

: 

In [None]:
# Intentionally empty: `eval_batch = ds.eval_batch` is assigned as the first line of the
# sampling cell that follows, so repeating it here (in a never-executed cell) was redundant.

In [4]:
# Conditioning inputs for sampling. They are defined outside the `do_sample` guard
# because compare_diff_cond_scales reads them as notebook globals and would otherwise
# hit a NameError when cfg.do_sample is False.
eval_batch = ds.eval_batch
eval_text_embeds = eval_batch["text_embed"] if "text_embed" in eval_batch else None
eval_cond_images = eval_batch["cond_image"] if "cond_image" in eval_batch else None

if cfg.do_sample:
    samples = trainer.sample(
        text_embeds=eval_text_embeds,
        cond_images=eval_cond_images,
        cond_scale=cfg.cond_scale,
        batch_size=cfg.batch_size,
        stop_at_unet_number=None,
        return_all_unet_outputs=True,
    )

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]

In [6]:
def compare_diff_cond_scales(trainer, cond_scales=None):
    """Sample once per ``cond_scale`` value and plot the results side by side.

    Args:
        trainer: object exposing ``sample(...)`` (the ImagenTrainer built above).
        cond_scales: iterable of ``cond_scale`` values to try. Defaults to
            ``range(0, 14, 2)``, i.e. 0, 2, ..., 12.

    Reads the notebook globals ``eval_text_embeds``, ``eval_cond_images`` and
    ``cfg`` (for ``batch_size``). Shows the figure; returns None.

    Raises:
        ValueError: if ``cond_scales`` is empty.
    """
    if cond_scales is None:
        cond_scales = range(0, 14, 2)
    cond_scales = list(cond_scales)
    if not cond_scales:
        raise ValueError("cond_scales must contain at least one value")

    # squeeze=False keeps `axs` 2-D even for a single scale, so axs[0] is always a row.
    fig, axs = plt.subplots(1, len(cond_scales), figsize=(20, 5), squeeze=False)
    for i, (ax, cond_scale) in enumerate(zip(axs[0], cond_scales)):
        samples = trainer.sample(
            text_embeds=eval_text_embeds,
            cond_images=eval_cond_images,
            cond_scale=cond_scale,
            batch_size=cfg.batch_size,
            stop_at_unet_number=None,
            return_all_unet_outputs=True,
        )
        # First image of the first returned output; assumed (C, H, W) — moved to (H, W, C).
        image = samples[0][0].detach().permute(1, 2, 0).cpu().numpy()
        # imshow rejects (H, W, 1) arrays, so single-channel outputs are shown as 2-D.
        if image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image)
        ax.set_title(f"cond_scale={cond_scale}")
        # Only the leftmost panel keeps its axis ticks.
        if i != 0:
            ax.set_xticks([])
            ax.set_yticks([])
    plt.show()