# Configs

In [None]:
import copy
import os
import random
import time
from collections import defaultdict

import matplotlib.image as image
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.optim as optim
from scipy.special import comb
from skimage import io
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms

In [None]:
# Ensure the log and checkpoint directories exist. Each directory is created
# independently (parents included), so an already-existing log directory can
# no longer prevent the checkpoint directory from being created.
os.makedirs("../log_dir/ss", exist_ok=True)
os.makedirs("../models/ss_models", exist_ok=True)

In [None]:
class Config:
    """Hyper-parameters and paths for the self-supervised pre-training run.

    The ``*_rate`` values are probabilities read by the augmentation pipeline.
    ``inpaint_rate`` is not a constructor argument; it is derived as
    ``1.0 - outpaint_rate``.
    """

    def __init__(self,
                 data_augmentation=True,
                 nonlinear_rate=0.5,
                 paint_rate=0.6,
                 outpaint_rate=0.4,
                 flip_rate=0.5,
                 local_rate=0.4,
                 load_saved_model_path='../models/pretrained/ss_pretrained.pth.tar',
                 original_dataset_dir='../data/Dataset',
                 preprocessed_dataset_dir='../data/Preprocessed_Dataset/',
                 labels_file='../data/labels.csv',
                 label_type='PVWM',
                 preprocess=True,
                 begin_at_slice=6,
                 num_slices=7,
                 image_crop_size=128
                 ):
        settings = {
            'data_augmentation': data_augmentation,
            'nonlinear_rate': nonlinear_rate,
            'paint_rate': paint_rate,
            'outpaint_rate': outpaint_rate,
            # In-painting is used whenever out-painting is not chosen.
            'inpaint_rate': 1.0 - outpaint_rate,
            'flip_rate': flip_rate,
            'local_rate': local_rate,
            'load_saved_model_path': load_saved_model_path,
            'original_dataset_dir': original_dataset_dir,
            'preprocessed_dataset_dir': preprocessed_dataset_dir,
            'labels_file': labels_file,
            'label_type': label_type,
            'preprocess': preprocess,
            'begin_at_slice': begin_at_slice,
            'num_slices': num_slices,
            'image_crop_size': image_crop_size,
        }
        for name, value in settings.items():
            setattr(self, name, value)

    def display(self):
        """Print every public, non-callable attribute as an aligned two-column table."""
        print("\nConfigurations:")
        visible = [name for name in dir(self)
                   if not (name.startswith("__") or callable(getattr(self, name)))]
        for name in visible:
            print("{:30} {}".format(name, getattr(self, name)))
        print("\n")


# Module-level configuration instance; the dataset's augmentation pipeline
# reads its rates (flip_rate, local_rate, ...) directly from this global.
config = Config()
config.display()

## Augmentations

In [None]:
def bernstein_poly(i, n, t):
    """Evaluate the degree-``n`` Bernstein basis term with index ``i`` at ``t``.

    Note the parameterisation used here: the weight is
    ``C(n, i) * t**(n - i) * (1 - t)**i``. With this form, term ``n`` dominates
    at ``t = 0`` and term ``0`` dominates at ``t = 1``. ``t`` may be a scalar
    or a NumPy array.
    """
    coefficient = comb(n, i)
    t_part = t ** (n - i)
    one_minus_t_part = (1 - t) ** i
    return coefficient * t_part * one_minus_t_part


def bezier_curve(points, nTimes=1000):
    """Sample the Bezier curve defined by a list of 2-D control points.

    Args:
        points: sequence of ``[x, y]`` control points, e.g.
            ``[[1, 1], [2, 3], [4, 5], ..., [Xn, Yn]]``.
        nTimes: number of evenly spaced parameter values in ``[0, 1]``
            (default 1000).

    Returns:
        ``(xvals, yvals)`` arrays of length ``nTimes``. Because
        ``bernstein_poly`` weights point ``i`` by ``t**(n-i) * (1-t)**i``, the
        samples start at the LAST control point (t=0) and end at the FIRST
        (t=1).

    See http://processingjs.nihongoresources.com/bezierinfo/
    """
    degree = len(points) - 1
    x_ctrl = np.array([pt[0] for pt in points])
    y_ctrl = np.array([pt[1] for pt in points])

    ts = np.linspace(0.0, 1.0, nTimes)
    basis = np.array([bernstein_poly(k, degree, ts) for k in range(degree + 1)])
    return np.dot(x_ctrl, basis), np.dot(y_ctrl, basis)


