In [None]:
import sys
from pathlib import Path

ROOT = Path.cwd().parent
sys.path.append(str(ROOT))

import mlflow
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import Metric

import datetime

from src.data.pipeline import IngestionPipeline
from src.datasets.dual_input import DualInputSequenceDataset
from src.models.gru import EnsembleGRU
from src.train import train_model
from src.utils.utils import collate_with_macro, CustomReduceLROnPlateau, FocalLoss

import logging

import mlflow
from mlflow.tracking import MlflowClient

def get_best_models(n_models: int, metric: str = "best_f1"):
    """Load up to ``n_models`` logged GRU models from the best-scoring MLflow runs.

    Runs of the ``bankruptcy-predictions`` experiment are ranked by
    ``metrics.<metric>`` in descending order and visited best-first. Every
    directory artifact whose path starts with ``GRUModel_`` is loaded with
    ``mlflow.pytorch.load_model`` until ``n_models`` models have been collected.
    A run whose artifacts cannot be listed or loaded is reported and skipped.

    Args:
        n_models: Maximum number of models to return. Values ``<= 0`` return an
            empty list without contacting the tracking server.
        metric: Metric name, without the ``metrics.`` prefix, used for ranking.

    Returns:
        A list of loaded models, best run first. It may hold fewer than
        ``n_models`` entries when not enough loadable models exist.

    Raises:
        KeyError: If the search result has no ``metrics.<metric>`` column.
    """
    # Previously a non-positive request never matched the `== n_models` stop
    # condition and silently loaded every model in the experiment.
    if n_models <= 0:
        print("Retrieved 0 models.")
        return []

    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    client = MlflowClient()

    all_runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    metric_col = "metrics." + metric
    if metric_col not in all_runs.columns:
        raise KeyError(
            f"No run in experiment 'bankruptcy-predictions' logged metric '{metric}'"
        )
    ranked_run_ids = all_runs.sort_values(by=metric_col, ascending=False)["run_id"]

    models = []
    for run_id in ranked_run_ids:
        try:
            for artifact in client.list_artifacts(run_id):
                if not (artifact.is_dir and artifact.path.startswith("GRUModel_")):
                    continue
                model_uri = f"runs:/{run_id}/{artifact.path}"
                print(f"Loading model: {model_uri}")
                models.append(mlflow.pytorch.load_model(model_uri=model_uri))

                if len(models) >= n_models:
                    print(f"Retrieved {len(models)} models.")
                    return models
        except Exception as e:
            # Best effort: one broken run must not prevent using the others.
            print(f"Skipping run {run_id} due to error: {e}")

    print(f"Retrieved {len(models)} models.")
    return models

