In [1]:
from __future__ import print_function
import os
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import argparse
import torch.utils.data as data
# from data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
# from layers.modules import MultiBoxLoss
# from layers.functions.prior_box import PriorBox
import time
import datetime
import math

import cv2
import numpy as np
import random
# from models.retinaface import RetinaFace

In [2]:
# Release cached, currently unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [3]:
import os

# PYTORCH_CUDA_ALLOC_CONF is a comma-separated list of `option:value` pairs
# (colon, not `=`). `min_alloc_size_mb` is not among the documented allocator
# options, so only the split-size cap is kept here.
# NOTE(review): the variable is read when the CUDA caching allocator is first
# used, so this cell has to run before any tensor is placed on the GPU.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:500'

In [4]:
# Configuration for the 'mobilenet0.25' backbone.
cfg_mnet = dict(
    name='mobilenet0.25',
    # PriorBox: anchor sizes (pixels) for each feature level.
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    # PriorBox: stride of each feature level relative to the input image.
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    # PriorBox: clamp generated anchors to [0, 1] when True.
    clip=False,
    loc_weight=2.0,
    gpu_train=True,
    batch_size=32,
    ngpu=1,
    epoch=250,
    decay1=190,
    decay2=220,
    image_size=640,
    pretrain=True,
    return_layers={'stage1': 1, 'stage2': 2, 'stage3': 3},
    in_channel=32,
    out_channel=64,
)

In [5]:
# Configuration for the 'Resnet50' backbone.
cfg_re50 = dict(
    name='Resnet50',
    # PriorBox: anchor sizes (pixels) for each feature level.
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    # PriorBox: stride of each feature level relative to the input image.
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    # PriorBox: clamp generated anchors to [0, 1] when True.
    clip=False,
    loc_weight=2.0,
    gpu_train=True,
    batch_size=24,
    ngpu=4,
    epoch=100,
    decay1=70,
    decay2=90,
    image_size=840,
    pretrain=True,
    return_layers={'layer2': 1, 'layer3': 2, 'layer4': 3},
    in_channel=256,
    out_channel=256,
)

In [6]:
# Configuration for the 'convnext_large' backbone.
cfg_convnext = dict(
    name='convnext_large',
    # PriorBox: anchor sizes (pixels) for each feature level.
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    # PriorBox: stride of each feature level relative to the input image.
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    # PriorBox: clamp generated anchors to [0, 1] when True.
    clip=False,
    loc_weight=2.0,
    gpu_train=True,
    batch_size=3,
    ngpu=1,
    epoch=100,
    decay1=70,
    decay2=90,
    image_size=840,
    pretrain=True,
    return_layers={'stem': 1, 'stages': 2, 'head': 3},
    in_channel=256,
    out_channel=256,
)

In [7]:
#  torch.cuda.empty_cache()

# data

In [8]:
def _crop(image, boxes, labels, landm, img_dim):
    """Randomly crop a square patch that fully contains at least one box.

    Up to 250 random square regions are tried. The side of each candidate is
    a randomly chosen fraction of the image's shorter side. A candidate is
    accepted when at least one ground-truth box lies entirely inside it and,
    after keeping only the boxes whose centre falls inside the region, at
    least one box survives the final size filter.

    Args:
        image: array of shape (height, width, channels).
        boxes: (n, 4) array of [x1, y1, x2, y2] in pixels.
        labels: length-n array, one entry per box.
        landm: (n, 10) array, five (x, y) landmark points per box.
        img_dim: training input size; used to rescale box sizes before the
            size filter.

    Returns:
        (image, boxes, labels, landmarks, pad_image_flag). On success the
        cropped image and the clipped/shifted annotations are returned with
        the flag set to False. If no attempt succeeds the inputs are returned
        unchanged with the flag set to True.
    """
    height, width, _ = image.shape
    short_side = min(width, height)
    candidate_scales = [0.3, 0.45, 0.6, 0.8, 1.0]

    for _attempt in range(250):
        side = int(random.choice(candidate_scales) * short_side)

        left = 0 if width == side else random.randrange(width - side)
        top = 0 if height == side else random.randrange(height - side)
        roi = np.array((left, top, left + side, top + side))
        origin, corner = roi[:2], roi[2:]

        # Require at least one ground-truth box to lie entirely inside the crop.
        if not (matrix_iof(boxes, roi[np.newaxis]) >= 1).any():
            continue

        # Keep only the boxes whose centre falls strictly inside the crop.
        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        keep = np.logical_and(origin < centers, centers < corner).all(axis=1)
        boxes_t = boxes[keep].copy()
        labels_t = labels[keep].copy()
        landms_t = landm[keep].copy().reshape([-1, 5, 2])
        if boxes_t.shape[0] == 0:
            continue

        image_t = image[top:top + side, left:left + side]

        # Clip the boxes to the crop and express them in crop coordinates.
        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], origin) - origin
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], corner) - origin

        # Shift the landmarks into crop coordinates and clamp them to the crop.
        shifted = np.maximum(landms_t - origin, np.array([0, 0]))
        landms_t[...] = np.minimum(shifted, corner - origin)
        landms_t = landms_t.reshape([-1, 10])

        # Size filter at training-image scale. The threshold is 0.0, so every
        # box with a positive rescaled width and height is kept.
        scaled_w = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / side * img_dim
        scaled_h = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / side * img_dim
        big_enough = np.minimum(scaled_w, scaled_h) > 0.0
        boxes_t = boxes_t[big_enough]
        labels_t = labels_t[big_enough]
        landms_t = landms_t[big_enough]
        if boxes_t.shape[0] == 0:
            continue

        return image_t, boxes_t, labels_t, landms_t, False

    # No acceptable crop was found: hand back the inputs and ask for padding.
    return image, boxes, labels, landm, True

