In [44]:
import numpy as np
import cv2
from pathlib import Path
import os
import glob
from tqdm.notebook import tqdm
import SimpleITK as sitk
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from monai.networks.nets import BasicUNetPlusPlus
from monai.transforms import *

import wandb


'''
1. RETOUCH dataset

    Preprocessing
    - Preprocessed dataset for now

    Training/validation/test
    - split train folder into 75/15/10
    - Use Topcon&Cirrus for validation and testing too? Or use ONLY Topcon&Cirrus for val&test?
        > Paper uses only topcon and cirrus for validation and testing

    SVDNA
    - Create Pytorch Dataset class:
        > define __getitem__ method such that I give it the path to the training set and it returns the SVDNA transformed images

    - when epoch starts, SVDNA is applied to each image
        > for each image, 1/3 chance to choose one of the domains
            > of the n_d images, one is chosen randomly for style transfer
        > for each image, k is sampled between 30 and 50
        > apply SVDNA using sampled image and k
    - Implementation:
        > dict containing spectralis images
        > dict containing cirrus and topcon images

        for epoch in total_epochs:
            source_imgs = spectralis_img
            target_imgs = cirrus_topcon_img

            source_svdna = [svdna(imgs[i], target_imgs):labels[i] for imgs, labels in source_imgs.items()]

            img_dataloader = Dataloader(source_svdna)

    Transformations
    - 


    - Possibilities for datastructure:
        - [{cirrus_img1:cirrus_img1, cirrus_img1_label:cirrus_img1_label}, {cirrus_img2:cirrus_img2, cirrus_img2_label:cirrus_img2_label}, ...]

    
        
https://docs.monai.io/en/stable/networks.html#basicunetplusplus



'''

data_dir = Path.cwd() / 'data' / 'Retouch-preprocessed'



In [2]:
sample_img = '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/practical/data/RETOUCH/TrainingSet-Release/Cirrus/03a60d9078d35b1488e6030880a29014/reference.mhd'
itk_image = sitk.ReadImage(sample_img)
image_array = sitk.GetArrayViewFromImage(itk_image)

# print the image's dimensions
print(image_array.shape)

# plot the image
for i in range(128):
    if i % 10 == 0:
        pass
        #plt.imshow(image_array[i], cmap='gray')
        #plt.show()

(128, 1024, 512)


In [3]:
name_dir = Path(Path.cwd() / 'data/RETOUCH/TrainingSet-Release/')
train_dir = Path(Path.cwd() / 'data/Retouch-Preprocessed/train')

In [4]:
data = {}
for device in os.listdir(name_dir):
    data[device] = os.listdir(name_dir / device)

for device, vals in data.items():
    print("Device: ", device, "; Number of folders: ", len(vals))

print("Preprocessed Retouch folders: ", len(os.listdir(train_dir)))

Device:  Topcon ; Number of folders:  22
Device:  Spectralis ; Number of folders:  24
Device:  Cirrus ; Number of folders:  24
Preprocessed Retouch folders:  70


In [5]:

def generate_black_images(main_folder, delete_images=False):
    """
    Generate black images for missing files in the label_image folder.

    Args:
        main_folder (str or Path): Path to the main folder.
        delete_images (bool, optional): Flag to delete the generated black images. Defaults to False.
    """

    if not delete_images:
    # Iterate over the subfolders
        for subfolder in sorted(os.listdir(main_folder)):
            subfolder_path = main_folder / subfolder

            # Check if the subfolder contains the 'image' and 'label_image' folders
            if os.path.isdir(subfolder_path) and 'image' in os.listdir(subfolder_path) and 'label_image' in os.listdir(subfolder_path):
                image_folder = subfolder_path / 'image'
                label_folder = subfolder_path / 'label_image'

                # Get the set of filenames in the 'image' folder
                image_files = sorted(set(os.listdir(image_folder)))

                # Get the set of filenames in the 'label_image' folder
                label_files = sorted(set(os.listdir(label_folder)))

                # Find the filenames that are in 'label_image' but not in 'image'
                missing_files = [i for i in image_files if i not in label_files]

                # Create a black image for each missing file
                for file in missing_files:
                    file_path = label_folder / file

                    # find the shape of the input image to create corresponding target
                    file_shape = cv2.imread(str(image_folder / file)).shape
                    black_image = np.zeros(file_shape)

                    # make unique names so files can be deleted again
                    cv2.imwrite(f"{str(file_path)[:-4]}_empty.png", black_image)

    # Delete the generated black images if delete_images flag is True
    if delete_images:
        for subfolder in sorted(os.listdir(main_folder)):
            subfolder_path = main_folder / subfolder
            if os.path.isdir(subfolder_path) and 'label_image' in os.listdir(subfolder_path):
                label_folder = subfolder_path / 'label_image'
                for file in os.listdir(label_folder):
                    file_path = label_folder / file
                    if "_empty" in str(file):
                        os.remove(str(file_path))

