In [22]:
import os
import math
import torch
import pickle
import torchvision


import numpy as np
import bbox_visualizer as bbv
import matplotlib.pyplot as plt


from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torchvision.datasets.voc import *
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

In [116]:
from torchvision.transforms import Compose
VOC_LABEL = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# Class name -> label id, numbered from 1 in VOC_LABEL order
# (id 0 is not assigned to any name — presumably reserved for background).
VOC_LABEL_PAIR = {name: label_id for label_id, name in enumerate(VOC_LABEL, start=1)}


# VOC ids skipped when loading because their annotations are considered wrong
# (presumably XML objects and instance masks disagree — TODO confirm).
WRONG_FILE_NAMES_IN_TRAIN = ['2009_005069']
WRONG_FILES_NAMES_IN_VAL = [
    '2008_005245', '2009_000455', '2009_004969', '2011_002644', '2011_002863',
]
WRONG_FILES_NAMES = [*WRONG_FILE_NAMES_IN_TRAIN, *WRONG_FILES_NAMES_IN_VAL]
def remove_wrong_annotated_file(file_names, wrong_file_names=None):
    """Drop known badly annotated VOC ids from ``file_names``.

    The list is filtered in place, so other references to it see the change. It is also
    returned, so ``names = remove_wrong_annotated_file(names)`` keeps working.

    Args:
        file_names: mutable list of VOC image ids (e.g. ``"2009_005069"``).
        wrong_file_names: ids to drop; defaults to ``WRONG_FILES_NAMES``.

    Returns:
        ``file_names`` itself, with every occurrence of a wrong id removed.
    """
    if wrong_file_names is None:
        wrong_file_names = WRONG_FILES_NAMES
    excluded = set(wrong_file_names)
    # Slice assignment keeps the same list object (as list.remove did), but it removes every
    # occurrence rather than only the first, in a single O(n) pass.
    file_names[:] = [name for name in file_names if name not in excluded]
    return file_names