In [9]:
def _distort(image):
    """Return a photometrically distorted copy of a BGR image.

    Brightness, contrast, saturation and hue jitter are each applied with
    probability 1/2. A first coin flip decides whether contrast is applied
    before the HSV round trip (right after brightness) or after it.
    The input array is not modified.
    """

    def _scale_shift(target, alpha=1, beta=0):
        # In place: target <- target * alpha + beta, limited to [0, 255].
        scaled = target.astype(float) * alpha + beta
        scaled[scaled < 0] = 0
        scaled[scaled > 255] = 255
        target[:] = scaled

    def _maybe_brightness(img):
        if random.randrange(2):
            _scale_shift(img, beta=random.uniform(-32, 32))

    def _maybe_contrast(img):
        if random.randrange(2):
            _scale_shift(img, alpha=random.uniform(0.5, 1.5))

    def _maybe_saturation_and_hue(bgr):
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _scale_shift(hsv[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion (hue values wrap around at 180)
        if random.randrange(2):
            hue = hsv[:, :, 0].astype(int) + random.randint(-18, 18)
            hsv[:, :, 0] = hue % 180

        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    distorted = image.copy()
    contrast_before_hsv = random.randrange(2)

    _maybe_brightness(distorted)
    if contrast_before_hsv:
        _maybe_contrast(distorted)
    distorted = _maybe_saturation_and_hue(distorted)
    if not contrast_before_hsv:
        _maybe_contrast(distorted)

    return distorted

In [10]:
def _mirror(image, boxes, landms):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]

        # landm
        landms = landms.copy()
        landms = landms.reshape([-1, 5, 2])
        landms[:, :, 0] = width - landms[:, :, 0]
        tmp = landms[:, 1, :].copy()
        landms[:, 1, :] = landms[:, 0, :]
        landms[:, 0, :] = tmp
        tmp1 = landms[:, 4, :].copy()
        landms[:, 4, :] = landms[:, 3, :]
        landms[:, 3, :] = tmp1
        landms = landms.reshape([-1, 10])

    return image, boxes, landms

In [11]:
def _pad_to_square(image, rgb_mean, pad_image_flag):
    if not pad_image_flag:
        return image
    height, width, _ = image.shape
    long_side = max(width, height)
    image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
    image_t[:, :] = rgb_mean
    image_t[0:0 + height, 0:0 + width] = image
    return image_t

In [12]:
class WiderFaceDetection(data.Dataset):
    """Face detection dataset read from a label file with five-point landmarks.

    The label file is plain text. Every image starts with a header line
    ``# <relative image path>`` and is followed by one line of
    whitespace-separated numbers per face. Image paths are resolved by
    replacing ``label.txt`` in ``txt_path`` with ``images/``.

    Attributes:
        imgs_path: list of image paths, one per header line.
        words: list (parallel to ``imgs_path``) of per-image lists of parsed
            label rows (lists of floats).
        preproc: optional callable ``(img, target) -> (img, target)``.
    """

    # Positions, inside one parsed label row, of the ten landmark coordinates
    # (x, y for five points). The value after each pair is skipped.
    _LANDMARK_COLUMNS = (4, 5, 7, 8, 10, 11, 13, 14, 16, 17)

    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []
        self.words = []

        image_dir = txt_path.replace('label.txt', 'images/')
        labels = []
        seen_header = False
        # The context manager closes the label file even if parsing fails.
        with open(txt_path, 'r') as label_file:
            for raw_line in label_file:
                line = raw_line.rstrip()
                if not line:
                    # Blank lines (e.g. at the end of the file) carry no label.
                    continue
                if line.startswith('#'):
                    if seen_header:
                        # A new image begins: store the previous image's faces.
                        self.words.append(labels)
                        labels = []
                    seen_header = True
                    self.imgs_path.append(image_dir + line[2:])
                else:
                    labels.append([float(value) for value in line.split()])

        # Faces of the last image in the file.
        self.words.append(labels)

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Return ``(image tensor, target array)`` for one image.

        The target has one row per face with 15 columns: x1, y1, x2, y2,
        ten landmark coordinates, and a flag that is -1 when the first
        landmark x is negative and 1 otherwise.

        If the image has no labelled face, only an empty ``(0, 15)`` array is
        returned (no image), so ``detection_collate`` collects nothing from
        that sample.

        Raises:
            OSError: if the image cannot be read.
        """
        image_path = self.imgs_path[index]
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("failed to read image: {}".format(image_path))

        labels = self.words[index]
        if len(labels) == 0:
            return np.zeros((0, 15))

        annotations = np.zeros((len(labels), 15))
        for row, label in zip(annotations, labels):
            # bbox: the label stores x, y, w, h; the target stores corners.
            row[0] = label[0]              # x1
            row[1] = label[1]              # y1
            row[2] = label[0] + label[2]   # x2
            row[3] = label[1] + label[3]   # y2

            # landmarks
            row[4:14] = [label[col] for col in self._LANDMARK_COLUMNS]
            row[14] = -1 if row[4] < 0 else 1

        target = annotations
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target

In [13]:
def detection_collate(batch):
    """Collate samples whose images carry a varying number of annotations.

    Every element of every sample is inspected: tensors are collected as
    images, NumPy arrays are converted to float tensors and collected as
    annotations, anything else is ignored.

    Arguments:
        batch: (iterable) samples, each an iterable of a tensor image and an
            annotation array

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations, one float tensor per image
    """
    images, annotations = [], []
    for sample in batch:
        for item in sample:
            if torch.is_tensor(item):
                images.append(item)
            elif isinstance(item, np.ndarray):
                annotations.append(torch.from_numpy(item).float())

    return (torch.stack(images, 0), annotations)

In [14]:
def _resize_subtract_mean(image, insize, rgb_mean):
    """Resize to ``insize`` x ``insize``, subtract ``rgb_mean``, return channels first.

    The interpolation method is drawn at random from five OpenCV methods.
    The result is a float32 array with its axes reordered to (2, 0, 1).
    """
    candidates = (
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_NEAREST,
        cv2.INTER_LANCZOS4,
    )
    chosen = candidates[random.randrange(5)]
    resized = cv2.resize(image, (insize, insize), interpolation=chosen).astype(np.float32)
    # In-place subtraction keeps the array float32.
    resized -= rgb_mean
    return resized.transpose(2, 0, 1)

In [15]:
class preproc(object):
    """Training-time augmentation pipeline for one image and its targets.

    Calling an instance runs, in order: random crop, photometric distortion,
    optional padding to a square, random horizontal mirror, resize with mean
    subtraction, and normalisation of boxes and landmarks by the width and
    height of the image as it was just before resizing.
    """

    def __init__(self, img_dim, rgb_means):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        """Augment ``image`` and ``targets``.

        Args:
            image: array of shape (height, width, channels).
            targets: (n, k) array with n > 0; the first four columns are the
                box, the last column is the label and the columns in between
                are the landmarks.

        Returns:
            (image, targets): the processed image and an array stacking the
            normalised boxes, normalised landmarks and labels column-wise.
        """
        assert targets.shape[0] > 0, "this image does not have gt"

        gt_boxes = targets[:, :4].copy()
        gt_landmarks = targets[:, 4:-1].copy()
        gt_labels = targets[:, -1].copy()

        patch, boxes_t, labels_t, landm_t, needs_padding = _crop(
            image, gt_boxes, gt_labels, gt_landmarks, self.img_dim)
        patch = _pad_to_square(_distort(patch), self.rgb_means, needs_padding)
        patch, boxes_t, landm_t = _mirror(patch, boxes_t, landm_t)

        # Size before resizing: the annotations are still in this pixel frame.
        height, width = patch.shape[:2]
        network_input = _resize_subtract_mean(patch, self.img_dim, self.rgb_means)

        for coords in (boxes_t, landm_t):
            coords[:, 0::2] /= width    # x coordinates
            coords[:, 1::2] /= height   # y coordinates

        targets_t = np.hstack((boxes_t, landm_t, np.expand_dims(labels_t, 1)))
        return network_input, targets_t

# utils

In [16]:
def matrix_iou(a, b):
    """
    Pairwise intersection-over-union of two box sets (NumPy version used by
    the data augmentation).

    ``a`` is (n, 4) and ``b`` is (m, 4), both [x1, y1, x2, y2]; the result
    is an (n, m) array.
    """
    top_left = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    bottom_right = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # Pairs that do not overlap get a zero intersection through the mask.
    overlaps = (top_left < bottom_right).all(axis=2)
    inter = (bottom_right - top_left).prod(axis=2) * overlaps

    area_a = (a[:, 2:] - a[:, :2]).prod(axis=1)
    area_b = (b[:, 2:] - b[:, :2]).prod(axis=1)
    union = area_a[:, np.newaxis] + area_b - inter
    return inter / union

In [17]:
def matrix_iof(a, b):
    """
    Pairwise intersection-over-foreground of two box sets (NumPy version
    used by the data augmentation).

    For every box in ``a`` (n, 4) and every box in ``b`` (m, 4) the
    intersection area is divided by the area of the ``a`` box, where that
    area is floored at 1. The result is an (n, m) array.
    """
    top_left = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    bottom_right = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # Pairs that do not overlap get a zero intersection through the mask.
    overlaps = (top_left < bottom_right).all(axis=2)
    inter = (bottom_right - top_left).prod(axis=2) * overlaps

    area_a = (a[:, 2:] - a[:, :2]).prod(axis=1)
    return inter / np.maximum(area_a[:, np.newaxis], 1)

In [18]:
def point_form(boxes):
    """ Convert centre-size boxes (cx, cy, w, h) to corner form
    (xmin, ymin, xmax, ymax).
    Args:
        boxes: (tensor) boxes in centre-size form, Shape: [n, 4].
    Return:
        boxes: (tensor) the same boxes as xmin, ymin, xmax, ymax.
    """
    centers = boxes[:, :2]
    half_sizes = boxes[:, 2:] / 2
    return torch.cat((centers - half_sizes, centers + half_sizes), 1)

In [19]:
def encode(matched, priors, variances):
    """Encode matched ground-truth boxes as regression targets relative to priors.
    Args:
        matched: (tensor) Ground-truth box for each prior in corner form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in centre-size form
            Shape: [num_priors, 4].
        variances: (list[float]) Two scale factors: the first for the centre
            offsets, the second for the log size ratios.
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]

    # Centre offset, measured in units of (variance * prior size).
    center_offset = (matched[:, :2] + matched[:, 2:]) / 2 - prior_centers
    center_offset.div_(variances[0] * prior_sizes)

    # Log of the size ratio between the matched box and its prior.
    size_ratio = (matched[:, 2:] - matched[:, :2]) / prior_sizes
    log_size = torch.log(size_ratio) / variances[1]

    return torch.cat([center_offset, log_size], 1)  # [num_priors, 4]

