In [1]:
import torch
from neuralop.models import FNO, FNO1d, FNO2d, FNO3d
from neuralop.utils import count_model_params

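## 1D

The generic `FNO` (one entry in `n_modes`) and the dimension-specific `FNO1d` are built with the same settings below; both return the same output shape and report the same parameter count.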

```
input : batch_size, in_channels , height
output: batch_size, out_channels, height
```

In [2]:
# Generic FNO set up for 1D data: n_modes holds one entry per spatial
# dimension, so a single entry means the model works along height only.
fno = FNO(
    n_modes=[16],
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [3]:
# Random batch laid out as (batch_size, in_channels, height); the values are
# irrelevant here because only the output shape is inspected.
input_shape = (5, 1, 9)
inputs = torch.rand(input_shape)
fno(inputs).shape

torch.Size([5, 257, 9])

In [4]:
# Parameter count, as reported by neuralop's count_model_params, of the model
# currently bound to `fno`. `fno` is re-bound in the 2D and 3D sections, so this
# refers to the 1D model only when the cells are run in order.
count_model_params(fno)

161185

In [5]:
# Dimension-specific 1D class: the mode count is passed as n_modes_height
# instead of an n_modes list. Every other setting matches the generic FNO cell.
fno1d = FNO1d(
    n_modes_height=16,
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [6]:
# Same (batch_size, in_channels, height) layout as used for the generic FNO,
# so the two output shapes can be compared directly.
input_shape = (5, 1, 9)
inputs = torch.rand(input_shape)
fno1d(inputs).shape

torch.Size([5, 257, 9])

In [7]:
# Parameter count of the FNO1d model, as reported by neuralop's
# count_model_params; compare with the value shown for the generic 1D FNO.
count_model_params(fno1d)

161185

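## 2D

The generic `FNO` (two entries in `n_modes`) and the dimension-specific `FNO2d` are built with the same settings below; both return the same output shape and report the same parameter count.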

```
input : batch_size, in_channels , height, width
output: batch_size, out_channels, height, width
```

In [20]:
# Generic FNO set up for 2D data: two n_modes entries, one for height and one
# for width. Note that this re-binds `fno`, replacing the 1D model.
fno = FNO(
    n_modes=[16, 16],
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [21]:
# Random batch laid out as (batch_size, in_channels, height, width); only the
# output shape is inspected.
input_shape = (5, 1, 9, 8)
inputs = torch.rand(input_shape)
fno(inputs).shape

torch.Size([5, 257, 9, 8])

In [22]:
# Parameter count, as reported by neuralop's count_model_params, of the model
# currently bound to `fno`. `fno` is re-bound in each section, so this refers
# to the 2D model only when the cells are run in order.
count_model_params(fno)

1267105

In [23]:
# Dimension-specific 2D class: mode counts are passed per axis
# (n_modes_height, n_modes_width) instead of an n_modes list. Every other
# setting matches the generic 2D FNO cell.
fno2d = FNO2d(
    n_modes_height=16, n_modes_width=16,
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [24]:
# Same (batch_size, in_channels, height, width) layout as used for the generic
# FNO, so the two output shapes can be compared directly.
input_shape = (5, 1, 9, 8)
inputs = torch.rand(input_shape)
fno2d(inputs).shape

torch.Size([5, 257, 9, 8])

In [25]:
# Parameter count of the FNO2d model, as reported by neuralop's
# count_model_params; compare with the value shown for the generic 2D FNO.
count_model_params(fno2d)

1267105

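## 3D

The generic `FNO` (three entries in `n_modes`) and the dimension-specific `FNO3d` are built with the same settings below; both return the same output shape and report the same parameter count.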

```
input : batch_size, in_channels , height, width, depth
output: batch_size, out_channels, height, width, depth
```

In [28]:
# Generic FNO set up for 3D data: three n_modes entries, one each for height,
# width and depth. Note that this re-binds `fno`, replacing the earlier model.
fno = FNO(
    n_modes=[16, 16, 16],
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [29]:
# Random batch laid out as (batch_size, in_channels, height, width, depth);
# only the output shape is inspected.
input_shape = (5, 1, 9, 8, 7)
inputs = torch.rand(input_shape)
fno(inputs).shape

torch.Size([5, 257, 9, 8, 7])

In [30]:
# Parameter count, as reported by neuralop's count_model_params, of the model
# currently bound to `fno`. `fno` is re-bound in each section, so this refers
# to the 3D model only when the cells are run in order.
count_model_params(fno)

18961825

In [31]:
# Dimension-specific 3D class: mode counts are passed per axis
# (n_modes_height, n_modes_width, n_modes_depth) instead of an n_modes list.
# Every other setting matches the generic 3D FNO cell.
fno3d = FNO3d(
    n_modes_height=16, n_modes_width=16, n_modes_depth=16,
    in_channels=1, out_channels=257,
    hidden_channels=32, lifting_channels=256, projection_channels=256,
    n_layers=4,
)

In [32]:
# Same (batch_size, in_channels, height, width, depth) layout as used for the
# generic FNO, so the two output shapes can be compared directly.
input_shape = (5, 1, 9, 8, 7)
inputs = torch.rand(input_shape)
fno3d(inputs).shape

torch.Size([5, 257, 9, 8, 7])

In [33]:
# Parameter count of the FNO3d model, as reported by neuralop's
# count_model_params; compare with the value shown for the generic 3D FNO.
count_model_params(fno3d)

18961825