In [1]:
import sys
from pathlib import Path

# Make the project root (the parent of the notebook's working directory)
# importable so the ``src.*`` imports below resolve. The membership check keeps
# re-runs of this cell from appending the same entry to sys.path repeatedly.
ROOT = Path.cwd().parent
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

import torch
import mlflow
import datetime
import logging
import yaml

from dataclasses import dataclass, field
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryMatthewsCorrCoef,
    MulticlassAccuracy, MulticlassAUROC, MulticlassF1Score)
from pathlib import Path

from src.datasets.dual_input import DualInputSequenceDataset
from src.models.gru import GRUModel
from src.data.pipeline import IngestionPipeline
from src.train.gru import train_gru
from src.utils.utils import CustomReduceLROnPlateau, collate_with_macro, TrainConfig, FocalLoss

# Configure the root logger once for the whole notebook: INFO and above,
# each record prefixed with timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

def load_yaml_file(path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed object.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (e.g. a dict
        for a mapping document, ``None`` for an empty file).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file content is not valid YAML.
    """
    with open(path, encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            # The exception must be raised: previously it was only constructed,
            # so a malformed file silently returned None and failed later in a
            # less obvious place (e.g. ``TrainConfig(**None)``).
            raise ValueError(f"Config file could not be loaded: {e}") from e
    
def train_model_from_config(cfg: TrainConfig) -> GRUModel:
    """Train a GRU model using the settings held in ``cfg``.

    Each config field is cast to the type ``train`` expects and forwarded to
    it. ``cfg.firm_data`` is resolved relative to the parent directory ("..")
    of the current working directory; ``cfg.macro_data`` entries are passed
    through unchanged (as strings).

    Args:
        cfg: Training configuration (see ``TrainConfig``).

    Returns:
        The model returned by ``train``.
    """
    return train(
        company_path=Path("..") / cfg.firm_data,
        # NOTE(review): unlike firm_data, macro paths are not prefixed with
        # ".." — confirm the config already stores them relative to cwd.
        macro_paths=[str(macro_path) for macro_path in cfg.macro_data],
        bankruptcy_col=str(cfg.bankruptcy_col),
        company_col=str(cfg.company_col),
        revenue_cap=int(cfg.revenue_cap),
        metrics=cfg.get_metrics().to(cfg.device),
        device=str(cfg.device),
        # NOTE(review): the layer count is read from ``cfg.num_classes``;
        # confirm the config has no dedicated layer-count field.
        num_layers=int(cfg.num_classes),
        hidden_size=int(cfg.hidden_size),
        output_size=1,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        train_fract=float(cfg.train_fract),
        # Dropout is a probability: int() truncated e.g. 0.2 to 0 and
        # silently disabled it.
        dropout=float(cfg.dropout),
        alpha=float(cfg.alpha),
        gamma=float(cfg.gamma),
        scheduler_factor=float(cfg.scheduler_factor),
        scheduler_patience=int(cfg.scheduler_patience),
        stopping_patience=int(cfg.stopping_patience),
        decay_ih=float(cfg.decay_ih),
        decay_hh=float(cfg.decay_hh),
        decay_other=float(cfg.decay_other),
        seed=int(cfg.seed),
    )

def _weight_decay_groups(model, decay_ih, decay_hh, decay_other):
    """Split ``model``'s parameters into Adam param groups with their own weight decay.

    Parameters whose name contains ``weight_ih`` / ``weight_hh`` (PyTorch's
    naming for recurrent input-hidden / hidden-hidden weight matrices) get
    ``decay_ih`` / ``decay_hh``; every other parameter gets ``decay_other``.
    Groups that end up empty are dropped.
    """
    ih_params, hh_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if "weight_ih" in name:
            ih_params.append(param)
        elif "weight_hh" in name:
            hh_params.append(param)
        else:
            other_params.append(param)
    groups = [
        {"params": ih_params, "weight_decay": decay_ih},
        {"params": hh_params, "weight_decay": decay_hh},
        {"params": other_params, "weight_decay": decay_other},
    ]
    return [group for group in groups if group["params"]]


def train(
    company_path: str,
    macro_paths: list[str],
    bankruptcy_col: str,
    company_col: str,
    revenue_cap: int,
    metrics: MetricCollection,
    seed: int,
    num_layers: int = 2,
    hidden_size: int = 64,
    output_size: int = 1,
    epochs: int = 50,
    lr: float = 1e-3,
    train_fract: float = 0.8,
    dropout: float = 0.2,
    alpha: float = 0.9,
    gamma: float = 2.0,
    scheduler_factor: float = 0.85,
    scheduler_patience: int = 50,
    stopping_patience: int = 10,
    stopping_window: int = 5,
    min_lr: float = 0.0,
    decay_ih: float = 1e-5,
    decay_hh: float = 1e-5,
    decay_other: float = 1e-5,
    device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu",
    batch_size: int = 32,
):
    """Ingest the data, train a ``GRUModel`` inside an MLflow run and save it.

    Args:
        company_path: Firm-level data file passed to ``IngestionPipeline``.
        macro_paths: Macro data files passed to ``IngestionPipeline``.
        bankruptcy_col: Name of the bankruptcy column in the firm data.
        company_col: Name of the company identifier column.
        revenue_cap: Revenue cap passed to ``IngestionPipeline``.
        metrics: torchmetrics collection evaluated by ``train_gru``.
        seed: NOTE(review): currently unused — the split seed comes from
            ``stratified_split`` and is what gets logged; confirm intent.
        num_layers, hidden_size, output_size, dropout: ``GRUModel`` settings.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        train_fract: Fraction of samples assigned to the training split.
        alpha, gamma: Only used by the commented-out ``FocalLoss`` alternative.
        scheduler_factor, scheduler_patience, min_lr: ``ReduceLROnPlateau``
            settings.
        stopping_patience, stopping_window: Early-stopping settings forwarded
            to ``train_gru``.
        decay_ih, decay_hh, decay_other: Weight decay per parameter group
            (see ``_weight_decay_groups``).
        device: Torch device name used for model, data and metrics.
        batch_size: Batch size of both DataLoaders.

    Returns:
        The trained model (after ``train_gru`` returns).
    """
    logger = logging.getLogger(__name__)

    ingestion = IngestionPipeline(
        company_path=company_path,
        macro_paths=macro_paths,
        company_col=company_col,
        bankruptcy_col=bankruptcy_col,
        revenue_cap=revenue_cap,
    )
    ingestion.run()
    X, M_past, M_future, y = ingestion.get_tensors()

    dataset = DualInputSequenceDataset(
        firm_tensor=X,
        macro_past_tensor=M_past,
        macro_future_tensor=M_future,
        labels=y,
    )

    # ``stratified_split`` returns the seed it used; keep it in its own name
    # instead of overwriting the ``seed`` argument.
    train_ds, val_ds, split_seed = dataset.stratified_split(train_fract)

    logger.info("Device: %s", device)
    metrics.to(device)

    # Move the datasets BEFORE building the loaders: if ``to_device`` returns
    # a new object, loaders built earlier would keep iterating the old ones.
    train_ds = train_ds.to_device(device)
    val_ds = val_ds.to_device(device)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_with_macro)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_with_macro)

    pos_weight = dataset.pos_weight()

    firm_dims, macro_dims, _ = dataset.input_dims()
    firm_input_size = firm_dims[-1]
    # NOTE(review): the macro feature count is taken from the second-to-last
    # dimension, unlike the firm one — confirm against the dataset layout.
    macro_input_size = macro_dims[-2]

    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    mlflow.set_experiment("bankruptcy-predictions")

    with mlflow.start_run():
        mlflow.set_tag("model_type", "gru")
        mlflow.log_param("seed", split_seed)

        model = GRUModel(
            firm_input_size=firm_input_size,
            macro_input_size=macro_input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            num_layers=num_layers,
            dropout=dropout,
        ).to(device)

        # Alternative loss kept for experimentation (the only user of alpha/gamma):
        # loss_fn = FocalLoss(alpha=alpha, gamma=gamma, reduction="mean")
        # ``.to(device)`` moves the pos_weight buffer next to the model outputs.
        loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight).to(device)

        mlflow.log_params({
            "hidden_size": hidden_size,
            "output_size": output_size,
            "num_layers": num_layers,
            "dropout": dropout,
            "lr": lr,
            "epochs": epochs,
            "batch_size": batch_size,
            "train_fract": train_fract,
            "scheduler_factor": scheduler_factor,
            "scheduler_patience": scheduler_patience,
            "min_lr": min_lr,
            "stopping_patience": stopping_patience,
            "stopping_window": stopping_window,
            "decay_ih": decay_ih,
            "decay_hh": decay_hh,
            "decay_other": decay_other,
        })

        optimizer = Adam(_weight_decay_groups(model, decay_ih, decay_hh, decay_other), lr=lr)
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode="min",
            factor=scheduler_factor,
            patience=scheduler_patience,
            min_lr=min_lr,
        )

        train_gru(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            stopping_patience=stopping_patience,
            stopping_window=stopping_window,
            device=device,
            epochs=epochs,
            metrics=metrics,
        )

        model_name = f"GRUModel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        mlflow.pytorch.log_model(model, model_name)
        models_dir = Path("../models")
        models_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create dirs
        torch.save(obj=model.state_dict(), f=models_dir / f"{model_name}.pth")
        print(f"Model saved: {model_name}")

    return model

In [2]:
# Load the GRU training config and unpack it into individual notebook
# variables for interactive experimentation.
config_dict = load_yaml_file("../config/gru_config.yml")
cfg = TrainConfig(**config_dict)

company_data_path = Path("..") / cfg.firm_data
macro_data_path = [str(macro_path) for macro_path in cfg.macro_data]
bankruptcy_col = str(cfg.bankruptcy_col)
company_col = str(cfg.company_col)
revenue_cap = int(cfg.revenue_cap)
metrics = cfg.get_metrics().to(cfg.device)
device = str(cfg.device)
# NOTE(review): the layer count is read from ``cfg.num_classes`` — confirm.
num_layers = int(cfg.num_classes)
# NOTE(review): hard-coded, whereas train_model_from_config uses
# cfg.hidden_size — confirm this override is intentional.
hidden_size = 16
output_size = 1
epochs = int(cfg.epochs)
lr = float(cfg.lr)
train_fract = float(cfg.train_fract)
# Dropout is a probability; int() would truncate e.g. 0.2 to 0.
dropout = float(cfg.dropout)
scheduler_factor = float(cfg.scheduler_factor)
scheduler_patience = int(cfg.scheduler_patience)
decay_ih = float(cfg.decay_ih)
decay_hh = float(cfg.decay_hh)
decay_other = float(cfg.decay_other)
seed = int(cfg.seed)

ingestion = IngestionPipeline(
    company_path=company_data_path,
    macro_paths=macro_data_path,
    company_col=company_col,
    bankruptcy_col=bankruptcy_col,
    revenue_cap=revenue_cap,
)

In [3]:
import gc
import time
import traceback

# Train repeatedly until interrupted. A failed run is logged with its traceback
# and followed by a short pause before the next attempt.
while True:
    model = None
    try:
        model = train_model_from_config(cfg)
    except Exception:
        # ``except Exception`` rather than a bare ``except:`` so that
        # KeyboardInterrupt (a BaseException) stops the loop instead of being
        # logged as a failed run and retried.
        logging.exception("Training failed.")
        time.sleep(3)
    finally:
        # Free the model and cached accelerator memory after every run,
        # successful or not; only touch the backend that actually exists.
        del model
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
    

INFO:src.data.loaders:Reading file: ../data/4941.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
ERROR:root:Training failed.
Traceback (most recent call last):
  File "/Users/guillaumedecina-halmi/Documents/202504 Bankruptcy prediction on restaurants/src/data/loaders.py", line 92, in _clean
    df[self.bankruptcy_col] = pd.to_datetime(
                              ^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/core/tools/datetimes.py", line 1067, in to_datetime
    values = convert_listlike(arg._values, format)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/core/tools/datetimes.py", line 407, in _convert_listlike_datetimes
    return _to_datetime_with_unit(arg, unit, name, utc, errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/co

Past: torch.Size([3, 1]), Future: torch.Size([3, 1])
Data sent to device: mps
Data sent to device: mps


Epoch 23/50 | Loss: 0.22356 | ACCURACY: 0.99571 | AUROC: 0.99852 | F1: 0.81818 | MATTHEWS: 0.81601 | LOSS: 0.22356 | LR: 0.00042:  46%|████▌     | 23/50 [00:32<00:38,  1.42s/it]

Early stopping at epoch 23, restoring model from epoch 22



INFO:src.data.loaders:Reading file: ../data/4941.xlsx


Model saved: GRUModel_20250804_194519
🏃 View run fun-horse-204 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/0f1b954130ae4dbbaa2321edd2c31a40
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


INFO:src.data.loaders:Dropping high-revenue outliers...
ERROR:root:Training failed.
Traceback (most recent call last):
  File "/Users/guillaumedecina-halmi/Documents/202504 Bankruptcy prediction on restaurants/src/data/loaders.py", line 92, in _clean
    df[self.bankruptcy_col] = pd.to_datetime(
                              ^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/core/tools/datetimes.py", line 1067, in to_datetime
    values = convert_listlike(arg._values, format)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/core/tools/datetimes.py", line 407, in _convert_listlike_datetimes
    return _to_datetime_with_unit(arg, unit, name, utc, errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/miniforge3/lib/python3.12/site-packages/pandas/core/tools/datetimes.py", line 512, in _to_datetime_with

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import datetime
import logging
import mlflow
import yaml

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from tqdm import tqdm
from pathlib import Path
from src.data.pipeline import IngestionPipeline
from src.datasets.dual_input import DualInputSequenceDataset
from src.models.tft import TFTModel
from src.utils.utils import CustomReduceLROnPlateau, TrainConfig, collate_with_macro

def load_yaml_file(path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed object.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (e.g. a dict
        for a mapping document, ``None`` for an empty file).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file content is not valid YAML.
    """
    with open(path, encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            # The exception must be raised: previously it was only constructed,
            # so a malformed file silently returned None and failed later in a
            # less obvious place (e.g. ``TrainConfig(**None)``).
            raise ValueError(f"Config file could not be loaded: {e}") from e


def _tft_step(model, batch, device):
    """Run ``model`` on one collated batch and return ``(y_hat, y)`` on ``device``."""
    firm_x = batch["firm_seq"].to(device)
    macro_x = batch["macro_seq"].to(device)
    decoder_x = macro_x[:, -12:, :]  # Condition on the last year of macro data
    y = batch["label"].to(device)
    return model(firm_x, macro_x, decoder_x), y


def _tft_update_metrics(metrics, criterion, y_hat, y):
    """Accumulate one batch of (detached) predictions and labels into ``metrics``."""
    preds = y_hat.detach()
    if isinstance(criterion, nn.BCEWithLogitsLoss):
        # With this loss the model emits logits; hand the metrics probabilities.
        preds = torch.sigmoid(preds)
    # Binary torchmetrics expect integer targets.
    # NOTE(review): assumes 0/1 labels with the same shape as y_hat — confirm.
    metrics.update(preds, y.long())


def _tft_log_metrics(metrics, prefix, loss, step):
    """Compute ``metrics``, log them plus ``loss`` to MLflow, reset, and return the values."""
    values = {name: metric.compute().item() for name, metric in metrics.items()}
    values["loss"] = loss
    for name, value in values.items():
        mlflow.log_metric(f"{prefix}_{name}", value, step=step)
    metrics.reset()  # start the next phase/epoch from empty metric state
    return values


def _mean_loss(losses, loader_name):
    """Return the mean of ``losses``; raise if the loader produced no batches."""
    if not losses:
        raise ValueError(f"{loader_name} yielded no batches")
    return sum(losses) / len(losses)


def train_tft(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    metrics: MetricCollection,
    criterion: "nn.Module | None" = None,
    optimizer_cls=torch.optim.Adam,
    lr: float = 1e-3,
    device: str = "cuda" if torch.cuda.is_available() else "mps"
        if torch.mps.is_available() else "cpu",
    early_stopping_patience: int = 10,
    scheduler_cls=None,
    scheduler_kwargs=None,
    log_fn=None,
):
    """Train a TFT-style model with early stopping and MLflow metric logging.

    Each epoch runs a training pass and a validation pass; per-phase metrics
    and mean losses are logged to MLflow as ``train_*`` / ``val_*``. The
    weights with the lowest validation loss are written to
    ``best_tft_model.pt`` and restored before the final model is logged and
    saved under ``models/``.

    Args:
        model: Module called as ``model(firm_x, macro_x, decoder_x)``.
        train_loader, val_loader: Loaders yielding dicts with ``firm_seq``,
            ``macro_seq`` and ``label`` tensors.
        num_epochs: Maximum number of epochs.
        metrics: torchmetrics collection, updated every batch and reset after
            each phase.
        criterion: Loss module; defaults to a fresh ``nn.MSELoss()``.
        optimizer_cls: Optimizer class, constructed with ``model.parameters()``
            and ``lr``.
        lr: Learning rate.
        device: Torch device name.
        early_stopping_patience: Epochs without validation improvement before
            stopping.
        scheduler_cls: Optional scheduler class; ``step(val_loss)`` is called
            each epoch.
        scheduler_kwargs: Keyword arguments for ``scheduler_cls`` (may be None).
        log_fn: Optional callback ``log_fn(epoch=, train_loss=, val_loss=)``.

    Returns:
        ``(model, history)`` where history holds per-epoch mean losses.

    Raises:
        ValueError: If a loader yields no batches.
    """
    if criterion is None:
        criterion = nn.MSELoss()  # fresh instance instead of a shared default
    model = model.to(device)
    if isinstance(criterion, nn.Module):
        criterion = criterion.to(device)  # e.g. moves a pos_weight buffer
    metrics = metrics.to(device)

    optimizer = optimizer_cls(model.parameters(), lr=lr)
    scheduler = (
        scheduler_cls(optimizer, **(scheduler_kwargs or {}))
        if scheduler_cls is not None
        else None
    )

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} / {num_epochs} — Training"):
            optimizer.zero_grad()
            y_hat, y = _tft_step(model, batch, device)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            _tft_update_metrics(metrics, criterion, y_hat, y)

        train_loss = _mean_loss(train_losses, "train_loader")
        _tft_log_metrics(metrics, "train", train_loss, epoch)
        history["train_loss"].append(train_loss)

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} / {num_epochs} — Validation"):
                # Labels come from the validation batch itself; previously the
                # last training batch's ``y`` was reused here.
                y_hat, y = _tft_step(model, batch, device)
                val_losses.append(criterion(y_hat, y).item())
                _tft_update_metrics(metrics, criterion, y_hat, y)

        val_loss = _mean_loss(val_losses, "val_loader")
        _tft_log_metrics(metrics, "val", val_loss, epoch)
        history["val_loss"].append(val_loss)

        if log_fn:
            log_fn(epoch=epoch, train_loss=train_loss, val_loss=val_loss)

        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), "best_tft_model.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered.")
                break

        if scheduler is not None:
            scheduler.step(val_loss)

    if best_state is not None:
        # Log/return the best-validation weights, not the last epoch's.
        model.load_state_dict(best_state)

    # Filesystem-safe timestamp (the raw datetime contains spaces and colons).
    model_name = f"TFTModel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    mlflow.pytorch.log_model(model, model_name)
    models_dir = Path("models")
    models_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create dirs
    torch.save(obj=model.state_dict(), f=models_dir / f"{model_name}.pth")
    print(f"Model saved: {model_name}")

    return model, history


