
Commit

Move normalization to load sample to save memory
danifranco committed Oct 6, 2023
1 parent 4017a05 commit 27120f5
Showing 2 changed files with 111 additions and 221 deletions.
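As a rough sketch of the memory-saving idea behind this commit (illustrative only; the class, file handling and helper names below are assumptions, not BiaPy's actual API): instead of building a fully normalized copy of X in __init__, the generator keeps only the normalization statistics there and applies them to each sample inside load_sample.

import numpy as np

class LazyNormGenerator:
    """Toy generator: statistics are computed once, samples are normalized lazily."""

    def __init__(self, paths, mean=None, std=None):
        self.paths = paths
        self.length = len(paths)
        # Only the statistics live in memory, never a normalized copy of the data.
        if mean is not None and std is not None:
            self.norm = {'type': 'custom', 'mean': mean, 'std': std}
        else:
            self.norm = {'type': 'div'}  # plain 0-1 scaling

    def load_sample(self, idx):
        img = np.load(self.paths[idx % self.length]).astype(np.float32)  # hypothetical .npy inputs
        if self.norm['type'] == 'custom':
            img = (img - self.norm['mean']) / self.norm['std']
        else:
            img = img / 255.0 if img.max() > 1 else img
        return img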
240 changes: 75 additions & 165 deletions data/generators/pair_base_data_generator.py
@@ -451,189 +451,96 @@ def __init__(self, ndim, X, Y, seed=0, in_memory=True, data_paths=None, da=True,
raise ValueError("Different number of raw and ground truth images ({} vs {}). "
"Please check the data!".format(len(self.data_paths), len(self.data_mask_path)))
self.length = len(self.data_paths)

self.first_no_bin_channel = -1
self.div_Y_on_load_bin_channels = False
self.div_Y_on_load_no_bin_channels = False

# X data analysis
self.X_norm = {}
if norm_custom_mean is not None and norm_custom_std is not None:
nsamples = len(self.data_paths)
else:
self.X = X
if self.Y_provided:
self.Y = Y
self.length = len(self.X)

self.first_no_bin_channel = -1
self.div_Y_on_load_bin_channels = False
self.div_Y_on_load_no_bin_channels = False

# X data analysis
self.X_norm = {}
self.X_norm['type'] = 'not_set_yet'
if norm_custom_mean is not None and norm_custom_std is not None:
if not in_memory:
sam = []
for i in range(len(self.data_paths)):
img, _ = self.load_sample(i)
sam.append(img)
if shape[-1] != img.shape[-1]:
raise ValueError("Channel of the DATA.PATCH_SIZE given {} does not correspond with the loaded image {}. "
"Please, check the channels of the images!".format(shape[-1], img.shape[-1]))
"Please, check the channels of the images!".format(shape[-1], img.shape[-1]))
if not random_crops_in_DA and shape != img.shape:
raise ValueError("Image shape {} does not match provided DATA.PATCH_SIZE {}. If you want to ensure "
"that PATCH_SIZE you have two options: 1) Set IN_MEMORY = True (as the images will be cropped "
"automatically to that DATA.PATCH_SIZE) ; 2) Set DATA.EXTRACT_RANDOM_PATCH = True to extract a patch "
"(if possible) from loaded image".format(img.shape, shape))

sam = np.array(sam)
self.X_norm['type'] = 'custom'
self.X_norm['mean'] = np.mean(sam)
self.X_norm['std'] = np.std(sam)
self.X_norm['orig_dtype'] = img.dtype
del sam
else:
self.X_norm['type'] = 'div'
img, _ = self.load_sample(0)
img, nsteps = norm_range01(img)
self.X_norm.update(nsteps)
if shape[-1] != img.shape[-1]:
raise ValueError("Channel of the patch size given {} does not correspond with the loaded image {}. "
"Please, check the channels of the images!".format(shape[-1], img.shape[-1]))
if not random_crops_in_DA and shape != img.shape:
raise ValueError("Image shape {} does not match provided DATA.PATCH_SIZE {}. If you want to ensure "
else:
self.X_norm['mean'] = np.mean(self.X)
self.X_norm['std'] = np.std(self.X)
self.X_norm['orig_dtype'] = self.X.dtype
self.X_norm['type'] = 'custom'
else:
img, _ = self.load_sample(0)
img, nsteps = norm_range01(img)
self.X_norm.update(nsteps)
if shape[-1] != img.shape[-1]:
raise ValueError("Channel of the patch size given {} does not correspond with the loaded image {}. "
"Please, check the channels of the images!".format(shape[-1], img.shape[-1]))
if not random_crops_in_DA and shape != img.shape:
raise ValueError("Image shape {} does not match provided DATA.PATCH_SIZE {}. If you want to ensure "
"that PATCH_SIZE you have two options: 1) Set IN_MEMORY = True (as the images will be cropped "
"automatically to that DATA.PATCH_SIZE) ; 2) Set DATA.EXTRACT_RANDOM_PATCH = True to extract a patch "
"(if possible) from loaded image".format(img.shape, shape))

self.X_channels = img.shape[-1]
self.Y_channels = img.shape[-1]
self.shape = shape if random_crops_in_DA else img.shape
del img
self.X_norm['type'] = 'div'

# Y data analysis
if self.Y_provided:
found = False
# Loop over a few masks to ensure foreground class is present to decide normalization
for i in range(min(10,len(self.data_mask_path))):
_, mask = self.load_sample(i)
if self.normalizeY == 'as_mask':
# Store whether all channels of the gt are binary or not (i.e. distance transform channel)
if not found and (mask.dtype is np.dtype(np.float32) or mask.dtype is np.dtype(np.float64)) and instance_problem:
for j in range(mask.shape[-1]):
if len(np.unique(mask[...,j])) > 2:
self.first_no_bin_channel = j
found = True
break

# If found high values divide masks
if self.first_no_bin_channel != -1:
if self.first_no_bin_channel != 0:
if np.max(mask[...,:self.first_no_bin_channel]) > 30: self.div_Y_on_load_bin_channels = True
if np.max(mask[...,self.first_no_bin_channel:]) > 30: self.div_Y_on_load_no_bin_channels = True
else:
if np.max(mask) > 30: self.div_Y_on_load_bin_channels = True
if np.max(mask) > 30: self.div_Y_on_load_no_bin_channels = True
else:
if np.max(mask) > 30: self.div_Y_on_load_bin_channels = True

self.Y_channels = mask.shape[-1]
self.Y_dtype = mask.dtype
del mask
else:
self.X = X
if self.Y_provided:
self.Y = Y
self.Y_channels = Y.shape[-1] if type(Y) != list else Y[0].shape[-1]
else:
self.Y_channels = X.shape[-1] if type(X) != list else X[0].shape[-1]
self.X_channels = X.shape[-1] if type(X) != list else X[0].shape[-1]
self.length = len(self.X)
if random_crops_in_DA:
self.shape = shape
else:
self.shape = X.shape[1:] if type(X) != list else X[0].shape[1:]

# X data analysis and normalization
self.X_norm = {}
if norm_custom_mean is not None and norm_custom_std is not None:
self.X_norm['type'] = 'custom'
self.X_norm['mean'] = np.mean(self.X)
self.X_norm['std'] = np.std(self.X)
self.X_channels = img.shape[-1]
self.shape = shape if random_crops_in_DA else img.shape
del img

self.X = normalize(self.X, self.X_norm['mean'], self.X_norm['std'])
else:
self.X_norm['type'] = 'div'
if type(X) != list:
self.X, normx = norm_range01(self.X)
else:
self.X[0], normx = norm_range01(self.X[0])
for i in range(1,len(self.X)):
self.X[i], _ = norm_range01(self.X[i])
self.X_norm.update(normx)

# Y data analysis
self.first_no_bin_channel = -1
if self.Y_provided:
self.div_Y_on_load_bin_channels = False
self.div_Y_on_load_no_bin_channels = False
# Y data analysis
if self.Y_provided:
found = False
# Loop over a few masks to ensure foreground class is present to decide normalization
n_samples = len(self.data_mask_path) if not in_memory else len(self.Y)
for i in range(n_samples):
_, mask = self.load_sample(i)
if self.normalizeY == 'as_mask':
if (_Y.dtype is np.dtype(np.float32) or _Y.dtype is np.dtype(np.float64)) and instance_problem:
for i in range(_Y.shape[-1]):
if len(np.unique(_Y[...,i])) > 2:
self.first_no_bin_channel = i
# Store whether all channels of the gt are binary or not (i.e. distance transform channel)
if not found and (mask.dtype is np.dtype(np.float32) or mask.dtype is np.dtype(np.float64)) and instance_problem:
for j in range(mask.shape[-1]):
if len(np.unique(mask[...,j])) > 2:
self.first_no_bin_channel = j
found = True
break

# If found high values divide masks
if self.first_no_bin_channel != -1:
if self.first_no_bin_channel != 0:
self.div_Y_on_load_bin_channels = True if np.max(_Y[...,:self.first_no_bin_channel]) > 30 else False
self.div_Y_on_load_no_bin_channels = True if np.max(_Y[...,self.first_no_bin_channel:]) > 30 else False
else:
self.div_Y_on_load_bin_channels = False
self.div_Y_on_load_no_bin_channels = True if np.max(_Y) > 30 else False
else:
self.div_Y_on_load_bin_channels = True if np.max(_Y) > 30 else False

# Y normalization
if type(Y) != list:
if self.first_no_bin_channel != -1:
if self.div_Y_on_load_bin_channels:
self.Y[...,:self.first_no_bin_channel] = self.Y[...,:self.first_no_bin_channel]/255
if self.div_Y_on_load_no_bin_channels:
if self.first_no_bin_channel != 0:
self.Y[...,self.first_no_bin_channel:] = self.Y[...,self.first_no_bin_channel:]/255
else:
self.Y = self.Y/255
if np.max(mask[...,:self.first_no_bin_channel]) > 30: self.div_Y_on_load_bin_channels = True
if np.max(mask[...,self.first_no_bin_channel:]) > 30: self.div_Y_on_load_no_bin_channels = True
else:
if self.div_Y_on_load_bin_channels: self.Y = self.Y/255
if np.max(mask) > 30: self.div_Y_on_load_bin_channels = True
if np.max(mask) > 30: self.div_Y_on_load_no_bin_channels = True
else:
for i in range(len(self.Y)):
if self.first_no_bin_channel != -1:
if self.div_Y_on_load_bin_channels:
self.Y[i][...,:self.first_no_bin_channel] = self.Y[i][...,:self.first_no_bin_channel]/255
if self.div_Y_on_load_no_bin_channels:
if self.first_no_bin_channel != 0:
self.Y[i][...,self.first_no_bin_channel:] = self.Y[i][...,self.first_no_bin_channel:]/255
else:
self.Y[i] = self.Y[i]/255
else:
if self.div_Y_on_load_bin_channels: self.Y[i] = self.Y[i]/255
self.Y_dtype = self.Y.dtype if type(self.Y) != list else self.Y[0].dtype
elif self.normalizeY == 'as_image':
self.Y_dtype = np.float32
if self.X_norm['type'] == 'div':
if type(X) != list:
self.Y, _ = norm_range01(self.Y)
else:
for i in range(len(self.Y)):
self.Y[i], _ = norm_range01(self.Y[i])
elif self.X_norm['type'] == 'custom':
self.Y = normalize(self.Y, self.X_norm['mean'], self.X_norm['std'])
else:
self.Y_dtype = self.Y.dtype if type(self.Y) != list else self.Y[0].dtype

t = "Training" if not val else "Validation"
if type(X) != list:
print("{} data X normalization - min: {} , max: {} , mean: {} , dtype: {}"
.format(t,np.min(self.X), np.max(self.X), np.mean(self.X), self.X.dtype))
if self.Y_provided:
print("{} data Y normalization - min: {} , max: {} , mean: {} , dtype: {}"
.format(t,np.min(self.Y), np.max(self.Y), np.mean(self.Y), self.Y.dtype))
else:
print("{} data[0] X normalization - min: {} , max: {} , mean: {} , dtype: {}"
.format(t,np.min(self.X[0]), np.max(self.X[0]), np.mean(self.X[0]), self.X[0].dtype))
if self.Y_provided:
print("{} data[0] Y normalization - min: {} , max: {} , mean: {} , dtype: {}"
.format(t,np.min(self.Y[0]), np.max(self.Y[0]), np.mean(self.Y[0]), self.Y[0].dtype))
print("Normalization config used for X: {}".format(self.X_norm))
if self.Y_provided:
print("Normalization config used for Y: {}".format(self.normalizeY))
if np.max(mask) > 30: self.div_Y_on_load_bin_channels = True
if not in_memory:
self.Y_channels = mask.shape[-1]
self.Y_dtype = mask.dtype
del mask

print("Normalization config used for X: {}".format(self.X_norm))
if self.Y_provided:
print("Normalization config used for Y: {}".format(self.normalizeY))

if self.ndim == 2:
resolution = tuple(resolution[i] for i in [1, 0]) # y, x -> x, y
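The hunk above also moves the ground-truth analysis onto loaded masks: a few masks are inspected to find a non-binary channel (e.g. a distance transform) and to decide whether values look 0-255 encoded (maximum above 30) and therefore need dividing by 255. A hedged, standalone restatement of that heuristic (the 30 threshold and channel logic mirror the diff; the function itself is illustrative, not part of the repository):

import numpy as np

def mask_division_flags(mask, instance_problem=True):
    """Return (first_no_bin_channel, div_bin, div_no_bin) following the diff's heuristic."""
    first_no_bin = -1
    # Float masks in an instance problem may carry a non-binary channel
    # (e.g. a distance transform): the first channel with more than 2 unique values.
    if mask.dtype in (np.float32, np.float64) and instance_problem:
        for j in range(mask.shape[-1]):
            if len(np.unique(mask[..., j])) > 2:
                first_no_bin = j
                break

    div_bin = div_no_bin = False
    if first_no_bin != -1:
        if first_no_bin != 0:
            # Binary channels sit before the non-binary ones.
            div_bin = bool(np.max(mask[..., :first_no_bin]) > 30)
            div_no_bin = bool(np.max(mask[..., first_no_bin:]) > 30)
        else:
            div_bin = div_no_bin = bool(np.max(mask) > 30)
    else:
        div_bin = bool(np.max(mask) > 30)
    return first_no_bin, div_bin, div_no_bin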
@@ -732,6 +639,7 @@ def __init__(self, ndim, X, Y, seed=0, in_memory=True, data_paths=None, da=True,
if extra_data_factor > 1:
self.extra_data_factor = extra_data_factor
self.o_indexes = np.concatenate([self.o_indexes]*extra_data_factor)
self.length = self.length*extra_data_factor
else:
self.extra_data_factor = 1
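When extra_data_factor oversamples the dataset, the epoch length is multiplied accordingly, and the load_sample change further down wraps any extended index back onto a real sample with a modulo. A small worked example, assuming 40 real samples and a factor of 3:

length = 40
extra_data_factor = 3
epoch_length = length * extra_data_factor   # 120 indices per epoch

# Any _idx in [0, 120) maps back to one of the 40 real samples.
for _idx in (0, 39, 40, 119):
    print(_idx, '->', _idx % length)        # 0->0, 39->39, 40->0, 119->39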

Expand Down Expand Up @@ -828,9 +736,10 @@ def __len__(self):
"""Defines the number of samples per epoch."""
return self.length

def load_sample(self, idx):
def load_sample(self, _idx):
"""Load one data sample given its corresponding index."""
# Choose the data source
idx = _idx % self.length
if self.in_memory:
img = self.X[idx]
img = np.squeeze(img)
@@ -850,15 +759,16 @@ def load_sample(self, idx):
img = np.squeeze(img)
if self.Y_provided:
mask = np.squeeze(mask)

# X normalization
if self.X_norm:
if self.X_norm['type'] == 'div':
img, _ = norm_range01(img)
elif self.X_norm['type'] == 'custom':
img = normalize(img, self.X_norm['mean'], self.X_norm['std'])

# Y normalization

# X normalization
if self.X_norm['type'] != "not_set_yet":
if self.X_norm['type'] == 'div':
img, _ = norm_range01(img)
elif self.X_norm['type'] == 'custom':
img = normalize(img, self.X_norm['mean'], self.X_norm['std'])

# Y normalization
if self.X_norm['type'] != "not_set_yet":
if self.normalizeY == 'as_mask' and self.Y_provided:
if self.first_no_bin_channel != -1:
if self.div_Y_on_load_bin_channels:
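In load_sample above, normalization is skipped while X_norm['type'] is still 'not_set_yet': __init__ can therefore call load_sample on raw data while it is still gathering statistics, and every later call applies the configured normalization. A toy sketch of that two-phase pattern (names and data handling are assumptions):

import numpy as np

class TwoPhaseLoader:
    def __init__(self, arrays):
        self.arrays = arrays
        self.norm = {'type': 'not_set_yet'}   # sentinel: normalization not configured yet
        # First pass: load raw samples (sentinel disables normalization) to gather stats.
        flat = np.concatenate([self.load_sample(i).ravel() for i in range(len(arrays))])
        self.norm = {'type': 'custom', 'mean': float(flat.mean()), 'std': float(flat.std())}

    def load_sample(self, idx):
        img = np.asarray(self.arrays[idx], dtype=np.float32)
        if self.norm['type'] != 'not_set_yet':   # only normalize once stats exist
            img = (img - self.norm['mean']) / self.norm['std']
        return img

loader = TwoPhaseLoader([np.random.rand(8, 8) for _ in range(4)])
sample = loader.load_sample(0)   # normalized on the fly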
