In [None]:
from spadio import SPADFolder, SPADData  # noqa
from spadclean import GenerateTestData, SPADHotpixelTool  # noqa
from pathlib import Path
from utils import clean_hotpixels
from inference import cpu_inference
from metadata import TrainData, ModelConfig, TrainConfig, load_config
from dataset import (
    BernoulliDataset3D,
    ValidationDataset3D,
    BinomDataset3D,
    N2NDataset3D,
)  # noqa
from spadgapmodels import SPADGAP
import torch
import torch.utils.data as dt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    EarlyStopping,
    DeviceStatsMonitor,
)
from lightning.pytorch.loggers import TensorBoardLogger
from tifffile import imwrite, imread
import numpy as np
import logging
import sys
import shutil


# Directory of the training run to resume; the run's own config.yml lives
# inside it and drives every setting read below.
default_root_dir = Path(
    "../models/20240413_073549_00999999_1000x32x256x256_skip=0_l=10_d=5_sf=32_ds=2at10_f=10.0_z=2_g=8_sd=0_b=tri_a=gelu_b=4_e=500_p=32"
)
configure_path = Path(default_root_dir, "config.yml")
config = load_config(path=configure_path)  # CLI argument

# The loading code only has branches for these two values, so reject anything
# else up front and report the values that are actually accepted.
VALID_DATA_TYPES = ("raw", "processed")
data_type = config["PATH"]["data_type"]
if data_type not in VALID_DATA_TYPES:
    raise ValueError(
        f"Data type must be one of {VALID_DATA_TYPES}, got {data_type!r}"
    )

# Pull every location out of the PATH section of the config in one place.
path_cfg = config["PATH"]
dir_path = Path(path_cfg["dir_path"])
num_of_files = path_cfg["num_of_files"]
data_dir = Path(path_cfg["data_dir"])
data_path = path_cfg["data_path"]
data_file = path_cfg["data_file"]
ground_truth_path = path_cfg["ground_truth_path"]
ground_truth_file = path_cfg["ground_truth_file"]
model_path = Path(path_cfg["model_path"])

# An empty explicit path means "use <data_dir>/<file name>" instead.
if data_path == "":
    data_path = data_dir / data_file
else:
    data_path = Path(data_path)

if ground_truth_path == "":
    ground_truth_path = data_dir / ground_truth_file
else:
    ground_truth_path = Path(ground_truth_path)
# Load the photon stack, either from a folder of raw SPAD files or from an
# already processed TIFF.
#
# A missing folder/file is logged and then re-raised: continuing would only
# fail a few lines later with an unrelated NameError, because `data` (or
# `input_folder`) would never have been bound.
if data_type == "raw":
    try:
        input_folder = SPADFolder(dir_path)
    except FileNotFoundError:
        logging.error("Folder not found: %s", dir_path)
        raise
    # Local name chosen so the builtin `input` is not shadowed.
    raw_stack = input_folder.spadstack[:num_of_files]
    data = raw_stack.process(clean_hotpixels)
    del raw_stack
    # NOTE(review): no ground truth is loaded on this branch, so
    # `ground_truth_file` is still the file name from the config here, while
    # later cells index it like an array -- confirm raw runs are not meant to
    # be evaluated against ground truth.
elif data_type == "processed":
    try:
        data = imread(data_path)
        # From here on `ground_truth_file` holds the loaded array rather than
        # the file name; later cells slice it directly.
        ground_truth_file = imread(ground_truth_path)
    except FileNotFoundError as err:
        logging.error("File not found: %s", err)
        raise


# Build the typed configuration objects from the raw config sections.
# The stack is converted to float32 separately for the training and the
# validation configuration, so each one receives its own array.
data_section = config["DATA"]

data_config = TrainData.from_config(data_section, data.astype(np.float32))
model_config = ModelConfig.from_config(config["MODEL"])
train_config = TrainConfig.from_config(
    config["TRAINING"],
    data_config,
    model_config,
)
val_data_config = TrainData.from_config_validation(
    data_section,
    data.astype(np.float32),
)

print(train_config.metadata())

In [None]:
# Datasets for training and validation, built from the dataclass configs.
train_data = BernoulliDataset3D.from_dataclass(data_config)
val_data = ValidationDataset3D.from_dataclass(val_data_config)

# Shared DataLoader settings taken from the training configuration.
loader_config = {
    "batch_size": train_config.batch_size,
    "shuffle": train_config.shuffle,
    "pin_memory": train_config.pin_memory,
    "drop_last": train_config.drop_last,
    "num_workers": train_config.num_workers,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only keep workers alive when there are any.
    "persistent_workers": train_config.num_workers > 0,
}

train_loader = dt.DataLoader(train_data, **loader_config)
# Validation reuses the same settings but is never shuffled.
loader_config["shuffle"] = False
val_loader = dt.DataLoader(val_data, **loader_config)

test_name = train_config.name

