<a href="https://colab.research.google.com/github/benihime91/pytorch_retinanet/blob/master/001_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Check which GPU this Colab runtime was given (the Trainer below requests gpus=1)
!nvidia-smi

In [None]:
# Ensure colab doesn't disconnect
%%javascript
function ClickConnect(){
console.log("Working");
document.querySelector("colab-toolbar-button#connect").click()
}setInterval(ClickConnect,60000)

## **Initial setup**

In [None]:
# install dependencies
! pip install pytorch-lightning wandb omegaconf --quiet
! pip install git+https://github.com/albumentations-team/albumentations --quiet

In [None]:
# Authenticate with Weights & Biases (skip this cell if you are not using the wandb logger).
# Never hard-code an API key in a notebook: `wandb.login()` uses WANDB_API_KEY from the
# environment when it is set and otherwise prompts for the key interactively.
import wandb

wandb.login()

In [None]:
# Attach Google Drive under /content/drive: the dataset zips are read from
# "My Drive" and checkpoints / weights are written back there later on.
from google.colab import drive as gdrive

gdrive.mount("/content/drive")

In [None]:
# Grab the Data
! unzip -qq /content/drive/My\ Drive/Pascal\ 2007\ Data/pascal_voc_2007_test.zip
! unzip -qq /content/drive/My\ Drive/Pascal\ 2007\ Data/pascal_voc_2007_train_val.zip
# Clone the RetinaNet Repo
! git clone https://github.com/benihime91/pytorch_retinanet.git

In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
os.chdir("/content/pytorch_retinanet")

%load_ext autoreload
%autoreload 2
%matplotlib inline

## **Dataset visualization**

In [None]:
# Helpers that turn the raw Pascal VOC folders into datasets
import pandas as pd
from utils.pascal import get_pascal, generate_pascal_category_names
from utils.pascal.pascal_transforms import compose_transforms
import albumentations as A

# Never truncate cell contents when displaying the annotation DataFrame
pd.set_option("display.max_colwidth", None)

# Where the extracted VOC-2007 splits live (XML annotations + JPEG images)
train_ann_pth = "/content/pascal_voc_2007_train_val/Annotations/"
train_im_pth = "/content/pascal_voc_2007_train_val/Images/"
test_ann_pth = "/content/pascal_voc_2007_test/Annotations/"
test_im_pth = "/content/pascal_voc_2007_test/Images/"

# Compute the transformations applied to both splits
tfms = compose_transforms()

# Build the train and test datasets (train first, then test); get_pascal also
# generates a CSV for each split (the train one is read back below).
trn_ds, test_ds = [
    get_pascal(ann_dir, img_dir, split, transforms=tfms)
    for ann_dir, img_dir, split in (
        (train_ann_pth, train_im_pth, "train"),
        (test_ann_pth, test_im_pth, "test"),
    )
]

In [None]:
# Inspect the training-split annotation CSV generated by get_pascal above
df = pd.read_csv("pascal_train.csv")
df.head(n=5)

In [None]:
# Derive the class-name lookup from the training annotations; it is passed as
# `label_map` whenever boxes are drawn below.
PASCAL_INSTANCE_CATEGORY_NAMES = generate_pascal_category_names(df)
PASCAL_INSTANCE_CATEGORY_NAMES

In [None]:
from torch.utils.data import DataLoader
from utils import collate_fn, visualize_boxes_and_labels_on_image_array

# Wrap the training set in a loader and pull only its first batch (5 samples)
dl = DataLoader(trn_ds, batch_size=5, collate_fn=collate_fn)
bs = next(iter(dl))

# collate_fn yields a 3-tuple per batch: images, targets, indices
image, target, idx = bs

In [None]:
# Draw the ground-truth boxes and class names of sample #1 in the batch
sample = 1
im = visualize_boxes_and_labels_on_image_array(
    # permute(1, 2, 0) moves the first axis last (assumes channels-first tensors)
    image=image[sample].permute(1, 2, 0).numpy(),
    boxes=target[sample]["boxes"].numpy(),
    classes=target[sample]["labels"].numpy(),
    scores=None,  # ground truth carries no confidence scores
    label_map=PASCAL_INSTANCE_CATEGORY_NAMES,
)

im

