In [1]:
import torch
from neuralop.models import TFNO, TFNO1d, TFNO2d, TFNO3d
from neuralop.utils import count_model_params

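## 1D

This section builds the same model twice: once with the generic `TFNO` (a one-entry `n_modes` list) and once with the dedicated `TFNO1d` class (`n_modes_height`). It then checks the output shape and parameter count of each. Expected tensor shapes: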

```
input : batch_size, in_channels , height
output: batch_size, out_channels, height
```

In [2]:
# Generic TFNO with a single entry in `n_modes`, matching the one spatial
# axis (height) of the 1D inputs used below.
fno = TFNO(
    # Fourier modes kept per spatial axis, and number of spectral layers
    n_modes=[16],
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [3]:
# Random batch: 5 samples, 1 channel, 9 grid points along height.
inputs = torch.rand((5, 1, 9))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = fno(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9])

In [4]:
# Parameter count of the generic 1D TFNO `fno`, as computed by neuralop's count_model_params
count_model_params(fno)

122385

In [5]:
# Dimension-specific 1D class: the mode count is passed per axis
# (n_modes_height) instead of as a list; all other hyperparameters
# match the generic TFNO above.
fno1d = TFNO1d(
    # Fourier modes kept along height, and number of spectral layers
    n_modes_height=16,
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [6]:
# Random batch: 5 samples, 1 channel, 9 grid points along height.
inputs = torch.rand((5, 1, 9))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = fno1d(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9])

In [7]:
# Parameter count of TFNO1d; compare with the generic 1D TFNO count above (same hyperparameters)
count_model_params(fno1d)

122385

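## 2D

The same comparison in two spatial dimensions: generic `TFNO` with a two-entry `n_modes` list versus `TFNO2d` with `n_modes_height` / `n_modes_width`. Expected tensor shapes: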

```
input : batch_size, in_channels , height, width
output: batch_size, out_channels, height, width
```

In [8]:
# Generic TFNO with two entries in `n_modes`, matching the two spatial
# axes (height, width) of the 2D inputs used below.
fno = TFNO(
    # Fourier modes kept per spatial axis, and number of spectral layers
    n_modes=[16, 16],
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [9]:
# Random batch: 5 samples, 1 channel, a 9 x 8 (height x width) grid.
inputs = torch.rand((5, 1, 9, 8))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = fno(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9, 8])

In [10]:
# Parameter count of the generic 2D TFNO `fno`, as computed by neuralop's count_model_params
count_model_params(fno)

710049

In [11]:
# Dimension-specific 2D class: mode counts are passed per axis
# (n_modes_height, n_modes_width) instead of as a list; all other
# hyperparameters match the generic 2D TFNO above.
fno2d = TFNO2d(
    # Fourier modes kept along each axis, and number of spectral layers
    n_modes_height=16,
    n_modes_width=16,
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [12]:
# Random batch: 5 samples, 1 channel, a 9 x 8 (height x width) grid.
inputs = torch.rand((5, 1, 9, 8))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = fno2d(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9, 8])

In [13]:
# Parameter count of TFNO2d; compare with the generic 2D TFNO count above (same hyperparameters)
count_model_params(fno2d)

710049

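## 3D

The same comparison in three spatial dimensions: generic `TFNO` with a three-entry `n_modes` list versus `TFNO3d` with `n_modes_height` / `n_modes_width` / `n_modes_depth`. Expected tensor shapes: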

```
input : batch_size, in_channels , height, width, depth
output: batch_size, out_channels, height, width, depth
```

In [14]:
# Generic TFNO with three entries in `n_modes`, matching the three spatial
# axes (height, width, depth) of the 3D inputs used below.
fno = TFNO(
    # Fourier modes kept per spatial axis, and number of spectral layers
    n_modes=[16, 16, 16],
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [15]:
# Random batch: 5 samples, 1 channel, a 9 x 8 x 7 (height x width x depth) grid.
inputs = torch.rand((5, 1, 9, 8, 7))
# Only the output shape is inspected, so skip building the autograd graph
# (this model is by far the largest in the notebook).
with torch.no_grad():
    outputs = fno(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9, 8, 7])

In [16]:
# Parameter count of the generic 3D TFNO `fno`, as computed by neuralop's count_model_params
count_model_params(fno)

9940449

In [17]:
# Dimension-specific 3D class: mode counts are passed per axis
# (n_modes_height, n_modes_width, n_modes_depth) instead of as a list;
# all other hyperparameters match the generic 3D TFNO above.
fno3d = TFNO3d(
    # Fourier modes kept along each axis, and number of spectral layers
    n_modes_height=16,
    n_modes_width=16,
    n_modes_depth=16,
    n_layers=4,
    # Channel sizes: data in/out and the internal widths
    in_channels=1,
    out_channels=257,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    # Factorized (Tucker) spectral weights; a float rank is presumably a
    # fraction of the full tensor size -- confirm against neuralop docs
    factorization='tucker',
    implementation='factorized',
    rank=0.5,
)

In [18]:
# Random batch: 5 samples, 1 channel, a 9 x 8 x 7 (height x width x depth) grid.
inputs = torch.rand((5, 1, 9, 8, 7))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = fno3d(inputs)
# Per the shape contract above: batch and spatial sizes are kept, and the
# channel axis becomes out_channels (257 in the constructor cell).
assert outputs.shape == (inputs.shape[0], 257, *inputs.shape[2:]), (
    f"unexpected output shape {tuple(outputs.shape)}"
)
outputs.shape

torch.Size([5, 257, 9, 8, 7])

In [19]:
# Parameter count of TFNO3d; compare with the generic 3D TFNO count above (same hyperparameters)
count_model_params(fno3d)

9940449