<a href="https://colab.research.google.com/github/luigiantonelli/Lightweight-Conditional-Swin-U-Net-for-Medical-Image-reconstruction-and-segmentation/blob/main/luigi_antonelli_thesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Author: Luigi Antonelli

Implementation in PyTorch (and PyTorch Lightning) of the Deep Learning architectures introduced in the Master Thesis "Lightweight Conditional Swin U-Net for Medical Image reconstruction and segmentation".

Note: the notebook is not ready to run since it has been modified to hide details about my personal file system and my Weights & Biases credentials.

Autore: Luigi Antonelli

Implementazione in PyTorch (e PyTorch Lightning) delle architetture di Deep Learning introdotte nella tesi di Laurea Magistrale "Lightweight Conditional Swin U-Net for Medical Image reconstruction and segmentation".

Nota: il notebook non è pronto per l'esecuzione poiché è stato modificato per nascondere i dettagli del mio file system personale e le mie credenziali di Weights & Biases.

# Imports and installations

In [None]:
!pip install pytorch-lightning==2.0.0 --quiet
!pip install torchmetrics --quiet
!pip install torchvision --quiet
!pip install gdown==4.5.4 --no-cache-dir --quiet
!pip install einops --quiet
!pip install fastmri --quiet
!pip install h5py --quiet
!pip install nibabel --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m715.6/715.6 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.1/58.1 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.4/101.4 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for runstats (setup.py) ... [?25l[?25hdone


In [None]:
import os
import glob
import math
import pickle
import random
from typing import *
from datetime import datetime
import warnings
import nibabel as nib
import pandas as pd
import h5py
from scipy.ndimage import zoom
import gdown
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from copy import deepcopy
#from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomResizedCrop, RandomRotation
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
import multiprocessing as mp

import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import fastmri
from fastmri.data import transforms as fastmriT

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the run through Lightning's `seed_everything` (seed 10, workers=True) for reproducibility.
seed_everything(10, workers=True)


INFO:lightning_fabric.utilities.seed:Global seed set to 10


10

# fastMRI dataset

In [None]:
class Mask:
    """Abstract base class for sampling masks.

    Subclasses implement `generate_mask`; instances are callable with a 2-D
    ``(n_row, n_col)`` or 3-D ``(n_sample, n_row, n_col)`` shape.
    """

    def __init__(self, acceleration: int):
        # Stored privately and exposed read-only through the `acceleration` property.
        self.__acceleration = acceleration

    @property
    def acceleration(self):
        """Acceleration value given at construction time."""
        return self.__acceleration

    def __call__(self, shape: Tuple) -> np.ndarray:
        """Build a mask for `shape`.

        Raises:
            ValueError: if `shape` does not have exactly 2 or 3 entries.
        """
        if len(shape) not in (2, 3):
            raise ValueError("Shape should have 2 or 3 dimensions")

        # A 2-D shape has no leading sample axis and is treated as one sample.
        *leading, n_row, n_col = shape
        n_sample = leading[0] if leading else 1

        return self.generate_mask(n_sample, n_row, n_col)

    def generate_mask(self, n_sample: int,n_row: int, n_col: int) -> np.ndarray:
        """Return the mask array; must be overridden by subclasses."""
        raise NotImplementedError

In [None]:
class RectangularEquispacedMask(Mask):
    """Column mask made of a fully kept central band plus spaced outer columns.

    Every row of the mask is identical: a contiguous block of columns around
    the centre is set to 1, and the remaining column budget is spent on
    columns placed symmetrically on both sides of that block.
    """

    def __init__(
        self,
        acceleration: int,
        central_ratio: int):

        super().__init__(acceleration)

        # Fraction of the kept columns assigned to the central band
        # (the dataset code in this file passes 0.8).
        self.__central_ratio = central_ratio

    @property
    def central_ratio(self):
        """Central-band ratio given at construction time."""
        return self.__central_ratio

    def generate_mask(
        self,
        n_sample: int,
        n_row: int,
        n_col: int,
        ) -> np.ndarray:
        """Return a ``(n_row, n_col)`` mask, or ``(n_sample, n_row, n_col)`` when n_sample != 1.

        A warning is emitted when the achieved ratio
        ``total entries / kept entries`` is lower than `acceleration`.
        """
        column_mask = np.zeros((n_row, n_col))

        # Column budget, and the share of it spent on the central band.
        n_lines = int(np.floor(n_col / self.acceleration))
        n_low = int(np.floor(n_lines * self.central_ratio))
        # Keep the central band width even so it splits symmetrically.
        n_low += n_low % 2

        half_low = n_low / 2
        start_low = int(n_col / 2 - half_low)
        end_low = int(n_col / 2 + half_low)
        column_mask[:, start_low:end_low] = 1

        n_high = n_lines - n_low
        if n_high > 0:
            # Gap between consecutive outer columns.
            high_tab = int(np.around(n_col - n_lines) / (n_high + 2))
            for step in range(int(n_high / 2)):
                offset = int(high_tab * (step + 1) + step)
                # One column on each side of the central band.
                column_mask[:, start_low - offset - 1] = 1
                column_mask[:, end_low + offset] = 1

        achieved = column_mask.size / column_mask.sum()
        if achieved < self.acceleration:
            warnings.warn("WARNING: Mask with wrong (less) acceleration")

        if n_sample == 1:
            return column_mask

        # Same 2-D mask replicated along a new leading sample axis.
        return np.repeat(column_mask[np.newaxis, ...], n_sample, axis=0)


In [None]:
from typing import Tuple

def normalize(
    input_array: np.ndarray
    ):
    """Min-max normalise the last two axes of `input_array` to the range [0, 1].

    Each 2-D plane (last two axes) is shifted so its minimum is 0 and divided
    by its maximum. A constant plane maps to all zeros. The input is not
    modified.

    Integer and boolean inputs are converted to float64 first (the in-place
    division cannot be stored in an integer array); floating inputs keep
    their dtype.

    Raises:
        ValueError: if `input_array` has fewer than 2 dimensions.
    """

    if len(input_array.shape) < 2:
        raise ValueError("Dimension must be at least 2")

    result = input_array.copy()
    if not np.issubdtype(result.dtype, np.inexact):
        result = result.astype(np.float64)

    plane_axes = (-2, -1)
    result -= result.min(axis=plane_axes, keepdims=True)
    _max = result.max(axis=plane_axes, keepdims=True)
    # Avoid 0/0 on constant planes. The guard must be representable in the
    # array's own dtype, otherwise it underflows back to zero (e.g. float32).
    _max[_max == 0] = np.finfo(result.dtype).tiny
    result /= _max

    return result

def apply_mask(
    k_space: np.ndarray,
    mask: np.ndarray
    ):
    """Return the element-wise product of `k_space` and `mask`.

    Raises:
        ValueError: if the two arrays do not have identical shapes.
    """

    same_shape = k_space.shape == mask.shape
    if not same_shape:
        raise ValueError("k_space and mask must have the same shape")

    masked = np.multiply(k_space, mask)
    return masked

def check_dark(
    img: np.ndarray,
    dark_threshold: float = 0.25
    ):
    """Measure the dark margins of a 2-D image.

    A row (column) counts as bright when its maximum exceeds `dark_threshold`.

    Returns:
        ``(rows, cols)`` where each entry is a two-element list
        ``[leading, trailing]``: the number of dark rows (columns) before the
        first bright one and after the last bright one. When nothing is
        bright the lists are ``[0, img.shape[0]]`` and ``[0, img.shape[1]]``.
    """

    def _margins(bright_flags, extent):
        # Leading margin = index of the first bright entry; trailing margin =
        # distance of the last bright entry from the far edge.
        hits = np.flatnonzero(bright_flags)
        if hits.size == 0:
            return [0, extent]
        return [int(hits[0]), int(bright_flags.shape[0] - 1 - hits[-1])]

    bright_rows = np.max(img, axis=1) > dark_threshold
    bright_cols = np.max(img, axis=0) > dark_threshold

    rows = _margins(bright_rows, img.shape[0])
    cols = _margins(bright_cols, img.shape[1])

    return (rows, cols)

def check_dark_square(
    img: np.ndarray,
    dark_threshold: float = 0.25
    ):
    """Return the smaller of the two dark margins per axis.

    Returns:
        ``(row_margin, col_margin)`` as ints: for each axis, the minimum of
        the leading and trailing margins reported by `check_dark`.
    """

    row_margins, col_margins = check_dark(img, dark_threshold)

    row_margin = int(min(row_margins))
    col_margin = int(min(col_margins))
    return (row_margin, col_margin)

In [None]:
class fastMRIDataset(Dataset):
    def __init__(self, preprocessed_data_path: str,
        original_data_path: str = None, mask: Mask = None, dark_threshold = 0.25, keyword: str = 'train',
        target_size : int = 256,
        max_proc: int = 1,
        train: bool = False,
        test: bool = False,
        max_vlms: int = 20,
        pd: bool = False,
        pdfs: bool = False):
        """Index preprocessed ``.npz`` samples, generating them first if requested.

        Args:
            preprocessed_data_path: directory holding (or receiving) the ``.npz`` files.
            original_data_path: directory with the raw ``.h5`` volumes. When given
                and `preprocessed_data_path` contains no ``.npz`` file yet, the
                volumes are preprocessed into it.
            mask: sampling mask used during preprocessing; defaults to
                ``RectangularEquispacedMask(acceleration=4, central_ratio=0.8)``.
            dark_threshold: intensity threshold used by the dark-region helpers.
            keyword: free-form tag exposed through the `keyword` property.
            target_size: side length of the square output images.
            max_proc: number of worker processes used for preprocessing.
            train, test: mode flags exposed through the matching properties;
                in test mode no label is stored.
            max_vlms: total number of volumes to preprocess (split across workers).
            pd, pdfs: stored flags; not used by the code visible here.

        Raises:
            Exception: if no ``.h5`` file is found in `original_data_path`, or
                no ``.npz`` file is available after the optional preprocessing.
        """

        super().__init__()
        self.original_data_path = original_data_path
        self.preprocessed_data_path = preprocessed_data_path
        self.images = []
        self.reconstructed_images = []
        self.__max_vlms = max_vlms // max_proc
        self.__pd = pd
        self.__pdfs = pdfs
        self.__target_size = target_size
        self.__dark_threshold = dark_threshold
        self.__test = test
        self.__train = train
        # Back the `keyword` property, which would otherwise raise AttributeError.
        self.__keyword = keyword

        if self.original_data_path is not None:

            preprocess = True

            h5_files = [os.path.join(self.original_data_path, f) for f in os.listdir(self.original_data_path) if '.h5' in f]
            if len(h5_files) == 0:
                raise Exception("No valid h5 files found")

            if os.path.isdir(self.preprocessed_data_path):
                temp = [os.path.join(self.preprocessed_data_path, f) for f in os.listdir(self.preprocessed_data_path) if '.npz' in f]

                # Existing output is reused as-is rather than regenerated.
                if len(temp) > 0:
                    preprocess = False
                    warnings.warn("WARNING: Preprocessed file already present")

            else:
                # makedirs also creates missing parent directories.
                os.makedirs(self.preprocessed_data_path, exist_ok=True)

            if preprocess:

                if not mask:
                    mask = RectangularEquispacedMask(acceleration=4, central_ratio=0.8)

                if max_proc == 1:
                    self._generate_samples(h5_files, mask)
                else:
                    # Split the file list into `max_proc` contiguous chunks.
                    linspace = np.linspace(start=0, stop=len(h5_files), num=max_proc+1, dtype=int)
                    # tqdm from https://stackoverflow.com/questions/66208601/tqdm-and-multiprocessing-python
                    lock = mp.Manager().Lock()
                    with mp.Pool(processes=max_proc) as pool:
                        pending = [
                            pool.apply_async(self._generate_samples, args=(h5_files[linspace[i]:linspace[i+1]], mask, i+1, lock))
                            for i in range(max_proc)
                        ]

                        pool.close()
                        pool.join()

                        # Re-raise any exception that occurred inside a worker;
                        # without this a failed worker goes unnoticed.
                        for job in pending:
                            job.get()

        self.preprocessed_data_list = [os.path.join(self.preprocessed_data_path, f) for f in os.listdir(self.preprocessed_data_path) if '.npz' in f]

        if len(self.preprocessed_data_list) == 0:
            raise Exception("No valid npz (preprocesses) files found")

    @property
    def keyword(self):
        """Return the private `__keyword` attribute.

        NOTE(review): confirm that `__init__` assigns the backing attribute;
        reading this property fails with AttributeError otherwise.
        """
        return self.__keyword

    @property
    def target_size(self):
        """Target image size given at construction time."""
        return self.__target_size

    @property
    def dark_threshold(self):
        """Dark threshold given at construction time."""
        return self.__dark_threshold

    @property
    def train(self):
        """Value of the `train` flag given at construction time."""
        return self.__train

    @property
    def test(self):
        """Value of the `test` flag given at construction time."""
        return self.__test

    @property
    def preprocessed_data_list(self):
        """List of ``.npz`` paths that back `__len__` and `__getitem__`."""
        return self.__preprocessed_data_list

    @preprocessed_data_list.setter
    def preprocessed_data_list(self, value):
        # Replaces the whole list; no validation is performed here.
        self.__preprocessed_data_list = value



    def _idxs_dark_slices(self, vlm: np.ndarray):
        rows_bools = np.sum((np.max(vlm, axis=2) > self.__dark_threshold), axis=1)
        cols_bools = np.sum((np.max(vlm, axis=1) > self.__dark_threshold), axis=1)
        slices_list = [True if r/vlm.shape[1] > 0.10 or c/vlm.shape[2] > 0.10 else False for r, c in zip(rows_bools, cols_bools)]

        start_idx = 0
        end_idx = vlm.shape[0] - 1

        for i in range(vlm.shape[0]):
            if slices_list[i]:
                start_idx = i
                break
        for i in range(vlm.shape[0]-1, -1, -1):
            if slices_list[i]:
                end_idx = i
                break

        return [start_idx, end_idx]

    def _crop_pad_2_square(
        self,
        img: np.ndarray,
        img_1: np.ndarray
    ):
        """Make a 2-D image (and a companion image) square by cropping and/or padding.

        The decision is driven by `img` only; `img_1` receives the same crop
        or padding. Along the longer axis the image is cropped symmetrically
        when the crop fits inside the dark margin reported by
        `check_dark_square`; otherwise the shorter axis is zero-padded,
        optionally after trimming the dark margin first.

        Args:
            img: reference 2-D image.
            img_1: second image, processed with the same crop/padding.

        Returns:
            ``(img_out, img_1_out)`` as new arrays; copies of the inputs when
            `img` is already square.
        """

        result = img.copy()
        result_1 = img_1.copy()

        h, w = result.shape
        if h == w:
            return (result, result_1)

        # Largest symmetric crop per axis that only removes dark pixels.
        max_h_cut, max_w_cut = check_dark_square(result)

        if h > w:
            h_cut = int((h-w)/2)
            if h_cut <= max_h_cut:
                # Explicit end index: `x[c:-c]` is empty when c == 0, which
                # happens for a size difference of one.
                return (result[h_cut:h-h_cut,:], result_1[h_cut:result_1.shape[0]-h_cut,:])

            w_pad = h_cut
            if max_h_cut != 0 and max_h_cut < int((h - self.__target_size)/2):
                # Trim the dark margin first, then pad what is still missing.
                result = result[max_h_cut:-max_h_cut,:]
                result_1 = result_1[max_h_cut:-max_h_cut,:]
                h, w = result.shape
                h_cut = int((h-w)/2)
                w_pad = h_cut
            return (np.pad(result, ((0, 0), (w_pad, w_pad)), mode='constant', constant_values=0),
                    np.pad(result_1, ((0, 0), (w_pad, w_pad)), mode='constant', constant_values=0))
        elif w > h:
            w_cut = int((w-h)/2)
            if w_cut <= max_w_cut:
                return (result[:,w_cut:w-w_cut], result_1[:,w_cut:result_1.shape[1]-w_cut])

            h_pad = w_cut
            if max_w_cut != 0 and max_w_cut < int((w - self.__target_size)/2):
                result = result[:,max_w_cut:-max_w_cut]
                result_1 = result_1[:,max_w_cut:-max_w_cut]
                h, w = result.shape
                w_cut = int((w-h)/2)
                h_pad = w_cut
            return (np.pad(result, ((h_pad, h_pad), (0, 0)), mode='constant', constant_values=0),
                    np.pad(result_1, ((h_pad, h_pad), (0, 0)), mode='constant', constant_values=0))

    def _gen_background_mask(
        self,
        sample,
        dark_threshold = 0.25
    ):

        img = torch.tensor(sample)

        bg_mask = torch.zeros(img.size())

        img_bools = img > dark_threshold
        h, w = img_bools.size()

        rows_bools = torch.ones(img_bools.size()) * -1
        cols_bools = torch.ones(img_bools.size()) * -1
        for r in range(h):
            _indxs = img_bools[r, :].argwhere()
            if len(_indxs) > 0:
                rows_bools[r, _indxs[0]:_indxs[-1]] = 0
        for c in range(w):
            _indxs = img_bools[:, c].argwhere()
            if len(_indxs) > 0:
                cols_bools[_indxs[0]:_indxs[-1], c] = 0

        bg_mask[((rows_bools * cols_bools) - 1).bool()] = 1

        return bg_mask.numpy()

    def _generate_samples(
        self,
        h5_files: list,
        mask: Mask,
        position: int = 1,
        lock = None
        ):
        """Preprocess raw ``.h5`` volumes into per-slice ``.npz`` files.

        For each volume the ``kspace`` dataset is read, a masked copy is
        built, and both are turned into magnitude images (inverse FFT,
        centre crop to 320x320, instance normalisation). Every slice is then
        squared, resized to ``target_size`` and min-max normalised. A slice is
        written only when its mask covers more than 15% of the image.

        Args:
            h5_files: paths of the volumes handled by this call.
            mask: object whose `generate_mask` produces the sampling mask.
            position: progress-bar slot, also shown in the bar label.
            lock: lock guarding progress-bar updates; one is created when omitted.

        Output files are named ``<volume>_<slice index>`` inside
        `preprocessed_data_path` and contain ``sample``, ``label``
        (``None`` in test mode) and ``sample_bg_mask``.
        """

        def _magnitude_image(kspace):
            # k-space -> magnitude image, cropped and instance-normalised.
            image = fastmri.complex_abs(fastmri.ifft2c(fastmriT.to_tensor(kspace)))
            image = fastmriT.center_crop(image, (320, 320))
            return fastmriT.normalize_instance(image, eps=1e-9)[0].numpy().astype(np.float64)

        def _resize_to_target(image):
            # Resample each axis independently to `target_size`, then rescale to [0, 1].
            factors = (self.target_size/image.shape[0], self.target_size/image.shape[1])
            return normalize(zoom(image, zoom=factors))

        if lock is None:
            lock = mp.Manager().Lock()

        with lock:
            bar = tqdm(
                desc=f'Worker {position}',
                # total=len(h5_files),
                total=self.__max_vlms,
                position=position,
                leave=False
            )

        count = 0
        for f in h5_files:

            # Load the volume into memory and release the HDF5 handle right away.
            with h5py.File(f, 'r') as hf:
                vlm_k = hf['kspace'][()]

            mask_vlm = mask.generate_mask(*vlm_k.shape)
            masked_vlm_k = apply_mask(vlm_k, mask_vlm)

            _samples = _magnitude_image(masked_vlm_k)
            _labels = _magnitude_image(vlm_k)

            # Volume name without directory or extension, independent of the path separator.
            volume_name = os.path.basename(f).split('.')[0]

            for i, (sample, label) in enumerate(zip(_samples, _labels)):

                label, sample = self._crop_pad_2_square(label, sample)
                sample = _resize_to_target(sample)
                label = _resize_to_target(label)

                # NOTE(review): `dark_threshold // 2` is a floor division (0.0 for the
                # default 0.25); confirm whether a true division was intended.
                inner_threshold = min(max(sample.mean(), self.dark_threshold // 2), self.dark_threshold)
                # The mask is computed twice: the second pass runs on the first
                # mask with the default threshold.
                sample_bg_mask = self._gen_background_mask(self._gen_background_mask(sample, inner_threshold))

                if sample_bg_mask.sum() / (self.target_size ** 2) > 0.15:

                    out_file_name = volume_name + '_' + str(i)
                    out_file = os.path.join(self.preprocessed_data_path, out_file_name)

                    if self.test:
                        np.savez_compressed(out_file, sample=sample, label=None, sample_bg_mask=sample_bg_mask)
                    else:
                        np.savez_compressed(out_file, sample=sample, label=label, sample_bg_mask=sample_bg_mask)

            with lock:
                bar.update(1)

            count += 1

            # Stop once this worker has handled its share of volumes.
            if count == self.__max_vlms:
                break

        with lock:
            bar.close()

    def __len__(self):
        """Number of preprocessed ``.npz`` files indexed by this dataset."""
        return len(self.__preprocessed_data_list)

    def __getitem__(self, idx):
        assert idx < self.__len__()

        with np.load(self.__preprocessed_data_list[idx]) as data:
            sample = data['sample']
            label = data['label']
            sample_bg_mask = data['sample_bg_mask']

        return (sample, label, sample_bg_mask)

In [None]:
def FastMRI_collate_fn(
    data_list: List[Tuple[np.ndarray, np.ndarray, np.ndarray]]
    ):
    """Collate ``(sample, label, sample_bg_mask)`` triples into batched float tensors.

    Every 2-D array gets a leading channel axis and the items are stacked
    along a new batch axis.

    Returns:
        ``(samples, labels, sample_bg_masks)``; `labels` is ``None`` as soon
        as one item of the batch has no label.
    """

    def _with_channel_axis(array):
        # (H, W) array -> (1, H, W) float tensor.
        return torch.tensor(array, dtype=torch.float)[None, :, :]

    sample_batch = []
    label_batch = []
    mask_batch = []
    label_missing = False

    for sample, label, sample_bg_mask in data_list:
        sample_batch.append(_with_channel_axis(sample))
        if label is None:
            label_missing = True
        else:
            label_batch.append(_with_channel_axis(label))
        mask_batch.append(_with_channel_axis(sample_bg_mask))

    stacked_labels = None if label_missing else torch.stack(label_batch)

    return (torch.stack(sample_batch), stacked_labels, torch.stack(mask_batch))

In [None]:
class fastMRIDataModule(pl.LightningDataModule):
    """Lightning data module wrapping `fastMRIDataset` for training, validation and testing.

    The training set is built from `preprocessed_data_path` /
    `original_data_path`; the validation and test sets are both built from the
    `*_test` paths.
    """

    def __init__(self, preprocessed_data_path: str, preprocessed_data_path_test: str,
        original_data_path: str = None, original_data_path_test : str = None, mask: Mask = None, dark_threshold = 0.25, keyword: str = 'train',
        target_size : int = 256,
        max_proc: int = 1,
        max_vlms: int = 20,
        pd: bool = False,
        pdfs: bool = False,
        batch_size: int = 32,
        use_validation_set: bool = True
    ):
        """Store the configuration; datasets are created later in `setup`.

        Args:
            preprocessed_data_path, original_data_path: training data locations.
            preprocessed_data_path_test, original_data_path_test: locations
                used for both the validation and the test dataset.
            mask, dark_threshold, target_size, max_proc, max_vlms, pd, pdfs:
                forwarded unchanged to `fastMRIDataset`.
            keyword: accepted for interface compatibility; `setup` passes its
                own 'train' / 'test' keywords to the datasets.
            batch_size: batch size of every dataloader.
            use_validation_set: whether a validation dataset/dataloader is provided.
        """
        super().__init__()
        self.original_data_path = original_data_path
        self.preprocessed_data_path = preprocessed_data_path
        self.original_data_path_test = original_data_path_test
        self.preprocessed_data_path_test = preprocessed_data_path_test
        self.mask = mask
        self.dark_threshold = dark_threshold
        self.target_size = target_size
        self.max_proc = max_proc
        self.max_vlms = max_vlms
        self.pd = pd
        # Previously assigned from `pd`, which silently discarded the `pdfs` argument.
        self.pdfs = pdfs
        self.batch_size = batch_size
        self.use_validation_set = use_validation_set

    def _build_dataset(self, preprocessed_path, original_path, keyword, **mode_flags):
        # Single place where the shared configuration is handed to fastMRIDataset.
        return fastMRIDataset(preprocessed_path,
            original_path, self.mask, dark_threshold = self.dark_threshold, keyword = keyword,
            target_size = self.target_size,
            max_proc = self.max_proc,
            max_vlms = self.max_vlms,
            pd = self.pd,
            pdfs = self.pdfs,
            **mode_flags)

    def setup(self, stage=None):
        """Create the datasets needed for `stage` ("fit" or "test")."""
        if stage == "fit":
            self.fastmri_train = self._build_dataset(
                self.preprocessed_data_path, self.original_data_path, 'train', train = True)

            if self.use_validation_set:
                # The validation set reads the test paths but is built in train
                # mode, so its labels are kept.
                self.fastmri_val = self._build_dataset(
                    self.preprocessed_data_path_test, self.original_data_path_test, 'test', train = True)

        if stage == "test":
            self.fastmri_test = self._build_dataset(
                self.preprocessed_data_path_test, self.original_data_path_test, 'test', test = True)

    def train_dataloader(self):
        """Shuffled training loader; the last partial batch is dropped only when no validation set is used."""
        return DataLoader(self.fastmri_train, batch_size=self.batch_size, collate_fn=FastMRI_collate_fn,
                          shuffle=True, num_workers=2, pin_memory=True, drop_last = not self.use_validation_set)

    def val_dataloader(self):
        """Validation loader, or None when validation is disabled."""
        if self.use_validation_set:
            return DataLoader(self.fastmri_val, batch_size=self.batch_size, collate_fn=FastMRI_collate_fn,
                            num_workers=2, pin_memory=True)
        else:
            return None

    def test_dataloader(self):
        """Test loader over the dataset created by ``setup("test")``."""
        return DataLoader(self.fastmri_test, batch_size=self.batch_size, collate_fn=FastMRI_collate_fn,)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        pass

# Liver Tumor Segmentation dataset

In [None]:
from __future__ import print_function
import numpy as np
import os
import glob
import skimage.io as io
import skimage.transform as trans
import sys

In [None]:
# !mkdir ./LiverTumorSegmentation/train_images_preprocessed

In [None]:
# !mkdir ./LiverTumorSegmentation/train_masks_preprocessed

In [None]:
# def preprocess_directory(df):
#     filtered_df = []
#     images_dir = "./LiverTumorSegmentation/train_images_preprocessed"
#     masks_dir = "./LiverTumorSegmentation/train_masks_preprocessed"
#     threshold_values = [0.0, 128.0, 255.0]
#     last_i = 0

#     for i in range(len(df)):
#         img_name = df.iloc[i, 0]
#         mask_name = df.iloc[i, 1]
#         image = io.imread(img_name)
#         mask = io.imread(mask_name)
#         mask_orig = mask.copy()
#         new_mask = torch.zeros(*mask.shape)
#         for idx, t in enumerate(threshold_values):
#             new_mask[(mask >= t - 15) & (mask <= t + 15)] = idx
#         if not torch.all(new_mask == 0):
#             filtered_df.append(df.iloc[i])

#             image_path = os.path.join(images_dir, os.path.basename(img_name))
#             io.imsave(image_path, image)

#             mask_path = os.path.join(masks_dir, os.path.basename(mask_name))
#             io.imsave(mask_path, mask_orig)
#         last_i = i
#         print(f"last index: {last_i}")

#     return filtered_df


# df_filtered = preprocess_directory(df)

In [None]:
# file_list = []
# mask_list = []
# for dirname, _, filenames in os.walk('./LiverTumorSegmentation/train_images_preprocessed'):
#     for filename in filenames:
#         file_list.append(filename)

# for dirname, _, filenames in os.walk('./LiverTumorSegmentation/train_masks_preprocessed'):
#     for filename in filenames:
#         mask_list.append(filename)

# files = pd.DataFrame({"train": file_list, "mask": mask_list})

# filename = []
# for i in range(len(files)):
#     root = "./LiverTumorSegmentation/train_images_preprocessed"
#     path = os.path.join(root, files["train"][i])
#     filename.append(path)

# mask = []
# for i in range(len(files)):
#     root = "./LiverTumorSegmentation/train_masks_preprocessed"
#     path = os.path.join(root, files["train"][i])
#     mask.append(path)

# df = pd.DataFrame(data={"filename": filename, 'mask' : mask})
# df['mask'] = df['mask'].str.split(".").str[0] + "_mask.jpg"

In [None]:
# Size handed to `transforms.Resize` by the liver-tumor dataset and data module defined below.
img_size = (256, 256)

class LiverTumorSegmentationDataset(Dataset):
    """Image / segmentation-mask pairs listed in a dataframe.

    The first dataframe column holds the image path, the second the mask path.
    Mask grey levels are mapped to class indices: 0 (background), 128 (liver)
    and 255 (tumor) become 0, 1 and 2.
    """

    def __init__(self, dataframe, img_size):
        self.dataframe = dataframe
        # Nearest-neighbour resizing, applied to both image and mask.
        self.transform = transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST_EXACT)
        self.threshold_values = [0.0, 128.0, 255.0]  # 0 is background, 128 is liver and 255 is tumor. Mapped to 0, 1, 2 respectively for class indices

    def threshold_function(self, mask):
        """Map grey levels to class indices with a tolerance of 15 around each reference value.

        Values outside every tolerance band keep the initial value 0.
        """
        class_map = torch.zeros(*mask.shape)
        for class_idx, reference in enumerate(self.threshold_values):
            lower, upper = reference - 15, reference + 15
            class_map[(mask >= lower) & (mask <= upper)] = class_idx
        return class_map

    def __len__(self):
        """Number of rows in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a float image tensor and a long class-index mask."""
        assert idx < len(self)

        row = self.dataframe.iloc[idx]
        image_path = row.iloc[0]  # first column: image filename
        mask_path = row.iloc[1]  # second column: mask filename
        raw_image = io.imread(image_path)
        raw_mask = io.imread(mask_path)

        # Channels-last -> channels-first, scaled to [0, 1].
        scaled = np.transpose(raw_image, (2, 0, 1)).astype(np.float32) / 255.0
        image = torch.tensor(scaled, dtype=torch.float32)

        # Only the last channel of the mask image is used.
        mask = self.threshold_function(raw_mask[..., -1])

        # Resize expects a leading axis, which is added and removed around the call.
        image = self.transform(image.unsqueeze(0)).squeeze(0)
        mask = self.transform(mask.unsqueeze(0)).squeeze(0)
        return image, mask.type(torch.long)

In [None]:
class LiverTumorSegmentationDataModule(pl.LightningDataModule):
    """Lightning data module serving `LiverTumorSegmentationDataset` splits built from dataframes."""

    def __init__(self, train_dataframe, validation_dataframe = None, test_dataframe = None,
        batch_size: int = 32,
        use_validation_set: bool = True
    ):
        """Store the dataframes and loader settings; datasets are created in `setup`.

        Args:
            train_dataframe: dataframe of the training split.
            validation_dataframe: optional dataframe of the validation split.
            test_dataframe: optional dataframe of the test split.
            batch_size: batch size of every dataloader.
            use_validation_set: whether validation should be run.
        """
        super().__init__()
        self.train_dataframe = train_dataframe
        self.validation_dataframe = validation_dataframe
        self.test_dataframe = test_dataframe
        self.batch_size = batch_size
        self.use_validation_set = use_validation_set

    def _has_validation_data(self):
        # Validation needs both the switch and an actual dataframe.
        return self.use_validation_set and self.validation_dataframe is not None

    def setup(self, stage=None):
        """Create the datasets needed for `stage` ("fit" or "test"), resized to the module-level `img_size`."""
        if stage == "fit":
            self.livertumorseg_train = LiverTumorSegmentationDataset(dataframe = self.train_dataframe, img_size = img_size)

            if self._has_validation_data():
                self.livertumorseg_val = LiverTumorSegmentationDataset(dataframe = self.validation_dataframe, img_size = img_size)

        if stage == "test":
            if self.test_dataframe is not None:
                self.livertumorseg_test = LiverTumorSegmentationDataset(dataframe = self.test_dataframe, img_size = img_size)

    def train_dataloader(self):
        """Shuffled training loader; the last partial batch is dropped only when no validation set is used."""
        return DataLoader(self.livertumorseg_train, batch_size=self.batch_size,
                          shuffle=True, pin_memory=True, drop_last = not self.use_validation_set)

    def val_dataloader(self):
        """Validation loader, or None when there is no validation data.

        Uses the same condition as `setup`, so enabling validation without
        passing a validation dataframe no longer fails with AttributeError.
        """
        if self._has_validation_data():
            return DataLoader(self.livertumorseg_val, batch_size=self.batch_size, pin_memory=True)
        else:
            return None

    def test_dataloader(self):
        """Test loader over the dataset created by ``setup("test")``."""
        return DataLoader(self.livertumorseg_test, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        pass

# Swin Transformer modules

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal `patch_size`
    projects every patch to `embedding_dim` channels; the spatial grid is
    then flattened into a sequence and layer-normalised.
    """

    def __init__(self, img_size, patch_size, input_channels, embedding_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        self.input_channels = input_channels
        self.embedding_dim = embedding_dim

        patch_conv = nn.Conv2d(input_channels, embedding_dim, kernel_size = self.patch_size, stride = self.patch_size)
        flatten_grid = Rearrange('b d h w -> b (h w) d')
        self.projection = nn.Sequential(
            patch_conv,
            flatten_grid,
            nn.LayerNorm(embedding_dim)
        )

    def forward(self, img):
        """Embed `img` of shape (batch, input_channels, img_size[0], img_size[1]).

        Returns a tensor of shape (batch, number of patches, embedding_dim).
        """
        _, channels, height, width = img.shape
        expected = (self.input_channels, self.img_size[0], self.img_size[1])
        assert (channels, height, width) == expected
        return self.projection(img)

In [None]:
def window_partition(x, window_size):
    """Split a (b, H, W, d) patch grid into square windows.

    Returns a tensor of shape (b * H/ws * W/ws, ws, ws, d) with ws = `window_size`.
    """
    pattern = "b (h ws_0) (w ws_1) d -> (b h w) ws_0 ws_1 d"
    return rearrange(x, pattern, ws_0 = window_size, ws_1 = window_size)

def window_merge(windows, window_size, height, width):
    """Reassemble the (b, height, width, d) patch grid from its windows (inverse of `window_partition`)."""
    windows_per_col = height // window_size
    windows_per_row = width // window_size
    pattern = "(b h w) ws_0 ws_1 d -> b (h ws_0) (w ws_1) d"
    return rearrange(windows, pattern, h = windows_per_col, w = windows_per_row,
                     ws_0 = window_size, ws_1 = window_size)

In [None]:
def scaled_dot_product_attention(query, key, value, scale, bias = 0, mask = None, dropout_layer = None):
    """Attention over (batch, heads, tokens, head_dim) tensors with optional bias and mask.

    Args:
        query, key, value: tensors of shape (b, h, q, d), (b, h, k, d), (b, h, k, d_v).
        scale: divisor applied to the raw query-key scores.
        bias: value or tensor added to the scaled scores (broadcast).
        mask: optional tensor broadcastable to (b / n_w, n_w, h, q, k), where
            n_w = mask.size(1). Positions where the mask is False/zero are
            excluded from the softmax.
        dropout_layer: optional module applied to the attention weights.

    Returns:
        Tensor of shape (b, h, q, d_v).
    """
    scores = (torch.einsum('b h q d, b h k d -> b h q k', query, key) / scale) + bias
    if mask is not None:
        b, h, q, _ = query.shape
        k = key.size(2)
        n_w = mask.size(1)
        # Each window has its own mask, so the window axis is split off from
        # the batch axis to let the mask broadcast over the remaining batch.
        scores = scores.view(b // n_w, n_w, h, q, k)
        # Fill with the most negative value of the score dtype; a fixed -1e10
        # cannot be represented in float16.
        scores = scores.masked_fill(mask.logical_not(), torch.finfo(scores.dtype).min)
        scores = scores.view(-1, h, q, k)
    weights = F.softmax(scores, dim = -1)
    if dropout_layer is not None:
        weights = dropout_layer(weights)
    return torch.matmul(weights, value)

In [None]:
class WindowMultiHeadAttention(nn.Module):
    """Multi-head self-attention inside a window, with a learned relative position bias.

    Args:
        embedding_dim: token dimension; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        window_size: side length M of the square window (M * M tokens).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embedding_dim, num_heads, window_size, dropout=0.2):
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.dim_head = embedding_dim // num_heads # Single head dimension
        self.sqrt_q = math.sqrt(self.dim_head)
        self.num_heads = num_heads
        self.window_size = window_size

        self.W_qkv = nn.Linear(embedding_dim, 3*embedding_dim, bias=True)
        self.W_o = nn.Linear(embedding_dim, embedding_dim, bias=True)

        self.dropout = nn.Dropout(dropout)

        # One learnable bias per head and per relative offset. Along each axis
        # the offset lies in {-M+1, ..., M-1}, i.e. 2M-1 values (M = 3 gives
        # {-2, -1, 0, 1, 2}), hence (2M-1)^2 rows. The index buffer below is
        # quadratic in the window area, but the table itself stays small.
        self.relative_position_table = nn.Parameter(torch.zeros(((2 * window_size - 1) * (2 * window_size - 1), num_heads)))

        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, W, W
        coords_flatten = torch.flatten(coords, 1)  # 2, W * W
        relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)  # 2, W * W, W * W
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # W * W, W * W, 2
        relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size - 1
        # Row-major flattening of the (2M-1) x (2M-1) offset grid: the row
        # offset must be scaled by the number of columns, 2M-1. Scaling by
        # M-1 made different offsets share the same table entry.
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1).view(-1)  # W * W, W * W and then flatten
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask = None):
        """Apply windowed attention to `x` of shape (batch, M * M, embedding_dim).

        `mask` is forwarded unchanged to `scaled_dot_product_attention`.
        Returns a tensor with the same shape as `x`.
        """

        # Project once, then split into query / key / value per head.
        qkv = rearrange(self.W_qkv(x), "b n (qkv n_h d_h) -> qkv b n_h n d_h", qkv = 3, n_h = self.num_heads, d_h = self.dim_head)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Gather the bias of every (query token, key token) pair and reshape it
        # to (1, heads, M * M, M * M) so it broadcasts over the batch.
        relative_position_bias = self.relative_position_table[self.relative_position_index]
        relative_position_bias = rearrange(relative_position_bias,
                                           "(ws_0 ws_1 ws_2 ws_3) n_h -> 1 n_h (ws_0 ws_1) (ws_2 ws_3)",
                                           ws_0 = self.window_size, ws_1 = self.window_size,
                                           ws_2 = self.window_size, ws_3 = self.window_size)

        attention_value = scaled_dot_product_attention(q, k, v, scale = self.sqrt_q,
                                                       bias = relative_position_bias,
                                                       mask = mask, dropout_layer = self.dropout)

        # Concatenate the heads back into the embedding dimension.
        attention_value = rearrange(attention_value, "b h n d_h -> b n (h d_h)", d_h = self.dim_head)
        return self.W_o(attention_value)

