In [None]:
from glob import glob
import os
import time


import numpy as np
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
from scipy.ndimage.morphology import binary_dilation
import segmentation_models_pytorch as smp
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm import tqdm

In [None]:
# Exploratory model: a 3-class U-Net on single-channel input with an
# ImageNet-pretrained EfficientNet-B7 encoder. (`model` is rebuilt as an FPN
# further down, before training.)
unet_config = dict(
    encoder_name="efficientnet-b7",  # backbone; lighter options include mobilenet_v2
    encoder_weights="imagenet",      # initialise the encoder from ImageNet weights
    in_channels=1,                   # grayscale input
    classes=3,                       # number of output channels
)
model = smp.Unet(**unet_config)

In [None]:
# Display the module tree of the exploratory U-Net built in the previous cell.
model

In [None]:
# Dummy forward pass (3 single-channel 512x512 inputs) whose output `y` is traced
# into an autograd graph in the next cell.
x = torch.randn(3, 1, 512, 512, requires_grad=True)
y = model(x)

In [None]:
from torchviz import make_dot

# Render the autograd graph of the dummy forward pass to PNG, labelling parameter
# nodes with their names.
# NOTE(review): the output file is called "rnn_torchviz" although the model is a U-Net.
param_lookup = dict(model.named_parameters())
make_dot(y, params=param_lookup).render("rnn_torchviz", format="png")

In [None]:
# Preferred GPU on the original multi-GPU machine. A hard-coded "cuda:7" makes
# every later `.to(device)` fail on machines with fewer GPUs, so fall back to
# cuda:0 in that case, and to the CPU when CUDA is unavailable.
GPU_INDEX = 7
if torch.cuda.is_available():
    gpu = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")
device