In [20]:
def encode_landm(matched, priors, variances):
    """Encode matched ground-truth landmarks as offsets relative to priors.
    Args:
        matched: (tensor) Five (x, y) landmark points for each prior
            Shape: [num_priors, 10].
        priors: (tensor) Prior boxes in centre-size form
            Shape: [num_priors, 4].
        variances: (list[float]) Scale factors; only the first is used.
    Return:
        encoded landm (tensor), Shape: [num_priors, 10]
    """
    count = matched.size(0)
    points = torch.reshape(matched, (count, 5, 2))

    # Repeat every prior once per landmark point: [num_priors, 5, 4].
    anchors = priors[:, :4].unsqueeze(1).expand(count, 5, 4)

    # Offset from the prior centre, in units of (variance * prior size).
    offsets = points - anchors[:, :, :2]
    offsets.div_(variances[0] * anchors[:, :, 2:])

    return offsets.reshape(count, -1)

In [21]:
def decode(loc, priors, variances):
    """Turn predicted box offsets back into corner-form boxes.

    This inverts the offset encoding applied at train time.
    Args:
        loc (tensor): location predictions,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in centre-size form.
            Shape: [num_priors, 4].
        variances: (list[float]) Two scale factors: the first for the centre
            offsets, the second for the log size ratios.
    Return:
        decoded boxes as (xmin, ymin, xmax, ymax), Shape: [num_priors, 4]
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]

    centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)

In [22]:
def decode_landm(pre, priors, variances):
    """Turn predicted landmark offsets back into landmark coordinates.

    This inverts the offset encoding applied at train time.
    Args:
        pre (tensor): landmark predictions,
            Shape: [num_priors, 10]
        priors (tensor): Prior boxes in centre-size form.
            Shape: [num_priors, 4].
        variances: (list[float]) Scale factors; only the first is used.
    Return:
        decoded landmarks, five (x, y) points per row, Shape: [num_priors, 10]
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]
    points = [
        prior_centers + pre[:, 2 * k:2 * k + 2] * variances[0] * prior_sizes
        for k in range(5)
    ]
    return torch.cat(points, dim=1)

In [23]:
def jaccard(box_a, box_b):
    """Pairwise jaccard overlap (intersection over union) of two box sets.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) boxes in corner form, Shape: [A, 4]
        box_b: (tensor) boxes in corner form, Shape: [B, 4]
    Return:
        jaccard overlap: (tensor) Shape: [A, B]
    """
    inter = intersect(box_a, box_b)

    size_a = box_a[:, 2:4] - box_a[:, 0:2]
    size_b = box_b[:, 2:4] - box_b[:, 0:2]
    area_a = (size_a[:, 0] * size_a[:, 1]).unsqueeze(1)  # [A, 1]
    area_b = (size_b[:, 0] * size_b[:, 1]).unsqueeze(0)  # [1, B]

    union = area_a + area_b - inter  # broadcasts to [A, B]
    return inter / union

In [24]:
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match priors to ground-truth boxes and write the encoded targets.

    Each prior is assigned the ground-truth box with the highest jaccard
    overlap, and every ground truth keeps its own best prior. Priors whose
    best overlap is below ``threshold`` are labelled background (0). The
    encoded box and landmark targets and the labels are written into row
    ``idx`` of ``loc_t``, ``landm_t`` and ``conf_t``.

    If no ground truth has a prior overlapping it by at least 0.2, row
    ``idx`` of all three target tensors is set to zero.

    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes in centre-size form, Shape: [num_priors, 4].
        variances: (list[float]) Two scale factors passed on to the encoders.
        labels: (tensor) Class label of each ground truth, Shape: [num_obj].
        landms: (tensor) Ground truth landmarks, Shape: [num_obj, 10].
        loc_t: (tensor) Filled with encoded location targets.
        conf_t: (tensor) Filled with the label matched to each prior.
        landm_t: (tensor) Filled with encoded landmark targets.
        idx: (int) current batch index
    Return:
        None. The results are written in place into ``loc_t``, ``conf_t``
        and ``landm_t``.
    """
    # jaccard index, Shape: [num_obj, num_priors]
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # (Bipartite Matching)
    # [num_obj, 1] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt: ground truths whose best prior overlaps them by < 0.2
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        # Nothing can be matched: the whole row becomes background. landm_t
        # is cleared as well so the caller's buffer holds defined values.
        loc_t[idx] = 0
        conf_t[idx] = 0
        landm_t[idx] = 0
        return

    # [1, num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    # An overlap of 2 is above any threshold, so the best prior of every
    # valid ground truth is always kept as a positive.
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)
    # TODO refactor: index best_prior_idx with long tensor
    # Ensure every gt is matched with its prior of max overlap. The loop is
    # sequential on purpose: if two ground truths share a best prior, the
    # later one wins.
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]            # Shape: [num_priors, 4], gt box assigned to each prior
    conf = labels[best_truth_idx]               # Shape: [num_priors], gt label assigned to each prior
    conf[best_truth_overlap < threshold] = 0    # priors below the overlap threshold become background
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc    # [num_priors, 4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] label for each prior
    landm_t[idx] = landm


In [25]:
def intersect(box_a, box_b):
    """ Pairwise intersection area of two box sets.

    The corner coordinates are broadcast against each other:
    [A,2] -> [A,1,2] and [B,2] -> [1,B,2], giving [A,B,2] results.
    Args:
      box_a: (tensor) bounding boxes in corner form, Shape: [A,4].
      box_b: (tensor) bounding boxes in corner form, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    lower_right = torch.min(box_a[:, None, 2:], box_b[None, :, 2:])
    upper_left = torch.max(box_a[:, None, :2], box_b[None, :, :2])
    # Boxes that do not overlap give a negative extent; clamp it to zero.
    extent = torch.clamp(lower_right - upper_left, min=0)
    return extent[:, :, 0] * extent[:, :, 1]

In [26]:
def log_sum_exp(x):
    """Row-wise log(sum(exp(x))) with a shift for numerical stability.

    The maximum over the whole input is subtracted before exponentiating and
    added back afterwards. Used for the unaveraged confidence loss across
    all examples in a batch.
    Args:
        x (tensor): conf_preds from conf layers, Shape: [n, num_classes]
    Return:
        (tensor) Shape: [n, 1]
    """
    peak = x.data.max()
    shifted_sum = torch.exp(x - peak).sum(1, keepdim=True)
    return torch.log(shifted_sum) + peak

# layers

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# from utils.box_utils import match, log_sum_exp
# from data import cfg_mnet
# GPU flag, read from the 'gpu_train' entry of the MobileNet config (cfg_mnet).
GPU = cfg_mnet['gpu_train']

