Merge pull request #266 from dhpitt/preprocessor
Fixes to make `DataProcessor` code doc build
JeanKossaifi committed Nov 16, 2023
2 parents 1e312db + 02f71ac commit 240b817
Showing 5 changed files with 13 additions and 55 deletions.
52 changes: 2 additions & 50 deletions examples/checkpoint_FNO_darcy.py
@@ -102,8 +102,8 @@

trainer = Trainer(model=model, n_epochs=20,
                  device=device,
+                 data_processor=data_processor,
-                 callbacks=[
-                     OutputEncoderCallback(output_encoder),
-                     CheckpointCallback(save_dir='./new_checkpoints',
-                                        resume_from_dir='./checkpoints/ep_10')
-                 ],
@@ -117,52 +117,4 @@
optimizer=optimizer,
scheduler=scheduler,
regularizer=False,
training_loss=train_loss)
-# %%
-# Plot the prediction, and compare with the ground-truth
-# Note that we trained on a very small resolution for
-# a very small number of epochs
-# In practice, we would train at larger resolution, on many more samples.
-#
-# However, for practicity, we created a minimal example that
-# i) fits in just a few Mb of memory
-# ii) can be trained quickly on CPU
-#
-# In practice we would train a Neural Operator on one or multiple GPUs
-
-test_samples = test_loaders[32].dataset
-
-fig = plt.figure(figsize=(7, 7))
-for index in range(3):
-    data = test_samples[index]
-    # Input x
-    x = data['x']
-    # Ground-truth
-    y = data['y']
-    # Model prediction
-    out = model(x.unsqueeze(0))
-
-    ax = fig.add_subplot(3, 3, index*3 + 1)
-    ax.imshow(x[0], cmap='gray')
-    if index == 0:
-        ax.set_title('Input x')
-    plt.xticks([], [])
-    plt.yticks([], [])
-
-    ax = fig.add_subplot(3, 3, index*3 + 2)
-    ax.imshow(y.squeeze())
-    if index == 0:
-        ax.set_title('Ground-truth y')
-    plt.xticks([], [])
-    plt.yticks([], [])
-
-    ax = fig.add_subplot(3, 3, index*3 + 3)
-    ax.imshow(out.squeeze().detach().numpy())
-    if index == 0:
-        ax.set_title('Model prediction')
-    plt.xticks([], [])
-    plt.yticks([], [])
-
-fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
-plt.tight_layout()
-fig.show()
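For reference, the pattern this commit establishes: normalization and positional encoding now travel with a `DataProcessor` passed directly to the `Trainer`, replacing the deleted `OutputEncoderCallback`. A minimal sketch of the resulting calls, assuming, as the hunks above suggest, that `model`, `device`, `data_processor`, the loaders, `optimizer`, `scheduler`, and `train_loss` are defined earlier in the script and that the second hunk's keyword arguments belong to `trainer.train`:

trainer = Trainer(model=model, n_epochs=20,
                  device=device,
                  data_processor=data_processor)

# The remaining keywords go to the training loop itself.
trainer.train(train_loader=train_loader,
              test_loaders=test_loaders,
              optimizer=optimizer,
              scheduler=scheduler,
              regularizer=False,
              training_loss=train_loss)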
2 changes: 2 additions & 0 deletions examples/plot_FNO_darcy.py
@@ -30,6 +30,7 @@
        test_batch_sizes=[32, 32],
        positional_encoding=True
)
+data_processor = data_processor.to(device)


# %%
@@ -112,6 +113,7 @@
fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
+   data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
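Note why the first added line is needed: `preprocess` moves every tensor to `self.device` (see `data_transforms.py` below), so the processor must be told which device to target. Its `to` method records the device and returns `self`, so the assignment form works. A short sketch, assuming `data_processor` comes from `load_darcy_flow_small` as above:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `to` stores the target device on the processor and returns it,
# mirroring the torch.nn.Module convention.
data_processor = data_processor.to(device)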
1 change: 1 addition & 0 deletions examples/plot_UNO_darcy.py
@@ -112,6 +112,7 @@
fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
+   data = data_processor.preprocess(data, batched=False)
    # Input x
    x = data['x']
    # Ground-truth
3 changes: 2 additions & 1 deletion examples/plot_darcy_flow.py
@@ -18,7 +18,7 @@
# Training samples are 16x16 and we load testing samples at both
# 16x16 and 32x32 (to test resolution invariance).

-train_loader, test_loaders, output_encoder = load_darcy_flow_small(
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=100, batch_size=4,
        test_resolutions=[16, 32], n_tests=[50, 50], test_batch_sizes=[4, 2],
)
@@ -50,6 +50,7 @@
index = 0

data = train_dataset[index]
+data = data_processor.preprocess(data, batched=False)
x = data['x']
y = data['y']
fig = plt.figure(figsize=(7, 7))
10 changes: 6 additions & 4 deletions neuralop/datasets/data_transforms.py
@@ -34,14 +34,14 @@ def to(self, device):
        self.device = device
        return self

-   def preprocess(self, data_dict):
+   def preprocess(self, data_dict, batched=True):
        x = data_dict['x'].to(self.device)
        y = data_dict['y'].to(self.device)

        if self.in_normalizer is not None:
            x = self.in_normalizer.transform(x)
        if self.positional_encoding is not None:
-           x = self.positional_encoding(x)
+           x = self.positional_encoding(x, batched=batched)
        if self.out_normalizer is not None and self.train:
            y = self.out_normalizer.transform(y)
@@ -122,7 +122,7 @@ def wrap(self, model):
        self.model = model
        return self

-   def preprocess(self, data_dict):
+   def preprocess(self, data_dict, batched=True):
        """
        Preprocess data assuming that if encoder exists, it has
        encoded all data during data loading
@@ -133,14 +133,16 @@ def preprocess(self, data_dict):
        data_dict: dict
            dictionary keyed with 'x', 'y' etc
            represents one batch of data input to a model
+       batched: bool
+           whether the first dimension of 'x', 'y' represents batching
        """
        data_dict = {k:v.to(self.device) for k,v in data_dict.items() if torch.is_tensor(v)}
        x,y = data_dict['x'], data_dict['y']
        if self.in_normalizer:
            x = self.in_normalizer.transform(x)
            y = self.out_normalizer.transform(y)
        if self.positional_encoding is not None:
-           x = self.positional_encoding(x)
+           x = self.positional_encoding(x, batched=batched)
        data_dict['x'],data_dict['y'] = self.patcher.patch(x,y)
        return data_dict

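The new `batched` flag is what the example fixes above rely on: a single item pulled straight from a dataset has no leading batch dimension, so the positional encoding must be told not to assume one. A minimal usage sketch, assuming `load_darcy_flow_small` is importable from `neuralop.datasets` as in the examples, and that `model` is a trained operator:

from neuralop.datasets import load_darcy_flow_small

train_loader, test_loaders, data_processor = load_darcy_flow_small(
    n_train=100, batch_size=4,
    test_resolutions=[16, 32], n_tests=[50, 50], test_batch_sizes=[4, 2],
)

# One raw sample: tensors shaped (channels, height, width), no batch axis.
data = test_loaders[32].dataset[0]
data = data_processor.preprocess(data, batched=False)

# Restore the batch axis only when calling the (assumed pre-trained) model.
out = model(data['x'].unsqueeze(0))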
