-
Notifications
You must be signed in to change notification settings - Fork 0
/
core.py
249 lines (200 loc) · 8.71 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""Some utility functions for working with PyTorch"""
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.
# %% auto 0
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['set_seed', 'pil_to_tensor', 'tensor_to_pil', 'iterate_modules', 'tensor_stats_df', 'get_torch_device',
'denorm_img_tensor', 'move_data_to_device', 'ImageDataset', 'compute_mean_std']
# %% ../nbs/00_core.ipynb 3
# Import necessary modules from the standard library
from pathlib import Path # For working with file paths
import logging # For logging messages
import hashlib
import random
# Disable logging warnings.
# NOTE(review): logging.disable() is process-wide. Importing this module suppresses
# WARNING-and-lower-severity records from *every* logger in the program, not only this
# module's. Confirm this global side effect is intended.
logging.disable(logging.WARNING)
import numpy as np # For working with arrays
from PIL import Image # For working with images
import torch # PyTorch module for deep learning
from torchvision import transforms # PyTorch module for image transformations
# %% ../nbs/00_core.ipynb 5
def set_seed(seed: int, # The seed value to be set for all random number generators.
             deterministic: bool = False # If True, uses deterministic algorithms in PyTorch where possible for reproducibility, at the cost of performance.
            ) -> None:
    """
    Sets the seed for generating random numbers in PyTorch, NumPy, and Python's random module.

    This function is used for reproducibility in stochastic operations, e.g. shuffling in data loaders,
    random initializations in neural networks, etc.

    `torch.use_deterministic_algorithms` is always called with `deterministic`, so passing False
    also turns deterministic mode off if it was previously enabled. When `deterministic` is True,
    cuDNN benchmarking is additionally disabled. Benchmarking selects convolution algorithms by
    timing them at run time, so the chosen algorithm (and therefore the result) can differ between runs.

    Note: The deterministic flag does not guarantee complete reproducibility. Operations which rely on CUDA might
    still produce non-deterministic results.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(deterministic)
    if deterministic:
        # Only switch benchmarking off; never force it on, so the default (False) is respected
        # when deterministic mode is not requested.
        torch.backends.cudnn.benchmark = False
# %% ../nbs/00_core.ipynb 8
def pil_to_tensor(img:Image.Image, # The input PIL image.
                  mean=(0.485, 0.456, 0.406), # The per-channel mean values for normalization.
                  std=(0.229, 0.224, 0.225) # The per-channel standard deviation values for normalization.
                 ):
    """
    Converts a PIL image to a normalized and batched PyTorch tensor.

    The image is converted with `transforms.ToTensor()`, normalized per channel with
    `transforms.Normalize(mean, std)`, and given a leading batch dimension of size 1.

    Returns:
        The normalized and batched tensor.
    """
    # Tuple defaults (instead of lists) avoid sharing one mutable object across calls;
    # Normalize accepts any sequence, so callers passing lists are unaffected.
    tensor = transforms.ToTensor()(img)
    normalized = transforms.Normalize(mean, std)(tensor)
    return normalized[None]  # add the batch dimension
# %% ../nbs/00_core.ipynb 16
def tensor_to_pil(tensor: torch.Tensor # The tensor to be converted
                 ):
    """
    Convert a tensor to a PIL image.

    A 4-D tensor is treated as batched, and its first dimension is dropped with `squeeze(0)`.
    This only has an effect when that dimension has size 1. The caller's tensor is never modified.

    Returns:
        img (PIL.Image): The PIL image
    """
    # Out-of-place squeeze: the previous in-place `squeeze_(0)` silently removed the batch
    # dimension from the caller's tensor as a side effect.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    # Use the ToPILImage() function from the transforms module to convert the tensor to a PIL image
    return transforms.ToPILImage()(tensor)
# %% ../nbs/00_core.ipynb 20
def iterate_modules(module: torch.nn.Module): # A PyTorch module that contains child modules to be iterated over.
    """
    Lazily walk every descendant of `module` (children, grandchildren, and so on).

    Submodules are produced depth-first in pre-order: each one is yielded before any of its own
    children. `module` itself is not yielded. An explicit stack of child iterators replaces recursion.
    """
    exhausted = object()  # sentinel returned by next() once a level has no more children
    levels = [iter(module.children())]
    while levels:
        current = next(levels[-1], exhausted)
        if current is exhausted:
            levels.pop()
            continue
        yield current
        # Descend into the submodule just yielded before continuing with its siblings.
        levels.append(iter(current.children()))
# %% ../nbs/00_core.ipynb 23
import pandas as pd
# %% ../nbs/00_core.ipynb 24
def tensor_stats_df(tensor, # Input tensor for which statistics are to be calculated.
                    attrs=("mean", "std", "min", "max"), # Names of zero-argument tensor methods to evaluate.
                    shape=True): # If True, include shape of the tensor in the output.
    """
    Calculate and return statistics of a given tensor as a pandas DataFrame.

    Each name in `attrs` is called as a method on `tensor` (e.g. `tensor.mean()`), and the result
    is converted to a Python scalar with `.item()`. Each method must therefore return a
    single-element tensor. The result is a one-column DataFrame indexed by statistic name, with
    a final "shape" row when `shape` is True.
    """
    # Tuple default instead of a list: avoids the shared-mutable-default pitfall; any iterable works.
    stats = {name: getattr(tensor, name)().item() for name in attrs}
    if shape:
        stats["shape"] = tensor.shape
    return pd.DataFrame.from_dict(stats, orient='index')
