In [1]:
import mlflow
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import Metric

import datetime

from src.data.pipeline import IngestionPipeline
from src.datasets.dual_input import DualInputSequenceDataset
from src.models.gru import EnsembleGRU
from src.train import train_model
from src.utils.utils import collate_with_macro, CustomReduceLROnPlateau, FocalLoss

import logging

import mlflow
from mlflow.tracking import MlflowClient

def get_best_models(n_models: int, metric: str = "val_matthews"):
    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    client = MlflowClient()
    
    raw_list = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    runs = raw_list.sort_values(by = "metrics." + metric, ascending = False)
    runs.reset_index(inplace=True)
    top_runs = runs[:n_models]["run_id"]
    
    model_paths = []
    
    for run in top_runs:
        artifacts = client.list_artifacts(run)
        
        for artifact in artifacts:
            if artifact.is_dir and artifact.path.startswith("GRUModel_"):
                model_name = artifact.path
                model_uri = f"runs:/{run}/{model_name}"
                model_paths.append(model_uri)

    models = []

    for model_uri in model_paths:
        model = mlflow.pytorch.load_model(model_uri=model_uri)
        models.append(model)
        
    return models

def train(
    company_path: str,
    macro_paths: list[str],
    bankruptcy_col: str,
    company_col: str,
    revenue_cap: int,
    n_models: int,
    metrics: list[Metric],
    seed: int,
    device: str,
    num_layers: int = 2,
    hidden_sizes: list[int] = 16,
    epochs: int = 50,
    lr: float = 1e-2,
    train_fract: float = 0.8,
    threshold: float = 0.5,
    dropout: float = 0.3,
    scheduler_factor: float = 0.5,
    scheduler_patience: int = 5,
    min_lr: float = 1e-5
):
    logger = logging.getLogger(__name__)
    logging.basicConfig(level = logging.INFO)
    
    ingestion = IngestionPipeline(
        company_path=company_path,
        macro_paths=macro_paths,
        company_col=company_col,
        bankruptcy_col=bankruptcy_col,
        revenue_cap=revenue_cap
    )
    
    ingestion.run()
    X, M, y = ingestion.get_tensors()
    
    dataset = DualInputSequenceDataset(
        firm_tensor = X,
        macro_tensor = M,
        labels = y
    )
    
    train_ds, val_ds, seed = dataset.stratified_split(train_fract=train_fract)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_with_macro)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=True, collate_fn=collate_with_macro)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") # Mac hardware acceleration
    logger.info(f"Device: {device}")
    
    pos_weight = dataset.pos_weight()
    
    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    mlflow.set_experiment("bankruptcy-predictions")
    with mlflow.start_run():
        mlflow.set_tag("model_type", "ensemble")
        mlflow.log_param("seed", seed)
        
        models = get_best_models(n_models=n_models)
        for model in models:
            model.to(device)
        
        ensemble_model = EnsembleGRU(
            models=models, hidden_sizes=hidden_sizes, threshold=threshold, dropout=dropout
        )
        ensemble_model = ensemble_model.to(device)
        
        # loss_fn = FocalLoss(alpha=0.9, gamma=2.5)
        loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)

        # Logging hyperparameters
        mlflow.log_param("hidden_size", hidden_sizes)
        mlflow.log_param("num_layers", len(hidden_sizes))
        mlflow.log_param("dropout", dropout)
        mlflow.log_param("lr", lr)
        
        optimizer = Adam(ensemble_model.parameters(), lr = lr)
        scheduler=CustomReduceLROnPlateau(
            optimizer=optimizer,
            factor=scheduler_factor,
            patience=scheduler_patience,
            min_lr=min_lr
        )
        train_model(
            model=ensemble_model,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epochs=epochs,
            metrics=metrics,
            stopping_patience=20,
            stopping_window=10
        )
        
        mlflow.pytorch.log_model(ensemble_model, f"model_{datetime.datetime.now()}")
        torch.save(obj = ensemble_model.state_dict(), f = f"model_{datetime.datetime.now()}.pth")
        
    return model

In [2]:
import yaml
from pathlib import Path

from src.utils.utils import TrainConfig

def load_yaml_file(path):
    with open(path) as stream:
        try:
            config_dict=yaml.safe_load(stream)
            return config_dict
        except yaml.YAMLError as e:
            TypeError(f"Config file could not be loaded: {e}")
    
