In [141]:
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import defaultdict
from ipywidgets import interact


import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, models
from torchvision.ops import nms

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# 시작 파라미터 설정

In [112]:
data_dir = '/home/hts/A_project/hts_pytorch/data/DRIVING-DATASET/Detection/'
train_data_dir = os.path.join(data_dir, 'train/')
val_data_dir = os.path.join(data_dir, 'val/')

csv_data = pd.read_csv(data_dir + 'df.csv')
csv_data = csv_data.drop(columns=['Source', 'Confidence','IsOccluded','IsTruncated','IsGroupOf','IsDepiction','IsInside','XClick1X','XClick2X','XClick3X','XClick4X','XClick1Y','XClick2Y','XClick3Y','XClick4Y'])
csv_data = csv_data[['ImageID', 'LabelName', 'XMin', 'YMin', 'XMax', 'YMax']]
train_data_list = os.listdir(train_data_dir)
val_data_list = os.listdir(val_data_dir)

Class_Name_To_Int = {'Bus':0, 'Truck':1}
Int_To_Class_Name = {0:'Bus', 1:'Truck'}
NUM_CLASSES = 2
VERBOSE_FREQ = 200

# 커스텀데이터셋 설정

In [130]:
class car_data_set():
    def __init__(self, data_dir, phase, csv_data, transformer = None):
        self.csv_data = csv_data
        self.phase_data_dir = (data_dir + phase + '/')
        self.data_list = os.listdir(self.phase_data_dir)
        self.transformer = transformer

    def __len__(self):
        return len(self.data_list)
    
    def get_label_def(self, image_name, img_H, img_W):
        label = self.csv_data.loc[(self.csv_data['ImageID'] == image_name.split(".")[0])]
        target_name = [Class_Name_To_Int[i] for i in label['LabelName'].values]
        bounding_box = label.drop(columns = ['ImageID', 'LabelName']).values
        bounding_box[:, [0,2]] *= img_W
        bounding_box[:, [1,3]] *= img_H

        return target_name, bounding_box
    

    def __getitem__(self, index):
        image_name = self.data_list[index]
        
        
        image = cv2.imread(self.phase_data_dir + image_name)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_H, img_W,_ = image.shape

        if self.transformer:
            image = self.transformer(image)
            _, img_H, img_W = image.shape


        target_name, bounding_box = self.get_label_def(image_name,img_H, img_W)

        target = {}
        target['boxes'] = torch.Tensor(bounding_box).float()
        target['labels'] = torch.Tensor(target_name).long()
        
        
        return image, target

In [131]:
def collate_fn(batch):
    image_list = []
    target_list = []

    for a,b in batch:
        image_list.append(a)
        target_list.append(b)
    
    return image_list, target_list

In [132]:
def build_dataloader(data_dir,csv_data):
    
    dataloaders = {}
    train_data_set = car_data_set(data_dir=data_dir, phase='train', csv_data=csv_data, transformer=transformer)
    val_data_set = car_data_set(data_dir=data_dir, phase='val', csv_data=csv_data, transformer=transformer)

    dataloaders['train'] = DataLoader(train_data_set, shuffle=True, collate_fn=collate_fn)
    dataloaders['val'] = DataLoader(val_data_set, shuffle=False, collate_fn=collate_fn)
    return dataloaders

# 이미지 데이터 확인 (transformer사용 X)

In [117]:
temp_set = car_data_set(data_dir=data_dir, phase='train', csv_data=csv_data, transformer=None)

#@interact(index=(0,len(train_data_set)-1))
def show_sample(index=0):
    images, target = temp_set[index]
    boundingbox = target['boxes']
    targetname = target['labels']

    fig, ax = plt.subplots()
    ax.imshow(images)

    for bbox,target in zip(boundingbox, targetname):
        target = Int_To_Class_Name[target.item()]
        xmin = bbox[0] 
        ymin = bbox[1]
        ax.text(xmin, ymin, target,color='red' )
        ax.add_patch(patches.Rectangle(
            (xmin, ymin), bbox[2]-bbox[0], bbox[3]-bbox[1],
            fill=False,
            edgecolor = 'red',
            linewidth=1))
    plt.show()


# 이미지 데이터 확인 (transformer사용 O)

In [129]:
transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=(448, 448)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

temp_set = car_data_set(data_dir=data_dir, phase='train', csv_data=csv_data, transformer=transformer)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def unnormalize(image, mean, std):
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    image = image * std + mean
    return image

#@interact(index=(0,len(train_data_set)-1))
def show_sample(index=0):
    images, target = temp_set[index]
    boundingbox = target['boxes']
    targetname = target['labels']

    fig, ax = plt.subplots()

    images = unnormalize(image=images, mean=mean, std=std)
    images = images.permute(1,2,0).numpy()
    ax.imshow(images)

    for bbox,target in zip(boundingbox, targetname):
        target = Int_To_Class_Name[target.item()]
        xmin = bbox[0] 
        ymin = bbox[1]
        ax.text(xmin, ymin, target,color='red' )
        ax.add_patch(patches.Rectangle(
            (xmin, ymin), bbox[2]-bbox[0], bbox[3]-bbox[1],
            fill=False,
            edgecolor = 'red',
            linewidth=1))
    plt.show()


