In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# import torch, traceback
# orig = getattr(torch._C, "_cuda_init", None)
# def _wrapped_cuda_init():
#     print("=== torch._C._cuda_init called ===")
#     traceback.print_stack(limit=8)
#     if orig:
#         orig()
# torch._C._cuda_init = _wrapped_cuda_init


from spadio import SPADFolder, SPADData  # noqa
from spadclean import GenerateTestData, SPADHotpixelTool  # noqa
from pathlib import Path
from utils import clean_hotpixels
from inference import cpu_inference
from metadata import TrainData, ModelConfig, TrainConfig, load_config
from dataset import (
    BernoulliDataset3D,
    ValidationDataset3D,
    PairedDataset,
    BinomDataset3D,
    N2NDataset3D,
)  # noqa
from spadgapmodels import SPADGAP
import torch
import torch.utils.data as dt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    EarlyStopping,
    DeviceStatsMonitor,
)
from lightning.pytorch.loggers import TensorBoardLogger
from tifffile import imwrite, imread
import numpy as np
import logging
import sys
import shutil
from tqdm.auto import tqdm
import dask.array as da

In [None]:
def prob_from_dcr(dcr_rate_hz, fps):
    """
    Per-frame probability of registering a dark count, given a dark count rate.

    Computes ``1 - exp(-dcr_rate_hz / fps)``: the probability of at least one
    event in a frame of duration ``1 / fps`` seconds when events arrive as a
    Poisson process with rate ``dcr_rate_hz``.

    Args:
        dcr_rate_hz: dark count rate in Hz (scalar or numpy array).
        fps: frame rate in frames per second.

    Returns:
        The per-frame dark count probability (same shape as the inputs broadcast).
    """
    expected_counts_per_frame = dcr_rate_hz / fps
    no_count_prob = np.exp(-expected_counts_per_frame)
    return 1 - no_count_prob

def thin_frames_uniform(frames, keep_prob, dcr_prob=None, seed=None, chunk_frames=400):
    """
    Randomly thin binary SPAD frames and optionally inject dark count photons.

    Each photon (non-zero pixel) is kept independently with probability
    ``keep_prob``. If ``dcr_prob`` is given, an independent Bernoulli(``dcr_prob``)
    dark count photon is OR-ed into every pixel of every frame, lowering the SNR.

    The computation is built as a lazy dask graph over chunks of ``chunk_frames``
    frames and evaluated chunk-by-chunk by dask's scheduler (threaded by default
    for dask arrays), so the full-size random arrays are never materialised at once.

    Args:
        frames: array of shape (T, H, W). Assumed binary (0/1): values are cast
            to uint8 and AND-ed with a 0/1 mask, so only the lowest bit of a
            non-binary value survives.
        keep_prob: probability in [0, 1] of keeping each photon.
        dcr_prob: per-pixel, per-frame dark count photon probability (not the
            rate itself; see ``prob_from_dcr``). ``None`` disables dark counts.
        seed: seed for the dask RandomState, for reproducible thinning.
        chunk_frames: number of frames per dask chunk along the time axis.

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``frames``.

    Raises:
        ValueError: if ``keep_prob`` is outside [0, 1].
    """
    if not 0.0 <= keep_prob <= 1.0:
        raise ValueError(f"keep_prob must be in [0, 1], got {keep_prob}")

    _, height, width = frames.shape  # also enforces a 3D (T, H, W) input
    lazy_frames = da.from_array(frames, chunks=(chunk_frames, height, width))
    rs = da.random.RandomState(seed)

    # Draw the keep-mask first, then dark counts, so seeded results match the
    # original draw order.
    mask = rs.random_sample(lazy_frames.shape, chunks=lazy_frames.chunks) < keep_prob
    thinned = lazy_frames.astype("uint8") & mask.astype("uint8")

    if dcr_prob is not None:
        dcr_photons = rs.binomial(
            1, dcr_prob, size=lazy_frames.shape, chunks=lazy_frames.chunks
        ).astype("uint8")
        thinned = thinned | dcr_photons
    # Without dark counts there is nothing to OR in; skip building a zero array.
    return thinned.compute()

In [None]:
configure_path = Path("./config.yml")
config = load_config(path=configure_path)  # CLI argument

# logging.basicConfig(
#     # filename=config["PATH"]["logger"],
#     level=logging.DEBUG,
#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
#     stream=sys.stdout,
# )

# Tunables for how much of the recording is used and how it is split.
VALID_DATA_TYPES = ("raw", "processed")
MAX_FRAMES = 40_000  # only the first MAX_FRAMES frames are used for training + validation
TRAIN_FRACTION = 0.8  # leading fraction used for training; the remainder is validation

data_type = config["PATH"]["data_type"]
if data_type not in VALID_DATA_TYPES:
    raise ValueError(f"Data type must be one of {VALID_DATA_TYPES}, got {data_type!r}")