def main():
    """Train a TFT bankruptcy model configured by ``config/model_config.yml``.

    Loads the config (paths are relative to the current working directory),
    runs the ingestion pipeline, builds a stratified train/validation split
    and trains ``TFTModel`` with ``train_tft`` inside an MLflow run.
    """
    config_dict = load_yaml_file("config/model_config.yml")
    cfg = TrainConfig(**config_dict)

    company_path = Path(cfg.firm_data)
    macro_paths = [Path(path) for path in cfg.macro_data]
    bankruptcy_col = str(cfg.bankruptcy_col)
    company_col = str(cfg.company_col)
    revenue_cap = int(cfg.revenue_cap)
    metrics = cfg.get_metrics().to(cfg.device)
    device = str(cfg.device)
    # NOTE(review): only logged (TFTModel takes no layer count); read from
    # ``cfg.num_classes`` as elsewhere in this notebook — confirm.
    num_layers = int(cfg.num_classes)
    hidden_size = int(cfg.hidden_size)
    output_size = 1
    epochs = int(cfg.epochs)
    lr = float(cfg.lr)
    train_fract = float(cfg.train_fract)
    # Dropout is a probability; int() truncated e.g. 0.2 to 0.
    dropout = float(cfg.dropout)
    scheduler_factor = float(cfg.scheduler_factor)
    scheduler_patience = int(cfg.scheduler_patience)

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)

    ingestion = IngestionPipeline(
        company_path=company_path,
        macro_paths=macro_paths,
        company_col=company_col,
        bankruptcy_col=bankruptcy_col,
        revenue_cap=revenue_cap,
    )
    ingestion.run()

    # Same unpacking and dataset construction as the GRU ``train`` function:
    # ``get_tensors`` yields firm, past-macro, future-macro and label tensors,
    # so the old 3-way unpacking raised "too many values to unpack".
    X, M_past, M_future, y = ingestion.get_tensors()
    dataset = DualInputSequenceDataset(
        firm_tensor=X,
        macro_past_tensor=M_past,
        macro_future_tensor=M_future,
        labels=y,
    )

    # ``stratified_split`` returns the seed it used for the split.
    train_ds, val_ds, split_seed = dataset.stratified_split(train_fract)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_with_macro)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=collate_with_macro)

    logger.info("Device: %s", device)

    firm_dims, macro_dims, _ = dataset.input_dims()
    firm_input_size = firm_dims[-1]
    # NOTE(review): mirrors the GRU ``train`` indexing — confirm for TFT, and
    # that the collated batches still expose a ``macro_seq`` key.
    macro_input_size = macro_dims[-2]

    optimizer_cls = Adam
    scheduler_cls = CustomReduceLROnPlateau
    scheduler_kwargs = {
        "mode": "min",
        "factor": scheduler_factor,
        "patience": scheduler_patience,
        "min_lr": 0.0,
    }

    pos_weight = dataset.pos_weight()
    loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)

    mlflow.set_tracking_uri("http://127.0.0.1:8080")
    mlflow.set_experiment("bankruptcy-predictions")

    with mlflow.start_run():
        # Logging hyperparameters
        mlflow.log_param("hidden_size", hidden_size)
        mlflow.log_param("output_size", output_size)
        mlflow.log_param("num_layers", num_layers)
        mlflow.log_param("dropout", dropout)
        mlflow.log_param("lr", lr)
        mlflow.log_param("seed", split_seed)

        model = TFTModel(
            static_input_dim=firm_input_size,
            encoder_input_dims=[1] * macro_input_size,
            decoder_input_dims=[1] * macro_input_size,
            hidden_dim=hidden_size,
            attention_heads=4,
            dropout=dropout,
        )

        train_tft(
            model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=epochs,
            criterion=loss_fn,
            optimizer_cls=optimizer_cls,
            lr=lr,
            # Use the configured device (metrics were moved to it above)
            # rather than train_tft's auto-detected default.
            device=device,
            scheduler_cls=scheduler_cls,
            scheduler_kwargs=scheduler_kwargs,
            metrics=metrics,
        )

In [None]:
# Run TFT training. In a notebook ``__name__`` is "__main__", so this executes;
# importing this code as a module no longer starts a training run.
if __name__ == "__main__":
    main()

In [None]:
# Close the currently active MLflow run, if any (e.g. one left open by an
# interrupted cell), so the next ``mlflow.start_run()`` starts a fresh run.
mlflow.end_run()