# In [None]:
import pytorch_lightning as L
from model.gauss_predictor import GaussPredictor
from model.gauss_predictor_with_latent_codes import GaussPredictorWithLatentCodes
from data_processing.data_module import FaceDataModule
from pytorch_lightning.loggers import WandbLogger
import yaml
from datetime import datetime
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import time

# Load the experiment settings. The context manager guarantees the file handle
# is closed even if YAML parsing fails (the bare open() call leaked it).
with open('./configs/settings_latent_codes.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# Build the data module from the loaded settings and run its 'fit' stage setup
# up front, so it is ready before being handed to the model constructor.
face_data_module = FaceDataModule(config=config)
face_data_module.setup('fit')

# Choose the predictor variant: the latent-code model is used only when the
# settings switch latent codes on; otherwise the plain predictor is used.
latent_codes_enabled = config['model']['model_parameters']['latent_codes']['enabled']
model_class = GaussPredictorWithLatentCodes if latent_codes_enabled else GaussPredictor

# Both construction paths receive the same keyword arguments.
model_kwargs = {'config': config, 'data_module': face_data_module}

# Restore the model from a checkpoint when one is configured; a falsy value
# (e.g. empty or null in the YAML) means a fresh model is created instead.
checkpoint_path = config['model']['checkpoint']
if not checkpoint_path:
    gauss_predictor = model_class(**model_kwargs)
else:
    gauss_predictor = model_class.load_from_checkpoint(checkpoint_path, **model_kwargs)

# Create the Weights & Biases logger that the trainer reports to. The original
# script passed `wandb_logger` to the Trainer without ever defining it, which
# raised a NameError. The run is named after its start time so runs stay
# distinguishable.
# NOTE(review): project/entity are left to the W&B defaults or environment
# variables -- set them here if this experiment needs a specific project.
wandb_logger = WandbLogger(name=datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

# Both callbacks watch the same validation metric, where lower is better.
monitored_metric = "val/loss/rendering"

trainer = L.Trainer(
    max_epochs=config['training']['max_epochs'],
    accelerator="gpu",
    devices=1,
    logger=wandb_logger,
    default_root_dir='./checkpoints/',
    # Ask the data module for fresh dataloaders at every epoch.
    reload_dataloaders_every_n_epochs=1,
    # Log on every training step.
    log_every_n_steps=1,
    callbacks=[
        # Stop once the monitored metric has not improved for 50 checks.
        EarlyStopping(monitor=monitored_metric, mode="min", patience=50),
        # Keep the 5 best checkpoints, named after their epoch number.
        ModelCheckpoint(monitor=monitored_metric, mode="min", filename="{epoch:02d}", save_top_k=5),
    ],
)

# Run training; the data module supplies the train/validation dataloaders.
trainer.fit(gauss_predictor, datamodule=face_data_module)