In [None]:
import pandas as pd
from models import efficient_net
import tensorflow as tf
import tensorflow.compat.v1 as tfc
from sklearn.utils.class_weight import compute_class_weight

from modeling import predict_from_csv, stacking_from_csv
from src import InputPipeline

%load_ext autoreload
%autoreload 2

## Check GPU availability for TensorFlow

In [None]:
# Some GPU setup
# for documentation about using gpus refer to: https://www.tensorflow.org/install/pip#windows-wsl2

# Fraction of GPU memory the v1 session created below may reserve.
GPU_MEMORY_FRACTION = 0.80

tf.keras.backend.clear_session()

# Fail fast: the rest of the notebook assumes a GPU is present.
device_name = tf.test.gpu_device_name()
if not device_name:
  raise SystemError('GPU device not found')

# Close a session left over from a previous run of this cell, so re-running the
# cell does not keep an extra session alive. On a fresh kernel `sess` is not
# defined yet; that NameError is the only failure we deliberately ignore.
try:
  sess.close()
except NameError:
  pass

tfc.enable_eager_execution()
# NOTE(review): with eager execution enabled, Keras training may not run through
# this session, so the memory cap below might not apply to it — confirm, or
# consider tf.config.set_logical_device_configuration instead.
gpu_options = tfc.GPUOptions(per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION)
sess = tfc.InteractiveSession(config=tfc.ConfigProto(gpu_options=gpu_options))

### Optional: class weights
- The data is highly imbalanced, so here we compute balanced per-class weights.

In [None]:
# Stratified training metadata; the class-weight cell below reads its "label" column.
TRAIN_CSV_PATH = "../data/train_images_stratified.csv"
train_df = pd.read_csv(TRAIN_CSV_PATH)

In [None]:
# Balanced weights, returned by sklearn in the same order as `classes`.
# `.unique()` alone yields labels in order of first appearance. Enumerating that
# would give index 0 the weight of whichever label happens to come first in the
# CSV. Sorting first and zipping label -> weight keeps each weight on its class.
# NOTE(review): assumes labels are the integer class indices the model is trained on.
label_classes = train_df["label"].sort_values().unique()
balanced_weights = compute_class_weight(class_weight='balanced', classes=label_classes, y=train_df["label"])
class_weights = dict(zip(label_classes.tolist(), balanced_weights.tolist()))
# NOTE(review): `class_weights` is not passed to `train_classifier` below, so it
# currently has no effect on training.

## Configure Parameters

In [None]:
# Model input shape. The input pipelines below receive INPUT_SHAPE[:2] as their
# `size` and are built with 3 channels.
INPUT_SHAPE = (220, 220, 3)

# Training hyperparameters shared by every classifier trained in this notebook.
CONF = dict(
  learning_rate=0.0001,
  batch_size=48,
  epochs=5,
  loss_function="sparse_categorical_crossentropy",
  metric="sparse_categorical_accuracy",
)

## Make Input Pipelines

In [None]:
# Constructor settings shared by both pipelines: split ratios, channel count,
# batch size and image size.
PIPELINE_KWARGS = dict(
  splits=(0.85, 0.0, 0.15),
  channels=3,
  batch_size=CONF["batch_size"],
  size=INPUT_SHAPE[:2],
)

# Subspecies pipeline: train/validation sets come from the stratified CSV files.
sub_species_input_pipeline = InputPipeline(**PIPELINE_KWARGS)
sub_species_input_pipeline.make_stratified_train_dataset(
  train_ds_path="../data/train_ds_images_stratified.csv",
  val_ds_path="../data/val_ds_images_stratified.csv",
)

# Species pipeline: built from an image directory; this split is NOT stratified.
species_input_pipeline = InputPipeline(**PIPELINE_KWARGS)
species_input_pipeline.make_train_datasets(directory="../data/train_images/species_classify")

## Training a Model

In [None]:
from modeling import train_classifier

In [None]:
# Train the subspecies classifier (200 classes) on the stratified subspecies pipeline.
# The stacking cell at the end uses this model as its secondary classifier.
# NOTE(review): the "_25e" suffix in model_name does not match CONF["epochs"] (5);
# confirm which is intended before relying on the name.
train_classifier(
  model_name="../classifiers/eff_net_hyptunning_25e",
  input_shape=INPUT_SHAPE,
  classes_to_classify=200,
  configuration=CONF,
  model=efficient_net,
  train_dataset=sub_species_input_pipeline.train_dataset,
  validation_dataset=sub_species_input_pipeline.validation_dataset,
)

In [None]:
# Train the species classifier (70 classes) on the (non-stratified) species pipeline.
# The stacking cell at the end uses this model as its primary classifier.
# NOTE(review): the "_50e" suffix in model_name does not match CONF["epochs"] (5);
# confirm which is intended before relying on the name.
train_classifier(
  model_name="../classifiers/species_efficient_net_classifier_50e",
  input_shape=INPUT_SHAPE,
  classes_to_classify=70,
  configuration=CONF,
  model=efficient_net,
  train_dataset=species_input_pipeline.train_dataset,
  validation_dataset=species_input_pipeline.validation_dataset,
)

## NN results (training history)

In [None]:
import pickle 
import matplotlib.pyplot as plt
import pandas as pd

HISTORY_PATH = "../classifiers/trainHistoryDict/subspecies_effnet_250_classifier_100e.pkl"

# NOTE: pickle.load can execute arbitrary code; only load history files this
# project wrote itself.
with open(HISTORY_PATH, 'rb') as history_file:
    history = pickle.load(history_file)

# One row per epoch, one column per recorded metric.
history_df = pd.DataFrame(history)

In [None]:
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(16, 5))

# Each panel: (axes, training column, validation column, training label,
# validation label, y-axis label).
panels = (
    (loss_ax, "loss", "val_loss", 'Training Loss', 'Validation Loss', 'Loss'),
    (acc_ax, "sparse_categorical_accuracy", "val_sparse_categorical_accuracy",
     'Training Accuracy', 'Validation Accuracy', 'Accuracy'),
)

for ax, train_col, val_col, train_label, val_label, y_label in panels:
    ax.plot(history_df[train_col], label=train_label)
    ax.plot(history_df[val_col], label=val_label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(y_label)
    ax.legend()

plt.show()

## Predict on the test set

In [None]:
# Predictions from this classifier need no label adjustment afterwards.
# NOTE(review): `size` receives the full 3-tuple INPUT_SHAPE, whereas the input
# pipelines receive INPUT_SHAPE[:2]; confirm which predict_from_csv expects.
SUBSPECIES_CLASSIFIER_PATH = "../classifiers/subspecies_effnet_250_classifier_100e"

predict_from_csv(
  classifier=SUBSPECIES_CLASSIFIER_PATH,
  dataset="../data/test_images_path.csv",
  path="../data/test_images",
  size=INPUT_SHAPE,
)

## Stacking Prediction

In [None]:
# Combine the species classifier (primary) with the subspecies classifier
# (secondary) on the test images.
# NOTE(review): `weights` presumably gives the (primary, secondary) blend
# weights, and `mapping` presumably links the two label spaces; confirm in
# modeling.stacking_from_csv.
stacking_kwargs = {
    "primary_classifier": "../classifiers/species_efficient_net_classifier_50e",
    "secondary_classifier": "../classifiers/eff_net_hyptunning_25e",
    "dataset": "../data/test_images_path.csv",
    "path": "../data/test_images",
    "size": INPUT_SHAPE,
    "weights": (1.0, 0.3),
    "mapping": "../mapping.pickle",
}
stacking_from_csv(**stacking_kwargs)