In [None]:
import torch
import pandas as pd
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_percentage_error, median_absolute_error
import scipy.stats as stats

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from core.dataset import DatasetLoader
from core.model import Model
from core.config import config
from core.logger import Logger

# Notebook-wide setup: cap how many DataFrame columns pandas renders, and pick
# the torch device (CUDA when available, CPU otherwise).
pd.options.display.max_columns = 20
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Report the selected device and the configured target columns, one per line.
environment_summary = (
    f"Computational Device: {device}",
    f"Target Variables: {config.columns.target_cols}",
)
print(*environment_summary, sep="\n")

In [None]:
# Resolve the data path and find the experiment run whose checkpoint will be
# evaluated. Runs are scanned in reverse lexicographic order of their directory
# names, and the first one containing checkpoints/best_model.pth is selected.
# NOTE(review): this picks the "latest" run only if run names sort
# chronologically -- confirm against the naming scheme used by training.
project_root = os.getcwd()
data_path = os.path.join(project_root, config.paths.train_data)
runs_dir = os.path.join(project_root, "Runs")
available_runs = sorted(os.listdir(runs_dir))

# Start from an explicit "nothing found" state so that a re-run of this cell
# can never silently reuse a selection left over from a previous execution.
latest_run = None
checkpoint_path = None

for run in reversed(available_runs):
    candidate_path = os.path.join(runs_dir, run, "checkpoints", "best_model.pth")
    if os.path.exists(candidate_path):
        latest_run = run
        checkpoint_path = candidate_path
        break

# Fail loudly here instead of with a NameError (or a bad path) further down.
if latest_run is None:
    raise FileNotFoundError(
        f"No run under {runs_dir!r} contains checkpoints/best_model.pth "
        f"(scanned {len(available_runs)} entries)."
    )

# An EMA checkpoint is optional; only its presence is recorded here.
ema_checkpoint_path = os.path.join(runs_dir, latest_run, "checkpoints", "best_model_ema.pth")
use_ema = os.path.exists(ema_checkpoint_path)

print(f"Experiment Run: {latest_run}")
print(f"Checkpoint Path: {checkpoint_path}")
print(f"EMA Model: {'Available' if use_ema else 'Not Available'}")

In [None]:
# Load the selected checkpoint onto the active device.
# NOTE: weights_only=False makes torch.load unpickle arbitrary Python objects,
# so only checkpoints from a trusted source should be loaded this way.
model_path = checkpoint_path
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Pick the weights: 'model_state_dict' wins, then 'ema_state_dict'; when neither
# key is present the loaded object itself is used as the state dict.
state_dict = next(
    (
        checkpoint[weights_key]
        for weights_key in ('model_state_dict', 'ema_state_dict')
        if weights_key in checkpoint
    ),
    checkpoint,
)

# Embedding sizes stored with the checkpoint, or None when absent.
ckpt_embedding_dims = checkpoint.get('embedding_dimensions')

In [None]:
# Build the data pipeline, passing along the embedding sizes read from the
# checkpoint, and obtain the three data loaders.
data_module = DatasetLoader(
    data_path,
    embedding_dimensions=ckpt_embedding_dims,
)
(
    train_loader,
    validation_loader,
    test_loader,
) = data_module.dataloader_pipeline()

# Scalers are kept so predictions can later be mapped back to original units.
target_scaler = data_module.target_scalers[config.columns.target_col_name]
continuous_scalers = data_module.continuous_scalers

# Instantiate the model with the data module's dimensions, then load the
# checkpoint weights and switch to evaluation mode.
embedding_dims = data_module.embedding_dimensions
model = Model(
    embedding_dimensions=embedding_dims,
    num_continuous=len(data_module.continuous_columns),
)
model.load_state_dict(state_dict)
model.eval()


