In [None]:
!pip install -q timm
!pip install -q effdet
!pip install -q py-cpuinfo

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import random
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import effdet
import timm
import cpuinfo
import gc
import wandb
from warnings import simplefilter


# Silence every Python warning for the rest of the session (applies globally, not only to this notebook's code).
simplefilter("ignore")
# Authenticate with Weights & Biases up front; a run is started with wandb.init() in the training cell.
wandb.login()

In [None]:
def get_device_name(device):
    """Return a human-readable description of the hardware behind ``device``.

    Args:
        device: Device specifier, e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"`` or a
            ``torch.device`` instance.

    Returns:
        str: ``"<name> (GPU)"`` when a CUDA device is requested and CUDA is
        available, otherwise ``"<cpu brand> (CPU)"``. The CPU brand is also
        reported when CUDA is requested but not available.
    """
    # The substring test only works on strings; going through str() lets
    # torch.device objects be passed as well.
    wants_cuda = "cuda" in str(device)

    if wants_cuda and torch.cuda.is_available():
        return f"{torch.cuda.get_device_name(device)} (GPU)"

    # CPU fallback: "brand_raw" is the brand string reported by py-cpuinfo.
    cpu_brand = cpuinfo.get_cpu_info()["brand_raw"]
    return f"{cpu_brand} (CPU)"


def transform_bounding_boxes(bounding_boxes, source_format="pascal_voc", target_format="pascal_voc"):
    """Convert a collection of boxes from ``source_format`` to ``target_format``.

    Args:
        bounding_boxes: Iterable of 4-value boxes expressed in ``source_format``.
        source_format: Format of the incoming boxes ("pascal_voc", "coco" or "yolo").
        target_format: Format the boxes should be converted to.

    Returns:
        np.ndarray: The converted boxes stacked into one array (an empty array
        when ``bounding_boxes`` is empty).
    """
    converted = [
        transform_bounding_box(box, source_format=source_format, target_format=target_format)
        for box in bounding_boxes
    ]
    return np.array(converted)
        

def transform_bounding_box(bounding_box, source_format="pascal_voc", target_format="pascal_voc"):
    """Convert a single box from ``source_format`` to ``target_format``.

    The conversion is delegated to the ``from_<source_format>`` helper; any
    unrecognised ``source_format`` is treated as "pascal_voc".

    Args:
        bounding_box: Sequence of four numbers in ``source_format``.
        source_format: "pascal_voc", "coco" or "yolo".
        target_format: Format to convert to; forwarded to the helper.

    Returns:
        Whatever the selected ``from_*`` helper returns for this box.
    """
    converters = {
        "pascal_voc": from_pascal_voc,
        "coco": from_coco,
        "yolo": from_yolo,
    }

    # Unknown source formats silently fall back to the Pascal VOC reader.
    if source_format in converters:
        converter = converters[source_format]
    else:
        converter = from_pascal_voc

    return converter(bounding_box=bounding_box, target_format=target_format)
        

def from_pascal_voc(bounding_box, target_format="pascal_voc"):
    """Convert a Pascal VOC box ``[x_min, y_min, x_max, y_max]`` to another format.

    Args:
        bounding_box: Sequence ``[x_min, y_min, x_max, y_max]``.
        target_format: "coco" -> ``[x_min, y_min, width, height]``;
            "yolo" -> ``[x_center, y_center, width, height]`` (same units as the
            input, not normalised); anything else returns the box unchanged.

    Returns:
        np.ndarray: The converted box, rounded to the nearest integer values.
    """
    x_min, y_min, x_max, y_max = bounding_box

    width = x_max - x_min
    height = y_max - y_min

    if target_format == "coco":
        formated_bounding_box = [x_min, y_min, width, height]

    elif target_format == "yolo":
        # The centre is the midpoint of the box. Using x_max / 2 (as before) is
        # only correct for boxes that start at the origin.
        x_center = x_min + width / 2
        y_center = y_min + height / 2

        formated_bounding_box = [x_center, y_center, width, height]

    else:
        formated_bounding_box = bounding_box

    return np.array(formated_bounding_box).round()
        
