In [8]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

# Utils

In [10]:
CLASS_NAME_TO_ID = {'Unformed': 0, 'Burr': 1}
CLASS_ID_TO_NAME = {0: 'Unformed', 1: 'Burr'}
BOX_COLOR = {'Unformed':(200, 0, 0), 'Burr':(0, 0, 200)}
TEXT_COLOR = (255, 255, 255)

def save_model(model_state, model_name, save_dir="./trained_model"):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model_state, os.path.join(save_dir, model_name))


def visualize_bbox(image, bbox, class_name, color=BOX_COLOR, thickness=2):
    x_center, y_center, w, h = bbox
    x_min = int(x_center - w/2)
    y_min = int(y_center - h/2)
    x_max = int(x_center + w/2)
    y_max = int(y_center + h/2)
    
    cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color[class_name], thickness=thickness)
    
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)    
    cv2.rectangle(image, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color[class_name], -1)
    cv2.putText(
        image,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35, 
        color=TEXT_COLOR, 
        lineType=cv2.LINE_AA,
    )
    return image


def visualize(image, bboxes, category_ids):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
#         print('category_id: ',category_id)
        class_name = CLASS_ID_TO_NAME[category_id.item()]
        img = visualize_bbox(img, bbox, class_name)
    return img

# Datasets