dir_path = Path(config["PATH"]["dir_path"])
num_of_files = config["PATH"]["num_of_files"]
data_dir = Path(config["PATH"]["data_dir"])
data_path = config["PATH"]["data_path"]
data_file = config["PATH"]["data_file"]
ground_truth_path = config["PATH"]["ground_truth_path"]
ground_truth_file = config["PATH"]["ground_truth_file"]
model_path = Path(config["PATH"]["model_path"])

# An empty explicit path means "data_dir / <file name>".
data_path = data_dir / data_file if data_path == "" else Path(data_path)
ground_truth_path = (
    data_dir / ground_truth_file if ground_truth_path == "" else Path(ground_truth_path)
)

if data_type == "raw":
    # Raw SPAD recordings: load, then remove hot pixels.
    try:
        if dir_path.is_dir():
            input_folder = SPADFolder(dir_path)
            raw_frames = input_folder.spadstack[:num_of_files]
            data = raw_frames.process(clean_hotpixels)
        else:
            input_folder = SPADData(dir_path)
            raw_frames = input_folder.data
            data = clean_hotpixels(raw_frames)
    except FileNotFoundError:
        # Stop here: continuing would only fail later with a NameError on `data`.
        logging.error("Folder not found: %s", dir_path)
        raise
    del raw_frames  # drop the reference to the uncleaned stack
else:  # data_type == "processed"
    try:
        data = imread(data_path)
        # Rebinds the config file name to the loaded ground-truth array;
        # later cells read the array under this name.
        ground_truth_file = imread(ground_truth_path)
    except FileNotFoundError:
        logging.error("File not found: %s or %s", data_path, ground_truth_path)
        raise

data = data[:MAX_FRAMES]
keep_prob = config["PATH"]["thin"]
if keep_prob < 1.0:
    # Simulate a lower photon budget by randomly dropping photons (fixed seed).
    data = thin_frames_uniform(data, keep_prob=keep_prob, seed=42)

# Chronological split: first TRAIN_FRACTION for training, the rest for validation.
idx_train = int(data.shape[0] * TRAIN_FRACTION)
data_config = TrainData.from_config(config["DATA"], data[:idx_train].astype(np.float32))
model_config = ModelConfig.from_config(config["MODEL"])
train_config = TrainConfig.from_config(config["TRAINING"], data_config, model_config)
val_data_config = TrainData.from_config_validation(
    config["DATA"], data[idx_train:].astype(np.float32)
)
print(train_config.metadata())

In [None]:
train_data = BernoulliDataset3D.from_dataclass(data_config)
val_data = ValidationDataset3D.from_dataclass(val_data_config)

GPU_DEVICES = [0, 1, 2, 3, 4, 5, 6, 7]  # GPUs used for DDP training

loader_config = {
    "batch_size": train_config.batch_size,
    "shuffle": train_config.shuffle,
    "pin_memory": train_config.pin_memory,
    "drop_last": train_config.drop_last,
    "num_workers": train_config.num_workers,
    # DataLoader raises ValueError for persistent_workers=True with num_workers=0,
    # so only keep workers alive when there are any.
    "persistent_workers": train_config.num_workers > 0,
}

train_loader = dt.DataLoader(train_data, **loader_config)
# Validation is never shuffled, so the limited validation batches are always the same.
val_loader = dt.DataLoader(val_data, **{**loader_config, "shuffle": False})

test_name = train_config.name
default_root_dir = model_path / test_name
default_root_dir.mkdir(parents=True, exist_ok=True)

model = SPADGAP.from_dataclass(model_config)
model.train()

# Use the logger from the same package as the Trainer (pytorch_lightning) instead
# of mixing in `lightning.pytorch`, which is a separate class hierarchy.
logger = pl.loggers.TensorBoardLogger(save_dir=model_path, name=test_name)

# NOTE(review): Lightning rejects plain DDP strategies inside interactive (Jupyter)
# sessions; this presumably runs as an exported script. Use "ddp_notebook" when
# training from inside the notebook — confirm.
trainer = pl.Trainer(
    default_root_dir=default_root_dir,
    accelerator="gpu",
    gradient_clip_val=1,
    precision=train_config.precision,  # type: ignore
    devices=GPU_DEVICES,
    strategy="ddp_find_unused_parameters_true",
    max_epochs=train_config.epochs,
    callbacks=[
        ModelCheckpoint(
            save_weights_only=True,
            mode="min",
            monitor="val_loss",
            save_top_k=2,
        ),
        LearningRateMonitor("epoch"),
        # EarlyStopping("val_loss", patience=25),
        # DeviceStatsMonitor(),
    ],
    logger=logger,
    profiler="simple",
    limit_val_batches=20,
    enable_model_summary=True,
    enable_checkpointing=True,
)
# print(f"input_size: {tuple(next(iter(train_loader))[0].shape)}")
print(f"file: {test_name}")

