In [1]:
!pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [2]:
import sys
sys.path.append('../input/siim-acr-pneumothorax-segmentation/stage_2_images')
from mask_functions import rle2mask

In [3]:
import numpy as np
import pandas as pd
import cv2 as cv
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
import pydicom
from glob import glob
import os

In [4]:
def squash(x, dim = -1):
    square_norm = torch.sum(x ** 2, dim = -1, keepdims = True)
    return square_norm / (1 + square_norm) * x / (torch.sqrt(square_norm) + 1e-6)

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels, out_channels, cap_dim, kernel_size = 5, stride = 2, padding = 2):
        super().__init__()
        self.cap_dim = cap_dim
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding)
    def forward(self, x):
        x = self.conv(x)
        number_capsules = x.shape[1] // self.cap_dim
        x = x.view(-1, number_capsules, self.cap_dim, *(x.shape[2:4]))
        x = x.permute(0, 3, 4, 1, 2)
        return squash(x)
class ConvCapsule(nn.Module):
    def __init__(self, input_shape, in_capsules, in_cap_dim, out_capsules, out_cap_dim, kernel_size, stride, padding = 0, iterations = 3):
        super().__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.padding = padding
        self.input_shape = input_shape
        self.in_capsules = in_capsules
        self.in_cap_dim = in_cap_dim
        self.out_capsules = out_capsules
        self.out_cap_dim = out_cap_dim
        self.iterations = iterations
        self.kernel_size = kernel_size
        self.stride = stride
        self.W = nn.Parameter(torch.randn(out_capsules * out_cap_dim, in_cap_dim, kernel_size, kernel_size, device = device))
        self.device = device
        nn.init.xavier_normal_(self.W)

    def forward(self, x):
        in_shape = x.shape
        x = x.permute(0, 3, 4, 1, 2)
        x = x.reshape(-1, self.in_cap_dim, *self.input_shape)
        x = F.conv2d(x, self.W, stride = self.stride, padding = self.padding)
        x = x.view(in_shape[0], in_shape[-2], self.out_capsules, self.out_cap_dim, x.shape[2], x.shape[3])
        u_hat = x.permute(0, 1, 4, 5, 2, 3)

        b = torch.zeros(u_hat.shape[:-1]).to(self.device)
        for _ in range(self.iterations - 1):
            c = F.softmax(b, dim = -2)
            s = torch.sum(u_hat * c.unsqueeze(-1), dim = 1)
            v = squash(s)
            b = b + torch.sum(u_hat * v.unsqueeze(1), dim = -1)
        c = F.softmax(b, dim = -2)
        s = torch.sum(u_hat * c.unsqueeze(-1), dim = 1)
        v = squash(s)
        return v
class ConvTranposeCapsule(nn.Module):
    def __init__(self, input_shape, in_capsules, in_cap_dim, out_capsules, out_cap_dim, kernel_size, stride, padding = 0, iterations = 3):
        super().__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.padding = padding
        self.input_shape = input_shape
        self.in_capsules = in_capsules
        self.in_cap_dim = in_cap_dim
        self.out_capsules = out_capsules
        self.out_cap_dim = out_cap_dim
        self.iterations = iterations
        self.kernel_size = kernel_size
        self.stride = stride
#         self.W = nn.Parameter(torch.randn(in_cap_dim, out_capsules * out_cap_dim, kernel_size, kernel_size).to(device))
        self.tconv = nn.ConvTranspose2d(in_cap_dim, out_capsules * out_cap_dim, kernel_size = kernel_size, padding = padding, stride = stride)
        self.device = device
