In [8]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's own modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from lib.reproduction import major_oxides
import mlflow
import numpy as np
import datetime
import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras


# Seed every RNG involved in training so that Restart & Run All is repeatable.
# `keras.utils.set_random_seed` seeds Python's `random`, NumPy and the active
# Keras backend in one call. Python's `random` matters here: Keras 3 appears to
# draw its default initializer/dropout seeds from it, and the explicit torch and
# NumPy calls below do not touch it.
SEED = 42
keras.utils.set_random_seed(SEED)
# Kept explicit so the torch/NumPy seeding does not depend on the Keras helper.
torch.manual_seed(SEED)
np.random.seed(SEED)


In [10]:
# Show the installed Keras version so the run's environment is recorded in the notebook.
print(keras.__version__)

3.2.1


In [11]:
import torch.nn as nn
import torch.optim as optim

# Run on CUDA when PyTorch can see a GPU; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [12]:
from keras import layers, optimizers

def build_model(input_dim, output_dim):
    """Build and compile the fully connected regression network.

    The network is an input of width ``input_dim`` followed by four
    ReLU ``Dense`` blocks (1024, 512, 256, 128 units), each followed by
    batch normalisation; the first two blocks additionally apply 30% dropout.
    A final ``Dense(output_dim)`` layer without activation produces the
    (linear) regression output.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of regression outputs.

    Returns:
        A compiled ``keras.models.Sequential`` model using Adam
        (learning rate 0.001), MSE loss, and RMSE / MAE metrics.
    """
    # (units, dropout rate or None) for each hidden block, in order.
    hidden_blocks = (
        (1024, 0.3),
        (512, 0.3),
        (256, None),
        (128, None),
    )

    net = keras.models.Sequential()
    net.add(layers.Input(shape=(input_dim,)))

    for units, dropout_rate in hidden_blocks:
        net.add(layers.Dense(units, activation='relu'))
        net.add(layers.BatchNormalization())
        if dropout_rate is not None:
            net.add(layers.Dropout(dropout_rate))

    # Output head: no activation, i.e. a linear output.
    net.add(layers.Dense(output_dim))

    net.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['root_mean_squared_error', 'mae'],
    )
    return net

In [13]:
# Width of the model's input layer (passed to `build_model` as `input_dim`).
# NOTE(review): presumably the number of feature columns left after
# preprocessing — confirm against the output of `get_preprocess_fn`.
INPUT_DIM = 6144
# Width of the model's final Dense layer; one model is trained per target.
OUTPUT_DIM = 1

In [14]:
# Columns handed to `get_preprocess_fn` as its second argument: every oxide
# column plus the two identifier columns.
drop_cols = major_oxides + ["ID", "Sample Name"]
# Alias for the full list of targets; not referenced by the other visible cells.
target_cols = major_oxides

In [15]:
from lib.cross_validation import (
    get_cross_validation_metrics,
)
from lib.metrics import rmse_metric, std_dev_metric
from functools import partial
from lib.deep_learning_utils import get_preprocess_fn, MLFlowCallback
from experiments.optuna_run import get_data


# Factory for EarlyStopping: calling `early_stopping_callback()` yields a fresh
# callback instance for every `fit` call instead of sharing one across runs.
early_stopping_callback = partial(
    keras.callbacks.EarlyStopping, monitor="val_loss", patience=25, restore_best_weights=True
)

# One timestamped MLflow experiment per execution of this cell.
mlflow.set_experiment(f'ANN_{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')


# Keyword arguments forwarded to every `model.fit` call and logged to MLflow.
args = {
    "epochs": 1000,
    "batch_size": 32,
    # EarlyStopping monitors `val_loss`, which Keras only computes when `fit`
    # receives validation data. Without it the callback never triggers (Keras
    # just warns), best weights are never restored, and every fit runs all
    # 1000 epochs. Holding out part of the *training* data keeps the
    # evaluation fold / test set untouched by model selection.
    # NOTE(review): Keras takes the *last* fraction of the samples, before
    # shuffling — confirm the rows are not ordered in a way that biases this.
    # NOTE(review): `validation_split` needs array/tensor inputs — confirm
    # that `preprocess_fn` returns such objects.
    "validation_split": 0.1,
}

for target in major_oxides:
    folds, train, test = get_data(target)

    with mlflow.start_run(run_name=f"ANN_{target}"):
        # Log hyper-parameters first so that interrupted or failed runs still
        # record the configuration they were started with.
        mlflow.log_params(args)

        # == CROSS VALIDATION ==
        cv_metrics = []
        for cv_train_data, cv_test_data in folds:
            # A new, untrained model per fold so no weights leak between folds.
            model = build_model(INPUT_DIM, OUTPUT_DIM)

            preprocess_fn = get_preprocess_fn([target], drop_cols)
            X_train, y_train = preprocess_fn(cv_train_data)
            X_test, y_test = preprocess_fn(cv_test_data)

            model.fit(
                X_train, y_train, **args, callbacks=[early_stopping_callback()]
            )  # don't want to use mlflow callback here
            y_pred = model.predict(X_test)

            rmse = rmse_metric(y_test, y_pred)
            std_dev = std_dev_metric(y_test, y_pred)
            cv_metrics.append([rmse, std_dev])

        # Aggregate the per-fold [rmse, std_dev] pairs and log them on the run.
        mlflow.log_metrics(get_cross_validation_metrics(cv_metrics).as_dict())

        # == TRAIN ON ALL DATA ==
        model = build_model(INPUT_DIM, OUTPUT_DIM)
        preprocess_fn = get_preprocess_fn([target], drop_cols)

        X_train, y_train = preprocess_fn(train)
        X_test, y_test = preprocess_fn(test)

        # Only the final fit streams per-epoch information through MLFlowCallback.
        model.fit(
            X_train,
            y_train,
            **args,
            callbacks=[MLFlowCallback(), early_stopping_callback()],
        )
        y_pred = model.predict(X_test)

        # Held-out test-set metrics for the model trained on the full training set.
        std_dev = std_dev_metric(y_test, y_pred)
        rmse = rmse_metric(y_test, y_pred)
        mlflow.log_metrics({"rmse": rmse, "std_dev": std_dev})

2024/05/24 14:11:43 INFO mlflow.tracking.fluent: Experiment with name 'ANN_20240524-141143' does not exist. Creating a new experiment.


Epoch 1/1000
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 25ms/step - loss: 3173.8875 - mae: 55.3699 - root_mean_squared_error: 56.3326 - val_loss: 5123.6445 - val_mae: 70.1403 - val_root_mean_squared_error: 71.5796
Epoch 2/1000
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 21ms/step - loss: 2975.9541 - mae: 54.1060 - root_mean_squared_error: 54.5504 - val_loss: 3636.9929 - val_mae: 59.5249 - val_root_mean_squared_error: 60.3075
Epoch 3/1000
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 17ms/step - loss: 2727.5586 - mae: 51.8660 - root_mean_squared_error: 52.2253 - val_loss: 2358.3076 - val_mae: 48.2423 - val_root_mean_squared_error: 48.5624
Epoch 4/1000
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 19ms/step - loss: 2358.6008 - mae: 48.2187 - root_mean_squared_error: 48.5629 - val_loss: 2011.9658 - val_mae: 44.5906 - val_root_mean_squared_error: 44.8549
Epoch 5/1000
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

KeyboardInterrupt: 