<a href="https://colab.research.google.com/github/benihime91/retinanet_pet_detector/blob/master/nbs/04_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Load Google Drive:**

In [None]:
# Run this cell to mount Google Drive at /content/drive
# (the archive, checkpoint and weight paths used later in this notebook all assume this mount point)
from google.colab import drive
drive.mount('/content/drive')

**setup**

In [None]:
# What GPU do we have? (shell command listing the attached NVIDIA GPU(s))
! nvidia-smi

In [None]:
# Ensure colab doesn't disconnect
%%javascript
function ClickConnect(){
console.log("Working");
document.querySelector("colab-toolbar-button#connect").click()
}setInterval(ClickConnect,60000)

In [None]:
# install dependencies
! pip install pytorch-lightning wandb --quiet
! pip install git+https://github.com/albumentations-team/albumentations --quiet

**Before running this cell make sure that the data is present in Google Drive at `My Drive/Data/oxford-iiit-pet.tgz`.  
The data can be downloaded [here](https://www.robots.ox.ac.uk/~vgg/data/pets/).  
Running this cell will extract the data to `/content/oxford-iiit-pet`.**

In [None]:
# unzip the data assuming the `The Oxford-IIIT Pet Dataset` is present as /content/drive/My\ Drive/Data/oxford-iiit-pet.tgz
# to download the dataset go to this link:
# https://www.robots.ox.ac.uk/~vgg/data/pets/
!tar xf /content/drive/My\ Drive/Data/oxford-iiit-pet.tgz -C /content/ 

**Clone the retinanet repo:**

In [None]:
# Clone the RetinaNet repo: provides the `pytorch_retinanet` package imported further below
! git clone https://github.com/benihime91/pytorch_retinanet.git

**Log in to wandb:**  
**Run the cell below only if you use `wandb` to track logs; enter your API key when prompted.**

In [None]:
# use wandb to track experiments : Comment this if not using wandb logger
! wanbd login # a74f67fd5fae293e301ea8b6710ee0241f595a63

**required imports:**

In [None]:
import warnings
import os
import sys
from typing import *
import time
import argparse

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import pandas as pd
import numpy as np
import re

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# PyTorch Imports
import torch
from torch import nn, optim
# explicit names instead of `from torch.optim import *`
from torch.optim import SGD, lr_scheduler
from torch.utils.data import Dataset, DataLoader

# Lightning import
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint, LearningRateLogger,)

# Import some useful utilities from the RetinaNet Repo:
from pytorch_retinanet.src.models import Retinanet
from pytorch_retinanet.src.utils.eval_utils import CocoEvaluator
from pytorch_retinanet.src.utils.eval_utils import get_coco_api_from_dataset
from pytorch_retinanet.src.utils.general_utils import collate_fn, xml_to_csv
from pytorch_retinanet import DetectionDataset, Visualizer

pl.seed_everything(42)  # change this seed number to get different results
pd.set_option("display.max_colwidth", None)  # show full file paths in DataFrame output

**Preprocess the data:** 

In [None]:
# Locations inside the extracted dataset
annot_dir = '/content/oxford-iiit-pet/annotations/xmls'  # xml annotation files
img_dir = '/content/oxford-iiit-pet/images'  # the training images

# Parse all xml annotations into one pandas DataFrame and peek at it
df = xml_to_csv(annot_dir)
df.head(n=5)

In [None]:
# regex extracting the class name from an image path,
# e.g. ".../american_bulldog_12.jpg" -> "american_bulldog" (the dot before "jpg" is escaped)
pat = re.compile(r"/([^/]+)_\d+\.jpg$")


def get_classes(df : pd.DataFrame) -> pd.DataFrame:
    """Add a lower-cased `class` column parsed from every `filename` in `df`.

    `df` is modified in place and also returned.

    Raises:
        ValueError: if a filename does not look like `<dir>/<class>_<number>.jpg`.
    """
    classes = []
    for fname in df.filename:
        match = pat.search(fname)
        if match is None:
            # fail with the offending filename instead of "'NoneType' object has no attribute 'group'"
            raise ValueError(f"Cannot extract a class name from filename {fname!r}")
        classes.append(match.group(1).lower())
    df["class"] = classes
    return df