In [None]:
class BoneAgeSegmentationDataset(Dataset):
    """
    Image/mask pairs for segmentation, read from file paths stored in a DataFrame.

    ``df`` must have an ``'input'`` column (image path) and a ``'label'`` column
    (mask path). Every pair is resized to 512x512. When ``transform`` is given
    (an albumentations-style callable taking ``image=`` / ``mask=`` keywords),
    it is applied afterwards.

    NOTE(review): ``cv2`` and ``albumentations`` (as ``A``) are used here but are
    not imported in the notebook's import cell. Add them there so the notebook
    runs on a fresh kernel.
    """
    def __init__(self, df, transform=None, mean=0.5, std=0.25):
        """
        :param df: DataFrame with 'input' and 'label' path columns
        :param transform: optional augmentation applied after resizing
        :param mean: stored as ``self.mean``; currently not applied to the images
        :param std: stored as ``self.std``; currently not applied to the images
        """
        super().__init__()
        self.df = df
        self.transform = transform
        self.mean = mean
        self.std = std
        self.resize = A.Compose([
            A.Resize(height=512, width=512)
        ])

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of silently returning None."""
        image = cv2.imread(path, flags)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None, which
            # would otherwise surface later as a confusing error inside the resize.
            raise FileNotFoundError(f"Could not read image file: {path!r}")
        return image

    def __getitem__(self, idx, raw=False):
        """
        Return the ``idx``-th ``(image, mask)`` pair.

        With ``raw=True`` the resized numpy arrays are returned as-is (no
        augmentation, no tensor conversion), which is convenient for plotting.
        Otherwise the image goes through ``T.functional.to_tensor``, and the mask
        is integer-divided by 255 and returned as a float32 tensor.

        :raises FileNotFoundError: if the image or the mask cannot be read
        """
        row = self.df.iloc[idx]
        img = self._read(row['input'], cv2.IMREAD_UNCHANGED)
        mask = self._read(row['label'], cv2.IMREAD_GRAYSCALE)
        resized = self.resize(image=img, mask=mask)
        img, mask = resized['image'], resized['mask']
        if raw:
            return img, mask

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        img = T.functional.to_tensor(img)
        # Masks are presumably stored as 0/255 on disk; integer division maps them to 0/1.
        mask = torch.as_tensor(mask // 255, dtype=torch.float32)
        return img, mask

In [None]:
# Pair every input image with its label mask by sort order. Hidden entries (e.g.
# `.ipynb_checkpoints`, `.DS_Store`) are skipped: they are not data, and the label
# sort key below would crash on names without a "-<number>" part.
inputs = sorted(f for f in os.listdir('input/') if not f.startswith('.'))
labels = sorted(
    (f for f in os.listdir('labels/') if not f.startswith('.')),
    key=lambda x: int(os.path.splitext(x)[0].split("-")[1]),
)
# NOTE(review): inputs are sorted lexicographically but labels numerically by the
# number after "-". The positional pairing is only correct if both orders agree;
# confirm against the actual file names.
assert len(inputs) == len(labels), f"{len(inputs)} input files vs {len(labels)} label files"

df = pd.DataFrame({
    'input': ['input/' + f for f in inputs],
    'label': ['labels/' + f for f in labels],
})

In [None]:
SEED = 42  # fixed seed so the train/valid/test split is reproducible across runs

# 80% train; the remaining 20% is split evenly into test and validation.
train_df, holdout_df = train_test_split(df, test_size=0.2, random_state=SEED)
test_df, valid_df = train_test_split(holdout_df, test_size=0.5, random_state=SEED)

In [None]:
# Light photometric augmentation for the training split only. Validation and test
# samples are just resized (the dataset itself resizes every pair to 512x512).
transform = A.Compose([
    A.RandomBrightnessContrast(p=0.3),
    A.ColorJitter(p=0.3),
    A.Resize(height=512, width=512),
])

train_dataset = BoneAgeSegmentationDataset(train_df, transform=transform)
valid_dataset, test_dataset = (BoneAgeSegmentationDataset(frame) for frame in (valid_df, test_df))

In [None]:
batch_size = 8

# Only the training data needs shuffling. Validation metrics are aggregated over
# the whole set, so a fixed order keeps them reproducible without affecting them.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
%matplotlib inline
n_examples = 4

fig, axs = plt.subplots(n_examples, 3, figsize=(20, n_examples*7), constrained_layout=True)
i = 0
for ax in axs:
    while True:
        image, mask = train_dataset.__getitem__(i, raw=True)
#         image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
        i += 1
#         if np.any(mask): 
        ax[0].set_title("MRI images")
        ax[0].imshow(image, cmap='gray')
        ax[1].set_title("Highlited abnormality")
        ax[1].imshow(np.where(mask==0, 0, image), cmap='gray')
        ax[2].imshow(mask, cmap='gray')
        ax[2].set_title("Abnormality mask")
        break

In [None]:
class EarlyStopping():
    """
    Signal when a monitored validation loss stops improving, checkpointing the
    best model seen so far to disk.
    """
    def __init__(self, patience:int = 6, min_delta: float = 0, weights_path: str = 'weights.pt'):
        """
        :param patience: number of consecutive non-improving calls tolerated before stopping
        :param min_delta: a new loss must undercut the best loss by more than this
            to count as an improvement
        :param weights_path: file that the best model's ``state_dict`` is written to
        """
        self.patience = patience
        self.min_delta = min_delta
        self.weights_path = weights_path
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss: float, model: torch.nn.Module):
        """
        Record one validation loss.

        On an improvement the model's weights are saved and the patience counter
        resets. Otherwise the counter grows.

        :return: True once ``patience`` consecutive calls have failed to improve, else False
        """
        improved = self.best_loss - val_loss > self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience
        self.best_loss = val_loss
        torch.save(model.state_dict(), self.weights_path)
        self.counter = 0
        return False

    def load_weights(self, model: torch.nn.Module):
        """
        Restore the best saved checkpoint into ``model``.

        :param model: module whose parameters are overwritten with the saved weights
        :return: the result of ``model.load_state_dict``
        """
        state = torch.load(self.weights_path)
        return model.load_state_dict(state)

In [None]:
def iou_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7):
    """
    Per-sample Intersection over Union between thresholded predictions and labels.

    Predictions above 0.5 count as foreground. Counts are summed over dims 1 and 2,
    so an (N, H, W) batch yields N scores. ``e`` keeps the ratio defined, and equal
    to 1, when both masks are empty.
    """
    binary = (predictions > 0.5).long()
    target = labels.byte()
    overlap = (binary & target).float().sum(dim=(1, 2))
    combined = (binary | target).float().sum(dim=(1, 2))
    return (overlap + e) / (combined + e)

def dice_pytorch(predictions: torch.Tensor, labels: torch.Tensor, e: float = 1e-7):
    """
    Per-sample Dice coefficient between thresholded predictions and labels.

    Predictions above 0.5 count as foreground. Counts are summed over dims 1 and 2,
    so an (N, H, W) batch yields N scores. ``e`` keeps the ratio defined, and equal
    to 1, when both masks are empty. The thresholding makes this a metric only;
    it carries no gradient.
    """
    binary = (predictions > 0.5).long()
    target = labels.byte()
    spatial = (1, 2)
    overlap = (binary & target).float().sum(spatial)
    sizes = binary.float().sum(spatial) + target.float().sum(spatial)
    return (2 * overlap + e) / (sizes + e)

In [None]:
def BCE_dice(output, target, alpha=0.01, eps=1e-7):
    """
    Binary cross-entropy plus a weighted, differentiable soft-Dice penalty.

    The Dice term is computed directly on the probabilities in ``output`` rather
    than on thresholded predictions (as ``dice_pytorch`` does). Thresholding has
    zero gradient, so a thresholded Dice term could never influence training.

    :param output: probabilities in [0, 1], batch-first, e.g. (N, H, W)
    :param target: float tensor of 0/1 labels with the same shape as ``output``
    :param alpha: weight of the Dice penalty
    :param eps: smoothing term keeping the Dice ratio defined for empty masks
    :return: scalar loss tensor
    """
    bce = torch.nn.functional.binary_cross_entropy(output, target)
    # Reduce over every non-batch axis so the per-sample Dice works for any rank >= 2.
    dims = tuple(range(1, output.dim()))
    intersection = (output * target).sum(dims)
    denominator = output.sum(dims) + target.sum(dims)
    soft_dice = 1 - ((2 * intersection + eps) / (denominator + eps)).mean()
    return bce + alpha * soft_dice

In [None]:
# Model that is actually trained: an FPN with an ImageNet-pretrained EfficientNet-B7
# encoder, single-channel input, and one sigmoid-activated output channel.
# nn.Module.to returns the module itself, so the moved model can be assigned directly.
model = smp.FPN(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
    activation='sigmoid',
).to(device)

In [None]:
def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `loader` and return the per-sample mean training loss."""
    model.train()
    running_loss = 0.0
    for img, mask in tqdm(loader):
        img, mask = img.to(device), mask.to(device)
        # Drop the singleton channel axis so predictions line up with the masks.
        predictions = model(img).squeeze(1)
        loss = loss_fn(predictions, mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample, not per batch.
        running_loss += loss.item() * img.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, loss_fn, device):
    """Return per-sample (mean loss, mean IoU, mean Dice) over `loader`, without gradients."""
    model.eval()
    total_loss = total_iou = total_dice = 0.0
    with torch.no_grad():
        for img, mask in loader:
            img, mask = img.to(device), mask.to(device)
            predictions = model(img).squeeze(1)
            total_dice += dice_pytorch(predictions, mask).sum().item()
            total_iou += iou_pytorch(predictions, mask).sum().item()
            total_loss += loss_fn(predictions, mask).item() * img.size(0)
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_iou / n_samples, total_dice / n_samples


