In [1]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Utils

In [3]:
CLASS_NAME_TO_ID = {'Unformed': 0, 'Burr': 1}
CLASS_ID_TO_NAME = {0: 'Unformed', 1: 'Burr'}
BOX_COLOR = {'Unformed':(200, 0, 0), 'Burr':(0, 0, 200)}
TEXT_COLOR = (255, 255, 255)

def save_model(model_state, model_name, save_dir="./trained_model"):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model_state, os.path.join(save_dir, model_name))


def visualize_bbox(image, bbox, class_name, color=BOX_COLOR, thickness=2):
    x_center, y_center, w, h = bbox
    x_min = int(x_center - w/2)
    y_min = int(y_center - h/2)
    x_max = int(x_center + w/2)
    y_max = int(y_center + h/2)
    
    cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color[class_name], thickness=thickness)
    
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)    
    cv2.rectangle(image, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color[class_name], -1)
    cv2.putText(
        image,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35, 
        color=TEXT_COLOR, 
        lineType=cv2.LINE_AA,
    )
    return image


def visualize(image, bboxes, category_ids):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
#         print('category_id: ',category_id)
        class_name = CLASS_ID_TO_NAME[category_id.item()]
        img = visualize_bbox(img, bbox, class_name)
    return img

## Datasets

In [4]:
class PET_dataset():
    def __init__(self,part,neck_dir,body_dir,phase, transformer=None, aug=None, aug_factor=0):
        self.neck_dir=neck_dir
        self.body_dir=body_dir
        self.part=part
        self.phase=phase
        self.transformer=transformer
        self.aug=aug
        self.aug_factor=aug_factor
        if(self.part=="body"):
            self.image_files = sorted([fn for fn in os.listdir(self.body_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.body_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        elif(self.part=="neck"):
            self.image_files = sorted([fn for fn in os.listdir(self.neck_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.neck_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        
        self.auged_img_list, self.auged_label_list=self.make_aug_list(self.image_files, self.label_files)
        
    def __getitem__(self,index):
        if(self.aug==None):
            filename, image = self.get_image(self.part, index)
            bboxes, class_ids = self.get_label(self.part, index)

            if(self.transformer):
                transformed_data=self.transformer(image=image, bboxes=bboxes, class_ids=class_ids)
                image = transformed_data['image']
                bboxes = np.array(transformed_data['bboxes'])
                class_ids = np.array(transformed_data['class_ids'])


            target = {}
    #         print(f'bboxes:{bboxes}\nclass_ids:{class_ids}\nlen_bboxes:{len(bboxes)}\nlen_class_ids:{len(class_ids)}')
    #         print(f'filename: {filename}')
            target["boxes"] = torch.Tensor(bboxes).float()
            target["labels"] = torch.Tensor(class_ids).long()

            ###
            bboxes=torch.Tensor(bboxes).float()
            class_ids=torch.Tensor(class_ids).long()
            target = np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1)
            ###
        else:
            image=self.auged_img_list[index][1]
            target=self.auged_label_list[index]
            filename=self.auged_img_list[index][0]
        return image, target, filename
    
    def __len__(self, ):
        length=0
        if(self.aug==None):
            length=len(self.image_files)
        else:
            length=len(self.auged_img_list)
        return length
    
    def make_aug_list(self,ori_image_list,ori_label_files):
        aug_image_list=[]
        aug_label_list=[]
        
        print(f"start making augmented images-- augmented factor:{self.aug_factor}")
        for i in range(len(ori_image_list)):
            filename, ori_image = self.get_image(self.part, i)
            ori_bboxes, ori_class_ids = self.get_label(self.part, i)
#             print(f"{filename}--------------\n{ori_bboxes}")
            for j in range(self.aug_factor):
                auged_data=self.aug(image=ori_image, bboxes=ori_bboxes, class_ids=ori_class_ids)
                image = auged_data['image']
                bboxes = np.array(auged_data['bboxes'])
                class_ids = np.array(auged_data['class_ids'])
                
                bboxes=torch.Tensor(bboxes).float()
                class_ids=torch.Tensor(class_ids).long()
                
                aug_image_list.append((filename, image))
                aug_label_list.append(np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1))
        
        print(f"total length of augmented images: {len(aug_image_list)}")
        
        return aug_image_list, aug_label_list
        
    
    def get_image(self, part, index): # 이미지 불러오는 함수
        filename = self.image_files[index]
        if(part=="body"):
#             print(f"body called!-> {self.part}")
            image_path = self.body_dir+"/"+self.phase+"/image/"+filename
        elif(part=="neck"):
#             print(f"neck called!-> {self.part}")
            image_path = self.neck_dir+"/"+self.phase+"/image/"+filename
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return filename, image
    
    def get_label(self, part, index): # label (box좌표, class_id) 불러오는 함수
        label_filename=self.label_files[index]
        if(part=="body"):
#             print(f"body label called!-> {self.part}")
            label_path = self.body_dir+"/"+self.phase+"/label/"+label_filename
        elif(part=="neck"):
#             print(f"neck label called!-> {self.part}")
            label_path = self.neck_dir+"/"+self.phase+"/label/"+label_filename
        with open(label_path, 'r') as file:
            labels = file.readlines()
        
        class_ids=[]
        bboxes=[]
        for label in labels:
            label=label.replace("\n", "")
            obj=label.split(' ')[0]
            coor=label.split(' ')[1:]
            obj=int(obj)
            coor=list(map(float, coor))
            class_ids.append(obj)
            bboxes.append(coor)
            
        return bboxes, class_ids
    

In [5]:
IMAGE_SIZE = 448

transformer = A.Compose([ 
        # bounding box의 변환, augmentation에서 albumentations는 Detection 학습을 할 때 굉장히 유용하다. 
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
        # albumentations 라이브러리에서는 Normalization을 먼저 진행해 주고 tensor화를 진행해 주어야한다.
    ],
    # box 위치에 대한 transformation도 함께 진행된다. 
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

augmentator=A.Compose([
#     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
#     A.HorizontalFlip(p=0.7),
#     A.Sharpen(p=0.7),
    A.RandomSizedBBoxSafeCrop(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1),
#     A.VerticalFlip (p=0.5),
#     A.HueSaturationValue(p=0.5),
#     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
#     A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
#     ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

def collate_fn(batch):
    image_list = []
    target_list = []
    filename_list = []
    
    for a,b,c in batch:
        image_list.append(a)
        target_list.append(b)
        filename_list.append(c)

    return torch.stack(image_list, dim=0), target_list, filename_list


In [6]:
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'

trainset_crop=PET_dataset(part='neck',neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=augmentator, aug_factor=50)



start making augmented images-- augmented factor:50
total length of augmented images: 10500


In [6]:
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'
PHASE='valid'
AUG_FAC=50
trainset_crop=PET_dataset(part='neck',neck_dir=NECK_PATH,body_dir=BODY_PATH,phase=PHASE, transformer=transformer, aug=augmentator, aug_factor=AUG_FAC)


_, _, temp_filename=trainset_crop[0]
file_num=0
for i in range(len(trainset_crop)):
    image, target, filename = trainset_crop[i]
    if(temp_filename==filename):
        file_num=file_num+1
    else:
        file_num=1
    temp_filename=filename
    file_title=filename.split(".")[0]
    save_title=f"{file_title}_aug{file_num}"
    
    #이미지 저장
    cv2.imwrite(f"/home/host_data/PET_data_image_patching/patched_Neck/{PHASE}/image/"+save_title+".jpg", image)
        
    #라벨파일 저장
    f = open(f"/home/host_data/PET_data_image_patching/patched_Neck/{PHASE}/label/"+save_title+".txt", 'w')
    print(f"ori-tar:{target}")
    for j in range(len(target)):
        write_str=f"{int(target[j][4])} {target[j][0]} {target[j][1]} {target[j][2]} {target[j][3]}\n"
        print(write_str)
        f.write(write_str)

    print(save_title)
#     print(str(target).replace("[", " ").replace("]", " "))
    print(target)

start making augmented images-- augmented factor:50
total length of augmented images: 1800
ori-tar:[[0.45514888 0.66199952 0.14617775 0.12399958 0.        ]
 [0.4534885  0.41199958 0.14285703 0.14400026 0.        ]]
0 0.4551488757133484 0.661999523639679 0.1461777538061142 0.12399958074092865

0 0.4534884989261627 0.41199958324432373 0.14285703003406525 0.14400026202201843

shape1_101_aug1
[[0.45514888 0.66199952 0.14617775 0.12399958 0.        ]
 [0.4534885  0.41199958 0.14285703 0.14400026 0.        ]]
ori-tar:[[0.29113805 0.82894677 0.27847788 0.16315734 0.        ]
 [0.28797495 0.49999943 0.27215168 0.18947403 0.        ]]
0 0.29113805294036865 0.8289467692375183 0.27847787737846375 0.16315734386444092

0 0.2879749536514282 0.49999943375587463 0.2721516788005829 0.18947403132915497

shape1_101_aug2
[[0.29113805 0.82894677 0.27847788 0.16315734 0.        ]
 [0.28797495 0.49999943 0.27215168 0.18947403 0.        ]]
ori-tar:[[0.06286819 0.75943285 0.04322152 0.29245183 0.        ]
 [0

ori-tar:[[0.89126366 0.44000098 0.05982884 0.40000042 0.        ]]
0 0.8912636637687683 0.440000981092453 0.05982884019613266 0.4000004231929779

shape1_12_aug21
[[0.89126366 0.44000098 0.05982884 0.40000042 0.        ]]
ori-tar:[[0.52880907 0.7519691  0.2592583  0.23622072 0.        ]]
0 0.5288090705871582 0.7519690990447998 0.25925830006599426 0.23622071743011475

shape1_12_aug22
[[0.52880907 0.7519691  0.2592583  0.23622072 0.        ]]
ori-tar:[[0.95506263 0.67272794 0.06508241 0.27272755 0.        ]]
0 0.9550626277923584 0.6727279424667358 0.06508240848779678 0.2727275490760803

shape1_12_aug23
[[0.95506263 0.67272794 0.06508241 0.27272755 0.        ]]
ori-tar:[[0.95531082 0.45454624 0.06624582 0.32085595 0.        ]]
0 0.9553108215332031 0.4545462429523468 0.06624581664800644 0.3208559453487396

shape1_12_aug24
[[0.95531082 0.45454624 0.06624582 0.32085595 0.        ]]
ori-tar:[[0.28481272 0.26388991 0.26582181 0.4166671  0.        ]]
0 0.28481271862983704 0.2638899087905884 0.26

ori-tar:[[0.58528072 0.33928588 0.07009377 0.25000072 0.        ]]
0 0.58528071641922 0.33928588032722473 0.07009377330541611 0.2500007152557373

shape1_151_aug43
[[0.58528072 0.33928588 0.07009377 0.25000072 0.        ]]
ori-tar:[[0.2416923  0.67099577 0.09063485 0.18181869 0.        ]]
0 0.24169230461120605 0.6709957718849182 0.09063485264778137 0.18181869387626648

shape1_151_aug44
[[0.2416923  0.67099577 0.09063485 0.18181869 0.        ]]
ori-tar:[[0.31638476 0.60919553 0.11299486 0.16091999 0.        ]]
0 0.31638476252555847 0.6091955304145813 0.11299485713243484 0.16091999411582947

shape1_151_aug45
[[0.31638476 0.60919553 0.11299486 0.16091999 0.        ]]
ori-tar:[[0.4013496  0.37662378 0.1011809  0.54545611 0.        ]]
0 0.40134960412979126 0.3766237795352936 0.10118089616298676 0.5454561114311218

shape1_151_aug46
[[0.4013496  0.37662378 0.1011809  0.54545611 0.        ]]
ori-tar:[[0.43007028 0.53284693 0.06993038 0.30657023 0.        ]]
0 0.43007028102874756 0.5328469276428

ori-tar:[[0.59482718 0.63802052 0.12413786 0.21354206 0.        ]]
0 0.5948271751403809 0.6380205154418945 0.12413785606622696 0.21354205906391144

shape1_63_aug15
[[0.59482718 0.63802052 0.12413786 0.21354206 0.        ]]
ori-tar:[[0.82857066 0.63907248 0.22857128 0.27152368 0.        ]]
0 0.8285706639289856 0.6390724778175354 0.22857128083705902 0.27152368426322937

shape1_63_aug16
[[0.82857066 0.63907248 0.22857128 0.27152368 0.        ]]
ori-tar:[[0.78225774 0.48611078 0.09677413 0.2277782  0.        ]]
0 0.7822577357292175 0.48611077666282654 0.0967741310596466 0.22777819633483887

shape1_63_aug17
[[0.78225774 0.48611078 0.09677413 0.2277782  0.        ]]
ori-tar:[[0.5181548  0.31470552 0.07065746 0.24117692 0.        ]]
0 0.5181547999382019 0.3147055208683014 0.07065746188163757 0.24117691814899445

shape1_63_aug18
[[0.5181548  0.31470552 0.07065746 0.24117692 0.        ]]
ori-tar:[[0.69701052 0.5132156  0.09782603 0.18061708 0.        ]]
0 0.6970105171203613 0.5132156014442444 0

ori-tar:[[0.68000174 0.33809501 0.53714257 0.46666661 0.        ]]
0 0.6800017356872559 0.3380950093269348 0.5371425747871399 0.4666666090488434

shape1_83_aug37
[[0.68000174 0.33809501 0.53714257 0.46666661 0.        ]]
ori-tar:[[0.8335911  0.4931505  0.28967628 0.33561641 0.        ]]
0 0.833591103553772 0.4931505024433136 0.2896762788295746 0.33561640977859497

shape1_83_aug38
[[0.8335911  0.4931505  0.28967628 0.33561641 0.        ]]
ori-tar:[[0.82332265 0.51388854 0.33215529 0.68055546 0.        ]]
0 0.8233226537704468 0.5138885378837585 0.3321552872657776 0.6805554628372192

shape1_83_aug39
[[0.82332265 0.51388854 0.33215529 0.68055546 0.        ]]
ori-tar:[[0.88673532 0.58636343 0.19183663 0.44545451 0.        ]]
0 0.8867353200912476 0.5863634347915649 0.1918366253376007 0.44545450806617737

shape1_83_aug40
[[0.88673532 0.58636343 0.19183663 0.44545451 0.        ]]
ori-tar:[[0.81768066 0.30660355 0.34622449 0.46226409 0.        ]]
0 0.8176806569099426 0.3066035509109497 0.346224

ori-tar:[[0.06003379 0.46323532 0.04459776 0.68137252 1.        ]]
1 0.06003379076719284 0.4632353186607361 0.04459775984287262 0.6813725233078003

shape2_38_aug4
[[0.06003379 0.46323532 0.04459776 0.68137252 1.        ]]
ori-tar:[[0.67816043 0.46462265 0.04269375 0.65566033 1.        ]]
1 0.6781604290008545 0.4646226465702057 0.04269374907016754 0.6556603312492371

shape2_38_aug5
[[0.67816043 0.46462265 0.04269375 0.65566033 1.        ]]
ori-tar:[[0.20642732 0.33413464 0.03213905 0.66826922 1.        ]]
1 0.20642732083797455 0.33413463830947876 0.03213905170559883 0.6682692170143127

shape2_38_aug6
[[0.20642732 0.33413464 0.03213905 0.66826922 1.        ]]
ori-tar:[[0.41006392 0.36718753 0.02783779 0.72395831 1.        ]]
1 0.4100639224052429 0.3671875298023224 0.027837788686156273 0.7239583134651184

shape2_38_aug7
[[0.41006392 0.36718753 0.02783779 0.72395831 1.        ]]
ori-tar:[[0.43058133 0.42640695 0.02439071 0.60173154 1.        ]]
1 0.43058133125305176 0.42640694975852966 0.0

ori-tar:[[0.07777809 0.39606801 0.07407362 0.34269717 0.        ]]
0 0.07777809351682663 0.3960680067539215 0.07407362014055252 0.3426971733570099

shape3_43_aug29
[[0.07777809 0.39606801 0.07407362 0.34269717 0.        ]]
ori-tar:[[0.15010643 0.46089444 0.16913214 0.34078267 0.        ]]
0 0.15010643005371094 0.4608944356441498 0.16913214325904846 0.34078267216682434

shape3_43_aug30
[[0.15010643 0.46089444 0.16913214 0.34078267 0.        ]]
ori-tar:[[0.04743117 0.31132141 0.0790509  0.38364843 0.        ]]
0 0.04743117094039917 0.3113214075565338 0.07905089855194092 0.38364842534065247

shape3_43_aug31
[[0.04743117 0.31132141 0.0790509  0.38364843 0.        ]]
ori-tar:[[0.04442285 0.53200084 0.07897286 0.48800078 0.        ]]
0 0.04442284628748894 0.5320008397102356 0.07897286117076874 0.4880007803440094

shape3_43_aug32
[[0.04442285 0.53200084 0.07897286 0.48800078 0.        ]]
ori-tar:[[0.15083008 0.59150392 0.12066291 0.39869347 0.        ]]
0 0.15083007514476776 0.591503918170929

ori-tar:[[0.91241997 0.50000012 0.1687887  0.28571308 0.        ]]
0 0.9124199748039246 0.5000001192092896 0.16878870129585266 0.2857130765914917

shape5_20_aug11
[[0.91241997 0.50000012 0.1687887  0.28571308 0.        ]]
ori-tar:[[0.91288269 0.43961361 0.05596584 0.21255948 0.        ]]
0 0.9128826856613159 0.43961361050605774 0.0559658408164978 0.21255947649478912

shape5_20_aug12
[[0.91288269 0.43961361 0.05596584 0.21255948 0.        ]]
ori-tar:[[0.63812089 0.20547952 0.29281574 0.15068428 0.        ]]
0 0.6381208896636963 0.20547951757907867 0.2928157448768616 0.15068428218364716

shape5_20_aug13
[[0.63812089 0.20547952 0.29281574 0.15068428 0.        ]]
ori-tar:[[0.80868983 0.63013726 0.06874144 0.60273713 0.        ]]
0 0.8086898326873779 0.6301372647285461 0.06874144077301025 0.6027371287345886

shape5_20_aug14
[[0.80868983 0.63013726 0.06874144 0.60273713 0.        ]]
ori-tar:[[0.84719306 0.25776404 0.11018638 0.13664538 0.        ]]
0 0.8471930623054504 0.2577640414237976 0.1

ori-tar:[[0.71618718 0.45040956 0.13531488 0.21218896 0.        ]]
0 0.7161871790885925 0.45040956139564514 0.13531488180160522 0.2121889591217041

shape6_9_aug34
[[0.71618718 0.45040956 0.13531488 0.21218896 0.        ]]
ori-tar:[[0.43569502 0.77305377 0.44677931 0.43080789 0.        ]]
0 0.43569502234458923 0.773053765296936 0.4467793107032776 0.430807888507843

shape6_9_aug35
[[0.43569502 0.77305377 0.44677931 0.43080789 0.        ]]
ori-tar:[[0.83865976 0.36783719 0.08380211 0.25088224 0.        ]]
0 0.8386597633361816 0.36783719062805176 0.08380211144685745 0.25088223814964294

shape6_9_aug36
[[0.83865976 0.36783719 0.08380211 0.25088224 0.        ]]
ori-tar:[[0.72313678 0.59777296 0.07150161 0.20309515 0.        ]]
0 0.7231367826461792 0.5977729558944702 0.07150161266326904 0.20309515297412872

shape6_9_aug37
[[0.72313678 0.59777296 0.07150161 0.20309515 0.        ]]
ori-tar:[[0.86702877 0.51368141 0.11761013 0.38080341 0.        ]]
0 0.8670287728309631 0.5136814117431641 0.11761

In [7]:
@interact(index=(0, len(trainset_crop)-1))

def show_sample(index=0):
    image, target, filename = trainset_crop[index]
#     image=image.permute(1,2,0).numpy()
    print(f"before cal target:\n{target}")
    img_H, img_W, _ = image.shape
    print(filename)
    print(image.shape)
#     print(image)

#     bboxes = target['boxes']
#     class_ids = target["labels"]
    
    ###
    bboxes = target[:, 0:4]
    class_ids = target[:, 4]
    ###
    bboxes[:, [0,2]] *= img_W
    bboxes[:, [1,3]] *= img_H
    
    print(bboxes)

    canvas = visualize(image, bboxes, class_ids)
    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

# show_sample()

interactive(children=(IntSlider(value=0, description='index', max=1049), Output()), _dom_classes=('widget-inte…