In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
sys.path.append(str(Path.cwd() / 'src'))

from src.configs import TestConfig
import numpy as np
from tqdm import tqdm
import torch
from dataset import MonoDS
from datetime import datetime
from PIL import Image

config = TestConfig()

# Model Setup
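Configure and initialize the model under test. The following options can be changed in the cell below by uncommenting and editing the corresponding lines; otherwise the values from `TestConfig` are used:
1. `config.model`: the model backbone, `"CUT"` or `"CycleGan"` (default: `"CUT"`).
2. `config.thermal.is_fpa_input`: whether to use the camera's intrinsic (FPA) temperature (default: `True`).
3. `config.thermal.is_physical_model`: whether to use the physical model (default: `True`).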

In [None]:
from cycle_gan_model import CycleGANModel
from cut_model import CUTModel
from utils.deep import NetPhase

# Optional overrides of the test configuration; uncomment and edit as needed.
# config.model = "CUT" # "CycleGan"
# config.thermal.is_fpa_input = True
# config.thermal.is_physical_model = True

# "CUT" selects the CUT backbone; any other value falls back to CycleGAN.
if config.model == "CUT":
    backbone = CUTModel
else:
    backbone = CycleGANModel

model = backbone(config)

## Load model checkpoint
`model.load_networks("best")` loads the following types of weights from the checkpoint tagged `best`:
- D: the weights of the discriminator.
- G: the weights of the deep generator.
- F: the weights of the contrastive classifier network (for CUT backbone only).
- PW: the weights of the affine layer applied to the physical estimator.
- coefficients: the calibrated coefficients of the physical estimator.

In [None]:
# Prepare the model for inference: select the test phase, run setup with the
# test config, then load the weights saved under the "best" checkpoint tag.
# NOTE(review): call order kept as-is; confirm whether setup() depends on the phase.
model.set_phase(NetPhase.test)
model.setup(config)
model.load_networks("best")

# Create the test dataloader

In [None]:
batch_size = 100  # number of images processed per forward pass

dataset = MonoDS(src_dir=Path("data", "pan", "test"))
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    # Keep a fixed order: output files are named by batch_index * batch_size + offset,
    # which maps back to dataset indices only when the order is not shuffled.
    shuffle=False,
    num_workers=int(config.data_loader.num_threads),
)
# len(dataloader) * batch_size over-counts when the dataset size is not a multiple of
# batch_size (the last batch is partial), so report the dataset length directly.
print(f"The number of images for inference = {len(dataset)}")

# Inference
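The loop below transforms the panchromatic test set into synthetic monochromatic images and saves them in the format set by the `fmt_of_output` variable: either `.npy` (original resolution) or `.png` (uint8). For each batch, the three visuals `pan_real`, `mono_fake` and `mono_phys` are written to separate sub-directories of a timestamped results folder.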

In [None]:
fmt_of_output = "png"  # "npy" -> raw arrays via np.save; otherwise images are written with PIL

# Visuals to export; the prefix before "_" ("pan" / "mono") is the domain passed to rec_image.
visual_names = ("pan_real", "mono_fake", "mono_phys")

path_to_save = Path.cwd() / "results" / "transformed" / datetime.now().strftime("%Y%m%d_h%Hm%Ms%S")
path_to_save.mkdir(parents=True, exist_ok=True)

# One sub-directory per visual, created once up front instead of on every batch.
sub_dirs = {name: path_to_save / name for name in visual_names}
for sub_dir in sub_dirs.values():
    sub_dir.mkdir(parents=True, exist_ok=True)

with torch.inference_mode():
    for i, data in enumerate(tqdm(dataloader, desc="Test")):
        model.set_input(data)
        model.forward()
        visuals = model.get_current_visuals()

        # Global index of this batch's first sample; valid because the dataloader does not
        # shuffle and only the final batch can be smaller than batch_size.
        first_idx = i * batch_size

        for name in visual_names:
            domain = name.split("_")[0]
            images = model.rec_image(visuals[name], domain=domain, fmt=fmt_of_output)
            for j, image in enumerate(images):
                out_path = sub_dirs[name] / f"{first_idx + j}.{fmt_of_output}"
                if fmt_of_output == "npy":
                    np.save(out_path, image)
                else:
                    Image.fromarray(image).save(out_path)