In [6]:
#generate_black_images(train_dir, delete_images=False)

In [40]:

'''
Below class should just be a function that does the following:
sort data and output one large list of dicts for all source images and one for all target images. Transforms are applied in the monai or torch
dataset classes. they don't need to be created.
'''


class OCTDataset(Dataset):
    '''
    Custon dataset object accomplishes the task of bringing the training data into the right format for further processing.
    '''
    
    def __init__(self, data_path: str, transform=None, generate_empty_labels=False, source_domain='Spectralis'):

        self.data_path = Path(data_path)

        self.transform = transform

        if generate_empty_labels:
            self.generate_black_images(data_path) # check if this misses arguments
                                                  # change method so it only creates black images for source domain

        self.dict_domain_images_sorted = self.filter_source_domain(data_path, source_domain)

        self.source_domain_dict = self.dict_domain_images_sorted[source_domain]
        self.target_domain_dict = [img_dict for domain in self.dict_domain_images_sorted if domain != source_domain for img_dict in self.dict_domain_images_sorted[domain]]


    def __len__(self):
        return len(self.source_data_paths)

    def __getitem__(self, index):
        source_img = self.source_domain_dict[index]
        target_img = self.target_domain_dict[index]
        
        if self.transform:
            # Apply transformations dynamically during training
            pass
            #source_img, target_img = self.transform(source_img, target_img)

        return source_img, target_img


    def filter_source_domain(self, data_path, source_domain):
        '''
        data_path: Path to the training set folder where all images are not sorted by domains.
        source_domain: The source domain for the upcoming SVDNA process.

        Returns: a dictionary containing three lists of dictionaries of the following structure:
                    {source domain: [{img: img1, label: label1}, {img: img2, label: label2}, ...], 
                    target domain 1: [{img: img1}, {img: img2}, ...],
                    target domain 2: [{img: img1}, {img: img2}, ...]}
        '''
        
        named_domain_folder = Path.cwd() / 'data/RETOUCH/TrainingSet-Release'

        domains = os.listdir(named_domain_folder)

        # creates dict e.g. {'cirrus':['path1', 'path2', ...], 'topcon':['path1', 'path2', ...]}
        img_folders_sorted_by_domain = {domain:os.listdir(named_domain_folder / domain) for domain in domains}


        # restructure source data into a list of dictionaries, where each dictionary has keys img and label

        unsorted_img_folders_training_set = os.listdir(data_path)

        domains_dict = {}

        for domain in domains:
            
            list_of_dicts_images = []
            
            for img_folder in unsorted_img_folders_training_set:

                if img_folder in img_folders_sorted_by_domain[domain]:
                    
                    subfolders = os.listdir(data_path / img_folder)
                    unsorted_img_folders_training_set.remove(img_folder)

                    if 'image' in subfolders and 'label_image' in subfolders:
                        
                        if domain == source_domain:
            
                            sliced_images = sorted(os.listdir(data_path / img_folder / 'image'))
                            sliced_labels = sorted(os.listdir(data_path / img_folder / 'label_image'))
                            
                            for i in range(len(sliced_images)):
                                if (sliced_images[i] == sliced_labels[i]) or (sliced_images[i][:-4] + '_empty.png' == sliced_labels[i]):
                                    list_of_dicts_images.append(
                                        {'img': str(data_path / img_folder / 'image' / sliced_images[i]), 'label': str(data_path / img_folder / 'label_image' / sliced_labels[i])}
                                        )

                                else:
                                    print(f"Image {img_folder}/image/{sliced_images[i]} has no corresponding label image. Skipping image. \nTake a look at 'generate_black_images' method.") 
                                    continue
                        
                        else:

                            sliced_images = sorted(os.listdir(data_path / img_folder / 'image'))

                            for i in range(len(sliced_images)):
                                    list_of_dicts_images.append({'img': str(data_path / img_folder / 'image' / sliced_images[i])})
                        
            domains_dict[domain] = list_of_dicts_images
                                        
        return domains_dict

    def generate_black_images(self, main_folder, delete_images=False):
        """
        Generate black images for missing files in the label_image folder.

        Args:
            main_folder (str or Path): Path to the main folder.
            delete_images (bool, optional): Flag to delete the generated black images. Defaults to False.
        """

        if not delete_images:
        # Iterate over the subfolders
            for subfolder in sorted(os.listdir(main_folder)):
                subfolder_path = main_folder / subfolder

                # Check if the subfolder contains the 'image' and 'label_image' folders
                if os.path.isdir(subfolder_path) and 'image' in os.listdir(subfolder_path) and 'label_image' in os.listdir(subfolder_path):
                    image_folder = subfolder_path / 'image'
                    label_folder = subfolder_path / 'label_image'

                    # Get the set of filenames in the 'image' folder
                    image_files = sorted(set(os.listdir(image_folder)))

                    # Get the set of filenames in the 'label_image' folder
                    label_files = sorted(set(os.listdir(label_folder)))

                    # Find the filenames that are in 'label_image' but not in 'image'
                    missing_files = [i for i in image_files if i not in label_files]

                    # Create a black image for each missing file
                    for file in missing_files:
                        file_path = label_folder / file

                        # find the shape of the input image to create corresponding target
                        file_shape = cv2.imread(str(image_folder / file)).shape
                        black_image = np.zeros(file_shape)

                        # make unique names so files can be deleted again
                        cv2.imwrite(f"{str(file_path)[:-4]}_empty.png", black_image)

        # Delete the generated black images if delete_images flag is True
        if delete_images:
            for subfolder in sorted(os.listdir(main_folder)):
                subfolder_path = main_folder / subfolder
                if os.path.isdir(subfolder_path) and 'label_image' in os.listdir(subfolder_path):
                    label_folder = subfolder_path / 'label_image'
                    for file in os.listdir(label_folder):
                        file_path = label_folder / file
                        if "_empty" in str(file):
                            os.remove(str(file_path))

    def delete_generated_labels(self):
        # Delete the generated black images
        self.generate_black_images(delete_images=True)

    



