In [1]:
import import_ipynb

import os
import math
import torch
import pickle
import torchvision
import mh_utils as MH

import numpy as np
import bbox_visualizer as bbv
import matplotlib.pyplot as plt
from typing import Dict, List, Any

from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torchvision.datasets.voc import *
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

importing Jupyter notebook from mh_utils.ipynb


In [2]:
class Device:
    """Helpers that detach tensors and move them onto a given device.

    "Data" helpers operate on model inputs (``(image, target)`` pairs);
    "result" helpers operate on the dictionaries returned by the model.
    """

    # Input data
    @staticmethod
    def to_for_sample_data(sample_data, device):
        """Return a detached copy of one ``(image, target)`` sample on ``device``."""
        image, target = sample_data
        moved_image = image.detach().to(device)
        moved_target = {}
        for key, value in target.items():
            moved_target[key] = value.detach().to(device)
        return (moved_image, moved_target)

    @staticmethod
    def to_for_batch_data(batch_data, device):
        """Return a detached copy of a ``(images, targets)`` batch on ``device``."""
        images, targets = batch_data

        moved_images = []
        for image in images:
            moved_images.append(image.detach().to(device))

        moved_targets = []
        for target in targets:
            moved_targets.append(
                {key: value.detach().to(device) for key, value in target.items()}
            )
        return (moved_images, moved_targets)

    # Result
    @staticmethod
    def to_for_sample_result(sample_result, device):
        """Return a detached copy of one result dictionary on ``device``."""
        moved = {}
        for key, value in sample_result.items():
            moved[key] = value.detach().to(device)
        return moved

    @staticmethod
    def to_for_batch_result(batch_result, device):
        """Apply :meth:`to_for_sample_result` to every result in the batch."""
        moved = []
        for sample_result in batch_result:
            moved.append(__class__.to_for_sample_result(sample_result, device))
        return moved