def preprare_data(img_dir: str, data: Union[str, pd.DataFrame]) -> Tuple[pd.DataFrame, LabelEncoder]:
    """Prepare the annotation frame and integer-encode its class names.

    Args:
        img_dir: folder holding the images; it is joined onto every `filename`.
        data: path to a csv file, or an already-loaded DataFrame. A DataFrame argument
            is copied, so the caller's frame is left unmodified.

    Returns:
        (df, le): the processed frame, which has image paths in `filename`, a `class`
        column and an integer `labels` column. Also the fitted `LabelEncoder`, which
        maps labels back to class names.
    """
    # work on a copy so the caller's frame is not mutated
    df = pd.read_csv(data) if isinstance(data, str) else data.copy()
    # modify filename to point to the image path
    df["filename"] = [os.path.join(img_dir, fname) for fname in df.filename.values]
    # get labels from the filename
    df = get_classes(df)
    # encode the labels: convert class names to integers
    le = LabelEncoder()
    df["labels"] = le.fit_transform(df["class"])
    return df, le


def create_label_dict(dataframe: pd.DataFrame, encoder: LabelEncoder) -> Dict[int, str]:
    """Map each encoded label present in `dataframe.labels` back to its class name.

    Keys are the distinct labels in ascending order. Values come from
    `encoder.inverse_transform` applied to those labels.
    """
    encoded = sorted(dataframe.labels.unique())
    decoded = list(encoder.inverse_transform(encoded))
    return dict(zip(encoded, decoded))


In [None]:
# Absolute image paths, parsed class names and integer labels; `le` maps labels back to names
df, le = preprare_data(img_dir=img_dir, data=df)
df.head()

In [None]:
# Grab the label dictionary: {integer label -> class name}
label_dict = create_label_dict(dataframe=df, encoder=le)
label_dict

**utility function to display image with bounding boxes:**

In [None]:
# Instantiate the visualizer
viz = Visualizer(class_names=label_dict)

# Function to display a random Image from the dataset
def display_random_image(dataframe: pd.DataFrame) -> None:
    """Draw the ground-truth boxes of one randomly chosen image from `dataframe`.

    Args:
        dataframe: annotation frame with `filename`, `xmin`, `ymin`, `xmax`, `ymax`
            and `labels` columns; several rows may share one filename.
    """
    n = np.random.randint(0, len(dataframe))
    # use the frame that was passed in (not the global `df`), with a positional lookup
    # so any index works, not only a RangeIndex
    fname = dataframe["filename"].iloc[n]
    rows = dataframe.loc[dataframe["filename"] == fname]
    boxes = rows[["xmin", "ymin", "xmax", "ymax"]].values
    labels = rows["labels"].values
    viz.draw_bboxes(fname, boxes=boxes, classes=labels, figsize=(10, 10))

**Display image from the data:**

In [None]:
# Sanity check: draw one random annotated image from the full dataset
display_random_image(dataframe=df)

