In [None]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from helpers.split import create_frequency_based_split

### Configuration


In [None]:
# Toggle for the optional diagnostic prints in the preprocessing cells.
ANALYSIS = False

DATASET_FILE_PATH = "dataset.csv"

# Output locations. MODELS / SUBFOLDER are combined when the trained model is
# saved; NOTE(review): GRAPH_FOLDER and PREDICTIONS are not read by any cell
# visible in this notebook — confirm they are still needed.
GRAPH_FOLDER = "graphs"
MODELS = "models"
PREDICTIONS = "predictions"
SUBFOLDER = "baseline"

VERBOSE = True
# NOTE(review): the training loop reads hyperparams["epochs"], not EPOCHS.
EPOCHS = 100

# Seed the global RNGs so that weight initialisation, dropout and DataLoader
# shuffling are repeatable across "Restart & Run All". The value matches the
# random_state passed to the train/test split.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Loading and Preprocessing


In [None]:
# Read the raw CSV named in the configuration cell into a DataFrame.
df = pd.read_csv(filepath_or_buffer=DATASET_FILE_PATH)

In [None]:
# Input columns fed to the model.
features = ["freq", "vb", "vc", "DEV_GEOM_L", "NUM_OF_TRANS_RF"]

# Output columns: real and imaginary part of each of the four S_deemb(i,j)
# entries, ordered (1,1), (1,2), (2,1), (2,2) with "real" before "imag".
targets = [
    f"S_deemb({row},{col})_{part}"
    for row in (1, 2)
    for col in (1, 2)
    for part in ("real", "imag")
]

In [None]:
if ANALYSIS:
    # Per-column count of missing values in the feature columns; only the
    # columns that actually contain missing values are printed.
    print("Checking for null values in features:")
    feature_nulls = df.loc[:, features].isna().sum()
    print(feature_nulls.loc[feature_nulls.gt(0)])

    # Per-column count of missing values in the target columns (all columns
    # are printed, including those with a zero count).
    print("\nChecking for null values in labels:")
    label_nulls = df.loc[:, targets].isna().sum()
    print(label_nulls)

In [None]:
# Keep only the rows in which every feature and every target is present.
df_clean = df.dropna(subset=[*features, *targets], how="any")

if ANALYSIS:
    # Summarise how many rows the null filter removed.
    print(f"\nOriginal dataset shape: {df.shape}")
    print(f"Cleaned dataset shape: {df_clean.shape}")
    print(f"Removed {df.shape[0] - df_clean.shape[0]} rows with null values")

In [None]:
# Apply improved frequency-based split.
# NOTE(review): train_mask / test_mask are used below (and in the split cell)
# as boolean row selectors aligned with df_clean — confirm against
# helpers.split.create_frequency_based_split.
train_mask, test_mask = create_frequency_based_split(
    df_clean, test_size=0.2, random_state=42
)


# Create separate dataframes for features and labels
X = df_clean[features].copy()
Y = df_clean[targets].copy()

# Encode categorical features as integer category codes
for categorical_col in ("DEV_GEOM_L", "NUM_OF_TRANS_RF"):
    X[categorical_col] = X[categorical_col].astype("category").cat.codes

# Scale freq, vb, vc.
# The scaler is fitted on the training rows only and then applied to every
# row, so the min/max of the held-out rows never influence the features the
# model is trained on (fitting on all rows leaks test-set statistics).
continuous_cols = ["freq", "vb", "vc"]
scaler = MinMaxScaler()
scaler.fit(X.loc[train_mask, continuous_cols])
X[continuous_cols] = scaler.transform(X[continuous_cols])

if ANALYSIS:
    print(f"\nFeature dataset shape: {X.shape}")
    print(f"S-parameter labels shape: {Y.shape}")

    print("\nFeature statistics (first 5 columns):")
    print(X.iloc[:, :5].describe())

    print("\nS-parameter statistics (first 4 columns):")
    print(Y.iloc[:, :4].describe())

    print("\nFeature and label separation complete!")

In [None]:
# Partition features and targets into train / test subsets using the masks
# produced by the frequency-based split.
Y = df_clean.loc[:, targets]

X_train = X[train_mask]
X_test = X[test_mask]

Y_train = Y[train_mask]
Y_test = Y[test_mask]

In [None]:
# Standardise the targets (intended to help convergence). The scaler is fitted
# on the training targets only and then applied to both subsets.
y_scaler = StandardScaler().fit(Y_train)

Y_train_scaled = y_scaler.transform(Y_train)
Y_test_scaled = y_scaler.transform(Y_test)