class VOC2012_MaskRCNN_InstanceSegmentation_Dataset(torch.utils.data.Dataset):
    """PASCAL VOC 2012 instance-segmentation dataset producing Mask R-CNN style targets.

    Each item is ``(image, target)``. ``target`` is a dict of tensors:
      - ``"labels"``: int64 label ids from ``VOC_LABEL_PAIR``, one per XML ``<object>``.
      - ``"boxes"``: float32 ``[xmin, ymin, xmax, ymax]`` rows read from the XML annotation.
      - ``"masks"``: uint8 binary masks, one per instance id in the ``SegmentationObject`` PNG.

    NOTE(review): labels/boxes follow the XML object order, while masks follow the ascending
    instance ids in the PNG. The pairing therefore assumes instance id ``k`` is the ``k``-th XML
    object. The ids in ``WRONG_FILES_NAMES`` are skipped, presumably because this assumption
    fails for them — confirm.
    """

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        image_transform: Optional[Callable] = Compose([torchvision.transforms.ToTensor()]),
        label_transform: Optional[Callable] = None,
        box_transform: Optional[Callable] = None,
        mask_transform: Optional[Callable] = None
    ):
        """Index the images, XML annotations and instance masks of one VOC split.

        Args:
            root: VOC2012 root directory. It must contain ``ImageSets/Segmentation``,
                ``JPEGImages``, ``Annotations`` and ``SegmentationObject``.
            image_set: split to load; one of ``"train"``, ``"trainval"``, ``"val"``.
            image_transform: applied to the RGB PIL image; the default converts it to a tensor.
            label_transform: applied to the int64 numpy label array.
            box_transform: applied to the float32 numpy ``(N, 4)`` box array.
            mask_transform: applied to the uint8 numpy mask array (one mask per instance).
        """
        root = Path(root)
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.box_transform = box_transform
        self.mask_transform = mask_transform

        self.image_set = verify_str_arg(image_set, "image_set", ["train", "trainval", "val"])

        split_f = root / "ImageSets" / "Segmentation" / f"{self.image_set}.txt"
        with open(split_f) as f:
            # Skip blank lines so they do not turn into bogus ".jpg"/".xml" paths.
            file_names = [line.strip() for line in f if line.strip()]
        file_names = remove_wrong_annotated_file(file_names)

        # Directory layout of the VOC2012 release.
        image_dir = root / "JPEGImages"
        annotation_dir = root / "Annotations"
        mask_dir = root / "SegmentationObject"

        # Parallel per-sample path lists, indexed by dataset index.
        self.images = [image_dir / f"{x}.jpg" for x in file_names]
        self.annotations = [annotation_dir / f"{x}.xml" for x in file_names]
        self.masks = [mask_dir / f"{x}.png" for x in file_names]

        assert len(self.images) == len(self.annotations) == len(self.masks)

    def __len__(self) -> int:
        """Number of samples in the split (after excluding wrong files)."""
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``(image, target)`` for sample ``index``; see the class docstring for ``target``."""
        image = self.get_image(index)
        labels, boxes = self.get_annotation(index)
        masks = self.get_masks(index)

        # as_tensor: accepts numpy arrays and also values a transform already turned into tensors.
        target = {
            "labels": torch.as_tensor(labels),
            "boxes": torch.as_tensor(boxes),
            "masks": torch.as_tensor(masks),
        }

        return (image, target)

    def get_image(self, index):
        """Load sample ``index`` as an RGB image and apply ``image_transform`` if set."""
        with Image.open(self.images[index]) as img:
            # convert() returns a new, fully loaded image, so closing the file is safe.
            image = img.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        return image

    def get_masks(self, index):
        """Binary instance masks decoded from the ``SegmentationObject`` PNG.

        Every pixel value other than 0 and 255 is treated as an instance id. One uint8 mask is
        produced per id, in ascending id order. Values 0 and 255 are excluded; presumably they
        are background and the VOC "void" border label.
        """
        with Image.open(self.masks[index]) as img:
            mask = np.array(img)

        id_list = np.unique(mask)
        id_list = id_list[(id_list != 0) & (id_list != 255)]

        # Broadcast-compare against every id: result has one mask per id along axis 0.
        masks = (mask[None, :] == id_list.reshape(-1, 1, 1)).astype(np.uint8)

        if self.mask_transform is not None:
            masks = self.mask_transform(masks)

        return masks

    def get_annotation(self, index):
        """Parse the XML annotation of sample ``index`` into ``(labels, boxes)`` numpy arrays.

        ``labels`` is int64 (ids from ``VOC_LABEL_PAIR``). ``boxes`` is float32 with rows
        ``[xmin, ymin, xmax, ymax]``. The optional label/box transforms are applied before
        returning.
        """
        annotation = VOCDetection.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
        objects = annotation["annotation"]["object"]

        labels = np.array([VOC_LABEL_PAIR[o["name"]] for o in objects], dtype=np.int64)
        boxes = [o["bndbox"] for o in objects]
        # Coordinates are parsed from XML text; the float32 dtype converts them.
        boxes = np.array([[box["xmin"], box["ymin"], box["xmax"], box["ymax"]] for box in boxes], dtype=np.float32)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.box_transform is not None:
            boxes = self.box_transform(boxes)

        return (labels, boxes)





In [21]:
class VOC2012_MaskRCNN_InstanceSegmentation:
    """Static helpers for training and evaluating Mask R-CNN on VOC2012 instance segmentation.

    Covers:
      - moving batches to a device;
      - converting target/prediction dicts between tensors and numpy;
      - visualising predictions;
      - a score-binned mask mAP evaluation.
    """

    # Label ids 0..20. The 20 VOC classes use 1..20; id 0 is presumably background.
    CLASS_NUM = 21
    # Ascending upper bounds of the score bins used by count_result / map_eval: 0.1, 0.2, ..., 1.0.
    SCORE_THRESHOLDS = [0.1*i for i in range(1, 11)]

    @staticmethod
    def loss_eval(model, device, data_loader):
        """Mean combined loss of ``model`` over ``data_loader``, without updating the model.

        The model is put in train mode because torchvision detection models return their loss
        dict (rather than detections) only when training. Gradients are disabled so no autograd
        graph is built.

        Returns:
            The mean per-batch loss as a numpy float (NaN if the loader is empty).
        """
        model.train()
        model.to(device)
        loss_list = []
        with torch.no_grad():
            for images, targets in tqdm(data_loader):
                images, targets = __class__.batch_to(images, targets, device)
                loss_dict = model(images, targets)
                loss_vector = torch.stack(list(loss_dict.values()))
                # NOTE(review): `**` binds tighter than `@`, so this is sum(l_i ** 1.5), not the
                # L2 norm sqrt(sum(l_i ** 2)). It matches train_one_epoch — confirm it is intended.
                loss = loss_vector @ loss_vector**(1/2)
                # .item() copies the scalar to the host; np.array over CUDA / graph-attached
                # tensors would fail.
                loss_list.append(loss.item())
        return np.array(loss_list).mean()

    @staticmethod
    def map_eval(model, device, data_loader, class_num, mask_threshold, score_threshold, iou_threshold=0.3):
        """Mask mAP of ``model`` over ``data_loader``.

        Args:
            model: detection model that, in eval mode, returns per-image dicts with ``scores``,
                ``labels``, ``boxes`` and ``masks``.
            device: device to run inference on.
            data_loader: yields ``(images, targets)`` batches of lists (see ``collate_batch``).
            class_num: number of label ids, including background.
            mask_threshold: probability cut-off that binarises the predicted soft masks.
            score_threshold: ascending list of score-bin upper bounds (e.g. ``SCORE_THRESHOLDS``).
            iou_threshold: minimum mask IoU for a detection to count as a true positive.

        Returns:
            Mean AP over the classes that have at least one ground-truth instance
            (0.0 if the loader is empty).
        """
        model.eval()
        model.to(device)

        counts = []
        with torch.no_grad():  # inference only; do not build autograd graphs
            for images, targets in tqdm(data_loader):
                # Only the images go to the model; targets are compared on the CPU afterwards.
                images = [image.to(device) for image in images]
                results = model(images)

                # Count GT, TP and detections per class (and score bin) for every image.
                for result, target in zip(results, targets):
                    result = __class__.tensor_to_numpy_for_result(result)
                    target = __class__.tensor_to_numpy_for_target(target)

                    p_scores, p_labels, _, p_masks = __class__.dict_to_tuple_for_result(result)
                    gt_labels, _, gt_masks = __class__.dict_to_tuple_for_target(target)

                    # Binarise the soft masks and drop their singleton axis 1.
                    p_masks = __class__.to_binary_by_threshold(p_masks, threshold=mask_threshold).squeeze(1)

                    count_result = __class__.count_result(p_masks, p_labels, p_scores, gt_masks, gt_labels,
                                                          class_num=class_num,
                                                          score_thresholds=score_threshold,
                                                          iou_threshold=iou_threshold)
                    counts.append(count_result)

        if not counts:
            return 0.0

        # Aggregate the per-image counts over the whole dataset.
        gt_num = np.stack([count["gt_num"] for count in counts]).sum(0)
        tp_num = np.stack([count["tp_num"] for count in counts]).sum(0)
        detected_num = np.stack([count["detected_num_per_level"] for count in counts]).sum(0)

        return __class__.ap_per_class(detected_num, gt_num, tp_num)

    @staticmethod
    def tensor_to_numpy(data):
        """Detach ``data`` and return it as a numpy array on the CPU."""
        return data.to("cpu").detach().numpy()

    # Batch #######################################################
    @staticmethod
    def batch_to(images, targets, device):
        """Move a batch (list of image tensors, list of target dicts of tensors) to ``device``."""
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        return images, targets

    # Image #######################################################
    @staticmethod
    def tensor_to_numpy_for_image(tensor_image):
        """Convert a channel-first image tensor to a channel-last numpy array for plotting."""
        return __class__.tensor_to_numpy(tensor_image).transpose((1, 2, 0))

    # Target ######################################################
    @staticmethod
    def dict_to_tuple_for_target(sample_target):
        """Unpack a target dict into ``(labels, boxes, masks)``."""
        return sample_target["labels"], sample_target["boxes"], sample_target["masks"]

    @staticmethod
    def tuple_to_dict_for_target(labels, boxes, masks):
        """Pack ``labels, boxes, masks`` into a target dict."""
        return {"labels": labels, "boxes": boxes, "masks": masks}

    @staticmethod
    def tensor_to_numpy_for_target(sample_target):
        """Return a copy of a target dict with every tensor converted to a CPU numpy array."""
        labels, boxes, masks = __class__.dict_to_tuple_for_target(sample_target)
        labels = __class__.tensor_to_numpy(labels)
        boxes = __class__.tensor_to_numpy(boxes)
        masks = __class__.tensor_to_numpy(masks)
        return __class__.tuple_to_dict_for_target(labels, boxes, masks)

    # Result ######################################################
    @staticmethod
    def dict_to_tuple_for_result(sample_result):
        """Unpack a prediction dict into ``(scores, labels, boxes, masks)``."""
        return sample_result["scores"], sample_result["labels"], sample_result["boxes"], sample_result["masks"]

    @staticmethod
    def tuple_to_dict_for_result(scores, labels, boxes, masks):
        """Pack ``scores, labels, boxes, masks`` into a prediction dict."""
        return {"scores": scores, "labels": labels, "boxes": boxes, "masks": masks}

    @staticmethod
    def tensor_to_numpy_for_result(sample_result):
        """Return a copy of a prediction dict with every tensor converted to a CPU numpy array."""
        scores, labels, boxes, masks = __class__.dict_to_tuple_for_result(sample_result)
        scores = __class__.tensor_to_numpy(scores)
        labels = __class__.tensor_to_numpy(labels)
        boxes = __class__.tensor_to_numpy(boxes)
        masks = __class__.tensor_to_numpy(masks)
        return __class__.tuple_to_dict_for_result(scores, labels, boxes, masks)

    @staticmethod
    def show_sample_result(sample_result):
        """Plot each predicted mask titled ``"<class name>(<score rounded to 2 places>)"``."""
        scores, labels, boxes, masks = __class__.dict_to_tuple_for_result(sample_result)

        # Label ids start at 1, VOC_LABEL at index 0.
        titles = [f"{VOC_LABEL[l-1]}({round(s.item(), 2)})" for l, s in zip(labels, scores)]
        __class__.show_images(masks, titles)

    @staticmethod
    def filter_sample_by_score(sample_result, threshold):
        """Keep only the predictions whose score is strictly greater than ``threshold``."""
        scores, labels, boxes, masks = __class__.dict_to_tuple_for_result(sample_result)
        indexes = np.where(scores > threshold)[0]
        return __class__.tuple_to_dict_for_result(scores[indexes], labels[indexes], boxes[indexes], masks[indexes])

    @staticmethod
    def to_binary_by_threshold(data, threshold):
        """Binarise a numpy array.

        data: numpy array (e.g. soft masks) to convert to binary.
        threshold: cut-off; elements greater than ``threshold`` become True, all others False.
        """
        return (data > threshold).astype(bool)

    @staticmethod
    def show_images(images, titles):
        """Plot ``images`` in a single column, one titled subplot each, and return the figure.

        An image with a leading singleton axis (e.g. a ``(1, H, W)`` soft mask) is squeezed to
        ``(H, W)`` so that ``imshow`` accepts it.
        """
        n = max(len(images), 1)  # at least one row so the figure height is non-zero
        fig = plt.figure(figsize=(4, 4 * n))
        for i, (image, title) in enumerate(zip(images, titles)):
            image = np.asarray(image)
            if image.ndim == 3 and image.shape[0] == 1:
                image = image[0]
            sub = fig.add_subplot(n, 1, i + 1)
            sub.imshow(image)
            sub.set_title(title)

        return fig

    # Evaluation ##################################################
    @staticmethod
    def score_level(score, thresholds):
        """Index of the first threshold ``>= score``; ``thresholds`` must be ascending.

        A score above every threshold is clamped into the last bin. Returning None instead would
        act as ``np.newaxis`` when used as an index and update a whole row.
        """
        for level, threshold in enumerate(thresholds):
            if score <= threshold:
                return level
        return len(thresholds) - 1

    @staticmethod
    def iou(mask1, mask2):
        """Intersection-over-union of two binary masks (any array-like convertible to bool).

        Returns 0.0 when both masks are empty.
        """
        m1 = np.asarray(mask1, dtype=bool)
        m2 = np.asarray(mask2, dtype=bool)
        union = np.logical_or(m1, m2).sum()
        if union == 0:
            return 0.0
        return np.logical_and(m1, m2).sum() / union

    @staticmethod
    def count_result(p_masks, p_labels, p_scores, gt_masks, gt_labels, class_num, score_thresholds, iou_threshold):
        """Count ground truth, detections and true positives for one image.

        Each detection is put in score bin ``score_level(score, score_thresholds)``; a higher bin
        means higher confidence. Within each class, detections are matched greedily, highest score
        first, to the still-unmatched ground truth with the largest mask IoU. A match with
        IoU >= ``iou_threshold`` counts as a true positive in the detection's score bin.

        Returns:
            dict of numpy arrays:
              - ``"detected_num"``: (class_num,) detections per class
              - ``"gt_num"``: (class_num,) ground-truth instances per class
              - ``"tp_num"``: (class_num, len(score_thresholds)) true positives per class and bin
              - ``"detected_num_per_level"``: (class_num, len(score_thresholds)) detections per
                class and bin
        """
        level_num = len(score_thresholds)
        score_levels = [__class__.score_level(score, score_thresholds) for score in p_scores]

        # Split prediction / GT indexes by class label.
        detected_indexes_per_c = [[] for _ in range(class_num)]
        for idx, label in enumerate(p_labels):
            detected_indexes_per_c[label].append(idx)

        gt_indexes_per_c = [[] for _ in range(class_num)]
        for idx, label in enumerate(gt_labels):
            gt_indexes_per_c[label].append(idx)

        # Per-class totals.
        detected_num = np.array([len(x) for x in detected_indexes_per_c])
        gt_num = np.array([len(x) for x in gt_indexes_per_c])

        # Detections per class and score bin (needed for per-threshold precision).
        detected_num_per_level = np.zeros((class_num, level_num), dtype=np.int64)
        for idx, label in enumerate(p_labels):
            detected_num_per_level[label, score_levels[idx]] += 1

        tp_num = np.zeros((class_num, level_num))
        for c in range(class_num):
            gt_idxes = list(gt_indexes_per_c[c])
            # Highest-scoring detections claim ground truth first (stable for ties).
            p_idxes = sorted(detected_indexes_per_c[c], key=lambda i: -p_scores[i])

            for p_idx in p_idxes:
                if not gt_idxes:  # every GT of this class is already matched
                    break

                iou_list = [__class__.iou(p_masks[p_idx], gt_masks[gt_idx]) for gt_idx in gt_idxes]
                max_idx = int(np.argmax(iou_list))

                if iou_threshold <= iou_list[max_idx]:
                    tp_num[c, score_levels[p_idx]] += 1
                    gt_idxes.pop(max_idx)  # each GT can be matched only once

        return {
            "detected_num": detected_num,
            "gt_num": gt_num,
            "tp_num": tp_num,
            "detected_num_per_level": detected_num_per_level,
        }

    @staticmethod
    def ap(detected_num, gt_num, tp_num):
        """Average precision of one class from score-binned counts.

        Args:
            detected_num: detections per score bin (1-D, same length as ``tp_num``), or a single
                total count. With only a total, every bin's precision uses that total.
            gt_num: number of ground-truth instances of the class.
            tp_num: true positives per score bin, lowest-confidence bin first.

        Returns:
            Area under the interpolated precision/recall curve sampled at the bin edges, or 0 when
            the class has no ground truth. The inputs are not modified.
        """
        if gt_num == 0:
            return 0

        # Cumulate from the most confident bin downwards: entry i counts everything in bins >= i.
        tp_cum = np.cumsum(np.asarray(tp_num, dtype=np.float64)[::-1])[::-1]
        detected = np.asarray(detected_num, dtype=np.float64)
        if detected.ndim == 0:
            det_cum = np.full_like(tp_cum, detected)
        else:
            det_cum = np.cumsum(detected[::-1])[::-1]

        recall = tp_cum / gt_num
        with np.errstate(divide="ignore", invalid="ignore"):
            # No detections at a threshold -> precision 0 instead of NaN from 0/0.
            precision = np.where(det_cum > 0, tp_cum / det_cum, 0.0)

        # Make precision a step function: each bin takes the best precision at any recall >= its
        # own. Lower bins have higher recall, hence a running max from index 0.
        precision = np.maximum.accumulate(precision)

        # Recall grows (non-strictly) as lower bins are added; weight precision by each recall gain.
        next_recall = np.append(recall[1:], 0.0)
        return float(((recall - next_recall) * precision).sum())

    @staticmethod
    def ap_per_class(detected_num, gt_num, tp_num):
        """Mean AP over the classes that have ground truth.

        Classes without any ground-truth instance (e.g. the unused background id 0) are left out
        of the mean rather than counted as AP 0. Returns 0.0 if no class has ground truth.
        """
        ap_list = [
            __class__.ap(detected_num[c], gt_num[c], tp_num[c])
            for c in range(len(gt_num))
            if gt_num[c] > 0
        ]
        if not ap_list:
            return 0.0
        return float(np.mean(ap_list))

In [107]:
def collate_batch(batch):
    """Collate ``(image, target)`` samples into ``(list_of_images, list_of_targets)``.

    Nothing is stacked: images and targets stay as Python lists in sample order.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
    return images, targets