In [3]:
from torchvision.transforms import Compose
# Class names in index order; index 0 is the background class.
VOC_LABEL = [
    "background",
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# Reverse lookup: class name -> integer label.
VOC_LABEL_PAIR = {name: index for index, name in enumerate(VOC_LABEL)}


# Sample ids that are excluded from the splits.
# NOTE(review): judging by the names these are samples whose annotation is
# considered wrong — the exact reason is not recorded here.
WRONG_FILE_NAMES_IN_TRAIN = ['2009_005069']
WRONG_FILES_NAMES_IN_VAL = ['2008_005245', '2009_000455', '2009_004969', '2011_002644', '2011_002863']
WRONG_FILES_NAMES = WRONG_FILE_NAMES_IN_TRAIN+WRONG_FILES_NAMES_IN_VAL
def remove_wrong_annotated_file(file_names):
    """Remove every excluded sample id from ``file_names``.

    The list is filtered in place (callers holding a reference see the change)
    and the same list object is returned. Unlike repeated ``list.remove``
    calls, every occurrence of an excluded id is dropped and the list is
    scanned only once.

    [Args]
        * file_names: (__List[str]__): sample ids read from a split file
    [Return]
        * file_names: (__List[str]__): the same list without excluded ids
    """
    excluded = set(WRONG_FILES_NAMES)
    file_names[:] = [name for name in file_names if name not in excluded]
    return file_names


class VOC2012_MaskRCNN_InstanceSegmentation_Dataset(torch.utils.data.Dataset):
    """VOC2012 instance-segmentation dataset producing Mask R-CNN style samples.

    Every item is ``(image, target)`` where ``target`` is a dictionary with the
    keys ``"labels"``, ``"boxes"`` (``[xmin, ymin, xmax, ymax]``) and ``"masks"``
    (one binary mask per instance id found in the ``SegmentationObject`` PNG).

    [Args]
        * root: directory containing ``ImageSets``, ``JPEGImages``,
          ``Annotations`` and ``SegmentationObject``
        * image_set: one of ``"train"``, ``"trainval"``, ``"val"``
        * image_transform / label_transform / box_transform / mask_transform:
          optional callables applied to the image, labels, boxes and masks
        * cropping: random-crop image and masks (at most 200x200) and rebuild
          the boxes from the cropped masks
        * mask_expending: expand every mask through ``MH.expend_mask``
        * flip: randomly flip horizontally and/or vertically (p=0.5 each)
        * jitter: apply colour jitter to the image
    """

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        image_transform: Optional[Callable] = Compose([torchvision.transforms.ToTensor()]),
        label_transform: Optional[Callable] = None,
        box_transform: Optional[Callable] = None,
        mask_transform: Optional[Callable] = None,
        cropping = False,
        mask_expending = False,
        flip = False,
        jitter = False
    ):
        root = Path(root)
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.box_transform = box_transform
        self.mask_transform = mask_transform
        self.cropping = cropping
        self.mask_expending = mask_expending
        self.flip = flip
        self.jitter = jitter

        self.image_set = verify_str_arg(image_set, "image_set", ["train", "trainval", "val"])

        # Read the sample ids of the split; blank lines are ignored so they
        # cannot turn into bogus ".jpg" / ".xml" paths.
        split_f = root/"ImageSets"/"Segmentation"/f"{self.image_set}.txt"
        with open(split_f) as f:
            file_names = [line.strip() for line in f if line.strip()]
        file_names = remove_wrong_annotated_file(file_names)

        # dir setting
        image_dir = root/"JPEGImages"
        annotation_dir = root/"Annotations"
        mask_dir = root/"SegmentationObject"

        # file name list setting
        self.images = [image_dir/f"{x}.jpg" for x in file_names]
        self.annotations = [annotation_dir/f"{x}.xml" for x in file_names]
        self.masks = [mask_dir/f"{x}.png" for x in file_names]

        assert len(self.images) == len(self.annotations) == len(self.masks)

    def __len__(self) -> int:
        """Number of samples in the split (after removing excluded ids)."""
        return len(self.images)

    def __getitem__(self, index: int):
        """Load sample ``index`` and apply the enabled augmentations.

        Returns ``(image, {"labels": ..., "boxes": ..., "masks": ...})``.
        """
        image = self.get_image(index)
        labels, boxes = self.get_annotation(index)
        masks = self.get_masks(index)

        # as_tensor also accepts tensors returned by a user transform.
        labels = torch.as_tensor(labels)
        boxes = torch.as_tensor(boxes)
        masks = torch.as_tensor(masks)

        # random cropping ###########################################
        if self.cropping:
            # image.shape is used here, so cropping assumes the image is a
            # (C, H, W) tensor (true for the default image_transform).
            _, h, w = image.shape
            image, masks = MH.random_cropping(image, masks, size=(min(200, h), min(200, w)))
            # Does each instance still have mask pixels inside the crop?
            is_instance = masks.reshape(masks.shape[0], -1).any(1)
            # Rebuild the boxes from the cropped masks.
            boxes = torch.stack([MH.extract_box_from_binary_mask(mask) for mask in masks])
            # Instances that vanished from the crop get label 0 (background).
            labels = labels * is_instance.type(torch.uint8)
        ###########################################

        # expand mask ###########################################
        # Expand every mask with MH.expend_mask (second argument 2; the exact
        # semantics are defined in mh_utils).
        if self.mask_expending:
            masks = [torch.tensor(MH.expend_mask(mask.numpy(), 2)) for mask in masks]
            masks = torch.stack(masks)
        ###########################################

        # Flip #########################################
        # Image, masks AND boxes are mirrored together so the target stays
        # consistent with the flipped image.
        if self.flip:
            if np.random.rand() < 0.5:
                image = torchvision.transforms.RandomHorizontalFlip(p=1)(image)
                masks = torchvision.transforms.RandomHorizontalFlip(p=1)(masks)
                boxes = self._flip_boxes(boxes, masks.shape[-1], (0, 2))

            if np.random.rand() < 0.5:
                image = torchvision.transforms.RandomVerticalFlip(p=1)(image)
                masks = torchvision.transforms.RandomVerticalFlip(p=1)(masks)
                boxes = self._flip_boxes(boxes, masks.shape[-2], (1, 3))

        if self.jitter:
            image = torchvision.transforms.ColorJitter(brightness=(0.3, 3), contrast=(0.45, 0.55), saturation=(0.45, 0.55), hue=(0.45, 0.45))(image)

        target = {
            "labels": labels,
            "boxes" : boxes,
            "masks" : masks
        }

        return (image, target)

    @staticmethod
    def _flip_boxes(boxes, extent, columns):
        """Mirror boxes along one axis of length ``extent``.

        ``columns`` is the ``(low, high)`` column pair of that axis: ``(0, 2)``
        for x, ``(1, 3)`` for y. Uses the ``extent - coordinate`` convention.
        Empty or non 2-D inputs are returned unchanged.
        """
        if boxes.ndim != 2 or boxes.shape[0] == 0:
            return boxes
        low, high = columns
        flipped = boxes.clone()
        # The old high edge becomes the new low edge and vice versa.
        flipped[:, low] = extent - boxes[:, high]
        flipped[:, high] = extent - boxes[:, low]
        return flipped

    def get_image(self, index):
        """Open image ``index`` as RGB and apply ``image_transform`` if set."""
        image = Image.open(self.images[index]).convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        return image

    def get_masks(self, index):
        """Split the instance PNG of sample ``index`` into one binary mask per id.

        Pixel values 0 and 255 are not treated as instance ids. The result is a
        ``uint8`` array of shape ``(num_ids, H, W)`` before ``mask_transform``.
        """
        mask = Image.open(self.masks[index])
        mask = np.array(mask)

        id_list = np.unique(mask)
        id_list = np.delete(id_list, np.where((id_list == 0) | (id_list == 255)))

        # Broadcast (1, H, W) against (num_ids, 1, 1) -> (num_ids, H, W).
        masks = (mask[None, :] == id_list.reshape(-1, 1, 1)).astype(np.uint8)

        if self.mask_transform is not None:
            # Fixed: this previously applied image_transform to the masks.
            masks = self.mask_transform(masks)

        return masks

    def get_annotation(self, index):# labels, boxes
        """Parse the VOC XML of sample ``index`` into ``(labels, boxes)``.

        ``labels`` is an ``int64`` array of ``VOC_LABEL_PAIR`` indices and
        ``boxes`` a ``float32`` array of ``[xmin, ymin, xmax, ymax]`` rows
        (both before the optional transforms).
        """
        annotation = VOCDetection.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
        objects = annotation["annotation"]["object"]

        labels = np.array([VOC_LABEL_PAIR[o["name"]] for o in objects], dtype = np.int64)
        boxes = [o["bndbox"] for o in objects]
        boxes = np.array([[box["xmin"], box["ymin"], box["xmax"], box["ymax"]] for box in boxes], dtype = np.float32)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.box_transform is not None:
            boxes = self.box_transform(boxes)

        return (labels, boxes)

