# Load Serengeti images into PyTorch

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import PIL
from PIL import Image
from sklearn.model_selection import train_test_split
import time
%matplotlib inline

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils import data
from torchvision import datasets, transforms as T
import torch.nn as nn
from torchvision.io import read_image
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import ToTensor, Lambda

In [None]:
img_dir = 'data'
# This CSV holds a reduced set of species classes; to use all of the
# species classes, read in 'image_labels.csv' instead.
annotations_file = 'labels_reduced_classes.csv'
labels = pd.read_csv(annotations_file)

# Shuffle every row with a fixed seed so the preview below is reproducible
labels = labels.sample(frac=1, random_state=42)

print(f'There are currently {len(labels)} rows in the dataset')
labels.head()

In [None]:
# Build lookups between species names and integer class indices
# (indices follow the alphabetical order of the species names)
species = sorted(labels['question__species'].unique())
print(f'There are {len(species)} unique species in this dataset')

species_to_idx = {name: position for position, name in enumerate(species)}
idx_to_species = {position: name for name, position in species_to_idx.items()}
idx_to_species

## PyTorch Datasets
PyTorch provides two data primitives: `DataLoader` and `Dataset`. A `Dataset` stores the samples and their corresponding labels, and a `DataLoader` wraps an iterable around the `Dataset` to enable easy access to the samples (batching, shuffling).
To define your own dataset, subclass `torch.utils.data.Dataset` and implement three methods:

- `__init__` to set up whatever the dataset needs (for example, reading the annotations file)
- `__len__` so that `len(dataset)` returns the size of the dataset
- `__getitem__` so that `dataset[i]` returns the `i`-th sample

In [None]:
#create a custom class for loading in the data to PyTorch, resizing, and creating the labels
class SnapshotSerengetiDataset(Dataset):
    """Map-style dataset of Snapshot Serengeti images and their species labels.

    Each row of the annotations CSV is read positionally: column 0 holds the
    image path (relative to ``img_dir``) and column 1 holds the species name,
    which is mapped to an integer class index through ``class_dict``.

    Args:
        annotations_file: path to the CSV of image paths and species names.
        img_dir: directory that the image paths in the CSV are relative to.
        class_dict: mapping from species name to integer class index.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the integer label.
    """

    def __init__(self, annotations_file, img_dir, class_dict, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.class2index = class_dict

    def __len__(self):
        """Return the number of rows in the annotations CSV."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is decoded as a 3-channel RGB tensor (channels first) and
        scaled from the 0-255 integer range to floats in [0, 1] before
        ``transform`` is applied. The label is the species' class index,
        optionally passed through ``target_transform``.

        Raises:
            KeyError: if the row's species name is not in ``class_dict``.
        """
        # Local import keeps this cell self-contained; ImageReadMode ships
        # alongside read_image in torchvision.io.
        from torchvision.io import ImageReadMode

        image_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Force RGB decoding: grayscale (1-channel) or RGBA (4-channel) files
        # would otherwise break the 3-channel Normalize and plt.imshow later on.
        image = read_image(image_path, mode=ImageReadMode.RGB)
        image = torch.mul(image, 1 / 255.0)  # scale uint8 [0, 255] to float [0, 1]
        label = self.class2index[self.img_labels.iloc[idx, 1]]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
#display sample of photos with no transforms
no_transforms = SnapshotSerengetiDataset(annotations_file, img_dir, species_to_idx)

figure = plt.figure(figsize=(16, 16))
cols, rows = 3, 3

# randperm yields distinct indices, so the grid never shows the same image
# twice (torch.randint sampled with replacement); min() also copes with a
# dataset smaller than the grid.
n_show = min(cols * rows, len(no_transforms))
sample_indices = torch.randperm(len(no_transforms))[:n_show].tolist()

for position, sample_idx in enumerate(sample_indices, start=1):
    img, species_label = no_transforms[sample_idx]
    figure.add_subplot(rows, cols, position)
    plt.title(idx_to_species[species_label])
    plt.axis("off")
    # Dataset images are channels-first (C, H, W); matplotlib wants (H, W, C)
    plt.imshow(img.permute(1, 2, 0))
# Lay out once, after every subplot has its title, so no title is ignored
figure.tight_layout()
plt.show()

## PyTorch Transforms

Data does not always come in the final processed form required for training machine learning algorithms. We use transforms to manipulate the data and make it suitable for training.

The `torchvision.transforms` module offers several commonly used transforms out of the box. You'll get to try these out at the end of this section.

In [None]:
# Standard preprocessing so images can be fed to pre-trained networks:
# resize, make sure the tensor is float32, then normalize each channel with
# the ImageNet mean/std used by torchvision's pre-trained models.
# This will be applied to both the test and validation sets in later notebooks.
standard_transform = T.Compose([
    T.Resize(256),  # int size: shorter side becomes 256 px, aspect ratio kept
    T.ConvertImageDtype(torch.float32),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Dataset that runs standard_transform on every image it returns
transformed_dataset = SnapshotSerengetiDataset(
    annotations_file=annotations_file,
    img_dir=img_dir,
    class_dict=species_to_idx,
    transform=standard_transform,
)

In [None]:
# Wrap the transformed dataset in a DataLoader that yields shuffled batches
batch_size = 16
transformed_loader = DataLoader(transformed_dataset, batch_size=batch_size, shuffle=True)

In [None]:
#function to plot sample of transformed images
def imshow(dataloader, cols=3, rows=3, *,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
           idx_to_label=None):
    """Plot a ``rows`` x ``cols`` grid of images from the first batch of ``dataloader``.

    Images are assumed to have been normalized with ``mean``/``std`` (the
    values used by the transforms in this notebook); the normalization is
    undone and the result clipped to [0, 1] before display.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches, with images
            shaped (batch, channels, height, width).
        cols, rows: grid dimensions. If the batch has fewer images than
            ``cols * rows``, only the available images are shown.
        mean, std: per-channel statistics that were used by ``Normalize``.
        idx_to_label: mapping from class index to display name; defaults to
            the notebook's ``idx_to_species``.
    """
    if idx_to_label is None:
        idx_to_label = idx_to_species
    images, species_labels = next(iter(dataloader))

    # Hoisted out of the loop: the statistics are the same for every image.
    mean = np.asarray(mean)
    std = np.asarray(std)
    # Indexing from 0 shows the first image of the batch, and min() avoids an
    # IndexError when the batch is smaller than the grid.
    n_show = min(cols * rows, len(images))

    figure = plt.figure(figsize=(16, 16))
    for i in range(n_show):
        figure.add_subplot(rows, cols, i + 1)
        plt.axis("off")
        # Undo Normalize (x * std + mean) in channels-last layout for matplotlib
        image = std * images[i].cpu().permute(1, 2, 0).numpy() + mean
        image = np.clip(image, 0, 1)
        plt.title(idx_to_label[species_labels[i].item()])
        plt.imshow(image)
    # Lay out once, after every subplot has its title
    figure.tight_layout()
    plt.show()

In [None]:
# Preview one batch after the standard transforms (the preprocessing used for the validation set)
imshow(transformed_loader, cols=3, rows=3)

# Activity 

#### Visualize different PyTorch transforms

See this document on PyTorch transforms - https://pytorch.org/vision/stable/transforms.html

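Here is a sample of PyTorch transforms to try out:

- Normalize
- Resize
- Scale (deprecated alias of `Resize`; removed in recent torchvision releases)
- CenterCrop
- Pad
- RandomCrop
- RandomHorizontalFlip
- RandomVerticalFlip
- RandomResizedCrop
- RandomSizedCrop (deprecated alias of `RandomResizedCrop`; removed in recent torchvision releases)
- LinearTransformation (requires a precomputed transformation matrix and mean vector)
- ColorJitter
- RandomRotation
- RandomAffine
- Grayscale (pass `num_output_channels=3` so the 3-channel `Normalize` still works)
- RandomGrayscale
- RandomPerspective
- RandomErasing
- GaussianBlur
- InterpolationMode (not a transform itself: an enum passed as the `interpolation=` argument of transforms such as Resize and RandomRotation)
- RandomInvert
- RandomPosterize (in `torchvision.transforms` this expects `uint8` tensors, but this dataset already returns float images)
- RandomSolarize
- RandomAdjustSharpness
- RandomAutocontrast
- RandomEqualize (in `torchvision.transforms` this expects `uint8` tensors, but this dataset already returns float images)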

In [None]:
# Try out transforms for the training set by adding them to the list below
my_transforms = [
    T.Resize(256),

    # Add transforms here by prefixing the transform name with T., for example:
    # T.RandomRotation(30),
    # Images reaching this list are already float tensors in [0, 1] (the
    # dataset rescales them), and anything placed after Normalize sees
    # normalized values.

    T.ConvertImageDtype(torch.float32),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
]

train_transform = T.Compose(my_transforms)
train_dataset = SnapshotSerengetiDataset(
    annotations_file=annotations_file,
    img_dir=img_dir,
    class_dict=species_to_idx,
    transform=train_transform,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Preview a batch with the training transforms applied
imshow(train_loader)