In [1]:
import os
import random
import math

import numpy as np
import pandas as pd
from PIL import Image, ImageDraw


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import torchvision
from torchvision import datasets, models, transforms
from torchvision.models.detection.retinanet import RetinaNet
import  torchvision.transforms.functional as F

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import xml.etree.ElementTree as ET
import collections
from torchvision.datasets.voc import VisionDataset

from functions import *
from functions_torch import *



In [2]:
params = {}
params['target_size']=(2000,1500)
params['batch_size'] = 1
params['lr'] = 0.001

voc_root = '/app/host/lacmus/dataset/full_lacmus_ds'

In [3]:
# Reworked class from pytorch (see https://pytorch.org/vision/0.8/_modules/torchvision/datasets/voc.html#VOCDetection)

class LADDDataSET(torchvision.datasets.VisionDataset):
    def __init__(
            self,
            root: str,
            image_set: str,
            transforms: Optional[Callable] = None):     
        super(LADDDataSET, self).__init__(root, transforms=transforms)
        self.image_set = image_set

        voc_root = root
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.')

        splits_dir = os.path.join(voc_root, 'ImageSets/Main')
        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert('RGB')
        description = LADDDataSET.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())

        # get bounding box coordinates 
        num_objs = len(description['annotation']['object'])
        boxes = []
        for l in description['annotation']['object']:
            bb = l['bndbox']
            boxes.append([int(bb['xmin']), int(bb['ymin']), int(bb['xmax']), int(bb['ymax'])])

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)         # there is only one class            
        target["labels"] = labels = torch.ones((num_objs,), dtype=torch.int64)
        
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target
    

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def parse_voc_xml(node: ET.Element) -> Dict[str, Any]:
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(LADDDataSET.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == 'annotation':
                def_dic['object'] = [def_dic['object']]
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

In [4]:
# Pytorch implemenation of retinanet doesn't supports train on Images without any objects (which, probably need to be fixed)
# see https://github.com/pytorch/vision/blob/master/torchvision/models/detection/retinanet.py#L475
# As a temporary solution, yet, we just filtering out empty images

splits_dir = os.path.join(voc_root, 'ImageSets/Main') 
annotation_dir = os.path.join(voc_root, 'Annotations')

with open(os.path.join(splits_dir,'train.txt'), "r") as f:
    file_names = [x.strip() for x in f.readlines()]

non_empty = []
for a in file_names:
    description = LADDDataSET.parse_voc_xml(
        ET.parse(os.path.join(annotation_dir, a + ".xml")).getroot()
    )
    num_objs = len(description['annotation']['object'])
    if num_objs > 0:
        non_empty.append(a+'\n')
        
with open(os.path.join(splits_dir,'train_non_empty.txt'), "w") as f:
    f.writelines(non_empty)

print('Total images '+str(len(file_names)), ' non empty: '+str(len(non_empty)))
                                                
                                    
        

Total images 1220  non empty: 1180


In [5]:
# # test DS
# im_idx = 99

# dataset = LADDDataSET('/app/host/lacmus/dataset/full_lacmus_ds','test',get_transform(train=True,target_size=params['target_size'])) 
# (image,target) = dataset[im_idx] 
# im = F.to_pil_image(image)
# draw = ImageDraw.Draw(im)

# for bb in target['boxes']:
#     draw.line([(bb[0], bb[1]), (bb[0], bb[3]), (bb[2], bb[3]),
#                (bb[2], bb[1]), (bb[0], bb[1])], width=4, fill=(255, 0, 0))

# im.show()

In [6]:
dataset_train = LADDDataSET(voc_root,'train_non_empty',get_transform(train=True,target_size=params['target_size'])) 
dataset_val = LADDDataSET(voc_root,'val',get_transform(train=False,target_size=params['target_size'])) 

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=params['batch_size'], shuffle=True, num_workers=4
     ,collate_fn=collate_fn
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=1, shuffle=False, num_workers=4
     ,collate_fn=collate_fn
)


In [7]:
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False, num_classes=2, pretrained_backbone=False, 
                                                           min_size=params['target_size'][0], max_size = params['target_size'][1])
model.load_state_dict(torch.load('/app/host/lacmus/weights/resnet50_SDD_epoch_8.pth'), strict=False)

# the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9, weight_decay=0.0005) #lr 0.001 -> 0.005
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

In [None]:
for epoch in range(10):

    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=100)
    print ("Train done, evaluating.")
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    inference_res = evaluate(model,data_loader_val)
    print('Inference done, computing mAp : ')
    print(evaluate_res(inference_res, iou_threshold = 0.5, score_threshold = 0.05))    
    print(evaluate_res(inference_res, iou_threshold = 0.6, score_threshold = 0.05))
    print('Epoch Done')

Epoch: [0]  [   0/1180]  eta: 0:25:37  lr: 0.000002  loss: 0.9846 (0.9846)  classification: 0.4986 (0.4986)  bbox_regression: 0.4859 (0.4859)  time: 1.3034  data: 0.8602  max mem: 2243
Epoch: [0]  [ 100/1180]  eta: 0:08:11  lr: 0.000102  loss: 1.1960 (1.5684)  classification: 0.5574 (0.7180)  bbox_regression: 0.6622 (0.8504)  time: 0.4471  data: 0.0094  max mem: 2532
Epoch: [0]  [ 200/1180]  eta: 0:07:23  lr: 0.000202  loss: 1.1922 (1.3587)  classification: 0.4271 (0.6163)  bbox_regression: 0.6627 (0.7424)  time: 0.4540  data: 0.0114  max mem: 2532
Epoch: [0]  [ 300/1180]  eta: 0:06:41  lr: 0.000302  loss: 0.8619 (1.2302)  classification: 0.3754 (0.5530)  bbox_regression: 0.4774 (0.6772)  time: 0.4637  data: 0.0118  max mem: 2532
Epoch: [0]  [ 400/1180]  eta: 0:05:57  lr: 0.000402  loss: 0.9241 (1.1860)  classification: 0.3910 (0.5190)  bbox_regression: 0.5610 (0.6670)  time: 0.4643  data: 0.0120  max mem: 2532
Epoch: [0]  [ 500/1180]  eta: 0:05:12  lr: 0.000501  loss: 0.8291 (1.1342) 

In [None]:
# After first epoch 'scores' in predictions was empty, which was reason for 0 mAp. If it will stay the same over several epoch - need to debug