#         nn.init.xavier_normal_(self.W)

    def forward(self, x):
        in_shape = x.shape
        x = x.permute(0, 3, 4, 1, 2)
        x = x.reshape(-1, self.in_cap_dim, *self.input_shape)
        x = self.tconv(x)
        #       x = F.conv_transpose2d(x, self.W, stride = self.stride, padding = self.padding)
        x = x.view(in_shape[0], in_shape[-2], self.out_capsules, self.out_cap_dim, x.shape[2], x.shape[3])
        u_hat = x.permute(0, 1, 4, 5, 2, 3)

        b = torch.zeros(u_hat.shape[:-1]).to(self.device)
        for _ in range(self.iterations - 1):
            c = F.softmax(b, dim = -2)
            s = torch.sum(u_hat * c.unsqueeze(-1), dim = 1)
            v = squash(s)
            b = b + torch.sum(u_hat * v.unsqueeze(1), dim = -1)
        c = F.softmax(b, dim = -2)
        s = torch.sum(u_hat * c.unsqueeze(-1), dim = 1)
        v = squash(s)
        return v

In [5]:
class UpConv(nn.Module):
    def __init__(self, input_shape, scale_factor, in_capsules, in_cap_dim, out_capsules, out_cap_dim, kernel_size, stride, padding, iterations = 3):
        super().__init__()
        self.scale_factor = scale_factor
        new_shape = (input_shape[0] * scale_factor, input_shape[1] * scale_factor)
        self.conv_cap = ConvCapsule(new_shape, in_capsules, in_cap_dim, out_capsules, out_cap_dim, kernel_size = kernel_size, padding = padding, stride = stride, iterations = iterations)
    def forward(self, x):
        in_dim = x.shape[-1]
        in_caps = x.shape[-2]
        h, w = x.shape[1: 3]
        x = x.permute(0, 3, 4, 1, 2)
        x = x.reshape(x.shape[0], -1, h, w)
        x = F.interpolate(x, scale_factor = self.scale_factor)
        x = x.reshape(x.shape[0], in_caps, in_dim, h * self.scale_factor, w * self.scale_factor)
        x = x.permute(0, 3, 4, 1, 2)
        x = self.conv_cap(x)
        return x

In [6]:
def dice_score(inputs, targets, smooth = 1):
    return (2 * (inputs * targets).sum() + smooth) / (inputs.sum() + targets.sum() + smooth)
def dice_loss(inputs, targets, smooth = 1):
    return 1 - (2 * (inputs * targets).sum() + smooth) / (inputs.sum() + targets.sum() + smooth)
def focal_loss(inputs, targets, alpha = 0.8, gamma = 2, reduction = 'none'):
    assert reduction in ['none', 'mean', 'sum']
    loss = F.binary_cross_entropy(inputs, targets, reduction = 'none')
    coeff = (1 - inputs) ** gamma
    loss = coeff * loss
    focal_loss = torch.where(targets == 1, loss, loss * alpha)
    if reduction == "none": return focal_loss
    elif reduction == "mean": return focal_loss.mean()
    elif reduction == "sum": return focal_loss.sum()
def criterion(inputs, targets, factors = [3, 1, 4], smooth = 1, alpha = 0.8, gamma = 2, reduction = 'mean'):
    return F.binary_cross_entropy(inputs, targets) * factors[0] \
            + dice_loss(inputs, targets, smooth = smooth) * factors[1] \
            + focal_loss(inputs, targets, alpha = alpha, gamma = gamma, reduction = reduction) * factors[2]

In [7]:
def get_conv_shape(input_shape, kernel_size, padding, stride):
    def fn(x):
        return (x + 2 * padding - kernel_size) // stride + 1
    return fn(input_shape[0]), fn(input_shape[1])
def get_conv_transpose_shape(input_shape, kernel_size, padding, stride):
    def fn(x):
        return stride * (x - 1) - 2 * padding + kernel_size
    return fn(input_shape[0]), fn(input_shape[1])
def get_upconv_shape(input_shape, scale):
    return input_shape[0] * scale, input_shape[1] * scale