In [None]:
# import torchinfo
# torchinfo.summary(model, input_size=(train_config.batch_size, 1, 64, 64, 64))

In [None]:
metadata_file = default_root_dir / "metadata.yml"
config_copy = default_root_dir / "config.yml"
final_checkpoint = default_root_dir / "final_model.ckpt"

# Record the run's configuration next to the checkpoints before training starts,
# so an interrupted run is still traceable.
model.train()
train_config.to_yaml(metadata_file)
shutil.copyfile(configure_path, config_copy)

trainer.fit(model, train_loader, val_loader)
trainer.save_checkpoint(final_checkpoint)

In [None]:
# trainer.save_checkpoint(default_root_dir / "final_model.ckpt")

In [None]:
# from utils import clear_vram  
# clear_vram()

In [None]:
# data_split = np.random.binomial(data, 0.9)


In [None]:
# splitted_inference = []
# for i in range(0, 10):
#     data_split = np.random.binomial(data, 0.9)
#     output = gpu_patch_inference(
#         model,
#         data_split[:512].astype(np.float32),
#         initial_patch_depth=48,
#         min_overlap=40,
#         device=train_config.device_number,
#     )
#     splitted_inference.append(output)

In [None]:
from matplotlib import cm
import cv2
from tqdm.auto import tqdm
def get_codec_for_format(format: str):
    """
    Return the OpenCV fourcc codec string for a video container format.

    Args:
        format: container extension without the dot, case-insensitive
            (e.g. "mp4", "AVI").

    Returns:
        A 4-character fourcc string suitable for ``cv2.VideoWriter_fourcc``.

    Raises:
        ValueError: if no codec is registered for ``format``.
    """
    codecs = {
        "mp4": "mp4v",
        "avi": "FFV1",  # lossless
        "mov": "avc1",
        # to_video accepts ".mkv" as a video suffix, so it needs a codec too;
        # FFV1 is a standard lossless codec for Matroska.
        "mkv": "FFV1",
    }
    key = format.lower()
    if key not in codecs:
        raise ValueError(f"I haven't added the codec for: {key}")
    return codecs[key]
def to_video(frames: np.ndarray, path, res_scale=1.0, playback_fps=None, cmap=None, format=None, maxv=None):
    """
    Save a stack of frames as a video file or as a directory of image files.

    If ``path`` has a video suffix (.mp4, .avi, .mov, .mkv; case-insensitive), a
    single video is written with OpenCV. Otherwise ``path`` is treated as a
    directory and one image per frame is written as ``frame_XXXXX.<format>``.

    Args:
        frames (np.ndarray): (T x H x W x 3) RGB or (T x H x W) intensity frames.
        path (str or Path): output video file path or directory for image files.
        res_scale (float): resolution scaling factor (nearest-neighbour interpolation).
        playback_fps (float or None): playback frame rate; required for video output.
        cmap: matplotlib colormap name or object applied to intensity frames
            (default "viridis"); ignored for RGB frames.
        format (str or None): video format (e.g. "mp4", "avi") or image format
            (e.g. "png"). If None, it is inferred from the path suffix for videos
            and defaults to "png" for image directories.
        maxv (float or None): value mapped to full brightness; frames are clipped
            to [0, maxv]. Defaults to the maximum of ``frames``, or 1.0 if that
            maximum is not positive.

    Raises:
        ValueError: on an unsupported frame shape, a missing ``playback_fps`` for
            video output, an unsupported video format, or a non-positive ``maxv``.
        OSError: if OpenCV cannot open the video writer.
    """
    # pyplot.get_cmap accepts a name or a Colormap and still exists in recent
    # matplotlib releases, unlike matplotlib.cm.get_cmap, which was removed in 3.9.
    import matplotlib.pyplot as plt

    path = Path(path)
    cmap_fn = plt.get_cmap("viridis" if cmap is None else cmap)

    if frames.ndim == 4:
        if frames.shape[3] != 3:
            raise ValueError("4D frames array must have shape (T, H, W, 3) for RGB video")
        is_rgb = True
    elif frames.ndim == 3:
        is_rgb = False
    else:
        raise ValueError("frames must be a 3D or 4D numpy array")

    # Normalisation constant so that intensities map into [0, 1].
    if maxv is None:
        maxv = float(np.max(frames))
        if maxv <= 0.0:
            maxv = 1.0
    elif maxv <= 0:
        raise ValueError(f"maxv must be positive, got {maxv}")

    height, width = frames.shape[1], frames.shape[2]
    if res_scale != 1.0:
        out_w = max(1, int(width * res_scale))
        out_h = max(1, int(height * res_scale))
    else:
        out_w, out_h = width, height

    suffix = path.suffix.lower()
    is_video_file = suffix in (".mp4", ".avi", ".mov", ".mkv")
    vidwriter = None
    if not is_video_file:
        path.mkdir(parents=True, exist_ok=True)
        if format is None:
            format = "png"
    else:
        if playback_fps is None:
            raise ValueError("playback_fps must be specified if saving a video file")
        if format is None:
            format = suffix[1:]
        fourcc = cv2.VideoWriter_fourcc(*get_codec_for_format(format))
        vidwriter = cv2.VideoWriter(str(path), fourcc, playback_fps, (out_w, out_h), isColor=True)
        # OpenCV does not raise when the codec/container is unavailable; it just
        # writes nothing. Fail loudly instead.
        if not vidwriter.isOpened():
            raise OSError(f"Could not open video writer for {path} (format={format!r})")

    try:
        for i in tqdm(range(len(frames)), desc="Writing video frames"):
            intensity = np.clip(frames[i], 0, maxv) / maxv  # normalize to [0, 1]
            if is_rgb:
                rgb = (intensity * 255.0).astype(np.uint8)
            else:
                # matplotlib colormaps return RGBA in [0, 1]; drop alpha.
                rgb = (cmap_fn(intensity)[..., :3] * 255.0).astype(np.uint8)
            # OpenCV expects BGR; cvtColor also yields a contiguous array.
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            if res_scale != 1.0:
                bgr = cv2.resize(bgr, (out_w, out_h), interpolation=cv2.INTER_NEAREST)
            if vidwriter is not None:
                vidwriter.write(bgr)
            else:
                cv2.imwrite(str(path / f"frame_{i:05d}.{format}"), bgr)
    finally:
        # Always finalize the container, even if a frame fails mid-way.
        if vidwriter is not None:
            vidwriter.release()

