In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import comet_ml
import torch
from torchvision import transforms
from rsl_depth_completion.diffusion.utils import set_seed

from pathlib import Path
import shutil
import numpy as np
from transformers import CLIPModel, CLIPProcessor

from kbnet import data_utils
import yaml
import argparse
from utils import plot_sample
import tensorflow as tf
from pathlib import Path
from torchvision.utils import save_image
from tqdm.auto import tqdm
from PIL import Image


%matplotlib inline

# config


In [None]:
input_channels = 1
timesteps = 300  # diffusion steps; also part of the experiment directory name

seed = 100
set_seed(seed)

# Let cuDNN autotune conv algorithms for speed; note this can reduce
# run-to-run determinism even with the seed above.
torch.backends.cudnn.benchmark = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from config import tmpdir

CLUSTER_DATASET_ARCHIVE = "/cluster/project/rsl/kzaitsev/datasets.tar"

is_cluster = os.path.exists("/cluster")
if is_cluster and not os.path.exists(f"{tmpdir}/cluster"):
    # Unpack the dataset archive once into the same `tmpdir` that the check
    # above inspects. Previously the check used Python's `tmpdir` while tar
    # extracted into the shell's `$TMPDIR`, so the two could disagree and the
    # archive would be re-extracted on every run. `check=True` surfaces a failed
    # extraction instead of silently continuing without data.
    import subprocess

    subprocess.run(
        ["tar", "-xf", CLUSTER_DATASET_ARCHIVE, "-C", str(tmpdir)],
        check=True,
    )

# model


In [None]:
# CLIP supplies the image embeddings that the dataset below uses as
# conditioning. The processor and the model come from the same checkpoint so
# that preprocessing and weights match.
extractor_model_ref = "openai/clip-vit-base-patch32"
extractor_processor = CLIPProcessor.from_pretrained(extractor_model_ref)
extractor_model = CLIPModel.from_pretrained(extractor_model_ref)

In [None]:
from config import path_to_project_dir, base_kitti_dataset_dir

ds_config_path = (
    Path(path_to_project_dir) / "rsl_depth_completion/configs/data/kitti_custom.yaml"
)
# Path.read_text opens *and closes* the file (the bare open().read() leaked the handle).
ds_config_str = ds_config_path.read_text()
# The YAML contains a literal "${data_dir}" placeholder that yaml.safe_load does
# not resolve, so substitute it textually before parsing.
ds_config_str = ds_config_str.replace("${data_dir}", base_kitti_dataset_dir)
ds_config = argparse.Namespace(**yaml.safe_load(ds_config_str)["ds_config"])

# Derived flags expected by the KITTI dataset code.
ds_config.use_pose = "photo" in ds_config.train_mode
ds_config.result = ds_config.result_dir
ds_config.use_rgb = ("rgb" in ds_config.input) or ds_config.use_pose
ds_config.use_d = "d" in ds_config.input
ds_config.use_g = "g" in ds_config.input

# File lists for the validation split (used below as the whole training pool).
val_image_paths = data_utils.read_paths(ds_config.val_image_path)
val_sparse_depth_paths = data_utils.read_paths(ds_config.val_sparse_depth_path)
val_intrinsics_paths = data_utils.read_paths(ds_config.val_intrinsics_path)
val_ground_truth_paths = data_utils.read_paths(ds_config.val_ground_truth_path)

In [None]:
import cv2
from rsl_depth_completion.conditional_diffusion import utils
from rsl_depth_completion.data.kitti.kitti_dataset import CustomKittiDCDataset