# Define the randomizable transformation class for dynamic transformations
class RandomTransform(RandomizableTransform):
    def __call__(self, source_img, target_img):
        self.randomize()
        # Implement your dynamic random transformations
        # Apply random transformations to both source and target images
        #return transformed_source_img, transformed_target_img




# Define the UNet model
class SegmentationUNet(nn.Module):
    def __init__(self):
        super(SegmentationUNet, self).__init__()
        # Define your UNet architecture using MONAI

    def forward(self, x):
        # Implement your forward pass logic
        return x
    



In [41]:
training_ready = False

if training_ready:

    # Initialize wandb
    wandb.init(project="PracticalWorkinAI", name="svdna_reproduction_retouch_only")

    # Instantiate the model, loss function, and optimizer
    model = SegmentationUNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Define the dataset paths
    source_data_paths = [...]  # List of paths to source domain data
    target_data_paths = [...]  # List of paths to target domain data

    # Create the dataset and dataloader
    transform = Compose([
        RandomTransform(),
        svdna,  # Apply SVDNA function
        ToTensord()
    ])

    dataset = OCTDataset(source_data_paths, target_data_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        for batch_idx, (source_img, target_img) in enumerate(dataloader):
            optimizer.zero_grad()

            # Forward pass
            output = model(source_img)

            # Compute loss
            loss = criterion(output, target_img)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Log loss to wandb
            wandb.log({"Loss": loss.item(), "Epoch": epoch, "Batch": batch_idx})

            # Print progress
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()}')

    # Save the trained model if needed
    torch.save(model.state_dict(), 'segmentation_unet.pth')

    # Finish wandb run
    wandb.finish()