In [None]:
# Draw the ground-truth boxes and class names of sample #3 in the batch
sample = 3
im = visualize_boxes_and_labels_on_image_array(
    # permute(1, 2, 0) moves the first axis last (assumes channels-first tensors)
    image=image[sample].permute(1, 2, 0).numpy(),
    boxes=target[sample]["boxes"].numpy(),
    classes=target[sample]["labels"].numpy(),
    scores=None,  # ground truth carries no confidence scores
    label_map=PASCAL_INSTANCE_CATEGORY_NAMES,
)

im

## **Training the RetinaNet model** :


In [None]:
# List the contents of the working directory (the cloned repository, after the chdir above)
! ls

In [None]:
import time

import pytorch_lightning as pl
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import Trainer
# Explicit names instead of `import *`: makes it clear where every logger/callback
# used below comes from and avoids silently shadowing notebook variables.
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from model import RetinaNetModel, LogCallback

# Seed the RNGs (Python, NumPy, torch) so that results are reproducible
pl.seed_everything(123)

In [None]:
# ========================================================================= #
# OVERRIDE THE CONFIG FILE: FIX DATA PATHS AND DATASET ARGUMENTS
# ========================================================================= #
# Locations of the extracted VOC-2007 annotations and images
train_ann_pth = "/content/pascal_voc_2007_train_val/Annotations/"
train_im_pth  = "/content/pascal_voc_2007_train_val/Images/"

test_ann_pth  = "/content/pascal_voc_2007_test/Annotations/"
test_im_pth   = "/content/pascal_voc_2007_test/Images/"

# Default hyper-parameters of the LightningModule, shipped with the repo
hparams = OmegaConf.load("/content/pytorch_retinanet/hparams.yaml")

# Dataset-specific overrides:
#   * Pascal VOC 2007 has 20 classes, not counting the "__background__" class
hparams.model.num_classes   = 20
hparams.dataset.kind        = "pascal"
#   * [annotations, images] pairs; the test split doubles as the validation set
hparams.dataset.trn_paths   = [train_ann_pth, train_im_pth]
hparams.dataset.valid_paths = [test_ann_pth, test_im_pth]

print(OmegaConf.to_yaml(hparams))

In [None]:
# Instantiate the LightningModule from the modified hyper-parameters
litModel = RetinaNetModel(hparams=hparams)

**Lightning Trainer:**

In [None]:
# ============================================================ #
# INSTANTIATE THE LIGHTNING TRAINER (logger + callbacks)
# ============================================================ #
# NOTE:
# For the full list of Trainer arguments see:
# https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

# --- Weights & Biases logger (any other Lightning logger works too) ---
wb_p = "pascal-2007"
wb_name = f"[{time.strftime('%m/%d %H:%M:%S')}]"  # run name = start timestamp
wb_logger = WandbLogger(project=wb_p, name=wb_name, anonymous="allow")

# --- record the learning rate at every optimizer step ---
lr_logger = LearningRateLogger(logging_interval="step")

# --- keep the 3 checkpoints with the lowest val_loss, on Google Drive ---
fname = "/content/drive/My Drive/pascal_checkpoints/weights_pascal/"
os.makedirs(fname, exist_ok=True)
checkpoint_callback = ModelCheckpoint(fname, monitor="val_loss", mode="min", save_top_k=3)

# --- stop after 10 validation checks without a val_loss improvement ---
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=10)

trainer = Trainer(
    # hardware / numerics
    gpus=1,
    precision=16,
    deterministic=True,
    # loop control
    max_epochs=55,
    num_sanity_val_steps=0,
    terminate_on_nan=True,
    # logging & callbacks
    logger=[wb_logger],
    callbacks=[LogCallback(), lr_logger],
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stop_callback,
    weights_summary=None,
)

In [None]:
# Train the model; metrics go to W&B and the best checkpoints to Google Drive
trainer.fit(litModel)

## **Evaluating the trained model**

In [None]:
# Evaluation results on the test dataset (or on the validation dataset if no
# test dataset is given), computed with the COCO API
trainer.test(litModel)

## **Saving the model weights** :

In [None]:
import torch

