#transform

# In [None]:
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional

import torch
from PIL import Image
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from torchvision.transforms import functional as F

# Public names exported by ``from <module> import *``.
__all__ = [
    "Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize",
    "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice",
    "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip",
    "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
    "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
    "RandomPerspective", "RandomErasing", "GaussianBlur",
]

# Maps each PIL resampling constant to a readable name; used by the
# ``__repr__`` of the resizing / affine transforms below.
_pil_interpolation_to_str = {
    getattr(Image, _name): 'PIL.Image.' + _name
    for _name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}


class Compose:
    """Chain paired transforms, each mapping ``(img1, label)`` to a new pair.

    Args:
        transforms (sequence): callables applied in order. Each one receives
            two positional arguments and must return exactly two values.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img1, label):
        for step in self.transforms:
            # Unpacking enforces that every step yields exactly two values.
            img1, label = step(img1, label)
        return img1, label

    def __repr__(self):
        listing = ''.join('\n    {0}'.format(step) for step in self.transforms)
        return self.__class__.__name__ + '(' + listing + '\n)'


class ToTensor:
    """Convert ``pic`` to a tensor by delegating to ``F.to_tensor``."""

    def __call__(self, pic):
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class PILToTensor:
    """Convert a PIL image to a tensor by delegating to ``F.pil_to_tensor``."""

    def __call__(self, pic):
        converted = F.pil_to_tensor(pic)
        return converted

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class ConvertImageDtype(torch.nn.Module):
    """Convert an image tensor to ``dtype`` using ``F.convert_image_dtype``.

    Args:
        dtype (torch.dtype): target dtype forwarded to ``F.convert_image_dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(self, image):
        target_dtype = self.dtype
        return F.convert_image_dtype(image, target_dtype)


class ToPILImage:
    """Convert ``pic`` to a PIL image by delegating to ``F.to_pil_image``.

    Args:
        mode (optional): PIL mode forwarded to ``F.to_pil_image``; ``None``
            leaves the choice to that function.
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        mode_part = '' if self.mode is None else 'mode={0}'.format(self.mode)
        return '{0}({1})'.format(self.__class__.__name__, mode_part)


class Normalize(torch.nn.Module):
    """Normalize a tensor with per-channel ``mean`` / ``std`` via ``F.normalize``.

    Args:
        mean (sequence): per-channel means.
        std (sequence): per-channel standard deviations.
        inplace (bool): forwarded to ``F.normalize``.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        mean, std, inplace = self.mean, self.std, self.inplace
        return F.normalize(tensor, mean, std, inplace)

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(type(self).__name__, self.mean, self.std)


class Resize(torch.nn.Module):
    """Resize a single image to ``size`` with ``F.resize``.

    Args:
        size (int or sequence): target size; a sequence must hold 1 or 2 values.
        interpolation: PIL resampling constant forwarded to ``F.resize``.

    Raises:
        TypeError: if ``size`` is neither an int nor a sequence.
        ValueError: if ``size`` is a sequence whose length is not 1 or 2.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        super().__init__()
        is_sequence = isinstance(size, Sequence)
        if not (is_sequence or isinstance(size, int)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
        if is_sequence and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.interpolation = interpolation

    def forward(self, img):
        size, mode = self.size, self.interpolation
        return F.resize(img, size, mode)

    def __repr__(self):
        mode_name = _pil_interpolation_to_str[self.interpolation]
        return '{0}(size={1}, interpolation={2})'.format(self.__class__.__name__, self.size, mode_name)


class Scale(Resize):
    """Deprecated alias of ``Resize``: warns on construction, then defers to it."""

    def __init__(self, *args, **kwargs):
        notice = ("The use of the transforms.Scale transform is deprecated, "
                  "please use transforms.Resize instead.")
        warnings.warn(notice)
        super().__init__(*args, **kwargs)


class CenterCrop(torch.nn.Module):
    """Crop the centre of a single image with ``F.center_crop``.

    Args:
        size (int or sequence): output size, normalised by ``_setup_size`` to two values.
    """

    def __init__(self, size):
        super().__init__()
        message = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=message)

    def forward(self, img):
        crop_size = self.size
        return F.center_crop(img, crop_size)

    def __repr__(self):
        return '{0}(size={1})'.format(type(self).__name__, self.size)


class Pad(torch.nn.Module):
    """Pad the borders of a single image via ``F.pad``.

    Args:
        padding (number or tuple/list): padding amount. A sequence must hold
            1, 2 or 4 elements.
        fill (number, str or tuple): fill value forwarded to ``F.pad``.
        padding_mode (str): one of ``constant``, ``edge``, ``reflect`` or ``symmetric``.

    Raises:
        TypeError: if ``padding`` or ``fill`` has an unsupported type.
        ValueError: if ``padding_mode`` is unknown or ``padding`` has the wrong length.
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, (numbers.Number, str, tuple)):
            raise TypeError("Got inappropriate fill arg")
        if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
        if isinstance(padding, Sequence):
            n_values = len(padding)
            if n_values not in (1, 2, 4):
                raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a "
                                 "{} element tuple".format(n_values))

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        return '{0}(padding={1}, fill={2}, padding_mode={3})'.format(
            self.__class__.__name__, self.padding, self.fill, self.padding_mode)


class Lambda:
    """Wrap an arbitrary callable so it can be used as a transform.

    Args:
        lambd (callable): function applied to the input on every call.

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        if callable(lambd):
            self.lambd = lambd
            return
        raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))

    def __call__(self, img):
        fn = self.lambd
        return fn(img)

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


class RandomTransforms:
    """Base class for transforms that choose or order a list of transforms randomly.

    Args:
        transforms (sequence): the candidate transforms.

    Raises:
        TypeError: if ``transforms`` is not a sequence.
    """

    def __init__(self, transforms):
        if isinstance(transforms, Sequence):
            self.transforms = transforms
        else:
            raise TypeError("Argument transforms should be a sequence")

    def __call__(self, *args, **kwargs):
        # Subclasses define how the transforms are actually applied.
        raise NotImplementedError()

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        return '\n'.join(lines) + '\n)'


class RandomApply(torch.nn.Module):
    """Apply a list of single-image transforms together, with probability ``p``.

    Args:
        transforms (sequence): callables each taking and returning one image.
        p (float): probability of applying the whole list.
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # A single draw decides for the whole list.
        if torch.rand(1) > self.p:
            return img
        for transform in self.transforms:
            img = transform(img)
        return img

    def __repr__(self):
        parts = [self.__class__.__name__ + '(', '    p={}'.format(self.p)]
        parts += ['    {0}'.format(t) for t in self.transforms]
        return '\n'.join(parts) + '\n)'


class RandomOrder(RandomTransforms):
    """Apply every transform exactly once, in an order reshuffled on each call."""

    def __call__(self, img):
        positions = [*range(len(self.transforms))]
        random.shuffle(positions)
        for transform in (self.transforms[k] for k in positions):
            img = transform(img)
        return img


class RandomChoice(RandomTransforms):
    """Apply one transform picked uniformly at random (``random.choice``) per call."""

    def __call__(self, img):
        chosen = random.choice(self.transforms)
        return chosen(img)