class VOC2012_MaskRCNN_InstanceSegmentation_DataLoader(torch.utils.data.DataLoader):
    """DataLoader that yields ``(list_of_images, list_of_targets)`` batches via ``collate_batch``."""

    def __init__(self, dataset, batch_size, **kwargs):
        """
        Args:
            dataset: map-style dataset yielding ``(image, target)`` pairs.
            batch_size: samples per batch.
            **kwargs: forwarded to ``torch.utils.data.DataLoader`` (e.g. ``shuffle``,
                ``num_workers``). ``collate_fn`` defaults to ``collate_batch``.
        """
        if "collate_fn" not in kwargs:
            # Keep images/targets as lists; they may differ in size and object count.
            kwargs["collate_fn"] = collate_batch
        super().__init__(dataset, batch_size=batch_size, **kwargs)


In [89]:

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: detection model that returns a dict of loss tensors from ``model(images, targets)``
            in training mode.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: yields ``(images, targets)``; images are tensors, targets are dicts of tensors.
        device: device each batch is moved to.
        epoch: epoch index. On epoch 0 the learning rate is warmed up linearly from 1/1000 of its
            value over at most 1000 iterations.
        print_freq: unused; kept for signature compatibility.
        scaler: optional ``GradScaler``. When given, the forward pass runs under CUDA autocast and
            backward/step go through the scaler.

    Returns:
        Mean combined loss over the epoch as a detached 0-dim CPU tensor (NaN for an empty loader).
    """
    model.train()
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        # At least one iteration: total_iters=0 would leave the LR stuck at warmup_factor * lr.
        warmup_iters = max(1, min(1000, len(data_loader) - 1))

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    losses_list = []
    for images, targets in tqdm(data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        with torch.autocast("cuda", enabled=scaler is not None):
            loss_dict = model(images, targets)
            loss_vector = torch.stack(list(loss_dict.values()))
            # NOTE(review): `**` binds tighter than `@`, so this is sum(l_i ** 1.5), not the
            # L2 norm sqrt(sum(l_i ** 2)); confirm this combination is intended.
            losses = loss_vector @ loss_vector**(1/2)
        # Store a detached value so the per-batch autograd graph is not kept alive all epoch.
        losses_list.append(losses.detach())

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

    if not losses_list:
        return torch.tensor(float("nan"))
    return torch.mean(torch.stack(losses_list)).to("cpu")