In [65]:
import numpy as np
import monai
import cv2
import torch
from monai.transforms import *
from pathlib import Path
import os
import glob
import random
from tqdm.notebook import tqdm
#import SimpleITK as sitk
import matplotlib.pyplot as plt


'''
1. RETOUCH dataset

    Preprocessing
    - Preprocessed dataset for now

    Training/validation/test
    - split train folder into 75/15/10
    - Use Topcon&Cirrus for validation and testing too? Or use ONLY Topcon&Cirrus for val&test?
        > Paper uses only topcon and cirrus for validation and testing

    SVDNA
    - Create Pytorch Dataset class:
        > define __getitem__ method such that I give it the path to the training set and it returns the SVDNA transformed images

    - when epoch starts, SVDNA is applied to each image
        > for each image, 1/3 chance to choose one of the domains
            > of the n_d images, one is chosen randomly for style transfer
        > for each image, k is sampled between 30 and 50
        > apply SVDNA using sampled image and k
    - Implementation:
        > dict containing spectralis images
        > dict containing cirrus and topcon images

        for epoch in range(total_epochs):
            source_imgs = spectralis_img
            target_imgs = cirrus_topcon_img

            source_svdna = [(svdna(img, target_imgs), label) for img, label in source_imgs.items()]

            img_dataloader = DataLoader(source_svdna)

    Transformations
    - TODO: decide which transformations to apply after SVDNA


    - Possible data structures:
        - dict(source: {Cirrus: {images, labels}}, target: {Spectralis: images, Topcon: images})
        - dict(Cirrus: {images, labels}, Spectralis: images, Topcon: images)

    
        
https://docs.monai.io/en/stable/networks.html#basicunetplusplus


First, I need to load the data and build pipelines that apply transformations during training:
one pipeline for the training data and another for the validation data.
Then I need to get a UNet++ from MONAI and train it on the training data.
SVDNA must be implemented and applied to the data as the first transformation:
during training, SVDNA is applied first, followed by all other transformations.



'''

# Project-relative locations, resolved against the current working directory.
# NOTE(review): 'Retouch-preprocessed' differs in case from the 'Retouch-Preprocessed' path used
# later in this notebook; both only agree on case-insensitive file systems — confirm the real name.
base_dir = Path.cwd()
data_dir = base_dir.joinpath('data', 'Retouch-preprocessed')

def get_data_dict(data_dir=None, dataset='train'):
    """Collect image/label file pairs for one split of the preprocessed RETOUCH data.

    Expects images in ``data_dir/dataset/<volume>/image/`` and labels in
    ``data_dir/dataset/<volume>/label_image/``. A label is matched by identical
    file name or, failing that, by the ``<stem>_empty.png`` placeholder naming
    used by ``generate_black_images``. Images with neither label are skipped,
    as are hidden files such as ``.DS_Store``.

    Args:
        data_dir (str or Path, optional): Root of the preprocessed data. Defaults
            to ``<cwd>/data/Retouch-preprocessed``, resolved at call time.
        dataset (str): Split sub-folder, e.g. ``'train'``.

    Returns:
        list[dict]: ``[{'image': str, 'label': str}, ...]`` ordered by volume
        folder name, then image file name.

    Raises:
        FileNotFoundError: If ``data_dir/dataset`` is not a directory.
    """
    if data_dir is None:
        # Resolved here rather than in the signature so a later os.chdir is honoured.
        data_dir = Path(os.getcwd()) / 'data' / 'Retouch-preprocessed'
    dataset_dir = Path(data_dir) / dataset
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}")

    data_dicts = []
    for volume_dir in sorted(dataset_dir.iterdir()):
        image_dir = volume_dir / 'image'
        label_dir = volume_dir / 'label_image'
        if not (image_dir.is_dir() and label_dir.is_dir()):
            continue
        for image_path in sorted(image_dir.iterdir()):
            if image_path.name.startswith('.'):  # e.g. macOS '.DS_Store'
                continue
            label_path = label_dir / image_path.name
            if not label_path.exists():
                # Fall back to the black placeholder written for unlabelled images.
                label_path = label_dir / f"{image_path.stem}_empty.png"
            if label_path.exists():
                data_dicts.append({'image': str(image_path), 'label': str(label_path)})
    return data_dicts