# Save only the wrapped RetinaNet's weights (not the Lightning module) to Drive;
# the Unix-time suffix gives every run its own file.
path = "/content/drive/My Drive/pascal_weights_{}.pth".format(int(time.time()))
weights = litModel.model.state_dict()
torch.save(weights, path)

## **Loading model weights :**

In [None]:
from retinanet import Retinanet

# Deserialize onto the CPU so this cell also works on a runtime without a GPU
# (weights saved from a CUDA model are otherwise restored onto CUDA).
state_dict = torch.load(path, map_location="cpu")

# num_classes / backbone must match what the weights were trained with
# (hparams.model.num_classes was set to 20 above; backbone presumably resnet50 per hparams.yaml — confirm).
MODEL = Retinanet(num_classes=20, backbone_kind="resnet50")
# Inference mode: dropout disabled, batch-norm uses its running statistics.
MODEL.eval()
MODEL.load_state_dict(state_dict)

## **Generating Predictions from the Model :**

In [None]:
# The classes the model can predict (used as `label_map` when drawing detections)
PASCAL_INSTANCE_CATEGORY_NAMES

In [None]:
from PIL import Image
import numpy as np
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils import visualize_boxes_and_labels_on_image_array

@torch.no_grad()
def get_preds(path):
    """
    Run the loaded RetinaNet (`MODEL`) on a single image file.

    Args:
        path: path to an image file readable by OpenCV.

    Returns:
        The first element of `MODEL.predict`'s output for this image; `filter_preds`
        reads its "scores", "labels" and "boxes" entries.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from `path`.
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; fail loudly
    # here rather than with an opaque error from cvtColor.
    if bgr is None:
        raise FileNotFoundError(f"Could not read an image from {path!r}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Divide pixel values by 255 (float in [0, 1]) and convert the HWC array to a CHW tensor.
    INFER_TRANSFORMS = A.Compose([A.ToFloat(max_value=255.0, always_apply=True),
                                  ToTensorV2(always_apply=True)
                                  ])

    TENSOR_IMAGE = INFER_TRANSFORMS(image=image)["image"]
    # predict takes a list of images; we pass one and return its result.
    PREDICTIONS = MODEL.predict([TENSOR_IMAGE])
    return PREDICTIONS[0]

def filter_preds(ps, threshold=0.5):
    """
    Keep only the predictions whose score is strictly above `threshold`.

    Args:
        ps: prediction dict with "scores", "labels" and "boxes" tensors (one entry
            per detection along the first dimension), as returned by `get_preds`.
        threshold: detections with a score <= threshold are dropped.

    Returns:
        A ``(scores, labels, boxes)`` tuple of numpy arrays with the kept detections.
    """
    keep = ps["scores"] > threshold

    def _to_numpy(tensor):
        # `.numpy()` raises for CUDA tensors and for tensors that require grad,
        # so detach and move to host memory first (no-ops for plain CPU tensors).
        return tensor[keep].detach().cpu().numpy()

    return _to_numpy(ps["scores"]), _to_numpy(ps["labels"]), _to_numpy(ps["boxes"])


def detect(image_path, threshold=0.5):
    """
    Generate detections on the image at `image_path` and draw them over it.

    Args:
        image_path: path to the input image.
        threshold: score threshold; only detections scoring strictly above it are drawn.

    Returns:
        The annotated image produced by `visualize_boxes_and_labels_on_image_array`
        (displayable inline, like the dataset previews earlier in the notebook).

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from `image_path`.
    """
    bgr = cv2.imread(image_path)
    # cv2.imread returns None (no exception) for missing/unreadable files.
    if bgr is None:
        raise FileNotFoundError(f"Could not read an image from {image_path!r}")
    # The drawing helper expects pixel values in [0, 1], so rescale the uint8
    # pixel intensities (this normalizes pixels, not box coordinates).
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0

    # Generate predictions for the image and keep the confident ones
    preds = get_preds(image_path)
    scores, labels, boxes = filter_preds(preds, threshold)

    # Keyword arguments (as in the dataset-preview calls) so scores and class
    # labels cannot be swapped by a positional-order mismatch.
    return visualize_boxes_and_labels_on_image_array(
        image=image,
        boxes=boxes,
        classes=labels,
        scores=scores,
        label_map=PASCAL_INSTANCE_CATEGORY_NAMES,
    )