In [1]:
import os
import datetime
import cv2
from PIL import Image
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data.dataset import Dataset
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from functools import partial
from torch.hub import load_state_dict_from_url
import torch.nn.functional as F
import math
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal
import shutil
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import DropPath, to_2tuple, trunc_normal_



In [2]:
# import mmcv
# from mmcv.fileio import FileClient
# # from mmcv.fileio import load as load_file
# from mmcv.parallel import is_module_wrapper
# from mmcv.utils import mkdir_or_exist
# from mmcv.runner import get_dist_info

In [3]:
def load_state_dict(module, state_dict, strict=False, logger=None):
    
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
#         if is_module_wrapper(module):
#             module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank = 0
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def _load_checkpoint(filename, map_location=None):
    checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint

def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H*W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {table_key}, pass")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                table_pretrained_resized = F.interpolate(
                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                     size=(S2, S2), mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint

In [4]:
class Mlp(nn.Module):
    
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):

    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):

        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x

class SwinTransformer(nn.Module):

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features
#         self.norm = norm_layer(self.num_features)
#         self.avgpool = nn.AdaptiveAvgPool1d(1)
#         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            load_checkpoint(self, pretrained, strict=False, logger=None)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()


In [5]:
backbone = SwinTransformer(frozen_stages=4)

In [6]:
# backbone.init_weights('/home/ubuntu/MyFiles/swin_tiny_patch4_window7_224.pth')
backbone.init_weights('e://毕业论文/swin_tiny_patch4_window7_224.pth')

The model and loaded state dict do not match exactly

unexpected key in source state_dict: norm.weight, norm.bias, head.weight, head.bias, layers.0.blocks.1.attn_mask, layers.1.blocks.1.attn_mask, layers.2.blocks.1.attn_mask, layers.2.blocks.3.attn_mask, layers.2.blocks.5.attn_mask

missing keys in source state_dict: norm0.weight, norm0.bias, norm1.weight, norm1.bias, norm2.weight, norm2.bias, norm3.weight, norm3.bias



In [7]:
class PPM(nn.ModuleList):
    def __init__(self, pool_sizes, in_channels=768, out_channels=256):
        super(PPM, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        for pool_size in pool_sizes:
            self.append(
                nn.Sequential(
                    nn.AdaptiveMaxPool2d(pool_size),
                    nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),
                )
            )     
            
    def forward(self, x):
        out_puts = []
        for ppm in self:
            ppm_out = nn.functional.interpolate(ppm(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
            out_puts.append(ppm_out)
        return out_puts
 
    
class PPMHEAD(nn.Module):
    def __init__(self, in_channels=768, out_channels=256, pool_sizes = [1, 2, 3, 6],num_classes=2):
        super(PPMHEAD, self).__init__()
        self.pool_sizes = pool_sizes
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)
        self.final = nn.Sequential(
            nn.Conv2d(self.in_channels + len(self.pool_sizes)*self.out_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
        )
        
    def forward(self, x):
        out = self.psp_modules(x)
        out.append(x)
        out = torch.cat(out, 1)
        out = self.final(out)
        return out
 
class FPNHEAD(nn.Module):
    def __init__(self, channels=768, out_channels=256):
        super(FPNHEAD, self).__init__()
        self.PPMHead = PPMHEAD(in_channels=channels, out_channels=out_channels)
        
        self.Conv_fuse1 = nn.Sequential(
            nn.Conv2d(channels//2, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.Conv_fuse1_ = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.Conv_fuse2 = nn.Sequential(
            nn.Conv2d(channels//4, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )    
        self.Conv_fuse2_ = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.Conv_fuse3 = nn.Sequential(
            nn.Conv2d(channels//8, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ) 
        self.Conv_fuse3_ = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    
        self.fuse_all = nn.Sequential(
            nn.Conv2d(out_channels*4, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.conv_x1 = nn.Conv2d(out_channels, out_channels, 1)
 
    def forward(self, input_fpn):
        # b, 768,16,16
        x1 = self.PPMHead(input_fpn[-1])
         #b, 256,32,32
        x = nn.functional.interpolate(x1, size=(x1.size(2)*2, x1.size(3)*2),mode='bilinear', align_corners=True)
        x = self.conv_x1(x) + self.Conv_fuse1(input_fpn[-2])  #b,256,32,32
        x2 = self.Conv_fuse1_(x)
        
        x = nn.functional.interpolate(x2, size=(x2.size(2)*2, x2.size(3)*2),mode='bilinear', align_corners=True)
        x = x + self.Conv_fuse2(input_fpn[-3]) #b,256, 64, 64
        x3 = self.Conv_fuse2_(x)  
 
        x = nn.functional.interpolate(x3, size=(x3.size(2)*2, x3.size(3)*2),mode='bilinear', align_corners=True)
        x = x + self.Conv_fuse3(input_fpn[-4]) #b,256,128,128
        x4 = self.Conv_fuse3_(x)
 
        x1 = F.interpolate(x1, x4.size()[-2:],mode='bilinear', align_corners=True)
        x2 = F.interpolate(x2, x4.size()[-2:],mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, x4.size()[-2:],mode='bilinear', align_corners=True)
 
        x = self.fuse_all(torch.cat([x1, x2, x3, x4], 1)) #b,256*4,128,128
        
        return x
    
class UPerNet(nn.Module):
    def __init__(self, num_classes, backbone):
        super(UPerNet, self).__init__()
        self.num_classes = num_classes
        self.backbone = backbone
        self.in_channels = 768
        self.channels = 256
        self.decoder = FPNHEAD()
        self.cls_seg = nn.Sequential(
            nn.Conv2d(self.channels, self.num_classes, kernel_size=3, padding=1),
        )
        
    def forward(self, x):
        x = self.backbone(x) 
        x = self.decoder(x)
        
        x = nn.functional.interpolate(x, size=(x.size(2)*4, x.size(3)*4),mode='bilinear', align_corners=True)
        x = self.cls_seg(x)
        return x
    
    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True

In [8]:
def fast_hist(a, b, n):
    #--------------------------------------------------------------------------------#
    #   a是转化成一维数组的标签，形状(H×W,)；b是转化成一维数组的预测结果，形状(H×W,)
    #--------------------------------------------------------------------------------#
    k = (a >= 0) & (a < n)
    #--------------------------------------------------------------------------------#
    #   np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数，返回值形状(n, n)
    #   返回中，写对角线上的为分类正确的像素点
    #--------------------------------------------------------------------------------#
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)  

def per_class_iu(hist):
    return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 

def per_class_PA_Recall(hist):
    return np.diag(hist) / np.maximum(hist.sum(1), 1) 

def per_class_Precision(hist):
    return np.diag(hist) / np.maximum(hist.sum(0), 1) 

def per_Accuracy(hist):
    return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) 

def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):  
    print('Num classes', num_classes)  
    #-----------------------------------------#
    #   创建一个全是0的矩阵，是一个混淆矩阵
    #-----------------------------------------#
    hist = np.zeros((num_classes, num_classes))  #2*2
    
    #------------------------------------------------#
    #   获得验证集标签路径列表，方便直接读取
    #   获得验证集图像分割结果路径列表，方便直接读取
    #------------------------------------------------#
    gt_imgs     = [os.path.join(gt_dir, x + "_segmentation.png") for x in png_name_list]  
    pred_imgs   = [os.path.join(pred_dir, x + ".png") for x in png_name_list]  

    #------------------------------------------------#
    #   读取每一个（图片-标签）对
    #------------------------------------------------#
    for ind in range(len(gt_imgs)): 
        #------------------------------------------------#
        #   读取一张图像分割结果，转化成numpy数组
        #------------------------------------------------#
        pred = np.array(Image.open(pred_imgs[ind]))  #0为   1为背景
        #------------------------------------------------#
        #   读取一张对应的标签，转化成numpy数组
        #------------------------------------------------#
        png = np.array(Image.open(gt_imgs[ind])) #0为背景， 255为目标
        label  = np.zeros_like(png)
        label[png <= 127.5] = 1

        # 如果图像分割结果与标签的大小不一样，这张图片就不计算
        if len(label.flatten()) != len(pred.flatten()):  
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
                    pred_imgs[ind]))
            continue

        #------------------------------------------------#
        #   对一张图片计算21×21的hist矩阵，并累加
        #------------------------------------------------#
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)  
        # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值
        if name_classes is not None and ind > 0 and ind % 10 == 0: 
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                    ind, 
                    len(gt_imgs),
                    100 * np.nanmean(per_class_iu(hist)),
                    100 * np.nanmean(per_class_PA_Recall(hist)),
                    100 * per_Accuracy(hist)
                )
            )
    #------------------------------------------------#
    #   计算所有验证集图片的逐类别mIoU值
    #------------------------------------------------#
    IoUs        = per_class_iu(hist)
    PA_Recall   = per_class_PA_Recall(hist)
    Precision   = per_class_Precision(hist)
    #------------------------------------------------#
    #   逐类别输出一下mIoU值
    #------------------------------------------------#
    if name_classes is not None:
        for ind_class in range(num_classes):
            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
                + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))

    #-----------------------------------------------------------------#
    #   在所有验证集图像上求所有类别平均的mIoU值，计算时忽略NaN值
    #-----------------------------------------------------------------#
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))  
    hist = np.array(hist)
    return np.array(hist, np.int), IoUs, PA_Recall, Precision


