Connected to base (Python 3.12.9)

In [None]:
import sys
from pathlib import Path

# Make the project root (the parent of the current working directory) importable
# so that the `src.*` imports in this notebook can be resolved.
ROOT = Path.cwd().parent
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from src.data.pipeline import IngestionPipeline
from src.datasets.dual_input import DualInputSequenceDataset
from src.models.tft import TFTModel
from src.utils.utils import TrainConfig

In [None]:
import torch
import mlflow
import datetime
import logging
import yaml

from dataclasses import dataclass, field
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryMatthewsCorrCoef,
    MulticlassAccuracy, MulticlassAUROC, MulticlassF1Score)
from pathlib import Path

from src.datasets.dual_input import DualInputSequenceDataset
from src.models.gru import GRUModel
from src.data.pipeline import IngestionPipeline
from src.train.gru import train_model
from src.utils.utils import CustomReduceLROnPlateau, collate_with_macro, TrainConfig, FocalLoss

# Configure logging at INFO level; each record is rendered as
# "timestamp | level | logger name | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

def load_yaml_file(path):
    """Load a YAML file and return its parsed content.

    Args:
        path: Location of the YAML file; passed straight to ``open``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.

    Raises:
        TypeError: If the content is not valid YAML. The parser error is
            chained as ``__cause__``.
    """
    with open(path) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            # The exception used to be constructed but never raised, so a
            # broken config silently produced ``None`` for the caller.
            raise TypeError(f"Config file could not be loaded: {e}") from e

In [None]:
# Load the model configuration and unpack it into explicitly typed variables.
config_dict = load_yaml_file("../config/model_config.yml")
cfg = TrainConfig(**config_dict)

company_data_path = Path("../" + cfg.firm_data)
# `series_id` instead of `id`, which would shadow the builtin.
macro_data_path = [str(series_id) for series_id in cfg.macro_data]
bankruptcy_col = str(cfg.bankruptcy_col)
company_col = str(cfg.company_col)
revenue_cap = int(cfg.revenue_cap)
metrics = cfg.get_metrics().to(cfg.device)
device = str(cfg.device)
# NOTE(review): the layer count is read from `cfg.num_classes` — confirm this
# is intended and not a stand-in for a dedicated layer-count setting.
num_layers = int(cfg.num_classes)
# Hard-coded here rather than read from the config.
hidden_size = 16
output_size = 1
epochs = int(cfg.epochs)
lr = float(cfg.lr)
train_fract = float(cfg.train_fract)
# float(), not int(): int() truncates a fractional rate such as 0.2 down to 0.
dropout = float(cfg.dropout)
scheduler_factor = float(cfg.scheduler_factor)
scheduler_patience = int(cfg.scheduler_patience)
decay_ih = float(cfg.decay_ih)
decay_hh = float(cfg.decay_hh)
decay_other = float(cfg.decay_other)
seed = int(cfg.seed)

# Build the ingestion pipeline from the firm file and the macro series list.
ingestion = IngestionPipeline(
    company_path=company_data_path,
    macro_paths=macro_data_path,
    company_col=company_col,
    bankruptcy_col=bankruptcy_col,
    revenue_cap=revenue_cap,
)

In [None]:
# Execute the ingestion pipeline built above; its tensors are retrieved
# afterwards via `ingestion.get_tensors()`.
ingestion.run()

In [None]:
# NOTE(review): three tensors are unpacked here, whereas `train()` further down
# unpacks four (X, M_past, M_future, y) from the same call — confirm which
# form matches the current `get_tensors()` return value.
X, M, y = ingestion.get_tensors()

In [None]:
# NOTE(review): positional three-argument construction; `train()` further down
# builds the same class with four keyword arguments (firm_tensor,
# macro_past_tensor, macro_future_tensor, labels) — confirm this call is still valid.
dataset = DualInputSequenceDataset(X, M, y)

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of 64 samples drawn from the dataset built above.
loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
)

In [None]:
from src.train.gru import train_one_epoch

In [None]:
# Inspect the shape of X; its last dimension is used as `company_input_dim`
# when the model is built below.
X.shape

In [None]:
# Inspect the shape of M; its last dimension is used as `macro_input_dim`
# when the model is built below.
M.shape

In [None]:
# Instantiate the TFT; the company and macro feature widths come from the last
# dimension of the respective tensors, the decoder width is fixed at 8.
tft = TFTModel(
    static_input_dim=0,
    company_input_dim=X.shape[-1],
    macro_input_dim=M.shape[-1],
    decoder_input_dim=8,
)

