In [0]:
import copy
import yaml
import argparse
import optuna
import mlflow
from mlflow.optuna import MlflowStorage
from mlflow.pyspark.optuna.study import MlflowSparkStudy
from optuna.integration import PyTorchLightningPruningCallback
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from pyspark.sql import SparkSession
from src.model.model_factory import get_model
from src.utils.callbacks import LogPredictionsCallback
from src.data.optuna_snuplass_datamodule import get_datamodule
from src.utils.get_from_overview import get_split_from_overview

In [0]:
def objective(trial: optuna.Trial, config: dict, model_name: str, tracking_uri: str, experiment_name: str) -> float:
    """
    Run one Optuna trial: sample hyperparameters, train the model and return
    the validation metric the study optimises.

    Args:
        trial (Trial): Optuna trial used to sample hyperparameters.
        config (dict): Full run configuration. This function does not write to
            it; sampled values go into a per-trial copy so that trials sharing
            the same dict cannot influence each other.
        model_name (str): Key of the model to tune under ``config['model']``.
        tracking_uri (str): MLflow tracking URI.
        experiment_name (str): Name of the MLflow experiment.

    Returns:
        float: Value of ``config['optuna']['metric_name']`` taken from the
        validation run executed after training.

    Raises:
        KeyError: If the configured metric is not among the validation metrics.
    """
    opt_cfg = config['optuna']
    train_cfg = config['training']
    monitor = train_cfg['monitor']
    monitor_mode = train_cfg['monitor_mode']

    # Sample hyperparameters and record them on the trial.
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    trial.set_user_attr('lr', lr)

    batch_size = trial.suggest_categorical('batch_size', [4, 8, 16, 32])
    trial.set_user_attr('batch_size', batch_size)

    # Build a per-trial configuration instead of mutating the caller's dict.
    # Only the path down to the tuned model's settings is copied; everything
    # else (for example the ID lists under 'data') is shared by reference.
    model_cfg = {**config['model'][model_name], 'lr': lr, 'batch_size': batch_size}
    trial_config = {**config, 'model': {**config['model'], model_name: model_cfg}}

    # Set up logging
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    mlf_logger = MLFlowLogger(
        experiment_name=experiment_name,
        run_name=f"{model_name}-trial{trial.number}",
        tracking_uri=tracking_uri,
        tags={'model': model_name}
    )

    # Data & model
    datamodule = get_datamodule(trial_config, model_name)
    model = get_model(model_name, model_cfg)

    early_stop = EarlyStopping(
        monitor=monitor,
        mode=monitor_mode,
        patience=train_cfg['early_stopping_patience'],
        verbose=True
    )
    checkpoint = ModelCheckpoint(
        dirpath="/tmp/checkpoints",
        monitor=monitor,
        mode=monitor_mode,
        save_top_k=1,
        # Template expands to "{epoch:02d}-{<monitor>:.4f}". The format spec
        # must not start with a space, otherwise positive values are rendered
        # with a leading blank and the file name contains a space.
        filename=f"{{epoch:02d}}-{{{monitor}:.4f}}"
    )
    log_pred = LogPredictionsCallback(**trial_config.get('log_predictions_callback', {}))
    pruning = PyTorchLightningPruningCallback(
        trial,
        monitor=monitor,
    )

    trainer = Trainer(
        logger=mlf_logger,
        default_root_dir="/tmp",
        max_epochs=train_cfg['max_epochs'],
        accelerator=train_cfg['accelerator'],
        devices=train_cfg['devices'],
        precision=train_cfg['precision'],
        callbacks=[early_stop, checkpoint, log_pred, pruning],
        log_every_n_steps=10,
        deterministic=True,
        enable_progress_bar=False,
    )

    trainer.fit(model, datamodule=datamodule)

    # NOTE(review): no ckpt_path is passed, so validation presumably uses the
    # weights the model holds when fit() returns rather than the best
    # checkpoint - confirm this is intended.
    val_metrics = trainer.validate(model, datamodule=datamodule)[0]

    metric_name = opt_cfg['metric_name']
    if metric_name not in val_metrics:
        raise KeyError(
            f"Metric '{metric_name}' not found in validation metrics: {sorted(val_metrics)}"
        )
    return val_metrics[metric_name]


def main(config_path: str):
    """
    Run an Optuna hyperparameter search for the first model listed in the
    configuration and print the best value and parameters.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        ValueError: If the YAML file is empty or its top level is not a mapping.
    """
    # Load the configuration
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file '{config_path}' must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    model_name = config['model_names'][0]
    opt_cfg = config['optuna']
    tracking_uri = config['logging'].get('tracking_uri', 'databricks')

    # The experiment lives under the current user's workspace folder.
    spark = SparkSession.builder.getOrCreate()
    username = spark.sql("SELECT current_user()").collect()[0][0]
    experiment_name = f"/Users/{username}/{model_name}"

    # Point MLflow at the configured tracking server BEFORE resolving the
    # experiment, so the ID is looked up on the same server that is used for
    # logging. set_experiment creates the experiment when it does not exist
    # and returns it, which avoids dereferencing None for a missing experiment.
    mlflow.set_tracking_uri(tracking_uri)
    experiment_id = mlflow.set_experiment(experiment_name).experiment_id

    # Fetch train, validation and holdout IDs
    train, val, holdout = get_split_from_overview(
        spark,
        config['data']['spark_catalog'],
        config['data']['spark_schema'],
        config['data']['train']['overview_table'],
        config['data']['train']['id_field'],
        require_mask=True
    )

    # Store the ID splits in the configuration handed to every trial
    config['data']['train_ids'] = train
    config['data']['val_ids'] = val
    config['data']['holdout_ids'] = holdout

    # Set up MLflow-backed Optuna storage
    mlflow_storage = MlflowStorage(experiment_id=experiment_id)

    # Set up the Optuna study
    study = MlflowSparkStudy(
        study_name=f"optuna-{model_name}",
        storage=mlflow_storage,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10)
    )

    # NOTE(review): this writes a private attribute and stores the raw string
    # from the config - confirm this is the form MlflowSparkStudy expects.
    study._directions = [opt_cfg['direction']]

    # Run the optimisation in parallel
    study.optimize(
        lambda trial: objective(trial, config, model_name, tracking_uri, experiment_name),
        n_trials=opt_cfg.get('n_trials', 30),
        n_jobs=opt_cfg.get('n_jobs', 1)
    )

    # Report the best value and parameters, labelled with the configured
    # metric name (falls back to the previous hard-coded label).
    best_trial = study.best_trial
    metric_label = opt_cfg.get('metric_name', 'iou')
    print(f"Beste {metric_label}: {best_trial.value}")
    print("Beste parametre: ")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")


if __name__ == '__main__':
    # Command-line entry point: only --config is defined. parse_known_args is
    # used so that any additional, unrecognised arguments are returned
    # separately and ignored instead of aborting the run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Sti til YAML-konfigurasjon")
    known_args, _unknown_args = cli.parse_known_args()
    main(known_args.config)
