This is a binary classification transfer learning template with wandb support. The code and dataset are uploaded to wandb as artifacts, while configuration parameters and other important metrics are logged to your wandb account, where you can view them in the web UI.

## Importing Packages

In [None]:
import warnings
# Ignore every warning raised from this point on.
warnings.filterwarnings('ignore')

In [None]:
import os
import numpy as np
import pandas
import matplotlib.pyplot as plt
import wandb
from wandb.keras import WandbCallback
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
from tensorflow.keras.applications.efficientnet_v2 import EfficientNetV2S
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint

## Parameters Selection

In [None]:
# Dataset / input parameters
# (height, width, channels) of the network input; use (h, w, 1) for grayscale images
input_shape = (140, 80, 3)
# number of classes, i.e. class folders in your dataset
num_classes = 2
# colour mode follows from the channel count (last entry of input_shape)
if input_shape[-1] == 3:
    color_mode = "rgb"
else:
    color_mode = "grayscale"

In [None]:
# Hyperparameters

# -- training loop
epochs = 10
batch_size = 64

# -- optimiser
optimizer = "adam"  # can be "adam" or "sgd"
learning_rate = 0.001

# -- classification head
starting_layer = "flatten"  # can be "flatten" or "globalavgpool"
fc_layer_neuron = 2048
dropout_rate = 0.3

## Splitting Dataset into Train/Val/Test directories

In [None]:
# #splitting dataset
# import splitfolders
# splitfolders.ratio('raw_data', output="dataset", seed=1337, ratio=(.8, 0.1,0.1)) 

## Initialize wandb

In [None]:
# Launch the wandb run: fill in your own project, entity and notes.
# save_code=True additionally saves this notebook to wandb.
wandb.init(
    project="test",
    entity="user-name",
    notes="test run",
    save_code=True,
)

## Loading Dataset

In [None]:
# Root folder of the split dataset and its train / val / test sub-folders
main_dir = "dataset"
train_dir, val_dir, test_dir = (
    os.path.join(main_dir, split_name) for split_name in ("train", "val", "test")
)

### log data artifacts

In [None]:
# Upload the unsplit "raw_data" folder to wandb as a Dataset artifact
raw_data = wandb.Artifact(name="raw_data", type="Dataset")
raw_data.add_dir(local_path="raw_data")
wandb.log_artifact(raw_data)

In [None]:
# Upload the train/val/test split folder (main_dir) to wandb as a Dataset artifact
split_data = wandb.Artifact(name="splitdata_80_10_10", type="Dataset")
split_data.add_dir(local_path=main_dir)
wandb.log_artifact(split_data)

### Dataloader

In [None]:
# Training pipeline: multiply pixel values by 1/255 and apply light augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=5,
    zoom_range=0.1,
    shear_range=0.1,
    brightness_range=[0.8,1],
    horizontal_flip=True,
)
# Validation / test pipeline: rescaling only, no augmentation
test_datagen = ImageDataGenerator(rescale=1./255)

In [None]:
# Arguments shared by all three directory iterators
_common_flow_args = dict(
    target_size=input_shape[:2],  # images are resized to this size
    color_mode=color_mode,        # "rgb" or "grayscale"
    batch_size=batch_size,        # number of images per batch
    class_mode="binary",          # one 0/1 label per image
)

# Training batches (augmented); seed fixed to make the result reproducible
train_generator = train_datagen.flow_from_directory(
    directory=train_dir, seed=42, **_common_flow_args)

# Validation batches
val_generator = test_datagen.flow_from_directory(
    directory=val_dir, **_common_flow_args)

# Test batches; shuffle=False keeps the file order fixed
test_generator = test_datagen.flow_from_directory(
    directory=test_dir, shuffle=False, **_common_flow_args)

In [None]:
# Class names, in the order they appear in the generator's class_indices mapping
classes = list(train_generator.class_indices)

### Visualizing some of the loaded samples

In [None]:
# Preview the first image of four consecutive training batches.
fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(15,15))

for i in range(4):

    # each call yields one batch: (images, labels)
    image, label = next(train_generator)

    # Pick the first sample before touching the shape: squeezing the whole
    # batch would drop the batch axis whenever a batch holds a single image,
    # and indexing [0] would then select a pixel row instead of an image.
    image = image[0]

    # undo the 1/255 rescale and convert to unsigned integers for plotting
    image = (image * 255).astype('uint8')

    # drop a trailing single channel so grayscale images plot as 2-D arrays
    if image.shape[-1] == 1:
        image = image[..., 0]

    # plot raw pixel data with its class name as title
    ax[i].imshow(image)
    ax[i].set_title(classes[int(label[0])])
    ax[i].axis('off')

## Model Development

In [None]:

# Layer that turns the backbone's feature map into a vector for the dense head,
# chosen by the `starting_layer` hyperparameter
# ("flatten" -> Flatten, anything else -> GlobalAveragePooling2D).
starting_layer_after_conv = Flatten() if starting_layer=="flatten" else GlobalAveragePooling2D()

# ImageNet-pretrained EfficientNetV2-S backbone without its classifier top;
# frozen so that only the new head is trained.
base_model = EfficientNetV2S(weights="imagenet", input_shape=input_shape, include_top=False)
base_model.trainable = False
# NOTE(review): the Keras docs describe EfficientNetV2 as including input
# preprocessing by default and expecting pixels in [0, 255], while the data
# generators in this notebook rescale by 1/255 - confirm which scaling is intended.

# Classification head: selected pooling/flatten layer -> dropout -> dense(relu)
# -> dropout -> single sigmoid unit for binary classification.
x = starting_layer_after_conv(base_model.output)
x = Dropout(dropout_rate)(x)
x = Dense(fc_layer_neuron, activation="relu")(x)
x = Dropout(dropout_rate)(x)
x = Dense(1, activation="sigmoid")(x)
model = Model(base_model.input,x)

In [None]:
# Print the summary of the assembled model
model.summary()

In [None]:
# Optimiser is chosen by the `optimizer` hyperparameter:
# "adam" -> Adam, anything else -> SGD, both with `learning_rate`.
optim = Adam(learning_rate=learning_rate) if optimizer=="adam" else SGD(learning_rate=learning_rate)
loss = "binary_crossentropy"
metric = "accuracy"

# `metrics` is passed as a list, which is the form the Keras compile API documents;
# a bare string is not accepted by every Keras version.
model.compile(loss=loss, optimizer=optim, metrics=[metric])

In [None]:
# Checkpoint callback: writes weights only, and only when val_loss reaches a new minimum
checkpoint_filepath = "checkpoint/model.ckpt"
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# len(train_generator) is already the number of batches per epoch, so it must
# not be divided by batch_size again (that shrinks the epoch to a fraction of
# the data, or to 0 steps for small datasets).
steps_per_epochs = len(train_generator)

# Train the head; metrics are streamed to wandb and the best weights
# (by val_loss) are written by the checkpoint callback.
model.fit(train_generator, epochs=epochs,
          steps_per_epoch=steps_per_epochs,
          validation_data=val_generator, callbacks=[WandbCallback(),model_checkpoint_callback])

## Model Evaluation

In [None]:
# Restore the weights stored at checkpoint_filepath before evaluating
model.load_weights(checkpoint_filepath)

In [None]:
# Loss and accuracy of the model on the test generator
test_loss, test_acc = model.evaluate(x=test_generator)
print(f'Test loss: {test_loss:.4f}\nTest accuracy: {test_acc*100:.4f}%')

In [None]:
# Record the run configuration (data parameters and hyperparameters) in wandb
wandb.config.update(dict(
    input_shape=input_shape,
    num_classes=num_classes,
    classes=classes,
    color_mode=color_mode,
    batch_size=batch_size,
    epochs=epochs,
    fc_layer_neuron=fc_layer_neuron,
    dropout_rate=dropout_rate,
    optimizer=optimizer,
    learning_rate=learning_rate,
    starting_layer=starting_layer,
))

# Log the final test-set metrics
wandb.log(dict(test_loss=test_loss, test_acc=test_acc))

### confusion matrix and classification report

In [None]:
# Model outputs for the test generator, turned into hard 0/1 labels
probabilities = model.predict(test_generator)
# NOTE(review): the cut-off here is 0.6; presumably the "accuracy" metric
# reported by model.evaluate uses a different threshold - confirm 0.6 is intended.
Y_pred = np.where(probabilities > 0.6, 1, 0)
Y_pred = Y_pred.reshape(probabilities.shape[0],)

In [None]:
# Confusion-matrix chart logged to wandb under "conf_mat"
cm = wandb.plot.confusion_matrix(
    class_names=classes,
    y_true=test_generator.classes,
    preds=Y_pred,
)
wandb.log({"conf_mat": cm})

In [None]:
# Same confusion matrix, rendered locally with scikit-learn / matplotlib
cfn_mtx = confusion_matrix(y_true=test_generator.classes, y_pred=Y_pred)
disp = ConfusionMatrixDisplay(cfn_mtx, display_labels=classes)
disp.plot()

plt.show()

In [None]:
# Classification report for the test predictions, returned as a dict
report = classification_report(
    y_true=test_generator.classes,
    y_pred=Y_pred,
    target_names=classes,
    output_dict=True,
)
print(report)

In [None]:
# Log every report entry except the last three keys
# (presumably the overall accuracy / average rows - verify against the printed report).
per_class_keys = list(report)[:-3]
for key_ in per_class_keys:
    wandb.log({key_: report[key_]})

In [None]:
# Finish the wandb run
wandb.finish()