def train(
    company_path: str,
    macro_paths: list[str],
    bankruptcy_col: str,
    company_col: str,
    revenue_cap: int,
    n_models: int,
    metrics: list[Metric],
    seed: int,
    device: str,
    num_layers: int = 2,
    hidden_sizes: list[int] = 16,
    epochs: int = 50,
    lr: float = 1e-2,
    train_fract: float = 0.8,
    threshold: float = 0.5,
    dropout: float = 0.3,
    scheduler_factor: float = 0.5,
    scheduler_patience: int = 5,
    min_lr: float = 1e-5
):
    """Train an ``EnsembleGRU`` on top of the best logged GRU models and log it to MLflow.

    Steps: run the ingestion pipeline, build a ``DualInputSequenceDataset``,
    split it into train/validation loaders, load ``n_models`` base models via
    ``get_best_models``, wrap them in an ``EnsembleGRU``, train it with
    ``train_model`` and finally log and save the trained ensemble.

    Args:
        company_path: Path to the firm-level data file.
        macro_paths: Paths to the macroeconomic series files.
        bankruptcy_col: Name of the label column.
        company_col: Name of the company identifier column.
        revenue_cap: Forwarded to ``IngestionPipeline`` as ``revenue_cap``.
        n_models: Number of base models requested from ``get_best_models``.
        metrics: Metrics forwarded to ``train_model``.
        seed: Accepted for interface compatibility but not forwarded anywhere.
            The seed logged to MLflow is the one returned by
            ``stratified_split``.
        device: Device to train on. When falsy, MPS is used if available,
            otherwise CPU.
        num_layers: Accepted for interface compatibility but unused. The logged
            ``num_layers`` parameter is ``len(hidden_sizes)``.
        hidden_sizes: Hidden sizes passed to ``EnsembleGRU``. A bare int is
            wrapped into a one-element list.
        epochs: Maximum number of training epochs.
        lr: Initial learning rate for Adam.
        train_fract: Fraction of the data used for training.
        threshold: Forwarded to ``EnsembleGRU``.
        dropout: Forwarded to ``EnsembleGRU``.
        scheduler_factor: ``factor`` of ``CustomReduceLROnPlateau``.
        scheduler_patience: ``patience`` of ``CustomReduceLROnPlateau``.
        min_lr: ``min_lr`` of ``CustomReduceLROnPlateau``.

    Returns:
        The trained ``EnsembleGRU`` instance.

    Raises:
        RuntimeError: If no base model could be loaded from MLflow.
    """
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)

    # The signature default is a bare int, but the code below needs a sequence
    # (len(hidden_sizes)); normalise so the default no longer crashes.
    if isinstance(hidden_sizes, int):
        hidden_sizes = [hidden_sizes]

    ingestion = IngestionPipeline(
        company_path=company_path,
        macro_paths=macro_paths,
        company_col=company_col,
        bankruptcy_col=bankruptcy_col,
        revenue_cap=revenue_cap
    )

    ingestion.run()
    X, M, y = ingestion.get_tensors()

    dataset = DualInputSequenceDataset(
        firm_tensor=X,
        macro_tensor=M,
        labels=y
    )

    # stratified_split reports the seed it used; keep it under its own name
    # instead of silently overwriting the `seed` argument.
    # NOTE(review): `seed` is never passed to the split — confirm whether
    # stratified_split accepts one.
    train_ds, val_ds, split_seed = dataset.stratified_split(train_fract=train_fract)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_with_macro)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=True, collate_fn=collate_with_macro)

    # Honour the caller's device; previously the argument was always replaced
    # by the auto-detected one.
    if device:
        device = torch.device(device)
    else:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # Mac hardware acceleration
    logger.info(f"Device: {device}")

    pos_weight = dataset.pos_weight()

    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    mlflow.set_experiment("bankruptcy-predictions")
    with mlflow.start_run():
        mlflow.set_tag("model_type", "ensemble")
        mlflow.log_param("seed", split_seed)

        base_models = get_best_models(n_models=n_models)
        if not base_models:
            raise RuntimeError(
                "No base GRU models could be loaded from MLflow; cannot build an ensemble."
            )
        for base_model in base_models:
            base_model.to(device)

        ensemble_model = EnsembleGRU(
            models=base_models, hidden_sizes=hidden_sizes, threshold=threshold, dropout=dropout
        )
        ensemble_model = ensemble_model.to(device)

        # loss_fn = FocalLoss(alpha=0.9, gamma=2.5)
        loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)

        # Logging hyperparameters
        mlflow.log_param("hidden_size", hidden_sizes)
        mlflow.log_param("num_layers", len(hidden_sizes))
        mlflow.log_param("dropout", dropout)
        mlflow.log_param("lr", lr)

        optimizer = Adam(ensemble_model.parameters(), lr=lr)
        scheduler = CustomReduceLROnPlateau(
            optimizer=optimizer,
            factor=scheduler_factor,
            patience=scheduler_patience,
            min_lr=min_lr
        )
        train_model(
            model=ensemble_model,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epochs=epochs,
            metrics=metrics,
            stopping_patience=20,
            stopping_window=10
        )

        model_name = f"EnsembleModel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        # Persist the trained ensemble. The earlier code used the leaked loop
        # variable `model`, i.e. the last base model, for logging, saving and
        # the return value.
        mlflow.pytorch.log_model(ensemble_model, model_name)
        Path("../models").mkdir(parents=True, exist_ok=True)
        torch.save(obj=ensemble_model.state_dict(), f=f"../models/{model_name}.pth")
        print(f"Model saved: {model_name}")

    return ensemble_model

