In [1]:
import os
os.environ['KERAS_BACKEND'] = 'torch'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
from keras import layers, models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2 as transforms


In [None]:
# Report whether PyTorch can see a CUDA device before any training starts.
cuda_available = torch.cuda.is_available()
print('CUDA available:', cuda_available)

In [None]:
# Group the dataset's class directories by crop. Each class directory is
# expected to be named "<crop>___<class>".
classes = {}

# Sorted so the printed report is the same on every run (os.listdir order is
# arbitrary).
for dir_name in sorted(os.listdir("dataset/disease-classification")):
    crop, separator, class_name = dir_name.partition("___")
    is_class_dir = os.path.isdir(
        os.path.join("dataset/disease-classification", dir_name)
    )
    if not separator or not is_class_dir:
        # Stray entries (hidden files, archives, ...) carry no crop/class pair
        # and would break a plain two-way unpacking of the name; skip them.
        print("skipping non-class entry:", dir_name)
        continue
    classes.setdefault(crop, []).append(class_name)


# Keep only the crops that have a 'healthy' class and at least one other class.
chosen_classes = {}
for crop, class_names in classes.items():
    if 'healthy' in class_names and len(class_names) > 1:
        chosen_classes[crop] = class_names
    else:
        print(crop, "doesnt meet criteria:", class_names)

print()
for crop, class_names in chosen_classes.items():
    print(crop, len(class_names), class_names)


In [4]:

def make_model(crop):
    """Build and compile the binary CNN classifier for one crop.

    Architecture: three Conv2D + MaxPooling2D blocks (32, 64 and 128 filters),
    a Flatten layer, two hidden Dense layers (512 and 128 units) and a single
    sigmoid output unit. The model is compiled with Adam, binary
    cross-entropy loss and the accuracy metric.

    Args:
        crop: Crop name; only used to build the model name
            ``model_for_<crop>``.

    Returns:
        The compiled ``keras`` Sequential model.
    """
    # The input is declared as (3, 256, 256), i.e. channels first. Pin the
    # layout on every layout-sensitive layer so the model does not depend on
    # the machine-wide Keras `image_data_format` setting (whose default,
    # "channels_last", would treat the last axis as channels).
    data_format = "channels_first"

    model = models.Sequential(
        name=f"model_for_{crop}",
        layers=[
            layers.Input(shape=(3, 256, 256)),
            # Block 1
            layers.Conv2D(
                32, kernel_size=(3, 3), activation="relu", data_format=data_format
            ),
            layers.MaxPooling2D(pool_size=(2, 2), data_format=data_format),
            # Block 2
            layers.Conv2D(
                64, kernel_size=(3, 3), activation="relu", data_format=data_format
            ),
            layers.MaxPooling2D(pool_size=(2, 2), data_format=data_format),
            # Block 3
            layers.Conv2D(
                128, kernel_size=(3, 3), activation="relu", data_format=data_format
            ),
            layers.MaxPooling2D(pool_size=(2, 2), data_format=data_format),
            # Flatten
            layers.Flatten(data_format=data_format),
            # Dense
            layers.Dense(512, activation="relu"),
            layers.Dense(128, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    return model

In [5]:

def make_dataloader_factory(crop):
    """Create an augmented image dataset for one crop and a DataLoader factory.

    Only files that live under a class directory named ``<crop>___<class>``
    are kept. Labels are the raw ``ImageFolder`` class indices; they are not
    converted to binary targets here.

    Args:
        crop: Crop name, i.e. the part of the class directory name before
            ``___``.

    Returns:
        A tuple ``(factory, num_samples)`` where
        ``factory(batch_size, num_workers=8)`` builds a shuffling DataLoader
        over the dataset and ``num_samples`` is the number of selected images.
    """
    # Local import: the file already depends on torchvision, and this keeps
    # the block self-contained.
    from torchvision.datasets.folder import IMG_EXTENSIONS

    root = "dataset/disease-classification/"
    class_dir_prefix = f'{crop}___'

    def is_valid_file(filepath):
        # Decide on the class directory (first path component below the root)
        # rather than on a substring of the whole path, so file names or
        # nested folders that happen to contain "<crop>___" cannot match.
        class_dir = os.path.relpath(filepath, root).split(os.sep)[0]
        if not class_dir.startswith(class_dir_prefix):
            return False
        # Supplying `is_valid_file` replaces ImageFolder's own extension
        # filter, so re-apply it to keep non-image files out of the dataset.
        return filepath.lower().endswith(IMG_EXTENSIONS)

    # Augmentations run on the uint8 image; the conversion to scaled float32
    # is the last step.
    augmentations = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.uint8),
        transforms.RandomHorizontalFlip(),

        transforms.RandomPosterize(bits=4, p=0.3),
        transforms.RandomPosterize(bits=5, p=0.3),
        transforms.RandomAffine(degrees=180, translate=(0.1, 0.5), scale=(0.6, 1.4)),
        transforms.RandomAutocontrast(),
        transforms.RandomAdjustSharpness(sharpness_factor=0.6, p=0.2),
        transforms.RandomAdjustSharpness(sharpness_factor=1.4, p=0.2),

        transforms.ToDtype(torch.float32, scale=True),
    ])

    dataset = ImageFolder(
        root=root,
        transform=augmentations,
        is_valid_file=is_valid_file,
        # Class directories of other crops end up with no valid files.
        allow_empty=True,
    )

    print(f"Number of samples for {crop}: {len(dataset.samples)}")

    def factory(batch_size, num_workers=8):
        """Build a shuffling DataLoader over the crop's dataset."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

    return factory, len(dataset.samples)

In [6]:

def _healthy_class_index(dataset, crop=None):
    """Return the class index of the healthy class for the crop in ``dataset``.

    With ``crop`` given, the index of ``"<crop>___healthy"`` is looked up in
    ``dataset.class_to_idx``. Without it, the healthy class is inferred as the
    only ``*___healthy`` class that actually has samples; this relies on
    ``dataset.targets`` (the per-sample class indices, as exposed by
    torchvision's ``ImageFolder``).

    Raises:
        KeyError: ``crop`` was given but has no ``healthy`` class.
        ValueError: ``crop`` was omitted and the healthy class is ambiguous
            or missing.
    """
    class_to_idx = dataset.class_to_idx
    if crop is not None:
        return class_to_idx[f"{crop}___healthy"]

    used_indices = set(dataset.targets)
    candidates = [
        idx
        for name, idx in class_to_idx.items()
        if name.endswith("___healthy") and idx in used_indices
    ]
    if len(candidates) != 1:
        raise ValueError(
            "cannot infer the healthy class from the dataset "
            f"({len(candidates)} candidates); pass `crop` explicitly"
        )
    return candidates[0]


def dataloader_wrapper(factory, batch_size, repeat, crop=None):
    """Yield ``(images, binary_labels)`` batches for ``repeat`` dataset passes.

    Args:
        factory: Callable ``factory(batch_size)`` returning an iterable of
            ``(images, class_indices)`` batches that also exposes ``.dataset``
            with a ``class_to_idx`` mapping.
        batch_size: Batch size forwarded to ``factory``.
        repeat: Number of full passes over the data loader.
        crop: Crop whose ``"<crop>___healthy"`` class is the negative class.
            When omitted, the healthy class is inferred from the dataset
            instead of being read from notebook-global state.

    Yields:
        ``(imgs, labels)`` where ``labels`` is a float32 tensor holding 0.0
        for healthy samples and 1.0 for every other class.
    """
    dataloader = factory(batch_size)
    healthy_idx = _healthy_class_index(dataloader.dataset, crop)
    for _ in range(repeat):
        for imgs, labels in dataloader:
            # Derive the targets from the untouched class indices in a single
            # comparison. Rewriting `labels` in place in two steps (healthy ->
            # 0, then "!= healthy" -> 1) turns the fresh zeros into ones
            # whenever the healthy index is not 0, labelling everything as
            # diseased.
            yield imgs, (labels != healthy_idx).to(torch.float32)


def train_model(crop, batch_size=32, dataset_repeat=5, epochs=3):
    """Build and train the binary healthy-vs-diseased classifier for ``crop``.

    Args:
        crop: Crop name used to select the data and name the model.
        batch_size: Number of samples per training step.
        dataset_repeat: Number of passes over the (augmented) dataset that
            the batch generator produces in total.
        epochs: Number of Keras epochs the generated batches are split into.

    Returns:
        The trained model.
    """
    model = make_model(crop)
    dataloader_factory, num_samples = make_dataloader_factory(crop)

    # The generator yields `dataset_repeat` passes in total; spread those
    # batches evenly over the requested number of epochs.
    steps_per_epoch = num_samples * dataset_repeat // batch_size // epochs

    model.fit(
        dataloader_wrapper(dataloader_factory, batch_size, dataset_repeat, crop=crop),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )
    return model

In [None]:
# Crop name as used in the dataset's class directories -> file name stem of
# the saved model for that crop.
crop_to_filename_dict = {
    'Apple': 'apple',
    'Cherry_(including_sour)': 'cherry',
    'Corn_(maize)': 'corn',
    'Grape': 'grape',
    'Peach': 'peach',
    'Pepper,_bell': 'pepper-bell',
    'Potato': 'potato',
    'Strawberry': 'strawberry',
    'Tomato': 'tomato',
}

# Create the output directory up front so a missing folder cannot make the
# save step fail after a model has already been trained.
os.makedirs("saved-models/disease-classification", exist_ok=True)

# Train one model per crop and save each one as soon as it is trained.
for crop, filename in crop_to_filename_dict.items():
    m = train_model(crop, dataset_repeat=3, epochs=1)
    m.save(f"saved-models/disease-classification/{filename}.keras")