# %% ../nbs/00_core.ipynb 27
def get_torch_device():
    """
    Pick the device string to use for PyTorch computations.

    Backends are checked in priority order: Metal Performance Shaders (MPS, macOS) first,
    then CUDA, falling back to the CPU when neither is available.

    Returns:
        str: "mps", "cuda", or "cpu".
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# %% ../nbs/00_core.ipynb 30
def denorm_img_tensor(img_tensor:torch.Tensor, # The tensor representing the normalized image.
                      mean:list, # The mean values used for normalization.
                      std:list): # The standard deviation values used for normalization.
    """
    Denormalize an image tensor, i.e. compute `img_tensor * std + mean` per channel.

    `mean` and `std` are reshaped to (C, 1, 1). They therefore broadcast along the channel
    dimension of a (C, H, W) tensor or the second dimension of a (B, C, H, W) tensor.

    Returns:
        torch.Tensor: The tensor representing the denormalized image.
    """
    # Create the statistics on the image's device, so CUDA/MPS tensors no longer hit a device
    # mismatch. Keep the default float dtype (what `torch.Tensor(...)` produced) so the
    # result's dtype promotion is unchanged.
    stat_dtype = torch.get_default_dtype()
    mean_tensor = torch.as_tensor(mean, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(std, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)
    # Denormalize the image tensor
    return img_tensor * std_tensor + mean_tensor
# %% ../nbs/00_core.ipynb 34
def move_data_to_device(data, # Data to move to the device.
                        device:torch.device # The PyTorch device to move the data to.
                       ): # Moved data with the same structure as the input but residing on the specified device.
    """
    Recursively move data to the specified device.

    Tensors are moved with `Tensor.to(device)`. Tuples, lists and dictionaries are rebuilt with
    every element (or dict value) moved recursively. Named tuples keep their own type, so field
    access still works. Any other object is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, tuple):
        moved = [move_data_to_device(d, device) for d in data]
        # Named tuples take their fields positionally; rebuild them with the same type instead of
        # collapsing them into a plain tuple. Other tuple subclasses become tuple, as before.
        if hasattr(data, "_fields"):
            return type(data)(*moved)
        return tuple(moved)
    if isinstance(data, list):
        return [move_data_to_device(d, device) for d in data]
    if isinstance(data, dict):
        return {k: move_data_to_device(v, device) for k, v in data.items()}
    # Not a tensor or container: leave it untouched.
    return data
# %% ../nbs/00_core.ipynb 38
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from typing import List
# %% ../nbs/00_core.ipynb 39
class ImageDataset(Dataset):
    """
    A PyTorch Dataset for RGB images.

    Args:
        image_paths: Indexable collection of image file paths; `__len__` is its length.
        transform: Callable applied to each RGB PIL image. When not provided (or falsy),
            `transforms.ToTensor()` is used.
    """
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released deterministically,
        # even if conversion fails. convert() returns a new in-memory image, so it stays
        # valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            rgb_img = img.convert('RGB')
        # Apply transformation
        return self.transform(rgb_img)
# %% ../nbs/00_core.ipynb 40
def compute_mean_std(image_paths:List[Path], # List of image file paths.
                     batch_size:int=32, # Number of images to process in a batch.
                     num_workers:int=0, # Number of subprocesses to use for data loading.
                     image_size:int=224, # Size to resize images to (used by the default transform only).
                     transform:transforms.Compose=None # Torchvision transforms to apply to the images.
                    )->dict: # Dictionary containing 'mean' and 'std' values.
    """
    Computes the per-channel mean and standard deviation of the images in `image_paths`.

    Images are loaded as RGB through `ImageDataset` and passed through `transform`. When no
    transform is given, they are resized to `image_size` x `image_size` and converted with
    `ToTensor()` (no normalization). Each batch is expected to be shaped
    (batch, channels, height, width). Statistics are pooled over every pixel of every image:

        mean = E[x],  std = sqrt(E[x^2] - E[x]^2)   (population standard deviation)

    Sums are accumulated in float64 so precision does not degrade on large datasets. The number
    of channels is read from the data rather than assumed to be 3.

    Returns:
        dict: {'mean': [...], 'std': [...]} with one float per channel.

    Raises:
        ValueError: If `image_paths` is empty (there is nothing to average).
    """
    if len(image_paths) == 0:
        raise ValueError("image_paths must contain at least one image path")

    if not transform:
        # Define transformation (without normalization)
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),  # Resize images
            transforms.ToTensor(),
        ])

    dataset = ImageDataset(image_paths, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    channel_sum = None     # per-channel sum of pixel values (float64), created from the first batch
    channel_sq_sum = None  # per-channel sum of squared pixel values (float64)
    total_pixels = 0       # pixels seen per channel

    for batch in tqdm(loader):
        # Flatten height and width; reshape (unlike view) also accepts non-contiguous tensors.
        flat = batch.reshape(batch.size(0), batch.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(1), dtype=torch.float64, device=flat.device)
            channel_sq_sum = torch.zeros(flat.size(1), dtype=torch.float64, device=flat.device)
        total_pixels += flat.size(0) * flat.size(2)
        channel_sum += flat.sum(dim=[0, 2])
        channel_sq_sum += (flat ** 2).sum(dim=[0, 2])

    mean = channel_sum / total_pixels
    # E[x^2] - E[x]^2 can come out slightly negative from rounding; clamp so sqrt never yields NaN.
    variance = (channel_sq_sum / total_pixels - mean ** 2).clamp_min(0)
    std = torch.sqrt(variance)
    return {'mean': mean.cpu().tolist(), 'std': std.cpu().tolist()}