def from_coco(bounding_box, target_format="pascal_voc"):
    """Convert a COCO box ``[x_min, y_min, width, height]`` to another format.

    Args:
        bounding_box: Sequence ``[x_min, y_min, width, height]``.
        target_format: "pascal_voc" -> ``[x_min, y_min, x_max, y_max]``;
            "yolo" -> ``[x_center, y_center, width, height]`` (same units as the
            input, not normalised); anything else returns the box unchanged.

    Returns:
        np.ndarray: The converted box, rounded to the nearest integer values.
    """
    x_min, y_min, width, height = bounding_box

    if target_format == "pascal_voc":
        formated_bounding_box = [x_min, y_min, x_min + width, y_min + height]

    elif target_format == "yolo":
        # The centre is the top-left corner plus half the extent. The previous
        # x_max / 2 was only right for boxes anchored at the origin.
        x_center = x_min + width / 2
        y_center = y_min + height / 2

        formated_bounding_box = [x_center, y_center, width, height]

    else:
        formated_bounding_box = bounding_box

    return np.array(formated_bounding_box).round()
    

def from_yolo(bounding_box, target_format="pascal_voc"):
    """Convert a centre-format box ``[x_center, y_center, width, height]``.

    Args:
        bounding_box: Sequence ``[x_center, y_center, width, height]``.
        target_format: "pascal_voc" -> ``[x_min, y_min, x_max, y_max]``;
            "coco" -> ``[x_min, y_min, width, height]``; anything else returns
            the box unchanged.

    Returns:
        np.ndarray: The converted box. Like ``from_pascal_voc`` and ``from_coco``
        the result is an ndarray, which ``draw_bboxes`` relies on when it calls
        ``.round()``. Values are left unrounded, as they were before.
    """
    x_center, y_center, width, height = bounding_box

    half_width = width / 2
    half_height = height / 2

    x_min = x_center - half_width
    x_max = x_center + half_width
    y_min = y_center - half_height
    y_max = y_center + half_height

    if target_format == "pascal_voc":
        formated_bounding_box = [x_min, y_min, x_max, y_max]

    elif target_format == "coco":
        formated_bounding_box = [x_min, y_min, width, height]

    else:
        formated_bounding_box = bounding_box

    # Previously a plain list was returned, which has no .round()/.astype().
    return np.array(formated_bounding_box)


def draw_bboxes(image, bboxes, source_format="pascal_voc", color=(0, 255, 255), thickness=1):
    """Return a copy of ``image`` with every box in ``bboxes`` drawn on it.

    Args:
        image: Image array to annotate; it is not modified.
        bboxes: Iterable of 4-value boxes in ``source_format``.
        source_format: "pascal_voc", "coco" or "yolo"; unknown values are
            treated as "pascal_voc".
        color: Rectangle colour passed to ``cv2.rectangle``.
        thickness: Rectangle line thickness passed to ``cv2.rectangle``.

    Returns:
        The annotated copy of ``image``.
    """
    methods = {
        "pascal_voc": from_pascal_voc,
        "coco": from_coco,
        "yolo": from_yolo,
    }
    # The converter does not depend on the individual box, so look it up once.
    from_method = methods.get(source_format, from_pascal_voc)

    image_with_bboxes = image.copy()
    for bbox in bboxes:
        corners = np.asarray(from_method(bounding_box=bbox, target_format="pascal_voc"))
        # .tolist() yields plain Python ints for the OpenCV point tuples.
        x_min, y_min, x_max, y_max = corners.round().astype(int).tolist()
        # Draw on the copy: drawing on `image` would modify the caller's array
        # and make the copy above pointless.
        image_with_bboxes = cv2.rectangle(
            image_with_bboxes, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness
        )

    return image_with_bboxes


