# Setup

In [1]:
"""
Description:
------------
This script initializes a full PyTorch training environment with:

- Modular imports for data, logging, MLflow, ONNX, and more
- Logger setup via Loguru with structured file logging
- MLflow experiment initialization and run lifecycle management
- Device detection (CUDA or CPU) and runtime settings
- Optional support for ONNX runtime inference
- Color-coded console output (colorama) for easier debugging
- Preloaded `requirements.txt` for reproducibility tracking

This script is intended to be imported into downstream training or inference modules.
"""

# --- Basic Utilities ---
import datetime as dt
import math
import os
import re
import sys
import time
import warnings
import json
import functools
from pathlib import Path

# --- Data Manipulation ---
import numpy as np
import pandas as pd

# --- Validation ---
from pydantic import BaseModel, Field, ValidationError

# --- Visualization ---
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Latex, Math

# --- Console Formatting ---
from colorama import Fore, Style

# --- Logging Framework ---
from loguru import logger
import uuid

# --- MLflow for Experiment Tracking ---
import mlflow
import mlflow.pytorch
import mlflow.sklearn
from mlflow.data import from_pandas
from mlflow.models.signature import infer_signature

# --- ONNX Runtime (Optional) ---
import onnx
import onnxruntime as ort

# Add the ONNX code below to the FastAPI app for deployment
# @app.on_event("startup")
# def load_model():
#     global session
#     session = ort.InferenceSession(
#         "model.onnx",
#         providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
#     )


# --- Scientific Computation ---
import scipy as sp
import sympy as sym
from plotly.data import experiment

# --- Scikit-Learn ---
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

from sklearn.model_selection import (
    train_test_split, GridSearchCV, StratifiedKFold, cross_val_score
)

from sklearn.preprocessing import (
    StandardScaler, MinMaxScaler, RobustScaler,
    OneHotEncoder, OrdinalEncoder
)

from sklearn.linear_model import (
    LogisticRegression, ElasticNet, Ridge, Lasso
)

from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    classification_report, accuracy_score,
    mean_squared_error, mean_absolute_error, r2_score
)

# --- PyTorch Core ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision import datasets, transforms

# --- Joblib for Model Serialization ---
import joblib

# --- Optional: Config Frameworks (commented) ---
# from omegaconf import OmegaConf
# import hydra

# --- Secrets Management ---
# Make the shared security helpers importable. The membership check keeps
# sys.path from accumulating duplicate entries when this cell is re-executed.
_security_path = "/mnt/git/github/gabemcwilliams/common-components/security"
if _security_path not in sys.path:
    sys.path.append(_security_path)
# NOTE(review): the star import hides which vault_mgr names this notebook
# relies on; consider importing them explicitly.
from vault_mgr import *

# Explicitly leave MLflow's "insecure TLS" switch off.
# NOTE(review): presumably this keeps certificate verification enabled for the
# tracking server connection -- confirm against the MLflow docs.
os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "false"

# --- PyTorch Runtime Setup ---
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
print(f"[INFO] PyTorch version: {torch.__version__}")
# NOTE(review): thread count is hard-coded; on hosts with fewer than 32 cores
# this may oversubscribe the CPU -- confirm it matches the target machine.
torch.set_num_threads(32)

  from .autonotebook import tqdm as notebook_tqdm


[INFO] Using device: cuda
[INFO] PyTorch version: 2.6.0+cu124




## Logger

In [8]:
# --- MLflow Experiment Setup ---
experiment_name = "torch_custom_datasets"
mlflow.set_experiment(experiment_name)

# --- Store requirements.txt as a list of lines for later logging reference ---
with open('/mnt/git/requirements.txt', encoding="utf-8") as f:
    requirements = f.read().splitlines()

# --- Init MLflow Run ---

# End any run that is still open (e.g. left over from a previous execution of
# this cell) before starting a new one.
if mlflow.active_run():
    mlflow.end_run()

run = mlflow.start_run()

# --- Logging Initialization ---

# Create log directory for the experiment
log_dir = Path("/mnt/mls/logs") / experiment_name
log_dir.mkdir(parents=True, exist_ok=True)

# Define log file path
log_file = log_dir / f"train_{experiment_name}.log"
# UTC time at which this setup ran; captured once, not per log record.
timestamp = dt.datetime.now(dt.timezone.utc).strftime("%Y_%m_%d_%H%M%S")
run_id = run.info.run_id

# Remove all previously registered handlers so re-running the cell does not
# stack duplicate sinks.
logger.remove()

# Attach run_id and timestamp to every record emitted through this logger.
logger = logger.bind(run_id=run_id, timestamp=timestamp)

# File logging with daily rotation and compression
logger.add(
    str(log_file),  # Path to the log file (sink)

    format="{extra[timestamp]} | {level} | {extra[run_id]} | {name}:{function}:{line} | {message}",
    # Format of each log line:
    # {extra[timestamp]} - Setup time bound above (YYYY_MM_DD_HHMMSS, UTC).
    #                      NOTE(review): this value is identical on every line;
    #                      it is NOT the time the record was emitted. Use
    #                      loguru's {time} field if per-record time is wanted.
    # {level}            - Log level (INFO, DEBUG, etc.)
    # {extra[run_id]}    - MLflow run id bound above
    # {name}             - Module name (e.g., 'helper' from 'helper.py')
    # {function}         - Name of the function that emitted the log
    # {line}             - Line number where logger was called
    # {message}          - The actual message logged

    level="DEBUG",  # Minimum log level for this handler (includes INFO, WARNING, ERROR, etc.)
    rotation="00:00",  # Rotate the log file daily at midnight
    retention="14 days",  # Keep rotated log files for 14 days, then delete them automatically
    compression="zip",  # Compress rotated log files using ZIP format to save space
    enqueue=True,  # Route records through a queue so logging from multiple threads/processes is safe
    backtrace=True,  # Show full stack trace (even outside of the except block) if an error occurs
    diagnose=False,  # Disable automatic introspection of variables in the traceback (safer for production)
    mode="a",  # Open the log file in append mode (ensures logs aren't overwritten on script restart)
    # Accept only records bound to this run. Records without a run_id are
    # dropped; they would also lack the extra fields the format string needs.
    filter=lambda record: record["extra"].get("run_id") == run_id
)


def log_time(func):
    """Decorator that logs how long each successful call to *func* takes.

    The wrapped callable receives the caller's arguments unchanged and its
    return value is passed through. After the call returns, an INFO record of
    the form ``"<name> took <seconds> seconds"`` is emitted through the
    module-level ``logger``. If *func* raises, the exception propagates and
    nothing is logged.

    Args:
        func: The callable to time.

    Returns:
        A wrapper around *func* that carries its metadata (via
        ``functools.wraps``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump if the system time is adjusted, which
        # would distort (or even negate) the measured duration.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        logger.info(f"{func.__name__} took {elapsed:.4f} seconds")
        return result

    return wrapper


# Emit an initial INFO record marking that logger setup has completed.
logger.info("Logger initialized (mlflow.start_run()).")