# Building the Model

In [None]:
import copy
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.multiprocessing as mp
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms as T
from torchvision.io import ImageReadMode, read_image
import torchvision.models as models
from torchvision.transforms import Lambda
from tqdm.notebook import tqdm

%matplotlib inline

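Now we have our label files (e.g. `image_labels.csv`) linking to the images in `/domino/datasets/local/serengeti-small-dataset`. Below we load `labels_reduced_classes.csv`, which covers a reduced set of species classes; to use all of the species classes, load `image_labels.csv` instead. `img_dir` should point at the directory that contains the images.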

In [None]:
img_dir = 'data'
# To use all of the species classes, read in image_labels.csv instead
annotations_file = 'labels_reduced_classes.csv'

labels = pd.read_csv(annotations_file)
# Sorting makes the species -> class-index mapping deterministic from run to run
species = sorted(labels['question__species'].unique())
species_to_idx = {name: idx for idx, name in enumerate(species)}
idx_to_species = dict(enumerate(species))

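Next, we choose how many images to sample from each class. Taking the same number from every class gives a balanced dataset and keeps training time manageable.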

In [None]:
# Number of images to sample from each class. Previously a class with fewer images
# than this made `.sample` raise; such classes now contribute all of their images,
# and a notice lists them.
num_samples = 200

small_dfs = []
capped_classes = {}
for animal in species:
    small_df = labels.loc[labels['question__species'] == animal, ['image_name', 'question__species']]
    n_take = min(num_samples, len(small_df))
    if n_take < num_samples:
        capped_classes[animal] = n_take
    small_dfs.append(small_df.sample(n_take, random_state=42))
if capped_classes:
    print('Fewer than {} images available for: {}'.format(num_samples, capped_classes))

reduced_images = pd.concat(small_dfs, ignore_index=True)
reduced_images.to_csv('reduced_images.csv', index=False)

# From here on, work with the reduced label file written above
annotations_file = 'reduced_images.csv'
labels = pd.read_csv(annotations_file)

labels['question__species'].value_counts()

Run the following cell to see how many images each class has, along with the `species_to_idx` dictionary that maps each species name to its class index.

In [None]:
species_counts = labels['question__species'].value_counts()
with pd.option_context('display.max_rows', 999):
    # Raise pandas' row limit so every class is listed rather than a truncated summary
    print(species_counts)
print('---\nspecies_to_idx: ' + str(species_to_idx))

## Building a Custom Dataset

To define your own dataset, inherit from PyTorch's `Dataset` class and implement three methods:
* `__init__`, which sets up the dataset (here: reading the label CSV and storing the transforms)
* `__len__`, so that `len(dataset)` returns the size of the dataset
* `__getitem__`, so that `dataset[i]` returns the `i`-th sample

In [None]:
class SnapshotSerengetiDataset(Dataset):
    """Map-style dataset of Snapshot Serengeti images and their species labels.

    Args:
        annotations_file: CSV (path or buffer) whose first column is the image file
            name relative to ``img_dir`` and whose second column is the species name.
        img_dir: directory containing the image files.
        class_dict: mapping from species name to integer class index.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the integer class index.
    """

    def __init__(self, annotations_file, img_dir, class_dict, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.class2index = class_dict

    def __len__(self):
        """Number of labelled images (rows in the annotations CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is decoded as 3-channel RGB and scaled by 1/255 to the [0, 1]
        range before ``transform`` runs. The label is the class index from
        ``class_dict``, passed through ``target_transform`` if one was given.
        """
        image_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Force RGB so every image has exactly 3 channels. With the default mode,
        # grayscale/RGBA files give 1/4-channel tensors, which break a 3-channel
        # Normalize and the stacking of mixed images into one batch.
        image = read_image(image_path, mode=ImageReadMode.RGB)
        image = torch.mul(image, 1 / 255.)  # scale to [0, 1]
        label = self.class2index[self.img_labels.iloc[idx, 1]]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

Here we have the image transforms discussed in the `2-Load-Data-PyTorch.ipynb` notebook as well as a transform to one-hot encode the species index.

In [None]:
# ImageNet channel mean/std, matching the ImageNet-pretrained backbone used below
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation (rotation + flip), crop, normalize
train_transform = T.Compose([
    T.Resize(256),
    T.RandomRotation(30),
    T.RandomHorizontalFlip(),
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float32),
    T.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation pipeline: the same deterministic steps without augmentation
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float32),
    T.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Turn an integer class index into a float32 one-hot vector of length len(species_to_idx)
target_transform = Lambda(
    lambda y: nn.functional.one_hot(torch.tensor(y), num_classes=len(species_to_idx)).float()
)

Now we split the data into stratified train and validation sets using scikit-learn's `train_test_split`.

It's good to remember that you can still import and use functions from tools you are familiar with even if you aren't using them elsewhere in a project.

In [None]:
random_seed = 42
val_size = .1

def stratified_split(annotations_file, test_size = 0.2, random_state=None):
    """Split the rows of a label CSV into train/test index sets, stratified by species.

    Args:
        annotations_file: path (or buffer) of a CSV with a ``question__species`` column.
        test_size: fraction (or absolute number) of rows to put in the test split.
        random_state: seed for the shuffle. Defaults to the notebook-level ``random_seed``,
            so existing calls produce the same split as before.

    Returns:
        ``(train_indices, test_indices)``: row labels of the freshly read CSV. The CSV is
        read with a default RangeIndex, so these labels equal the row positions.
    """
    if random_state is None:
        random_state = random_seed
    img_labels = pd.read_csv(annotations_file)
    # Stratifying on the species column keeps class proportions equal in both splits
    train_indices, test_indices = train_test_split(
        img_labels.index,
        stratify=img_labels['question__species'],
        test_size=test_size,
        random_state=random_state,
    )
    return train_indices, test_indices

train_indices, val_indices = stratified_split(annotations_file=annotations_file, test_size=val_size)

dataset_size = len(train_indices) + len(val_indices)

# Re-read the species column from the same file that was split, so this cell does
# not depend on the mutable global `labels` (and the indices line up by construction).
split_species = pd.read_csv(annotations_file)['question__species']

print('Train counts\n---')
print(split_species.iloc[train_indices].value_counts())
print('\nValidation counts\n---')
print(split_species.iloc[val_indices].value_counts())

Now that we have our train and validation indices, we can generate an instance of the dataset for each (with the appropriate transforms). We also generate a random sampler based on the indices to ensure that images are shuffled across epochs.

In [None]:
# Both datasets read the full label file; the samplers restrict each one to its
# own split, so only the image transforms differ between them.
shared_kwargs = dict(
    annotations_file=annotations_file,
    img_dir=img_dir,
    class_dict=species_to_idx,
    target_transform=target_transform,
)
train_dataset = SnapshotSerengetiDataset(transform=train_transform, **shared_kwargs)
val_dataset = SnapshotSerengetiDataset(transform=val_transform, **shared_kwargs)

# Random samplers over each split's indices (a fresh permutation on every pass)
train_sampler, val_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

## Dataloaders

PyTorch dataloaders wrap an iterable around a dataset and make it easier to feed data to a model during training.

The key parameter here is the batch size: the number of images passed through the model in each step.

In [None]:
batch_size = 128  # images per batch

# The samplers choose (and shuffle) which indices each loader visits,
# so `shuffle` is left at its default (a sampler and shuffle=True cannot be combined).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=batch_size,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    sampler=val_sampler,
    batch_size=batch_size,
)

Printing the tensor sizes and other intermediate data is one of the easiest ways to confirm that we are prepping the data as we expect.

In [None]:
#Printing out tensor sizes
images, labels = next(iter(train_loader))
print(f"Feature batch shape: {images.size()}")
print(f"Labels batch shape: {labels.size()}")
label = labels[0]
print("Label: " + idx_to_species[label.argmax(0).item()])

image = images[0]
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image = std * image.permute(1, 2, 0).numpy() + mean
image = np.clip(image, 0, 1)
plt.imshow(image)
plt.show()

## Model Training Loops

One of PyTorch's greatest strengths is that you explicitly control what happens at every step of model training.

The `train_model` function shown here loops through the provided number of epochs to complete the train and validation steps at each.

Depending on what you're trying to have your model do, your training function may be significantly more or less complicated.

In [None]:
def _run_phase(model, loader, criterion, optimizer, is_train, device):
    """Run one pass over `loader`; return (mean loss, accuracy) over the samples seen.

    Gradients are tracked and the optimizer steps only when `is_train` is True.
    Accuracy compares the arg-max of the model outputs with the arg-max of the
    targets (one-hot vectors in this notebook). An empty loader yields (0.0, 0.0).
    """
    model.train(is_train)  # equivalent to model.train() / model.eval()
    running_loss = 0.0
    running_corrects = 0
    num_seen = 0

    for inputs, labels in tqdm(loader):
        # Keep each batch on the same device as the model's parameters
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # Only build the autograd graph during training
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_len = inputs.size(0)
        running_loss += loss.item() * batch_len
        running_corrects += (preds == labels.argmax(1)).sum().item()
        num_seen += batch_len

    if num_seen == 0:
        return 0.0, 0.0
    # Normalize by the samples actually seen, so the metrics stay exact even if a
    # precomputed split size would be fractional or off by one
    return running_loss / num_seen, running_corrects / num_seen


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, loaders=None):
    """Train `model`, validating after every epoch, and return it with its best weights.

    Args:
        model: the network to train (modified in place).
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: learning-rate scheduler stepped once per epoch after the training
            phase. Anything without a callable ``step`` attribute (e.g. ``None``) is
            ignored. Schedulers whose ``step`` needs a metric are not supported.
        num_epochs: number of train + validation epochs.
        loaders: dict with ``'train'`` and ``'val'`` iterables of ``(inputs, targets)``
            batches. Defaults to the notebook-level ``dataloaders`` dict.

    Returns:
        The same model object with the weights of the epoch that had the highest
        validation accuracy loaded.
    """
    if loaders is None:
        loaders = dataloaders
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, loaders[phase], criterion, optimizer, phase == 'train', device)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Step the LR schedule after this epoch's optimizer updates
            if phase == 'train' and callable(getattr(scheduler, 'step', None)):
                scheduler.step()

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

## Importing a model

Although it's good to understand the relevant deep learning model components (e.g. convolutional layers, ReLU) you generally don't need to build an architecture from scratch.

Instead, you can find a model architecture that has already been proven to perform well on a similar problem and start with that!

To facilitate this, PyTorch (through `torchvision`) provides a library of popular, state-of-the-art model architectures that you can import into your project. We'll use the ResNet-50 architecture here because it has been shown to work well for animal recognition.

## Transfer Learning

Large models can require a lot of data and are expensive to train. One common approach to mitigate this is transfer learning.

Transfer learning is using what was learned for one task to solve a different task.

Here, we use PyTorch's pretrained ResNet-50 (trained on the popular 1000-class ImageNet dataset). We first replace the final fully-connected layer so that it has one output neuron per species class. To keep the pretrained features of the model fixed during training, we then freeze all but that final layer, which maps the features to the expected classes.

In [None]:
model = models.resnet50(pretrained=True)
# Replace the classification head with one output per species class
model.fc = nn.Linear(model.fc.in_features, len(species_to_idx))
num_frozen_layers = 9

# Freeze every parameter in the first `num_frozen_layers` top-level child modules;
# later children (intended to be just the new `fc` head) stay trainable.
child_modules = list(model.children())
for frozen_module in child_modules[:max(num_frozen_layers, 0)]:
    frozen_module.requires_grad_(False)
print('Number of unfrozen layers: ' + str(len(child_modules) - num_frozen_layers))

You can use the command below to see details about what's happening in each layer of our model.

In [None]:
# Print each top-level child module of the network, in registration order
for _, submodule in model.named_children():
    print(submodule)

## Time to Train!

Now we're ready to train our model. We first do final formatting of our inputs and define a loss function and optimizer.

In [None]:
dataloaders = {'train': train_loader, 'val': val_loader}
# Use the actual split sizes. train_test_split rounds the validation count up, so
# dataset_size * val_size can be fractional and need not match the real counts.
dataset_sizes = {'train': len(train_indices), 'val': len(val_indices)}

loss_fn = nn.BCEWithLogitsLoss()  # pairs with the one-hot float targets from target_transform
# Only the unfrozen parameters need to be optimized
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=5e-3, momentum=0.9, weight_decay=5e-3)

# No learning-rate scheduler is used in this run, so pass None explicitly
model = train_model(model, loss_fn, optimizer, None, num_epochs=1)

## Visualizing the Output

One of the most important things you can do is **look at your data!**

Visualizing samples from your model with the predicted label can give you confidence that it is performing as expected.

In [None]:
def _denormalize(image):
    """Turn a normalized (C, H, W) image tensor into an (H, W, C) array in [0, 1] for imshow."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Invert Normalize(mean, std) and move channels last
    image = std * image.cpu().permute(1, 2, 0).numpy() + mean
    return np.clip(image, 0, 1)


def visualize_model(model, num_images=6):
    """Show `num_images` validation images titled with the model's predicted species.

    Images come from ``dataloaders['val']``. They may span several batches if a batch
    is smaller than `num_images`, and are laid out in a two-column grid in a single
    figure. The model runs without gradients, and its original train/eval mode is
    restored afterwards, even if plotting fails.
    """
    was_training = model.training
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Round up so an odd num_images still fits in the 2-column grid
    n_rows = max(1, (num_images + 1) // 2)
    figure = plt.figure(figsize=(16, 16))
    images_so_far = 0
    try:
        with torch.no_grad():
            for images, _ in dataloaders['val']:
                _, preds = torch.max(model(images.to(device)), 1)
                # Only plot as many images from this batch as are still needed
                remaining = max(num_images - images_so_far, 0)
                for image, pred in zip(images[:remaining], preds):
                    images_so_far += 1
                    figure.add_subplot(n_rows, 2, images_so_far)
                    plt.axis('off')
                    plt.title('predicted: {}'.format(idx_to_species[pred.item()]))
                    plt.imshow(_denormalize(image))
                if images_so_far >= num_images:
                    break
    finally:
        model.train(mode=was_training)

    figure.tight_layout()
    plt.show()


visualize_model(model, num_images=6)