In [None]:
# Summarise the loaded model: parameter counts, which weight variant was
# actually loaded, and any training metadata stored in the checkpoint.
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# `use_ema` only says that an EMA checkpoint file exists next to the standard
# one. The label must describe the weights that were really loaded, so it is
# derived from the loaded file and from the key the state dict was taken from.
loaded_from_ema_file = os.path.abspath(model_path) == os.path.abspath(ema_checkpoint_path)
loaded_from_ema_key = (
    'model_state_dict' not in checkpoint and 'ema_state_dict' in checkpoint
)
model_variant = 'EMA' if (loaded_from_ema_file or loaded_from_ema_key) else 'Standard'

print(f"Model variant:        {model_variant}")
print(f"Total parameters:     {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

# Optional metadata: printed only when the checkpoint carries it.
if 'epoch' in checkpoint:
    print(f"Training epochs:      {checkpoint['epoch']}")
if 'best_val_auc' in checkpoint:
    print(f"Best validation AUC:  {checkpoint['best_val_auc']:.4f}")

In [None]:
# Run the model over the test loader, collecting predictions, targets and
# per-sample sequence lengths, then compute regression metrics on the
# de-normalised values.
# NOTE(review): nothing in this notebook moves `model` or the batches to
# `device`; inference runs wherever the loader and Model place their tensors.
all_preds = []
all_targets = []
all_lengths = []

with torch.no_grad():
    for x_cat, x_cont, y, lengths in test_loader:
        preds = model(x_cat, x_cont, lengths)

        all_preds.append(preds.cpu())
        all_targets.append(y.cpu())
        all_lengths.extend(lengths.cpu().numpy())

preds_tensor = torch.cat(all_preds)
targets_tensor = torch.cat(all_targets)
seq_lengths = np.array(all_lengths)


def _to_original_units(values):
    """Invert the target scaler, apply expm1, and floor the result at zero."""
    # expm1 presumably undoes a log1p applied during preprocessing -- TODO confirm.
    unscaled = target_scaler.inverse_transform(values.reshape(-1, 1))
    return np.clip(np.expm1(unscaled), 0, None)


den_targets = _to_original_units(targets_tensor)
den_preds = _to_original_units(preds_tensor)

# Error metrics in original units.
mae = np.abs(den_preds - den_targets).mean()
rmse = np.sqrt(((den_preds - den_targets) ** 2).mean())

# R^2 = 1 - SS_res / SS_tot; undefined (NaN) when the targets have no variance.
ss_res = ((den_targets - den_preds) ** 2).sum()
ss_tot = ((den_targets - den_targets.mean()) ** 2).sum()
if ss_tot != 0:
    r2 = 1 - ss_res / ss_tot
else:
    r2 = float('nan')


In [None]:
# Print a text summary of the evaluation: sequence-length statistics, the
# distribution of the de-normalised targets, and the regression metrics.
print(f"Sequence length - Min: {seq_lengths.min()}, Max: {seq_lengths.max()}, Mean: {seq_lengths.mean():.1f}")

# The original cell printed this header with nothing beneath it; the basic
# statistics of the true targets (original units) are reported here.
print("\nTarget Distribution:")
print(f"  Min:    {den_targets.min():.4f}")
print(f"  Max:    {den_targets.max():.4f}")
print(f"  Mean:   {den_targets.mean():.4f}")
print(f"  Median: {np.median(den_targets):.4f}")
print(f"  Std:    {den_targets.std():.4f}")

print("\nRegression Metrics:")
print(f"  MAE:  {mae:.4f}")
print(f"  RMSE: {rmse:.4f}")
# Label repaired: it was mis-encoded (UTF-8 bytes decoded as Latin-1).
print(f"  R²:   {r2:.4f}")

In [None]:
# Signed prediction errors (prediction minus truth) in original units,
# flattened to one dimension; later cells reuse `errors`.
errors = (den_preds - den_targets).ravel()

# Histogram of the residuals with automatic bin selection.
hist_fig, hist_ax = plt.subplots(figsize=(8, 5))
hist_ax.hist(errors, bins='auto', color='skyblue', edgecolor='black')
hist_ax.set_title('Histogram of Prediction Errors (Residuals)')
hist_ax.set_xlabel('Prediction Error')
hist_ax.set_ylabel('Frequency')
hist_ax.grid(True, linestyle='--', alpha=0.6)
plt.show()

In [None]:
# Two diagnostic scatter plots of the prediction error: against the sequence
# length (left) and against the true target value (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].scatter(seq_lengths, errors, alpha=0.5, color='royalblue', edgecolor='k')
axes[0].set_title('Error vs. Sequence Length')
axes[0].set_xlabel('Sequence Length')
axes[0].set_ylabel('Prediction Error')
axes[0].grid(True, linestyle='--', alpha=0.6)

# Use the de-normalised targets so the x-axis is in the same units as `errors`
# and matches its 'True Target Value' label; `targets_tensor` is still in the
# scaled space the model was trained on.
true_target_values = np.ravel(den_targets)
axes[1].scatter(true_target_values, errors, alpha=0.5, color='darkorange', edgecolor='k')
axes[1].set_title('Error vs. Target Value')
axes[1].set_xlabel('True Target Value')
axes[1].set_ylabel('Prediction Error')
axes[1].grid(True, linestyle='--', alpha=0.6)

plt.tight_layout()
plt.show()

In [None]:
# Predicted vs. true values, with the y = x diagonal drawn across the range of
# the true targets as the ideal-fit reference.
ideal_span = [den_targets.min(), den_targets.max()]

pvt_fig, pvt_ax = plt.subplots(figsize=(7, 6))
pvt_ax.scatter(den_targets, den_preds, alpha=0.5, color='mediumseagreen', edgecolor='k')
pvt_ax.plot(ideal_span, ideal_span, 'r--', lw=2, label='Ideal Fit')
pvt_ax.set_title('Predicted vs. True Values')
pvt_ax.set_xlabel('True Target Value')
pvt_ax.set_ylabel('Predicted Value')
pvt_ax.legend()
pvt_ax.grid(True, linestyle='--', alpha=0.6)
pvt_fig.tight_layout()
plt.show()

In [None]:
# Residuals plotted against the predicted values; the dashed red line marks
# zero error.
resid_fig, resid_ax = plt.subplots(figsize=(7, 5))
resid_ax.scatter(den_preds, errors, alpha=0.5, color='slateblue', edgecolor='k')
resid_ax.axhline(0, color='red', linestyle='--', lw=2)
resid_ax.set_title('Residuals vs. Predicted Values')
resid_ax.set_xlabel('Predicted Value')
resid_ax.set_ylabel('Residual (Error)')
resid_ax.grid(True, linestyle='--', alpha=0.6)
resid_fig.tight_layout()
plt.show()

In [None]:
# Normal probability (QQ) plot of the residuals; scipy draws onto the given
# axes, after which the default title is replaced.
qq_fig, qq_ax = plt.subplots(figsize=(6, 6))
stats.probplot(errors, dist="norm", plot=qq_ax)
qq_ax.set_title('QQ-plot of Residuals')
qq_ax.grid(True, linestyle='--', alpha=0.6)
qq_fig.tight_layout()
plt.show()

In [None]:
# Cumulative absolute error: sort the absolute errors in ascending order and
# plot their running sum against the 1-based sample rank.
sorted_abs_errors = np.abs(errors)
sorted_abs_errors.sort()
cum_abs_error = sorted_abs_errors.cumsum()

cum_fig, cum_ax = plt.subplots(figsize=(8, 5))
cum_ax.plot(np.arange(1, cum_abs_error.size + 1), cum_abs_error, color='teal')
cum_ax.set_title('Cumulative Absolute Error')
cum_ax.set_xlabel('Sample (sorted by error)')
cum_ax.set_ylabel('Cumulative Absolute Error')
cum_ax.grid(True, linestyle='--', alpha=0.6)
cum_fig.tight_layout()
plt.show()