def data_augmentation(x, y, prob=0.5):
    """Randomly mirror an input/target pair along the same spatial axis.

    With probability ``prob`` both arrays are flipped once along one of their
    last two axes, chosen uniformly at random. Otherwise they are returned
    unchanged (same objects).

    Args:
        x: input array, e.g. a channel-first slice of shape (1, rows, cols)
            or a plain (rows, cols) image.
        y: target array with the same shape as ``x``; it receives the same flip.
        prob: probability of flipping.

    Returns:
        ``(x, y)``, flipped or unchanged. Flipped results are NumPy views.
    """
    if random.random() >= prob:
        return x, y
    # Use the two trailing (spatial) axes. For channel-first (1, H, W) slices,
    # flipping axis 0 would only flip the singleton channel (a no-op); for
    # plain 2-D images this is the same [0, 1] choice as before.
    axis = random.choice([x.ndim - 2, x.ndim - 1])
    return np.flip(x, axis=axis), np.flip(y, axis=axis)


def nonlinear_transformation(x, prob=0.5):
    """Apply a random monotonic intensity remapping built from a Bezier curve.

    With probability ``1 - prob`` ``x`` is returned unchanged (same object).
    Otherwise a cubic Bezier curve through (0, 0) and (1, 1) with two random
    inner control points is sampled. Intensities (assumed to lie in
    ``[0, 255]``) are mapped through it with ``np.interp``.

    Args:
        x: intensity array, expected in ``[0, 255]``.
        prob: probability of applying the transform.

    Returns:
        A new float array in ``[0, 255]``, or ``x`` itself when skipped.
    """
    if random.random() >= prob:
        return x
    points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]]
    xvals, yvals = bezier_curve(points, nTimes=100)
    if random.random() < 0.5:
        # Half chance of an inverted mapping. bezier_curve samples run from the
        # last control point (1, 1) toward the first (0, 0), so y is roughly
        # descending. Sorting only x pairs ascending x with that descending y.
        xvals = np.sort(xvals)
    else:
        xvals, yvals = np.sort(xvals), np.sort(yvals)
    return 255 * np.interp(np.true_divide(x, 255), xvals, yvals)


