In [None]:
import SimpleITK as sitk
import vtk
import os
import numpy as np
from vtk.util import numpy_support
from PIL import Image, ImageOps

import tensorflow as tf
import os
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img


In [None]:
def read_raw_image(file_path, rows=256, cols=256):
    """Load a headerless uint16 ``.raw`` volume and rescale it to 8-bit.

    The file is read as a flat sequence of uint16 values (native byte order,
    as ``np.fromfile`` uses) and reshaped to ``(n_slices, rows, cols)``.
    Intensities are then scaled so that the volume maximum maps to 255; the
    cast to uint8 truncates rather than rounds.

    Args:
        file_path: path of the raw file.
        rows, cols: in-plane size of each slice. The defaults reproduce the
            original hard-coded 256x256 layout.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(n_slices, rows, cols)``.
        An all-zero volume is returned as zeros instead of producing NaNs.

    Raises:
        ValueError: if the file is empty or its size is not a whole number of
            ``rows * cols`` slices.
    """
    raw_data = np.fromfile(file_path, dtype=np.uint16)
    slice_size = rows * cols
    # Use integer arithmetic plus an explicit check instead of float division
    # followed by int(), which hid partial slices behind a reshape error.
    if raw_data.size == 0 or raw_data.size % slice_size:
        raise ValueError(
            f"{file_path}: {raw_data.size} uint16 values is not a whole number "
            f"of {rows}x{cols} slices"
        )
    image = raw_data.reshape((raw_data.size // slice_size, rows, cols))

    peak = image.max()
    if peak == 0:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        return np.zeros(image.shape, dtype=np.uint8)
    normalized_image = (image / peak) * 255
    return normalized_image.astype(np.uint8)

def read_vtk(file_path):
    """Read a ``.vtk`` file with ``vtk.vtkDataSetReader``.

    Args:
        file_path: path of the VTK file.

    Returns:
        ``(dataset, dims)``: the reader's output object and the value of its
        ``GetDimensions()``. NOTE(review): this assumes the file holds a
        structured, image-like dataset exposing ``GetDimensions`` — confirm for
        new inputs.
    """
    vtk_reader = vtk.vtkDataSetReader()
    vtk_reader.SetFileName(file_path)
    vtk_reader.Update()  # actually parse the file before asking for output

    dataset = vtk_reader.GetOutput()
    return dataset, dataset.GetDimensions()

def vtk_to_numpy(poly_data, dims):
    """Turn the point scalars of a VTK dataset into a list of 2-D slices.

    The scalars are reshaped to ``(dims[2], dims[1], dims[0])``, which gives a
    3-D volume with access to each slice. Slices are then taken along axis 1;
    the original author describes these as coronal slices.

    Args:
        poly_data: VTK dataset whose point data carries scalars (for example
            the first value returned by ``read_vtk``).
        dims: the dataset's ``GetDimensions()`` tuple.

    Returns:
        A list of ``dims[1]`` arrays. Entry ``i`` is ``volume[:, i, :] * 255``,
        so a label pixel of value 1 is drawn white when saved as an image.
    """
    scalars = numpy_support.vtk_to_numpy(poly_data.GetPointData().GetScalars())
    volume = scalars.reshape((dims[2], dims[1], dims[0]))
    return [np.array(volume[:, slice_idx, :] * 255) for slice_idx in range(dims[1])]


In [None]:
def save_slices_as_pngs(image, segmentation, dims, file_index, output_folder):
    """Save every slice that has a non-empty mask as an image/mask PNG pair.

    Args:
        image: 3-D uint8 volume (e.g. from ``read_raw_image``). Slice ``i`` is
            taken as ``image[:, i, :]``, which the original author describes
            as a coronal slice.
        segmentation: VTK dataset from ``read_vtk``; it is converted with
            ``vtk_to_numpy``.
        dims: dimensions of ``segmentation`` as returned by ``read_vtk``.
        file_index: offset added to the slice index to build the file name
            ``ID_<file_index + i, zero-padded to 6 digits>.png``.
        output_folder: path prefix. Images are written to
            ``<prefix>_img/images`` and masks to ``<prefix>_mask/mask``.
    """
    image_dir = output_folder + "_img/images"
    mask_dir = output_folder + "_mask/mask"
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    segmentation_array = vtk_to_numpy(segmentation, dims)

    # Only visit slices present in both volumes. Iterating over image.shape[1]
    # alone raised IndexError when the segmentation had fewer slices.
    n_slices = min(image.shape[1], len(segmentation_array))
    for i in range(n_slices):
        segmented_mask = np.asarray(segmentation_array[i])
        if not segmented_mask.any():
            continue  # no labelled pixel in this slice

        file_name = f'ID_{str(file_index + i).zfill(6)}.png'

        image_slice = Image.fromarray(image[:, i, :]).resize([256, 256])
        # Histogram equalisation; remove this line to keep raw intensities.
        image_slice = ImageOps.equalize(image_slice)
        image_slice.save(os.path.join(image_dir, file_name))

        # NEAREST keeps the mask binary. Pillow's default resampling for
        # resize() interpolates and creates grey levels at label borders.
        image_mask = Image.fromarray(segmented_mask.astype(np.uint8)).resize(
            [256, 256], Image.NEAREST
        )
        image_mask.save(os.path.join(mask_dir, file_name))

In [None]:
def image_process(rep_path, output_folder):
    """Export image/mask PNG pairs for every case/timepoint folder under ``rep_path``.

    The traversal expects ``rep_path/<case>/<timepoint>/`` folders. Each
    timepoint folder should hold a ``.raw`` volume and a ``.vtk`` segmentation.

    Every complete pair is passed to ``save_slices_as_pngs`` with a file-index
    offset of ``256 * k``, where ``k`` is the number of pairs already exported.
    This keeps file names from different volumes distinct, assuming each volume
    yields at most 256 slices along axis 1. Timepoint folders lacking either
    file are skipped with a warning.

    Side effect: case folders whose names contain spaces are renamed in place,
    with spaces replaced by underscores.

    Args:
        rep_path: root folder containing one sub-folder per case.
        output_folder: path prefix forwarded to ``save_slices_as_pngs``.
    """
    import warnings

    pair_count = 0
    # sorted() makes the case order, and therefore the output file indices,
    # reproducible; os.listdir order is arbitrary.
    for case_name in sorted(os.listdir(rep_path)):
        case = os.path.join(rep_path, case_name)
        if not os.path.isdir(case):
            continue

        # Rename only the case folder itself. Replacing spaces in the whole
        # path broke the rename whenever rep_path contained a space.
        if " " in case_name:
            renamed = os.path.join(rep_path, case_name.replace(" ", "_"))
            os.rename(case, renamed)
            case = renamed

        for t_name in sorted(os.listdir(case)):
            t_dir = os.path.join(case, t_name)
            if not os.path.isdir(t_dir):
                continue  # stray file next to the timepoint folders

            # Reset per folder so a volume is never paired with a segmentation
            # left over from a different timepoint.
            image = segmentation = dims = None
            for file_name in sorted(os.listdir(t_dir)):
                link = os.path.join(t_dir, file_name)
                if file_name.endswith(".raw"):
                    image = read_raw_image(link)
                elif file_name.endswith(".vtk"):
                    segmentation, dims = read_vtk(link)

            if image is None or segmentation is None:
                warnings.warn(f"Skipping {t_dir}: missing .raw or .vtk file")
                continue
            save_slices_as_pngs(image, segmentation, dims, pair_count * 256, output_folder)
            pair_count += 1
                    

In [None]:
# Input tree (one folder per case, one sub-folder per timepoint) and the output
# prefix: "data" produces ./data_img/images and ./data_mask/mask.
rep_path, output_folder = "./Segmentation_diaphragm/", "data"

image_process(rep_path=rep_path, output_folder=output_folder)

In [None]:


# Offline data augmentation. The same random geometric transform is applied to
# each image and its mask (same seed and shuffle order for both generators),
# and the results are written next to the originals as "<k>_aug.png".
# NOTE(review): re-running this cell will also pick up previously written
# *_aug.png files as inputs.
images_directory = './data_img/'
masks_directory = './data_mask/'

# Geometric transforms only. featurewise_center / featurewise_std_normalization
# were dropped: they require ImageDataGenerator.fit(), which was never called,
# so Keras ignored them with a warning. They are also meaningless for masks and
# for PNGs written back to disk.
data_gen_args = dict(
    rotation_range=90,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    shear_range=0.2
)

# Create the image and mask generators.
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

seed = 1
# The exported PNGs are single-channel. Load them as grayscale instead of the
# default RGB so the augmented files keep the same format as the originals.
image_generator = image_datagen.flow_from_directory(
    images_directory,
    class_mode=None,
    color_mode="grayscale",
    seed=seed
)

mask_generator = mask_datagen.flow_from_directory(
    masks_directory,
    class_mode=None,
    color_mode="grayscale",
    seed=seed
)

train_generator = zip(image_generator, mask_generator)

num_iterations = 150

# Running counter for output names. Deriving the name from i * batch_size + j
# overwrote earlier files whenever the short last batch of an epoch appeared.
saved_count = 0

# Loop over several iterations to obtain different augmented images and masks.
for i in range(num_iterations):
    augmented_images, augmented_masks = next(train_generator)

    for j in range(augmented_images.shape[0]):
        # scale=False keeps the input intensity range instead of applying a
        # per-image min-max stretch.
        img = array_to_img(np.clip(augmented_images[j], 0, 255), scale=False)
        # Interpolation during rotation/zoom/shear blurs mask borders, so
        # re-binarise. Assumes masks are stored as 0/255.
        binary_mask = np.where(augmented_masks[j] > 127, 255.0, 0.0)
        mask = array_to_img(binary_mask, scale=False)

        img.save(os.path.join(images_directory + "images/", f"{saved_count}_aug.png"))
        mask.save(os.path.join(masks_directory + "mask/", f"{saved_count}_aug.png"))
        saved_count += 1

    