In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data
import timm

from effdet import EfficientDet, DetBenchTrain, get_efficientdet_config
from effdet.config.model_config import efficientdet_model_param_dict
from effdet.efficientdet import HeadNet

In [6]:
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from ensemble_boxes import ensemble_boxes_wbf

# Build model & load weights

In [2]:
# Function to build effdet model
def create_model(
    num_classes=3,
    image_size=512,
    architecture="tf_efficientnetv2_l",
    drop_path_rate=0.2,
    pretrained_backbone=True,
):
    """Build an EfficientDet detector wrapped in ``DetBenchTrain``.

    Registers ``architecture`` in effdet's ``efficientdet_model_param_dict`` so
    that ``get_efficientdet_config`` can resolve it, builds the network, and
    replaces its classification head so it outputs ``num_classes`` classes.

    Args:
        num_classes: Number of object classes the classification head predicts.
        image_size: Side length of the square input; stored in the config as
            ``(image_size, image_size)``.
        architecture: Name used both as the registry/config key and as the
            backbone name.
        drop_path_rate: Value forwarded to the backbone as ``drop_path_rate``.
        pretrained_backbone: Forwarded to ``EfficientDet``. Pass ``False`` when
            all weights will be restored from a checkpoint anyway, to avoid
            fetching pretrained backbone weights that are immediately overwritten.

    Returns:
        ``DetBenchTrain`` wrapping the configured ``EfficientDet`` network.
    """
    # NOTE: mutates effdet's module-level registry; registering the same name
    # again simply overwrites the previous entry.
    efficientdet_model_param_dict[architecture] = dict(
        name=architecture,
        backbone_name=architecture,
        backbone_args=dict(drop_path_rate=drop_path_rate),
        num_classes=num_classes,
        url='',  # presumably: no pretrained full-detector weights for this custom entry
    )

    config = get_efficientdet_config(architecture)
    config.update({
        'num_classes': num_classes,
        'image_size': (image_size, image_size),
    })

    net = EfficientDet(config, pretrained_backbone=pretrained_backbone)
    # Swap in a fresh classification head sized to config.num_classes.
    net.class_net = HeadNet(config, num_outputs=config.num_classes)
    return DetBenchTrain(net, config)

In [9]:
class EfficientDetModel(pl.LightningModule):
    """LightningModule wrapping the EfficientDet bench built by ``create_model``.

    Args:
        num_classes: Number of object classes, forwarded to ``create_model``.
        img_size: Square input size, forwarded to ``create_model`` and kept as
            ``self.img_size``.
        prediction_confidence_threshold: Kept as
            ``self.prediction_confidence_threshold`` (presumably the score cut-off
            applied by inference code).
        learning_rate: Kept as ``self.lr``.
        wbf_iou_threshold: Kept as ``self.wbf_iou_threshold`` (presumably the IoU
            threshold for ``ensemble_boxes_wbf``).
        model_architecture: Architecture name forwarded to ``create_model``.
    """

    def __init__(
        self,
        num_classes=3,
        img_size=512,
        prediction_confidence_threshold=0.2,
        learning_rate=0.0002,
        wbf_iou_threshold=0.44,
        model_architecture='tf_efficientnetv2_l',
    ):
        super().__init__()
        # Record the init args in self.hparams (and in saved checkpoints) so
        # load_from_checkpoint can rebuild the same architecture later.
        self.save_hyperparameters()
        # Keep the inference/training settings instead of silently dropping them.
        self.img_size = img_size
        self.prediction_confidence_threshold = prediction_confidence_threshold
        self.lr = learning_rate
        self.wbf_iou_threshold = wbf_iou_threshold
        self.model = create_model(
            num_classes, img_size, architecture=model_architecture
        )

    @auto_move_data
    def forward(self, images, targets):
        """Run the wrapped bench on ``images`` with ``targets``.

        ``auto_move_data`` (Lightning) moves the inputs to this module's device
        before the call.
        """
        return self.model(images, targets)

In [10]:
CHECKPOINT_PATH = 'weights/effdet_l.ckpt'

# load_from_checkpoint is a classmethod that constructs and returns a new model,
# so call it on the class: building an instance first only to discard it would
# construct the full network (and fetch pretrained backbone weights) for nothing.
# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines;
# move the model with .to(device) afterwards if a GPU is wanted.
model = EfficientDetModel.load_from_checkpoint(CHECKPOINT_PATH, map_location='cpu')

model.eval();

# Run image for inference