## SWIN Transformer Plant Disease Detector Model

### Import Libraries

In [71]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import json
import keras_cv
import os
from sklearn.metrics import classification_report, confusion_matrix

### Data Preprocessing

In [124]:
# Root folder handed to the dataset loaders, and the seed they share.
data_directory = "PlantVillage"
seed_value = 27

# Label strings in "<crop>___<condition>" form (38 entries).
# NOTE(review): the dataset-loading cells pass class_names=None and their
# recorded output reports 16 classes, so this list may not match what is
# actually on disk -- confirm before relying on its length.
class_names = [
    # Apple
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    # Blueberry
    "Blueberry___healthy",
    # Cherry
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    # Corn
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    # Grape
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    # Orange
    "Orange___Haunglongbing_(Citrus_greening)",
    # Peach
    "Peach___Bacterial_spot",
    "Peach___healthy",
    # Bell pepper
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    # Potato
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    # Raspberry
    "Raspberry___healthy",
    # Soybean
    "Soybean___healthy",
    # Squash
    "Squash___Powdery_mildew",
    # Strawberry
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    # Tomato
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]

#### Hyperparameters

In [125]:
# --- Task size -------------------------------------------------------------
number_of_classes = len(class_names)

# --- Image / window geometry -----------------------------------------------
image_dimension = 32  # 224
window_size = 2
shift_size = 1

input_shape = (image_dimension, image_dimension, 3)
image_size = input_shape[:2]
# The patch edge is taken from window_size, so both are currently 2.
patch_size = (window_size,) * 2

# --- Transformer block settings --------------------------------------------
dropout_rate = 0.03
number_of_heads = 8
embedding_dimension = 64
number_of_MLP = 256

qkv_bias = True

# Patch grid: image extent floor-divided by patch extent along each axis.
number_of_patches_x = input_shape[0] // patch_size[0]
number_of_patches_y = input_shape[1] // patch_size[1]

# --- Optimisation settings -------------------------------------------------
batch_size = 32
learning_rate = 1e-4
number_of_epochs = 10

# NOTE(review): the dataset-loading cells pass validation_split=0.3 as a
# literal, so this value is not what they use -- confirm which is intended.
validation_split = 0.03
weight_decay = 1e-4
label_smoothing = 0.1

#### Window Functions