class RandomCrop(torch.nn.Module):
    """Crop a single image at a random location, optionally padding it first.

    Args:
        size (int or sequence): output size ``(h, w)``; an int gives a square crop.
        padding (optional): padding forwarded to ``F.pad`` before every crop;
            ``None`` disables it.
        pad_if_needed (bool): pad the image when it is smaller than ``size`` so
            the crop fits.
        fill: fill value forwarded to ``F.pad``.
        padding_mode (str): padding mode forwarded to ``F.pad``.
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Sample ``(top, left, height, width)`` for a random crop of ``output_size``.

        Args:
            img (PIL Image or Tensor): image to be cropped.
            output_size (tuple): expected output size ``(h, w)``.

        Returns:
            tuple: ``(i, j, h, w)`` suitable for ``F.crop``.

        Raises:
            ValueError: if the crop is larger than the image in either dimension.
        """
        w, h = F._get_image_size(img)
        th, tw = output_size

        # The former ``h + 1 < th or w + 1 < tw`` test let a crop one pixel
        # larger than the image through. ``torch.randint(0, 0)`` below then
        # failed with an opaque error instead of this message.
        if h < th or w < tw:
            raise ValueError(
                "Required crop size {} is larger than input image size {}".format((th, tw), (h, w))
            )

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1, )).item()
        j = torch.randint(0, w - tw + 1, size=(1, )).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()

        self.size = tuple(_setup_size(
            size, error_msg="Please provide only two dimensions (h, w) for size."
        ))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Optionally pad ``img``, then return a random ``self.size`` crop of it."""
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = F._get_image_size(img)
        # Pad the width if needed. The 2-element list is forwarded to F.pad as-is
        # (presumably [left/right, top/bottom]); confirm against F.pad.
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # Pad the height if needed.
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)


class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip an image and its label together with probability ``p``.

    Args:
        p (float): probability that the pair is flipped.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img1, label):
        # One draw governs both inputs so the image and label stay aligned.
        if not torch.rand(1) < self.p:
            return img1, label
        return F.hflip(img1), F.hflip(label)

    def __repr__(self):
        return '{}(p={})'.format(type(self).__name__, self.p)


class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip an image and its label together with probability ``p``.

    Args:
        p (float): probability that the pair is flipped.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img1, label):
        # One draw governs both inputs so the image and label stay aligned.
        if not torch.rand(1) < self.p:
            return img1, label
        return F.vflip(img1), F.vflip(label)

    def __repr__(self):
        return '{}(p={})'.format(type(self).__name__, self.p)


class RandomPerspective(torch.nn.Module):
    """Apply a random perspective warp to a single image with probability ``p``.

    Args:
        distortion_scale (float): how far each corner may move inward, as a
            fraction of half the image width/height.
        p (float): probability of applying the warp.
        interpolation: resampling constant forwarded to ``F.perspective``.
        fill: fill value forwarded to ``F.perspective``.

    NOTE(review): unlike the flip/rotation transforms in this file, ``forward``
    takes one image, not an ``(img1, label)`` pair. It therefore cannot be used
    inside the paired ``Compose``.
    """
    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
        super().__init__()
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def forward(self, img):
        """Warp ``img`` with probability ``self.p``; otherwise return it unchanged."""
        if torch.rand(1) < self.p:
            # F._get_image_size yields (width, height), matching get_params' argument order.
            width, height = F._get_image_size(img)
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Each output corner is drawn independently. It may move inward from the
        matching image corner by at most ``int(distortion_scale * half_width)``
        pixels horizontally and ``int(distortion_scale * half_height)``
        pixels vertically.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        # The order of the torch.randint calls below determines which values a
        # seeded run produces; keep it stable.
        half_height = height // 2
        half_width = width // 2
        # x in [0, d*hw], y in [0, d*hh] (randint's upper bound is exclusive).
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        # x in [width - 1 - d*hw, width - 1].
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        # Start points are the exact pixel corners of the image, as [x, y].
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomResizedCrop(torch.nn.Module):
    """Randomly crop an ``(img1, label)`` pair and resize both crops to ``size``.

    The same crop box is used for ``img1`` and ``label`` so they stay aligned.

    Args:
        size (int or sequence): output size ``(h, w)``; an int gives a square output.
        scale (sequence): ``(min, max)`` range for the crop area as a fraction
            of the image area.
        ratio (sequence): ``(min, max)`` range for the crop aspect ratio (width / height).
        interpolation: resampling constant used to resize ``img1``.
        label_interpolation (optional): resampling constant used to resize
            ``label``. ``None`` (the default) reuses ``interpolation``, matching
            the previous behaviour. For label maps holding discrete class ids,
            pass ``Image.NEAREST`` so that resizing does not blend neighbouring
            ids into meaningless values.

    Raises:
        TypeError: if ``scale`` or ``ratio`` is not a sequence.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR,
                 label_interpolation=None):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.interpolation = interpolation
        # Kept as given (possibly None) and resolved in forward(). A later change
        # to ``self.interpolation`` therefore still applies to the label.
        self.label_interpolation = label_interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(
            img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Sample ``(top, left, height, width)`` of a random crop box.

        Tries up to 10 random boxes. Each box's area fraction is drawn
        uniformly from ``scale`` and its aspect ratio log-uniformly from
        ``ratio``. If none fits inside the image, falls back to a centred box
        clamped to ``ratio``.

        Args:
            img (PIL Image or Tensor): image whose size bounds the crop.
            scale (list): ``(min, max)`` area fraction.
            ratio (list): ``(min, max)`` aspect ratio.

        Returns:
            tuple: ``(i, j, h, w)`` suitable for ``F.resized_crop``.
        """
        # F._get_image_size returns (width, height) for PIL images and tensors
        # alike. ``img.size`` only worked for PIL; on a tensor it is a method.
        width, height = F._get_image_size(img)
        area = height * width

        # Loop-invariant; sampling log-uniformly treats r and 1/r symmetrically.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img1, label):
        """Crop both inputs with one sampled box and resize them to ``self.size``."""
        i, j, h, w = self.get_params(img1, self.scale, self.ratio)
        label_interpolation = self.label_interpolation
        if label_interpolation is None:
            label_interpolation = self.interpolation
        cropped_img = F.resized_crop(img1, i, j, h, w, self.size, self.interpolation)
        cropped_label = F.resized_crop(label, i, j, h, w, self.size, label_interpolation)
        return cropped_img, cropped_label

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


class RandomSizedCrop(RandomResizedCrop):
    """Deprecated alias of ``RandomResizedCrop``: warns on construction, then defers to it."""

    def __init__(self, *args, **kwargs):
        notice = ("The use of the transforms.RandomSizedCrop transform is deprecated, "
                  "please use transforms.RandomResizedCrop instead.")
        warnings.warn(notice)
        super().__init__(*args, **kwargs)


class FiveCrop(torch.nn.Module):
    """Delegate to ``F.five_crop`` with a fixed crop ``size``.

    Args:
        size (int or sequence): crop size, normalised by ``_setup_size`` to two values.
    """

    def __init__(self, size):
        super().__init__()
        message = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=message)

    def forward(self, img):
        crop_size = self.size
        return F.five_crop(img, crop_size)

    def __repr__(self):
        return '{0}(size={1})'.format(type(self).__name__, self.size)


class TenCrop(torch.nn.Module):
    """Delegate to ``F.ten_crop`` with a fixed crop ``size``.

    Args:
        size (int or sequence): crop size, normalised by ``_setup_size`` to two values.
        vertical_flip (bool): forwarded to ``F.ten_crop``.
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        message = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=message)
        self.vertical_flip = vertical_flip

    def forward(self, img):
        crop_size, use_vflip = self.size, self.vertical_flip
        return F.ten_crop(img, crop_size, use_vflip)

    def __repr__(self):
        return '{0}(size={1}, vertical_flip={2})'.format(
            self.__class__.__name__, self.size, self.vertical_flip)


class LinearTransformation(torch.nn.Module):
    """Flatten each image, subtract ``mean_vector``, then right-multiply by ``transformation_matrix``.

    Args:
        transformation_matrix (Tensor): square ``[D x D]`` matrix, where
            ``D = C * H * W`` of the images it will be applied to.
        mean_vector (Tensor): length-``D`` vector subtracted before the product.

    Raises:
        ValueError: if the matrix is not square, the vector length does not
            match it, or the two tensors are on different devices.
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
                             " as any one of the dimensions of the transformation_matrix [{}]"
                             .format(tuple(transformation_matrix.size())))

        if transformation_matrix.device != mean_vector.device:
            raise ValueError("Input tensors should be on the same device. Got {} and {}"
                             .format(transformation_matrix.device, mean_vector.device))

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """Transform ``tensor``, whose last three dims multiply to ``D``; the shape is preserved.

        Raises:
            ValueError: on a size mismatch with the matrix, or a device mismatch.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError("Input tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
                             "{}".format(self.transformation_matrix.shape[0]))

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
                             "Got {} vs {}".format(tensor.device, self.mean_vector.device))

        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such
        # as transposed or permuted tensors.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        # torch.mm returns a fresh contiguous tensor, so ``view`` is safe here.
        return transformed_tensor.view(shape)

    def __repr__(self):
        # The old repr closed the call after the matrix and emitted
        # ``..., (mean_vector=...)``; produce one well-formed call-like string.
        return '{0}(transformation_matrix={1}, mean_vector={2})'.format(
            self.__class__.__name__, self.transformation_matrix.tolist(), self.mean_vector.tolist())


class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of ``img1``.

    Only the image is jittered; ``label`` is returned unchanged. This makes the
    transform usable alongside the paired ``(img1, label)`` transforms.

    Args:
        brightness, contrast, saturation (float or (min, max)): a single number
            ``v >= 0`` gives the factor range ``[max(0, 1 - v), 1 + v]``. A pair
            gives the range directly and must satisfy ``0 <= min <= max``.
        hue (float or (min, max)): a single number ``v >= 0`` gives ``[-v, v]``.
            Either form must lie within ``[-0.5, 0.5]``.

    Raises:
        ValueError: for a negative scalar or a range outside its bound.
        TypeError: for anything other than a number or a length-2 list/tuple.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalise one jitter argument to a ``[min, max]`` range, or ``None`` for a no-op."""
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
            # Scalars used to skip the bound check that pairs get. ``hue=0.8``
            # was therefore accepted and produced factors outside (-0.5, 0.5).
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # A range collapsed onto ``center`` (0 or (1., 1.) for brightness,
        # contrast and saturation; (0., 0.) for hue) is a no-op: store None.
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on a single image.

        Arguments are the ranges produced by ``_check_input`` (or ``None`` to
        skip an adjustment).

        Returns:
            A callable ``transform(img) -> img`` that applies the sampled
            adjustments in a random order.
        """
        adjustments = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            adjustments.append(lambda img: F.adjust_brightness(img, brightness_factor))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            adjustments.append(lambda img: F.adjust_contrast(img, contrast_factor))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            adjustments.append(lambda img: F.adjust_saturation(img, saturation_factor))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            adjustments.append(lambda img: F.adjust_hue(img, hue_factor))

        random.shuffle(adjustments)

        # ``Compose`` in this module calls each step as ``t(img1, label)`` and
        # unpacks a pair, so it cannot chain these single-image adjustments.
        # Chain them directly instead.
        def transform(img):
            for adjust in adjustments:
                img = adjust(img)
            return img

        return transform

    def forward(self, img1, label):
        """Jitter ``img1`` with freshly sampled factors; ``label`` passes through unchanged."""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img1 = F.adjust_brightness(img1, brightness_factor)

            if fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img1 = F.adjust_contrast(img1, contrast_factor)

            if fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img1 = F.adjust_saturation(img1, saturation_factor)

            if fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img1 = F.adjust_hue(img1, hue_factor)

        return img1, label

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string


class RandomRotation(torch.nn.Module):
    """Rotate an image and its label by the same randomly drawn angle.

    Args:
        degrees: rotation range, processed by ``_setup_angle`` (2 values required).
        resample, expand, center, fill: forwarded unchanged to ``F.rotate`` for
            both inputs. ``center`` must be a 2-element sequence when given.
    """

    def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2, ))
        self.center = center
        self.resample, self.expand, self.fill = resample, expand, fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Draw an angle uniformly from ``[degrees[0], degrees[1]]``."""
        low, high = float(degrees[0]), float(degrees[1])
        return float(torch.empty(1).uniform_(low, high).item())

    def forward(self, img1, label):
        angle = self.get_params(self.degrees)
        options = (self.resample, self.expand, self.center, self.fill)
        rotated_img = F.rotate(img1, angle, *options)
        rotated_label = F.rotate(label, angle, *options)
        return rotated_img, rotated_label

    def __repr__(self):
        fields = [
            'degrees={0}'.format(self.degrees),
            'resample={0}'.format(self.resample),
            'expand={0}'.format(self.expand),
        ]
        if self.center is not None:
            fields.append('center={0}'.format(self.center))
        if self.fill is not None:
            fields.append('fill={0}'.format(self.fill))
        return '{0}({1})'.format(self.__class__.__name__, ', '.join(fields))


class RandomAffine(torch.nn.Module):
    """Apply a random affine transform (rotation, translation, scale, shear) to one image.

    Args:
        degrees: rotation range, processed by ``_setup_angle`` (2 values required).
        translate (sequence, optional): ``(a, b)``, the maximum absolute horizontal
            and vertical shift as fractions of the image width and height. Each
            value must be in ``[0, 1]``.
        scale (sequence, optional): ``(min, max)`` scale range; both values must be positive.
        shear (optional): processed by ``_setup_angle`` (2 or 4 values). Two values
            give an x-shear range only; four values also give a y-shear range.
        resample: resampling constant forwarded to ``F.affine``. ``__repr__`` looks
            it up in ``_pil_interpolation_to_str``.
        fillcolor: fill value forwarded to ``F.affine``.

    NOTE(review): unlike the flip/rotation transforms in this file, ``forward``
    takes one image, not an ``(img1, label)`` pair. It therefore cannot be used
    inside the paired ``Compose``.
    """
    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2, ))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2, ))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(
            degrees: List[float],
            translate: Optional[List[float]],
            scale_ranges: Optional[List[float]],
            shears: Optional[List[float]],
            img_size: List[int]
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Sample ``(angle, (tx, ty), scale, (shear_x, shear_y))`` for ``F.affine``.

        ``img_size`` is ``[width, height]`` as returned by ``F._get_image_size``.
        Any range passed as ``None`` yields the identity value: no translation,
        scale 1.0, zero shear. The order of the RNG draws below determines a
        seeded run's output; keep it stable.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            # Maximum shift in pixels, then a symmetric uniform draw rounded to an int.
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            # Only a 4-value shear spec carries a y-shear range.
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """Apply one freshly sampled affine transform to ``img``."""
        img_size = F._get_image_size(img)

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        # Builds a template that includes only the options that differ from
        # their defaults, then fills it from the instance attributes.
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        d = dict(self.__dict__)
        d['resample'] = _pil_interpolation_to_str[d['resample']]
        return s.format(name=self.__class__.__name__, **d)


class Grayscale(torch.nn.Module):
    """Convert a single image to grayscale with ``F.rgb_to_grayscale``.

    Args:
        num_output_channels (int): channel count requested from ``F.rgb_to_grayscale``.
    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def forward(self, img):
        channels = self.num_output_channels
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        return '{0}(num_output_channels={1})'.format(type(self).__name__, self.num_output_channels)