In [None]:
# Move both input tensors onto the device read from the config (`device`).
X = X.to(device)
M = M.to(device)

In [None]:
# Move the model onto the same `device` variable the input tensors were sent to.
tft = tft.to(device)

In [None]:
# All-zero placeholder with a trailing width of 8 (the `decoder_input_dim` the
# model was built with), reused for both the decoder and the static inputs.
static_inputs = torch.zeros((X.shape[0], X.shape[1], 8), device=X.device)
# Call the module itself rather than `.forward()`, so that any registered
# forward hooks are executed as well.
logits, weights = tft(X, M, decoder_inputs=static_inputs, static_inputs=static_inputs)

In [None]:
from torch.utils.data import DataLoader

In [None]:
import sys
from pathlib import Path

# Make the project root (the parent of the current working directory) importable
# so that the `src.*` imports in this cell can be resolved.
ROOT = Path.cwd().parent
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

import torch
import mlflow
import datetime
import logging
import yaml
import gc
import time


from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torchmetrics import Metric

from pathlib import Path

from src.datasets.dual_input import DualInputSequenceDataset
from src.models.tft import TFTModel
from src.data.pipeline import IngestionPipeline
from src.train.tft import train_tft
from src.utils.utils import TrainConfig, FocalLoss

# Configure logging at INFO level; each record is rendered as
# "timestamp | level | logger name | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

def load_yaml_file(path):
    """Load a YAML file and return its parsed content.

    Args:
        path: Location of the YAML file; passed straight to ``open``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.

    Raises:
        TypeError: If the content is not valid YAML. The parser error is
            chained as ``__cause__``.
    """
    with open(path) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as e:
            # The exception used to be constructed but never raised, so a
            # broken config silently produced ``None`` for the caller.
            raise TypeError(f"Config file could not be loaded: {e}") from e
    
def train_model_from_config(cfg: TrainConfig) -> TFTModel:
    """Run ``train`` with every argument taken from a ``TrainConfig``.

    Each configuration value is coerced to the type ``train`` expects before
    being forwarded. ``output_size`` is fixed at 1.

    Args:
        cfg: Configuration object providing the data locations, column names
            and hyperparameters read below.

    Returns:
        The model returned by ``train``.
    """
    return train(
        company_path=Path("../" + cfg.firm_data),
        # `series_id` instead of `id`, which would shadow the builtin.
        macro_paths=[str(series_id) for series_id in cfg.macro_data],
        bankruptcy_col=str(cfg.bankruptcy_col),
        company_col=str(cfg.company_col),
        revenue_cap=int(cfg.revenue_cap),
        metrics=cfg.get_metrics().to(cfg.device),
        device=str(cfg.device),
        hidden_size=int(cfg.hidden_size),
        output_size=1,
        epochs=int(cfg.epochs),
        lr=float(cfg.lr),
        train_fract=float(cfg.train_fract),
        # float(), not int(): int() truncates a fractional rate such as 0.2 to 0.
        dropout=float(cfg.dropout),
        alpha=float(cfg.alpha),
        # gamma used to be filled from `cfg.alpha` (copy-paste slip). Prefer a
        # dedicated `gamma` setting; fall back to the previous value when the
        # config does not define one, so existing configs keep working.
        gamma=float(getattr(cfg, "gamma", cfg.alpha)),
        scheduler_factor=float(cfg.scheduler_factor),
        scheduler_patience=int(cfg.scheduler_patience),
        stopping_patience=int(cfg.stopping_patience),
        decay_ih=float(cfg.decay_ih),
        decay_hh=float(cfg.decay_hh),
        decay_other=float(cfg.decay_other),
        seed=int(cfg.seed),
    )

