In [2]:
import torch
from torch import nn

In [3]:
class Sine(nn.Module):
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, num_layers):
        super().__init__()
        self.d_in = nn.Linear(in_dim, hidden_dim)
        lin = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.linear_layers = nn.ModuleList(lin)
        self.d_out = nn.Linear(hidden_dim, out_dim)
        self.activation = Sine()

    def forward(self, x):
        """
        x      : [batch_size, in_dim]
        
        output : [batch_size, out_dim]
        """

        x = self.activation(self.d_in(x))
        for l in self.linear_layers:
            x = self.activation(l(x))
        x = self.d_out(x)
        return x

In [4]:
import torch
from torchinfo import summary
from neuralop.models import UNO

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [34]:
model = UNO(
            in_channels=3,
            out_channels=64,
            hidden_channels=64,
            lifting_channels=64,
            projection_channels=64,
            n_layers = 4,
            uno_n_modes = [[8,8],
                           [8,8],
                           [8,8],
                           [8,8]],
            uno_out_channels = [64,
                                64,
                                64,
                                64],
            uno_scalings =  [[0.5,0.5],
                            [0.5,0.5],
                            [0.5,0.5],
                            [0.5,0.5]]
        ).to(device)

In [35]:
bs = 1
summary(model, input_size=(bs, 3, 512, 256))

Layer (type:depth-idx)                   Output Shape              Param #
UNO                                      [1, 64, 32, 16]           --
├─MLP: 1-1                               [1, 64, 512, 256]         --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Conv2d: 3-1                  [1, 64, 512, 256]         256
│    │    └─Conv2d: 3-2                  [1, 64, 512, 256]         4,160
├─ModuleList: 1-6                        --                        (recursive)
│    └─FNOBlocks: 2-2                    [1, 64, 256, 128]         --
│    │    └─ModuleList: 3-3              --                        4,096
│    │    └─SpectralConv: 3-4            [1, 64, 256, 128]         327,744
├─ModuleDict: 1-5                        --                        (recursive)
│    └─Conv2d: 2-3                       [1, 64, 256, 128]         4,096
├─ModuleList: 1-6                        --                        (recursive)
│    └─FNOBlocks: 2-4                    [1

In [36]:
outputs = model(torch.randn(bs, 3, 512, 256).to(device))
outputs.shape

torch.Size([1, 64, 32, 16])

In [38]:
3 * 512 * 256

393216

In [37]:
64 * 32 * 16

32768

In [40]:
(3 * 512 * 256) / (64 * 32 * 16)

12.0

In [28]:
class DeepONet(nn.Module):
    def __init__(self, trunk_in_dim, out_dim, latent_dim, hidden_dim, num_layers):
        super().__init__()
        self.branch_inc = UNO(
            in_channels=3,
            out_channels=64,
            hidden_channels=64,
            lifting_channels=64,
            projection_channels=64,
            n_layers = 4,
            uno_n_modes = [[8,8],
                           [8,8],
                           [8,8],
                           [8,8]],
            uno_out_channels = [64,
                                64,
                                64,
                                64],
            uno_scalings =  [[0.5,0.5],
                            [0.5,0.5],
                            [0.5,0.5],
                            [0.5,0.5],]
        )
        self.branch_layer = nn.Linear(64*32*16, latent_dim)
        self.trunk_layer = MLP(trunk_in_dim, latent_dim, hidden_dim, num_layers)
        self.d_out = nn.Linear(latent_dim, out_dim)
        self.activation = Sine()
    
    def forward(self, bc, x):
        """
        bc     : [batch_size, 3, 512, 256]
        x      : [batch_size, batch_coords, trunk_in_dim]

        output : [batch_size, batch_coords, out_dim]
        """
        branch_latent = self.branch_inc(bc)
        branch_latent = torch.flatten(branch_latent, 1)
        branch_latent = self.branch_layer(branch_latent)
        trunk_latent = self.trunk_layer(x)
        latent = branch_latent[:, None, :] * trunk_latent
        output = self.d_out(self.activation(latent))
        return output

In [50]:
don = DeepONet(3, 3, 8192, 256, 4).to(device)
summary(don, input_type=[(bs, 3, 512, 256), (bs, 1000, 3)])

Layer (type:depth-idx)                                  Param #
DeepONet                                                --
├─UNO: 1-1                                              --
│    └─MLP: 2-1                                         --
│    │    └─ModuleList: 3-1                             4,416
│    └─ModuleList: 2-2                                  --
│    │    └─FNOBlocks: 3-2                              331,840
│    │    └─FNOBlocks: 3-3                              331,840
│    │    └─FNOBlocks: 3-4                              331,840
│    │    └─FNOBlocks: 3-5                              663,616
│    │    └─FNOBlocks: 3-6                              663,616
│    └─ModuleDict: 2-3                                  --
│    │    └─Conv2d: 3-7                                 4,096
│    │    └─Conv2d: 3-8                                 4,096
│    └─MLP: 2-4                                         --
│    │    └─ModuleList: 3-9                             6,240
├─Linear: 1-2 

In [51]:
don(torch.randn(bs, 3, 512, 256).to(device), torch.randn(bs, 1000, 3).to(device)).shape

torch.Size([1, 1000, 3])