### Model


In [None]:
# Model and optimisation settings. The visible cells read "batch_size" when
# building the DataLoader and "epochs" / "patience" in the training loop.
hyperparams = dict(
    hidden_sizes=[384, 768, 1536],
    dropout_rate=0.1,
    activation="gelu",
    lr=0.002,
    epochs=300,
    patience=40,
    batch_size=512,
    lr_scheduler_type="reduce_on_plateau",
)

In [None]:
class FrequencyAwareNetwork(nn.Module):
    """Two-branch MLP that processes frequency columns separately from the rest.

    The input columns selected by ``freq_indices`` go through one stack of
    ``Linear -> activation -> BatchNorm1d -> Dropout`` blocks, the columns
    selected by ``other_indices`` go through a second, independent stack of the
    same shape. The two branch outputs are concatenated and passed through a
    combined stack that ends in a linear output layer.

    Args:
        freq_features: Number of input columns fed to the frequency branch.
        other_features: Number of input columns fed to the other branch.
        hidden_sizes: Layer widths. The first two entries size both branches;
            any remaining entries size the combined stack. At least two
            entries are required.
        dropout_rate: Dropout probability used after every hidden block.
        activation: One of ``"silu"``, ``"relu"`` or ``"gelu"``.
        output_size: Width of the final linear layer. Defaults to 2, which is
            the value the layer was previously hard-coded to.

    Raises:
        ValueError: If ``activation`` is not supported or ``hidden_sizes`` has
            fewer than two entries.
    """

    # Supported activation names mapped to their module classes.
    _ACTIVATIONS = {"silu": nn.SiLU, "relu": nn.ReLU, "gelu": nn.GELU}

    def __init__(
        self,
        freq_features,
        other_features,
        hidden_sizes,
        dropout_rate,
        activation,
        output_size=2,
    ):
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Unsupported activation function: {activation}")
        if len(hidden_sizes) < 2:
            raise ValueError(
                "hidden_sizes must contain at least two entries "
                f"(got {len(hidden_sizes)})."
            )
        activation_cls = self._ACTIVATIONS[activation]

        # Both branches use the first two hidden sizes.
        branch_sizes = hidden_sizes[:2]

        # Frequency-specific processing branch
        self.freq_branch = nn.Sequential(
            *self._hidden_blocks(
                freq_features, branch_sizes, activation_cls, dropout_rate
            )
        )

        # Other parameters branch
        self.other_branch = nn.Sequential(
            *self._hidden_blocks(
                other_features, branch_sizes, activation_cls, dropout_rate
            )
        )

        # Combined processing of the concatenated branch outputs
        combined_in = hidden_sizes[1] * 2  # Output size from both branches combined
        combined_sizes = hidden_sizes[2:]
        combined_layers = self._hidden_blocks(
            combined_in, combined_sizes, activation_cls, dropout_rate
        )

        # Final linear output layer; its input is the last combined width, or
        # the concatenated branch width when there is no combined hidden layer.
        last_width = combined_sizes[-1] if len(combined_sizes) > 0 else combined_in
        combined_layers.append(nn.Linear(last_width, output_size))

        self.combined = nn.Sequential(*combined_layers)

        # Column indices used by forward(); must be set via set_feature_indices().
        self.freq_indices = None
        self.other_indices = None

    @staticmethod
    def _hidden_blocks(in_size, sizes, activation_cls, dropout_rate):
        """Return a flat list of Linear/activation/BatchNorm1d/Dropout layers.

        A fresh activation module is created for every block instead of
        reusing one instance throughout the network. The layer order matches
        the previous implementation, so positions inside each ``nn.Sequential``
        (and therefore ``state_dict`` keys) are unchanged.
        """
        layers = []
        prev_size = in_size
        for h_size in sizes:
            layers.append(nn.Linear(prev_size, h_size))
            layers.append(activation_cls())
            layers.append(nn.BatchNorm1d(h_size))
            layers.append(nn.Dropout(dropout_rate))
            prev_size = h_size
        return layers

    def forward(self, x):
        """Run both branches on their column subsets and return the combined output.

        Args:
            x: 2-D input; columns are selected with ``x[:, indices]``.

        Raises:
            ValueError: If ``set_feature_indices()`` has not been called.
        """
        # Split input into frequency and other features
        if self.freq_indices is None or self.other_indices is None:
            raise ValueError(
                "Feature indices not set. Call set_feature_indices() first."
            )

        freq_input = x[:, self.freq_indices]
        other_input = x[:, self.other_indices]

        # Process through branches
        freq_features = self.freq_branch(freq_input)
        other_features = self.other_branch(other_input)

        # Combine and output
        combined = torch.cat([freq_features, other_features], dim=1)
        return self.combined(combined)

    def set_feature_indices(self, freq_indices, other_indices):
        """Set indices for frequency and other features."""
        self.freq_indices = freq_indices
        self.other_indices = other_indices