In [None]:
class SwinTransformerBlock(nn.Module):
    """Pre-norm Swin Transformer block: window attention and an MLP, each inside a residual connection.

    When ``shift_size > 0`` the token grid is cyclically rolled before it is split
    into windows, and a boolean mask is passed to the attention layer. The mask is
    True for pairs of positions that come from different regions of the grid.

    Args:
        embedding_dim: channel size of the tokens.
        hidden_size: width of the hidden layer of the MLP.
        patch_resolution: (height, width) of the token grid; ``forward`` requires
            exactly ``height * width`` tokens.
        num_heads: number of heads given to ``WindowMultiHeadAttention``.
        window_size: side of the square attention windows.
        shift_size: cyclic shift applied before windowing; 0 disables shifting and the mask.
        dropout: probability shared by the attention layer, the MLP and both residual branches.
    """
    def __init__(self, embedding_dim, hidden_size, patch_resolution, num_heads, window_size, shift_size=0, dropout=0.2):
        super(SwinTransformerBlock, self).__init__()
        self.height, self.width = patch_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm_layer1 = nn.LayerNorm(embedding_dim)
        self.attention = WindowMultiHeadAttention(embedding_dim, num_heads, window_size, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm_layer2 = nn.LayerNorm(embedding_dim)
        self.ff = nn.Sequential(
            nn.Linear(embedding_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embedding_dim)
        )
        self.dropout2 = nn.Dropout(dropout)

        # Registering None keeps the attribute defined for the unshifted block.
        attention_mask = self._build_shift_mask() if self.shift_size > 0 else None
        self.register_buffer("attention_mask", attention_mask)

    def _build_shift_mask(self):
        """Return the [1, num_windows, 1, patches_per_window, patches_per_window] mask for shifted windows."""
        img_mask = torch.zeros((self.height, self.width))
        # Three bands per axis: untouched area, last window minus the shift, and the shifted strip.
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        region_id = 0
        # Only valid when the roll in forward() uses (-self.shift_size, -self.shift_size).
        for h in bands:
            for w in bands:
                img_mask[h, w] = region_id
                region_id += 1

        mask_windows = rearrange(img_mask,
                                 "(h_w ws_0) (w_w ws_1) -> (h_w w_w) (ws_0 ws_1)",
                                 ws_0 = self.window_size, ws_1 = self.window_size)
        # True where the two positions of a window carry different region ids.
        attention_mask = (mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)) != 0
        # The two singleton axes broadcast over the batch size and the number of heads.
        return attention_mask.unsqueeze(1).unsqueeze(0)

    def forward(self, x):
        """Apply the block to ``x`` laid out as ``b (h w) d`` and return a tensor with the same layout."""
        n = x.size(1)
        assert n == self.height*self.width, (
            f"expected {self.height*self.width} tokens (height = {self.height}, width = {self.width}), "
            f"got x.shape = {tuple(x.shape)}"
        )

        z = x
        x = rearrange(self.norm_layer1(x), "b (h w) d -> b h w d", h = self.height)

        if self.shift_size > 0: # shifted-window attention
            x = torch.roll(x, shifts = (-self.shift_size, -self.shift_size), dims = (1, 2))

        windows = window_partition(x, self.window_size)
        windows = rearrange(windows, "bhw ws_0 ws_1 d -> bhw (ws_0 ws_1) d", ws_0 = self.window_size)

        attention = rearrange(self.attention(windows, self.attention_mask),
                              "bhw (ws_0 ws_1) d -> bhw ws_0 ws_1 d", ws_0 = self.window_size)

        x = window_merge(attention, self.window_size, self.height, self.width)

        if self.shift_size > 0: # undo the shift
            x = torch.roll(x, shifts = (self.shift_size, self.shift_size), dims = (1, 2))

        x = rearrange(x, "b h w d -> b (h w) d")
        x = z + self.dropout1(x)

        return x + self.dropout2(self.ff(self.norm_layer2(x)))

