In [1]:
import os
from glob import glob
import yaml
from tqdm import tqdm
from ipywidgets import interact
from collections import OrderedDict

from monai.transforms import(
    Compose,
    AddChanneld,
    LoadImage,
    Resized,
    ToTensord,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    Rotate90d,
    apply_transforms
)
from monai.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as f
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class UNet(nn.Module):
    """2-D U-Net with four encoder stages, a bottleneck and four decoder stages.

    Each stage is built by ``_block`` (two 3x3 conv -> batch-norm -> ReLU
    triples).  Every encoder stage is followed by 2x2 max-pooling; every
    decoder stage is preceded by a stride-2 transposed convolution and is fed
    the matching encoder output through a channel-wise concatenation.  A final
    1x1 convolution followed by a sigmoid produces ``out_channels`` maps.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the returned tensor.
        init_features: width of the first encoder stage; it doubles at every
            deeper stage, reaching ``16 * init_features`` in the bottleneck.

    Note:
        The input is pooled by 2 four times and upsampled by 2 four times, so
        height and width presumably need to be multiples of 16 for the skip
        concatenations to line up.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super().__init__()

        # Channel widths of the five resolution levels (top -> bottleneck).
        c1, c2, c3, c4, c5 = (init_features * 2 ** level for level in range(5))
        make_block = UNet._block

        # Contracting path.  Modules are registered in the same order as the
        # forward pass so the state_dict layout stays stable.
        self.encoder1 = make_block(in_channels, c1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = make_block(c1, c2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = make_block(c2, c3, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = make_block(c3, c4, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = make_block(c4, c5, name="bottleneck")

        # Expanding path.  Each decoder block sees twice the upsampled width
        # because the skip connection is concatenated along the channel axis.
        self.upconv4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.decoder4 = make_block(c4 * 2, c4, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.decoder3 = make_block(c3 * 2, c3, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.decoder2 = make_block(c2 * 2, c2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.decoder1 = make_block(c1 * 2, c1, name="dec1")

        # 1x1 projection to the requested number of output maps.
        self.conv = nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated output maps."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        out = self.bottleneck(self.pool4(skip4))

        # Deepest level first: upsample, append the skip tensor, then decode.
        decoding_path = (
            (self.upconv4, skip4, self.decoder4),
            (self.upconv3, skip3, self.decoder3),
            (self.upconv2, skip2, self.decoder2),
            (self.upconv1, skip1, self.decoder1),
        )
        for upsample, skip, decode in decoding_path:
            out = decode(torch.cat((upsample(out), skip), dim=1))

        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """Build one stage: two (3x3 conv, batch-norm, ReLU) triples.

        Layer names are ``name`` + ``conv1``/``norm1``/``relu1`` and
        ``conv2``/``norm2``/``relu2``.  The convolutions carry no bias term.
        """
        layers = OrderedDict()
        conv_inputs = in_channels
        for conv_key, norm_key, relu_key in (
            ("conv1", "norm1", "relu1"),
            ("conv2", "norm2", "relu2"),
        ):
            layers[name + conv_key] = nn.Conv2d(
                in_channels=conv_inputs,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[name + norm_key] = nn.BatchNorm2d(num_features=features)
            layers[name + relu_key] = nn.ReLU(inplace=True)
            # The second convolution maps features -> features.
            conv_inputs = features
        return nn.Sequential(layers)

In [112]:
# Load the run configuration; later cells read its sections by key.
config_path = 'config.yaml'
# Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
with open(config_path, 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

In [113]:
class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of 2-D slices cut from paired image / mask volumes.

    Every path matched by ``<images_directory>/*`` is paired, by sorted
    position, with the path at the same index of ``<masks_directory>/*``.
    Both files are read with MONAI ``LoadImage`` and split along axis 2; each
    slice becomes one sample.  All slices are loaded eagerly in ``__init__``.

    Args:
        images_directory: directory holding the image volumes.
        masks_directory: directory holding the mask volumes.
        transform: optional callable applied to the ``{'image', 'label'}``
            sample dict in ``__getitem__``.  ``None`` (default) returns the
            raw slices unchanged.

    Raises:
        ValueError: if the two directories contain a different number of
            files, or a volume and its mask differ in slice count.
    """

    def __init__(self, images_directory, masks_directory, transform = None):
        self.images_directory = sorted(glob(images_directory + '/*'))
        self.masks_directory = sorted(glob(masks_directory + '/*'))
        # Pairing is purely positional, so unequal counts would either raise an
        # IndexError half-way through loading or silently mis-pair the files.
        if len(self.images_directory) != len(self.masks_directory):
            raise ValueError(
                f'Found {len(self.images_directory)} image files but '
                f'{len(self.masks_directory)} mask files; they must match one-to-one.'
            )
        self.transform = transform
        # Not read anywhere in this class; kept in case outside code uses it.
        self.img_size = 512
        self.lung_2d, self.mask_2d = self.get_lists()

    def get_lists(self):
        """Load every volume pair and return ``(image_slices, mask_slices)``.

        The two lists are index-aligned: element ``k`` of each comes from the
        same volume pair and the same position along axis 2.
        """
        lung_2d = []
        mask_2d = []
        # One reader instance is enough; it holds no per-file state we rely on.
        loader = LoadImage(image_only=True, ensure_channel_first=False, simple_keys=True)

        for image_path, mask_path in zip(self.images_directory, self.masks_directory):
            lung_3d = loader(image_path)
            mask_3d = loader(mask_path)

            n_slices = lung_3d.shape[2]
            if mask_3d.shape[2] != n_slices:
                raise ValueError(
                    f'Slice count mismatch: {image_path} has {n_slices} slices, '
                    f'{mask_path} has {mask_3d.shape[2]}.'
                )

            for j in tqdm(range(n_slices)):
                lung_2d.append(lung_3d[:, :, j])
                mask_2d.append(mask_3d[:, :, j])

        return lung_2d, mask_2d

    def __len__(self):
        """Total number of 2-D slices across all loaded volumes."""
        return len(self.lung_2d)

    def __getitem__(self, idx):
        """Return the sample dict for slice ``idx``, transformed if requested.

        An out-of-range ``idx`` raises ``IndexError`` from the underlying
        list, which is what lets plain iteration over the dataset terminate.
        """
        sample = {'image': self.lung_2d[idx], 'label': self.mask_2d[idx]}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
    

