In [126]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import numpy as np 
import math 
import os
import csv
import pandas as pd
from PIL import Image
from torchvision.io import read_image
import torchvision.transforms as transforms

In [60]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using:', device)

Using: cuda


In [75]:
# Automating the generation of a csv file with all directories, image names and their lables
# ---directory----image_name----binary_category----category

with open('./lables.csv', 'w') as file:
    writer = csv.writer(file)
    writer.writerow(["directory", "image_name", "binary_cat", "multi_cat"])
    
    categories = {
        "healthy": 0,
        "unhealthy" : 1,
        "Anthracnose" : 2,
        "Bacterial Blight": 3,
        "Black Spot": 4,
        "Citrus Canker": 5, 
        "Citrus Hindu Mite":6,
        "Citrus Leafminer":7,
        "Curl Leaf":8,
        "Deficiency":9,
        "Dry Leaf":10,
        "Greening":11,
        "Melanose":12,
        "Sooty Mould":13,
        "Spider Mites":14 
    }
    
    for path, subdirs, files in os.walk('./dataset'):
        if "./dataset/healthy" in path:
            for img in files:
                writer.writerow([path, img, categories["healthy"], categories["healthy"]])
        if "./dataset/unhealthy" in path:
            for img in files:
                # path[20:] cut out the unnecessary directory snince we only needthe diseasw name, which is the latter part of the directory
                multi = path[20:]
                writer.writerow([path, img, categories["unhealthy"], categories[multi]])

f.close()

In [114]:
class CitrusImageDataset(Dataset):
    """
    A custom dataset class for Citrus leaf image dataset.
    - Resizes images to a specified width and height. 
    - Imaplements methods to get dataset items and dataset length (as reccomended by pytorch documentation)
    - Adds an argument for common transformation for the image dataset

    Args:
        annotations_file: takes in the cvs file that contains lables of images 
        img_dir: direcotry in which images are stored
        transform: to modify the features, needed to manipulate data to make it suitable for training
        target_transform: to modify the labels - that accept callables containing the transformation logic
        
    """
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.directory_list = img_dir
        self.data_paths = []
        self.transform = transform
        self.target_transform = target_transform

        for directory in img_dir: # healthy or unhealthy dirs 
            for filename in os.listdir(directory):
                if ".jpg" in filename:
                    self.data_paths.append(os.path.join(directory, filename))
                    #print("dircoties:" + os.path.join(directory, filename))
                else:
                    for f in os.listdir("./dataset/unhealthy/"): # since unhealthy folder has subfolders with images
                        path = os.path.join("./dataset/unhealthy/", f)
                        for ff in os.listdir(path):
                            self.data_paths.append(os.path.join(path, ff))
                            #print("dircoties:" + os.path.join(path, ff))
        
        

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]
        image = Image.open(data_path)
        binary_label = self.img_labels.iloc[idx, 2]
        multi_cat_label = self.img_labels.iloc[idx, -1]
        
        if self.transform:
            image = self.transform(image)
        return image, binary_label, multi_cat_label

In [124]:
# Defining transformations - using industry standard, can try to customize later
"""
Normalize does: image = (image - mean)/std;  mean for all 3 channels: [0.485, 0.456, 0.406], std for all 3 channels: [0.229, 0.224, 0.225] 
"""
img_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224,224)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet's mean and std
])
labels_dir = "./lables.csv"
img_dir = ["./dataset/healthy", "./dataset/unhealthy"]
dataset = CitrusImageDataset(labels_dir, img_dir, img_transforms)
#print(len(dataset))


In [129]:
batch_size = 100
val_size = 2000
train_size = 38000
test_size = 3017
train_data, test_data, val_data = random_split(dataset, [train_size,test_size, val_size])
train_dl = DataLoader(train_data, batch_size, shuffle = True, num_workers = 4, pin_memory = True)
val_dl = DataLoader(val_data, batch_size*2, num_workers = 4, pin_memory = True)

In [130]:
def show_batch(dl):
    """Plot images grid of single batch"""
    for images, labels in dl:
        fig,ax = plt.subplots(figsize = (16,12))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(make_grid(images,nrow=16).permute(1,2,0))
        break
        
show_batch(train_dl)

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 399, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 399, in <listcomp>
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_229/3592318900.py", line 42, in __getitem__
    binary_label = self.img_labels.iloc[idx, 2]
                   ~~~~~~~~~~~~~~~~~~~~^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/pandas/core/indexing.py", line 1183, in __getitem__
    return self.obj._get_value(*key, takeable=self._takeable)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/pandas/core/frame.py", line 4212, in _get_value
    return series._values[index]
           ~~~~~~~~~~~~~~^^^^^^^
IndexError: index 38954 is out of bounds for axis 0 with size 3513
