# Feature Extraction Tests

In [None]:
import os
import pathlib
from dataclasses import dataclass
import tensorflow as tf
import pandas as pd
import wandb

from bcd.model.network.shainnet import ShainNetConfig, ShainNetFactory
from bcd.model.repo import ModelRepo
from bcd.model.base import DenseNet
from bcd.model.experiment import Experiment

# Allow up to 999 rows to be displayed when rendering DataFrames.
pd.options.display.max_rows = 999

## Preliminaries

## Configuration

In [None]:
# Project Parameters
# Record the notebook name for Weights & Biases. The variable is set through
# os.environ instead of the %env magic because %env does not strip quotes:
# the quoted form exported a variable whose name and value both contained the
# quote characters, so the intended WANDB_NOTEBOOK_NAME was never set.
os.environ["WANDB_NOTEBOOK_NAME"] = "test_feature_extraction.ipynb"

# Dataset name and training image directory for each execution mode.
datasets = {
    "Development": {
        "name": "CBIS-DDSM_10",
        "directory": "data/image/1_final/training_10/training/",
    },
    "Stage": {
        "name": "CBIS-DDSM_30",
        "directory": "data/image/1_final/training_30/training/",
    },
    "Production": {
        "name": "CBIS-DDSM",
        "directory": "data/image/1_final/training/training/",
    },
}

# Key into `datasets`; also embedded in the project name below.
mode = "Development"
project = f"Breast-Cancer-Detection-{mode}"

# Experiment Parameters
# Set to True to retrain even when the model and config already exist.
force = False

# Training Config
loss = "binary_crossentropy"
# Upper bound on the number of training epochs; an early stopping callback
# may end training sooner.
epochs = 5
learning_rate = 1e-4
training_config = dict(
    loss=loss,
    epochs=epochs,
    learning_rate=learning_rate,
)

# Network Config
# Activation shared by all networks configured in this notebook.
activation = "sigmoid"

# Dataset params
# Look up the entry for the active mode once and read both fields from it.
_dataset_spec = datasets[mode]
dataset = _dataset_spec["name"]
if mode == "Production":
    batch_size = 64
else:
    batch_size = 32
input_shape = (224, 224, 3)
output_shape = 1
train_dir = pathlib.Path(_dataset_spec["directory"]).with_suffix("")
dataset_config = dict(
    dataset=dataset,
    batch_size=batch_size,
    input_shape=input_shape,
    output_shape=output_shape,
)

# Checkpoint Config
# Settings collected for model checkpointing; the metric watched is
# validation accuracy.
ckpt_monitor = "val_accuracy"
ckpt_mode = "auto"
ckpt_verbose = 1
ckpt_save_best_only = True
ckpt_save_weights_only = False
checkpoint_config = dict(
    monitor=ckpt_monitor,
    verbose=ckpt_verbose,
    save_best_only=ckpt_save_best_only,
    mode=ckpt_mode,
    save_weights_only=ckpt_save_weights_only,
)

# Early stop parameters
# NOTE(review): patience (10) is larger than the configured epochs (5), so the
# callback cannot halt training early in the current setup - confirm intended.
es_monitor = "val_loss"  # Validation loss is the quantity watched.
es_min_delta = 0.0001
es_patience = 10  # Epochs without improvement tolerated before stopping.
es_restore_best_weights = True  # Keep the best weights, not the last epoch's.
es_verbose = 1
early_stop_config = dict(
    min_delta=es_min_delta,
    monitor=es_monitor,
    patience=es_patience,
    restore_best_weights=es_restore_best_weights,
    verbose=es_verbose,
)

# Reduce LR on Plateau Parameters
rlr_monitor = "val_loss"
rlr_mode = "auto"
rlr_factor = 0.5
rlr_patience = 3
rlr_min_delta = 1e-4
rlr_min_lr = 1e-10
rlr_verbose = 1
learning_rate_schedule_config = dict(
    monitor=rlr_monitor,
    factor=rlr_factor,
    patience=rlr_patience,
    verbose=rlr_verbose,
    mode=rlr_mode,
    min_delta=rlr_min_delta,
    min_lr=rlr_min_lr,
)



## Experiment Config

In [None]:
# Gather the project, mode and per-section settings into the single
# configuration object handed to the experiment.
config = dict(
    project=project,
    mode=mode,
    dataset=dataset_config,
    training=training_config,
    checkpoint=checkpoint_config,
    early_stop=early_stop_config,
    learning_rate_schedule=learning_rate_schedule_config,
)

## Load Data

In [None]:
# Both datasets are read from the same directory and partitioned with
# validation_split. The split arguments are defined once so the two calls
# cannot drift apart: the training and validation subsets only stay disjoint
# when both calls use the same seed and split fraction.
train_dir = pathlib.Path(train_dir).with_suffix('')
_split_kwargs = dict(
    labels="inferred",
    color_mode="rgb",
    # Height and width taken from input_shape so the loader and the network
    # agree on the image size.
    image_size=input_shape[:2],
    shuffle=True,
    validation_split=0.2,
    interpolation="bilinear",
    seed=123,
    batch_size=batch_size,
)

# Training DataSet (80%)
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    subset="training",
    **_split_kwargs,
)

# Validation DataSet (20%)
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    subset="validation",
    **_split_kwargs,
)

## Callbacks

In [None]:
# Build the callbacks from the configuration dictionaries, which hold the same
# keyword/value pairs as the individual es_* and rlr_* settings.
early_stop_callback = tf.keras.callbacks.EarlyStopping(**early_stop_config)
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    **learning_rate_schedule_config
)

# Order matches the original list: early stopping first, then LR reduction.
callbacks = [early_stop_callback, reduce_lr_callback]

## Dependencies

In [None]:
# Model repository constructed with the current mode and project.
repo = ModelRepo(mode=mode, project=project)

# The optimizer class itself (not an instance) is passed on; presumably the
# experiment instantiates it with the configured learning rate - verify.
optimizer = tf.keras.optimizers.Adam

# Metrics reported during training: accuracy, AUC, precision and recall.
metrics = [
    "accuracy",
    tf.keras.metrics.AUC(),
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
]

## Build Model

In [None]:
# Configure the ShainNet factory and create the network with a DenseNet base
# model. The activation is supplied both through the config object and as a
# direct factory argument, exactly as in the original call.
network_config = ShainNetConfig(activation=activation)
factory = ShainNetFactory(
    config=network_config,
    input_shape=input_shape,
    output_shape=output_shape,
    activation=activation,
)
base_model = DenseNet()
network = factory.create(base_model=base_model)


## Build Experiment

In [None]:
# Assemble the experiment from the network, configuration and dependencies
# prepared in the earlier cells, then run it on the two datasets.
experiment = Experiment(
    network=network,
    config=config,
    optimizer=optimizer,
    repo=repo,
    callbacks=callbacks,
    metrics=metrics,
    force=force,
)
experiment.run(train_ds=train_ds, val_ds=val_ds)

## Check Repository

In [None]:
# The repository must report an entry for this model name and configuration.
model_registered = repo.exists(name="ShainNet_DenseNet", config=config)
assert model_registered