class CapUnetNet(nn.Module):
    def __init__(self, image_shape, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size = 5, padding = 2)
        self.relu1 = nn.ReLU()
        self.cap1 = PrimaryCapsule(1, 16, 16, kernel_size = 5, stride = 1, padding = 2)
        self.cap2 = ConvCapsule(image_shape, 1, 16, 2, 16, kernel_size = 5, stride = 2, padding = 2, iterations = 1)
        out_shape2 = get_conv_shape(image_shape, kernel_size = 5, padding = 2, stride = 2)
        self.cap2_1 = ConvCapsule(out_shape2, 2, 16, 4, 16, kernel_size = 3, padding = 1, stride = 1)
        self.cap3 = ConvCapsule(out_shape2, 4, 16, 4, 32, kernel_size = 3, padding = 1, stride = 2)
        out_shape3 = get_conv_shape(out_shape2, kernel_size = 3, padding = 1, stride = 2)
        self.cap3_1 = ConvCapsule(out_shape3, 4, 32, 8, 32, kernel_size = 3, padding = 1, stride = 1)
        self.cap4 = ConvCapsule(out_shape3, 8, 32, 8, 64, kernel_size = 3, padding = 1, stride = 2)
        out_shape4 = get_conv_shape(out_shape3, kernel_size = 3, padding = 1, stride = 2)

        self.up_cap4 = UpConv(out_shape4, 2, 8, 64, 8, 32, kernel_size = 3, stride = 1, padding = 1)
        out_up3 = get_upconv_shape(out_shape4, 2)
        self.up_cap4_1 = ConvCapsule(out_up3, 16, 32, 4, 32, kernel_size = 3, padding = 1, stride = 1)
        self.up_cap3 = UpConv(out_up3, 2, 4, 32, 4, 16, kernel_size = 3, stride = 1, padding = 1)
        out_up2 = get_upconv_shape(out_up3, 2)
        self.up_cap3_1 = ConvCapsule(out_up2, 8, 16, 2, 16, kernel_size = 3, padding = 1, stride = 1)
        self.up_cap2 = UpConv(out_up2, 2, 2, 16, 1, 16, kernel_size = 3, padding = 1, stride = 1)
        out_up1 = get_upconv_shape(out_up2, 2)
        self.re_cap = ConvCapsule(out_up1, 2, 16, 1, 16, kernel_size = 3, padding = 1, stride = 1)
        self.up_tconv = nn.Conv2d(16, 8, kernel_size = 3, padding = 1)
        self.up_relu = nn.ReLU()
        self.re_conv = nn.Conv2d(8, 1, kernel_size = 1)


    def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu1(x)
        out_cap1 = self.cap1(x)
        out_down2 = self.cap2(out_cap1)
        out_down2_1 = self.cap2_1(out_down2)
        out_down3 = self.cap3(out_down2_1)
        out_down3_1 = self.cap3_1(out_down3)
        out_down4 = self.cap4(out_down3_1)

        out_up3 = self.up_cap4(out_down4)
        out_cat3 = torch.cat((out_down3_1, out_up3), dim = -2)
        out_up3_1 = self.up_cap4_1(out_cat3)
        out_up2 = self.up_cap3(out_up3_1)
        out_cat2 = torch.cat((out_down2_1, out_up2), dim = -2)
        out_up2_1 = self.up_cap3_1(out_cat2)
        out_up1 = self.up_cap2(out_up2_1)
        out = torch.cat((out_cap1, out_up1), dim = -2)
        out = self.re_cap(out)
        out = out_up1.view(*out_up1.shape[:-2], -1)
        out = out.permute(0, 3, 1, 2)
        out = self.up_tconv(out)
        out = self.up_relu(out)
        out = self.re_conv(out)
        return torch.sigmoid(out)

In [8]:
summary(CapUnetNet((512, 512), 1), (1, 512, 512))