def train_model_from_config(cfg: TrainConfig) -> EnsembleGRU:
    """Main training function"""
    return train(
        company_path = Path(cfg.firm_data),
        macro_paths = [Path(path) for path in cfg.macro_data],
        bankruptcy_col = cfg.bankruptcy_col,
        company_col=cfg.company_col,
        metrics=cfg.get_metrics().to(cfg.device),
        revenue_cap=3000,
        n_models=5,
        device=cfg.device,
        num_layers=cfg.num_classes,
        hidden_sizes=[cfg.hidden_size] if isinstance(cfg.hidden_size, int) else cfg.hidden_size,
        epochs=cfg.epochs,
        lr=float(cfg.lr),
        train_fract=cfg.train_fract,
        dropout=cfg.dropout,
        scheduler_factor=cfg.scheduler_factor,
        scheduler_patience=cfg.scheduler_patience,
        seed=cfg.seed
    )

In [3]:
for i in range(10):
    with open("config/model_config.yml") as stream:
        config=yaml.safe_load(stream)
        config = TrainConfig(**config)
    train_model_from_config(config)

INFO:src.data.loaders:Reading file: data/demo_data.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/_8tc5upe.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/04locp20.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=1928', 'data', '

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 29/50 | Loss: 1.06419 | ACCURACY: 0.89833 | AUROC: 0.76423 | F1: 0.30435 | MATTHEWS: 0.25053 | LR: 0.00375:  58%|█████▊    | 29/50 [06:06<04:25, 12.63s/it]

Early stopping at epoch 29, restoring model from epoch 28



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run respected-shark-559 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/cd8f05d0336b40e188b4c1618c8ef3ee
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/46zwvvgi.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/z8eug65p.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=11550', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 31/50 | Loss: 1.09560 | ACCURACY: 0.89496 | AUROC: 0.75411 | F1: 0.30303 | MATTHEWS: 0.24825 | LR: 0.00281:  62%|██████▏   | 31/50 [06:30<03:59, 12.59s/it]

Early stopping at epoch 31, restoring model from epoch 30



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run legendary-shoat-447 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/54cb10600a394abd878c16319f4e684f
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/6hhkindj.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/0gn1_q08.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=91378', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 29/50 | Loss: 1.11230 | ACCURACY: 0.89913 | AUROC: 0.73566 | F1: 0.26801 | MATTHEWS: 0.21388 | LR: 0.00281:  58%|█████▊    | 29/50 [04:11<03:02,  8.67s/it]

Early stopping at epoch 29, restoring model from epoch 28



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run receptive-pug-234 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/9da52f7cc4ba4d74b9bca4b31db84152
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/n5kznvyu.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/vlw8qima.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=42416', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 30/50 | Loss: 1.09694 | ACCURACY: 0.90389 | AUROC: 0.73988 | F1: 0.28190 | MATTHEWS: 0.23044 | LR: 0.00281:  60%|██████    | 30/50 [04:30<03:00,  9.01s/it]

Early stopping at epoch 30, restoring model from epoch 29



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run bedecked-dove-970 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/6af5cd7cc2bd4330891f7c50a2dbcb89
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/ssuej1ji.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/pmo7z56j.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=44673', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 29/50 | Loss: 1.11340 | ACCURACY: 0.89337 | AUROC: 0.73466 | F1: 0.26337 | MATTHEWS: 0.20659 | LR: 0.00500:  58%|█████▊    | 29/50 [04:19<03:07,  8.94s/it]

Early stopping at epoch 29, restoring model from epoch 28



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run learned-cub-710 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/c5862465fdd6477bbfcbc90c989180ce
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/zgy4k_sh.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/6tkg1los.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=34401', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 34/50 | Loss: 1.12255 | ACCURACY: 0.90409 | AUROC: 0.73003 | F1: 0.25348 | MATTHEWS: 0.20279 | LR: 0.00500:  68%|██████▊   | 34/50 [05:03<02:22,  8.93s/it]

Early stopping at epoch 34, restoring model from epoch 33



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run funny-carp-707 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/de636120bf2144d1bbd56a23306e9417
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/agz8j9ms.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/7cc6nknm.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=64490', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 30/50 | Loss: 1.12473 | ACCURACY: 0.90707 | AUROC: 0.72580 | F1: 0.25949 | MATTHEWS: 0.21109 | LR: 0.00375:  60%|██████    | 30/50 [02:40<01:47,  5.36s/it]

Early stopping at epoch 30, restoring model from epoch 29



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run learned-carp-807 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/dfedcc92bc1d4bc2bfdda0221efaa58c
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/hbuy88xq.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/zg4nr94w.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=73561', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 31/50 | Loss: 1.11762 | ACCURACY: 0.89496 | AUROC: 0.73472 | F1: 0.28223 | MATTHEWS: 0.22656 | LR: 0.00281:  62%|██████▏   | 31/50 [02:48<01:43,  5.43s/it]

Early stopping at epoch 31, restoring model from epoch 30





