In [12]:
import os
import numpy as np
import pickle
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from monai.transforms import EnsureChannelFirst, LoadImage, Compose, ScaleIntensity
from monai.data import ArrayDataset
from torch.utils.data import DataLoader

import polarTransform

In [13]:
# Shared preprocessing for both images and masks: load the file, move the
# channel axis first, then min-max rescale intensities into [0, 1].
transformer = Compose([LoadImage(image_only=True),
                       EnsureChannelFirst(),
                       ScaleIntensity()])

train_image_path = "data/REFUGE2/Train/Images/"
train_dm_path = "data/REFUGE2/Train/Disc_Masks/"
test_image_path = "data/REFUGE2/Test/Images/"
test_dm_path = "data/REFUGE2/Test/Disc_Masks/"
val_image_path = "data/REFUGE2/Validation/Images/"
val_dm_path = "data/REFUGE2/Validation/Disc_Masks/"


def make_paired_loader(image_dir, mask_dir, transform=transformer,
                       batch_size=1, shuffle=False, num_workers=2):
    """Build an image/mask ``ArrayDataset`` and a ``DataLoader`` over it.

    Images and masks are paired purely by position in their sorted directory
    listings, so both directories must hold files that sort into the same
    order (e.g. matching file stems).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the corresponding masks.
        transform: Transform applied to both images and masks.
        batch_size, shuffle, num_workers: Forwarded to ``DataLoader``.

    Returns:
        ``(dataset, dataloader)``.

    Raises:
        ValueError: if the two directories contain different numbers of files,
            which would otherwise silently mis-pair images and masks.
    """
    image_files = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))
    mask_files = sorted(os.path.join(mask_dir, name) for name in os.listdir(mask_dir))
    if len(image_files) != len(mask_files):
        raise ValueError(
            f"{len(image_files)} files in {image_dir!r} but {len(mask_files)} in {mask_dir!r}; "
            "images and masks are paired by sorted order and must match one-to-one."
        )

    dataset = ArrayDataset(img=image_files,
                           img_transform=transform,
                           seg=mask_files,
                           seg_transform=transform)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            num_workers=num_workers)
    return dataset, dataloader


train_data, train_dataloader = make_paired_loader(train_image_path, train_dm_path)
val_data, val_dataloader = make_paired_loader(val_image_path, val_dm_path)
test_data, test_dataloader = make_paired_loader(test_image_path, test_dm_path)

In [14]:
# Polar-transform every validation image/mask pair around the centroid of the
# mask's foreground pixels and write the results under `new_path`.
new_path = "data_test/REFUGE2/Validation/"
old_path = "data/REFUGE2/Validation/"
# Output file names. The validation dataset was built from the sorted listing of
# this same Images/ directory, so names[j] corresponds to the j-th sample.
names = sorted(os.listdir(old_path + "Images/"))
assert len(names) == len(val_dataloader.dataset), (
    f"{len(names)} files in {old_path}Images/ but "
    f"{len(val_dataloader.dataset)} samples in val_dataloader"
)

os.makedirs(new_path + "Disc_Masks", exist_ok=True)
os.makedirs(new_path + "Images", exist_ok=True)

# Output file name -> polarTransform settings used for that pair.
settings_dict = {}

# Wrap the DataLoader itself (not the enumerate object) so tqdm can read its
# length and display a real total.
for j, batch in enumerate(tqdm(val_dataloader)):
    # Only sample 0 of each batch is used, so this relies on batch_size=1.
    # Undo the [0, 1] rescaling and move channels last for PIL / polarTransform.
    # NOTE(review): MONAI's PIL-based reader reverses spatial axes by default
    # (reverse_indexing=True), so these arrays may be transposed relative to the
    # files on disk -- confirm this orientation is intended.
    image = np.array(255*batch[0][0].permute(1, 2, 0)).astype(np.uint8)
    mask = np.array(255*batch[1][0].permute(1, 2, 0)).astype(np.uint8)

    # (row indices, column indices) of pixels equal to 255 in mask channel 1.
    # NOTE(review): the name suggests channel 1 / value 255 marks the cup -- confirm.
    cup_points = np.where(mask[:, :, 1] == 255)
    if cup_points[0].size == 0:
        # Without this, .mean() yields NaN and round() fails with an unrelated error.
        raise ValueError(
            f"Mask for {names[j]} has no pixels equal to 255 in channel 1; "
            "cannot choose a polar-transform center."
        )
    # Center passed as (column, row), i.e. (x, y).
    center = (round(cup_points[1].mean()), round(cup_points[0].mean()))
    transformed_mask, settings = polarTransform.convertToPolarImage(mask,
                                                                    center=center,
                                                                    hasColor=True,
                                                                    radiusSize=mask.shape[0],
                                                                    angleSize=mask.shape[1])

    transformed_mask = transformed_mask.astype(np.uint8)
    new_mask = Image.fromarray(transformed_mask)
    # NOTE(review): the mask is saved under the *image's* file name, so PIL picks
    # the format from the image extension; if that is .jpg the mask is written as
    # lossy JPEG and label values will be perturbed -- consider a lossless format.
    new_mask.save(new_path + "Disc_Masks/" + names[j])

    # Reuse the mask's settings so image and mask share the same polar grid, and
    # cast to uint8 like the mask in case polarTransform returns another dtype.
    transformed_image = settings.convertToPolarImage(image).astype(np.uint8)
    new_image = Image.fromarray(transformed_image)
    new_image.save(new_path + "Images/" + names[j])

    settings_dict[names[j]] = settings
    

0it [00:00, ?it/s]

In [15]:
# Persist the per-image polarTransform settings next to the transformed data.
# Only unpickle this file later if it comes from a trusted source.
with open(f"{new_path}settings.pickle", mode="wb") as settings_file:
    pickle.dump(settings_dict, settings_file, protocol=pickle.HIGHEST_PROTOCOL)