# Light Schrödinger Bridge Night to Day Domain Transform

This notebook trains the LightSB model on the AutoencoderKL latents we generated. It performs a grid search over the n_potentials, is_diagonal and epsilon parameters to find the best-performing model.

Next, it transforms the night validation set with the trained model. Finally, we decode the transformed latents back into images with the autoencoder and save them to a separate folder in the data directory.

## Imports

In [1]:
import os, sys

# if colab, mount drive and get the git repo
if 'google.colab' in sys.modules:
    from google.colab import drive
    print(os.getcwd())
    drive.mount('/content/drive')
    !git clone --recurse-submodules https://github.com/jsluijter02/LightSB_YOLO

    # Append LightSB_YOLO path
    sys.path.append(os.path.join(os.getcwd(), 'LightSB_YOLO'))

    ## TODO: 

# otherwise local path append
else:
    sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
import numpy as np
np.random.seed(0)
from argparse import Namespace
import copy
import torch
import diffusers
from datetime import date

from scripts.utils import dirs
dirs.add_LIGHTSB_to_PATH()

from scripts.models.autoencoderkl import AutoencoderKL_BDD
from scripts.models.lightsb import LightSB_BDD
from scripts.evals.FID import latent_FID_score, image_FID_score
from scripts.utils import img
from scripts.utils.device import get_device



## Encoder Model

In [3]:
encoder = AutoencoderKL_BDD()

## Load Data

In [4]:
data_dir = dirs.get_data_dir()
latent_dir = os.path.join(data_dir, "encodings")

# Only keep the val_night file names; the other splits do not need them
train_day_latents, _  = encoder.load_latents(os.path.join(latent_dir, "train_day.npz"))
print(train_day_latents.shape)

train_night_latents, _ = encoder.load_latents(os.path.join(latent_dir, "train_night.npz"))
print("train_night_latents shape: ", train_night_latents.shape)

val_day_latents, _ = encoder.load_latents(os.path.join(latent_dir, "val_day.npz"))
print("val_day_latents shape: ", val_day_latents.shape)

val_night_latents, val_night_filenames = encoder.load_latents(os.path.join(latent_dir, "val_night.npz"))
print("val_night_latents shape: ", val_night_latents.shape)
print("val_night_filenames length: ", len(val_night_filenames))

np_data = {"train_day": train_day_latents, 
           "train_night": train_night_latents, 
           "val_day": val_day_latents, 
           "val_night": val_night_latents}

(36800, 4096)
train_night_latents shape:  (28028, 4096)
val_day_latents shape:  (5258, 4096)
val_night_latents shape:  (3929, 4096)
val_night_filenames length:  3929
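
The loading step above reads latents and file names from .npz archives. The internals of encoder.load_latents are not shown in this notebook, so the following is only a minimal sketch of such a round trip. It assumes the archive stores two arrays under the hypothetical keys latents and filenames, and it uses an in-memory buffer and placeholder data instead of the real files.

```python
import io

import numpy as np

# Placeholder data: 3 flattened latents of size 4096 and their file names.
latents = np.zeros((3, 4096), dtype=np.float32)
filenames = np.array(["a.jpg", "b.jpg", "c.jpg"])

# Write the archive to an in-memory buffer (a file path works the same way).
# The key names "latents" and "filenames" are assumptions for illustration.
buffer = io.BytesIO()
np.savez(buffer, latents=latents, filenames=filenames)
buffer.seek(0)

# Read the arrays back; NpzFile loads each array on access.
with np.load(buffer) as archive:
    loaded_latents = archive["latents"]
    loaded_filenames = archive["filenames"].tolist()

print(loaded_latents.shape, len(loaded_filenames))
```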


## Load Light Schrödinger Bridge model

In [5]:
# sb_config = LightSB_BDD.standard_config()
# sb = LightSB_BDD(sb_config, np_data=np_data)

## Train SB with a Grid Search + Val Set Evaluation
To get a sense of how each parameter affects the transformation, this step performs a grid search and saves the best parameters and state dictionary.

Evaluation method: FID on latents.

Due to computational and time constraints, it is not possible to decode ALL images for every configuration in the grid and compute image-based FID scores, so a latent-based proxy is used to select the final LightSB model parameters. The image-based FID is computed only on the final image set.

This proxy differs from the actual FID metric, which is computed on features extracted by a pretrained CNN (Inception-v3) rather than directly on the autoencoder latents.

Additionally, downstream, the mAP of the YOLOPX algorithm will determine how successful this method is as a preprocessing step.
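
The latent-based proxy described above can be illustrated with a generic Fréchet distance between two Gaussians fitted to the latent sets. The implementation of latent_FID_score is not shown in this notebook, so the function below (frechet_distance, a hypothetical name) is a sketch of the standard formula and not necessarily what the project uses. The random inputs are placeholders.

```python
import numpy as np


def frechet_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Frechet distance between Gaussians fitted to the rows of x and y."""
    mu_x, mu_y = x.mean(axis=0), y.mean(axis=0)
    cov_x = np.cov(x, rowvar=False)
    cov_y = np.cov(y, rowvar=False)

    # tr((cov_x cov_y)^(1/2)) is the sum of the square roots of the
    # eigenvalues of cov_x @ cov_y, which are real and non-negative for
    # positive semi-definite inputs (clipping guards against round-off).
    eigvals = np.linalg.eigvals(cov_x @ cov_y)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    diff = mu_x - mu_y
    return float(diff @ diff + np.trace(cov_x) + np.trace(cov_y) - 2.0 * tr_sqrt)


# Placeholder "latents": two 8-dimensional point clouds, one shifted by 0.5.
rng = np.random.default_rng(0)
a = rng.normal(size=(500, 8))
b = rng.normal(loc=0.5, size=(500, 8))

same = frechet_distance(a, a)      # identical sets: distance close to 0
shifted = frechet_distance(a, b)   # shifted set: clearly larger distance
print(same, shifted)
```

Lower values mean the two distributions are closer, so the grid search should prefer the configuration with the smallest score.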

### Grid Search Parameters

In [6]:
epsilons = [0.01, 0.1, 0.5]
n_potentials = [10, 20]
is_diagonal = ... # Doesn't work on mac, but it 
max_steps = [1000, 10000, 50000]
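
As a quick check on the size of this search, the combinations of the three lists that the loop iterates over can be enumerated with itertools.product. This is a minimal sketch that mirrors the nested loops below; the dictionary keys are chosen for illustration only.

```python
from itertools import product

epsilons = [0.01, 0.1, 0.5]
n_potentials = [10, 20]
max_steps = [1000, 10000, 50000]

# One dictionary per training run, in the same order as the nested loops
# (epsilon outermost, max_steps innermost).
grid = [
    {"epsilon": eps, "n_potentials": n_pot, "max_steps": steps}
    for eps, n_pot, steps in product(epsilons, n_potentials, max_steps)
]
n_runs = len(grid)
print(n_runs)
```

Each entry corresponds to one full training run plus one validation transform, which is why the evaluation inside the loop has to stay cheap.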

### Grid Search Setup

In [7]:
# args = Namespace()

# models = {}
# fid_scores = {}
# sample_indices = img.sample_indices(len(sb.X_test), how_many=5)

# best = {
#     "name": None,
#     "fid": None,
#     "state_dict": None,
#     "epsilon" : None,
#     "max_steps": None,
#     "n_potentials": None
# }

# run_date = date.today().strftime('%Y%m%d')  # separate name, so the imported date class is not shadowed
# runs_path = os.path.join(dirs.get_base_dir(), "runs", "LightSB_GridSearch", run_date)
# os.makedirs(runs_path, exist_ok=True)

### Grid Search Loop + Internal Val Set Eval

In [8]:
# for eps in epsilons:
#     for potential in n_potentials:
#         for steps in max_steps:
#             print("Started process for: ", steps, " ", eps, " ", potential)
#             # to reduce training time, reduce steps to 1000
#             args.MAX_STEPS = steps
#             args.EPSILON = eps
#             args.N_POTENTIALS = potential

#             print("Reloading model")
#             sb.update_config(args=args)
#             sb.reload_model()
#             print("Reloaded model")

#             print("Training model")
#             sb.train()
#             print("Trained model")

#             print("Transforming Validation Latents")
#             transformed = sb.transform(sb.X_test)

#             fid = latent_FID_score(transformed, sb.Y_test)
#             print("FID-score on latents: ", fid)

#             state_dict = copy.deepcopy(sb.model.state_dict())

#             save = {
#                 "fid": fid,
#                 "state_dict": state_dict,
#                 "max_steps": steps,
#                 "epsilon": eps,
#                 "n_potentials": potential
#             }

#             model_name = f'{fid}_{eps}_{potential}_{steps}.pt'

#             torch.save(save, os.path.join(runs_path, model_name))

#             # Keep the params of the lowest FID score (lower is better)
#             if best["fid"] is None or fid < best["fid"]:
#                 best["fid"] = fid
#                 best["epsilon"] = eps
#                 best["max_steps"] = steps
#                 best["n_potentials"] = potential
#                 best["name"] = model_name
#                 best["state_dict"] = state_dict

#             sample_imgs = diffusers.utils.pt_to_pil(encoder.decode_latents(transformed[sample_indices]))
#             img.plot_samples(sample_imgs, title=f'Transformed Val Set Images with Params: Steps: {steps}, Epsilon: {eps}, N_potentials: {potential}. FID: {fid}')

In [9]:
# print(f'Best model: {best["name"]}, FID: {best["fid"]}')
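
The bookkeeping for the best run can be sketched in isolation. Because a lower FID indicates a closer match, the update keeps the candidate with the smallest score and treats a missing score as "no best run yet". The helper name update_best and the run records are hypothetical, and the FID values are made-up placeholders, not results from this notebook.

```python
from typing import Optional


def update_best(best: Optional[dict], candidate: dict) -> dict:
    """Return whichever run has the lower FID (lower is better)."""
    if best is None or candidate["fid"] < best["fid"]:
        return candidate
    return best


# Placeholder run records; the FID values are invented for illustration.
runs = [
    {"name": "run_a", "fid": 12.5},
    {"name": "run_b", "fid": 9.75},
    {"name": "run_c", "fid": 11.0},
]

best = None
for run in runs:
    best = update_best(best, run)

print(best["name"], best["fid"])
```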

## Transform Val Images with Best Model

In [10]:
best_cfg = LightSB_BDD.standard_config()
best_cfg.MODEL.EPSILON = 0.1  # best["epsilon"]
best_cfg.MODEL.N_POTENTIALS = 20  # best["n_potentials"]
best_cfg.MAX_STEPS = 10000  # best["max_steps"]

best_sb = LightSB_BDD(config=best_cfg, np_data=np_data)

# best_sb.load_state_dict(best["state_dict"])

In [11]:
best_sb.train()

100%|██████████| 10000/10000 [02:40<00:00, 62.30it/s]


In [None]:
best_transformed = best_sb.transform(best_sb.X_test)

In [15]:
np.save(os.path.join(dirs.get_data_dir(), "LightSB_transformed.npy"), best_transformed.cpu())

In [None]:
best_transformed = torch.as_tensor(np.load(os.path.join(dirs.get_data_dir(), "LightSB_transformed.npy")), device=get_device())
decoded = encoder.decode_latents(best_transformed, batch_size=16)

latents shape:  torch.Size([3929, 4096])
reshaped latents shape:  torch.Size([3929, 4, 32, 32])


100%|██████████| 246/246 [08:33<00:00,  2.09s/it]

decoded latents shape:  torch.Size([3929, 3, 256, 256])
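
The shapes and the progress bar in the output above can be checked with a little arithmetic. The sketch below uses only the numbers printed by the cell and assumes the decoder processes the latents in consecutive batches of batch_size.

```python
import math

n_latents, flat_dim = 3929, 4096      # flattened latents, as printed above
channels, height, width = 4, 32, 32   # reshaped latent layout
image_size = 256                      # decoded image height and width
batch_size = 16

# Each flat vector must hold exactly one (4, 32, 32) latent.
assert channels * height * width == flat_dim

# Spatial upsampling factor between latent and decoded image.
upsampling = image_size // height

# Number of decoder batches and the size of the final, partial batch.
n_batches = math.ceil(n_latents / batch_size)
last_batch = n_latents - (n_batches - 1) * batch_size

print(upsampling, n_batches, last_batch)
```

The batch count matches the 246 iterations shown by the progress bar.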





In [8]:
np.save(os.path.join(dirs.get_data_dir(), "LightSB_decoded.npy"), decoded.cpu())

## Save Images to Folder

In [5]:
decoded = torch.as_tensor(np.load(os.path.join(dirs.get_data_dir(), "LightSB_decoded.npy")), device=get_device())

In [6]:
encoder.save_imgs(decoded, filenames=val_night_filenames, folder_name="LightSB_Images")

3929it [00:01, 1982.43it/s]


## Get ACTUAL Val Night FID Metric

In [None]:
## Day / Day FID
## Night / Day FID
## Day / Fake Day FID