In [None]:
class PatchMerge(nn.Module):
    """Downsampling step: concatenate every 2x2 neighbourhood of tokens and project 4*d channels to 2*d.

    Args:
        patch_resolution: (height, width) of the incoming token grid; both must be even.
        embedding_dim: channel size d of the incoming tokens.
        use_norm_layer: apply a LayerNorm over the 4*d concatenated channels before the projection.
    """
    def __init__(self, patch_resolution, embedding_dim, use_norm_layer = False):
        super().__init__()
        self.height, self.width = patch_resolution
        assert self.height % 2 == 0 and self.width % 2 == 0

        merged_dim = 4*embedding_dim
        pre_norm = nn.LayerNorm(merged_dim) if use_norm_layer else nn.Identity()
        self.projection = nn.Sequential(pre_norm, nn.Linear(merged_dim, 2*embedding_dim))

    def forward(self, x):
        """Map ``b (h w) d`` tokens to ``b (h/2 w/2) 2d`` tokens."""
        assert x.size(1) == self.height*self.width

        grid = rearrange(x, "b (h w) d -> b h w d", h = self.height)
        # Channel order: (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col).
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, -1)  # b h/2 w/2 4*d
        tokens = rearrange(merged, "b half_h half_w d4 -> b (half_h half_w) d4")

        return self.projection(tokens)


