In [1]:
import torch
from torchinfo import summary
from neuralop.models import UNO
from rtmag.train.model import UNO as un

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# U-shaped neural operator from the project package (`un` is rtmag's UNO) on a 2-D grid:
# four layers halve both spatial axes, then four layers double them back.
# With out_channels=[256, 256, 256], the recorded forward-pass output further down
# is (1, 256, 256, 512, 3), i.e. the three 256-channel outputs end up on a trailing axis.
model = un(
    in_channels=3,
    out_channels=[256] * 3,
    # lifting
    hidden_channels=256,
    # mlp hidden layer neurons
    lifting_channels=256,
    projection_channels=256,
    # number of layers
    n_layers=8,
    # channel widths: contract 128 -> 16, then mirror back up to 128
    uno_out_channels=[128, 64, 32, 16] + [16, 32, 64, 128],
    # per-layer resolution factors for the two spatial axes
    uno_scalings=[[0.5, 0.5] for _ in range(4)] + [[2.0, 2.0] for _ in range(4)],
    # Fourier modes kept per spatial axis, identical for all 8 layers
    uno_n_modes=[[8, 8] for _ in range(8)],
).to(device)

In [4]:
bs = 1  # batch size of the dummy input handed to torchinfo
# Tabulate layers / output shapes / parameter counts for a (bs, 3, 256, 512) input.
summary(model, (bs, 3, 256, 512))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 256, 256, 512, 3]     --
├─MLP: 1-1                               [1, 256, 256, 512]        --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv2d: 3-1                  [1, 256, 256, 512]        1,024
│    │    └─Conv2d: 3-2                  [1, 256, 256, 512]        65,792
├─ModuleList: 1-10                       --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 128, 128, 256]        --
│    │    └─ModuleList: 3-3              --                        32,768
│    │    └─SpectralConv: 3-4            [1, 128, 128, 256]        2,621,568
├─ModuleDict: 1-9                        --                        (recursive)
│    └─Conv2d: 2-3                       [1, 128, 128, 256]        16,384
├─ModuleList: 1-10                       --                        (recursive)
│    └─FNOBlocks: 2-4               

In [5]:
# Smoke-test the forward pass on random data with the same shape used in the summary.
# The tensor is created directly on `device` (no host allocation + copy), and autograd is
# disabled because only the output shape is inspected afterwards — keeping activations
# for a backward pass would just waste (GPU) memory at this resolution.
x = torch.randn(bs, 3, 256, 512, device=device)
with torch.no_grad():
    y = model(x)

In [6]:
# Output shape of the rtmag UNO forward pass.
y.size()

torch.Size([1, 256, 256, 512, 3])

In [4]:
# NOTE(review): superseded by the next cell, which reassigns `model` with identical settings
# except narrower uno_out_channels; this configuration is never summarized or run here.
# 3-D neuralop UNO: the first three layers scale the first two axes by 0.5 and the third
# axis by 2.0; the last three layers apply the inverse scaling.
model = UNO(
    in_channels=1,
    out_channels=256,
    hidden_channels=64,
    lifting_channels=256,
    projection_channels=256,
    n_layers=6,
    # NOTE(review): 8 modes are requested on the third axis, whose size is only 3 in the
    # summary inputs used later in this notebook — confirm how SpectralConv truncates this.
    uno_n_modes=[[8, 8, 8] for _ in range(6)],
    uno_out_channels=[64, 128, 256, 256, 128, 64],
    uno_scalings=[[0.5, 0.5, 2.0] for _ in range(3)] + [[2.0, 2.0, 0.5] for _ in range(3)],
).to(device)

In [20]:
# 3-D neuralop UNO, narrower variant (64/128 channels): the first three layers scale the
# first two axes by 0.5 and the third axis by 2.0; the last three layers invert that.
model = UNO(
    in_channels=1,
    out_channels=256,
    hidden_channels=64,
    lifting_channels=256,
    projection_channels=256,
    n_layers=6,
    # NOTE(review): 8 modes on the third axis, which has size 3 in the summary input below
    # (size 6 after the first scaling per the recorded summary) — confirm intended.
    uno_n_modes=[[8, 8, 8] for _ in range(6)],
    uno_out_channels=[64, 64, 128] + [128, 64, 64],
    uno_scalings=[[0.5, 0.5, 2.0] for _ in range(3)] + [[2.0, 2.0, 0.5] for _ in range(3)],
).to(device)

