In [1]:
import numpy as np
import torch
import tree

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

True

In [5]:
import sys
sys.path.append('../')

# from deeponet.mio_dataset import preprocess
from deeponet.mio_dataset import MioDataset
from deeponet.modules import MioBranchNet

In [6]:
DATA_PATH = "/root/projects/deeponet/data/g0.1-sigma_a3-sigma_t6.npz"
BRANCH_KEYS = ['sigma', 'boundary']
TRUNK_KEYS = ['phase_coords']
LABEL_KEY = 'psi_label'

np_data = dict(np.load(DATA_PATH, allow_pickle=True))

In [7]:
dataset = MioDataset(np_data, BRANCH_KEYS, TRUNK_KEYS, 'psi_label', collocation_size=20)
tree.map_structure(lambda x: x.shape, next(dataset))

({'sigma': (3200,), 'boundary': (1920,)},
 {'phase_coords': (20, 4)},
 {'psi_label': (20,)})

In [15]:
dataloaders = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=0)
inputs = next(iter(dataloaders))

In [19]:
from functools import reduce
inputs = reduce(lambda x, y: {**x, **y}, inputs)
tree.map_structure(lambda x: x.shape, inputs)

{'sigma': torch.Size([2, 3200]),
 'boundary': torch.Size([2, 1920]),
 'phase_coords': torch.Size([2, 20, 4]),
 'psi_label': torch.Size([2, 20])}

In [20]:
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.models.fourier_net import FourierNetArch
from modulus.sym.models.deeponet import DeepONetArch
from modulus.sym.key import Key
import torch

arch_1 = FullyConnectedArch(
    [Key("sigma", size=3200)], [Key("b1", size=128)], layer_size=128, nr_layers=2
)
arch_2 = FullyConnectedArch(
    [Key("boundary", size=1920)], [Key("b2", size=128)], layer_size=128, nr_layers=2
)
branch_net = MioBranchNet([arch_1, arch_2], [Key("branch", size=128)])
# arch_3 = FullyConnectedArch(
#     [Key("x3", size=2)], [Key("b3", size=2)], layer_size=64, nr_layers=2
# )

trunk_net = FourierNetArch(
        input_keys=[Key("phase_coords", size=4)],
        output_keys=[Key("trunk", 128)],
    )

deeponet = DeepONetArch(
        output_keys=[Key("u")],
        branch_net=branch_net,
        trunk_net=trunk_net,
    )

model = deeponet.make_node("deeponet")
# model = branch_net.make_node("branch_net")
# input = {"sigma": torch.randn(16, 3200), "boundary": torch.randn(16, 1920)}
# inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
# input.update({"x": torch.randn(16, 2)})
output = model.evaluate(inputs)

RuntimeError: Tensors must have same number of dimensions: got 2 and 3

In [6]:
output

{'u': tensor([[ 9.1977e-05],
         [-1.8999e-05],
         [-1.2630e-04],
         [ 9.6460e-05],
         [-5.8578e-05],
         [ 2.8761e-04],
         [ 2.4340e-05],
         [ 6.7943e-06],
         [-2.2871e-04],
         [-1.6994e-04],
         [-2.5038e-05],
         [ 2.6192e-05],
         [ 7.3439e-05],
         [-4.8227e-06],
         [-1.5742e-05],
         [ 7.2662e-05]], grad_fn=<SplitWithSizesBackward0>)}

In [7]:
print(deeponet)

DeepONetArch(
  (branch_net): MioBranchNet(
    (0): FullyConnectedArch(
      (_impl): FullyConnectedArchCore(
        (layers): ModuleList(
          (0): FCLayer(
            (linear): WeightNormLinear(in_features=3200, out_features=128, bias=True)
          )
          (1): FCLayer(
            (linear): WeightNormLinear(in_features=128, out_features=128, bias=True)
          )
        )
        (final_layer): FCLayer(
          (linear): Linear(in_features=128, out_features=128, bias=True)
        )
      )
    )
    (1): FullyConnectedArch(
      (_impl): FullyConnectedArchCore(
        (layers): ModuleList(
          (0): FCLayer(
            (linear): WeightNormLinear(in_features=1920, out_features=128, bias=True)
          )
          (1): FCLayer(
            (linear): WeightNormLinear(in_features=128, out_features=128, bias=True)
          )
        )
        (final_layer): FCLayer(
          (linear): Linear(in_features=128, out_features=128, bias=True)
        )
      )
  

In [13]:
list_1 = [torch.randn(16, 3200), torch.randn(16, 3200), torch.randn(16, 3200)]

In [22]:
torch.randn(16, 3200).view(-1).shape

torch.Size([51200])