In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import xgboost as xgb
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

### Configuration


In [None]:
# Set to True to run the optional data-inspection and feature-importance cells.
ANALYSIS = False
# Root output folders; a SUBFOLDER child is created under each one on demand.
GRAPH_FOLDER = "graphs"  # feature-importance plots
MODELS = "models"  # per-target model state dicts
PREDICTIONS = "predictions"  # train/test prediction CSVs
SUBFOLDER = "feature_extraction"  # sub-directory name shared by all three roots

### Preprocessing


In [None]:
# Helper function to identify frequency-related features
def identify_frequency_features(X_columns):
    """Split column positions into frequency-related and remaining features.

    A column counts as frequency-related when its label contains ``"freq"`` or
    ``"band"`` (case-insensitive). Labels are passed through ``str()`` first so
    that non-string labels (e.g. integer column names) do not raise.

    Args:
        X_columns: Iterable of column labels, in feature-matrix order.

    Returns:
        tuple[list[int], list[int]]: ``(freq_features, other_features)`` — the
        positional indices of the matching columns and of all other columns,
        both in ascending order.

    Side effects:
        Prints a one-line summary with the size of both groups.
    """
    lowered = [str(col).lower() for col in X_columns]
    freq_features = [
        pos for pos, label in enumerate(lowered) if "freq" in label or "band" in label
    ]
    # Set lookup keeps the complement computation linear in the number of columns.
    freq_lookup = set(freq_features)
    other_features = [pos for pos in range(len(lowered)) if pos not in freq_lookup]

    print(
        f"Identified {len(freq_features)} frequency-related features and {len(other_features)} other features"
    )
    return freq_features, other_features

In [None]:
# Step 1: Load the dataset
file_path = "dataset.csv"
df = pd.read_csv(file_path)

target_cols = ["gm", "Cmu", "Cpi", "Zout_real", "Zin_real", "Zout_imag", "Zin_imag"]
nan_heavy_cols = ["MAG", "MSG"]  # Columns with too many NaN values TODO: investigate
exclude_columns = (
    target_cols
    + nan_heavy_cols
    + [
        "TIMEDATE",
        "OPERATOR",
        "REMARKS",
        "TECHNO",
        "LOT",
        "WAFER",
        "CHIP",
        "MODULE",
        "DEV_NAME",
        "S(1,1)_real",
        "S(1,1)_imag",
        "S(1,2)_real",
        "S(1,2)_imag",
        "S(2,1)_real",
        "S(2,1)_imag",
        "S(2,2)_real",
        "S(2,2)_imag",
        "S_deemb(1,1)_real",
        "S_deemb(1,1)_imag",
        "S_deemb(1,2)_real",
        "S_deemb(1,2)_imag",
        "S_deemb(2,1)_real",
        "S_deemb(2,1)_imag",
        "S_deemb(2,2)_real",
        "S_deemb(2,2)_imag",
    ]
)

# Feature columns: every numeric column that is neither a target, NaN-heavy,
# an identifier, nor a raw S-parameter.
all_cols = df.columns.tolist()
excluded_lookup = set(exclude_columns)
X_cols = [
    col
    for col in all_cols
    if col not in excluded_lookup and pd.api.types.is_numeric_dtype(df[col])
]

# Rows with a missing target cannot supervise training; left in, their NaNs end
# up in the target tensors and turn the loss into NaN. `df` itself stays
# untouched because the analysis cell reads the raw frame.
df_model = df.dropna(subset=target_cols)

# Split input and targets
X = df_model[X_cols]
Y = df_model[target_cols]

# Train-test split
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=42
)

# A feature with no observed value in the training split cannot be imputed;
# SimpleImputer would silently drop it and shift every later column position,
# so remove such columns explicitly to keep names and positions in sync.
observed_cols = [col for col in X_train.columns if X_train[col].notna().any()]
X_train = X_train[observed_cols]
X_test = X_test[observed_cols]

# Fill remaining gaps with training-set means, then standardize. Both
# transformers are fitted on the training split only, so no test statistics
# leak into training. Without the imputation step, NaNs would pass through
# StandardScaler into the tensors below.
x_imputer = SimpleImputer(strategy="mean")
x_scaler = StandardScaler()
X_train_scaled = x_scaler.fit_transform(x_imputer.fit_transform(X_train))
X_test_scaled = x_scaler.transform(x_imputer.transform(X_test))

# Positional indices refer to the columns of the scaled matrices above.
freq_idx, other_idx = identify_frequency_features(X_train.columns)

# Convert to PyTorch tensors
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train.values, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test.values, dtype=torch.float32)

In [None]:
# Optional sanity check: print the loaded frame's column labels, then its shape.
if ANALYSIS:
    for summary in (df.columns, df.shape):
        print(summary)

### Analysis


In [None]:
def get_feature_importance():
    """Rank features per target with XGBoost and save one bar chart per target.

    Reads the module-level ``df``. Rows missing any target are dropped, the
    remaining numeric columns are mean-imputed and standardized, and for each
    target an ``XGBRegressor`` is fitted on 80% of the rows using every other
    numeric column (the other targets included) as features.

    Side effects:
        Creates ``GRAPH_FOLDER/SUBFOLDER`` if needed, writes
        ``<target>_importance.png`` there for every target, and prints a
        progress line per target.

    Returns:
        None
    """
    # Target variables
    targets = ["gm", "Cpi", "Cmu", "Zin_real", "Zin_imag", "Zout_real", "Zout_imag"]

    # Drop non-numeric or identifier columns that shouldn't be features
    exclude_columns = [
        "TIMEDATE",
        "OPERATOR",
        "REMARKS",
        "TECHNO",
        "LOT",
        "WAFER",
        "CHIP",
        "MODULE",
        "DEV_NAME",
        "S(1,1)_real",
        "S(1,1)_imag",
        "S(1,2)_real",
        "S(1,2)_imag",
        "S(2,1)_real",
        "S(2,1)_imag",
        "S(2,2)_real",
        "S(2,2)_imag",
        "S_deemb(1,1)_real",
        "S_deemb(1,1)_imag",
        "S_deemb(1,2)_real",
        "S_deemb(1,2)_imag",
        "S_deemb(2,1)_real",
        "S_deemb(2,1)_imag",
        "S_deemb(2,2)_real",
        "S_deemb(2,2)_imag",
    ]
    excluded_lookup = set(exclude_columns)
    all_numeric = [
        col
        for col in df.columns
        if col not in excluded_lookup and pd.api.types.is_numeric_dtype(df[col])
    ]

    # Directory to save plots
    output_dir = Path(GRAPH_FOLDER) / SUBFOLDER
    output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare data (remove rows with missing target values)
    df_clean = df.dropna(subset=targets)

    # SimpleImputer drops columns that have no observed value at all, which
    # would make the array narrower than the name list used to rebuild the
    # frame below. Filter those columns out up front so names stay aligned.
    usable_cols = [col for col in all_numeric if df_clean[col].notna().any()]
    X_raw = df_clean[usable_cols]

    # Impute missing values
    imputer = SimpleImputer(strategy="mean")
    X_imputed = imputer.fit_transform(X_raw)

    # Normalize features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_imputed)

    # Reconstruct feature names after transformation
    X_df = pd.DataFrame(X_scaled, columns=usable_cols)

    # Loop over each target, using all other numeric columns (including other targets) as features
    for target in targets:
        print(f"\n🔍 Feature Importance for Target: {target}")

        feature_cols = [col for col in X_df.columns if col != target]
        X_target = X_df[feature_cols]
        y_target = X_df[target]

        # Only the training part is fitted; the split is kept so the model sees
        # the same 80% of rows as before.
        X_train, _, y_train, _ = train_test_split(
            X_target, y_target, test_size=0.2, random_state=42
        )

        # Train XGBoost model
        # NOTE(review): `gpu_hist` / `gpu_predictor` were deprecated in XGBoost 2.0
        # in favour of `device="cuda"` with `tree_method="hist"` — confirm against
        # the installed version before changing.
        model = xgb.XGBRegressor(
            tree_method="gpu_hist",
            predictor="gpu_predictor",
            n_estimators=100,
            max_depth=6,
            learning_rate=0.1,
            random_state=42,
        )
        model.fit(X_train, y_train)

        # Most important feature first.
        importances = model.feature_importances_
        sorted_idx = importances.argsort()[::-1]
        sorted_features = [feature_cols[i] for i in sorted_idx]

        # Plot
        positions = range(len(sorted_idx))
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(positions, importances[sorted_idx], align="center")
        ax.set_yticks(positions)
        ax.set_yticklabels(sorted_features)
        ax.set_xlabel("Feature Importance")
        ax.set_title(f"Feature Importance for Predicting {target}")
        ax.invert_yaxis()  # largest bar at the top
        fig.tight_layout()

        fig.savefig(output_dir / f"{target}_importance.png")
        plt.close(fig)  # release the figure; one is created per target


if ANALYSIS:
    get_feature_importance()

### Model


In [None]:
# Config
# Order in which targets are predicted; each stage's prediction is appended to
# the inputs of the stages after it.
chain_targets = [
    "gm",
    "Cmu",
    "Cpi",
    "Zout_real",
    "Zin_real",
    "Zout_imag",
    "Zin_imag",
]  # Derived from analysis
batch_size = 2048
num_workers = 12
# Pinning only applies to CPU tensors; enable it only if the DataLoader's
# dataset lives on the CPU.
pin_memory = False
# DataLoader rejects persistent_workers=True when num_workers == 0, so derive
# the flag instead of hard-coding it.
persistent_workers = num_workers > 0
epochs = 100
learning_rate = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

if num_workers > 0:
    # Worker processes are started with "spawn"; force=True lets the cell be
    # re-run after a start method has already been set.
    mp.set_start_method("spawn", force=True)

In [None]:
# Reset the CUDA peak-memory counter so a later peak-memory readout covers only
# the work done after this cell. Guarded because the call raises on hosts
# without CUDA, while the configuration otherwise falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

In [None]:
# Define Frequency-Aware Neural Network
class FrequencyAwareNetwork(nn.Module):
    """Two-branch MLP regressor with a single scalar output per row.

    The input columns are split into a "frequency" group and an "other" group
    (see ``set_feature_indices``). Each group passes through its own stack of
    ``Linear -> activation -> BatchNorm1d -> Dropout`` layers sized by the first
    two entries of ``hidden_sizes``; the two branch outputs are concatenated and
    fed through the remaining ``hidden_sizes`` and a final ``Linear(..., 1)``.

    Args:
        freq_features: Number of input columns fed to the frequency branch.
        other_features: Number of input columns fed to the other branch.
        hidden_sizes: Layer widths; needs at least two entries.
        dropout_rate: Dropout probability used after every hidden layer.
        activation: One of ``"silu"``, ``"relu"``, ``"gelu"``.

    Raises:
        KeyError: If ``activation`` is not one of the supported names.
        IndexError: If ``hidden_sizes`` has fewer than two entries.
    """

    def __init__(
        self,
        freq_features,
        other_features,
        # hidden_sizes=[128, 256, 512],
        hidden_sizes=(64, 128, 256),  # immutable default instead of a shared list
        dropout_rate=0.2,
        activation="silu",
    ):
        super().__init__()
        self.freq_indices = None
        self.other_indices = None
        # Map to classes rather than instances so every layer gets its own
        # activation module instead of one object shared across both branches.
        act_cls = {"silu": nn.SiLU, "relu": nn.ReLU, "gelu": nn.GELU}[activation]

        def dense_stack(in_size, widths):
            """Return the layer list for one Linear/act/BatchNorm/Dropout stack."""
            layers = []
            for width in widths:
                layers += [
                    nn.Linear(in_size, width),
                    act_cls(),
                    nn.BatchNorm1d(width),
                    nn.Dropout(dropout_rate),
                ]
                in_size = width
            return layers

        branch_widths = hidden_sizes[:2]
        self.freq_branch = nn.Sequential(*dense_stack(freq_features, branch_widths))
        self.other_branch = nn.Sequential(*dense_stack(other_features, branch_widths))

        # Both branches end at hidden_sizes[1], so the concatenation is twice as wide.
        comb_input = hidden_sizes[1] * 2
        tail_widths = hidden_sizes[2:]
        combined = dense_stack(comb_input, tail_widths)
        head_input = tail_widths[-1] if len(tail_widths) > 0 else comb_input
        combined.append(nn.Linear(head_input, 1))
        self.combined = nn.Sequential(*combined)

    def forward(self, x):
        """Return predictions of shape ``(rows, 1)`` for a 2-D input ``x``.

        Raises:
            ValueError: If ``set_feature_indices`` has not been called yet.
        """
        if self.freq_indices is None or self.other_indices is None:
            raise ValueError("Feature indices not set.")
        freq_part = x[:, self.freq_indices]
        other_part = x[:, self.other_indices]
        merged = torch.cat(
            [self.freq_branch(freq_part), self.other_branch(other_part)], dim=1
        )
        return self.combined(merged)

    def set_feature_indices(self, freq_idx, other_idx):
        """Record which input columns feed the frequency and the other branch.

        Args:
            freq_idx: Column positions for the frequency branch.
            other_idx: Column positions for the other branch.

        Raises:
            ValueError: If a sized index collection does not contain as many
                positions as the matching branch expects; this reports the
                mismatch here rather than as a shape error inside ``forward``.
        """
        checks = (
            ("freq_idx", freq_idx, self.freq_branch[0].in_features),
            ("other_idx", other_idx, self.other_branch[0].in_features),
        )
        for label, indices, expected in checks:
            # Unsized index objects (e.g. slices) are accepted unchecked.
            if hasattr(indices, "__len__") and len(indices) != expected:
                raise ValueError(
                    f"{label} has {len(indices)} entries but the branch expects {expected}."
                )
        self.freq_indices = freq_idx
        self.other_indices = other_idx

# Define Chained Predictor
class ChainedPredictor(nn.Module):
    """Run one sub-model per target, feeding each prediction to the next stage.

    Stage ``i`` (0-based position in ``chain_order``) is built with
    ``model_fn(freq_size, base_other_size + i)`` because every earlier stage
    appends one prediction column to the input.

    Args:
        model_fn: Callable ``(freq_size, other_size) -> nn.Module``. The module
            must be callable on the 2-D input and expose
            ``set_feature_indices(freq_idx, other_idx)``.
        chain_order: Target names in prediction order; also used as model keys.
        freq_size: Number of frequency inputs passed to every stage.
        base_other_size: Number of non-frequency inputs before any chaining.
        device: Device the sub-models are created on.
    """

    def __init__(
        self, model_fn, chain_order, freq_size, base_other_size, device="cuda"
    ):
        super().__init__()
        self.models = nn.ModuleDict()
        self.chain_order = chain_order
        self.device = device
        self.feature_indices = {}
        self.freq_size = freq_size
        self.base_other_size = base_other_size

        for stage, name in enumerate(chain_order):
            # each new prediction is appended to inputs
            stage_model = model_fn(freq_size, base_other_size + stage)
            self.models[name] = stage_model.to(device)

    def set_feature_indices(self, name, freq_idx, other_idx):
        """Forward the index split to stage ``name`` and remember it."""
        self.models[name].set_feature_indices(freq_idx, other_idx)
        self.feature_indices[name] = (freq_idx, other_idx)

    def _current_device(self):
        """Return the device the parameters live on right now.

        ``self.device`` only records the construction-time device and goes
        stale once the module is moved with ``.to()`` / ``.cpu()`` / ``.cuda()``;
        it is used only as a fallback when there are no parameters.
        """
        first_param = next(self.parameters(), None)
        return self.device if first_param is None else first_param.device

    def forward(self, x):
        """Return ``{target_name: prediction}`` for every stage, in chain order."""
        outputs = {}
        x = x.to(self._current_device())
        for name in self.chain_order:
            stage_out = self.models[name](x)
            outputs[name] = stage_out
            # Widen the input by this stage's prediction for the next stage.
            x = torch.cat([x, stage_out], dim=1)
        return outputs

In [None]:
# Tracking
stage_predictions_train = {}
stage_predictions_test = {}
trained_models = {}
target_scalers = {}

# Fix the torch RNG so weight init, dropout and shuffling repeat between runs
# (42 matches the random_state used for the data split).
torch.manual_seed(42)

# Input tensors. They live on the training device and gain one column per
# stage (the previous stage's prediction).
X_train_chain = X_train_tensor.clone().to(device)
X_test_chain = X_test_tensor.clone().to(device)

# Directory to save models
model_dir = Path(MODELS) / SUBFOLDER
model_dir.mkdir(parents=True, exist_ok=True)

for i, target in enumerate(chain_targets):
    print(f"\n🔁 Training model for: {target} (Stage {i + 1}/{len(chain_targets)})")

    # Look the target column up by name: the Y tensors follow `target_cols`,
    # which need not be in the same order as `chain_targets`.
    target_col = target_cols.index(target)

    # Get and scale target values
    y_scaler = StandardScaler()
    y_train_np = Y_train_tensor[:, target_col].cpu().numpy().reshape(-1, 1)
    y_test_np = Y_test_tensor[:, target_col].cpu().numpy().reshape(-1, 1)
    y_train_scaled = y_scaler.fit_transform(y_train_np)
    y_test_scaled = y_scaler.transform(y_test_np)
    target_train = torch.as_tensor(y_train_scaled, dtype=torch.float32).to(device)
    target_test = torch.as_tensor(y_test_scaled, dtype=torch.float32).to(device)
    target_scalers[target] = y_scaler

    # Create model; every column that is not a frequency feature (including the
    # predictions appended by earlier stages) goes to the "other" branch.
    n_inputs = X_train_chain.shape[1]
    freq_lookup = set(freq_idx)
    model = FrequencyAwareNetwork(
        freq_features=len(freq_idx),
        other_features=n_inputs - len(freq_idx),
        hidden_sizes=[64, 128, 256],
        dropout_rate=0.2,
        activation="silu",
    ).to(device)

    model.set_feature_indices(
        freq_idx=freq_idx,
        other_idx=[j for j in range(n_inputs) if j not in freq_lookup],
    )

    # The loader is fed CPU copies: worker processes should not be handed CUDA
    # tensors, and each batch is moved to the device inside the loop instead.
    train_dataset = TensorDataset(X_train_chain.cpu(), target_train.cpu())
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=pin_memory,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    best_model_state = None
    best_loss = float("inf")
    patience = 15
    counter = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        batches_used = 0

        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            optimizer.zero_grad()
            pred = model(xb)
            loss = criterion(pred, yb)

            if torch.isnan(loss):
                print("⚠️ Skipping batch due to NaN loss.")
                continue

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batches_used += 1

        # Average over the batches that actually contributed a loss.
        avg_loss = running_loss / max(batches_used, 1)

        # Validation
        # NOTE(review): the held-out test split doubles as the early-stopping
        # validation set, so the reported val loss is optimistically biased.
        model.eval()
        with torch.no_grad():
            val_pred = model(X_test_chain)
            val_loss = criterion(val_pred, target_test).item()

        print(
            f"Epoch {epoch + 1:3d}: Train Loss = {avg_loss:.5f} | Val Loss = {val_loss:.5f}"
        )

        # Early stopping logic. A NaN val_loss compares False here, so it can
        # never be recorded as the best epoch.
        if val_loss < best_loss:
            best_loss = val_loss
            # state_dict() returns references to the live tensors; clone them,
            # otherwise later epochs overwrite the "best" snapshot in place.
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("🛑 Early stopping.")
                break

    # Load best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("⚠️ No best state found, using last epoch weights.")

    # Save model weights. The fitted target scaler is only kept in memory in
    # `target_scalers`; it is not written to disk here.
    trained_models[target] = model
    model_path = model_dir / f"{target}_model.pt"
    torch.save(model.state_dict(), model_path)
    print(f"✅ Saved best model for {target} to {model_path}")

    # Predict for next stage (values are in the scaled target space)
    model.eval()
    with torch.no_grad():
        train_pred = model(X_train_chain)
        test_pred = model(X_test_chain)

    stage_predictions_train[target] = train_pred
    stage_predictions_test[target] = test_pred

    # Chain: append predictions to input for next stage
    X_train_chain = torch.cat([X_train_chain, train_pred], dim=1)
    X_test_chain = torch.cat([X_test_chain, test_pred], dim=1)

In [None]:
# Report the peak CUDA memory allocated so far, converted from bytes to megabytes.
peak_megabytes = torch.cuda.max_memory_allocated() / 1e6
print(peak_megabytes, "MB used")

In [None]:
# Save predictions
predictions_dir = Path(PREDICTIONS) / SUBFOLDER
predictions_dir.mkdir(parents=True, exist_ok=True)


def _predictions_to_frame(stage_predictions):
    """Build a DataFrame with one column per target from ``{target: tensor}``.

    Each tensor is flattened with ``reshape(-1)`` rather than ``squeeze()`` so
    a single-row prediction stays a length-1 column instead of collapsing to a
    scalar, which ``pd.DataFrame`` would reject.
    """
    return pd.DataFrame(
        {
            target: preds.detach().reshape(-1).cpu().numpy()
            for target, preds in stage_predictions.items()
        }
    )


# NOTE: the stored values are the models' raw outputs, i.e. in the standardized
# target space; apply the matching entry of `target_scalers` to get physical units.
train_pred_df = _predictions_to_frame(stage_predictions_train)
train_pred_df.to_csv(predictions_dir / "train_predictions.csv", index=False)
test_pred_df = _predictions_to_frame(stage_predictions_test)
test_pred_df.to_csv(predictions_dir / "test_predictions.csv", index=False)
print("✅ Predictions saved.")

In [None]:
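# Recreate and load a trained model (reference snippet, intentionally left commented out).
# `stage` is the target's position in `chain_targets`: every earlier stage appended one
# prediction column to the inputs, so the "other" branch is that many columns wider.
# stage = chain_targets.index("Zout_real")
# n_inputs = len(freq_idx) + len(other_idx) + stage
# loaded_model = FrequencyAwareNetwork(freq_features=len(freq_idx), other_features=len(other_idx) + stage)
# loaded_model.set_feature_indices(freq_idx, [j for j in range(n_inputs) if j not in freq_idx])
# loaded_model.load_state_dict(torch.load("models/feature_extraction/Zout_real_model.pt", map_location=device))
# loaded_model.to(device).eval()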