In [126]:
 def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.keras.ops.reshape(
        x,
        (
            -1,
            number_of_patches_y,
            window_size,
            number_of_patches_x,
            window_size,
            channels,
        ),
    )
    x = tf.keras.ops.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = tf.keras.ops.reshape(x, (-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    """Reassemble square windows into a (height, width) feature map.

    Args:
        windows: Tensor that can be reshaped to
            (-1, height // window_size, width // window_size,
            window_size, window_size, channels).
        window_size: Edge length of each square window.
        height: Height of the feature map to rebuild.
        width: Width of the feature map to rebuild.
        channels: Size of the last (feature) axis.

    Returns:
        Tensor of shape (-1, height, width, channels).
    """
    # Derive the window grid from the requested output size rather than from
    # the module-level patch counts, and keep rows (y) before columns (x) so
    # the layout matches the (height, width) order of the final reshape.
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.keras.ops.reshape(
        windows,
        (
            -1,
            patch_num_y,
            patch_num_x,
            window_size,
            window_size,
            channels,
        ),
    )
    # Interleave grid and in-window axes: (rows, window rows, cols, window cols).
    x = tf.keras.ops.transpose(x, (0, 1, 3, 2, 4, 5))
    x = tf.keras.ops.reshape(x, (-1, height, width, channels))
    return x

#### Self Attention

In [127]:
class WindowAttention(tf.keras.layers.Layer):
    """Multi-head self-attention over window tokens with a learned relative position bias.

    Args:
        dimension: Feature size of the input tokens; also the output size of
            the final projection.
        window_size: Pair ``(window_height, window_width)``.
        number_of_heads: Number of attention heads. ``dimension`` is
            floor-divided by it to obtain the per-head size.
        qkv_bias: Whether the fused query/key/value projection uses a bias.
        dropout_rate: Rate of the dropout layer applied to the attention
            weights and to the projected output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self,
        dimension,
        window_size,
        number_of_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dimension = dimension
        self.window_size = window_size
        self.number_of_heads = number_of_heads
        self.scale = (dimension // number_of_heads) ** -0.5
        self.qkv = tf.keras.layers.Dense(dimension * 3, use_bias=qkv_bias)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.proj = tf.keras.layers.Dense(dimension)

        # One learnable bias per head for every possible (dy, dx) offset
        # between two positions of the same window.
        number_of_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(number_of_window_elements, self.number_of_heads),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )

        # Precompute, for each pair of window positions, the row of the bias
        # table that holds their offset.
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        # Shift offsets to start at zero, then fold (dy, dx) into one index.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        self.relative_position_index = tf.keras.Variable(
            initializer=relative_position_index,
            shape=relative_position_index.shape,
            dtype="int",
            trainable=False,
        )

    def call(self, x, mask=None):
        """Apply windowed self-attention.

        Args:
            x: 3-D tensor whose static shape unpacks as
                (windows, tokens_per_window, channels).
            mask: Optional additive mask whose first axis is the number of
                windows per image; it is broadcast over batch and heads and
                added to the attention logits before the softmax.

        Returns:
            Tensor reshaped to (-1, tokens_per_window, channels).
        """
        _, size, channels = x.shape
        head_dimensions = channels // self.number_of_heads
        x_qkv = self.qkv(x)
        x_qkv = tf.keras.ops.reshape(
            x_qkv, (-1, size, 3, self.number_of_heads, head_dimensions)
        )
        x_qkv = tf.keras.ops.transpose(x_qkv, (2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.keras.ops.transpose(k, (0, 1, 3, 2))
        attn = q @ k

        # Look up the per-head bias for every pair of window positions.
        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = tf.keras.ops.reshape(
            self.relative_position_index, (-1,)
        )
        relative_position_bias = tf.keras.ops.take(
            self.relative_position_bias_table,
            relative_position_index_flat,
            axis=0,
        )
        relative_position_bias = tf.keras.ops.reshape(
            relative_position_bias,
            (num_window_elements, num_window_elements, -1),
        )
        relative_position_bias = tf.keras.ops.transpose(relative_position_bias, (2, 0, 1))
        attn = attn + tf.keras.ops.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.shape[0]
            mask_float = tf.keras.ops.cast(
                tf.keras.ops.expand_dims(tf.keras.ops.expand_dims(mask, axis=1), axis=0),
                "float32",
            )
            # Expose the per-image window axis so the mask lines up, add it,
            # then fold the axis back into the leading dimension.
            attn = (
                tf.keras.ops.reshape(attn, (-1, nW, self.number_of_heads, size, size))
                + mask_float
            )
            attn = tf.keras.ops.reshape(attn, (-1, self.number_of_heads, size, size))

        attn = tf.keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = tf.keras.ops.transpose(x_qkv, (0, 2, 1, 3))
        x_qkv = tf.keras.ops.reshape(x_qkv, (-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

#### Swin Transformer

In [128]:
class SwinTransformer(tf.keras.layers.Layer):
    """Transformer block using (optionally shifted) window attention followed by an MLP.

    Args:
        dimension: Feature size of the input tokens.
        number_of_patches: Pair ``(height, width)`` of the token grid.
        number_of_heads: Number of attention heads.
        window_size: Edge length of the square attention window.
        shift_size: Cyclic shift applied before windowing; 0 disables the
            shift and the attention mask.
        number_of_MLP: Width of the hidden layer of the MLP.
        qkv_bias: Whether the query/key/value projection uses a bias.
        dropout_rate: Rate shared by every dropout layer in the block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self,
        dimension,
        number_of_patches,
        number_of_heads,
        window_size=7,
        shift_size=0,
        number_of_MLP=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.dimension = dimension  # number of input dimension
        self.number_of_patches = number_of_patches  # number of embedded patches
        self.number_of_heads = number_of_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.number_of_MLP = number_of_MLP  # number of MLP nodes

        # Clamp the window to the token grid BEFORE building the attention
        # layer, so its relative-position tables use the window size that
        # call() will actually partition with.
        if min(self.number_of_patches) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.number_of_patches)

        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dimension,
            window_size=(self.window_size, self.window_size),
            number_of_heads=number_of_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = tf.keras.layers.Dropout(dropout_rate)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)

        self.mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(number_of_MLP),
                tf.keras.layers.Activation(tf.keras.activations.gelu),
                tf.keras.layers.Dropout(dropout_rate),
                tf.keras.layers.Dense(dimension),
                tf.keras.layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Create the additive attention mask used when the windows are shifted."""
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.number_of_patches
            # Three bands per axis: the untouched part, the last window minus
            # the shifted strip, and the shifted strip itself.
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Label every (row band, column band) region with its own id.
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.keras.ops.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.keras.ops.reshape(
                mask_windows, [-1, self.window_size * self.window_size]
            )
            # Pairs of positions with different region ids get -100.0 (added
            # to the attention logits); pairs from the same region get 0.0.
            attn_mask = tf.keras.ops.expand_dims(mask_windows, axis=1) - tf.keras.ops.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.keras.ops.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.keras.ops.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.keras.Variable(
                initializer=attn_mask,
                shape=attn_mask.shape,
                dtype=attn_mask.dtype,
                trainable=False,
            )

    def call(self, x, training=False):
        """Run the attention and MLP sub-blocks, each wrapped in a residual connection.

        Args:
            x: 3-D tensor whose static shape unpacks as
                (batch, height * width, channels).
            training: Passed to the block-level dropout after each sub-block.

        Returns:
            Tensor of shape (-1, height * width, channels).
        """
        height, width = self.number_of_patches
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.keras.ops.reshape(x, (-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.keras.ops.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.keras.ops.reshape(
            x_windows, (-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.keras.ops.reshape(
            attn_windows,
            (-1, self.window_size, self.window_size, channels),
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            # Undo the cyclic shift applied before windowing.
            x = tf.keras.ops.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.keras.ops.reshape(x, (-1, height * width, channels))
        x = self.drop_path(x, training=training)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        # Same explicit training flag as the first residual branch.
        x = self.drop_path(x, training=training)
        x = x_skip + x
        return x

#### Patch Embed / Extract

In [129]:
def patch_extract(images):
    """Cut a batch of images into flattened, non-overlapping patches.

    The patch height and width come from the module-level ``patch_size``.

    Args:
        images: 4-D image batch accepted by ``tf.image.extract_patches``.

    Returns:
        Tensor of shape (batch, patch_rows * patch_columns, patch_dimension),
        where the last three sizes are read from the static shape of the
        extracted patches.
    """
    # Local name avoids shadowing the module-level batch_size hyperparameter.
    number_of_images = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=(1, patch_size[0], patch_size[1], 1),
        strides=(1, patch_size[0], patch_size[1], 1),
        rates=(1, 1, 1, 1),
        padding="VALID",
    )
    # Use both grid axes so a non-square patch grid is flattened correctly.
    patch_rows = patches.shape[1]
    patch_columns = patches.shape[2]
    patch_dimension = patches.shape[-1]
    return tf.reshape(
        patches, (number_of_images, patch_rows * patch_columns, patch_dimension)
    )


class PatchEmbedding(tf.keras.layers.Layer):
    """Project flattened patches and add a learned per-position embedding.

    Args:
        number_of_patches: Number of patch positions; sizes the embedding table.
        embedding_dimension: Output size of both the projection and the
            position embedding.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, number_of_patches, embedding_dimension, **kwargs):
        super().__init__(**kwargs)
        self.number_of_patches = number_of_patches
        self.proj = tf.keras.layers.Dense(embedding_dimension)
        self.pos_embed = tf.keras.layers.Embedding(
            input_dim=number_of_patches,
            output_dim=embedding_dimension,
        )

    def call(self, patch):
        """Return the projected patches plus the embedding of positions 0..number_of_patches-1."""
        positions = tf.keras.ops.arange(start=0, stop=self.number_of_patches)
        projected = self.proj(patch)
        position_embedding = self.pos_embed(positions)
        return projected + position_embedding