In [None]:
# Helper function to split a given DataFrame
def create_splits(df: pd.DataFrame, split_sz: float = 0.3) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split `df` into two frames by image, so all boxes of one image land in the same split.

    Args:
        df: annotation frame with a `filename` column (one or more rows per image).
        split_sz: fraction of the unique images assigned to the second frame.

    Returns:
        (df_first, df_second): rows whose image went to the first / second split. Each
        frame keeps the row order and gets a fresh RangeIndex. `df` itself is not modified.
    """
    # Split the unique image paths (not rows) so boxes of one image never straddle splits
    unique_ids = list(df.filename.unique())
    train_ids, val_ids = train_test_split(
        unique_ids, shuffle=True, random_state=42, test_size=split_sz
    )

    # Vectorised membership masks. This avoids a per-row loop that rebuilt a set on every
    # iteration, and avoids chained assignment (which silently does nothing under pandas
    # copy-on-write). It also avoids adding a helper "split" column to the caller's frame.
    df_trn = df.loc[df.filename.isin(train_ids)].reset_index(drop=True)
    df_val = df.loc[df.filename.isin(val_ids)].reset_index(drop=True)
    return df_trn, df_val

**Create splits in the DataFrame to get `train`, `validation` & `test` sets:**

In [None]:
# Create splits by image: 70% train; the remaining 30% is halved into validation and test
df_trn, df_val = create_splits(df, split_sz=0.3)
df_val, df_test = create_splits(df_val, split_sz=0.5)

# sanity check: no image may appear in more than one split
trn_imgs, val_imgs, test_imgs = set(df_trn.filename), set(df_val.filename), set(df_test.filename)
assert not (trn_imgs & val_imgs or trn_imgs & test_imgs or val_imgs & test_imgs), "splits overlap"

print('Num examples in train dataset :', len(trn_imgs))
print('Num examples in validation dataset :', len(val_imgs))
print('Num examples in test dataset :', len(test_imgs))

In [None]:
# Sanity check: first rows of the train split
df_trn.head(n=3)

In [None]:
# Sanity check: first rows of the validation split
df_val.head(n=3)

In [None]:
# Sanity check: first rows of the test split
df_test.head(n=3)

**sanity check**:

In [None]:
# Sanity check: one random annotated image from each split (train here; valid & test below)
display_random_image(df_trn)

In [None]:
display_random_image(df_val)  # random image from the validation split

In [None]:
display_random_image(df_test)  # random image from the test split

**Instantiate image transformations:**

We use `albumentations` for image transformations. Check [albumentations docs](https://albumentations.ai/docs/examples/example_bboxes/) for API reference & list of transformations

In [None]:
def get_tfms() -> Dict[str, A.Compose]:
    """Build the albumentations pipelines for the "train", "valid" and "test" datasets.

    Every pipeline ends by scaling pixel values by 1/255 to floats and converting the
    image to a tensor. The train pipeline first applies random augmentations. Boxes are
    expected in `pascal_voc` format, with their labels passed as `class_labels`.
    """
    def _to_tensor_steps():
        # shared tail of every pipeline: uint8 -> float (divided by 255) -> tensor
        return [A.ToFloat(max_value=255.0, always_apply=True), ToTensorV2(always_apply=True)]

    def _compose(steps):
        return A.Compose(steps, bbox_params=A.BboxParams(format="pascal_voc", label_fields=["class_labels"]),)

    # random augmentations for the training images only [modify this to add transformations]
    train_augs = [
        A.HorizontalFlip(p=0.5),
        A.ToGray(p=0.2),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.5),
    ]

    return {
        "train": _compose(train_augs + _to_tensor_steps()),
        "valid": _compose(_to_tensor_steps()),
        "test" : _compose(_to_tensor_steps()),
    }

**Create `pl.LightningModule` instance :**

In [None]:
# Create pl.LightningModule instance

# ========
# INFO :
# ========
# The hparams config (an argparse.Namespace) should contain the following :
# ========
# 1. optimizer : torch.optim.Optimizer -> Optimizer for the model
# 2. scheduler : Union[torch.optim.lr_scheduler, dict, None] -> Scheduler for the Optimizer
#                (a Lightning scheduler dict as built in the config cell, or None)

# 3. trn_df    : pandas.DataFrame -> train dataframe
# 4. trn_tfms  : A.Compose -> albumentation transformations to apply to the training dataset
# 5. trn_bs    : int -> train batch_size

# 6. val_df    : pandas.DataFrame -> validation dataframe
# 7. val_tfms  : A.Compose -> albumentation transformations to apply to the validation dataset
# 8. val_bs    : int -> validation batch_size

# 9.  test_df  : pandas.DataFrame -> test dataframe
# 10. test_tfms: A.Compose -> albumentation transformations to apply to the test dataset
# 11. test_bs  : int -> test batch_size

# 12. iou_types: List -> for coco evaluation set it to ["bbox"].

class DetectionModel(pl.LightningModule):
    """Lightning wrapper around a detection `nn.Module` (e.g. RetinaNet).

    The dataloaders are built from the DataFrames in `hparams`. Training minimises the sum
    of the loss dict returned by the model, and validation/test report COCO bbox mAP.
    """

    def __init__(self,model: nn.Module, hparams: argparse.Namespace) -> None:
        super(DetectionModel, self).__init__()
        self.model = model
        self.hparams = hparams

    @property
    def num_batches(self) -> List:
        "returns a list containing the number of batches in train, val & test dataloaders"
        # NB: building the val/test dataloaders also (re)creates their COCO evaluators
        return [len(self.train_dataloader()), len(self.val_dataloader()), len(self.test_dataloader())]

    ######## Configure Optimizer & Schedulers #############
    def configure_optimizers(self, *args, **kwargs):
        "return the optimizer (and the scheduler, when one is configured) from `hparams`"
        optimizer = self.hparams.optimizer
        scheduler = self.hparams.scheduler
        if scheduler is not None:
            return [optimizer], [scheduler]
        return [optimizer]

    ############# Forward Pass of the Model ##############
    def forward(self, xb, *args, **kwargs):
        "forward step: delegates to the wrapped model"
        return self.model(xb)

    ############# Helpers shared by validation & test ##############
    def _coco_eval_step(self, batch, evaluator) -> dict:
        "run the model on `batch` and feed its outputs, keyed by image id, to `evaluator`"
        images, targets, _ = batch
        targets = [{k: v for k, v in t.items()} for t in targets]
        outputs = self.model(images, targets)
        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
        evaluator.update(res)
        return {}

    @staticmethod
    def _coco_summary(evaluator) -> dict:
        "accumulate & print the COCO stats, and package bbox `stats[0]` as `mAP` for Lightning"
        evaluator.accumulate()
        evaluator.summarize()
        metric = torch.as_tensor(evaluator.coco_eval["bbox"].stats[0])
        logs = {"mAP": metric}
        return {"mAP": metric, "log": logs, "progress_bar": logs,}

    ############# Train ##############
    def train_dataloader(self, *args, **kwargs):
        "instantiate train dataloader"
        trn_ds = DetectionDataset(self.hparams.trn_df, self.hparams.trn_tfms)
        trn_dl = DataLoader(
            trn_ds, batch_size=self.hparams.trn_bs, shuffle=True, collate_fn=collate_fn, pin_memory=True,
            )
        return trn_dl

    def training_step(self, batch, batch_idx, *args, **kwargs):
        "one training step: the loss is the sum of the model's loss dict"
        images, targets, _ = batch
        targets = [{k: v for k, v in t.items()} for t in targets]
        loss_dict = self.model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        return {"loss": losses, "log": loss_dict, "progress_bar": loss_dict}

    ############# Validation ##############
    def val_dataloader(self, *args, **kwargs):
        "instantiate validation dataloader and its COCO evaluator"
        val_ds = DetectionDataset(self.hparams.val_df, self.hparams.val_tfms)
        loader = DataLoader(val_ds, batch_size=self.hparams.val_bs, shuffle=False, collate_fn=collate_fn,)
        # ground-truth coco api, kept so a fresh evaluator can be built after every validation run
        self.val_coco = get_coco_api_from_dataset(loader.dataset)
        self.coco_evaluator = CocoEvaluator(self.val_coco, self.hparams.iou_types)
        return loader

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        "one validation step"
        return self._coco_eval_step(batch, self.coco_evaluator)

    def validation_epoch_end(self, outputs, *args, **kwargs):
        "report validation mAP, then reset the evaluator for the next validation run"
        result = self._coco_summary(self.coco_evaluator)
        # NOTE(review): the dataloader (and thus the evaluator) may be built only once per fit.
        # A fresh evaluator stops the next epoch's updates (and the sanity-check batches) from
        # landing in the same evaluator. Assumes CocoEvaluator keeps state across `update` calls.
        self.coco_evaluator = CocoEvaluator(self.val_coco, self.hparams.iou_types)
        return result

    ############# Test ##############
    def test_dataloader(self, *args, **kwargs):
        "instantiate test dataloader and its COCO evaluator"
        test_ds = DetectionDataset(self.hparams.test_df, self.hparams.test_tfms)
        loader = DataLoader(test_ds, batch_size=self.hparams.test_bs, shuffle=False, collate_fn=collate_fn,)
        # instantiate coco_api to track metrics
        coco = get_coco_api_from_dataset(loader.dataset)
        self.test_evaluator = CocoEvaluator(coco, self.hparams.iou_types)
        return loader

    def test_step(self, batch, batch_idx, *args, **kwargs):
        "one test step"
        return self._coco_eval_step(batch, self.test_evaluator)

    def test_epoch_end(self, outputs, *args, **kwargs):
        "report test mAP"
        return self._coco_summary(self.test_evaluator)

**Specify configs:**

**Configs for `DetectionModel` :**


In [None]:
# Specify parameters for the DetectionModel:

# load in the RetinaNet model
NUM_CLASSES = 37  # Oxford-IIIT Pets Dataset has 37 classes
# sanity check: the label encoding must produce exactly NUM_CLASSES classes
assert NUM_CLASSES == len(label_dict), f"NUM_CLASSES={NUM_CLASSES} but {len(label_dict)} labels were encoded"
BACKBONE = 'resnet18'  # backbone for RetinaNet Model
model = Retinanet(num_classes=NUM_CLASSES, backbone_kind=BACKBONE)

# instantiate optimizer
LR = 1e-03  # learning rate for Optimizer
MOMENTUM = 0.9  # Momentum for the Optimizer
WEIGHT_DECAY = 0.0001  # Weight Decay for Optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, LR, weight_decay=WEIGHT_DECAY, momentum=MOMENTUM)

# Instantiate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
# convert scheduler to lightning format
INTERVAL = "step"  # step the scheduler after every optimizer 'step' (or once per 'epoch')
scheduler = {"scheduler": scheduler, "interval": INTERVAL, "frequency": 1,}

# Instantiate Transforms:
transforms = get_tfms()

# Train dataset parameters:
trn_df = df_trn  # dataframe
trn_tfms = transforms['train']  # transformations for the train dataset
trn_bs = 32  # batch_size

# Valid dataset parameters:
val_df = df_val  # dataframe
val_tfms = transforms['valid']  # transformations for the validation dataset
val_bs = 32  # batch_size

# Test dataset parameters:
test_df = df_test  # dataframe
test_tfms = transforms['test']  # transformations for the test dataset
test_bs = 32  # batch_size

# set iou types:
iou_types = ['bbox']


# Create arguments: each split uses its own dataframe, transforms and batch size
hparams = {
    'optimizer': optimizer,
    'scheduler': scheduler,
    'trn_df': trn_df,
    'trn_tfms': trn_tfms,
    'trn_bs': trn_bs,
    'val_df': val_df,
    'val_tfms': val_tfms,
    'val_bs': val_bs,
    'test_df': test_df,
    'test_tfms': test_tfms,
    'test_bs': test_bs,
    'iou_types': iou_types,
}

# Convert dictionary to args
hparams = argparse.Namespace(**hparams)

In [None]:
# sanity check: show every hyper-parameter handed to the DetectionModel
print(f"{hparams}")

**Configs for `LightningTrainer`:**

In [None]:
# Create configs for the Lightning trainer

# Wandb logger: assuming wandb is set-up [Optional]
wb_name = f"{time.strftime('%d-%m-||-%I.%M.%S%-p')}"  # change the run name here
wb_p = "retinanet-oxford-pets"  # change the project name here
wb_logger = WandbLogger(name=wb_name, project=wb_p, anonymous="allow",)

# learning-rate logger: logs the learning rate at every optimizer step
lr_logger = LearningRateLogger(logging_interval="step")

# set callbacks & loggers:
logger = [wb_logger]
callbacks = [lr_logger]

# checkpoint callback: keep only the best checkpoint according to validation mAP
ckpt_dir = "/content/drive/My Drive/pascal_checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
checkpoint_callback = ModelCheckpoint(ckpt_dir, mode="max", monitor="mAP", save_top_k=1,)

# early stopping callback: stop when validation mAP stops improving
early_stop_callback = EarlyStopping(mode="max", monitor="mAP", patience=5,)

NUM_EPOCHS = 20  # total number of training epochs; tune as needed

check_val_every_n_epoch = 5  # Validation check interval (in epochs)
gpus = 1  # gpus to use
precision = 16  # precision
max_epochs = NUM_EPOCHS  # Total number of Epochs

# Convert trainer flags into a dictionary
trainer_config = {
    'logger': logger,
    'callbacks': callbacks,
    'checkpoint_callback': checkpoint_callback,
    'early_stop_callback': early_stop_callback,
    'gpus': gpus,
    'precision': precision,
    'max_epochs': max_epochs,
    'check_val_every_n_epoch': check_val_every_n_epoch,
}

# Convert dictionary to args
trainer_config = argparse.Namespace(**trainer_config)

In [None]:
print(f"{trainer_config}")  # sanity check: flags passed to pl.Trainer

**Grab the model & the trainer:**

In [None]:
retinanet = DetectionModel(model, hparams)

# pl.Trainer takes its settings as keyword arguments. Passing the Namespace positionally
# would bind the whole object to a single parameter, so unpack it instead.
trainer = pl.Trainer(**vars(trainer_config))

**Train model:**

In [None]:
# Train (runs validation every `check_val_every_n_epoch` epochs)
trainer.fit(retinanet)

**Evaluate:**

In [None]:
# Test model on the Test DataLoader
# NB: with no model argument, Lightning evaluates the best checkpoint
# (NOTE(review): behaviour of the pinned Lightning version; confirm if upgrading)
trainer.test()

**save trained weights**

In [None]:
# Save the trained network's weights (legacy, non-zip serialization format)
fname = '/content/drive/My Drive/resnet18-pets-ver0.0.1.pth'
torch.save(retinanet.model.state_dict(), f=fname, _use_new_zipfile_serialization=False)

**Finetune (Optional):** 

**set-up finetune parameters:**

In [None]:
# Set up new parameters for fine-tuning
LR = 1e-05  # max learning rate of the one-cycle schedule
NUM_EPOCHS = 22

# unfreeze every parameter of the trained network
retinanet.model.requires_grad_(True)
params = [p for p in retinanet.model.parameters() if p.requires_grad]
# Instantiate Optimizer
optimizer = torch.optim.AdamW(params, weight_decay=1e-02)

# OneCycleLR needs the number of optimizer steps per epoch, i.e. the number of train batches
steps_per_epoch = len(retinanet.train_dataloader())

# instantiate scheduler (stepped after every optimizer step)
scheduler = {
    "scheduler": torch.optim.lr_scheduler.OneCycleLR(optimizer, LR, epochs=NUM_EPOCHS, steps_per_epoch=steps_per_epoch),
    "interval": "step",
    "frequency": 1,
    }


**Update the configuration:**

In [None]:
# Point the training configuration (`hparams`) at the new optimizer & scheduler
hparams.optimizer = optimizer
hparams.scheduler = scheduler

**Update trainer & model:**

In [None]:
# TODO: make it more effective

# Reinstantiate model: reuse the trained network and reload the weights saved after the first run
retinanet_2 = DetectionModel(retinanet.model, hparams)
retinanet_2.model.load_state_dict(torch.load('/content/drive/My Drive/resnet18-pets-ver0.0.1.pth'))

# Reinstantiate trainer with fresh callbacks, so the checkpoint / early-stopping state from the
# first run does not carry over; gradients are clipped during fine-tuning
finetune_ckpt_dir = "/content/drive/My Drive/pascal_checkpoints/finetune"
os.makedirs(finetune_ckpt_dir, exist_ok=True)
trainer_2 = pl.Trainer(
    logger=[wb_logger],
    callbacks=[LearningRateLogger(logging_interval="step")],
    checkpoint_callback=ModelCheckpoint(finetune_ckpt_dir, mode="max", monitor="mAP", save_top_k=1,),
    early_stop_callback=EarlyStopping(mode="max", monitor="mAP", patience=5,),
    check_val_every_n_epoch=5,
    gpus=1,
    precision=16,
    gradient_clip_val=0.1,
    max_epochs=NUM_EPOCHS,
)

**Train**:

In [None]:
# Fine-tune the whole network with the new optimizer & scheduler
trainer_2.fit(retinanet_2)

**Evaluate on test data**:

In [None]:
# Evaluate the fine-tuned model on the test dataloader
trainer_2.test()

**save model weights:**

In [None]:
# Save the fine-tuned network. retinanet_2 is referenced explicitly so this cell stays correct
# even if retinanet_2 is later built from a copy rather than the shared `retinanet.model`.
fname = '/content/drive/My Drive/resnet18-pets-ver0.0.2.pth'
torch.save(retinanet_2.model.state_dict(), fname, _use_new_zipfile_serialization=False)

**Inference**: 

**Load in a `torch` model to do inference:**  

**Model weights can be loaded in two ways. To load the weights trained above, set `fname` to the path where the `state_dict` was saved.**

In [None]:
# Instantiate Torch Model for Inference
m = Retinanet(num_classes=37, backbone_kind='resnet18')

# Path of a saved `state_dict` (e.g. one of the .pth files written above).
# Leave it as None to load released weights from a url with the next cell instead.
fname = None
if fname is not None:
    # map_location="cpu" lets weights saved from a GPU model load on a CPU-only runtime
    m.load_state_dict(torch.load(fname, map_location="cpu"))
else:
    print("[INFO] `fname` is None: no local weights loaded; load weights from a url in the next cell.")

**Alternatively, load already-trained weights, which are available [here](https://github.com/benihime91/retinanet_pet_detector/releases).**  

**Uncomment the cell given below.**  
**Copy the URL of the weights file and assign it to `url`.**  
**If using a GPU set `device` to `cuda`, otherwise set it to `cpu`.**

In [None]:
# url = None
# device = 'cuda'  # 'cuda' when using a GPU, otherwise 'cpu' ('gpu' is not a valid torch device)

# # load model state_dict from url
# state_dict = torch.hub.load_state_dict_from_url(url, map_location=device)
# m.load_state_dict(state_dict)

**import & helper functions for inference:**

In [None]:
from google.colab import files
from torchvision.ops.boxes import batched_nms

# run on the first GPU when one is available, otherwise on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# inference-time transforms: scale to float and convert to a tensor, no augmentation
test_tfms = A.Compose([
    A.ToFloat(max_value=255.0, always_apply=True),
    ToTensorV2(always_apply=True),
])

In [None]:
@torch.no_grad()
def get_preds(
    model: Union[nn.Module, pl.LightningModule],
    path: str,
    threshold: float,
    iou_threshold: float,
    device: torch.device,
) -> Tuple[List, List, List]:
    """Run `model` on the image at `path` and return its filtered detections.

    The image is preprocessed with the global `test_tfms`. Detections scoring at or below
    `threshold` are dropped, then `batched_nms` with `iou_threshold` is applied.

    Returns:
        (boxes, classes, scores) as lists built from the kept numpy arrays.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from `path`.
    """
    model.to(device)

    # Load the image; cv2.imread returns None (rather than raising) when it cannot read `path`
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read an image from {path!r}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Process the image
    img = test_tfms(image=img)["image"].to(device)

    # Generate predictions for a batch of one image
    model.eval()
    pred = model([img])[0]

    pred_boxes, pred_class, pred_score = pred["boxes"], pred["labels"], pred["scores"]
    # keep only detections scoring above the threshold
    mask = pred_score > threshold
    boxes, clas, scores = pred_boxes[mask], pred_class[mask], pred_score[mask]

    # NMS run separately per predicted label (`clas` is the grouping index)
    keep_idxs = batched_nms(boxes, scores, clas, iou_threshold)
    return (
        list(boxes[keep_idxs].cpu().numpy()),
        list(clas[keep_idxs].cpu().numpy()),
        list(scores[keep_idxs].cpu().numpy()),
    )


def object_detection_api(
    model: Union[nn.Module, pl.LightningModule],
    device: torch.device,
    img_path: str = None,
    score_threshold: float = 0.5,
    iou_threshold: float = 0.2,
) -> None:
    """Predict on one image and draw the resulting boxes with the global `viz`.

    When `img_path` is None, the user is asked to upload a file and the first uploaded
    file is used.
    """
    if img_path is None:
        uploaded_names = list(files.upload().keys())
        img_path = uploaded_names[0]

    print("[INFO] Generating predictions ....")
    boxes, labels, scores = get_preds(model, img_path, score_threshold, iou_threshold, device)
    print("[INFO] Creating bbox on the image .... ")
    # NOTE(review): draw_bboxes receives an RGB array here but a file path in
    # display_random_image; confirm the Visualizer accepts both.
    rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    viz.draw_bboxes(rgb, boxes=boxes, classes=labels, scores=scores)

**inference on test images:**

In [None]:
idx = 10  # row index of the test image to run inference on
object_detection_api(m, device, img_path=df_test["filename"][idx], score_threshold=0.7, iou_threshold=0.2)

**inference on user given images:**

In [None]:
# Inference on user images: with no `img_path`, the user is asked to upload one
object_detection_api(m, device, score_threshold=0.7, iou_threshold=0.2)