Layer (type:depth-idx)                   Output Shape              Param #
├─PrimaryCapsule: 1-1                    [-1, 512, 512, 1, 16]     --
|    └─Conv2d: 2-1                       [-1, 16, 512, 512]        416
├─ConvCapsule: 1-2                       [-1, 256, 256, 2, 16]     12,800
├─ConvCapsule: 1-3                       [-1, 256, 256, 4, 16]     9,216
├─ConvCapsule: 1-4                       [-1, 128, 128, 4, 32]     18,432
├─ConvCapsule: 1-5                       [-1, 128, 128, 8, 32]     73,728
├─ConvCapsule: 1-6                       [-1, 64, 64, 8, 64]       147,456
├─UpConv: 1-7                            [-1, 128, 128, 8, 32]     --
|    └─ConvCapsule: 2-2                  [-1, 128, 128, 8, 32]     147,456
├─ConvCapsule: 1-8                       [-1, 128, 128, 4, 32]     36,864
├─UpConv: 1-9                            [-1, 256, 256, 4, 16]     --
|    └─ConvCapsule: 2-3                  [-1, 256, 256, 4, 16]     18,432
├─ConvCapsule: 1-10                      [-1, 256, 

Layer (type:depth-idx)                   Output Shape              Param #
├─PrimaryCapsule: 1-1                    [-1, 512, 512, 1, 16]     --
|    └─Conv2d: 2-1                       [-1, 16, 512, 512]        416
├─ConvCapsule: 1-2                       [-1, 256, 256, 2, 16]     12,800
├─ConvCapsule: 1-3                       [-1, 256, 256, 4, 16]     9,216
├─ConvCapsule: 1-4                       [-1, 128, 128, 4, 32]     18,432
├─ConvCapsule: 1-5                       [-1, 128, 128, 8, 32]     73,728
├─ConvCapsule: 1-6                       [-1, 64, 64, 8, 64]       147,456
├─UpConv: 1-7                            [-1, 128, 128, 8, 32]     --
|    └─ConvCapsule: 2-2                  [-1, 128, 128, 8, 32]     147,456
├─ConvCapsule: 1-8                       [-1, 128, 128, 4, 32]     36,864
├─UpConv: 1-9                            [-1, 256, 256, 4, 16]     --
|    └─ConvCapsule: 2-3                  [-1, 256, 256, 4, 16]     18,432
├─ConvCapsule: 1-10                      [-1, 256, 

In [9]:
class XrayDataset(Dataset):
    def __init__(self, paths, df, target_shape = (512, 512), mode = 'train', transforms = None):
        super().__init__()
        self.paths = paths
        self.df = df
        self.target_shape = target_shape
        self.mode = mode
        self.transforms = transforms
        self.data = self.prepare_data()
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        info = self.data[idx]
        path = info['path']
        dcmdata = pydicom.dcmread(path)
        image = dcmdata.pixel_array
        if self.target_shape is not None:
            image = cv.resize(image, self.target_shape[::-1], interpolation = cv.INTER_CUBIC)
        if self.mode == "test":
            image = np.expand_dims(image, 0) / 255.
            image = image.astype(np.float32)
            return torch.from_numpy(image)
        encoded_masks = info['masks']
        mask = np.zeros(self.target_shape)
        for encoded_mask in encoded_masks:
            if encoded_mask != "-1":
                _mask = rle2mask(encoded_mask, 1024, 1024).T
                _mask = cv.resize(_mask, self.target_shape[::-1], interpolation = cv.INTER_CUBIC)
                mask[_mask == 255] = 255
        if self.transforms:
            aug = self.transforms(image = image, mask = mask)
            image = aug['image']
            mask = aug['mask']        
        image = np.expand_dims(image, 0) / 255.
        mask = np.expand_dims(mask, 0) / 255.
        image = image.astype(np.float32)
        mask = mask.astype(np.float32)
        return torch.from_numpy(image), torch.from_numpy(mask)

    def prepare_data(self):
        data = []
        image_ids = self.df['ImageId'].unique()
        for image_id in tqdm(image_ids):
            index = list(filter(lambda x: image_id in self.paths[x], range(len(self.paths))))
            if len(index) == 0:
                continue
            index = index[0]
            path = self.paths[index]
            all_chests = self.df[self.df["ImageId"] == image_id]
            encode_rois = []
            for _, row in all_chests.iterrows():
                encode_rois.append(row[" EncodedPixels"])
            data.append({
                'image_id': image_id,
                'path': path,
                'masks': encode_rois
            })
        return data

In [10]:
paths = glob(os.path.join('..', 'input', 'siim-train-test', 'dicom-images-train', '*', '*', '*.dcm'))

In [11]:
df = pd.read_csv("../input/siim-train-test/train-rle.csv")

In [12]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INIT_LR = 1e-5
BATCH_SIZE = 2
TRAIN_RATE = 0.75
EPOCHS = 20

In [13]:
train_paths, val_paths = train_test_split(paths, train_size = TRAIN_RATE)

In [14]:
train_dataset = XrayDataset(train_paths, df)
val_dataset = XrayDataset(val_paths, df, mode = 'val')

100%|██████████| 12047/12047 [01:16<00:00, 157.85it/s]
100%|██████████| 12047/12047 [00:21<00:00, 556.44it/s]


In [15]:
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)

In [16]:
model = CapUnetNet((512, 512), 1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = INIT_LR)

In [17]:
ITERS_PER_EPOCH = len(train_paths) // BATCH_SIZE

In [18]:
for epoch in range(EPOCHS):
    bar = tqdm(range(ITERS_PER_EPOCH))
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        pred_masks = model(images)
        loss = criterion(pred_masks, masks)
        loss.backward()
        optimizer.step()
        bar.set_description(f'epoch {epoch + 1} iter {i + 1} loss {loss.cpu().detach().item()}')
    with torch.no_grad():
        acc_loss = 0
        acc_dice = 0
        count = 0
        for images, masks in tqdm(val_loader):
            images = images.to(torch.float32).to(DEVICE)
            masks  = masks.to(torch.float32).to(DEVICE)
            pred_masks = model(images)
            loss = criterion(pred_masks, masks)
            acc_dice += dice_score(pred_masks, masks)
            acc_loss += loss.cpu().item()
            count += 1
        acc_loss /= count
        acc_dice /= count
        print(f'[VAL] epoch {epoch} loss {acc_loss} dice {acc_dice}')

epoch 1 iter 1 loss 3.8499557971954346:   0%|          | 0/4533 [00:38<?, ?it/s]

RuntimeError: all elements of input should be between 0 and 1

In [19]:
pred_masks

tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]]],


        [[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]]]], grad_fn=<SigmoidBackward>)