# Bind the result instead of discarding it so later cells can use it.
train_files = get_data_dict(data_dir, dataset='train')


In [63]:
import SimpleITK as sitk  # the top-level import in the first cell is commented out

# Inspect one raw RETOUCH volume.
sample_img = '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release/Cirrus/03a60d9078d35b1488e6030880a29014/reference.mhd'
itk_image = sitk.ReadImage(sample_img)
image_array = sitk.GetArrayViewFromImage(itk_image)

# print the image's dimensions
print(image_array.shape)

# Set to True to plot every 10th slice along the first array axis.
SHOW_SLICES = False
if SHOW_SLICES:
    for i in range(0, image_array.shape[0], 10):
        fig, ax = plt.subplots()
        ax.imshow(image_array[i], cmap='gray')
        ax.set_title(f'Slice {i}')
        ax.axis('off')
        plt.show()

(128, 1024, 512)


In [50]:
# Hard-coded, machine-specific locations: the raw RETOUCH release (one sub-folder per device)
# and the preprocessed training split.
name_dir, train_dir = map(Path, (
    '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release/',
    '/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/Retouch-Preprocessed/train/',
))

In [61]:
# Map each device folder (e.g. Cirrus / Spectralis / Topcon) to the names of its volume folders.
# Only directories are considered, so stray files such as macOS '.DS_Store' are ignored
# (os.listdir on such a file would raise NotADirectoryError).
data = {
    device.name: sorted(volume.name for volume in device.iterdir() if volume.is_dir())
    for device in sorted(name_dir.iterdir())
    if device.is_dir()
}

for device, vals in data.items():
    print("Device: ", device, "; Number of folders: ", len(vals))

preprocessed_folders = [folder for folder in train_dir.iterdir() if folder.is_dir()]
print("Preprocessed Retouch folders: ", len(preprocessed_folders))

Device:  Topcon ; Number of folders:  22
Device:  Spectralis ; Number of folders:  24
Device:  Cirrus ; Number of folders:  24
Preprocessed Retouch folders:  70


In [136]:

def generate_black_images(main_folder, delete_images=False):
    """
    Create (or remove) all-black placeholder labels for images that have no label.

    Every sub-folder of ``main_folder`` that contains both an ``image`` and a
    ``label_image`` folder is processed. For each file in ``image`` that has
    neither a same-named file nor an existing placeholder in ``label_image``,
    a zero-valued ``uint8`` image with the shape ``cv2.imread`` reports for the
    input is written to ``label_image/<stem>_empty.png``. The ``_empty.png``
    suffix marks the file as generated so it can be removed again.

    Args:
        main_folder (str or Path): Path to the main folder.
        delete_images (bool, optional): If True, delete previously generated
            ``*_empty.png`` placeholders instead of creating them. Defaults to False.

    Raises:
        ValueError: If an unlabelled image cannot be read by OpenCV.
    """
    main_folder = Path(main_folder)
    if delete_images:
        _delete_black_images(main_folder)
    else:
        _create_black_images(main_folder)


def _create_black_images(main_folder):
    """Write a black ``<stem>_empty.png`` label for every image that has no label yet."""
    for subfolder_path in sorted(main_folder.iterdir()):
        image_folder = subfolder_path / 'image'
        label_folder = subfolder_path / 'label_image'
        if not (image_folder.is_dir() and label_folder.is_dir()):
            continue

        # A set gives O(1) membership tests instead of scanning a list per image.
        label_files = set(os.listdir(label_folder))
        # Images that exist in 'image' but have no label in 'label_image'.
        for file in sorted(os.listdir(image_folder)):
            if file.startswith('.'):  # hidden files such as macOS '.DS_Store' are not images
                continue
            placeholder_name = f"{Path(file).stem}_empty.png"
            # Skip real labels and placeholders left by a previous run.
            if file in label_files or placeholder_name in label_files:
                continue

            image = cv2.imread(str(image_folder / file))
            if image is None:  # cv2.imread signals unreadable files with None, not an exception
                raise ValueError(f"Could not read image {image_folder / file}")
            # Target mask with the same shape as the input image.
            black_image = np.zeros(image.shape, dtype=np.uint8)
            cv2.imwrite(str(label_folder / placeholder_name), black_image)