In [9]:
class LossHistory():
    def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
        self.log_dir        = log_dir
        self.val_loss_flag  = val_loss_flag

        self.losses         = []
        if self.val_loss_flag:
            self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss = None):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        if self.val_loss_flag:
            self.val_loss.append(val_loss)
        
        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        if self.val_loss_flag:
            with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
                f.write(str(val_loss))
                f.write("\n")
            
        self.writer.add_scalar('loss', loss, epoch)
        if self.val_loss_flag:
            self.writer.add_scalar('val_loss', val_loss, epoch)
            
        self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        if self.val_loss_flag:
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
            
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            if self.val_loss_flag:
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
        
class EvalCallback():
    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
            miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        super(EvalCallback, self).__init__()
        
        self.net                = net
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.image_ids          = image_ids
        self.dataset_path       = dataset_path
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.miou_out_path      = miou_out_path
        self.eval_flag          = eval_flag
        self.period             = period
        
        self.image_ids          = [image_id.split()[0][:-4] for image_id in image_ids]
        self.mious      = [0]
        self.epoches    = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data = np.array(image_data)
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]  #num_class, 512, 512
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))
#         image = Image.fromarray(np.uint8(pr*255))
        return image
    
    def on_epoch_end(self, epoch, model_eval):
        if epoch % self.period == 0 and self.eval_flag:
            self.net    = model_eval
#             jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".jpg"))
#             png         = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + "_segmentation.png"))
#             gt_dir      = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
            gt_dir      = os.path.join(self.dataset_path, "Labels")
            pred_dir    = os.path.join(self.miou_out_path, 'detection-results')
            if not os.path.exists(self.miou_out_path):
                os.makedirs(self.miou_out_path)
            if not os.path.exists(pred_dir):
                os.makedirs(pred_dir)
            print("Get miou.")
            for image_id in tqdm(self.image_ids):
                #-------------------------------#
                #   从文件中读取图像
                #-------------------------------#
                image_path  = os.path.join(self.dataset_path, "Images/"+image_id+".jpg")
                image       = Image.open(image_path)
                #------------------------------#
                #   获得预测txt
                #------------------------------#
                image       = self.get_miou_png(image)
                image.save(os.path.join(pred_dir, image_id + ".png"))
                        
            print("Calculate miou.")
            _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # 执行计算mIoU的函数
            temp_miou = np.nanmean(IoUs) * 100

            self.mious.append(temp_miou)
            self.epoches.append(epoch)

            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(temp_miou))
                f.write("\n")
            
            plt.figure()
            plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Miou')
            plt.title('A Miou Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
            plt.cla()
            plt.close("all")

            print("Get miou done.")
            shutil.rmtree(self.miou_out_path)


In [10]:
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image 
    else:
        image = image.convert('RGB')
        return image 
def preprocess_input(image):
    image /= 255.0
    return image

def resize_image(image, size):
    iw, ih  = image.size
    w, h    = size

    scale   = min(w/iw, h/ih)
    nw      = int(iw*scale)
    nh      = int(ih*scale)

    image   = image.resize((nw,nh), Image.BICUBIC)
    new_image = Image.new('RGB', size, (128,128,128))
    new_image.paste(image, ((w-nw)//2, (h-nh)//2))

    return new_image, nw, nh

class UnetDataset(Dataset):
    def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
        super(UnetDataset, self).__init__()
        self.annotation_lines   = annotation_lines
        self.length             = len(annotation_lines)
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.train              = train
        self.dataset_path       = dataset_path

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        annotation_line = self.annotation_lines[index]
        name            = annotation_line.split()[0][:-4]

        #-------------------------------#
        #   从文件中读取图像
        #-------------------------------#
        jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".jpg"))
        png         = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + "_segmentation.png"))
        #-------------------------------#
        #   数据增强
        #-------------------------------#
#         jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)
        jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = False)

        jpg         = np.array(jpg)
        jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
        png         = np.array(png)
        #-------------------------------------------------------#
        #   这里的标签处理方式和普通voc的处理方式不同
        #   将小于127.5的像素点设置为目标像素点。
        #-------------------------------------------------------#
        modify_png  = np.zeros_like(png)
        modify_png[png <= 127.5] = 1
        seg_labels  = modify_png
        seg_labels  = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
        seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