In [22]:
for p in model.parameters():
    if p.requires_grad:
         print(p.name, p.data)

None tensor([[[[-3.5453e-02,  8.0910e-02,  7.8410e-03, -9.3192e-02, -1.6132e-01],
          [-7.6303e-02, -1.4904e-01,  9.8070e-02, -4.7112e-02,  6.1998e-02],
          [ 1.9028e-01,  1.2712e-01, -8.4803e-02,  6.3459e-02,  3.6548e-02],
          [-1.7687e-01, -1.7444e-01,  1.7430e-01,  1.7025e-02, -5.2179e-02],
          [ 6.9308e-02,  3.1164e-02,  1.5894e-01, -6.4666e-02,  9.3455e-03]]],


        [[[ 1.5420e-01,  1.2863e-01,  5.9149e-02, -9.1593e-02, -5.4059e-02],
          [-8.3433e-02, -1.6336e-01,  1.0875e-01,  1.5094e-01, -2.4183e-03],
          [ 1.3701e-01,  7.2726e-02,  4.5911e-02, -5.0948e-02, -2.8501e-03],
          [ 1.3533e-01,  2.6575e-02,  5.4932e-02,  5.5155e-02,  1.3815e-01],
          [ 3.4955e-02,  3.8531e-02, -3.2546e-02, -1.9486e-01, -1.0140e-01]]],


        [[[ 1.2958e-01, -1.2638e-01, -1.4096e-01,  1.6479e-01, -1.8333e-01],
          [ 1.9504e-01,  2.4040e-02, -9.4551e-02, -1.6500e-01, -1.7157e-01],
          [ 3.8113e-02, -1.5854e-01,  2.9875e-02,  9.8179e-02, 

In [40]:
model(images.to(torch.float32).to(DEVICE))

tensor([[[[0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          ...,
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228]]],


        [[[0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          ...,
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228],
          [0.5228, 0.5228, 0.5228,  ..., 0.5228, 0.5228, 0.5228]]]],
       grad_fn=<SigmoidBackward>)