In [1]:
from enum import Enum
from pathlib import Path
from typing import TypedDict

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from helpers.evals import evaluate_model_with_scaling
from helpers.features import process_dataset
from helpers.loaders import prepare_data_for_interpolation
from helpers.models import FrequencyAwareNetwork
from helpers.trainers import train_frequency_aware_model

### Config


In [2]:
# Run-mode switches.
# NOTE(review): ANALYSIS and VERBOSE are not referenced by the visible cells of
# this notebook — confirm they are still needed.
ANALYSIS, VERBOSE = False, True

# Input CSV read by the data-loading cell.
DATASET_FILE_PATH = "dataset.csv"

# Output folder names. MODELS/SUBFOLDER is combined into `model_dir` in the
# training setup cell; GRAPH_FOLDER and PREDICTIONS are not referenced by the
# visible cells.
GRAPH_FOLDER, MODELS, PREDICTIONS = "graphs", "models", "predictions"
SUBFOLDER = "baseline"

### Data


In [3]:
# Read the raw table, then hand it to the project helper, which returns the
# train/test features and targets (test_size=0.2, fixed random_state), two
# scalers, and the index groups for frequency-related vs. other features.
df = pd.read_csv(DATASET_FILE_PATH)

X_train, Y_train, X_test, Y_test, voltage_scaler, freq_scaler, freq_idx, other_idx = (
    process_dataset(df, test_size=0.2, random_state=42)
)

Identified 18 frequency-related features and 13 other features


### Training


In [4]:
# Folder for this run's model artifacts: <MODELS>/<SUBFOLDER>.
# NOTE(review): not referenced by any visible cell; trained weights are never saved.
model_dir = Path(MODELS).joinpath(SUBFOLDER)

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
class SchedulerTypes(str, Enum):
    """Learning-rate scheduler choices, forwarded to the trainer as ``scheduler_str``.

    Mixing in ``str`` makes each member compare equal to its plain string value.
    """

    REDUCE_ON_PLATEAU = "reduce_on_plateau"
    STEP = "step"
    COSINE_ANNEALING = "cosine_annealing"
    ONE_CYCLE = "one_cycle"
    EXPONENTIAL = "exponential"
    NONE = "none"  # presumably disables LR scheduling — confirm in helpers.trainers


class ActivationTypes(str, Enum):
    """Activation-function choices forwarded to ``FrequencyAwareNetwork``.

    Mixing in ``str`` makes each member compare equal to its plain string value.
    """

    GELU = "gelu"
    RELU = "relu"
    SILU = "silu"


class Hyperparameters(TypedDict):
    """Per-model training configuration read by the training loop."""

    hidden_sizes: list[int]  # layer widths passed to FrequencyAwareNetwork
    dropout_rate: float  # dropout rate passed to FrequencyAwareNetwork
    learning_rate: float  # learning rate given to optim.Adam
    activation: ActivationTypes  # activation passed to FrequencyAwareNetwork
    lr_scheduler_type: SchedulerTypes  # forwarded to the trainer as scheduler_str
    epochs: int  # epoch budget passed to the trainer
    patience: int  # presumably early-stopping patience in epochs — confirm in helpers.trainers
    batch_size: int  # batch size passed to prepare_data_for_interpolation
    scale_y: bool  # presumably toggles target scaling in prepare_data_for_interpolation


class ModelDict(TypedDict):
    """One entry of ``models_to_train``: a model to build, train and evaluate."""

    model_name: str  # key under which this model's evaluation is stored in `results`
    labels: tuple[str, str]  # (real-part, imaginary-part) target column names
    hparams: Hyperparameters


In [6]:
# Values shared by the S-parameter models. Each model supplies its own
# `hidden_sizes` first and then overrides only the entries that differ.
# Keys are listed in the same order the per-model dicts have always used,
# so the merged dicts keep that order.
_shared_hparams = {
    "learning_rate": 0.002,
    "dropout_rate": 0.1,
    "batch_size": 1024,
    "epochs": 200,
    "patience": 30,
    "lr_scheduler_type": SchedulerTypes.REDUCE_ON_PLATEAU,
    "activation": ActivationTypes.GELU,
    "scale_y": False,
}

s11_params: Hyperparameters = {
    "hidden_sizes": [256, 512, 1024, 512],
    **_shared_hparams,
    "learning_rate": 0.001,
}
s12_params: Hyperparameters = {
    "hidden_sizes": [384, 768, 1536, 768, 384],
    **_shared_hparams,
    "epochs": 300,
    "patience": 40,
    "scale_y": True,
}

s21_params: Hyperparameters = {
    "hidden_sizes": [1024, 2048, 2048, 1024],
    **_shared_hparams,
}

s22_params: Hyperparameters = {
    "hidden_sizes": [1024, 1536, 2048, 1536, 1024],
    **_shared_hparams,
}