In [4]:
import yaml
from pathlib import Path

from src.utils.utils import TrainConfig

def load_yaml_file(path):
    """Parse the YAML file at ``path`` and return its content.

    Args:
        path: Location of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.

    Raises:
        TypeError: If the file is not valid YAML. The parser error is chained
            as the cause. Previously the exception object was created but never
            raised, so the function silently returned ``None``.
    """
    with open(path) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            raise TypeError(f"Config file could not be loaded: {e}") from e
    
def train_model_from_config(cfg: TrainConfig) -> EnsembleGRU:
    """Translate a ``TrainConfig`` into keyword arguments for ``train`` and run it.

    Data paths from the config are prefixed with ``"../"``. ``revenue_cap`` and
    ``n_models`` are fixed at 3000 and 7. The result of ``train`` is returned
    unchanged.
    """
    parent_prefix = "../"

    # A scalar hidden size is wrapped so `train` always receives a list.
    hidden = cfg.hidden_size
    if isinstance(hidden, int):
        hidden = [hidden]

    macro_files = []
    for relative_path in cfg.macro_data:
        macro_files.append(Path(parent_prefix + relative_path))

    train_kwargs = {
        "company_path": Path(parent_prefix + cfg.firm_data),
        "macro_paths": macro_files,
        "bankruptcy_col": cfg.bankruptcy_col,
        "company_col": cfg.company_col,
        "metrics": cfg.get_metrics().to(cfg.device),
        "revenue_cap": 3000,
        "n_models": 7,
        "device": cfg.device,
        # NOTE(review): num_layers is fed from cfg.num_classes — confirm this is intended.
        "num_layers": cfg.num_classes,
        "hidden_sizes": hidden,
        "epochs": cfg.epochs,
        "lr": float(cfg.lr),
        "train_fract": cfg.train_fract,
        "dropout": cfg.dropout,
        "scheduler_factor": cfg.scheduler_factor,
        "scheduler_patience": cfg.scheduler_patience,
        "seed": cfg.seed,
    }
    return train(**train_kwargs)

In [5]:
import gc

# Train ensembles back to back, re-reading the config before every run so edits
# to the YAML file take effect on the next iteration.
# MAX_RUNS = None keeps the original "loop until interrupted or an error"
# behaviour; set an integer to make Restart & Run All terminate.
MAX_RUNS = None

completed_runs = 0
while MAX_RUNS is None or completed_runs < MAX_RUNS:
    with open("../config/ensemble_config.yml") as stream:
        config = TrainConfig(**yaml.safe_load(stream))
    model = train_model_from_config(config)
    completed_runs += 1

    # Drop the reference and reclaim memory before the next run.
    del model
    gc.collect()
    # Only touch the MPS allocator when the MPS backend is actually available.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/1deusn4h.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/fore3s8t.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=56877', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Loading model: runs:/f23b8ed7cd1b422fba1cb1a7289c1b12/GRUModel_20250724_061750


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/f1269821006f477d8424259e3ada789c/GRUModel_20250723_224211


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2d8ca5d28d58486daf85ed77074301e5/GRUModel_20250724_102840


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/ffbe042b627b4877aa32198a422e8dfa/GRUModel_20250724_065920


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/7d88e80a8aac4cea99184cee4072d3aa/GRUModel_20250723_203657


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/c6eba5f871c04869aa6b249adb9d368d/GRUModel_20250723_223300


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2a70998b91a6404fa5e9522dee1e45fc/GRUModel_20250723_205917


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Retrieved 7 models.