def local_pixel_shuffling(x, prob=0.5, num_block=800):
    """Shuffle pixels inside many small random windows of channel 0.

    With probability ``1 - prob`` ``x`` is returned unchanged (same object).
    Otherwise ``num_block`` windows are drawn, each at most
    ``rows // 20 x cols // 20`` pixels (never smaller than 1x1). Each window's
    pixels are read from the unmodified input, permuted, and written into a
    copy of ``x``.

    Args:
        x: array of shape (channels, rows, cols). Only channel 0 is shuffled.
        prob: probability of applying the transform.
        num_block: number of windows to shuffle (default 800).

    Returns:
        A new array, or ``x`` itself when skipped. ``x`` is never modified.
    """
    if random.random() >= prob:
        return x
    _, img_rows, img_cols = x.shape
    shuffled = copy.deepcopy(x)
    # Clamp to 1 so images smaller than 20 px do not make randint(1, 0) raise.
    max_block_x = max(1, img_rows // 20)
    max_block_y = max(1, img_cols // 20)
    for _ in range(num_block):
        block_x = random.randint(1, max_block_x)
        block_y = random.randint(1, max_block_y)
        noise_x = random.randint(0, img_rows - block_x)
        noise_y = random.randint(0, img_cols - block_y)
        # Windows are always read from the untouched input, so overlapping
        # blocks never re-shuffle pixels that were already moved.
        window = x[0, noise_x:noise_x + block_x, noise_y:noise_y + block_y].flatten()
        np.random.shuffle(window)
        shuffled[0, noise_x:noise_x + block_x, noise_y:noise_y + block_y] = window.reshape((block_x, block_y))
    return shuffled


def image_in_painting(x):
    """Fill up to five random rectangles of channel 0 with flat random intensities.

    Each attempt proceeds with probability 0.95 and paints one rectangle of
    ``rows/20..rows/10`` by ``cols/20..cols/10`` pixels, kept at least 3 px
    away from the border. The fill value is uniform in ``[0, 255)``.

    Args:
        x: array of shape (channels, rows, cols). It is modified IN PLACE.

    Returns:
        ``x`` (the same object), after painting.
    """
    _, img_rows, img_cols = x.shape
    cnt = 5
    while cnt > 0 and random.random() < 0.95:
        block_noise_size_x = random.randint(img_rows // 20, img_rows // 10)
        block_noise_size_y = random.randint(img_cols // 20, img_cols // 10)
        noise_x = random.randint(3, img_rows - block_noise_size_x - 3)
        noise_y = random.randint(3, img_cols - block_noise_size_y - 3)
        x[0,
          noise_x:noise_x + block_noise_size_x,
          noise_y:noise_y + block_noise_size_y] = np.random.rand(1)[0] * 255
        # Bound the loop by ``cnt``. Without the decrement it only ends on the
        # 5% draw and paints ~19 rectangles on average.
        cnt -= 1
    return x


def image_out_painting(x):
    """Keep up to five large random windows of ``x``; flood everything else.

    The output starts as a constant image (random intensity in ``[0, 255)``).
    One window is always restored from the input. Up to four more are
    restored, each with probability 0.95 and stopping at the first failure.
    Windows are about 2/3 to 3/4 of each spatial side and stay at least 3 px
    from the border.

    Args:
        x: array of shape (channels, rows, cols). Not modified.

    Returns:
        A new float array of the same shape.
    """
    img_deps, img_rows, img_cols = x.shape
    image_temp = copy.deepcopy(x)
    x = np.full((img_deps, img_rows, img_cols), np.random.rand(1)[0] * 255)

    def _restore_random_window():
        # Copy one random window (all channels) from the original image back
        # into the flooded output.
        block_x = img_rows - random.randint(3 * img_rows // 12, 4 * img_rows // 12)
        block_y = img_cols - random.randint(3 * img_cols // 12, 4 * img_cols // 12)
        noise_x = random.randint(3, img_rows - block_x - 3)
        noise_y = random.randint(3, img_cols - block_y - 3)
        x[:, noise_x:noise_x + block_x, noise_y:noise_y + block_y] = \
            image_temp[:, noise_x:noise_x + block_x, noise_y:noise_y + block_y]

    _restore_random_window()
    cnt = 4
    while cnt > 0 and random.random() < 0.95:
        _restore_random_window()
        # Without the decrement the loop only stopped on the 5% draw.
        cnt -= 1
    return x


## Dataloader and Self-supervised dataset

In [None]:
class SelfSupervisedDataset(Dataset):
    """Self-supervised restoration dataset built from per-patient slice folders.

    Item ``idx`` is ``[inputs, targets]``. Both are float tensors of shape
    ``(num_slices, 1, S, S)`` with ``S = image_crop_size``, or
    ``(2 * num_slices, 1, S, S)`` when ``preprocess`` is true: slices from the
    preprocessed directory are concatenated after the original ones.
    ``targets`` are the (possibly flipped) slices scaled to ``[0, 1]``;
    ``inputs`` are the same slices after the random corruptions configured by
    the global ``config``. Rows for which a patient has no slice stay zero.
    """

    def __init__(self, original_dataset_dir, preprocessed_dataset_dir, labels_file, begin_at_slice=6, num_slices=7,
                 image_crop_size=128,
                 label_type='PVWM', preprocess=True,
                 transform=None):
        """
        Args:
            original_dataset_dir (string): Directory with one DICOM sub-folder per patient ID.
            preprocessed_dataset_dir (string): Directory with one image sub-folder per patient ID.
            labels_file (string): Path to the csv file with an ``ID`` column (plus PVWM and DWM labels).
            begin_at_slice (int): Index (in sorted file order) of the first slice to use,
                since the middle slices are the ones of interest.
            num_slices (int): Number of consecutive slices to use from each series.
            image_crop_size (int): Height and width of the slices after ``transform``.
            label_type (string): 'PVWM' or 'DWM'. Validated but not used by this class.
            preprocess (boolean): Whether to also load the preprocessed slices.
            transform (callable, optional): Applied to each ``(H, W, 1)`` uint8 slice. It must
                return a tensor of spatial size ``image_crop_size``.
        """

        assert label_type in ['PVWM', 'DWM'], "Invalid label type {}, label type must be one of ['PVWM', 'DWM']".format(
            label_type)

        self.original_dataset_dir = original_dataset_dir
        self.preprocessed_dataset_dir = preprocessed_dataset_dir
        self.labels = pd.read_csv(labels_file, dtype={"ID": str})
        self.begin_at_slice = begin_at_slice
        self.num_slices = num_slices
        self.image_size = image_crop_size
        self.label_type = label_type
        self.preprocess = preprocess
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        slices = self.get_original_slices(idx)

        if self.preprocess:
            preprocessed_slices = self.get_preprocessed_slices(idx)
            slices['slices'] = torch.cat([slices['slices'], preprocessed_slices['slices']], dim=0)
            slices['label'] = torch.cat([slices['label'], preprocessed_slices['label']], dim=0)
        sample = [slices['slices'], slices['label']]
        return sample

    def get_original_slices(self, idx):
        """Load and corrupt the selected DICOM slices of patient ``idx``."""
        return self._load_slices(idx, self.original_dataset_dir, self._read_dicom_slice)

    def get_preprocessed_slices(self, idx):
        """Load and corrupt the selected preprocessed slices of patient ``idx``."""
        # Must read from the preprocessed directory, not the original one.
        return self._load_slices(idx, self.preprocessed_dataset_dir, self._read_image_slice)

    @staticmethod
    def _rescale_to_uint8(slice_data):
        """Scale a slice so its maximum becomes 255 and cast it to uint8.

        Slices whose maximum is not positive (e.g. all-black) are cast without
        scaling. This avoids the 0/0 -> NaN that would otherwise appear.
        """
        slice_data = np.asarray(slice_data, dtype=float)
        peak = np.max(slice_data)
        if peak > 0:
            slice_data = slice_data * (255.0 / peak)
        return slice_data.astype(np.uint8)

    @staticmethod
    def _read_dicom_slice(path):
        """Read one DICOM file and rescale its pixel data to uint8."""
        return SelfSupervisedDataset._rescale_to_uint8(pydicom.dcmread(path).pixel_array)

    @staticmethod
    def _read_image_slice(path):
        """Read one image file with scikit-image and rescale it to uint8."""
        return SelfSupervisedDataset._rescale_to_uint8(io.imread(path))

    def _load_slices(self, idx, root_dir, read_slice):
        """Build the ``{'slices', 'label'}`` dict for patient ``idx`` under ``root_dir``."""
        slices = torch.zeros((self.num_slices, 1, self.image_size, self.image_size))
        slices_label = torch.zeros((self.num_slices, 1, self.image_size, self.image_size))

        patient_id = self.labels.loc[idx, 'ID']  # patient id in the dataset
        patient_dir = os.path.join(root_dir, patient_id)

        all_slices = sorted(os.listdir(patient_dir))
        middle_slices = all_slices[self.begin_at_slice: self.begin_at_slice + self.num_slices]

        for i, slice_file in enumerate(middle_slices):
            slice_data = read_slice(os.path.join(patient_dir, slice_file))
            x, target = self._make_training_pair(slice_data)
            slices[i] = torch.from_numpy(np.true_divide(x, 255)).unsqueeze(0)
            slices_label[i] = torch.from_numpy(np.true_divide(target, 255)).unsqueeze(0)
        return {'slices': slices, 'label': slices_label}

    def _make_training_pair(self, slice_data):
        """Return ``(corrupted_input, target)`` arrays in ``[0, 255]`` for one slice.

        The optional transform runs first; ``.numpy()`` is called on its output,
        which is then multiplied by 255. The target only receives the random
        flip. The input additionally gets local shuffling, the non-linear
        intensity mapping, and in-/out-painting, each gated by the global
        ``config`` rates.
        """
        if self.transform:
            slice_data = (self.transform(np.expand_dims(slice_data, 2))).numpy() * 255
        if slice_data.ndim == 2:
            # The corruption helpers unpack a 3-D (channels, rows, cols) shape.
            slice_data = slice_data[np.newaxis]
        x = copy.deepcopy(slice_data)
        x, slice_data = data_augmentation(x, slice_data, config.flip_rate)
        x = local_pixel_shuffling(x, prob=config.local_rate)
        x = nonlinear_transformation(x, config.nonlinear_rate)
        if random.random() < config.paint_rate:
            if random.random() < config.inpaint_rate:
                x = image_in_painting(x)
            else:
                x = image_out_painting(x)
        return x, slice_data

In [None]:
# Random rotation followed by a center crop. The crop size comes from the
# Config so the transform's output always matches the dataset's tensors,
# which are allocated with config.image_crop_size.
transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.RandomRotation(degrees=45),
                                transforms.CenterCrop(config.image_crop_size),
                                transforms.ToTensor()])


In [None]:
# Build the self-supervised dataset entirely from the Config values, applying
# the rotate/crop pipeline `transform` to every slice.
dataset = SelfSupervisedDataset(
    original_dataset_dir=config.original_dataset_dir,
    preprocessed_dataset_dir=config.preprocessed_dataset_dir,
    labels_file=config.labels_file,
    begin_at_slice=config.begin_at_slice,
    num_slices=config.num_slices,
    image_crop_size=config.image_crop_size,
    label_type=config.label_type,
    preprocess=config.preprocess,
    transform=transform,
)

In [None]:
train_split = 0.9
batch_size = 4
# Fixed seed so the train/validation partition is reproducible across kernel
# restarts. The evaluation cell later draws from ``val_set``; with an unseeded
# split, a model saved in one session could be evaluated on slices it was
# trained on in another.
split_seed = 0

dataset_size = len(dataset)
train_size = int(np.floor(train_split * dataset_size))
validation_size = dataset_size - train_size

train_set, val_set = torch.utils.data.random_split(
    dataset, [train_size, validation_size],
    generator=torch.Generator().manual_seed(split_seed))

# Equal weights plus sampling with replacement: each epoch draws
# len(train_set) patients uniformly at random (duplicates possible).
# WeightedRandomSampler does not require the weights to sum to 1.
train_samples_weights = torch.ones(len(train_set), dtype=torch.float64)
train_sampler = WeightedRandomSampler(weights=train_samples_weights, num_samples=len(train_samples_weights),
                                      replacement=True)

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, sampler=train_sampler),
    'val': DataLoader(val_set, batch_size=batch_size),
}


## Save data

In [None]:
# Pull one training batch to inspect the corrupted inputs and their targets.
train_batch = next(iter(dataloaders["train"]))
images, labels = train_batch

In [None]:
# Write every slice of the inspected batch as PNGs. Inputs go to image/ and
# targets to label/, numbered in (sample, slice) order. makedirs(exist_ok=True)
# lets the cell be re-run without FileExistsError.
check2_dir = '../log_dir/ss/check2'
os.makedirs(os.path.join(check2_dir, 'image'), exist_ok=True)
os.makedirs(os.path.join(check2_dir, 'label'), exist_ok=True)
num = 0
for i in range(images.shape[0]):
    for j in range(images.shape[1]):
        image.imsave(os.path.join(check2_dir, 'image', '{}.png'.format(num)), images[i][j][0], cmap='gray')
        image.imsave(os.path.join(check2_dir, 'label', '{}.png'.format(num)), labels[i][j][0], cmap='gray')
        num += 1

In [None]:
# Shape of the inspected input batch. Given the dataset's per-item layout this
# is presumably (batch, slices, 1, H, W) — confirm.
images.shape

In [None]:
# Shape of the matching target batch; it should equal images.shape.
labels.shape


# Model

## Resnet18

In [None]:
# Public names of the ResNet section. NOTE(review): ``__all__`` only affects
# ``from <module> import *``, so it presumably has no effect inside this notebook.
__all__ = ['ResNet', 'resnet18']

import torch
from torch import nn
from torch.hub import load_state_dict_from_url
from torch.nn import init

# Download URLs of published ImageNet checkpoints, keyed by architecture name.
# Only ``_resnet(..., pretrained=True)`` reads this table.
# NOTE(review): the ResNet defined below has a 1-channel stem and stage widths
# 2/4/8/16, so these checkpoints presumably do not match its parameter shapes.
# Confirm before passing pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Padding equal to the dilation keeps the spatial size unchanged at stride 1.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (channel projection, optional striding)."""
    projection = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                           stride=stride, bias=False)
    return projection


class BasicBlock(nn.Module):
    """Two-convolution residual block used by ``resnet18``.

    ``conv3x3 -> norm -> ReLU -> conv3x3 -> norm`` is added to the input, or to
    ``downsample(input)`` when given, and passed through a final ReLU. The
    output has ``planes`` channels (``expansion == 1``).
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=2, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 2:
            # Name the width this narrow variant actually accepts (2).
            raise ValueError('BasicBlock only supports groups=1 and base_width=2')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    The inner width is ``int(planes * base_width / 2) * groups``. Striding (when
    ``stride != 1``) happens in the 3x3 convolution, and ``downsample``, if
    given, must project the shortcut to the same shape.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=2, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 2.)) * groups
        out_planes = planes * self.expansion
        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """Narrow ResNet for single-channel images.

    The stem is a 1-channel 7x7 stride-2 convolution with 2 output channels,
    followed by BatchNorm, ReLU and a stride-2 max-pool. Four residual stages
    with widths 2/4/8/16 (times ``block.expansion``) follow; stages 2-4 use
    stride 2. The network ends with global average pooling and a linear head
    with ``num_classes`` outputs.

    NOTE: ``ResNetUNet`` slices ``children()`` by position. The registration
    order (conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc) must
    therefore not change.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=2, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Stem width; _make_layer updates this as stages are built.
        self.inplanes = 2
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Single input channel (grayscale slices).
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 2, layers[0])
        self.layer2 = self._make_layer(block, 4, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 8, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 16, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks with ``planes`` base width.

        Only the first block strides (or dilates, when ``dilate``). It receives
        a 1x1 projection shortcut whenever the stride or channel count changes.
        Updates ``self.inplanes`` and ``self.dilation`` for the next stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # Stem -> four residual stages -> global average pool -> linear head.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate ``ResNet`` and optionally load the weights listed for ``arch``.

    Extra keyword arguments are forwarded to ``ResNet``. NOTE(review): the
    ``model_urls`` checkpoints were published for a different (wider,
    3-channel) architecture, so ``pretrained=True`` presumably fails in
    ``load_state_dict`` — confirm before relying on it.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net


def resnet18(pretrained=False, progress=True, **kwargs):
    """ResNet-18 layout (BasicBlock, two blocks per stage) from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, load the weights listed in ``model_urls['resnet18']``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        **kwargs: forwarded to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs)


In [None]:
# Instantiate the narrow backbone and list its top-level children; ResNetUNet
# slices these children by position.
base_model = resnet18(pretrained=False)
print([child for child in base_model.children()])

## ResnetUnet

In [None]:

def init_weights(m):
    """Kaiming-normal (fan_in) init for ``nn.Linear`` layers, with bias set to 0.01.

    Intended for ``module.apply(init_weights)``. Every other module type,
    including ``nn.Conv2d``, is left untouched.
    """
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, mode='fan_in')
        # Linear(bias=False) has ``bias is None``; skip it instead of crashing.
        if m.bias is not None:
            m.bias.data.fill_(0.01)


def convrelu(in_channels, out_channels, kernel, padding):
    """Return ``Conv2d(kernel, padding) -> in-place ReLU`` as an ``nn.Sequential``."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class ResNetUNet(nn.Module):
    """U-Net style encoder/decoder built on the narrow single-channel ``resnet18``.

    The encoder reuses resnet18's children: stem, max-pool, layer1..layer4.
    The decoder upsamples by 2 five times. At the first four scales it
    concatenates a 1x1-projected encoder feature map. At the last scale it
    concatenates a full-resolution branch computed directly from the input.
    The output has ``n_class`` channels at input resolution. The five
    stride-2 steps mean H and W presumably must be divisible by 32 — confirm
    for non-128 inputs.

    NOTE: ``self.base_model`` and ``self.layer0..layer4`` share the same
    submodules, so ``state_dict()`` contains both key families. Keep this
    structure, or migrate checkpoints, before changing it.
    """

    def __init__(self, n_class):
        super(ResNetUNet, self).__init__()

        # Construction order is kept as-is: it fixes the RNG consumption of
        # parameter initialisation.
        self.base_model = resnet18(pretrained=False)

        self.base_layers = list(self.base_model.children())
        self.layer0 = nn.Sequential(*self.base_layers[0:3])  # size=(N, 2, x.H/2, x.W/2)
        # init_weights only touches nn.Linear, so these .apply calls leave the
        # convolutional layers' default initialisation unchanged.
        self.layer0.apply(init_weights)

        self.layer0_1x1 = convrelu(2, 2, 1, 0)
        self.layer0_1x1.apply(init_weights)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 2, x.H/4, x.W/4)
        self.layer1.apply(init_weights)

        self.layer1_1x1 = convrelu(2, 2, 1, 0)
        self.layer1_1x1.apply(init_weights)
        self.layer2 = self.base_layers[5]  # size=(N, 4, x.H/8, x.W/8)
        self.layer2.apply(init_weights)

        self.layer2_1x1 = convrelu(4, 4, 1, 0)
        self.layer2_1x1.apply(init_weights)
        self.layer3 = self.base_layers[6]  # size=(N, 8, x.H/16, x.W/16)
        self.layer3.apply(init_weights)

        self.layer3_1x1 = convrelu(8, 8, 1, 0)
        self.layer3_1x1.apply(init_weights)
        self.layer4 = self.base_layers[7]  # size=(N, 16, x.H/32, x.W/32)
        self.layer4.apply(init_weights)

        self.layer4_1x1 = convrelu(16, 16, 1, 0)
        self.layer4_1x1.apply(init_weights)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convs: input channels = upsampled features + skip features.
        self.conv_up3 = convrelu(8 + 16, 16, 3, 1)
        self.conv_up2 = convrelu(4 + 16, 8, 3, 1)
        self.conv_up1 = convrelu(2 + 8, 8, 3, 1)
        self.conv_up0 = convrelu(2 + 8, 4, 3, 1)

        self.conv_up3.apply(init_weights)
        self.conv_up2.apply(init_weights)
        self.conv_up1.apply(init_weights)
        self.conv_up0.apply(init_weights)

        # Full-resolution branch taken directly from the 1-channel input.
        self.conv_original_size0 = convrelu(1, 2, 3, 1)
        self.conv_original_size1 = convrelu(2, 2, 3, 1)
        self.conv_original_size2 = convrelu(2 + 4, 2, 3, 1)

        self.conv_original_size0.apply(init_weights)
        self.conv_original_size1.apply(init_weights)
        self.conv_original_size2.apply(init_weights)

        self.conv_last = nn.Conv2d(2, n_class, 1)
        init.kaiming_normal_(self.conv_last.weight, mode='fan_in')

    def forward(self, x, decode=True):
        """Return the ``n_class``-channel reconstruction, or the bottleneck when ``decode`` is False.

        With ``decode=False`` the result is the projected layer4 features of
        shape (N, 16, H/32, W/32).
        """
        layer0, layer1, layer2, layer3, layer4 = self.encode(x)

        if not decode:
            return layer4

        # The full-resolution branch is only consumed by the decoder, so it is
        # computed only when decoding.
        x_original = self.conv_original_size0(x)
        x_original = self.conv_original_size1(x_original)
        return self.decode(x_original, layer0, layer1, layer2, layer3, layer4)

    def encode(self, x):
        """Run the ResNet stages; layer4 is already passed through its 1x1 projection."""
        layer0 = self.layer0(x)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        return layer0, layer1, layer2, layer3, layer4

    def decode(self, x_original, layer0, layer1, layer2, layer3, layer4):
        """Upsample step by step, fusing the 1x1-projected skip at each scale."""
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out