In [11]:
class PET_dataset():
    def __init__(self,part,neck_dir,body_dir,phase, transformer=None, aug=None, aug_factor=0):
        self.neck_dir=neck_dir
        self.body_dir=body_dir
        self.part=part
        self.phase=phase
        self.transformer=transformer
        self.aug=aug
        self.aug_factor=aug_factor
        if(self.part=="body"):
            self.image_files = sorted([fn for fn in os.listdir(self.body_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.body_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        elif(self.part=="neck"):
            self.image_files = sorted([fn for fn in os.listdir(self.neck_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.neck_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        
        self.auged_img_list, self.auged_label_list=self.make_aug_list(self.image_files, self.label_files)
        
    def __getitem__(self,index):
        if(self.aug==None):
            filename, image = self.get_image(self.part, index)
            bboxes, class_ids = self.get_label(self.part, index)

            if(self.transformer):
                transformed_data=self.transformer(image=image, bboxes=bboxes, class_ids=class_ids)
                image = transformed_data['image']
                bboxes = np.array(transformed_data['bboxes'])
                class_ids = np.array(transformed_data['class_ids'])


            target = {}
    #         print(f'bboxes:{bboxes}\nclass_ids:{class_ids}\nlen_bboxes:{len(bboxes)}\nlen_class_ids:{len(class_ids)}')
    #         print(f'filename: {filename}')
            target["boxes"] = torch.Tensor(bboxes).float()
            target["labels"] = torch.Tensor(class_ids).long()

            ###
            bboxes=torch.Tensor(bboxes).float()
            class_ids=torch.Tensor(class_ids).long()
            target = np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1)
            ###
        else:
            image=self.auged_img_list[index][1]
            target=self.auged_label_list[index]
            filename=self.auged_img_list[index][0]
        return image, target, filename
    
    def __len__(self, ):
        length=0
        if(self.aug==None):
            length=len(self.image_files)
        else:
            length=len(self.auged_img_list)
        return length
    
    def make_aug_list(self,ori_image_list,ori_label_files):
        aug_image_list=[]
        aug_label_list=[]
        
        print(f"start making augmented images-- augmented factor:{self.aug_factor}")
        for i in range(len(ori_image_list)):
            filename, ori_image = self.get_image(self.part, i)
            ori_bboxes, ori_class_ids = self.get_label(self.part, i)
            for j in range(self.aug_factor):
                auged_data=self.aug(image=ori_image, bboxes=ori_bboxes, class_ids=ori_class_ids)
                image = auged_data['image']
                bboxes = np.array(auged_data['bboxes'])
                class_ids = np.array(auged_data['class_ids'])
                
                bboxes=torch.Tensor(bboxes).float()
                class_ids=torch.Tensor(class_ids).long()
                
                aug_image_list.append((filename, image))
                aug_label_list.append(np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1))
        
        print(f"total length of augmented images: {len(aug_image_list)}")
        
        return aug_image_list, aug_label_list
        
    
    def get_image(self, part, index): # 이미지 불러오는 함수
        filename = self.image_files[index]
        if(part=="body"):
#             print(f"body called!-> {self.part}")
            image_path = self.body_dir+"/"+self.phase+"/image/"+filename
        elif(part=="neck"):
#             print(f"neck called!-> {self.part}")
            image_path = self.neck_dir+"/"+self.phase+"/image/"+filename
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return filename, image
    
    def get_label(self, part, index): # label (box좌표, class_id) 불러오는 함수
        label_filename=self.label_files[index]
        if(part=="body"):
#             print(f"body label called!-> {self.part}")
            label_path = self.body_dir+"/"+self.phase+"/label/"+label_filename
        elif(part=="neck"):
#             print(f"neck label called!-> {self.part}")
            label_path = self.neck_dir+"/"+self.phase+"/label/"+label_filename
        with open(label_path, 'r') as file:
            labels = file.readlines()
        
        class_ids=[]
        bboxes=[]
        for label in labels:
            label=label.replace("\n", "")
            obj=label.split(' ')[0]
            coor=label.split(' ')[1:]
            obj=int(obj)
            coor=list(map(float, coor))
            class_ids.append(obj)
            bboxes.append(coor)
            
        return bboxes, class_ids
    

In [12]:
IMAGE_SIZE = 448

transformer = A.Compose([ 
        # bounding box의 변환, augmentation에서 albumentations는 Detection 학습을 할 때 굉장히 유용하다. 
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
        # albumentations 라이브러리에서는 Normalization을 먼저 진행해 주고 tensor화를 진행해 주어야한다.
    ],
    # box 위치에 대한 transformation도 함께 진행된다. 
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

augmentator=A.Compose([
#     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.HorizontalFlip(p=0.7),
#     A.Sharpen(p=0.7),
    A.BBoxSafeRandomCrop(p=0.6),
    A.VerticalFlip (p=0.5),
    A.HueSaturationValue(p=0.5),
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

def collate_fn(batch):
    image_list = []
    target_list = []
    filename_list = []
    
    for a,b,c in batch:
        image_list.append(a)
        target_list.append(b)
        filename_list.append(c)

    return torch.stack(image_list, dim=0), target_list, filename_list


In [83]:
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'
trainset_yes_aug=PET_dataset(part='neck',neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=augmentator, aug_factor=5)
trainset_no_aug=PET_dataset(part='neck',neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=None)


start making augmented images-- augmented factor:5
total length of augmented images: 1050
start making augmented images-- augmented factor:0
total length of augmented images: 0


In [84]:
len(trainset_yes_aug)

1050

In [85]:
@interact(index=(0, len(trainset_no_aug)-1))

def show_sample(index=0):
    image, target, filename = trainset_no_aug[index]
    image=image.permute(1,2,0).numpy()
    img_H, img_W, _ = image.shape
    print(filename)
    print(image.shape)
#     print(image)

#     bboxes = target['boxes']
#     class_ids = target["labels"]
    
    ###
    bboxes = target[:, 0:4]
    class_ids = target[:, 4]
    ###
    bboxes[:, [0,2]] *= img_W
    bboxes[:, [1,3]] *= img_H

    canvas = visualize(image, bboxes, class_ids)
    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

# show_sample()

interactive(children=(IntSlider(value=0, description='index', max=209), Output()), _dom_classes=('widget-inter…

In [9]:
@interact(index=(0, len(trainset_yes_aug)-1))

def show_sample(index=0):
    image, target, filename = trainset_yes_aug[index]
    image=image.permute(1,2,0).numpy()
    img_H, img_W, _ = image.shape
    print(filename)
    print(image.shape)
#     print(image)

#     bboxes = target['boxes']
#     class_ids = target["labels"]
    ###
    bboxes = target[:, 0:4]
    class_ids = target[:, 4]
    ###
    bboxes[:, [0,2]] *= img_W
    bboxes[:, [1,3]] *= img_H
    print(bboxes)

    canvas = visualize(image, bboxes, class_ids)
    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

# show_sample()

interactive(children=(IntSlider(value=0, description='index', max=1049), Output()), _dom_classes=('widget-inte…

## Model

In [13]:
class YOLO_SWIN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes
        self.num_bboxes = 2
        self.grid_size = 7

#         resnet18 = torchvision.models.resnet18(pretrained = True)
        swin=torchvision.models.swin_v2_t(weights='IMAGENET1K_V1')
        layers = [m for m in swin.children()] #Resnet에서 Yolo에서 가져올수 있을만한 layer만 선별적으로 가져오기 위해서

        # 기존 Resnet18의 layer들중에서 맨 뒤에 두개만 제외하고 다 가져와서 Backbone으로 사용
        self.backbone = nn.Sequential(*layers[:-3]) 
        self.head = nn.Sequential(
                nn.Conv2d(in_channels=768, out_channels=1024, kernel_size=1, padding=0,bias=False),
                nn.BatchNorm2d(1024),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(1024),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(1024),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1,bias=False),
                nn.BatchNorm2d(1024),
                nn.ReLU(inplace=True),

                nn.Conv2d(in_channels=1024, out_channels=(4+1)*self.num_bboxes+num_classes, kernel_size=1, padding=0, bias=False),
                nn.AdaptiveAvgPool2d(output_size=(self.grid_size, self.grid_size))
            )

    def forward(self, x):
        out = self.backbone(x)
        # out = self.neck(out)
        out = self.head(out) # input (batch, 3, 448, 448) -> output feature (batch, 12, 7, 7)
        return out


In [14]:
NUM_CLASSES = 2
model = YOLO_SWIN(num_classes=NUM_CLASSES)
model.to(device)

YOLO_SWIN(
  (backbone): Sequential(
    (0): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): Permute()
        (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (1): Sequential(
        (0): SwinTransformerBlockV2(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): ShiftedWindowAttentionV2(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (cpb_mlp): Sequential(
              (0): Linear(in_features=2, out_features=512, bias=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=512, out_features=3, bias=False)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (0): Linear(in_features=96, out_f

In [15]:
torchsummary.summary(model, (3,448,448))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 112, 112]           4,704
           Permute-2         [-1, 112, 112, 96]               0
         LayerNorm-3         [-1, 112, 112, 96]             192
            Linear-4          [-1, 15, 15, 512]           1,536
              ReLU-5          [-1, 15, 15, 512]               0
            Linear-6            [-1, 15, 15, 3]           1,536
ShiftedWindowAttentionV2-7         [-1, 112, 112, 96]               0
         LayerNorm-8         [-1, 112, 112, 96]             192
   StochasticDepth-9         [-1, 112, 112, 96]               0
           Linear-10        [-1, 112, 112, 384]          37,248
             GELU-11        [-1, 112, 112, 384]               0
          Dropout-12        [-1, 112, 112, 384]               0
           Linear-13         [-1, 112, 112, 96]          36,960
          Dropout-14         [-1,

In [16]:
x = torch.randn(1, 3, 448, 448).to(device)
with torch.no_grad():
    y = model(x)
print(y.shape)

torch.Size([1, 12, 7, 7])


# Loss func

In [17]:
class YOLO_LOSS():
    def __init__(self, num_classes, device, lambda_coord=5., lambda_noobj=0.5):
        self.num_classes = num_classes
        self.device = device
        self.grid_size = 7
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.mse_loss = nn.MSELoss(reduction="sum")

    def __call__(self, predictions, targets):
        self.batch_size, _, _, _ = predictions.shape
        groundtruths = self.build_batch_target_grid(targets)
        groundtruths = groundtruths.to(self.device)
        
        with torch.no_grad():
            iou1 = self.get_IoU(predictions[:, 1:5, ...], groundtruths[:, 1:5, ...])
            iou2 = self.get_IoU(predictions[:, 6:10, ...], groundtruths[:, 1:5, ...])

        ious = torch.stack([iou1, iou2], dim=1)
        max_iou, best_box = ious.max(dim=1, keepdim=True)
        max_iou = torch.cat([max_iou, max_iou], dim=1)
        best_box = torch.cat([best_box.eq(0), best_box.eq(1)], dim=1)

        predictions_ = predictions[:, :5*2, ...].reshape(self.batch_size, 2, 5, self.grid_size, self.grid_size)
        obj_pred = predictions_[:, :, 0, ...]
        xy_pred = predictions_[:, :, 1:3, ...]
        wh_pred = predictions_[:, :, 3:5, ...]
        cls_pred = predictions[:, 5*2:, ...]

        groundtruths_ = groundtruths[:, :5, ...].reshape(self.batch_size, 1, 5, self.grid_size, self.grid_size)
        obj_target = groundtruths_[:, :, 0, ...]
        xy_target = groundtruths_[:, :, 1:3, ...]
        wh_target= groundtruths_[:, :, 3:5, ...]
        cls_target = groundtruths[:, 5:, ...]
        
        positive = obj_target * best_box

        obj_loss = self.mse_loss(positive * obj_pred, positive * ious)
        noobj_loss = self.mse_loss((1 - positive) * obj_pred, ious*0)
        xy_loss = self.mse_loss(positive.unsqueeze(dim=2) * xy_pred, positive.unsqueeze(dim=2) * xy_target)
        wh_loss = self.mse_loss(positive.unsqueeze(dim=2) * (wh_pred.sign() * (wh_pred.abs() + 1e-8).sqrt()),
                           positive.unsqueeze(dim=2) * (wh_target + 1e-8).sqrt())
        cls_loss = self.mse_loss(obj_target * cls_pred, cls_target)
        
        obj_loss /= self.batch_size
        noobj_loss /= self.batch_size
        bbox_loss = (xy_loss+wh_loss) / self.batch_size
        cls_loss /= self.batch_size
        
        total_loss = obj_loss + self.lambda_noobj*noobj_loss + self.lambda_coord*bbox_loss + cls_loss
        return total_loss, (obj_loss.item(), noobj_loss.item(), bbox_loss.item(), cls_loss.item())
    
    def build_target_grid(self, target):
        target_grid = torch.zeros((1+4+self.num_classes, self.grid_size, self.grid_size), device=self.device)

        for gt in target:
            xc, yc, w, h, cls_id = gt
            xn = (xc % (1/self.grid_size))
            yn = (yc % (1/self.grid_size))
            cls_id = int(cls_id)

            i_grid = int(xc * self.grid_size)
            j_grid = int(yc * self.grid_size)
            target_grid[0, j_grid, i_grid] = 1
            target_grid[1:5, j_grid, i_grid] = torch.Tensor([xn,yn,w,h])
#             print(5+cls_id, j_grid, i_grid)
            target_grid[5+cls_id, j_grid, i_grid] = 1

        return target_grid
    
    def build_batch_target_grid(self, targets):
        target_grid_batch = torch.stack([self.build_target_grid(target) for target in targets], dim=0)
        return target_grid_batch
    
    def get_IoU(self, cbox1, cbox2):
        box1 = self.xywh_to_xyxy(cbox1)
        box2 = self.xywh_to_xyxy(cbox2)

        x1 = torch.max(box1[:, 0, ...], box2[:, 0, ...])
        y1 = torch.max(box1[:, 1, ...], box2[:, 1, ...])
        x2 = torch.min(box1[:, 2, ...], box2[:, 2, ...])
        y2 = torch.min(box1[:, 3, ...], box2[:, 3, ...])

        intersection = (x2-x1).clamp(min=0) * (y2-y1).clamp(min=0)
        union = abs(cbox1[:, 2, ...]*cbox1[:, 3, ...]) + \
                abs(cbox2[:, 2, ...]*cbox2[:, 3, ...]) - intersection

        intersection[intersection.gt(0)] = intersection[intersection.gt(0)] / union[intersection.gt(0)]
        return intersection
    
    def generate_xy_normed_grid(self):
        y_offset, x_offset = torch.meshgrid(torch.arange(self.grid_size), torch.arange(self.grid_size))
        xy_grid = torch.stack([x_offset, y_offset], dim=0)
        xy_normed_grid = xy_grid / self.grid_size
        return xy_normed_grid.to(self.device)

    def xywh_to_xyxy(self, bboxes):
        xy_normed_grid = self.generate_xy_normed_grid()
        xcyc = bboxes[:,0:2,...] + xy_normed_grid.tile(self.batch_size, 1,1,1)
        wh = bboxes[:,2:4,...]
        x1y1 = xcyc - (wh/2)
        x2y2 = xcyc + (wh/2)
        return torch.cat([x1y1, x2y2], dim=1)

# Train

In [18]:
def train_one_epoch(dataloaders, model, criterion, optimizer, device):
    train_loss = defaultdict(float)
    val_loss = defaultdict(float)
    
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
        
        running_loss = defaultdict(float)
        for index, batch in enumerate(dataloaders[phase]):
            images = batch[0].to(device)
            targets = batch[1]
            filenames = batch[2]
            
            with torch.set_grad_enabled(phase == "train"): # phase가 train 일때만 gradient 추적기능을 킨다.
                predictions = model(images) #prediction shape=> B,12,7,7
#             print(f"predictions:{predictions}, \ntargets: {targets}\n")
            loss, (obj_loss, noobj_loss, bbox_loss, cls_loss) = criterion(predictions, targets)
#             print(f"loss:{loss}, obj_loss:{obj_loss}, noobj_loss:{noobj_loss}\nbbox_loss:{bbox_loss}, cls_loss:{cls_loss}\n--------------\n")
            if phase == "train":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # 현재 epoch단계에서 loss가 얼마인지 running loss 가출력
                running_loss["total_loss"] += loss.item()
                running_loss["obj_loss"] += obj_loss
                running_loss["noobj_loss"] += noobj_loss
                running_loss["bbox_loss"] += bbox_loss
                running_loss["cls_loss"] += cls_loss
                
                train_loss["total_loss"] += loss.item()
                train_loss["obj_loss"] += obj_loss
                train_loss["noobj_loss"] += noobj_loss
                train_loss["bbox_loss"] += bbox_loss
                train_loss["cls_loss"] += cls_loss
                
                if (index > 0) and (index % VERBOSE_FREQ) == 0:
                    text = f"<<<iteration:[{index}/{len(dataloaders[phase])}] - "
                    for k, v in running_loss.items():
                        text += f"{k}: {v/VERBOSE_FREQ:.4f}  "
                        running_loss[k] = 0.
                    print(text)
            else:
                val_loss["total_loss"] += loss.item()
                val_loss["obj_loss"] += obj_loss
                val_loss["noobj_loss"] += noobj_loss
                val_loss["bbox_loss"] += bbox_loss
                val_loss["cls_loss"] += cls_loss

    for k in train_loss.keys():
        train_loss[k] /= len(dataloaders["train"])
        val_loss[k] /= len(dataloaders["val"])
    return train_loss, val_loss

In [19]:
def build_dataloader(part, NECK_PATH, BODY_PATH, batch_size=2, aug_factor=0):
    IMAGE_SIZE = 448
    transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
    )
    augmentator=A.Compose([
    #     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.HorizontalFlip(p=0.7),
    #     A.Sharpen(p=0.7),
        A.BBoxSafeRandomCrop(p=0.6),
        A.VerticalFlip (p=0.6),
        A.HueSaturationValue(p=0.6),
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
    )
    
    dataloaders = {}
#     train_dataset = Detection_dataset(data_dir=data_dir, phase="train", transformer=transformer)
    train_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=augmentator, aug_factor=aug_factor)
    dataloaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

#     val_dataset = Detection_dataset(data_dir=data_dir, phase="val", transformer=transformer)
    val_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='valid', transformer=transformer, aug=augmentator, aug_factor=aug_factor)
    dataloaders["val"] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return dataloaders

In [13]:
# data_dir = "/content/drive/MyDrive/fastCamMedicalProj/DATASET/DATASET/Detection/"
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'
is_cuda = True

NUM_CLASSES = 2
IMAGE_SIZE = 448
BATCH_SIZE = 8
VERBOSE_FREQ = 20
LR=0.0001
AUG_FACTOR=20
BACKBONE="YOLO_SWIN_T"
PART="neck"
num_epochs = 100
# DEVICE = torch.device('cuda' if torch.cuda.is_available and is_cuda else 'cpu')

dataloaders = build_dataloader(part=PART,NECK_PATH=NECK_PATH,BODY_PATH=BODY_PATH,batch_size=BATCH_SIZE, aug_factor=AUG_FACTOR)
model = YOLO_SWIN(num_classes=NUM_CLASSES)
model = model.to(device)
criterion = YOLO_LOSS(num_classes=NUM_CLASSES, device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

start making augmented images-- augmented factor:20
total length of augmented images: 4200
start making augmented images-- augmented factor:20
total length of augmented images: 720


In [14]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="yolo_swin_neck",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "architecture": BACKBONE,
    "dataset": PART,
    "epochs": num_epochs,
    "aug factor":AUG_FACTOR,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mgomduribo[0m ([33murp[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
best_epoch = 0
best_score = float('inf')
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss, val_loss = train_one_epoch(dataloaders, model, criterion, optimizer, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
#     train_loss["obj_loss"] += obj_loss
#     train_loss["noobj_loss"] += noobj_loss
#     train_loss["bbox_loss"] += bbox_loss
#     train_loss["cls_loss"] += cls_loss
    wandb.log({"Train Loss": train_loss['total_loss'],
               "Train obj Loss":train_loss["obj_loss"],
               "Train bbox Loss":train_loss["bbox_loss"],
               "Train class Loss":train_loss["cls_loss"],
               "Val Loss": val_loss['total_loss'],
               "Val obj Loss":val_loss["obj_loss"],
               "Val bbox Loss":val_loss["bbox_loss"],
               "Val class Loss":val_loss["cls_loss"],})
    print(f"\nepoch:{epoch+1}/{num_epochs} - Train Loss: {train_loss['total_loss']:.4f}, Val Loss: {val_loss['total_loss']:.4f}\n")
    
    if (epoch+1) % 10 == 0:
        save_model(model.state_dict(), f'model_{epoch+1}.pth', save_dir=f"./trained_model/{BACKBONE}_{PART}_LR{LR}_AUG{AUG_FACTOR}")
wandb.finish()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<<<iteration:[20/525] - total_loss: 6.9731  obj_loss: 0.0772  noobj_loss: 3.9329  bbox_loss: 0.7826  cls_loss: 1.0162  
<<<iteration:[40/525] - total_loss: 3.8196  obj_loss: 0.0522  noobj_loss: 2.6349  bbox_loss: 0.3829  cls_loss: 0.5354  
<<<iteration:[60/525] - total_loss: 3.2380  obj_loss: 0.0586  noobj_loss: 2.3305  bbox_loss: 0.3300  cls_loss: 0.3641  
<<<iteration:[80/525] - total_loss: 3.2955  obj_loss: 0.0496  noobj_loss: 2.2320  bbox_loss: 0.3516  cls_loss: 0.3718  
<<<iteration:[100/525] - total_loss: 4.5594  obj_loss: 0.0439  noobj_loss: 2.1509  bbox_loss: 0.6174  cls_loss: 0.3530  
<<<iteration:[120/525] - total_loss: 3.6070  obj_loss: 0.0408  noobj_loss: 2.0752  bbox_loss: 0.4348  cls_loss: 0.3545  
<<<iteration:[140/525] - total_loss: 3.2490  obj_loss: 0.0394  noobj_loss: 1.9889  bbox_loss: 0.3683  cls_loss: 0.3735  
<<<iteration:[160/525] - total_loss: 3.2121  obj_loss: 0.0278  noobj_loss: 1.8492  bbox_loss: 0.3840  cls_loss: 0.3398  
<<<iteration:[180/525] - total_loss:

<<<iteration:[320/525] - total_loss: 0.7090  obj_loss: 0.0828  noobj_loss: 0.2428  bbox_loss: 0.0719  cls_loss: 0.1453  
<<<iteration:[340/525] - total_loss: 0.6744  obj_loss: 0.0680  noobj_loss: 0.2265  bbox_loss: 0.0690  cls_loss: 0.1484  
<<<iteration:[360/525] - total_loss: 0.6969  obj_loss: 0.0610  noobj_loss: 0.2176  bbox_loss: 0.0713  cls_loss: 0.1708  
<<<iteration:[380/525] - total_loss: 0.7726  obj_loss: 0.0610  noobj_loss: 0.2268  bbox_loss: 0.0773  cls_loss: 0.2115  
<<<iteration:[400/525] - total_loss: 0.7059  obj_loss: 0.0666  noobj_loss: 0.2216  bbox_loss: 0.0706  cls_loss: 0.1755  
<<<iteration:[420/525] - total_loss: 0.7525  obj_loss: 0.0632  noobj_loss: 0.2255  bbox_loss: 0.0798  cls_loss: 0.1772  
<<<iteration:[440/525] - total_loss: 0.5746  obj_loss: 0.0744  noobj_loss: 0.1839  bbox_loss: 0.0507  cls_loss: 0.1547  
<<<iteration:[460/525] - total_loss: 0.5276  obj_loss: 0.0775  noobj_loss: 0.1851  bbox_loss: 0.0488  cls_loss: 0.1137  
<<<iteration:[480/525] - total_l

<<<iteration:[100/525] - total_loss: 0.4147  obj_loss: 0.1261  noobj_loss: 0.0907  bbox_loss: 0.0295  cls_loss: 0.0959  
<<<iteration:[120/525] - total_loss: 0.4701  obj_loss: 0.0964  noobj_loss: 0.0940  bbox_loss: 0.0350  cls_loss: 0.1520  
<<<iteration:[140/525] - total_loss: 0.3942  obj_loss: 0.1149  noobj_loss: 0.0893  bbox_loss: 0.0298  cls_loss: 0.0855  
<<<iteration:[160/525] - total_loss: 0.4412  obj_loss: 0.1245  noobj_loss: 0.0901  bbox_loss: 0.0330  cls_loss: 0.1066  
<<<iteration:[180/525] - total_loss: 0.4566  obj_loss: 0.1184  noobj_loss: 0.0998  bbox_loss: 0.0361  cls_loss: 0.1080  
<<<iteration:[200/525] - total_loss: 0.4597  obj_loss: 0.0978  noobj_loss: 0.0868  bbox_loss: 0.0344  cls_loss: 0.1466  
<<<iteration:[220/525] - total_loss: 0.3816  obj_loss: 0.1066  noobj_loss: 0.0904  bbox_loss: 0.0289  cls_loss: 0.0854  
<<<iteration:[240/525] - total_loss: 0.4383  obj_loss: 0.1277  noobj_loss: 0.0865  bbox_loss: 0.0347  cls_loss: 0.0941  
<<<iteration:[260/525] - total_l

<<<iteration:[400/525] - total_loss: 0.4021  obj_loss: 0.1262  noobj_loss: 0.0760  bbox_loss: 0.0292  cls_loss: 0.0921  
<<<iteration:[420/525] - total_loss: 0.3524  obj_loss: 0.1208  noobj_loss: 0.0718  bbox_loss: 0.0235  cls_loss: 0.0784  
<<<iteration:[440/525] - total_loss: 0.3677  obj_loss: 0.1236  noobj_loss: 0.0750  bbox_loss: 0.0247  cls_loss: 0.0831  
<<<iteration:[460/525] - total_loss: 0.3856  obj_loss: 0.1198  noobj_loss: 0.0741  bbox_loss: 0.0277  cls_loss: 0.0900  
<<<iteration:[480/525] - total_loss: 0.3620  obj_loss: 0.1306  noobj_loss: 0.0754  bbox_loss: 0.0253  cls_loss: 0.0673  
<<<iteration:[500/525] - total_loss: 0.3801  obj_loss: 0.1394  noobj_loss: 0.0736  bbox_loss: 0.0258  cls_loss: 0.0747  
<<<iteration:[520/525] - total_loss: 0.3992  obj_loss: 0.1302  noobj_loss: 0.0728  bbox_loss: 0.0260  cls_loss: 0.1028  

epoch:8/100 - Train Loss: 0.4361, Val Loss: 0.3699

<<<iteration:[20/525] - total_loss: 0.4198  obj_loss: 0.1145  noobj_loss: 0.0877  bbox_loss: 0.0331 

<<<iteration:[180/525] - total_loss: 0.3638  obj_loss: 0.1405  noobj_loss: 0.0648  bbox_loss: 0.0184  cls_loss: 0.0991  
<<<iteration:[200/525] - total_loss: 0.3595  obj_loss: 0.1377  noobj_loss: 0.0668  bbox_loss: 0.0208  cls_loss: 0.0846  
<<<iteration:[220/525] - total_loss: 0.3314  obj_loss: 0.1422  noobj_loss: 0.0723  bbox_loss: 0.0201  cls_loss: 0.0524  
<<<iteration:[240/525] - total_loss: 0.3743  obj_loss: 0.1403  noobj_loss: 0.0691  bbox_loss: 0.0224  cls_loss: 0.0876  
<<<iteration:[260/525] - total_loss: 0.3443  obj_loss: 0.1120  noobj_loss: 0.0675  bbox_loss: 0.0209  cls_loss: 0.0940  
<<<iteration:[280/525] - total_loss: 0.3800  obj_loss: 0.1356  noobj_loss: 0.0653  bbox_loss: 0.0237  cls_loss: 0.0932  
<<<iteration:[300/525] - total_loss: 0.3552  obj_loss: 0.1217  noobj_loss: 0.0662  bbox_loss: 0.0276  cls_loss: 0.0626  
<<<iteration:[320/525] - total_loss: 0.3246  obj_loss: 0.1337  noobj_loss: 0.0664  bbox_loss: 0.0176  cls_loss: 0.0697  
<<<iteration:[340/525] - total_l

<<<iteration:[480/525] - total_loss: 0.3027  obj_loss: 0.1389  noobj_loss: 0.0644  bbox_loss: 0.0150  cls_loss: 0.0564  
<<<iteration:[500/525] - total_loss: 0.3608  obj_loss: 0.1575  noobj_loss: 0.0639  bbox_loss: 0.0160  cls_loss: 0.0914  
<<<iteration:[520/525] - total_loss: 0.3389  obj_loss: 0.1351  noobj_loss: 0.0626  bbox_loss: 0.0191  cls_loss: 0.0771  

epoch:13/100 - Train Loss: 0.3398, Val Loss: 0.3266

<<<iteration:[20/525] - total_loss: 0.3324  obj_loss: 0.1347  noobj_loss: 0.0689  bbox_loss: 0.0175  cls_loss: 0.0758  
<<<iteration:[40/525] - total_loss: 0.3260  obj_loss: 0.1461  noobj_loss: 0.0666  bbox_loss: 0.0178  cls_loss: 0.0578  
<<<iteration:[60/525] - total_loss: 0.3331  obj_loss: 0.1543  noobj_loss: 0.0648  bbox_loss: 0.0180  cls_loss: 0.0565  
<<<iteration:[80/525] - total_loss: 0.3357  obj_loss: 0.1428  noobj_loss: 0.0666  bbox_loss: 0.0193  cls_loss: 0.0629  
<<<iteration:[100/525] - total_loss: 0.3358  obj_loss: 0.1585  noobj_loss: 0.0634  bbox_loss: 0.0211  c

<<<iteration:[260/525] - total_loss: 0.3632  obj_loss: 0.1698  noobj_loss: 0.0648  bbox_loss: 0.0184  cls_loss: 0.0689  
<<<iteration:[280/525] - total_loss: 0.3302  obj_loss: 0.1365  noobj_loss: 0.0600  bbox_loss: 0.0198  cls_loss: 0.0648  
<<<iteration:[300/525] - total_loss: 0.2942  obj_loss: 0.1419  noobj_loss: 0.0622  bbox_loss: 0.0152  cls_loss: 0.0453  
<<<iteration:[320/525] - total_loss: 0.3276  obj_loss: 0.1460  noobj_loss: 0.0613  bbox_loss: 0.0167  cls_loss: 0.0676  
<<<iteration:[340/525] - total_loss: 0.3289  obj_loss: 0.1446  noobj_loss: 0.0673  bbox_loss: 0.0177  cls_loss: 0.0623  
<<<iteration:[360/525] - total_loss: 0.3502  obj_loss: 0.1619  noobj_loss: 0.0690  bbox_loss: 0.0170  cls_loss: 0.0690  
<<<iteration:[380/525] - total_loss: 0.3595  obj_loss: 0.1466  noobj_loss: 0.0664  bbox_loss: 0.0216  cls_loss: 0.0718  
<<<iteration:[400/525] - total_loss: 0.3324  obj_loss: 0.1812  noobj_loss: 0.0618  bbox_loss: 0.0158  cls_loss: 0.0415  
<<<iteration:[420/525] - total_l

<<<iteration:[40/525] - total_loss: 0.3094  obj_loss: 0.1638  noobj_loss: 0.0702  bbox_loss: 0.0121  cls_loss: 0.0500  
<<<iteration:[60/525] - total_loss: 0.3063  obj_loss: 0.1632  noobj_loss: 0.0691  bbox_loss: 0.0138  cls_loss: 0.0394  
<<<iteration:[80/525] - total_loss: 0.3393  obj_loss: 0.1512  noobj_loss: 0.0684  bbox_loss: 0.0167  cls_loss: 0.0704  
<<<iteration:[100/525] - total_loss: 0.3370  obj_loss: 0.1642  noobj_loss: 0.0707  bbox_loss: 0.0163  cls_loss: 0.0560  
<<<iteration:[120/525] - total_loss: 0.3282  obj_loss: 0.1553  noobj_loss: 0.0763  bbox_loss: 0.0130  cls_loss: 0.0699  
<<<iteration:[140/525] - total_loss: 0.3077  obj_loss: 0.1598  noobj_loss: 0.0675  bbox_loss: 0.0130  cls_loss: 0.0490  
<<<iteration:[160/525] - total_loss: 0.3513  obj_loss: 0.1481  noobj_loss: 0.0687  bbox_loss: 0.0171  cls_loss: 0.0835  
<<<iteration:[180/525] - total_loss: 0.3065  obj_loss: 0.1579  noobj_loss: 0.0687  bbox_loss: 0.0143  cls_loss: 0.0428  
<<<iteration:[200/525] - total_loss

<<<iteration:[340/525] - total_loss: 0.3102  obj_loss: 0.1574  noobj_loss: 0.0669  bbox_loss: 0.0141  cls_loss: 0.0488  
<<<iteration:[360/525] - total_loss: 0.3146  obj_loss: 0.1725  noobj_loss: 0.0695  bbox_loss: 0.0140  cls_loss: 0.0371  
<<<iteration:[380/525] - total_loss: 0.3319  obj_loss: 0.1695  noobj_loss: 0.0718  bbox_loss: 0.0138  cls_loss: 0.0574  
<<<iteration:[400/525] - total_loss: 0.2979  obj_loss: 0.1612  noobj_loss: 0.0692  bbox_loss: 0.0114  cls_loss: 0.0451  
<<<iteration:[420/525] - total_loss: 0.3253  obj_loss: 0.1690  noobj_loss: 0.0709  bbox_loss: 0.0122  cls_loss: 0.0601  
<<<iteration:[440/525] - total_loss: 0.3306  obj_loss: 0.1545  noobj_loss: 0.0736  bbox_loss: 0.0151  cls_loss: 0.0638  
<<<iteration:[460/525] - total_loss: 0.3194  obj_loss: 0.1641  noobj_loss: 0.0696  bbox_loss: 0.0123  cls_loss: 0.0589  
<<<iteration:[480/525] - total_loss: 0.3062  obj_loss: 0.1551  noobj_loss: 0.0662  bbox_loss: 0.0131  cls_loss: 0.0525  
<<<iteration:[500/525] - total_l

<<<iteration:[120/525] - total_loss: 0.2980  obj_loss: 0.1414  noobj_loss: 0.0777  bbox_loss: 0.0166  cls_loss: 0.0345  
<<<iteration:[140/525] - total_loss: 0.2970  obj_loss: 0.1598  noobj_loss: 0.0737  bbox_loss: 0.0128  cls_loss: 0.0362  
<<<iteration:[160/525] - total_loss: 0.2928  obj_loss: 0.1624  noobj_loss: 0.0746  bbox_loss: 0.0121  cls_loss: 0.0324  
<<<iteration:[180/525] - total_loss: 0.3049  obj_loss: 0.1558  noobj_loss: 0.0720  bbox_loss: 0.0127  cls_loss: 0.0495  
<<<iteration:[200/525] - total_loss: 0.3435  obj_loss: 0.1564  noobj_loss: 0.0764  bbox_loss: 0.0152  cls_loss: 0.0728  
<<<iteration:[220/525] - total_loss: 0.2994  obj_loss: 0.1591  noobj_loss: 0.0709  bbox_loss: 0.0121  cls_loss: 0.0442  
<<<iteration:[240/525] - total_loss: 0.3408  obj_loss: 0.1541  noobj_loss: 0.0856  bbox_loss: 0.0180  cls_loss: 0.0538  
<<<iteration:[260/525] - total_loss: 0.3131  obj_loss: 0.1598  noobj_loss: 0.0715  bbox_loss: 0.0138  cls_loss: 0.0488  
<<<iteration:[280/525] - total_l

<<<iteration:[420/525] - total_loss: 0.2920  obj_loss: 0.1524  noobj_loss: 0.0743  bbox_loss: 0.0115  cls_loss: 0.0451  
<<<iteration:[440/525] - total_loss: 0.3035  obj_loss: 0.1558  noobj_loss: 0.0768  bbox_loss: 0.0135  cls_loss: 0.0420  
<<<iteration:[460/525] - total_loss: 0.2946  obj_loss: 0.1590  noobj_loss: 0.0799  bbox_loss: 0.0106  cls_loss: 0.0428  
<<<iteration:[480/525] - total_loss: 0.2928  obj_loss: 0.1591  noobj_loss: 0.0771  bbox_loss: 0.0126  cls_loss: 0.0323  
<<<iteration:[500/525] - total_loss: 0.3166  obj_loss: 0.1754  noobj_loss: 0.0808  bbox_loss: 0.0128  cls_loss: 0.0369  
<<<iteration:[520/525] - total_loss: 0.3006  obj_loss: 0.1552  noobj_loss: 0.0846  bbox_loss: 0.0124  cls_loss: 0.0413  

epoch:26/100 - Train Loss: 0.2999, Val Loss: 0.3090

<<<iteration:[20/525] - total_loss: 0.3232  obj_loss: 0.1623  noobj_loss: 0.0788  bbox_loss: 0.0143  cls_loss: 0.0499  
<<<iteration:[40/525] - total_loss: 0.2974  obj_loss: 0.1760  noobj_loss: 0.0776  bbox_loss: 0.0108 

<<<iteration:[200/525] - total_loss: 0.3074  obj_loss: 0.1684  noobj_loss: 0.0852  bbox_loss: 0.0119  cls_loss: 0.0369  
<<<iteration:[220/525] - total_loss: 0.3132  obj_loss: 0.1783  noobj_loss: 0.0842  bbox_loss: 0.0102  cls_loss: 0.0421  
<<<iteration:[240/525] - total_loss: 0.3043  obj_loss: 0.1480  noobj_loss: 0.0854  bbox_loss: 0.0146  cls_loss: 0.0408  
<<<iteration:[260/525] - total_loss: 0.2970  obj_loss: 0.1655  noobj_loss: 0.0801  bbox_loss: 0.0108  cls_loss: 0.0374  
<<<iteration:[280/525] - total_loss: 0.2903  obj_loss: 0.1775  noobj_loss: 0.0802  bbox_loss: 0.0097  cls_loss: 0.0244  
<<<iteration:[300/525] - total_loss: 0.3052  obj_loss: 0.1703  noobj_loss: 0.0759  bbox_loss: 0.0134  cls_loss: 0.0301  
<<<iteration:[320/525] - total_loss: 0.3266  obj_loss: 0.1665  noobj_loss: 0.0803  bbox_loss: 0.0145  cls_loss: 0.0474  
<<<iteration:[340/525] - total_loss: 0.3303  obj_loss: 0.1678  noobj_loss: 0.0811  bbox_loss: 0.0136  cls_loss: 0.0538  
<<<iteration:[360/525] - total_l

<<<iteration:[500/525] - total_loss: 0.2997  obj_loss: 0.1642  noobj_loss: 0.0797  bbox_loss: 0.0116  cls_loss: 0.0376  
<<<iteration:[520/525] - total_loss: 0.3001  obj_loss: 0.1738  noobj_loss: 0.0871  bbox_loss: 0.0108  cls_loss: 0.0287  

epoch:31/100 - Train Loss: 0.2924, Val Loss: 0.2950

<<<iteration:[20/525] - total_loss: 0.3008  obj_loss: 0.1584  noobj_loss: 0.0919  bbox_loss: 0.0121  cls_loss: 0.0357  
<<<iteration:[40/525] - total_loss: 0.3006  obj_loss: 0.1702  noobj_loss: 0.0828  bbox_loss: 0.0098  cls_loss: 0.0399  
<<<iteration:[60/525] - total_loss: 0.2994  obj_loss: 0.1533  noobj_loss: 0.0881  bbox_loss: 0.0107  cls_loss: 0.0484  
<<<iteration:[80/525] - total_loss: 0.2788  obj_loss: 0.1686  noobj_loss: 0.0930  bbox_loss: 0.0090  cls_loss: 0.0190  
<<<iteration:[100/525] - total_loss: 0.3014  obj_loss: 0.1801  noobj_loss: 0.0786  bbox_loss: 0.0094  cls_loss: 0.0352  
<<<iteration:[120/525] - total_loss: 0.2787  obj_loss: 0.1406  noobj_loss: 0.0847  bbox_loss: 0.0120  c

<<<iteration:[280/525] - total_loss: 0.3068  obj_loss: 0.1842  noobj_loss: 0.0850  bbox_loss: 0.0108  cls_loss: 0.0262  
<<<iteration:[300/525] - total_loss: 0.2874  obj_loss: 0.1698  noobj_loss: 0.0931  bbox_loss: 0.0110  cls_loss: 0.0162  
<<<iteration:[320/525] - total_loss: 0.2983  obj_loss: 0.1709  noobj_loss: 0.0885  bbox_loss: 0.0103  cls_loss: 0.0317  
<<<iteration:[340/525] - total_loss: 0.2932  obj_loss: 0.1673  noobj_loss: 0.0829  bbox_loss: 0.0111  cls_loss: 0.0289  
<<<iteration:[360/525] - total_loss: 0.2974  obj_loss: 0.1853  noobj_loss: 0.0918  bbox_loss: 0.0085  cls_loss: 0.0237  
<<<iteration:[380/525] - total_loss: 0.2732  obj_loss: 0.1560  noobj_loss: 0.0931  bbox_loss: 0.0090  cls_loss: 0.0257  
<<<iteration:[400/525] - total_loss: 0.2914  obj_loss: 0.1513  noobj_loss: 0.0909  bbox_loss: 0.0108  cls_loss: 0.0404  
<<<iteration:[420/525] - total_loss: 0.2652  obj_loss: 0.1419  noobj_loss: 0.0937  bbox_loss: 0.0104  cls_loss: 0.0243  
<<<iteration:[440/525] - total_l

<<<iteration:[60/525] - total_loss: 0.2855  obj_loss: 0.1678  noobj_loss: 0.0889  bbox_loss: 0.0089  cls_loss: 0.0286  
<<<iteration:[80/525] - total_loss: 0.2913  obj_loss: 0.1459  noobj_loss: 0.0950  bbox_loss: 0.0111  cls_loss: 0.0424  
<<<iteration:[100/525] - total_loss: 0.2924  obj_loss: 0.1563  noobj_loss: 0.0841  bbox_loss: 0.0103  cls_loss: 0.0424  
<<<iteration:[120/525] - total_loss: 0.2838  obj_loss: 0.1642  noobj_loss: 0.0927  bbox_loss: 0.0102  cls_loss: 0.0223  
<<<iteration:[140/525] - total_loss: 0.2704  obj_loss: 0.1563  noobj_loss: 0.0884  bbox_loss: 0.0093  cls_loss: 0.0232  
<<<iteration:[160/525] - total_loss: 0.2761  obj_loss: 0.1450  noobj_loss: 0.0855  bbox_loss: 0.0108  cls_loss: 0.0341  
<<<iteration:[180/525] - total_loss: 0.2981  obj_loss: 0.1677  noobj_loss: 0.0846  bbox_loss: 0.0106  cls_loss: 0.0353  
<<<iteration:[200/525] - total_loss: 0.2802  obj_loss: 0.1496  noobj_loss: 0.0931  bbox_loss: 0.0111  cls_loss: 0.0285  
<<<iteration:[220/525] - total_los

<<<iteration:[360/525] - total_loss: 0.3133  obj_loss: 0.1931  noobj_loss: 0.0930  bbox_loss: 0.0094  cls_loss: 0.0265  
<<<iteration:[380/525] - total_loss: 0.3019  obj_loss: 0.1767  noobj_loss: 0.0963  bbox_loss: 0.0096  cls_loss: 0.0288  
<<<iteration:[400/525] - total_loss: 0.2833  obj_loss: 0.1543  noobj_loss: 0.0949  bbox_loss: 0.0110  cls_loss: 0.0264  
<<<iteration:[420/525] - total_loss: 0.2925  obj_loss: 0.1518  noobj_loss: 0.1031  bbox_loss: 0.0116  cls_loss: 0.0310  
<<<iteration:[440/525] - total_loss: 0.2693  obj_loss: 0.1509  noobj_loss: 0.0916  bbox_loss: 0.0100  cls_loss: 0.0224  
<<<iteration:[460/525] - total_loss: 0.2760  obj_loss: 0.1595  noobj_loss: 0.0977  bbox_loss: 0.0095  cls_loss: 0.0199  
<<<iteration:[480/525] - total_loss: 0.2880  obj_loss: 0.1599  noobj_loss: 0.0905  bbox_loss: 0.0097  cls_loss: 0.0341  
<<<iteration:[500/525] - total_loss: 0.2828  obj_loss: 0.1465  noobj_loss: 0.0944  bbox_loss: 0.0105  cls_loss: 0.0366  
<<<iteration:[520/525] - total_l

<<<iteration:[140/525] - total_loss: 0.2857  obj_loss: 0.1718  noobj_loss: 0.0926  bbox_loss: 0.0086  cls_loss: 0.0246  
<<<iteration:[160/525] - total_loss: 0.2717  obj_loss: 0.1580  noobj_loss: 0.0968  bbox_loss: 0.0102  cls_loss: 0.0146  
<<<iteration:[180/525] - total_loss: 0.2882  obj_loss: 0.1770  noobj_loss: 0.0890  bbox_loss: 0.0088  cls_loss: 0.0227  
<<<iteration:[200/525] - total_loss: 0.2932  obj_loss: 0.1720  noobj_loss: 0.0931  bbox_loss: 0.0098  cls_loss: 0.0257  
<<<iteration:[220/525] - total_loss: 0.2970  obj_loss: 0.1509  noobj_loss: 0.1015  bbox_loss: 0.0107  cls_loss: 0.0417  
<<<iteration:[240/525] - total_loss: 0.2884  obj_loss: 0.1593  noobj_loss: 0.1003  bbox_loss: 0.0101  cls_loss: 0.0288  
<<<iteration:[260/525] - total_loss: 0.2928  obj_loss: 0.1640  noobj_loss: 0.1004  bbox_loss: 0.0082  cls_loss: 0.0376  
<<<iteration:[280/525] - total_loss: 0.3013  obj_loss: 0.1804  noobj_loss: 0.1002  bbox_loss: 0.0095  cls_loss: 0.0235  
<<<iteration:[300/525] - total_l

<<<iteration:[440/525] - total_loss: 0.2637  obj_loss: 0.1532  noobj_loss: 0.1004  bbox_loss: 0.0080  cls_loss: 0.0202  
<<<iteration:[460/525] - total_loss: 0.3002  obj_loss: 0.1748  noobj_loss: 0.0999  bbox_loss: 0.0096  cls_loss: 0.0274  
<<<iteration:[480/525] - total_loss: 0.2889  obj_loss: 0.1778  noobj_loss: 0.1007  bbox_loss: 0.0078  cls_loss: 0.0219  
<<<iteration:[500/525] - total_loss: 0.3061  obj_loss: 0.1835  noobj_loss: 0.1039  bbox_loss: 0.0107  cls_loss: 0.0170  
<<<iteration:[520/525] - total_loss: 0.2953  obj_loss: 0.1509  noobj_loss: 0.1007  bbox_loss: 0.0123  cls_loss: 0.0324  

epoch:44/100 - Train Loss: 0.2804, Val Loss: 0.2825

<<<iteration:[20/525] - total_loss: 0.2966  obj_loss: 0.1855  noobj_loss: 0.0997  bbox_loss: 0.0076  cls_loss: 0.0233  
<<<iteration:[40/525] - total_loss: 0.2720  obj_loss: 0.1633  noobj_loss: 0.0988  bbox_loss: 0.0084  cls_loss: 0.0172  
<<<iteration:[60/525] - total_loss: 0.2848  obj_loss: 0.1658  noobj_loss: 0.1032  bbox_loss: 0.0094  

<<<iteration:[220/525] - total_loss: 0.2834  obj_loss: 0.1735  noobj_loss: 0.1037  bbox_loss: 0.0080  cls_loss: 0.0182  
<<<iteration:[240/525] - total_loss: 0.2667  obj_loss: 0.1541  noobj_loss: 0.0961  bbox_loss: 0.0097  cls_loss: 0.0159  
<<<iteration:[260/525] - total_loss: 0.2791  obj_loss: 0.1641  noobj_loss: 0.1034  bbox_loss: 0.0082  cls_loss: 0.0224  
<<<iteration:[280/525] - total_loss: 0.2664  obj_loss: 0.1504  noobj_loss: 0.0977  bbox_loss: 0.0085  cls_loss: 0.0247  
<<<iteration:[300/525] - total_loss: 0.2959  obj_loss: 0.1701  noobj_loss: 0.1010  bbox_loss: 0.0110  cls_loss: 0.0202  
<<<iteration:[320/525] - total_loss: 0.2825  obj_loss: 0.1716  noobj_loss: 0.0964  bbox_loss: 0.0086  cls_loss: 0.0199  
<<<iteration:[340/525] - total_loss: 0.2629  obj_loss: 0.1548  noobj_loss: 0.0929  bbox_loss: 0.0087  cls_loss: 0.0184  
<<<iteration:[360/525] - total_loss: 0.2954  obj_loss: 0.1751  noobj_loss: 0.1082  bbox_loss: 0.0093  cls_loss: 0.0199  
<<<iteration:[380/525] - total_l

<<<iteration:[520/525] - total_loss: 0.3063  obj_loss: 0.1803  noobj_loss: 0.1096  bbox_loss: 0.0077  cls_loss: 0.0324  

epoch:49/100 - Train Loss: 0.2765, Val Loss: 0.2882

<<<iteration:[20/525] - total_loss: 0.2634  obj_loss: 0.1581  noobj_loss: 0.0998  bbox_loss: 0.0082  cls_loss: 0.0145  
<<<iteration:[40/525] - total_loss: 0.2775  obj_loss: 0.1638  noobj_loss: 0.1096  bbox_loss: 0.0092  cls_loss: 0.0127  
<<<iteration:[60/525] - total_loss: 0.2747  obj_loss: 0.1685  noobj_loss: 0.0977  bbox_loss: 0.0073  cls_loss: 0.0207  
<<<iteration:[80/525] - total_loss: 0.2769  obj_loss: 0.1727  noobj_loss: 0.1016  bbox_loss: 0.0079  cls_loss: 0.0140  
<<<iteration:[100/525] - total_loss: 0.2774  obj_loss: 0.1691  noobj_loss: 0.1084  bbox_loss: 0.0075  cls_loss: 0.0166  
<<<iteration:[120/525] - total_loss: 0.2714  obj_loss: 0.1637  noobj_loss: 0.1033  bbox_loss: 0.0076  cls_loss: 0.0183  
<<<iteration:[140/525] - total_loss: 0.2905  obj_loss: 0.1646  noobj_loss: 0.1006  bbox_loss: 0.0096  c

<<<iteration:[300/525] - total_loss: 0.2683  obj_loss: 0.1657  noobj_loss: 0.1059  bbox_loss: 0.0074  cls_loss: 0.0129  
<<<iteration:[320/525] - total_loss: 0.2829  obj_loss: 0.1696  noobj_loss: 0.1050  bbox_loss: 0.0073  cls_loss: 0.0240  
<<<iteration:[340/525] - total_loss: 0.2723  obj_loss: 0.1575  noobj_loss: 0.1042  bbox_loss: 0.0090  cls_loss: 0.0175  
<<<iteration:[360/525] - total_loss: 0.2673  obj_loss: 0.1582  noobj_loss: 0.1164  bbox_loss: 0.0067  cls_loss: 0.0172  
<<<iteration:[380/525] - total_loss: 0.2412  obj_loss: 0.1457  noobj_loss: 0.1011  bbox_loss: 0.0061  cls_loss: 0.0147  
<<<iteration:[400/525] - total_loss: 0.2891  obj_loss: 0.1715  noobj_loss: 0.1070  bbox_loss: 0.0077  cls_loss: 0.0255  
<<<iteration:[420/525] - total_loss: 0.3026  obj_loss: 0.1681  noobj_loss: 0.1139  bbox_loss: 0.0085  cls_loss: 0.0352  
<<<iteration:[440/525] - total_loss: 0.2747  obj_loss: 0.1654  noobj_loss: 0.1093  bbox_loss: 0.0079  cls_loss: 0.0152  
<<<iteration:[460/525] - total_l

<<<iteration:[80/525] - total_loss: 0.2793  obj_loss: 0.1812  noobj_loss: 0.1086  bbox_loss: 0.0066  cls_loss: 0.0110  
<<<iteration:[100/525] - total_loss: 0.2716  obj_loss: 0.1611  noobj_loss: 0.1060  bbox_loss: 0.0083  cls_loss: 0.0158  
<<<iteration:[120/525] - total_loss: 0.2682  obj_loss: 0.1607  noobj_loss: 0.1046  bbox_loss: 0.0084  cls_loss: 0.0134  
<<<iteration:[140/525] - total_loss: 0.2563  obj_loss: 0.1492  noobj_loss: 0.1030  bbox_loss: 0.0069  cls_loss: 0.0209  
<<<iteration:[160/525] - total_loss: 0.2816  obj_loss: 0.1727  noobj_loss: 0.1052  bbox_loss: 0.0076  cls_loss: 0.0182  
<<<iteration:[180/525] - total_loss: 0.2719  obj_loss: 0.1717  noobj_loss: 0.1138  bbox_loss: 0.0060  cls_loss: 0.0131  
<<<iteration:[200/525] - total_loss: 0.2870  obj_loss: 0.1818  noobj_loss: 0.1091  bbox_loss: 0.0066  cls_loss: 0.0174  
<<<iteration:[220/525] - total_loss: 0.2872  obj_loss: 0.1675  noobj_loss: 0.1157  bbox_loss: 0.0089  cls_loss: 0.0175  
<<<iteration:[240/525] - total_lo

<<<iteration:[380/525] - total_loss: 0.2746  obj_loss: 0.1684  noobj_loss: 0.1029  bbox_loss: 0.0075  cls_loss: 0.0171  
<<<iteration:[400/525] - total_loss: 0.2742  obj_loss: 0.1645  noobj_loss: 0.1169  bbox_loss: 0.0067  cls_loss: 0.0179  
<<<iteration:[420/525] - total_loss: 0.2701  obj_loss: 0.1558  noobj_loss: 0.1143  bbox_loss: 0.0073  cls_loss: 0.0206  
<<<iteration:[440/525] - total_loss: 0.2924  obj_loss: 0.1627  noobj_loss: 0.1132  bbox_loss: 0.0087  cls_loss: 0.0297  
<<<iteration:[460/525] - total_loss: 0.2750  obj_loss: 0.1672  noobj_loss: 0.1168  bbox_loss: 0.0074  cls_loss: 0.0124  
<<<iteration:[480/525] - total_loss: 0.2747  obj_loss: 0.1682  noobj_loss: 0.1104  bbox_loss: 0.0080  cls_loss: 0.0113  
<<<iteration:[500/525] - total_loss: 0.2685  obj_loss: 0.1627  noobj_loss: 0.1162  bbox_loss: 0.0065  cls_loss: 0.0154  
<<<iteration:[520/525] - total_loss: 0.2498  obj_loss: 0.1430  noobj_loss: 0.1043  bbox_loss: 0.0072  cls_loss: 0.0186  

epoch:57/100 - Train Loss: 0.27

<<<iteration:[160/525] - total_loss: 0.2464  obj_loss: 0.1443  noobj_loss: 0.0998  bbox_loss: 0.0076  cls_loss: 0.0144  
<<<iteration:[180/525] - total_loss: 0.2586  obj_loss: 0.1510  noobj_loss: 0.1072  bbox_loss: 0.0073  cls_loss: 0.0175  
<<<iteration:[200/525] - total_loss: 0.2485  obj_loss: 0.1498  noobj_loss: 0.1098  bbox_loss: 0.0058  cls_loss: 0.0146  
<<<iteration:[220/525] - total_loss: 0.2641  obj_loss: 0.1373  noobj_loss: 0.1067  bbox_loss: 0.0096  cls_loss: 0.0257  
<<<iteration:[240/525] - total_loss: 0.2696  obj_loss: 0.1608  noobj_loss: 0.1113  bbox_loss: 0.0080  cls_loss: 0.0133  
<<<iteration:[260/525] - total_loss: 0.2461  obj_loss: 0.1429  noobj_loss: 0.1048  bbox_loss: 0.0076  cls_loss: 0.0130  
<<<iteration:[280/525] - total_loss: 0.2599  obj_loss: 0.1407  noobj_loss: 0.1114  bbox_loss: 0.0094  cls_loss: 0.0163  
<<<iteration:[300/525] - total_loss: 0.2710  obj_loss: 0.1635  noobj_loss: 0.1030  bbox_loss: 0.0074  cls_loss: 0.0192  
<<<iteration:[320/525] - total_l

<<<iteration:[460/525] - total_loss: 0.2728  obj_loss: 0.1612  noobj_loss: 0.1182  bbox_loss: 0.0070  cls_loss: 0.0175  
<<<iteration:[480/525] - total_loss: 0.2651  obj_loss: 0.1568  noobj_loss: 0.1046  bbox_loss: 0.0067  cls_loss: 0.0225  
<<<iteration:[500/525] - total_loss: 0.2611  obj_loss: 0.1606  noobj_loss: 0.1015  bbox_loss: 0.0076  cls_loss: 0.0120  
<<<iteration:[520/525] - total_loss: 0.2671  obj_loss: 0.1489  noobj_loss: 0.1129  bbox_loss: 0.0084  cls_loss: 0.0197  

epoch:62/100 - Train Loss: 0.2679, Val Loss: 0.2743

<<<iteration:[20/525] - total_loss: 0.2855  obj_loss: 0.1732  noobj_loss: 0.1164  bbox_loss: 0.0075  cls_loss: 0.0163  
<<<iteration:[40/525] - total_loss: 0.2875  obj_loss: 0.1685  noobj_loss: 0.1144  bbox_loss: 0.0080  cls_loss: 0.0219  
<<<iteration:[60/525] - total_loss: 0.2841  obj_loss: 0.1499  noobj_loss: 0.1125  bbox_loss: 0.0085  cls_loss: 0.0353  
<<<iteration:[80/525] - total_loss: 0.2535  obj_loss: 0.1547  noobj_loss: 0.1108  bbox_loss: 0.0069  c

<<<iteration:[240/525] - total_loss: 0.2912  obj_loss: 0.1712  noobj_loss: 0.1246  bbox_loss: 0.0086  cls_loss: 0.0149  
<<<iteration:[260/525] - total_loss: 0.2505  obj_loss: 0.1468  noobj_loss: 0.1123  bbox_loss: 0.0070  cls_loss: 0.0125  
<<<iteration:[280/525] - total_loss: 0.2764  obj_loss: 0.1703  noobj_loss: 0.1089  bbox_loss: 0.0080  cls_loss: 0.0117  
<<<iteration:[300/525] - total_loss: 0.2792  obj_loss: 0.1735  noobj_loss: 0.1165  bbox_loss: 0.0066  cls_loss: 0.0145  
<<<iteration:[320/525] - total_loss: 0.2633  obj_loss: 0.1612  noobj_loss: 0.1141  bbox_loss: 0.0063  cls_loss: 0.0137  
<<<iteration:[340/525] - total_loss: 0.2635  obj_loss: 0.1642  noobj_loss: 0.1186  bbox_loss: 0.0058  cls_loss: 0.0109  
<<<iteration:[360/525] - total_loss: 0.2725  obj_loss: 0.1671  noobj_loss: 0.1154  bbox_loss: 0.0068  cls_loss: 0.0137  
<<<iteration:[380/525] - total_loss: 0.2778  obj_loss: 0.1480  noobj_loss: 0.1270  bbox_loss: 0.0084  cls_loss: 0.0245  
<<<iteration:[400/525] - total_l


epoch:67/100 - Train Loss: 0.2690, Val Loss: 0.2805

<<<iteration:[20/525] - total_loss: 0.2659  obj_loss: 0.1478  noobj_loss: 0.1218  bbox_loss: 0.0079  cls_loss: 0.0176  
<<<iteration:[40/525] - total_loss: 0.2667  obj_loss: 0.1530  noobj_loss: 0.1123  bbox_loss: 0.0085  cls_loss: 0.0148  
<<<iteration:[60/525] - total_loss: 0.2700  obj_loss: 0.1672  noobj_loss: 0.1159  bbox_loss: 0.0065  cls_loss: 0.0122  
<<<iteration:[80/525] - total_loss: 0.2670  obj_loss: 0.1624  noobj_loss: 0.1199  bbox_loss: 0.0068  cls_loss: 0.0108  
<<<iteration:[100/525] - total_loss: 0.2577  obj_loss: 0.1617  noobj_loss: 0.1182  bbox_loss: 0.0061  cls_loss: 0.0067  
<<<iteration:[120/525] - total_loss: 0.2729  obj_loss: 0.1769  noobj_loss: 0.1152  bbox_loss: 0.0060  cls_loss: 0.0085  
<<<iteration:[140/525] - total_loss: 0.2554  obj_loss: 0.1527  noobj_loss: 0.1231  bbox_loss: 0.0062  cls_loss: 0.0100  
<<<iteration:[160/525] - total_loss: 0.2597  obj_loss: 0.1540  noobj_loss: 0.1117  bbox_loss: 0.0072  c

<<<iteration:[320/525] - total_loss: 0.2688  obj_loss: 0.1683  noobj_loss: 0.1284  bbox_loss: 0.0052  cls_loss: 0.0103  
<<<iteration:[340/525] - total_loss: 0.2809  obj_loss: 0.1705  noobj_loss: 0.1216  bbox_loss: 0.0081  cls_loss: 0.0091  
<<<iteration:[360/525] - total_loss: 0.2698  obj_loss: 0.1581  noobj_loss: 0.1245  bbox_loss: 0.0076  cls_loss: 0.0114  
<<<iteration:[380/525] - total_loss: 0.3023  obj_loss: 0.1743  noobj_loss: 0.1255  bbox_loss: 0.0098  cls_loss: 0.0164  
<<<iteration:[400/525] - total_loss: 0.2714  obj_loss: 0.1600  noobj_loss: 0.1114  bbox_loss: 0.0082  cls_loss: 0.0146  
<<<iteration:[420/525] - total_loss: 0.2621  obj_loss: 0.1606  noobj_loss: 0.1186  bbox_loss: 0.0064  cls_loss: 0.0104  
<<<iteration:[440/525] - total_loss: 0.3010  obj_loss: 0.1670  noobj_loss: 0.1299  bbox_loss: 0.0085  cls_loss: 0.0263  
<<<iteration:[460/525] - total_loss: 0.2628  obj_loss: 0.1561  noobj_loss: 0.1250  bbox_loss: 0.0072  cls_loss: 0.0083  
<<<iteration:[480/525] - total_l

<<<iteration:[100/525] - total_loss: 0.2675  obj_loss: 0.1672  noobj_loss: 0.1270  bbox_loss: 0.0052  cls_loss: 0.0110  
<<<iteration:[120/525] - total_loss: 0.2807  obj_loss: 0.1725  noobj_loss: 0.1226  bbox_loss: 0.0071  cls_loss: 0.0113  
<<<iteration:[140/525] - total_loss: 0.2826  obj_loss: 0.1756  noobj_loss: 0.1257  bbox_loss: 0.0067  cls_loss: 0.0106  
<<<iteration:[160/525] - total_loss: 0.2618  obj_loss: 0.1506  noobj_loss: 0.1241  bbox_loss: 0.0077  cls_loss: 0.0107  
<<<iteration:[180/525] - total_loss: 0.2579  obj_loss: 0.1448  noobj_loss: 0.1271  bbox_loss: 0.0084  cls_loss: 0.0074  
<<<iteration:[200/525] - total_loss: 0.2711  obj_loss: 0.1726  noobj_loss: 0.1158  bbox_loss: 0.0052  cls_loss: 0.0146  
<<<iteration:[220/525] - total_loss: 0.2588  obj_loss: 0.1567  noobj_loss: 0.1224  bbox_loss: 0.0057  cls_loss: 0.0123  
<<<iteration:[240/525] - total_loss: 0.2445  obj_loss: 0.1433  noobj_loss: 0.1210  bbox_loss: 0.0061  cls_loss: 0.0101  
<<<iteration:[260/525] - total_l

<<<iteration:[400/525] - total_loss: 0.2642  obj_loss: 0.1655  noobj_loss: 0.1284  bbox_loss: 0.0052  cls_loss: 0.0087  
<<<iteration:[420/525] - total_loss: 0.2670  obj_loss: 0.1584  noobj_loss: 0.1268  bbox_loss: 0.0068  cls_loss: 0.0111  
<<<iteration:[440/525] - total_loss: 0.2690  obj_loss: 0.1640  noobj_loss: 0.1248  bbox_loss: 0.0067  cls_loss: 0.0092  
<<<iteration:[460/525] - total_loss: 0.2550  obj_loss: 0.1476  noobj_loss: 0.1206  bbox_loss: 0.0075  cls_loss: 0.0098  
<<<iteration:[480/525] - total_loss: 0.2514  obj_loss: 0.1530  noobj_loss: 0.1185  bbox_loss: 0.0062  cls_loss: 0.0082  
<<<iteration:[500/525] - total_loss: 0.2760  obj_loss: 0.1727  noobj_loss: 0.1262  bbox_loss: 0.0060  cls_loss: 0.0102  
<<<iteration:[520/525] - total_loss: 0.2720  obj_loss: 0.1724  noobj_loss: 0.1179  bbox_loss: 0.0061  cls_loss: 0.0099  

epoch:75/100 - Train Loss: 0.2681, Val Loss: 0.2705

<<<iteration:[20/525] - total_loss: 0.2783  obj_loss: 0.1669  noobj_loss: 0.1330  bbox_loss: 0.0071

<<<iteration:[180/525] - total_loss: 0.2718  obj_loss: 0.1660  noobj_loss: 0.1292  bbox_loss: 0.0058  cls_loss: 0.0123  
<<<iteration:[200/525] - total_loss: 0.2506  obj_loss: 0.1550  noobj_loss: 0.1110  bbox_loss: 0.0059  cls_loss: 0.0105  
<<<iteration:[220/525] - total_loss: 0.2528  obj_loss: 0.1466  noobj_loss: 0.1179  bbox_loss: 0.0065  cls_loss: 0.0146  
<<<iteration:[240/525] - total_loss: 0.2698  obj_loss: 0.1683  noobj_loss: 0.1260  bbox_loss: 0.0053  cls_loss: 0.0122  
<<<iteration:[260/525] - total_loss: 0.2705  obj_loss: 0.1622  noobj_loss: 0.1283  bbox_loss: 0.0065  cls_loss: 0.0116  
<<<iteration:[280/525] - total_loss: 0.4080  obj_loss: 0.1087  noobj_loss: 0.1097  bbox_loss: 0.0452  cls_loss: 0.0187  
<<<iteration:[300/525] - total_loss: 0.5235  obj_loss: 0.0941  noobj_loss: 0.1031  bbox_loss: 0.0714  cls_loss: 0.0207  
<<<iteration:[320/525] - total_loss: 0.4834  obj_loss: 0.0860  noobj_loss: 0.1172  bbox_loss: 0.0610  cls_loss: 0.0338  
<<<iteration:[340/525] - total_l

<<<iteration:[480/525] - total_loss: 0.2590  obj_loss: 0.1534  noobj_loss: 0.1034  bbox_loss: 0.0081  cls_loss: 0.0135  
<<<iteration:[500/525] - total_loss: 0.2545  obj_loss: 0.1405  noobj_loss: 0.1094  bbox_loss: 0.0076  cls_loss: 0.0215  
<<<iteration:[520/525] - total_loss: 0.2557  obj_loss: 0.1427  noobj_loss: 0.0988  bbox_loss: 0.0095  cls_loss: 0.0163  

epoch:80/100 - Train Loss: 0.2493, Val Loss: 0.2578

<<<iteration:[20/525] - total_loss: 0.2733  obj_loss: 0.1552  noobj_loss: 0.1159  bbox_loss: 0.0074  cls_loss: 0.0232  
<<<iteration:[40/525] - total_loss: 0.2289  obj_loss: 0.1144  noobj_loss: 0.1034  bbox_loss: 0.0105  cls_loss: 0.0103  
<<<iteration:[60/525] - total_loss: 0.2365  obj_loss: 0.1344  noobj_loss: 0.1139  bbox_loss: 0.0068  cls_loss: 0.0111  
<<<iteration:[80/525] - total_loss: 0.2740  obj_loss: 0.1688  noobj_loss: 0.1113  bbox_loss: 0.0079  cls_loss: 0.0099  
<<<iteration:[100/525] - total_loss: 0.2369  obj_loss: 0.1338  noobj_loss: 0.1048  bbox_loss: 0.0082  c

<<<iteration:[260/525] - total_loss: 0.2519  obj_loss: 0.1529  noobj_loss: 0.1172  bbox_loss: 0.0066  cls_loss: 0.0077  
<<<iteration:[280/525] - total_loss: 0.2742  obj_loss: 0.1739  noobj_loss: 0.1128  bbox_loss: 0.0063  cls_loss: 0.0124  
<<<iteration:[300/525] - total_loss: 0.2620  obj_loss: 0.1612  noobj_loss: 0.1160  bbox_loss: 0.0065  cls_loss: 0.0105  
<<<iteration:[320/525] - total_loss: 0.2298  obj_loss: 0.1234  noobj_loss: 0.1141  bbox_loss: 0.0072  cls_loss: 0.0135  
<<<iteration:[340/525] - total_loss: 0.2453  obj_loss: 0.1489  noobj_loss: 0.1071  bbox_loss: 0.0058  cls_loss: 0.0139  
<<<iteration:[360/525] - total_loss: 0.2587  obj_loss: 0.1561  noobj_loss: 0.1219  bbox_loss: 0.0068  cls_loss: 0.0078  
<<<iteration:[380/525] - total_loss: 0.2437  obj_loss: 0.1381  noobj_loss: 0.1193  bbox_loss: 0.0072  cls_loss: 0.0101  
<<<iteration:[400/525] - total_loss: 0.2612  obj_loss: 0.1638  noobj_loss: 0.1109  bbox_loss: 0.0067  cls_loss: 0.0083  
<<<iteration:[420/525] - total_l

<<<iteration:[40/525] - total_loss: 0.2405  obj_loss: 0.1418  noobj_loss: 0.1230  bbox_loss: 0.0057  cls_loss: 0.0087  
<<<iteration:[60/525] - total_loss: 0.2439  obj_loss: 0.1443  noobj_loss: 0.1110  bbox_loss: 0.0059  cls_loss: 0.0144  
<<<iteration:[80/525] - total_loss: 0.2517  obj_loss: 0.1645  noobj_loss: 0.1093  bbox_loss: 0.0053  cls_loss: 0.0060  
<<<iteration:[100/525] - total_loss: 0.2418  obj_loss: 0.1413  noobj_loss: 0.1092  bbox_loss: 0.0061  cls_loss: 0.0154  
<<<iteration:[120/525] - total_loss: 0.2344  obj_loss: 0.1392  noobj_loss: 0.1192  bbox_loss: 0.0056  cls_loss: 0.0077  
<<<iteration:[140/525] - total_loss: 0.2394  obj_loss: 0.1400  noobj_loss: 0.1218  bbox_loss: 0.0060  cls_loss: 0.0087  
<<<iteration:[160/525] - total_loss: 0.2630  obj_loss: 0.1492  noobj_loss: 0.1203  bbox_loss: 0.0087  cls_loss: 0.0100  
<<<iteration:[180/525] - total_loss: 0.2570  obj_loss: 0.1577  noobj_loss: 0.1218  bbox_loss: 0.0058  cls_loss: 0.0093  
<<<iteration:[200/525] - total_loss

<<<iteration:[340/525] - total_loss: 0.2774  obj_loss: 0.1670  noobj_loss: 0.1152  bbox_loss: 0.0068  cls_loss: 0.0186  
<<<iteration:[360/525] - total_loss: 0.2536  obj_loss: 0.1512  noobj_loss: 0.1184  bbox_loss: 0.0067  cls_loss: 0.0097  
<<<iteration:[380/525] - total_loss: 0.2670  obj_loss: 0.1726  noobj_loss: 0.1235  bbox_loss: 0.0052  cls_loss: 0.0063  
<<<iteration:[400/525] - total_loss: 0.2567  obj_loss: 0.1502  noobj_loss: 0.1268  bbox_loss: 0.0064  cls_loss: 0.0111  
<<<iteration:[420/525] - total_loss: 0.2677  obj_loss: 0.1660  noobj_loss: 0.1212  bbox_loss: 0.0059  cls_loss: 0.0114  
<<<iteration:[440/525] - total_loss: 0.2531  obj_loss: 0.1471  noobj_loss: 0.1302  bbox_loss: 0.0057  cls_loss: 0.0122  
<<<iteration:[460/525] - total_loss: 0.2508  obj_loss: 0.1490  noobj_loss: 0.1144  bbox_loss: 0.0068  cls_loss: 0.0108  
<<<iteration:[480/525] - total_loss: 0.2332  obj_loss: 0.1446  noobj_loss: 0.1142  bbox_loss: 0.0051  cls_loss: 0.0058  
<<<iteration:[500/525] - total_l

<<<iteration:[120/525] - total_loss: 0.2435  obj_loss: 0.1424  noobj_loss: 0.1193  bbox_loss: 0.0065  cls_loss: 0.0092  
<<<iteration:[140/525] - total_loss: 0.2570  obj_loss: 0.1518  noobj_loss: 0.1257  bbox_loss: 0.0060  cls_loss: 0.0123  
<<<iteration:[160/525] - total_loss: 0.2574  obj_loss: 0.1536  noobj_loss: 0.1337  bbox_loss: 0.0063  cls_loss: 0.0056  
<<<iteration:[180/525] - total_loss: 0.2300  obj_loss: 0.1324  noobj_loss: 0.1197  bbox_loss: 0.0058  cls_loss: 0.0087  
<<<iteration:[200/525] - total_loss: 0.2921  obj_loss: 0.1719  noobj_loss: 0.1232  bbox_loss: 0.0069  cls_loss: 0.0240  
<<<iteration:[220/525] - total_loss: 0.2398  obj_loss: 0.1445  noobj_loss: 0.1191  bbox_loss: 0.0058  cls_loss: 0.0069  
<<<iteration:[240/525] - total_loss: 0.2716  obj_loss: 0.1664  noobj_loss: 0.1312  bbox_loss: 0.0060  cls_loss: 0.0097  
<<<iteration:[260/525] - total_loss: 0.2537  obj_loss: 0.1479  noobj_loss: 0.1295  bbox_loss: 0.0057  cls_loss: 0.0123  
<<<iteration:[280/525] - total_l

<<<iteration:[420/525] - total_loss: 0.2703  obj_loss: 0.1669  noobj_loss: 0.1347  bbox_loss: 0.0049  cls_loss: 0.0114  
<<<iteration:[440/525] - total_loss: 0.2216  obj_loss: 0.1253  noobj_loss: 0.1197  bbox_loss: 0.0060  cls_loss: 0.0065  
<<<iteration:[460/525] - total_loss: 0.2370  obj_loss: 0.1372  noobj_loss: 0.1207  bbox_loss: 0.0056  cls_loss: 0.0113  
<<<iteration:[480/525] - total_loss: 0.2547  obj_loss: 0.1618  noobj_loss: 0.1193  bbox_loss: 0.0053  cls_loss: 0.0070  
<<<iteration:[500/525] - total_loss: 0.2508  obj_loss: 0.1507  noobj_loss: 0.1322  bbox_loss: 0.0056  cls_loss: 0.0060  
<<<iteration:[520/525] - total_loss: 0.2647  obj_loss: 0.1733  noobj_loss: 0.1212  bbox_loss: 0.0045  cls_loss: 0.0083  

epoch:93/100 - Train Loss: 0.2518, Val Loss: 0.2607

<<<iteration:[20/525] - total_loss: 0.2728  obj_loss: 0.1731  noobj_loss: 0.1273  bbox_loss: 0.0056  cls_loss: 0.0078  
<<<iteration:[40/525] - total_loss: 0.2560  obj_loss: 0.1612  noobj_loss: 0.1205  bbox_loss: 0.0050 

<<<iteration:[200/525] - total_loss: 0.2678  obj_loss: 0.1618  noobj_loss: 0.1340  bbox_loss: 0.0062  cls_loss: 0.0077  
<<<iteration:[220/525] - total_loss: 0.2494  obj_loss: 0.1465  noobj_loss: 0.1297  bbox_loss: 0.0053  cls_loss: 0.0114  
<<<iteration:[240/525] - total_loss: 0.2594  obj_loss: 0.1593  noobj_loss: 0.1379  bbox_loss: 0.0048  cls_loss: 0.0070  
<<<iteration:[260/525] - total_loss: 0.2394  obj_loss: 0.1403  noobj_loss: 0.1284  bbox_loss: 0.0056  cls_loss: 0.0068  
<<<iteration:[280/525] - total_loss: 0.2604  obj_loss: 0.1590  noobj_loss: 0.1271  bbox_loss: 0.0057  cls_loss: 0.0091  
<<<iteration:[300/525] - total_loss: 0.2595  obj_loss: 0.1579  noobj_loss: 0.1222  bbox_loss: 0.0066  cls_loss: 0.0077  
<<<iteration:[320/525] - total_loss: 0.2654  obj_loss: 0.1647  noobj_loss: 0.1223  bbox_loss: 0.0060  cls_loss: 0.0094  
<<<iteration:[340/525] - total_loss: 0.2466  obj_loss: 0.1553  noobj_loss: 0.1287  bbox_loss: 0.0044  cls_loss: 0.0049  
<<<iteration:[360/525] - total_l

<<<iteration:[500/525] - total_loss: 0.2487  obj_loss: 0.1503  noobj_loss: 0.1393  bbox_loss: 0.0049  cls_loss: 0.0044  
<<<iteration:[520/525] - total_loss: 0.2343  obj_loss: 0.1407  noobj_loss: 0.1282  bbox_loss: 0.0047  cls_loss: 0.0061  

epoch:98/100 - Train Loss: 0.2549, Val Loss: 0.2608

<<<iteration:[20/525] - total_loss: 0.2649  obj_loss: 0.1676  noobj_loss: 0.1381  bbox_loss: 0.0047  cls_loss: 0.0044  
<<<iteration:[40/525] - total_loss: 0.2672  obj_loss: 0.1632  noobj_loss: 0.1297  bbox_loss: 0.0059  cls_loss: 0.0097  
<<<iteration:[60/525] - total_loss: 0.2302  obj_loss: 0.1341  noobj_loss: 0.1287  bbox_loss: 0.0051  cls_loss: 0.0064  
<<<iteration:[80/525] - total_loss: 0.2520  obj_loss: 0.1544  noobj_loss: 0.1389  bbox_loss: 0.0045  cls_loss: 0.0056  
<<<iteration:[100/525] - total_loss: 0.2640  obj_loss: 0.1590  noobj_loss: 0.1331  bbox_loss: 0.0048  cls_loss: 0.0146  
<<<iteration:[120/525] - total_loss: 0.2503  obj_loss: 0.1431  noobj_loss: 0.1344  bbox_loss: 0.0064  c

0,1
Train Loss,█▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train bbox Loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train class Loss,█▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train obj Loss,▁▂▅▅▇▇▇████████████████▇▇██████▆▇▇▇▇▇▇▇▇
Val Loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val bbox Loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val class Loss,█▄▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val obj Loss,▁▄▆▆▇▇█▇█████▇█▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▇

0,1
Train Loss,0.25024
Train bbox Loss,0.00558
Train class Loss,0.00841
Train obj Loss,0.1501
Val Loss,0.26846
Val bbox Loss,0.00763
Val class Loss,0.00489
Val obj Loss,0.15236


# Test Dataset Inference

In [20]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [22]:
def load_model(ckpt_path, num_classes, device):
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = YOLO_SWIN(num_classes=num_classes)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    return model

In [23]:
IMAGE_SIZE=448
transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

In [24]:
# ckpt_path="./trained_model/YOLO_SWIN_T_body_LR0.0001_AUG30/model_90.pth"
ckpt_path="/workspace/Plastic_Bottle_defect_detection/trained_model/YOLO_SWIN_T_neck_LR0.0001_AUG20/model_100.pth"
model = load_model(ckpt_path, NUM_CLASSES, device)

In [25]:
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'
test_dataset=PET_dataset("neck" ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='test', transformer=transformer, aug=None)
test_dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

start making augmented images-- augmented factor:0
total length of augmented images: 0


In [26]:
len(test_dataset)

25

In [27]:
@torch.no_grad()
def model_predict(image, model, conf_thres=0.2, iou_threshold=0.1):
    predictions = model(image)
    prediction = predictions.detach().cpu().squeeze(dim=0)
#     print(prediction.shape)
    
    grid_size = prediction.shape[-1]
    y_grid, x_grid = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size))
    stride_size = IMAGE_SIZE/grid_size

    conf = prediction[[0,5], ...].reshape(1, -1)
    xc = (prediction[[1,6], ...] * IMAGE_SIZE + x_grid*stride_size).reshape(1,-1)
    yc = (prediction[[2,7], ...] * IMAGE_SIZE + y_grid*stride_size).reshape(1,-1)
    w = (prediction[[3,8], ...] * IMAGE_SIZE).reshape(1,-1)
    h = (prediction[[4,9], ...] * IMAGE_SIZE).reshape(1,-1)
    cls = torch.max(prediction[10:, ...].reshape(NUM_CLASSES, -1), dim=0).indices.tile(1,2)
    
    x_min = xc - w/2
    y_min = yc - h/2
    x_max = xc + w/2
    y_max = yc + h/2

    prediction_res = torch.cat([x_min, y_min, x_max, y_max, conf, cls], dim=0)
    prediction_res = prediction_res.transpose(0,1)

    # x_min과 y_min이 음수가 되지않고, x_max와 y_max가 이미지 크기를 넘지 않게 제한
    prediction_res[:, 2].clip(min=0, max=image.shape[1]) 
    prediction_res[:, 3].clip(min=0, max=image.shape[0])
        
    pred_res = prediction_res[prediction_res[:, 4] > conf_thres]
    nms_index = torchvision.ops.nms(boxes=pred_res[:, 0:4], scores=pred_res[:, 4], iou_threshold=iou_threshold)
    pred_res_ = pred_res[nms_index].numpy()
    
    n_obj = pred_res_.shape[0]
    bboxes = np.zeros(shape=(n_obj, 4), dtype=np.float32)
    bboxes[:, 0:2] = (pred_res_[:, 0:2] + pred_res_[:, 2:4]) / 2
    bboxes[:, 2:4] = pred_res_[:, 2:4] - pred_res_[:, 0:2]
    scores = pred_res_[:, 4]
    class_ids = pred_res_[:, 5]
    
    # 이미지 값이 들어가면 모델을 통해서, 후처리까지 포함된 yolo 포멧의 box좌표, 그 좌표에 대한 confidence score
    # 그리고 class id를 반환
    return bboxes, scores, class_ids

In [28]:
pred_images = []
pred_labels =[]

for index, batch in enumerate(test_dataloaders):
    images = batch[0].to(device)
    bboxes, scores, class_ids = model_predict(images, model, conf_thres=0.1, iou_threshold=0.1)
    
    if len(bboxes) > 0:
        prediction_yolo = np.concatenate([bboxes, scores[:, np.newaxis], class_ids[:, np.newaxis]], axis=1)
    else:
        prediction_yolo = np.array([])
    
    # 텐서형의 이미지를 다시 unnormalize를 시키고, 다시 chw를 hwc로 바꾸고 넘파이로 바꾼다.
    np_image = make_grid(images[0], normalize=True).cpu().permute(1,2,0).numpy()
    pred_images.append(np_image)
    pred_labels.append(prediction_yolo)

    

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [29]:
from ipywidgets import interact

@interact(index=(0,len(pred_images)-1))
def show_result(index=0):
    print(pred_labels[index])
    if len(pred_labels[index]) > 0:
        result = visualize(pred_images[index], pred_labels[index][:, 0:4], pred_labels[index][:, 5])
    else:
        result = pred_images[index]
        
    plt.figure(figsize=(6,6))
    plt.imshow(result)
    plt.show()

interactive(children=(IntSlider(value=0, description='index', max=24), Output()), _dom_classes=('widget-intera…