In [4]:
class VOC2012_MaskRCNN_InstanceSegmentation_DataLoader(torch.utils.data.DataLoader):
    """DataLoader whose batches are ``(list_of_images, list_of_targets)``.

    Samples are kept in plain lists instead of being stacked, so every image
    in a batch may have a different size and instance count.
    """

    @staticmethod
    def collate_batch(batch):
        """Split a list of ``(image, target)`` samples into two parallel lists."""
        images = [sample[0] for sample in batch]
        targets = [sample[1] for sample in batch]

        return images, targets

    def __init__(self, dataset, batch_size, **kwargs):
        """Create the loader.

        [Args]
            * dataset: dataset yielding ``(image, target)`` samples
            * batch_size: number of samples per batch
            * **kwargs: forwarded to ``torch.utils.data.DataLoader`` (for
              example ``shuffle=True`` or ``num_workers``); ``collate_fn``
              defaults to :meth:`collate_batch`
        """
        kwargs.setdefault("collate_fn", __class__.collate_batch)
        super().__init__(dataset, batch_size = batch_size, **kwargs)


In [25]:
class VOC2012_MaskRCNN_InstanceSegmentation:
    """Model construction, inference and evaluation helpers for Mask R-CNN on VOC2012."""

    # 20 VOC classes + background.
    CLASS_NUM = 21
    # Score thresholds in descending order: 1.0, 0.9, ..., 0.1.
    SCORE_THRESHOLDS = [0.1*i for i in range(10, 0, -1)]

    @staticmethod
    def get_model(pre_trained_model_path = ""):
        """Build a Mask R-CNN (ResNet50-FPN v2) whose heads predict ``CLASS_NUM`` classes.

        The backbone starts from torchvision's default pretrained weights and
        the box / mask predictor heads are replaced with freshly initialised
        layers. When ``pre_trained_model_path`` is non-empty, the state dict
        stored there is loaded on top.
        """
        class_num = __class__.CLASS_NUM
        # weights="DEFAULT" is the non-deprecated spelling of pretrained=True.
        model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights="DEFAULT")

        # Size the new heads from the layers they replace.
        box_predictor = model.roi_heads.box_predictor
        box_in_features = box_predictor.cls_score.in_features
        box_predictor.cls_score = torch.nn.Linear(in_features=box_in_features, out_features=class_num, bias=True)
        box_predictor.bbox_pred = torch.nn.Linear(in_features=box_in_features, out_features=4 * class_num, bias=True)

        mask_predictor = model.roi_heads.mask_predictor
        mask_in_channels = mask_predictor.mask_fcn_logits.in_channels
        mask_predictor.mask_fcn_logits = torch.nn.Conv2d(mask_in_channels, class_num, kernel_size=(1, 1), stride=(1, 1))

        if pre_trained_model_path:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine; callers move the model to the target device afterwards.
            model.load_state_dict(torch.load(pre_trained_model_path, map_location="cpu"))
        return model

    @staticmethod
    def show_comparison_for_one_sample(model, sample, device, iou_threshold = 0.5, mask_binary_threshold=0.5):
        """
        [Operation]
            Run the model on a single sample, pair every detection with a
            ground-truth instance and plot the pairs side by side.
            ``iou_threshold`` decides whether a detection counts as matched and
            ``mask_binary_threshold`` is used to binarise the predicted masks.
        """
        sample_result = __class__.eval_sample(model, sample, device)
        sample_result = Device.to_for_sample_result(sample_result, "cpu")
        sample_result = Converter.tensor_to_numpy_for_result(sample_result)
        sample_result["masks"] = MH.to_binary_by_threshold(sample_result["masks"], threshold = mask_binary_threshold)
        sample = Converter.tensor_to_numpy_for_sample_data(sample)
        candidates, targets, iou_list = Evaluator.compare_result_and_gt(sample_result, sample, __class__.CLASS_NUM, iou_threshold)

        Visualizer.show_comparison_result(candidates, targets, iou_list)

    @staticmethod
    def eval_sample(model, sample, device = "cpu"):
        """Run inference on one ``(image, target)`` sample.

        Returns the model's result dictionary with the channel dimension of
        ``"masks"`` squeezed out.
        """
        model.eval()
        model.to(device)
        image, target = Device.to_for_sample_data(sample, device)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            result = model([image], [target])[0]
        return  Converter.squeeze_dimention_for_sample_result(result)

    @staticmethod
    def loss_eval(model, device, data_loader):
        """Return the mean loss over ``data_loader`` without updating the model.

        The model is put in train mode because that is the mode in which it
        returns its loss dictionary; gradients are disabled. The per-batch
        loss is ``sum(l ** 1.5)`` over the individual losses, the same
        aggregation used by the training functions.
        """
        model.train()
        model.to(device)
        loss_list = []
        with torch.no_grad():
            for sample in tqdm(data_loader):
                images, targets = Device.to_for_batch_data(sample, device)
                loss_dict = model(images, targets)
                loss_vector = torch.stack([value for value in loss_dict.values()])
                loss = loss_vector @ loss_vector**(1/2)
                loss_list.append(loss.item())

        return np.array(loss_list).mean()

    @staticmethod
    def map_eval(model, device, data_loader, class_num, mask_threshold, score_threshold, iou_threshold = 0.3):
        """Evaluate the model on ``data_loader`` and return ``Evaluator.ap_per_class``'s value.

        [Args]
            * class_num: number of classes (including background)
            * mask_threshold: threshold used to binarise the predicted masks
            * score_threshold: descending list of score thresholds
            * iou_threshold: minimum mask IoU for a true positive (default 0.3)
        """
        model.eval()
        model.to(device)

        counts = []
        with torch.no_grad():
            for images, targets in tqdm(data_loader):
                images, targets = Device.to_for_batch_data((images, targets), device)
                results = model(images, targets)

                # Count GT / TP / detections per class for every sample.
                for result, target in zip(results, targets):

                    result = Converter.tensor_to_numpy_for_result(result)
                    target = Converter.tensor_to_numpy_for_target(target)

                    p_scores, p_labels, p_boxes, p_masks = Converter.dict_to_tuple_for_result(result)
                    gt_labels, gt_boxes, gt_masks = Converter.dict_to_tuple_for_target(target)

                    p_masks = MH.to_binary_by_threshold(p_masks, threshold = mask_threshold).squeeze(1)

                    count_result = Evaluator.count_result(p_masks, p_labels, p_scores, gt_masks, gt_labels,
                                                class_num = class_num,
                                                score_thresholds = score_threshold,
                                                iou_threshold = iou_threshold)
                    counts.append(count_result)

        # Aggregate the per-sample counts over the whole dataset.
        detected_num_list = [count["detected_num"] for count in counts]
        gt_num_list = [count["gt_num"] for count in counts]
        tp_num_list = [count["tp_num"] for count in counts]

        detected_num = np.stack(detected_num_list).sum(0)
        gt_num = np.stack(gt_num_list).sum(0)
        tp_num = np.stack(tp_num_list).sum(0)

        # Compute and return the mAP.
        return Evaluator.ap_per_class(detected_num, gt_num, tp_num)

    @staticmethod
    def filter_sample_by_score(sample_result, threshold):
        """Keep only the detections whose score is strictly above ``threshold``.

        ``sample_result`` must hold numpy arrays under ``"scores"``,
        ``"labels"``, ``"boxes"`` and ``"masks"``.
        """
        # The dict <-> tuple helpers live on Converter, not on this class.
        scores, labels, boxes, masks = Converter.dict_to_tuple_for_result(sample_result)
        indexes =  np.where(scores>threshold)[0]
        return Converter.tuple_to_dict_for_result(scores[indexes], labels[indexes], boxes[indexes], masks[indexes])


    


