Importing Libraries

In [None]:
from pathlib import Path
from typing import Any, Dict

import os
import numpy as np
from PIL import Image
from pytorch_lightning import Trainer
from torchvision.transforms import ToPILImage

from heat_anomaly.config import get_configurable_parameters
from heat_anomaly.data import get_datamodule
from heat_anomaly.models import get_model
from heat_anomaly.pre_processing.transforms import Denormalize
from heat_anomaly.utils.callbacks import LoadModelCallback, get_callbacks

Model configuration and parameters

In [None]:
# Name of the model to train; it also selects which config file is loaded below,
# i.e. ./heat_anomaly/models/<MODEL>/ir_image.yaml.
MODEL = "cflow"
CONFIG_PATH = "./heat_anomaly/models/{}/ir_image.yaml".format(MODEL)

# Parsed configuration shared by the data module, model, callbacks and trainer.
config = get_configurable_parameters(config_path=CONFIG_PATH)

Loading the data

In [None]:
# Build the data module described by the config.
datamodule = get_datamodule(config)
# Follow Lightning's DataModule lifecycle: prepare_data() first (one-time data
# preparation), then setup() (builds the datasets/splits). The original called
# setup() before prepare_data(), so setup() could run before the data it may
# depend on had been prepared.
datamodule.prepare_data()
datamodule.setup()

Training the model

In [None]:
# Instantiate the model and its training callbacks from the shared config
# (left-to-right evaluation keeps get_model before get_callbacks).
model, callbacks = get_model(config), get_callbacks(config)

In [None]:
# Train: Trainer options are taken from the `trainer` section of the config,
# and the callbacks built above are attached to it.
trainer = Trainer(callbacks=callbacks, **config.trainer)
trainer.fit(datamodule=datamodule, model=model)

Testing the best checkpoint

In [None]:
# Load the best checkpoint recorded during training, then evaluate on the test split.
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
if not best_model_path:
    # Fail loudly instead of passing an empty/missing path to LoadModelCallback.
    raise RuntimeError(
        "No best checkpoint found: make sure a ModelCheckpoint callback is configured "
        "and that training saved at least one checkpoint."
    )

# Remove any LoadModelCallback left in the list (e.g. from a previous run of this
# cell) so re-running the cell does not stack duplicate loaders.
trainer.callbacks[:] = [cb for cb in trainer.callbacks if not isinstance(cb, LoadModelCallback)]
load_model_callback = LoadModelCallback(weights_path=best_model_path)
trainer.callbacks.insert(0, load_model_callback)
trainer.test(model=model, datamodule=datamodule)