In [None]:
# Build a single-output-channel ResNetUNet (rebinding ``base_model``) and list
# its top-level submodules; the notebook displays the list as the cell output.
base_model = ResNetUNet(1)

list(base_model.children())

# Train

In [None]:
def save_checkpoint(state,
                    filename='../models/ss_models/self_supervised_model.pth.tar'):
    """Serialize ``state`` (typically a dict of epoch/weights/optimizer) with ``torch.save``.

    When ``filename`` is a path, its parent directory is created if missing,
    so saving does not depend on an earlier cell having created it.
    File-like objects are passed straight to ``torch.save``.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)

In [None]:


def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Mean-squared-error reconstruction loss; also accumulates it into ``metrics``.

    Args:
        pred: network output tensor.
        target: tensor with the same shape as ``pred``.
        metrics: mapping (e.g. ``defaultdict(float)``). ``metrics['loss']`` is
            increased by ``loss * target.size(0)`` as a Python float, so that
            dividing by the sample count later yields the mean loss.
        bce_weight: unused; kept for backward compatibility with callers.

    Returns:
        The scalar loss tensor, still attached to the autograd graph.
    """
    loss = nn.functional.mse_loss(pred, target)
    metrics['loss'] += loss.item() * target.size(0)
    return loss


def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric divided by ``epoch_samples`` on one line.

    Output format: ``"<phase>: name1: value1, name2: value2"``.
    """
    summary = ", ".join("{}: {:4f}".format(name, metrics[name] / epoch_samples)
                        for name in metrics.keys())
    print("{}: {}".format(phase, summary))


def _run_phase(model, optimizer, loader, phase):
    """Run one pass over ``loader`` in train or eval mode; return ``(metrics, samples_seen)``."""
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()
    metrics = defaultdict(float)
    epoch_samples = 0

    for inputs, labels in loader:
        # Fold (batch, slices, 1, H, W) into (batch * slices, 1, H, W) for the 2-D network.
        inputs = inputs.reshape(-1, 1, *inputs.shape[-2:]).float().to(device)
        labels = labels.reshape(-1, 1, *labels.shape[-2:]).float().to(device)
        optimizer.zero_grad()

        # Track history only in train.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = calc_loss(outputs, labels, metrics)
            if is_train:
                loss.backward()
                optimizer.step()

        epoch_samples += inputs.size(0)
    return metrics, epoch_samples


def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train ``model`` with a train/val phase per epoch and return it with the best val weights.

    Uses the module-level ``dataloaders`` and ``device``. ``scheduler`` is
    stepped once per epoch with the validation loss (as ReduceLROnPlateau
    expects). Each new best validation loss is saved through
    ``save_checkpoint``.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

            metrics, epoch_samples = _run_phase(model, optimizer, dataloaders[phase], phase)
            if epoch_samples == 0:
                # Avoid dividing by zero when a split is empty.
                print("{}: no samples".format(phase))
                continue

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            if phase != 'val':
                continue

            # Step only on the validation loss. Also stepping with the train loss
            # feeds ReduceLROnPlateau two incomparable values per epoch.
            scheduler.step(epoch_loss)

            # deep copy the model
            if epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'best_score': epoch_loss,
                    'optimizer': optimizer.state_dict(),
                })

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 1

model = ResNetUNet(n_class=num_class).to(device)
model.float()

# freeze backbone layers
# Comment out to finetune further
# for l in model.base_layers:
#     for param in l.parameters():
#         param.requires_grad = False

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# Multiply the LR by 0.95 after `patience` = 30 scheduler steps without a
# decrease of the monitored loss. The deprecated `verbose` argument is omitted
# because its default is already off.
lr_scheduler = ReduceLROnPlateau(optimizer_ft, "min", factor=0.95, patience=30)

model = train_model(model, optimizer_ft, lr_scheduler, num_epochs=1000)

# Test

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(config.load_saved_model_path)
checkpoint = torch.load(config.load_saved_model_path, map_location=device)
resnet_unet = ResNetUNet(n_class=1).to(device)
resnet_unet.load_state_dict(checkpoint['state_dict'])
# Reuse the checkpoint loaded above instead of deserializing the file twice more.
print(checkpoint['epoch'])
print(checkpoint['best_score'])

In [None]:
import shutil

# Start the reconstruction check from an empty output directory; a missing
# directory is not an error.
try:
    shutil.rmtree('../log_dir/check')
except FileNotFoundError:
    pass

In [None]:
# Reconstruct one validation batch with the checkpoint loaded into
# ``resnet_unet`` and save inputs, targets and reconstructions as PNGs.
resnet_unet.eval()
test_loader = DataLoader(val_set, batch_size=2)

inputs, labels = next(iter(test_loader))
check_dir = '../log_dir/check'
for sub_dir in ('image', 'label', 'reconstruction'):
    os.makedirs(os.path.join(check_dir, sub_dir), exist_ok=True)
num = 0
for i in range(inputs.shape[0]):
    for j in range(inputs.shape[1]):
        image.imsave(os.path.join(check_dir, 'image', '{}.png'.format(num)), inputs[i][j][0], cmap='gray')
        image.imsave(os.path.join(check_dir, 'label', '{}.png'.format(num)), labels[i][j][0], cmap='gray')
        num += 1
# Same (sample, slice) flattening as the numbering above, so file k matches file k.
height, width = inputs.shape[-2:]
inputs = inputs.reshape(-1, 1, height, width).float().to(device)
labels = labels.reshape(-1, 1, height, width).float().to(device)
with torch.no_grad():
    pred = resnet_unet(inputs)
# The network is trained with plain MSE on its raw output (no sigmoid), so the
# raw output is the reconstruction.
pred = pred.cpu().numpy()
print(pred.shape)
for i in range(len(pred)):
    image.imsave(os.path.join(check_dir, 'reconstruction', '{}.png'.format(i)), pred[i][0], cmap='gray')