In [None]:
# Build float32 tensors from the (scaled) feature frames and target arrays.
X_train_tensor = torch.tensor(X_train.to_numpy(), dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train_scaled, dtype=torch.float32)

X_test_tensor = torch.tensor(X_test.to_numpy(), dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test_scaled, dtype=torch.float32)

# Mini-batch iterator over the training pairs, reshuffled every epoch.
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=hyperparams["batch_size"],
)

### Training


In [None]:
def _get_hyperparameter(hyperparameters, *keys):
    """Return the value stored under the first of ``keys`` found in ``hyperparameters``.

    Raises:
        KeyError: If none of the keys is present (reported with the first key).
    """
    for key in keys:
        if key in hyperparameters:
            return hyperparameters[key]
    raise KeyError(keys[0])


def _fit_single_model(
    model,
    train_loader,
    X_val,
    Y_val,
    criterion,
    optimizer,
    device,
    epochs,
    patience,
    lr_scheduler_type="one_cycle",
):
    """Train one network with early stopping and return it with its loss history.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` mini-batches.
        X_val, Y_val: Tensors used to compute the per-epoch validation loss.
        criterion: Loss function.
        optimizer: Optimizer already bound to ``model.parameters()``.
        device: Device on which training runs.
        epochs: Maximum number of epochs.
        patience: Number of epochs without validation improvement tolerated
            before training stops.
        lr_scheduler_type: ``"one_cycle"``, ``"reduce_on_plateau"`` or
            ``None`` / ``"none"`` for a constant learning rate.

    Returns:
        Tuple ``(model, train_losses, val_losses)``; the model carries the
        weights of the epoch with the lowest validation loss.

    Raises:
        ValueError: If ``lr_scheduler_type`` is not recognised.
    """
    model = model.to(device)
    X_val = X_val.to(device)
    Y_val = Y_val.to(device)

    step_per_batch = False
    if lr_scheduler_type == "one_cycle":
        # The optimizer's configured learning rate is used as the cycle peak.
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=optimizer.param_groups[0]["lr"],
            epochs=epochs,
            steps_per_epoch=len(train_loader),
        )
        step_per_batch = True
    elif lr_scheduler_type == "reduce_on_plateau":
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    elif lr_scheduler_type is None or lr_scheduler_type == "none":
        scheduler = None
    else:
        raise ValueError(f"Unsupported lr_scheduler_type: {lr_scheduler_type}")

    train_losses = []
    val_losses = []
    best_val_loss = float("inf")
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            if step_per_batch:
                scheduler.step()
            running_loss += loss.item()
        train_loss = running_loss / max(len(train_loader), 1)

        # Validation pass in eval mode so dropout / batch-norm are deterministic.
        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(X_val), Y_val).item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if scheduler is not None and not step_per_batch:
            scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns live references, so clone to freeze the snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_losses, val_losses