class MultiBoxLoss(nn.Module):
    """Multi-task detection loss: box regression, classification, landmarks.

    Compute Targets:
        1) Match ground truth boxes with prior boxes by jaccard overlap
           (see ``match``); priors below the overlap threshold are background.
        2) Encode the matched boxes and landmarks as offsets relative to
           their priors.
        3) Hard negative mining: per image, keep only the background priors
           with the highest classification loss, at most
           ``neg_pos`` x (number of positives).
    Objective Loss:
        Three values are returned separately:
          * Smooth L1 box loss over all positive priors, divided by N
          * cross entropy over positives and mined negatives, divided by N
          * Smooth L1 landmark loss over priors with a label > 0, divided
            by N1
        where N is the number of positive priors (label != 0) and N1 the
        number of priors with a label > 0, each floored at 1.
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): (loc, conf, landm) predictions.
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                landm shape: torch.size(batch_size,num_priors,10)
            priors (tensor): prior boxes, shape: torch.size(num_priors,4)
            targets (sequence of tensors): one [num_objs, 15] tensor per
                image; columns 0-3 are the box, 4-13 the landmarks and the
                last column is the label.
        Return:
            (loss_l, loss_c, loss_landm): box, classification and landmark
            losses.
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)
        # Run on whatever device the predictions live on, so the loss works
        # on CPU and GPU alike.
        device = loc_data.device

        # Match priors (default boxes) and ground truth boxes. The buffers
        # are zero-initialised so rows that match() leaves untouched never
        # hold arbitrary memory.
        loc_t = torch.zeros(num, num_priors, 4)
        landm_t = torch.zeros(num, num_priors, 10)
        conf_t = torch.zeros(num, num_priors, dtype=torch.long)
        defaults = priors.data
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)
        landm_t = landm_t.to(device)

        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        # Only priors with a label > 0 contribute to the landmark loss.
        pos1 = conf_t > 0
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # For boxes and classification every non-zero label is a positive;
        # collapse them to class 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Per-prior classification loss, used to rank the negatives
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Normalise by the number of positives (at least 1)
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm

In [28]:
from math import ceil
from itertools import product as product

class PriorBox(object):
    """Generates the prior (anchor) boxes for every feature-map location.

    For each feature level the map size is ``ceil(image_size / step)`` in
    both directions, and every cell receives one square anchor per entry of
    that level's ``min_sizes``.
    """

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        # [rows, cols] of the feature map at every level.
        self.feature_maps = [
            [ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)]
            for step in self.steps
        ]
        self.name = "s"

    def forward(self):
        """Return all anchors as a [num_priors, 4] tensor of (cx, cy, w, h).

        The values are expressed as fractions of the image size. When
        ``clip`` is set they are clamped to [0, 1].
        """
        img_h, img_w = self.image_size[0], self.image_size[1]
        rows = []
        for level, feature_map in enumerate(self.feature_maps):
            stride = self.steps[level]
            level_sizes = self.min_sizes[level]
            for i, j in product(range(feature_map[0]), range(feature_map[1])):
                # Centre of cell (i, j), relative to the image size.
                cx = (j + 0.5) * stride / img_w
                cy = (i + 0.5) * stride / img_h
                for min_size in level_sizes:
                    rows.append([cx, cy, min_size / img_w, min_size / img_h])

        # back to torch land
        output = torch.Tensor(rows).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output

In [29]:
# pb = priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))  # code that generates the prior (anchor) boxes
# pb.forward()

# models

In [30]:
import timm

In [31]:
# model = timm.models.convnext_large(pretrained=True)

In [32]:
# print(model)

In [33]:
import time
import torch
import torch.nn as nn
import torchvision.models._utils as _utils
import torchvision.models as models
import torch.nn.functional as F
from torch.autograd import Variable

def conv_bn(inp, oup, stride = 1, leaky = 0):
    """3x3 convolution (padding 1, no bias) -> BatchNorm -> in-place LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)

def conv_bn_no_relu(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) -> BatchNorm, without activation."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*layers)

def conv_bn1X1(inp, oup, stride, leaky=0):
    """1x1 convolution (no padding, no bias) -> BatchNorm -> in-place LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)

def conv_dw(inp, oup, stride, leaky=0.1):
    """Depthwise 3x3 conv block followed by a pointwise 1x1 conv block.

    Each convolution has no bias and is followed by BatchNorm and an
    in-place LeakyReLU.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)

class SSH(nn.Module):
    """Three-branch context block whose branch outputs are concatenated on channels.

    The branches stack one, two and three 3x3 convolutions respectively and
    contribute ``out_channel // 2``, ``out_channel // 4`` and
    ``out_channel // 4`` channels, so the concatenation has ``out_channel``
    channels. A ReLU is applied after the concatenation.
    """

    def __init__(self, in_channel, out_channel):
        """
        :param in_channel: channels of the incoming feature map.
        :param out_channel: channels of the result; must be divisible by 4.
        """
        super().__init__()
        assert out_channel % 4 == 0
        half = out_channel // 2
        quarter = out_channel // 4
        # Narrow variants (<= 64 channels) use a leaky slope of 0.1, wider ones 0.
        leaky = 0.1 if out_channel <= 64 else 0

        # Branch 1: a single 3x3 conv.
        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        # Branch 2: two stacked 3x3 convs; the first one is shared with branch 3.
        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        # Branch 3: two more 3x3 convs on top of the shared first conv.
        # (Attribute names keep their original spelling; they are state_dict keys.)
        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        """Run the three branches on ``input`` and return relu(concat(branches))."""
        branch3 = self.conv3X3(input)

        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)

        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        return F.relu(torch.cat([branch3, branch5, branch7], dim=1))

class FPN(nn.Module):
    """Top-down feature pyramid over three feature maps.

    Every input is projected to ``out_channels`` by a 1x1 conv block. The
    coarser level is then resized (nearest neighbour) to the next finer
    level's spatial size, added to it, and smoothed by a 3x3 conv block.
    """

    def __init__(self,in_channels_list,out_channels):
        """
        :param in_channels_list: channel counts of the three inputs, indexed 0..2.
        :param out_channels: channel count of every pyramid output.
        """
        super().__init__()
        # Narrow pyramids (<= 64 channels) use a leaky slope of 0.1, wider ones 0.
        leaky = 0.1 if out_channels <= 64 else 0

        # Lateral 1x1 projections, one per input level.
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        # 3x3 smoothing applied after each top-down addition.
        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input1, input2, input3):
        """Return ``[level1, level2, level3]``; level 3 is only projected, never merged."""
        lateral1 = self.output1(input1)
        lateral2 = self.output2(input2)
        lateral3 = self.output3(input3)

        # Level 3 -> level 2.
        size2 = [lateral2.size(2), lateral2.size(3)]
        merged2 = self.merge2(lateral2 + F.interpolate(lateral3, size=size2, mode="nearest"))

        # Merged level 2 -> level 1.
        size1 = [lateral1.size(2), lateral1.size(3)]
        merged1 = self.merge1(lateral1 + F.interpolate(merged2, size=size1, mode="nearest"))

        return [merged1, merged2, lateral3]