Epoch 31/50 | Loss: 1.15228 | ACCURACY: 0.62609 | AUROC: 0.71341 | F1: 0.19906 | MATTHEWS: 0.15726 | LR: 0.00281:  62%|██████▏   | 31/50 [15:36<09:33, 30.19s/it]

Early stopping at epoch 31, restoring model from epoch 30



INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx


🏃 View run stylish-gull-327 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/dae896b8cfe248d1a40016a9ef691c00
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/d6pgk1_q.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/2ramvdrv.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=10714', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Loading model: runs:/f23b8ed7cd1b422fba1cb1a7289c1b12/GRUModel_20250724_061750


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/f1269821006f477d8424259e3ada789c/GRUModel_20250723_224211


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2d8ca5d28d58486daf85ed77074301e5/GRUModel_20250724_102840


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/ffbe042b627b4877aa32198a422e8dfa/GRUModel_20250724_065920


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/7d88e80a8aac4cea99184cee4072d3aa/GRUModel_20250723_203657


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/c6eba5f871c04869aa6b249adb9d368d/GRUModel_20250723_223300


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2a70998b91a6404fa5e9522dee1e45fc/GRUModel_20250723_205917


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Retrieved 7 models.


Epoch 32/50 | Loss: 1.13057 | ACCURACY: 0.65886 | AUROC: 0.72402 | F1: 0.21481 | MATTHEWS: 0.17957 | LR: 0.00281:  64%|██████▍   | 32/50 [16:22<09:12, 30.72s/it]

Early stopping at epoch 32, restoring model from epoch 31



INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx


🏃 View run capricious-shrimp-71 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/e7bb809f1a2f41f29a49fc95e92b7f78
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/9a028m1a.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/4l2lyfwk.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=47818', 'data', 'file=/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/t

Loading model: runs:/f23b8ed7cd1b422fba1cb1a7289c1b12/GRUModel_20250724_061750


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/f1269821006f477d8424259e3ada789c/GRUModel_20250723_224211


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2d8ca5d28d58486daf85ed77074301e5/GRUModel_20250724_102840


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/ffbe042b627b4877aa32198a422e8dfa/GRUModel_20250724_065920


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/7d88e80a8aac4cea99184cee4072d3aa/GRUModel_20250723_203657


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/c6eba5f871c04869aa6b249adb9d368d/GRUModel_20250723_223300


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2a70998b91a6404fa5e9522dee1e45fc/GRUModel_20250723_205917


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Retrieved 7 models.


Epoch 30/50 | Loss: 1.14772 | ACCURACY: 0.67534 | AUROC: 0.71556 | F1: 0.20976 | MATTHEWS: 0.16580 | LR: 0.00281:  60%|██████    | 30/50 [15:14<10:09, 30.48s/it]

Early stopping at epoch 30, restoring model from epoch 29





