In [None]:
# Load IPython's autoreload extension and put it in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from lib.reproduction import major_oxides
import mlflow
import numpy as np
import datetime
import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras


# Seed every random number generator the run relies on.
# torch.manual_seed / np.random.seed cover torch and NumPy only; Python's own
# `random` module is left unseeded. keras.utils.set_random_seed is the
# documented Keras entry point that seeds Python, NumPy and the active backend
# together, so it is called as well to make model initialisation and dropout
# repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
keras.utils.set_random_seed(SEED)


In [None]:
# Show the installed Keras version so it is recorded in the notebook output.
print(keras.__version__)

In [None]:
import torch.nn as nn
import torch.optim as optim

# Prefer CUDA when torch reports a usable GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from keras import layers, optimizers

def build_model(input_dim, output_dim):
    """Create and compile the fully connected regression network.

    The network is a Keras ``Sequential`` stack: an input of width
    ``input_dim``, ReLU dense layers of 1024, 512, 256 and 128 units (with
    dropout of 0.3 after the first two), and a final dense layer of
    ``output_dim`` units with no activation.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of units in the linear output layer.

    Returns:
        The model compiled with Adam (learning rate 0.001), MSE loss, and
        root-mean-squared-error and MAE metrics.
    """
    network = keras.models.Sequential()

    # Layers are listed in the order they are stacked.
    stack = [
        layers.Input(shape=(input_dim,)),
        layers.Dense(1024, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(output_dim),  # No activation, linear output
    ]
    for layer in stack:
        network.add(layer)

    network.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['root_mean_squared_error', 'mae'],
    )
    return network

In [None]:
# Width of the model's input layer (passed to build_model). Presumably the number
# of feature columns left after preprocessing — TODO confirm against the data.
INPUT_DIM = 6144
# Single output unit: each model is fitted against one target column at a time.
OUTPUT_DIM = 1

In [None]:
# Columns handed to get_preprocess_fn as the drop list: every name in
# major_oxides plus the "ID" and "Sample Name" columns.
drop_cols = major_oxides + ["ID", "Sample Name"]
# Alias for the full target list; not referenced by the other visible cells.
target_cols = major_oxides

In [None]:
from lib.cross_validation import (
    get_cross_validation_metrics,
)
from lib.metrics import rmse_metric, std_dev_metric
from functools import partial
from lib.deep_learning_utils import get_preprocess_fn, MLFlowCallback
from experiments.optuna_run import get_data
from lib.norms import Norm3Scaler


# Factory rather than an instance: every call to early_stopping_callback()
# builds a new EarlyStopping object, so each model.fit gets its own callback.
# It monitors val_loss with a patience of 25 epochs and restores the best weights.
early_stopping_callback = partial(
    keras.callbacks.EarlyStopping, monitor="val_loss", patience=25, restore_best_weights=True
)

# The experiment name embeds the current timestamp, so each execution of this
# cell selects a differently named MLflow experiment.
mlflow.set_experiment(f'ANN_{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')


# Keyword arguments forwarded to model.fit via **args and logged as MLflow params.
# With early stopping active, "epochs" acts as an upper bound.
args = {
    "epochs": 1000,
    "batch_size": 32,
}


def check_group_overlap(train_set, validation_set, group_column="Sample Name"):
    """Assert that two splits have no ``group_column`` value in common.

    Args:
        train_set: Frame holding the first split.
        validation_set: Frame holding the second split.
        group_column: Name of the column whose values identify a group.

    Raises:
        AssertionError: If at least one group value occurs in both splits;
            the message lists the shared values.
    """
    groups_per_split = [
        set(split[group_column].unique()) for split in (train_set, validation_set)
    ]
    overlap = groups_per_split[0] & groups_per_split[1]
    assert not overlap, f"Data leakage detected: {overlap}"


