In [1]:
import glob
import os
import random
from pathlib import Path

import numpy as np
import cv2
import SimpleITK as sitk
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from monai.networks.nets import UNet
from monai.transforms import Compose, Randomizable, RandomizableTransform
from monai.transforms import LoadNiftid, AddChanneld, ScaleIntensityd, ToTensord, RandFlipd, RandRotate90d
import wandb


'''
1. RETOUCH dataset

    Preprocessing
    - Preprocessed dataset for now

    Training/validation/test
    - split train folder into 75/15/10
    - Use Topcon&Cirrus for validation and testing too? Or use ONLY Topcon&Cirrus for val&test?
        > Paper uses only topcon and cirrus for validation and testing

    SVDNA
    - Create Pytorch Dataset class:
        > define __getitem__ method such that I give it the path to the training set and it returns the SVDNA transformed images

    - when epoch starts, SVDNA is applied to each image
        > for each image, 1/3 chance to choose one of the domains
            > of the n_d images, one is chosen randomly for style transfer
        > for each image, k is sampled between 30 and 50
        > apply SVDNA using sampled image and k
    - Implementation:
        > dict containing spectralis images
        > dict containing cirrus and topcon images

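        for epoch in range(total_epochs):
            source_imgs = spectralis_imgs       # {image: label}, Spectralis only
            target_imgs = cirrus_topcon_imgs    # pool of Cirrus + Topcon images used as style targets

            # re-apply SVDNA every epoch so each source image gets a fresh target image and k
            source_svdna = [(svdna(img, target_imgs), label) for img, label in source_imgs.items()]

            img_dataloader = DataLoader(source_svdna)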

    Transformations
    - 


    - Possibilities for datastructure:
        - [{cirrus_img1:cirrus_img1, cirrus_img1_label:cirrus_img1_label}, {cirrus_img2:cirrus_img2, cirrus_img2_label:cirrus_img2_label}, ...]

    
        
https://docs.monai.io/en/stable/networks.html#basicunetplusplus



'''

base_dir = Path(os.getcwd())
data_dir = base_dir / 'data' / 'Retouch-preprocessed'


def get_data_dict(data_dir=data_dir, dataset='train'):
    """Return the folder of one dataset split (e.g. 'train') inside ``data_dir``.

    Work in progress: the plan is to return a list of {'img': ..., 'label': ...}
    records for the split (see the data-structure notes in the first cell). For now
    only the split directory is resolved and returned.

    Args:
        data_dir (str or Path): Root of the preprocessed RETOUCH data. Defaults to the
            module-level ``data_dir`` (``<cwd at import>/data/Retouch-preprocessed``).
        dataset (str): Name of the split subfolder. Defaults to 'train'.

    Returns:
        Path: ``data_dir / dataset``.
    """
    # Wrap in Path() so plain strings work too; `str / str` would raise TypeError.
    return Path(data_dir) / dataset


In [2]:
# Quick look at one raw RETOUCH volume (machine-specific absolute path).
sample_img = '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release/Cirrus/03a60d9078d35b1488e6030880a29014/reference.mhd'
itk_image = sitk.ReadImage(sample_img)
# Zero-copy view of the ITK pixel buffer: `itk_image` must stay alive while the view is used.
image_array = sitk.GetArrayViewFromImage(itk_image)

# print the image's dimensions
print(image_array.shape)

# Set to True to plot every 10th slice along axis 0. The loop bound comes from the
# array itself instead of a hard-coded 128, so it works for volumes of any depth.
SHOW_SLICES = False
if SHOW_SLICES:
    for i in range(0, image_array.shape[0], 10):
        plt.imshow(image_array[i], cmap='gray')
        plt.title(f"slice {i}")
        plt.show()

(128, 1024, 512)


In [3]:
# Data root for this machine; set the OPTIMA_DATA_ROOT environment variable to point
# elsewhere instead of editing the absolute default below.
DATA_ROOT = Path(os.environ.get('OPTIMA_DATA_ROOT', '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data'))
name_dir = DATA_ROOT / 'RETOUCH' / 'TrainingSet-Release'  # raw release, one subfolder per device
train_dir = DATA_ROOT / 'Retouch-Preprocessed' / 'train'   # preprocessed training volumes

In [4]:
# Map each acquisition device in the raw release to its volume folders. Only
# directories are counted, so stray files such as .DS_Store neither crash the
# listing nor inflate the counts.
data = {
    device_dir.name: [volume.name for volume in device_dir.iterdir() if volume.is_dir()]
    for device_dir in name_dir.iterdir()
    if device_dir.is_dir()
}

for device, vals in data.items():
    print("Device: ", device, "; Number of folders: ", len(vals))

print("Preprocessed Retouch folders: ", sum(1 for entry in train_dir.iterdir() if entry.is_dir()))

Device:  Topcon ; Number of folders:  22
Device:  Spectralis ; Number of folders:  24
Device:  Cirrus ; Number of folders:  24
Preprocessed Retouch folders:  70


In [5]:

def generate_black_images(main_folder, delete_images=False):
    """
    Generate (or delete) black placeholder labels for images that have no label.

    Each subfolder of ``main_folder`` that contains an ``image`` and a
    ``label_image`` folder is processed. Every file in ``image`` without a
    same-named file in ``label_image`` gets an all-black image of the same shape,
    written to ``label_image/<stem>_empty.png``. The ``_empty.png`` suffix marks
    generated files so they can be removed again.

    Args:
        main_folder (str or Path): Path to the main folder.
        delete_images (bool, optional): If True, delete previously generated
            ``*_empty.png`` placeholders instead of creating them. Defaults to False.
    """
    main_folder = Path(main_folder)

    for subfolder_path in sorted(main_folder.iterdir()):
        image_folder = subfolder_path / 'image'
        label_folder = subfolder_path / 'label_image'
        # Skips stray files (e.g. .DS_Store) and folders without labels.
        if not label_folder.is_dir():
            continue

        if delete_images:
            # Match only the exact suffix written below, never real labels that merely contain "_empty".
            for label_path in label_folder.iterdir():
                if label_path.name.endswith('_empty.png'):
                    label_path.unlink()
            continue

        if not image_folder.is_dir():
            continue

        label_files = set(os.listdir(label_folder))  # set -> O(1) membership tests

        # Files that are in 'image' but have no label in 'label_image'.
        for file in sorted(os.listdir(image_folder)):
            placeholder_name = f"{Path(file).stem}_empty.png"
            # Already labelled, or a placeholder exists from an earlier run.
            if file in label_files or placeholder_name in label_files:
                continue

            image = cv2.imread(str(image_folder / file))
            if image is None:  # cv2 could not decode it, i.e. not an image file
                continue

            # Same shape as the input image; explicit uint8 because PNG stores 8/16-bit data.
            black_image = np.zeros(image.shape, dtype=np.uint8)
            cv2.imwrite(str(label_folder / placeholder_name), black_image)


In [12]:
#generate_black_images(train_dir, delete_images=False)

In [14]:

class OCTDataset(Dataset):
    """Pairs source-domain OCT volumes with randomly drawn target-domain volumes.

    Work in progress: folder discovery and indexing work, but image loading
    (``load_source_image`` / ``load_target_image``) is not implemented yet.

    Args:
        data_path (str or Path): Folder with the preprocessed volumes.
        transform (callable, optional): Called as ``transform(source_img, target_img)``
            and expected to return the transformed pair.
        generate_empty_labels (bool): If True, create black placeholder labels under
            ``data_path`` before indexing.
        source_domain (str): Device folder name used as the source domain.
        domain_root (str or Path, optional): Folder with one subfolder per device, each
            containing that device's volume folders. Defaults to ``DEFAULT_DOMAIN_ROOT``.
    """

    # Raw RETOUCH release: its per-device subfolders tell which volume came from which device.
    DEFAULT_DOMAIN_ROOT = Path('/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release')

    def __init__(self, data_path: str, transform=None, generate_empty_labels=False, source_domain='Spectralis',
                 domain_root=None):

        self.data_path = Path(data_path)

        self.transform = transform
        self.source_domain = source_domain

        if generate_empty_labels:
            # TODO: only create black images for the source domain.
            self.generate_black_images(self.data_path)

        self.source_data_paths, self.target_data_paths = self.filter_source_domain(
            source_domain, domain_root=domain_root)

        # Index-addressable list of source volume folders for __len__/__getitem__.
        self._source_items = list(self.source_data_paths[source_domain])

    def __len__(self):
        return len(self._source_items)

    def __getitem__(self, index):
        source_img = self.load_source_image(self._source_items[index])
        target_img = self.load_target_image(self._sample_target())

        if self.transform:
            # Apply transformations dynamically during training
            source_img, target_img = self.transform(source_img, target_img)

        return source_img, target_img

    def _sample_target(self):
        """Draw one target volume folder: a random target domain, then a random volume in it."""
        # stdlib `random` is reseeded in every DataLoader worker, so workers draw different targets.
        domain = random.choice(sorted(self.target_data_paths))
        return random.choice(self.target_data_paths[domain])

    def filter_source_domain(self, source_domain, domain_root=None):
        """Split the RETOUCH volume folders into source and target domains.

        Args:
            source_domain (str): Device folder to treat as the source domain.
            domain_root (str or Path, optional): Folder with one subfolder per device.
                Defaults to ``DEFAULT_DOMAIN_ROOT``.

        Returns:
            tuple[dict, dict]: ``({source_domain: [folder, ...]}, {other_domain: [folder, ...], ...})``
            with sorted volume folder *names* (not full paths),
            e.g. ``{'Cirrus': [...], 'Topcon': [...]}`` for the targets.

        Raises:
            ValueError: If ``source_domain`` is not a subfolder of ``domain_root``.
        """
        root = self.DEFAULT_DOMAIN_ROOT if domain_root is None else Path(domain_root)

        # Only directories count as devices / volumes; stray files such as .DS_Store are skipped.
        folders_by_domain = {
            domain_dir.name: sorted(entry.name for entry in domain_dir.iterdir() if entry.is_dir())
            for domain_dir in sorted(root.iterdir())
            if domain_dir.is_dir()
        }
        if source_domain not in folders_by_domain:
            raise ValueError(f"source_domain {source_domain!r} not found in {root}")

        source_data_paths = {source_domain: folders_by_domain[source_domain]}
        target_data_paths = {domain: folders for domain, folders in folders_by_domain.items()
                             if domain != source_domain}

        return source_data_paths, target_data_paths

    def load_source_image(self, source_path):
        """Load the source image for one volume folder. Not implemented yet."""
        raise NotImplementedError("OCTDataset.load_source_image is not implemented yet")

    def load_target_image(self, target_path):
        """Load the target image for one volume folder. Not implemented yet."""
        raise NotImplementedError("OCTDataset.load_target_image is not implemented yet")

    def generate_black_images(self, main_folder, delete_images=False):
        """Create (or, with ``delete_images=True``, remove) black placeholder labels under ``main_folder``.

        Delegates to the module-level ``generate_black_images`` function so the logic
        exists in one place instead of two diverging copies.

        Args:
            main_folder (str or Path): Path to the main folder.
            delete_images (bool, optional): Delete the generated placeholders instead. Defaults to False.
        """
        # Inside a method the bare name resolves to the module-level function, not to this method.
        generate_black_images(main_folder, delete_images=delete_images)

    def delete_generated_labels(self):
        """Delete the black placeholder labels generated under ``data_path``."""
        self.generate_black_images(self.data_path, delete_images=True)
    


# Define the randomizable transformation class for dynamic transformations
class RandomTransform(RandomizableTransform):
    """Placeholder for paired random augmentations used by ``OCTDataset``.

    Called as ``transform(source_img, target_img)``. No augmentation is implemented
    yet, so both inputs are returned unchanged.
    """

    def __call__(self, source_img, target_img):
        # MONAI's RandomizableTransform.randomize takes a `data` argument (calling it
        # without one raises TypeError); it sets self._do_transform using self.prob.
        self.randomize(None)
        # TODO: when self._do_transform is True, apply the same random transformation to both images.
        return source_img, target_img




# Define the UNet model
class SegmentationUNet(nn.Module):
    """2D segmentation network built on MONAI's ``UNet`` (imported at the top of the notebook).

    The network must own trainable parameters; otherwise ``optim.Adam(model.parameters())``
    raises ``ValueError`` (empty parameter list).

    Args:
        in_channels (int): Channels of the input B-scans. Default 1 (assumes grayscale — TODO confirm).
        out_channels (int): Number of output logit channels / classes. Default 4
            (TODO confirm: presumably background + 3 fluid classes).
        channels (tuple[int, ...]): Feature maps per encoder level.
        strides (tuple[int, ...]): Down-sampling stride between levels (one fewer entry than ``channels``).
    """

    def __init__(self, in_channels=1, out_channels=4, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2)):
        super(SegmentationUNet, self).__init__()
        # The first argument is the spatial dimensionality. It is named `dimensions` in old MONAI
        # releases and `spatial_dims` in newer ones, so pass it positionally.
        self.unet = UNet(2, in_channels, out_channels, channels=channels, strides=strides)

    def forward(self, x):
        """Return per-pixel logits (batch, out_channels, H, W) for x of shape (batch, in_channels, H, W)."""
        # NOTE(review): H and W presumably need to be divisible by the product of `strides` — confirm.
        return self.unet(x)
    



NameError: name 'Dataset' is not defined

In [None]:
# Guard: flip to True only once data loading, augmentation and SVDNA are implemented.
training_ready = False

if training_ready:

    # Initialize wandb
    wandb.init(project="PracticalWorkinAI", name="svdna_reproduction_retouch_only")
    try:
        # Instantiate the model, loss function, and optimizer
        model = SegmentationUNet()
        model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # OCTDataset calls its transform as transform(source_img, target_img), which a
        # single-input MONAI Compose pipeline does not fit.
        # TODO: chain SVDNA and tensor conversion here once they exist (`svdna` is not defined yet).
        transform = RandomTransform()

        # OCTDataset discovers the source/target volume folders itself; it only needs the data folder.
        dataset = OCTDataset(train_dir, transform=transform)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

        # NOTE(review): OCTDataset's second output is a target-domain image, but below it is
        # used as the segmentation target; CrossEntropyLoss expects class-index labels.
        # Confirm what the dataset should return before enabling training.

        # Training loop
        num_epochs = 10
        for epoch in range(num_epochs):
            for batch_idx, (source_img, target_img) in enumerate(dataloader):
                optimizer.zero_grad()

                # Forward pass
                output = model(source_img)

                # Compute loss
                loss = criterion(output, target_img)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Log loss to wandb
                wandb.log({"Loss": loss.item(), "Epoch": epoch, "Batch": batch_idx})

                # Print progress
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()}')

        # Save the trained model if needed
        torch.save(model.state_dict(), 'segmentation_unet.pth')
    finally:
        # Close the wandb run even if training fails part-way.
        wandb.finish()

In [19]:
source_domain = 'Spectralis'
data_path = '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/Retouch-Preprocessed/train/'
named_domain_folder = Path('/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release/')

# Device -> volume folder names; only directories count, so stray files such as .DS_Store are skipped.
filter_dict = {
    domain_dir.name: os.listdir(domain_dir)
    for domain_dir in named_domain_folder.iterdir()
    if domain_dir.is_dir()
}

# List the preprocessed folder once (as a set) instead of once per candidate volume.
preprocessed_folders = set(os.listdir(data_path))
source_data_paths = {source_domain: [img for img in filter_dict[source_domain] if img in preprocessed_folders]}

In [20]:
# Spectralis volume folders from the raw release that are also present in the preprocessed `data_path`.
source_data_paths

{'Spectralis': ['f520e52af71723efed0d6fecf934075b',
  '4223bd8c7002828a266d03b666b5c492',
  'dccf9863a1914de55b78a98f0791f605',
  'fe02f982b78218ab05c755c01c7876b1',
  'be0a143d6181d7f4f366fe2c0cc23075',
  '7501081e3e7577af524c6f7703d8d538',
  '11f08e45fcac03ea64b7138754117435',
  'c7e339f972f58b5388fbeb13adc670eb',
  '83e7b9339019b549c87830f90c560fea',
  '7096ef65c25c1c87068ed9cdf2b73a59',
  'f4b90ca25b223e4598c32caab90dc4aa',
  'bd14b46237f99db413f29d4797f246e1',
  '8844486c9f1b8952d0988d0acbea86c8',
  '92a92b80b8a4c48dc4775dd60819cf3b',
  'af3868172d16615c9556cacabcc80d66',
  '7b2607e057592d507c4ec4732bae64c2',
  '4a8a81b1c06072385738775dccdc7942',
  '76ebaa858a59427f392d97ed4b894f6d',
  '8945696b67c38140f30e99c0050aa925',
  'ecf4cb44944f518c82a4f108b94b0dfd',
  'd1ad979857c1877496207f3d4e5f5c52',
  '88bab0b300b19bdac3dc431cb3507a2a',
  '7456e703d778e17f29e42f97615f4a68',
  '7a05c267fdde4c819b19eb30da70d387']}

In [12]:
# Toy list of {'img', 'label'} records, used to check how dict values are accessed.
d = [{'img': [1, 2, 3], 'label': [1, 2, 3]}]
d += [{'img': value, 'label': value + 1} for value in (3, 5, 7)]
d[1].values()

dict_values([3, 4])

In [13]:
# Reminder: when applying SVDNA, make sure the same target images are not chosen every epoch.
# Raw-release location relative to the current working directory:
Path.cwd() / 'data' / 'RETOUCH' / 'TrainingSet-Release'

PosixPath('/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/practical/data/RETOUCH/TrainingSet-Release')

In [36]:
def filter_source_domain(source_domain, domain_root=None):
    """Split the RETOUCH volume folders into source and target domains.

    ``domain_root`` contains one subfolder per acquisition device (e.g. Cirrus,
    Spectralis, Topcon), each holding that device's volume folders.

    Args:
        source_domain (str): Device folder to treat as the source domain.
        domain_root (str or Path, optional): Folder with one subfolder per device.
            Defaults to the local RETOUCH ``TrainingSet-Release`` folder.

    Returns:
        tuple[dict, dict]: ``({source_domain: [folder, ...]}, {other_domain: [folder, ...], ...})``
        with sorted volume folder *names* (not full paths), e.g.
        ``({'Spectralis': [...]}, {'Cirrus': [...], 'Topcon': [...]})``.

    Raises:
        ValueError: If ``source_domain`` is not a subfolder of ``domain_root``.
    """
    if domain_root is None:
        domain_root = '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release'
    root = Path(domain_root)

    # Only directories count as devices / volumes; stray files such as .DS_Store are skipped.
    folders_by_domain = {
        domain_dir.name: sorted(entry.name for entry in domain_dir.iterdir() if entry.is_dir())
        for domain_dir in sorted(root.iterdir())
        if domain_dir.is_dir()
    }
    if source_domain not in folders_by_domain:
        raise ValueError(f"source_domain {source_domain!r} not found in {root}")

    source_data_paths = {source_domain: folders_by_domain[source_domain]}
    target_data_paths = {domain: folders for domain, folders in folders_by_domain.items()
                         if domain != source_domain}

    return source_data_paths, target_data_paths
    

# NOTE(review): the saved output below ([22, 24]) does not match this function's return value
# (a tuple of two dicts); it looks stale — re-run the cell to refresh it.
filter_source_domain('Spectralis')

[22, 24]


In [38]:
# Show the current working directory's contents (IPython shell escape).
!ls

SVDNA_demo_updated.ipynb   main.ipynb
[34mdata[m[m                       main.py
layer_transformation(1).py unet_training_dict.py