class MinimagenDatasetCustom(CustomKittiDCDataset):
    """KITTI depth-completion dataset that adds CLIP conditioning to each item.

    ``__getitem__`` returns a dict with:
      - ``perturbed_sdm``: interpolated (densified) sparse depth, divided by ``max_depth``
      - ``rgb_embed``: CLIP image embedding of the RGB image
      - ``sdm_embed``: CLIP image embedding of the sparse depth map (gray -> 3 channels)
      - ``rgb_image``: the RGB image divided by 255
      - ``sparse_dm``: sparse depth divided by ``max_depth``

    Uses the notebook-level globals ``extractor_model`` / ``extractor_processor``.
    """

    # Fixed crop window applied to both depth and RGB when ``do_crop`` is True.
    # NOTE(review): indexing as [:, rows, cols] assumes channel-first tensors.
    crop_top = 50
    crop_left = 400
    crop_size = 256

    def __init__(
        self,
        include_cond_image=False,
        sdm_transform=None,
        do_crop=False,
        *args,
        **kwargs
    ):
        """
        Args:
            include_cond_image: stored on the instance; not read by this class.
            sdm_transform: stored as ``self.sdm_transform`` (defaults to
                ``ToTensor``); not applied inside this class.
            do_crop: crop depth and RGB to the ``crop_size`` x ``crop_size``
                window starting at (``crop_top``, ``crop_left``).
            *args, **kwargs: forwarded to ``CustomKittiDCDataset``.
        """
        super().__init__(*args, **kwargs)
        self.include_cond_image = include_cond_image
        self.default_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.sdm_transform = sdm_transform or self.default_transform
        # Depth normalization constant (presumably metres, KITTI range — TODO confirm).
        self.max_depth = 80
        self.do_crop = do_crop

    def __getitem__(self, idx):
        items = super().__getitem__(idx)
        sparse_dm = items["d"]
        rgb_image = items["img"]
        if self.do_crop:
            rows = slice(self.crop_top, self.crop_top + self.crop_size)
            cols = slice(self.crop_left, self.crop_left + self.crop_size)
            sparse_dm = sparse_dm[:, rows, cols]
            rgb_image = rgb_image[:, rows, cols]
        # Out-of-place: the previous `/=` wrote through to the tensor (or crop
        # view of it) returned by the parent dataset, and fails on int tensors.
        sparse_dm = sparse_dm / self.max_depth

        interpolated_sparse_dm = torch.from_numpy(
            utils.interpolate_sparse_depth(
                sparse_dm.squeeze().numpy(), do_multiscale=True
            )
        ).unsqueeze(0)

        rgb_pixel_values = self.extract_img_features(rgb_image)
        sdm_pixel_values = self.extract_img_features(
            cv2.cvtColor(sparse_dm.squeeze().numpy(), cv2.COLOR_GRAY2RGB)
        )
        # The embeddings are only used detached, so skip building an autograd
        # graph through the CLIP vision tower for every sample.
        with torch.no_grad():
            rgb_embed = extractor_model.get_image_features(pixel_values=rgb_pixel_values)
            sdm_embed = extractor_model.get_image_features(pixel_values=sdm_pixel_values)

        sample = {
            "perturbed_sdm": interpolated_sparse_dm.detach(),
            "rgb_embed": rgb_embed.detach(),
            "sdm_embed": sdm_embed.detach(),
            "rgb_image": (rgb_image / 255).detach(),
            "sparse_dm": (sparse_dm).detach(),
        }
        return sample

    def extract_img_features(self, cond_image):
        """Preprocess one image with the CLIP processor and return ``pixel_values``.

        ``cond_image`` is anything ``np.array`` accepts (here a torch tensor or a
        numpy array); it is wrapped into a batch of one for the processor.

        NOTE(review): the sparse depth passed in is already scaled to ~[0, 1];
        if the processor rescales by 1/255 again, the depth embedding is
        computed from a near-black image — confirm and consider do_rescale=False.
        """
        batch_of_one = torch.from_numpy(np.array(cond_image)).unsqueeze(0)
        return extractor_processor(
            images=batch_of_one,
            return_tensors="pt",
        ).pixel_values


# Validation split of KITTI with the fixed crop enabled; item 0 is a quick
# smoke test of the whole __getitem__ pipeline (interpolation + CLIP embeddings).
ds = MinimagenDatasetCustom(
    do_crop=True,
    include_cond_image=True,
    ds_config=ds_config,
    image_paths=val_image_paths,
    sparse_depth_paths=val_sparse_depth_paths,
    intrinsics_paths=val_intrinsics_paths,
    ground_truth_paths=val_ground_truth_paths,
)
x = ds[0]
plot_sample(x)
# Displayed: shapes of the interpolated depth, both embeddings and the RGB image.
tuple(x[key].shape for key in ("perturbed_sdm", "rgb_embed", "sdm_embed", "rgb_image"))

In [None]:
# Use the first half of the dataset, then split it 80/20 into train/valid.
ds_subset = torch.utils.data.Subset(ds, range(0, len(ds) // 2))
train_size = int(0.8 * len(ds_subset))
test_size = len(ds_subset) - train_size
# A dedicated, seeded generator makes the split reproducible regardless of how
# much of the global RNG earlier (or re-run) cells have consumed.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, valid_dataset = torch.utils.data.random_split(
    ds_subset, [train_size, test_size], generator=split_generator
)

is_cluster = os.path.exists("/cluster")
if is_cluster:
    BATCH_SIZE = 16
    NUM_WORKERS = min(20, BATCH_SIZE)
else:
    BATCH_SIZE = 4
    NUM_WORKERS = 0

dl_opts = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    # NOTE(review): drop_last also drops the tail of the *validation* set.
    "drop_last": True,
}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts, shuffle=False)

