In [1]:
import torch
import numpy as np

import modulus.sym
from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig, to_yaml
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.models.fourier_net import FourierNetArch
from modulus.sym.models.deeponet import DeepONetArch
from modulus.sym.domain.constraint.continuous import DeepONetConstraint
from modulus.sym.domain.validator.discrete import GridValidator
from modulus.sym.dataset.discrete import DictGridDataset

from modulus.sym.key import Key
import os
import sys

# Make the sibling `deeponet` package importable.
# The parent directory is resolved against the kernel's current working directory.
# This assumes the notebook is launched from its own folder, one level below the
# project root (TODO confirm).
# The path is frozen to an absolute value so that a later chdir does not change
# what it points to.
_PROJECT_ROOT = os.path.abspath("..")
# Guard the append so that re-running this cell does not pile up duplicate
# sys.path entries.
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

from deeponet.preprocess import preprocess
from deeponet.modules import MioBranchNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# [datasets]
# Load the training data and preprocess it into the branch / trunk / label arrays
# consumed by the DeepONet data constraint.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
DATA_PATH = "/root/projects/deeponet/data/g0.1-sigma_a3-sigma_t6.npz"
BRANCH_KEYS = ["sigma", "boundary"]  # inputs fed to the branch net
TRUNK_KEYS = ["phase_coords"]        # inputs fed to the trunk net
LABEL_KEY = "psi_label"              # supervised target for the "psi" output

# For an .npz archive, np.load returns an NpzFile that keeps the file open.
# Materialise every array inside a `with` block so the handle is released.
# NOTE(review): allow_pickle=True can execute arbitrary code when object arrays
# are unpickled. Only load trusted data files this way.
with np.load(DATA_PATH, allow_pickle=True) as npz_file:
    np_raw = dict(npz_file)

# NOTE(review): the meaning of `repeat` is defined in deeponet.preprocess; confirm there.
np_data = preprocess(np_raw, BRANCH_KEYS, TRUNK_KEYS, LABEL_KEY, repeat=2)

# The constraint cell indexes np_data by these keys; fail early with a clear message.
_missing_keys = set(BRANCH_KEYS + TRUNK_KEYS + [LABEL_KEY]) - set(np_data)
assert not _missing_keys, f"preprocess output is missing keys: {sorted(_missing_keys)}"


In [4]:
# Layer widths shared across the architecture (all currently 128).
HIDDEN_SIZE = 128  # hidden-layer width of every sub-network
LATENT_SIZE = 128  # width of the branch / trunk feature vectors combined by DeepONetArch

# Flattened input widths.
# NOTE(review): SIGMA_SIZE / BOUNDARY_SIZE presumably must equal the last dimension
# of np_data["sigma"] / np_data["boundary"] after preprocessing. Confirm against
# deeponet.preprocess.
SIGMA_SIZE = 3200
BOUNDARY_SIZE = 1920
PHASE_COORD_SIZE = 4

# Branch: one fully connected sub-net per branch input.
# The sub-nets are combined by the project-local MioBranchNet (see deeponet.modules).
# (A single FullyConnectedArch on "sigma" alone was previously tried as the branch net.)
arch_1 = FullyConnectedArch(
    [Key("sigma", size=SIGMA_SIZE)],
    [Key("b1", size=LATENT_SIZE)],
    layer_size=HIDDEN_SIZE,
    nr_layers=2,
)
arch_2 = FullyConnectedArch(
    [Key("boundary", size=BOUNDARY_SIZE)],
    [Key("b2", size=LATENT_SIZE)],
    layer_size=HIDDEN_SIZE,
    nr_layers=2,
)
branch_net = MioBranchNet([arch_1, arch_2], [Key("branch", size=LATENT_SIZE)])

# Trunk: Fourier-feature network over the coordinates, using axis frequencies 0..4.
trunk_net = FourierNetArch(
    input_keys=[Key("phase_coords", size=PHASE_COORD_SIZE)],
    output_keys=[Key("trunk", size=LATENT_SIZE)],
    nr_layers=4,
    layer_size=HIDDEN_SIZE,
    frequencies=("axis", list(range(5))),
)

deeponet = DeepONetArch(
    output_keys=[Key("psi")],
    branch_net=branch_net,
    trunk_net=trunk_net,
)

nodes = [deeponet.make_node("deepo")]

In [6]:
BATCH_SIZE = 16  # mini-batch size for the supervised data constraint

# Supervised data constraint: fit the network output "psi" to the preprocessed labels.
domain = Domain()

data = DeepONetConstraint.from_numpy(
    nodes=nodes,
    invar={k: np_data[k] for k in BRANCH_KEYS + TRUNK_KEYS},
    outvar={"psi": np_data[LABEL_KEY]},
    batch_size=BATCH_SIZE,
)
domain.add_constraint(data, "data")
# [constraint]

# Show the models the domain reports as saveable.
# Use the cell's rich display rather than print().
domain.get_saveable_models()

[DeepONetArch(
  (branch_net): MioBranchNet()
  (trunk_net): FourierNetArch(
    (fourier_layer_params): FourierLayer()
    (fc): FullyConnectedArchCore(
      (layers): ModuleList(
        (0): FCLayer(
          (linear): WeightNormLinear(in_features=84, out_features=128, bias=True)
        )
        (1): FCLayer(
          (linear): WeightNormLinear(in_features=128, out_features=128, bias=True)
        )
        (2): FCLayer(
          (linear): WeightNormLinear(in_features=128, out_features=128, bias=True)
        )
        (3): FCLayer(
          (linear): WeightNormLinear(in_features=128, out_features=128, bias=True)
        )
      )
      (final_layer): FCLayer(
        (linear): Linear(in_features=128, out_features=128, bias=True)
      )
    )
  )
  (branch_linear): Identity()
  (trunk_linear): Identity()
  (output_linear): Linear(in_features=128, out_features=1, bias=False)
)]