class MobileNetV1(nn.Module):
    """Small MobileNetV1-style classifier split into three named stages.

    ``stage1`` ends at 64 channels, ``stage2`` at 128 and ``stage3`` at 256;
    a global average pool and a 256 -> 1000 linear layer follow.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) of each depthwise-separable block.
        stage1_blocks = [(8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1)]
        stage2_blocks = [(64, 128, 2)] + [(128, 128, 1)] * 5
        stage3_blocks = [(128, 256, 2), (256, 256, 1)]

        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # plain 3x3 stem conv
            *[conv_dw(c_in, c_out, s) for c_in, c_out, s in stage1_blocks]
        )
        self.stage2 = nn.Sequential(
            *[conv_dw(c_in, c_out, s) for c_in, c_out, s in stage2_blocks]
        )
        self.stage3 = nn.Sequential(
            *[conv_dw(c_in, c_out, s) for c_in, c_out, s in stage3_blocks]
        )
        self.avg = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        """Run all stages, pool to 1x1, flatten to ``(-1, 256)`` and apply ``fc``."""
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.avg(x)
        flat = pooled.view(-1, 256)
        return self.fc(flat)

In [34]:
# mnet = MobileNetV1()
# print(mnet)

In [35]:
import torch
import torch.nn as nn
import torchvision.models.detection.backbone_utils as backbone_utils
import torchvision.models._utils as _utils
import torch.nn.functional as F
from collections import OrderedDict

# from models.net import MobileNetV1 as MobileNetV1
# from models.net import FPN as FPN
# from models.net import SSH as SSH



class ClassHead(nn.Module):
    """1x1 conv head producing two class scores per anchor per location."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, 2 * self.num_anchors, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        """Map ``(N, C, H, W)`` to ``(N, H*W*num_anchors, 2)``."""
        # Channels-last so each anchor's 2 scores are contiguous before the reshape.
        scores = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = scores.shape[0]
        return scores.view(batch, -1, 2)

class BboxHead(nn.Module):
    """1x1 conv head producing four box-regression values per anchor per location."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, 4 * num_anchors, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        """Map ``(N, C, H, W)`` to ``(N, H*W*num_anchors, 4)``."""
        # Channels-last so each anchor's 4 values are contiguous before the reshape.
        deltas = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = deltas.shape[0]
        return deltas.view(batch, -1, 4)

class LandmarkHead(nn.Module):
    """1x1 conv head producing ten landmark-regression values per anchor per location."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, 10 * num_anchors, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        """Map ``(N, C, H, W)`` to ``(N, H*W*num_anchors, 10)``."""
        # Channels-last so each anchor's 10 values are contiguous before the reshape.
        landmarks = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = landmarks.shape[0]
        return landmarks.view(batch, -1, 10)

class RetinaFace(nn.Module):
    """Face detector: backbone -> FPN -> SSH -> class / box / landmark heads.

    The backbone is chosen by ``cfg['name']``:

    * ``'mobilenet0.25'``  -- the ``MobileNetV1`` defined in this notebook,
    * ``'Resnet50'``       -- ``torchvision.models.resnet50``,
    * ``'convnext_large'`` -- NOTE: this actually builds timm's
      ``convnext_tiny``; the config name is kept as-is because checkpoints
      are written under ``cfg['name']``.

    ConvNeXt path: the three tapped backbone outputs are reduced to 1x1 and
    then expanded by stride-1 transposed convolutions whose kernel size equals
    the target map size (105, 53 and 27).
    NOTE(review): those sizes equal ceil(840/8), ceil(840/16), ceil(840/32),
    so they appear tied to ``image_size=840`` with steps ``[8, 16, 32]`` --
    confirm against PriorBox before changing the input size.

    MobileNet / ResNet path: the tapped feature maps go straight into the FPN,
    with input widths ``in_channel * 2 / 4 / 8``.
    """

    def __init__(self, cfg = None, phase = 'train'):
        """
        :param cfg:  Network related settings.
        :param phase: train or test.
        :raises TypeError: if ``cfg`` is None.
        :raises ValueError: if ``cfg['name']`` is not a supported backbone.
        """
        super(RetinaFace,self).__init__()
        if cfg is None:
            raise TypeError("RetinaFace needs a cfg dict such as cfg_mnet, cfg_re50 or cfg_convnext")
        self.phase = phase
        name = cfg['name']
        # The head surgery and the transposed convs below only apply to ConvNeXt.
        self._is_convnext = (name == 'convnext_large')

        if name == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                # Load onto the CPU: load_state_dict copies into the (CPU) backbone
                # anyway, and this keeps construction working on CUDA-less hosts.
                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    # Strip the DataParallel `module.` prefix only when it is present.
                    new_state_dict[k[7:] if k.startswith('module.') else k] = v
                # load params
                backbone.load_state_dict(new_state_dict)
        elif name == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        elif self._is_convnext:
            import timm
            # Honour cfg['pretrain']; default to True when the key is missing.
            backbone = timm.models.convnext_tiny(pretrained=cfg.get('pretrain', True))
        else:
            raise ValueError("Unsupported backbone name: {!r}".format(name))

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])

        if self._is_convnext:
            # Keep only the named pooling and norm layers of the classifier head,
            # then append an activation.
            keep_layers = ('global_pool', 'norm')
            head_layers = [
                layer
                for layer_name, layer in self.body['head'].named_children()
                if layer_name in keep_layers
            ]
            if len(head_layers) != len(keep_layers):
                raise RuntimeError(
                    "ConvNeXt head does not expose the expected layers {}".format(keep_layers))
            self.leaky_relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)
            head_layers.append(self.leaky_relu)
            self.body['head'] = nn.Sequential(*head_layers)

            # With stride 1 and no padding, a transposed conv turns a 1x1 map into
            # a (kernel x kernel) map, so these produce 105 / 53 / 27 sized maps.
            self.max_pool_1 = nn.AdaptiveMaxPool2d((1, 1))
            self.transposed_conv_1 = nn.ConvTranspose2d(in_channels=96, out_channels=96, kernel_size=105, stride=1, padding=0)
            self.transposed_conv_2 = nn.ConvTranspose2d(in_channels=768, out_channels=768, kernel_size=53, stride=1, padding=0)
            self.transposed_conv_3 = nn.ConvTranspose2d(in_channels=768, out_channels=768, kernel_size=27, stride=1, padding=0)

            # Channel widths of the three maps produced above.
            in_channels_list = [96, 768, 768]
        else:
            # MobileNetV1 stages emit 64 / 128 / 256 channels (in_channel = 32).
            # NOTE(review): for ResNet50, in_channel = 256 gives 512 / 1024 / 2048,
            # which should match layer2..layer4 -- confirm.
            in_channels_stage2 = cfg['in_channel']
            in_channels_list = [
                in_channels_stage2 * 2,
                in_channels_stage2 * 4,
                in_channels_stage2 * 8,
            ]

        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list,out_channels)  # top-down merge with upsampling
        self.ssh1 = SSH(out_channels, out_channels)    # per-level context module
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``ClassHead`` per pyramid level."""
        return nn.ModuleList([ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``BboxHead`` per pyramid level."""
        return nn.ModuleList([BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``LandmarkHead`` per pyramid level."""
        return nn.ModuleList([LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self,inputs):
        """Run the detector on a batch of images.

        :param inputs: image batch fed to the backbone.
        :return: ``(bbox_regressions, classifications, ldm_regressions)``, each
            concatenated over the three pyramid levels along dim 1. Outside the
            ``'train'`` phase the classifications are passed through a softmax
            over the last dimension.
        """
        # IntermediateLayerGetter returns a dict keyed by the return_layers values (1, 2, 3).
        out = self.body(inputs)

        if self._is_convnext:
            # Collapse the stem / stages maps to 1x1, then expand every level to
            # its fixed target size with the transposed convolutions.
            out[1] = self.transposed_conv_1(self.max_pool_1(out[1]))
            out[2] = self.transposed_conv_2(self.max_pool_1(out[2]))
            # out[3] comes from the trimmed head -- presumably already 1x1 after
            # its global pool, so it is not pooled again here.
            out[3] = self.transposed_conv_3(out[3])

        # FPN
        fpn = self.fpn(out[1], out[2], out[3])

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
        return output

# train

In [36]:
import easydict
from tqdm import tqdm

# Command-line style settings for the training run, exposed as attributes.
args = easydict.EasyDict(dict(
    training_dataset='/data/dhk/face/face_detecte/widerface/train/label.txt',
    network='convnext_large',
    num_workers=0,
    lr=1e-3,
    momentum=0.9,
    resume_net=None,
    resume_epoch=0,
    weight_decay=5e-4,
    gamma=0.1,
    save_folder='./weights/',

    # To resume training from a checkpoint, override the two resume_* values, e.g.:
    #     resume_net="./weights/convnext_large_epoch_30.pth",
    #     resume_epoch=30,
))

In [37]:
# Make sure the checkpoint folder exists (creates missing parents, tolerates races).
os.makedirs(args.save_folder, exist_ok=True)

# Map the requested network name to its configuration dict.
cfg = None
if args.network == "mobile0.25":
    cfg = cfg_mnet
elif args.network == "resnet50":
    cfg = cfg_re50
elif args.network == "convnext_large":
    cfg = cfg_convnext
if cfg is None:
    # Fail here with a clear message instead of a TypeError on cfg[...] below.
    raise ValueError(
        "Unknown args.network {!r}; expected 'mobile0.25', 'resnet50' or 'convnext_large'".format(args.network))

rgb_mean = (104, 117, 123) # bgr order
num_classes = 2
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']

num_workers = args.num_workers
momentum = args.momentum
weight_decay = args.weight_decay
initial_lr = args.lr
gamma = args.gamma
training_dataset = args.training_dataset
save_folder = args.save_folder

net = RetinaFace(cfg=cfg)
print("Printing net...")
print(net)

if args.resume_net is not None:
    print('Loading resume network...')
    state_dict = torch.load(args.resume_net, map_location=torch.device('cpu'))
    # Checkpoints saved from a DataParallel-wrapped model prefix every key with
    # `module.`; strip it so they load into the bare model. Checkpoints without
    # the prefix are passed through untouched.
    if any(key.startswith('module.') for key in state_dict):
        state_dict = {
            (key[7:] if key.startswith('module.') else key): value
            for key, value in state_dict.items()
        }
    net.load_state_dict(state_dict)

if num_gpu > 1 and gpu_train:
    net = torch.nn.DataParallel(net).cuda()
else:
    net = net.cuda()
    torch.cuda.empty_cache()

cudnn.benchmark = True


optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)  # object-detection loss

priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))  # builds the prior (anchor) boxes