🏃 View run peaceful-owl-246 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/94ca71f322fe4290b0e0d16bdbcdb059
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
INFO:prophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/te0yvthg.json
DEBUG:cmdstanpy:input tempfile: /var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/tmpn22451t0/np3t3z_6.json
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: None
DEBUG:cmdstanpy:CmdStan args: ['/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/prophet/stan_model/prophet_model.bin', 'random', 'seed=61552', 'data

Loading model: runs:/f23b8ed7cd1b422fba1cb1a7289c1b12/GRUModel_20250724_061750


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/f1269821006f477d8424259e3ada789c/GRUModel_20250723_224211


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2d8ca5d28d58486daf85ed77074301e5/GRUModel_20250724_102840


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/ffbe042b627b4877aa32198a422e8dfa/GRUModel_20250724_065920


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/7d88e80a8aac4cea99184cee4072d3aa/GRUModel_20250723_203657


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/c6eba5f871c04869aa6b249adb9d368d/GRUModel_20250723_223300


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Loading model: runs:/2a70998b91a6404fa5e9522dee1e45fc/GRUModel_20250723_205917


Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Retrieved 7 models.


Epoch 29/50 | Loss: 1.13148 | ACCURACY: 0.65766 | AUROC: 0.72367 | F1: 0.21565 | MATTHEWS: 0.18153 | LR: 0.00375:  58%|█████▊    | 29/50 [14:58<10:50, 30.97s/it]

Early stopping at epoch 29, restoring model from epoch 28



INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx


🏃 View run amazing-seal-389 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/3ee011a5c66542a9b893a2dfbe9f7331
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")
  df["Date"]=pd.to_datetime(df["Date"], errors="coerce")


ValueError: Dataframe has less than 2 non-NaN rows.

In [None]:
import mlflow
import mlflow.pytorch
import torch
from src.models.gru import GRUModel

# Point MLflow at the local tracking server before resolving the run URI.
mlflow.set_tracking_uri("http://127.0.0.1:8080")
# Hard-coded run ID and artifact name of one logged GRU model.
model_uri = 'runs:/ded35c0837614ae2ab5feafda113ee9c/GRUModel_20250702_021258'

# Load the logged artifact as a PyTorch model object.
model = mlflow.pytorch.load_model(model_uri)

In [None]:
# Display the repr of whatever `model` currently refers to in the kernel.
model

In [None]:
# Fetch all runs of the "bankruptcy-predictions" experiment from the local
# tracking server; the result is kept in `runs` for the cells below.
mlflow.set_tracking_uri("http://127.0.0.1:8080")
runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])

In [None]:
# Display the search result stored in `runs`.
runs

In [None]:
# Keep the seven runs with the highest validation Matthews score and renumber
# them from 0; the previous row labels end up in a new "index" column.
selected_runs = (
    runs.sort_values(by="metrics.val_matthews", ascending=False)
    .head(7)
    .reset_index()
)

In [None]:
import mlflow.artifacts


# Download the artifacts of one hard-coded run; the call's return value is
# displayed as the cell output.
mlflow.artifacts.download_artifacts(run_id="fe05629c8fa04c18a6e37553051db968")

In [None]:
# List the columns available on the selected runs.
selected_runs.columns

In [None]:
# Artifact URI of the row labelled 1.
# NOTE(review): which row that is depends on the current index of `selected_runs`.
selected_runs["artifact_uri"][1]

In [None]:
# Pick the entry labelled 89 from `run_ids`.
# NOTE(review): `run_ids` is only assigned in the last cell of this notebook,
# so this cell fails on a fresh kernel run from top to bottom.
run = run_ids[89]

In [None]:
# Display `run_ids` (assigned in the last cell of this notebook).
run_ids

In [None]:
# Load one logged GRU model by its hard-coded runs:/ URI and display it.
mlflow.pytorch.load_model("runs:/bb5b7d115ac04197a760a5d7aba049e9/GRUModel_20250701_233333")

In [None]:
# Display the run_id column of the search result.
runs.run_id

In [None]:
# Display `run_ids` again (assigned in the last cell of this notebook).
run_ids

In [None]:
from mlflow.tracking import MlflowClient

# Collect the runs:/ URIs of every logged GRU model directory of one run.
# NOTE(review): the original cell did not parse (unbalanced parenthesis) and
# used a placeholder URI; the (run id, model URI) tuple is the presumed intent.
client = MlflowClient()
run = run_ids[89]
model_paths = []
artifacts = client.list_artifacts(run)
for artifact in artifacts:
    # The string method is `startswith`; `starts_with` raises AttributeError.
    if artifact.is_dir and artifact.path.startswith("GRUModel_"):
        model_name = artifact.path
        model_uri = f"runs:/{run}/{model_name}"
        model_paths.append((run, model_uri))