# Prefer the explicitly saved final model; otherwise fall back to the
# checkpoints written during training.
ckpt_path = default_root_dir / "final_model.ckpt"
if not ckpt_path.exists():
    ckpt_dir = default_root_dir / "version_0" / "checkpoints"
    # NOTE(review): ordering is lexicographic on the path, so the last entry
    # is not necessarily the most recent epoch (e.g. "epoch=9" sorts after
    # "epoch=10") -- confirm this is the checkpoint you want.
    check_points = sorted(ckpt_dir.glob("*.ckpt"))
    if not check_points:
        # Fail with a message that names both places that were searched
        # instead of a bare IndexError from check_points[-1].
        raise FileNotFoundError(
            f"No final_model.ckpt in {default_root_dir} "
            f"and no *.ckpt files in {ckpt_dir}"
        )
    ckpt_path = check_points[-1]
    print(f"Using checkpoint: {ckpt_path}")

# Restore the network and put it in training mode for the fit that follows.
model = SPADGAP.load_from_checkpoint(ckpt_path)
model.train()

# TensorBoard logs go under <model_path>/<run name>.
# NOTE(review): TensorBoardLogger is imported from `lightning.pytorch` while
# the Trainer comes from `pytorch_lightning`; presumably that is why the
# `# type: ignore` markers are needed -- confirm the two interoperate in the
# installed versions.
logger = TensorBoardLogger(save_dir=model_path, name=test_name)

# Keep the two checkpoints with the lowest validation loss (weights only) and
# record the learning rate once per epoch. EarlyStopping and
# DeviceStatsMonitor are intentionally not enabled.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    save_weights_only=True,
)
lr_monitor = LearningRateMonitor("epoch")

trainer = pl.Trainer(
    default_root_dir=default_root_dir,
    accelerator="cuda",
    devices=[train_config.device_number],
    precision=train_config.precision,  # type: ignore
    max_epochs=train_config.epochs,
    gradient_clip_val=1,
    callbacks=[checkpoint_callback, lr_monitor],
    logger=logger,  # type: ignore
    profiler="simple",
    limit_val_batches=20,
    enable_model_summary=True,
    enable_checkpointing=True,
)

# Sanity check: shape of the second element of the first training batch.
first_batch_shape = tuple(next(iter(train_loader))[1].shape)
print(f"input_size: {first_batch_shape}")
print(f"file: {test_name}")

In [None]:
# Continue training, then store the end state as the run's final checkpoint
# (the file the checkpoint-resolution step looks for first).
trainer.fit(model, train_loader, val_loader)
final_ckpt_path = default_root_dir / "final_model.ckpt"
trainer.save_checkpoint(final_ckpt_path)

In [None]:
# Manual re-save of the current trainer state (e.g. after interrupting the
# fit cell before its own save ran).
manual_ckpt_path = default_root_dir / "final_model.ckpt"
trainer.save_checkpoint(manual_ckpt_path)

In [None]:
# Call the project helper `clear_vram` before running inference.
# NOTE(review): presumably releases cached GPU memory left over from training
# (inferred from the name; the implementation lives in utils) -- confirm.
from utils import clear_vram
clear_vram()

In [None]:
# Keep an untouched copy of the loaded stack: the binomial sampling below
# draws from `data_backup`, while `data` itself is rebound by a later cell.
data_backup = data.copy()

In [None]:
# Repeated-thinning experiment: draw `n` independent binomial realisations of
# the same frame window, denoise each one, and score it against ground truth.
#
# Both helpers are imported here because this cell runs before the later
# cells that import them; without this a fresh Restart & Run All raises
# NameError.
from inference import gpu_patch_inference
from utils import group_metrics

n = 10  # number of independent realisations
keep_probability = 0.96  # per-count success probability of the binomial draw
frame_window = slice(3450, 3450 + 512)  # frames evaluated in this experiment

datalist = []  # thinned inputs, one per realisation (already cropped)
output_list = []  # network outputs, one per realisation
output_value = []  # value returned by group_metrics for each realisation

# Loop-invariant work: crop once instead of sampling the whole stack and
# discarding everything outside the window on every iteration.
counts_window = data_backup[frame_window]
ground_truth = ground_truth_file[frame_window].astype(float)

# No random seed is set, so the realisations differ from run to run.
# NOTE(review): the model was last put in train mode -- confirm that
# gpu_patch_inference switches it to eval mode itself.
for i in range(n):
    default_root_dir_b = default_root_dir / f"psnr096/new_inference_{i}"
    default_root_dir_b.mkdir(exist_ok=True, parents=True)

    # `input` (shadowing the builtin) is kept as the name because the
    # following cells read the last realisation through it.
    input = np.random.binomial(counts_window, keep_probability).astype(np.float32)
    datalist.append(input)

    output = gpu_patch_inference(
        model,
        input,
        initial_patch_depth=48,
        min_overlap=40,
        device=train_config.device_number,
    )
    imwrite(default_root_dir_b / "output.tif", output)
    output_list.append(output)

    psnr = group_metrics(
        input,
        output,
        ground_truth,
        default_root_dir_b,
        device=train_config.device,
    )
    output_value.append(psnr)
    

In [None]:
# Combine the per-realisation outputs along a new leading axis and score two
# aggregates against the same ground truth: the element-wise sum and the
# element-wise median. `input` and `ground_truth` are whatever the repeated
# inference loop left behind (its last realisation).
output = np.stack(output_list)

