In [1]:
import albumentations as A
import cv2
import torch
import sys

from albumentations.pytorch.transforms import ToTensorV2
from effdet import default_detection_model_configs, load_checkpoint, load_pretrained, EfficientDet, \
                   DetBenchTrain
from effdet.config.model_config import efficientdet_model_param_dict
from omegaconf import OmegaConf
from pathlib import Path

In [2]:
sys.path.append('../src')

In [3]:
from data.dataset import WheatDataset
from data.utils import collate

In [4]:
DATA_DIR = Path('/media/dmitry/data/global-wheat-detection')

In [5]:
def get_efficientdet_config(model_name='tf_efficientdet_d1'):
    """Get the default config for EfficientDet based on model name."""
    config = default_detection_model_configs()
    config.num_classes = 1
    model_config = efficientdet_model_param_dict[model_name]
    return OmegaConf.merge(config, OmegaConf.create(model_config))

In [6]:
def create_model(model_name, bench_task='', pretrained=False, checkpoint_path='', 
                 checkpoint_ema=False, **kwargs):    
    config = get_efficientdet_config(model_name)

    pretrained_backbone = kwargs.pop('pretrained_backbone', True)
    if pretrained or checkpoint_path:
        pretrained_backbone = False  # no point in loading backbone weights

    redundant_bias = kwargs.pop('redundant_bias', None)
    if redundant_bias is not None:
        # override config if set to something
        config.redundant_bias = redundant_bias

    model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)

    # FIXME handle different head classes / anchors and re-init of necessary layers w/ pretrained load

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema)
    elif pretrained:
        load_pretrained(model, config.url)

    # wrap model in task specific bench if set
    if bench_task == 'train':
        model = DetBenchTrain(model, config)
    elif bench_task == 'predict':
        model = DetBenchPredict(model, config)
    return model

In [7]:
get_efficientdet_config('tf_efficientdet_d1')

{'name': 'tf_efficientdet_d1', 'backbone_name': 'tf_efficientnet_b1', 'backbone_args': {'drop_path_rate': 0.2}, 'image_size': 640, 'num_classes': 1, 'min_level': 3, 'max_level': 7, 'num_levels': 5, 'num_scales': 3, 'aspect_ratios': [[1.0, 1.0], [1.4, 0.7], [0.7, 1.4]], 'anchor_scale': 4.0, 'pad_type': 'same', 'act_type': 'swish', 'box_class_repeats': 3, 'fpn_cell_repeats': 4, 'fpn_channels': 88, 'separable_conv': True, 'apply_bn_for_resampling': True, 'conv_after_downsample': False, 'conv_bn_relu_pattern': False, 'use_native_resize_op': False, 'pooling_type': None, 'redundant_bias': True, 'fpn_name': None, 'fpn_config': None, 'fpn_drop_path_rate': 0.0, 'alpha': 0.25, 'gamma': 1.5, 'delta': 0.1, 'box_loss_weight': 50.0, 'url': 'https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1-4c7ebaf2.pth'}

In [8]:
device = torch.device('cuda:1')

model = create_model(
    'tf_efficientdet_d1', 
    bench_task='train',
    pretrained=False,
    pretrained_backbone=True,
    redundant_bias=None,
    checkpoint_path=''
).to(device)

In [9]:
# N = 5
# B = 5  # boxes per image

# H, W = 640, 640

# base = torch.tensor([0, 0, H, W])[None, None, :]
# scale = torch.cat([torch.zeros(N, B, 2), torch.rand(N, B, 2)], dim=2)
# shift = (torch.rand(N, B, 2) * torch.tensor([H, W])).repeat(1, 1, 2)


# dummy_boxes = base * scale + shift
# dummy_boxes[:, [0, 2]] = dummy_boxes[:, [0, 2]].clamp_(0, H)
# dummy_boxes[:, [1, 3]] = dummy_boxes[:, [1, 3]].clamp_(0, W)
# dummy_boxes = dummy_boxes.to(device)

# dummy_boxes.shape

In [10]:
# x = torch.randn(5, 3, 640, 640).to(device)
# cls = torch.zeros(N, B).to(device)

# target = dict(
#     bbox=dummy_boxes,
#     cls=cls  
# )

# model(x, target)

In [11]:
image_dir = DATA_DIR/'train'
csv_path = DATA_DIR/'train.csv'

tfms = [
    A.Flip(),
    A.RandomRotate90(),
    A.Resize(640, 640, interpolation=cv2.INTER_AREA),
    # we can send byte tensors to GPU and convert byte -> float there
    ToTensorV2()
]
tfms = A.Compose(tfms, bbox_params=A.BboxParams('pascal_voc'))

ds = WheatDataset(image_dir, csv_path, transforms=tfms)

Parsing bboxes...: 100%|██████████| 24/24 [00:00<00:00, 149.70it/s]


In [12]:
dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False, collate_fn=collate)

In [13]:
for images, bboxes in dl:
    images = images.to(device).float()
    boxes, cls = [], []
    
    for b in bboxes:
        c = torch.ones(len(b), device=device)
        b = b.to(device).float()
        boxes.append(b)
        cls.append(c)
    
    target = dict(bbox=boxes, cls=cls)
    out = model(images, target)
    break
    
out

{'loss': tensor(3.1527, device='cuda:1', grad_fn=<AddBackward0>),
 'class_loss': tensor(1.1541, device='cuda:1', grad_fn=<SumBackward1>),
 'box_loss': tensor(0.0400, device='cuda:1', grad_fn=<SumBackward1>)}