In [None]:

from typing import Hashable, Optional, Union, Sequence
import scipy.ndimage

from typing import *
from monai.transforms import *
from monai.config.type_definitions import KeysCollection

labels_mapping = {255: 0,
                  0: 1,
                  80: 2,
                  160: 3}

class LayerPositionToProbabilityMap(MapTransform):
    def __init__(self, keys: Sequence, target_size,target_keys: Sequence = None):
        super().__init__(keys)
        if target_keys is None:
            self.target_keys = keys
        self.target_keys = target_keys
        self.target_size = target_size
    
    def smoothing_function(self, mask):
        mean = 0
        std = 0.5
        scale = 1 / (std * np.sqrt(2 * np.pi))
        return scale * np.exp(-(mask - mean)**2 / (2 * std**2))
        #return -(mask - mean)**2 / (2 * std**2)
        
    def __call__(self, data):
        for i, key in enumerate(self.keys):
            layer = data[key]
            column = np.arange(self.target_size[1])
            column = np.expand_dims(column, 1)
            mask = np.repeat(column, self.target_size[0], axis=1)
            mask = np.expand_dims(mask, 0)
            mask = np.repeat(mask, layer.shape[0], axis=0)
            #layer = np.expand_dims(layer, 1)
            mask = mask - layer
            #mask = np.ones_like(mask)
            mask = mask.astype(np.float32)
            
            mask = self.smoothing_function(mask)
            mask = mask / np.expand_dims(mask.sum(axis=1), axis=2)
            data[self.target_keys[i]] = mask

            
        return data
    

