## Swin Transformer Plant Disease Detector Model

### Import Libraries

In [71]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import json
import keras_cv
import os
from sklearn.metrics import classification_report, confusion_matrix

### Data Preprocessing

In [124]:
# Folder that the dataset loaders in later cells read images from.
data_directory = "PlantVillage"

# Seed passed as `seed=` to both image_dataset_from_directory calls.
seed_value = 27

# Label strings in "<crop>___<condition>" form, 38 entries in total.
# NOTE(review): the saved loader output in this notebook reports 16 classes,
# so this list may not match the folder actually on disk -- confirm.
class_names = [
    # Apple
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    # Blueberry
    'Blueberry___healthy',
    # Cherry
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    # Corn
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    # Grape
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    # Orange
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach
    'Peach___Bacterial_spot',
    'Peach___healthy',
    # Bell pepper
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    # Potato
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    # Raspberry
    'Raspberry___healthy',
    # Soybean
    'Soybean___healthy',
    # Squash
    'Squash___Powdery_mildew',
    # Strawberry
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    # Tomato
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]

#### Hyperparameters

In [125]:
# --- Input geometry -------------------------------------------------------
number_of_classes = len(class_names)
image_dimension = 32  # the author's alternative value: 224
window_size = 2
shift_size = 1

input_shape = (image_dimension, image_dimension, 3)  # square image, 3 channels
image_size = (image_dimension, image_dimension)      # passed to the dataset loaders
patch_size = (window_size, window_size)              # patch edge is tied to window_size here

# --- Model-size settings --------------------------------------------------
# NOTE(review): none of the cells visible in this notebook consumes the
# values in this group yet; presumably they feed the model-building code.
dropout_rate = 0.03
number_of_heads = 8
embedding_dimension = 64
number_of_MLP = 256

qkv_bias = True

# Patches along each image axis (integer division of the image by the patch).
number_of_patches_x = input_shape[0] // patch_size[0]
number_of_patches_y = input_shape[1] // patch_size[1]

# --- Optimisation ---------------------------------------------------------
batch_size = 32
learning_rate = 0.0001
number_of_epochs = 10

# Fraction of files held out from training. Was `.03`, which contradicted the
# 0.3 literal that both image_dataset_from_directory calls actually pass
# (28894 / 41276 files kept for training in the saved output, i.e. a 70/30 split).
validation_split = 0.3
weight_decay = 0.0001
label_smoothing = 0.1

### Training/Validation/Test Data Split

#### Training Images

In [130]:
# Training subset: the files left after 30% are held out. Labels are inferred
# from the sub-directory names and emitted as one-hot ("categorical") vectors.
training_set = tf.keras.utils.image_dataset_from_directory(
    data_directory,
    labels="inferred",
    label_mode="categorical",
    class_names=None,  # let the loader discover the classes on disk
    color_mode="rgb",
    batch_size=batch_size,
    image_size=image_size,
    shuffle=True,
    validation_split=0.3,
    subset="training",
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    verbose=True,
    seed=seed_value,
)

# Because `class_names=None`, the one-hot width and label order come from the
# directory, not from the hand-written list (the saved output reports 16
# classes while that list has 38 entries). Re-sync both globals with what the
# loader found, so any later use of them agrees with the label tensors.
# This must happen before `.map()` is applied, which drops the attribute.
class_names = training_set.class_names
number_of_classes = len(class_names)

Found 41276 files belonging to 16 classes.
Using 28894 files for training.


#### Validation Images

In [131]:
# Held-out subset (30% of the files). `seed`, `shuffle` and `validation_split`
# are kept identical to the training loader's, so the two subsets are carved
# from the same shuffled file list.
validation_set = tf.keras.utils.image_dataset_from_directory(
    data_directory,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="rgb",
    batch_size=batch_size,  # shared constant instead of a second hard-coded 32
    image_size=image_size,
    shuffle=True,
    validation_split=0.3,
    subset="validation",
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    verbose=True,
    seed=seed_value,
)

Found 41276 files belonging to 16 classes.
Using 12382 files for validation.