#         seg_labels  = np.eye(self.num_classes)[seg_labels.reshape([-1])]
#         seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes))

        return jpg, modify_png, seg_labels

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
        image   = cvtColor(image)
        label   = Image.fromarray(np.array(label))
        #------------------------------#
        #   获得图像的高宽与目标高宽
        #------------------------------#
        iw, ih  = image.size
        h, w    = input_shape

        if not random:
            iw, ih  = image.size
            scale   = min(w/iw, h/ih)
            nw      = int(iw*scale)
            nh      = int(ih*scale)

            image       = image.resize((nw,nh), Image.BICUBIC)
            new_image   = Image.new('RGB', [w, h], (128,128,128))
            new_image.paste(image, ((w-nw)//2, (h-nh)//2))

            label       = label.resize((nw,nh), Image.NEAREST)
            new_label   = Image.new('L', [w, h], (0))
            new_label.paste(label, ((w-nw)//2, (h-nh)//2))
            return new_image, new_label

        #------------------------------------------#
        #   对图像进行缩放并且进行长和宽的扭曲
        #------------------------------------------#
        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
        scale = self.rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)
        label = label.resize((nw,nh), Image.NEAREST)
        
        #------------------------------------------#
        #   翻转图像
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: 
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        
        #------------------------------------------#
        #   将图像多余的部分加上灰条
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_label = Image.new('L', (w,h), (0))
        new_image.paste(image, (dx, dy))
        new_label.paste(label, (dx, dy))
        image = new_image
        label = new_label
        image = np.array(image)
        image_data      = np.array(image, np.uint8)
        #---------------------------------#
        #   对图像进行色域变换
        #   计算色域变换的参数
        #---------------------------------#
        r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        #---------------------------------#
        #   将图像转到HSV上
        #---------------------------------#
        hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype           = image_data.dtype
        #---------------------------------#
        #   应用变换
        #---------------------------------#
        x       = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
        
        return image_data, label

# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
    images      = []
    pngs        = []
    seg_labels  = []
    for img, png, labels in batch:
        images.append(img)
        pngs.append(png)
        seg_labels.append(labels)
    images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    pngs        = torch.from_numpy(np.array(pngs)).long() #batchsize, 512,512,    1表示背景
    seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)#batchsize, 512,512, num_classes+1
    return images, pngs, seg_labels


In [11]:
# Cuda = True
Cuda = False
distributed     = False
sync_bn         = False
fp16            = False
num_classes = 2
pretrained  = True
model_path  = ""
input_shape = [224, 224]
Init_Epoch          = 0
Freeze_Epoch        = 25
Freeze_batch_size   = 2
UnFreeze_Epoch      = 50
Unfreeze_batch_size = 2
Freeze_Train        = True
Init_lr             = 1e-4
Min_lr              = Init_lr * 0.01
optimizer_type      = "adam"
momentum            = 0.9
weight_decay        = 0
lr_decay_type       = 'step'
save_period         = 5
save_dir            = 'logs'
eval_flag           = True
eval_period         = 5
# VOCdevkit_path  = '/home/ubuntu/MyFiles/Skin_Datasets/train'
VOCdevkit_path  = 'e://毕业论文/Skin_Datasets/train'
dice_loss       = True
focal_loss      = False
cls_weights     = np.ones([num_classes], np.float32)
device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank      = 0

In [12]:
with open('total_lines_3.txt','r')as f:
    total_lines = f.readlines()
total_lines = [line.strip() for line in total_lines]
# train_lines = total_lines[:1998]
# val_lines = total_lines[1998:1998+348]
train_lines = total_lines[:1998]
val_lines = total_lines[1998:1998+348]
num_train   = len(train_lines)
num_val     = len(val_lines)

In [13]:
model = UPerNet(num_classes=num_classes,backbone=backbone).train()

In [14]:
time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
loss_history    = LossHistory(log_dir, model, input_shape=input_shape)
scaler = None

  if W % self.patch_size[1] != 0:
  if H % self.patch_size[0] != 0:
  Hp = int(np.ceil(H / self.window_size)) * self.window_size
  Wp = int(np.ceil(W / self.window_size)) * self.window_size
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ..\aten\src\ATen\native\BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
  assert L == H * W, "input feature has wrong size"
  B = int(windows.shape[0] / (H * W / window_size / window_size))
  if pad_r > 0 or pad_b > 0:
  assert L == H * W, "input feature has wrong size"
  pad_input = (H % 2 == 1) or (W % 2 == 1)
  if pad_input:


In [15]:
model_train     = model.train()
if Cuda:
    model_train = torch.nn.DataParallel(model)
    cudnn.benchmark = True
    model_train = model_train.cuda()

In [16]:
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n       = iters // step_size
        out_lr  = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters  = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start     = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter         = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate  = (min_lr / lr) ** (1 / (step_num - 1))
        step_size   = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

In [17]:
UnFreeze_flag = False
# if Freeze_Train:
#     model.freeze_backbone()
batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
epoch_step      = num_train // batch_size
epoch_step_val  = num_val // batch_size
nbs = 16
lr_limit_max    = 1e-4 if optimizer_type == 'adam' else 1e-1
lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
optimizer = {
    'adam'  : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay),
    'sgd'   : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay)
}[optimizer_type]
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
train_dataset   = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
val_dataset     = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)
train_sampler   = None
val_sampler     = None
shuffle         = True
# gen             = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, pin_memory=True,
#                             drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
# gen_val         = DataLoader(val_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=True, 
#                             drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)
gen             = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, pin_memory=False,
                            drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
gen_val         = DataLoader(val_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=False, 
                            drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)
eval_callback   = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \
                                eval_flag=eval_flag, period=eval_period)

In [18]:
def CE_Loss(inputs, target, cls_weights, num_classes=2):
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    temp_target = target.view(-1)  #1为背景

    CE_loss  = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
    return CE_loss

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss

def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice系数
    #--------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score

def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
    total_loss      = 0
    total_f_score   = 0

    val_loss        = 0
    val_f_score     = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step: 
            break
        imgs, pngs, labels = batch
        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            if cuda:
                imgs    = imgs.cuda(local_rank)
                pngs    = pngs.cuda(local_rank)
                labels  = labels.cuda(local_rank)
                weights = weights.cuda(local_rank)

        optimizer.zero_grad()
        if not fp16:
            #----------------------#
            #   前向传播
            #----------------------#
            outputs = model_train(imgs)
            #----------------------#
            #   损失计算
            #----------------------#
#             if focal_loss:
#                 loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
#             else:
            loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)

            if dice_loss:
                main_dice = Dice_loss(outputs, labels)
                loss      = loss + main_dice

            with torch.no_grad():
                #-------------------------------#
                #   计算f_score
                #-------------------------------#
                _f_score = f_score(outputs, labels)

            loss.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                #----------------------#
                #   前向传播
                #----------------------#
                outputs = model_train(imgs)
                #----------------------#
                #   损失计算
                #----------------------#
                if focal_loss:
                    loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
                else:
                    loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)

                if dice_loss:
                    main_dice = Dice_loss(outputs, labels)
                    loss      = loss + main_dice

                with torch.no_grad():
                    #-------------------------------#
                    #   计算f_score
                    #-------------------------------#
                    _f_score = f_score(outputs, labels)

            #----------------------#
            #   反向传播
            #----------------------#
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        total_loss      += loss.item()
        total_f_score   += _f_score.item()
        
        if local_rank == 0:
            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 
                                'f_score'   : total_f_score / (iteration + 1),
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        imgs, pngs, labels = batch
        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            if cuda:
                imgs    = imgs.cuda(local_rank)
                pngs    = pngs.cuda(local_rank)
                labels  = labels.cuda(local_rank)
                weights = weights.cuda(local_rank)

            #----------------------#
            #   前向传播
            #----------------------#
            outputs = model_train(imgs)
            #----------------------#
            #   损失计算
            #----------------------#
#             if focal_loss:
#                 loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
#             else:
            loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)

            if dice_loss:
                main_dice = Dice_loss(outputs, labels)
                loss  = loss + main_dice
            #-------------------------------#
            #   计算f_score
            #-------------------------------#
            _f_score    = f_score(outputs, labels)

            val_loss    += loss.item()
            val_f_score += _f_score.item()
            
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss'  : val_loss / (iteration + 1),
                                'f_score'   : val_f_score / (iteration + 1),
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)
            
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss/ epoch_step, val_loss/ epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train)
        print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
        
        #-----------------------------------------------#
        #   保存权值
        #-----------------------------------------------#
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val)))