In [None]:
# Number of batches per epoch for (train, valid).
tuple(map(len, (train_dataloader, valid_dataloader)))

In [None]:
from imagen_pytorch import Unet, Imagen

import gc

# Free memory held by a previously built model when this cell is re-run.
gc.collect()
torch.cuda.empty_cache()

# The conditioning "text" embeddings are CLIP image features, whose width is the
# CLIP projection dim. Deriving it keeps the Unet/Imagen in sync with the
# extractor instead of repeating a hard-coded 512 in two places.
TEXT_EMBED_DIM = extractor_model.config.projection_dim

unet_base_params = dict(
    dim=64,
    dim_mults=[1, 1, 2, 2, 4, 4],
    channels=1,
    channels_out=None,
    text_embed_dim=TEXT_EMBED_DIM,
    num_resnet_blocks=2,
    layer_attns=[False, False, False, False, False, True],
    layer_cross_attns=[False, False, False, False, False, True],
    attn_heads=8,
    lowres_cond=False,
    memory_efficient=False,
    attend_at_middle=False,
    cond_dim=None,
    cond_images_channels=3,  # RGB conditioning image
)
imagen_params = dict(
    text_embed_dim=TEXT_EMBED_DIM,
    channels=1,
    timesteps=timesteps,
    loss_type="l2",
    lowres_sample_noise_level=0.2,
    dynamic_thresholding_percentile=0.9,
    only_train_unet_number=None,
    image_sizes=[128],
    text_encoder_name="google/t5-v1_1-base",
    auto_normalize_img=True,
    cond_drop_prob=0.1,
    condition_on_text=True,
)

unet_base = Unet(**unet_base_params)
unets = [unet_base]

imagen = Imagen(unets=unets, **imagen_params)

# Moving Imagen moves every unet registered inside it; a separate
# unet_base.to(device) is redundant (or moves an unused copy if Imagen cloned it).
imagen.to(device)

imagen = torch.compile(imagen)

print(
    "Number of parameters in model",
    sum(p.numel() for p in imagen.parameters() if p.requires_grad),
)

In [None]:
import torch.optim as optim

# Adam with a small constant learning rate over all of Imagen's parameters.
# `lr` is also embedded in the experiment directory name further down.
lr = 5e-6
trainable_params = imagen.parameters()
optimizer = optim.Adam(trainable_params, lr=lr)

In [None]:
# Never commit credentials: read the Comet key from the environment. When it is
# unset (None), comet_ml falls back to its own configuration lookup.
# NOTE(review): the key previously hard-coded here should be rotated.
experiment = comet_ml.Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="rsl_depth_completion",
    auto_metric_logging=True,
    auto_param_logging=True,
    auto_histogram_tensorboard_logging=True,
    log_env_details=False,
    log_env_host=False,
)

In [None]:
# Log both hyper-parameter dicts to Comet, namespaced by a prefix so the
# Imagen and Unet settings with the same key do not collide.
for prefix, params in (("imagen", imagen_params), ("unet_base", unet_base_params)):
    experiment.log_parameters({f"{prefix}_{k}": v for k, v in params.items()})

In [None]:
# Run-specific sub-path built with f-string "=" specifiers, e.g.
# input_name='interp_sdm'/cond='rgb+sdm'/lr=5e-06_timesteps=300
# (note the repr quotes that end up in the directory names).
input_name = "interp_sdm"
cond = "rgb+sdm"
exp_dir = f"{input_name=}/{cond=}/{lr=}_{timesteps=}"
logdir = Path("./logs")
train_logdir = logdir.joinpath("train", exp_dir)
train_logdir = logdir / "train" / exp_dir

In [None]:
num_epochs = 21
out_dir = f"results/{exp_dir}"

os.makedirs(out_dir, exist_ok=True)
train_writer = tf.summary.create_file_writer(str(train_logdir))

start_epoch_scaler = 0
start_epoch = num_epochs * start_epoch_scaler
final_epoch = start_epoch + num_epochs * 2
# The loop runs `final_epoch - start_epoch` epochs (twice `num_epochs`), so
# size the outer progress bar to that instead of `num_epochs`.
progress_bar = tqdm(total=final_epoch - start_epoch, disable=False)

