In [None]:
from google.colab import drive
# Mount Google Drive so the shared project folder under /content/drive is reachable from this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import sys
PROJECT_DIR = '/content/drive/Shareddrives/EECS 545 Project/5DE/lane-detection-2019-howard'
# Make the project's own modules importable (presumably including `utils`, imported later).
# The membership guard keeps re-runs of this cell from appending duplicate sys.path entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
class UNetFactory(nn.Module):
    """
    Generic U-shaped network: encode, optionally pass through a bridge, then decode.

    The encoder emits skip tensors that the decoder concatenates back in, so the decoding
    stage can recover detail lost during downsampling. The bridge (if any) must not
    upsample or downsample.
    """
    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super(UNetFactory, self).__init__()
        # attribute order kept as encoder, bridge, decoder so state_dict ordering is unchanged
        self.encoder = UNetEncoder(encoder_blocks)
        self.bridge = bridge
        self.decoder = UNetDecoder(decoder_blocks)

    def forward(self, x):
        # the encoder returns [deepest_output, skip_0, skip_1, ...] as one flat list
        deepest, *skips = self.encoder(x)
        bottleneck = deepest if self.bridge is None else self.bridge(deepest)
        return self.decoder(bottleneck, skips)

class UNetEncoder(nn.Module):
    """
    UNet encoder: a chain of blocks whose intermediate outputs are cached as skips.

    Convention: the encoder is split into blocks at each downsampling step. The first block
    performs no downsampling; every later block contains exactly one downsampling op. The
    output of every block except the last (i.e. the feature map right before each
    downsampling) is kept as a skip for the decoder.
    """
    def __init__(self, blocks):
        super(UNetEncoder, self).__init__()
        assert len(blocks) > 0
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """
        Run all blocks and return ``[final_output, skip_0, ..., skip_{k-2}]``.

        Skips are in shallow-to-deep order. They are returned as one flat list, which was the
        only way the original author found to return several tensors. With a single block the
        result is just ``[final_output]``; the previous index-based loop raised
        UnboundLocalError in that case.
        """
        skips = []
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        return [self.blocks[-1](x)] + skips

class UNetDecoder(nn.Module):
    """
    UNet decoder: upsamples several times, concatenating the matching encoder skip after each upsample.

    Convention: the decoder is split into blocks at the upsampling boundaries. Every block
    except the last contains one upsampling op; the last block does none. Consequently every
    block except the first is fed the channel-wise concatenation ``[skip, x]``.
    """
    def __init__(self, blocks):
        super(UNetDecoder, self).__init__()
        assert len(blocks) > 1
        self.blocks = nn.ModuleList(blocks)

    def _center_crop(self, skip, x):
        """
        Center-crop ``skip`` and ``x`` (both unpacked as (N, C, H, W)) to their common
        (min H, min W); only the larger of the two actually shrinks.
        """
        _, _, skip_h, skip_w = skip.shape
        _, _, x_h, x_w = x.shape
        out_h = min(skip_h, x_h)
        out_w = min(skip_w, x_w)
        # offsets are 0 for the tensor that already has the target size
        skip_top, skip_left = (skip_h - out_h) // 2, (skip_w - out_w) // 2
        x_top, x_left = (x_h - out_h) // 2, (x_w - out_w) // 2
        cropped_skip = skip[:, :, skip_top:skip_top + out_h, skip_left:skip_left + out_w]
        cropped_x = x[:, :, x_top:x_top + out_h, x_left:x_left + out_w]
        return cropped_skip, cropped_x

    def forward(self, x, skips, reverse_skips=True):
        """
        Decode ``x`` using one skip per block after the first.

        ``skips`` arrive shallow-to-deep from UNetEncoder, so by default they are reversed
        so that the deepest skip meets the first upsampled map.
        """
        assert len(skips) == len(self.blocks) - 1
        ordered_skips = skips[::-1] if reverse_skips else skips
        x = self.blocks[0](x)
        for skip_idx, block in enumerate(self.blocks[1:]):
            skip, x = self._center_crop(ordered_skips[skip_idx], x)
            x = block(torch.cat([skip, x], dim=1))
        return x

def unet_convs(in_channels, out_channels, padding=0):
    """
    Two stacked (3x3 conv -> BatchNorm -> ReLU) stages, the most common unit in the UNet paper.

    The convs have no bias (BatchNorm follows). With the default ``padding=0`` each conv shrinks
    H and W by 2, matching the paper's unpadded convs; ``padding=1`` preserves spatial size.
    """
    layers = []
    for stage_in_channels in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)