def _delete_black_images(main_folder):
    """Remove every ``*_empty.png`` placeholder under ``<main_folder>/*/label_image``."""
    for label_folder in sorted(main_folder.glob('*/label_image')):
        if not label_folder.is_dir():
            continue
        # Match the exact suffix written above so real files that merely contain
        # "_empty" in their name are left alone.
        for placeholder in label_folder.glob('*_empty.png'):
            placeholder.unlink()


In [138]:
# Fill missing labels in the preprocessed training folder with black placeholder images
# (delete_images defaults to False, i.e. create rather than delete).
generate_black_images(main_folder=train_dir)

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from monai.networks.nets import UNet
from monai.transforms import Compose, Randomizable, RandomizableTransform
from monai.transforms import LoadNiftid, AddChanneld, ScaleIntensityd, ToTensord, RandFlipd, RandRotate90d
import wandb

# Start the experiment-tracking run; the training loop below logs its metrics to it.
wandb.init(name="svdna_reproduction_retouch_only", project="PracticalWorkinAI")

# Define the SVDNA function
def svdna(k, target_path, src_path, histo_matching_degree=0.5, image_size=None):
    """Apply SVD-based noise adaptation (SVDNA) of a source image towards a target image.

    Both images are loaded as grayscale and resized (nearest neighbour) to a
    square of side ``image_size``. The source keeps its ``k`` largest singular
    components (content). The target contributes all of its remaining components
    (style). Their sum is clipped to [0, 255], cast to ``uint8`` and
    histogram-matched to the target image.

    NOTE(review): ``Image`` (PIL), ``A`` (albumentations) and ``readIm`` are not
    imported/defined anywhere in this notebook; define them before calling.

    Args:
        k (int): Number of leading singular values kept from the source image.
        target_path (str or Path): Image whose style is transferred.
        src_path (str or Path): Image whose content is kept.
        histo_matching_degree (float): Blend ratio for ``A.HistogramMatching``.
        image_size (int, optional): Side length after resizing; falls back to the
            global ``IMAGE_SIZE`` when omitted.

    Returns:
        tuple: ``(resized_src, resized_target, content_src, target_style, svdna_im,
        noise_adapted_im_clipped)``. ``content_src`` keeps a leading axis of
        length 1; ``target_style`` is squeezed.
    """
    size = IMAGE_SIZE if image_size is None else image_size

    target_img = Image.open(target_path).convert("L")
    src_img = Image.open(src_path).convert("L")

    resized_target = np.array(target_img.resize((size, size), Image.NEAREST))
    resized_src = np.array(src_img.resize((size, size), Image.NEAREST))

    # np.linalg.svd returns singular values in descending order.
    u_target, s_target, vh_target = np.linalg.svd(resized_target, full_matrices=False)
    u_source, s_source, vh_source = np.linalg.svd(resized_src, full_matrices=False)

    # Style: the target without its k largest singular values (copy instead of aliasing s_target).
    thresholded_singular_target = s_target.copy()
    thresholded_singular_target[:k] = 0
    # Content: only the source's k largest singular values.
    thresholded_singular_source = s_source.copy()
    thresholded_singular_source[k:] = 0

    # The list wrapper adds a leading axis of length 1, kept for the existing return shape.
    target_style = np.array([u_target @ (np.diag(thresholded_singular_target) @ vh_target)])
    content_src = np.array([u_source @ (np.diag(thresholded_singular_source) @ vh_source)])

    noise_adapted_im = content_src + target_style
    noise_adapted_im_clipped = np.squeeze(noise_adapted_im).clip(0, 255).astype(np.uint8)

    transform_hist = A.Compose([
        A.HistogramMatching([target_path], blend_ratio=(histo_matching_degree, histo_matching_degree), read_fn=readIm, p=1)
    ])
    svdna_im = transform_hist(image=noise_adapted_im_clipped)["image"]

    return resized_src, resized_target, content_src, np.squeeze(target_style), svdna_im, noise_adapted_im_clipped