class PatchExpand(nn.Module):
    """Upsampling step: a bias-free linear layer followed by a channel-to-space rearrangement.

    Regular mode doubles both grid sides and halves the channels (d -> 2d -> 2x2 patches of d/2).
    Final mode multiplies both grid sides by 4 and keeps the channels (d -> 16d -> 4x4 patches of d).

    Args:
        patch_resolution: (height, width) of the incoming token grid.
        embedding_dim: channel size d of the incoming tokens.
        use_norm_layer: apply a LayerNorm over the output channels.
        is_final: select the 4x expansion described above.
    """
    def __init__(self, patch_resolution, embedding_dim, use_norm_layer = False, is_final = False):
        super().__init__()
        self.height, self.width = patch_resolution

        if is_final:
            scale, expanded_dim, norm_dim = 4, 16*embedding_dim, embedding_dim
        else:
            scale, expanded_dim, norm_dim = 2, 2*embedding_dim, embedding_dim // 2

        channels_out = expanded_dim // (scale**2)
        stages = [
            nn.Linear(embedding_dim, expanded_dim, bias=False),
            Rearrange("b (h w) d -> b h w d", h = self.height),
            # Spread each token's channels over a scale x scale block of new grid positions.
            Rearrange("b h w (ps_0 ps_1 c) -> b (h ps_0) (w ps_1) c", h = self.height, ps_0 = scale, ps_1 = scale, c = channels_out),
            Rearrange("b new_h new_w c -> b (new_h new_w) c"),
            nn.LayerNorm(norm_dim) if use_norm_layer else nn.Identity(),
        ]
        self.expand_and_rearrange = nn.Sequential(*stages)

    def forward(self, x):
        """Expand ``b (h w) d`` tokens; the token count must equal height * width."""
        assert x.size(1) == self.height*self.width

        return self.expand_and_rearrange(x)