def training_loop(epochs, model, train_loader, valid_loader, optimizer, loss_fn, lr_scheduler,
                  device=None, patience=7):
    """
    Train `model` for up to `epochs` epochs with early stopping on validation loss.

    Each epoch does the following:
    - runs one optimisation pass over `train_loader`;
    - evaluates loss, mean IoU and mean Dice on `valid_loader`;
    - steps `lr_scheduler` with the validation loss;
    - asks `EarlyStopping` (which checkpoints the best weights) whether to stop.

    When training ends, whether early or after all epochs, the best checkpoint from
    this run is loaded back into `model`, and the model is left in eval mode.

    :param device: device that batches are moved to; defaults to the device of the
        model's first parameter
    :param patience: early-stopping patience in epochs
    :return: dict of per-epoch lists under 'train_loss', 'val_loss', 'val_IoU', 'val_dice'
    """
    if device is None:
        device = next(model.parameters()).device
    history = {'train_loss': [], 'val_loss': [], 'val_IoU': [], 'val_dice': []}
    early_stopping = EarlyStopping(patience=patience)

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        val_loss, val_IoU, val_dice = _evaluate(model, valid_loader, loss_fn, device)
        elapsed = time.time() - start_time

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_IoU'].append(val_IoU)
        history['val_dice'].append(val_dice)
        print(f'Epoch: {epoch}/{epochs} | Training loss: {train_loss} | Validation loss: {val_loss} | Validation Mean IoU: {val_IoU} '
              f'| Validation Dice coefficient: {val_dice} | Time: {elapsed:.1f}s')

        lr_scheduler.step(val_loss)
        if early_stopping(val_loss, model):
            break

    # Restore the best checkpoint whether training stopped early or ran all epochs.
    # best_loss stays inf only if nothing was saved in this run (e.g. NaN losses);
    # in that case skip loading, so a stale file from an earlier run is never used.
    if early_stopping.best_loss < float('inf'):
        early_stopping.load_weights(model)
    model.eval()
    return history