🏃 View run nimble-wolf-699 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/ba8e782635364432afb00d167980e432
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Reading file: data/demo_data.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/r6_zmp4l.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/v_d_1_h3.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=12886', 'data', 

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 33/50 | Loss: 1.11756 | ACCURACY: 0.89793 | AUROC: 0.73224 | F1: 0.25291 | MATTHEWS: 0.19814 | LR: 0.00281:  66%|██████▌   | 33/50 [02:56<01:30,  5.34s/it]

Early stopping at epoch 33, restoring model from epoch 32



INFO:src.data.loaders:Reading file: data/demo_data.xlsx


🏃 View run adaptable-shrike-89 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/d09609ffc3464c07bea2c0cd2d8af002
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/fmz_6ahk.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpxitkoibe/8naxrqpd.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=13702', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

🏃 View run zealous-moose-704 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/27298398f5634776a36a62b45030630f
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


RuntimeError: stack expects a non-empty TensorList

In [None]:
import mlflow
import mlflow.pytorch
import torch
from src.models.gru import GRUModel

mlflow.set_tracking_uri("http://127.0.0.1:8080")
model_uri = 'runs:/bb5b7d115ac04197a760a5d7aba049e9/GRUModel_20250701_233333'

model = mlflow.pytorch.load_model(model_uri)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
mlflow.set_tracking_uri("http://127.0.0.1:8080")
runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])

In [None]:
runs

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.train_matthews,metrics.best_threshold,metrics.best_f1,metrics.train_lr,...,params.hidden_size,params.num_layers,params.output_size,tags.mlflow.log-model.history,tags.mlflow.runName,tags.mlflow.source.type,tags.model_type,tags.mlflow.source.name,tags.mlflow.user,tags.ensemble
0,6cdacabfdbf644019d2c147256818dd5,387584985157093548,FINISHED,mlflow-artifacts:/387584985157093548/6cdacabfd...,2025-07-14 12:24:37.989000+00:00,2025-07-14 12:42:00.901000+00:00,0.071489,0.5,0.340249,0.005625,...,"[64, 32, 16]",2,,"[{""run_id"": ""6cdacabfdbf644019d2c147256818dd5""...",chill-hound-382,LOCAL,ensemble,/Users/guillaumedecina-halmi/miniforge3/lib/py...,guillaumedecina-halmi,
1,805f7855e48b4ac8939d4959d54efb0e,387584985157093548,FINISHED,mlflow-artifacts:/387584985157093548/805f7855e...,2025-07-14 12:10:09.239000+00:00,2025-07-14 12:24:28.802000+00:00,0.067752,0.5,0.303571,0.005625,...,"[64, 32, 16]",2,,"[{""run_id"": ""805f7855e48b4ac8939d4959d54efb0e""...",silent-finch-41,LOCAL,ensemble,/Users/guillaumedecina-halmi/miniforge3/lib/py...,guillaumedecina-halmi,
2,0a7b61714e52426da980dcaaa94a24b7,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/0a7b61714...,2025-07-14 11:59:04.549000+00:00,2025-07-14 12:09:11.246000+00:00,-0.048402,,,0.010000,...,"[64, 32, 16]",2,,,dapper-crab-958,LOCAL,ensemble,/Users/guillaumedecina-halmi/miniforge3/lib/py...,guillaumedecina-halmi,
3,d788cea214014ceabbab6c592c0cdef4,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/d788cea21...,2025-07-14 11:56:36.312000+00:00,2025-07-14 11:56:42.762000+00:00,,,,,...,"[64, 32, 16]",2,,,inquisitive-midge-0,LOCAL,ensemble,/Users/guillaumedecina-halmi/miniforge3/lib/py...,guillaumedecina-halmi,
4,8f51c4836b994542b500e3bae383105e,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/8f51c4836...,2025-07-14 11:55:07.056000+00:00,2025-07-14 11:55:13.686000+00:00,,,,,...,,,,,powerful-asp-432,LOCAL,ensemble,/Users/guillaumedecina-halmi/miniforge3/lib/py...,guillaumedecina-halmi,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1030,a5ef4ae82b974f178db362043a06f378,387584985157093548,FINISHED,mlflow-artifacts:/387584985157093548/a5ef4ae82...,2025-04-03 16:17:27.379000+00:00,2025-04-03 16:17:35.267000+00:00,,,,,...,,,,"[{""run_id"": ""a5ef4ae82b974f178db362043a06f378""...",wise-boar-183,LOCAL,,/Users/guillaumedecina-halmi/Library/Python/3....,guillaumedecina-halmi,
1031,4acbec115384493d846cb0cf25f07b0c,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/4acbec115...,2025-04-03 16:16:10.017000+00:00,2025-04-03 16:16:10.070000+00:00,,,,,...,,,,,welcoming-stork-436,LOCAL,,/Users/guillaumedecina-halmi/Library/Python/3....,guillaumedecina-halmi,
1032,a318ac43d95d474595373f042d3efc0d,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/a318ac43d...,2025-04-03 16:15:02.037000+00:00,2025-04-03 16:15:02.096000+00:00,,,,,...,,,,,classy-donkey-176,LOCAL,,/Users/guillaumedecina-halmi/Library/Python/3....,guillaumedecina-halmi,
1033,4e3239ab510142ab9b5dc6c04ace7e95,387584985157093548,FAILED,mlflow-artifacts:/387584985157093548/4e3239ab5...,2025-04-03 16:14:30.895000+00:00,2025-04-03 16:14:33.204000+00:00,,,,,...,,,,,mysterious-midge-268,LOCAL,,/Users/guillaumedecina-halmi/Library/Python/3....,guillaumedecina-halmi,