# 모델 생성

In [133]:
def build_model():
    model = models.detection.fasterrcnn_resnet50_fpn(pretrain=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
    return model

# one_epoch설정

In [134]:
def train_one_epoch(dataloaders, model, optimizer, device):
    train_loss = defaultdict(float)
    val_loss = defaultdict(float)

    model.train()

    for phase in ['train', 'val']:
        for index, batch in enumerate(dataloaders[phase]):
            images = batch[0]
            targets = batch[1] 

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.item()} for t in targets]

            with torch.set_grad_enabled(phase == 'train'):
                loss = model(images, targets)
            total_loss = sum(each_loss for each_loss in loss.values())

            if phase == 'train':
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if(index>0) and (index % 100) == 0:
                    text = f"{index}/{len(dataloaders[phase])} - "
                    for k,v in loss.items():
                        text += f"{k}: {v.item():.4f}"
                    print(text)
                
                for k, v in loss.items():
                    train_loss[k] += v.item()
                train_loss['total_loss'] += total_loss.item()
            
            else:
                for k, v in loss.items():
                    val_loss[k] += v.item()
                val_loss['total_loss'] +=total_loss.item()
    for k in train_loss.keys():
        train_loss[k] /= len(dataloaders['train'])
        val_loss[k] /= len(dataloaders['val'])
    return train_loss, val_loss

# 학습 파라미터와 학습

In [136]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dataloaders = build_dataloader(data_dir=data_dir, csv_data=csv_data)
model = build_model().to(device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/hts/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:09<00:00, 10.8MB/s]


In [137]:
num_epoch=30

train_losses = []
val_losses = []

for epoch in range(num_epoch):
    train_loss, val_loss = train_one_epoch(dataloaders=dataloaders, model=model, optimizer=optimizer, device=device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"epoch:{epoch+1}/{num_epoch} - Train Loss: {train_loss['total_loss']:.4f}, Val Loss: {val_loss['total_loss']:.4f}")
    
    if (epoch+1) % 10 == 0:
        os.makedirs('./trained_model/', exist_ok=True)
        torch.save(model, os.path.join('./trained_model/','bestmodel.pt'),_use_new_zipfile_serialization=False)




ValueError: not enough values to unpack (expected 3, got 2)

# 학습된 모델 불러오기

In [None]:
test_model = build_model()
test_model.load_state_dict(torch.load('./trained_model/bestmodel.pt'))
test_model.eval().to(device=device)

# 겹치는거 확인해주는 구간

In [147]:
def postprocess(prediction, conf_thres=0.2, iou_threshold = 0.1):
    pred_box = prediction['boxes'].cpu().detach().numpy()
    pred_label = prediction['labels'].cpu().detach().numpy()
    pred_conf = prediction['scores'].cpu().detach().numpy()

    valid_index = pred_conf>conf_thres
    pred_box = pred_box[valid_index]
    pred_label = pred_label[valid_index]
    pred_conf = pred_conf[valid_index]

    valid_index = nms(torch.tensor(pred_box.astype(np.float32)), torch.tensor(pred_conf), iou_threshold=iou_threshold)
    pred_box = pred_box[valid_index.numpy()]
    pred_conf = pred_conf[valid_index.numpy()]
    pred_label = pred_label[valid_index.numpy()]
    return np.concatenate((pred_box, pred_conf[:, np.newaxis], pred_label[:, np.newaxis]), axis=1)

In [138]:

for index, batch in enumerate(dataloaders['val']):
    images = batch[0]
    targets = batch[1]

    images = list(image.to(device) for image in images)
    targets = [{k: v.to(device) for k,v in t.items()} for t in targets]

    with torch.no_grad():
        prediction = model(images)

    image = images[0]
    prediction = postprocess(prediction[0])
    prediction[:, 2].clip(min=0, max = image.shape[1])
    prediction[:, 3].clip(min=0, max = image.shape[0])

    plt.imshow(image)
    print(prediction)


    


    if index == 1:
        break

In [None]:
def show_sample(index=0):
    images, target = temp_set[index]
    boundingbox = target['boxes']
    targetname = target['labels']

    fig, ax = plt.subplots()

    images = unnormalize(image=images, mean=mean, std=std)
    images = images.permute(1,2,0).numpy()
    ax.imshow(images)

    for bbox,target in zip(boundingbox, targetname):
        target = Int_To_Class_Name[target.item()]
        xmin = bbox[0] 
        ymin = bbox[1]
        ax.text(xmin, ymin, target,color='red' )
        ax.add_patch(patches.Rectangle(
            (xmin, ymin), bbox[2]-bbox[0], bbox[3]-bbox[1],
            fill=False,
            edgecolor = 'red',
            linewidth=1))
    plt.show()

# NMS적용