# fp16 autocast on CUDA needs loss scaling so small gradients do not underflow.
# With enabled=False (CPU, bfloat16 autocast) the scaler is a pass-through.
grad_scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

# torch.compile returns a wrapper whose state_dict keys carry an "_orig_mod."
# prefix; sample and checkpoint through the underlying module instead.
imagen_module = getattr(imagen, "_orig_mod", imagen)

num_batches = len(train_dataloader)


def _log_epoch_samples(epoch, batch, sdm_embed, rgb_img):
    """Sample with the given conditioning and write images to disk and TensorBoard."""
    with torch.no_grad():
        samples = imagen_module.sample(
            text_embeds=sdm_embed,
            cond_images=rgb_img,
            cond_scale=1.0,
            stop_at_unet_number=1,
            return_all_unet_outputs=True,
        )
    # One entry per unet; with stop_at_unet_number=1 only the base unet's batch.
    base_samples = samples[0]

    # Explicit figure + close so repeated calls neither draw onto one shared
    # axes nor keep figures alive across epochs.
    fig, ax = plt.subplots()
    ax.imshow(base_samples[0].permute(1, 2, 0).cpu().detach().numpy())
    fig.savefig(f"{out_dir}/first_sample_in_batch_{epoch:04d}.png")
    plt.close(fig)

    out_path = f"{out_dir}/sample-{epoch}.png"
    save_image(base_samples, out_path, nrow=10)
    with Image.open(out_path) as out_img:  # release the file handle after reading
        sample_grid = np.expand_dims(np.array(out_img), 0)

    max_outputs = len(base_samples)
    with train_writer.as_default():
        tf.summary.image("samples", sample_grid, max_outputs=max_outputs, step=epoch)
        # Inputs/conditioning of the same batch, channel-first -> channel-last.
        for tag, key in (
            ("interp_sdm", "perturbed_sdm"),
            ("sdm", "sparse_dm"),
            ("rgb", "rgb_image"),
        ):
            tf.summary.image(
                tag,
                batch[key].numpy().transpose(0, 2, 3, 1),
                max_outputs=max_outputs,
                step=epoch,
            )


for epoch in range(start_epoch, final_epoch):
    progress_bar.set_description(f"Epoch {epoch}")
    running_loss = {"loss": 0}
    batch = None  # stays None if the dataloader yields nothing
    for batch_idx, batch in tqdm(enumerate(train_dataloader), total=num_batches):
        perturbed_sdm = batch["perturbed_sdm"].to(device)
        rgb_img = batch["rgb_image"].to(device)
        sdm_embed = batch["sdm_embed"].to(device)

        # Zero every step: zeroing once per epoch made each optimizer.step()
        # apply the accumulated sum of all gradients since the epoch started.
        optimizer.zero_grad()
        for i in range(1, 2):
            # Autocast covers only the forward pass; backward and the optimizer
            # step run outside it, as recommended by the PyTorch AMP docs.
            with torch.autocast(device.type):
                loss = imagen(
                    perturbed_sdm,
                    text_embeds=sdm_embed,
                    cond_images=rgb_img,
                    unet_number=i,
                )
            grad_scaler.scale(loss).backward()
            batch_loss = loss.item()
            running_loss["loss"] += batch_loss
        grad_scaler.step(optimizer)
        grad_scaler.update()

        with train_writer.as_default():
            tf.summary.scalar(
                "batch/loss",
                batch_loss,
                step=epoch * num_batches + batch_idx,
            )

    # Log the per-batch mean so the epoch curve does not scale with dataset size.
    mean_loss = running_loss["loss"] / max(num_batches, 1)
    with train_writer.as_default():
        tf.summary.scalar("epoch/loss", mean_loss, step=epoch)

    progress_bar.update(1)

    # NOTE(review): fires at epochs 1, 6, 11, ...; confirm `epoch % 5` was not intended.
    if (epoch - 1) % 5 == 0:
        print(f"Epoch: {epoch}\t{running_loss}\tmean_loss={mean_loss:.6f}")
        progress_bar.set_postfix(**running_loss)
        if batch is not None:
            _log_epoch_samples(epoch, batch, sdm_embed, rgb_img)

train_writer.flush()
torch.save(imagen_module.state_dict(), f"{out_dir}/imagen_epoch_{final_epoch}.pt")

In [None]:
# Close the TensorBoard writer so buffered summaries are flushed to disk, then
# mark the Comet run as finished.
train_writer.close()
experiment.end()