In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
from keras import layers, models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2 as transforms


In [2]:
import keras
keras.backend.backend()

'torch'

In [3]:
print('CUDA available:', torch.cuda.is_available())

CUDA available: True


In [4]:
classes = {}

for dir_name in os.listdir("dataset/disease-classification"):
    crop, class_name = dir_name.split("___")
    classes[crop] = classes.get(crop, []) + [class_name]


chosen_classes = {}
for crop, class_names in classes.items():
    if 'healthy' in class_names and len(class_names) > 1:
        chosen_classes[crop] = class_names
    else:
        print(crop, "doesnt meet criteria:", class_names)

print()
for crop, class_names in chosen_classes.items():
    print(crop, len(class_names), class_names)


Squash doesnt meet criteria: ['Powdery_mildew']
Soybean doesnt meet criteria: ['healthy']
Blueberry doesnt meet criteria: ['healthy']
Raspberry doesnt meet criteria: ['healthy']
Orange doesnt meet criteria: ['Haunglongbing_(Citrus_greening)']

Cherry_(including_sour) 2 ['healthy', 'Powdery_mildew']
Tomato 10 ['Bacterial_spot', 'Tomato_mosaic_virus', 'Late_blight', 'Target_Spot', 'Tomato_Yellow_Leaf_Curl_Virus', 'Leaf_Mold', 'Early_blight', 'healthy', 'Septoria_leaf_spot', 'Spider_mites Two-spotted_spider_mite']
Strawberry 2 ['Leaf_scorch', 'healthy']
Corn_(maize) 4 ['Common_rust_', 'Northern_Leaf_Blight', 'healthy', 'Cercospora_leaf_spot Gray_leaf_spot']
Apple 4 ['Apple_scab', 'Black_rot', 'healthy', 'Cedar_apple_rust']
Grape 4 ['Black_rot', 'Leaf_blight_(Isariopsis_Leaf_Spot)', 'healthy', 'Esca_(Black_Measles)']
Potato 3 ['healthy', 'Late_blight', 'Early_blight']
Pepper,_bell 2 ['healthy', 'Bacterial_spot']
Peach 2 ['healthy', 'Bacterial_spot']


In [5]:

def make_model(crop):
    model = models.Sequential(
        name=f"model_for_{crop}",
        layers=[
            layers.Input(shape=(3, 256, 256)),
            # Block 1
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            # Block 2
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            # Block 3
            layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            # Flatten
            layers.Flatten(),
            # Dense
            layers.Dense(512, activation="relu"),
            layers.Dense(128, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    # model.summary()
    return model

In [6]:

def make_dataloader_factory(crop):
    def is_valid_file(filepath):
        return f'{crop}___' in filepath

    dataset = ImageFolder(
        root="dataset/disease-classification/",
        transform=transforms.Compose([
            transforms.ToImage(),
            transforms.ToDtype(torch.uint8),
            transforms.RandomHorizontalFlip(),

            transforms.RandomPosterize(bits=4, p=0.3),
            transforms.RandomPosterize(bits=5, p=0.3),
            transforms.RandomAffine(degrees=180, translate=(0.1, 0.5), scale=(0.6, 1.4)),
            transforms.RandomAutocontrast(),
            transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.2),
            transforms.RandomAdjustSharpness(sharpness_factor=1.4, p=0.2),

            transforms.ToDtype(torch.float32, scale=True),
        ]),
        is_valid_file=is_valid_file,
        allow_empty=True,
    )

    print(f"Number of samples for {crop}: {len(dataset.samples)}")

    def factory(batch_size):
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )
        return dataloader

    return factory, len(dataset.samples)

In [7]:

def dataloader_wrapper(factory, batch_size, repeat):
    dataloader = factory(batch_size)
    for _ in range(repeat):
        for imgs, labels in dataloader:
            labels[labels == dataloader.dataset.class_to_idx[f"{crop}___healthy"]] = 0.0
            labels[labels != dataloader.dataset.class_to_idx[f"{crop}___healthy"]] = 1.0
            yield imgs, labels


def train_model(crop, batch_size=32, dataset_repeat=5, epochs=3):
    model = make_model(crop)
    dataloader_factory, num_samples = make_dataloader_factory(crop)

    steps_per_epoch = num_samples * dataset_repeat // batch_size // epochs

    model.fit(
        dataloader_wrapper(dataloader_factory, batch_size, dataset_repeat),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )
    return model

In [8]:
crop_to_filename_dict = {
    'Apple': 'apple',
    'Cherry_(including_sour)': 'cherry',
    'Corn_(maize)': 'corn',
    'Grape': 'grape',
    'Peach': 'peach',
    'Pepper,_bell': 'pepper-bell',
    'Potato': 'potato',
    'Strawberry': 'strawberry',
    'Tomato': 'tomato',
}

for crop, filename in crop_to_filename_dict.items():
    m = train_model(crop, dataset_repeat=3, epochs=1)
    m.save(f"saved-models/disease-classification/{filename}.keras")

    # Delete Model
    del m

    # Clear GPU Memory
    keras.backend.clear_session()

    # Remove cached tensors
    torch.cuda.empty_cache()

Number of samples for Apple: 7771
[1m728/728[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 136ms/step - accuracy: 0.9951 - loss: 0.0068


2024-12-17 01:26:45.075464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734379005.090059  609712 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734379005.094293  609712 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 01:26:45.110677: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Number of samples for Cherry_(including_sour): 3509
[1m328/328[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 133ms/step - accuracy: 0.9879 - loss: 0.0135
Number of samples for Corn_(maize): 7316
[1m685/685[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 134ms/step - accuracy: 0.9913 - loss: 0.0073
Number of samples for Grape: 7222
[1m677/677[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 137ms/step - accuracy: 0.9899 - loss: 0.0078
Number of samples for Peach: 3566
[1m334/334[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 135ms/step - accuracy: 0.9845 - loss: 0.0133
Number of samples for Pepper,_bell: 3901
[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 134ms/step - accuracy: 0.9828 - loss: 0.0129
Number of samples for Potato: 5702
[1m534/534[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 132ms/step - accuracy: 0.9980 - loss: 0.0087
Number of samples for Strawberry: 3598
[1m337/337[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[