def save_checkpoint(model, optimizer, epoch=None, loss=None, path="checkpoint.pth"):
    """Serialise model and optimizer state to ``path`` and return the saved dict.

    Args:
        model: Object exposing ``state_dict()`` (stored under "model_state").
        optimizer: Object exposing ``state_dict()`` (stored under "optimizer_state").
        epoch: Optional epoch number stored alongside the weights.
        loss: Optional loss value stored alongside the weights.
        path: Destination file for ``torch.save``.

    Returns:
        dict: The checkpoint with keys "model_state", "optimizer_state",
        "epoch" and "loss".
    """
    checkpoint = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss
    }

    # torch.save takes the destination as its second positional argument (`f`);
    # it has no `path` keyword, so `path=path` raised a TypeError.
    torch.save(checkpoint, path)

    return checkpoint


def create_model(model_name="tf_efficientdet_d0", num_classes=1, pretrained=True, image_size=(512, 512), checkpoint_path=None, mode="train"):
    """Build an EfficientDet wrapped in an effdet train or predict bench.

    Args:
        model_name: effdet configuration name, e.g. "tf_efficientdet_d0".
        num_classes: Number of classes written into the config and used for
            the replacement classification head.
        pretrained: Forwarded to ``EfficientDet`` as ``pretrained_backbone``.
        image_size: Value written to ``config.image_size``.
        checkpoint_path: Optional file whose content is handed directly to
            ``load_state_dict`` of the un-wrapped network. The file must
            therefore hold a bare state dict; a dict nesting the weights under
            another key (such as the one ``save_checkpoint`` builds) will not
            load as-is.
        mode: "inference" wraps the network in ``DetBenchPredict``; any other
            value wraps it in ``DetBenchTrain``.

    Returns:
        The bench-wrapped model.
    """
    model_config = effdet.get_efficientdet_config(model_name)
    model_config.image_size = image_size
    model_config.num_classes = num_classes
    model_config.norm_kwargs = dict(eps=.001, momentum=.01)

    network = effdet.EfficientDet(model_config, pretrained_backbone=pretrained)
    # Swap in a classification head sized for `num_classes`.
    network.class_net = effdet.efficientdet.HeadNet(model_config, num_outputs=model_config.num_classes)

    if checkpoint_path is not None:
        # Without CUDA the stored tensors are remapped to the CPU; otherwise
        # torch.load keeps its default placement.
        map_location = None if torch.cuda.is_available() else "cpu"
        state = torch.load(checkpoint_path, map_location=map_location)

        network.load_state_dict(state)
        print(f"Loaded checkpoint from '{checkpoint_path}'")

    bench_class = effdet.DetBenchPredict if mode == "inference" else effdet.DetBenchTrain
    return bench_class(network, model_config)


def train_one_batch(batch, model, optimizer, scaler=None, clipping_norm=None, inputs_device="cpu", targets_device="cpu"):
    """Run one optimisation step on ``batch`` and return its loss.

    Args:
        batch: Batch as produced by the data loader; moved to the devices by
            ``Dataset.collate_batch``.
        model: Callable as ``model(inputs, targets)`` returning a mapping with
            a "loss" entry.
        optimizer: Optimizer stepping the model's parameters.
        scaler: Optional ``GradScaler``. When given, the forward pass runs
            under ``autocast`` and the step goes through the scaler.
        clipping_norm: Optional maximum gradient norm; no clipping when None.
        inputs_device: Device the inputs are moved to.
        targets_device: Device the targets are moved to.

    Returns:
        The batch loss tensor taken from ``outputs["loss"]``.
    """
    def forward_loss():
        # Device transfer and the forward pass are identical in both code paths.
        inputs, targets = Dataset.collate_batch(batch, inputs_device=inputs_device, targets_device=targets_device)
        return model(inputs, targets)["loss"]

    should_clip = clipping_norm is not None

    optimizer.zero_grad()

    if scaler is None:
        # Full-precision path.
        batch_loss = forward_loss()
        batch_loss.backward()

        if should_clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clipping_norm)

        optimizer.step()
        return batch_loss

    # Mixed-precision path.
    with autocast():
        batch_loss = forward_loss()

    scaler.scale(batch_loss).backward()

    if should_clip:
        # Gradients must be unscaled before their norm is measured and clipped.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clipping_norm)

    scaler.step(optimizer)
    scaler.update()

    return batch_loss


