In [1]:
from os import listdir as ls

import matplotlib.pyplot as plt
import cv2

from torch.utils.data import Dataset, DataLoader

from natsort import natsorted

import segmentation_models_pytorch
from timm import create_model

from tqdm.notebook import tqdm

import torch
import nibabel as nib

from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

import numpy as np

from segmentation_models_pytorch import Unet, UnetPlusPlus
from segmentation_models_pytorch.losses import DiceLoss

from einops import rearrange

In [2]:
class SegmentationDatasetNumpy(Dataset):
    def __init__(self, train, dataset):
        super().__init__()
        assert train in ['train', 'val']
        assert dataset in ['corona', 'radiopedia', 'all']
        self.train = train
        if self.train == 'train':
            if dataset == 'corona':
                data = np.load('corona_train_segm_dataset.npz')
                self.all_data = data['images']
                self.all_masks = data['masks']
            elif dataset == 'radiopedia':
                data = np.load('radiopedia_train_segm_dataset.npz')
                self.all_data = data['images']
                self.all_masks = data['masks']
            elif dataset == 'all':
                data_corona = np.load('corona_train_segm_dataset.npz')
                data_radio = np.load('radiopedia_train_segm_dataset.npz')
                self.all_data = np.concatenate([data_corona['images'], data_radio['images']], axis=-1)
                self.all_masks = np.concatenate([data_corona['masks'], data_radio['masks']], axis=-1)
        else:
            if dataset == 'corona':
                data = np.load('corona_val_segm_dataset.npz')
                self.all_data = data['images']
                self.all_masks = data['masks']
            elif dataset == 'radiopedia':
                data = np.load('radiopedia_val_segm_dataset.npz')
                self.all_data = data['images']
                self.all_masks = data['masks']
            elif dataset == 'all':
                data_corona = np.load('corona_val_segm_dataset.npz')
                data_radio = np.load('radiopedia_val_segm_dataset.npz')
                self.all_data = np.concatenate([data_corona['images'], data_radio['images']], axis=-1)
                self.all_masks = np.concatenate([data_corona['masks'], data_radio['masks']], axis=-1)

    def __len__(self):
        return self.all_data.shape[-1]
    
    def __getitem__(self, idx):
#         print(idx)
        return self.all_data[...,idx], self.all_masks[...,idx]

In [10]:
class SegmentationDataset(Dataset):
    def __init__(self, train):
        super().__init__()
        assert train in ['train', 'val']
        self.train = train
        self.prefix = '/media/ssd-3t/datasets/course_data/covid_segmentation/'
        all_files = ls(self.prefix+'images')
        corona_files = natsorted([elem for elem in all_files if 'corona' in elem])
        radiopedia_files = natsorted([elem for elem in all_files if 'radio' in elem])
        thr_corona = int(len(corona_files)*0.8)
        thr_radio = int(len(radiopedia_files)*0.8)
        if self.train == 'train':
            self.all_files = corona_files[:thr_corona]
#             self.all_files = radiopedia_files[:thr_radio]
        else:
            self.all_files = corona_files[thr_corona:]
#             self.all_files = radiopedia_files[thr_radio:]
        all_data = None
        all_masks = None
        for elem in self.all_files:
            img_data = nib.load(self.prefix+'images/'+elem).get_fdata().astype('float32')
            img_data = cv2.resize(img_data, (224,224))
            mask_data = nib.load(self.prefix+'masks/'+elem).get_fdata().astype('uint8')
            mask_data = cv2.resize(mask_data, (224,224))
            if all_data is not None:
                all_data = np.concatenate([all_data, img_data], axis=-1)
                all_masks = np.concatenate([all_masks, mask_data], axis=-1)
            else:
                all_data = img_data
                all_masks = mask_data
        self.all_data = all_data
        self.all_masks = all_masks
        
    def __len__(self):
        return self.all_data.shape[-1]
    
    def __getitem__(self, idx):
#         print(idx)
        return self.all_data[...,idx], self.all_masks[...,idx]

In [119]:
val_ds = SegmentationDatasetNumpy('val', 'corona')
train_ds = SegmentationDatasetNumpy('train', 'corona')

val_ds_rad = SegmentationDatasetNumpy('val', 'radiopedia')
train_ds_rad = SegmentationDatasetNumpy('train', 'radiopedia')