class CropImages(MapTransform):

    def __init__(self, keys: KeysCollection, source_key : str, crop_size, crop_allowance=10, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.source_key = source_key
        self.crop_size = crop_size
        self.crop_allowance = crop_allowance


    def __call__(self, data):


        d = dict(data)
        for ki, key in enumerate(self.keys):
            preliminary = data[self.source_key]
            min_val = max(preliminary.min() - self.crop_allowance, 0)
            img = data[key]
            img_crop = img[:, min_val:(min_val + self.crop_size), :]

            d[key] = img_crop
        return d

class CropImages(MapTransform):

    def __init__(self, keys: KeysCollection, source_key : str, crop_size, crop_allowance=10, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.source_key = source_key
        self.crop_size = crop_size
        self.crop_allowance = crop_allowance


    def __call__(self, data):


        d = dict(data)
        for ki, key in enumerate(self.keys):
            preliminary = data[self.source_key]
            min_val = max(preliminary.min() - self.crop_allowance, 0)
            img = data[key]
            img_crop = img[:, min_val:(min_val + self.crop_size), :]

            d[key] = img_crop
        return d

class CropValImages(MapTransform):

    def __init__(self, keys: KeysCollection, source_key : str, crop_size, crop_allowance=10, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.source_key = source_key
        self.crop_size = crop_size
        self.crop_allowance = crop_allowance
        self.positions = [0, 1, 3, 5, 7]

    def __call__(self, data):


        d = dict(data)
        for ki, key in enumerate(self.keys):
            crop_id = data[self.source_key]
            img = data[key]
            img_crop = img[:, :, (self.positions[crop_id]*100):(self.positions[crop_id]*100 + self.crop_size)]

            d[key] = img_crop
        return d

class BilateralFilter(MapTransform):

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)


    def __call__(self, data):
        d = dict(data)
        for ki, key in enumerate(self.keys):
            img = data[key]
            #img_filter = np.expand_dims(cv2.bilateralFilter(img[0], 10, 50, 50), 0)
            img_filter = img
            d[key] = img_filter
        return d

class ConvertToMultiChannelMasks(MapTransform):
    """
    Convert labels of ong image into 0,1,2,3
    0 - background
    1 - RNFL
    2 - GCIPL
    3 - Choroid
    """

    def __init__(self, keys: KeysCollection, target_keys: List[str], allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.target_keys = target_keys


    def __call__(self, data):


        d = dict(data)
        for ki, key in enumerate(self.keys):

            mask = data[key]
            h, w= mask.shape

            mask_new = np.zeros((len(labels_mapping), h, w),dtype=np.uint8)
                      

            for i, (k, v) in enumerate(labels_mapping.items()):
                mask_new[i][mask==k] = 1

            d[self.target_keys[ki]] = mask_new[1:]
        return d

class GetMaskPositions(MapTransform):
    def __init__(self, keys: KeysCollection, target_keys: List[str], allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.target_keys = target_keys


    def __call__(self, data):


        d = dict(data)
        for ki, key in enumerate(self.keys):

            mask = data[key]
            num, h, w = mask.shape

            mask_positions = np.zeros((num * 2, w),dtype=np.float32)
                      

            for i in range(num):
                mask_positions[i*2] = np.argmax(mask[i], axis=0)
                mask_positions[i*2 + 1] = h - np.argmax(np.flip(mask[i], axis=0), axis=0)

            # The first and the second masks share a common border
            removal_mask = np.ones(len(mask_positions), dtype=bool)
            removal_mask[1] = False
            mask_positions = mask_positions[removal_mask]
            mask_positions = np.expand_dims(mask_positions, 1)
            d[self.target_keys[ki]] = mask_positions
            d["invalid_masks"] = np.ones_like(mask_positions)
        return d


class ConvertToMultiChannelGOALS(MapTransform):
    """
    Convert labels of ong image into 0,1,2,3
    0 - background
    1 - RNFL
    2 - GCIPL
    3 - Choroid
    """

    def __call__(self, data):


        d = dict(data)
        for key in self.keys:

            mask = data[key]

            mask_new = np.zeros_like(mask,dtype=np.uint8)

            for k,v in labels_mapping.items():
                mask_new[mask==k] = v

            d[key] = mask_new
        return d


transforms = Compose([
    LoadImaged(keys=['image','segmentation']),
    Lambdad(keys=['image','segmentation'], func = lambda x: x.transpose()[0:1]),
    #Lambdad(keys=['image','segmentation'], func = lambda x: np.expand_dims(x, 0)),
    #AddChanneld(keys=['image','segmentation']),
    RandZoomd(keys=["image", "segmentation"], mode=["area", "nearest-exact"], prob=0.3, min_zoom=1.3, max_zoom=1.3),
    Resized(keys=["image", "segmentation"], mode=["area", "nearest-exact"], spatial_size=[-1, 400]), # We first only resize horizontally, for the correct image width
    RandFlipd(keys=["image", "segmentation"], spatial_axis=1, prob=0.3),
    RandHistogramShiftd(keys=["image"], prob=0.3),
    RandAffined(keys=["image", "segmentation"], prob=0.3, shear_range=[(-0.7, 0.7), (0.0, 0.0)], translate_range=[(-300, 100), (0, 0)], mode=["bilinear", "nearest"], padding_mode="zeros"),
    Lambdad(keys=['image','segmentation'], func = lambda x: x[0, ...]),
    ConvertToMultiChannelMasks(keys=['segmentation'], target_keys=["masks"]),
    GetMaskPositions(keys=['masks'], target_keys=["mask_positions"]), #We get the layer position, but on the original height
    #AddChanneld(keys=['image','segmentation']),
    Resized(keys=["image", "segmentation", "masks"], mode=["area", "nearest-exact", "nearest-exact"], spatial_size=[400, 400]),
    Lambdad(keys=['mask_positions'], func = lambda x: x * 400 / 800), #We scale down the positions to have more accurate positions
    #Lambdad(keys=['image'], func = lambda x: np.clip((x - x.mean()) / x.std(), -1, 1)),
    Lambdad(keys=['image'], func = lambda x: 2*(x - x.min()) / (x.max() - x.min()) - 1 ),
    LayerPositionToProbabilityMap(["mask_positions"], target_size=(400,400), target_keys=["mask_probability_map"])
])


In [42]:
ds = OCTDataset(data_dir)

In [43]:
ds[0]

IndexError: list index out of range