Connected to base (Python 3.12.9)

In [1]:
import sys
from pathlib import Path

ROOT = Path.cwd().parent
sys.path.append(str(ROOT))

from src.data.pipeline import IngestionPipeline
from src.datasets.dual_input import DualInputSequenceDataset
from src.models.tft import TFTModel
from src.utils.utils import TrainConfig

In [2]:
import torch
import mlflow
import datetime
import logging
import yaml

from dataclasses import dataclass, field
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryMatthewsCorrCoef,
    MulticlassAccuracy, MulticlassAUROC, MulticlassF1Score)
from pathlib import Path

from src.datasets.dual_input import DualInputSequenceDataset
from src.models.gru import GRUModel
from src.data.pipeline import IngestionPipeline
from src.train import train_model
from src.utils.utils import CustomReduceLROnPlateau, collate_with_macro, TrainConfig, FocalLoss

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

def load_yaml_file(path):
    with open(path) as stream:
        try:
            config_dict=yaml.safe_load(stream)
            return config_dict
        except yaml.YAMLError as e:
            TypeError(f"Config file could not be loaded: {e}")

In [3]:
config_dict = load_yaml_file("../config/model_config.yml")
cfg = TrainConfig(**config_dict)

company_data_path = Path("../" + cfg.firm_data)
macro_data_path = [str(id) for id in cfg.macro_data]
bankruptcy_col = str(cfg.bankruptcy_col)
company_col=str(cfg.company_col)
revenue_cap=int(cfg.revenue_cap)
metrics=cfg.get_metrics().to(cfg.device)
device=str(cfg.device)
num_layers=int(cfg.num_classes)
hidden_size=16
output_size=1
epochs=int(cfg.epochs)
lr=float(cfg.lr)
train_fract=float(cfg.train_fract)
dropout=int(cfg.dropout)
scheduler_factor=float(cfg.scheduler_factor)
scheduler_patience=int(cfg.scheduler_patience)
decay_ih=float(cfg.decay_ih)
decay_hh=float(cfg.decay_hh)
decay_other=float(cfg.decay_other)
seed=int(cfg.seed)

ingestion = IngestionPipeline(
    company_path=company_data_path,
    macro_paths=macro_data_path,
    company_col=company_col,
    bankruptcy_col=bankruptcy_col,
    revenue_cap=revenue_cap
)

In [4]:
ingestion.run()

INFO:src.data.loaders:Reading file: ../data/demo_data.xlsx
INFO:src.data.loaders:Dropping high-revenue outliers...
INFO:src.data.loaders:Loading 3 macroeconomic series...
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/010774417
INFO:sdmx.client:with headers {'User-Agent': 'python-requests/2.32.3', 'Accept-Encoding': 'gzip, deflate, br, zstd', 'Accept': 'application/vnd.sdmx.genericdata+xml;version=2.1', 'Connection': 'keep-alive'}
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/001763782
INFO:sdmx.client:with headers {'User-Agent': 'python-requests/2.32.3', 'Accept-Encoding': 'gzip, deflate, br, zstd', 'Accept': 'application/vnd.sdmx.genericdata+xml;version=2.1', 'Connection': 'keep-alive'}
  getattr(self, f"handle_{query_type}")()
INFO:sdmx.client:Request https://www.bdm.insee.fr/series/sdmx/data/SERIES_BDM/001587668
INFO:sdmx.client:with hea

In [5]:
X, M, y = ingestion.get_tensors()

INFO:src.data.tensor_factory:Converting financial series to tensors...
INFO:src.data.tensor_factory:Scaling financial data with RobustScaler...
INFO:src.data.tensor_factory:Shaped financial data tensor: (6296, 3, 4)
INFO:src.data.tensor_factory:Shaped macro data tensor: torch.Size([3, 36])


In [6]:
dataset = DualInputSequenceDataset(X, M, y)

In [7]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size = 64, shuffle=True)

In [8]:
from src.train import train_one_epoch

In [9]:
X.shape

torch.Size([6296, 3, 4])

In [10]:
M.shape

torch.Size([3, 36])

In [11]:
tft = TFTModel(static_input_dim=0, company_input_dim=X.shape[-1], macro_input_dim=M.shape[-1], decoder_input_dim=8)

In [12]:
X.shape

torch.Size([6296, 3, 4])

In [16]:
static_inputs = torch.zeros((X.shape[0], 3, 8), device=X.device)
logits, weights = tft.forward(X, M, decoder_inputs=static_inputs, static_inputs=static_inputs)

In [18]:
logits.mean()

tensor(-0.0471, grad_fn=<MeanBackward0>)