#         if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
#             print('Save best model to best_epoch_weights.pth')
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
            
        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))


In [19]:
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [19]:
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

for epoch in range(Init_Epoch, UnFreeze_Epoch):
    #---------------------------------------#
    #   如果模型有冻结学习部分
    #   则解冻，并设置参数
    #---------------------------------------#
    if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
        batch_size = Unfreeze_batch_size

        #-------------------------------------------------------------------#
        #   判断当前batch_size，自适应调整学习率
        #-------------------------------------------------------------------#
        lr_limit_max    = 1e-4 if optimizer_type == 'adam' else 1e-1
        lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
        Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
        Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
        #---------------------------------------#
        #   获得学习率下降的公式
        #---------------------------------------#
        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

        model.unfreeze_backbone()

        epoch_step      = num_train // batch_size
        epoch_step_val  = num_val // batch_size

        if epoch_step == 0:
            raise ValueError("数据集过小，无法继续进行训练，请扩充数据集。")


        gen             = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, pin_memory=True,
                                    drop_last = True, collate_fn = unet_dataset_collate, sampler=train_sampler)
        gen_val         = DataLoader(val_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=True, 
                                    drop_last = True, collate_fn = unet_dataset_collate, sampler=val_sampler)

        UnFreeze_flag = True

    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
    fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, 
                epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)

if local_rank == 0:
    loss_history.writer.close()
#     fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)

Start Train


Epoch 1/50: 100%|██████████| 999/999 [14:51<00:00,  1.12it/s, f_score=0.893, lr=0.0001, total_loss=0.278]


Finish Train
Start Validation


Epoch 1/50: 100%|██████████| 174/174 [02:31<00:00,  1.15it/s, f_score=0.913, lr=0.0001, val_loss=0.214]


Finish Validation
Epoch:1/50
Total Loss: 0.278 || Val Loss: 0.214 
Start Train


Epoch 2/50: 100%|██████████| 999/999 [14:06<00:00,  1.18it/s, f_score=0.922, lr=0.0001, total_loss=0.199]


Finish Train
Start Validation


Epoch 2/50: 100%|██████████| 174/174 [02:14<00:00,  1.29it/s, f_score=0.924, lr=0.0001, val_loss=0.195]


Finish Validation
Epoch:2/50
Total Loss: 0.199 || Val Loss: 0.195 
Start Train


Epoch 3/50: 100%|██████████| 999/999 [13:59<00:00,  1.19it/s, f_score=0.933, lr=0.0001, total_loss=0.171]


Finish Train
Start Validation


Epoch 3/50: 100%|██████████| 174/174 [02:16<00:00,  1.28it/s, f_score=0.93, lr=0.0001, val_loss=0.182] 


Finish Validation
Epoch:3/50
Total Loss: 0.171 || Val Loss: 0.182 
Start Train


Epoch 4/50: 100%|██████████| 999/999 [15:40<00:00,  1.06it/s, f_score=0.937, lr=0.0001, total_loss=0.161]


Finish Train
Start Validation


Epoch 4/50: 100%|██████████| 174/174 [02:33<00:00,  1.13it/s, f_score=0.93, lr=0.0001, val_loss=0.183] 


Finish Validation
Epoch:4/50
Total Loss: 0.161 || Val Loss: 0.183 
Start Train


Epoch 5/50: 100%|██████████| 999/999 [15:00<00:00,  1.11it/s, f_score=0.939, lr=0.0001, total_loss=0.153]


Finish Train
Start Validation


Epoch 5/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.928, lr=0.0001, val_loss=0.19] 


Finish Validation
Get miou.


100%|██████████| 348/348 [04:39<00:00,  1.24it/s]


Calculate miou.
Num classes 2
===> mIoU: 84.88; mPA: 90.08; Accuracy: 94.34
Get miou done.
Epoch:5/50
Total Loss: 0.153 || Val Loss: 0.190 
Start Train


Epoch 6/50: 100%|██████████| 999/999 [17:01<00:00,  1.02s/it, f_score=0.945, lr=5.99e-5, total_loss=0.135]


Finish Train
Start Validation


Epoch 6/50: 100%|██████████| 174/174 [02:26<00:00,  1.19it/s, f_score=0.934, lr=5.99e-5, val_loss=0.176]


Finish Validation
Epoch:6/50
Total Loss: 0.135 || Val Loss: 0.176 
Start Train


Epoch 7/50: 100%|██████████| 999/999 [15:09<00:00,  1.10it/s, f_score=0.95, lr=5.99e-5, total_loss=0.123] 


Finish Train
Start Validation


Epoch 7/50: 100%|██████████| 174/174 [02:36<00:00,  1.11it/s, f_score=0.931, lr=5.99e-5, val_loss=0.178]


Finish Validation
Epoch:7/50
Total Loss: 0.123 || Val Loss: 0.178 
Start Train


Epoch 8/50: 100%|██████████| 999/999 [15:49<00:00,  1.05it/s, f_score=0.952, lr=5.99e-5, total_loss=0.117] 


Finish Train
Start Validation


Epoch 8/50: 100%|██████████| 174/174 [02:39<00:00,  1.09it/s, f_score=0.93, lr=5.99e-5, val_loss=0.179] 


Finish Validation
Epoch:8/50
Total Loss: 0.117 || Val Loss: 0.179 
Start Train


Epoch 9/50: 100%|██████████| 999/999 [15:17<00:00,  1.09it/s, f_score=0.954, lr=5.99e-5, total_loss=0.111]


Finish Train
Start Validation


Epoch 9/50: 100%|██████████| 174/174 [02:32<00:00,  1.14it/s, f_score=0.936, lr=5.99e-5, val_loss=0.173]


Finish Validation
Epoch:9/50
Total Loss: 0.111 || Val Loss: 0.173 
Start Train


Epoch 10/50: 100%|██████████| 999/999 [15:03<00:00,  1.11it/s, f_score=0.956, lr=5.99e-5, total_loss=0.105] 


Finish Train
Start Validation


Epoch 10/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.931, lr=5.99e-5, val_loss=0.193]


Finish Validation
Get miou.


100%|██████████| 348/348 [03:48<00:00,  1.52it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.47; mPA: 91.52; Accuracy: 94.44
Get miou done.
Epoch:10/50
Total Loss: 0.105 || Val Loss: 0.193 
Start Train


Epoch 11/50: 100%|██████████| 999/999 [15:27<00:00,  1.08it/s, f_score=0.961, lr=3.59e-5, total_loss=0.0933]


Finish Train
Start Validation


Epoch 11/50: 100%|██████████| 174/174 [02:13<00:00,  1.30it/s, f_score=0.933, lr=3.59e-5, val_loss=0.194]


Finish Validation
Epoch:11/50
Total Loss: 0.093 || Val Loss: 0.194 
Start Train


Epoch 12/50: 100%|██████████| 999/999 [15:15<00:00,  1.09it/s, f_score=0.963, lr=3.59e-5, total_loss=0.0905]


Finish Train
Start Validation


Epoch 12/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.934, lr=3.59e-5, val_loss=0.191]


