In [1]:
import os
import mlflow
import torch
import random
import numpy as np
from tqdm import tqdm
import sys

sys.path.append('./')
from src.utils.data_utils import read_dataset
from src.trainers.AETrainer import AETrainer

# Location of the input data and the train/test CSV file names
# (paths are relative to the notebook's working directory).
DATA_PATH = './data'
FILE_NAME_TRAIN = 'BP_safety_network_master_NN_train.csv'
FILE_NAME_TEST = 'BP_safety_network_master_NN_test.csv'
# Seed for the Python, NumPy and PyTorch random generators.
RANDOM_SEED = 42

# MLflow tracking server. `mlflow.set_tracking_uri` overrides the standard
# MLFLOW_TRACKING_URI environment variable, so honour that variable explicitly
# and fall back to the local development server otherwise.
MLFLOW_TRACKING_URI = os.environ.get('MLFLOW_TRACKING_URI', 'http://127.0.0.1:8080')
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [2]:
# Make '01_DAE' the active MLflow experiment for the runs started below.
experiment = mlflow.set_experiment(experiment_name='01_DAE')

In [3]:
# Number of training epochs passed to AETrainer.train (also logged to MLflow).
NUM_EPOCHS = 1500

with mlflow.start_run(log_system_metrics=True) as run:
    # Seed every random generator used here so repeated runs are comparable.
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)  # all visible GPUs, not just the current one
    # Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Define PyTorch device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Read the train and test datasets onto `device`. The third value returned by
    # read_dataset is used as the model's input dimension.
    X_train, y_train, non_accident_dim = read_dataset(DATA_PATH, FILE_NAME_TRAIN, targets='rs_crashes_2324', device=device)
    X_test, y_test, test_non_accident_dim = read_dataset(DATA_PATH, FILE_NAME_TEST, targets='rs_crashes_2324', device=device)
    # The model is sized from the training data, so the test data must match it.
    assert non_accident_dim == test_non_accident_dim, (
        f'Train/test input dimension mismatch: {non_accident_dim} != {test_non_accident_dim}'
    )

    # Model hyper-parameters: passed verbatim to AETrainer.
    params = {
        'inp_dim': non_accident_dim,
        'noise_factor': 0.5,
        'enc_dim': 4,
        'learning_rate': 1e-2,
        'weight_decay': 1e-8
    }
    # Log the hyper-parameters together with the run settings needed to reproduce
    # this run; a new dict is built, so `params` itself is left untouched.
    mlflow.log_params({**params, 'num_epochs': NUM_EPOCHS, 'random_seed': RANDOM_SEED})

    # Define, train and evaluate model
    trainer = AETrainer(**params)
    # NOTE(review): the test set is also passed to `train` as its second data pair
    # (presumably validation data). If AETrainer uses it for model selection or early
    # stopping, the 'test' metrics below are optimistically biased — consider carving
    # a validation split out of the training data instead.
    trainer.train(X_train, y_train, X_test, y_test, NUM_EPOCHS)
    trainer.evaluate(X_test, y_test, 'test')

2025/10/02 12:43:02 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
100%|██████████████████████████████████████████████████████████████████████████████| 1500/1500 [01:33<00:00, 16.08it/s]
2025/10/02 12:44:59 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/10/02 12:44:59 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


🏃 View run popular-slug-721 at: http://127.0.0.1:8080/#/experiments/775881003945806667/runs/e5a5b25fc25f4435bd9177664c21ff9a
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/775881003945806667
