In [6]:
# imports
import os
import shutil
import csv
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.io as io
import torchvision.transforms as T
from torch.utils.data import Dataset

# Steps in Preprocessing

1. Read in the whole PlantVillage dataset
2. Resize all images to uniform dimensions
3. Apply a number of transformations to the PlantVillage dataset
4. Create two output dataset folders, each with images labelled as one of two categories (healthy or unhealthy):
   1. Using unaugmented images
   2. Using augmented images
5. Create custom PyTorch datasets corresponding to these folders
6. Repeat steps (1) to (5) with the PlantLeaves and PlantaeK datasets, combining these two datasets into one

In [2]:
# Source: the (augmented) PlantVillage dataset, expected inside the current working directory.
# Destination: folder that receives the resized, re-encoded training images.
cur_dir = os.getcwd()
train_src_dir = os.path.join(cur_dir, 'Plant_leave_diseases_dataset_with_augmentation')
train_dst_dir = os.path.join(cur_dir, 'load_dataset', 'dataset', 'train')

In [5]:
# Resize every image in each class subfolder of train_src_dir to a common size,
# re-encode it as JPEG (quality 90) into train_dst_dir as img<N>.jpg, and record
# a [dst_filename, label] row in train_set.
# Labels: 1 = healthy (class folder name ends with 'healthy'), 0 = diseased.
train_set = []
img_index = 0
img_size = (256, 256)  # minimum size for ImageNet is 224x224, but default dataset is mostly 256x256; might tweak
# antialias=True smooths when downsampling tensors (matches PIL resizing) and
# silences torchvision's warning about the changing default for this argument.
transform = T.Resize(size=img_size, antialias=True)

IGNORED_CLASS_DIRS = {'Background_without_leaves'}  # not a leaf class; ignore it
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}  # formats io.read_image can decode

# io.write_jpeg does not create parent folders, so make sure the target exists.
os.makedirs(train_dst_dir, exist_ok=True)

# Iterate through all the subfolders (each corresponds to species + healthy/disease).
# sorted() makes the img<N> numbering reproducible; os.listdir order is arbitrary.
for subdir in sorted(os.listdir(train_src_dir)):
    superdir = os.path.join(train_src_dir, subdir)
    if subdir in IGNORED_CLASS_DIRS or not os.path.isdir(superdir):
        continue

    # 1 to indicate healthy, 0 to indicate diseased
    label = 1 if subdir.endswith('healthy') else 0

    for src_filename in sorted(os.listdir(superdir)):
        # skip stray non-image files (e.g. Thumbs.db, .DS_Store) that would crash read_image
        if os.path.splitext(src_filename)[1].lower() not in IMAGE_EXTENSIONS:
            continue

        dst_filename = 'img{}.jpg'.format(img_index)

        img = io.read_image(os.path.join(superdir, src_filename), mode=io.ImageReadMode.RGB)
        img = transform(img)
        io.write_jpeg(img, os.path.join(train_dst_dir, dst_filename), quality=90)

        train_set.append([dst_filename, label])
        img_index += 1