In [114]:
# Settings from the 'augmentation_staff' section of the config.
params = config['augmentation_staff']
dataset_root = params['dataset_path']
# Scans and masks are read from two sub-folders of the dataset root
# (presumably CT volumes and lung masks, going by the folder names).
lungs_path, masks_path = (
    os.path.join(dataset_root, sub_dir) for sub_dir in ('ct_scans_tr', 'lung_mask')
)
# No transform is passed here: the dataset serves raw slices.
train_dataset = MyDataset(lungs_path, masks_path)

100%|██████████████████████████████████████| 301/301 [00:00<00:00, 20394.58it/s]
100%|██████████████████████████████████████| 200/200 [00:00<00:00, 29836.77it/s]
100%|██████████████████████████████████████| 200/200 [00:00<00:00, 13189.22it/s]
100%|██████████████████████████████████████| 270/270 [00:00<00:00, 23306.00it/s]
100%|████████████████████████████████████████| 290/290 [00:00<00:00, 301.96it/s]
100%|██████████████████████████████████████| 213/213 [00:00<00:00, 26000.78it/s]
100%|██████████████████████████████████████| 249/249 [00:00<00:00, 13683.71it/s]
100%|██████████████████████████████████████| 301/301 [00:00<00:00, 17248.25it/s]
100%|██████████████████████████████████████| 256/256 [00:00<00:00, 34721.96it/s]
100%|██████████████████████████████████████| 301/301 [00:00<00:00, 13020.28it/s]
100%|████████████████████████████████████████| 39/39 [00:00<00:00, 28612.53it/s]
100%|██████████████████████████████████████| 418/418 [00:00<00:00, 28540.11it/s]
100%|███████████████████████