def validate(model, loader, inputs_device="cpu", targets_device="cpu"):
    """Return the mean of ``outputs["loss"]`` over every batch of ``loader``.

    Gradients are disabled while the batches are evaluated. The model's
    train/eval mode is not touched here; it stays whatever the caller set.

    Args:
        model: Callable as ``model(inputs, targets)`` returning a mapping with
            a "loss" entry.
        loader: Iterable of batches that also supports ``len()``.
        inputs_device: Device the inputs are moved to.
        targets_device: Device the targets are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    def loss_of(batch):
        inputs, targets = Dataset.collate_batch(batch, inputs_device=inputs_device, targets_device=targets_device)
        return model(inputs, targets)["loss"]

    with torch.no_grad():
        batch_losses = [loss_of(batch) for batch in loader]

    # sum() starts from the integer 0, exactly like the running total it replaces.
    return sum(batch_losses) / len(loader)

In [None]:
class Dataset:
    """Map-style detection dataset that reads images from disk with OpenCV.

    Each sample is an RGB image and, when box annotations are available, a
    target dict with "bboxes" and "labels" (all labels are 1: single class).

    NOTE(review): boxes are handed to the model in the order given by
    ``bboxes_format_return``. effdet's training bench is commonly fed
    ``[y_min, x_min, y_max, x_max]`` — confirm whether a swap is needed.
    """

    def __init__(self, pathes, bboxes=None, bboxes_format="pascal_voc", bboxes_format_return="pascal_voc", transforms=None):
        """Store the sample description.

        Args:
            pathes: Indexable collection of image file paths.
            bboxes: Optional indexable collection of annotation strings (see
                ``get_bboxes``); ``None`` makes the dataset image-only.
            bboxes_format: Format of the boxes coming out of ``transforms``.
            bboxes_format_return: Format of the boxes returned to the caller.
            transforms: Optional callable invoked as
                ``transforms(image=..., bboxes=..., class_labels=..., class_categories=...)``.
        """
        self.pathes = pathes
        self.bboxes = bboxes
        self.transforms = transforms
        self.bboxes_format = bboxes_format
        self.bboxes_format_return = bboxes_format_return

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.pathes)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as an RGB array.

        Raises:
            OSError: If OpenCV could not read the file (``cv2.imread`` returned
                ``None``), instead of failing later with an opaque cvtColor error.
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image (missing or unreadable file): '{path}'")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def get_bboxes(bboxes_string):
        """Parse an annotation string into an integer array of boxes.

        Args:
            bboxes_string: ``"no_box"`` or boxes separated by ``";"`` whose
                values are separated by whitespace.

        Returns:
            np.ndarray: int32 array with one row per box; shape ``(0, 4)`` for
            ``"no_box"`` so that ``.shape[0]`` and row iteration keep working.
        """
        if bboxes_string == "no_box":
            return np.zeros((0, 4), dtype=np.int32)

        boxes = [bbox.split() for bbox in bboxes_string.split(";")]
        return np.asarray(boxes, dtype=np.int32)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, target)`` when annotations are available, where ``target``
            is ``{"bboxes": tensor, "labels": tensor}``; otherwise only the raw
            RGB image array (``transforms`` is not applied in that case).
        """
        image = Dataset._read_rgb(self.pathes[index])

        if self.bboxes is None:
            return image

        bboxes = Dataset.get_bboxes(self.bboxes[index])

        if self.transforms is not None:
            labels = np.ones(shape=(bboxes.shape[0], ), dtype=np.int32)
            # Boxes are passed through unchanged. The pipeline is configured
            # with its own bbox format (pixel-unit "pascal_voc" in this
            # notebook) and converts internally. Pre-normalising the whole
            # (N, 4) array with the single-box helper A.normalize_bbox scaled
            # boxes twice and failed outright for images with fewer than 4 boxes.
            augmented = self.transforms(image=image, bboxes=bboxes, class_labels=labels, class_categories=labels)
            image, bboxes = augmented["image"], augmented["bboxes"]
            bboxes = transform_bounding_boxes(bboxes, source_format=self.bboxes_format, target_format=self.bboxes_format_return)
            if bboxes.size == 0:
                # Every box was dropped by the augmentation: keep a 2-D shape.
                bboxes = bboxes.reshape(0, 4)

        # Recomputed because augmentation may have removed boxes.
        labels = np.ones(shape=(bboxes.shape[0], ), dtype=np.int32)

        # as_tensor accepts both arrays and tensors (the image already is a
        # tensor when the pipeline ends with ToTensorV2).
        target = {
            "bboxes": torch.as_tensor(bboxes),
            "labels": torch.as_tensor(labels),
        }

        return torch.as_tensor(image), target

    def show_samples(self, rows=1, columns=1, color=(0, 255, 255), thickness=1, coef=3):
        """Plot a ``rows`` x ``columns`` grid of randomly chosen, un-augmented samples.

        Args:
            rows: Number of grid rows.
            columns: Number of grid columns.
            color: Box colour forwarded to ``draw_bboxes``.
            thickness: Box line thickness forwarded to ``draw_bboxes``.
            coef: Size of one grid cell; the figure is ``columns*coef`` by ``rows*coef``.
        """
        fig = plt.figure(figsize=(columns*coef, rows*coef))
        n_iterations = rows * columns
        n_samples = len(self)

        for i in range(n_iterations):
            # Sampling is with replacement, so an image can appear twice.
            index = random.randint(0, n_samples-1)
            image = Dataset._read_rgb(self.pathes[index])

            if self.bboxes is not None:
                bboxes = Dataset.get_bboxes(self.bboxes[index])
                image = draw_bboxes(image, bboxes, source_format="pascal_voc", color=color, thickness=thickness)

            ax = fig.add_subplot(rows, columns, i+1)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_title(f"Sample's index: {index}", loc="left")
            ax.imshow(image)

        fig.show()

    @staticmethod
    def collate_batch(batch, inputs_device="cpu", targets_device="cpu"):
        """Move a collated batch to the requested devices.

        Args:
            batch: Either ``(inputs, targets)`` as produced by ``collate_fn``
                (a tuple, or the list a pinning DataLoader turns it into) or a
                bare tensor of images.
            inputs_device: Device for the image tensor.
            targets_device: Device for the "bbox" and "cls" tensors.

        Returns:
            ``(inputs, targets)`` with float targets, or the moved image tensor
            for an image-only batch. ``targets`` is updated in place.
        """
        # Decide by container type: `len(batch) == 2` alone is also true for an
        # image-only tensor holding exactly two images.
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            inputs, targets = batch
            inputs = inputs.to(inputs_device)
            targets["bbox"] = [bbox.to(targets_device).float() for bbox in targets["bbox"]]
            targets["cls"] = [label.to(targets_device).float() for label in targets["cls"]]

            return inputs, targets

        return batch.to(inputs_device)

    @staticmethod
    def collate_fn(batch):
        """Collate samples from ``__getitem__`` into one batch.

        Args:
            batch: List of ``(image, target)`` pairs or of bare images (arrays
                or tensors); all images must share one shape.

        Returns:
            ``(images, {"bbox": [...], "cls": [...]})`` when targets are present,
            otherwise ``images``. ``images`` is a float32 tensor stacked along a
            new first dimension; boxes and labels stay per-sample lists because
            their lengths differ between images.
        """
        all_images, all_bboxes, all_labels = [], [], []
        for sample in batch:
            # Type check rather than len(): an image whose first dimension is 2
            # must not be mistaken for an (image, target) pair.
            if isinstance(sample, (tuple, list)):
                image, target = sample
                all_bboxes.append(target["bboxes"])
                all_labels.append(target["labels"])
            else:
                image = sample

            # Image-only samples are NumPy arrays, annotated ones are tensors.
            all_images.append(torch.as_tensor(image))

        if all_images:
            images = torch.stack(all_images).to(torch.float32)
        else:
            images = torch.empty(0, dtype=torch.float32)

        if (len(all_bboxes) != 0) and (len(all_labels) != 0):
            all_targets = {
                "bbox": all_bboxes,
                "cls": all_labels,
            }

            return images, all_targets

        return images