In [None]:
selected_runs = runs.sort_values(by = "metrics.val_matthews", ascending = False)[:7]
selected_runs.reset_index(inplace=True)

In [None]:
import mlflow.artifacts


mlflow.artifacts.download_artifacts(run_id="fe05629c8fa04c18a6e37553051db968")

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

'/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpux7aooyn/'

In [None]:
selected_runs.columns

Index(['index', 'run_id', 'experiment_id', 'status', 'artifact_uri',
       'start_time', 'end_time', 'metrics.train_matthews',
       'metrics.best_threshold', 'metrics.best_f1', 'metrics.train_lr',
       'metrics.val_auroc', 'metrics.train_accuracy', 'metrics.val_accuracy',
       'metrics.val_matthews', 'metrics.train_f1', 'metrics.val_loss',
       'metrics.train_auroc', 'metrics.val_f1',
       'metrics.val_BinaryMatthewsCorrCoef', 'metrics.val_BinaryF1Score',
       'metrics.val_BinaryRecall', 'metrics.train_BinaryRecall',
       'metrics.val_BinaryPrecision', 'metrics.train_BinaryAveragePrecision',
       'metrics.train_BinaryF1Score', 'metrics.train_BinaryPrecision',
       'metrics.train_BinaryMatthewsCorrCoef',
       'metrics.val_BinaryAveragePrecision', 'metrics.train_BinaryAUROC',
       'metrics.val_BinaryAUROC', 'metrics.auc', 'metrics.lr', 'metrics.f1',
       'metrics.loss', 'params.lr', 'params.seed', 'params.dropout',
       'params.hidden_size', 'params.num_layers'

In [None]:
selected_runs["artifact_uri"][1]

'mlflow-artifacts:/387584985157093548/fe05629c8fa04c18a6e37553051db968/artifacts'

In [None]:
run = run_ids[89]

NameError: name 'run_ids' is not defined

In [None]:
run_ids

In [None]:
mlflow.pytorch.load_model("runs:/bb5b7d115ac04197a760a5d7aba049e9/GRUModel_20250701_233333")

In [None]:
runs.run_id

In [None]:
run_ids

In [None]:
from mlflow.tracking import MlflowClient
client = MlflowClient()
model_paths = []
artifacts = client.list_artifacts(run_ids[89])
for artifact in artifacts:
    if artifact.is_dir and artifact.path.starts_with("GRUModel_"):
        model_name = artifact.path
        model_uri = f"runs/.../{model_name}"
        model_paths.append((run.infomodel_name )

In [None]:
import mlflow
from mlflow.tracking import MlflowClient

def get_best_models(n_models: int = 7, metric: str = "val_matthews"):
    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    client = MlflowClient()
    
    runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    selected_runs = runs.sort_values(by = "metrics." + metric, ascending = False)[:n_models]["run_id"]
    selected_runs.reset_index(inplace=True)
    
    model_paths = []
    
    for run in selected_runs:
        run_id = run.info.run_id
        artifacts = client.list_artifacts(run_id)
        
        for artifact in artifacts:
            if artifact.is_dir and artifact.path.starts_with("GRUModel_"):
                model_name = artifact.path
                model_uri = f"runs:/{run_id}/{model_name}"
                model_paths.append(model_uri)

    models = []

    for model_uri in model_paths:
        model = mlflow.pytorch.load_model(model_uri = model_uri)
        models.append(model)

In [None]:
models = get_best_models()
models

In [None]:
mlflow.set_tracking_uri("http://127.0.0.1:8080")
runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
selected_runs = runs.sort_values(by = "metrics.val_matthews", ascending = False)[:7]["tags.mlflow.log-model.history"]

In [None]:
import re

In [None]:
re.search(selected_runs[101], r"\"$")

In [None]:
selected_runs[101]

In [None]:
def get_best_models(n_models: int = 7, metric: str = "val_matthews"):
    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    selected_runs = runs.sort_values(by = "metrics." + metric, ascending = False)[:n_models]["run_id"]
    for run in selected_runs:
        model_uri = "runs:/" + run


In [None]:
run_ids = get_best_models()