adding flag option to only pad the last dim #185

Merged
merged 12 commits into from
Aug 11, 2023
44 changes: 33 additions & 11 deletions neuralop/layers/padding.py
@@ -6,11 +6,12 @@ class DomainPadding(nn.Module):

Parameters
----------
domain_padding : float
domain_padding : float or list
typically, between zero and one, percentage of padding to use
if a list, make sure it matches the number of dims (d1, ..., dN)
padding_mode : {'symmetric', 'one-sided'}, optional
whether to pad on both sides, by default 'one-sided'

Notes
-----
This class works for any input resolution, as long as it is in the form
@@ -39,8 +40,12 @@ def pad(self, x):
"""
resolution = x.shape[2:]

# if domain_padding is a scalar, expand it to one value per spatial dim; a list is passed through as-is
if isinstance(self.domain_padding, (float, int)):
self.domain_padding = [float(self.domain_padding)]*len(resolution)

assert len(self.domain_padding) == len(resolution), "domain_padding length must match the number of spatial/time dimensions (excluding batch and channel dims)"
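# Illustrative example (not part of the diff): for an input of shape
# (batch, ch, 10, 10), domain_padding=0.2 expands to [0.2, 0.2], while
# passing [0, 0.2] pads only the last spatial dim, which is what this PR enables.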

if self.output_scaling_factor is None:
self.output_scaling_factor = [1]*len(resolution)
elif isinstance(self.output_scaling_factor, (float, int)):
@@ -52,33 +57,50 @@ def pad(self, x):

except KeyError:
padding = [int(round(p*r)) for (p, r) in zip(self.domain_padding, resolution)]
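# Illustrative: with resolution (10, 10) and domain_padding [0.1, 0.2],
# padding = [round(0.1 * 10), round(0.2 * 10)] = [1, 2].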

print(f'Padding inputs of {resolution=} with {padding=}, {self.padding_mode}')



# padding is being applied in reverse order (so we must reverse the padding list)
padding = padding[::-1]

output_pad = padding

output_pad = [int(round(i*j)) for (i,j) in zip(self.output_scaling_factor,output_pad)]



# The F.pad(x, padding) function pads the tensor 'x' in reverse order of the 'padding' list, i.e. the last axis of tensor 'x' is
# padded by the amount given at the first position of the 'padding' list.
# Details about F.pad can be found here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
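# Illustrative: F.pad(x, [1, 1, 2, 2]) pads the last axis of x by (1, 1) and
# the second-to-last axis by (2, 2).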

if self.padding_mode == 'symmetric':
# Pad both sides
unpad_indices = (Ellipsis, ) + tuple([slice(p, -p, None) for p in output_pad[::-1] ])
unpad_list = list()
for p in output_pad[::-1]:
if p == 0:
padding_end = None
padding_start = None
else:
padding_end = p
padding_start = -p
unpad_list.append(slice(padding_end, padding_start, None))
unpad_indices = (Ellipsis, ) + tuple(unpad_list)
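# Illustrative: when p == 0, slice(None, None) leaves that dim unchanged;
# the previous slice(p, -p) would have become slice(0, 0), an empty slice,
# since -0 == 0.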

padding = [i for p in padding for i in (p, p)]

elif self.padding_mode == 'one-sided':
# One-side padding
unpad_indices = (Ellipsis, ) + tuple([slice(None, -p, None) for p in output_pad[::-1]])
unpad_list = list()
for p in output_pad[::-1]:
if p == 0:
padding_start = None
else:
padding_start = -p
unpad_list.append(slice(None, padding_start, None))
unpad_indices = (Ellipsis, ) + tuple(unpad_list)
padding = [i for p in padding for i in (0, p)]
else:
raise ValueError(f'Got {self.padding_mode=}')

self._padding[f'{resolution}'] = padding


padded = F.pad(x, padding, mode='constant')

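A minimal usage sketch of the list-valued padding introduced in this file (illustrative only; it assumes the DomainPadding signature shown above and the pad/unpad API exercised by the tests below):

import torch
from neuralop.layers.padding import DomainPadding

x = torch.randn(2, 3, 10, 10)                  # (batch, channels, d1, d2)
padder = DomainPadding([0, 0.2], 'one-sided')  # pad only the last dim by 20%
padded = padder.pad(x)                         # shape becomes (2, 3, 10, 12)
unpadded = padder.unpad(padded)                # restores the original (2, 3, 10, 10)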
38 changes: 34 additions & 4 deletions neuralop/layers/tests/test_padding.py
@@ -3,16 +3,46 @@
import pytest

@pytest.mark.parametrize('mode', ['one-sided', 'symmetric'])
def test_DomainPadding(mode):
out_size = {'one-sided': 12, 'symmetric': 14}
@pytest.mark.parametrize('padding', [0.2, [0.1, 0.2]])
def test_DomainPadding_2d(mode, padding):
if isinstance(padding, float):
out_size = {'one-sided': [12, 12], 'symmetric': [14, 14]}
else:
out_size = {'one-sided': [11, 12], 'symmetric': [12, 14]}

data = torch.randn((2, 3, 10, 10))
padder = DomainPadding(0.2, mode)
padder = DomainPadding(padding, mode)
padded = padder.pad(data)

target_shape = list(padded.shape)
target_shape[-1] = target_shape[-2] = out_size[mode]
# create the target shape from hardcoded out_size
for pad_dim in range(1,3):
target_shape[-pad_dim] = out_size[mode][-pad_dim]
assert list(padded.shape) == target_shape

unpadded = padder.unpad(padded)
assert unpadded.shape == data.shape


@pytest.mark.parametrize('mode', ['one-sided', 'symmetric'])
@pytest.mark.parametrize('padding', [0.2, [0.1, 0, 0.2]])
def test_DomainPadding_3d(mode, padding):
if isinstance(padding, float):
out_size = {'one-sided': [12, 12, 12], 'symmetric': [14, 14, 14]}
else:
out_size = {'one-sided': [11, 10, 12], 'symmetric': [12, 10, 14]}

data = torch.randn((2, 3, 10, 10, 10))
padder = DomainPadding(padding, mode)
padded = padder.pad(data)

target_shape = list(padded.shape)
# create the target shape from hardcoded out_size
for pad_dim in range(1,4):
target_shape[-pad_dim] = out_size[mode][-pad_dim]
assert list(padded.shape) == target_shape

unpadded = padder.unpad(padded)
assert unpadded.shape == data.shape
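For reference, the hardcoded out_size values above follow directly from the padding arithmetic (a sketch, not part of the diff):

# one-sided:  new_size = size + round(p * size)       e.g. 10 + round(0.2 * 10) = 12
# symmetric:  new_size = size + 2 * round(p * size)   e.g. 10 + 2 * round(0.2 * 10) = 14
# so padding [0.1, 0.2] gives [11, 12] one-sided and [12, 14] symmetric.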


20 changes: 12 additions & 8 deletions neuralop/models/fno.py
@@ -1,4 +1,5 @@
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from functools import partialmethod

@@ -44,18 +45,20 @@ class FNO(nn.Module):
By default None, otherwise tanh is used before FFT in the FNO block
use_mlp : bool, optional
Whether to use an MLP layer after each FNO block, by default False
mlp_dropout : float
droupout parameter of MLP layer (default is 0)
mlp_expansion : float
expansion parameter of MLP layer (default is 0.5)
mlp_dropout : float, optional
dropout parameter of MLP layer, by default 0
mlp_expansion : float, optional
expansion parameter of MLP layer, by default 0.5
non_linearity : nn.Module, optional
Non-Linearity module to use, by default F.gelu
norm : F.module, optional
Normalization layer to use, by default None
preactivation : bool, default is False
if True, use resnet-style preactivation
skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use, by default 'soft-gating'
fno_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use in fno, by default 'linear'
mlp_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use in mlp, by default 'soft-gating'
separable : bool, default is False
if True, use a depthwise separable spectral convolution
factorization : str or None, {'tucker', 'cp', 'tt'}
@@ -135,10 +138,11 @@ def __init__(self, n_modes, hidden_channels,
# When updated, change should be reflected in fno blocks
self._incremental_n_modes = incremental_n_modes

if domain_padding is not None and domain_padding > 0:
self.domain_padding = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=output_scaling_factor)
if domain_padding is not None and ((isinstance(domain_padding, list) and sum(domain_padding) > 0) or (isinstance(domain_padding, (float, int)) and domain_padding > 0)):
self.domain_padding = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=output_scaling_factor)
else:
self.domain_padding = None

self.domain_padding_mode = domain_padding_mode

if output_scaling_factor is not None and not joint_factorization:
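A hedged end-to-end sketch at the model level (the domain_padding and domain_padding_mode keywords follow this diff; the other constructor arguments and the neuralop.models import path are assumed, not shown in the PR):

import torch
from neuralop.models import FNO

# pad only the last spatial dim by 10%, one-sided
model = FNO(n_modes=(16, 16), hidden_channels=32,
            in_channels=3, out_channels=1,
            domain_padding=[0, 0.1],
            domain_padding_mode='one-sided')

x = torch.randn(4, 3, 64, 64)
y = model(x)      # padding is applied before the FNO blocks and removed afterwards
print(y.shape)    # expected: torch.Size([4, 1, 64, 64])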