In [None]:
# Run configuration shared by the cells below.
config = dict(
    epochs=1,            # passes over the training loader
    device="cuda",       # device the model and every batch are moved to
    size=(512, 512),     # unpacked as (width, height) for resizing; also the model's image_size
    lr=1e-3,             # learning rate handed to the optimizer
    batch_size=4,        # DataLoader batch size
    pin_memory=True,     # DataLoader pin_memory flag
    num_workers=4,       # DataLoader worker processes
)

In [None]:
train_images_path = "../input/wheat-detection/user_task/images"
train_path = "../input/wheat-detection/user_task/train.csv"

# Load the annotation table and turn bare file names into full image paths.
train = pd.read_csv(train_path)
train["image_name"] = [os.path.join(train_images_path, filename) for filename in train["image_name"]]
train.columns = ["path", "bboxes"]

# Remove images that carry no annotation at all.
no_boxes = train[train["bboxes"] == "no_box"]
train = train.drop(index=no_boxes.index)

# Row labels removed from the training table by hand
# (presumably samples that failed during loading, going by the name — TODO confirm).
crashed_samples = [267, 564, 695, 847, 870, 1059, 1160, 1286, 1296, 1609, 1618, 1644, 1720, 1764, 1921, 1963, 2088, 2134, 2309, 2327, 2476, 2533, 2613, 2621, 2633, 2653, 2676, 2689, 2701, 2731, 2814, 2816, 2897, 3090, 3368, 3436, 3581, 3587, 3688, 3717, 3778, 3853, 3858, 3974, 4145, 4236, 4381, 4407, 4410, 4444, 4514, 4583]
train = train.drop(index=crashed_samples)

