In [3]:
import os
import sys
import warnings

import torch
import numpy as np

import modulus.sym
from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig, to_yaml
from modulus.sym.solver import Solver
from modulus.sym.domain import Domain
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.models.fourier_net import FourierNetArch
from modulus.sym.models.deeponet import DeepONetArch
from modulus.sym.domain.constraint.continuous import DeepONetConstraint
from modulus.sym.domain.validator.discrete import GridValidator
from modulus.sym.dataset.discrete import DictGridDataset

In [4]:
import torch
# Ask PyTorch whether a usable CUDA device is visible in this environment;
# as the last expression in the cell, the boolean result is displayed as output.
torch.cuda.is_available()

False

In [6]:
from typing import Dict, List, Tuple, Union
from modulus.sym.key import Key
from modulus.sym.models.arch import Arch
from torch import Tensor
import torch

class MioBranchNet(Arch):
    """Multi-input branch network that multiplies the outputs of several sub-networks.

    Each network in ``branch_net_list`` receives only the slice of the
    concatenated input tensor that corresponds to its own input keys. The
    per-branch outputs are stacked along a new leading dimension and reduced
    with an element-wise product, so every branch must return a tensor of the
    same shape (``torch.stack`` requires it).

    Parameters
    ----------
    branch_net_list : List[Arch]
        Branch networks. Their input keys are concatenated in list order to
        form this model's input keys. Must contain at least one network.
    output_keys : List[Key]
        Keys used to split the combined output into named variables.
        NOTE(review): the ``None`` default is forwarded unchanged to
        ``Arch.__init__``; confirm the base class accepts it before relying
        on the default.
    detach_keys : List[Key] or None
        Keys forwarded to ``Arch.__init__`` as ``detach_keys``. ``None``
        (the default) is treated as an empty list.

    Raises
    ------
    ValueError
        If ``branch_net_list`` is empty.
    """

    def __init__(
        self,
        branch_net_list: List[Arch],
        output_keys: List[Key] = None,
        detach_keys: Union[List[Key], None] = None,
    ) -> None:
        if len(branch_net_list) == 0:
            # Fail early with a clear message instead of inside torch.stack later.
            raise ValueError("branch_net_list must contain at least one branch network")

        super().__init__(
            input_keys=[],
            output_keys=output_keys,
            # A fresh list per instance avoids sharing one mutable default object.
            detach_keys=[] if detach_keys is None else detach_keys,
        )
        self.num_branches = len(branch_net_list)

        # nn.ModuleList registers every branch as a submodule, so its parameters
        # show up in .parameters() / state_dict() and follow .to(device).
        # A plain Python list would leave them invisible to the optimizer.
        self.branch_net_list = torch.nn.ModuleList(branch_net_list)

        # The combined model consumes the union of all branch inputs, in branch order.
        self.input_keys = [key for b in self.branch_net_list for key in b.input_keys]
        self.input_key_dict = {str(var): var.size for var in self.input_keys}

        # One non-persistent index buffer per branch, used to slice that branch's
        # columns out of the concatenated input tensor.
        for i, branch in enumerate(self.branch_net_list):
            index = self.prepare_slice_index(
                self.input_key_dict, branch.input_key_dict.keys()
            )
            self.register_buffer("branch_slice_index_" + str(i), index, persistent=False)

    def _tensor_forward(self, x: Tensor) -> Tensor:
        """Run every branch on its slice of ``x`` and multiply the results element-wise.

        ``x`` is the concatenated input built in ``forward``; slicing happens
        along the last dimension.
        """
        branch_outputs = [
            branch._tensor_forward(
                self.slice_input(x, getattr(self, "branch_slice_index_" + str(i)), dim=-1)
            )
            for i, branch in enumerate(self.branch_net_list)
        ]
        # Stack on a new leading axis, then reduce that axis with a product.
        stacked_tensor = torch.stack(branch_outputs)
        output = torch.prod(stacked_tensor, dim=0)

        output = self.process_output(output, self.output_scales_tensor)
        return output

    def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Concatenate the named inputs, evaluate the branches, and split the result by output key."""
        x = self.concat_input(
            in_vars,
            self.input_key_dict.keys(),
            detach_dict=self.detach_key_dict,
            dim=-1,
        )
        y = self._tensor_forward(x)
        return self.split_output(y, self.output_key_dict, dim=-1)

In [7]:
from modulus.sym.models.fully_connected import FullyConnectedArch
from modulus.sym.key import Key
import torch

# Build three fully connected branches with identical hyper-parameters; each
# maps its own 2-wide input key to its own 2-wide output key. They are created
# in order 1, 2, 3 so weight initialisation consumes the RNG in a fixed order.
arch_1, arch_2, arch_3 = (
    FullyConnectedArch(
        [Key(in_name, size=2)],
        [Key(out_name, size=2)],
        layer_size=64,
        nr_layers=2,
    )
    for in_name, out_name in (("x1", "b1"), ("x2", "b2"), ("x3", "b3"))
)

# Combine the branches; their outputs are multiplied element-wise into "output".
branch_net = MioBranchNet(
    [arch_1, arch_2, arch_3],
    [Key("output", size=2)],
)

# Wrap the architecture as a node and run a smoke test on random data
# (64 rows, 2 features per branch input).
model = branch_net.make_node("branch_net")
# NOTE(review): `input` shadows the Python builtin for the rest of the kernel
# session; the name is kept here because later cells may reference it.
input = {name: torch.randn(64, 2) for name in ("x1", "x2", "x3")}
output = model.evaluate(input)