# Light Schrödinger Bridge Night to Day Domain Transform

This notebook trains the LightSB model on the AutoencoderKL latents generated earlier. It performs a grid search over the `epsilon`, `n_potentials` and `max_steps` parameters (`is_diagonal` is listed but not yet searched) to find the best-performing model.

Next, it transforms the night validation set. Finally, the transformed latents are decoded back into images with the autoencoder and saved to a separate folder in the data directory.
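
The latents in this notebook are stored as flat vectors of length 4096 and reshaped to `(N, 4, 32, 32)` before decoding (these shapes are printed by the decoding cell further down). The sketch below illustrates that round trip with NumPy. The helper names `flatten_latents` and `unflatten_latents` are hypothetical and not part of the repository, and the row-major layout is an assumption.

```python
import numpy as np

# Assumed latent layout: (channels, height, width); 4 * 32 * 32 = 4096
LATENT_SHAPE = (4, 32, 32)

def flatten_latents(latents):
    """Flatten (N, C, H, W) latents to (N, C*H*W) vectors."""
    return latents.reshape(latents.shape[0], -1)

def unflatten_latents(flat, shape=LATENT_SHAPE):
    """Restore (N, C*H*W) vectors to (N, C, H, W) latents."""
    return flat.reshape(flat.shape[0], *shape)

# Dummy batch of two latents, used only to check the round trip
batch = np.arange(2 * 4 * 32 * 32, dtype=np.float32).reshape(2, 4, 32, 32)
flat = flatten_latents(batch)
restored = unflatten_latents(flat)
```

Because both directions use the same memory order, flattening and then restoring returns the original array unchanged.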

## Imports

In [27]:
import os, sys

# if colab, mount drive and get the git repo
if 'google.colab' in sys.modules:
    from google.colab import drive
    print(os.getcwd())
    drive.mount('/content/drive')
    !git clone --recurse-submodules https://github.com/jsluijter02/LightSB_YOLO

    # Append LightSB_YOLO path
    sys.path.append(os.path.join(os.getcwd(), 'LightSB_YOLO'))

    ## TODO: 

# otherwise local path append
else:
    sys.path.append(os.path.dirname(os.getcwd()))

In [28]:
import numpy as np
np.random.seed(0)
from argparse import Namespace
import copy
import torch
import diffusers
from datetime import date

from scripts.utils import dirs
dirs.add_LIGHTSB_to_PATH()

from scripts.models.autoencoderkl import AutoencoderKL_BDD
from scripts.models.lightsb import LightSB_BDD
from scripts.evals.FID import latent_FID_score, image_FID_score
from scripts.utils import img
from scripts.utils.device import get_device

## Encoder Model

In [29]:
encoder = AutoencoderKL_BDD()

## Load Data

In [30]:
data_dir = dirs.get_data_dir()
latent_dir = os.path.join(data_dir, "encodings")

# Only keep the val_night file names; the other splits' file names are not needed
train_day_latents, _  = encoder.load_latents(os.path.join(latent_dir, "train_day.npz"))
print(train_day_latents.shape)

train_night_latents, _ = encoder.load_latents(os.path.join(latent_dir, "train_night.npz"))
print("train_night_latents shape: ", train_night_latents.shape)

val_day_latents, _ = encoder.load_latents(os.path.join(latent_dir, "val_day.npz"))
print("val_day_latents shape: ", val_day_latents.shape)

val_night_latents, val_night_filenames = encoder.load_latents(os.path.join(latent_dir, "val_night.npz"))
print("val_night_latents shape: ", val_night_latents.shape)
print("val_night_filenames length: ", len(val_night_filenames))

np_data = {"train_day": train_day_latents, 
           "train_night": train_night_latents, 
           "val_day": val_day_latents, 
           "val_night": val_night_latents}

(36800, 4096)
train_night_latents shape:  (28028, 4096)


KeyboardInterrupt: 

## Load Light Schrödinger Bridge model

In [None]:
# sb_config = LightSB_BDD.standard_config()
# sb = LightSB_BDD(sb_config, np_data=np_data)

## Train SB with a Grid Search + Val Set Evaluation
To get a sense of each parameter's effect on the transformation, this step performs a grid search and saves the best parameters and state dictionary.
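
The selection logic of such a grid search can be sketched as follows. This is a minimal illustration, not the notebook's implementation: `grid_search` and `toy_score` are hypothetical names, and `toy_score` is a made-up stand-in for "train, transform, compute the latent FID". The only assumption carried over is that a lower score is better, as is the case for FID.

```python
from itertools import product

def grid_search(score_fn, grid):
    """Evaluate every parameter combination; return the one with the lowest score."""
    best_params, best_score = None, float("inf")
    names = list(grid)
    for values in product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        score = score_fn(**params)
        if score < best_score:  # lower is better, as with FID
            best_params, best_score = params, score
    return best_params, best_score

def toy_score(epsilon, n_potentials, max_steps):
    """Made-up score for illustration only; it does not model LightSB."""
    return abs(epsilon - 0.1) + 1.0 / n_potentials + 1000.0 / max_steps

grid = {
    "epsilon": [0.01, 0.1, 0.5],
    "n_potentials": [10, 20],
    "max_steps": [1000, 10000, 50000],
}
best_params, best_score = grid_search(toy_score, grid)
```

Starting from `float("inf")` avoids comparing against `None` on the first iteration, and the full grid of 3 x 2 x 3 = 18 combinations is visited exactly once.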

Evaluation method: FID on latents.

Due to computational and time constraints, it is not feasible to decode all images once per grid configuration and compute image-level FID scores, so a latent-based proxy is used to select the final LightSB parameters. The image-level FID is computed only on the final image set.

Note that this proxy differs from the actual FID metric, which is computed on features extracted by a pretrained CNN (an Inception network) rather than on autoencoder latents.
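
For reference, the Fréchet distance underlying FID compares two Gaussians fitted to the feature sets: the squared distance between the means plus a term comparing the covariances. The sketch below computes it directly on arbitrary feature vectors with NumPy. `frechet_distance` is a hypothetical helper and is not claimed to match the repository's `latent_FID_score`. The trace of the matrix square root is obtained from eigenvalues, which is an implementation choice made here to avoid extra dependencies.

```python
import numpy as np

def frechet_distance(x, y):
    """Fréchet distance between Gaussians fitted to the rows of x and y."""
    mu_x, mu_y = x.mean(axis=0), y.mean(axis=0)
    cov_x = np.cov(x, rowvar=False)
    cov_y = np.cov(y, rowvar=False)
    # Tr((cov_x cov_y)^(1/2)) is the sum of the square roots of the eigenvalues
    # of cov_x @ cov_y, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(cov_x @ cov_y)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_x - mu_y
    return float(diff @ diff + np.trace(cov_x) + np.trace(cov_y) - 2.0 * tr_sqrt)

# Synthetic features for illustration only (not notebook data)
rng = np.random.default_rng(0)
a = rng.normal(size=(500, 8))
b = rng.normal(loc=1.0, size=(500, 8))
same = frechet_distance(a, a)      # identical sets: distance close to zero
shifted = frechet_distance(a, b)   # shifted mean: clearly larger distance
```

On 4096-dimensional latents the eigenvalue step works on a 4096 x 4096 matrix, so it is noticeably more expensive than in this small example.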

Additionally, the downstream mAP of the YOLOPX model will determine how successful this method is as a preprocessing step.

### Grid Search Parameters

In [None]:
epsilons = [0.01, 0.1, 0.5]
n_potentials = [10, 20]
is_diagonal = ... # Doesn't work on mac, but it 
max_steps = [1000, 10000, 50000]

### Grid Search Setup

In [None]:
# args = Namespace()

# models = {}
# fid_scores = {}
# sample_indices = img.sample_indices(len(sb.X_test), how_many=5)

# best = {
#     "name": None,
#     "fid": None,
#     "state_dict": None,
#     "epsilon" : None,
#     "max_steps": None,
#     "n_potentials": None
# }

# run_date = date.today().strftime('%Y%m%d')  # separate name so the imported `date` class is not shadowed
# runs_path = os.path.join(dirs.get_base_dir(), "runs", "LightSB_GridSearch", run_date)
# os.makedirs(runs_path, exist_ok=True)

### Grid Search Loop + Internal Val Set Eval

In [None]:
# TODO: Split train 50/50 for grid search

# for eps in epsilons:
#     for potential in n_potentials:
#         for steps in max_steps:
#             print("Started process for: ", steps, " ", eps, " ", potential)
#             # to reduce training time, reduce steps to 1000
#             args.MAX_STEPS = steps
#             args.EPSILON = eps
#             args.N_POTENTIALS = potential

#             print("Reloading model")
#             sb.update_config(args=args)
#             sb.reload_model()
#             print("Reloaded model")

#             print("Training model")
#             sb.train()
#             print("Trained model")

#             print("Transforming Validation Latents")
#             transformed = sb.transform(sb.X_test)

#             fid = latent_FID_score(transformed, sb.Y_test)
#             print("FID-score on latents: ", fid)

#             state_dict = copy.deepcopy(sb.model.state_dict())

#             save = {
#                 "fid": fid,
#                 "state_dict": state_dict,
#                 "max_steps": steps,
#                 "epsilon": eps,
#                 "n_potentials": potential
#             }

#             model_name = f'{fid}_{eps}_{potential}_{steps}.pt'

#             torch.save(save, os.path.join(runs_path, model_name))

#             # Save the params of the lowest (best) FID score so far
#             if best["fid"] is None or fid < best["fid"]:
#                 best["fid"] = fid
#                 best["epsilon"] = eps
#                 best["max_steps"] = steps
#                 best["n_potentials"] = potential
#                 best["name"] = model_name
#                 best["state_dict"] = state_dict

#             sample_imgs = diffusers.utils.pt_to_pil(encoder.decode_latents(transformed[sample_indices]))
#             img.plot_samples(sample_imgs, title=f'Transformed Val Set Images with Params: Steps: {steps}, Epsilon: {eps}, N_potentials: {potential}. FID: {fid}')

In [None]:
# print(f'Best model: {best["name"]}, FID: {best["fid"]}')

## Transform Val Images with Best Model

In [None]:
best_cfg = LightSB_BDD.standard_config()
best_cfg.MODEL.EPSILON = 0.1  # best["epsilon"]
best_cfg.MODEL.N_POTENTIALS = 20  # best["n_potentials"]
best_cfg.MAX_STEPS = 10000  # best["max_steps"]

best_sb = LightSB_BDD(config=best_cfg, np_data=np_data)

# best_sb.load_state_dict(best["state_dict"])

In [None]:
best_sb.train()

100%|██████████| 10000/10000 [02:40<00:00, 62.30it/s]


In [None]:
best_transformed = best_sb.transform(best_sb.X_test)

In [None]:
np.save(os.path.join(dirs.get_data_dir(), "LightSB_transformed.npy"), best_transformed.cpu())

In [None]:
best_transformed = torch.as_tensor(np.load(os.path.join(dirs.get_data_dir(), "LightSB_transformed.npy")), device=get_device())
decoded = encoder.decode_latents(best_transformed, batch_size=16)

latents shape:  torch.Size([3929, 4096])
reshaped latents shape:  torch.Size([3929, 4, 32, 32])


100%|██████████| 246/246 [08:33<00:00,  2.09s/it]

decoded latents shape:  torch.Size([3929, 3, 256, 256])





In [None]:
np.save(os.path.join(dirs.get_data_dir(), "LightSB_decoded.npy"), decoded.cpu())

## Save Images to Folder

In [None]:
decoded = torch.as_tensor(np.load(os.path.join(dirs.get_data_dir(), "LightSB_decoded.npy")), device=get_device())

In [None]:
## TODO: Reverse to original aspect ratio

encoder.save_imgs(decoded, filenames=val_night_filenames, folder_name="LightSB_Images", split="val")

3929it [00:18, 208.12it/s]


## Get ACTUAL Val Night FID Metric

In [None]:
## Day / Day FID
## Night / Day FID
## Day / Fake Day FID

## LightSB GT Box Check

In [None]:
indices = img.sample_indices(len(val_night_filenames), how_many=5)
print(indices)

[3291 1004 1254  704  399]


In [None]:
# --- GT bbox overlay on SB images (YOLOPX loader-space, one batch, with tiny-box filter) ---

import os
import cv2
import numpy as np
import torchvision.transforms as transforms

from scripts.utils import dirs
from scripts.dataset import bdd
from scripts.models.yolo import YOLOPX_BDD

import models.YOLOPX.lib.dataset as dataset
from models.YOLOPX.lib.utils.utils import DataLoaderX
from models.YOLOPX.lib.utils.plot import plot_one_box  # if your file is plots.py, change import

# 0) init
yolo = YOLOPX_BDD()

sb_root = os.path.join(dirs.get_data_dir(), "LightSB_Images")
assert os.path.exists(sb_root), f"SB folder not found: {sb_root}"

out_dir = os.path.join(dirs.get_data_dir(), "sb_gt_overlays")
os.makedirs(out_dir, exist_ok=True)

# 1) build val dataset (avoid train aug)
transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_set = bdd.get_bdd_dataset(yolo.config, is_train=False, skip=True, transform=transf)
val_set.db = bdd.get_db(yolo.config, is_train=False, timeofday="night")

# remap image path to SB by basename
for d in val_set.db:
    d["image"] = os.path.join(sb_root, os.path.basename(d["image"]))

# filter missing files
val_set.db = [d for d in val_set.db if os.path.exists(d["image"])]
print("DB after filtering:", len(val_set.db))

loader = DataLoaderX(
    val_set,
    batch_size=16,
    shuffle=False,
    num_workers=yolo.config.WORKERS,
    collate_fn=dataset.AutoDriveDataset.collate_fn
)

# 2) get one batch
imgs, targets, paths, shapes = next(iter(loader))

# 3) denorm helper (imgs are normalized)
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[None, None, :]
std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)[None, None, :]

t = targets[0].cpu().numpy()  # per-image padded targets, format: [cls, cx, cy, w, h] in 640-space

# 4) filter to make the visualization readable
MIN_WH = 12
MIN_AREA = 12 * 12

# 5) draw GT boxes on the EXACT loader image (so coords match)
for i in range(len(paths)):
    # tensor -> uint8 BGR image for cv2 drawing
    im = imgs[i].permute(1, 2, 0).cpu().numpy()          # HWC, float
    im = (im * std + mean) * 255.0
    im = np.clip(im, 0, 255).astype(np.uint8)
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

    ti = t[i]  # [max_objs, fields]
    H, W = im.shape[:2]

    for row in ti:
        if row.sum() == 0:
            continue  # padding rows

        cls = int(row[0])
        cx, cy, bw, bh = row[1:5].astype(float)

        # coords are already in pixels in 640-space, so no scaling is needed

        # xywh -> x1y1x2y2
        x1 = cx - bw / 2
        y1 = cy - bh / 2
        x2 = cx + bw / 2
        y2 = cy + bh / 2

        w = x2 - x1
        h = y2 - y1
        if w < MIN_WH or h < MIN_WH or (w * h) < MIN_AREA:
            continue

        plot_one_box([x1, y1, x2, y2], im, color=[0, 0, 255])

    out_path = os.path.join(out_dir, f"sb_gt_batch0_{i}_{os.path.basename(paths[i])}")
    cv2.imwrite(out_path, im)

print("Saved filtered GT overlays to:", out_dir)

AUTO_RESUME: False
CUDNN:
  BENCHMARK: True
  DETERMINISTIC: False
  ENABLED: True
DATASET:
  CLAHE_CLIPLIMIT: 2.0
  CLAHE_VAL: False
  COLOR_RGB: False
  DATAROOT: /Users/jochem/Documents/GitHub/LightSB_YOLO/data/bdd/images
  DATASET: BddDataset
  DATA_FORMAT: jpg
  FLIP: True
  HSV_H: 0.015
  HSV_S: 0.7
  HSV_V: 0.4
  LABELROOT: /Users/jochem/Documents/GitHub/LightSB_YOLO/data/bdd/det_annotations
  LANEROOT: /Users/jochem/Documents/GitHub/LightSB_YOLO/data/bdd/ll_seg_annotations
  MASKROOT: /Users/jochem/Documents/GitHub/LightSB_YOLO/data/bdd/da_seg_annotations
  ORG_IMG_SIZE: [720, 1280]
  ROT_FACTOR: 10
  SCALE_FACTOR: 0.25
  SELECT_DATA: False
  SHEAR: 0.0
  TEST_SET: val
  TRAIN_SET: train
  TRANSLATE: 0.1
DEBUG: False
GPUS: (0,)
LOG_DIR: runs/
LOSS:
  BOX_GAIN: 0.05
  CLS_GAIN: 0.5
  CLS_POS_WEIGHT: 1.0
  DA_SEG_GAIN: 0.2
  FL_GAMMA: 2.0
  LL_IOU_GAIN: 0.2
  LL_SEG_GAIN: 0.2
  LOSS_NAME: 
  MULTI_HEAD_LAMBDA: None
  OBJ_GAIN: 1.0
  OBJ_POS_WEIGHT: 1.0
  SEG_POS_WEIGHT: 1.0
MODEL

DB after filtering: 3929
Saved filtered GT overlays to: /Users/jochem/Documents/GitHub/LightSB_YOLO/data/sb_gt_overlays


In [None]:
import os
from scripts.utils import dirs

print("data dir:", dirs.get_data_dir())
print("contents:", os.listdir(dirs.get_data_dir())[:50])

data dir: /Users/jochem/Documents/GitHub/LightSB_YOLO/data
contents: ['encodings', '.DS_Store', 'LightSB_transformed.npy', 'weights', 'LightSB_decoded.npy', 'pkl_files', 'CLAHE_images', 'LightSB_Images', 'encodings.zip', 'bdd']