# Conditional modules

In [None]:
def create_batch_range(t, right_pad_dims = 1):
    """Return ``arange(t.shape[0])`` on ``t``'s device, shaped ``(b, 1, ..., 1)``.

    Args:
        t: tensor whose first dimension is the batch size.
        right_pad_dims: number of trailing singleton dimensions to append, so the
            result broadcasts against an index tensor with that many extra axes.
    """
    broadcast_shape = (-1,) + (1,) * right_pad_dims
    return torch.arange(t.shape[0], device = t.device).reshape(broadcast_shape)

def batched_gather(x, indices):
    """Per-sample gather along dim 1: ``out[b, ...] = x[b, indices[b, ...]]``."""
    # The batch index needs one singleton axis per non-batch axis of `indices` to broadcast against it.
    sample_ids = create_batch_range(indices, right_pad_dims = indices.ndim - 1)
    return x[sample_ids, indices]

def route_back(x, routed_tokens, indices):
    """Write ``routed_tokens[b, j]`` into ``x[b, indices[b, j]]`` in place and return ``x``."""
    sample_ids = create_batch_range(routed_tokens)
    x[sample_ids, indices] = routed_tokens
    return x

In [None]:
class LightRouter(nn.Module):
    """Scores every token against one learned vector and selects the k highest-scoring tokens.

    Args:
        embedding_dim: channel size of the tokens.
        k: number of tokens to select; capped at the number of tokens at call time.
    """
    def __init__(self, embedding_dim, k):
        super().__init__()
        self.k = k
        self.routing_embeddings = nn.Parameter(torch.randn(embedding_dim))

    def forward(self, x):
        """Route ``x`` laid out as ``b n d``.

        Returns:
            (values, indices, scores): the top-k softmax scores ``b k 1``, their token
            indices ``b k`` and the softmax scores of all tokens ``b n 1``.
        """
        logits = torch.einsum("b n d, d -> b n", x, self.routing_embeddings)
        token_scores = logits.softmax(dim = -1)

        n_selected = min(self.k, x.size(1))
        top_values, top_indices = token_scores.topk(n_selected, dim = -1)
        # The trailing singleton axis broadcasts over the embedding dimension.
        return top_values.unsqueeze(-1), top_indices, token_scores.unsqueeze(-1)

In [None]:
class IterativeSoftTopkRouter(nn.Module):
    """Router that relaxes top-k selection with an iterative (coordinate-descent style) soft top-k.

    Tokens are scored against one learned vector. Two dual variables ``a`` and ``b`` are
    then updated alternately for ``iters`` steps, and ``exp((scores + a + b) / eps)`` is
    used as the soft selection weight of every token.

    Args:
        embedding_dim: channel size of the tokens.
        k: number of tokens to select; capped at the number of tokens at call time.
        eps: final (and minimum) temperature of the relaxation.
        eps_init: optional starting temperature; the larger of ``eps_init`` and ``eps`` is used.
        eps_decay: multiplicative decay applied to the temperature after every step.
        iters: number of update steps.
    """
    def __init__(self, embedding_dim, k, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20):
        super(IterativeSoftTopkRouter, self).__init__()
        self.k = k
        # The a-update below normalises the weights so that they sum to exp(self.logk);
        # this must be log(k) for the total selected mass to be k.
        self.logk = math.log(k)
        self.routing_embeddings = nn.Parameter(torch.randn(embedding_dim))
        self.eps = eps
        self.eps_init = eps_init
        self.eps_decay = eps_decay
        self.iters = iters

    def forward(self, x):
        """Route ``x`` laid out as ``b n d``.

        Returns:
            (values, indices, scores): the k largest soft weights ``b k 1``, their token
            indices ``b k`` and the soft weights of all tokens ``b n 1``.
        """
        scores = torch.einsum("b n d, d -> b n", x, self.routing_embeddings)
        a = 0
        b = -scores

        current_eps = max(self.eps_init if self.eps_init is not None else self.eps, self.eps)

        for _ in range(self.iters):
            sb = ((scores + b) / current_eps)

            # a: makes exp((scores + a + b) / eps) sum to exp(logk) over the tokens for the current b.
            a = current_eps * (self.logk - sb.logsumexp(dim = -1, keepdim = True))
            # b: caps scores + a + b at 0, i.e. every weight at 1.
            b = -F.relu(scores + a)

            current_eps = max(current_eps * self.eps_decay, self.eps)

        normalized_scores = ((scores + a + b) / current_eps).exp()

        k = min(self.k, x.size(1))
        values, indices = torch.topk(normalized_scores, dim = -1, k = k)
        return values.unsqueeze(-1), indices, normalized_scores.unsqueeze(-1) # unsqueeze to broadcast over the embedding dimension