def train(
    company_path: str | Path,
    macro_paths: list[str],
    bankruptcy_col: str,
    company_col: str,
    revenue_cap: int,
    metrics: list[Metric],
    seed: int,
    num_layers: int = 2,
    hidden_size: int = 64,
    output_size: int = 1,
    epochs: int = 50,
    lr: float = 1e-3,
    train_fract: float = 0.8,
    dropout: float = 0.2,
    alpha: float = 0.9,
    gamma: float = 2.0,
    scheduler_factor: float = 0.85,
    scheduler_patience: int = 50,
    stopping_patience: int = 10,
    stopping_window: int = 5,
    min_lr: float = 0.0,
    decay_ih:float = 1e-5,
    decay_hh:float = 1e-5,
    decay_other:float = 1e-5,
    device: str="cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
) -> "TFTModel":
    """Ingest the data, train a ``TFTModel`` and record the run in MLflow.

    Steps: run the ingestion pipeline, wrap its tensors in a
    ``DualInputSequenceDataset``, split it into train/validation parts, build
    the model, loss, optimizer and LR scheduler, delegate the loop to
    ``train_tft``, then log the model to MLflow and save its ``state_dict``
    under ``../models``.

    Args:
        company_path: Location of the firm-level data file.
        macro_paths: Macro series identifiers handed to the ingestion pipeline.
        bankruptcy_col: Name of the label column.
        company_col: Name of the company identifier column.
        revenue_cap: Revenue cap handed to the ingestion pipeline.
        metrics: Metrics object; it must support ``.to(device)``.
        seed: Currently not forwarded anywhere; the seed logged to MLflow is
            the one returned by ``stratified_split``.
        num_layers: Currently unused.
        hidden_size: Passed to the model as ``hidden_dim``.
        output_size: Currently unused.
        epochs: Number of epochs passed to ``train_tft``.
        lr: Learning rate for Adam.
        train_fract: Fraction passed to ``stratified_split``.
        dropout: Dropout rate of the model; also logged to MLflow.
        alpha: Currently unused (the loss is ``BCEWithLogitsLoss``).
        gamma: Currently unused (the loss is ``BCEWithLogitsLoss``).
        scheduler_factor: ``factor`` of ``ReduceLROnPlateau``.
        scheduler_patience: ``patience`` of ``ReduceLROnPlateau``.
        stopping_patience: Passed to ``train_tft``.
        stopping_window: Passed to ``train_tft``.
        min_lr: ``min_lr`` of ``ReduceLROnPlateau``.
        decay_ih: Currently unused.
        decay_hh: Currently unused.
        decay_other: Currently unused.
        device: Device the metrics, datasets and model are moved to.

    Returns:
        The trained model.
    """
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)

    ingestion = IngestionPipeline(
        company_path=company_path,
        macro_paths=macro_paths,
        company_col=company_col,
        bankruptcy_col=bankruptcy_col,
        revenue_cap=revenue_cap
    )

    ingestion.run()
    X, M_past, M_future, y = ingestion.get_tensors()

    dataset = DualInputSequenceDataset(
        firm_tensor=X,
        macro_past_tensor=M_past,
        macro_future_tensor=M_future,
        labels=y
    )

    # The split reports the seed it used; keep it under its own name instead of
    # silently overwriting the `seed` argument.
    # NOTE(review): `seed` is never handed to the split — confirm whether
    # `stratified_split` can accept it to make runs reproducible.
    train_ds, val_ds, split_seed = dataset.stratified_split(train_fract)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    logger.info(f"Device: {device}")

    # Keep the returned object rather than relying on `.to()` working in place.
    metrics = metrics.to(device)
    # NOTE(review): the loaders above were created from the pre-move dataset
    # objects; this presumably relies on `to_device` moving data in place — confirm.
    train_ds = train_ds.to_device(device)
    val_ds = val_ds.to_device(device)

    # Only the trailing (feature) dimension of each input is needed for the model.
    firm_input_size, macro_past_input_size, macro_future_input_size = dataset.input_dims()
    firm_input_size = firm_input_size[-1]
    macro_past_input_size = macro_past_input_size[-1]
    macro_future_input_size = macro_future_input_size[-1]

    mlflow.set_tracking_uri('http://127.0.0.1:8080')
    mlflow.set_experiment('bankruptcy-predictions')

    with mlflow.start_run():
        mlflow.log_param("seed", split_seed)
        mlflow.set_tag("model_type", "tft")
        model = TFTModel(
            static_input_dim=0,
            company_input_dim=firm_input_size,
            macro_input_dim=macro_past_input_size,
            decoder_input_dim=macro_future_input_size,
            hidden_dim=hidden_size,
            attention_heads=8,
            # Use the `dropout` argument: it used to be hard-coded to 0.1 here
            # while the argument's value was what got logged to MLflow.
            dropout=dropout
        )

        model = model.to(device)
        pos_weight = dataset.pos_weight()

        loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)

        # Logging hyperparameters
        mlflow.log_param("hidden_size", hidden_size)
        mlflow.log_param("dropout", dropout)
        mlflow.log_param("lr", lr)

        optimizer = Adam(model.parameters(), lr=lr)
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode="min",
            factor=scheduler_factor,
            patience=scheduler_patience,
            min_lr=min_lr
        )

        train_tft(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            stopping_patience=stopping_patience,
            stopping_window=stopping_window,
            device=device,
            epochs=epochs,
            metrics=metrics
        )

        model_name = f"TFTModel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        mlflow.pytorch.log_model(model, model_name)
        # Create the output folder first so a finished run is not lost to a
        # missing directory at save time.
        models_dir = Path("../models")
        models_dir.mkdir(parents=True, exist_ok=True)
        torch.save(obj=model.state_dict(), f=models_dir / f"{model_name}.pth")
        print(f"Model saved: {model_name}")

    return model