# The priors are constants: build them once, without autograd, and keep them on the GPU.
with torch.no_grad():
    priors = priorbox.forward()
    priors = priors.cuda()

def train():
    """Train the module-level ``net`` on the WIDER FACE training set.

    Reads its configuration from module-level names (``args``, ``cfg``,
    ``net``, ``optimizer``, ``criterion``, ``priors``, ``batch_size``,
    ``max_epoch``, ...). Iterations are counted globally; a fresh shuffled
    DataLoader is created at the start of every epoch. A checkpoint is written
    every 10 epochs (every 5 once past ``cfg['decay1']``) and a final one at
    the end.

    When resuming (``args.resume_epoch > 0``) the learning-rate decay index is
    initialised from the milestones already passed, so the resumed run uses
    the decayed learning rate instead of restarting at the initial one.
    """
    net.train()
    epoch = 0 + args.resume_epoch
    print('Loading Dataset...')

    dataset = WiderFaceDetection(training_dataset,preproc(img_dim, rgb_mean))

    epoch_size = math.ceil(len(dataset) / batch_size)
    max_iter = max_epoch * epoch_size

    # Iterations at which the learning rate is multiplied by gamma.
    stepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)

    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    # Milestones strictly before the first iteration never hit the
    # `iteration in stepvalues` check below, so count them up front.
    # A milestone equal to start_iter is handled by the loop itself.
    step_index = sum(1 for milestone in stepvalues if milestone < start_iter)

    for iteration in tqdm(range(start_iter, max_iter)):
        if iteration % epoch_size == 0:
            # create batch iterator
            batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate))
            if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > cfg['decay1']):
                torch.save(net.state_dict(), os.path.join(save_folder, cfg['name'] + '_epoch_' + str(epoch) + '.pth'))
            epoch += 1

        load_t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator)

        images = images.cuda()
        targets = [anno.cuda() for anno in targets]

        # forward
        out = net(images)

        # backprop
        optimizer.zero_grad()
        loss_l, loss_c, loss_landm = criterion(out, priors, targets)
        loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
        loss.backward()
        optimizer.step()
        load_t1 = time.time()
        torch.cuda.empty_cache()
        # batch_time covers data loading plus the forward/backward/step above.
        batch_time = load_t1 - load_t0
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
              .format(epoch, max_epoch, (iteration % epoch_size) + 1,
              epoch_size, iteration + 1, max_iter, loss_l.item(), loss_c.item(), loss_landm.item(), lr, batch_time, str(datetime.timedelta(seconds=eta))))

    torch.save(net.state_dict(), os.path.join(save_folder, cfg['name'] + '_Final.pth'))


def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size):
    """Compute the current learning rate and write it into every param group.

    Outside warm-up the rate is ``initial_lr * gamma ** step_index``, where
    ``initial_lr`` is a module-level name. Warm-up applies only while
    ``epoch <= warmup_epoch``; with ``warmup_epoch = -1`` that branch is not
    taken for non-negative epochs.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py

    :return: the learning rate that was set.
    """
    warmup_epoch = -1
    in_warmup = epoch <= warmup_epoch
    if not in_warmup:
        # Step decay: one factor of gamma per milestone passed.
        lr = initial_lr * (gamma ** (step_index))
    else:
        # Linear ramp from 1e-6 towards initial_lr over the warm-up iterations.
        lr = 1e-6 + (initial_lr-1e-6) * iteration / (epoch_size * warmup_epoch)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

# Start training when this cell is executed as the main module.
if __name__ == '__main__':
    train()