output_sum = output.sum(axis=0)
default_root_dir_b = default_root_dir / "psnr096" / "sum"
psnr_sum = group_metrics(
    input,
    output_sum,
    ground_truth,
    default_root_dir_b,
    device=train_config.device,
)

output_median = np.median(output, axis=0)
default_root_dir_b = default_root_dir / "psnr096" / "median"
psnr_median = group_metrics(
    input,
    output_median,
    ground_truth,
    default_root_dir_b,
    device=train_config.device,
)


In [None]:
# Trimmed sum: remove the largest and the smallest value across realisations
# from the total at every position.
elementwise_max = output.max(axis=0)
elementwise_min = output.min(axis=0)
output_sum_b = (output_sum - elementwise_max) - elementwise_min

In [None]:
# Score the trimmed sum. Note that this rebinds `psnr_sum`, replacing the
# value computed for the plain sum.
default_root_dir_b = default_root_dir / "psnr096" / "sum_no_outliers"
psnr_sum = group_metrics(
    input,
    output_sum_b,
    ground_truth,
    default_root_dir_b,
    device=train_config.device,
)

In [None]:
# Save the trimmed sum into whichever directory `default_root_dir_b` currently
# points at (set by the cell that scored it).
trimmed_sum_path = default_root_dir_b / "output.tif"
imwrite(trimmed_sum_path, output_sum_b)

In [None]:
from matplotlib import pyplot as plt

# Plot the metric returned for each realisation of the repeated-thinning loop.
# That loop keeps the binomial probability fixed at 0.96, so the x-axis is the
# index of the draw, not a swept probability; the labels say so. Output file
# names are unchanged so existing references to them keep working.
fig, ax = plt.subplots()
ax.plot(output_value)
ax.set_title("PSNR across repeated binomial draws (p = 0.96)")
ax.set_xlabel("Draw index")
ax.set_ylabel("PSNR")
fig.savefig(default_root_dir / "psnr096" / "PSNR_vs_P.png")
fig.savefig(default_root_dir / "psnr096" / "PSNR_vs_P.svg", format="svg", dpi=1200)

In [None]:
# Show the run directory as the cell output (where the files above were written).
default_root_dir

In [None]:
from inference import gpu_patch_inference

# Denoise a single thinned realisation and store input and result side by side.
default_root_dir_b = default_root_dir / "rand_split"
# imwrite does not create directories, and nothing else creates this one.
default_root_dir_b.mkdir(parents=True, exist_ok=True)

# Entries of `datalist` are already cropped to the 512-frame evaluation window
# when they are appended, so slicing them with [3450:3450+512] again would
# produce an empty array. A dedicated name is used instead of rebinding `data`,
# which later cells still slice as the full stack.
split_input = datalist[1]
output = gpu_patch_inference(
    model,
    split_input.astype(np.float32),
    initial_patch_depth=48,
    min_overlap=40,
    device=train_config.device_number,
)
imwrite(default_root_dir_b / "input.tif", split_input)
imwrite(default_root_dir_b / "inference.tif", output)

In [None]:
# Denoise the first ten thinned realisations and write the sum of the results.
outputlist = []

default_root_dir_b = default_root_dir / "10rand_split"
default_root_dir_b.mkdir(parents=True, exist_ok=True)
for i in range(10):
    # Entries of `datalist` are already cropped to the 512-frame evaluation
    # window when they are appended; slicing them with [3450:3450+512] again
    # would hand an empty array to the model.
    output = gpu_patch_inference(
        model,
        datalist[i].astype(np.float32),
        initial_patch_depth=48,
        min_overlap=40,
        device=train_config.device_number,
    )
    outputlist.append(output)
output = np.stack(outputlist).sum(axis=0)

# NOTE(review): the stored "input" is the window of `data`, not of the thinned
# realisations that were actually fed to the model -- confirm this is the
# intended reference image.
imwrite(default_root_dir_b / "input.tif", data[3450:3450+512])
imwrite(default_root_dir_b / "inference.tif", output)

In [None]:
from inference import gpu_patch_inference

# Plain inference on the evaluation window of `data`; the window and the
# result are written directly into the run directory.
window = data[3450:3450+512]
output = gpu_patch_inference(
    model,
    window.astype(np.float32),
    initial_patch_depth=48,
    min_overlap=40,
    device=train_config.device_number,
)
imwrite(default_root_dir / "input.tif", window)
imwrite(default_root_dir / "inference.tif", output)

In [None]:
from utils import group_metrics

# Score the most recent inference. `output` was produced from frames
# 3450..3962 of `data` and the ground truth is cropped to the same window, so
# the input handed to the metrics must be cropped as well; passing the whole
# stack would compare arrays that cover different frames.
ground_truth = ground_truth_file[3450:3450+512].astype(float)
input = data[3450:3450+512].astype(float)
image = output
# NOTE(review): `default_root_dir_b` still points at whichever directory the
# previously executed cell set, while the inference above wrote into
# `default_root_dir` -- confirm where these metrics should be stored.
group_metrics(input, image, ground_truth, default_root_dir_b, device=train_config.device)