In [21]:
bs = 1  # batch size of the dummy input handed to torchinfo
# 5-D dummy input for the 3-D operator: (bs, in_channels=1, 256, 512, 3).
summary(model, (bs, 1, 256, 512, 3))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 256, 256, 512, 3]     --
├─MLP: 1-1                               [1, 64, 256, 512, 3]      --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv3d: 3-1                  [1, 256, 256, 512, 3]     512
│    │    └─Conv3d: 3-2                  [1, 64, 256, 512, 3]      16,448
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 64, 128, 256, 6]      --
│    │    └─ModuleList: 3-3              --                        4,096
│    │    └─SpectralConv: 3-4            [1, 64, 128, 256, 6]      2,621,504
├─ModuleDict: 1-7                        --                        (recursive)
│    └─Conv3d: 2-3                       [1, 64, 128, 256, 6]      4,096
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-4                   

: 

In [8]:
# 3-D neuralop UNO with hidden width 32: the first three layers halve the first two axes,
# the last three double them back; the third axis is never rescaled (factor 1.0).
model = UNO(
    in_channels=1,
    out_channels=256,
    hidden_channels=32,
    lifting_channels=256,
    projection_channels=256,
    n_layers=6,
    # 16 modes on the outermost layers, 8 on the four inner ones.
    # NOTE(review): 16 modes on the third axis exceeds its size of 3 in the summary input —
    # confirm how SpectralConv handles this.
    uno_n_modes=[[16, 16, 16]] + [[8, 8, 8] for _ in range(4)] + [[16, 16, 16]],
    uno_out_channels=[32, 64, 128] + [128, 64, 32],
    uno_scalings=[[0.5, 0.5, 1.0] for _ in range(3)] + [[2.0, 2.0, 1.0] for _ in range(3)],
).to(device)

In [9]:
bs = 1  # batch size of the dummy input handed to torchinfo
# 5-D dummy input for the 3-D operator: (bs, in_channels=1, 256, 512, 3).
summary(model, (bs, 1, 256, 512, 3))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 256, 256, 512, 3]     --
├─MLP: 1-1                               [1, 32, 256, 512, 3]      --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv3d: 3-1                  [1, 256, 256, 512, 3]     512
│    │    └─Conv3d: 3-2                  [1, 32, 256, 512, 3]      8,224
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 128, 128, 256, 3]     --
│    │    └─ModuleList: 3-3              --                        4,096
│    │    └─SpectralConv: 3-4            [1, 128, 128, 256, 3]     18,874,496
├─ModuleDict: 1-7                        --                        (recursive)
│    └─Conv3d: 2-3                       [1, 128, 128, 256, 3]     16,384
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-4                  

In [3]:
# 3-D neuralop UNO that keeps resolution except for one 0.5x layer (2nd) and one 2x layer (5th),
# applied uniformly to all three axes.
# NOTE(review): the third axis has size 3 in the summary input, so a 0.5 factor gives a
# non-integer size (1.5) — confirm how UNO rounds it; the recorded summary output is still
# (1, 256, 256, 512, 3).
model = UNO(
    hidden_channels=32,
    in_channels=1,
    out_channels=256,
    lifting_channels=256,
    projection_channels=256,
    n_layers=6,
    uno_n_modes=[[16, 16, 16]] + [[8, 8, 8] for _ in range(4)] + [[16, 16, 16]],
    uno_out_channels=[32] + [64] * 4 + [32],
    uno_scalings=[
        [1.0] * 3,
        [0.5] * 3,
        [1.0] * 3,
        [1.0] * 3,
        [2.0] * 3,
        [1.0] * 3,
    ],
).to(device)