Finish Validation
Epoch:12/50
Total Loss: 0.090 || Val Loss: 0.191 
Start Train


Epoch 13/50: 100%|██████████| 999/999 [14:54<00:00,  1.12it/s, f_score=0.964, lr=3.59e-5, total_loss=0.0852]


Finish Train
Start Validation


Epoch 13/50: 100%|██████████| 174/174 [02:36<00:00,  1.11it/s, f_score=0.933, lr=3.59e-5, val_loss=0.2]  


Finish Validation
Epoch:13/50
Total Loss: 0.085 || Val Loss: 0.200 
Start Train


Epoch 14/50: 100%|██████████| 999/999 [15:48<00:00,  1.05it/s, f_score=0.966, lr=3.59e-5, total_loss=0.0791]


Finish Train
Start Validation


Epoch 14/50: 100%|██████████| 174/174 [02:44<00:00,  1.06it/s, f_score=0.933, lr=3.59e-5, val_loss=0.204]


Finish Validation
Epoch:14/50
Total Loss: 0.079 || Val Loss: 0.204 
Start Train


Epoch 15/50: 100%|██████████| 999/999 [15:10<00:00,  1.10it/s, f_score=0.968, lr=3.59e-5, total_loss=0.0762]


Finish Train
Start Validation


Epoch 15/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.935, lr=3.59e-5, val_loss=0.2]  


Finish Validation
Get miou.


100%|██████████| 348/348 [03:59<00:00,  1.45it/s]


Calculate miou.
Num classes 2
===> mIoU: 86.48; mPA: 93.06; Accuracy: 94.76
Get miou done.
Epoch:15/50
Total Loss: 0.076 || Val Loss: 0.200 
Start Train


Epoch 16/50: 100%|██████████| 999/999 [15:27<00:00,  1.08it/s, f_score=0.969, lr=2.15e-5, total_loss=0.0729]


Finish Train
Start Validation


Epoch 16/50: 100%|██████████| 174/174 [02:21<00:00,  1.23it/s, f_score=0.932, lr=2.15e-5, val_loss=0.205]


Finish Validation
Epoch:16/50
Total Loss: 0.073 || Val Loss: 0.205 
Start Train


Epoch 17/50: 100%|██████████| 999/999 [15:11<00:00,  1.10it/s, f_score=0.971, lr=2.15e-5, total_loss=0.0687]


Finish Train
Start Validation


Epoch 17/50: 100%|██████████| 174/174 [02:35<00:00,  1.12it/s, f_score=0.934, lr=2.15e-5, val_loss=0.216]


Finish Validation
Epoch:17/50
Total Loss: 0.069 || Val Loss: 0.216 
Start Train


Epoch 18/50: 100%|██████████| 999/999 [15:59<00:00,  1.04it/s, f_score=0.972, lr=2.15e-5, total_loss=0.065] 


Finish Train
Start Validation


Epoch 18/50: 100%|██████████| 174/174 [02:51<00:00,  1.01it/s, f_score=0.934, lr=2.15e-5, val_loss=0.21] 


Finish Validation
Epoch:18/50
Total Loss: 0.065 || Val Loss: 0.210 
Start Train


Epoch 19/50: 100%|██████████| 999/999 [17:08<00:00,  1.03s/it, f_score=0.973, lr=2.15e-5, total_loss=0.0642]


Finish Train
Start Validation


Epoch 19/50: 100%|██████████| 174/174 [02:50<00:00,  1.02it/s, f_score=0.932, lr=2.15e-5, val_loss=0.216]


Finish Validation
Epoch:19/50
Total Loss: 0.064 || Val Loss: 0.216 
Start Train


Epoch 20/50: 100%|██████████| 999/999 [15:49<00:00,  1.05it/s, f_score=0.973, lr=2.15e-5, total_loss=0.0635]


Finish Train
Start Validation


Epoch 20/50: 100%|██████████| 174/174 [02:44<00:00,  1.06it/s, f_score=0.927, lr=2.15e-5, val_loss=0.222]


Finish Validation
Get miou.


100%|██████████| 348/348 [04:13<00:00,  1.37it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.4; mPA: 91.34; Accuracy: 94.43
Get miou done.
Epoch:20/50
Total Loss: 0.064 || Val Loss: 0.222 
Start Train


Epoch 21/50: 100%|██████████| 999/999 [15:36<00:00,  1.07it/s, f_score=0.976, lr=1.29e-5, total_loss=0.0581]


Finish Train
Start Validation


Epoch 21/50: 100%|██████████| 174/174 [02:16<00:00,  1.27it/s, f_score=0.934, lr=1.29e-5, val_loss=0.228]


Finish Validation
Epoch:21/50
Total Loss: 0.058 || Val Loss: 0.228 
Start Train


Epoch 22/50: 100%|██████████| 999/999 [14:50<00:00,  1.12it/s, f_score=0.976, lr=1.29e-5, total_loss=0.0569]


Finish Train
Start Validation


Epoch 22/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.932, lr=1.29e-5, val_loss=0.232]


Finish Validation
Epoch:22/50
Total Loss: 0.057 || Val Loss: 0.232 
Start Train


Epoch 23/50: 100%|██████████| 999/999 [15:51<00:00,  1.05it/s, f_score=0.977, lr=1.29e-5, total_loss=0.0546]


Finish Train
Start Validation


Epoch 23/50: 100%|██████████| 174/174 [02:33<00:00,  1.13it/s, f_score=0.933, lr=1.29e-5, val_loss=0.235]


Finish Validation
Epoch:23/50
Total Loss: 0.055 || Val Loss: 0.235 
Start Train


Epoch 24/50: 100%|██████████| 999/999 [17:11<00:00,  1.03s/it, f_score=0.976, lr=1.29e-5, total_loss=0.0565]


Finish Train
Start Validation


Epoch 24/50: 100%|██████████| 174/174 [02:45<00:00,  1.05it/s, f_score=0.931, lr=1.29e-5, val_loss=0.228]


Finish Validation
Epoch:24/50
Total Loss: 0.057 || Val Loss: 0.228 
Start Train


Epoch 25/50: 100%|██████████| 999/999 [15:41<00:00,  1.06it/s, f_score=0.977, lr=1.29e-5, total_loss=0.0538]


Finish Train
Start Validation


Epoch 25/50: 100%|██████████| 174/174 [02:31<00:00,  1.15it/s, f_score=0.933, lr=1.29e-5, val_loss=0.224]


Finish Validation
Get miou.


