<span style="font-size:150%">라이브러리</span>

In [None]:
import os
import sys
# Path setting: the lecture root is the parent of the notebook's working
# directory; putting it on sys.path lets the project-local `utils` package
# (imported below) resolve.
lecture_root = os.path.dirname(os.getcwd())
sys.path.append(lecture_root)
print(lecture_root)

import torch
import torchvision
import torchvision.models.detection as detection
from torchvision.models import ResNet50_Weights
from torch.utils.data import DataLoader
import utils.coco.transforms as T
from utils.coco.engine import train_one_epoch
from utils.coco.cal_utils import * 

import matplotlib.pyplot as plt
import cv2
%matplotlib inline

<span style="font-size:150%">경로 및 파라미터 설정</span>

In [None]:
# Recomputed here so this configuration cell is self-contained.
lecture_root = os.path.dirname(os.getcwd())
data_path = os.path.join(lecture_root, 'data', 'PennFudanPed')
ckp_path = os.path.join(lecture_root, 'checkpoints', 'detection')
os.makedirs(ckp_path, exist_ok=True)  # a checkpoint is saved here after every training epoch
num_epoch = 10

# Seed torch's global RNG so runs are repeatable: DataLoader shuffling and
# randomly initialised layers draw from it (presumably the random flip in
# utils.coco.transforms does too — TODO confirm).
seed = 42
torch.manual_seed(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

<span style="font-size:150%">어그멘테이션 설정</span>

In [None]:
def get_transform(is_train):
    """Build the (image, target) transform pipeline for the Penn-Fudan data.

    Every sample is converted from a PIL image to a tensor and then to
    ``torch.float``; training samples additionally get a random horizontal
    flip applied with probability 0.5.

    Args:
        is_train: when truthy, append the random horizontal flip.

    Returns:
        A ``T.Compose`` over the selected transforms, in order.
    """
    pipeline = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if is_train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

<span style="font-size:150%">데이터셋</span>

In [None]:
import os
import numpy as np
import torch
from PIL import Image


class PennFudanDataset(torch.utils.data.Dataset):
    """Penn-Fudan Database for Pedestrian Detection and Segmentation.

    Download the dataset from https://www.cis.upenn.edu/~jshi/ped_html/

    ``root`` must contain a ``PNGImages`` folder of images and a ``PedMasks``
    folder of instance masks. Images and masks are paired by position after
    sorting each folder's file names.

    Each item is ``(img, target)`` where ``target`` holds:
        boxes    -- float32 tensor of shape (N, 4), ``[xmin, ymin, xmax, ymax]``
        labels   -- int64 tensor of shape (N,), all 1 (single class)
        masks    -- uint8 tensor of shape (N, H, W), one binary mask per instance
        image_id -- int64 tensor ``[idx]``
        area     -- float32 tensor of shape (N,)
        iscrowd  -- int64 tensor of shape (N,), all 0
    N may be 0. If ``transforms`` is not None it is called as
    ``transforms(img, target)`` and its result is returned.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that image i and mask i refer to the same sample.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))
        # Pairing is purely positional: one missing file on either side would
        # silently misalign every later sample, so fail loudly instead.
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"found {len(self.imgs)} images but {len(self.masks)} masks under {root!r}"
            )

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # The mask is deliberately NOT converted to RGB: each pixel value is an
        # instance id, with 0 meaning background.
        mask = np.array(Image.open(mask_path))

        # Every non-zero id is one instance. Filtering 0 out explicitly (rather
        # than dropping the first unique value) stays correct when a mask has
        # no background pixel at all.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # Split the id-encoded mask into one binary mask per instance: (N, H, W).
        masks = mask == obj_ids[:, None, None]

        # Tight bounding box around each instance's pixels.
        boxes = []
        for instance_mask in masks:
            ys, xs = np.nonzero(instance_mask)
            boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])

        # reshape keeps the (N, 4) shape even when N == 0, so the column
        # indexing in the area computation works for images without pedestrians.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

In [None]:
# Both splits read the same image folder; only the transforms differ (the
# training copy adds random horizontal flips).
trainset = PennFudanDataset(data_path, get_transform(is_train=True))
testset = PennFudanDataset(data_path, get_transform(is_train=False))

# Hold out the last image (in sorted file order) for visualisation and train
# on all the others.
indices = list(range(len(trainset)))
train_indices, test_indices = indices[:-1], indices[-1:]
dataset = torch.utils.data.Subset(trainset, train_indices)
dataset_test = torch.utils.data.Subset(testset, test_indices)

# `collate_fn` is not defined in this notebook; presumably it comes from
# `from utils.coco.cal_utils import *` in the first cell.
trainLoader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
testLoader = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

<span style="font-size:150%">모델 선언</span>

In [None]:
# Faster R-CNN on an ImageNet-pretrained ResNet-50 + FPN backbone.
backbone = detection.backbone_utils.resnet_fpn_backbone(
    backbone_name='resnet50', weights=ResNet50_Weights.IMAGENET1K_V2)

# AnchorGenerator needs one aspect-ratio entry per size entry; repeating the
# ratio tuple len(anchor_sizes) times keeps the two in sync if sizes change.
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
anchor_generator = detection.rpn.AnchorGenerator(
    sizes=anchor_sizes,
    aspect_ratios=((0.5, 1.0, 2.0),) * len(anchor_sizes))

# RoIAlign pooling over the backbone feature maps named '0'..'3'.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                output_size=7,
                                                sampling_ratio=2)

# num_classes=2: background plus the single pedestrian class (every label the
# dataset emits is 1).
model = detection.FasterRCNN(backbone,
                             num_classes=2,
                             rpn_anchor_generator=anchor_generator,
                             box_roi_pool=roi_pooler)
model = model.to(device)

# Only hand trainable parameters to the optimizer; frozen parameters
# (requires_grad=False) never receive a gradient anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.01, weight_decay=0.0005)

<span style="font-size:150%">모델 학습</span>

In [None]:
# Train for `num_epoch` epochs, saving the weights after each one as
# detector_0.pth, detector_1.pth, ... under ckp_path.
for epoch in range(num_epoch):
    train_one_epoch(model, optimizer, trainLoader, device, epoch, print_freq=10)
    epoch_ckp_file = os.path.join(ckp_path, f'detector_{epoch}.pth')
    torch.save(model.state_dict(), epoch_ckp_file)

<span style="font-size:150%">모델 출력 시각화</span>

In [None]:
# Take the single held-out test image and its ground-truth annotations.
imgs, targets = next(iter(testLoader))
img = imgs[0]
target = targets[0]
# CHW tensor -> HWC array for OpenCV / matplotlib. `permute` returns a strided
# view and `.numpy()` on a CPU tensor shares its memory; `.copy()` gives a
# C-contiguous array (cv2 drawing functions reject non-contiguous output
# images) that no longer aliases `img`, so drawing on it cannot alter the
# model input.
sample = img.permute(1, 2, 0).cpu().numpy().copy()
boxes = target['boxes'].cpu().numpy().astype(int)
boxes

In [None]:
# Run the trained detector on the held-out test image.
model.eval()
# Infer on whatever device the model already lives on, instead of moving the
# model to the CPU and rebinding the global `device`. That hidden state would
# presumably break a later re-run of the training cell, which still uses `device`.
model_device = next(model.parameters()).device
with torch.no_grad():  # inference only: don't build an autograd graph
    outputs = model(img.unsqueeze(0).to(model_device))
# Bring every prediction tensor (e.g. 'boxes', 'scores') to the CPU for plotting.
outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(16, 8))

score_threshold = 0.5  # only predictions scoring above this are drawn
# `sample` is a float image (get_transform applies ConvertImageDtype(torch.float))
# and matplotlib expects float RGB in [0, 1], so colours use a 0-1 scale.
# 0-255 values would simply be clipped by imshow, with a warning.
pred_color = (1.0, 0.0, 0.0)  # red: predictions
gt_color = (0.0, 0.0, 1.0)    # blue: ground truth

# Draw on a copy so re-running this cell does not stack boxes onto `sample`;
# the copy is also C-contiguous, which cv2 requires for its output image.
canvas = sample.copy()

prediction = outputs[0]
for box, score in zip(prediction['boxes'].int(), prediction['scores']):
    print(box, score)
    if score > score_threshold:
        x1, y1, x2, y2 = box.tolist()
        cv2.rectangle(canvas, (x1, y1), (x2, y2), pred_color, 3)

for box in targets[0]['boxes'].int():
    x1, y1, x2, y2 = box.tolist()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), gt_color, 3)

ax.set_title(f'red: predictions (score > {score_threshold}), blue: ground truth')
ax.set_axis_off()
ax.imshow(canvas);