In [None]:
# Training configuration:
# - Adam with lr=1e-3;
# - learning rate multiplied by 0.2 after 2 epochs without validation improvement;
# - at most 100 epochs.
epochs = 100
loss_fn = BCE_dice
optimizer = Adam(model.parameters(), lr=0.001)
lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.2, patience=2)

history = training_loop(epochs, model, train_loader, valid_loader, optimizer, loss_fn, lr_scheduler)

In [None]:
# Training vs validation loss per epoch.
# NOTE(review): the fixed y-range of 0-0.3 may clip early epochs.
fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(history['train_loss'], label='Training', color='red')
ax.plot(history['val_loss'], label='Validation', color='green')
ax.set_ylim(0, 0.3)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (BCE Dice)')
ax.set_title('Training Performance')
ax.legend()
plt.show()

In [None]:
# Validation mean IoU and Dice per epoch.
# NOTE(review): the y-range is fixed to 0.7-1, so lower values are clipped.
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_ylim(0.7, 1)
curves = [
    ('val_IoU', 'Validation Mean IoU', 'blue'),
    ('val_dice', 'Validation Dice', 'orange'),
]
for key, curve_label, curve_color in curves:
    ax.plot(history[key], label=curve_label, color=curve_color)
ax.legend()
ax.set_xlabel('Epoch')
ax.set_title('Mean IoU and Dice in Validation')
plt.show()

In [None]:
# Evaluate on the held-out test set. Set eval mode explicitly instead of relying on
# an earlier cell having left the model in that state; train-mode layers would
# otherwise change the results.
model.eval()
with torch.no_grad():
    running_IoU = 0
    running_dice = 0
    running_loss = 0
    for img, mask in test_loader:
        img, mask = img.to(device), mask.to(device)
        predictions = model(img).squeeze(1)
        running_dice += dice_pytorch(predictions, mask).sum().item()
        running_IoU += iou_pytorch(predictions, mask).sum().item()
        # Weight by batch size so the averages below are per sample.
        running_loss += loss_fn(predictions, mask).item() * img.size(0)

n_test = len(test_dataset)
test_loss = running_loss / n_test
test_dice = running_dice / n_test
test_IoU = running_IoU / n_test

print(f'Tests: loss: {test_loss} | Mean IoU: {test_IoU} | Dice coefficient: {test_dice}')

In [None]:
%matplotlib inline

width = 3
columns = 2 
n_examples = columns * width

fig, axs = plt.subplots(columns, width, figsize=(4*width , 4*columns), constrained_layout=True)
red_patch = mpatches.Patch(color='red', label='The red data')
fig.legend(loc='upper right',handles=[
    mpatches.Patch(color='red', label='Ground truth'),
    mpatches.Patch(color='green', label='Predicted abnormality')])
i = 0
with torch.no_grad():
    for data in test_loader:
        image, mask = data
        mask = mask[0]
        if not mask.byte().any():
            continue
        image = image.to(device)
        prediction = model(image).to('cpu')[0][0]
        prediction = torch.where(prediction > 0.5, 1, 0)
        prediction_edges = prediction - binary_dilation(prediction)
        ground_truth = mask - binary_dilation(mask)
        image = cv2.cvtColor(image[0].to('cpu').permute(1, 2, 0).numpy(),cv2.COLOR_GRAY2RGB)
#         print(prediction_edges.bool().shape)
        image[ground_truth.bool(), :] = [1, 0, 0]
        image[prediction_edges.bool(), :] = [0, 1, 0]
        axs[i//width][i%width].imshow(image)
        if n_examples == i + 1:
            break
        i += 1

In [None]:
# Quick inspection: the logical complement of the last predicted outline mask.
prediction_edges.bool().logical_not()

In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
trained_state = model.state_dict()
torch.save(trained_state, 'model.pth')