def split_dataset_by_group(dataset, split_ratio: float, group_column="Sample Name"):
    """Randomly partition ``dataset`` by whole groups.

    A fraction ``split_ratio`` of the distinct ``group_column`` values is drawn
    without replacement using the global NumPy RNG. Every row whose group was
    drawn goes into the first returned frame; all other rows go into the second.

    Note that the *first* return value receives ``split_ratio`` of the groups,
    so a caller wanting an 80/20 train/validation split must pass ``0.8``.

    Args:
        dataset: Frame containing ``group_column``.
        split_ratio: Fraction of distinct groups, within ``[0, 1]``, assigned
            to the first returned frame. The group count is truncated with ``int``.
        group_column: Name of the column whose values identify a group.

    Returns:
        ``(training_set, remaining_set)``: the rows of the drawn groups and the
        rows of every other group.

    Raises:
        ValueError: If ``split_ratio`` lies outside ``[0, 1]``.
        AssertionError: If ``check_group_overlap`` finds a group in both frames.
    """
    # Fail early with a clear message instead of an unrelated error from
    # np.random.choice (negative size / sample larger than population).
    if not 0 <= split_ratio <= 1:
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio!r}")

    unique_groups = dataset[group_column].unique()
    n_selected = int(split_ratio * len(unique_groups))
    selected_groups = np.random.choice(unique_groups, size=n_selected, replace=False)

    # Build the membership mask once and reuse it for both halves.
    in_selected = dataset[group_column].isin(selected_groups)
    training_set = dataset[in_selected]
    remaining_set = dataset[~in_selected]

    # Check for group overlap
    check_group_overlap(training_set, remaining_set, group_column)

    return training_set, remaining_set


# Fraction of the "Sample Name" groups held out as the validation set that
# drives early stopping; the other 1 - SPLIT_RATIO of the groups are used to fit.
SPLIT_RATIO = 0.2

# One MLflow run per target: cross-validate first, then fit a final model on the
# full training set and score it on the held-out test set.
for target in major_oxides:
    folds, train, test = get_data(target)

    with mlflow.start_run(run_name=f"ANN_{target}"):
        # == CROSS VALIDATION ==
        cv_metrics = []
        for cv_train_data, cv_test_data in folds:
            # split_dataset_by_group places `split_ratio` of the groups in its
            # FIRST return value, so the training share is requested explicitly.
            # Passing SPLIT_RATIO itself fitted on 20% of the groups and
            # validated on the other 80%.
            train_cv, val_cv = split_dataset_by_group(cv_train_data, 1 - SPLIT_RATIO)
            check_group_overlap(train_cv, val_cv, "Sample Name")

            # Fresh, untrained model for every fold.
            model = build_model(INPUT_DIM, OUTPUT_DIM)

            preprocess_fn = get_preprocess_fn([target], drop_cols)

            # Fit the scaler on the training portion only, then apply it to the
            # validation and test portions.
            scaler = Norm3Scaler()
            scaler.fit(train_cv)
            train_cv = scaler.transform(train_cv)
            val_cv = scaler.transform(val_cv)
            cv_test_data = scaler.transform(cv_test_data)

            X_train, y_train = preprocess_fn(train_cv)
            X_val, y_val = preprocess_fn(val_cv)
            X_test, y_test = preprocess_fn(cv_test_data)

            model.fit(
                X_train, y_train, **args, callbacks=[early_stopping_callback()], validation_data=(X_val, y_val)
            )  # don't want to use mlflow callback here
            y_pred = model.predict(X_test)

            rmse = rmse_metric(y_test, y_pred)
            std_dev = std_dev_metric(y_test, y_pred)
            cv_metrics.append([rmse, std_dev])

        # Aggregate the per-fold [rmse, std_dev] pairs and log them on this run.
        mlflow.log_metrics(get_cross_validation_metrics(cv_metrics).as_dict())

        # == TRAIN ON ALL DATA ==
        model = build_model(INPUT_DIM, OUTPUT_DIM)
        preprocess_fn = get_preprocess_fn([target], drop_cols)
        # Same orientation as in the folds: keep 1 - SPLIT_RATIO of the groups
        # for fitting and hold out SPLIT_RATIO of them for early stopping.
        train, validation = split_dataset_by_group(train, 1 - SPLIT_RATIO)

        check_group_overlap(train, validation, "Sample Name")

        scaler = Norm3Scaler()
        scaler.fit(train)
        train = scaler.transform(train)
        validation = scaler.transform(validation)
        test = scaler.transform(test)

        X_train, y_train = preprocess_fn(train)
        X_val, y_val = preprocess_fn(validation)
        X_test, y_test = preprocess_fn(test)

        # Only the final fit reports its training progress through MLFlowCallback.
        model.fit(
            X_train,
            y_train,
            **args,
            callbacks=[MLFlowCallback(), early_stopping_callback()],
            validation_data=(X_val, y_val),
        )
        y_pred = model.predict(X_test)

        std_dev = std_dev_metric(y_test, y_pred)
        rmse = rmse_metric(y_test, y_pred)
        mlflow.log_metrics({"rmse": rmse, "std_dev": std_dev})

        mlflow.log_params({**args, "target": target})