In [6]:
class Converter:
    """Conversions between tensors / numpy arrays and dict / tuple layouts."""

    # dimension
    @staticmethod
    def squeeze_dimention_for_sample_result(sample_result):
        """Drop dimension 1 of ``sample_result["masks"]`` when it has size 1.

        The dictionary is modified in place and returned.
        """
        sample_result["masks"] = torch.squeeze(sample_result["masks"], 1)
        return sample_result

    @staticmethod
    def squeeze_dimention_for_batch_result(batch_result):
        """Apply :meth:`squeeze_dimention_for_sample_result` to every result."""
        # Fixed: this used to call itself, which recursed without end.
        return [__class__.squeeze_dimention_for_sample_result(x) for x in batch_result]

    # tensor numpy
    @staticmethod
    def tensor_to_numpy(data):
        """Return ``data`` as a numpy array (moved to CPU and detached first)."""
        return data.to("cpu").detach().numpy()

    @staticmethod
    def tensor_to_numpy_for_image(tensor_image):
        """Convert an image tensor to numpy and permute its axes with ``(1, 2, 0)``.

        For a ``(C, H, W)`` tensor this gives an ``(H, W, C)`` array.
        """
        return __class__.tensor_to_numpy(tensor_image).transpose((1, 2, 0))

    @staticmethod
    def tensor_to_numpy_for_target(sample_target):
        """Convert the labels, boxes and masks of a target dict to numpy."""
        labels, boxes, masks = __class__.dict_to_tuple_for_target(sample_target)
        labels = __class__.tensor_to_numpy(labels)
        boxes = __class__.tensor_to_numpy(boxes)
        masks = __class__.tensor_to_numpy(masks)
        return __class__.tuple_to_dict_for_target(labels, boxes, masks)

    @staticmethod
    def tensor_to_numpy_for_sample_data(sample_data):
        """Convert an ``(image, target)`` sample to numpy."""
        image, target = sample_data
        return (__class__.tensor_to_numpy_for_image(image), __class__.tensor_to_numpy_for_target(target))

    @staticmethod
    def tensor_to_numpy_for_result(sample_result):
        """Convert the scores, labels, boxes and masks of a result dict to numpy."""
        scores, labels, boxes, masks = __class__.dict_to_tuple_for_result(sample_result)
        scores = __class__.tensor_to_numpy(scores)
        labels = __class__.tensor_to_numpy(labels)
        boxes = __class__.tensor_to_numpy(boxes)
        masks = __class__.tensor_to_numpy(masks)
        return __class__.tuple_to_dict_for_result(scores, labels, boxes, masks)

    # dict - tuple
    @staticmethod
    def dict_to_tuple_for_target(sample_target):
        """Return ``(labels, boxes, masks)`` from a target dict."""
        return sample_target["labels"], sample_target["boxes"], sample_target["masks"]

    @staticmethod
    def dict_to_tuple_for_result(sample_result):
        """Return ``(scores, labels, boxes, masks)`` from a result dict."""
        return  sample_result["scores"], sample_result["labels"], sample_result["boxes"], sample_result["masks"]

    @staticmethod
    def tuple_to_dict_for_result(scores, labels, boxes, masks):
        """Build a result dict from its four components."""
        return {"scores":scores, "labels":labels, "boxes":boxes, "masks":masks}

    @staticmethod
    def tuple_to_dict_for_target(labels, boxes, masks):
        """Build a target dict from its three components."""
        return {"labels":labels, "boxes":boxes, "masks":masks}