In [None]:
# Hold out 20% of the annotated images for validation.
test_size = 0.2
# Fixed seed: without it every kernel restart produces a different split, so
# validation losses are not comparable between runs and a model resumed from a
# checkpoint may be validated on images it was trained on.
split_seed = 42
train_data, validation_data, train_targets, validation_targets = train_test_split(
    train["path"].values,
    train["bboxes"].values,
    test_size=test_size,
    random_state=split_seed,
)

In [None]:
width, height = config["size"]

# Box handling shared by both pipelines: pixel-unit [x_min, y_min, x_max, y_max]
# boxes, with the two label fields Dataset.__getitem__ passes along.
bbox_parameters = A.BboxParams(
    format="pascal_voc",
    label_fields=["class_labels", "class_categories"],
    min_area=0,
    min_visibility=0,
)

# Options common to the training and the validation side.
dataset_options = dict(bboxes_format="pascal_voc", bboxes_format_return="pascal_voc")
loader_options = dict(
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    drop_last=False,
    collate_fn=Dataset.collate_fn,
)

# Training: resize, random flips and an occasional shift/scale/rotate.
train_transforms = A.Compose(
    [
        A.Resize(width=width, height=height),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.2),
        A.Normalize(),
        ToTensorV2(),
    ],
    bbox_params=bbox_parameters,
)
train_dataset = Dataset(pathes=train_data, bboxes=train_targets, transforms=train_transforms, **dataset_options)
train_loader = DataLoader(dataset=train_dataset, **loader_options)