class RandomGrayscale(torch.nn.Module):
    """Convert a single image to grayscale with probability ``p``.

    Args:
        p (float): probability of conversion.
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, img):
        # Request as many output channels as the input has. This is looked up
        # before the random draw, as in the original.
        channels = F._get_image_num_channels(img)
        if not torch.rand(1) < self.p:
            return img
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        return '{0}(p={1})'.format(type(self).__name__, self.p)


class RandomErasing(torch.nn.Module):
    """With probability ``p``, overwrite a random rectangle of an image tensor.

    Args:
        p (float): probability of erasing, in ``[0, 1]``.
        scale (tuple/list): ``(min, max)`` erased area as a fraction of the image area, within ``[0, 1]``.
        ratio (tuple/list): ``(min, max)`` aspect ratio (height / width) of the erased box.
        value: fill for the box. A number or a sequence (one value or one per
            channel) gives a constant fill; the string ``"random"`` gives
            per-pixel standard-normal noise.
        inplace (bool): forwarded to ``F.erase``.

    Raises:
        TypeError: for an unsupported ``value`` type, or non-sequence ``scale`` / ``ratio``.
        ValueError: for a string ``value`` other than ``"random"``, or ``scale`` / ``p``
            outside ``[0, 1]``.

    NOTE(review): ``forward`` takes one image (indexed as ``[..., C, H, W]``), not
    an ``(img1, label)`` pair. It therefore cannot be used inside the paired ``Compose``.
    """
    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        # Reversed ranges only warn; they are not rejected.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Sample ``(i, j, h, w, v)``: box position, box size, and the fill tensor.

        Tries up to 10 boxes. Each has an area fraction drawn uniformly from
        ``scale`` and an aspect ratio (h / w) drawn uniformly from ``ratio``.
        A box must be strictly smaller than the image in both dimensions.
        ``value=None`` produces standard-normal noise of shape ``[C, h, w]``.
        Otherwise ``value`` is broadcast over the box. The order of the RNG
        draws determines a seeded run's output; keep it stable.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Shape [len(value), 1, 1] so it broadcasts over the box.
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1, )).item()
            j = torch.randint(0, img_w - w + 1, size=(1, )).item()
            return i, j, h, w, v

        # Return original image
        # (the whole-image box filled with ``img`` itself; presumably F.erase
        # then leaves the content unchanged — confirm against F.erase)
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """Erase one random box of ``img`` with probability ``self.p``; otherwise return it unchanged.

        Raises:
            ValueError: if a sequence ``value`` has neither 1 entry nor one per channel.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            # (number -> [number], "random" -> None, tuple -> list, list unchanged)
            if isinstance(self.value, (int, float)):
                value = [self.value, ]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, tuple):
                value = list(self.value)
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    "{} (number of input channels)".format(img.shape[-3])
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img


class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is re-sampled on every call.

    Args:
        kernel_size (int or sequence): kernel size, normalised to a pair by
            ``_setup_size``; every entry must be odd and positive.
        sigma (float or sequence of two floats): either a single positive number
            (used as a fixed sigma) or a ``(min, max)`` range from which sigma is
            drawn uniformly on each call.

    Raises:
        ValueError: if ``kernel_size`` or ``sigma`` is invalid.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0. < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Draw one sigma uniformly from ``[sigma_min, sigma_max]`` with torch's RNG."""
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        # A later notebook cell rebinds the global ``F`` to torch.nn.functional, which
        # has no ``gaussian_blur``; bind torchvision's functional API explicitly.
        # Note: a local import is not TorchScript-compatible.
        from torchvision.transforms import functional as TF

        sigma = self.get_params(self.sigma[0], self.sigma[1])
        # The same sigma is used for both axes.
        return TF.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self):
        s = '(kernel_size={}, '.format(self.kernel_size)
        s += 'sigma={})'.format(self.sigma)
        return self.__class__.__name__ + s


def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError("{} should be a sequence of length {}.".format(name, msg))
    if len(x) not in req_sizes:
        raise ValueError("{} should be sequence of length {}.".format(name, msg))


def _setup_angle(x, name, req_sizes=(2, )):
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError("If {} is a single number, it must be positive.".format(name))
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]


#utils

In [None]:
import torch
import logging
import torch.nn as nn
import numpy as np
from skimage import measure
from torch._utils import _accumulate
from torch import randperm
from scipy.ndimage import morphology


def random_split(dataset, lengths, inds=None, israndom=True):
    r"""
    Split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; must sum to len(dataset)
        inds (sequence, optional): explicit index order, required when ``israndom`` is False
        israndom (bool): if True, use a fresh ``torch.randperm`` permutation (printed
            so the split can be reproduced by passing it back as ``inds``)

    Returns:
        list of ``torch.utils.data.Subset``, one per entry of ``lengths``.

    Raises:
        ValueError: if the lengths do not sum to len(dataset), or if ``israndom`` is
            False and ``inds`` is missing or has the wrong length.
    """
    # itertools.accumulate replaces the private torch._utils._accumulate helper.
    from itertools import accumulate

    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    if israndom:
        indices = randperm(total).tolist()
        print(indices)
    else:
        if inds is None:
            raise ValueError("inds must be provided when israndom is False")
        if len(inds) != total:
            raise ValueError("inds must contain exactly sum(lengths) indices")
        indices = inds

    return [torch.utils.data.Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    It is assumed that the batch dimension is present.
    Args:
        input (torch.Tensor): 3D (N, H, W) or 4D (N, D, H, W) integer label tensor; it is used
            as the ``scatter_`` index, so it must be int64 and its non-ignored labels must be < C
        C (int): number of channels/labels
        ignore_index (int): ignore index to be kept during the expansion; every channel at an
            ignored position is set to ``ignore_index``
    Returns:
        4D/5D float32 torch.Tensor (NxCxSPATIAL) on the same device as ``input``
    Raises:
        ValueError: if ``input`` is not 3D or 4D
    """
    if input.dim() not in (3, 4):
        raise ValueError("expected a 3D or 4D label tensor, got {}D".format(input.dim()))

    # expand the input tensor to Nx1xSPATIAL before scattering
    input = input.unsqueeze(1)
    # create output tensor shape (NxCxSPATIAL)
    shape = list(input.size())
    shape[1] = C

    if ignore_index is not None:
        # create ignore_index mask for the result
        mask = input.expand(shape) == ignore_index
        # clone the src tensor and zero out ignore_index in the input
        input = input.clone()
        input[input == ignore_index] = 0
        # scatter to get the one-hot tensor (allocated directly on the target device)
        result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
        # bring back the ignore_index in the result
        result[mask] = ignore_index
        return result

    # scatter to get the one-hot tensor
    return torch.zeros(shape, device=input.device).scatter_(1, input, 1)


def random_split(dataset, lengths, inds=None, israndom=True):
    r"""
    Split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; must sum to len(dataset)
        inds (sequence, optional): explicit index order, required when ``israndom`` is False
        israndom (bool): if True, use a fresh ``torch.randperm`` permutation (printed
            so the split can be reproduced by passing it back as ``inds``)

    Returns:
        list of ``torch.utils.data.Subset``, one per entry of ``lengths``.

    Raises:
        ValueError: if the lengths do not sum to len(dataset), or if ``israndom`` is
            False and ``inds`` is missing or has the wrong length.
    """
    # itertools.accumulate replaces the private torch._utils._accumulate helper.
    from itertools import accumulate

    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input data!")

    if israndom:
        indices = randperm(total).tolist()
        print(indices)
    else:
        if inds is None:
            raise ValueError("inds must be provided when israndom is False")
        if len(inds) != total:
            raise ValueError("inds must contain exactly sum(lengths) indices")
        indices = inds

    return [torch.utils.data.Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


def weights_init(m):
    """Initializer intended for ``model.apply(weights_init)``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02). Modules whose
    class name contains 'BatchNorm' get weights ~ N(1, 0.02) and a zero bias.

    Modules that match by name but own no weight tensor are skipped instead of
    raising AttributeError. Examples are wrapper blocks such as ``BasicConv`` and
    ``MultiScaleFeatureFusionConvBlock3d``, and BatchNorm with ``affine=False``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        return
    if classname.find('Conv') != -1:
        nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias.data, 0)


def logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to ``filename`` and to a stream.

    Args:
        filename: log file path; opened in "w" mode, so an existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: logger name passed to ``logging.getLogger`` (None = root logger).

    Handlers that a previous call attached to the same logger are removed and closed
    first. This stops re-run notebook cells from duplicating every message or
    writing to the old file. Handlers added by other code are left untouched.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    log = logging.getLogger(name)
    log.setLevel(level_dict[verbosity])

    for old in list(log.handlers):
        if getattr(old, "_added_by_logger_helper", False):
            log.removeHandler(old)
            old.close()

    fh = logging.FileHandler(filename, "w")
    sh = logging.StreamHandler()
    for handler in (fh, sh):
        handler.setFormatter(formatter)
        # tag so a later call can find and replace this handler
        handler._added_by_logger_helper = True
        log.addHandler(handler)
    return log


def iou_score(output, target):
    """Intersection-over-union of binarised prediction and target masks.

    A tensor ``output`` is treated as logits: sigmoid is applied and the result is rounded.
    A non-tensor ``output`` is used as given. Both arrays are thresholded at 0.5, and
    the score is ``(|pred & gt| + 1e-5) / (|pred | gt| + 1e-5)``.
    """
    eps = 1e-5
    pred = (torch.sigmoid(output).data.cpu().round().numpy()
            if torch.is_tensor(output) else output)
    truth = target.data.cpu().numpy() if torch.is_tensor(target) else target

    pred_mask, truth_mask = pred > 0.5, truth > 0.5
    overlap = (pred_mask & truth_mask).sum()
    combined = (pred_mask | truth_mask).sum()
    return (overlap + eps) / (combined + eps)