#### Test Images

In [132]:
# Carve the held-out batches into validation (first two thirds of the
# batches) and test (the remaining batches).
# NOTE(review): `validation_set` was built with shuffle=True; if the loader
# reshuffles on every pass, take()/skip() may not yield stable, disjoint
# partitions across iterations -- confirm before trusting test metrics.
# NOTE: this cell rebinds `validation_set`, so running it twice shrinks it again.
number_of_validation_batches = tf.data.experimental.cardinality(validation_set)
_split_index = (number_of_validation_batches * 2) // 3
validation_set, test_set = (
    validation_set.take(_split_index),
    validation_set.skip(_split_index),
)

#### Data Augmentation

In [133]:
# Preprocessing + augmentation pipeline applied to training batches:
# scale pixel values by 1/255, then randomly flip, rotate and zoom.
_augmentation_steps = [
    tf.keras.layers.Rescaling(scale=1./255),
    tf.keras.layers.RandomFlip(mode="horizontal"),
    tf.keras.layers.RandomRotation(factor=0.2),
    tf.keras.layers.RandomZoom(height_factor=0.1),
]
augment_data = tf.keras.Sequential(_augmentation_steps)

In [134]:
AUTOTUNE = tf.data.AUTOTUNE

# NOTE: every statement below rebinds a dataset to a scaled version of itself,
# so running this cell a second time would divide the pixels by 255 twice.

# Training batches: rescale + random augmentation (training=True keeps the
# random layers active). The map runs in parallel so augmentation overlaps
# with the training step instead of running serially.
training_set = training_set.map(
    lambda x, y: (augment_data(x, training=True), y),
    num_parallel_calls=AUTOTUNE,
)
training_set = training_set.prefetch(buffer_size=AUTOTUNE)

# Validation / test batches: only the same 1/255 scaling, no augmentation.
validation_set = validation_set.map(
    lambda x, y: (x / 255.0, y),
    num_parallel_calls=AUTOTUNE,
).prefetch(buffer_size=AUTOTUNE)
test_set = test_set.map(
    lambda x, y: (x / 255.0, y),
    num_parallel_calls=AUTOTUNE,
).prefetch(buffer_size=AUTOTUNE)

### Building the Model

In [135]:
# Keras layer classes, presumably for the classification head.
# NOTE(review): no visible cell uses these names, and no visible cell defines
# the `model` object that the compile/fit/save cells below call -- the
# model-construction code appears to be missing from this section.
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout

### Compiling the Model

In [None]:
# Configure training: Adam optimiser, categorical cross-entropy on the one-hot
# labels, accuracy as the reported metric. The step size now comes from the
# `learning_rate` hyperparameter instead of a duplicated 0.0001 literal.
# NOTE(review): `weight_decay` and `label_smoothing` are declared in the
# hyperparameter cell but are not applied here -- confirm whether intended.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Print Keras's summary of `model`.
# NOTE(review): `model` is not defined in any visible cell -- confirm where it comes from.
model.summary()

### Training the Model

In [None]:
# Fit on the augmented training batches and evaluate on the validation batches
# after each epoch. The epoch count comes from the `number_of_epochs`
# hyperparameter instead of a duplicated literal 10. The returned History
# object is kept so its metrics can be saved in a later cell.
training_history = model.fit(
    x=training_set,
    validation_data=validation_set,
    epochs=number_of_epochs,
)

### Evaluate the Model

### Save the Model

In [None]:
# Write the model to "swin_model_trained.keras" in the current working directory.
model.save("swin_model_trained.keras")

### Recording the Training History

In [None]:
# Save the per-epoch metrics to JSON. Each value is coerced to a built-in
# float first: depending on the Keras/TF version the history may hold NumPy
# scalars, which json.dump refuses to serialise.
serializable_history = {
    metric: [float(value) for value in values]
    for metric, values in training_history.history.items()
}
with open("training_hist.json", "w") as f:
    json.dump(serializable_history, f)

### Metrics Evaluation and Visualization

#### Accuracy

#### Classification Report

#### Confusion Matrix