In [7]:
# One separately trained model per de-embedded S-parameter:
# (model name, (real-part column, imaginary-part column), hyperparameters).
models_to_train: list[ModelDict] = [
    {"model_name": name, "labels": labels, "hparams": hparams}
    for name, labels, hparams in (
        ("s11", ("S_deemb(1,1)_real", "S_deemb(1,1)_imag"), s11_params),
        ("s12", ("S_deemb(1,2)_real", "S_deemb(1,2)_imag"), s12_params),
        ("s21", ("S_deemb(2,1)_real", "S_deemb(2,1)_imag"), s21_params),
        ("s22", ("S_deemb(2,2)_real", "S_deemb(2,2)_imag"), s22_params),
    )
]

In [8]:
# Seed for PyTorch's global RNG. It is re-applied before each model so weight
# initialisation, dropout and any torch-driven shuffling are reproducible per
# model, whatever order or subset of models is trained. 42 matches the
# random_state used for the train/test split.
TORCH_SEED = 42

# model_name -> return value of evaluate_model_with_scaling. It is filled one
# model at a time, so an interrupted run still keeps every finished model.
results = {}

for model_to_train in models_to_train:
    model_name = model_to_train["model_name"]
    hparams = model_to_train["hparams"]
    real_label, imag_label = model_to_train["labels"]
    label_pair = [real_label, imag_label]
    y_train_pair = Y_train[label_pair]
    y_test_pair = Y_test[label_pair]

    print(f"Training {model_name} model")

    # Seed before the loader and the network are created, so both start from a
    # reproducible RNG state.
    torch.manual_seed(TORCH_SEED)

    X_train_tensor, Y_train_tensor, X_test_tensor, Y_test_tensor, loader, y_scaler = (
        prepare_data_for_interpolation(
            X_train,
            y_train_pair,
            X_test,
            y_test_pair,
            batch_size=hparams["batch_size"],
            scale_y=hparams["scale_y"],
        )
    )

    model = FrequencyAwareNetwork(
        len(freq_idx),
        len(other_idx),
        hparams["hidden_sizes"],
        hparams["dropout_rate"],
        hparams["activation"],
    )
    model.set_feature_indices(freq_idx, other_idx)

    optimizer = optim.Adam(model.parameters(), lr=hparams["learning_rate"])
    criterion = nn.MSELoss()
    trained_model = train_frequency_aware_model(
        model,
        loader,
        X_test_tensor,
        Y_test_tensor,
        criterion,
        optimizer,
        device,
        epochs=hparams["epochs"],
        patience=hparams["patience"],
        scheduler_str=hparams["lr_scheduler_type"],
    )
    # NOTE(review): trained weights are never persisted, although `model_dir`
    # is defined in the training setup cell.

    # NOTE(review): evaluation receives the full Y_test plus label_pair rather
    # than y_test_pair; presumably the helper selects the columns itself — confirm.
    metrics = evaluate_model_with_scaling(
        trained_model,
        X_test_tensor,
        Y_test,
        label_pair,
        device,
        y_scaler,
    )
    results[model_name] = metrics

Training s11 model


Training Epochs:   8%|▊         | 15/200 [00:09<02:00,  1.54it/s, Epoch=15, Val Loss=0.007247, Best=0.006822, LR=0.001]


KeyboardInterrupt: 

### Results


In [None]:
def _print_metric_lines(metric: dict) -> None:
    """Print one block of error metrics, one tab-indented line each.

    Args:
        metric: Mapping with ``"rmse"``, ``"r2"`` and ``"mae"`` keys, plus
            ``"smape"`` (preferred when present) or otherwise ``"mape"``.
    """
    print(f"\tRMSE: {metric['rmse']:.6f}")
    print(f"\tR²: {metric['r2']:.6f}")
    print(f"\tMAE: {metric['mae']:.6f}")
    if "smape" in metric:
        print(f"\tSMAPE: {metric['smape']:.2f}%")
    else:
        print(f"\tMAPE: {metric['mape']:.2f}%")


for model_to_train in models_to_train:
    model_name = model_to_train["model_name"]

    # Training can stop before every model finishes (e.g. a KeyboardInterrupt),
    # so report the models that have results instead of failing with KeyError.
    if model_name not in results:
        print(f"No results for {model_name} (model was not trained).")
        continue

    # Each entry unpacks as (per-component metrics, averaged metrics, predictions).
    component_metrics, avg_metrics, _predictions_original = results[model_name]

    print("--" * 20)
    print("--" * 20)
    print(f"Performance metrics for {model_name}:")
    for component, metric in component_metrics.items():
        print(f"\n\t{component}:")
        _print_metric_lines(metric)

    # The averages use the same metric order as the per-component blocks.
    print(f"\nAverage metrics for {model_name}:\n")
    _print_metric_lines(avg_metrics)

print("--" * 20)
print("--" * 20)