In [5]:
bs = 1  # batch size of the dummy input handed to torchinfo
# 5-D dummy input for the 3-D operator: (bs, in_channels=1, 256, 512, 3).
summary(model, (bs, 1, 256, 512, 3))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 256, 256, 512, 3]     --
├─MLP: 1-1                               [1, 32, 256, 512, 3]      --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv3d: 3-1                  [1, 256, 256, 512, 3]     512
│    │    └─Conv3d: 3-2                  [1, 32, 256, 512, 3]      8,224
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 32, 256, 512, 3]      --
│    │    └─ModuleList: 3-3              --                        1,024
│    │    └─SpectralConv: 3-4            [1, 32, 256, 512, 3]      4,718,624
├─ModuleDict: 1-7                        --                        (recursive)
│    └─Conv3d: 2-3                       [1, 32, 256, 512, 3]      1,024
├─ModuleList: 1-8                        --                        (recursive)
│    └─FNOBlocks: 2-4                    

In [21]:
# Channel count when three 256-channel outputs are packed along the channel axis
# (used as out_channels for the 2-D UNO below).
3 * 256

768

In [10]:
# 2-D neuralop UNO that emits 768 channels, i.e. three 256-channel groups packed along the
# channel axis; they are split back into a trailing axis of size 3 at the end of the notebook.
model = UNO(
    in_channels=3,
    out_channels=3 * 256,
    # lifting
    hidden_channels=256,
    # mlp hidden layer neurons
    lifting_channels=256,
    projection_channels=256,
    # number of layers
    n_layers=8,
    # channel widths: contract 128 -> 16, then mirror back up to 128
    uno_out_channels=[128, 64, 32, 16] + [16, 32, 64, 128],
    # per-layer resolution factors for the two spatial axes
    uno_scalings=[[0.5, 0.5] for _ in range(4)] + [[2.0, 2.0] for _ in range(4)],
    # Fourier modes kept per spatial axis, identical for all 8 layers
    uno_n_modes=[[8, 8] for _ in range(8)],
).to(device)

In [11]:
bs = 1  # batch size of the dummy input handed to torchinfo
# Tabulate layers / output shapes / parameter counts for a (bs, 3, 256, 512) input.
summary(model, (bs, 3, 256, 512))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 768, 256, 512]        --
├─MLP: 1-1                               [1, 256, 256, 512]        --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv2d: 3-1                  [1, 256, 256, 512]        1,024
│    │    └─Conv2d: 3-2                  [1, 256, 256, 512]        65,792
├─ModuleList: 1-10                       --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 128, 128, 256]        --
│    │    └─ModuleList: 3-3              --                        32,768
│    │    └─SpectralConv: 3-4            [1, 128, 128, 256]        2,621,568
├─ModuleDict: 1-9                        --                        (recursive)
│    └─Conv2d: 2-3                       [1, 128, 128, 256]        16,384
├─ModuleList: 1-10                       --                        (recursive)
│    └─FNOBlocks: 2-4               

In [12]:
# Uniform random dummy input, allocated directly on `device` to avoid a host-side
# allocation followed by a copy.
x = torch.rand((bs, 3, 256, 512), device=device)

In [13]:
# Forward pass for shape inspection only; autograd is disabled so no activations are kept
# for a backward pass that never happens.
with torch.no_grad():
    y = model(x)

In [14]:
# Packed output: (bs, 768, 256, 512).
y.size()

torch.Size([1, 768, 256, 512])

In [15]:
# Shape of the first 256-channel group of the packed output.
y[:, :256].shape

torch.Size([1, 256, 256, 512])

In [20]:
# Unpack the 768 output channels into three contiguous 256-channel groups and stack them on a
# new trailing axis: (bs, 3*256, H, W) -> (bs, 256, H, W, 3). This matches the layout of the
# rtmag UNO output earlier in the notebook. torch.chunk derives the group boundaries, so the
# 0/256/512/768 offsets are no longer hard-coded.
cc = torch.stack(y.chunk(3, dim=1), dim=-1)
cc.shape

torch.Size([1, 256, 256, 512, 3])