def dice_coeff(output, target):
    """Soft Dice coefficient between ``sigmoid(output)`` and ``target``.

    Args:
        output (torch.Tensor): logits; sigmoid is applied here.
        target (torch.Tensor): ground truth with the same number of elements.

    Returns:
        ``(2 * sum(p * t) + 1e-5) / (sum(p) + sum(t) + 1e-5)`` as a numpy scalar.
    """
    smooth = 1e-5

    # reshape (not view) so non-contiguous inputs such as transposed/permuted tensors work
    output = torch.sigmoid(output).detach().reshape(-1).cpu().numpy()
    target = target.detach().reshape(-1).cpu().numpy()
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)


class Evaluator(object):
    """Accumulate a confusion matrix over batches and derive segmentation metrics.

    ``confusion_matrix[g, p]`` counts pixels with ground-truth class ``g`` and
    predicted class ``p`` (built by ``_generate_matrix``). The binary metrics
    (Precision, Recall, Specificity, F1score, F2score, Intersection_over_Union)
    read only the top-left 2x2 block, treating class 0 as positive and class 1 as
    negative. Ratios with a zero denominator follow numpy semantics (nan/inf with
    a RuntimeWarning).
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Precision(self):
        """TP / (TP + FP) for class 0 (column 0 = everything predicted as class 0)."""
        precision = np.diag(self.confusion_matrix)[0] / self.confusion_matrix[:, 0].sum()
        return precision

    def Recall(self):
        """TP / (TP + FN) for class 0 (row 0 = everything whose ground truth is class 0)."""
        recall = np.diag(self.confusion_matrix)[0] / self.confusion_matrix[0, :].sum()
        return recall

    def Specificity(self):
        """TN / (TN + FP) with class 1 as the negative class.

        TN = cm[1, 1] and FP = cm[1, 0] (ground truth 1 predicted as 0).
        """
        cm = self.confusion_matrix
        specificity = cm[1, 1] / (cm[1, 1] + cm[1, 0])
        return specificity

    def F1score(self):
        """Harmonic mean of Precision and Recall."""
        prec = self.Precision()
        rec = self.Recall()
        f1_score = (2 * prec * rec) / (prec + rec)
        return f1_score

    def F2score(self):
        """F-beta score with beta = 2 (recall weighted more than precision)."""
        prec = self.Precision()
        rec = self.Recall()
        f2_score = (5 * prec * rec) / (4 * prec + rec)
        return f2_score

    def Intersection_over_Union(self):
        """IoU of class 0: TP / (TP + FP + FN)."""
        cm = self.confusion_matrix
        iou = np.diag(cm)[0] / (cm[0, 0] + cm[1, 0] + cm[0, 1])
        return iou

    def Pixel_Accuracy(self):
        """Fraction of all counted pixels that lie on the diagonal."""
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        """Mean over classes of diag / row sum, ignoring nan entries."""
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        Acc = np.nanmean(Acc)
        return Acc

    def Mean_Intersection_over_Union(self):
        """Mean over classes of diag / (row sum + column sum - diag), ignoring nan entries."""
        MIoU = np.diag(self.confusion_matrix) / (
                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                    np.diag(self.confusion_matrix))
        MIoU = np.nanmean(MIoU)
        return MIoU

    def Frequency_Weighted_Intersection_over_Union(self):
        """Per-class IoU weighted by ground-truth class frequency (row sum / total)."""
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
                    np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                    np.diag(self.confusion_matrix))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        """Return the ``num_class x num_class`` confusion matrix for one batch.

        Pixels whose ground-truth label lies outside ``[0, num_class)`` are ignored.
        Both label arrays are cast to int, so float-typed predictions work with
        ``np.bincount``. Predicted labels are assumed to lie in ``[0, num_class)``;
        otherwise the reshape below fails.
        """
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask].astype('int')
        count = np.bincount(label, minlength=self.num_class ** 2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        print(confusion_matrix)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        """Add one batch's counts; both arrays must have the same shape."""
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Zero the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.num_class,) * 2)


