<a href="https://colab.research.google.com/github/benihime91/pytorch_retinanet/blob/master/001_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Introduction**

In this notebook, we implement [PyTorch RetinaNet](https://github.com/benihime91) for a custom dataset.

We will take the following steps to implement PyTorch RetinaNet on our custom data:
* Install PyTorch RetinaNet along with the required dependencies.
* Download the custom dataset.
* Write the training configuration YAML file.
* Train the detection model.
* Use the trained PyTorch RetinaNet object detector for inference on test images.


### **Setting up Colab :**

In [None]:
# What GPU do we have ?
!nvidia-smi

In [None]:
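%%javascript
// Ensure Colab doesn't disconnect: click the connect button every 60 seconds.
// (A cell magic such as %%javascript has to be the first line of the cell.)
function ClickConnect(){
console.log("Working");
document.querySelector("colab-toolbar-button#connect").click()
}
setInterval(ClickConnect,60000)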

# **Install PyTorch RetinaNet and Dependencies** :

In [None]:
# Clone the RetinaNet Repo
!git clone https://github.com/benihime91/pytorch_retinanet.git
# install dependencies
!pip install pytorch-lightning omegaconf wandb --quiet
!pip install git+https://github.com/albumentations-team/albumentations --quiet
!echo "[   OK   ] Installed all dependencies "

In [None]:
#Update sys path to include the PyTorch RetinaNet modules
import warnings
import os
import sys

warnings.filterwarnings('ignore')
sys.path.append("/content/pytorch_retinanet/")
%load_ext autoreload
%autoreload 2
%matplotlib inline

!echo "[   OK   ] Setup Done "

# **Prepare PyTorch RetinaNet Object Detection Training Data**

We will use the **[BCCD Dataset](https://public.roboflow.com/object-detection/bccd)** from Roboflow. There are 364 images across three classes.


To train on a custom dataset, the data needs to be in either **csv** or **pascal-voc** format. Roboflow makes it easy to generate datasets, and we can download them directly in the required format.

We will download the dataset in **Pascal-VOC** format and then use the built-in methods available in PyTorch RetinaNet to convert our data into **csv** format.
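
To make the conversion step concrete, here is a minimal standard-library sketch of what flattening a Pascal VOC annotation into csv-style rows can look like. It is not the repository's `convert_annotations_to_df`; the function name `voc_xml_to_rows`, the sample XML, and the `class` column name are assumptions made for illustration, and the real helper may use a different column layout.

```python
import xml.etree.ElementTree as ET

# Hypothetical Pascal VOC annotation, inlined so the sketch is self-contained.
SAMPLE_XML = """<annotation>
    <filename>cell_0001.jpg</filename>
    <size><width>640</width><height>480</height><depth>3</depth></size>
    <object>
        <name>RBC</name>
        <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>140</ymax></bndbox>
    </object>
    <object>
        <name>WBC</name>
        <bndbox><xmin>200</xmin><ymin>50</ymin><xmax>330</xmax><ymax>190</ymax></bndbox>
    </object>
</annotation>"""


def voc_xml_to_rows(xml_text):
    """Return one dict per <object> found in a Pascal VOC annotation string."""
    root = ET.fromstring(xml_text)
    filename = root.findtext("filename")
    rows = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        row = {"filename": filename, "class": obj.findtext("name")}
        # Coordinates are stored as text; go through float() so that
        # values such as "10.0" are accepted as well.
        for key in ("xmin", "ymin", "xmax", "ymax"):
            row[key] = int(float(box.findtext(key)))
        rows.append(row)
    return rows


rows = voc_xml_to_rows(SAMPLE_XML)
for row in rows:
    print(row)
```

Each row holds one bounding box, so an image with several objects contributes several rows that share the same `filename`.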

In [None]:
#Downloading data from Roboflow
#UPDATE THIS LINK - get our data from Roboflow
%cd /content
!curl -L "https://public.roboflow.com/ds/vYOavmoHyM?key=byxE8x2t11" > roboflow.zip; unzip -qq roboflow.zip; rm roboflow.zip

In [None]:
#Set up paths 

#Path to where the Images are stored
TRAIN_IMAGE_PATH = "/content/train/"
VALID_IMAGE_PATH = "/content/valid/"
TEST_IMAGE_PATH  = "/content/test/"
#Path to where annotations are stored
TRAIN_ANNOT_PATH = "/content/train/"
VALID_ANNOT_PATH = "/content/valid/"
TEST_ANNOT_PATH  = "/content/test/"

In [None]:
import pandas as pd
from utils.pascal import convert_annotations_to_df
from PIL import Image
import cv2
import numpy as np

pd.set_option("display.max_colwidth", None)
np.random.seed(123)

## **Generate csv file from XML Annotations:**

In [None]:
#convert xml files to pandas DataFrames
train_df = convert_annotations_to_df(TRAIN_ANNOT_PATH, TRAIN_IMAGE_PATH, image_set="train")
valid_df = convert_annotations_to_df(VALID_ANNOT_PATH, VALID_IMAGE_PATH, image_set="test")
test_df  = convert_annotations_to_df(TEST_ANNOT_PATH, TEST_IMAGE_PATH, image_set="test")

!echo "[   DONE  ] DataFrame(s) Generated !"


def remove_invalid_annots(df):
    """
    Removes annotations where xmax <= xmin or ymax <= ymin
    from the given dataframe
    """
    df = df[df.xmax > df.xmin]
    df = df[df.ymax > df.ymin]
    # return a re-indexed copy instead of modifying the filtered slice in place
    return df.reset_index(drop=True)

# removing annotations that are not valid annotations
!echo "[   INFO  ] Removing invalid annotations ... "
train_df = remove_invalid_annots(train_df)
valid_df = remove_invalid_annots(valid_df)
test_df = remove_invalid_annots(test_df)
!echo "[   DONE  ] OK !"

### **CSV Files are as follows :**

In [None]:
train_df.head()

In [None]:
valid_df.head()

In [None]:
test_df.head()

In [None]:
#Paths where to save the generated dataframes
TRAIN_CSV = "/content/train_data.csv"
VALID_CSV = "/content/valid_data.csv"
TEST_CSV = "/content/test_data.csv"

#Save the dataframes to disk
train_df.to_csv(TRAIN_CSV, index=False)
valid_df.to_csv(VALID_CSV, index=False)
test_df.to_csv(TEST_CSV, index=False)

In [None]:
train_df = pd.read_csv(TRAIN_CSV)
valid_df = pd.read_csv(VALID_CSV)
test_df = pd.read_csv(TEST_CSV)

# **View Images from the Dataset** :

We can use the function `visualize_boxes_and_labels_on_image_array` from the RetinaNet repo to draw the bounding boxes over the images. To use this function we first need to create a Label Map: a list of class names in which each name sits at the index equal to its integer label.
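
As a rough illustration of such a list, the sketch below builds one from class names. It is not the repository's `generate_pascal_category_names`: the helper name `build_label_map`, the alphabetical ordering, and the placement of `"__background__"` at index 0 are assumptions, and the class names are illustrative.

```python
def build_label_map(class_names):
    """Return a list whose index is the integer label of each class name.

    Index 0 is reserved for "__background__" (an assumption of this sketch),
    and the remaining names are sorted so the mapping is reproducible.
    """
    unique_names = sorted(set(class_names))
    return ["__background__"] + unique_names


# Illustrative class names, repeated as they would be across annotation rows.
label_map = build_label_map(["WBC", "RBC", "Platelets", "RBC"])

# The inverse lookup (name -> integer label) follows directly from the list.
name_to_id = {name: idx for idx, name in enumerate(label_map)}

print(label_map)
print(name_to_id)
```

With a list of this shape, `len(label_map) - 1` gives the number of object classes excluding the background entry.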

## **Let's now generate the Label Map which is used for visualization:**

In [None]:
from utils.pascal import generate_pascal_category_names

LABEL_MAP = generate_pascal_category_names(train_df)
LABEL_MAP

## **Plot images with Bounding boxes over them**:

In [None]:
from utils import visualize_boxes_and_labels_on_image_array
import matplotlib.pyplot as plt

In [None]:
def grab_bbs_(dataframe, index: int):
    """
    Takes in a pandas DataFrame and a row index.
    Returns the filename of the image at that row, together with all the
    bounding boxes and class labels that belong to the same image.
    """
    # valid row indices run from 0 to len(dataframe) - 1
    assert 0 <= index < len(dataframe), f"[  ERROR  ] Invalid index for dataframe with len: {len(dataframe)}"
    fname = dataframe.filename[index]
    locs = dataframe.loc[dataframe.filename == fname]
    bbs  = locs[["xmin", "ymin", "xmax", "ymax"]].values
    cls  = locs["labels"].values
    return fname, bbs, cls

### **Image from Train Data:**

In [None]:
#grab image path, boxes and class targets
image, boxes, clas = grab_bbs_(train_df, index=0)

#load and normalize the image
image = Image.open(image)
image = np.array(image) / 255.

#draw boxes over the image
image = visualize_boxes_and_labels_on_image_array(
                image=image,
                boxes=boxes, 
                scores=None, 
                classes=clas,
                label_map=LABEL_MAP,
)

# # The images are large and take a long time to
# # render. Alternatively, plot them using
# # matplotlib with a fixed figure size:
# plt.figure(figsize=(15,15))
# plt.axis("off")
# plt.imshow(image);
image

### **Image from Validation data:**

In [None]:
#grab image path, boxes and class targets
image, boxes, clas = grab_bbs_(valid_df, index=50)

#load and normalize the image
image = Image.open(image)
image = np.array(image) / 255.

#draw boxes over the image
image = visualize_boxes_and_labels_on_image_array(
                image=image,
                boxes=boxes, 
                scores=None, 
                classes=clas,
                label_map=LABEL_MAP,
)

# # The images are large and take a long time to
# # render. Alternatively, plot them using
# # matplotlib with a fixed figure size:
# plt.figure(figsize=(15,15))
# plt.axis("off")
# plt.imshow(image);
image

### **Image from Test Data:**

In [None]:
#grab image path, boxes and class targets
image, boxes, clas = grab_bbs_(test_df, index=100)

#load and normalize the image
image = Image.open(image)
image = np.array(image) / 255.

#draw boxes over the image
image = visualize_boxes_and_labels_on_image_array(
                image=image,
                boxes=boxes, 
                scores=None, 
                classes=clas,
                label_map=LABEL_MAP,
)

# # The images are large and take a long time to
# # render. Alternatively, plot them using
# # matplotlib with a fixed figure size:
# plt.figure(figsize=(15,15))
# plt.axis("off")
# plt.imshow(image);
image

# **Configure Custom PyTorch RetinaNet Object Detection Training Configuration** :


In [None]:
!cat /content/pytorch_retinanet/hparams.yaml

In [None]:
from omegaconf import OmegaConf

#load in the hparams.yaml file using OmegaConf
hparams = OmegaConf.load("/content/pytorch_retinanet/hparams.yaml")

# ========================================================================= #
# MODIFICATION OF THE CONFIG FILE TO FIX PATHS AND DATASET-ARGUMENTS :
# ========================================================================= #
#specify kind of data to use
hparams.dataset.kind = "csv"
#Paths to the csv files
hparams.dataset.trn_paths   = TRAIN_CSV
hparams.dataset.valid_paths = VALID_CSV
hparams.dataset.test_paths  = TEST_CSV
# number of classes in dataset excluding the 
# "__background__" class
hparams.model.num_classes = len(LABEL_MAP) - 1

#Changing optimizer parameters
hparams.optimizer = {
    "class_name": "torch.optim.SGD", 
    "params": {
        "lr": 0.003,
        "momentum": 0.9,
        "weight_decay" : 0.001,
        },
    }

print(OmegaConf.to_yaml(hparams))
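
To see what the three optimizer values above control, here is a plain-Python sketch of a single SGD update with momentum and weight decay, following the update rule documented for `torch.optim.SGD` with its default `dampening=0` and `nesterov=False`. The function name `sgd_momentum_step` and the toy numbers are assumptions for illustration; during training the actual update is performed by PyTorch.

```python
def sgd_momentum_step(params, grads, velocities,
                      lr=0.003, momentum=0.9, weight_decay=0.001):
    """Apply one SGD-with-momentum update to lists of scalar parameters."""
    new_params, new_velocities = [], []
    for p, g, v in zip(params, grads, velocities):
        g = g + weight_decay * p        # weight decay adds an L2 term to the gradient
        v = momentum * v + g            # momentum accumulates past gradients
        new_params.append(p - lr * v)   # step against the accumulated direction
        new_velocities.append(v)
    return new_params, new_velocities


# Toy example: one parameter, starting with zero velocity.
params, velocities = sgd_momentum_step([1.0], [0.5], [0.0])
print(params, velocities)

# A second step with the same gradient moves further, because the
# velocity from the first step is carried over.
params2, velocities2 = sgd_momentum_step(params, [0.5], velocities)
print(params2, velocities2)
```

A larger `momentum` keeps more of the previous direction, while `weight_decay` pulls every parameter slightly towards zero on each step.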

# **Instantiate Lightning-Module and Lightning-Trainer**

In [None]:
import time
import pytorch_lightning as pl

from pytorch_lightning import Trainer
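# import only the names used below instead of wildcard imports
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint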
from model import RetinaNetModel, LogCallback

# seed so that results are reproducible
pl.seed_everything(123)

## **Load in the Lightning-Trainer :**

In [None]:
# ============================================================ #
# INSTANTIATE LIGHTNING-TRAINER with CALLBACKS :
# ============================================================ #
# NOTE: 
# For a list of whole trainer specific arguments see : 
# https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

#Logger
#check https://pytorch-lightning.readthedocs.io/en/latest/loggers.html for all supported loggers
#we will use wandb logger

#wandb API-KEY
!wandb login "a74f67fd5fae293e301ea8b6710ee0241f595a63"
#Wandb project name
PNAME = "bloood-cell-roboflow"
LOGGER = WandbLogger(name=f"{time.time()}", anonymous=True, project=PNAME)

#Learning-rate Logger to log the learning-rate to the logger
LR_LOGGER = LearningRateLogger(logging_interval="step")

#Model Checkpoint Callback, this callback will save checkpoints 
#each time our val loss decreases
fname = "/content/checkpoints/"
os.makedirs(fname, exist_ok=True)
CHECKPOINT_CALLBACK = ModelCheckpoint(fname, mode="min", monitor="val_loss", save_top_k=3,)

#callback for early-stopping
EARLY_STOPPING_CALLBACK = EarlyStopping(mode="min", monitor="val_loss", patience=20, verbose=True)

#instantiate LightningTrainer
trainer = Trainer(precision=16, 
                  max_epochs=100,
                  gpus=1, 
                  logger=[LOGGER],
                  early_stop_callback=EARLY_STOPPING_CALLBACK, 
                  checkpoint_callback=CHECKPOINT_CALLBACK,
                  callbacks=[LogCallback(), LR_LOGGER], 
                  weights_summary=None,
                  terminate_on_nan=True, 
                  benchmark=True,
                  );
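
The early-stopping callback above is configured with `mode="min"` and `patience=20`. The sketch below shows the general patience logic in plain Python. It is a simplified illustration, not Lightning's implementation: the class name `PatienceTracker` is hypothetical, and options such as a minimum improvement delta are left out.

```python
class PatienceTracker:
    """Signal a stop once the monitored value has not improved for `patience` checks."""

    def __init__(self, patience=20, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None   # best value seen so far
        self.wait = 0      # consecutive checks without improvement

    def update(self, value):
        """Record a new value and return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best

        if improved:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Toy validation losses; a small patience keeps the example short.
tracker = PatienceTracker(patience=2, mode="min")
val_losses = [0.9, 0.8, 0.85, 0.82, 0.7]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if tracker.update(loss):
        stopped_at = epoch
        break
print(stopped_at, tracker.best)
```

In this toy run the loop stops before the final, lower loss is ever seen, which is the trade-off a small patience makes. A larger value such as 20 tolerates longer plateaus at the cost of more epochs.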

## **Load in the Lightning-Module using the hparams file modified above & Start Training :**

In [None]:
# Instantiate lightning-module
litModel = RetinaNetModel(hparams=hparams)

In [None]:
!echo "[   START   ] START TRAINING ... "

#TIP: if using wandb logger go to wandb dashboard to view live logs
trainer.fit(litModel)

!echo "[    END    ] TRAINING COMPLETE ! "

# **Evaluating the trained model using COCO-API Metrics** :

In [None]:
# Evaluate on the test dataset (or on the validation dataset
# if no test dataset is given) using the COCO API
!echo "[   START   ] START EVALUATION OF MODEL ON TEST IMAGES USING COCO-API ... "
trainer.test(litModel)
!echo "[    END    ] DONE EVALUATING MODEL ON TEST IMAGES USING COCO-API ! "
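
COCO-style detection metrics decide whether a predicted box matches a ground-truth box by their intersection over union (IoU), and report average precision over a range of IoU thresholds. As a small sketch of that building block, the function below computes IoU for two boxes in the `[xmin, ymin, xmax, ymax]` layout used by the csv files in this notebook. The function name `box_iou` is hypothetical, and this is not the COCO API's own implementation.

```python
def box_iou(box_a, box_b):
    """Return the intersection over union of two [xmin, ymin, xmax, ymax] boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Width and height of the overlapping region (zero when the boxes are disjoint).
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection

    return intersection / union if union > 0 else 0.0


identical = box_iou([0, 0, 10, 10], [0, 0, 10, 10])
half_shift = box_iou([0, 0, 10, 10], [5, 0, 15, 10])
disjoint = box_iou([0, 0, 10, 10], [20, 20, 30, 30])
print(identical, half_shift, disjoint)
```

A prediction shifted by half the box width already drops to an IoU of one third, which is why the stricter IoU thresholds in the COCO summary are much harder to satisfy than the 0.5 threshold.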

# **Export the model weights** :

In [None]:
import torch

PATH = "/content/trained_weights.pth"
torch.save(litModel.model.state_dict(), PATH)

# **Load PyTorch Model from the trained Lightning-Module weights :**

In [None]:
from retinanet import Retinanet

state_dict = torch.load(PATH)

MODEL = Retinanet(num_classes=hparams.model.num_classes, backbone_kind=hparams.model.backbone_kind)
MODEL.load_state_dict(state_dict)
MODEL.to("cuda:0");

# **Run Inference on Test Images with Custom PyTorch Object Detector**

In [None]:
from PIL import Image
import numpy as np
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import visualize_boxes_and_labels_on_image_array

@torch.no_grad()
def get_preds(path, threshold=0.6,):
    """
    Generates predictions on the given image from the given path.
    """
    image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    
    INFER_TRANSFORMS = A.Compose([A.ToFloat(max_value=255.0, always_apply=True),
                                  ToTensorV2(always_apply=True)
                                  ])
    
    TENSOR_IMAGE = INFER_TRANSFORMS(image=image)["image"].to("cuda:0")
    PREDICTIONS = MODEL.predict([TENSOR_IMAGE])
    #print(PREDICTIONS[0])
    return PREDICTIONS[0]

def detect(image_path, threshold=0.6):
    """
    Generate detections for the image at the given image path.

    Args:
        image_path: Path to the input image
        threshold: Score threshold used to filter predictions

    Returns: a tuple (boxes, labels, scores) of NumPy arrays holding
             the predictions whose score is above the threshold.
    """
    # Generate predictions for the given image
    preds = get_preds(image_path, threshold)
    # Keep only the predictions that score above the threshold
    boxes, labels, scores = preds["boxes"], preds["labels"], preds["scores"]
    mask = scores > threshold
    boxes = boxes[mask]
    labels = labels[mask]
    scores = scores[mask]
    # Move the results to the CPU and convert them to NumPy arrays for plotting
    return boxes.cpu().numpy(), labels.cpu().numpy(), scores.cpu().numpy()

In [None]:
IMAGE, REAL_BOXES, REAL_LABELS = grab_bbs_(test_df, index=200)
IMAGE = Image.open(IMAGE)
IMAGE = np.array(IMAGE) / 255.

#draw boxes over the image
REAL_IMAGE = visualize_boxes_and_labels_on_image_array(
                image=IMAGE,
                boxes=REAL_BOXES, 
                scores=None, 
                classes=REAL_LABELS,
                label_map=LABEL_MAP,
)

REAL_IMAGE

# use a separate name so the weights path stored in PATH is not overwritten
IMAGE_PATH = test_df.filename[200]
THRESHOLD = 0.7

PRED_BOXES, PRED_LABELS, PRED_SCORES = detect(IMAGE_PATH, THRESHOLD)

PRED_IMAGE = visualize_boxes_and_labels_on_image_array(
                image=IMAGE, 
                boxes=PRED_BOXES, 
                scores=PRED_SCORES,
                classes= PRED_LABELS,
                label_map=LABEL_MAP,
                )

In [None]:
!echo "[ INFERENCE ] ORIGINAL IMAGE : "

REAL_IMAGE

In [None]:
!echo "[ INFERENCE ] PREDICTIONS : "

PRED_IMAGE