def train_model(X_train, X_test, Y_train, Y_test, hyperparameters, selected_features):
    """
    Train frequency-aware models for each S-parameter with conditional scaling.

    One ``FrequencyAwareNetwork`` is trained per entry of ``s_parameter_models``;
    target scaling is applied only to the entry named ``"S12"``.

    NOTE(review): ``s_parameter_models``, ``prepare_data_for_pytorch_with_scaling``,
    ``evaluate_model_with_scaling``, ``plot_learning_curves``, ``plot_predictions``
    and ``plot_error_distribution`` are not defined in the visible part of this
    notebook and must be provided elsewhere.

    NOTE(review): the test tensors are also used as the validation set for
    early stopping, so the reported test metrics may be optimistic.

    Args:
        X_train, X_test: Feature frames; ``X_train`` must expose ``.columns``
            containing ``"freq"``.
        Y_train, Y_test: Target data passed through to the data-preparation
            and evaluation helpers.
        hyperparameters: Dict with ``hidden_sizes``, ``dropout_rate``,
            ``batch_size``, ``epochs``, a learning rate (``learning_rate`` or
            ``lr``), an early-stopping patience (``early_stopping_patience``
            or ``patience``) and optionally ``activation`` and
            ``lr_scheduler_type``.
        selected_features: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(models, all_results, all_predictions, scalers)``, each a dict
        keyed by model name.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Identify frequency-related features from the frame that was passed in
    # (not from a notebook-level global).
    freq_indices = [X_train.columns.get_loc("freq")]
    other_indices = [i for i in range(X_train.shape[1]) if i not in freq_indices]

    # Accept both the long key names and the short ones used by the
    # notebook-level `hyperparams` dict.
    learning_rate = _get_hyperparameter(hyperparameters, "learning_rate", "lr")
    patience = _get_hyperparameter(
        hyperparameters, "early_stopping_patience", "patience"
    )

    # Store results and models
    models = {}
    all_results = {}
    all_predictions = {}
    scalers = {}  # Store scalers for each model

    # Record start time
    start_time = time.time()

    # Train a model for each S-parameter
    for model_name, components in s_parameter_models.items():
        print(f"\n{'=' * 50}")
        print(f"Training frequency-aware model for {model_name}")
        print(f"{'=' * 50}")

        # Decide whether to scale Y data (only for S12)
        scale_y = model_name == "S12"

        # Prepare data with conditional scaling
        prep_results = prepare_data_for_pytorch_with_scaling(
            X_train,
            Y_train,
            X_test,
            Y_test,
            components,
            hyperparameters["batch_size"],
            scale_y=scale_y,
        )

        (
            X_train_tensor,
            Y_train_tensor,
            X_test_tensor,
            Y_test_tensor,
            train_loader,
            y_scaler,
        ) = prep_results
        if scale_y:
            scalers[model_name] = y_scaler
            print("Applied StandardScaler to Y values for S12")

        # Initialize model
        model = FrequencyAwareNetwork(
            len(freq_indices),
            len(other_indices),
            hyperparameters["hidden_sizes"],
            hyperparameters["dropout_rate"],
            hyperparameters.get("activation", "gelu"),
        )
        model.set_feature_indices(freq_indices, other_indices)

        # Loss and optimizer
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Fit the network. This used to call `train_model(...)`, which resolves
        # to this very function and fails on the argument count.
        trained_model, train_losses, val_losses = _fit_single_model(
            model,
            train_loader,
            X_test_tensor,
            Y_test_tensor,
            criterion,
            optimizer,
            device,
            hyperparameters["epochs"],
            patience,
            lr_scheduler_type=hyperparameters.get("lr_scheduler_type", "one_cycle"),
        )

        # Plot learning curves
        plot_learning_curves(train_losses, val_losses, model_name)

        # Evaluate model with proper scaling handling
        metrics, avg_metrics, predictions = evaluate_model_with_scaling(
            trained_model,
            X_test_tensor,
            Y_test_tensor,
            Y_test,
            components,
            device,
            scalers.get(model_name),
        )

        # Plot predictions and error distributions
        plot_predictions(Y_test, predictions, components, model_name)
        plot_error_distribution(Y_test, predictions, components, model_name)

        # Print results
        print(f"\nPerformance metrics for {model_name}:")
        for component, metric in metrics.items():
            print(f"  {component}:")
            print(f"    RMSE: {metric['rmse']:.6f}")
            print(f"    R²: {metric['r2']:.6f}")
            print(f"    MAE: {metric['mae']:.6f}")
            if "smape" in metric:
                print(f"    SMAPE: {metric['smape']:.2f}%")
            else:
                print(f"    MAPE: {metric['mape']:.2f}%")

        print(f"\nAverage metrics for {model_name}:")
        print(f"  R²: {avg_metrics['r2']:.6f}")
        print(f"  RMSE: {avg_metrics['rmse']:.6f}")
        print(f"  MAE: {avg_metrics['mae']:.6f}")
        if "smape" in avg_metrics:
            print(f"  SMAPE: {avg_metrics['smape']:.2f}%")
        else:
            print(f"  MAPE: {avg_metrics['mape']:.2f}%")

        # Store results
        models[model_name] = trained_model
        all_results[model_name] = {
            "component_metrics": metrics,
            "avg_metrics": avg_metrics,
        }
        all_predictions[model_name] = predictions

    # Record total training time
    train_time = time.time() - start_time
    print(f"\nTotal training time: {train_time:.2f} seconds")

    # Save models; make sure the output directory exists first.
    Path("freq_aware_results").mkdir(parents=True, exist_ok=True)
    for model_name, model in models.items():
        torch.save(model.state_dict(), f"freq_aware_results/{model_name}_model.pth")

    print("Models and results saved to freq_aware_results/")

    return models, all_results, all_predictions, scalers

In [None]:
# Training loop
# NOTE(review): `model`, `criterion`, `optimizer` and `scheduler` are not
# defined in any cell visible in this notebook; they must exist before this
# cell runs, otherwise it fails on a fresh kernel — confirm and add the
# defining cell.
# `scheduler.step(avg_loss)` below takes a metric, which matches a
# plateau-style scheduler.

# Make sure the model lives on the same device as the batches (no-op when it
# is already there).
model.to(device)

best_loss = float("inf")
best_model_state = None
counter = 0
start_time = time.time()

for epoch in range(hyperparams["epochs"]):
    model.train()
    total_loss = 0
    loop = tqdm(
        train_loader, desc=f"Epoch [{epoch + 1}/{hyperparams['epochs']}]", leave=False
    )
    for xb, yb in loop:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    # Terminate the line so the early-stopping message is not glued onto it.
    print(f"Epoch {epoch + 1}: Avg Loss = {avg_loss:.6f}")

    if scheduler:
        scheduler.step(avg_loss)

    # Early stopping is driven by the mean training loss of the epoch.
    if avg_loss < best_loss:
        best_loss = avg_loss
        # state_dict() hands back references to the live parameters, so the
        # tensors are cloned to keep a snapshot that later updates cannot change.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        counter = 0
    else:
        counter += 1
        if counter >= hyperparams["patience"]:
            print(f"Early stopping at epoch {epoch + 1}")
            break

# Restore the best weights so the cells that save and evaluate `model` use
# them instead of the weights from the last epoch.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

print(
    f"Training complete in {time.time() - start_time:.2f} seconds. Best loss: {best_loss:.6f}"
)

                                                                              

Epoch 1: Avg Loss = 0.051038

                                                                              

Epoch 2: Avg Loss = 0.051200

                                                                              

Epoch 3: Avg Loss = 0.050686

                                                                              

Epoch 4: Avg Loss = 0.051306

                                                                              

Epoch 5: Avg Loss = 0.049064

                                                                              

Epoch 6: Avg Loss = 0.051155

                                                                              

Epoch 7: Avg Loss = 0.049329

                                                                              

Epoch 8: Avg Loss = 0.049317

                                                                              

Epoch 9: Avg Loss = 0.048813

Epoch [10/300]:  31%|███       | 77/252 [00:00<00:00, 251.72it/s, loss=0.0473]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7faca6389a90>>
Traceback (most recent call last):
  File "/home/w01f/ml4rf/env/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
                                                                               

Epoch 10: Avg Loss = 0.049101

                                                                               

Epoch 11: Avg Loss = 0.049042

                                                                               

Epoch 12: Avg Loss = 0.049999

                                                                               

Epoch 13: Avg Loss = 0.050266

                                                                               

Epoch 14: Avg Loss = 0.047869

                                                                               

Epoch 15: Avg Loss = 0.048905

                                                                               

Epoch 16: Avg Loss = 0.047888

                                                                               

Epoch 17: Avg Loss = 0.049683

                                                                               

Epoch 18: Avg Loss = 0.048878

                                                                               

Epoch 19: Avg Loss = 0.048019

                                                                               

Epoch 20: Avg Loss = 0.048153

Epoch [21/300]:  90%|█████████ | 227/252 [00:00<00:00, 278.51it/s, loss=0.0508]

In [None]:
# Persist the model weights under <MODELS>/<SUBFOLDER>/baseline_model.pt,
# creating the directory tree when it does not exist yet.
model_dir = Path(MODELS, SUBFOLDER)
model_dir.mkdir(exist_ok=True, parents=True)

model_path = model_dir.joinpath("baseline_model.pt")
torch.save(model.state_dict(), model_path)

### Evaluation


In [None]:
# Evaluation
# Switch to inference mode first: the training loop leaves the model in
# train() mode, in which any dropout / batch-norm layers would still behave
# stochastically or use per-batch statistics.
model.eval()
with torch.no_grad():
    preds_scaled = model(X_test_tensor.to(device)).cpu().numpy()

# Map the predictions back from standardised units to the original target scale.
preds = y_scaler.inverse_transform(preds_scaled)

# Per-target metrics (one value per column of `targets`).
r2 = r2_score(Y_test, preds, multioutput="raw_values")
rmse = np.sqrt(mean_squared_error(Y_test, preds, multioutput="raw_values"))
mae = mean_absolute_error(Y_test, preds, multioutput="raw_values")

for i, name in enumerate(targets):
    print(f"{name}: R²={r2[i]:.4f}, RMSE={rmse[i]:.4f}, MAE={mae[i]:.4f}")