In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import tqdm

from data.dataset_utils import get_dataset, get_transforms
from detection_models.detection_model_utils import get_model

# Inference Notebook

## Load dataset

We first load the dataset. We can supply either a path to a folder containing [`images`, `labels`] subfolders (e.g., `04_model_input` together with `mode="val"` or `mode="train"`) or a path to one of the raw `caseX` folders in `01_raw`, which contain only images. In the latter case, the dataset will load the corresponding labels from the `03_primary/labels` folder.

Furthermore, we keep only the positive or only the negative frames, depending on the `negative` flag. The video sequences for the first 13 cases (the ones containing both positive and negative frames) are not "continuous", in the sense that positive and negative frames are not interleaved, but rather belong to separate clips/moments in the colonoscopy.

In [None]:
negative = False # whether to take negative or positive frames

bbox_format = 'pascal_voc'
resize = 512

# Validation-time transforms; the nested dict is forwarded unchanged to get_transforms
transform_params = {
    'transforms': {'resize': resize, 'min_area': 900, 'min_visibility': 0.25},
    'format': bbox_format,
}
val_transforms = get_transforms('val', params=transform_params, normalize=True)

dataset = get_dataset('', '/home/thuynh/data/01_raw/case63', val_transforms, bbox_format)


def _frame_number(path):
    """Return the integer formed by the four characters before the 4-character file suffix."""
    return int(path[-8:-4])


def _positive_sort_key(path):
    """Sort key for positive frames: (number at path[-20:-18] with any 'a' stripped, frame number)."""
    return int(path[-20:-18].strip('a')), _frame_number(path)


def _negative_sort_key(path):
    """Sort key for negative frames: (digit right after 'Negative_' or 0 if not a digit, frame number)."""
    leading_char = path.split('Negative_')[-1][0]
    clip_index = int(leading_char) if leading_char.isdigit() else 0
    return clip_index, _frame_number(path)


# Keep only the requested kind of frame for the selected case and order them sequentially
if negative:
    # Only the first 1200 negative frames are kept, to have a 20FPS video and a representative sample.
    # NOTE(review): the slice is taken BEFORE sorting, so which 1200 frames survive depends on the
    # order in which the dataset lists its images - confirm this is intended.
    selected_frames = [path for path in dataset.images_list if 'Negative' in path][:1200]
    selected_frames.sort(key=_negative_sort_key)
else:
    selected_frames = [path for path in dataset.images_list if 'Negative' not in path]
    selected_frames.sort(key=_positive_sort_key)

dataset.images_list = selected_frames

# Keep labels_list aligned with images_list: same path, 4-character suffix swapped for '.json'
dataset.labels_list = [f'{path[:-4]}.json' for path in dataset.images_list]

## Load model

We load one trained model. We set `image_mean = (0,0,0)` and `image_std = (1,1,1)` because the `dataset.visualize()` method performs image normalization internally when provided with a model for inference.

In addition, we set the confidence score threshold for accepting a prediction (`box_score_thresh`) to `0.5`.

In [None]:
# Model specifications

# Checkpoint file whose 'model' entry is loaded into the detector
ckpt_path = '/home/thuynh/torchvision_tutorial/runs/2022-05-11_18_56_13_2173/model_9.pth' 

# Name looked up in `torchvision.models.detection` to pick the architecture
model_name = 'fasterrcnn_mobilenet_v3_large_fpn'

num_classes = 7

pretrained = True
pretrained_backbone = True

# Extra keyword arguments forwarded to the torchvision model constructor.
# min_size/max_size follow the dataset `resize`; mean/std are left at identity values
# and predictions scoring below 0.5 are discarded by the model.
kwargs = {
    "trainable_backbone_layers": 5,
    "min_size": resize,
    "max_size": resize,
    'image_mean': (0., 0., 0.),
    'image_std': (1., 1., 1.),
    'box_score_thresh': 0.5,
}

In [None]:
# Build the detector: look the constructor up once and share the common keyword arguments
detector_builder = torchvision.models.detection.__dict__[model_name]
builder_kwargs = dict(pretrained=pretrained, pretrained_backbone=pretrained_backbone, **kwargs)

if 'rcnn' not in model_name:
    # Non-RCNN detectors receive the number of classes directly at construction time
    model = detector_builder(num_classes=num_classes, **builder_kwargs)
else:
    # RCNN detectors are built as-is, then their box predictor is replaced by one
    # with `num_classes` outputs on top of the same input feature size
    model = detector_builder(**builder_kwargs)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    predictor_cls = torchvision.models.detection.faster_rcnn.FastRCNNPredictor
    model.roi_heads.box_predictor = predictor_cls(in_features, num_classes)

# Restore the trained weights (deserialized on CPU) and switch to evaluation mode
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['model'])
model = model.eval()

## Inference pass

Next, we perform the inference pass using the `dataset.visualize()` method with the `model` argument specified. 

Internally, the `visualize` method loads all specified images (random ones otherwise) and their corresponding labels. It then performs the inference pass with the specified model and plots the ground-truth and predicted (if any) bounding boxes onto the **original** image (n.b.: the model input size and the original image size usually do not match).

In [None]:
# Run inference on every selected frame and draw the boxes.
# `images` feeds the video-writing cell below; `grid` is only needed for the optional in-notebook plot.
images, grid = dataset.visualize(
    images_list=dataset.images_list,
    model=model,
    resize=resize,
)

In [None]:
# Plot the grid in the notebook (only practical for a few images); uncomment the lines below to use
# plt.figure(figsize=(10,40))
# plt.axis('off')
# plt.imshow(np.array(grid.permute(1,2,0)))
# plt.show()

## Save sequence

Finally, we save the annotated sequence (frames with the drawn bounding boxes) as a video using `opencv`.

In [None]:
# Case identifier, taken as the second-to-last '/'-separated component of the dataset image path
case_no = dataset.images_path.split('/')[-2] # extract from filepath

# Fail early with a clear message instead of an IndexError on images[0]
if len(images) == 0:
    raise ValueError(f'No frames to write for "{case_no}": `images` is empty.')

# Define FPS to make video fit within 60 seconds or at least 10FPS 
FPS = max(10.0, round(len(images)/60, 1))

suffix = '_negative' if negative else '_positive'

# Save video
fourcc = cv2.VideoWriter_fourcc(*'XVID')

# All frames are written with the size of the first one
height, width, channels = images[0].shape
video_name = f'/home/thuynh/data/07_reporting/video/{model_name}/{case_no}{suffix}.avi'

# Create the output folder if needed (no separate existence check, so no race)
os.makedirs(os.path.dirname(video_name), exist_ok=True)

video_writer = cv2.VideoWriter(video_name, fourcc, fps=FPS, frameSize=(width, height))

# VideoWriter does not raise when it cannot open the output (e.g. codec unavailable);
# without this check the cell would report success without having written a video.
if not video_writer.isOpened():
    video_writer.release()
    raise RuntimeError(f'Could not open video writer for: {video_name}')

try:
    for img in tqdm.tqdm(images, total=len(images), desc=f'Creating video for {case_no}...'):
        # Frames are converted from RGB to the BGR channel order expected by OpenCV
        video_writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
finally:
    # Always finalize the file, even if a frame fails to convert/write.
    # No cv2.destroyAllWindows() here: this cell opens no HighGUI windows, and that call
    # raises on headless OpenCV builds, which previously prevented the release below it.
    video_writer.release()

print(f'Video for "{case_no}" saved to: {video_name}.')