In [7]:
class Evaluator:
    """Counting and average-precision helpers for instance-segmentation results."""

    @staticmethod
    def score_level(score, thresholds):
        """
        [Args]
            * score: (__float__): score whose bucket is looked up
            * thresholds: (__Numpy(N, dtype = float32)__): confidence thresholds of the detector (descending order)
        [Return]
            * result: (__int__ or None): index of the first threshold that ``score`` reaches,
              or ``None`` when ``score`` is below every threshold
        """
        for level, threshold in enumerate(thresholds):
            if score >= threshold:
                return level
        return None

    @staticmethod
    def count_result(p_masks, p_labels, p_scores, gt_masks, gt_labels, class_num, score_thresholds, iou_threshold):
        """
        [Operation]
            Count, per class, the number of detections, ground-truth objects and
            true positives of one sample.

            A detection is a true positive when its best mask IoU against the
            still-unmatched ground truths of the same class reaches
            ``iou_threshold``; the matched ground truth is then consumed.
            Detections whose score is below every threshold are still counted
            in ``detected_num`` but are not matched and never count as true
            positives.

        [Return]
            * result: (__Dict[str, Any]__):
                {
                    "detected_num": (__Numpy(C, dtype = int)__): detections per class
                    "gt_num": (__Numpy(C, dtype = int)__): ground-truth objects per class
                    "tp_num": (__Numpy(C, S, dtype = float)__): true positives per class and score level (scores descending)
                }

        """
        score_levels = [__class__.score_level(score, score_thresholds) for score in p_scores]

        #########################################################################
        # Group prediction / GT indexes by class.
        detected_indexes_per_c = [[] for i in range(class_num)]
        for idx, label in enumerate(p_labels):
            detected_indexes_per_c[label].append(idx)

        gt_indexes_per_c = [[] for i in range(class_num)]
        for idx, label in enumerate(gt_labels):
            gt_indexes_per_c[label].append(idx)

        #########################################################################
        # Per-class counts of predictions and GT.
        detected_num = np.array([len(x) for x in detected_indexes_per_c])
        gt_num = np.array([len(x) for x in gt_indexes_per_c])

        #########################################################################
        tp_num = np.zeros((class_num, len(score_thresholds)))

        for c in range(class_num):
            p_idxes = detected_indexes_per_c[c]
            # Work on a copy: matched ground truths are popped from this list.
            gt_idxes = list(gt_indexes_per_c[c])

            for p_idx in p_idxes:
                if not gt_idxes: # nothing left to match against
                    break

                level = score_levels[p_idx]
                if level is None:
                    # Below every threshold. Indexing tp_num with None would
                    # broadcast the increment over the whole row.
                    continue

                iou_list = [MH.iou(p_masks[p_idx], gt_masks[gt_idx]) for gt_idx in gt_idxes]
                max_idx = int(np.argmax(iou_list))

                if iou_threshold <= iou_list[max_idx]:
                    tp_num[c][level] += 1
                    gt_idxes.pop(max_idx)

        return {"detected_num":detected_num, "gt_num":gt_num, "tp_num":tp_num}

    @staticmethod
    def ap(detected_num, gt_num, tp_num):
        """
        [Operation]
            * Compute the average precision of a single class from the number of
              detections, the number of ground-truth objects and the true
              positives per score level. ``tp_num`` is not modified.
        [Args]
            * detected_num: (__int__): detections of this class over the whole dataset
            * gt_num: (__int__): ground-truth objects of this class over the whole dataset
            * tp_num: (__Numpy(S)__): true positives per score level (scores descending)
        [Return]
            * result: (__float__): average precision (sum of recall-step * precision)
        """
        # Per-level counts -> counts at or above each score level (cumulative).
        tp_cum = np.cumsum(np.asarray(tp_num, dtype=np.float64))

        # Cumulative recall per level (non-decreasing); 1 when there is no GT.
        recall = np.ones(tp_cum.shape) if gt_num == 0 else (tp_cum / gt_num)

        # Width of each recall step, differenced against the untouched
        # cumulative values.
        recall_width = np.diff(recall, prepend=0.0)

        # Precision per level; defined as 1 when nothing was detected.
        precision = np.ones(tp_cum.shape) if detected_num == 0 else (tp_cum / detected_num)

        # Staircase: use the best precision found at equal or higher recall.
        precision = np.maximum.accumulate(precision[::-1])[::-1]

        # Area under the staircase.
        return (recall_width * precision).sum()

    @staticmethod
    def ap_per_class(detected_num, gt_num, tp_num):
        """
        [Operation]
            * Compute the AP of every class with :meth:`ap` and return their mean (mAP).
        [Args]
            * detected_num: (__Numpy(C)__): detections per class over the whole dataset            # C = class num
            * gt_num: (__Numpy(C)__): ground-truth objects per class over the whole dataset
            * tp_num: (__Numpy(C, S)__): true positives per class and score level                   # S = score level num
                - score levels are in descending order (index 0 is the highest score)
        [Result]
            * result: (__float__): mean of the per-class AP values
        """

        ap_list = []
        for c in range(len(detected_num)):
            ap_value = __class__.ap(detected_num[c], gt_num[c], tp_num[c])
            ap_list.append(ap_value)

        return np.array(ap_list).mean()

    @staticmethod
    def compare_result_and_gt(result, sample, class_num, iou_threshold):
        """
            [Operation]
                Pair every detection of the model with a ground-truth instance
                (using ``iou_threshold``) and return the aligned information.

            [Args]
                * result: (__Dict[str, Any]__): output of Mask R-CNN for one sample
                    {
                            "scores": (__Numpy(M)__),          # M = detected instance num (candidate num)
                            "labels": (__Numpy(M)__),
                            "boxes": (__Numpy(M, 4)__),
                            "masks": (__Numpy(M, H, W)__),
                    }
                * sample: (__Tuple[Image, Target]__): input of Mask R-CNN for one sample
                    - Image: not used here
                    - Target: (__Dict[str, Any]__)           # N = target instance num
                        {
                            "labels": (__Numpy(N)__),
                            "boxes": (__Numpy(N, 4)__),
                            "masks": (__Numpy(N, H, W)__),
                        }

            [Result]
                * result: (__Tuple[Candidates, Targets, Iou_List]__)
                    - Candidates: (__Dict[str, Any]__): detections in the order returned by MH.pair_up_instances
                    - Targets: (__Dict[str, Any]__): the target paired with each candidate; an
                      unmatched candidate (index -1) gets label 0, box [-1, -1, -1, -1] and an empty mask
                    - Iou_List: (__Numpy(M)__): mask IoU between each candidate and its paired target
        """
        image, target = sample

        # Pair detections with ground truths (who detects what).
        pairs = MH.pair_up_instances(result, target, class_num, iou_threshold)
        c_indexes = [pair[0] for pair in pairs] # detection indexes
        t_indexes = [pair[1] for pair in pairs] # index of the paired target (-1 when there is none)

        c_scores, c_labels, c_boxes, c_masks = Converter.dict_to_tuple_for_result(result)
        t_labels, t_boxes, t_masks = Converter.dict_to_tuple_for_target(target)

        # Append a placeholder entry at the end so that index -1 selects it.
        t_labels = np.append(t_labels, [0], axis = 0) # background
        t_boxes = np.append(np.reshape(t_boxes, (-1, 4)), [[-1, -1, -1, -1]], axis = 0)
        # shape[1:] also works when the target holds no instance at all.
        t_masks = np.append(t_masks, np.zeros((1,) + tuple(t_masks.shape[1:])), axis = 0)

        # Reorder the candidates according to the candidate indexes.
        c_scores, c_labels, c_boxes, c_masks = c_scores[c_indexes], c_labels[c_indexes], c_boxes[c_indexes], c_masks[c_indexes]

        # Pick the target paired with every candidate.
        t_labels, t_boxes, t_masks = t_labels[t_indexes], t_boxes[t_indexes], t_masks[t_indexes]

        candidates = {
            "scores" : c_scores,
            "labels" : c_labels,
            "boxes" : c_boxes,
            "masks" : c_masks,
        }

        targets = {
            "labels" : t_labels,
            "boxes" : t_boxes,
            "masks" : t_masks
        }

        iou_list = np.array([MH.iou(c, t) for c, t in zip(c_masks, t_masks)])

        return (candidates, targets, iou_list)

