### Environment information from Kaggle
You can write up to 20 GB to the current directory (`/kaggle/working/`). Its contents are preserved as output when you create a version using "Save & Run All".
You can also write temporary files to `/kaggle/temp/`, but they won't be saved outside of the current session.
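The snippet below is a minimal sketch of writing a file that Kaggle would keep as notebook output. The helper `output_dir` and the file `notes.txt` are hypothetical. The fallback to a local temporary directory is an assumption added so the cell also runs outside Kaggle, where `/kaggle/working` does not exist.

```python
import os
import tempfile

# '/kaggle/working' exists only inside a Kaggle session; elsewhere we assume a temp dir is fine
KAGGLE_WORKING = os.path.join(os.sep, 'kaggle', 'working')


def output_dir() -> str:
    """Hypothetical helper: return a directory to save notebook outputs into."""
    if os.path.isdir(KAGGLE_WORKING):
        return KAGGLE_WORKING
    # Outside Kaggle: a throwaway directory (not preserved anywhere)
    return tempfile.mkdtemp(prefix='kaggle_working_')


# Example: on Kaggle this file would appear in the saved version's output
out_path = os.path.join(output_dir(), 'notes.txt')
with open(out_path, 'w') as f:
    f.write('saved by this notebook version\n')
```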

## Imports
We'll keep our imports here for easier reference.

In [1]:
import pandas as pd  # To easily read csv files (image filenames & label)
import os  # For building paths
import json  # To parse the label map
from torch.utils.data import Dataset, DataLoader  # We'll build a custom dataset & dataloader
import torchvision.transforms  # To convert PIL images into float tensors of shape (C, H, W)
import torch
from PIL import Image  # Python Image Library for reading our image files

## GPU Use
See if we have an available CUDA-enabled GPU. We'll move data/models to the selected device later on.

In [2]:
if not torch.cuda.is_available():
    print('No CUDA-enabled GPU detected. Please enable GPU acceleration for faster training.')
    device = torch.device('cpu')  # A torch.device on both branches, not a plain string
else:
    device = torch.device('cuda:0')
    print(f'CUDA-enabled GPU detected. Using device: {device}')

CUDA-enabled GPU detected. Using device: cuda:0
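The device choice above can also be written as a single expression. The sketch below does that and then shows the two ways objects are moved to the selected device. The tensor `x` and the small `torch.nn.Linear` model are placeholders, not part of this notebook's pipeline.

```python
import torch

# One-line equivalent of the cell above; always yields a torch.device object
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Tensor.to() returns a tensor on the target device (the same tensor if it is
# already there), so reassign the result. Module.to() moves parameters in place
# and returns the module itself.
x = torch.ones(2, 3).to(device)
model = torch.nn.Linear(3, 1).to(device)
print(x.device, next(model.parameters()).device)
```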


In [3]:
INPUT_PATH = os.path.join(os.sep, 'kaggle', 'input', 'cassava-leaf-disease-classification')
TRAIN_IMAGES_PATH = os.path.join(INPUT_PATH, 'train_images')
TEST_IMAGES_PATH = os.path.join(INPUT_PATH, 'test_images')

# Explore what we have

# train.csv contains image filenames & their class label corresponding to the leaf disease (or healthy) 
train_labels = pd.read_csv(os.path.join(INPUT_PATH, 'train.csv'))
# print(train_labels.head())
# print(train_labels['image_id'][[0, 2]].values)
print(f'Number of training images: {train_labels.shape[0]}')

Number of training images: 21397
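A natural next step is to check how the images are spread across classes. The sketch below runs on a small made-up DataFrame (`example_labels`) that has the same `image_id` and `label` columns as `train.csv`. Applying the same two calls to `train_labels` would give the real counts. No claim is made here about what those counts are.

```python
import pandas as pd

# Hypothetical stand-in for train.csv: same columns, made-up rows
example_labels = pd.DataFrame({
    'image_id': ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg'],
    'label': [3, 3, 0, 4, 3],
})

# Absolute count and share of images per class label, ordered by label
counts = example_labels['label'].value_counts().sort_index()
shares = example_labels['label'].value_counts(normalize=True).sort_index()
print(counts)
print(shares.round(2))
```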


## Load label meanings
The file `label_num_to_disease_map.json` maps our class labels (integers) to their meaning (a disease name, or healthy).


In [4]:
with open(os.path.join(INPUT_PATH, 'label_num_to_disease_map.json')) as json_file:
    LABEL_MAP = json.load(json_file)
    
for k, v in LABEL_MAP.items():
    print(f'{k}: {v}')

0: Cassava Bacterial Blight (CBB)
1: Cassava Brown Streak Disease (CBSD)
2: Cassava Green Mottle (CGM)
3: Cassava Mosaic Disease (CMD)
4: Healthy
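Note that `json` always decodes object keys as strings, so the keys of `LABEL_MAP` are `'0'` to `'4'`, not integers. Looking up an integer label from `train.csv` directly would therefore raise a `KeyError`. The sketch below uses a two-entry excerpt of the mapping printed above and converts the keys to `int`.

```python
import json

# Two-entry excerpt of label_num_to_disease_map.json
raw = '{"0": "Cassava Bacterial Blight (CBB)", "4": "Healthy"}'
label_map = json.loads(raw)
print(type(next(iter(label_map))))  # <class 'str'>

# Re-key with ints so the integer labels from train.csv can be used directly
label_map_int = {int(k): v for k, v in label_map.items()}
print(label_map_int[4])  # Healthy
```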


## Custom Dataset
We'll need a custom dataset, inheriting from `torch.utils.data.Dataset`, that loads images and labels on demand. It must implement `__len__` (the number of samples) and `__getitem__` (one sample by index). We'll also convert images to tensors inside this class.

In [5]:
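class TrainingDataset(Dataset):
    def __init__(self, train_images_path):
        # Converts a PIL image to a float tensor of shape (C, H, W) with values in [0, 1]
        self.transform = torchvision.transforms.ToTensor()
        self.labels_df = pd.read_csv(os.path.join(INPUT_PATH, 'train.csv'))
        self.images_path = train_images_path

    def __len__(self):
        return self.labels_df.shape[0]

    def __getitem__(self, idx):
        # Use positional indexing (.iloc) for both the filename and the label
        row = self.labels_df.iloc[idx]
        img_name = os.path.join(self.images_path, row['image_id'])
        # The context manager closes the file handle once the image is converted
        with Image.open(img_name) as pil_img:
            img = self.transform(pil_img.convert('RGB'))
        label = int(row['label'])

        return img, label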

## Test that our dataset works
Here we instantiate our dataset and confirm that it returns what we expect.

In [6]:
train_dataset = TrainingDataset(TRAIN_IMAGES_PATH)

assert len(train_dataset) == train_labels.shape[0]

for i in range(5):
    img, label = train_dataset[i]
    print(f'Image {i}')
    print(f'\ttype: {type(img)}, shape: {img.shape}, label: {label}')

Image 0
	type: <class 'torch.Tensor'>, shape: torch.Size([3, 600, 800]), label: 0
Image 1
	type: <class 'torch.Tensor'>, shape: torch.Size([3, 600, 800]), label: 3
Image 2
	type: <class 'torch.Tensor'>, shape: torch.Size([3, 600, 800]), label: 1
Image 3
	type: <class 'torch.Tensor'>, shape: torch.Size([3, 600, 800]), label: 1
Image 4
	type: <class 'torch.Tensor'>, shape: torch.Size([3, 600, 800]), label: 3


## Use DataLoader
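The `DataLoader` class gives us batching. With `pin_memory=True`, it pins batches made of plain tensors automatically. Our collate function returns a custom batch type, though, so that type needs its own `pin_memory()` method, which the `DataLoader` will call on each batch. Pinned (page-locked) host memory makes CPU-to-GPU copies faster and lets them run asynchronously, which helps keep the GPU busy during training.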
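For comparison, the sketch below shows the plain-tensor case with synthetic data (`TensorDataset` over random "images"). No custom `pin_memory()` method is needed here. It also shows the `non_blocking=True` transfer that pinned memory enables. Pinning is only requested when CUDA is available. This is an assumption made so the cell also runs on CPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in data: 20 "images" of shape 3x4x4 with labels in 0..4
images = torch.rand(20, 3, 4, 4)
labels = torch.randint(0, 5, (20,))
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Batches of plain tensors are pinned by the DataLoader itself when pin_memory=True
loader = DataLoader(TensorDataset(images, labels), batch_size=8, pin_memory=use_cuda)

for batch_images, batch_labels in loader:
    # From pinned memory, non_blocking=True makes the host-to-GPU copy asynchronous;
    # for a CPU tensor moved to the CPU, .to() simply returns the same tensor
    batch_images = batch_images.to(device, non_blocking=True)
    batch_labels = batch_labels.to(device, non_blocking=True)
    print(batch_images.shape, batch_images.device)
```

With 20 samples and `batch_size=8`, the last batch holds the remaining 4 samples.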

In [7]:
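class ImageBatch:
    """Custom batch type returned by collate_wrapper."""

    def __init__(self, data):
        # data is a list of (image_tensor, label) pairs produced by the dataset
        self.images = torch.stack([datum[0] for datum in data])
        self.labels = torch.tensor([datum[1] for datum in data])

    def pin_memory(self):
        # Called by DataLoader(pin_memory=True) on every batch it yields
        self.images = self.images.pin_memory()
        self.labels = self.labels.pin_memory()
        return self


def collate_wrapper(batch):
    return ImageBatch(batch)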

### Test DataLoader
Below we check that our data loader returns what we expect and that our memory is actually pinned. Memory is only pinned when a CUDA-enabled GPU is available. On a CPU-only machine, PyTorch skips pinning (with a warning) even though `pin_memory=True`.

In [8]:
dl = DataLoader(train_dataset, batch_size=10, collate_fn=collate_wrapper, pin_memory=True)
for batch_index, image_batch in enumerate(dl):
    if batch_index == 5:
        break
    # len(image_batch.images) is the batch size; image_batch.images[0] is a single (C, H, W) image
    print(f'Batch {batch_index} size: {len(image_batch.images)}. is_pinned: {image_batch.images.is_pinned()}')

Batch 0 size: 10. is_pinned: True
Batch 1 size: 10. is_pinned: True
Batch 2 size: 10. is_pinned: True
Batch 3 size: 10. is_pinned: True
Batch 4 size: 10. is_pinned: True
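A training loop will also need to move each batch to the device. The sketch below is a hypothetical `DeviceBatch` class, not the notebook's `ImageBatch`. It adds a `to(device)` method alongside `pin_memory()` and runs on a small synthetic `TensorDataset`. It also makes the indexing point from the test above concrete: `len(batch.images)` is the batch size, whereas `len(batch.images[0])` is the channel count of a single image.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class DeviceBatch:
    """Hypothetical custom batch: tensors for images and labels, pinnable and movable."""

    def __init__(self, data):
        self.images = torch.stack([image for image, _ in data])
        self.labels = torch.tensor([int(label) for _, label in data])

    def pin_memory(self):
        # Called by DataLoader(pin_memory=True) on each batch when a GPU is present
        self.images = self.images.pin_memory()
        self.labels = self.labels.pin_memory()
        return self

    def to(self, device):
        # Asynchronous copies when the source tensors are pinned
        self.images = self.images.to(device, non_blocking=True)
        self.labels = self.labels.to(device, non_blocking=True)
        return self


use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
dataset = TensorDataset(torch.rand(10, 3, 6, 8), torch.arange(10) % 5)
loader = DataLoader(dataset, batch_size=4, collate_fn=DeviceBatch, pin_memory=use_cuda)

first = next(iter(loader)).to(device)
print(len(first.images), len(first.images[0]), first.labels.tolist())  # 4 3 [0, 1, 2, 3]
```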
