In [18]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Utils

In [20]:
CLASS_NAME_TO_ID = {'Unformed': 0, 'Burr': 1}
CLASS_ID_TO_NAME = {0: 'Unformed', 1: 'Burr'}
BOX_COLOR = {'Unformed':(200, 0, 0), 'Burr':(0, 0, 200)}
TEXT_COLOR = (255, 255, 255)

def save_model(model_state, model_name, save_dir="./trained_model"):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model_state, os.path.join(save_dir, model_name))


def visualize_bbox(image, bbox, class_name, color=BOX_COLOR, thickness=2):
    x_center, y_center, w, h = bbox
    x_min = int(x_center - w/2)
    y_min = int(y_center - h/2)
    x_max = int(x_center + w/2)
    y_max = int(y_center + h/2)
    
    cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color[class_name], thickness=thickness)
    
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)    
    cv2.rectangle(image, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color[class_name], -1)
    cv2.putText(
        image,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35, 
        color=TEXT_COLOR, 
        lineType=cv2.LINE_AA,
    )
    return image


def visualize(image, bboxes, category_ids):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
#         print('category_id: ',category_id)
        class_name = CLASS_ID_TO_NAME[category_id.item()]
        img = visualize_bbox(img, bbox, class_name)
    return img

## Datasets

In [21]:
class PET_dataset():
    def __init__(self,part,neck_dir,body_dir,phase, transformer=None, aug=None, aug_factor=0):
        self.neck_dir=neck_dir
        self.body_dir=body_dir
        self.part=part
        self.phase=phase
        self.transformer=transformer
        self.aug=aug
        self.aug_factor=aug_factor
        if(self.part=="body"):
            self.image_files = sorted([fn for fn in os.listdir(self.body_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.body_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        elif(self.part=="neck"):
            self.image_files = sorted([fn for fn in os.listdir(self.neck_dir+"/"+self.phase+"/image") if fn.endswith("jpg")])
            self.label_files= sorted([lab for lab in os.listdir(self.neck_dir+"/"+self.phase+"/label") if lab.endswith("txt")])
        
        self.auged_img_list, self.auged_label_list=self.make_aug_list(self.image_files, self.label_files)
        
    def __getitem__(self,index):
        if(self.aug==None):
            filename, image = self.get_image(self.part, index)
            bboxes, class_ids = self.get_label(self.part, index)

            if(self.transformer):
                transformed_data=self.transformer(image=image, bboxes=bboxes, class_ids=class_ids)
                image = transformed_data['image']
                bboxes = np.array(transformed_data['bboxes'])
                class_ids = np.array(transformed_data['class_ids'])


            target = {}
    #         print(f'bboxes:{bboxes}\nclass_ids:{class_ids}\nlen_bboxes:{len(bboxes)}\nlen_class_ids:{len(class_ids)}')
    #         print(f'filename: {filename}')
            target["boxes"] = torch.Tensor(bboxes).float()
            target["labels"] = torch.Tensor(class_ids).long()

            ###
            bboxes=torch.Tensor(bboxes).float()
            class_ids=torch.Tensor(class_ids).long()
            target = np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1)
            ###
        else:
            image=self.auged_img_list[index][1]
            target=self.auged_label_list[index]
            filename=self.auged_img_list[index][0]
        return image, target, filename
    
    def __len__(self, ):
        length=0
        if(self.aug==None):
            length=len(self.image_files)
        else:
            length=len(self.auged_img_list)
        return length
    
    def make_aug_list(self,ori_image_list,ori_label_files):
        aug_image_list=[]
        aug_label_list=[]
        
        print(f"start making augmented images-- augmented factor:{self.aug_factor}")
        for i in range(len(ori_image_list)):
            filename, ori_image = self.get_image(self.part, i)
            ori_bboxes, ori_class_ids = self.get_label(self.part, i)
            for j in range(self.aug_factor):
                auged_data=self.aug(image=ori_image, bboxes=ori_bboxes, class_ids=ori_class_ids)
                image = auged_data['image']
                bboxes = np.array(auged_data['bboxes'])
                class_ids = np.array(auged_data['class_ids'])
                
                bboxes=torch.Tensor(bboxes).float()
                class_ids=torch.Tensor(class_ids).long()
                
                aug_image_list.append((filename, image))
                aug_label_list.append(np.concatenate((bboxes, class_ids[:, np.newaxis]), axis=1))
        
        print(f"total length of augmented images: {len(aug_image_list)}")
        
        return aug_image_list, aug_label_list
        
    
    def get_image(self, part, index): # 이미지 불러오는 함수
        filename = self.image_files[index]
        if(part=="body"):
#             print(f"body called!-> {self.part}")
            image_path = self.body_dir+"/"+self.phase+"/image/"+filename
        elif(part=="neck"):
#             print(f"neck called!-> {self.part}")
            image_path = self.neck_dir+"/"+self.phase+"/image/"+filename
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return filename, image
    
    def get_label(self, part, index): # label (box좌표, class_id) 불러오는 함수
        label_filename=self.label_files[index]
        if(part=="body"):
#             print(f"body label called!-> {self.part}")
            label_path = self.body_dir+"/"+self.phase+"/label/"+label_filename
        elif(part=="neck"):
#             print(f"neck label called!-> {self.part}")
            label_path = self.neck_dir+"/"+self.phase+"/label/"+label_filename
        with open(label_path, 'r') as file:
            labels = file.readlines()
        
        class_ids=[]
        bboxes=[]
        for label in labels:
            label=label.replace("\n", "")
            obj=label.split(' ')[0]
            coor=label.split(' ')[1:]
            obj=int(obj)
            coor=list(map(float, coor))
            class_ids.append(obj)
            bboxes.append(coor)
            
        return bboxes, class_ids
    

In [22]:
IMAGE_SIZE = 448

transformer = A.Compose([ 
        # bounding box의 변환, augmentation에서 albumentations는 Detection 학습을 할 때 굉장히 유용하다. 
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
        # albumentations 라이브러리에서는 Normalization을 먼저 진행해 주고 tensor화를 진행해 주어야한다.
    ],
    # box 위치에 대한 transformation도 함께 진행된다. 
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

augmentator=A.Compose([
#     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.HorizontalFlip(p=0.7),
#     A.Sharpen(p=0.7),
    A.BBoxSafeRandomCrop(p=0.6),
    A.VerticalFlip (p=0.5),
    A.HueSaturationValue(p=0.5),
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

def collate_fn(batch):
    image_list = []
    target_list = []
    filename_list = []
    
    for a,b,c in batch:
        image_list.append(a)
        target_list.append(b)
        filename_list.append(c)

    return torch.stack(image_list, dim=0), target_list, filename_list


## Model

In [23]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import init
# from .cbam import *
# from .bam import *

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            self.cbam = CBAM( planes, 16 )
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if not self.cbam is None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            self.cbam = CBAM( planes * 4, 16 )
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if not self.cbam is None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers,  network_type, num_classes, att_type=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.network_type = network_type
        # different model config between ImageNet and CIFAR 
        if network_type == "ImageNet":
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AvgPool2d(7)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        if att_type=='BAM':
            self.bam1 = BAM(64*block.expansion)
            self.bam2 = BAM(128*block.expansion)
            self.bam3 = BAM(256*block.expansion)
        else:
            self.bam1, self.bam2, self.bam3 = None, None, None

        self.layer1 = self._make_layer(block, 64,  layers[0], att_type=att_type)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)

        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        self.final_conv = nn.Conv2d(512, 12, kernel_size=1, stride=1, padding=0, bias=False)

        init.kaiming_normal(self.fc.weight)
        for key in self.state_dict():
            if key.split('.')[-1]=="weight":
                if "conv" in key:
                    init.kaiming_normal(self.state_dict()[key], mode='fan_out')
                if "bn" in key:
                    if "SpatialGate" in key:
                        self.state_dict()[key][...] = 0
                    else:
                        self.state_dict()[key][...] = 1
            elif key.split(".")[-1]=='bias':
                self.state_dict()[key][...] = 0

    def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM'))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM'))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.network_type == "ImageNet":
            x = self.maxpool(x)

        x = self.layer1(x)
        if not self.bam1 is None:
            x = self.bam1(x)

        x = self.layer2(x)
        if not self.bam2 is None:
            x = self.bam2(x)

        x = self.layer3(x)
        if not self.bam3 is None:
            x = self.bam3(x)

        x = self.layer4(x)

#         if self.network_type == "ImageNet":
#             x = self.avgpool(x)
#         else:
        x = F.avg_pool2d(x, 2)
        x = self.final_conv(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
        return x

def ResidualNet(network_type, depth, num_classes, att_type):

    assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'

    if depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)

    elif depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)

    return model

In [41]:
ckpt_path="/workspace/Plastic_Bottle_defect_detection/trained_model/YOLO_RESNET_CBAM_neck_LR0.0001_IP50_nonPretrain/model_100.pth"
state_dict_my = torch.load(ckpt_path)
print(state_dict_my.keys())

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.cbam.ChannelGate.mlp.1.weight', 'layer1.0.cbam.ChannelGate.mlp.1.bias', 'layer1.0.cbam.ChannelGate.mlp.3.weight', 'layer1.0.cbam.ChannelGate.mlp.3.bias', 'layer1.0.cbam.SpatialGate.spatial.conv.weight', 'layer1.0.cbam.SpatialGate.spatial.bn.weight', 'layer1.0.cbam.SpatialGate.spatial.bn.bias', 'layer1.0.cbam.SpatialGate.spatial.bn.running_mean', 'layer1.0.cbam.SpatialGate.spatial.bn.running_var', 'layer1.0.cbam.SpatialGate.spatial.bn.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.run

In [40]:
ckpt_path="/workspace/Plastic_Bottle_defect_detection/experiments/trained_model/RESNET50_CBAM_new_name_wrap.pth"
state_dict = torch.load(ckpt_path)
print(list(state_dict["state_dict"].keys()))

['module.conv1.weight', 'module.bn1.weight', 'module.bn1.bias', 'module.bn1.running_mean', 'module.bn1.running_var', 'module.layer1.0.conv1.weight', 'module.layer1.0.bn1.weight', 'module.layer1.0.bn1.bias', 'module.layer1.0.bn1.running_mean', 'module.layer1.0.bn1.running_var', 'module.layer1.0.conv2.weight', 'module.layer1.0.bn2.weight', 'module.layer1.0.bn2.bias', 'module.layer1.0.bn2.running_mean', 'module.layer1.0.bn2.running_var', 'module.layer1.0.conv3.weight', 'module.layer1.0.bn3.weight', 'module.layer1.0.bn3.bias', 'module.layer1.0.bn3.running_mean', 'module.layer1.0.bn3.running_var', 'module.layer1.0.downsample.0.weight', 'module.layer1.0.downsample.1.weight', 'module.layer1.0.downsample.1.bias', 'module.layer1.0.downsample.1.running_mean', 'module.layer1.0.downsample.1.running_var', 'module.layer1.0.cbam.ChannelGate.mlp.1.weight', 'module.layer1.0.cbam.ChannelGate.mlp.1.bias', 'module.layer1.0.cbam.ChannelGate.mlp.3.weight', 'module.layer1.0.cbam.ChannelGate.mlp.3.bias', 'mod

In [51]:
for layer in list(state_dict["state_dict"].keys()):
    name=layer.split('module.')[1]
#     print(name)
    if(name not in list(state_dict_my)):
        print(layer)
        

module.layer1.0.conv3.weight
module.layer1.0.bn3.weight
module.layer1.0.bn3.bias
module.layer1.0.bn3.running_mean
module.layer1.0.bn3.running_var
module.layer1.0.downsample.0.weight
module.layer1.0.downsample.1.weight
module.layer1.0.downsample.1.bias
module.layer1.0.downsample.1.running_mean
module.layer1.0.downsample.1.running_var
module.layer1.1.conv3.weight
module.layer1.1.bn3.weight
module.layer1.1.bn3.bias
module.layer1.1.bn3.running_mean
module.layer1.1.bn3.running_var
module.layer1.2.conv3.weight
module.layer1.2.bn3.weight
module.layer1.2.bn3.bias
module.layer1.2.bn3.running_mean
module.layer1.2.bn3.running_var
module.layer2.0.conv3.weight
module.layer2.0.bn3.weight
module.layer2.0.bn3.bias
module.layer2.0.bn3.running_mean
module.layer2.0.bn3.running_var
module.layer2.1.conv3.weight
module.layer2.1.bn3.weight
module.layer2.1.bn3.bias
module.layer2.1.bn3.running_mean
module.layer2.1.bn3.running_var
module.layer2.2.conv3.weight
module.layer2.2.bn3.weight
module.layer2.2.bn3.bias


In [26]:
def load_model(ckpt_path, num_classes, device):
    checkpoint = torch.load(ckpt_path, map_location=device)
#     model = YOLO_SWIN(num_classes=num_classes)
#     model = YOLO_SWIN(num_classes=num_classes)
    model = ResNet(BasicBlock, [3, 4, 6, 3], network_type="ImageNet", num_classes=NUM_CLASSES, att_type="CBAM")
    model.load_state_dict(checkpoint)
    model = model.to(device)
    return model

ckpt_path="/workspace/Plastic_Bottle_defect_detection/experiments/trained_model/RESNET50_CBAM_new_name_wrap.pth"
model = load_model(ckpt_path, NUM_CLASSES, device)

  init.kaiming_normal(self.fc.weight)
  init.kaiming_normal(self.state_dict()[key], mode='fan_out')


RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.cbam.ChannelGate.mlp.1.weight", "layer1.0.cbam.ChannelGate.mlp.1.bias", "layer1.0.cbam.ChannelGate.mlp.3.weight", "layer1.0.cbam.ChannelGate.mlp.3.bias", "layer1.0.cbam.SpatialGate.spatial.conv.weight", "layer1.0.cbam.SpatialGate.spatial.bn.weight", "layer1.0.cbam.SpatialGate.spatial.bn.bias", "layer1.0.cbam.SpatialGate.spatial.bn.running_mean", "layer1.0.cbam.SpatialGate.spatial.bn.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.cbam.ChannelGate.mlp.1.weight", "layer1.1.cbam.ChannelGate.mlp.1.bias", "layer1.1.cbam.ChannelGate.mlp.3.weight", "layer1.1.cbam.ChannelGate.mlp.3.bias", "layer1.1.cbam.SpatialGate.spatial.conv.weight", "layer1.1.cbam.SpatialGate.spatial.bn.weight", "layer1.1.cbam.SpatialGate.spatial.bn.bias", "layer1.1.cbam.SpatialGate.spatial.bn.running_mean", "layer1.1.cbam.SpatialGate.spatial.bn.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.cbam.ChannelGate.mlp.1.weight", "layer1.2.cbam.ChannelGate.mlp.1.bias", "layer1.2.cbam.ChannelGate.mlp.3.weight", "layer1.2.cbam.ChannelGate.mlp.3.bias", "layer1.2.cbam.SpatialGate.spatial.conv.weight", "layer1.2.cbam.SpatialGate.spatial.bn.weight", "layer1.2.cbam.SpatialGate.spatial.bn.bias", "layer1.2.cbam.SpatialGate.spatial.bn.running_mean", "layer1.2.cbam.SpatialGate.spatial.bn.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.cbam.ChannelGate.mlp.1.weight", "layer2.0.cbam.ChannelGate.mlp.1.bias", "layer2.0.cbam.ChannelGate.mlp.3.weight", "layer2.0.cbam.ChannelGate.mlp.3.bias", "layer2.0.cbam.SpatialGate.spatial.conv.weight", "layer2.0.cbam.SpatialGate.spatial.bn.weight", "layer2.0.cbam.SpatialGate.spatial.bn.bias", "layer2.0.cbam.SpatialGate.spatial.bn.running_mean", "layer2.0.cbam.SpatialGate.spatial.bn.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.cbam.ChannelGate.mlp.1.weight", "layer2.1.cbam.ChannelGate.mlp.1.bias", "layer2.1.cbam.ChannelGate.mlp.3.weight", "layer2.1.cbam.ChannelGate.mlp.3.bias", "layer2.1.cbam.SpatialGate.spatial.conv.weight", "layer2.1.cbam.SpatialGate.spatial.bn.weight", "layer2.1.cbam.SpatialGate.spatial.bn.bias", "layer2.1.cbam.SpatialGate.spatial.bn.running_mean", "layer2.1.cbam.SpatialGate.spatial.bn.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.cbam.ChannelGate.mlp.1.weight", "layer2.2.cbam.ChannelGate.mlp.1.bias", "layer2.2.cbam.ChannelGate.mlp.3.weight", "layer2.2.cbam.ChannelGate.mlp.3.bias", "layer2.2.cbam.SpatialGate.spatial.conv.weight", "layer2.2.cbam.SpatialGate.spatial.bn.weight", "layer2.2.cbam.SpatialGate.spatial.bn.bias", "layer2.2.cbam.SpatialGate.spatial.bn.running_mean", "layer2.2.cbam.SpatialGate.spatial.bn.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.cbam.ChannelGate.mlp.1.weight", "layer2.3.cbam.ChannelGate.mlp.1.bias", "layer2.3.cbam.ChannelGate.mlp.3.weight", "layer2.3.cbam.ChannelGate.mlp.3.bias", "layer2.3.cbam.SpatialGate.spatial.conv.weight", "layer2.3.cbam.SpatialGate.spatial.bn.weight", "layer2.3.cbam.SpatialGate.spatial.bn.bias", "layer2.3.cbam.SpatialGate.spatial.bn.running_mean", "layer2.3.cbam.SpatialGate.spatial.bn.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.cbam.ChannelGate.mlp.1.weight", "layer3.0.cbam.ChannelGate.mlp.1.bias", "layer3.0.cbam.ChannelGate.mlp.3.weight", "layer3.0.cbam.ChannelGate.mlp.3.bias", "layer3.0.cbam.SpatialGate.spatial.conv.weight", "layer3.0.cbam.SpatialGate.spatial.bn.weight", "layer3.0.cbam.SpatialGate.spatial.bn.bias", "layer3.0.cbam.SpatialGate.spatial.bn.running_mean", "layer3.0.cbam.SpatialGate.spatial.bn.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.cbam.ChannelGate.mlp.1.weight", "layer3.1.cbam.ChannelGate.mlp.1.bias", "layer3.1.cbam.ChannelGate.mlp.3.weight", "layer3.1.cbam.ChannelGate.mlp.3.bias", "layer3.1.cbam.SpatialGate.spatial.conv.weight", "layer3.1.cbam.SpatialGate.spatial.bn.weight", "layer3.1.cbam.SpatialGate.spatial.bn.bias", "layer3.1.cbam.SpatialGate.spatial.bn.running_mean", "layer3.1.cbam.SpatialGate.spatial.bn.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.cbam.ChannelGate.mlp.1.weight", "layer3.2.cbam.ChannelGate.mlp.1.bias", "layer3.2.cbam.ChannelGate.mlp.3.weight", "layer3.2.cbam.ChannelGate.mlp.3.bias", "layer3.2.cbam.SpatialGate.spatial.conv.weight", "layer3.2.cbam.SpatialGate.spatial.bn.weight", "layer3.2.cbam.SpatialGate.spatial.bn.bias", "layer3.2.cbam.SpatialGate.spatial.bn.running_mean", "layer3.2.cbam.SpatialGate.spatial.bn.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.cbam.ChannelGate.mlp.1.weight", "layer3.3.cbam.ChannelGate.mlp.1.bias", "layer3.3.cbam.ChannelGate.mlp.3.weight", "layer3.3.cbam.ChannelGate.mlp.3.bias", "layer3.3.cbam.SpatialGate.spatial.conv.weight", "layer3.3.cbam.SpatialGate.spatial.bn.weight", "layer3.3.cbam.SpatialGate.spatial.bn.bias", "layer3.3.cbam.SpatialGate.spatial.bn.running_mean", "layer3.3.cbam.SpatialGate.spatial.bn.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.cbam.ChannelGate.mlp.1.weight", "layer3.4.cbam.ChannelGate.mlp.1.bias", "layer3.4.cbam.ChannelGate.mlp.3.weight", "layer3.4.cbam.ChannelGate.mlp.3.bias", "layer3.4.cbam.SpatialGate.spatial.conv.weight", "layer3.4.cbam.SpatialGate.spatial.bn.weight", "layer3.4.cbam.SpatialGate.spatial.bn.bias", "layer3.4.cbam.SpatialGate.spatial.bn.running_mean", "layer3.4.cbam.SpatialGate.spatial.bn.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.cbam.ChannelGate.mlp.1.weight", "layer3.5.cbam.ChannelGate.mlp.1.bias", "layer3.5.cbam.ChannelGate.mlp.3.weight", "layer3.5.cbam.ChannelGate.mlp.3.bias", "layer3.5.cbam.SpatialGate.spatial.conv.weight", "layer3.5.cbam.SpatialGate.spatial.bn.weight", "layer3.5.cbam.SpatialGate.spatial.bn.bias", "layer3.5.cbam.SpatialGate.spatial.bn.running_mean", "layer3.5.cbam.SpatialGate.spatial.bn.running_var", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.cbam.ChannelGate.mlp.1.weight", "layer4.0.cbam.ChannelGate.mlp.1.bias", "layer4.0.cbam.ChannelGate.mlp.3.weight", "layer4.0.cbam.ChannelGate.mlp.3.bias", "layer4.0.cbam.SpatialGate.spatial.conv.weight", "layer4.0.cbam.SpatialGate.spatial.bn.weight", "layer4.0.cbam.SpatialGate.spatial.bn.bias", "layer4.0.cbam.SpatialGate.spatial.bn.running_mean", "layer4.0.cbam.SpatialGate.spatial.bn.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.cbam.ChannelGate.mlp.1.weight", "layer4.1.cbam.ChannelGate.mlp.1.bias", "layer4.1.cbam.ChannelGate.mlp.3.weight", "layer4.1.cbam.ChannelGate.mlp.3.bias", "layer4.1.cbam.SpatialGate.spatial.conv.weight", "layer4.1.cbam.SpatialGate.spatial.bn.weight", "layer4.1.cbam.SpatialGate.spatial.bn.bias", "layer4.1.cbam.SpatialGate.spatial.bn.running_mean", "layer4.1.cbam.SpatialGate.spatial.bn.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.cbam.ChannelGate.mlp.1.weight", "layer4.2.cbam.ChannelGate.mlp.1.bias", "layer4.2.cbam.ChannelGate.mlp.3.weight", "layer4.2.cbam.ChannelGate.mlp.3.bias", "layer4.2.cbam.SpatialGate.spatial.conv.weight", "layer4.2.cbam.SpatialGate.spatial.bn.weight", "layer4.2.cbam.SpatialGate.spatial.bn.bias", "layer4.2.cbam.SpatialGate.spatial.bn.running_mean", "layer4.2.cbam.SpatialGate.spatial.bn.running_var", "fc.weight", "fc.bias", "final_conv.weight". 
	Unexpected key(s) in state_dict: "epoch", "best_prec1", "state_dict". 

In [8]:
NUM_CLASSES = 2
# model = YOLO_SWIN(num_classes=NUM_CLASSES)
model = ResNet(BasicBlock, [3, 4, 6, 3], network_type="ImageNet", num_classes=NUM_CLASSES, att_type="CBAM")
model.to(device)

  init.kaiming_normal(self.fc.weight)
  init.kaiming_normal(self.state_dict()[key], mode='fan_out')


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (avgpool): AvgPool2d(kernel_size=7, stride=7, padding=0)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (cbam): CBAM(
        (ChannelGate): ChannelGate(
          (mlp): Sequential(
            (0): Flatten()
            (1): Linear(in_features=64, out_features=4, bias=True)
          

In [9]:
torchsummary.summary(model, (3,448,448))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
          Flatten-10                   [-1, 64]               0
           Linear-11                    [-1, 4]             260
             ReLU-12                    [-1, 4]               0
           Linear-13                   [-1, 64]             320
          Flatten-14                   



In [10]:
x = torch.randn(1, 3, 448, 448).to(device)
with torch.no_grad():
    y = model(x)
print(y.shape)

torch.Size([1, 12, 7, 7])


## Loss func

In [11]:
class YOLO_LOSS():
    def __init__(self, num_classes, device, lambda_coord=5., lambda_noobj=0.5):
        self.num_classes = num_classes
        self.device = device
        self.grid_size = 7
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.mse_loss = nn.MSELoss(reduction="sum")

    def __call__(self, predictions, targets):
        self.batch_size, _, _, _ = predictions.shape
        groundtruths = self.build_batch_target_grid(targets)
        groundtruths = groundtruths.to(self.device)
        
        with torch.no_grad():
            iou1 = self.get_IoU(predictions[:, 1:5, ...], groundtruths[:, 1:5, ...])
            iou2 = self.get_IoU(predictions[:, 6:10, ...], groundtruths[:, 1:5, ...])

        ious = torch.stack([iou1, iou2], dim=1)
        max_iou, best_box = ious.max(dim=1, keepdim=True)
        max_iou = torch.cat([max_iou, max_iou], dim=1)
        best_box = torch.cat([best_box.eq(0), best_box.eq(1)], dim=1)

        predictions_ = predictions[:, :5*2, ...].reshape(self.batch_size, 2, 5, self.grid_size, self.grid_size)
        obj_pred = predictions_[:, :, 0, ...]
        xy_pred = predictions_[:, :, 1:3, ...]
        wh_pred = predictions_[:, :, 3:5, ...]
        cls_pred = predictions[:, 5*2:, ...]

        groundtruths_ = groundtruths[:, :5, ...].reshape(self.batch_size, 1, 5, self.grid_size, self.grid_size)
        obj_target = groundtruths_[:, :, 0, ...]
        xy_target = groundtruths_[:, :, 1:3, ...]
        wh_target= groundtruths_[:, :, 3:5, ...]
        cls_target = groundtruths[:, 5:, ...]
        
        positive = obj_target * best_box

        obj_loss = self.mse_loss(positive * obj_pred, positive * ious)
        noobj_loss = self.mse_loss((1 - positive) * obj_pred, ious*0)
        xy_loss = self.mse_loss(positive.unsqueeze(dim=2) * xy_pred, positive.unsqueeze(dim=2) * xy_target)
        wh_loss = self.mse_loss(positive.unsqueeze(dim=2) * (wh_pred.sign() * (wh_pred.abs() + 1e-8).sqrt()),
                           positive.unsqueeze(dim=2) * (wh_target + 1e-8).sqrt())
        cls_loss = self.mse_loss(obj_target * cls_pred, cls_target)
        
        obj_loss /= self.batch_size
        noobj_loss /= self.batch_size
        bbox_loss = (xy_loss+wh_loss) / self.batch_size
        cls_loss /= self.batch_size
        
        total_loss = obj_loss + self.lambda_noobj*noobj_loss + self.lambda_coord*bbox_loss + cls_loss
        return total_loss, (obj_loss.item(), noobj_loss.item(), bbox_loss.item(), cls_loss.item())
    
    def build_target_grid(self, target):
        target_grid = torch.zeros((1+4+self.num_classes, self.grid_size, self.grid_size), device=self.device)

        for gt in target:
            xc, yc, w, h, cls_id = gt
            xn = (xc % (1/self.grid_size))
            yn = (yc % (1/self.grid_size))
            cls_id = int(cls_id)

            i_grid = int(xc * self.grid_size)
            j_grid = int(yc * self.grid_size)
            target_grid[0, j_grid, i_grid] = 1
            target_grid[1:5, j_grid, i_grid] = torch.Tensor([xn,yn,w,h])
#             print(5+cls_id, j_grid, i_grid)
            target_grid[5+cls_id, j_grid, i_grid] = 1

        return target_grid
    
    def build_batch_target_grid(self, targets):
        target_grid_batch = torch.stack([self.build_target_grid(target) for target in targets], dim=0)
        return target_grid_batch
    
    def get_IoU(self, cbox1, cbox2):
        box1 = self.xywh_to_xyxy(cbox1)
        box2 = self.xywh_to_xyxy(cbox2)

        x1 = torch.max(box1[:, 0, ...], box2[:, 0, ...])
        y1 = torch.max(box1[:, 1, ...], box2[:, 1, ...])
        x2 = torch.min(box1[:, 2, ...], box2[:, 2, ...])
        y2 = torch.min(box1[:, 3, ...], box2[:, 3, ...])

        intersection = (x2-x1).clamp(min=0) * (y2-y1).clamp(min=0)
        union = abs(cbox1[:, 2, ...]*cbox1[:, 3, ...]) + \
                abs(cbox2[:, 2, ...]*cbox2[:, 3, ...]) - intersection

        intersection[intersection.gt(0)] = intersection[intersection.gt(0)] / union[intersection.gt(0)]
        return intersection
    
    def generate_xy_normed_grid(self):
        y_offset, x_offset = torch.meshgrid(torch.arange(self.grid_size), torch.arange(self.grid_size))
        xy_grid = torch.stack([x_offset, y_offset], dim=0)
        xy_normed_grid = xy_grid / self.grid_size
        return xy_normed_grid.to(self.device)

    def xywh_to_xyxy(self, bboxes):
        xy_normed_grid = self.generate_xy_normed_grid()
        xcyc = bboxes[:,0:2,...] + xy_normed_grid.tile(self.batch_size, 1,1,1)
        wh = bboxes[:,2:4,...]
        x1y1 = xcyc - (wh/2)
        x2y2 = xcyc + (wh/2)
        return torch.cat([x1y1, x2y2], dim=1)

## Train

In [12]:
def train_one_epoch(dataloaders, model, criterion, optimizer, device):
    train_loss = defaultdict(float)
    val_loss = defaultdict(float)
    
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
        
        running_loss = defaultdict(float)
        for index, batch in enumerate(dataloaders[phase]):
            images = batch[0].to(device)
            targets = batch[1]
            filenames = batch[2]
            
            with torch.set_grad_enabled(phase == "train"): # phase가 train 일때만 gradient 추적기능을 킨다.
                predictions = model(images) #prediction shape=> B,12,7,7
#             print(f"predictions:{predictions}, \ntargets: {targets}\n")
            loss, (obj_loss, noobj_loss, bbox_loss, cls_loss) = criterion(predictions, targets)
#             print(f"loss:{loss}, obj_loss:{obj_loss}, noobj_loss:{noobj_loss}\nbbox_loss:{bbox_loss}, cls_loss:{cls_loss}\n--------------\n")
            if phase == "train":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # 현재 epoch단계에서 loss가 얼마인지 running loss 가출력
                running_loss["total_loss"] += loss.item()
                running_loss["obj_loss"] += obj_loss
                running_loss["noobj_loss"] += noobj_loss
                running_loss["bbox_loss"] += bbox_loss
                running_loss["cls_loss"] += cls_loss
                
                train_loss["total_loss"] += loss.item()
                train_loss["obj_loss"] += obj_loss
                train_loss["noobj_loss"] += noobj_loss
                train_loss["bbox_loss"] += bbox_loss
                train_loss["cls_loss"] += cls_loss
                
                if (index > 0) and (index % VERBOSE_FREQ) == 0:
                    text = f"<<<iteration:[{index}/{len(dataloaders[phase])}] - "
                    for k, v in running_loss.items():
                        text += f"{k}: {v/VERBOSE_FREQ:.4f}  "
                        running_loss[k] = 0.
                    print(text)
            else:
                val_loss["total_loss"] += loss.item()
                val_loss["obj_loss"] += obj_loss
                val_loss["noobj_loss"] += noobj_loss
                val_loss["bbox_loss"] += bbox_loss
                val_loss["cls_loss"] += cls_loss

    for k in train_loss.keys():
        train_loss[k] /= len(dataloaders["train"])
        val_loss[k] /= len(dataloaders["val"])
    return train_loss, val_loss

In [13]:
def build_dataloader(part, NECK_PATH, BODY_PATH, batch_size=2, aug_factor=0):
    IMAGE_SIZE = 448
    transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
    )
    augmentator=A.Compose([
    #     A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.HorizontalFlip(p=0.7),
    #     A.Sharpen(p=0.7),
        A.BBoxSafeRandomCrop(p=0.6),
        A.VerticalFlip (p=0.6),
        A.HueSaturationValue(p=0.6),
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
    )
    
    dataloaders = {}
#     train_dataset = Detection_dataset(data_dir=data_dir, phase="train", transformer=transformer)
#     train_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=augmentator, aug_factor=aug_factor)
    train_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='train', transformer=transformer, aug=None)
    dataloaders["train"] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

#     val_dataset = Detection_dataset(data_dir=data_dir, phase="val", transformer=transformer)
    val_dataset=PET_dataset(part ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='valid', transformer=transformer, aug=None)
    dataloaders["val"] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    print(f"trainset:{len(train_dataset)} validset:{len(val_dataset)}")
    return dataloaders

In [15]:
# data_dir = "/content/drive/MyDrive/fastCamMedicalProj/DATASET/DATASET/Detection/"
# NECK_PATH = '/home/host_data/PET_data/Neck'
# BODY_PATH = '/home/host_data/PET_data/Body'
NECK_PATH = '/home/host_data/PET_data_image_patching/patched_Neck'
BODY_PATH = '/home/host_data/PET_data_image_patching/Body'
is_cuda = True

NUM_CLASSES = 2
IMAGE_SIZE = 448
BATCH_SIZE = 16
VERBOSE_FREQ = 20
LR=0.0001
AUG_FACTOR=0
PATCH_FACTOR=50
BACKBONE="YOLO_RESNET_CBAM"
PART="neck"
num_epochs = 100
# DEVICE = torch.device('cuda' if torch.cuda.is_available and is_cuda else 'cpu')

dataloaders = build_dataloader(part=PART,NECK_PATH=NECK_PATH,BODY_PATH=BODY_PATH,batch_size=BATCH_SIZE, aug_factor=AUG_FACTOR)
model = ResNet(BasicBlock, [3, 4, 6, 3], network_type="ImageNet", num_classes=NUM_CLASSES, att_type="CBAM")
model = model.to(device)
criterion = YOLO_LOSS(num_classes=NUM_CLASSES, device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

start making augmented images-- augmented factor:0
total length of augmented images: 0
start making augmented images-- augmented factor:0
total length of augmented images: 0
trainset:10500 validset:1800


  init.kaiming_normal(self.fc.weight)
  init.kaiming_normal(self.state_dict()[key], mode='fan_out')


In [16]:
import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="yolo_cbam_neck_IMAGE_PATCH",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "architecture": BACKBONE,
    "dataset": PART,
    "epochs": num_epochs,
    "patch factor":PATCH_FACTOR,
    "aug factor":AUG_FACTOR,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mgomduribo[0m ([33murp[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
best_epoch = 0
best_score = float('inf')
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    train_loss, val_loss = train_one_epoch(dataloaders, model, criterion, optimizer, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
#     train_loss["obj_loss"] += obj_loss
#     train_loss["noobj_loss"] += noobj_loss
#     train_loss["bbox_loss"] += bbox_loss
#     train_loss["cls_loss"] += cls_loss
    wandb.log({"Train Loss": train_loss['total_loss'],
               "Train obj Loss":train_loss["obj_loss"],
               "Train bbox Loss":train_loss["bbox_loss"],
               "Train class Loss":train_loss["cls_loss"],
               "Val Loss": val_loss['total_loss'],
               "Val obj Loss":val_loss["obj_loss"],
               "Val bbox Loss":val_loss["bbox_loss"],
               "Val class Loss":val_loss["cls_loss"],})
    print(f"\nepoch:{epoch+1}/{num_epochs} - Train Loss: {train_loss['total_loss']:.4f}, Val Loss: {val_loss['total_loss']:.4f}\n")
    
    if (epoch+1) % 10 == 0:
        save_model(model.state_dict(), f'model_{epoch+1}.pth', save_dir=f"./trained_model/{BACKBONE}_{PART}_LR{LR}_IP{PATCH_FACTOR}")
wandb.finish()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<<<iteration:[20/657] - total_loss: 7403.3331  obj_loss: 406.4925  noobj_loss: 12789.6373  bbox_loss: 116.8576  cls_loss: 17.7339  
<<<iteration:[40/657] - total_loss: 69.6353  obj_loss: 1.2321  noobj_loss: 60.8392  bbox_loss: 6.4654  cls_loss: 5.6564  
<<<iteration:[60/657] - total_loss: 38.2591  obj_loss: 0.5847  noobj_loss: 36.3274  bbox_loss: 3.2739  cls_loss: 3.1411  
<<<iteration:[80/657] - total_loss: 30.2336  obj_loss: 0.5018  noobj_loss: 26.2221  bbox_loss: 2.8558  cls_loss: 2.3420  
<<<iteration:[100/657] - total_loss: 24.2520  obj_loss: 0.3655  noobj_loss: 20.9246  bbox_loss: 2.2766  cls_loss: 2.0411  
<<<iteration:[120/657] - total_loss: 18.9928  obj_loss: 0.3464  noobj_loss: 16.5720  bbox_loss: 1.7445  cls_loss: 1.6381  
<<<iteration:[140/657] - total_loss: 16.1883  obj_loss: 0.2349  noobj_loss: 13.9078  bbox_loss: 1.5359  cls_loss: 1.3201  
<<<iteration:[160/657] - total_loss: 15.8058  obj_loss: 0.1900  noobj_loss: 12.5080  bbox_loss: 1.5885  cls_loss: 1.4192  
<<<iterati

<<<iteration:[80/657] - total_loss: 2.5772  obj_loss: 0.0466  noobj_loss: 1.9511  bbox_loss: 0.2480  cls_loss: 0.3150  
<<<iteration:[100/657] - total_loss: 2.4204  obj_loss: 0.0513  noobj_loss: 1.9643  bbox_loss: 0.2186  cls_loss: 0.2939  
<<<iteration:[120/657] - total_loss: 2.5796  obj_loss: 0.0463  noobj_loss: 2.0003  bbox_loss: 0.2491  cls_loss: 0.2874  
<<<iteration:[140/657] - total_loss: 2.6342  obj_loss: 0.0444  noobj_loss: 1.9614  bbox_loss: 0.2626  cls_loss: 0.2960  
<<<iteration:[160/657] - total_loss: 2.4682  obj_loss: 0.0438  noobj_loss: 1.9363  bbox_loss: 0.2287  cls_loss: 0.3127  
<<<iteration:[180/657] - total_loss: 2.3412  obj_loss: 0.0476  noobj_loss: 1.8461  bbox_loss: 0.2218  cls_loss: 0.2614  
<<<iteration:[200/657] - total_loss: 2.3915  obj_loss: 0.0561  noobj_loss: 1.8890  bbox_loss: 0.2225  cls_loss: 0.2784  
<<<iteration:[220/657] - total_loss: 2.4210  obj_loss: 0.0441  noobj_loss: 1.9142  bbox_loss: 0.2250  cls_loss: 0.2947  
<<<iteration:[240/657] - total_lo

<<<iteration:[140/657] - total_loss: 1.5015  obj_loss: 0.0531  noobj_loss: 1.0478  bbox_loss: 0.1437  cls_loss: 0.2062  
<<<iteration:[160/657] - total_loss: 1.6813  obj_loss: 0.0442  noobj_loss: 1.1575  bbox_loss: 0.1641  cls_loss: 0.2380  
<<<iteration:[180/657] - total_loss: 1.5797  obj_loss: 0.0596  noobj_loss: 1.1354  bbox_loss: 0.1477  cls_loss: 0.2139  
<<<iteration:[200/657] - total_loss: 1.5209  obj_loss: 0.0503  noobj_loss: 1.0568  bbox_loss: 0.1439  cls_loss: 0.2224  
<<<iteration:[220/657] - total_loss: 1.6278  obj_loss: 0.0460  noobj_loss: 1.1452  bbox_loss: 0.1600  cls_loss: 0.2090  
<<<iteration:[240/657] - total_loss: 1.4528  obj_loss: 0.0534  noobj_loss: 1.0610  bbox_loss: 0.1334  cls_loss: 0.2021  
<<<iteration:[260/657] - total_loss: 1.4116  obj_loss: 0.0446  noobj_loss: 1.0193  bbox_loss: 0.1287  cls_loss: 0.2138  
<<<iteration:[280/657] - total_loss: 1.6034  obj_loss: 0.0408  noobj_loss: 1.0250  bbox_loss: 0.1688  cls_loss: 0.2064  
<<<iteration:[300/657] - total_l

<<<iteration:[200/657] - total_loss: 1.2063  obj_loss: 0.0529  noobj_loss: 0.7885  bbox_loss: 0.1035  cls_loss: 0.2418  
<<<iteration:[220/657] - total_loss: 1.2103  obj_loss: 0.0459  noobj_loss: 0.8160  bbox_loss: 0.1165  cls_loss: 0.1738  
<<<iteration:[240/657] - total_loss: 1.6244  obj_loss: 0.0452  noobj_loss: 0.7923  bbox_loss: 0.1990  cls_loss: 0.1878  
<<<iteration:[260/657] - total_loss: 1.1682  obj_loss: 0.0498  noobj_loss: 0.7447  bbox_loss: 0.1138  cls_loss: 0.1769  
<<<iteration:[280/657] - total_loss: 1.1578  obj_loss: 0.0476  noobj_loss: 0.8211  bbox_loss: 0.1043  cls_loss: 0.1783  
<<<iteration:[300/657] - total_loss: 1.1547  obj_loss: 0.0501  noobj_loss: 0.7910  bbox_loss: 0.1057  cls_loss: 0.1805  
<<<iteration:[320/657] - total_loss: 1.2357  obj_loss: 0.0560  noobj_loss: 0.8210  bbox_loss: 0.1141  cls_loss: 0.1987  
<<<iteration:[340/657] - total_loss: 1.1343  obj_loss: 0.0578  noobj_loss: 0.7124  bbox_loss: 0.1052  cls_loss: 0.1941  
<<<iteration:[360/657] - total_l

<<<iteration:[260/657] - total_loss: 0.9582  obj_loss: 0.0567  noobj_loss: 0.6212  bbox_loss: 0.0857  cls_loss: 0.1624  
<<<iteration:[280/657] - total_loss: 0.9985  obj_loss: 0.0559  noobj_loss: 0.6137  bbox_loss: 0.0998  cls_loss: 0.1365  
<<<iteration:[300/657] - total_loss: 0.9225  obj_loss: 0.0576  noobj_loss: 0.5808  bbox_loss: 0.0820  cls_loss: 0.1648  
<<<iteration:[320/657] - total_loss: 0.9972  obj_loss: 0.0474  noobj_loss: 0.6147  bbox_loss: 0.0914  cls_loss: 0.1854  
<<<iteration:[340/657] - total_loss: 0.9476  obj_loss: 0.0560  noobj_loss: 0.6102  bbox_loss: 0.0807  cls_loss: 0.1830  
<<<iteration:[360/657] - total_loss: 0.8603  obj_loss: 0.0584  noobj_loss: 0.5996  bbox_loss: 0.0691  cls_loss: 0.1565  
<<<iteration:[380/657] - total_loss: 0.9170  obj_loss: 0.0682  noobj_loss: 0.6140  bbox_loss: 0.0738  cls_loss: 0.1727  
<<<iteration:[400/657] - total_loss: 1.0022  obj_loss: 0.0576  noobj_loss: 0.6446  bbox_loss: 0.0901  cls_loss: 0.1720  
<<<iteration:[420/657] - total_l

<<<iteration:[320/657] - total_loss: 4.3482  obj_loss: 0.0479  noobj_loss: 0.8173  bbox_loss: 0.7425  cls_loss: 0.1792  
<<<iteration:[340/657] - total_loss: 2.9188  obj_loss: 0.0523  noobj_loss: 0.7377  bbox_loss: 0.4595  cls_loss: 0.2004  
<<<iteration:[360/657] - total_loss: 3.8083  obj_loss: 0.0542  noobj_loss: 0.7396  bbox_loss: 0.6355  cls_loss: 0.2070  
<<<iteration:[380/657] - total_loss: 3.6352  obj_loss: 0.0452  noobj_loss: 0.7277  bbox_loss: 0.6068  cls_loss: 0.1920  
<<<iteration:[400/657] - total_loss: 5.1156  obj_loss: 0.0483  noobj_loss: 0.7389  bbox_loss: 0.9025  cls_loss: 0.1854  
<<<iteration:[420/657] - total_loss: 5.7986  obj_loss: 0.0522  noobj_loss: 0.7156  bbox_loss: 1.0348  cls_loss: 0.2148  
<<<iteration:[440/657] - total_loss: 4.2478  obj_loss: 0.0497  noobj_loss: 0.7435  bbox_loss: 0.7296  cls_loss: 0.1783  
<<<iteration:[460/657] - total_loss: 4.3194  obj_loss: 0.0521  noobj_loss: 0.7280  bbox_loss: 0.7449  cls_loss: 0.1789  
<<<iteration:[480/657] - total_l

<<<iteration:[380/657] - total_loss: 2.9322  obj_loss: 0.0467  noobj_loss: 0.5738  bbox_loss: 0.4878  cls_loss: 0.1596  
<<<iteration:[400/657] - total_loss: 2.6774  obj_loss: 0.0664  noobj_loss: 0.5476  bbox_loss: 0.4329  cls_loss: 0.1729  
<<<iteration:[420/657] - total_loss: 3.2264  obj_loss: 0.0529  noobj_loss: 0.5761  bbox_loss: 0.5430  cls_loss: 0.1706  
<<<iteration:[440/657] - total_loss: 1.5767  obj_loss: 0.0510  noobj_loss: 0.5663  bbox_loss: 0.2170  cls_loss: 0.1575  
<<<iteration:[460/657] - total_loss: 2.7730  obj_loss: 0.0638  noobj_loss: 0.5006  bbox_loss: 0.4580  cls_loss: 0.1688  
<<<iteration:[480/657] - total_loss: 1.8936  obj_loss: 0.0672  noobj_loss: 0.5825  bbox_loss: 0.2789  cls_loss: 0.1409  
<<<iteration:[500/657] - total_loss: 1.6070  obj_loss: 0.0628  noobj_loss: 0.5228  bbox_loss: 0.2280  cls_loss: 0.1429  
<<<iteration:[520/657] - total_loss: 3.4337  obj_loss: 0.0615  noobj_loss: 0.5131  bbox_loss: 0.5922  cls_loss: 0.1546  
<<<iteration:[540/657] - total_l

<<<iteration:[440/657] - total_loss: 0.9796  obj_loss: 0.0608  noobj_loss: 0.4736  bbox_loss: 0.1049  cls_loss: 0.1574  
<<<iteration:[460/657] - total_loss: 1.4140  obj_loss: 0.0628  noobj_loss: 0.4624  bbox_loss: 0.1944  cls_loss: 0.1480  
<<<iteration:[480/657] - total_loss: 2.0024  obj_loss: 0.0717  noobj_loss: 0.5279  bbox_loss: 0.3070  cls_loss: 0.1317  
<<<iteration:[500/657] - total_loss: 1.7762  obj_loss: 0.0594  noobj_loss: 0.4957  bbox_loss: 0.2623  cls_loss: 0.1571  
<<<iteration:[520/657] - total_loss: 1.0988  obj_loss: 0.0576  noobj_loss: 0.4539  bbox_loss: 0.1308  cls_loss: 0.1604  
<<<iteration:[540/657] - total_loss: 1.6667  obj_loss: 0.0676  noobj_loss: 0.4709  bbox_loss: 0.2457  cls_loss: 0.1353  
<<<iteration:[560/657] - total_loss: 1.7977  obj_loss: 0.0554  noobj_loss: 0.4565  bbox_loss: 0.2706  cls_loss: 0.1609  
<<<iteration:[580/657] - total_loss: 1.1135  obj_loss: 0.0751  noobj_loss: 0.4732  bbox_loss: 0.1297  cls_loss: 0.1534  
<<<iteration:[600/657] - total_l

<<<iteration:[500/657] - total_loss: 1.0557  obj_loss: 0.0661  noobj_loss: 0.4854  bbox_loss: 0.1213  cls_loss: 0.1405  
<<<iteration:[520/657] - total_loss: 1.1454  obj_loss: 0.0693  noobj_loss: 0.4386  bbox_loss: 0.1426  cls_loss: 0.1436  
<<<iteration:[540/657] - total_loss: 1.0531  obj_loss: 0.0688  noobj_loss: 0.4343  bbox_loss: 0.1262  cls_loss: 0.1361  
<<<iteration:[560/657] - total_loss: 1.2126  obj_loss: 0.0608  noobj_loss: 0.4593  bbox_loss: 0.1526  cls_loss: 0.1593  
<<<iteration:[580/657] - total_loss: 0.8932  obj_loss: 0.0766  noobj_loss: 0.4030  bbox_loss: 0.0909  cls_loss: 0.1605  
<<<iteration:[600/657] - total_loss: 0.9621  obj_loss: 0.0715  noobj_loss: 0.4226  bbox_loss: 0.1048  cls_loss: 0.1552  
<<<iteration:[620/657] - total_loss: 0.9482  obj_loss: 0.0712  noobj_loss: 0.3985  bbox_loss: 0.1082  cls_loss: 0.1367  
<<<iteration:[640/657] - total_loss: 0.8863  obj_loss: 0.0620  noobj_loss: 0.4721  bbox_loss: 0.0848  cls_loss: 0.1642  

epoch:17/100 - Train Loss: 1.04

<<<iteration:[560/657] - total_loss: 1.0200  obj_loss: 0.0711  noobj_loss: 0.4017  bbox_loss: 0.1170  cls_loss: 0.1630  
<<<iteration:[580/657] - total_loss: 0.7015  obj_loss: 0.0721  noobj_loss: 0.3681  bbox_loss: 0.0619  cls_loss: 0.1359  
<<<iteration:[600/657] - total_loss: 0.7742  obj_loss: 0.0690  noobj_loss: 0.3945  bbox_loss: 0.0753  cls_loss: 0.1313  
<<<iteration:[620/657] - total_loss: 0.6702  obj_loss: 0.0637  noobj_loss: 0.3580  bbox_loss: 0.0603  cls_loss: 0.1259  
<<<iteration:[640/657] - total_loss: 0.6811  obj_loss: 0.0762  noobj_loss: 0.3520  bbox_loss: 0.0540  cls_loss: 0.1587  

epoch:19/100 - Train Loss: 0.8083, Val Loss: 0.6829

<<<iteration:[20/657] - total_loss: 0.7259  obj_loss: 0.0806  noobj_loss: 0.3775  bbox_loss: 0.0627  cls_loss: 0.1431  
<<<iteration:[40/657] - total_loss: 0.6859  obj_loss: 0.0757  noobj_loss: 0.3830  bbox_loss: 0.0605  cls_loss: 0.1162  
<<<iteration:[60/657] - total_loss: 0.6665  obj_loss: 0.0724  noobj_loss: 0.3779  bbox_loss: 0.0528  

<<<iteration:[620/657] - total_loss: 0.6221  obj_loss: 0.0749  noobj_loss: 0.3103  bbox_loss: 0.0501  cls_loss: 0.1415  
<<<iteration:[640/657] - total_loss: 0.6078  obj_loss: 0.0783  noobj_loss: 0.3182  bbox_loss: 0.0477  cls_loss: 0.1319  

epoch:21/100 - Train Loss: 0.7142, Val Loss: 0.6298

<<<iteration:[20/657] - total_loss: 0.6942  obj_loss: 0.0815  noobj_loss: 0.3455  bbox_loss: 0.0569  cls_loss: 0.1553  
<<<iteration:[40/657] - total_loss: 0.6138  obj_loss: 0.0764  noobj_loss: 0.3238  bbox_loss: 0.0493  cls_loss: 0.1290  
<<<iteration:[60/657] - total_loss: 0.9370  obj_loss: 0.0766  noobj_loss: 0.3200  bbox_loss: 0.1147  cls_loss: 0.1272  
<<<iteration:[80/657] - total_loss: 0.7216  obj_loss: 0.0649  noobj_loss: 0.3647  bbox_loss: 0.0705  cls_loss: 0.1221  
<<<iteration:[100/657] - total_loss: 0.6081  obj_loss: 0.0768  noobj_loss: 0.2956  bbox_loss: 0.0492  cls_loss: 0.1375  
<<<iteration:[120/657] - total_loss: 0.6100  obj_loss: 0.0721  noobj_loss: 0.3085  bbox_loss: 0.0478  c

<<<iteration:[40/657] - total_loss: 0.6024  obj_loss: 0.0692  noobj_loss: 0.3186  bbox_loss: 0.0468  cls_loss: 0.1398  
<<<iteration:[60/657] - total_loss: 0.5630  obj_loss: 0.0803  noobj_loss: 0.2826  bbox_loss: 0.0420  cls_loss: 0.1313  
<<<iteration:[80/657] - total_loss: 0.6580  obj_loss: 0.0817  noobj_loss: 0.3268  bbox_loss: 0.0569  cls_loss: 0.1283  
<<<iteration:[100/657] - total_loss: 0.6468  obj_loss: 0.0832  noobj_loss: 0.3027  bbox_loss: 0.0582  cls_loss: 0.1211  
<<<iteration:[120/657] - total_loss: 0.7837  obj_loss: 0.0744  noobj_loss: 0.3804  bbox_loss: 0.0794  cls_loss: 0.1223  
<<<iteration:[140/657] - total_loss: 0.5926  obj_loss: 0.0862  noobj_loss: 0.2996  bbox_loss: 0.0455  cls_loss: 0.1293  
<<<iteration:[160/657] - total_loss: 0.5876  obj_loss: 0.0771  noobj_loss: 0.2881  bbox_loss: 0.0479  cls_loss: 0.1270  
<<<iteration:[180/657] - total_loss: 0.5656  obj_loss: 0.0863  noobj_loss: 0.2719  bbox_loss: 0.0444  cls_loss: 0.1215  
<<<iteration:[200/657] - total_loss

<<<iteration:[100/657] - total_loss: 0.6120  obj_loss: 0.0852  noobj_loss: 0.2909  bbox_loss: 0.0495  cls_loss: 0.1338  
<<<iteration:[120/657] - total_loss: 0.6228  obj_loss: 0.0750  noobj_loss: 0.2724  bbox_loss: 0.0612  cls_loss: 0.1057  
<<<iteration:[140/657] - total_loss: 0.5103  obj_loss: 0.0741  noobj_loss: 0.2514  bbox_loss: 0.0396  cls_loss: 0.1124  
<<<iteration:[160/657] - total_loss: 0.5906  obj_loss: 0.0795  noobj_loss: 0.2618  bbox_loss: 0.0515  cls_loss: 0.1227  
<<<iteration:[180/657] - total_loss: 0.5498  obj_loss: 0.0935  noobj_loss: 0.2637  bbox_loss: 0.0420  cls_loss: 0.1143  
<<<iteration:[200/657] - total_loss: 0.5831  obj_loss: 0.0834  noobj_loss: 0.2724  bbox_loss: 0.0495  cls_loss: 0.1161  
<<<iteration:[220/657] - total_loss: 0.6256  obj_loss: 0.0764  noobj_loss: 0.2791  bbox_loss: 0.0590  cls_loss: 0.1146  
<<<iteration:[240/657] - total_loss: 0.5732  obj_loss: 0.0887  noobj_loss: 0.2985  bbox_loss: 0.0462  cls_loss: 0.1045  
<<<iteration:[260/657] - total_l

<<<iteration:[160/657] - total_loss: 0.5568  obj_loss: 0.0783  noobj_loss: 0.2346  bbox_loss: 0.0486  cls_loss: 0.1183  
<<<iteration:[180/657] - total_loss: 0.5888  obj_loss: 0.0844  noobj_loss: 0.2688  bbox_loss: 0.0533  cls_loss: 0.1038  
<<<iteration:[200/657] - total_loss: 0.5456  obj_loss: 0.0935  noobj_loss: 0.2557  bbox_loss: 0.0409  cls_loss: 0.1199  
<<<iteration:[220/657] - total_loss: 0.5246  obj_loss: 0.0834  noobj_loss: 0.2728  bbox_loss: 0.0406  cls_loss: 0.1017  
<<<iteration:[240/657] - total_loss: 0.5211  obj_loss: 0.0917  noobj_loss: 0.2459  bbox_loss: 0.0414  cls_loss: 0.0994  
<<<iteration:[260/657] - total_loss: 0.6024  obj_loss: 0.0741  noobj_loss: 0.2653  bbox_loss: 0.0555  cls_loss: 0.1181  
<<<iteration:[280/657] - total_loss: 0.5543  obj_loss: 0.0782  noobj_loss: 0.2865  bbox_loss: 0.0403  cls_loss: 0.1312  
<<<iteration:[300/657] - total_loss: 0.6464  obj_loss: 0.0790  noobj_loss: 0.3035  bbox_loss: 0.0614  cls_loss: 0.1086  
<<<iteration:[320/657] - total_l

<<<iteration:[220/657] - total_loss: 0.5151  obj_loss: 0.0858  noobj_loss: 0.2375  bbox_loss: 0.0400  cls_loss: 0.1106  
<<<iteration:[240/657] - total_loss: 0.7508  obj_loss: 0.0900  noobj_loss: 0.2429  bbox_loss: 0.0840  cls_loss: 0.1193  
<<<iteration:[260/657] - total_loss: 0.5095  obj_loss: 0.0806  noobj_loss: 0.2542  bbox_loss: 0.0376  cls_loss: 0.1136  
<<<iteration:[280/657] - total_loss: 0.5266  obj_loss: 0.0983  noobj_loss: 0.2302  bbox_loss: 0.0420  cls_loss: 0.1031  
<<<iteration:[300/657] - total_loss: 0.5007  obj_loss: 0.0892  noobj_loss: 0.2389  bbox_loss: 0.0390  cls_loss: 0.0973  
<<<iteration:[320/657] - total_loss: 0.5152  obj_loss: 0.0970  noobj_loss: 0.2245  bbox_loss: 0.0384  cls_loss: 0.1141  
<<<iteration:[340/657] - total_loss: 0.5502  obj_loss: 0.0873  noobj_loss: 0.2422  bbox_loss: 0.0451  cls_loss: 0.1162  
<<<iteration:[360/657] - total_loss: 0.5131  obj_loss: 0.0884  noobj_loss: 0.2309  bbox_loss: 0.0411  cls_loss: 0.1038  
<<<iteration:[380/657] - total_l

<<<iteration:[280/657] - total_loss: 0.4844  obj_loss: 0.0898  noobj_loss: 0.2233  bbox_loss: 0.0340  cls_loss: 0.1129  
<<<iteration:[300/657] - total_loss: 0.5423  obj_loss: 0.0868  noobj_loss: 0.2104  bbox_loss: 0.0456  cls_loss: 0.1221  
<<<iteration:[320/657] - total_loss: 0.4761  obj_loss: 0.0872  noobj_loss: 0.2296  bbox_loss: 0.0353  cls_loss: 0.0975  
<<<iteration:[340/657] - total_loss: 0.4880  obj_loss: 0.0908  noobj_loss: 0.2190  bbox_loss: 0.0348  cls_loss: 0.1137  
<<<iteration:[360/657] - total_loss: 0.5420  obj_loss: 0.0831  noobj_loss: 0.2446  bbox_loss: 0.0472  cls_loss: 0.1005  
<<<iteration:[380/657] - total_loss: 0.4768  obj_loss: 0.0873  noobj_loss: 0.2172  bbox_loss: 0.0336  cls_loss: 0.1128  
<<<iteration:[400/657] - total_loss: 0.6106  obj_loss: 0.0914  noobj_loss: 0.2423  bbox_loss: 0.0557  cls_loss: 0.1193  
<<<iteration:[420/657] - total_loss: 0.5041  obj_loss: 0.0886  noobj_loss: 0.2353  bbox_loss: 0.0398  cls_loss: 0.0987  
<<<iteration:[440/657] - total_l

<<<iteration:[340/657] - total_loss: 0.5445  obj_loss: 0.0893  noobj_loss: 0.2086  bbox_loss: 0.0475  cls_loss: 0.1131  
<<<iteration:[360/657] - total_loss: 0.5344  obj_loss: 0.0959  noobj_loss: 0.2126  bbox_loss: 0.0447  cls_loss: 0.1087  
<<<iteration:[380/657] - total_loss: 0.5508  obj_loss: 0.0948  noobj_loss: 0.2177  bbox_loss: 0.0438  cls_loss: 0.1282  
<<<iteration:[400/657] - total_loss: 0.4897  obj_loss: 0.0923  noobj_loss: 0.2200  bbox_loss: 0.0352  cls_loss: 0.1112  
<<<iteration:[420/657] - total_loss: 0.5245  obj_loss: 0.0906  noobj_loss: 0.2208  bbox_loss: 0.0474  cls_loss: 0.0866  
<<<iteration:[440/657] - total_loss: 0.5072  obj_loss: 0.1016  noobj_loss: 0.2138  bbox_loss: 0.0365  cls_loss: 0.1162  
<<<iteration:[460/657] - total_loss: 0.4882  obj_loss: 0.0956  noobj_loss: 0.2049  bbox_loss: 0.0375  cls_loss: 0.1027  
<<<iteration:[480/657] - total_loss: 0.4948  obj_loss: 0.0918  noobj_loss: 0.2268  bbox_loss: 0.0387  cls_loss: 0.0962  
<<<iteration:[500/657] - total_l

<<<iteration:[400/657] - total_loss: 0.4752  obj_loss: 0.0912  noobj_loss: 0.1998  bbox_loss: 0.0340  cls_loss: 0.1142  
<<<iteration:[420/657] - total_loss: 0.4891  obj_loss: 0.0863  noobj_loss: 0.2143  bbox_loss: 0.0373  cls_loss: 0.1091  
<<<iteration:[440/657] - total_loss: 0.5262  obj_loss: 0.0897  noobj_loss: 0.2098  bbox_loss: 0.0446  cls_loss: 0.1083  
<<<iteration:[460/657] - total_loss: 0.5195  obj_loss: 0.0928  noobj_loss: 0.2675  bbox_loss: 0.0374  cls_loss: 0.1057  
<<<iteration:[480/657] - total_loss: 0.4901  obj_loss: 0.0851  noobj_loss: 0.1972  bbox_loss: 0.0417  cls_loss: 0.0978  
<<<iteration:[500/657] - total_loss: 0.4638  obj_loss: 0.0878  noobj_loss: 0.2109  bbox_loss: 0.0332  cls_loss: 0.1047  
<<<iteration:[520/657] - total_loss: 0.4709  obj_loss: 0.0841  noobj_loss: 0.2042  bbox_loss: 0.0373  cls_loss: 0.0980  
<<<iteration:[540/657] - total_loss: 0.5103  obj_loss: 0.0854  noobj_loss: 0.2061  bbox_loss: 0.0425  cls_loss: 0.1092  
<<<iteration:[560/657] - total_l

<<<iteration:[460/657] - total_loss: 0.4508  obj_loss: 0.0882  noobj_loss: 0.2135  bbox_loss: 0.0318  cls_loss: 0.0968  
<<<iteration:[480/657] - total_loss: 0.4458  obj_loss: 0.1058  noobj_loss: 0.1853  bbox_loss: 0.0307  cls_loss: 0.0936  
<<<iteration:[500/657] - total_loss: 0.4775  obj_loss: 0.0916  noobj_loss: 0.1895  bbox_loss: 0.0359  cls_loss: 0.1117  
<<<iteration:[520/657] - total_loss: 0.4531  obj_loss: 0.0932  noobj_loss: 0.1911  bbox_loss: 0.0341  cls_loss: 0.0936  
<<<iteration:[540/657] - total_loss: 0.4667  obj_loss: 0.1046  noobj_loss: 0.1883  bbox_loss: 0.0314  cls_loss: 0.1111  
<<<iteration:[560/657] - total_loss: 0.4379  obj_loss: 0.0850  noobj_loss: 0.1987  bbox_loss: 0.0314  cls_loss: 0.0967  
<<<iteration:[580/657] - total_loss: 0.4724  obj_loss: 0.1016  noobj_loss: 0.1895  bbox_loss: 0.0375  cls_loss: 0.0883  
<<<iteration:[600/657] - total_loss: 0.4502  obj_loss: 0.1043  noobj_loss: 0.1957  bbox_loss: 0.0294  cls_loss: 0.1012  
<<<iteration:[620/657] - total_l

<<<iteration:[520/657] - total_loss: 0.4536  obj_loss: 0.1021  noobj_loss: 0.1805  bbox_loss: 0.0276  cls_loss: 0.1234  
<<<iteration:[540/657] - total_loss: 0.4500  obj_loss: 0.0948  noobj_loss: 0.1910  bbox_loss: 0.0325  cls_loss: 0.0970  
<<<iteration:[560/657] - total_loss: 0.4855  obj_loss: 0.0875  noobj_loss: 0.1998  bbox_loss: 0.0380  cls_loss: 0.1081  
<<<iteration:[580/657] - total_loss: 0.4787  obj_loss: 0.0942  noobj_loss: 0.1883  bbox_loss: 0.0372  cls_loss: 0.1042  
<<<iteration:[600/657] - total_loss: 0.4405  obj_loss: 0.0949  noobj_loss: 0.1847  bbox_loss: 0.0282  cls_loss: 0.1121  
<<<iteration:[620/657] - total_loss: 0.4324  obj_loss: 0.1010  noobj_loss: 0.1984  bbox_loss: 0.0287  cls_loss: 0.0886  
<<<iteration:[640/657] - total_loss: 0.4112  obj_loss: 0.0861  noobj_loss: 0.1788  bbox_loss: 0.0297  cls_loss: 0.0872  

epoch:40/100 - Train Loss: 0.4547, Val Loss: 0.4777

<<<iteration:[20/657] - total_loss: 0.4570  obj_loss: 0.1036  noobj_loss: 0.1933  bbox_loss: 0.0316

<<<iteration:[580/657] - total_loss: 0.4221  obj_loss: 0.0961  noobj_loss: 0.1680  bbox_loss: 0.0301  cls_loss: 0.0913  
<<<iteration:[600/657] - total_loss: 0.4674  obj_loss: 0.0769  noobj_loss: 0.1751  bbox_loss: 0.0384  cls_loss: 0.1110  
<<<iteration:[620/657] - total_loss: 0.4289  obj_loss: 0.0995  noobj_loss: 0.1813  bbox_loss: 0.0319  cls_loss: 0.0794  
<<<iteration:[640/657] - total_loss: 0.4492  obj_loss: 0.1009  noobj_loss: 0.1867  bbox_loss: 0.0321  cls_loss: 0.0946  

epoch:42/100 - Train Loss: 0.4489, Val Loss: 0.5004

<<<iteration:[20/657] - total_loss: 0.4889  obj_loss: 0.0990  noobj_loss: 0.1889  bbox_loss: 0.0376  cls_loss: 0.1078  
<<<iteration:[40/657] - total_loss: 0.4604  obj_loss: 0.1081  noobj_loss: 0.1745  bbox_loss: 0.0330  cls_loss: 0.1003  
<<<iteration:[60/657] - total_loss: 0.6674  obj_loss: 0.0862  noobj_loss: 0.1930  bbox_loss: 0.0749  cls_loss: 0.1103  
<<<iteration:[80/657] - total_loss: 0.5192  obj_loss: 0.0749  noobj_loss: 0.1770  bbox_loss: 0.0519  c

<<<iteration:[640/657] - total_loss: 0.4200  obj_loss: 0.0999  noobj_loss: 0.1873  bbox_loss: 0.0287  cls_loss: 0.0832  

epoch:44/100 - Train Loss: 0.4552, Val Loss: 0.4474

<<<iteration:[20/657] - total_loss: 0.5438  obj_loss: 0.0957  noobj_loss: 0.1837  bbox_loss: 0.0499  cls_loss: 0.1066  
<<<iteration:[40/657] - total_loss: 0.4286  obj_loss: 0.0970  noobj_loss: 0.1778  bbox_loss: 0.0299  cls_loss: 0.0935  
<<<iteration:[60/657] - total_loss: 0.4518  obj_loss: 0.1127  noobj_loss: 0.1757  bbox_loss: 0.0315  cls_loss: 0.0936  
<<<iteration:[80/657] - total_loss: 0.4725  obj_loss: 0.0984  noobj_loss: 0.1755  bbox_loss: 0.0405  cls_loss: 0.0837  
<<<iteration:[100/657] - total_loss: 0.4685  obj_loss: 0.1040  noobj_loss: 0.1784  bbox_loss: 0.0314  cls_loss: 0.1185  
<<<iteration:[120/657] - total_loss: 0.4247  obj_loss: 0.0931  noobj_loss: 0.1695  bbox_loss: 0.0301  cls_loss: 0.0963  
<<<iteration:[140/657] - total_loss: 0.3895  obj_loss: 0.0910  noobj_loss: 0.1782  bbox_loss: 0.0275  c

<<<iteration:[60/657] - total_loss: 0.4753  obj_loss: 0.0943  noobj_loss: 0.1710  bbox_loss: 0.0357  cls_loss: 0.1170  
<<<iteration:[80/657] - total_loss: 0.4410  obj_loss: 0.1005  noobj_loss: 0.1700  bbox_loss: 0.0329  cls_loss: 0.0908  
<<<iteration:[100/657] - total_loss: 0.4132  obj_loss: 0.1098  noobj_loss: 0.1690  bbox_loss: 0.0253  cls_loss: 0.0926  
<<<iteration:[120/657] - total_loss: 0.4038  obj_loss: 0.1049  noobj_loss: 0.1673  bbox_loss: 0.0276  cls_loss: 0.0775  
<<<iteration:[140/657] - total_loss: 0.4211  obj_loss: 0.0969  noobj_loss: 0.1709  bbox_loss: 0.0276  cls_loss: 0.1005  
<<<iteration:[160/657] - total_loss: 0.4224  obj_loss: 0.0952  noobj_loss: 0.1743  bbox_loss: 0.0303  cls_loss: 0.0888  
<<<iteration:[180/657] - total_loss: 0.4234  obj_loss: 0.1149  noobj_loss: 0.1836  bbox_loss: 0.0256  cls_loss: 0.0889  
<<<iteration:[200/657] - total_loss: 0.4307  obj_loss: 0.0988  noobj_loss: 0.1641  bbox_loss: 0.0285  cls_loss: 0.1075  
<<<iteration:[220/657] - total_los

<<<iteration:[120/657] - total_loss: 0.4285  obj_loss: 0.0982  noobj_loss: 0.1636  bbox_loss: 0.0307  cls_loss: 0.0952  
<<<iteration:[140/657] - total_loss: 0.4372  obj_loss: 0.1100  noobj_loss: 0.1800  bbox_loss: 0.0281  cls_loss: 0.0967  
<<<iteration:[160/657] - total_loss: 0.4010  obj_loss: 0.0815  noobj_loss: 0.1642  bbox_loss: 0.0291  cls_loss: 0.0920  
<<<iteration:[180/657] - total_loss: 0.4163  obj_loss: 0.1015  noobj_loss: 0.1655  bbox_loss: 0.0261  cls_loss: 0.1018  
<<<iteration:[200/657] - total_loss: 0.4044  obj_loss: 0.1101  noobj_loss: 0.1659  bbox_loss: 0.0246  cls_loss: 0.0884  
<<<iteration:[220/657] - total_loss: 0.4129  obj_loss: 0.1050  noobj_loss: 0.1678  bbox_loss: 0.0263  cls_loss: 0.0923  
<<<iteration:[240/657] - total_loss: 0.4291  obj_loss: 0.1057  noobj_loss: 0.1696  bbox_loss: 0.0278  cls_loss: 0.0997  
<<<iteration:[260/657] - total_loss: 0.4355  obj_loss: 0.0969  noobj_loss: 0.1665  bbox_loss: 0.0303  cls_loss: 0.1038  
<<<iteration:[280/657] - total_l

<<<iteration:[180/657] - total_loss: 0.3982  obj_loss: 0.1041  noobj_loss: 0.1646  bbox_loss: 0.0255  cls_loss: 0.0844  
<<<iteration:[200/657] - total_loss: 0.4257  obj_loss: 0.1032  noobj_loss: 0.1690  bbox_loss: 0.0298  cls_loss: 0.0891  
<<<iteration:[220/657] - total_loss: 0.3809  obj_loss: 0.0956  noobj_loss: 0.1626  bbox_loss: 0.0238  cls_loss: 0.0852  
<<<iteration:[240/657] - total_loss: 0.3799  obj_loss: 0.0985  noobj_loss: 0.1644  bbox_loss: 0.0261  cls_loss: 0.0686  
<<<iteration:[260/657] - total_loss: 0.4285  obj_loss: 0.1034  noobj_loss: 0.1727  bbox_loss: 0.0283  cls_loss: 0.0973  
<<<iteration:[280/657] - total_loss: 0.3904  obj_loss: 0.0961  noobj_loss: 0.1595  bbox_loss: 0.0274  cls_loss: 0.0777  
<<<iteration:[300/657] - total_loss: 0.4353  obj_loss: 0.0963  noobj_loss: 0.1615  bbox_loss: 0.0277  cls_loss: 0.1196  
<<<iteration:[320/657] - total_loss: 0.3833  obj_loss: 0.0947  noobj_loss: 0.1594  bbox_loss: 0.0270  cls_loss: 0.0741  
<<<iteration:[340/657] - total_l

<<<iteration:[240/657] - total_loss: 0.4031  obj_loss: 0.1058  noobj_loss: 0.1592  bbox_loss: 0.0275  cls_loss: 0.0801  
<<<iteration:[260/657] - total_loss: 0.3917  obj_loss: 0.1048  noobj_loss: 0.1537  bbox_loss: 0.0252  cls_loss: 0.0839  
<<<iteration:[280/657] - total_loss: 0.4130  obj_loss: 0.1020  noobj_loss: 0.1649  bbox_loss: 0.0278  cls_loss: 0.0895  
<<<iteration:[300/657] - total_loss: 0.3942  obj_loss: 0.0974  noobj_loss: 0.1577  bbox_loss: 0.0273  cls_loss: 0.0815  
<<<iteration:[320/657] - total_loss: 0.3760  obj_loss: 0.0981  noobj_loss: 0.1572  bbox_loss: 0.0251  cls_loss: 0.0737  
<<<iteration:[340/657] - total_loss: 0.4005  obj_loss: 0.1030  noobj_loss: 0.1538  bbox_loss: 0.0260  cls_loss: 0.0905  
<<<iteration:[360/657] - total_loss: 0.3841  obj_loss: 0.1003  noobj_loss: 0.1514  bbox_loss: 0.0249  cls_loss: 0.0838  
<<<iteration:[380/657] - total_loss: 0.4114  obj_loss: 0.1099  noobj_loss: 0.1648  bbox_loss: 0.0275  cls_loss: 0.0816  
<<<iteration:[400/657] - total_l

<<<iteration:[300/657] - total_loss: 0.3917  obj_loss: 0.1146  noobj_loss: 0.1539  bbox_loss: 0.0232  cls_loss: 0.0840  
<<<iteration:[320/657] - total_loss: 0.4301  obj_loss: 0.1105  noobj_loss: 0.1631  bbox_loss: 0.0301  cls_loss: 0.0876  
<<<iteration:[340/657] - total_loss: 0.3929  obj_loss: 0.0957  noobj_loss: 0.1561  bbox_loss: 0.0263  cls_loss: 0.0875  
<<<iteration:[360/657] - total_loss: 0.3947  obj_loss: 0.1050  noobj_loss: 0.1543  bbox_loss: 0.0251  cls_loss: 0.0869  
<<<iteration:[380/657] - total_loss: 0.3971  obj_loss: 0.1037  noobj_loss: 0.1501  bbox_loss: 0.0269  cls_loss: 0.0839  
<<<iteration:[400/657] - total_loss: 0.4298  obj_loss: 0.1187  noobj_loss: 0.1538  bbox_loss: 0.0280  cls_loss: 0.0939  
<<<iteration:[420/657] - total_loss: 0.3992  obj_loss: 0.0999  noobj_loss: 0.1520  bbox_loss: 0.0265  cls_loss: 0.0907  
<<<iteration:[440/657] - total_loss: 0.3761  obj_loss: 0.1043  noobj_loss: 0.1509  bbox_loss: 0.0251  cls_loss: 0.0706  
<<<iteration:[460/657] - total_l

<<<iteration:[360/657] - total_loss: 0.3777  obj_loss: 0.1077  noobj_loss: 0.1402  bbox_loss: 0.0240  cls_loss: 0.0801  
<<<iteration:[380/657] - total_loss: 0.3916  obj_loss: 0.1124  noobj_loss: 0.1533  bbox_loss: 0.0233  cls_loss: 0.0862  
<<<iteration:[400/657] - total_loss: 0.3810  obj_loss: 0.1038  noobj_loss: 0.1567  bbox_loss: 0.0240  cls_loss: 0.0787  
<<<iteration:[420/657] - total_loss: 0.3802  obj_loss: 0.1081  noobj_loss: 0.1537  bbox_loss: 0.0234  cls_loss: 0.0781  
<<<iteration:[440/657] - total_loss: 0.4027  obj_loss: 0.1100  noobj_loss: 0.1561  bbox_loss: 0.0263  cls_loss: 0.0832  
<<<iteration:[460/657] - total_loss: 0.4078  obj_loss: 0.1176  noobj_loss: 0.1538  bbox_loss: 0.0234  cls_loss: 0.0962  
<<<iteration:[480/657] - total_loss: 0.4026  obj_loss: 0.1151  noobj_loss: 0.1521  bbox_loss: 0.0242  cls_loss: 0.0903  
<<<iteration:[500/657] - total_loss: 0.3930  obj_loss: 0.1045  noobj_loss: 0.1521  bbox_loss: 0.0264  cls_loss: 0.0806  
<<<iteration:[520/657] - total_l

<<<iteration:[420/657] - total_loss: 0.4029  obj_loss: 0.1170  noobj_loss: 0.1470  bbox_loss: 0.0241  cls_loss: 0.0919  
<<<iteration:[440/657] - total_loss: 0.4031  obj_loss: 0.1118  noobj_loss: 0.1531  bbox_loss: 0.0264  cls_loss: 0.0826  
<<<iteration:[460/657] - total_loss: 0.3815  obj_loss: 0.1092  noobj_loss: 0.1468  bbox_loss: 0.0250  cls_loss: 0.0740  
<<<iteration:[480/657] - total_loss: 0.3967  obj_loss: 0.1137  noobj_loss: 0.1537  bbox_loss: 0.0239  cls_loss: 0.0868  
<<<iteration:[500/657] - total_loss: 0.3862  obj_loss: 0.0981  noobj_loss: 0.1419  bbox_loss: 0.0243  cls_loss: 0.0955  
<<<iteration:[520/657] - total_loss: 0.3979  obj_loss: 0.1114  noobj_loss: 0.1529  bbox_loss: 0.0253  cls_loss: 0.0837  
<<<iteration:[540/657] - total_loss: 0.3710  obj_loss: 0.1035  noobj_loss: 0.1464  bbox_loss: 0.0234  cls_loss: 0.0772  
<<<iteration:[560/657] - total_loss: 0.3956  obj_loss: 0.1140  noobj_loss: 0.1583  bbox_loss: 0.0235  cls_loss: 0.0850  
<<<iteration:[580/657] - total_l

<<<iteration:[480/657] - total_loss: 0.3751  obj_loss: 0.1113  noobj_loss: 0.1452  bbox_loss: 0.0213  cls_loss: 0.0845  
<<<iteration:[500/657] - total_loss: 0.3721  obj_loss: 0.0932  noobj_loss: 0.1474  bbox_loss: 0.0247  cls_loss: 0.0815  
<<<iteration:[520/657] - total_loss: 0.3729  obj_loss: 0.1141  noobj_loss: 0.1470  bbox_loss: 0.0227  cls_loss: 0.0716  
<<<iteration:[540/657] - total_loss: 0.3908  obj_loss: 0.1016  noobj_loss: 0.1488  bbox_loss: 0.0256  cls_loss: 0.0870  
<<<iteration:[560/657] - total_loss: 0.3869  obj_loss: 0.1047  noobj_loss: 0.1428  bbox_loss: 0.0248  cls_loss: 0.0867  
<<<iteration:[580/657] - total_loss: 0.3814  obj_loss: 0.1109  noobj_loss: 0.1433  bbox_loss: 0.0244  cls_loss: 0.0771  
<<<iteration:[600/657] - total_loss: 0.3951  obj_loss: 0.1117  noobj_loss: 0.1514  bbox_loss: 0.0247  cls_loss: 0.0843  
<<<iteration:[620/657] - total_loss: 0.3935  obj_loss: 0.1032  noobj_loss: 0.1475  bbox_loss: 0.0236  cls_loss: 0.0986  
<<<iteration:[640/657] - total_l

<<<iteration:[540/657] - total_loss: 0.7786  obj_loss: 0.0728  noobj_loss: 0.1379  bbox_loss: 0.1103  cls_loss: 0.0852  
<<<iteration:[560/657] - total_loss: 0.3984  obj_loss: 0.1093  noobj_loss: 0.1454  bbox_loss: 0.0269  cls_loss: 0.0817  
<<<iteration:[580/657] - total_loss: 0.3897  obj_loss: 0.1005  noobj_loss: 0.1411  bbox_loss: 0.0298  cls_loss: 0.0696  
<<<iteration:[600/657] - total_loss: 0.3758  obj_loss: 0.1016  noobj_loss: 0.1404  bbox_loss: 0.0231  cls_loss: 0.0885  
<<<iteration:[620/657] - total_loss: 0.4065  obj_loss: 0.0989  noobj_loss: 0.1390  bbox_loss: 0.0309  cls_loss: 0.0835  
<<<iteration:[640/657] - total_loss: 0.3771  obj_loss: 0.0996  noobj_loss: 0.1426  bbox_loss: 0.0254  cls_loss: 0.0792  

epoch:63/100 - Train Loss: 0.3969, Val Loss: 0.4467

<<<iteration:[20/657] - total_loss: 0.4034  obj_loss: 0.1131  noobj_loss: 0.1525  bbox_loss: 0.0266  cls_loss: 0.0812  
<<<iteration:[40/657] - total_loss: 0.3571  obj_loss: 0.1020  noobj_loss: 0.1463  bbox_loss: 0.0226 

<<<iteration:[600/657] - total_loss: 0.3535  obj_loss: 0.1143  noobj_loss: 0.1391  bbox_loss: 0.0215  cls_loss: 0.0623  
<<<iteration:[620/657] - total_loss: 0.3892  obj_loss: 0.1145  noobj_loss: 0.1380  bbox_loss: 0.0250  cls_loss: 0.0805  
<<<iteration:[640/657] - total_loss: 0.3814  obj_loss: 0.1152  noobj_loss: 0.1458  bbox_loss: 0.0224  cls_loss: 0.0813  

epoch:65/100 - Train Loss: 0.3802, Val Loss: 0.4113

<<<iteration:[20/657] - total_loss: 0.4214  obj_loss: 0.1253  noobj_loss: 0.1484  bbox_loss: 0.0240  cls_loss: 0.1019  
<<<iteration:[40/657] - total_loss: 0.3533  obj_loss: 0.1004  noobj_loss: 0.1453  bbox_loss: 0.0199  cls_loss: 0.0807  
<<<iteration:[60/657] - total_loss: 0.3767  obj_loss: 0.1062  noobj_loss: 0.1461  bbox_loss: 0.0232  cls_loss: 0.0813  
<<<iteration:[80/657] - total_loss: 0.3795  obj_loss: 0.1076  noobj_loss: 0.1449  bbox_loss: 0.0245  cls_loss: 0.0772  
<<<iteration:[100/657] - total_loss: 0.3861  obj_loss: 0.1102  noobj_loss: 0.1400  bbox_loss: 0.0253  c


epoch:67/100 - Train Loss: 0.3752, Val Loss: 0.5118

<<<iteration:[20/657] - total_loss: 0.4334  obj_loss: 0.1119  noobj_loss: 0.1437  bbox_loss: 0.0337  cls_loss: 0.0811  
<<<iteration:[40/657] - total_loss: 0.4055  obj_loss: 0.1048  noobj_loss: 0.1420  bbox_loss: 0.0308  cls_loss: 0.0755  
<<<iteration:[60/657] - total_loss: 0.3929  obj_loss: 0.1130  noobj_loss: 0.1407  bbox_loss: 0.0241  cls_loss: 0.0891  
<<<iteration:[80/657] - total_loss: 0.3842  obj_loss: 0.1063  noobj_loss: 0.1499  bbox_loss: 0.0266  cls_loss: 0.0701  
<<<iteration:[100/657] - total_loss: 0.3643  obj_loss: 0.1188  noobj_loss: 0.1379  bbox_loss: 0.0201  cls_loss: 0.0759  
<<<iteration:[120/657] - total_loss: 0.3819  obj_loss: 0.1106  noobj_loss: 0.1426  bbox_loss: 0.0235  cls_loss: 0.0826  
<<<iteration:[140/657] - total_loss: 0.3791  obj_loss: 0.1128  noobj_loss: 0.1415  bbox_loss: 0.0217  cls_loss: 0.0870  
<<<iteration:[160/657] - total_loss: 0.3535  obj_loss: 0.1029  noobj_loss: 0.1453  bbox_loss: 0.0217  c

<<<iteration:[80/657] - total_loss: 0.3990  obj_loss: 0.0992  noobj_loss: 0.1476  bbox_loss: 0.0273  cls_loss: 0.0895  
<<<iteration:[100/657] - total_loss: 0.3730  obj_loss: 0.1192  noobj_loss: 0.1366  bbox_loss: 0.0221  cls_loss: 0.0752  
<<<iteration:[120/657] - total_loss: 0.3506  obj_loss: 0.0993  noobj_loss: 0.1329  bbox_loss: 0.0220  cls_loss: 0.0748  
<<<iteration:[140/657] - total_loss: 0.3914  obj_loss: 0.1202  noobj_loss: 0.1318  bbox_loss: 0.0231  cls_loss: 0.0900  
<<<iteration:[160/657] - total_loss: 0.3655  obj_loss: 0.1158  noobj_loss: 0.1367  bbox_loss: 0.0213  cls_loss: 0.0748  
<<<iteration:[180/657] - total_loss: 0.3524  obj_loss: 0.1069  noobj_loss: 0.1418  bbox_loss: 0.0207  cls_loss: 0.0713  
<<<iteration:[200/657] - total_loss: 0.3660  obj_loss: 0.1066  noobj_loss: 0.1374  bbox_loss: 0.0219  cls_loss: 0.0813  
<<<iteration:[220/657] - total_loss: 0.3685  obj_loss: 0.1093  noobj_loss: 0.1378  bbox_loss: 0.0201  cls_loss: 0.0897  
<<<iteration:[240/657] - total_lo

<<<iteration:[140/657] - total_loss: 0.3682  obj_loss: 0.1120  noobj_loss: 0.1347  bbox_loss: 0.0232  cls_loss: 0.0730  
<<<iteration:[160/657] - total_loss: 0.3941  obj_loss: 0.1098  noobj_loss: 0.1393  bbox_loss: 0.0253  cls_loss: 0.0884  
<<<iteration:[180/657] - total_loss: 0.3796  obj_loss: 0.1101  noobj_loss: 0.1381  bbox_loss: 0.0229  cls_loss: 0.0859  
<<<iteration:[200/657] - total_loss: 0.3907  obj_loss: 0.1096  noobj_loss: 0.1361  bbox_loss: 0.0242  cls_loss: 0.0922  
<<<iteration:[220/657] - total_loss: 0.3441  obj_loss: 0.1055  noobj_loss: 0.1424  bbox_loss: 0.0207  cls_loss: 0.0641  
<<<iteration:[240/657] - total_loss: 0.3811  obj_loss: 0.1130  noobj_loss: 0.1314  bbox_loss: 0.0241  cls_loss: 0.0821  
<<<iteration:[260/657] - total_loss: 0.3901  obj_loss: 0.1162  noobj_loss: 0.1350  bbox_loss: 0.0240  cls_loss: 0.0864  
<<<iteration:[280/657] - total_loss: 0.3653  obj_loss: 0.1058  noobj_loss: 0.1374  bbox_loss: 0.0249  cls_loss: 0.0661  
<<<iteration:[300/657] - total_l

<<<iteration:[200/657] - total_loss: 0.3665  obj_loss: 0.1034  noobj_loss: 0.1353  bbox_loss: 0.0240  cls_loss: 0.0755  
<<<iteration:[220/657] - total_loss: 0.4051  obj_loss: 0.1195  noobj_loss: 0.1502  bbox_loss: 0.0282  cls_loss: 0.0694  
<<<iteration:[240/657] - total_loss: 0.4195  obj_loss: 0.0978  noobj_loss: 0.1356  bbox_loss: 0.0356  cls_loss: 0.0760  
<<<iteration:[260/657] - total_loss: 0.3726  obj_loss: 0.1022  noobj_loss: 0.1323  bbox_loss: 0.0284  cls_loss: 0.0622  
<<<iteration:[280/657] - total_loss: 0.3811  obj_loss: 0.1096  noobj_loss: 0.1326  bbox_loss: 0.0254  cls_loss: 0.0780  
<<<iteration:[300/657] - total_loss: 0.3674  obj_loss: 0.1016  noobj_loss: 0.1389  bbox_loss: 0.0271  cls_loss: 0.0607  
<<<iteration:[320/657] - total_loss: 0.3894  obj_loss: 0.1079  noobj_loss: 0.1347  bbox_loss: 0.0291  cls_loss: 0.0688  
<<<iteration:[340/657] - total_loss: 0.3767  obj_loss: 0.1126  noobj_loss: 0.1318  bbox_loss: 0.0263  cls_loss: 0.0666  
<<<iteration:[360/657] - total_l

<<<iteration:[260/657] - total_loss: 0.3610  obj_loss: 0.1136  noobj_loss: 0.1302  bbox_loss: 0.0209  cls_loss: 0.0777  
<<<iteration:[280/657] - total_loss: 0.3762  obj_loss: 0.1076  noobj_loss: 0.1404  bbox_loss: 0.0223  cls_loss: 0.0868  
<<<iteration:[300/657] - total_loss: 0.3482  obj_loss: 0.1104  noobj_loss: 0.1318  bbox_loss: 0.0208  cls_loss: 0.0678  
<<<iteration:[320/657] - total_loss: 0.3378  obj_loss: 0.1084  noobj_loss: 0.1341  bbox_loss: 0.0200  cls_loss: 0.0625  
<<<iteration:[340/657] - total_loss: 0.3852  obj_loss: 0.1100  noobj_loss: 0.1341  bbox_loss: 0.0253  cls_loss: 0.0816  
<<<iteration:[360/657] - total_loss: 0.3621  obj_loss: 0.1182  noobj_loss: 0.1367  bbox_loss: 0.0225  cls_loss: 0.0632  
<<<iteration:[380/657] - total_loss: 0.3778  obj_loss: 0.1301  noobj_loss: 0.1392  bbox_loss: 0.0223  cls_loss: 0.0664  
<<<iteration:[400/657] - total_loss: 0.3407  obj_loss: 0.1198  noobj_loss: 0.1343  bbox_loss: 0.0183  cls_loss: 0.0624  
<<<iteration:[420/657] - total_l

<<<iteration:[320/657] - total_loss: 0.3555  obj_loss: 0.1128  noobj_loss: 0.1324  bbox_loss: 0.0213  cls_loss: 0.0701  
<<<iteration:[340/657] - total_loss: 0.3502  obj_loss: 0.1140  noobj_loss: 0.1372  bbox_loss: 0.0212  cls_loss: 0.0617  
<<<iteration:[360/657] - total_loss: 0.3495  obj_loss: 0.1021  noobj_loss: 0.1371  bbox_loss: 0.0214  cls_loss: 0.0719  
<<<iteration:[380/657] - total_loss: 0.3999  obj_loss: 0.1049  noobj_loss: 0.1260  bbox_loss: 0.0310  cls_loss: 0.0773  
<<<iteration:[400/657] - total_loss: 0.3729  obj_loss: 0.1069  noobj_loss: 0.1360  bbox_loss: 0.0263  cls_loss: 0.0666  
<<<iteration:[420/657] - total_loss: 0.3748  obj_loss: 0.1008  noobj_loss: 0.1294  bbox_loss: 0.0281  cls_loss: 0.0689  
<<<iteration:[440/657] - total_loss: 0.3809  obj_loss: 0.1169  noobj_loss: 0.1318  bbox_loss: 0.0224  cls_loss: 0.0861  
<<<iteration:[460/657] - total_loss: 0.3642  obj_loss: 0.1025  noobj_loss: 0.1337  bbox_loss: 0.0235  cls_loss: 0.0774  
<<<iteration:[480/657] - total_l

<<<iteration:[380/657] - total_loss: 0.3586  obj_loss: 0.1127  noobj_loss: 0.1314  bbox_loss: 0.0232  cls_loss: 0.0641  
<<<iteration:[400/657] - total_loss: 0.3645  obj_loss: 0.1253  noobj_loss: 0.1371  bbox_loss: 0.0212  cls_loss: 0.0648  
<<<iteration:[420/657] - total_loss: 0.3652  obj_loss: 0.1164  noobj_loss: 0.1297  bbox_loss: 0.0197  cls_loss: 0.0854  
<<<iteration:[440/657] - total_loss: 0.3438  obj_loss: 0.1091  noobj_loss: 0.1346  bbox_loss: 0.0219  cls_loss: 0.0579  
<<<iteration:[460/657] - total_loss: 0.3725  obj_loss: 0.1167  noobj_loss: 0.1336  bbox_loss: 0.0215  cls_loss: 0.0814  
<<<iteration:[480/657] - total_loss: 0.3771  obj_loss: 0.1162  noobj_loss: 0.1359  bbox_loss: 0.0236  cls_loss: 0.0751  
<<<iteration:[500/657] - total_loss: 0.3840  obj_loss: 0.1010  noobj_loss: 0.1276  bbox_loss: 0.0316  cls_loss: 0.0614  
<<<iteration:[520/657] - total_loss: 0.3690  obj_loss: 0.1185  noobj_loss: 0.1292  bbox_loss: 0.0242  cls_loss: 0.0648  
<<<iteration:[540/657] - total_l

<<<iteration:[440/657] - total_loss: 0.3545  obj_loss: 0.1072  noobj_loss: 0.1348  bbox_loss: 0.0242  cls_loss: 0.0589  
<<<iteration:[460/657] - total_loss: 0.3510  obj_loss: 0.1129  noobj_loss: 0.1340  bbox_loss: 0.0208  cls_loss: 0.0668  
<<<iteration:[480/657] - total_loss: 0.3635  obj_loss: 0.1125  noobj_loss: 0.1287  bbox_loss: 0.0218  cls_loss: 0.0776  
<<<iteration:[500/657] - total_loss: 0.3363  obj_loss: 0.1003  noobj_loss: 0.1280  bbox_loss: 0.0208  cls_loss: 0.0681  
<<<iteration:[520/657] - total_loss: 0.3379  obj_loss: 0.1273  noobj_loss: 0.1304  bbox_loss: 0.0168  cls_loss: 0.0616  
<<<iteration:[540/657] - total_loss: 0.3553  obj_loss: 0.1227  noobj_loss: 0.1314  bbox_loss: 0.0205  cls_loss: 0.0646  
<<<iteration:[560/657] - total_loss: 0.3520  obj_loss: 0.1085  noobj_loss: 0.1294  bbox_loss: 0.0227  cls_loss: 0.0653  
<<<iteration:[580/657] - total_loss: 0.3620  obj_loss: 0.1042  noobj_loss: 0.1295  bbox_loss: 0.0248  cls_loss: 0.0691  
<<<iteration:[600/657] - total_l

<<<iteration:[500/657] - total_loss: 0.3414  obj_loss: 0.1111  noobj_loss: 0.1285  bbox_loss: 0.0201  cls_loss: 0.0654  
<<<iteration:[520/657] - total_loss: 0.3384  obj_loss: 0.1219  noobj_loss: 0.1230  bbox_loss: 0.0186  cls_loss: 0.0617  
<<<iteration:[540/657] - total_loss: 0.3551  obj_loss: 0.1232  noobj_loss: 0.1324  bbox_loss: 0.0206  cls_loss: 0.0630  
<<<iteration:[560/657] - total_loss: 0.3442  obj_loss: 0.1099  noobj_loss: 0.1323  bbox_loss: 0.0219  cls_loss: 0.0586  
<<<iteration:[580/657] - total_loss: 0.3633  obj_loss: 0.1188  noobj_loss: 0.1346  bbox_loss: 0.0204  cls_loss: 0.0751  
<<<iteration:[600/657] - total_loss: 0.3784  obj_loss: 0.1101  noobj_loss: 0.1282  bbox_loss: 0.0280  cls_loss: 0.0641  
<<<iteration:[620/657] - total_loss: 0.3364  obj_loss: 0.1045  noobj_loss: 0.1252  bbox_loss: 0.0206  cls_loss: 0.0665  
<<<iteration:[640/657] - total_loss: 0.3482  obj_loss: 0.1106  noobj_loss: 0.1295  bbox_loss: 0.0204  cls_loss: 0.0711  

epoch:84/100 - Train Loss: 0.35

<<<iteration:[560/657] - total_loss: 0.3567  obj_loss: 0.1165  noobj_loss: 0.1297  bbox_loss: 0.0209  cls_loss: 0.0710  
<<<iteration:[580/657] - total_loss: 0.3582  obj_loss: 0.1125  noobj_loss: 0.1274  bbox_loss: 0.0202  cls_loss: 0.0808  
<<<iteration:[600/657] - total_loss: 0.3270  obj_loss: 0.1111  noobj_loss: 0.1264  bbox_loss: 0.0194  cls_loss: 0.0558  
<<<iteration:[620/657] - total_loss: 0.3355  obj_loss: 0.1082  noobj_loss: 0.1304  bbox_loss: 0.0201  cls_loss: 0.0616  
<<<iteration:[640/657] - total_loss: 0.3235  obj_loss: 0.1060  noobj_loss: 0.1262  bbox_loss: 0.0197  cls_loss: 0.0557  

epoch:86/100 - Train Loss: 0.3507, Val Loss: 0.3982

<<<iteration:[20/657] - total_loss: 0.3742  obj_loss: 0.1331  noobj_loss: 0.1328  bbox_loss: 0.0198  cls_loss: 0.0759  
<<<iteration:[40/657] - total_loss: 0.3740  obj_loss: 0.1249  noobj_loss: 0.1257  bbox_loss: 0.0209  cls_loss: 0.0819  
<<<iteration:[60/657] - total_loss: 0.3437  obj_loss: 0.1119  noobj_loss: 0.1319  bbox_loss: 0.0198  

<<<iteration:[620/657] - total_loss: 0.3263  obj_loss: 0.1160  noobj_loss: 0.1290  bbox_loss: 0.0187  cls_loss: 0.0521  
<<<iteration:[640/657] - total_loss: 0.3627  obj_loss: 0.1067  noobj_loss: 0.1308  bbox_loss: 0.0236  cls_loss: 0.0724  

epoch:88/100 - Train Loss: 0.3495, Val Loss: 0.3840

<<<iteration:[20/657] - total_loss: 0.4032  obj_loss: 0.1171  noobj_loss: 0.1428  bbox_loss: 0.0270  cls_loss: 0.0796  
<<<iteration:[40/657] - total_loss: 0.3568  obj_loss: 0.1177  noobj_loss: 0.1237  bbox_loss: 0.0210  cls_loss: 0.0724  
<<<iteration:[60/657] - total_loss: 0.3590  obj_loss: 0.1317  noobj_loss: 0.1274  bbox_loss: 0.0187  cls_loss: 0.0703  
<<<iteration:[80/657] - total_loss: 0.3473  obj_loss: 0.1203  noobj_loss: 0.1222  bbox_loss: 0.0194  cls_loss: 0.0686  
<<<iteration:[100/657] - total_loss: 0.3334  obj_loss: 0.1189  noobj_loss: 0.1222  bbox_loss: 0.0181  cls_loss: 0.0631  
<<<iteration:[120/657] - total_loss: 0.3342  obj_loss: 0.1117  noobj_loss: 0.1233  bbox_loss: 0.0181  c

<<<iteration:[40/657] - total_loss: 0.3429  obj_loss: 0.1180  noobj_loss: 0.1293  bbox_loss: 0.0193  cls_loss: 0.0635  
<<<iteration:[60/657] - total_loss: 0.3493  obj_loss: 0.1208  noobj_loss: 0.1283  bbox_loss: 0.0196  cls_loss: 0.0666  
<<<iteration:[80/657] - total_loss: 0.3326  obj_loss: 0.1262  noobj_loss: 0.1253  bbox_loss: 0.0186  cls_loss: 0.0506  
<<<iteration:[100/657] - total_loss: 0.3368  obj_loss: 0.1164  noobj_loss: 0.1200  bbox_loss: 0.0181  cls_loss: 0.0701  
<<<iteration:[120/657] - total_loss: 0.3389  obj_loss: 0.1027  noobj_loss: 0.1238  bbox_loss: 0.0216  cls_loss: 0.0662  
<<<iteration:[140/657] - total_loss: 0.3546  obj_loss: 0.1215  noobj_loss: 0.1248  bbox_loss: 0.0192  cls_loss: 0.0745  
<<<iteration:[160/657] - total_loss: 0.3439  obj_loss: 0.1157  noobj_loss: 0.1225  bbox_loss: 0.0184  cls_loss: 0.0750  
<<<iteration:[180/657] - total_loss: 0.3250  obj_loss: 0.1274  noobj_loss: 0.1221  bbox_loss: 0.0163  cls_loss: 0.0551  
<<<iteration:[200/657] - total_loss

<<<iteration:[100/657] - total_loss: 0.3555  obj_loss: 0.1217  noobj_loss: 0.1292  bbox_loss: 0.0215  cls_loss: 0.0619  
<<<iteration:[120/657] - total_loss: 0.3488  obj_loss: 0.1186  noobj_loss: 0.1216  bbox_loss: 0.0184  cls_loss: 0.0774  
<<<iteration:[140/657] - total_loss: 0.3521  obj_loss: 0.1267  noobj_loss: 0.1247  bbox_loss: 0.0200  cls_loss: 0.0628  
<<<iteration:[160/657] - total_loss: 0.3303  obj_loss: 0.1070  noobj_loss: 0.1198  bbox_loss: 0.0193  cls_loss: 0.0669  
<<<iteration:[180/657] - total_loss: 0.3570  obj_loss: 0.1282  noobj_loss: 0.1301  bbox_loss: 0.0193  cls_loss: 0.0674  
<<<iteration:[200/657] - total_loss: 0.3340  obj_loss: 0.1105  noobj_loss: 0.1260  bbox_loss: 0.0199  cls_loss: 0.0609  
<<<iteration:[220/657] - total_loss: 0.3744  obj_loss: 0.1207  noobj_loss: 0.1237  bbox_loss: 0.0244  cls_loss: 0.0698  
<<<iteration:[240/657] - total_loss: 0.3469  obj_loss: 0.1180  noobj_loss: 0.1275  bbox_loss: 0.0193  cls_loss: 0.0688  
<<<iteration:[260/657] - total_l

<<<iteration:[160/657] - total_loss: 0.3390  obj_loss: 0.1118  noobj_loss: 0.1311  bbox_loss: 0.0200  cls_loss: 0.0615  
<<<iteration:[180/657] - total_loss: 0.3329  obj_loss: 0.1186  noobj_loss: 0.1279  bbox_loss: 0.0186  cls_loss: 0.0574  
<<<iteration:[200/657] - total_loss: 0.3244  obj_loss: 0.1157  noobj_loss: 0.1220  bbox_loss: 0.0184  cls_loss: 0.0555  
<<<iteration:[220/657] - total_loss: 0.3396  obj_loss: 0.1174  noobj_loss: 0.1225  bbox_loss: 0.0182  cls_loss: 0.0699  
<<<iteration:[240/657] - total_loss: 0.3305  obj_loss: 0.1173  noobj_loss: 0.1216  bbox_loss: 0.0181  cls_loss: 0.0616  
<<<iteration:[260/657] - total_loss: 0.3297  obj_loss: 0.1071  noobj_loss: 0.1266  bbox_loss: 0.0193  cls_loss: 0.0630  
<<<iteration:[280/657] - total_loss: 0.3436  obj_loss: 0.1357  noobj_loss: 0.1205  bbox_loss: 0.0190  cls_loss: 0.0523  
<<<iteration:[300/657] - total_loss: 0.3552  obj_loss: 0.1165  noobj_loss: 0.1225  bbox_loss: 0.0239  cls_loss: 0.0581  
<<<iteration:[320/657] - total_l

<<<iteration:[220/657] - total_loss: 0.3466  obj_loss: 0.1227  noobj_loss: 0.1244  bbox_loss: 0.0196  cls_loss: 0.0637  
<<<iteration:[240/657] - total_loss: 0.3298  obj_loss: 0.1178  noobj_loss: 0.1216  bbox_loss: 0.0174  cls_loss: 0.0641  
<<<iteration:[260/657] - total_loss: 0.3361  obj_loss: 0.1242  noobj_loss: 0.1248  bbox_loss: 0.0187  cls_loss: 0.0561  
<<<iteration:[280/657] - total_loss: 0.3274  obj_loss: 0.1135  noobj_loss: 0.1269  bbox_loss: 0.0188  cls_loss: 0.0563  
<<<iteration:[300/657] - total_loss: 0.3381  obj_loss: 0.1217  noobj_loss: 0.1235  bbox_loss: 0.0187  cls_loss: 0.0608  
<<<iteration:[320/657] - total_loss: 0.3585  obj_loss: 0.1111  noobj_loss: 0.1213  bbox_loss: 0.0222  cls_loss: 0.0756  
<<<iteration:[340/657] - total_loss: 0.3134  obj_loss: 0.1185  noobj_loss: 0.1239  bbox_loss: 0.0167  cls_loss: 0.0492  
<<<iteration:[360/657] - total_loss: 0.3312  obj_loss: 0.1068  noobj_loss: 0.1254  bbox_loss: 0.0195  cls_loss: 0.0643  
<<<iteration:[380/657] - total_l

<<<iteration:[280/657] - total_loss: 0.3318  obj_loss: 0.1097  noobj_loss: 0.1164  bbox_loss: 0.0200  cls_loss: 0.0641  
<<<iteration:[300/657] - total_loss: 0.3366  obj_loss: 0.1276  noobj_loss: 0.1262  bbox_loss: 0.0191  cls_loss: 0.0506  
<<<iteration:[320/657] - total_loss: 0.3459  obj_loss: 0.1207  noobj_loss: 0.1224  bbox_loss: 0.0198  cls_loss: 0.0651  
<<<iteration:[340/657] - total_loss: 0.3276  obj_loss: 0.1140  noobj_loss: 0.1228  bbox_loss: 0.0191  cls_loss: 0.0569  
<<<iteration:[360/657] - total_loss: 0.3323  obj_loss: 0.1298  noobj_loss: 0.1205  bbox_loss: 0.0174  cls_loss: 0.0555  
<<<iteration:[380/657] - total_loss: 0.3302  obj_loss: 0.1096  noobj_loss: 0.1255  bbox_loss: 0.0198  cls_loss: 0.0590  
<<<iteration:[400/657] - total_loss: 0.3345  obj_loss: 0.1144  noobj_loss: 0.1187  bbox_loss: 0.0213  cls_loss: 0.0544  
<<<iteration:[420/657] - total_loss: 0.3473  obj_loss: 0.1200  noobj_loss: 0.1255  bbox_loss: 0.0190  cls_loss: 0.0694  
<<<iteration:[440/657] - total_l

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Train Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train bbox Loss,█▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train class Loss,█▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train obj Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▄▂▂▅▃▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val bbox Loss,█▄▂▂█▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val class Loss,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val obj Loss,▂▁▂▃▂▄▃▃▄▅▃▄▆▇▆▅▆▆▅▇▆▆▆▇▇▇▆▆█▆▇█▇▇▇██▇██

0,1
Train Loss,0.3376
Train bbox Loss,0.01932
Train class Loss,0.06206
Train obj Loss,0.1175
Val Loss,0.38656
Val bbox Loss,0.02824
Val class Loss,0.06158
Val obj Loss,0.12768


## Test Dataset Inference

In [None]:
import numpy as np
import os 
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torch import nn
import torchsummary
from torch.utils.data import DataLoader
from collections import defaultdict
from torchvision.utils import make_grid

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def load_model(ckpt_path, num_classes, device):
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = YOLO_SWIN(num_classes=num_classes)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    return model

In [None]:
IMAGE_SIZE=448
transformer = A.Compose([
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']),
)

In [None]:
# ckpt_path="./trained_model/YOLO_SWIN_T_body_LR0.0001_AUG30/model_90.pth"
ckpt_path="/workspace/Plastic_Bottle_defect_detection/trained_model/YOLO_SWIN_T_neck_LR0.0001_Image_Patch50/model_100.pth"
model = load_model(ckpt_path, NUM_CLASSES, device)

In [None]:
NECK_PATH = '/home/host_data/PET_data/Neck'
BODY_PATH = '/home/host_data/PET_data/Body'
test_dataset=PET_dataset("neck" ,neck_dir=NECK_PATH,body_dir=BODY_PATH,phase='test', transformer=transformer, aug=None)
test_dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [None]:
len(test_dataset)

In [None]:
@torch.no_grad()
def model_predict(image, model, conf_thres=0.2, iou_threshold=0.1):
    predictions = model(image)
    prediction = predictions.detach().cpu().squeeze(dim=0)
    f_map=prediction

#     print(prediction.shape)
    
    grid_size = prediction.shape[-1]
    y_grid, x_grid = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size))
    stride_size = IMAGE_SIZE/grid_size

    conf = prediction[[0,5], ...].reshape(1, -1)
    xc = (prediction[[1,6], ...] * IMAGE_SIZE + x_grid*stride_size).reshape(1,-1)
    yc = (prediction[[2,7], ...] * IMAGE_SIZE + y_grid*stride_size).reshape(1,-1)
    w = (prediction[[3,8], ...] * IMAGE_SIZE).reshape(1,-1)
    h = (prediction[[4,9], ...] * IMAGE_SIZE).reshape(1,-1)
    cls = torch.max(prediction[10:, ...].reshape(NUM_CLASSES, -1), dim=0).indices.tile(1,2)
    
    x_min = xc - w/2
    y_min = yc - h/2
    x_max = xc + w/2
    y_max = yc + h/2

    prediction_res = torch.cat([x_min, y_min, x_max, y_max, conf, cls], dim=0)
    prediction_res = prediction_res.transpose(0,1)

    # x_min과 y_min이 음수가 되지않고, x_max와 y_max가 이미지 크기를 넘지 않게 제한
    prediction_res[:, 2].clip(min=0, max=image.shape[1]) 
    prediction_res[:, 3].clip(min=0, max=image.shape[0])
        
    pred_res = prediction_res[prediction_res[:, 4] > conf_thres]
    nms_index = torchvision.ops.nms(boxes=pred_res[:, 0:4], scores=pred_res[:, 4], iou_threshold=iou_threshold)
    pred_res_ = pred_res[nms_index].numpy()
    
    n_obj = pred_res_.shape[0]
    bboxes = np.zeros(shape=(n_obj, 4), dtype=np.float32)
    bboxes[:, 0:2] = (pred_res_[:, 0:2] + pred_res_[:, 2:4]) / 2
    bboxes[:, 2:4] = pred_res_[:, 2:4] - pred_res_[:, 0:2]
    scores = pred_res_[:, 4]
    class_ids = pred_res_[:, 5]
    
    # 이미지 값이 들어가면 모델을 통해서, 후처리까지 포함된 yolo 포멧의 box좌표, 그 좌표에 대한 confidence score
    # 그리고 class id를 반환
    return bboxes, scores, class_ids,f_map

In [None]:
pred_images = []
pred_labels =[]
feature_maps=[]

for index, batch in enumerate(test_dataloaders):
    images = batch[0].to(device)
    bboxes, scores, class_ids, fmap = model_predict(images, model, conf_thres=0.1, iou_threshold=0.1)
    
    if len(bboxes) > 0:
        prediction_yolo = np.concatenate([bboxes, scores[:, np.newaxis], class_ids[:, np.newaxis]], axis=1)
    else:
        prediction_yolo = np.array([])
    
    # 텐서형의 이미지를 다시 unnormalize를 시키고, 다시 chw를 hwc로 바꾸고 넘파이로 바꾼다.
    np_image = make_grid(images[0], normalize=True).cpu().permute(1,2,0).numpy()
    pred_images.append(np_image)
    pred_labels.append(prediction_yolo)
    feature_maps.append(fmap)

    

In [None]:
from ipywidgets import interact

@interact(index=(0,len(pred_images)-1))
def show_result(index=0):
    print(pred_labels[index])
    if len(pred_labels[index]) > 0:
        result = visualize(pred_images[index], pred_labels[index][:, 0:4], pred_labels[index][:, 5])
    else:
        result = pred_images[index]
        
    plt.figure(figsize=(6,6))
    plt.imshow(result)
    plt.show()

In [None]:
#feature map에서 0,5번쨰에 해당하는 objectness 투사
from ipywidgets import interact

@interact(index=(0,len(pred_images)-1))
def show_result(index=0):
    print(pred_labels[index])
    if len(pred_labels[index]) > 0:
        result = visualize(pred_images[index], pred_labels[index][:, 0:4], pred_labels[index][:, 5])
    else:
        result = pred_images[index]
    
    f_map=feature_maps[index]
    zero_canvas=np.zeros((448,448))

    cv_re1=cv2.resize(f_map[0,:,:].numpy(),(448,448))
    cv_re2=cv2.resize(f_map[5,:,:].numpy(),(448,448))
    zero_canvas=zero_canvas+cv_re1+cv_re2

    
    fig = plt.figure()
    rows = 1
    cols = 2
    ax1 = fig.add_subplot(rows, cols, 1)
    ax1.imshow(result)
    ax1.set_title('Detection')
    ax1.axis("off")
    
    ax2 = fig.add_subplot(rows, cols, 2)
    ax2.imshow(zero_canvas)
    ax2.set_title('feature map-objectness')
    ax2.axis("off")

    plt.show()