100%|██████████| 348/348 [04:05<00:00,  1.42it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.79; mPA: 91.28; Accuracy: 94.63
Get miou done.
Epoch:25/50
Total Loss: 0.054 || Val Loss: 0.224 
Start Train


Epoch 26/50: 100%|██████████| 999/999 [15:10<00:00,  1.10it/s, f_score=0.978, lr=7.74e-6, total_loss=0.051] 


Finish Train
Start Validation


Epoch 26/50: 100%|██████████| 174/174 [02:23<00:00,  1.21it/s, f_score=0.93, lr=7.74e-6, val_loss=0.242] 


Finish Validation
Epoch:26/50
Total Loss: 0.051 || Val Loss: 0.242 
Start Train


Epoch 27/50: 100%|██████████| 999/999 [15:01<00:00,  1.11it/s, f_score=0.979, lr=7.74e-6, total_loss=0.0504]


Finish Train
Start Validation


Epoch 27/50: 100%|██████████| 174/174 [02:31<00:00,  1.15it/s, f_score=0.933, lr=7.74e-6, val_loss=0.237]


Finish Validation
Epoch:27/50
Total Loss: 0.050 || Val Loss: 0.237 
Start Train


Epoch 28/50: 100%|██████████| 999/999 [16:14<00:00,  1.02it/s, f_score=0.979, lr=7.74e-6, total_loss=0.0494]


Finish Train
Start Validation


Epoch 28/50: 100%|██████████| 174/174 [02:37<00:00,  1.11it/s, f_score=0.93, lr=7.74e-6, val_loss=0.243] 


Finish Validation
Epoch:28/50
Total Loss: 0.049 || Val Loss: 0.243 
Start Train


Epoch 29/50: 100%|██████████| 999/999 [15:36<00:00,  1.07it/s, f_score=0.979, lr=7.74e-6, total_loss=0.0491]


Finish Train
Start Validation


Epoch 29/50: 100%|██████████| 174/174 [02:30<00:00,  1.16it/s, f_score=0.93, lr=7.74e-6, val_loss=0.242] 


Finish Validation
Epoch:29/50
Total Loss: 0.049 || Val Loss: 0.242 
Start Train


Epoch 30/50: 100%|██████████| 999/999 [17:27<00:00,  1.05s/it, f_score=0.979, lr=7.74e-6, total_loss=0.0484]


Finish Train
Start Validation


Epoch 30/50: 100%|██████████| 174/174 [02:33<00:00,  1.13it/s, f_score=0.932, lr=7.74e-6, val_loss=0.253]


Finish Validation
Get miou.


100%|██████████| 348/348 [03:56<00:00,  1.47it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.22; mPA: 90.62; Accuracy: 94.43
Get miou done.
Epoch:30/50
Total Loss: 0.048 || Val Loss: 0.253 
Start Train


Epoch 31/50: 100%|██████████| 999/999 [15:00<00:00,  1.11it/s, f_score=0.98, lr=4.64e-6, total_loss=0.0474] 


Finish Train
Start Validation


Epoch 31/50: 100%|██████████| 174/174 [02:22<00:00,  1.22it/s, f_score=0.932, lr=4.64e-6, val_loss=0.245]


Finish Validation
Epoch:31/50
Total Loss: 0.047 || Val Loss: 0.245 
Start Train


Epoch 32/50: 100%|██████████| 999/999 [15:10<00:00,  1.10it/s, f_score=0.98, lr=4.64e-6, total_loss=0.0466] 


Finish Train
Start Validation


Epoch 32/50: 100%|██████████| 174/174 [02:35<00:00,  1.12it/s, f_score=0.931, lr=4.64e-6, val_loss=0.258]


Finish Validation
Epoch:32/50
Total Loss: 0.047 || Val Loss: 0.258 
Start Train


Epoch 33/50: 100%|██████████| 999/999 [14:59<00:00,  1.11it/s, f_score=0.981, lr=4.64e-6, total_loss=0.0459]


Finish Train
Start Validation


Epoch 33/50: 100%|██████████| 174/174 [02:35<00:00,  1.12it/s, f_score=0.931, lr=4.64e-6, val_loss=0.253]


Finish Validation
Epoch:33/50
Total Loss: 0.046 || Val Loss: 0.253 
Start Train


Epoch 34/50: 100%|██████████| 999/999 [15:14<00:00,  1.09it/s, f_score=0.98, lr=4.64e-6, total_loss=0.0464] 


Finish Train
Start Validation


Epoch 34/50: 100%|██████████| 174/174 [02:36<00:00,  1.11it/s, f_score=0.931, lr=4.64e-6, val_loss=0.256]


Finish Validation
Epoch:34/50
Total Loss: 0.046 || Val Loss: 0.256 
Start Train


Epoch 35/50: 100%|██████████| 999/999 [14:48<00:00,  1.12it/s, f_score=0.98, lr=4.64e-6, total_loss=0.0471] 


Finish Train
Start Validation


Epoch 35/50: 100%|██████████| 174/174 [02:43<00:00,  1.06it/s, f_score=0.933, lr=4.64e-6, val_loss=0.262]


Finish Validation
Get miou.


100%|██████████| 348/348 [04:09<00:00,  1.39it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.44; mPA: 91.11; Accuracy: 94.48
Get miou done.
Epoch:35/50
Total Loss: 0.047 || Val Loss: 0.262 
Start Train


Epoch 36/50: 100%|██████████| 999/999 [15:26<00:00,  1.08it/s, f_score=0.981, lr=2.78e-6, total_loss=0.0453]


Finish Train
Start Validation


Epoch 36/50: 100%|██████████| 174/174 [02:12<00:00,  1.31it/s, f_score=0.931, lr=2.78e-6, val_loss=0.263]


Finish Validation
Epoch:36/50
Total Loss: 0.045 || Val Loss: 0.263 
Start Train


Epoch 37/50: 100%|██████████| 999/999 [15:30<00:00,  1.07it/s, f_score=0.981, lr=2.78e-6, total_loss=0.045] 


Finish Train
Start Validation


Epoch 37/50: 100%|██████████| 174/174 [02:55<00:00,  1.01s/it, f_score=0.931, lr=2.78e-6, val_loss=0.258]


Finish Validation
Epoch:37/50
Total Loss: 0.045 || Val Loss: 0.258 
Start Train


Epoch 38/50: 100%|██████████| 999/999 [15:05<00:00,  1.10it/s, f_score=0.981, lr=2.78e-6, total_loss=0.0452]


Finish Train
Start Validation


Epoch 38/50: 100%|██████████| 174/174 [02:36<00:00,  1.11it/s, f_score=0.931, lr=2.78e-6, val_loss=0.252]


Finish Validation
Epoch:38/50
Total Loss: 0.045 || Val Loss: 0.252 
Start Train


Epoch 39/50: 100%|██████████| 999/999 [15:36<00:00,  1.07it/s, f_score=0.981, lr=2.78e-6, total_loss=0.0445]


Finish Train
Start Validation


Epoch 39/50: 100%|██████████| 174/174 [02:28<00:00,  1.17it/s, f_score=0.933, lr=2.78e-6, val_loss=0.258]


Finish Validation
Epoch:39/50
Total Loss: 0.044 || Val Loss: 0.258 
Start Train


Epoch 40/50: 100%|██████████| 999/999 [17:22<00:00,  1.04s/it, f_score=0.981, lr=2.78e-6, total_loss=0.0444]


Finish Train
Start Validation


Epoch 40/50: 100%|██████████| 174/174 [02:31<00:00,  1.15it/s, f_score=0.932, lr=2.78e-6, val_loss=0.259]


Finish Validation
Get miou.


100%|██████████| 348/348 [03:53<00:00,  1.49it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.25; mPA: 90.73; Accuracy: 94.44
Get miou done.
Epoch:40/50
Total Loss: 0.044 || Val Loss: 0.259 
Start Train


Epoch 41/50: 100%|██████████| 999/999 [16:14<00:00,  1.02it/s, f_score=0.981, lr=1.67e-6, total_loss=0.0443]


Finish Train
Start Validation


Epoch 41/50: 100%|██████████| 174/174 [02:16<00:00,  1.27it/s, f_score=0.931, lr=1.67e-6, val_loss=0.259]


Finish Validation
Epoch:41/50
Total Loss: 0.044 || Val Loss: 0.259 
Start Train


Epoch 42/50: 100%|██████████| 999/999 [15:23<00:00,  1.08it/s, f_score=0.982, lr=1.67e-6, total_loss=0.0429]


Finish Train
Start Validation


Epoch 42/50: 100%|██████████| 174/174 [02:33<00:00,  1.13it/s, f_score=0.932, lr=1.67e-6, val_loss=0.25] 


Finish Validation
Epoch:42/50
Total Loss: 0.043 || Val Loss: 0.250 
Start Train


Epoch 43/50: 100%|██████████| 999/999 [14:53<00:00,  1.12it/s, f_score=0.981, lr=1.67e-6, total_loss=0.0438]


Finish Train
Start Validation


Epoch 43/50: 100%|██████████| 174/174 [02:32<00:00,  1.14it/s, f_score=0.932, lr=1.67e-6, val_loss=0.254]


Finish Validation
Epoch:43/50
Total Loss: 0.044 || Val Loss: 0.254 
Start Train


Epoch 44/50: 100%|██████████| 999/999 [16:04<00:00,  1.04it/s, f_score=0.981, lr=1.67e-6, total_loss=0.0437]


Finish Train
Start Validation


Epoch 44/50: 100%|██████████| 174/174 [02:56<00:00,  1.01s/it, f_score=0.932, lr=1.67e-6, val_loss=0.262]


Finish Validation
Epoch:44/50
Total Loss: 0.044 || Val Loss: 0.262 
Start Train


Epoch 45/50: 100%|██████████| 999/999 [15:14<00:00,  1.09it/s, f_score=0.982, lr=1.67e-6, total_loss=0.0435]


Finish Train
Start Validation


Epoch 45/50: 100%|██████████| 174/174 [02:38<00:00,  1.10it/s, f_score=0.935, lr=1.67e-6, val_loss=0.25] 


Finish Validation
Get miou.


100%|██████████| 348/348 [04:33<00:00,  1.27it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.3; mPA: 90.79; Accuracy: 94.45
Get miou done.
Epoch:45/50
Total Loss: 0.043 || Val Loss: 0.250 
Start Train


Epoch 46/50: 100%|██████████| 999/999 [16:37<00:00,  1.00it/s, f_score=0.982, lr=1e-6, total_loss=0.0432]


Finish Train
Start Validation


Epoch 46/50: 100%|██████████| 174/174 [02:43<00:00,  1.06it/s, f_score=0.929, lr=1e-6, val_loss=0.258]


Finish Validation
Epoch:46/50
Total Loss: 0.043 || Val Loss: 0.258 
Start Train


Epoch 47/50: 100%|██████████| 999/999 [15:43<00:00,  1.06it/s, f_score=0.982, lr=1e-6, total_loss=0.043] 


Finish Train
Start Validation


Epoch 47/50: 100%|██████████| 174/174 [02:39<00:00,  1.09it/s, f_score=0.928, lr=1e-6, val_loss=0.254]


Finish Validation
Epoch:47/50
Total Loss: 0.043 || Val Loss: 0.254 
Start Train


Epoch 48/50: 100%|██████████| 999/999 [18:21<00:00,  1.10s/it, f_score=0.982, lr=1e-6, total_loss=0.0426]


Finish Train
Start Validation


Epoch 48/50: 100%|██████████| 174/174 [03:00<00:00,  1.04s/it, f_score=0.933, lr=1e-6, val_loss=0.259]


Finish Validation
Epoch:48/50
Total Loss: 0.043 || Val Loss: 0.259 
Start Train


Epoch 49/50: 100%|██████████| 999/999 [16:34<00:00,  1.00it/s, f_score=0.982, lr=1e-6, total_loss=0.043] 


Finish Train
Start Validation


Epoch 49/50: 100%|██████████| 174/174 [02:32<00:00,  1.14it/s, f_score=0.933, lr=1e-6, val_loss=0.261]


Finish Validation
Epoch:49/50
Total Loss: 0.043 || Val Loss: 0.261 
Start Train


Epoch 50/50: 100%|██████████| 999/999 [15:26<00:00,  1.08it/s, f_score=0.982, lr=1e-6, total_loss=0.0424]


Finish Train
Start Validation


Epoch 50/50: 100%|██████████| 174/174 [02:35<00:00,  1.12it/s, f_score=0.931, lr=1e-6, val_loss=0.264]


Finish Validation
Get miou.


100%|██████████| 348/348 [04:22<00:00,  1.33it/s]


Calculate miou.
Num classes 2
===> mIoU: 85.16; mPA: 90.77; Accuracy: 94.38
Get miou done.
Epoch:50/50
Total Loss: 0.042 || Val Loss: 0.264 


In [20]:
test_lines = total_lines[1998+348:]
# test_dataset = UnetDataset(test_lines, input_shape, num_classes, False, VOCdevkit_path)
# test_sampler = None
# gen_test = DataLoader(test_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=True, 
#                                     drop_last = True, collate_fn = unet_dataset_collate, sampler=test_sampler)
class GetTestResult():
    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
            miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        super(GetTestResult, self).__init__()
        
        self.net                = net
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.image_ids          = image_ids
        self.dataset_path       = dataset_path
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.miou_out_path      = miou_out_path
        self.eval_flag          = eval_flag
        self.period             = period
        
        self.image_ids          = [image_id.split()[0][:-4] for image_id in image_ids]
        self.mious      = [0]
        self.epoches    = [0]
#         if self.eval_flag:
#             with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
#                 f.write(str(0))
#                 f.write("\n")

    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data = np.array(image_data)
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]  #num_class, 512, 512
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))
#         image = Image.fromarray(np.uint8(pr*255))
        return image
    
    def get_result(self, model_eval):
        self.net    = model_eval
#             jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".jpg"))
#             png         = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + "_segmentation.png"))
#             gt_dir      = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        gt_dir      = os.path.join(self.dataset_path, "Labels")
        pred_dir    = '/home/ubuntu/user_space/detection-results'
        if not os.path.exists(self.miou_out_path):
            os.makedirs(self.miou_out_path)
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            #-------------------------------#
            #   从文件中读取图像
            #-------------------------------#
            image_path  = os.path.join(self.dataset_path, "Images/"+image_id+".jpg")
            image       = Image.open(image_path)
            #------------------------------#
            #   获得预测txt
            #------------------------------#
            image       = self.get_miou_png(image)
            image.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # 执行计算mIoU的函数
        temp_miou = np.nanmean(IoUs) * 100
        print("Get predict image.")
        for image_id in tqdm(self.image_ids):
            image_path = os.path.join(pred_dir, image_id + ".png")
            image = Image.open(image_path)
            result = Image.fromarray(np.uint8((1-np.array(image))*255))
            result.save(image_path)
    
GetResult = GetTestResult(model, input_shape, num_classes, test_lines, VOCdevkit_path, log_dir, Cuda, \
                                eval_flag=eval_flag, period=eval_period)
GetResult.get_result(model_train)

Get miou.


100%|██████████| 348/348 [04:08<00:00,  1.40it/s]


Calculate miou.
Num classes 2
===> mIoU: 86.62; mPA: 91.48; Accuracy: 94.85
Get predict image.


100%|██████████| 348/348 [01:08<00:00,  5.10it/s]


In [20]:
model.load_state_dict(torch.load('d://swin-tran-total-line3/ep050-loss0.042-val_loss0.264.pth',map_location=torch.device('cpu')))

<All keys matched successfully>

In [21]:
# test_lines = total_lines[1998+348:]
# test_dataset = UnetDataset(test_lines, input_shape, num_classes, False, VOCdevkit_path)
# test_sampler = None
# gen_test = DataLoader(test_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=True, 
#                                     drop_last = True, collate_fn = unet_dataset_collate, sampler=test_sampler)
data_path = 'D://360Downloads//ISIC2018_Task1-2_Test_Input'
test_lines = os.listdir(data_path)[1:-1]
class GetTestResult():
    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
            miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        super(GetTestResult, self).__init__()
        
        self.net                = net
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.image_ids          = image_ids
        self.dataset_path       = dataset_path
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.miou_out_path      = miou_out_path
        self.eval_flag          = eval_flag
        self.period             = period
        
        self.image_ids          = [image_id.split()[0][:-4] for image_id in image_ids]
        self.mious      = [0]
        self.epoches    = [0]
#         if self.eval_flag:
#             with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
#                 f.write(str(0))
#                 f.write("\n")

    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data = np.array(image_data)
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]  #num_class, 512, 512
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))
#         image = Image.fromarray(np.uint8(pr*255))
        return image
    
    def get_result(self, model_eval):
        self.net    = model_eval