def unet(in_channels, out_channels):
    """
    Build the UNet from the original paper (https://arxiv.org/abs/1505.04597).

    The network uses unpadded 3x3 convs and four 2x2 max-pool downsamplings (ceil_mode) from
    64 up to 512 channels. A 1024-channel bridge follows, then four 2x2 up-convs back to
    64 channels and a 1x1 conv head. UNetDecoder center-crops the skips before concatenation.
    """
    def down_pool():
        return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

    # encoder: plain conv stage, three (pool + convs) stages doubling the width, then a final pool
    encoder_blocks = [unet_convs(in_channels, 64)]
    for width in (64, 128, 256):
        encoder_blocks.append(nn.Sequential(down_pool(), unet_convs(width, width * 2)))
    encoder_blocks.append(down_pool())

    # bridge: two conv3x3 at the bottom, wrapped in a Sequential so state_dict keys are bridge.0.*
    bridge = nn.Sequential(unet_convs(512, 1024))

    # decoder: up-conv, three (convs + up-conv) stages halving the width, then convs + 1x1 head
    decoder_blocks = [nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)]
    for width in (512, 256, 128):
        decoder_blocks.append(nn.Sequential(
            unet_convs(width * 2, width),
            nn.ConvTranspose2d(width, width // 2, kernel_size=2, stride=2),
        ))
    decoder_blocks.append(nn.Sequential(
        unet_convs(128, 64),
        nn.Conv2d(64, out_channels, kernel_size=1),
    ))
    return UNetFactory(encoder_blocks, decoder_blocks, bridge)

def unet_resnet(resnet_type, in_channels, out_channels, pretrained=True):
    """
    UNet whose encoder is a torchvision ResNet/ResNeXt, with a decoder sized to match it.

    The encoder exposes six stages: the raw input (identity), conv1+bn1+relu, maxpool+layer1,
    layer2, layer3 and layer4. The decoder mirrors these stages with padded 3x3 convs and
    2x up-convs. The intent is for the output to keep the input's spatial size; presumably this
    holds for H and W that are multiples of 32 (TODO confirm for other sizes, where UNetDecoder
    center-crops the skips).

    @param resnet_type: 'resnet18' | 'resnet34' | 'resnet50' | 'resnet101' | 'resnet152' | 'resnext50_32x4d'
    @param in_channels: channels of the network input. Only used for the final skip concat;
        resnet.conv1 is not rebuilt, so inputs presumably must have 3 channels.
    @param out_channels: channels of the network output
    @param pretrained: forwarded to the torchvision builder
    @raise ValueError: for an unknown resnet_type
    """
    # output channels of layer1..layer4 for the two residual block families
    basic_block_channels = [64, 128, 256, 512]      # BasicBlock: resnet18 / resnet34
    bottleneck_channels = [256, 512, 1024, 2048]    # Bottleneck: resnet50/101/152, resnext50_32x4d
    zoo = torchvision.models.resnet
    backbones = {
        'resnet18': (zoo.resnet18, basic_block_channels),
        'resnet34': (zoo.resnet34, basic_block_channels),
        'resnet50': (zoo.resnet50, bottleneck_channels),
        'resnet101': (zoo.resnet101, bottleneck_channels),
        'resnet152': (zoo.resnet152, bottleneck_channels),
        'resnext50_32x4d': (zoo.resnext50_32x4d, bottleneck_channels),
    }
    if resnet_type not in backbones:
        raise ValueError("unexpected resnet_type")
    build_resnet, stage_channels = backbones[resnet_type]
    resnet = build_resnet(pretrained)
    # channels produced by each encoder block: raw input, conv1 (always 64), layer1..layer4
    encoder_out_channels = [in_channels, 64] + stage_channels

    encoder_blocks = [
        nn.Sequential(),                                      # raw input (identity)
        nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),  # conv1
        nn.Sequential(resnet.maxpool, resnet.layer1),         # conv2_x
        resnet.layer2,                                        # conv3_x
        resnet.layer3,                                        # conv4_x
        resnet.layer4,                                        # conv5_x
    ]
    bridge = None  # the deepest ResNet stage already acts as the bottleneck

    bottom_channels = encoder_out_channels[-1]
    width = bottom_channels // 2
    decoder_blocks = [nn.ConvTranspose2d(bottom_channels, width, kernel_size=2, stride=2)]  # up-conv2x2
    # deep -> shallow over the intermediate skips; the raw-input skip is consumed by the last block
    for skip_channels in reversed(encoder_out_channels[1:-1]):
        decoder_blocks.append(nn.Sequential(  # two padded conv3x3 on [skip, x], then up-conv2x2
            unet_convs(skip_channels + width, width, padding=1),
            nn.ConvTranspose2d(width, width // 2, kernel_size=2, stride=2),
        ))
        width //= 2
    decoder_blocks.append(nn.Sequential(  # two padded conv3x3 on [input, x], then conv1x1 head
        unet_convs(encoder_out_channels[0] + width, width, padding=1),
        nn.Conv2d(width, out_channels, kernel_size=1),
    ))

    return UNetFactory(encoder_blocks, decoder_blocks, bridge)

In [None]:
def encode_gray_label(labels):
    """
    Convert gray-level label values into class ids for binary segmentation.

    Gray values 200, 204 and 209 become class 1; every other value becomes class 0. That
    includes the values that formed classes 2-7 in the original 8-class scheme (201, 203,
    217, 210, 214, 220, 221, 222, 224, 225, 226, 205, 227, 250) and everything with
    ignoreInEval=True.

    @param labels: grayscale label image
    @return: array with the same shape and dtype as ``labels`` holding 0/1 class ids
    """
    foreground_gray_values = (200, 204, 209)
    encoded_labels = np.zeros_like(labels)  # everything starts as class 0
    for gray in foreground_gray_values:
        encoded_labels[labels == gray] = 1
    return encoded_labels
def train_data_generator(image_list, label_list, batch_size, out_size, height_crop_offset):
    """
    Endless generator of shuffled training batches.

    :param image_list: absolute paths of the input images
    :param label_list: absolute paths of the grayscale label images, aligned with image_list
    :param batch_size: number of samples per yielded batch
    :param out_size: output size (w, h), forwarded to crop_resize_data
    :param height_crop_offset: rows removed from the top of every image/label before resizing
    :yields: ``(images, labels, filenames)``, where:
        - images is a float32 tensor of shape (n, 3, h, w), RGB, scaled to [-1, 1];
        - labels is an int64 tensor of shape (n, h, w) holding class ids;
        - filenames is the list of image paths in the batch.
    :raises ValueError: if image_list is empty (previously this looped forever)
    :raises RuntimeError: if a full pass over the data yields no readable image/label pair
    """
    if len(image_list) == 0:
        raise ValueError('image_list is empty')
    indices = np.arange(0, len(image_list))
    out_images = []
    out_labels = []
    out_images_filename = []
    while True:  # loops forever; the caller decides how many batches to draw
        np.random.shuffle(indices)
        read_any = False
        for i in indices:
            try:
                image = cv2.imread(image_list[i])
                labels = cv2.imread(label_list[i], cv2.IMREAD_GRAYSCALE)
            except Exception:
                # e.g. a path that is not a string; skip this sample (best effort)
                continue
            # cv2.imread reports a missing or unreadable file by returning None, not by raising
            if image is None or labels is None:
                continue
            read_any = True
            image, labels = crop_resize_data(image, labels, out_size, height_crop_offset)
            labels = encode_gray_label(labels)

            # data augmentation would go here

            out_images.append(image)
            out_labels.append(labels)
            out_images_filename.append(image_list[i])
            if len(out_images) == batch_size:
                batch_images = np.array(out_images, dtype=np.float32)
                batch_labels = np.array(out_labels, dtype=np.int64)
                batch_images = batch_images[:, :, :, ::-1]          # BGR (cv2) -> RGB
                batch_images = batch_images.transpose(0, 3, 1, 2)   # (n, h, w, c) -> (n, c, h, w)
                batch_images = batch_images * 2 / 255 - 1           # normalize to [-1, 1]
                yield (torch.from_numpy(batch_images),
                       torch.from_numpy(batch_labels).long(),
                       out_images_filename)
                out_images = []
                out_labels = []
                out_images_filename = []
        if not read_any:
            raise RuntimeError('no readable image/label pair found in a full pass over the data')
def crop_resize_data(image, labels, out_size, height_crop_offset):
    """
    Drop the top ``height_crop_offset`` rows of the image (and labels), then resize to ``out_size``.

    @param image: image array with rows along axis 0
    @param labels: label array aligned with image, or None
    @param out_size: (w, h), passed to cv2.resize as dsize
    @param height_crop_offset: number of leading rows to discard
    @return: (resized image, resized labels or None)
    """
    resized_image = cv2.resize(image[height_crop_offset:], out_size, interpolation=cv2.INTER_LINEAR)
    if labels is None:
        return resized_image, None
    # nearest-neighbour so every pixel stays a valid class id (no blending between ids)
    resized_labels = cv2.resize(labels[height_crop_offset:], out_size, interpolation=cv2.INTER_NEAREST)
    return resized_image, resized_labels


In [None]:
"""
@description: Configure Class 
"""

from os.path import join as pjoin
from os.path import dirname, abspath
import torch

class ConfigTrain(object):
    """Training configuration: paths, device, network choice and optimisation settings."""
    # Directories
    PROJECT_ROOT = "/content/drive/Shareddrives/EECS 545 Project/5DE/lane-detection-2019-howard"
    DATA_LIST_ROOT = pjoin(PROJECT_ROOT, 'data_list')
    TRAIN_ROOT = "/content/drive/Shareddrives/EECS 545 Project/data"
    IMAGE_ROOT = pjoin(TRAIN_ROOT, 'Image_Data')
    LABEL_ROOT = pjoin(TRAIN_ROOT, 'Gray_Label')
    WEIGHTS_ROOT = pjoin(PROJECT_ROOT, 'binary_classification','weights')
    WEIGHTS_SAVE_ROOT = WEIGHTS_ROOT
    LOG_ROOT = pjoin(PROJECT_ROOT, 'logs')

    # Log file for suspicious training samples (presumably written when SUSPICIOUS_RATE triggers);
    # earlier variants are kept for reference.
    # LOG_SUSPICIOUS_FILES = pjoin(LOG_ROOT, 'suspicious_files.log')
    # LOG_SUSPICIOUS_FILES = pjoin(LOG_ROOT, 'suspicious_files_b1.log')
    # LOG_SUSPICIOUS_FILES = pjoin(LOG_ROOT, 'suspicious_files_b4.log')
    LOG_SUSPICIOUS_FILES = pjoin(LOG_ROOT, 'suspicious_files_b2.log')

    # Device. NOTE: the 'cuda:0' default below is always overwritten by the availability check,
    # which runs (and prints) once, when this class body is executed.
    DEVICE = 'cuda:0' 
    
    if torch.cuda.is_available():
      print("Using the GPU. You are good to go!")
      DEVICE = 'cuda'
    else:
      print("Using the CPU. Overall speed may be slowed down")
      DEVICE = 'cpu'
    
    # Network type (a name accepted by create_net)
    NET_NAME = 'unet_resnet101'
    #NET_NAME = 'resnext50_32x4d'
    # Network parameters
    NUM_CLASSES = 2  # binary: class 1 = gray values 200/204/209, class 0 = everything else (see encode_gray_label)
    IMAGE_SIZE = (768, 256)  # training image size as (width, height): passed to cv2.resize as dsize
    # IMAGE_SIZE = (1024, 384)  # alternative (width, height)
    #IMAGE_SIZE = (1536, 512)  # alternative (width, height)
    HEIGHT_CROP_OFFSET = 690  # number of rows cropped off the top of the original image
    # BATCH_SIZE = 8  # batch size
    # BATCH_SIZE = 1  # batch size
    # BATCH_SIZE = 4  # batch size
    BATCH_SIZE = 8  # batch size
    EPOCH_NUM = 8  # total number of epochs
    PRETRAIN = False # whether to load PRETRAINED_WEIGHTS before training
    EPOCH_BEGIN = 0  # epoch to resume from; 0 means train from scratch
    # NOTE(review): this checkpoint's directory name refers to a resnext50_32x4d model while NET_NAME
    # is unet_resnet101; it is only loaded when PRETRAIN is True — confirm before enabling.
    PRETRAINED_WEIGHTS = pjoin(WEIGHTS_ROOT, '1024x384_b4_unet_resnext50_32x4d', 'result_6.pt')
    BASE_LR = 0.001  # initial learning rate given to the Adam optimizer
    # One list of learning rates per epoch (indexed by epoch). Presumably utils.ajust_learning_rate
    # spreads each epoch's list across that epoch's iterations — TODO confirm in utils.
    LR_STRATEGY = [
        [0.001], # epoch 0
        [0.001], # epoch 1
        [0.001], # epoch 2
        [0.001, 0.0006, 0.0003, 0.0001, 0.0004, 0.0008, 0.001], # epoch 3
        [0.001, 0.0006, 0.0003, 0.0001, 0.0004, 0.0008, 0.001], # epoch 4
        [0.001, 0.0006, 0.0003, 0.0001, 0.0004, 0.0008, 0.001], # epoch 5
        [0.0004, 0.0003, 0.0002, 0.0001, 0.0002, 0.0003, 0.0004], # epoch 6
        [0.0004, 0.0003, 0.0002, 0.0001, 0.0002, 0.0003, 0.0004], # epoch 7
    ]
    # Suspicion ratio: when an iteration's mIoU falls below SUSPICIOUS_RATE * the current epoch mIoU,
    # that iteration's training sample indices are logged so the data can be inspected manually.
    SUSPICIOUS_RATE = 0.8

# TODO: define PROJECT_ROOT explicitly before re-enabling ConfigInference.
# NOTE(review): ConfigInference is disabled by wrapping it in a string literal. `__file__` is
# normally not defined inside a notebook, so `dirname(abspath(__file__))` would raise NameError
# if this were un-quoted as-is. Because the string is the last expression in the cell, Jupyter
# echoes it as the cell's output.
'''    
class ConfigInference(object):
    # 目录
    PROJECT_ROOT = dirname(abspath(__file__)) 
    DATA_ROOT = pjoin(PROJECT_ROOT, 'data')
    IMAGE_ROOT = pjoin(DATA_ROOT, 'TestImage')
    LABEL_ROOT = pjoin(DATA_ROOT, 'TestLabel')
    OVERLAY_ROOT = pjoin(DATA_ROOT, 'TestOverlay')
    WEIGHTS_ROOT = pjoin(PROJECT_ROOT, 'weights')
    PRETRAINED_WEIGHTS = pjoin(WEIGHTS_ROOT, '1024x384_b4_unet_resnext50_32x4d', 'resnext50_32x4d-7cdf4587.pth')
    LOG_ROOT = pjoin(PROJECT_ROOT, 'logs')

    # 设备
    DEVICE = 'cuda:0'

    # 网络类型
    NET_NAME = 'resnext50_32x4d'

    # 网络参数
    NUM_CLASSES = 8  # 8个类别
    # IMAGE_SIZE = (768, 256)  # 训练的图片的尺寸(h,w)
    IMAGE_SIZE = (1024, 384)  # 训练的图片的尺寸(h,w)
    # IMAGE_SIZE = (1536, 512)  # 训练的图片的尺寸(h,w)
    HEIGHT_CROP_OFFSET = 690  # 在height方向上将原图裁掉的offset
    BATCH_SIZE = 1  # 数据批次大小

    # 原图的大小
    IMAGE_SIZE_ORG = (3384, 1710)
'''    

Using the GPU. You are good to go!


"    \nclass ConfigInference(object):\n    # 目录\n    PROJECT_ROOT = dirname(abspath(__file__)) \n    DATA_ROOT = pjoin(PROJECT_ROOT, 'data')\n    IMAGE_ROOT = pjoin(DATA_ROOT, 'TestImage')\n    LABEL_ROOT = pjoin(DATA_ROOT, 'TestLabel')\n    OVERLAY_ROOT = pjoin(DATA_ROOT, 'TestOverlay')\n    WEIGHTS_ROOT = pjoin(PROJECT_ROOT, 'weights')\n    PRETRAINED_WEIGHTS = pjoin(WEIGHTS_ROOT, '1024x384_b4_unet_resnext50_32x4d', 'resnext50_32x4d-7cdf4587.pth')\n    LOG_ROOT = pjoin(PROJECT_ROOT, 'logs')\n\n    # 设备\n    DEVICE = 'cuda:0'\n\n    # 网络类型\n    NET_NAME = 'resnext50_32x4d'\n\n    # 网络参数\n    NUM_CLASSES = 8  # 8个类别\n    # IMAGE_SIZE = (768, 256)  # 训练的图片的尺寸(h,w)\n    IMAGE_SIZE = (1024, 384)  # 训练的图片的尺寸(h,w)\n    # IMAGE_SIZE = (1536, 512)  # 训练的图片的尺寸(h,w)\n    HEIGHT_CROP_OFFSET = 690  # 在height方向上将原图裁掉的offset\n    BATCH_SIZE = 1  # 数据批次大小\n\n    # 原图的大小\n    IMAGE_SIZE_ORG = (3384, 1710)\n"

In [None]:
'''
Define loss
'''
class MySoftmaxCrossEntropyLoss(nn.Module):
    """
    Mean softmax cross-entropy over every position of a dense prediction.

    Logits of shape (N, C, d1, d2, ...) are flattened to one row of ``nbclasses`` scores per
    position; 2-D logits (M, C) are used as-is. The target is flattened to match.
    """

    def __init__(self, nbclasses):
        super(MySoftmaxCrossEntropyLoss, self).__init__()
        self.nbclasses = nbclasses

    def forward(self, inputs, target):
        logits = inputs
        if logits.dim() > 2:
            # (N, C, d1, ...) -> (N, C, L) -> (N, L, C) -> (N*L, C)
            per_sample = logits.view(logits.size(0), logits.size(1), -1)
            logits = per_sample.transpose(1, 2).reshape(-1, self.nbclasses)
        flat_target = target.view(-1)
        return F.cross_entropy(logits, flat_target, reduction="mean")


class FocalLoss(nn.Module):
    """
    Focal loss for multi-class logits (Lin et al., "Focal Loss for Dense Object Detection").

    Per element: ``loss = -(1 - p_t) ** gamma * alpha_t * log(p_t)``, where ``p_t`` is the
    softmax probability of the target class and ``alpha_t`` is the weight of the target class.

    Args:
        gamma: focusing exponent; 0 gives (optionally alpha-weighted) cross entropy.
        alpha: one of:
            - None, for no class weighting (the default; it used to crash here);
            - a scalar ``a``, giving two-class weights ``[a, 1 - a]`` (class 0 gets ``a``);
            - a sequence or tensor with one weight per class.
        size_average: True -> mean over all elements, False -> sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if alpha is None:
            # building [alpha, 1 - alpha] from None would raise TypeError
            self.alpha = None
        else:
            alpha_tensor = torch.as_tensor(alpha, dtype=torch.float32)
            if alpha_tensor.dim() == 0:
                alpha_tensor = torch.stack([alpha_tensor, 1 - alpha_tensor])
            self.alpha = alpha_tensor
        self.size_average = size_average

    def forward(self, inputs, target):
        if inputs.dim() > 2:
            inputs = inputs.view(inputs.size(0), inputs.size(1), -1)  # N,C,H,W => N,C,H*W
            inputs = inputs.transpose(1, 2)  # N,C,H*W => N,H*W,C
            inputs = inputs.contiguous().view(-1, inputs.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(inputs, dim=1).gather(1, target).view(-1)  # log p_t per element
        pt = logpt.exp()  # taken before alpha scaling so the modulating factor uses the true p_t

        if self.alpha is not None:
            # alpha is a plain attribute (not a registered buffer), so match inputs' device/dtype here
            alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            logpt = logpt * alpha.gather(0, target.view(-1))
        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()


def make_one_hot(input, num_classes):
    """
    Convert a class-index tensor into a one-hot float tensor.

    Args:
        input: integer tensor of shape [N, 1, *]
        num_classes: number of classes
    Returns:
        float32 tensor of shape [N, num_classes, *], always allocated on the CPU
    """
    one_hot_shape = list(input.shape)
    one_hot_shape[1] = num_classes
    cpu_index = input.cpu()  # the result is built on the CPU, so the index must be there too
    return torch.zeros(one_hot_shape).scatter_(1, cpu_index, 1)


class BinaryDiceLoss(nn.Module):
    """Dice loss for a binary (single-channel) prediction.

    Per sample: ``1 - (2 * sum(predict * target) + smooth) / (sum(predict^p + target^p) + smooth)``.

    Args:
        smooth: float added to numerator and denominator to smooth the loss and avoid NaN, default: 1
        p: exponent used in the denominator sum(predict^p) + sum(target^p), default: 2
        reduction: 'mean' returns the mean over the batch, 'sum' the sum,
            'none' a tensor of shape [N,]
    Inputs:
        predict: a tensor of shape [N, *]
        target: a tensor with the same batch size and (per sample) the same number of
            elements as predict — typically the same shape
    Returns:
        Loss tensor according to ``reduction``
    Raises:
        Exception if ``reduction`` is not 'mean', 'sum' or 'none'
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # flatten everything but the batch dimension
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)
        num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den  # shape [N,]

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Multi-class Dice loss; the target must be one-hot encoded.

    Softmax is applied over dim 1 of ``predict``. Then BinaryDiceLoss is computed per class,
    optionally weighted, and summed. The sum is divided by the total number of classes,
    including an ignored one.

    Args:
        weight: optional per-class weights, any array-like of shape [num_classes,]
        ignore_index: class index to skip
        predict: A tensor of shape [N, C, *] (raw scores)
        target: A tensor of same shape with predict (one-hot)
        other kwargs are passed to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'
        dice = BinaryDiceLoss(**self.kwargs)
        num_classes = target.shape[1]

        weight = None
        if self.weight is not None:
            # validate once, up front (the old code read a non-existent `self.weights` attribute)
            weight = torch.as_tensor(self.weight)
            assert weight.shape[0] == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, weight.shape[0])

        predict = F.softmax(predict, dim=1)
        total_loss = 0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if weight is not None:
                dice_loss = dice_loss * weight[i]
            total_loss += dice_loss

        return total_loss/num_classes