val_ds_all = SegmentationDatasetNumpy('val', 'all')
train_ds_all = SegmentationDatasetNumpy('train', 'all')

In [120]:
mean = train_ds_all.all_data.mean()
std = train_ds_all.all_data.std()

In [11]:
val_ds = SegmentationDataset('val')
train_ds = SegmentationDataset('train')


In [12]:
np.savez_compressed('corona_train_segm_dataset', images=train_ds.all_data, masks=train_ds.all_masks)
np.savez_compressed('corona_val_segm_dataset', images=val_ds.all_data, masks=val_ds.all_masks)

In [123]:
# train_ds.all_data = train_ds.all_data[...,150:164]
# train_ds.all_masks = train_ds.all_masks[...,150:164]
# val_ds.all_data = train_ds.all_data
# val_ds.all_masks = train_ds.all_masks

In [124]:
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=8)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=8)

train_loader_rad = DataLoader(train_ds_rad, batch_size=16, shuffle=True, num_workers=8)
val_loader_rad = DataLoader(val_ds_rad, batch_size=16, shuffle=False, num_workers=8)

train_loader_all = DataLoader(train_ds_all, batch_size=16, shuffle=True, num_workers=8)
val_loader_all = DataLoader(val_ds_all, batch_size=16, shuffle=False, num_workers=8)

model = Unet(encoder_name='resnet50', encoder_weights=None, in_channels=1, classes=1)
model = model.cuda()

criterion = torch.nn.BCEWithLogitsLoss()
criterion = DiceLoss('binary')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [125]:
def calculate_dice(pred, y):
    return 2*((pred>0)*y).sum(axis=(1,2,3))/((pred>0).sum(axis=(1,2,3)) + y.sum(axis=(1,2,3))+1e-3)

def dice_loss(pred, y):
    probs = torch.sigmoid(pred)
    return 1 - (2*probs.int()*y).sum()/(probs.sum()+y.sum()+1e-3)

def train(train_loader):
    model.train()
    model.zero_grad()
    train_loss = 0
    for batch in tqdm(train_loader):
        x, y = batch
        x = (x - mean)/std
        x = x.unsqueeze(1).cuda()
        y = y.unsqueeze(1).cuda().float()
        pred = model(x)
        loss = criterion(pred, y)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return train_loss/len(train_loader)
        
def val(val_loader):
    model.eval()
    val_loss = 0
    val_dice = 0
    non_zeros_y = 0
    for batch in tqdm(val_loader):
        x, y = batch
        x = (x - mean)/std
        x = x.unsqueeze(1).cuda()
        y = y.unsqueeze(1).cuda().float()
        with torch.no_grad():
            pred = model(x)
        loss = criterion(pred, y)
        val_loss += loss.item()
        d = calculate_dice(pred, y)
        val_dice += d.sum()
        non_zeros_y += (y.sum(axis=[1,2,3]) != 0).sum()
    return val_dice/non_zeros_y, val_loss/len(val_loader), pred, y

In [128]:
train_losses = list()
val_losses = list()
val_dices = list()

# criterion = dice_loss
dice, val_loss, pred, y = val(val_loader)
# val_losses.append(val_loss)
# val_dices.append(dice)
print(dice)
for i in range(20):
    tr_loss = train(train_loader_all)
    train_losses.append(tr_loss)
    dice, val_loss, pred, y = val(val_loader)
    val_losses.append(val_loss)
    val_dices.append(dice)
    print(i, val_dices[i])
    d, _, _, _ = val(val_loader_rad)
    print(d)

  0%|          | 0/35 [00:00<?, ?it/s]

tensor(5.2980e-06, device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

0 tensor(0.6923, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

1 tensor(0.7022, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

2 tensor(0.6529, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

3 tensor(0.7150, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

4 tensor(0.6790, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

5 tensor(0.6976, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

6 tensor(0.6975, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

7 tensor(0.6842, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

8 tensor(0.7015, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

9 tensor(0.7063, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

10 tensor(0.6918, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

11 tensor(0.7058, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

12 tensor(0.6779, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

13 tensor(0.6886, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

14 tensor(0.7025, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

15 tensor(0.6891, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

16 tensor(0.7176, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

17 tensor(0.7092, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

18 tensor(0.6881, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

19 tensor(0.7127, device='cuda:0')


  0%|          | 0/9 [00:00<?, ?it/s]

tensor(0., device='cuda:0')