In [None]:
from inference import gpu_patch_inference

dataset = "guitar-0.03125"
checkpoint_path = f"models/{dataset}/final_model.ckpt"
model = SPADGAP.load_from_checkpoint(checkpoint_path)

indata = data[:10000].astype(np.float32)
output = gpu_patch_inference(
    model,
    indata,
    initial_patch_depth=48,
    min_overlap=40,
    device=train_config.device_number,
)

resdir = Path(f"results/{dataset}")
resdir.mkdir(parents=True, exist_ok=True)
imwrite(resdir / "input.tif", indata)
imwrite(resdir / "inference.tif", output)

# Gamma-encoded copies brighten dim regions for viewing. The "fps1000" variants
# keep every 100th frame (presumably 1000 fps from a 100 kHz recording — confirm).
gamma = output ** (1/2.2)
gamma14 = output ** (1/4.0)
video_jobs = [
    ("inference.avi", output),
    ("inference-gamma.avi", gamma),
    ("inference-gamma-1-4.avi", gamma14),
    ("inference-gamma-fps1000.avi", gamma[::100]),
    ("inference-gamma-1-4-fps1000.avi", gamma14[::100]),
]
for video_name, video_frames in video_jobs:
    to_video(video_frames, resdir / video_name, playback_fps=30, cmap="grey")

In [None]:
# Reload the saved inference; the inference cell writes it to results/<dataset>/.
output = imread(Path("results") / dataset / "inference.tif")

In [None]:
import matplotlib.pyplot as plt

FRAME_RATE_HZ = 100_000  # frames per second used to convert frame index to seconds
PROBE_PIXEL = (50, 450)  # (row, col) of the pixel whose time trace is plotted

time_s = np.arange(output.shape[0]) / FRAME_RATE_HZ
fig, ax = plt.subplots()
ax.plot(time_s, output[:, PROBE_PIXEL[0], PROBE_PIXEL[1]])
ax.set(
    xlabel="Time (s)",
    ylabel="Inference output",
    title=f"Time trace at pixel {PROBE_PIXEL}",
)
# ax.set_xlim(0, 0.01)
plt.show()

In [None]:
from utils import group_metrics

N_METRIC_FRAMES = 512  # number of leading frames compared against ground truth

# ground_truth_file only holds an array when data_type == "processed"; otherwise
# it is still the file name from the config.
if not isinstance(ground_truth_file, np.ndarray):
    raise RuntimeError(
        "Ground truth frames are only loaded when data_type == 'processed'; "
        f"got {ground_truth_file!r} instead of an array."
    )

# Slice all three to the same frames; `output` covers many more frames than 512.
input_frames = data[:N_METRIC_FRAMES].astype(float)
image = output[:N_METRIC_FRAMES]
ground_truth = ground_truth_file[:N_METRIC_FRAMES].astype(float)
group_metrics(input_frames, image, ground_truth, default_root_dir, device=train_config.device)