In [None]:

"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse as ifilterfalse


def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension of the Jaccard loss w.r.t. the sorted errors.

    See Alg. 1 of Berman et al. 2018. ``gt_sorted`` is a 1-D binary ground-truth tensor,
    ordered by decreasing error. Returns a float tensor of the same length.
    """
    n = len(gt_sorted)
    total_fg = gt_sorted.sum()
    intersection = total_fg - gt_sorted.float().cumsum(0)
    union = total_fg + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if n > 1:  # a single pixel has nothing to difference against
        # first entry kept, the rest become successive differences
        jaccard = torch.cat((jaccard[:1], jaccard[1:] - jaccard[:-1]))
    return jaccard


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU (in percent) of the foreground class; binary: 1 foreground, 0 background.

    With per_image=True the IoU is averaged over images, otherwise computed over the whole
    batch. Pixels labelled ``ignore`` do not count toward false positives. An image whose
    union is empty scores EMPTY.
    """
    pairs = zip(preds, labels) if per_image else [(preds, labels)]
    ious = []
    for pred, label in pairs:
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        ious.append(float(intersection) / float(union) if union else EMPTY)
    return 100 * mean(ious)  # mean across images if per_image


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of per-class IoUs (in percent) for every class in range(C) other than ``ignore``.

    With per_image=True each class IoU is averaged over images. A class whose union is
    empty scores EMPTY.
    """
    pairs = zip(preds, labels) if per_image else [(preds, labels)]
    per_pair_ious = []
    for pred, label in pairs:
        class_ious = []
        for cls in range(C):
            if cls == ignore:  # the ignored label may still appear among predictions (ENet - CityScapes)
                continue
            intersection = ((label == cls) & (pred == cls)).sum()
            union = ((label == cls) | ((pred == cls) & (label != ignore))).sum()
            class_ious.append(float(intersection) / float(union) if union else EMPTY)
        per_pair_ious.append(class_ious)
    # transpose to per-class columns, then average across images
    ious = [mean(column) for column in zip(*per_pair_ious)]
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss.

      logits: [B, H, W] logits at each pixel (any real value)
      labels: [B, H, W] binary ground-truth masks (0 or 1)
      per_image: compute the loss per image and average, instead of over the whole batch
      ignore: void class id
    """
    if not per_image:
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return mean(
        lovasz_hinge_flat(*flatten_binary_scores(img_logits.unsqueeze(0), img_labels.unsqueeze(0), ignore))
        for img_logits, img_labels in zip(logits, labels)
    )


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened predictions.

      logits: [P] logits at each prediction (any real value)
      labels: [P] binary ground-truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels: return a zero that keeps the graph, so the gradients are 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.  # {0, 1} -> {-1, +1}
    errors = 1. - logits * Variable(signs)
    errors_sorted, order = torch.sort(errors, dim=0, descending=True)
    grad = lovasz_grad(labels[order.data])
    return torch.dot(F.relu(errors_sorted), Variable(grad))


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten batch predictions and labels to 1-D (binary case), dropping pixels labelled ``ignore``.
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is None:
        return flat_scores, flat_labels
    keep = flat_labels != ignore
    return flat_scores[keep], flat_labels[keep]