#             jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".jpg"))
#             png         = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + "_segmentation.png"))
#             gt_dir      = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        gt_dir      = 'e://毕业论文/ISIC2018_Task1_Test_GroundTruth'
#         pred_dir    = '/home/ubuntu/user_space/detection-results'
        pred_dir = 'd://swin-tran-total-line3/test-results50'
        if not os.path.exists(self.miou_out_path):
            os.makedirs(self.miou_out_path)
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            #-------------------------------#
            #   从文件中读取图像
            #-------------------------------#
            image_path  = os.path.join(self.dataset_path, image_id+'.jpg')
            image       = Image.open(image_path)
            #------------------------------#
            #   获得预测txt
            #------------------------------#
            image       = self.get_miou_png(image)
            image.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # 执行计算mIoU的函数
        temp_miou = np.nanmean(IoUs) * 100
        print("Get predict image.")
        for image_id in tqdm(self.image_ids):
            image_path = os.path.join(pred_dir, image_id + ".png")
            image = Image.open(image_path)
            result = Image.fromarray(np.uint8((1-np.array(image))*255))
            result.save(image_path)
    
GetResult = GetTestResult(model, input_shape, num_classes, test_lines, data_path, log_dir, Cuda, \
                                eval_flag=eval_flag, period=eval_period)