# Define the custom dataset class
class OCTDataset(Dataset):
    """Source-domain OCT images paired with randomly drawn target-domain images.

    The device sub-folders of ``named_domain_folder`` (e.g. Cirrus, Spectralis,
    Topcon) each contain one folder per volume. Those volume names are matched
    against the volume folders in ``data_path``. This splits the images in
    ``data_path/<volume>/image`` into a source domain and the remaining
    target domains.

    Args:
        data_path (str or Path): Folder with one sub-folder per volume, each
            holding an ``image`` folder.
        transform (callable, optional): Called as ``transform(source_img, target_img)``;
            must return the transformed ``(source_img, target_img)`` pair.
        source_domain (str): Device folder used as the source domain.
        named_domain_folder (str or Path): Folder with one sub-folder per device.
    """

    def __init__(self, data_path, transform=None, source_domain='Spectralis',
                 named_domain_folder='/Users/moritz/Documents/Master/OPTIMA_Masterarbeit/data/RETOUCH/TrainingSet-Release/'):
        self.data_path = Path(data_path)
        self.transform = transform
        self.named_domain_folder = Path(named_domain_folder)
        self.source_data_paths, self.target_data_paths = self.filter_source_domain(source_domain)
        # Flat list so __len__/__getitem__ can index the source images directly.
        self._source_images = [path for paths in self.source_data_paths.values() for path in paths]
        # Only target domains that actually contain images can be sampled from.
        self._target_domains = [domain for domain, paths in self.target_data_paths.items() if paths]

    def __len__(self):
        return len(self._source_images)

    def __getitem__(self, index):
        source_img = self.load_source_image(self._source_images[index])
        if not self._target_domains:
            raise RuntimeError(f"No target-domain images found in {self.data_path}")
        # Choose a target domain uniformly, then one of its images at random.
        domain = random.choice(self._target_domains)
        target_img = self.load_target_image(random.choice(self.target_data_paths[domain]))

        if self.transform:
            # Apply transformations dynamically during training
            source_img, target_img = self.transform(source_img, target_img)

        return source_img, target_img

    def filter_source_domain(self, source_domain):
        """Split the images in ``self.data_path`` by acquisition device.

        Args:
            source_domain (str): Device folder name treated as the source domain.

        Returns:
            tuple[dict, dict]: ``({source_domain: [image paths]},
            {other_domain: [image paths], ...})`` with ``Path`` values.

        Raises:
            ValueError: If ``source_domain`` is not a device folder.
        """
        domains = sorted(d.name for d in self.named_domain_folder.iterdir() if d.is_dir())
        if source_domain not in domains:
            raise ValueError(f"Unknown source domain {source_domain!r}; expected one of {domains}")

        source_data_paths = {source_domain: self._image_paths(source_domain)}
        target_data_paths = {domain: self._image_paths(domain) for domain in domains if domain != source_domain}
        return source_data_paths, target_data_paths

    def _image_paths(self, domain):
        """Sorted image paths of all volumes in ``self.data_path`` that belong to ``domain``."""
        volumes = sorted(v.name for v in (self.named_domain_folder / domain).iterdir() if v.is_dir())
        return [
            image
            for volume in volumes
            if (self.data_path / volume / 'image').is_dir()
            for image in sorted((self.data_path / volume / 'image').iterdir())
            if not image.name.startswith('.')  # skip hidden files such as '.DS_Store'
        ]

    def load_source_image(self, source_path):
        """Load a source-domain image as a single-channel (grayscale) array."""
        return self._read_grayscale(source_path)

    def load_target_image(self, target_path):
        """Load a target-domain image as a single-channel (grayscale) array."""
        return self._read_grayscale(target_path)

    @staticmethod
    def _read_grayscale(path):
        image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if image is None:  # cv2.imread signals unreadable files with None, not an exception
            raise FileNotFoundError(f"Could not read image {path}")
        return image

    def generate_black_images(self, delete_images=False):
        """Create (or, with ``delete_images=True``, delete) black placeholder labels in ``self.data_path``.

        Delegates to the module-level ``generate_black_images`` instead of duplicating it.
        """
        generate_black_images(self.data_path, delete_images=delete_images)


    

# Define the randomizable transformation class for dynamic transformations
class RandomTransform(RandomizableTransform):
    """Randomly flip a source/target image pair together along the last axis.

    With probability ``prob`` both images are flipped; otherwise both are
    returned unchanged. Inputs are assumed to be numpy arrays (as produced by
    ``cv2.imread``).

    Args:
        prob (float): Probability of applying the flip.
    """

    def __init__(self, prob=0.5):
        super().__init__(prob=prob)

    def __call__(self, source_img, target_img):
        # RandomizableTransform.randomize requires a `data` argument; it sets self._do_transform.
        self.randomize(None)
        if not self._do_transform:
            return source_img, target_img
        # .copy() removes the negative strides of np.flip, which torch.from_numpy rejects.
        transformed_source_img = np.flip(source_img, axis=-1).copy()
        transformed_target_img = np.flip(target_img, axis=-1).copy()
        return transformed_source_img, transformed_target_img