False

In [8]:
class Visualizer:
    """Plotting helpers for detection results."""

    @staticmethod
    def show_comparison_result(candidates, targets, iou_list):
        """
            Plot every detection (candidate) next to the ground-truth instance
            (target) paired with it, one row per candidate.

            Nothing is drawn when there are no candidates.
        """
        c_scores, c_labels, c_boxes, c_masks = candidates["scores"], candidates["labels"], candidates["boxes"], candidates["masks"]
        t_labels, t_boxes, t_masks = targets["labels"], targets["boxes"], targets["masks"]

        n = len(c_scores)
        if n == 0:
            # A figure of height 3*n = 0 is not a valid figure size.
            return

        fig = plt.figure(figsize = (6, 3*n))

        for i, (c_score, c_label, c_box, c_mask, t_label, t_box, t_mask, iou) in enumerate(
            zip(c_scores, c_labels, c_boxes, c_masks, t_labels, t_boxes, t_masks, iou_list)
        ):
            # Predicted instance (left column).
            sub = fig.add_subplot(n, 2, 2*i+1)
            sub.imshow(c_mask)
            sub.set_title("Prediction(%s, %.3f, %.3f)"%(VOC_LABEL[c_label], c_score, iou), fontdict={"fontsize":8})

            # Corresponding ground-truth instance (right column).
            sub = fig.add_subplot(n, 2, 2*i+2)
            sub.imshow(t_mask)
            sub.set_title(f"GT({VOC_LABEL[t_label]})", fontdict={"fontsize":8})