def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    """Set a step-decayed learning rate on every param group of ``optimizer``.

    lr = init_lr * decay_rate ** (epoch // decay_epoch)

    The rate is computed from ``init_lr`` rather than multiplied into the current
    value, so calling this once per epoch does not compound the decay.
    """
    decay = decay_rate ** (epoch // decay_epoch)
    lr = init_lr * decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element into ``[-grad_clip, grad_clip]`` in place.

    This is element-wise value clipping, not norm clipping. Parameters without a
    gradient are skipped.

    :param optimizer: optimizer whose ``param_groups`` hold the parameters
    :param grad_clip: absolute bound for each gradient element
    :return: None
    """
    all_params = (p for group in optimizer.param_groups for p in group['params'])
    for p in all_params:
        if p.grad is None:
            continue
        p.grad.data.clamp_(-grad_clip, grad_clip)


def universal3Dlargestregion(deal):
    """Keep the largest 3-D connected component(s) of ``deal`` as a 0/1 mask.

    Components are labelled with ``measure.label(deal, connectivity=3)`` (full
    connectivity for a 3-D array). Components are scanned in label order. Every
    label whose area beats the running maximum is recorded in ``save_indexs``, and
    the last two recorded labels (the largest component and the previous running
    maximum) are kept.

    NOTE(review): the previous running maximum is not necessarily the second-largest
    component. For example, areas [5, 1, 9, 8] keep labels 1 and 3, not 3 and 4.
    Confirm whether "two largest components" was intended.

    Args:
        deal: 3-D input array (label 0 of ``measure.label`` is background).

    Returns:
        Integer array shaped like ``deal`` with 1 on kept components and 0 elsewhere.
        With no foreground the mask is all zeros. If the first component is the
        largest, only that component is kept.
    """
    # label each connected region; connectivity equals the array rank (3) for a 3-D array
    labels = measure.label(deal, connectivity=3)
    regions = measure.regionprops(labels)  # per-region properties, in label order
    num = labels.max()  # number of foreground components
    print('白色区域数量', num)
    # keep[label] == 1 for components to retain; index 0 is the background
    keep = np.array([0] * (num + 1))
    save_indexs = []
    best_area = None
    for k in range(num):
        area = regions[k].area
        if best_area is None or best_area < area:
            best_area = area
            save_indexs.append(k + 1)  # regionprops is 0-based, component labels start at 1
    print('save_index: ', save_indexs)
    # [-2:] tolerates lists shorter than two (no components, or a single running maximum)
    for idx in save_indexs[-2:]:
        keep[idx] = 1
    return keep[labels]


def measureimg(o_img, t_num=1):
    """Keep the ``t_num`` largest connected components of ``o_img``.

    Components are found with ``measure.label`` on ``o_img.astype("bool")``. Pixels
    of the kept components copy their values from ``o_img``; everything else is 0.
    Equal areas are ranked by ascending label.

    Args:
        o_img: input array; non-zero entries are foreground.
        t_num: number of components to keep. Values <= 0 keep nothing. Values larger
            than the number of components keep all of them. An image without
            components yields an all-zero result.

    Returns:
        Array of the same shape and dtype as ``o_img`` (via ``np.zeros_like``).
    """
    p_img = np.zeros_like(o_img)
    # temp_img=morphology.binary_dilation(o_img.astype("bool"),iterations=2)
    labelled = measure.label(o_img.astype("bool"))
    regions = measure.regionprops(labelled)
    # Connected regions ranked by pixel count. sorted() is stable, so ties keep
    # ascending-label order.
    ranked = sorted(regions, key=lambda r: r.area, reverse=True)
    for region in ranked[:max(t_num, 0)]:
        selected = labelled == region.label
        p_img[selected] = o_img[selected]
    return p_img

#multiscaleUnet

In [None]:
import torch
from torch import nn
import math
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention2(nn.Module):
    """Channel attention: rescale each channel of ``x`` by a learned gate.

    Average- and max-pooled per-channel descriptors pass through a shared bottleneck
    of 1x1 convolutions (``in_planes -> max(in_planes // ratio, 1) -> in_planes``).
    The two results are summed and squashed with a sigmoid, and the output is
    ``sigmoid(...) * x``.

    Args:
        in_planes: number of input channels.
        ratio: bottleneck reduction ratio. It was previously ignored (hard-coded 16);
            the hidden width is clamped to at least 1 so small channel counts work.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention2, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = max(in_planes // ratio, 1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out) * x

class SpatialAttention(nn.Module):
    """Spatial attention: gate every position of ``x`` by a learned 2-D map.

    The channel-wise mean and max (2 channels) go through a single ``kernel_size``
    convolution with "same" padding and a sigmoid. The result multiplies ``x``.
    Only kernel sizes 3 and 7 are accepted.
    """

    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # odd kernel -> half-width padding keeps H and W unchanged (7 -> 3, 3 -> 1)
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        descriptors = torch.cat(
            [x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True).values], dim=1)
        gate = self.sigmoid(self.conv1(descriptors))
        return gate * x

class CBAM(nn.Module):
    """Channel attention (ratio 16) followed by spatial attention (7x7 kernel)."""

    def __init__(self, in_channels):
        super(CBAM, self).__init__()

        self.ca = ChannelAttention2(in_channels, 16)
        self.sa = SpatialAttention(7)

    def forward(self, x):
        channel_refined = self.ca(x)
        return self.sa(channel_refined)



class conv_block(nn.Module):
    """Two Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU stages.

    The padding keeps H and W unchanged; the first stage maps ``in_channels`` to
    ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_block, self).__init__()

        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class conv_block2(nn.Module):
    """Three chained 3x3 conv stages with dense fusion and a 1x1 shortcut.

    The outputs of the three Conv2d -> BatchNorm2d -> ReLU stages are concatenated
    (3 * out_channels) and fused by a 1x1 conv. The result is added to a 1x1
    projection of the input. H and W are preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_block2, self).__init__()

        def stage(c_in):
            return nn.Sequential(
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True))

        self.conv1 = stage(in_channels)
        self.conv2 = stage(out_channels)
        self.conv3 = stage(out_channels)

        self.conv1x1 = nn.Conv2d(out_channels * 3, out_channels, kernel_size=1)
        self.conv1x2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        feats = [self.conv1(x)]
        for block in (self.conv2, self.conv3):
            feats.append(block(feats[-1]))
        fused = self.conv1x1(torch.cat(feats, dim=1))
        shortcut = self.conv1x2(x)
        return fused + shortcut

class conv_block4(nn.Module):
    """Like ``conv_block2`` but with a CBAM2 attention step before the 1x1 fusion.

    The three chained Conv2d -> BatchNorm2d -> ReLU stage outputs are concatenated
    (3 * out_channels), passed through ``CBAM2`` and fused by a 1x1 conv. The result
    is added to a 1x1 projection of the input.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_block4, self).__init__()

        def stage(c_in):
            return nn.Sequential(
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True))

        self.conv1 = stage(in_channels)
        self.conv2 = stage(out_channels)
        self.conv3 = stage(out_channels)

        self.conv1x1 = nn.Conv2d(out_channels * 3, out_channels, kernel_size=1)
        self.conv1x2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.attention = CBAM2(out_channels * 3)

    def forward(self, x):
        feats = [self.conv1(x)]
        for block in (self.conv2, self.conv3):
            feats.append(block(feats[-1]))
        attended = self.attention(torch.cat(feats, dim=1))
        fused = self.conv1x1(attended)
        shortcut = self.conv1x2(x)
        return fused + shortcut

class conv_block3(nn.Module):
    """Single Conv3d(3x3x3, stride 1, padding 1) -> BatchNorm3d -> ReLU stage."""

    def __init__(self, in_channels, out_channels):
        super(conv_block3, self).__init__()

        conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm3d(out_channels)
        self.conv = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)

class MultiScaleFeatureFusionConvBlock3d(nn.Module):
    '''
    Multi-scale feature fusion module followed by a convolution block (3-D).

    A 1x1x1 conv projects the input to ``out_channels``, which is split along the
    channel axis into four equal groups:
    - Group 1 passes through unchanged.
    - Group 2 goes through a ``conv_block3``.
    - Groups 3 and 4 each have the previous group's output added before their own ``conv_block3``.
    The sum of the four groups is added back to every group. The groups are then
    concatenated and passed through a 1x1x1 conv and a final ``conv_block3``.

    Raises:
        ValueError: if ``out_channels`` is not divisible by 4. Otherwise the
            concatenation would have ``4 * (out_channels // 4)`` channels, which
            does not match ``conv1_out``.
    '''

    def __init__(self, in_channels, out_channels):
        super(MultiScaleFeatureFusionConvBlock3d, self).__init__()
        if out_channels % 4 != 0:
            raise ValueError("out_channels must be divisible by 4, got {}".format(out_channels))

        self.out_channels_split = out_channels // 4
        self.conv1_in = nn.Conv3d(in_channels, out_channels, [1, 1, 1], stride=1, padding=0)
        self.conv3_addition_2 = conv_block3(self.out_channels_split, self.out_channels_split)
        self.conv3_addition_3 = conv_block3(self.out_channels_split, self.out_channels_split)
        self.conv3_addition_4 = conv_block3(self.out_channels_split, self.out_channels_split)
        self.conv1_out = nn.Conv3d(out_channels, out_channels, [1, 1, 1], stride=1, padding=0)

        self.conv3 = conv_block3(out_channels, out_channels)

    def forward(self, x):
        s = self.out_channels_split
        f = self.conv1_in(x)
        f_1 = f[:, 0:s, :, :, :]
        f_2 = self.conv3_addition_2(f[:, s:2 * s, :, :, :])
        f_3 = self.conv3_addition_3(f[:, 2 * s:3 * s, :, :, :] + f_2)
        f_4 = self.conv3_addition_4(f[:, 3 * s:4 * s, :, :, :] + f_3)
        fusion = f_1 + f_2 + f_3 + f_4
        f_1 = f_1 + fusion
        f_2 = f_2 + fusion
        f_3 = f_3 + fusion
        f_4 = f_4 + fusion
        return self.conv3(self.conv1_out(torch.cat((f_1, f_2, f_3, f_4), dim=1)))


class UNet(nn.Module):
    """Three-level 2-D U-Net built from ``conv_block2`` residual blocks.

    ``conv_block2`` contains only 2-D layers (Conv2d/BatchNorm2d), so pooling,
    up-sampling and the output head are 2-D as well. Previously they were 3-D
    (MaxPool3d/ConvTranspose3d/Conv3d), which made ``forward`` fail for every input.
    MaxPool3d reads a 4-D (N, C, H, W) batch as an unbatched (C, D, H, W) volume and
    halves the channel axis, which then no longer matches the next block.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        n = 32
        filters = [n, n*2, n*4, n*8, n*16]
        self.conv0 = conv_block2(in_channels, filters[0])
        self.conv1 = conv_block2(filters[0], filters[1])
        self.conv2 = conv_block2(filters[1], filters[2])
        self.conv3 = conv_block2(filters[2], filters[3])

        self.conv6 = conv_block2(filters[3], filters[2])
        self.conv7 = conv_block2(filters[2], filters[1])
        self.conv8 = conv_block2(filters[1], filters[0])

        self.maxpool1 = nn.MaxPool2d(2)
        self.maxpool2 = nn.MaxPool2d(2)
        self.maxpool3 = nn.MaxPool2d(2)

        self.transconv3 = nn.ConvTranspose2d(filters[3], filters[3] // 2, kernel_size=2, stride=2)
        self.transconv2 = nn.ConvTranspose2d(filters[2], filters[2] // 2, kernel_size=2, stride=2)
        self.transconv1 = nn.ConvTranspose2d(filters[1], filters[1] // 2, kernel_size=2, stride=2)

        self.conv1x1 = nn.Conv2d(filters[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder.

        Args:
            x: (N, in_channels, H, W) batch. H and W should be divisible by 8 so the
                three 2x poolings and transposed convolutions line up for the skip
                concatenations.

        Returns:
            Raw (N, out_channels, H, W) output of the final 1x1 convolution.
        """
        # encoder #
        x_cat1 = self.conv0(x)
        x_cat2 = self.conv1(self.maxpool1(x_cat1))
        x_cat3 = self.conv2(self.maxpool2(x_cat2))
        x = self.conv3(self.maxpool3(x_cat3))

        # decoder: upsample, concatenate the matching skip connection, fuse #
        x = self.conv6(torch.cat([x_cat3, self.transconv3(x)], dim=1))
        x = self.conv7(torch.cat([x_cat2, self.transconv2(x)], dim=1))
        x = self.conv8(torch.cat([x_cat1, self.transconv1(x)], dim=1))

        return self.conv1x1(x)
##################################################################
class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d (eps=1e-5, momentum=0.01) and ReLU.

    ``bn=False`` / ``relu=False`` store ``None`` in ``self.bn`` / ``self.relu`` and
    skip that step. The conv has no bias by default.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = (nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
                   if bn else None)
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        out = self.conv(x)
        for step in (self.bn, self.relu):
            if step is not None:
                out = step(out)
        return out

class Flatten(nn.Module):
    """Collapse every dimension after the first: (N, ...) -> (N, prod(...)), via ``view``."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

class ChannelGate(nn.Module):
    """CBAM channel attention gate.

    Each pooling in ``pool_types`` collapses the spatial dims of the input to one
    value per channel. Each pooled vector goes through a shared two-layer MLP and
    the results are summed. The sigmoid of the sum rescales every channel of the input.

    Args:
        gate_channels: number of input channels C.
        reduction_ratio: the MLP hidden width is ``max(C // reduction_ratio, 1)``.
        pool_types: names from 'avg', 'max', 'lp' (power-2 LP pool) and 'lse'
            (log-sum-exp); defaults to ``['avg', 'max']``.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown name.
            Previously an unknown name silently reused the previous pool's output.
    """

    _POOL_TYPES = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None):
        super(ChannelGate, self).__init__()
        if pool_types is None:
            # fresh list per instance instead of a shared mutable default
            pool_types = ['avg', 'max']
        unknown = [p for p in pool_types if p not in self._POOL_TYPES]
        if unknown or not pool_types:
            raise ValueError("pool_types must be a non-empty subset of {}, got {}".format(
                self._POOL_TYPES, pool_types))

        hidden = max(gate_channels // reduction_ratio, 1)
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, gate_channels)
            )
        self.pool_types = pool_types

    def forward(self, x):
        spatial = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, spatial, stride=spatial)
            elif pool_type == 'lse':
                # LSE pool only
                pooled = logsumexp_2d(x)
            else:
                raise ValueError("unknown pool type: {}".format(pool_type))
            channel_att_raw = self.mlp(pooled)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over every dimension after the first two.

    Returns a tensor of shape (N, C, 1). The per-row maximum is subtracted before
    ``exp`` and added back afterwards to avoid overflow.
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = flat.max(dim=2, keepdim=True)[0]
    return peak + torch.log(torch.sum(torch.exp(flat - peak), dim=2, keepdim=True))

class ChannelPool(nn.Module):
    """Stack the channel-wise max (first) and mean (second) into a 2-channel tensor."""

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)

class SpatialGate(nn.Module):
    """CBAM spatial attention gate.

    ``ChannelPool`` reduces the input to [max, mean] channels. A 7x7 ``BasicConv``
    (BatchNorm, no ReLU, "same" padding) maps them to one channel. Its sigmoid
    multiplies the input, broadcast over channels.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # torch.sigmoid instead of F.sigmoid: F.sigmoid is the deprecated alias, and
        # this notebook rebinds the global name ``F`` between cells.
        scale = torch.sigmoid(x_out)  # single channel, broadcast over x's channels
        return x * scale

class CBAM2(nn.Module):
    """CBAM block: ``ChannelGate`` followed, unless ``no_spatial``, by ``SpatialGate``.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: reduction ratio of the channel-gate MLP.
        pool_types: pooling names for the channel gate; defaults to ``['avg', 'max']``.
        no_spatial: if True, skip the spatial gate.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None, no_spatial=False):
        super(CBAM2, self).__init__()
        if pool_types is None:
            # Fresh list per instance. ChannelGate stores the list as
            # ``self.pool_types``, so a shared default list would be aliased across
            # every CBAM2 built with the default.
            pool_types = ['avg', 'max']
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out


In [None]:
import torch
from torch.autograd import Variable

# Smoke-test setup: a random input and the network, placed on the GPU when one is
# available (unconditional .cuda() calls raise on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

img = torch.randn(1, 32, 128, 128).to(device)

net = UNet(1, 2).to(device)

# NOTE(review): UNet(1, 2) expects 1 input channel but img has 32 in dim 1;
# confirm the intended input layout before enabling the forward pass below.
#out = net(img)

#print(out.size())

In [None]:
# Print the module hierarchy of ``net`` built in the previous cell.
print(net)

UNet(
  (conv0): conv_block2(
    (conv1): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (conv2): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (conv3): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (conv1x1): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
    (conv1x2): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv1): conv_block2(
    (conv1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64

#UNet

In [None]:
import torch
from torch import nn

# class conv_block(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(conv_block, self).__init__()
#
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True))
#
#     def forward(self, x):
#         x = self.conv(x)
#         return x


class trans_block(nn.Module):
    """2x up-sampling block: ConvTranspose2d(kernel 2, stride 2) -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super(trans_block, self).__init__()

        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """Four-level 2-D U-Net built from ``conv_block4`` (CBAM-augmented) blocks.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        n = 32
        filters = [n, n*2, n*4, n*8, n*16]
        self.conv0 = conv_block4(in_channels, filters[0])
        self.conv1 = conv_block4(filters[0], filters[1])
        self.conv2 = conv_block4(filters[1], filters[2])
        self.conv3 = conv_block4(filters[2], filters[3])
        self.conv4 = conv_block4(filters[3], filters[4])

        self.conv5 = conv_block4(filters[4], filters[3])
        self.conv6 = conv_block4(filters[3], filters[2])
        self.conv7 = conv_block4(filters[2], filters[1])
        self.conv8 = conv_block4(filters[1], filters[0])

        self.maxpool1 = nn.MaxPool2d(2)
        self.maxpool2 = nn.MaxPool2d(2)
        self.maxpool3 = nn.MaxPool2d(2)
        self.maxpool4 = nn.MaxPool2d(2)

        self.transconv4 = nn.ConvTranspose2d(filters[4], filters[4] // 2, kernel_size=2, stride=2)
        self.transconv3 = nn.ConvTranspose2d(filters[3], filters[3] // 2, kernel_size=2, stride=2)
        self.transconv2 = nn.ConvTranspose2d(filters[2], filters[2] // 2, kernel_size=2, stride=2)
        self.transconv1 = nn.ConvTranspose2d(filters[1], filters[1] // 2, kernel_size=2, stride=2)

        self.conv1x1 = nn.Conv2d(filters[0], out_channels, kernel_size=1)

        # Initialize convolutions' parameters
        self.init_conv2d()

    def init_conv2d(self):
        """Xavier-uniform initialise every nn.Conv2d in the network and zero its bias.

        Walks ``self.modules()`` (all descendants). ``self.children()`` only yields
        direct children, so the final ``conv1x1`` was the only Conv2d it reached;
        the convolutions inside the encoder/decoder blocks kept PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                # some convs have no bias (e.g. BasicConv inside CBAM2 defaults to bias=False)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        """Run the encoder/decoder.

        Args:
            x: (N, in_channels, H, W) batch. H and W should be divisible by 16 so the
                four 2x poolings and transposed convolutions line up for the skip
                concatenations.

        Returns:
            Raw (N, out_channels, H, W) output of the final 1x1 convolution.
        """
        # encoder #
        x = self.conv0(x)
        x_cat1 = x
        x = self.maxpool1(x)

        x = self.conv1(x)
        x_cat2 = x
        x = self.maxpool2(x)

        x = self.conv2(x)
        x_cat3 = x
        x = self.maxpool3(x)

        x = self.conv3(x)
        x_cat4 = x
        x = self.maxpool4(x)

        x = self.conv4(x)

        # decoder #
        x_trans4 = self.transconv4(x)
        x = torch.cat([x_cat4, x_trans4], dim=1)
        x = self.conv5(x)

        x_trans3 = self.transconv3(x)
        x = torch.cat([x_cat3, x_trans3], dim=1)
        x = self.conv6(x)

        x_trans2 = self.transconv2(x)
        x = torch.cat([x_cat2, x_trans2], dim=1)
        x = self.conv7(x)

        x_trans1 = self.transconv1(x)
        x = torch.cat([x_cat1, x_trans1], dim=1)
        x = self.conv8(x)

        output = self.conv1x1(x)

        return output


#utils

In [None]:
import torch
import logging
import torch.nn as nn
import numpy as np
from skimage import measure
from torch._utils import _accumulate
from torch import randperm
from scipy.ndimage import morphology


def random_split(dataset, lengths, inds=None, israndom=True):
    r"""
    Split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; must sum to len(dataset)
        inds (sequence, optional): explicit index order, required when ``israndom`` is False
        israndom (bool): if True, use a fresh ``torch.randperm`` permutation (printed
            so the split can be reproduced by passing it back as ``inds``)

    Returns:
        list of ``torch.utils.data.Subset``, one per entry of ``lengths``.

    Raises:
        ValueError: if the lengths do not sum to len(dataset), or if ``israndom`` is
            False and ``inds`` is missing or has the wrong length.
    """
    # itertools.accumulate replaces the private torch._utils._accumulate helper.
    from itertools import accumulate

    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    if israndom:
        indices = randperm(total).tolist()
        print(indices)
    else:
        if inds is None:
            raise ValueError("inds must be provided when israndom is False")
        if len(inds) != total:
            raise ValueError("inds must contain exactly sum(lengths) indices")
        indices = inds

    return [torch.utils.data.Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    It is assumed that the batch dimension is present.
    Args:
        input (torch.Tensor): 3D (N, H, W) or 4D (N, D, H, W) integer label tensor; it is used
            as the ``scatter_`` index, so it must be int64 and its non-ignored labels must be < C
        C (int): number of channels/labels
        ignore_index (int): ignore index to be kept during the expansion; every channel at an
            ignored position is set to ``ignore_index``
    Returns:
        4D/5D float32 torch.Tensor (NxCxSPATIAL) on the same device as ``input``
    Raises:
        ValueError: if ``input`` is not 3D or 4D
    """
    if input.dim() not in (3, 4):
        raise ValueError("expected a 3D or 4D label tensor, got {}D".format(input.dim()))

    # expand the input tensor to Nx1xSPATIAL before scattering
    input = input.unsqueeze(1)
    # create output tensor shape (NxCxSPATIAL)
    shape = list(input.size())
    shape[1] = C

    if ignore_index is not None:
        # create ignore_index mask for the result
        mask = input.expand(shape) == ignore_index
        # clone the src tensor and zero out ignore_index in the input
        input = input.clone()
        input[input == ignore_index] = 0
        # scatter to get the one-hot tensor (allocated directly on the target device)
        result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
        # bring back the ignore_index in the result
        result[mask] = ignore_index
        return result

    # scatter to get the one-hot tensor
    return torch.zeros(shape, device=input.device).scatter_(1, input, 1)


def random_split(dataset, lengths, inds=None, israndom=True):
    r"""
    Split a dataset into non-overlapping subsets of the given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split.
        lengths (sequence): lengths of the splits to be produced; must sum to
            ``len(dataset)``.
        inds (sequence, optional): explicit index order used when ``israndom``
            is False; must contain exactly ``sum(lengths)`` indices.
        israndom (bool): if True (default) the indices are a random permutation
            drawn with ``torch.randperm``; otherwise ``inds`` is used as given.

    Returns:
        list of ``torch.utils.data.Subset``, one per entry of ``lengths``, taken
        as consecutive slices of the index order.

    Raises:
        ValueError: if the lengths do not sum to the dataset length, or if
            ``israndom`` is False and ``inds`` is missing or has the wrong size.
    """
    from itertools import accumulate  # public replacement for torch._utils._accumulate

    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input data!")

    if israndom:
        indices = torch.randperm(total).tolist()
    else:
        if inds is None:
            raise ValueError("inds must be provided when israndom is False")
        if len(inds) != total:
            raise ValueError("inds must contain exactly sum(lengths) indices")
        indices = inds

    return [torch.utils.data.Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(accumulate(lengths), lengths)]


def weights_init(m):
    """Initialise one module's parameters; intended for ``net.apply(weights_init)``.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``'BatchNorm'`` get weights
    drawn from N(1, 0.02) and a zero bias. Modules that match by name but have
    no such parameter (e.g. a user-defined ``ConvBlock`` container, or
    ``BatchNorm`` built with ``affine=False``) are skipped instead of raising.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


def logger(filename, verbosity=1, name=None):
    """Configure a logger that writes to both ``filename`` and the console.

    Args:
        filename: log file path; the file is truncated (mode ``"w"``).
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: passed to ``logging.getLogger``; ``None`` selects the root logger.

    Returns:
        logging.Logger: the configured logger.

    Handlers that an earlier call of this helper attached to the same logger
    are removed and closed first, so re-running the call (e.g. re-executing a
    notebook cell) neither duplicates every message nor leaks file handles.
    Handlers installed by other code are left alone.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    log = logging.getLogger(name)
    log.setLevel(level_dict[verbosity])

    for old in [h for h in log.handlers if getattr(h, "_from_logger_helper", False)]:
        log.removeHandler(old)
        old.close()

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    fh._from_logger_helper = True  # tag so a later call can replace it
    log.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh._from_logger_helper = True
    log.addHandler(sh)
    return log


def iou_score(output, target):
    """Intersection-over-union of two binary masks.

    A tensor ``output`` is treated as logits: sigmoid, moved to CPU, rounded.
    A tensor ``target`` is only moved to CPU. Non-tensor inputs are used as
    given. Both are binarised with ``> 0.5``; the score is
    ``(|A & B| + 1e-5) / (|A | B| + 1e-5)``.
    """
    eps = 1e-5
    pred_mask = (torch.sigmoid(output).data.cpu().round().numpy()
                 if torch.is_tensor(output) else output) > 0.5
    true_mask = (target.data.cpu().numpy()
                 if torch.is_tensor(target) else target) > 0.5
    overlap = (pred_mask & true_mask).sum()
    combined = (pred_mask | true_mask).sum()
    return (overlap + eps) / (combined + eps)


def dice_coeff(output, target):
    """Soft Dice coefficient between ``sigmoid(output)`` and ``target``.

    Both tensors are flattened and converted to numpy; ``output`` is treated as
    logits. Returns ``(2 * sum(p * t) + 1e-5) / (sum(p) + sum(t) + 1e-5)``.
    """
    eps = 1e-5
    probs = torch.sigmoid(output).view(-1).data.cpu().numpy()
    truth = target.view(-1).data.cpu().numpy()
    overlap = (probs * truth).sum()
    numerator = 2. * overlap + eps
    denominator = probs.sum() + truth.sum() + eps
    return numerator / denominator


class Evaluator(object):
    """Accumulate a confusion matrix over batches and derive segmentation metrics.

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth class is ``i``
    and whose predicted class is ``j`` (see ``_generate_matrix``).

    The binary metrics (``Precision``, ``Recall``, ``Specificity``,
    ``F1score``, ``F2score``, ``Intersection_over_Union``) treat class 0 as the
    positive class and class 1 as the negative class, so they assume
    ``num_class == 2``. ``Pixel_Accuracy*``, ``Mean_Intersection_over_Union``
    and ``Frequency_Weighted_Intersection_over_Union`` work for any
    ``num_class``.
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Precision(self):
        """TP / (TP + FP) for class 0: ``cm[0, 0] / sum(cm[:, 0])``."""
        cm = self.confusion_matrix
        return np.diag(cm)[0] / cm[:, 0].sum()

    def Recall(self):
        """TP / (TP + FN) for class 0: ``cm[0, 0] / sum(cm[0, :])``."""
        cm = self.confusion_matrix
        return np.diag(cm)[0] / cm[0, :].sum()

    def Specificity(self):
        """TN / (TN + FP) with class 0 positive.

        TN is ``cm[1, 1]`` (class 1 predicted as 1) and FP is ``cm[1, 0]``
        (class 1 predicted as 0).
        """
        cm = self.confusion_matrix
        return cm[1, 1] / (cm[1, 1] + cm[1, 0])

    def F1score(self):
        """Harmonic mean of ``Precision`` and ``Recall``."""
        prec = self.Precision()
        rec = self.Recall()
        return (2 * prec * rec) / (prec + rec)

    def F2score(self):
        """F-beta score with beta = 2 (recall weighted more than precision)."""
        prec = self.Precision()
        rec = self.Recall()
        return (5 * prec * rec) / (4 * prec + rec)

    def Intersection_over_Union(self):
        """IoU of class 0: TP / (TP + FP + FN)."""
        cm = self.confusion_matrix
        return np.diag(cm)[0] / (cm[0, 0] + cm[1, 0] + cm[0, 1])

    def Pixel_Accuracy(self):
        """Fraction of all counted pixels that were classified correctly."""
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Pixel_Accuracy_Class(self):
        """Per-class accuracy (recall) averaged over classes, ignoring NaNs."""
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        return np.nanmean(Acc)

    def Mean_Intersection_over_Union(self):
        """Per-class IoU averaged over classes, ignoring NaNs."""
        MIoU = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        return np.nanmean(MIoU)

    def Frequency_Weighted_Intersection_over_Union(self):
        """Per-class IoU weighted by each class's ground-truth pixel frequency."""
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        return (freq[freq > 0] * iu[freq > 0]).sum()

    def _generate_matrix(self, gt_image, pre_image):
        """Confusion matrix (rows: ground truth, cols: prediction) for one batch.

        Pixels whose ground truth lies outside ``[0, num_class)`` are ignored.
        Both maps are cast to int, so float class maps are accepted (np.bincount
        rejects floats).
        """
        gt_image = np.asarray(gt_image)
        pre_image = np.asarray(pre_image)
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask].astype(int)
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        """Accumulate one batch; both class maps must have the same shape.

        Raises:
            ValueError: if the shapes differ.
        """
        if np.shape(gt_image) != np.shape(pre_image):
            raise ValueError("gt_image and pre_image must have the same shape, got "
                             f"{np.shape(gt_image)} and {np.shape(pre_image)}")
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Zero the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.num_class,) * 2)


def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    """Step-decay schedule.

    Sets every param group's learning rate to
    ``init_lr * decay_rate ** (epoch // decay_epoch)``.

    The rate is derived from ``init_lr`` rather than multiplied into the current
    lr, so calling this once per epoch does not compound the decay.
    """
    new_lr = init_lr * decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr


def clip_gradient(optimizer, grad_clip):
    """
    Clamp, element-wise and in place, every existing gradient of the parameters
    held by ``optimizer`` into ``[-grad_clip, grad_clip]``.

    :param optimizer: optimizer whose ``param_groups`` hold the parameters.
    :param grad_clip: clipping threshold.
    :return: None
    """
    low, high = -grad_clip, grad_clip
    grads = (p.grad
             for group in optimizer.param_groups
             for p in group['params']
             if p.grad is not None)
    for grad in grads:
        grad.data.clamp_(low, high)


def universal3Dlargestregion(deal, num_keep=2):
    """Keep the largest connected components of a 3D volume as a 0/1 mask.

    Args:
        deal: 3D array; non-zero voxels are foreground.
        num_keep: number of largest components to keep. The default of 2
            matches the two labels the original implementation retained.

    Returns:
        Integer array with the same shape as ``deal``. It is 1 on voxels of the
        ``num_keep`` largest components and 0 elsewhere. With fewer components
        than ``num_keep``, all are kept; an empty volume yields all zeros.
    """
    # connectivity=3 on a 3D array = full (26-neighbour) connectivity
    labels = measure.label(deal, connectivity=3)
    regions = measure.regionprops(labels)
    num = labels.max()  # number of foreground components
    print('白色区域数量', num)

    # Largest area first; the sort is stable, so ties keep ascending label order.
    ranked = sorted(regions, key=lambda r: r.area, reverse=True)
    save_indexs = [r.label for r in ranked[:num_keep]]
    print('save_index: ', save_indexs)

    # Lookup table indexed by component label: 1 = keep, 0 = drop (label 0 is background).
    keep_lut = np.zeros(num + 1, dtype=int)
    keep_lut[save_indexs] = 1
    return keep_lut[labels]


def measureimg(o_img, t_num=1):
    """Keep only the ``t_num`` largest connected components of ``o_img``.

    Non-zero pixels are foreground; components are found with
    ``measure.label`` at its default connectivity. Pixels of the kept
    components keep their original values from ``o_img``; all others are zero.
    With fewer than ``t_num`` components, all of them are kept. An image with
    no foreground returns all zeros instead of raising.
    """
    p_img = np.zeros_like(o_img)
    labeled = measure.label(o_img.astype("bool"))
    props = measure.regionprops(labeled)
    # Largest area first; the stable sort puts the lowest label first on ties,
    # matching the original "first maximum" selection.
    largest = sorted(props, key=lambda r: r.area, reverse=True)[:t_num]
    for region in largest:
        sel = labeled == region.label
        p_img[sel] = o_img[sel]
    return p_img

#loss

In [None]:
import torch
from torch import nn
from torch.nn.functional import one_hot


class WCEDCELoss(nn.Module):
    """Weighted cross-entropy plus weighted soft Dice loss for multi-class segmentation.

    ``loss = inter_weights * CE + (1 - inter_weights) * weighted_dice``.

    Args:
        num_classes: number of classes C; ``pred`` must have C channels.
        inter_weights: mixing factor between the CE and Dice terms.
        intra_weights: optional per-class weight tensor of length C. It is
            used as the CE ``weight`` and to weight the per-class Dice losses.
            If None, CE is unweighted and the Dice weights are recomputed on
            every call as ``numel(label) / (count_of_class + 1e-5)``.
            NOTE(review): a class absent from ``label`` then gets a weight of
            about ``numel * 1e5``; confirm this is intended.
        device: kept for backward compatibility. Automatically computed Dice
            weights are created on ``pred``'s device, so CPU inputs work with
            the default.
    """

    def __init__(self, num_classes=4, inter_weights=0.5, intra_weights=None, device='cuda'):
        super(WCEDCELoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=intra_weights)
        self.num_classes = num_classes
        self.intra_weights = intra_weights
        self.inter_weights = inter_weights
        self.device = device

    def dice_loss(self, prediction, target, weights):
        """Per-class weighted soft Dice loss.

        Args:
            prediction: raw logits of shape (N, C, *spatial); softmax over dim 1.
            target: one-hot target with the same shape as ``prediction``.
            weights: per-class weights of shape (C,).

        Returns:
            Scalar tensor: the mean over classes of
            ``weights[c] * (1 - batch-mean Dice of class c)``.
        """
        smooth = 1e-5

        prediction = torch.softmax(prediction, dim=1)
        batchsize, num_classes = target.size(0), target.size(1)
        prediction = prediction.view(batchsize, num_classes, -1)
        target = target.view(batchsize, num_classes, -1)

        intersection = (prediction * target).sum(2)
        dice = (2. * intersection + smooth) / (prediction.sum(2) + target.sum(2) + smooth)
        dice_loss = 1 - dice.sum(0) / batchsize
        return (dice_loss * weights).mean()

    def forward(self, pred, label):
        """Compute the combined loss.

        Args:
            pred: raw logits of shape (N, C, *spatial), e.g. (N, C, H, W) or
                (N, C, D, H, W).
            label: integer class map of shape (N, *spatial) with values in [0, C).

        Returns:
            Scalar loss tensor.
        """
        cel = self.ce_loss(pred, label)
        # (N, *spatial, C) -> (N, C, *spatial); works for any number of spatial dims.
        label_onehot = one_hot(label, num_classes=self.num_classes).movedim(-1, 1).contiguous()

        if self.intra_weights is None:
            # inverse class frequency, computed on the prediction's device
            counts = torch.bincount(label.reshape(-1), minlength=self.num_classes)[:self.num_classes]
            intra_weights = label.numel() / (counts.to(device=pred.device, dtype=torch.float32) + 1e-5)
        else:
            intra_weights = self.intra_weights
        dicel = self.dice_loss(pred, label_onehot, intra_weights)
        return cel * self.inter_weights + dicel * (1 - self.inter_weights)

In [None]:
# Smoke test of WCEDCELoss on random 4-class data; uses the GPU when available.
demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
wcedceloss = WCEDCELoss(device=demo_device)
label = torch.randint(low=0, high=4, size=[2, 224, 224]).to(demo_device)
print(one_hot(label, 4).shape)
prediction = torch.randn([2, 4, 224, 224]).to(demo_device)
loss = wcedceloss(pred=prediction, label=label)
print('loss: ', loss)

torch.Size([2, 224, 224, 4])
loss:  tensor(2.3606, device='cuda:0')


#main

In [None]:
from torch.utils.data.dataset import Dataset
import os
from torchvision.transforms import *
from PIL import Image
import random
import torch
import numpy as np
from torch.utils.data import DataLoader


In [None]:
def transform():
    """Image pipeline: resize to 512x512 with nearest-neighbour, then ``ToTensor``.

    ``InterpolationMode`` replaces the PIL integer constant, which torchvision
    reports as deprecated ("Argument interpolation should be of type
    InterpolationMode instead of int").
    """
    from torchvision.transforms import InterpolationMode
    return Compose([
        Resize([512, 512], InterpolationMode.NEAREST),
        ToTensor()
    ])


def label_transform():
    """Mask pipeline: resize to 512x512 with nearest-neighbour so label values stay discrete.

    ``InterpolationMode`` replaces the deprecated PIL integer constant.
    """
    from torchvision.transforms import InterpolationMode
    return Compose([
        Resize([512, 512], InterpolationMode.NEAREST),
    ])

def augmentation_transform():
    """Build a joint image/mask augmentation ``aug(image, label) -> (image, label)``.

    ``Test1Dataset.__getitem__`` calls the result with two arguments. The
    ``from torchvision.transforms import *`` in this notebook binds
    ``Compose`` to torchvision's single-input version, and independent random
    transforms would desynchronise image and mask anyway. This version draws
    each random parameter once and applies it to both inputs:

    * horizontal flip with p = 0.5, then vertical flip with p = 0.5;
    * rotation by an angle drawn uniformly from [-180, 180] degrees
      (nearest-neighbour);
    * ``ColorJitter()`` on the image only; its defaults make it a no-op;
    * a ``RandomResizedCrop``-style box (scale 0.08-1, ratio 3/4-4/3) resized
      to 512x512, bilinear for the image and nearest for the mask so label
      values stay discrete.
    """
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms import functional as TF
    from torchvision.transforms import ColorJitter as _ColorJitter
    from torchvision.transforms import RandomResizedCrop as _RandomResizedCrop
    from torchvision.transforms import RandomRotation as _RandomRotation

    size = [512, 512]
    jitter = _ColorJitter()

    def _augment(image, label):
        if torch.rand(1) < 0.5:
            image, label = TF.hflip(image), TF.hflip(label)
        if torch.rand(1) < 0.5:
            image, label = TF.vflip(image), TF.vflip(label)

        angle = _RandomRotation.get_params([-180.0, 180.0])
        image = TF.rotate(image, angle, interpolation=InterpolationMode.NEAREST)
        label = TF.rotate(label, angle, interpolation=InterpolationMode.NEAREST)

        image = jitter(image)

        top, left, height, width = _RandomResizedCrop.get_params(
            image, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
        image = TF.resized_crop(image, top, left, height, width, size,
                                InterpolationMode.BILINEAR)
        label = TF.resized_crop(label, top, left, height, width, size,
                                InterpolationMode.NEAREST)
        return image, label

    return _augment

class Test1Dataset(Dataset):
    """Paired 2D image/mask dataset from ``<image_dir>/2d_images`` and ``<image_dir>/2d_masks``.

    Only ``.tif`` files are used. Images and masks are paired by sorted
    filename order: the i-th image (sorted) goes with the i-th mask (sorted).

    Args:
        image_dir: root directory containing the two sub-folders.
        stage: split name; augmentation is only applied when it is ``'train'``.
        augmentation: if True and ``stage == 'train'``, each sample is passed
            through ``augmentation_transform()`` with probability 1/2.

    Each item is ``(image, label)``. ``image`` is the output of ``transform()``.
    ``label`` is the mask after ``label_transform()``, converted to a float
    tensor with a leading channel dimension, with values greater than 1 set to 1.

    Raises:
        ValueError: if the two folders contain different numbers of ``.tif`` files.
    """

    def __init__(self, image_dir='', stage='train', augmentation=False):
        super().__init__()
        self.stage = stage
        self.augmentation = augmentation

        image_path = os.path.join(image_dir, '2d_images')
        label_path = os.path.join(image_dir, '2d_masks')

        # os.listdir order is arbitrary; sort both listings so index i refers
        # to the same sample in each.
        self.flair_image_paths = sorted(
            os.path.join(image_path, name) for name in os.listdir(image_path)
            if name.endswith('.tif'))
        self.label_paths = sorted(
            os.path.join(label_path, name) for name in os.listdir(label_path)
            if name.endswith('.tif'))

        if len(self.flair_image_paths) != len(self.label_paths):
            raise ValueError(
                f"found {len(self.flair_image_paths)} images but "
                f"{len(self.label_paths)} masks")

        self.transform = transform()
        self.label_transform = label_transform()
        self.augmentation_transform = augmentation_transform()

    @staticmethod
    def _load(path):
        """Open an image and return an independent in-memory copy so the file is closed."""
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, index):
        flair_image = self._load(self.flair_image_paths[index])
        label = self._load(self.label_paths[index])

        # Augment half of the training samples (random.choice([0, 1]) is truthy half the time).
        if self.augmentation and self.stage == 'train' and random.choice([0, 1]):
            flair_image, label = self.augmentation_transform(flair_image, label)

        image = self.transform(flair_image)

        label = self.label_transform(label)
        label = torch.from_numpy(np.array(label)).float().unsqueeze(0)
        # Masks are binarised: every value above 1 becomes foreground (1).
        label[label > 1] = 1

        return image, label

    def __len__(self):
        return len(self.label_paths)

In [None]:
dataset_path = '/content/drive/MyDrive/dataset'

# Quick sanity check: build the dataset, show it, wrap it in a shuffled loader.
dataset = Test1Dataset(augmentation=False, image_dir=dataset_path)
print(dataset)

loader_options = dict(batch_size=2, shuffle=True, num_workers=0)
dataloader = DataLoader(dataset, **loader_options)

print(len(dataset))

<__main__.Test1Dataset object at 0x7f37e0e4ca10>
267


  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
import time
import os
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import logging
import torch.nn as nn
import numpy as np
from skimage import measure
from torch._utils import _accumulate
from torch import randperm
from scipy.ndimage import morphology

def logger(filename, verbosity=1, name=None):
    """Configure a logger that writes to both ``filename`` and the console.

    Args:
        filename: log file path; the file is truncated (mode ``"w"``).
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: passed to ``logging.getLogger``; ``None`` selects the root logger.

    Returns:
        logging.Logger: the configured logger.

    Handlers that an earlier call of this helper attached to the same logger
    are removed and closed first, so re-running the call (e.g. re-executing a
    notebook cell) neither duplicates every message nor leaks file handles.
    Handlers installed by other code are left alone.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    log = logging.getLogger(name)
    log.setLevel(level_dict[verbosity])

    for old in [h for h in log.handlers if getattr(h, "_from_logger_helper", False)]:
        log.removeHandler(old)
        old.close()

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    fh._from_logger_helper = True  # tag so a later call can replace it
    log.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh._from_logger_helper = True
    log.addHandler(sh)
    return log


In [None]:
Num_epochs = 10
Num_batchsize = 3
Learningrate = 0.0003
# Background + foreground; must match the UNet output channels and WCEDCELoss.
Num_classes = 2
# Fall back to CPU so the cell still runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
Imagedir = '/content/drive/MyDrive/dataset/'
net_model_path = '/content/drive/MyDrive/Models'

Architecture = ['Unet']
architecture = Architecture[0]

AUGMENTATION = False
WCEDCELOSS = False
Deepsupervision = False

############################load data#########################
set1 = Test1Dataset(image_dir=Imagedir, stage='train', augmentation=AUGMENTATION)

train_set = set1

train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=Num_batchsize, shuffle=True, pin_memory=True)

############################load the net#####################

inet = UNet(in_channels=1, out_channels=Num_classes)
print("#parameters:", sum(param.numel() for param in inet.parameters()))

inet = inet.to(device)
##################loss function and optimization##############

if WCEDCELOSS:
    # num_classes must equal the network's output channels. The class default
    # (4) would build a 4-channel one-hot target for this 2-channel prediction.
    criterion = WCEDCELoss(num_classes=Num_classes, intra_weights=torch.tensor([1., 5.]).to(device),
                           device=device, inter_weights=0.5)
else:
    criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.AdamW(inet.parameters(), lr=Learningrate, betas=(0.9, 0.999))

#######Train the net###################

os.makedirs(net_model_path, exist_ok=True)
results = {'loss': [], 'dice': [], 'iou': [], 'val_loss': [], 'val_dice': [], 'val_iou': []}

for epoch in range(1, Num_epochs + 1):
    epochresults = {'loss': [], 'dice': [], 'iou': [], 'val_loss': [], 'val_dice': [], 'val_iou': []}
    inet.train()
    for iteration, (image, label) in enumerate(train_loader):
        # .float() in both branches: ToTensor may yield integer tensors for 16-bit TIFFs.
        image = image.to(device).float()
        label = label.to(device)
        optimizer.zero_grad()

        if Deepsupervision:
            pred, pred1, pred2, pred3, pred4 = inet(image)
            # 'nearest' keeps downsampled labels as valid class ids; 'bilinear'
            # would blend them into fractions that .long() then truncates.
            loss0 = criterion(pred, label.squeeze(1).long())
            loss1 = criterion(pred1, F.interpolate(label, scale_factor=1. / 2., mode='nearest').squeeze(1).long())
            loss2 = criterion(pred2, F.interpolate(label, scale_factor=1. / 4., mode='nearest').squeeze(1).long())
            loss3 = criterion(pred3, F.interpolate(label, scale_factor=1. / 8., mode='nearest').squeeze(1).long())
            loss4 = criterion(pred4, F.interpolate(label, scale_factor=1. / 16., mode='nearest').squeeze(1).long())
            loss = 0.4 * loss0 + 0.3 * loss1 + 0.2 * loss2 + 0.05 * loss3 + 0.05 * loss4
        else:
            pred = inet(image)
            loss = criterion(pred, label.squeeze(1).long())

        loss.backward()
        optimizer.step()
        ########loss of each iteration########
        loss_value = loss.item()
        if iteration % 100 == 0:
            print("Train: Epoch/Epoches {}/{}\t"
                  "iteration/iterations {}/{}\t"
                  "loss {:.3f}".format(epoch, Num_epochs, iteration, len(train_loader), loss_value))

        epochresults['loss'].append(loss_value)

    # Store the mean of this epoch's iteration losses.
    results['loss'].append(np.mean(epochresults['loss']))

    ######Average loss of each epoch########
    print("Average: Epoch/Epoches {}/{}\t"
          "train epoch loss {:.3f}\n".format(epoch, Num_epochs, np.mean(epochresults['loss'])))
    #######save latest epoch model#######
    # A checkpoint is written every epoch; despite "best" in the filename, no
    # validation-based selection happens here.
    torch.save(inet.state_dict(), net_model_path + "/net_best_epoch_%d.pth" % epoch)


cuda
#parameters: 13852026


  "Argument interpolation should be of type InterpolationMode instead of int. "


Train: Epoch/Epoches 1/10	iteration/iterations 0/89	loss 43.097
Average: Epoch/Epoches 1/10	train epoch loss 4.985

Train: Epoch/Epoches 2/10	iteration/iterations 0/89	loss 1.424
Average: Epoch/Epoches 2/10	train epoch loss 0.739

Train: Epoch/Epoches 3/10	iteration/iterations 0/89	loss 0.886
Average: Epoch/Epoches 3/10	train epoch loss 0.445

Train: Epoch/Epoches 4/10	iteration/iterations 0/89	loss 0.488
Average: Epoch/Epoches 4/10	train epoch loss 0.353

Train: Epoch/Epoches 5/10	iteration/iterations 0/89	loss 0.319
Average: Epoch/Epoches 5/10	train epoch loss 0.437

Train: Epoch/Epoches 6/10	iteration/iterations 0/89	loss 0.244
Average: Epoch/Epoches 6/10	train epoch loss 0.365

Train: Epoch/Epoches 7/10	iteration/iterations 0/89	loss 0.349
Average: Epoch/Epoches 7/10	train epoch loss 0.378

Train: Epoch/Epoches 8/10	iteration/iterations 0/89	loss 0.267
Average: Epoch/Epoches 8/10	train epoch loss 0.324

Train: Epoch/Epoches 9/10	iteration/iterations 0/89	loss 0.232
Average: Epoch/E