Printing net...
RetinaFace(
  (body): IntermediateLayerGetter(
    (stem): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_features=384, bias=True)
              (act): GELU(approximate=none)
              (drop1): Dropout(p=0.0, inplace=False)
              (fc2): Linear(in_features=384, out_features=96, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv_dw): Conv2d(96, 96, kernel_size=(

  0%|                                         | 1/300580 [00:54<4573:41:41, 54.78s/it]

Epoch:31/100 || Epochiter: 1/4294 || Iter: 128821/429400 || Loc: 2.3042 Cla: 2.6143 Landm: 6.4057 || LR: 0.00100000 || Batchtime: 41.2181 s || ETA: 143 days, 9:29:01


  0%|                                         | 2/300580 [00:55<1930:41:00, 23.12s/it]

Epoch:31/100 || Epochiter: 2/4294 || Iter: 128822/429400 || Loc: 1.1234 Cla: 1.6305 Landm: 2.8694 || LR: 0.00100000 || Batchtime: 0.8007 s || ETA: 2 days, 18:51:06


  0%|                                         | 3/300580 [00:56<1091:07:33, 13.07s/it]

Epoch:31/100 || Epochiter: 3/4294 || Iter: 128823/429400 || Loc: 1.2139 Cla: 1.2735 Landm: 4.2410 || LR: 0.00100000 || Batchtime: 0.9391 s || ETA: 3 days, 6:24:43


  0%|                                          | 4/300580 [00:57<696:28:18,  8.34s/it]

Epoch:31/100 || Epochiter: 4/4294 || Iter: 128824/429400 || Loc: 2.1467 Cla: 2.4438 Landm: 7.7295 || LR: 0.00100000 || Batchtime: 0.9332 s || ETA: 3 days, 5:55:02


  0%|                                          | 5/300580 [00:59<478:21:03,  5.73s/it]

Epoch:31/100 || Epochiter: 5/4294 || Iter: 128825/429400 || Loc: 0.9368 Cla: 1.1698 Landm: 2.0097 || LR: 0.00100000 || Batchtime: 0.9358 s || ETA: 3 days, 6:08:06


  0%|                                          | 6/300580 [01:00<347:00:29,  4.16s/it]

Epoch:31/100 || Epochiter: 6/4294 || Iter: 128826/429400 || Loc: 1.7693 Cla: 1.4351 Landm: 2.8791 || LR: 0.00100000 || Batchtime: 0.9398 s || ETA: 3 days, 6:28:12


  0%|                                          | 7/300580 [01:01<263:31:22,  3.16s/it]

Epoch:31/100 || Epochiter: 7/4294 || Iter: 128827/429400 || Loc: 0.3714 Cla: 0.8116 Landm: 1.1563 || LR: 0.00100000 || Batchtime: 0.9348 s || ETA: 3 days, 6:02:50


  0%|                                          | 8/300580 [01:02<208:49:31,  2.50s/it]

Epoch:31/100 || Epochiter: 8/4294 || Iter: 128828/429400 || Loc: 2.4281 Cla: 2.8353 Landm: 14.2419 || LR: 0.00100000 || Batchtime: 0.9368 s || ETA: 3 days, 6:13:08


  0%|                                          | 9/300580 [01:03<172:15:12,  2.06s/it]

Epoch:31/100 || Epochiter: 9/4294 || Iter: 128829/429400 || Loc: 1.8371 Cla: 1.7529 Landm: 5.2811 || LR: 0.00100000 || Batchtime: 0.9374 s || ETA: 3 days, 6:16:08


  0%|                                         | 10/300580 [01:04<147:33:34,  1.77s/it]

Epoch:31/100 || Epochiter: 10/4294 || Iter: 128830/429400 || Loc: 3.2227 Cla: 3.1823 Landm: 9.5236 || LR: 0.00100000 || Batchtime: 0.9423 s || ETA: 3 days, 6:40:38


  0%|                                         | 11/300580 [01:05<130:06:02,  1.56s/it]

Epoch:31/100 || Epochiter: 11/4294 || Iter: 128831/429400 || Loc: 0.6674 Cla: 0.6648 Landm: 3.5038 || LR: 0.00100000 || Batchtime: 0.9216 s || ETA: 3 days, 4:56:58


  0%|                                         | 12/300580 [01:06<117:52:14,  1.41s/it]

Epoch:31/100 || Epochiter: 12/4294 || Iter: 128832/429400 || Loc: 1.4351 Cla: 1.7042 Landm: 3.9756 || LR: 0.00100000 || Batchtime: 0.9139 s || ETA: 3 days, 4:18:04


  0%|                                         | 13/300580 [01:07<109:53:11,  1.32s/it]

Epoch:31/100 || Epochiter: 13/4294 || Iter: 128833/429400 || Loc: 1.7910 Cla: 1.6166 Landm: 6.4922 || LR: 0.00100000 || Batchtime: 0.9350 s || ETA: 3 days, 6:04:00


  0%|                                         | 14/300580 [01:08<104:28:26,  1.25s/it]

Epoch:31/100 || Epochiter: 14/4294 || Iter: 128834/429400 || Loc: 0.9824 Cla: 0.9689 Landm: 0.8449 || LR: 0.00100000 || Batchtime: 0.9411 s || ETA: 3 days, 6:34:38


  0%|                                         | 15/300580 [01:10<100:48:39,  1.21s/it]

Epoch:31/100 || Epochiter: 15/4294 || Iter: 128835/429400 || Loc: 1.0259 Cla: 1.4422 Landm: 3.9362 || LR: 0.00100000 || Batchtime: 0.9421 s || ETA: 3 days, 6:39:15


  0%|                                          | 16/300580 [01:11<98:00:22,  1.17s/it]

Epoch:31/100 || Epochiter: 16/4294 || Iter: 128836/429400 || Loc: 1.7758 Cla: 2.4448 Landm: 5.3454 || LR: 0.00100000 || Batchtime: 0.9357 s || ETA: 3 days, 6:07:21


  0%|                                          | 17/300580 [01:12<96:16:24,  1.15s/it]

Epoch:31/100 || Epochiter: 17/4294 || Iter: 128837/429400 || Loc: 1.8962 Cla: 2.1159 Landm: 11.4943 || LR: 0.00100000 || Batchtime: 0.9402 s || ETA: 3 days, 6:29:41


  0%|                                          | 18/300580 [01:13<94:55:59,  1.14s/it]

Epoch:31/100 || Epochiter: 18/4294 || Iter: 128838/429400 || Loc: 0.7696 Cla: 1.3196 Landm: 6.1080 || LR: 0.00100000 || Batchtime: 0.9362 s || ETA: 3 days, 6:09:54


  0%|                                          | 19/300580 [01:14<93:56:05,  1.13s/it]

Epoch:31/100 || Epochiter: 19/4294 || Iter: 128839/429400 || Loc: 2.4152 Cla: 2.3934 Landm: 7.2540 || LR: 0.00100000 || Batchtime: 0.9348 s || ETA: 3 days, 6:02:59


  0%|                                          | 20/300580 [01:15<93:16:20,  1.12s/it]

Epoch:31/100 || Epochiter: 20/4294 || Iter: 128840/429400 || Loc: 1.6214 Cla: 2.1546 Landm: 6.0218 || LR: 0.00100000 || Batchtime: 0.9352 s || ETA: 3 days, 6:04:52


  0%|                                          | 21/300580 [01:16<92:57:58,  1.11s/it]

Epoch:31/100 || Epochiter: 21/4294 || Iter: 128841/429400 || Loc: 2.1235 Cla: 2.3811 Landm: 1.0993 || LR: 0.00100000 || Batchtime: 0.9439 s || ETA: 3 days, 6:48:12


  0%|                                          | 22/300580 [01:17<92:38:46,  1.11s/it]

Epoch:31/100 || Epochiter: 22/4294 || Iter: 128842/429400 || Loc: 0.4241 Cla: 0.7360 Landm: 0.8467 || LR: 0.00100000 || Batchtime: 0.9385 s || ETA: 3 days, 6:21:26


  0%|                                          | 23/300580 [01:18<92:18:02,  1.11s/it]

Epoch:31/100 || Epochiter: 23/4294 || Iter: 128843/429400 || Loc: 0.4441 Cla: 0.6940 Landm: 3.7414 || LR: 0.00100000 || Batchtime: 0.9352 s || ETA: 3 days, 6:04:29


  0%|                                          | 24/300580 [01:19<92:06:00,  1.10s/it]

Epoch:31/100 || Epochiter: 24/4294 || Iter: 128844/429400 || Loc: 0.9251 Cla: 1.4658 Landm: 3.8233 || LR: 0.00100000 || Batchtime: 0.9357 s || ETA: 3 days, 6:07:26


  0%|                                          | 25/300580 [01:21<92:01:11,  1.10s/it]

Epoch:31/100 || Epochiter: 25/4294 || Iter: 128845/429400 || Loc: 1.4391 Cla: 1.6673 Landm: 7.0280 || LR: 0.00100000 || Batchtime: 0.9384 s || ETA: 3 days, 6:20:41


  0%|                                          | 26/300580 [01:22<91:31:33,  1.10s/it]

Epoch:31/100 || Epochiter: 26/4294 || Iter: 128846/429400 || Loc: 0.6804 Cla: 0.8311 Landm: 2.1367 || LR: 0.00100000 || Batchtime: 0.9193 s || ETA: 3 days, 4:45:03


  0%|                                          | 27/300580 [01:23<91:58:05,  1.10s/it]

Epoch:31/100 || Epochiter: 27/4294 || Iter: 128847/429400 || Loc: 1.4481 Cla: 1.9116 Landm: 2.7483 || LR: 0.00100000 || Batchtime: 0.9503 s || ETA: 3 days, 7:20:11


  0%|                                          | 28/300580 [01:24<91:28:55,  1.10s/it]

Epoch:31/100 || Epochiter: 28/4294 || Iter: 128848/429400 || Loc: 0.7083 Cla: 0.9838 Landm: 2.2742 || LR: 0.00100000 || Batchtime: 0.9205 s || ETA: 3 days, 4:51:13


  0%|                                          | 29/300580 [01:25<91:36:04,  1.10s/it]

Epoch:31/100 || Epochiter: 29/4294 || Iter: 128849/429400 || Loc: 1.2144 Cla: 1.1069 Landm: 2.2815 || LR: 0.00100000 || Batchtime: 0.9380 s || ETA: 3 days, 6:18:22


  0%|                                          | 30/300580 [01:26<91:17:22,  1.09s/it]

Epoch:31/100 || Epochiter: 30/4294 || Iter: 128850/429400 || Loc: 3.7664 Cla: 2.9580 Landm: 13.0194 || LR: 0.00100000 || Batchtime: 0.9235 s || ETA: 3 days, 5:05:46


  0%|                                          | 31/300580 [01:27<91:34:28,  1.10s/it]

Epoch:31/100 || Epochiter: 31/4294 || Iter: 128851/429400 || Loc: 2.7841 Cla: 2.5545 Landm: 8.2937 || LR: 0.00100000 || Batchtime: 0.9411 s || ETA: 3 days, 6:34:03


  0%|                                          | 32/300580 [01:28<91:32:29,  1.10s/it]

Epoch:31/100 || Epochiter: 32/4294 || Iter: 128852/429400 || Loc: 0.5250 Cla: 0.9921 Landm: 2.9204 || LR: 0.00100000 || Batchtime: 0.9344 s || ETA: 3 days, 6:00:23


  0%|                                          | 33/300580 [01:29<91:46:19,  1.10s/it]

Epoch:31/100 || Epochiter: 33/4294 || Iter: 128853/429400 || Loc: 1.5507 Cla: 1.8020 Landm: 5.0400 || LR: 0.00100000 || Batchtime: 0.9451 s || ETA: 3 days, 6:54:06


  0%|                                          | 34/300580 [01:30<91:50:05,  1.10s/it]

Epoch:31/100 || Epochiter: 34/4294 || Iter: 128854/429400 || Loc: 1.5599 Cla: 1.9446 Landm: 5.0650 || LR: 0.00100000 || Batchtime: 0.9377 s || ETA: 3 days, 6:17:15


  0%|                                          | 35/300580 [01:31<91:50:15,  1.10s/it]

Epoch:31/100 || Epochiter: 35/4294 || Iter: 128855/429400 || Loc: 1.5798 Cla: 2.1716 Landm: 8.4387 || LR: 0.00100000 || Batchtime: 0.9384 s || ETA: 3 days, 6:20:43


  0%|                                          | 36/300580 [01:33<91:53:38,  1.10s/it]

Epoch:31/100 || Epochiter: 36/4294 || Iter: 128856/429400 || Loc: 3.3307 Cla: 2.9028 Landm: 11.0731 || LR: 0.00100000 || Batchtime: 0.9389 s || ETA: 3 days, 6:22:51


  0%|                                          | 37/300580 [01:34<91:59:07,  1.10s/it]

Epoch:31/100 || Epochiter: 37/4294 || Iter: 128857/429400 || Loc: 3.3168 Cla: 3.1650 Landm: 15.5418 || LR: 0.00100000 || Batchtime: 0.9410 s || ETA: 3 days, 6:33:26


  0%|                                          | 38/300580 [01:35<92:03:20,  1.10s/it]

Epoch:31/100 || Epochiter: 38/4294 || Iter: 128858/429400 || Loc: 0.7185 Cla: 1.3126 Landm: 2.5378 || LR: 0.00100000 || Batchtime: 0.9420 s || ETA: 3 days, 6:38:33


  0%|                                          | 39/300580 [01:36<92:10:10,  1.10s/it]

Epoch:31/100 || Epochiter: 39/4294 || Iter: 128859/429400 || Loc: 0.5231 Cla: 0.7099 Landm: 4.7102 || LR: 0.00100000 || Batchtime: 0.9437 s || ETA: 3 days, 6:46:51


  0%|                                          | 40/300580 [01:37<92:02:07,  1.10s/it]

Epoch:31/100 || Epochiter: 40/4294 || Iter: 128860/429400 || Loc: 1.3019 Cla: 2.2624 Landm: 6.4614 || LR: 0.00100000 || Batchtime: 0.9372 s || ETA: 3 days, 6:14:17


  0%|                                          | 41/300580 [01:38<92:02:58,  1.10s/it]

Epoch:31/100 || Epochiter: 41/4294 || Iter: 128861/429400 || Loc: 2.0326 Cla: 2.1800 Landm: 10.6178 || LR: 0.00100000 || Batchtime: 0.9427 s || ETA: 3 days, 6:41:46


  0%|                                          | 42/300580 [01:39<91:57:42,  1.10s/it]

Epoch:31/100 || Epochiter: 42/4294 || Iter: 128862/429400 || Loc: 2.5018 Cla: 2.2590 Landm: 7.4482 || LR: 0.00100000 || Batchtime: 0.9387 s || ETA: 3 days, 6:22:06


  0%|                                          | 43/300580 [01:40<92:02:11,  1.10s/it]

Epoch:31/100 || Epochiter: 43/4294 || Iter: 128863/429400 || Loc: 2.8020 Cla: 2.4779 Landm: 12.2708 || LR: 0.00100000 || Batchtime: 0.9442 s || ETA: 3 days, 6:49:24


  0%|                                          | 44/300580 [01:41<92:04:40,  1.10s/it]

Epoch:31/100 || Epochiter: 44/4294 || Iter: 128864/429400 || Loc: 2.7526 Cla: 2.6406 Landm: 12.5519 || LR: 0.00100000 || Batchtime: 0.9429 s || ETA: 3 days, 6:42:43


  0%|                                         | 44/300580 [01:42<194:40:17,  2.33s/it]


KeyboardInterrupt: 

In [None]:
# Print the module tree of whatever `net` currently refers to in the kernel.
print(net)

In [None]:
# NOTE(review): `fffffffffff` is not defined anywhere in the visible notebook, so
# this cell raises NameError -- presumably a deliberate barrier that stops
# "Run All" before the scratch cells below. Confirm the intent, then either
# delete the cell or replace it with an explicit, documented stop.
fffffffffff

In [None]:
# def get_intermediate_layers(model, layer_names):
#     outputs = {}
#     for name, module in model.named_modules():
#         if name in layer_names:
#             outputs[name] = module
#     return outputs

# Scratch experiment: build the ConvNeXt-Tiny backbone, list its top-level
# children, then trim its classifier head the same way RetinaFace does.
backbone = timm.models.convnext_tiny(pretrained=True)
for name, module in backbone.named_children():
    print(name)

# NOTE: this redefines cfg_convnext for the cells that follow (batch_size 6 here).
cfg_convnext = {
    'name': 'convnext_large',
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'variance': [0.1, 0.2],
    'clip': False,
    'loc_weight': 2.0,
    'gpu_train': True,
    'batch_size': 6,
    'ngpu': 1,
    'epoch': 100,
    'decay1': 70,
    'decay2': 90,
    'image_size': 840,
    'pretrain': True,
    'return_layers': {'stem': 1, 'stages':2, 'head' : 3},
    'in_channel': 256,
    'out_channel': 256
}

body = _utils.IntermediateLayerGetter(backbone, cfg_convnext['return_layers'])

# Names of the head layers to keep.
keep_layers = ['global_pool', 'norm']

# Select by the children's actual names. The previous zip(head_layers, keep_layers)
# paired modules with the keep-list itself, so its filter was always true and it
# simply kept the first two children, whatever they were.
head_layers = [
    layer
    for layer_name, layer in body['head'].named_children()
    if layer_name in keep_layers
]

# Put the trimmed head back as a Sequential.
body['head'] = nn.Sequential(*head_layers)

print(body)

In [None]:
import torch
# Environment report, one value per line: torch version, CUDA availability,
# the CUDA version torch was built with, and the cuDNN version.
print(
    torch.__version__,
    torch.cuda.is_available(),
    torch.version.cuda,
    torch.backends.cudnn.version(),
    sep="\n",
)

In [None]:
# Configuration consumed by RetinaFace / PriorBox / the training loop for the
# ConvNeXt variant (batch_size 6 in this copy).
cfg_convnext = dict(
    name='convnext_large',
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    clip=False,
    loc_weight=2.0,
    gpu_train=True,
    batch_size=6,
    ngpu=1,
    epoch=100,
    decay1=70,
    decay2=90,
    image_size=840,
    pretrain=True,
    return_layers=dict(stem=1, stages=2, head=3),
    in_channel=256,
    out_channel=256,
)

In [None]:
from thop import clever_format, profile

if __name__ == '__main__':
    # Build a fresh ConvNeXt-based detector and move it to the GPU when one is available.
    net = RetinaFace(cfg=cfg_convnext)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = net.to(device)

    # One random 840x840 RGB image is enough for counting operations.
    dummy_input = torch.randn(1, 3, 840, 840).to(device)

    # NOTE(review): the count returned by profile is doubled below -- presumably
    # converting multiply-accumulates to FLOPs; confirm against the thop docs.
    flops, params = profile(model, (dummy_input,), verbose=False)
    flops *= 2

    # Turn both numbers into strings formatted with "%.3f" before printing.
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))

In [None]:
# Report how much memory torch currently has allocated on the GPU.
import torch

# torch reports the amount in bytes.
mem_bytes = torch.cuda.memory_allocated()

# Convert bytes to gibibytes (1 GiB = 2**30 bytes).
mem_gb = mem_bytes / (1 << 30)
print(f"Currently allocated memory: {mem_gb} GB")