In [9]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one epoch and return the mean loss as a CPU tensor.

    [Args]
        * model: detection model returning a loss dictionary in train mode
        * optimizer: optimizer updating the model parameters
        * data_loader: iterable of ``(images, targets)`` batches
        * device: device the batches are moved to (the model is not moved here)
        * epoch: epoch index; a linear learning-rate warm-up runs during epoch 0
        * print_freq: unused, kept for interface compatibility
        * scaler: optional gradient scaler; enables mixed precision when given
    [Return]
        * mean of the per-batch losses (detached, on CPU)
    """
    model.train()
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    losses_list = []
    for images, targets in tqdm(data_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        with torch.autocast(device_type="cuda", enabled=scaler is not None):
            loss_dict = model(images, targets)
            loss_list = torch.stack([value for value in loss_dict.values()])
            # Aggregated loss: sum(l ** 1.5) over the individual losses.
            losses = loss_list @ loss_list**(1/2)
        # Keep only the value; storing the attached tensor would keep every
        # batch's autograd graph alive until the end of the epoch.
        losses_list.append(losses.detach())

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
    return torch.mean(torch.stack(losses_list)).to("cpu")

In [None]:
def train(model, data_loader, optimizer, scheduler, device):
    """Train ``model`` for one pass over ``data_loader``.

    The scheduler is stepped once, after the whole pass.

    [Args]
        * model: detection model returning a loss dictionary in train mode
        * data_loader: iterable of ``(images, targets)`` batches
        * optimizer: optimizer updating the model parameters
        * scheduler: learning-rate scheduler, stepped once per call
        * device: device the model and the batches are moved to
    [Return]
        * mean of the per-batch losses (detached, on CPU)
    """
    model.train()
    model.to(device)
    losses_list = []
    for images, targets in tqdm(data_loader):
        # to device
        images, targets = Device.to_for_batch_data((images, targets), device)

        # model
        loss_dict = model(images, targets)

        # loss: sum(l ** 1.5) over the individual losses
        loss_list = torch.stack([value for value in loss_dict.values()])
        losses = loss_list @ loss_list**(1/2)
        # Store only the value so each batch's autograd graph can be freed.
        losses_list.append(losses.detach())

        # back propagation
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    scheduler.step()

    return torch.mean(torch.stack(losses_list)).to("cpu")