In [None]:
class ConditionalFeedForward(nn.Module):
    """Feed-forward layer with a light branch for all tokens and a heavy branch for routed tokens.

    Every token goes through ``light_ff`` (hidden width = input_dim). The tokens picked
    by the router also go through ``heavy_ff`` (hidden width = 4 * input_dim); their
    output is scaled by the router score and added at the token's original position.

    Args:
        input_dim: channel size of the tokens.
        k: number of tokens routed to the heavy branch.
        dropout: dropout probability inside both branches.
        use_iterative_algorithm: use ``IterativeSoftTopkRouter`` instead of ``LightRouter``.
        eps, eps_init, eps_decay, iters: forwarded to ``IterativeSoftTopkRouter``.
    """
    def __init__(self, input_dim, k, dropout = 0.2, use_iterative_algorithm = False, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20):
        super(ConditionalFeedForward, self).__init__()

        self.k = k

        if not use_iterative_algorithm:
            self.router = LightRouter(input_dim, k)
        else:
            self.router= IterativeSoftTopkRouter(input_dim, k, eps, eps_init, eps_decay, iters)

        self.light_ff = nn.Sequential(
            nn.Linear(input_dim, input_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(input_dim, input_dim, bias=True)
        )
        self.heavy_ff = nn.Sequential(
            nn.Linear(input_dim, 4*input_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4*input_dim, input_dim, bias=True)
        )

    def forward(self, x):
        """Return ``light_ff(x)`` plus the scattered, score-weighted ``heavy_ff`` output; ``x`` is ``b n d``."""
        light_output = self.light_ff(x)

        values, indices, _ = self.router(x)
        x_topk = batched_gather(x, indices)

        heavy_output_topk = values * self.heavy_ff(x_topk)
        # Allocate the scatter target with the dtype of the routed tokens: indexed assignment
        # requires matching dtypes, so a default-dtype buffer fails for non-float32 inputs.
        heavy_output = route_back(heavy_output_topk.new_zeros(x.size()), heavy_output_topk, indices)

        return light_output + heavy_output

In [None]:
class ConditionalWindowAttention(nn.Module):
    """Window attention with a light branch for all tokens and a routed heavy branch.

    The light branch is a ``WindowMultiHeadAttention`` with ``num_heads // 2`` heads.
    In the heavy branch every token produces a query, but keys and values are computed
    only from the tokens selected by ``router_kv`` (scaled by their router score). The
    heavy output is weighted per query by the scores of ``router_q`` and added to the
    light output.

    Args:
        embedding_dim: channel size of the tokens; must be divisible by ``num_heads``.
        num_heads: number of heads of the heavy branch.
        window_size: side of the square attention windows.
        k: number of key/value tokens routed to the heavy branch.
        dropout: dropout probability passed to both attention branches.
        use_iterative_algorithm: use ``IterativeSoftTopkRouter`` instead of ``LightRouter``.
        eps, eps_init, eps_decay, iters: forwarded to ``IterativeSoftTopkRouter``.
    """
    def __init__(self, embedding_dim, num_heads, window_size, k, dropout=0.2, use_iterative_algorithm = False, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20):
        super(ConditionalWindowAttention, self).__init__()
        assert embedding_dim % num_heads == 0
        self.k = k

        if not use_iterative_algorithm:
            self.router_q = LightRouter(embedding_dim, k)
            self.router_kv = LightRouter(embedding_dim, k)
        else:
            self.router_q = IterativeSoftTopkRouter(embedding_dim, k, eps, eps_init, eps_decay, iters)
            self.router_kv = IterativeSoftTopkRouter(embedding_dim, k, eps, eps_init, eps_decay, iters)

        self.light_attention = WindowMultiHeadAttention(embedding_dim, num_heads // 2, window_size, dropout)

        self.heavy_W_q = nn.Linear(embedding_dim, embedding_dim)
        self.heavy_W_kv = nn.Linear(embedding_dim, 2*embedding_dim)
        self.num_heads = num_heads
        self.dim_head = embedding_dim // num_heads
        self.sqrt_q = math.sqrt(self.dim_head)
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

        # (2 * M - 1) relative positions exist on each axis: |{-M+1, ..., M-1}| = 2M - 1
        # (with M = 3: |{-2, -1, 0, 1, 2}| = 5). The index below is quadratic in the window
        # area, but the table itself only needs (2M - 1)^2 learnable rows per head.
        self.relative_position_table = nn.Parameter(torch.zeros(((2 * window_size - 1) * (2 * window_size - 1), num_heads)))

        # register_buffer makes this persistent, so a state_dict saved with an older index restores that index.
        self.register_buffer("relative_position_index", self._relative_position_index(window_size))

    @staticmethod
    def _relative_position_index(window_size):
        """Return the flattened (W*W * W*W,) table row for every (query, key) position pair of a window."""
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing = "ij"))  # 2, W, W
        coords_flatten = torch.flatten(coords, 1)  # 2, W * W
        relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)  # 2, W * W, W * W
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # W * W, W * W, 2
        relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size - 1
        # Row-major flattening of a (2W-1) x (2W-1) grid: the row stride must be 2W-1,
        # otherwise different (dh, dw) offsets collapse onto the same table row.
        relative_coords[:, :, 0] *= 2 * window_size - 1
        return relative_coords.sum(-1).view(-1)

    def forward(self, x, mask=None):
        """Apply both branches to ``x`` laid out as ``b n d`` with ``n = window_size ** 2``.

        Args:
            x: window tokens; the batch axis is rearranged below as ``(b_o n_w)`` when a mask is given.
            mask: optional tensor laid out as ``1 n_w 1 n n`` (see the repeat pattern below);
                its last axis is restricted to the routed key positions.
        """
        light_output = self.light_attention(x, mask)

        _, _, normalized_scores_q = self.router_q(x)
        values_kv, indices_kv, _ = self.router_kv(x)
        x_topk_kv = values_kv * batched_gather(x, indices_kv)

        # light(x_i, x) + normalized_scores_q * heavy(x_i, values_kv * x), with values_kv treated as zero for non-routed tokens
        q = rearrange(self.heavy_W_q(x), "b n (n_h d_h) -> b n_h n d_h", n_h = self.num_heads, d_h = self.dim_head)
        kv = rearrange(self.heavy_W_kv(x_topk_kv), "b topk (kv n_h d_h) -> kv b n_h topk d_h", n_h = self.num_heads, d_h = self.dim_head)
        k, v = kv[0], kv[1]

        relative_position_bias = self.relative_position_table[self.relative_position_index]
        relative_position_bias = rearrange(relative_position_bias,
                                           "(ws_0 ws_1 ws_2 ws_3) n_h -> 1 n_h (ws_0 ws_1) (ws_2 ws_3)",
                                           ws_0 = self.window_size, ws_1 = self.window_size,
                                           ws_2 = self.window_size, ws_3 = self.window_size)

        # Keep only the bias columns of the routed keys: b n_h n n -> b n_h n topk.
        relative_position_bias = repeat(relative_position_bias, "1 ... -> b ...", b = x.size(0))
        indices_kv = repeat(indices_kv, "b topk -> b n_h n topk", n_h = self.num_heads, n = x.size(1))
        relative_position_bias_topk = torch.gather(relative_position_bias, dim = -1, index = indices_kv)


        if mask is None:
            mask_topk = None
        else:
            # Same column selection for the mask, after splitting the batch axis into (b_o, n_w).
            indices_kv = rearrange(indices_kv, "(b_o n_w) ... -> b_o n_w ...", b_o = x.size(0) // mask.size(1)) # b_o = b // num_windows
            mask = repeat(mask, "1 n_w 1 ... -> b_o n_w n_h ...", b_o = x.size(0) // mask.size(1), n_h = self.num_heads)
            mask_topk = torch.gather(mask, dim = -1, index = indices_kv)

        heavy_output_topk = scaled_dot_product_attention(q, k, v,
                                                         scale = self.sqrt_q, bias = relative_position_bias_topk,
                                                         mask = mask_topk, dropout_layer = self.dropout)

        heavy_output = normalized_scores_q * rearrange(heavy_output_topk, "b h n d_h -> b n (h d_h)", d_h = self.dim_head)
        return light_output + heavy_output

In [None]:
class ConditionalSwinTransformerBlock(nn.Module):
    """Swin block whose attention and feed-forward layers are the conditional (routed) variants.

    Layout is pre-norm with two residual branches: ``ConditionalWindowAttention``
    followed by ``ConditionalFeedForward``. With ``shift_size > 0`` the token grid is
    cyclically rolled before windowing and a boolean mask (True = positions from
    different regions of the grid) is passed to the attention layer.

    Args:
        embedding_dim: channel size of the tokens.
        hidden_size: accepted for signature parity; not read by this block.
        patch_resolution: (height, width) of the token grid.
        num_heads: number of heads given to the conditional attention.
        window_size: side of the square attention windows.
        k: number of tokens routed to the heavy branches.
        shift_size: cyclic shift applied before windowing; 0 disables shifting and the mask.
        dropout: probability shared by attention, feed-forward and both residual branches.
        use_iterative_algorithm, eps, eps_init, eps_decay, iters: router options forwarded to both sub-layers.
    """
    def __init__(self, embedding_dim, hidden_size, patch_resolution, num_heads, window_size, k, shift_size=0, dropout=0.2, use_iterative_algorithm = False, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20):
        super().__init__()
        self.height, self.width = patch_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.k = k

        router_args = (use_iterative_algorithm, eps, eps_init, eps_decay, iters)

        self.norm_layer1 = nn.LayerNorm(embedding_dim)
        self.attention = ConditionalWindowAttention(embedding_dim, num_heads, window_size, k, dropout, *router_args)
        self.dropout1 = nn.Dropout(dropout)
        self.norm_layer2 = nn.LayerNorm(embedding_dim)
        self.ff = ConditionalFeedForward(embedding_dim, k, dropout, *router_args)
        self.dropout2 = nn.Dropout(dropout)

        attention_mask = None
        if self.shift_size > 0:
            img_mask = torch.zeros((self.height, self.width))
            # Three bands per axis: untouched area, last window minus the shift, and the shifted strip.
            bands = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
                     slice(-self.shift_size, None))
            # Only valid when the roll in forward() uses (-self.shift_size, -self.shift_size).
            regions = [(rows, cols) for rows in bands for cols in bands]
            for region_id, (rows, cols) in enumerate(regions):
                img_mask[rows, cols] = region_id

            mask_windows = rearrange(img_mask,
                                     "(h_w ws_0) (w_w ws_1) -> (h_w w_w) (ws_0 ws_1)",
                                     ws_0 = self.window_size, ws_1 = self.window_size)
            # [num_windows, patches_per_window, patches_per_window]; True where region ids differ.
            different_region = (mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)) != 0
            # [1, num_windows, 1, patches_per_window, patches_per_window]:
            # the singleton axes broadcast over the batch size and the number of heads.
            attention_mask = different_region.unsqueeze(1).unsqueeze(0)

        self.register_buffer("attention_mask", attention_mask)

    def forward(self, x):
        """Apply the block to ``x`` laid out as ``b (h w) d`` and return a tensor with the same layout."""
        assert x.size(1) == self.height*self.width
        shifted = self.shift_size > 0

        grid = rearrange(self.norm_layer1(x), "b (h w) d -> b h w d", h = self.height)
        if shifted: # shifted-window attention
            grid = torch.roll(grid, shifts = (-self.shift_size, -self.shift_size), dims = (1, 2))

        windows = rearrange(window_partition(grid, self.window_size),
                            "bhw ws_0 ws_1 d -> bhw (ws_0 ws_1) d", ws_0 = self.window_size)
        attended = self.attention(windows, self.attention_mask)
        attended = rearrange(attended, "bhw (ws_0 ws_1) d -> bhw ws_0 ws_1 d", ws_0 = self.window_size)

        grid = window_merge(attended, self.window_size, self.height, self.width)
        if shifted: # undo the shift
            grid = torch.roll(grid, shifts = (self.shift_size, self.shift_size), dims = (1, 2))

        after_attention = x + self.dropout1(rearrange(grid, "b h w d -> b (h w) d"))
        return after_attention + self.dropout2(self.ff(self.norm_layer2(after_attention)))

# Swin Transformer U-Nets



