# Final Inference Notebook with TTA

This solution/module/whatever makes use of the following libraries/modules:

EfficientDet-PyTorch (https://github.com/rwightman/efficientdet-pytorch) licensed under Apache 2.0, Copyright Ross Wightman, License: third_party/effdet/LICENSE

PyTorch (https://github.com/pytorch/pytorch) 

## Introduction

This notebook aims to provide a working training script for the updated version of Effdet created by [@rwrightman](https://www.kaggle.com/rwightman). I realised the training code released by [@shonenkov](https://www.kaggle.com/shonenkov) does not work for the updated Effdet, hence, i made some changes accordingly.

#### Main Changes:
* **DetBenchTrain** forward() now takes in a dictionary object with key:value pair {bbox:, cls:, img_size:, img_scale:} as argument
* **DetBenchPredict** forward() takes in (image, img_size, img_scale) as argumemnt

(Im not sure what is the significance of img_scale and img_size, would appreciate if anyone can advise on this)

### Download and Import Dependencies

In [1]:
!pip install albumentations
!pip install --no-deps '../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl' > /dev/null
!pip install --no-deps '../input/timm-0130/timm-0.1.30-py3-none-any.whl' > /dev/null
import sys
sys.path.insert(0,'../input/effidetpytorch/effidetpytorch') #add packages to system path to allow import
sys.path.insert(0,'../input/torch-img-model')
sys.path.insert(0,'../input/omegaconf')



In [2]:
import numpy as np 
import pandas as pd 
import torch
import os
from glob import glob
import random
from tqdm.notebook import tqdm
import cv2
import albumentations as A
from torch.utils.data import Dataset,DataLoader
from sklearn.model_selection import StratifiedKFold
from albumentations.pytorch.transforms import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib.patches as patches

### Data Preprocessing

In [3]:
IMG_SIZE = 1024

### Model Instantiation

In [4]:
import effdet
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet.efficientdet import HeadNet

#def get_net():
#    config = get_efficientdet_config('tf_efficientdet_d7')
#    net = EfficientDet(config, pretrained_backbone=False)
#    checkpoint = torch.load('../input/efficientdet-model/eff_det_models/tf_efficientdet_d5-ef44aea8.pth') #d3-d7 ('efficientdet_model' folder) 
#    net.load_state_dict(checkpoint)
#    config.num_classes = 1
#    config.image_size = IMG_SIZE
#    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))
#    return DetBenchTrain(net, config)

# net = get_net()

### Run training

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

## Inference
For the inference, i have referred to [@shonenkov](https://www.kaggle.com/shonenkov) insightful use of Weighted Box Fusion(WBF) that uses information from all boxes to fix the overlapping bounding boxes issues as well as the Test Time Augmentation(TTA) template that he has kindly provided [here](https://www.kaggle.com/shonenkov/wbf-over-tta-single-model-efficientdet). Do head over and read about it!

In [6]:
sys.path.insert(0, "../input/weightedboxfusion")

import gc
from effdet import DetBenchPredict
from ensemble_boxes import *

TEST_PATH = "../input/global-wheat-detection/test/"

### Test Dataset

In [7]:
class WheatData(Dataset):
    def __init__(self, img_ids, transform=None):
        self.img_ids = img_ids
        self.transform = transform
        
    def __getitem__(self, index):
        img_id = self.img_ids[index]
        image = cv2.imread(f'{TEST_PATH}/{img_id}.jpg', cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = image /255.0
        
        if self.transform:
            sample = {'image' : image}
            sample = self.transform(**sample)
            image = sample['image']
        
        target = {}
        target['img_scale'] = torch.tensor([1.])
            
        return image, img_id, target
        
    def __len__(self) -> int: #annotate parameters with their expected type
        return self.img_ids.shape[0]

In [8]:
def valid_transform():
    return A.Compose([
            A.Resize(height=IMG_SIZE, width=IMG_SIZE, p=1.0),
            ToTensorV2(p=1.0)], 
            p=1.0)


def collate_fn(batch):
    return tuple(zip(*batch))


test_dataset = WheatData(
    img_ids=np.array([path.split('/')[-1][:-4] for path in glob(f'{TEST_PATH}/*.jpg')]),
    transform=valid_transform())

test_loader = DataLoader(test_dataset,
                         batch_size = 4,
                         shuffle = False,
                         drop_last = False,
                         collate_fn = collate_fn) 

### Load saved model

In [9]:
from glob import glob

### All the weights

In [10]:
#WEIGHTS = glob("/kaggle/input/*/*.bin")
WEIGHTS  = [
#     '/kaggle/input/0723-effdet7valid-f0/best-checkpoint-043epoch.bin',
#             '/kaggle/input/0804-effidetd7/f1_ep39_loss_0.404.bin',
#             '/kaggle/input/0725-effidetd7/f2_ep_39_loss_0.39553.bin',
#             '/kaggle/input/0725-effidetd7/f3_ep_33_loss_0.40588.bin',
#             '/kaggle/input/0725-effidetd7/f4_ep_41_loss_0.39407.bin',
            '/kaggle/input/0803-ramen-dx-3cups/f0-0387-ep64.bin',
            '/kaggle/input/0803-ramen-dx-3cups/f1-0387-ep42.bin',
            '/kaggle/input/0804ramendxf2/f2_ep31_loss_0.3883.bin']

In [11]:
def get_pretrained_tag(path):
    if "ramen" in path:
        return 'tf_efficientdet_d7x'
    else:
        return 'tf_efficientdet_d7'

In [12]:
class NET(torch.nn.Module):
    def __init__(self,model):
        super().__init__()
        self.model = model
        
    def forward(self,*args,**kwargs):
        return self.model(*args,**kwargs)

In [13]:
def load_net(checkpoint_path,):
    tag = get_pretrained_tag(checkpoint_path)
    
    
    if tag == 'tf_efficientdet_d7x':
        print(f'[7DX] from {checkpoint_path}')
        config = get_efficientdet_config(tag)
        model = EfficientDet(config, pretrained_backbone=False)
    
        config.num_classes = 1
        config.image_size = IMG_SIZE
        model.class_net = HeadNet(config, num_outputs=config.num_classes, 
                                norm_kwargs=dict(eps=.001, momentum=.01))
    
        net = NET(model)
    
        checkpoint = torch.load(checkpoint_path)
        #net.load_state_dict(checkpoint['model_state_dict'])
    
        import collections
        apex_model = collections.OrderedDict()
        for key, value in checkpoint['model_state_dict'].items():
            k = 'model.'+ key
            apex_model[k] = value
            #if 'anchor_labeler.anchors.boxes' in key:
            #    pass
            #elif 'model.' in key:
            #    k = key[5:]
            #    apex_model[k] = value
            #
            #else:
            #    apex_model[key] = value
        net.load_state_dict(apex_model)
    
        del checkpoint
        del apex_model
        gc.collect()
    
        
    else:
        print(f'[7D] from {checkpoint_path}')
        config = get_efficientdet_config(tag)
        model = EfficientDet(config, pretrained_backbone=False)
    
        config.num_classes = 1
        config.image_size = IMG_SIZE
        model.class_net = HeadNet(config, num_outputs=config.num_classes, 
                            norm_kwargs=dict(eps=.001, momentum=.01))
    
        net = NET(model)
        checkpoint = torch.load(checkpoint_path)
        
        import collections
        apex_model = collections.OrderedDict()
        for key, value in checkpoint['model'].items():
            if 'anchor_labeler.anchors.boxes' in key:
                pass
            else:
                apex_model[key] = value
        net.load_state_dict(apex_model)
    
        del checkpoint
        del apex_model
        gc.collect()
    
    net = DetBenchPredict(net, config)
    
    net.eval()
    net = net.cuda()
    return net.cuda()

# load
Nets = list(load_net(weight) for weight in WEIGHTS)

[7DX] from /kaggle/input/0803-ramen-dx-3cups/f0-0387-ep64.bin
[7DX] from /kaggle/input/0803-ramen-dx-3cups/f1-0387-ep42.bin
[7DX] from /kaggle/input/0804ramendxf2/f2_ep31_loss_0.3883.bin


### Testtime inference

In [14]:
class BaseWheatTTA:
    """ author: @shonenkov """
    image_size = IMG_SIZE

    def augment(self, image):
        raise NotImplementedError
    
    def batch_augment(self, images):
        raise NotImplementedError
    
    def deaugment_boxes(self, boxes):
        raise NotImplementedError

class TTAHorizontalFlip(BaseWheatTTA):
    """ author: @shonenkov """

    def augment(self, image):
        return image.flip(1)
    
    def batch_augment(self, images):
        return images.flip(2)
    
    def deaugment_boxes(self, boxes):
        boxes[:, [1,3]] = self.image_size - boxes[:, [3,1]]
        return boxes

class TTAVerticalFlip(BaseWheatTTA):
    """ author: @shonenkov """
    
    def augment(self, image):
        return image.flip(2)
    
    def batch_augment(self, images):
        return images.flip(3)
    
    def deaugment_boxes(self, boxes):
        boxes[:, [0,2]] = self.image_size - boxes[:, [2,0]]
        return boxes
    
class TTARotate90(BaseWheatTTA):
    """ author: @shonenkov """
    
    def augment(self, image):
        return torch.rot90(image, 1, (1, 2))

    def batch_augment(self, images):
        return torch.rot90(images, 1, (2, 3))
    
    def deaugment_boxes(self, boxes):
        res_boxes = boxes.copy()
        res_boxes[:, [0,2]] = self.image_size - boxes[:, [1,3]]
        res_boxes[:, [1,3]] = boxes[:, [2,0]]
        return res_boxes

class TTACompose(BaseWheatTTA):
    """ author: @shonenkov """
    def __init__(self, transforms):
        self.transforms = transforms
        
    def augment(self, image):
        for transform in self.transforms:
            image = transform.augment(image)
        return image
    
    def batch_augment(self, images):
        for transform in self.transforms:
            images = transform.batch_augment(images)
        return images
    
    def prepare_boxes(self, boxes):
        result_boxes = boxes.copy()
        result_boxes[:,0] = np.min(boxes[:, [0,2]], axis=1)
        result_boxes[:,2] = np.max(boxes[:, [0,2]], axis=1)
        result_boxes[:,1] = np.min(boxes[:, [1,3]], axis=1)
        result_boxes[:,3] = np.max(boxes[:, [1,3]], axis=1)
        return result_boxes
    
    def deaugment_boxes(self, boxes):
        for transform in self.transforms[::-1]:
            boxes = transform.deaugment_boxes(boxes)
        return self.prepare_boxes(boxes)

### Weighted Box Fusion

In [15]:
def run_wbf(predictions, image_index, image_size=IMG_SIZE, iou_thr=0.44, 
            skip_box_thr=0.43, weights=None):
    boxes = [(prediction[image_index]['boxes']/(image_size-1)).tolist() for prediction in predictions]
    scores = [prediction[image_index]['scores'].tolist() for prediction in predictions]
    labels = [np.ones(prediction[image_index]['scores'].shape[0]).astype(int).tolist() for prediction in predictions]
    boxes, scores, labels = ensemble_boxes_wbf.weighted_boxes_fusion(boxes, scores, labels, weights=None, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    boxes = boxes*(image_size-1)
    return boxes, scores, labels

### Running inference

In [16]:
def format_prediction_string(boxes, scores):
    pred_strings = []
    for j in zip(scores, boxes):
        pred_strings.append("{0:.4f} {1} {2} {3} {4}".format(j[0], j[1][0], j[1][1], 
                                                             j[1][2], j[1][3]))
    return " ".join(pred_strings)

In [17]:
from itertools import product

tta_transforms = []
for tta_combination in product([TTAHorizontalFlip(), None], 
                               [TTAVerticalFlip(), None],
                               [TTARotate90(), None]):
    tta_transforms.append(TTACompose([tta_transform for tta_transform in tta_combination if tta_transform]))

In [18]:
# WBF over TTA
def predict_single(net,tta = True):
    def predict_func(images, target, score_thres=0.5):
        with torch.no_grad():
            prediction = []
            images = torch.stack(images).to(device).float()
            img_scale = torch.tensor([1]*images.shape[0]).float().cuda()
            img_size = torch.tensor([(IMG_SIZE, IMG_SIZE) for target in targets]).to(device)
    
            '''
    
            Within the forward function of the DetBenchPredict class, it takes in 3 arguments (image, image_scale, image_size)
            The return object is as follows: 
            detections = torch.cat([boxes, scores, classes.float()], dim=1) 
            where the first 4 col will be the bboxes, 5th col the scores
            Find out more at https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/bench.py
    
            '''
            if tta:
                for tta_transform in tta_transforms:
                    result = []
                    det = net(tta_transform.batch_augment(images.clone()),
                              img_scales = img_scale,
                              img_size = img_size)
        
                    for i in range(images.shape[0]):
                        boxes = det[i].detach().cpu().numpy()[:,:4]    
                        scores = det[i].detach().cpu().numpy()[:,4]
                        indexes = np.where(scores > score_thres)[0]
                        boxes = boxes[indexes]
                        boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
                        boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
                        boxes = tta_transform.deaugment_boxes(boxes.copy())
                        result.append({
                            'boxes': boxes,
                            'scores': scores[indexes],
                        })
        
                    prediction.append(result)
            else:
                result = []
                det = net(images.clone(),
                              img_scales = img_scale,
                              img_size = img_size)
        
                for i in range(images.shape[0]):
                    boxes = det[i].detach().cpu().numpy()[:,:4]    
                    scores = det[i].detach().cpu().numpy()[:,4]
                    indexes = np.where(scores > score_thres)[0]
                    boxes = boxes[indexes]
                    boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
                    boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
                    result.append({
                            'boxes': boxes,
                            'scores': scores[indexes],
                        })
        
                prediction.append(result)
    
        return prediction
    return predict_func

In [19]:
from itertools import chain

In [20]:
def predict(images, target,tta=False, score_thres=0.5):
    preds = []
    for net in Nets:
        preds.append(predict_single(net,tta=tta)(images, target, score_thres))
    return list(chain(*preds))

In [21]:
firstnet = Nets[0]

In [22]:
%%time
results = []
for images, image_ids, targets in test_loader:
    predictions = predict(images, targets,tta = True)
    for i, image in enumerate(images):
        boxes, scores, labels = run_wbf(predictions, image_index=i)
        boxes = (boxes*1024/1024).astype(np.int32).clip(min=0, max=1023)
        image_id = image_ids[i]
        
        boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
        boxes[:, 3] = boxes[:, 3] - boxes[:, 1]

        result = {
            'image_id': image_id,
            'PredictionString': format_prediction_string(boxes, scores)
        }
        results.append(result)



CPU times: user 33.7 s, sys: 15.5 s, total: 49.2 s
Wall time: 50.1 s


In [23]:
test_df = pd.DataFrame(results, columns=['image_id', 'PredictionString'])
test_df.to_csv('submission.csv', index=False)
test_df.head(20)

Unnamed: 0,image_id,PredictionString
0,348a992bb,0.8578 732 222 141 88 0.8136 598 444 122 98 0....
1,796707dd7,0.8602 709 823 110 104 0.8564 895 331 113 94 0...
2,aac893a91,0.8652 559 532 121 188 0.8521 28 451 103 158 0...
3,f5a1f0358,0.8969 688 204 113 93 0.8917 943 435 79 185 0....
4,cb8d261a3,0.8743 753 489 120 91 0.8529 311 167 101 201 0...
5,cc3532ff6,0.9557 772 828 164 161 0.8581 909 124 112 96 0...
6,51f1be19e,0.8374 608 81 160 180 0.8209 839 265 136 205 0...
7,51b3e36ab,0.8760 495 359 315 131 0.8742 870 287 152 142 ...
8,53f253011,0.8863 14 33 145 110 0.8830 621 100 119 146 0....
9,2fd875eaa,0.8874 106 584 141 85 0.8742 730 155 83 88 0.8...