class StableBCELoss(torch.nn.modules.Module):
    """Numerically stable binary cross entropy on logits.

    Returns the mean of ``max(x, 0) - x * y + log(1 + exp(-|x|))``.
    """
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        # log(1 + exp(-|x|)) never overflows, unlike log(1 + exp(-x))
        softplus_term = (1 + (-input.abs()).exp()).log()
        per_element = input.clamp(min=0) - input * target + softplus_term
        return per_element.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross-entropy loss on logits, skipping pixels labelled ``ignore``.

      logits: [B, H, W] logits at each pixel (any real value)
      labels: [B, H, W] binary ground-truth masks (0 or 1)
      ignore: void class id
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    return StableBCELoss()(flat_logits, Variable(flat_labels.float()))


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss.

      probas: [B, C, H, W] class probabilities at each prediction (between 0 and 1).
              A [B, H, W] input is interpreted as a binary (sigmoid) output.
      labels: [B, H, W] ground-truth labels (between 0 and C - 1)
      classes: 'all', 'present' (classes present in labels), or a list of classes to average
      per_image: compute the loss per image and average, instead of over the whole batch
      ignore: void class label
    """
    if not per_image:
        return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    per_image_losses = (
        lovasz_softmax_flat(*flatten_probas(img_probas.unsqueeze(0), img_labels.unsqueeze(0), ignore),
                            classes=classes)
        for img_probas, img_labels in zip(probas, labels)
    )
    return mean(per_image_losses)


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss on flattened predictions.

      probas: [P, C] class probabilities at each prediction (between 0 and 1)
      labels: [P] ground-truth labels (between 0 and C - 1)
      classes: 'all', 'present' (classes present in labels), or a list of classes to average
    Returns the average of the per-class losses via ``mean`` (0 when no class contributes),
    or a 0-d zero tensor when there are no predictions at all.
    """
    if probas.numel() == 0:
        # only void pixels: the gradients should be 0. Reduce to a 0-d tensor so it can be
        # averaged/accumulated alongside ordinary scalar losses (a [0, C] tensor cannot).
        return probas.sum() * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # compare by value: the old `classes is 'present'` relied on string interning (SyntaxWarning)
    only_present = isinstance(classes, str) and classes == 'present'
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground mask for class c
        if only_present and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flatten batch predictions to [P, C] and labels to [P], dropping pixels labelled ``ignore``.

    A [B, H, W] input is treated as a sigmoid output and given a singleton class axis.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing keeps the [P', C] shape even when exactly one pixel is valid.
    # nonzero().squeeze() would collapse to a 0-d index and return a bare [C] row.
    return probas[valid], labels[valid]

def xloss(logits, labels, ignore=None):
    """
    Cross-entropy loss that skips pixels labelled ``ignore``.

    When ``ignore`` is None, label 255 is skipped (the value that used to be hard-coded), so
    existing callers keep their behaviour. Previously an explicit ``ignore`` was silently unused.
    """
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)


# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    """NaN test that works for floats and (elementwise) tensors: NaN is the only value unequal to itself."""
    not_equal_to_itself = x != x
    return not_equal_to_itself
    
    