class PatchMerging(tf.keras.layers.Layer):
    """Halve the token grid by concatenating each 2x2 neighbourhood, then project.

    Args:
        number_of_patches: Pair ``(height, width)`` of the incoming token grid.
        embedding_dimension: The projection outputs ``2 * embedding_dimension``
            features.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``), in
            line with the other custom layers in this file.
    """

    def __init__(self, number_of_patches, embedding_dimension, **kwargs):
        super().__init__(**kwargs)
        self.number_of_patches = number_of_patches
        self.embedding_dimension = embedding_dimension
        self.linear_trans = tf.keras.layers.Dense(2 * embedding_dimension, use_bias=False)

    def call(self, x):
        """Merge 2x2 token groups.

        Args:
            x: 3-D tensor whose last static dimension is the channel count C;
                it is reshaped to (-1, height, width, C).

        Returns:
            Tensor of shape
            (-1, (height // 2) * (width // 2), 2 * embedding_dimension).
        """
        height, width = self.number_of_patches
        _, _, C = x.shape
        x = tf.keras.ops.reshape(x, (-1, height, width, C))
        # The four strided views pick the four members of every 2x2 group.
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.keras.ops.concatenate((x0, x1, x2, x3), axis=-1)
        x = tf.keras.ops.reshape(x, (-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

### Training/Validation/Test Data Split

#### Training Images

In [130]:
# Training subset: the "training" side of a 0.3 validation split, with
# one-hot ("categorical") labels inferred from the directory structure.
training_set = tf.keras.utils.image_dataset_from_directory(
    data_directory,
    # Labels
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    # Split and ordering
    validation_split=0.3,
    subset="training",
    seed=seed_value,
    shuffle=True,
    # Batching and image decoding
    batch_size=batch_size,
    image_size=image_size,
    color_mode="rgb",
    interpolation="bilinear",
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    # Miscellaneous
    follow_links=False,
    verbose=True,
)

Found 41276 files belonging to 16 classes.
Using 28894 files for training.


#### Validation Images

In [131]:
# Held-out subset: the "validation" side of the same 0.3 split, using the
# same seed as the training call. The batch size comes from the shared
# hyperparameter instead of a literal so both loaders stay in step.
validation_set = tf.keras.utils.image_dataset_from_directory(
    data_directory,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="rgb",
    batch_size=batch_size,
    image_size=image_size,
    shuffle=True,
    validation_split=0.3,
    subset="validation",
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    verbose=True,
    seed=seed_value,
)

Found 41276 files belonging to 16 classes.
Using 12382 files for validation.


#### Test Images

In [132]:
# Carve the held-out data in two: the first two thirds of its batches stay
# as validation data, the remaining batches become the test set.
# NOTE(review): validation_set was created with shuffle=True -- confirm that
# skip()/take() yield disjoint, stable subsets across iterations.
number_of_validation_batches = tf.data.experimental.cardinality(validation_set)
validation_batches_to_keep = (number_of_validation_batches * 2) // 3
validation_set, test_set = (
    validation_set.take(validation_batches_to_keep),
    validation_set.skip(validation_batches_to_keep),
)

#### Data Augmentation

In [133]:
# Training-time preprocessing: scale pixels by 1/255, then apply a random
# horizontal flip, random rotation and random zoom.
augment_data = tf.keras.Sequential(
    [
        tf.keras.layers.Rescaling(scale=1.0 / 255),
        tf.keras.layers.RandomFlip(mode="horizontal"),
        tf.keras.layers.RandomRotation(factor=0.2),
        tf.keras.layers.RandomZoom(height_factor=0.1),
    ]
)

In [134]:
AUTOTUNE = tf.data.AUTOTUNE

# The model input is a sequence of flattened patches, not raw images, so each
# pipeline ends with patch_extract. Training batches are augmented (which
# includes the 1/255 rescale); validation and test batches are only rescaled.
training_set = training_set.map(
    lambda x, y: (patch_extract(augment_data(x, training=True)), y)
)
training_set = training_set.prefetch(buffer_size=AUTOTUNE)
validation_set = validation_set.map(
    lambda x, y: (patch_extract(x / 255.0), y)
).prefetch(buffer_size=AUTOTUNE)
test_set = test_set.map(
    lambda x, y: (patch_extract(x / 255.0), y)
).prefetch(buffer_size=AUTOTUNE)

### Building the Model

In [135]:
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout

In [136]:
# Each sample is a sequence of flattened patches: one row per patch, one
# column per pixel value inside the patch (patch height * width * channels).
total_patches = number_of_patches_x * number_of_patches_y
patch_dimension = patch_size[0] * patch_size[1] * input_shape[2]

# Named `inputs` so the builtin input() is not shadowed.
inputs = tf.keras.layers.Input(shape=(total_patches, patch_dimension))
print(number_of_patches_x, number_of_patches_y)
x = PatchEmbedding(total_patches, embedding_dimension)(inputs)

# Two Swin blocks: the first with regular windows, the second with windows
# shifted by shift_size.
for block_shift_size in (0, shift_size):
    x = SwinTransformer(
        dimension=embedding_dimension,
        number_of_patches=(number_of_patches_x, number_of_patches_y),
        number_of_heads=number_of_heads,
        window_size=window_size,
        shift_size=block_shift_size,
        number_of_MLP=number_of_MLP,
        qkv_bias=qkv_bias,
        dropout_rate=dropout_rate,
    )(x)

x = PatchMerging((number_of_patches_x, number_of_patches_y), embedding_dimension=embedding_dimension)(x)
x = tf.keras.layers.GlobalAveragePooling1D()(x)

# Size the classifier from the one-hot label width the dataset actually
# yields: the loader reports fewer classes than len(class_names), and a
# mismatched output width would not line up with the categorical labels.
# Falls back to number_of_classes if the label width is not statically known.
number_of_output_classes = training_set.element_spec[1].shape[-1] or number_of_classes
output = tf.keras.layers.Dense(number_of_output_classes, activation="softmax")(x)

# Later cells (compile / summary / fit / save) operate on `model`.
model = tf.keras.Model(inputs=inputs, outputs=output)

16 16


InvalidArgumentError: {{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 256 values, but the requested shape requires a multiple of 1024 [Op:Reshape]

### Compiling the Model

In [None]:
# Adam with the learning rate taken from the hyperparameter cell.
# NOTE(review): weight_decay and label_smoothing are declared in the
# hyperparameter cell but are not used here -- confirm whether they should
# be applied to the optimizer / loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

### Training the Model

In [None]:
# Train on the training pipeline and evaluate on the validation pipeline
# after every epoch; the epoch count comes from the hyperparameter cell.
training_history = model.fit(
    x=training_set,
    validation_data=validation_set,
    epochs=number_of_epochs,
)

### Evaluate the Model

### Save the Model

In [None]:
# Write the trained model to a single .keras file in the working directory.
model.save("swin_model_trained.keras")

### Recording the training history 

In [None]:
# Store the per-epoch metrics collected by fit() as JSON.
serialized_history = json.dumps(training_history.history)
with open("training_hist.json", "w") as history_file:
    history_file.write(serialized_history)

### Metrics Evaluation and Visualization

#### Accuracy

#### Classification Report

#### Confusion Matrix