# Validation: resize and normalise only, no random augmentation.
validation_transforms = A.Compose(
    [
        A.Resize(width=width, height=height),
        A.Normalize(),
        ToTensorV2(),
    ],
    bbox_params=bbox_parameters,
)
validation_dataset = Dataset(pathes=validation_data, bboxes=validation_targets, transforms=validation_transforms, **dataset_options)
validation_loader = DataLoader(dataset=validation_dataset, **loader_options)

In [None]:
# Visual sanity check: plot one randomly chosen training image with its annotated boxes drawn on top.
train_dataset.show_samples(rows=1, columns=1, thickness=2, coef=5)

In [None]:
# Weights the network is initialised from (the backbone is not downloaded: pretrained=False).
pretrained_checkpoint = "../input/wheat2020-checkpoints/effdet_ed5_512_fold4.pth"

model = create_model(
    model_name="tf_efficientdet_d5",
    image_size=config["size"],
    pretrained=False,
    checkpoint_path=pretrained_checkpoint,
    mode="train",
)

# Gradient scaler for the mixed-precision path of train_one_batch.
scaler = GradScaler()
optimizer = optim.AdamW(params=model.parameters(), lr=config["lr"])

In [None]:
experiment_name = "EfficientDetD5"
num_iterations_per_epoch = len(train_loader)
num_iterations = config["epochs"] * num_iterations_per_epoch
device_name = get_device_name(config["device"])
debug_iterations = 50      # print the training loss every N iterations
validation_steps = 500     # run a full validation pass every N iterations
best_validation_loss = np.inf
passed_iterations = 1      # 1-based counter of processed training batches

wandb.init(project="Wheat Detection", entity="zzmtsvv", name=experiment_name)
# `DEVICE` was never defined in this notebook (NameError on a fresh kernel);
# the device lives in the config, as everywhere else in this cell.
model.to(config["device"])
print(f"Training on '{device_name}' for {num_iterations} iterations / {config['epochs']} epochs")
for epoch in range(1, config["epochs"]+1):
    model.train()
    for batch in train_loader:
        batch_loss = train_one_batch(batch=batch,
                                     model=model,
                                     optimizer=optimizer,
                                     scaler=scaler,
                                     clipping_norm=1,
                                     inputs_device=config["device"],
                                     targets_device=config["device"])
        train_loss_value = batch_loss.item()

        if passed_iterations % validation_steps == 0:
            validation_loss = validate(model=model,
                                       loader=validation_loader,
                                       inputs_device=config["device"],
                                       targets_device=config["device"])
            validation_loss_value = validation_loss.item()

            # Keep a checkpoint only when validation improves on the best so far.
            if validation_loss_value < best_validation_loss:
                checkpoint_path = f"{experiment_name}_{validation_loss_value}.pth"
                checkpoint = save_checkpoint(model=model,
                                             optimizer=optimizer,
                                             epoch=epoch,
                                             loss=validation_loss_value,
                                             path=checkpoint_path)

                # Store a plain float so the best value does not keep a tensor alive.
                best_validation_loss = validation_loss_value
                print(f"Iteration [{passed_iterations}/{num_iterations}] Saving Checkpoint with Validation Loss: {validation_loss_value} | Path: {checkpoint_path}")

            print(f"Iteration [{passed_iterations}/{num_iterations}] Validation Loss: {validation_loss_value}")
            wandb.log({"validation_loss": validation_loss_value})

        if passed_iterations % debug_iterations == 0:
            print(f"Iteration [{passed_iterations}/{num_iterations}] Train Loss: {train_loss_value}")

        wandb.log({"train_loss": train_loss_value})
        passed_iterations += 1