# Define the UNet model
class SegmentationUNet(nn.Module):
    """2D MONAI ``UNet`` wrapped in a plain ``nn.Module``.

    Args:
        in_channels (int): Number of input channels (1 for grayscale input).
        out_channels (int): Number of output classes. NOTE(review): the default of 4
            assumes background plus three lesion classes — confirm against the labels.
        channels (tuple[int, ...]): Feature channels per UNet level.
        strides (tuple[int, ...]): Down-sampling strides between levels
            (one fewer than ``channels``).
    """

    def __init__(self, in_channels=1, out_channels=4, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2)):
        super().__init__()
        self.unet = UNet(
            spatial_dims=2,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=channels,
            strides=strides,
        )

    def forward(self, x):
        """Return the UNet output for ``x`` (expected shape: batch, channels, H, W)."""
        return self.unet(x)

# Network, loss and optimiser for the segmentation task.
model = SegmentationUNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# OCTDataset takes the preprocessed training folder (its `data_path`) and calls
# transform(source_img, target_img) with two arguments.
# NOTE(review): monai's Compose passes a single argument to each transform, so RandomTransform,
# svdna (needs k and two file paths) and ToTensord (needs `keys`) cannot be chained that way.
# TODO: apply SVDNA inside OCTDataset.__getitem__.
transform = RandomTransform()

dataset = OCTDataset(train_dir, transform=transform)
# num_workers=0: worker processes started with `spawn` (the default on macOS/Windows) cannot
# import Dataset classes defined inside a notebook.
# TODO: default collation only stacks images of equal size — resize before batching if needed.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# Training loop
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    # tqdm progress bar instead of printing one line per batch.
    progress = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')
    for batch_idx, (source_img, target_img) in enumerate(progress):
        optimizer.zero_grad()

        # Forward pass
        output = model(source_img)

        # NOTE(review): OCTDataset yields (source-domain image, target-domain image). Using the second
        # element as the CrossEntropyLoss target assumes it is a class-index mask — confirm.
        loss = criterion(output, target_img)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() call synchronises with the device).
        loss_value = loss.item()
        wandb.log({"Loss": loss_value, "Epoch": epoch, "Batch": batch_idx})
        progress.set_postfix(loss=loss_value)

# Persist only the learned weights (the state_dict), not the whole module object.
checkpoint_path = 'segmentation_unet.pth'
torch.save(model.state_dict(), checkpoint_path)

# Close the wandb run started at the top of this cell.
wandb.finish()


['80293194ec21abbfd3da95f83cdbd5ab',
 '93ae688615be155f7207e4c1f45eb725',
 'd40ef678ef3dff06eb3aa3a4f4ae99e6',
 'a953b47afe290923d91a08610799ea42',
 '6a96f7c246314e625bacfb78d18ad09e',
 '70f44050a581f94509fd63bcbd797971',
 '92a0bdb6013c8d0b4770e4e907d3d8cf',
 '8e78a8e83df1029abbd2479ff1d5f92f',
 'a6be6cdf56b8354d72483606af62b8bb',
 '6af50661914d9c03417adbe6eb91ebae',
 '03a60d9078d35b1488e6030880a29014',
 '644b104a47cf811b3828eb4655284991',
 '091e68f597bb8e45ce9478363ef686b3',
 '42533245b1c2cf49d9d5c554dbdee5a0',
 '1e0e71d2acdc57f10ab6712ab87b2ef7',
 '5fef2f4c2adcda3b9e07801f713d4ccf',
 '6107c818602ec2abf340a89b8e225b12',
 '8248fbc73c9381d0bf3e909c7d732f84',
 '8482e2f4d2aae33a5eac5178801df9fb',
 '7291fbbae825c6a9230b8787ee7645d2',
 'b35010885c2127df56140ffbfc3db3e3',
 '4d1f2722e3f9c55b689621d1228526db',
 '3c68f67cd2e2b41afa54bf6059f509d1',
 'e6747bd39c2fc8a907a7193731724eab']