def mean(l, ignore_nan=False, empty=0):
    """
    Average of an iterable (works with generators); a nanmean when ``ignore_nan`` is True.

    Returns ``empty`` for an empty input, or raises ValueError('Empty mean') when
    ``empty == 'raise'``. A single element is returned unchanged (not divided). The running
    sum uses ``+=``, so for tensors/arrays it accumulates in place into the first element.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    missing = object()
    total = next(values, missing)
    if total is missing:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    count = 1
    for value in values:
        total += value
        count += 1
    return total if count == 1 else total / count

In [None]:
'''
util files
'''
def create_net(in_channels, out_channels, net_name='unet'):
    """
    Create a segmentation network by name.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param net_name: one of 'unet', 'unet_resnet18', 'unet_resnet34', 'unet_resnet50',
        'unet_resnet101', 'unet_resnet152', 'unet_resnext50_32x4d', or the legacy alias
        'resnext50_32x4d'
    :raises ValueError: for an unsupported net_name
    """
    # net_name -> resnet_type understood by unet_resnet
    resnet_backbones = {
        'unet_resnet18': 'resnet18',
        'unet_resnet34': 'resnet34',
        'unet_resnet50': 'resnet50',
        'unet_resnet101': 'resnet101',
        'unet_resnet152': 'resnet152',
        'unet_resnext50_32x4d': 'resnext50_32x4d',
        'resnext50_32x4d': 'resnext50_32x4d',  # legacy name kept for existing configs
    }
    if net_name == 'unet':
        return unet(in_channels, out_channels)
    if net_name in resnet_backbones:
        return unet_resnet(resnet_backbones[net_name], in_channels, out_channels)
    raise ValueError('Not supported net_name: {}'.format(net_name))

def create_loss(predicts: torch.Tensor, labels: torch.Tensor, num_classes):
    """
    Cross-entropy training loss plus mean IoU for a batch of logits.

    @param predicts: logits, shape=(n, c, h, w)
    @param labels: class ids, shape=(n, h, w) or shape=(n, 1, h, w)
    @param num_classes: int, must equal the channel count c of predicts
    @return: (loss, mean_iou). loss is the mean cross entropy over all pixels; mean_iou is
        the average of the per-class IoUs from compute_iou.
    """
    # (n, c, h, w) -> (n, h, w, c) -> (n*h*w, c): one row of class scores per pixel
    flat_predicts = predicts.permute((0, 2, 3, 1)).reshape((-1, num_classes))
    # F.cross_entropy applies log-softmax internally
    loss = F.cross_entropy(flat_predicts, labels.flatten(), reduction='mean')
    # NOTE: a Dice term (DiceLoss on make_one_hot labels) used to be computed here but was never
    # added to the loss. It was dropped to avoid a per-iteration one-hot scatter on the CPU and
    # an import-time dependency on an external `utils` module. For BCE + Dice, compute it with
    # make_one_hot / DiceLoss and add it to `loss`.
    ious = compute_iou(flat_predicts, labels.reshape((-1, 1)), num_classes)
    return loss, torch.mean(ious)

def compute_iou(predicts, labels, num_classes):
    """
    Per-class IoU of the arg-max prediction.

    @param predicts: class scores, shape=(-1, classes)
    @param labels: class ids, shape=(-1, 1)
    @return: float tensor of shape (num_classes,). A class absent from both predictions and
        labels gets 0 (the 1e-6 term only guards the division).
    """
    predicted_ids = torch.argmax(F.softmax(predicts, dim=1), dim=1, keepdim=True)
    ious = torch.zeros(num_classes)
    for cls in range(num_classes):
        pred_mask = predicted_ids == cls
        label_mask = labels == cls
        intersect = torch.sum(pred_mask * label_mask)
        area = torch.sum(pred_mask) + torch.sum(label_mask) - intersect
        ious[cls] = intersect / (area + 1e-6)
    return ious


In [None]:
"""
@description: 执行训练
"""


"""
import
"""
#from config import ConfigTrain
import utils
from os.path import join as pjoin
import pandas as pd
import numpy as np
import cv2
import torch
import time

"""
main
"""
from tqdm import tqdm
if __name__ == '__main__':
    # NOTE(review): ConfigTrain is not imported in this cell (the import above is
    # commented out); it must come from an earlier cell — confirm for Restart & Run All.
    cfg = ConfigTrain()
    print('Pick device: ', cfg.DEVICE)
    device = torch.device(cfg.DEVICE)

    # Network
    print('Generating net: ', cfg.NET_NAME)
    net = create_net(3, cfg.NUM_CLASSES, net_name=cfg.NET_NAME)
    if cfg.PRETRAIN:  # load pretrained weights
        print('Load pretrain weights: ', cfg.PRETRAINED_WEIGHTS)
        net.load_state_dict(torch.load(cfg.PRETRAINED_WEIGHTS, map_location='cpu'))
    net.to(device)
    net.train()  # be explicit: dropout / batch-norm must be in training mode
    # Optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=cfg.BASE_LR)

    # Training data generator
    print('Preparing data... batch_size: {}, image_size: {}, crop_offset: {}'.format(cfg.BATCH_SIZE, cfg.IMAGE_SIZE, cfg.HEIGHT_CROP_OFFSET))
    df_train = pd.read_csv(pjoin(cfg.DATA_LIST_ROOT, 'train_split.csv'))
    data_generator = train_data_generator(np.array(df_train['image']),
                                          np.array(df_train['label']),
                                          cfg.BATCH_SIZE, cfg.IMAGE_SIZE, cfg.HEIGHT_CROP_OFFSET)

    # Training
    print('Let us train ...')
    log_iters = 1  # print a log line every `log_iters` iterations
    epoch_size = len(df_train) // cfg.BATCH_SIZE  # iterations per epoch (last partial batch dropped)
    loss_plot = []  # per-iteration total loss, accumulated across all epochs
    iou_plot = []   # per-iteration mean IoU, accumulated across all epochs
    for epoch in range(cfg.EPOCH_BEGIN, cfg.EPOCH_NUM):
        epoch_loss = 0.0
        epoch_miou = 0.0
        prev_time = time.time()
        for iteration in tqdm(range(1, epoch_size + 1)):
            images, labels, images_filename = next(data_generator)
            images = images.to(device)
            labels = labels.to(device)

            lr = utils.ajust_learning_rate(optimizer, cfg.LR_STRATEGY, epoch, iteration - 1, epoch_size)

            optimizer.zero_grad()
            predicts = net(images)

            # NOTE(review): this calls utils.create_loss, not the create_loss defined in
            # this notebook; the utils version prints tensor shapes on every call (see the
            # captured output). Confirm which one is intended.
            cross_loss, mean_iou = utils.create_loss(predicts, labels, cfg.NUM_CLASSES)
            # Lovasz-softmax is fed class probabilities rather than raw logits
            probs = torch.nn.functional.softmax(predicts, dim=1)
            loss_lovasz_softmax = utils.lovasz_softmax(probs, labels)
            loss = cross_loss + loss_lovasz_softmax

            loss.backward()
            optimizer.step()

            # Pull the scalars off the device once per iteration instead of repeatedly
            loss_value = loss.item()
            iou_value = mean_iou.item()
            epoch_loss += loss_value
            epoch_miou += iou_value
            loss_plot.append(loss_value)
            iou_plot.append(iou_value)

            if iteration % log_iters == 0:
                # tqdm.write keeps the progress bar intact instead of interleaving with print()
                tqdm.write("[Epoch-%d Iter-%d] LR: %.4f: iter loss: %.3f, iter iou: %.3f, epoch loss: %.3f, epoch iou: %.3f,  time cost: %.3f s"
                           % (epoch, iteration, lr, loss_value, iou_value, epoch_loss / iteration, epoch_miou / iteration, time.time() - prev_time))
            prev_time = time.time()

            # TODO: log `images_filename` of batches whose IoU drops well below the running
            # epoch mean (cfg.SUSPICIOUS_RATE / cfg.LOG_SUSPICIOUS_FILES) once a log file exists.

        torch.save(net.state_dict(),
                   pjoin(cfg.WEIGHTS_SAVE_ROOT, "weights_ep_%d_%.3f_%.3f.pth"
                         % (epoch, epoch_loss / epoch_size, epoch_miou / epoch_size)))
    



Pick device:  cuda
Generating net:  unet_resnet101
Preparing data... batch_size: 4, image_size: (768, 256), crop_offset: 690
Let us train ...


  0%|          | 0/1874 [00:00<?, ?it/s]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-1] LR: 0.0010: iter loss: 1.119, iter iou: 0.402, epoch loss: 1.119, epoch iou: 0.402,  time cost: 8.679 s


  0%|          | 1/1874 [00:10<5:36:13, 10.77s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-2] LR: 0.0010: iter loss: 1.169, iter iou: 0.353, epoch loss: 1.144, epoch iou: 0.378,  time cost: 9.140 s


  0%|          | 2/1874 [00:19<5:03:34,  9.73s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-3] LR: 0.0010: iter loss: 1.006, iter iou: 0.460, epoch loss: 1.098, epoch iou: 0.405,  time cost: 9.001 s


  0%|          | 3/1874 [00:28<4:53:07,  9.40s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-4] LR: 0.0010: iter loss: 0.970, iter iou: 0.478, epoch loss: 1.066, epoch iou: 0.423,  time cost: 16.074 s


  0%|          | 4/1874 [00:44<6:15:03, 12.03s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-5] LR: 0.0010: iter loss: 0.902, iter iou: 0.499, epoch loss: 1.033, epoch iou: 0.438,  time cost: 11.890 s


  0%|          | 5/1874 [00:56<6:13:13, 11.98s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-6] LR: 0.0010: iter loss: 0.823, iter iou: 0.623, epoch loss: 0.998, epoch iou: 0.469,  time cost: 11.518 s


  0%|          | 6/1874 [01:08<6:08:11, 11.83s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-7] LR: 0.0010: iter loss: 0.808, iter iou: 0.536, epoch loss: 0.971, epoch iou: 0.479,  time cost: 10.701 s


  0%|          | 7/1874 [01:18<5:56:34, 11.46s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-8] LR: 0.0010: iter loss: 0.787, iter iou: 0.529, epoch loss: 0.948, epoch iou: 0.485,  time cost: 10.317 s


  0%|          | 8/1874 [01:29<5:45:04, 11.10s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-9] LR: 0.0010: iter loss: 0.723, iter iou: 0.591, epoch loss: 0.923, epoch iou: 0.497,  time cost: 12.981 s


  0%|          | 9/1874 [01:42<6:03:15, 11.69s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-10] LR: 0.0010: iter loss: 0.722, iter iou: 0.496, epoch loss: 0.903, epoch iou: 0.497,  time cost: 13.317 s


  1%|          | 10/1874 [01:55<6:18:46, 12.19s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-11] LR: 0.0010: iter loss: 0.669, iter iou: 0.620, epoch loss: 0.882, epoch iou: 0.508,  time cost: 11.333 s


  1%|          | 11/1874 [02:06<6:10:22, 11.93s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-12] LR: 0.0010: iter loss: 0.646, iter iou: 0.653, epoch loss: 0.862, epoch iou: 0.520,  time cost: 9.384 s


  1%|          | 12/1874 [02:16<5:46:09, 11.15s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-13] LR: 0.0010: iter loss: 0.632, iter iou: 0.650, epoch loss: 0.844, epoch iou: 0.530,  time cost: 11.089 s


  1%|          | 13/1874 [02:27<5:45:20, 11.13s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-14] LR: 0.0010: iter loss: 0.597, iter iou: 0.621, epoch loss: 0.827, epoch iou: 0.537,  time cost: 10.533 s


  1%|          | 14/1874 [02:37<5:39:30, 10.95s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-15] LR: 0.0010: iter loss: 0.620, iter iou: 0.553, epoch loss: 0.813, epoch iou: 0.538,  time cost: 12.143 s


  1%|          | 15/1874 [02:50<5:50:31, 11.31s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-16] LR: 0.0010: iter loss: 0.594, iter iou: 0.608, epoch loss: 0.799, epoch iou: 0.542,  time cost: 10.393 s


  1%|          | 16/1874 [03:00<5:41:43, 11.04s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-17] LR: 0.0010: iter loss: 0.561, iter iou: 0.657, epoch loss: 0.785, epoch iou: 0.549,  time cost: 11.494 s


  1%|          | 17/1874 [03:11<5:45:45, 11.17s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-18] LR: 0.0010: iter loss: 0.562, iter iou: 0.676, epoch loss: 0.773, epoch iou: 0.556,  time cost: 13.053 s


  1%|          | 18/1874 [03:25<6:03:07, 11.74s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-19] LR: 0.0010: iter loss: 0.618, iter iou: 0.530, epoch loss: 0.765, epoch iou: 0.554,  time cost: 11.032 s


  1%|          | 19/1874 [03:36<5:56:22, 11.53s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-20] LR: 0.0010: iter loss: 0.527, iter iou: 0.678, epoch loss: 0.753, epoch iou: 0.561,  time cost: 10.273 s


  1%|          | 20/1874 [03:46<5:44:28, 11.15s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-21] LR: 0.0010: iter loss: 0.581, iter iou: 0.542, epoch loss: 0.745, epoch iou: 0.560,  time cost: 9.168 s


  1%|          | 21/1874 [03:55<5:26:02, 10.56s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-22] LR: 0.0010: iter loss: 0.561, iter iou: 0.563, epoch loss: 0.736, epoch iou: 0.560,  time cost: 11.336 s


  1%|          | 22/1874 [04:06<5:33:04, 10.79s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-23] LR: 0.0010: iter loss: 0.513, iter iou: 0.634, epoch loss: 0.726, epoch iou: 0.563,  time cost: 10.871 s


  1%|          | 23/1874 [04:17<5:33:41, 10.82s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-24] LR: 0.0010: iter loss: 0.517, iter iou: 0.582, epoch loss: 0.718, epoch iou: 0.564,  time cost: 11.501 s


  1%|▏         | 24/1874 [04:29<5:39:47, 11.02s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-25] LR: 0.0010: iter loss: 0.503, iter iou: 0.559, epoch loss: 0.709, epoch iou: 0.564,  time cost: 11.348 s


  1%|▏         | 25/1874 [04:40<5:42:37, 11.12s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-26] LR: 0.0010: iter loss: 0.508, iter iou: 0.542, epoch loss: 0.701, epoch iou: 0.563,  time cost: 9.948 s


  1%|▏         | 26/1874 [04:50<5:31:36, 10.77s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-27] LR: 0.0010: iter loss: 0.507, iter iou: 0.631, epoch loss: 0.694, epoch iou: 0.565,  time cost: 10.828 s


  1%|▏         | 27/1874 [05:01<5:32:03, 10.79s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-28] LR: 0.0010: iter loss: 0.472, iter iou: 0.709, epoch loss: 0.686, epoch iou: 0.570,  time cost: 9.735 s


  1%|▏         | 28/1874 [05:11<5:22:10, 10.47s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-29] LR: 0.0010: iter loss: 0.466, iter iou: 0.632, epoch loss: 0.679, epoch iou: 0.573,  time cost: 9.566 s


  2%|▏         | 29/1874 [05:20<5:13:39, 10.20s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-30] LR: 0.0010: iter loss: 0.470, iter iou: 0.619, epoch loss: 0.672, epoch iou: 0.574,  time cost: 10.514 s


  2%|▏         | 30/1874 [05:31<5:16:18, 10.29s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-31] LR: 0.0010: iter loss: 0.453, iter iou: 0.716, epoch loss: 0.665, epoch iou: 0.579,  time cost: 9.744 s


  2%|▏         | 31/1874 [05:40<5:11:04, 10.13s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-32] LR: 0.0010: iter loss: 0.464, iter iou: 0.658, epoch loss: 0.658, epoch iou: 0.581,  time cost: 11.000 s


  2%|▏         | 32/1874 [05:51<5:18:57, 10.39s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-33] LR: 0.0010: iter loss: 0.482, iter iou: 0.570, epoch loss: 0.653, epoch iou: 0.581,  time cost: 12.034 s


  2%|▏         | 33/1874 [06:03<5:33:54, 10.88s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-34] LR: 0.0010: iter loss: 0.461, iter iou: 0.658, epoch loss: 0.647, epoch iou: 0.583,  time cost: 12.580 s


  2%|▏         | 34/1874 [06:16<5:49:24, 11.39s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-35] LR: 0.0010: iter loss: 0.432, iter iou: 0.655, epoch loss: 0.641, epoch iou: 0.585,  time cost: 10.151 s


  2%|▏         | 35/1874 [06:26<5:37:41, 11.02s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-36] LR: 0.0010: iter loss: 0.460, iter iou: 0.612, epoch loss: 0.636, epoch iou: 0.586,  time cost: 10.266 s


  2%|▏         | 36/1874 [06:36<5:30:39, 10.79s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-37] LR: 0.0010: iter loss: 0.441, iter iou: 0.647, epoch loss: 0.631, epoch iou: 0.588,  time cost: 10.308 s


  2%|▏         | 37/1874 [06:47<5:26:08, 10.65s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-38] LR: 0.0010: iter loss: 0.427, iter iou: 0.703, epoch loss: 0.626, epoch iou: 0.591,  time cost: 9.048 s


  2%|▏         | 38/1874 [06:56<5:11:06, 10.17s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-39] LR: 0.0010: iter loss: 0.431, iter iou: 0.645, epoch loss: 0.621, epoch iou: 0.592,  time cost: 8.950 s


  2%|▏         | 39/1874 [07:05<4:59:51,  9.80s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-40] LR: 0.0010: iter loss: 0.420, iter iou: 0.685, epoch loss: 0.616, epoch iou: 0.594,  time cost: 7.618 s


  2%|▏         | 40/1874 [07:12<4:39:35,  9.15s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-41] LR: 0.0010: iter loss: 0.417, iter iou: 0.672, epoch loss: 0.611, epoch iou: 0.596,  time cost: 10.607 s


  2%|▏         | 41/1874 [07:23<4:52:47,  9.58s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-42] LR: 0.0010: iter loss: 0.395, iter iou: 0.746, epoch loss: 0.606, epoch iou: 0.600,  time cost: 10.381 s


  2%|▏         | 42/1874 [07:33<4:59:54,  9.82s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-43] LR: 0.0010: iter loss: 0.392, iter iou: 0.731, epoch loss: 0.601, epoch iou: 0.603,  time cost: 12.842 s


  2%|▏         | 43/1874 [07:46<5:27:31, 10.73s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-44] LR: 0.0010: iter loss: 0.394, iter iou: 0.684, epoch loss: 0.596, epoch iou: 0.605,  time cost: 9.892 s


  2%|▏         | 44/1874 [07:56<5:19:35, 10.48s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-45] LR: 0.0010: iter loss: 0.420, iter iou: 0.637, epoch loss: 0.592, epoch iou: 0.605,  time cost: 10.424 s


  2%|▏         | 45/1874 [08:06<5:18:52, 10.46s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-46] LR: 0.0010: iter loss: 0.423, iter iou: 0.672, epoch loss: 0.588, epoch iou: 0.607,  time cost: 9.987 s


  2%|▏         | 46/1874 [08:16<5:14:28, 10.32s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-47] LR: 0.0010: iter loss: 0.403, iter iou: 0.672, epoch loss: 0.584, epoch iou: 0.608,  time cost: 9.604 s


  3%|▎         | 47/1874 [08:26<5:07:43, 10.11s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-48] LR: 0.0010: iter loss: 0.399, iter iou: 0.687, epoch loss: 0.581, epoch iou: 0.610,  time cost: 10.685 s


  3%|▎         | 48/1874 [08:37<5:12:46, 10.28s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-49] LR: 0.0010: iter loss: 0.386, iter iou: 0.690, epoch loss: 0.577, epoch iou: 0.612,  time cost: 9.167 s


  3%|▎         | 49/1874 [08:46<5:02:32,  9.95s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-50] LR: 0.0010: iter loss: 0.422, iter iou: 0.676, epoch loss: 0.573, epoch iou: 0.613,  time cost: 11.875 s


  3%|▎         | 50/1874 [08:58<5:19:51, 10.52s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-51] LR: 0.0010: iter loss: 0.393, iter iou: 0.698, epoch loss: 0.570, epoch iou: 0.614,  time cost: 10.083 s


  3%|▎         | 51/1874 [09:08<5:15:43, 10.39s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-52] LR: 0.0010: iter loss: 0.391, iter iou: 0.648, epoch loss: 0.567, epoch iou: 0.615,  time cost: 11.438 s


  3%|▎         | 52/1874 [09:19<5:25:06, 10.71s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-53] LR: 0.0010: iter loss: 0.390, iter iou: 0.674, epoch loss: 0.563, epoch iou: 0.616,  time cost: 9.802 s


  3%|▎         | 53/1874 [09:29<5:16:41, 10.43s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-54] LR: 0.0010: iter loss: 0.372, iter iou: 0.733, epoch loss: 0.560, epoch iou: 0.618,  time cost: 9.508 s


  3%|▎         | 54/1874 [09:39<5:08:09, 10.16s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-55] LR: 0.0010: iter loss: 0.378, iter iou: 0.706, epoch loss: 0.556, epoch iou: 0.620,  time cost: 12.267 s


  3%|▎         | 55/1874 [09:51<5:27:05, 10.79s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-56] LR: 0.0010: iter loss: 0.401, iter iou: 0.666, epoch loss: 0.554, epoch iou: 0.621,  time cost: 9.459 s


  3%|▎         | 56/1874 [10:00<5:14:50, 10.39s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-57] LR: 0.0010: iter loss: 0.367, iter iou: 0.720, epoch loss: 0.550, epoch iou: 0.623,  time cost: 10.123 s


  3%|▎         | 57/1874 [10:10<5:12:14, 10.31s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-58] LR: 0.0010: iter loss: 0.354, iter iou: 0.775, epoch loss: 0.547, epoch iou: 0.625,  time cost: 9.561 s


  3%|▎         | 58/1874 [10:20<5:05:14, 10.09s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-59] LR: 0.0010: iter loss: 0.361, iter iou: 0.712, epoch loss: 0.544, epoch iou: 0.627,  time cost: 9.063 s


  3%|▎         | 59/1874 [10:29<4:55:50,  9.78s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-60] LR: 0.0010: iter loss: 0.378, iter iou: 0.700, epoch loss: 0.541, epoch iou: 0.628,  time cost: 7.601 s


  3%|▎         | 60/1874 [10:37<4:35:52,  9.13s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-61] LR: 0.0010: iter loss: 0.375, iter iou: 0.662, epoch loss: 0.538, epoch iou: 0.628,  time cost: 8.934 s


  3%|▎         | 61/1874 [10:46<4:34:01,  9.07s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-62] LR: 0.0010: iter loss: 0.414, iter iou: 0.615, epoch loss: 0.536, epoch iou: 0.628,  time cost: 9.087 s


  3%|▎         | 62/1874 [10:55<4:34:02,  9.07s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-63] LR: 0.0010: iter loss: 0.379, iter iou: 0.676, epoch loss: 0.534, epoch iou: 0.629,  time cost: 11.644 s


  3%|▎         | 63/1874 [11:06<4:57:05,  9.84s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-64] LR: 0.0010: iter loss: 0.373, iter iou: 0.688, epoch loss: 0.531, epoch iou: 0.630,  time cost: 12.493 s


  3%|▎         | 64/1874 [11:19<5:20:56, 10.64s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-65] LR: 0.0010: iter loss: 0.386, iter iou: 0.643, epoch loss: 0.529, epoch iou: 0.630,  time cost: 10.394 s


  3%|▎         | 65/1874 [11:29<5:18:34, 10.57s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-66] LR: 0.0010: iter loss: 0.355, iter iou: 0.650, epoch loss: 0.526, epoch iou: 0.630,  time cost: 11.109 s


  4%|▎         | 66/1874 [11:40<5:23:19, 10.73s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-67] LR: 0.0010: iter loss: 0.362, iter iou: 0.649, epoch loss: 0.524, epoch iou: 0.631,  time cost: 9.368 s


  4%|▎         | 67/1874 [11:50<5:10:48, 10.32s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-68] LR: 0.0010: iter loss: 0.362, iter iou: 0.753, epoch loss: 0.522, epoch iou: 0.632,  time cost: 9.131 s


  4%|▎         | 68/1874 [11:59<4:59:52,  9.96s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-69] LR: 0.0010: iter loss: 0.360, iter iou: 0.702, epoch loss: 0.519, epoch iou: 0.633,  time cost: 11.294 s


  4%|▎         | 69/1874 [12:10<5:11:43, 10.36s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-70] LR: 0.0010: iter loss: 0.357, iter iou: 0.680, epoch loss: 0.517, epoch iou: 0.634,  time cost: 9.371 s


  4%|▎         | 70/1874 [12:20<5:02:42, 10.07s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-71] LR: 0.0010: iter loss: 0.385, iter iou: 0.636, epoch loss: 0.515, epoch iou: 0.634,  time cost: 10.427 s


  4%|▍         | 71/1874 [12:30<5:05:44, 10.17s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-72] LR: 0.0010: iter loss: 0.370, iter iou: 0.655, epoch loss: 0.513, epoch iou: 0.634,  time cost: 10.696 s


  4%|▍         | 72/1874 [12:41<5:10:14, 10.33s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-73] LR: 0.0010: iter loss: 0.333, iter iou: 0.716, epoch loss: 0.511, epoch iou: 0.636,  time cost: 12.752 s


  4%|▍         | 73/1874 [12:53<5:31:51, 11.06s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-74] LR: 0.0010: iter loss: 0.406, iter iou: 0.627, epoch loss: 0.509, epoch iou: 0.635,  time cost: 11.292 s


  4%|▍         | 74/1874 [13:05<5:33:50, 11.13s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-75] LR: 0.0010: iter loss: 0.357, iter iou: 0.740, epoch loss: 0.507, epoch iou: 0.637,  time cost: 10.273 s


  4%|▍         | 75/1874 [13:15<5:25:59, 10.87s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-76] LR: 0.0010: iter loss: 0.345, iter iou: 0.755, epoch loss: 0.505, epoch iou: 0.638,  time cost: 10.561 s


  4%|▍         | 76/1874 [13:26<5:22:56, 10.78s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-77] LR: 0.0010: iter loss: 0.341, iter iou: 0.752, epoch loss: 0.503, epoch iou: 0.640,  time cost: 10.749 s


  4%|▍         | 77/1874 [13:36<5:22:30, 10.77s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-78] LR: 0.0010: iter loss: 0.379, iter iou: 0.667, epoch loss: 0.501, epoch iou: 0.640,  time cost: 11.119 s


  4%|▍         | 78/1874 [13:47<5:25:32, 10.88s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-79] LR: 0.0010: iter loss: 0.338, iter iou: 0.763, epoch loss: 0.499, epoch iou: 0.642,  time cost: 9.936 s


  4%|▍         | 79/1874 [13:57<5:16:52, 10.59s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-80] LR: 0.0010: iter loss: 0.338, iter iou: 0.762, epoch loss: 0.497, epoch iou: 0.643,  time cost: 9.146 s


  4%|▍         | 80/1874 [14:07<5:03:50, 10.16s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-81] LR: 0.0010: iter loss: 0.323, iter iou: 0.765, epoch loss: 0.495, epoch iou: 0.645,  time cost: 9.452 s


  4%|▍         | 81/1874 [14:16<4:57:10,  9.94s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-82] LR: 0.0010: iter loss: 0.371, iter iou: 0.629, epoch loss: 0.493, epoch iou: 0.645,  time cost: 10.178 s


  4%|▍         | 82/1874 [14:26<4:59:13, 10.02s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-83] LR: 0.0010: iter loss: 0.342, iter iou: 0.762, epoch loss: 0.492, epoch iou: 0.646,  time cost: 7.677 s


  4%|▍         | 83/1874 [14:34<4:38:01,  9.31s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-84] LR: 0.0010: iter loss: 0.330, iter iou: 0.716, epoch loss: 0.490, epoch iou: 0.647,  time cost: 8.674 s


  4%|▍         | 84/1874 [14:42<4:32:10,  9.12s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-85] LR: 0.0010: iter loss: 0.333, iter iou: 0.766, epoch loss: 0.488, epoch iou: 0.648,  time cost: 10.419 s


  5%|▍         | 85/1874 [14:53<4:43:41,  9.51s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-86] LR: 0.0010: iter loss: 0.359, iter iou: 0.694, epoch loss: 0.486, epoch iou: 0.649,  time cost: 10.339 s


  5%|▍         | 86/1874 [15:03<4:50:52,  9.76s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-87] LR: 0.0010: iter loss: 0.375, iter iou: 0.649, epoch loss: 0.485, epoch iou: 0.649,  time cost: 10.707 s


  5%|▍         | 87/1874 [15:14<4:59:08, 10.04s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-88] LR: 0.0010: iter loss: 0.348, iter iou: 0.716, epoch loss: 0.484, epoch iou: 0.650,  time cost: 9.129 s


  5%|▍         | 88/1874 [15:23<4:50:45,  9.77s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-89] LR: 0.0010: iter loss: 0.324, iter iou: 0.767, epoch loss: 0.482, epoch iou: 0.651,  time cost: 10.344 s


  5%|▍         | 89/1874 [15:33<4:55:46,  9.94s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-90] LR: 0.0010: iter loss: 0.327, iter iou: 0.712, epoch loss: 0.480, epoch iou: 0.652,  time cost: 10.054 s


  5%|▍         | 90/1874 [15:43<4:56:36,  9.98s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-91] LR: 0.0010: iter loss: 0.329, iter iou: 0.756, epoch loss: 0.478, epoch iou: 0.653,  time cost: 10.143 s


  5%|▍         | 91/1874 [15:54<4:57:56, 10.03s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-92] LR: 0.0010: iter loss: 0.344, iter iou: 0.710, epoch loss: 0.477, epoch iou: 0.653,  time cost: 9.154 s


  5%|▍         | 92/1874 [16:03<4:50:01,  9.77s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-93] LR: 0.0010: iter loss: 0.308, iter iou: 0.751, epoch loss: 0.475, epoch iou: 0.654,  time cost: 8.702 s


  5%|▍         | 93/1874 [16:11<4:40:19,  9.44s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-94] LR: 0.0010: iter loss: 0.344, iter iou: 0.671, epoch loss: 0.474, epoch iou: 0.655,  time cost: 8.738 s


  5%|▌         | 94/1874 [16:20<4:33:53,  9.23s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-95] LR: 0.0010: iter loss: 0.359, iter iou: 0.673, epoch loss: 0.473, epoch iou: 0.655,  time cost: 9.299 s


  5%|▌         | 95/1874 [16:30<4:34:20,  9.25s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-96] LR: 0.0010: iter loss: 0.307, iter iou: 0.763, epoch loss: 0.471, epoch iou: 0.656,  time cost: 10.829 s


  5%|▌         | 96/1874 [16:40<4:48:10,  9.72s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-97] LR: 0.0010: iter loss: 0.333, iter iou: 0.725, epoch loss: 0.469, epoch iou: 0.657,  time cost: 9.730 s


  5%|▌         | 97/1874 [16:50<4:48:05,  9.73s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-98] LR: 0.0010: iter loss: 0.312, iter iou: 0.752, epoch loss: 0.468, epoch iou: 0.658,  time cost: 9.633 s


  5%|▌         | 98/1874 [17:00<4:47:07,  9.70s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-99] LR: 0.0010: iter loss: 0.369, iter iou: 0.669, epoch loss: 0.467, epoch iou: 0.658,  time cost: 9.615 s


  5%|▌         | 99/1874 [17:09<4:46:10,  9.67s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-100] LR: 0.0010: iter loss: 0.330, iter iou: 0.733, epoch loss: 0.465, epoch iou: 0.658,  time cost: 9.958 s


  5%|▌         | 100/1874 [17:19<4:48:29,  9.76s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-101] LR: 0.0010: iter loss: 0.313, iter iou: 0.752, epoch loss: 0.464, epoch iou: 0.659,  time cost: 9.652 s


  5%|▌         | 101/1874 [17:29<4:47:26,  9.73s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-102] LR: 0.0010: iter loss: 0.309, iter iou: 0.721, epoch loss: 0.462, epoch iou: 0.660,  time cost: 10.129 s


  5%|▌         | 102/1874 [17:39<4:50:45,  9.85s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-103] LR: 0.0010: iter loss: 0.370, iter iou: 0.657, epoch loss: 0.461, epoch iou: 0.660,  time cost: 9.483 s


  5%|▌         | 103/1874 [17:49<4:47:32,  9.74s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-104] LR: 0.0010: iter loss: 0.300, iter iou: 0.745, epoch loss: 0.460, epoch iou: 0.661,  time cost: 9.577 s


  6%|▌         | 104/1874 [17:58<4:45:53,  9.69s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-105] LR: 0.0010: iter loss: 0.292, iter iou: 0.761, epoch loss: 0.458, epoch iou: 0.662,  time cost: 8.940 s


  6%|▌         | 105/1874 [18:07<4:39:06,  9.47s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-106] LR: 0.0010: iter loss: 0.329, iter iou: 0.708, epoch loss: 0.457, epoch iou: 0.662,  time cost: 9.045 s


  6%|▌         | 106/1874 [18:16<4:35:09,  9.34s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-107] LR: 0.0010: iter loss: 0.318, iter iou: 0.741, epoch loss: 0.456, epoch iou: 0.663,  time cost: 9.650 s


  6%|▌         | 107/1874 [18:26<4:37:48,  9.43s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-108] LR: 0.0010: iter loss: 0.270, iter iou: 0.807, epoch loss: 0.454, epoch iou: 0.664,  time cost: 9.315 s


  6%|▌         | 108/1874 [18:35<4:36:35,  9.40s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-109] LR: 0.0010: iter loss: 0.340, iter iou: 0.684, epoch loss: 0.453, epoch iou: 0.664,  time cost: 10.935 s


  6%|▌         | 109/1874 [18:46<4:49:56,  9.86s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-110] LR: 0.0010: iter loss: 0.343, iter iou: 0.707, epoch loss: 0.452, epoch iou: 0.665,  time cost: 10.688 s


  6%|▌         | 110/1874 [18:57<4:57:08, 10.11s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-111] LR: 0.0010: iter loss: 0.347, iter iou: 0.692, epoch loss: 0.451, epoch iou: 0.665,  time cost: 11.605 s


  6%|▌         | 111/1874 [19:08<5:10:12, 10.56s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-112] LR: 0.0010: iter loss: 0.281, iter iou: 0.770, epoch loss: 0.450, epoch iou: 0.666,  time cost: 9.909 s


  6%|▌         | 112/1874 [19:18<5:04:17, 10.36s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-113] LR: 0.0010: iter loss: 0.378, iter iou: 0.662, epoch loss: 0.449, epoch iou: 0.666,  time cost: 9.552 s


  6%|▌         | 113/1874 [19:28<4:57:00, 10.12s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-114] LR: 0.0010: iter loss: 0.353, iter iou: 0.689, epoch loss: 0.448, epoch iou: 0.666,  time cost: 10.551 s


  6%|▌         | 114/1874 [19:38<5:00:36, 10.25s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-115] LR: 0.0010: iter loss: 0.301, iter iou: 0.747, epoch loss: 0.447, epoch iou: 0.667,  time cost: 11.470 s


  6%|▌         | 115/1874 [19:50<5:11:09, 10.61s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-116] LR: 0.0010: iter loss: 0.296, iter iou: 0.750, epoch loss: 0.446, epoch iou: 0.668,  time cost: 11.564 s


  6%|▌         | 116/1874 [20:01<5:19:22, 10.90s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-117] LR: 0.0010: iter loss: 0.356, iter iou: 0.714, epoch loss: 0.445, epoch iou: 0.668,  time cost: 9.798 s


  6%|▌         | 117/1874 [20:11<5:09:32, 10.57s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-118] LR: 0.0010: iter loss: 0.346, iter iou: 0.704, epoch loss: 0.444, epoch iou: 0.668,  time cost: 9.724 s


  6%|▋         | 118/1874 [20:21<5:01:57, 10.32s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-119] LR: 0.0010: iter loss: 0.309, iter iou: 0.754, epoch loss: 0.443, epoch iou: 0.669,  time cost: 9.865 s


  6%|▋         | 119/1874 [20:31<4:57:46, 10.18s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-120] LR: 0.0010: iter loss: 0.304, iter iou: 0.766, epoch loss: 0.442, epoch iou: 0.670,  time cost: 10.840 s


  6%|▋         | 120/1874 [20:42<5:03:23, 10.38s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-121] LR: 0.0010: iter loss: 0.298, iter iou: 0.761, epoch loss: 0.440, epoch iou: 0.671,  time cost: 10.506 s


  6%|▋         | 121/1874 [20:52<5:04:22, 10.42s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-122] LR: 0.0010: iter loss: 0.365, iter iou: 0.675, epoch loss: 0.440, epoch iou: 0.671,  time cost: 12.030 s


  7%|▋         | 122/1874 [21:04<5:18:16, 10.90s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-123] LR: 0.0010: iter loss: 0.308, iter iou: 0.741, epoch loss: 0.439, epoch iou: 0.671,  time cost: 11.935 s


  7%|▋         | 123/1874 [21:16<5:27:08, 11.21s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-124] LR: 0.0010: iter loss: 0.356, iter iou: 0.706, epoch loss: 0.438, epoch iou: 0.671,  time cost: 9.791 s


  7%|▋         | 124/1874 [21:26<5:14:37, 10.79s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-125] LR: 0.0010: iter loss: 0.265, iter iou: 0.794, epoch loss: 0.437, epoch iou: 0.672,  time cost: 9.103 s


  7%|▋         | 125/1874 [21:35<4:59:39, 10.28s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-126] LR: 0.0010: iter loss: 0.302, iter iou: 0.750, epoch loss: 0.436, epoch iou: 0.673,  time cost: 10.414 s


  7%|▋         | 126/1874 [21:45<5:00:39, 10.32s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-127] LR: 0.0010: iter loss: 0.279, iter iou: 0.786, epoch loss: 0.434, epoch iou: 0.674,  time cost: 10.243 s


  7%|▋         | 127/1874 [21:56<4:59:50, 10.30s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-128] LR: 0.0010: iter loss: 0.323, iter iou: 0.737, epoch loss: 0.434, epoch iou: 0.674,  time cost: 11.473 s


  7%|▋         | 128/1874 [22:07<5:09:51, 10.65s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-129] LR: 0.0010: iter loss: 0.342, iter iou: 0.706, epoch loss: 0.433, epoch iou: 0.675,  time cost: 9.402 s


  7%|▋         | 129/1874 [22:16<4:58:50, 10.28s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-130] LR: 0.0010: iter loss: 0.346, iter iou: 0.706, epoch loss: 0.432, epoch iou: 0.675,  time cost: 10.827 s


  7%|▋         | 130/1874 [22:27<5:03:26, 10.44s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-131] LR: 0.0010: iter loss: 0.322, iter iou: 0.741, epoch loss: 0.431, epoch iou: 0.675,  time cost: 10.060 s


  7%|▋         | 131/1874 [22:37<4:59:58, 10.33s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-132] LR: 0.0010: iter loss: 0.372, iter iou: 0.673, epoch loss: 0.431, epoch iou: 0.675,  time cost: 8.837 s


  7%|▋         | 132/1874 [22:46<4:46:48,  9.88s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-133] LR: 0.0010: iter loss: 0.310, iter iou: 0.735, epoch loss: 0.430, epoch iou: 0.676,  time cost: 9.700 s


  7%|▋         | 133/1874 [22:56<4:45:08,  9.83s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-134] LR: 0.0010: iter loss: 0.237, iter iou: 0.809, epoch loss: 0.429, epoch iou: 0.677,  time cost: 9.918 s


  7%|▋         | 134/1874 [23:06<4:45:44,  9.85s/it]

torch.Size([786432, 2])
torch.Size([786432])
[Epoch-0 Iter-135] LR: 0.0010: iter loss: 0.347, iter iou: 0.693, epoch loss: 0.428, epoch iou: 0.677,  time cost: 9.500 s


  7%|▋         | 135/1874 [23:15<4:42:37,  9.75s/it]

torch.Size([786432, 2])
torch.Size([786432])


In [None]:
import matplotlib.pyplot as plt
fig, (ax_loss, ax_iou) = plt.subplots(1, 2, figsize=(10, 5))
# The optimized objective is the create_loss term plus Lovasz-softmax (see the training cell).
# NOTE(review): the cross term comes from utils.create_loss; if that version also adds Dice,
# extend this title accordingly.
fig.suptitle('cross-entropy + Lovasz-softmax training')

ax_loss.plot(loss_plot)
ax_loss.set(xlabel='iteration', ylabel='loss')

ax_iou.plot(iou_plot)
# mean IoU is recorded as a fraction in [0, 1], not a percentage
ax_iou.set(xlabel='iteration', ylabel='mean IoU')
plt.show()

In [None]:
# Persist the final network weights.
# NOTE(review): relative path -> written to the kernel's current working directory
# (on Colab presumably /content, which is not on Drive and is lost on runtime reset).
final_weights = net.state_dict()
torch.save(final_weights, "result.pt")

In [None]:
# Manually save a checkpoint, e.g. after interrupting training mid-epoch.
# The running sums are averaged over the iterations reached in the current epoch
# (`iteration`) rather than the full `epoch_size`; dividing a partial-epoch sum by
# `epoch_size` would put a deflated loss/IoU in the filename. When the epoch finished
# normally, iteration == epoch_size and the name matches the end-of-epoch checkpoint.
# If the interrupt landed before the sums were updated, the divisor is one too large.
iters_done = iteration
torch.save(net.state_dict(),
           pjoin(cfg.WEIGHTS_SAVE_ROOT, "weights_ep_%d_%.3f_%.3f.pth"
                 % (epoch, epoch_loss / iters_done, epoch_miou / iters_done)))