GetResult.get_result(model_train)

  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]

Get miou.


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [14:08<00:00,  1.18it/s]


Calculate miou.
Num classes 2


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return np.array(hist, np.int), IoUs, PA_Recall, Precision
  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]

===> mIoU: 83.06; mPA: 90.24; Accuracy: 92.24
Get predict image.


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:42<00:00,  9.74it/s]


In [21]:
test_lines = total_lines[1998+348:]
# test_dataset = UnetDataset(test_lines, input_shape, num_classes, False, VOCdevkit_path)
# test_sampler = None
# gen_test = DataLoader(test_dataset  , shuffle = shuffle, batch_size = batch_size, pin_memory=True, 
#                                     drop_last = True, collate_fn = unet_dataset_collate, sampler=test_sampler)
class GetTestResult():
    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
            miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        super(GetTestResult, self).__init__()
        
        self.net                = net
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.image_ids          = image_ids
        self.dataset_path       = dataset_path
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.miou_out_path      = miou_out_path
        self.eval_flag          = eval_flag
        self.period             = period
        
        self.image_ids          = [image_id.split()[0][:-4] for image_id in image_ids]
        self.mious      = [0]
        self.epoches    = [0]
#         if self.eval_flag:
#             with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
#                 f.write(str(0))
#                 f.write("\n")

    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像，防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测，所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   给图像增加灰条，实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data = np.array(image_data)
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]  #num_class, 512, 512
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   进行图片的resize
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   取出每一个像素点的种类
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)
    
        image = Image.fromarray(np.uint8(pr))
#         image = Image.fromarray(np.uint8(pr*255))
        return image
    
    def get_result(self, model_eval):
        self.net    = model_eval
#             jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".jpg"))
#             png         = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + "_segmentation.png"))
#             gt_dir      = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        gt_dir      = os.path.join(self.dataset_path, "Labels")
#         pred_dir    = '/home/ubuntu/user_space/detection-results'
        pred_dir = 'd://swin-tran-total-line3/detetion-results25'
        if not os.path.exists(self.miou_out_path):
            os.makedirs(self.miou_out_path)
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            #-------------------------------#
            #   从文件中读取图像
            #-------------------------------#
            image_path  = os.path.join(self.dataset_path, "Images/"+image_id+".jpg")
            image       = Image.open(image_path)
            #------------------------------#
            #   获得预测txt
            #------------------------------#
            image       = self.get_miou_png(image)
            image.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # 执行计算mIoU的函数
        temp_miou = np.nanmean(IoUs) * 100
        print("Get predict image.")
        for image_id in tqdm(self.image_ids):
            image_path = os.path.join(pred_dir, image_id + ".png")
            image = Image.open(image_path)
            result = Image.fromarray(np.uint8((1-np.array(image))*255))
            result.save(image_path)
    
GetResult = GetTestResult(model, input_shape, num_classes, test_lines, VOCdevkit_path, log_dir, Cuda, \
                                eval_flag=eval_flag, period=eval_period)
GetResult.get_result(model_train)

  0%|                                                                                          | 0/348 [00:00<?, ?it/s]

Get miou.


100%|████████████████████████████████████████████████████████████████████████████████| 348/348 [06:16<00:00,  1.08s/it]


Calculate miou.
Num classes 2


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return np.array(hist, np.int), IoUs, PA_Recall, Precision
  0%|                                                                                          | 0/348 [00:00<?, ?it/s]

===> mIoU: 87.16; mPA: 91.91; Accuracy: 95.06
Get predict image.


100%|████████████████████████████████████████████████████████████████████████████████| 348/348 [00:53<00:00,  6.46it/s]


In [37]:
name          = total_lines[1998+360].split()[0][:-4]

In [38]:
input_shape   = [224, 224]
image         = Image.open(os.path.join(os.path.join(train_dataset.dataset_path, "Images"), name + ".jpg"))
label         = Image.open(os.path.join(os.path.join(train_dataset.dataset_path, "Labels"), name + "_segmentation.png"))
# jpg, png    = train_dataset.get_random_data(jpg, png, input_shape, random = False)
# jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
# png         = np.array(png)
image         = eval_callback.get_miou_png(image)
result        = Image.fromarray(np.uint8((1-np.array(image))*255))

Get miou.


100%|██████████| 348/348 [06:56<00:00,  1.20s/it]


Calculate miou.
Num classes 2
===> mIoU: 82.76; mPA: 91.35; Accuracy: 92.0
Get predict image.


100%|██████████| 348/348 [01:59<00:00,  2.90it/s]
