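# ShainNet

Feature-extraction experiments: a ShainNet network is built on top of each base model listed in `base_models` and trained on the 2-class image dataset configured below. One experiment is run per base model.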

In [1]:
import os
import pathlib
import tensorflow as tf
import pandas as pd
import wandb

from bcd.model.network.base import NetworkConfig
from bcd.model.network.shainnet import ShainNetConfig, ShainNetFactory
from bcd.model.repo import ExperimentRepo
from bcd.model.base import *
from bcd.model.experiment import FeatureExtractionExperiment
from bcd.model.config import *

# Allow up to 999 rows to be rendered when displaying DataFrames.
pd.options.display.max_rows = 999

## Configuration

In [2]:
# Experiment Parameters
# Tell Weights & Biases which notebook is running. The variable name must NOT be
# quoted: `%env "WANDB_NOTEBOOK_NAME" "shainnet.ipynb"` creates a variable whose
# name literally includes the quote characters, so W&B never sees it.
os.environ["WANDB_NOTEBOOK_NAME"] = "shainnet.ipynb"
mode = "Stage"  # Passed to ProjectConfig, DatasetConfig and ExperimentRepo.
force = False  # Whether to retrain if the model and config already exist.
# Base models to evaluate; one experiment is run per entry.
base_models = [DenseNet(), EfficientNet(), Inception(), InceptionResNet(), MobileNet(), ResNet(), Xception()]

env: "WANDB_NOTEBOOK_NAME"="shainnet.ipynb"


In [3]:
def create_config(network_config: NetworkConfig, epochs: int = 50, learning_rate: float = 1e-4):
    """Assemble the full experiment Config around a network configuration.

    Reads the notebook-level ``mode`` variable (set in the configuration cell)
    for the project and dataset sections.

    Args:
        network_config: Architecture-specific configuration (e.g. a ShainNetConfig).
        epochs: Number of training epochs stored in TrainConfig. Defaults to 50,
            the value previously hard-coded here.
        learning_rate: Learning rate stored in TrainConfig. Defaults to 1e-4,
            the value previously hard-coded here.

    Returns:
        A Config bundling the project, dataset, train, network, checkpoint,
        early-stop and learning-rate-schedule sections.
    """
    project_config = ProjectConfig(mode=mode)

    train_config = TrainConfig(epochs=epochs, learning_rate=learning_rate)

    dataset_config = DatasetConfig(mode=mode)

    # Checkpoints track validation accuracy, keeping only the best full model.
    checkpoint_config = CheckPointConfig(monitor="val_accuracy", verbose=1, save_best_only=True, save_weights_only=False, mode="auto")

    # Early stopping tracks validation loss and rolls back to the best weights.
    early_stop_config = EarlyStopConfig(min_delta=1e-4, monitor="val_loss", patience=10, restore_best_weights=True, verbose=1)

    # NOTE(review): Keras ReduceLROnPlateau has no `restore_best_weights` argument;
    # confirm whether LearningRateScheduleConfig actually uses it.
    learning_rate_schedule_config = LearningRateScheduleConfig(min_delta=1e-4, monitor="val_loss", factor=0.5, patience=3, restore_best_weights=True, verbose=1, mode="auto")

    config = Config(project=project_config,
                    dataset=dataset_config,
                    train=train_config,
                    network=network_config,
                    checkpoint=checkpoint_config,
                    early_stop=early_stop_config,
                    learning_rate_schedule=learning_rate_schedule_config)
    return config

# ShainNet settings: 224x224 RGB input, four dense layers (two followed by dropout).
# activation="sigmoid" with output_shape=1 presumably means a single-unit binary
# output, matching the 2-class dataset (TODO confirm against ShainNetConfig).
network_config = ShainNetConfig(
    activation="sigmoid",
    input_shape=(224, 224, 3),
    output_shape=1,
    dense1=1024,
    dropout1=0.5,
    dense2=1024,
    dropout2=0.3,
    dense3=512,
    dense4=128,
)
config = create_config(network_config=network_config)

## Load Data

In [4]:
# Strip any file suffix from the configured path. NOTE(review): this is a no-op
# unless the directory name contains a dot; it looks like a leftover from the
# TensorFlow tutorial, where it removed an archive extension.
train_dir = pathlib.Path(config.dataset.train_dir).with_suffix('')

# Both subsets MUST share the same seed and validation_split; otherwise the
# training and validation files are no longer guaranteed to be disjoint.
SPLIT_SEED = 123
VALIDATION_SPLIT = 0.2
# Must match the height/width of network_config.input_shape (224, 224, 3).
IMAGE_SIZE = (224, 224)

dataset_kwargs = dict(
    labels="inferred",
    color_mode="rgb",
    image_size=IMAGE_SIZE,
    shuffle=True,
    validation_split=VALIDATION_SPLIT,
    interpolation="bilinear",
    seed=SPLIT_SEED,
    batch_size=config.dataset.batch_size,
)

# Training DataSet (80%)
train_ds = tf.keras.utils.image_dataset_from_directory(train_dir, subset='training', **dataset_kwargs)

# Validation DataSet (20%)
val_ds = tf.keras.utils.image_dataset_from_directory(train_dir, subset='validation', **dataset_kwargs)

Found 816 files belonging to 2 classes.
Using 653 files for training.
Found 816 files belonging to 2 classes.
Using 163 files for validation.


## Callbacks

In [5]:
# Stop training when the monitored early-stop metric stops improving.
es_cfg = config.early_stop
early_stop_callback = tf.keras.callbacks.EarlyStopping(
    monitor=es_cfg.monitor,
    min_delta=es_cfg.min_delta,
    patience=es_cfg.patience,
    restore_best_weights=es_cfg.restore_best_weights,
    verbose=es_cfg.verbose,
)

# Reduce the learning rate when the monitored metric plateaus.
lr_cfg = config.learning_rate_schedule
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=lr_cfg.monitor,
    factor=lr_cfg.factor,
    patience=lr_cfg.patience,
    verbose=lr_cfg.verbose,
    mode=lr_cfg.mode,
    min_delta=lr_cfg.min_delta,
    min_lr=lr_cfg.min_lr,
)

callbacks = [early_stop_callback, reduce_lr_callback]

## Dependencies

In [6]:
# Repository used by the experiments, scoped to this project and mode.
repo = ExperimentRepo(mode=mode, project=config.project.name)

# Passed as the optimizer class, not an instance; presumably the experiment
# instantiates it (TODO confirm in FeatureExtractionExperiment).
optimizer = tf.keras.optimizers.Adam

# Accuracy plus AUC, precision and recall metric objects, created in that order.
metrics = ['accuracy'] + [
    metric_cls()
    for metric_cls in (tf.keras.metrics.AUC, tf.keras.metrics.Precision, tf.keras.metrics.Recall)
]

## Build Factory

In [7]:
# Build the ShainNet factory from the network section of the experiment config.
shainnet_config = config.network
factory = ShainNetFactory(config=shainnet_config)

## Run Experiments

In [None]:
# Run one feature-extraction experiment per base model.
for base_model in base_models:
    network = factory.create(base_model=base_model)
    experiment = FeatureExtractionExperiment(
        network=network,
        config=config,
        optimizer=optimizer,
        repo=repo,
        callbacks=callbacks,
        metrics=metrics,
        tags=[network.architecture, base_model.name],
        # Honour the `force` flag from the configuration cell; previously a
        # hard-coded False meant setting force = True had no effect.
        force=force,
    )
    experiment.run(train_ds=train_ds, val_ds=val_ds)
