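# ShainNet Feature Extraction

Runs a `FeatureExtractionExperiment` for the ShainNet architecture on a DenseNet base model. The images in the configured training directory are split 80/20 into training and validation subsets.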

In [1]:
import os
import pathlib
from dataclasses import dataclass
import tensorflow as tf
import pandas as pd
import wandb

from bcd.model.network.shainnet import ShainNetConfig, ShainNetFactory
from bcd.model.repo import ExperimentRepo
from bcd.model.base import *
from bcd.model.experiment import FeatureExtractionExperiment
from bcd.model.config import *

# Allow up to 999 rows when rendering DataFrames in this notebook.
pd.set_option("display.max_rows", 999)

## Preliminaries

### Configuration

In [2]:
# Experiment Parameters
# Set via os.environ rather than `%env "NAME" "value"`: that magic form keeps the
# quote characters as part of the variable name and value, so W&B would never
# see WANDB_NOTEBOOK_NAME. TODO(review): confirm this matches the notebook's filename.
os.environ["WANDB_NOTEBOOK_NAME"] = "experiments.ipynb"
mode = "Development"
force = False  # Whether to retrain if the model and config already exist.

env: "WANDB_NOTEBOOK_NAME"="experiments.ipynb"


In [3]:

# Assemble the experiment configuration section by section, then bundle it.
project_config = ProjectConfig(mode=mode)

train_config = TrainConfig(epochs=5, learning_rate=1e-4)

# ShainNet head hyper-parameters (dense layer widths and dropout rates).
network_config = ShainNetConfig(
    dense1=1024,
    dropout1=0.5,
    dense2=1024,
    dropout2=0.3,
    dense3=512,
    dense4=128,
)

dataset_config = DatasetConfig(mode=mode)

checkpoint_config = CheckPointConfig(
    monitor="val_accuracy",
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode="auto",
)

early_stop_config = EarlyStopConfig(
    min_delta=1e-4,
    monitor="val_loss",
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# NOTE(review): `restore_best_weights` is accepted here but the callbacks cell does
# not forward it to ReduceLROnPlateau. Conversely, that cell reads
# `config.learning_rate_schedule.min_lr`, which is not set here — presumably a
# class default; confirm.
learning_rate_schedule_config = LearningRateScheduleConfig(
    min_delta=1e-4,
    monitor="val_loss",
    factor=0.5,
    patience=3,
    restore_best_weights=True,
    verbose=1,
    mode="auto",
)

config = Config(
    project=project_config,
    dataset=dataset_config,
    train=train_config,
    network=network_config,
    checkpoint=checkpoint_config,
    early_stop=early_stop_config,
    learning_rate_schedule=learning_rate_schedule_config,
)


### Load Data

In [4]:
# Both subsets are carved from the same directory. Keras only yields disjoint
# training/validation subsets when `validation_split` and `seed` match across
# the two calls, so the shared arguments are defined once and reused.
train_dir = pathlib.Path(config.dataset.train_dir).with_suffix('')
dataset_kwargs = dict(
    labels="inferred",
    color_mode="rgb",
    image_size=(224, 224),
    shuffle=True,
    validation_split=0.2,
    interpolation="bilinear",
    seed=123,
    batch_size=config.dataset.batch_size,
)

# Training DataSet (80%)
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir, subset="training", **dataset_kwargs
)

# Validation DataSet (20%)
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir, subset="validation", **dataset_kwargs
)

Found 276 files belonging to 2 classes.
Using 221 files for training.
Found 276 files belonging to 2 classes.
Using 55 files for validation.


### Callbacks

In [5]:
# Early stopping, driven entirely by config.early_stop.
early_stop_callback = tf.keras.callbacks.EarlyStopping(
    monitor=config.early_stop.monitor,
    min_delta=config.early_stop.min_delta,
    patience=config.early_stop.patience,
    restore_best_weights=config.early_stop.restore_best_weights,
    verbose=config.early_stop.verbose,
)

# Learning-rate reduction on plateau, driven by config.learning_rate_schedule.
# NOTE(review): config.learning_rate_schedule.restore_best_weights is not used —
# ReduceLROnPlateau takes no such argument.
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=config.learning_rate_schedule.monitor,
    factor=config.learning_rate_schedule.factor,
    patience=config.learning_rate_schedule.patience,
    verbose=config.learning_rate_schedule.verbose,
    mode=config.learning_rate_schedule.mode,
    min_delta=config.learning_rate_schedule.min_delta,
    min_lr=config.learning_rate_schedule.min_lr,
)

callbacks = [
    early_stop_callback,
    reduce_lr_callback,
]

### Dependencies

In [6]:
# Experiment repository for this mode/project, the optimizer *class* (not an
# instance), and the metrics reported during training.
repo = ExperimentRepo(mode=mode, project=config.project.name)
optimizer = tf.keras.optimizers.Adam
metrics = [
    "accuracy",
    tf.keras.metrics.AUC(),
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
]

## Build Model

In [7]:
# Build the ShainNet network from its hyper-parameters on a DenseNet base model.
factory = ShainNetFactory(config=config.network)
network = factory.create(
    base_model=DenseNet(),
)


## Run Experiment

In [8]:
# Wire everything into the experiment and run it. `force` comes from the
# configuration cell rather than a hard-coded literal, so toggling it there
# actually takes effect.
experiment = FeatureExtractionExperiment(
    network=network,
    config=config,
    optimizer=optimizer,
    repo=repo,
    callbacks=callbacks,
    metrics=metrics,
    notes="test",
    tags=["ShainNet", "DenseNet"],
    force=force,
)
experiment.run(train_ds=train_ds, val_ds=val_ds)

                                         ShainNet_DenseNet                                          
# ------------------------------------------------------------------------------------------------ #
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 ShainNet_DenseNet_input_lay  [(None, 224, 224, 3)]    0         
 er (InputLayer)                                                 
                                                                 
 tf.math.truediv (TFOpLambda  (None, 224, 224, 3)      0         
 )                                                               
                                                                 
 tf.nn.bias_add (TFOpLambda)  (None, 224, 224, 3)      0         
                                                                 
 tf.math.truediv_1 (TFOpLamb  (None, 224, 224, 3)      0         
 da)                                                     

## Test: Re-run When the Model Already Exists

In [9]:
# Run the same experiment object a second time. Presumably this exercises the
# "model already exists" path, since the experiment was built without forcing a
# retrain — TODO confirm against FeatureExtractionExperiment.run.
experiment.run(val_ds=val_ds, train_ds=train_ds)

                                         ShainNet_DenseNet                                          
# ------------------------------------------------------------------------------------------------ #
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 ShainNet_DenseNet_input_lay  [(None, 224, 224, 3)]    0         
 er (InputLayer)                                                 
                                                                 
 tf.math.truediv (TFOpLambda  (None, 224, 224, 3)      0         
 )                                                               
                                                                 
 tf.nn.bias_add (TFOpLambda)  (None, 224, 224, 3)      0         
                                                                 
 tf.math.truediv_1 (TFOpLamb  (None, 224, 224, 3)      0         
 da)                                                     