# FossilNET Classification

In this notebook, we fine-tune a pre-trained ResNet18 network to classify images from the FossilNET dataset, and use Ray Tune to search for good hyperparameters.


## What is FossilNET?

FossilNET is an image dataset collected and curated by [Matt Hall](https://github.com/kwinkunks). The dataset is made available under a [CC-0 License](../datasets/fossilnet/fossilnet-copyright-info.md).

The dataset consists of 3000 128x128 color images across 10 classes. The full dataset is available for experimentation.

In this tutorial we target just 4 classes:

| dinosaurs | fishes | forams | trilobites |
|:---:|:---:|:---:|:---:|
| ![dino](../datasets/fossilnet/tvt_split/4/train/dinosaurs/00949.png) | ![fish](../datasets/fossilnet/tvt_split/4/train/fishes/01603.png) | ![forams](../datasets/fossilnet/tvt_split/4/train/forams/01923.png) | ![trilobites](../datasets/fossilnet/tvt_split/4/train/trilobites/02866.png) |






### Load Dependencies

We load the usual dependencies, plus [PyTorch](https://pytorch.org/docs/stable/index.html) and the [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html) helper library, which gives us access to pretrained models, data loaders and transforms for image problems.

In [None]:
%load_ext autoreload
%autoreload 2

from dependencies import *

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
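# numpy, ray and tune are used further down the notebook
import numpy as np
import ray
from ray import tune

from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms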

## Loading Data

FossilNET is neatly organised on disk into train / val / test folders, each of which contains a subfolder per class holding the appropriate images.

This means we can use a `torchvision.datasets.ImageFolder` to take care of loading, and inject `torchvision.transforms` to pre-process and augment the dataset on the fly.

So we create a function that sets this up for the train and validation splits and returns data loaders ready for use. Note that the validation loader reads its images from the `test` folder.
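To make the folder-to-label convention concrete, here is a simplified, standard-library-only sketch of how `ImageFolder` derives labels: each class subfolder becomes a class, and the class names are sorted alphabetically and numbered from 0. The helpers `find_classes` and `list_samples` are illustrative functions written for this sketch; the real `ImageFolder` also filters files by image extension and walks subfolders recursively, which is skipped here.

```python
import os
import tempfile

def find_classes(root):
    """Class names are the sub-directory names, sorted alphabetically and indexed from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

def list_samples(root):
    """Return (file_path, class_index) pairs for every file in every class folder."""
    classes, class_to_idx = find_classes(root)
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples

# Build a throwaway tree shaped like tvt_split/4/train, with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    train_dir = os.path.join(tmp, "train")
    for cls in ["trilobites", "dinosaurs", "forams", "fishes"]:
        os.makedirs(os.path.join(train_dir, cls))
        open(os.path.join(train_dir, cls, "00001.png"), "w").close()

    classes, class_to_idx = find_classes(train_dir)
    samples = list_samples(train_dir)

print(class_to_idx)  # {'dinosaurs': 0, 'fishes': 1, 'forams': 2, 'trilobites': 3}
```

One consequence: the class indices the model predicts follow the alphabetical order of the folder names, so the dataset's `class_to_idx` attribute is the mapping worth keeping alongside a trained model.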

In [None]:
from os import path

# torchvision's pretrained models were trained on inputs normalised with the ImageNet mean and std
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def get_data_loaders(fossilnet_path,
                     batch_size=16,
                     augment_flip=True,
                     use_grayscale=True):
    
    #
    # We build a list (pipeline) of transforms, which Compose then presents to the dataset
    #
    txs = []
    
    if use_grayscale:
        # convert to gray but keep 3 channels, as resnet expects 3-channel input
        txs.append(transforms.Grayscale(3))
    
    if augment_flip:
        txs.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip()
        ])
    
    txs.extend([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    
    #
    # Use the torchvision ImageFolder Dataset class
    #
    train_dataset = datasets.ImageFolder(
                            root=path.join(fossilnet_path, 'train'),
                            transform=transforms.Compose(txs)
                        )
    
    #
    # Set up a DataLoader to get shuffled batches of images for training
    #
    train_loader = DataLoader(train_dataset, 
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=3,
                              pin_memory=True)
    
    #
    # Set up a Dataset and DataLoader for the validation data (read from the 'test' folder).
    # This time without shuffling or augmentation
    #
    val_txs = []
    
    if use_grayscale:
        val_txs.append(transforms.Grayscale(3))
    
    val_txs.extend([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    
    val_dataset = datasets.ImageFolder(
                            root=path.join(fossilnet_path, 'test'),
                            transform=transforms.Compose(val_txs)
                        )
    
    # pass val_dataset here (the original referenced an undefined test_dataset)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    
    return train_loader, val_loader
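
`transforms.Compose` applies each transform in order, and because `ImageFolder` runs the composed transform every time an image is loaded, the random flips produce a fresh orientation each epoch rather than a fixed augmented copy. Below is a pure-Python sketch of that idea on a tiny 2x2 "image". The helpers `compose`, `hflip`, `vflip` and `random_apply` are illustrative, not torchvision code; the 0.5 probability matches the default of `RandomHorizontalFlip` and `RandomVerticalFlip`.

```python
import random

def compose(*fns):
    """Apply callables in order, like torchvision.transforms.Compose."""
    def apply(x):
        for fn in fns:
            x = fn(x)
        return x
    return apply

def hflip(img):
    # mirror left-right: reverse every row
    return [row[::-1] for row in img]

def vflip(img):
    # mirror top-bottom: reverse the order of the rows
    return img[::-1]

def random_apply(fn, p, rng):
    """Apply fn with probability p, otherwise pass the input through unchanged."""
    def maybe(x):
        return fn(x) if rng.random() < p else x
    return maybe

rng = random.Random(0)
augment = compose(random_apply(hflip, 0.5, rng),
                  random_apply(vflip, 0.5, rng))

img = [[1, 2],
       [3, 4]]

# each "epoch" re-runs the pipeline, so the same stored image can come out differently
views = [augment(img) for _ in range(8)]
for view in views:
    print(view)
```

With both flips enabled each image can appear in four orientations, a cheap way to enlarge a small dataset. The validation pipeline leaves them out so that validation scores are computed on the images as stored.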
            

### Define a Neural Network - We'll use ResNet18

We set up a simple PyTorch module that loads the pretrained ResNet18 weights, freezes them, and resets only the last (fully connected) layer, which we then retrain for our target classes.

[About ResNet18 Architecture](https://www.researchgate.net/figure/ResNet-18-Architecture_tbl1_322476121)
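
To see how little is actually trained, count the learnable values in the replacement head. An `nn.Linear(in_features, out_features)` layer has `in_features * out_features` weights plus `out_features` biases, and ResNet18's final `fc` layer takes 512 input features. The arithmetic in plain Python (`linear_params` is a helper written for this sketch):

```python
def linear_params(in_features, out_features, bias=True):
    """Number of learnable values in an nn.Linear(in_features, out_features) layer."""
    return in_features * out_features + (out_features if bias else 0)

RESNET18_FC_IN = 512  # in_features of resnet18's final fc layer

original_head = linear_params(RESNET18_FC_IN, 1000)  # the ImageNet 1000-class head
fossil_head = linear_params(RESNET18_FC_IN, 4)       # our 4-class head

print(original_head, fossil_head)  # 513000 2052
```

So the ImageNet head had 513,000 parameters, while our 4-class head trains only 2,052 and every other weight stays frozen. On the real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` should report the same 2,052.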

In [None]:
# cue cool name
class FossilResNet(nn.Module):
    
    def __init__(self, num_outputs=10):
        super(FossilResNet, self).__init__()
        
        # this will pull the weights down to a local cache on first execution
        self.model_conv = torchvision.models.resnet18(pretrained=True)
        
        # we turn off gradients on all layers, so the optimiser will ignore them during backward
        for param in self.model_conv.parameters():
            param.requires_grad = False

        # we replace the last layer with a freshly initialised one targeting the correct number of outputs;
        # its parameters have requires_grad=True by default, so only this layer is optimised
        num_ftrs = self.model_conv.fc.in_features
        self.model_conv.fc = nn.Linear(num_ftrs, num_outputs)
        
    def forward(self, x):
        # returns raw, unnormalised class scores (logits)
        return self.model_conv(x)
        

## Train & Validate functions

We have seen train and test/validate functions a few times now: they loop over a dataset, calculate losses and metrics, and return.

This time `train` and `validate` are each called once per epoch, and each returns the micro-averaged F1 score over all examples.
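
For single-label classification like ours, micro-averaged F1 works out to the same number as plain accuracy: every wrong prediction counts once as a false positive (for the predicted class) and once as a false negative (for the true class), so precision and recall are equal. A small standard-library check, independent of `sklearn` (`micro_f1` and `accuracy` are helpers written for this sketch):

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1: pool TP, FP and FN across all classes, then compute F1 once."""
    classes = set(y_true) | set(y_pred)
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == p)
    # a wrong prediction is a false positive for the predicted class...
    fp = sum(1 for c in classes for t, p in pairs if p == c and t != c)
    # ...and a false negative for the true class
    fn = sum(1 for c in classes for t, p in pairs if t == c and p != c)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

y_true = [0, 1, 2, 3, 0, 1, 2, 3]
y_pred = [0, 1, 2, 0, 0, 2, 2, 3]   # two mistakes out of eight
print(micro_f1(y_true, y_pred), accuracy(y_true, y_pred))  # 0.75 0.75
```

So `best_val_f1_score` can be read as the best validation accuracy. `average='macro'` would instead weight every class equally, which matters more when classes are imbalanced.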

In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score

def train(model, optimizer, train_loader, device=None):
    device = device or torch.device("cpu")

    #
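    # Accumulate labels and predictions manually over all batches
    #
    y_all = []
    y_class = []
    
    # iterate over all training batches
    model.train()
    for X, y in tqdm(train_loader, desc="Training..."):

        # move the batch to the target device (the GPU, if available)
        X, y = X.to(device), y.to(device)
        
        # zero gradients from last step
        optimizer.zero_grad()
        
        # run forward pass; the model returns raw logits
        y_pred = model(X)
        
        # compute the loss: cross_entropy applies log_softmax and then nll_loss,
        # whereas nll_loss on its own expects log-probabilities, not logits
        loss = F.cross_entropy(y_pred, y)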
        
        # backpropagate
        loss.backward()
        
        # step the optimiser
        optimizer.step()
        
        # keep hold of target and compute y_class for metrics
        y_all.extend(y.tolist())   
        _, c = torch.max(y_pred, 1)
        y_class.extend(c.tolist())
        
    #
    # Compute f1 on all examples
    #
    return f1_score(y_all, y_class, average='micro')

In [None]:
def validate(model, data_loader, device=None):
    device = device or torch.device("cpu")
    
    #
    # Accumulate labels and predictions manually over all batches
    #
    y_all = []
    y_class = []
        
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(data_loader, desc="Testing..."):
            X, y = X.to(device), y.to(device)
            
            y_pred = model(X).cpu()
            
            # keep hold of target and compute y_class for metrics
            y_all.extend(y.tolist())   
            _, c = torch.max(y_pred, 1)
            y_class.extend(c.tolist())
            
    #
    # Compute f1 on all examples
    #

    return f1_score(y_all, y_class, average='micro')

## Create the Trainable Class

We create a Ray Tune `Trainable` wrapper class, as before.
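
As a reminder of how Tune drives a `Trainable`, here is a toy, standard-library-only simulation: `_setup` runs once per trial, `_train` runs once per iteration and returns a dict of metrics, and a dict passed as `stop` ends the trial as soon as any listed metric reaches its threshold. `ToyTrainable`, `should_stop` and `run_trial` are hypothetical names for this sketch, not Ray's implementation; real Tune adds scheduling, checkpointing and logging on top.

```python
class ToyTrainable:
    """Hypothetical stand-in for a tune.Trainable subclass (not Ray's implementation)."""

    def __init__(self, config):
        self.iteration = 0
        self._setup(config)          # called once per trial

    def _setup(self, config):
        self.gain = config.get("gain", 0.1)
        self.score = 0.0
        self._best_score = 0.0

    def _train(self):
        # pretend one epoch raises the score by `gain`, capped at 1.0
        self.score = min(1.0, self.score + self.gain)
        self._best_score = max(self._best_score, self.score)
        return dict(train_f1_score=self.score,
                    best_train_f1_score=self._best_score)

    def train(self):
        # the driver adds bookkeeping such as the iteration count to each result
        result = self._train()
        self.iteration += 1
        result["training_iteration"] = self.iteration
        return result


def should_stop(result, stop):
    """Dict-style criteria: stop when ANY listed metric reaches its threshold."""
    return any(result[key] >= value for key, value in stop.items())


def run_trial(config, stop):
    trial = ToyTrainable(config)
    while True:
        result = trial.train()
        if should_stop(result, stop):
            return result


full = run_trial({"gain": 0.3}, {"train_f1_score": 0.95, "training_iteration": 200})
smoke = run_trial({"gain": 0.3}, {"train_f1_score": 0.95, "training_iteration": 1})
print(full["training_iteration"], smoke["training_iteration"])  # 4 1
```

This is why the smoke test in the `tune.run` cell sets `training_iteration` to 1: every sampled configuration stops after a single epoch, whatever its score.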


In [None]:
from os import path

class FossilTrainable(tune.Trainable):
    
    def _setup(self, config):
        # detect whether CUDA is available, as Ray will assign GPUs if available and configured
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.train_loader, self.test_loader = get_data_loaders(
            #
            # This path needs to be correct for your local system
            #
            path.expanduser('~/dev/transform-2020-ray-wip/datasets/fossilnet/tvt_split/4'),
            batch_size=int(config.get("batch_size", 16)),
            augment_flip=config.get("augment_flip", True),
            use_grayscale=config.get("use_grayscale", True)
        )
        
        #
        # Create the network
        #
        self.model = FossilResNet(num_outputs=4).to(self.device)
        
        #
        # Setup the optimiser
        #
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            weight_decay=config.get("weight_decay", 1e-5)
        )

        #
        # Use Trainable state to keep track of best scores
        #
        self._best_train_f1_score = 0.
        self._best_val_f1_score = 0.
        
    def _train(self):
        train_f1_score = train(self.model,
                               self.optimizer,
                               self.train_loader,
                               device=self.device)
        
        val_f1_score = validate(self.model,
                                self.test_loader,
                                self.device)
        
        if (train_f1_score > self._best_train_f1_score):
            self._best_train_f1_score = train_f1_score
        
        if (val_f1_score > self._best_val_f1_score):
            self._best_val_f1_score = val_f1_score
        
        #
        # Really we should return losses here too and we
        # are free to extend the return dict with anything we want to track
        #
        return dict(
            train_f1_score=train_f1_score,
            best_train_f1_score=self._best_train_f1_score,
            val_f1_score=val_f1_score,
            best_val_f1_score=self._best_val_f1_score
        )

    def _save(self, checkpoint_dir):
        checkpoint_path = path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path
    
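    def _restore(self, checkpoint_path):
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only worker (and vice versa)
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))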

### Check for Cuda

In [None]:
print('CUDA available' if torch.cuda.is_available() else 'CPU only')

### Start Ray

In [None]:
ray.shutdown()
ray.init(num_cpus=3, num_gpus=1, include_webui=True)
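
It is worth checking how many trials can run in parallel with these settings. Under the simplifying assumption that a trial is only started when every resource it requests is free (ignoring memory, custom resources and multi-node placement), the limit is set by whichever resource runs out first. `max_concurrent_trials` is a helper written for this sketch, not a Ray API:

```python
def max_concurrent_trials(cluster, per_trial):
    """Simplified resource accounting: the scarcest requested resource sets the limit."""
    return min(int(cluster[name] // amount)
               for name, amount in per_trial.items() if amount > 0)

# Values used in this notebook: ray.init(num_cpus=3, num_gpus=1) and
# resources_per_trial={"cpu": 3, "gpu": 1} in the tune.run cell
notebook = max_concurrent_trials({"cpu": 3, "gpu": 1}, {"cpu": 3, "gpu": 1})

# A hypothetical larger machine, for comparison
bigger = max_concurrent_trials({"cpu": 8, "gpu": 2}, {"cpu": 2, "gpu": 1})

print(notebook, bigger)  # 1 2
```

With the notebook's settings the sampled configurations therefore run one at a time. Ray's CPU and GPU counts are logical scheduling units, so they do not limit the DataLoader worker processes a trial starts itself.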

In [None]:
#
# Work around a potential bug in TensorBoard logging of tune.choice() by sampling choices ourselves
# 
def _choice(items):
    return items[np.random.randint(len(items))]


#
# Setup our Parameter Optimisation Space
#
config = dict(
    lr=tune.uniform(1e-3, 1e-1),
    weight_decay=tune.loguniform(1e-7, 1e-3),
    batch_size=tune.sample_from(lambda _: _choice([8, 16, 32, 64])),
    augment_flip=tune.sample_from(lambda _: _choice([True, False])),
    use_grayscale=tune.sample_from(lambda _: _choice([True, False]))
)

#
# Before committing to a big run, do a dry run: sample N (here 10) hyperparameter configurations
# and train each for a single iteration to check that every option works
#
# Then set this to False and tune for real
#
smoke_test = True

analysis = tune.run(
    FossilTrainable,
    local_dir="~/ray_results/torch_fossilnet",
    resources_per_trial={
        "cpu": 3,
        "gpu": 1
    },
    num_samples=10 if smoke_test else 50,
    checkpoint_at_end=True,
    keep_checkpoints_num=5,
    checkpoint_freq=10,
    stop={
        "train_f1_score": 0.95,
        "training_iteration": 1 if smoke_test else 200,
    },
    config=config
)

print("Best config is:", analysis.get_best_config(metric="best_val_f1_score"))
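
To see what the search space actually produces, here is a standard-library sketch of the three kinds of sampling used in `config`. It is not Tune's implementation: `sample_uniform`, `sample_loguniform` and `sample_choice` are hypothetical helpers, and we assume `tune.loguniform` samples uniformly in log space, which is its documented intent. The comparison at the end shows why a log scale suits `weight_decay`, whose range spans four orders of magnitude.

```python
import math
import random

rng = random.Random(42)

def sample_uniform(low, high):
    return rng.uniform(low, high)

def sample_loguniform(low, high):
    """Uniform in log space, so every decade between low and high is equally likely."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

def sample_choice(items):
    return items[rng.randrange(len(items))]

def sample_config():
    # mirrors the search space in the config dict above
    return dict(
        lr=sample_uniform(1e-3, 1e-1),
        weight_decay=sample_loguniform(1e-7, 1e-3),
        batch_size=sample_choice([8, 16, 32, 64]),
        augment_flip=sample_choice([True, False]),
        use_grayscale=sample_choice([True, False]),
    )

configs = [sample_config() for _ in range(10)]   # num_samples=10 in the smoke test

# How often does each strategy produce a weight_decay below 1e-5?
n = 10_000
log_draws = [sample_loguniform(1e-7, 1e-3) for _ in range(n)]
lin_draws = [sample_uniform(1e-7, 1e-3) for _ in range(n)]
frac_log = sum(d < 1e-5 for d in log_draws) / n   # roughly 0.5: 2 of the 4 decades
frac_lin = sum(d < 1e-5 for d in lin_draws) / n   # roughly 0.01
print(frac_log, frac_lin)
```

A plain uniform draw would almost never explore the small end of the range, which is where useful weight decay values often lie.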

In [None]:
import ray
ray.shutdown()


## Next

Head over to EC2 and check the results of a longer run.