In [2]:
# Load the TFT configuration and unpack it into explicitly typed variables.
config_dict = load_yaml_file("../config/tft_config.yml")
cfg = TrainConfig(**config_dict)

company_data_path = Path("../" + cfg.firm_data)
# `series_id` instead of `id`, which would shadow the builtin.
macro_data_path = [str(series_id) for series_id in cfg.macro_data]
bankruptcy_col = str(cfg.bankruptcy_col)
company_col = str(cfg.company_col)
revenue_cap = int(cfg.revenue_cap)
metrics = cfg.get_metrics().to(cfg.device)
device = str(cfg.device)
# NOTE(review): the layer count is read from `cfg.num_classes` — confirm this
# is intended and not a stand-in for a dedicated layer-count setting.
num_layers = int(cfg.num_classes)
# Hard-coded here rather than read from the config.
hidden_size = 16
output_size = 1
epochs = int(cfg.epochs)
lr = float(cfg.lr)
train_fract = float(cfg.train_fract)
# float(), not int(): int() truncates a fractional rate such as 0.2 down to 0.
dropout = float(cfg.dropout)
scheduler_factor = float(cfg.scheduler_factor)
scheduler_patience = int(cfg.scheduler_patience)
decay_ih = float(cfg.decay_ih)
decay_hh = float(cfg.decay_hh)
decay_other = float(cfg.decay_other)
seed = int(cfg.seed)

# Build the ingestion pipeline from the firm file and the macro series list.
ingestion = IngestionPipeline(
    company_path=company_data_path,
    macro_paths=macro_data_path,
    company_col=company_col,
    bankruptcy_col=bankruptcy_col,
    revenue_cap=revenue_cap,
)

In [3]:
# Train repeatedly until the kernel is interrupted; a failed run is logged and
# retried after a short pause.
while True:
    try:
        model = train_model_from_config(cfg)
        # Drop the trained model and release cached memory before the next run.
        del model
        gc.collect()
        # Only touch the MPS allocator when an MPS backend is actually present.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    except Exception:
        # `except Exception` instead of a bare `except`: KeyboardInterrupt and
        # SystemExit now propagate, so interrupting the kernel stops the loop
        # instead of being logged as a failed run and retried.
        logging.error("Training failed.", exc_info=True)
        time.sleep(3)

INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/010774417
INFO:sdmx.client:with headers {'User-Agent': 'python-requests/2.32.3', 'Accept-Encoding': 'gzip, deflate, br, zstd', 'Accept': 'application/vnd.sdmx.genericdata+xml;version=2.1', 'Connection': 'keep-alive'}
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/001763782
INFO:sdmx.client:with headers {'User-Agent': 'python-requests/2.32.3', 'Accept-Encoding': 'gzip, deflate, br, zstd', 'Accept': 'application/vnd.sdmx.genericdata+xml;version=2.1', 'Connection': 'keep-alive'}
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/001587668
INFO:sdmx.client:with hea

Past: torch.Size([3, 43]), Future: torch.Size([3, 12])
Data sent to device: mps
Data sent to device: mps


Training:   0%|          | 0/100 [00:00<?, ?it/s]
ERROR:root:Training failed.
Traceback (most recent call last):
  File "/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/ipykernel_38104/2159626803.py", line 3, in <module>
    model = train_model_from_config(cfg)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/ipykernel_38104/2536884172.py", line 46, in train_model_from_config
    return train(
           ^^^^^^
  File "/var/folders/h1/hrjhnsw55w3fh7wq8fc7_bcm0000gn/T/ipykernel_38104/2536884172.py", line 169, in train
    train_tft(
  File "/Users/guillaumedecina-halmi/Documents/202504 Bankruptcy prediction on restaurants/src/train/tft.py", line 97, in train_tft
    train_loss, train_metrics = train_one_epoch(
                                ^^^^^^^^^^^^^^^^
  File "/Users/guillaumedecina-halmi/Documents/202504 Bankruptcy prediction on restaurants/src/train/tft.py", line 27, in train_one_epoch
    preds, _ = model(firm_seq, encoder_input

🏃 View run loud-moose-367 at: http://127.0.0.1:8080/#/experiments/387584985157093548/runs/b43e4e533d834655a116e46b3c91ad8b
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/387584985157093548


KeyboardInterrupt: 