In [1]:
# Reload edited project modules (e.g. fuseformer_poetry) before each cell runs,
# so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import os
import json
from fuseformer_poetry.model.fuseformer import FuseFormer
from fuseformer_poetry.callbacks.eval_metrics import EvalMetrics
from fuseformer_poetry.data.data_module import CusDataModule
from lightning.pytorch.loggers import MLFlowLogger
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.tuner import Tuner

In [3]:
# config.json lives in the `fuseformer_poetry` directory under the parent of the
# current working directory.
# NOTE(review): Path() resolves against the kernel's cwd, so this assumes the
# notebook is launched from a folder that is a sibling of `fuseformer_poetry/` — confirm.
config_path = os.path.join(Path().resolve().parent, 'fuseformer_poetry', 'config.json')

# Use a context manager so the file handle is closed deterministically. Read as
# UTF-8 explicitly: the JSON spec mandates it, and the platform default may differ.
with open(config_path, encoding='utf-8') as config_file:
    config = json.load(config_file)


In [4]:
# Seed Python, NumPy and torch RNGs (and dataloader worker processes) before the
# data module and model are built. Any randomness in construction/weight init is
# then reproducible across Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

# Both objects are configured from the same parsed config.json.
dmodule = CusDataModule(config)
model = FuseFormer(config)

In [5]:
# Eagerly run the data module's 'fit'-stage setup.
# NOTE(review): Lightning's Trainer.fit (and the Tuner's trial runs) call
# datamodule.setup('fit') themselves. This explicit call is presumably only
# here to inspect the datasets interactively; setup will run again inside fit.
# Confirm it is still needed.
dmodule.setup(stage='fit')

In [None]:
# Single source of truth for the quantity that both checkpointing and early
# stopping track. Sharing it keeps the two callbacks from drifting apart.
MONITOR_METRIC = "val_gen_loss"
MONITOR_MODE = "min"

# Checkpoint file names embed the logged losses.
# NOTE(review): assumes the model logs train_/val_ gen/dis losses under exactly these names.
CKPT_FILENAME = "{epoch}-{step}-{train_gen_loss:.2f}-{train_dis_loss:.2f}-{val_gen_loss:.2f}-{val_dis_loss:.2f}"

# Keep the best checkpoint by MONITOR_METRIC, plus a rolling `last` checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename=CKPT_FILENAME,
    monitor=MONITOR_METRIC,
    mode=MONITOR_MODE,
    save_last=True,
)

# Stop training once MONITOR_METRIC stops improving (EarlyStopping's default patience).
es_callback = EarlyStopping(monitor=MONITOR_METRIC, mode=MONITOR_MODE)

# Project-specific evaluation callback.
eval_callback = EvalMetrics()

In [7]:
# MLflow tracking server. Honor the standard MLFLOW_TRACKING_URI environment
# variable so the notebook can target another server without edits. Fall back to
# the local server this notebook was originally run against.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:8080")

# NOTE(review): run_name is hardcoded; bump it per experiment to keep runs distinguishable.
mlflow_logger = MLFlowLogger(
    experiment_name="fuseformer",
    run_name='test_3',
    tracking_uri=MLFLOW_TRACKING_URI,
)

In [None]:
trainer = L.Trainer(
    # Training is pinned to CPU regardless of available accelerators.
    accelerator='cpu',
    # Upper bound on epochs; early stopping may end training sooner.
    max_epochs=100,
    log_every_n_steps=10,
    logger=mlflow_logger,
    # Order is preserved: evaluation metrics, checkpointing, then early stopping.
    callbacks=[eval_callback, checkpoint_callback, es_callback],
)

In [None]:
tuner = Tuner(trainer)

# Auto-scale the batch size with binary search.
# Lightning's default mode="power" only doubles the batch size until a trial
# fails. mode="binsearch" then binary-searches between the last successful and the
# first failing size, which is the behavior this cell intends.
# NOTE(review): with accelerator='cpu', a failing trial means exhausting host RAM,
# so this search can be slow and memory-hungry. Consider capping max_trials.
# NOTE(review): the tuner writes the result back to a `batch_size` attribute on the
# model or datamodule (presumably CusDataModule.batch_size) — confirm it exists.
optimal_batch_size = tuner.scale_batch_size(model, datamodule=dmodule, mode="binsearch")
optimal_batch_size


In [None]:
# Train using the data module's loaders. Passing it via the explicit `datamodule=`
# keyword makes clear it is not a plain train dataloader.
trainer.fit(model=model, datamodule=dmodule)

In [None]:
# Validate the best checkpoint selected by ModelCheckpoint (lowest val_gen_loss).
# Passing `model` without ckpt_path would evaluate the in-memory last-epoch weights.
# Under EarlyStopping those are from after the best epoch, so the reported metrics
# would not match the saved best model.
# Requires a preceding trainer.fit so a best checkpoint path exists.
trainer.validate(model, datamodule=dmodule, ckpt_path="best")