In [115]:
# Preprocessing applied to a sample dict before it is plotted.
# NOTE(review): `pixdim` and `img_size` are parsed from the config with eval(),
# which executes whatever expression config.yaml contains.  Consider
# ast.literal_eval or native YAML sequences if the config is not fully trusted.
_plot_keys = ["image", "label"]
_plot_half_width = params['window_width'] / 2
_plot_window_lo = params['window_lvl'] - _plot_half_width
_plot_window_hi = params['window_lvl'] + _plot_half_width

transforms_for_ploting = Compose(
    [
        # Add a leading channel axis to the 2-D slices.
        AddChanneld(keys=_plot_keys),
        Rotate90d(keys=_plot_keys, k=1),
        # Bilinear resampling for the image, nearest-neighbour for the label.
        Spacingd(
            keys=['image', 'label'],
            pixdim=eval(params['pixdim']),
            mode=('bilinear', 'nearest'),
        ),
        # Clip the image to the configured intensity window and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=_plot_window_lo,
            a_max=_plot_window_hi,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=_plot_keys, source_key="image"),
        Resized(keys=_plot_keys, spatial_size=eval(params['img_size'])),
    ]
)
def show_layer(layer):
    """Plot one preprocessed slice next to its mask.

    ``layer`` is an index into ``train_dataset``; the sample is passed through
    ``transforms_for_ploting`` before being drawn.
    """
    sample = transforms_for_ploting(train_dataset[layer])
    # Drop the singleton channel axis so imshow receives a 2-D array.
    lung = sample['image'].squeeze()
    mask = sample['label'].squeeze()

    _, (ax_lung, ax_mask) = plt.subplots(1, 2, figsize=(18, 15))

    ax_lung.imshow(lung, cmap='bone')
    ax_lung.axis('off')
    ax_lung.set_title('Original')

    # Only the mask is drawn in this panel; the slice is not rendered under it.
    ax_mask.imshow(mask, alpha=0.5, cmap='nipy_spectral')
    ax_mask.axis('off')
    ax_mask.set_title('With mask')


# Slider over every slice index of the dataset.
interact(show_layer, layer=(0, len(train_dataset) - 1))