In [None]:
import mlflow
from mlflow.tracking import MlflowClient

def get_best_models(n_models: int = 7, metric: str = "val_matthews"):
    """Load every logged GRU model found in the ``n_models`` best MLflow runs.

    Runs of the ``bankruptcy-predictions`` experiment are ranked by
    ``metrics.<metric>`` in descending order and the first ``n_models`` runs
    are kept. Each directory artifact of those runs whose path starts with
    ``GRUModel_`` is loaded with ``mlflow.pytorch.load_model``.

    NOTE: this redefines ``get_best_models`` with different defaults than the
    definition near the top of the notebook; whichever cell ran last wins.

    Args:
        n_models: Number of top-ranked runs to inspect. Values ``<= 0`` return
            an empty list.
        metric: Metric name, without the ``metrics.`` prefix, used for ranking.

    Returns:
        A list of loaded models in ranking order. Its length depends on how
        many matching artifacts the selected runs contain.
    """
    if n_models <= 0:
        return []

    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    client = MlflowClient()

    runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    ranked = runs.sort_values(by="metrics." + metric, ascending=False)
    # Iterating the run_id column yields plain strings, so no reset_index
    # (which raises for an in-place call on a Series) and no `.info` access.
    selected_run_ids = ranked["run_id"].head(n_models)

    model_paths = []
    for run_id in selected_run_ids:
        for artifact in client.list_artifacts(run_id):
            if artifact.is_dir and artifact.path.startswith("GRUModel_"):
                model_paths.append(f"runs:/{run_id}/{artifact.path}")

    # The original built this list but never returned it.
    return [mlflow.pytorch.load_model(model_uri=model_uri) for model_uri in model_paths]

In [None]:
# Call the currently defined get_best_models with its defaults and display
# the result.
# NOTE(review): several cells define get_best_models; which one runs here
# depends on cell execution order.
models = get_best_models()
models

In [None]:
# Re-query the experiment and keep, for the seven runs with the best validation
# Matthews score, only the logged-model history tag. Original row labels are
# retained.
mlflow.set_tracking_uri("http://127.0.0.1:8080")
runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
selected_runs = (
    runs.sort_values(by="metrics.val_matthews", ascending=False)
    .head(7)
    .loc[:, "tags.mlflow.log-model.history"]
)

In [None]:
import re

In [None]:
# re.search takes (pattern, string). The arguments were swapped, so the
# history text was being compiled as the regular expression.
# Looks for a double quote at the end of the history string labelled 101.
re.search(r"\"$", selected_runs[101])

In [None]:
# Display the entry labelled 101 of `selected_runs`.
selected_runs[101]

In [None]:
def get_best_models(n_models: int = 7, metric: str = "val_matthews"):
    """Return the run IDs of the ``n_models`` best runs of the experiment.

    Runs of ``bankruptcy-predictions`` are ranked by ``metrics.<metric>`` in
    descending order. The earlier version built a ``runs:/`` prefix per run,
    discarded it and returned ``None``, although the cell below stores the
    result in ``run_ids``.

    NOTE: this definition shadows the model-loading ``get_best_models`` defined
    earlier in the notebook. Re-running ``train`` in the same kernel after this
    cell makes it receive run IDs instead of models.

    Args:
        n_models: Number of top-ranked runs to keep. Negative values are
            treated as 0.
        metric: Metric name, without the ``metrics.`` prefix, used for ranking.

    Returns:
        A pandas Series of run ID strings, best first, carrying the row labels
        of the search result.
    """
    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    runs = mlflow.search_runs(experiment_names=["bankruptcy-predictions"])
    ranked = runs.sort_values(by="metrics." + metric, ascending=False)
    return ranked[:max(n_models, 0)]["run_id"]


In [None]:
# Store the result of the currently defined get_best_models in `run_ids`,
# the name that earlier cells index with `run_ids[89]`.
run_ids = get_best_models()