In [None]:
class BasicLayer(nn.Module):
    """One U-Net stage: ``depth`` Swin blocks followed by an optional patch merge or patch expand.

    Blocks alternate between unshifted windows (even positions) and windows shifted by
    ``window_size // 2`` (odd positions).

    Args:
        embedding_dim: channel size of the tokens entering the stage.
        patch_resolution: (height, width) of the token grid entering the stage.
        depth: number of blocks.
        num_heads: attention heads per block.
        window_size: side of the square attention windows.
        hidden_size: hidden width passed to the blocks.
        dropout: dropout probability passed to the blocks.
        downsample: append a ``PatchMerge`` (mutually exclusive with ``upsample``).
        upsample: append a ``PatchExpand`` (mutually exclusive with ``downsample``).
        use_conditional_blocks: build ``ConditionalSwinTransformerBlock`` instead of ``SwinTransformerBlock``.
        use_iterative_algorithm, eps, eps_init, eps_decay, iters: router options for the conditional blocks.
        k: number of tokens routed to the heavy branches of the conditional blocks
            (defaults to 7, the value that used to be hard-coded).
    """
    def __init__(self, embedding_dim, patch_resolution, depth, num_heads, window_size,
                 hidden_size, dropout=0.2, downsample=False, upsample=False, use_conditional_blocks = False, use_iterative_algorithm = False, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20, k = 7):
        super(BasicLayer, self).__init__()
        assert not (downsample and upsample), "Only one of downsample and upsample can be True"

        if not use_conditional_blocks:
            blocks = [
                SwinTransformerBlock(embedding_dim, hidden_size,
                                    patch_resolution, num_heads, window_size,
                                    shift_size = 0 if i % 2 == 0 else window_size // 2, dropout = dropout)
                for i in range(depth)
            ]
        else:
            blocks = [
                ConditionalSwinTransformerBlock(embedding_dim, hidden_size,
                                    patch_resolution, num_heads, window_size, k,
                                    shift_size = 0 if i % 2 == 0 else window_size // 2, dropout = dropout, use_iterative_algorithm = use_iterative_algorithm,
                                    eps = eps, eps_init = eps_init, eps_decay = eps_decay, iters = iters)
                for i in range(depth)
            ]

        self.blocks = nn.Sequential(*blocks)

        if downsample:
            self.patch_rearrange = PatchMerge(patch_resolution, embedding_dim, use_norm_layer = True)
        elif upsample:
            self.patch_rearrange = PatchExpand(patch_resolution, embedding_dim, use_norm_layer = True)
        else:
            self.patch_rearrange = nn.Identity()

    def forward(self, x):
        """Run the blocks, then the resampling step (identity when neither flag was set)."""
        x = self.blocks(x)
        return self.patch_rearrange(x)


In [None]:
class NMSE(nn.Module):
    """Normalised mean squared error: ``||y - x||^2 / ||y||^2`` over all elements of the inputs."""
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return the error of prediction ``x`` against reference ``y`` as a scalar tensor."""
        # With no ord/dim, torch.linalg.norm flattens its input and returns the 2-norm.
        residual_energy = torch.linalg.norm(y - x).pow(2)
        reference_energy = torch.linalg.norm(y).pow(2)
        return residual_energy / reference_energy

In [None]:
class SwinTransformer_Unet(pl.LightningModule):
    """U-shaped Swin Transformer (encoder, bottleneck, decoder with skip connections) as a LightningModule.

    The encoder has ``len(depths)`` downsampling stages, the bottleneck has two blocks,
    and the decoder mirrors the encoder. Each decoder stage receives the matching encoder
    input through concatenation followed by a linear layer. The last ``PatchExpand``
    always upsamples by 4, so the output has 4x the patch-grid resolution.

    Two tasks are supported through ``fastMRI``:
      * True: reconstruction with MSE loss; MAE / PSNR / SSIM / NMSE on validation.
        Batches are read as ``(x, y, ..., mask)``: the prediction is clamped to [0, 1] and
        both prediction and target are multiplied by the last batch element.
      * False: 3-class segmentation with cross-entropy loss and Dice score.

    Args:
        img_size, patch_size: int or (h, w) tuple; the image must be divisible by the patch.
        input_channels, output_channels: channels of the input image and of the prediction.
        embedding_dim: channel size after patch embedding; doubled at every encoder stage.
        hidden_size: MLP hidden width used by every stage.
        window_size: side of the attention windows.
        dropout: dropout probability used by every stage.
        norm_layer: accepted for signature compatibility; not read.
        depths: number of blocks of each encoder (and mirrored decoder) stage.
        num_heads: heads per stage; needs ``len(depths) + 1`` entries (the last one is for the bottleneck).
        use_conditional_blocks, use_iterative_algorithm, eps, eps_init, eps_decay, iters:
            forwarded to every ``BasicLayer``.
        learning_rate: Adam learning rate.
        epoch: starting value of the epoch counter written to the W&B log.
        fastMRI: task selector described above.
    """
    def __init__(self, img_size, patch_size, input_channels, embedding_dim, hidden_size, window_size, output_channels, dropout = 0.2, norm_layer = nn.LayerNorm,
                 depths = [2, 2, 2], num_heads = [4, 4, 4, 4], use_conditional_blocks = False, learning_rate = 1e-3, epoch = 0,
                 use_iterative_algorithm = False, eps = 2e-2, eps_init = None, eps_decay = 1., iters = 20, fastMRI = True):
        super(SwinTransformer_Unet, self).__init__()
        def to_tuple(x):
            if not isinstance(x, tuple):
                return (x, x)
            return x

        self.wandb_log = {}
        self.epoch = epoch
        self.train_loss = []

        self.learning_rate = learning_rate
        self.img_size, self.patch_size = to_tuple(img_size), to_tuple(patch_size)

        assert self.img_size[0] % self.patch_size[0] == 0 and self.img_size[1] % self.patch_size[1] == 0
        self.patch_resolution = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
        self.depths = len(depths)

        num_stages = len(depths)

        def resolution_at(level):
            # Token grid after `level` patch merges.
            return (self.patch_resolution[0] // (2**level), self.patch_resolution[1] // (2**level))

        # Options shared by every BasicLayer of the network.
        layer_options = dict(use_conditional_blocks = use_conditional_blocks,
                             use_iterative_algorithm = use_iterative_algorithm,
                             eps = eps, eps_init = eps_init, eps_decay = eps_decay, iters = iters)

        self.patch_embeddings = PatchEmbedding(self.img_size, self.patch_size, input_channels, embedding_dim)

        self.encoder = nn.ModuleList([
            BasicLayer(int(embedding_dim * 2**i), resolution_at(i),
                       depths[i], num_heads[i],
                       window_size, hidden_size, dropout, downsample = True, **layer_options)
            for i in range(num_stages)
        ])

        self.bottleneck = BasicLayer(int(embedding_dim * 2**num_stages), resolution_at(num_stages), 2,
                                     num_heads[num_stages], window_size, hidden_size, dropout, **layer_options)

        # Fuse [decoder tokens ; skip tokens] back to the stage width.
        self.concat_linear = nn.ModuleList([
            nn.Linear(2*int(embedding_dim*2**(num_stages-i-1)), int(embedding_dim*2**(num_stages-i-1)))
            for i in range(num_stages)
        ])

        self.decoder = nn.ModuleList([
            PatchExpand(resolution_at(num_stages), int(embedding_dim * 2**num_stages), use_norm_layer = True)
            ] + [
            BasicLayer(int(embedding_dim * 2**(num_stages - 1 - i)), resolution_at(num_stages - 1 - i),
                    depths[num_stages - 1 - i], num_heads[num_stages - 1 - i],
                    window_size, hidden_size, dropout, upsample = (i < num_stages - 1), **layer_options)
            for i in range(num_stages)
            ] + [
            PatchExpand(self.patch_resolution, embedding_dim, use_norm_layer = True, is_final = True)
        ])

        self.norm = nn.LayerNorm(int(embedding_dim * 2**num_stages))
        self.norm_up = nn.LayerNorm(embedding_dim)

        self.output = nn.Conv2d(embedding_dim, output_channels, kernel_size = 1, bias = False)

        # Stand-alone metric attributes kept so existing code that reads them keeps working;
        # the steps below only update the task-specific *_metrics objects.
        self.train_mae = torchmetrics.MeanAbsoluteError()

        self.validation_mae = torchmetrics.MeanAbsoluteError()
        self.validation_psnr = torchmetrics.image.PeakSignalNoiseRatio(data_range=1)
        self.validation_ssim = torchmetrics.image.StructuralSimilarityIndexMeasure(kernel_size=7, data_range=1)
        self.validation_nmse = NMSE()
        self.validation_nmse_errors = []

        self.test_mae = torchmetrics.MeanAbsoluteError()

        if fastMRI:
            self.loss_function = F.mse_loss
            self.train_metrics = torchmetrics.MeanAbsoluteError()
            self.validation_metrics = nn.ModuleList([torchmetrics.MeanAbsoluteError(),
                                       torchmetrics.image.PeakSignalNoiseRatio(data_range=1),
                                       torchmetrics.image.StructuralSimilarityIndexMeasure(kernel_size=7, data_range=1),
                                       NMSE()])
            self.validation_nmse_errors = []
            self.test_metrics = torchmetrics.MeanAbsoluteError()
        else:
            self.loss_function = nn.CrossEntropyLoss()
            self.train_metrics = torchmetrics.Dice(num_classes = 3, ignore_index = 0)
            self.validation_metrics = torchmetrics.Dice(num_classes = 3, ignore_index = 0)
            self.test_metrics = torchmetrics.Dice(num_classes = 3, ignore_index = 0)

        self.fastMRI = fastMRI

    def compute_metrics(self, metrics, pred, target):
        """Update ``metrics`` with one batch.

        A ModuleList is treated as [torchmetrics..., NMSE]: the last entry is the
        hand-written NMSE, whose per-batch value is appended to ``validation_nmse_errors``.
        """
        if isinstance(metrics, nn.ModuleList):
            for m in metrics[:-1]:
                m.update(pred, target)
            self.validation_nmse_errors.append(metrics[-1](pred, target))  # NMSE is built from scratch and kept last for simplicity
        else:
            metrics.update(pred, target)
        return

    def reset_metrics(self, metrics):
        """Reset ``metrics``; for a ModuleList also clear the collected NMSE values."""
        if isinstance(metrics, nn.ModuleList):
            for m in metrics[:-1]:
                m.reset()
            self.validation_nmse_errors = []
        else:
            metrics.reset()

    def forward_down_and_bottleneck(self, x):
        """Embed the image, run encoder and bottleneck; return the result and the input of every encoder stage."""
        x = self.patch_embeddings(x)
        x_downsample = []

        for layer in self.encoder:
            x_downsample.append(x)
            x = layer(x)

        x = self.bottleneck(x)

        x = self.norm(x)

        return x, x_downsample

    def forward_up(self, x, x_downsample):
        """Run the decoder with skip connections and project to ``output_channels`` feature maps."""
        x = self.decoder[0](x)
        for idx, layer_up in enumerate(self.decoder[1:-1]):
            # Deepest skip first: decoder stage idx pairs with encoder stage depths - 1 - idx.
            x = torch.cat([x, x_downsample[self.depths - 1 - idx]],-1)
            x = self.concat_linear[idx](x)
            x = layer_up(x)

        x = self.norm_up(x)
        x = self.decoder[-1](x)

        # The final PatchExpand always upsamples by 4, hence the fixed factor here.
        x = rearrange(x, "b (h w) d -> b d h w", h = 4*self.patch_resolution[0], w = 4*self.patch_resolution[1])
        return self.output(x)


    def forward(self, x):
        x, x_downsample = self.forward_down_and_bottleneck(x)
        return self.forward_up(x, x_downsample)

    def inference(self, x):
        """Forward pass in eval mode without gradients; the module is left in eval mode."""
        self.eval()
        with torch.no_grad():
            return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8, verbose=True)
        return [optimizer], [scheduler]

    def _clamp_and_mask(self, y_pred, y, bg_m):
        """fastMRI post-processing: clamp the prediction to [0, 1], then mask prediction and target.

        ``y`` is multiplied in place, as the individual steps did before this helper existed.
        """
        y_pred = y_pred.clamp(0, 1)
        y *= bg_m
        y_pred *= bg_m
        return y_pred, y

    def training_step(self, batch, batch_idx):
        x, y = batch[:2]
        y_pred = self(x)
        if self.fastMRI:
            y_pred, y = self._clamp_and_mask(y_pred, y, batch[-1])

        loss = self.loss_function(y_pred, y)
        # Store a detached copy: keeping the live loss would retain every step's autograd graph until epoch end.
        self.train_loss.append(loss.detach())

        if batch_idx % 100 == 0:  # layers set on eval to keep track of performance on training set
            y_pred2 = self.inference(x)
            self.train()
            if self.fastMRI:
                # Same post-processing as the loss and the validation metrics (y is already masked above).
                y_pred2 = y_pred2.clamp(0, 1)
                y_pred2 *= batch[-1]
            self.compute_metrics(self.train_metrics, y_pred2, y if self.fastMRI else y.type(torch.uint8))


        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch[:2]
        y_pred = self.inference(x)

        if self.fastMRI:
            y_pred, y = self._clamp_and_mask(y_pred, y, batch[-1])
        else:
            y = y.type(torch.uint8)
        self.compute_metrics(self.validation_metrics, y_pred, y)


    def test_step(self, batch, batch_idx):
        x, y = batch[:2]
        y_pred = self.inference(x)

        if self.fastMRI:
            y_pred, y = self._clamp_and_mask(y_pred, y, batch[-1])
        else:
            y = y.type(torch.uint8)
        self.compute_metrics(self.test_metrics, y_pred, y)

    def on_train_epoch_end(self):
        self.wandb_log['train_loss'] = sum(self.train_loss)/len(self.train_loss)
        if self.fastMRI:
            self.wandb_log['train_mae'] = self.train_metrics.compute()
        else:
            self.wandb_log['train_dicescore'] = self.train_metrics.compute()
        self.wandb_log['epoch'] = self.epoch
        self.epoch += 1
        self.train_loss = []
        self.reset_metrics(self.train_metrics)

    def on_validation_epoch_end(self):
        if self.fastMRI:
            val_mae = self.validation_metrics[0].compute()
            self.log('val_mae_epoch', val_mae,
                    on_step=False, on_epoch=True, prog_bar=True, logger=True)
            self.wandb_log['validation_mae'] = val_mae
            self.wandb_log['validation_psnr'] = self.validation_metrics[1].compute()
            self.wandb_log['validation_ssim'] = self.validation_metrics[2].compute()
            self.wandb_log['validation_nmse'] = sum(self.validation_nmse_errors)/len(self.validation_nmse_errors)
        else:
            val_dice = self.validation_metrics.compute()
            self.log('val_dicescore_epoch', val_dice,
                    on_step=False, on_epoch=True, prog_bar=True, logger=True)
            self.wandb_log['validation_dicescore'] = val_dice

        self.reset_metrics(self.validation_metrics)
        wandb.log(self.wandb_log)

    def on_test_epoch_end(self):
        # test_step accumulates into self.test_metrics (MAE for fastMRI, Dice otherwise),
        # so that is the object that must be read and reset here.
        name = 'test_mae' if self.fastMRI else 'test_dicescore'
        value = self.test_metrics.compute()
        self.log(f'{name}_epoch', value,
                 on_epoch=True, prog_bar=True, logger=True)
        self.wandb_log[name] = value
        self.reset_metrics(self.test_metrics)


# wandb setup

In [None]:
!pip install wandb --quiet

In [None]:
import wandb

In [None]:
# Authenticate without writing the key into the notebook: wandb.login() uses the
# WANDB_API_KEY environment variable or stored credentials, and prompts otherwise.
wandb.login()

# Training fastMRI

In [None]:
# --- Paths -----------------------------------------------------------------
root_dir = "./training/checkpoints"
checkpoint_dir = root_dir + "/c_swin_u_net"
# logger_dir = "./training/tensorboard/c_swin_u_net"

# --- Optimisation ----------------------------------------------------------
EPOCHS = 50
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
GRADIENT_CLIP_VAL = 0.5

# --- Model size ------------------------------------------------------------
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
DROP_PROB = 0.2
# One entry of `depths` per encoder stage; `num_heads` has one extra entry for the bottleneck.
depths = [2, 2]
num_heads = [4, 8, 16]
# Smaller alternative: depths = [2], num_heads = [4, 8]

# --- Router options (only read by the conditional blocks) ------------------
use_iterative_algorithm = False
eps = 2e-2
eps_init = 4
eps_decay = 0.7
iters = 20

In [None]:
# Checkpoint file names start with the wall-clock time (HH.MM) at which this cell ran.
now = f"{datetime.now():%H.%M}"

# Keep the 7 checkpoints with the lowest 'val_mae_epoch', plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=now + 'c_swin_u_net{epoch:02d}_{step:06d}_{val_mae_epoch:.3f}',
    monitor='val_mae_epoch',
    mode='min',
    save_top_k=7,
    save_last=True,
    verbose=True,
)

# logger = TensorBoardLogger(logger_dir, name="c_swin_u_net")
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]

# Keyword arguments later unpacked into Trainer(...).
trainer_input = dict(
    default_root_dir=root_dir,
    accelerator="auto",
    devices=1,
    log_every_n_steps=50,
    val_check_interval=1.0,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    max_epochs=EPOCHS,
    # logger=logger,
    callbacks=callbacks,
)

# Keyword arguments for the fastMRI data module.
datamodule_input = dict(
    preprocessed_data_path="./singlecoil_train_preprocessed",
    preprocessed_data_path_test="./singlecoil_val_preprocessed",
    # original_data_path="./singlecoil_train",
    # original_data_path_test="./singlecoil_val",
    mask=RectangularEquispacedMask(acceleration=4, central_ratio=0.8),
    max_proc=1,
    batch_size=BATCH_SIZE,
)
fastmri_dm = fastMRIDataModule(**datamodule_input)

# Run configuration recorded in W&B. Values come from the hyper-parameter cell
# so the logged config cannot silently diverge from the constants defined there.
config_wandb = {
    "moby_pretrain": "none",
    "batch_size": BATCH_SIZE,
    "embedding_dim": EMBEDDING_DIM,
    "hidden_size": HIDDEN_SIZE,
    "dropout": DROP_PROB,
    "depths": depths,
    "num_heads": num_heads,
    "learning_rate": LEARNING_RATE,
    "use_conditional_blocks": True,
    "use_iterative_algorithm": use_iterative_algorithm, "eps": eps, "eps_init": eps_init, "eps_decay": eps_decay, "iters": iters
}

In [None]:
# Start the W&B run for the fastMRI experiment; config_wandb is passed as the run's config.
wandb.init(project = "fastMRI_training", config = config_wandb)

In [None]:
trainer = Trainer(**trainer_input)
# dropout and use_conditional_blocks are read from config_wandb, so the model that is
# trained is the one described by the configuration logged to W&B.
swin_u_net = SwinTransformer_Unet(img_size = 256, patch_size = 4, input_channels = 1, embedding_dim = EMBEDDING_DIM,
                                  hidden_size = HIDDEN_SIZE, window_size = 8, output_channels = 1,
                                  dropout = config_wandb["dropout"], norm_layer = nn.LayerNorm,
                                  depths = depths, num_heads = num_heads, learning_rate = LEARNING_RATE,
                                  use_conditional_blocks = config_wandb["use_conditional_blocks"],
                                  use_iterative_algorithm = use_iterative_algorithm,
                                  eps = eps, eps_init = eps_init, eps_decay = eps_decay, iters = iters)
trainer.fit(swin_u_net, datamodule=fastmri_dm)

In [None]:
# Close the current W&B run before another experiment is started.
wandb.finish()

In [None]:
# Number of batches in the training dataloader.
len(fastmri_dm.train_dataloader())

# Training Liver Tumor Segmentation

In [None]:
# --- Paths -----------------------------------------------------------------
root_dir = "./training/checkpoints"
checkpoint_dir = root_dir + "/swin_u_net"
# logger_dir = "./training/tensorboard/c_swin_u_net"

# --- Optimisation ----------------------------------------------------------
EPOCHS = 30
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
GRADIENT_CLIP_VAL = 0.5

# --- Model size ------------------------------------------------------------
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
DROP_PROB = 0.2
# One entry of `depths` per encoder stage; `num_heads` has one extra entry for the bottleneck.
depths = [2, 2]
num_heads = [4, 8, 16]
# Smaller alternative: depths = [2], num_heads = [4, 8]

# --- Router options (only read by the conditional blocks) ------------------
use_iterative_algorithm = False
eps = 2e-2
eps_init = 4
eps_decay = 0.7
iters = 20

In [None]:
# Hour.minute timestamp prefixed to the checkpoint file names of this run.
now = datetime.now().strftime("%H.%M")

# Keep the 7 checkpoints with the highest 'val_dicescore_epoch' (mode='max'),
# plus the most recent one (save_last=True).
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=now+'swin_u_net{epoch:02d}_{step:06d}_{val_dicescore_epoch:.3f}',
    save_top_k=7,
    monitor='val_dicescore_epoch',
    mode='max',
    verbose=True,
    save_last=True,
)

# No explicit logger is handed to the Trainer (the TensorBoardLogger setup is
# disabled); only the checkpoint callback and a progress bar are registered.
callbacks = [checkpoint_callback, TQDMProgressBar(refresh_rate=20)]
trainer_input = {
    "default_root_dir": root_dir,
    "accelerator": "auto",
    "devices": 1,
    "log_every_n_steps": 50,
    "val_check_interval": 1.0,
    "gradient_clip_val": GRADIENT_CLIP_VAL,
    "max_epochs": EPOCHS,
    "callbacks": callbacks,
}

# Data module built from the train/validation dataframes prepared earlier in
# the notebook.
datamodule_input = {
    "train_dataframe": df_train,
    "validation_dataframe": df_val,
    "batch_size": BATCH_SIZE,
}
livertumor_dm = LiverTumorSegmentationDataModule(**datamodule_input)

# Hyperparameters attached to the Weights & Biases run (passed as `config`
# to wandb.init in the next cell).
config_wandb = {
    "batch_size": BATCH_SIZE,
    "embedding_dim": EMBEDDING_DIM,
    "hidden_size": HIDDEN_SIZE,
    # Use the DROP_PROB constant from the hyperparameter cell instead of a
    # duplicated literal; keep it in sync with the dropout given to the model.
    "dropout": DROP_PROB,
    "depths": depths,
    "num_heads": num_heads,
    "learning_rate": LEARNING_RATE,
    "use_conditional_blocks": False,
    # Settings of the optional iterative algorithm.
    "use_iterative_algorithm": use_iterative_algorithm,
    "eps": eps,
    "eps_init": eps_init,
    "eps_decay": eps_decay,
    "iters": iters,
}

In [None]:
# Start the Weights & Biases run for the liver tumor segmentation experiment,
# attaching the hyperparameter dictionary built in the previous cell.
wandb.init(project="liver_dataset_training_new", config=config_wandb)

In [None]:
# Build the Lightning trainer and the Swin U-Net without conditional blocks
# (3 input channels, 3 output channels, fastMRI=False), then train it on the
# liver tumor segmentation data module.
trainer = Trainer(**trainer_input)
swin_u_net = SwinTransformer_Unet(
    img_size=256,
    patch_size=4,
    input_channels=3,
    embedding_dim=EMBEDDING_DIM,
    hidden_size=HIDDEN_SIZE,
    window_size=8,
    output_channels=3,
    # Take the dropout probability from the DROP_PROB constant of the
    # hyperparameter cell rather than repeating the literal here.
    dropout=DROP_PROB,
    norm_layer=nn.LayerNorm,
    depths=depths,
    num_heads=num_heads,
    learning_rate=LEARNING_RATE,
    use_conditional_blocks=False,
    use_iterative_algorithm=use_iterative_algorithm,
    eps=eps,
    eps_init=eps_init,
    eps_decay=eps_decay,
    iters=iters,
    fastMRI=False,
)
trainer.fit(swin_u_net, datamodule=livertumor_dm)

In [None]:
# Close the Weights & Biases run opened by wandb.init for the liver tumor
# segmentation training.
wandb.finish()