interactive(children=(IntSlider(value=1759, description='layer', max=3519), Output()), _dom_classes=('widget-i…

<function __main__.show_layer(layer)>

In [116]:
# Preprocessing applied to every training sample, ending with tensor conversion.
# NOTE(review): `pixdim` and `img_size` are parsed from the config with eval(),
# which executes whatever expression config.yaml contains.  Consider
# ast.literal_eval or native YAML sequences if the config is not fully trusted.
_train_keys = ["image", "label"]
_train_half_width = params['window_width'] / 2
_train_window_lo = params['window_lvl'] - _train_half_width
_train_window_hi = params['window_lvl'] + _train_half_width

train_transforms = Compose(
    [
        # Add a leading channel axis to the 2-D slices.
        AddChanneld(keys=_train_keys),
        Rotate90d(keys=_train_keys, k=1),
        # Bilinear resampling for the image, nearest-neighbour for the label.
        Spacingd(
            keys=['image', 'label'],
            pixdim=eval(params['pixdim']),
            mode=('bilinear', 'nearest'),
        ),
        # Clip the image to the configured intensity window and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=_train_window_lo,
            a_max=_train_window_hi,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=_train_keys, source_key="image"),
        Resized(keys=_train_keys, spatial_size=eval(params['img_size'])),
        ToTensord(keys=_train_keys),
    ]
)

In [117]:
# Pull every sample dict out of the slice dataset into a plain list, wrap that
# list in a MONAI Dataset carrying `train_transforms`, and batch it.
train_dataset_for_dataloader = list(train_dataset)
transformed_dataset = Dataset(
    data=train_dataset_for_dataloader,
    transform=train_transforms,
)
train_loader = DataLoader(
    transformed_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
    num_workers=1,
)

8

In [118]:
class DiceLoss(nn.Module):
    """Soft Dice loss evaluated on channel 0 of prediction and target.

    With ``p`` and ``t`` the flattened channel-0 values of the whole batch::

        loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    Args:
        smooth: constant added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero.  Defaults to ``1.0``.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """Return the scalar Dice loss for ``y_pred`` against ``y_true``.

        Both tensors must have identical sizes; only index 0 along dim 1 is
        used, any further channels are ignored.

        Raises:
            AssertionError: if the two tensors differ in size.
        """
        assert y_pred.size() == y_true.size(), (
            f'DiceLoss expects equal sizes, got {tuple(y_pred.size())} '
            f'and {tuple(y_true.size())}'
        )
        # Keep channel 0 only and flatten batch + spatial dims into one vector.
        y_pred = y_pred[:, 0].reshape(-1)
        y_true = y_true[:, 0].reshape(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1. - dsc

In [119]:
# Training settings from the 'train_staff' section of the config.
params_train = config['train_staff']

# Use the requested accelerator only if it is actually available.  Every other
# case, including a config value that is neither 'mps' nor 'cuda', falls back
# to the CPU so that `device` is always defined on a fresh kernel.
requested_device = params_train['device']
if requested_device == 'mps' and torch.backends.mps.is_available():
    device = torch.device('mps')
elif requested_device == 'cuda' and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = UNet()
model.to(device)
dc_loss = DiceLoss()
dc_loss = dc_loss.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas = (0.9, 0.999))
# The learning rate is multiplied by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.9)

In [120]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [121]:
# Trainable-parameter count of `model`; displayed as this cell's output.
count_parameters(model)

7762465

In [104]:
# Train for the configured number of epochs, recording the mean Dice loss and
# mean Dice score (1 - loss) of every epoch and checkpointing after each one.
# NOTE(review): the recorded run of this cell stopped with an MPS out-of-memory
# error on the first batch; a smaller batch_size / img_size in the config is
# the likely remedy - TODO confirm.
epochs = params_train['networks']['epochs']
loss_epochs_list = []
acc_epochs_list = []

# Make sure batch-norm layers are in training mode even if an earlier cell
# switched the model to eval().
model.train()

for epoch in range(epochs):
    loss_val = 0.0
    acc_val = 0.0
    loop = tqdm(train_loader)
    for sample in loop:
        lung = sample['image'].to(device)
        label = sample['label'].to(device)

        optimizer.zero_grad()
        pred = model(lung)

        loss = dc_loss(pred, label)
        loss.backward()
        optimizer.step()

        # Work with plain Python floats from here on.  The Dice score is by
        # definition 1 - Dice loss, so it is derived from the value already
        # computed instead of re-running the loss on CPU copies; that also
        # avoids accumulating tensors that are still attached to the autograd
        # graph.
        loss_current = loss.item()
        acc_current = 1.0 - loss_current
        loss_val += loss_current
        acc_val += acc_current
        loop.set_description_str(f'loss = {loss_current}   acc = {acc_current}')

    # Per-epoch means; guard against an empty loader.
    n_batches = max(len(train_loader), 1)
    loss_epochs_list.append(loss_val / n_batches)
    acc_epochs_list.append(acc_val / n_batches)

    scheduler.step()
    # Overwrites the same checkpoint file after every epoch.
    torch.save(model.state_dict(), 'model_weights')
        
    

  t = cls([], dtype=storage.dtype, device=storage.device)
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  0%|                                                    | 0/38 [00:04<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 7.51 GB, other allocations: 1.80 GB, max allowed: 9.07 GB). Tried to allocate 256 bytes on shared pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).