In [1]:
import numpy as np
import torch
import tree
import yaml
import ml_collections
import sys
sys.path.append('../')
%reload_ext autoreload
%autoreload 2
from deeponet.data.loader import create_loader
from deeponet.data.dataset import MioDataset, preprocess
from deeponet.modules import DeepONet, FullyConnected
from deeponet.utils import Key

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input locations. NOTE(review): these are absolute paths from the original
# dev machine — point them at your own checkout before running.
DATA_PATH = "/root/projects/deeponet/data/test/bc1-g0.1/bc1-g0.1.npz"
config_path = '/root/projects/deeponet/deeponet/config.yaml'

# Experiment config, wrapped in a ConfigDict for attribute access (cfg.model.…).
with open(config_path, "r") as f:
    cfg = ml_collections.ConfigDict(yaml.safe_load(f))

# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed; copying it into a dict inside `with` reads every array eagerly and
# then releases the file handle instead of leaking it.
# WARNING: allow_pickle=True can execute arbitrary code from object arrays —
# only load .npz files you trust.
with np.load(DATA_PATH, allow_pickle=True) as npz:
    np_data = dict(npz)

In [3]:
# Shape of every raw array loaded from the .npz archive.
{name: array.shape for name, array in np_data.items()}

{'sigma': (100, 40, 40, 2),
 'psi_label': (100, 40, 40, 24),
 'scattering_kernel': (100, 24, 24),
 'boundary_scattering_kernel': (100, 160, 12, 24),
 'self_scattering_kernel': (100, 24, 24),
 'boundary': (100, 160, 12),
 'sigma_a': (100, 40, 40),
 'sigma_t': (100, 40, 40),
 'phase_coords': (40, 40, 24, 4),
 'weights': (24,)}

In [4]:
# `preprocess` turns the raw dict into a 4-tuple (branch inputs, trunk inputs,
# labels, input-size dict) — see the shape summary below — and this cell rebinds
# `np_data` to that tuple. Re-running the cell on its own would feed the tuple
# back into `preprocess`, so fail fast with a clear message instead.
assert isinstance(np_data, dict), (
    "np_data is already preprocessed; re-run the data-loading cell first"
)
np_data = preprocess(cfg, np_data)

In [5]:
# Summarise the preprocessed structure: tensors become their shapes, while
# everything else (e.g. the integer feature sizes) passes through unchanged.
tree.map_structure(
    lambda leaf: leaf.shape if torch.is_tensor(leaf) else leaf,
    np_data,
)

({'sigma_a': torch.Size([100, 1600]),
  'sigma_t': torch.Size([100, 1600]),
  'boundary': torch.Size([100, 1920]),
  'scattering_kernel': torch.Size([100, 576])},
 {'phase_coords': torch.Size([38400, 4])},
 {'psi_label': torch.Size([100, 38400])},
 {'branch': {'sigma_a': 1600,
   'sigma_t': 1600,
   'boundary': 1920,
   'scattering_kernel': 576},
  'trunk': {'phase_coords': 4}})

In [37]:
# Number of collocation points per dataset item (matches the per-item
# `phase_coords` / `psi_label` length of 20 in the sample output below).
COLLOCATION_SIZE = 20

# The first three preprocess outputs are branch inputs, trunk inputs and labels;
# the fourth (input-size dict) is used later to build the model.
dataset = MioDataset(*np_data[:3], collocation_size=COLLOCATION_SIZE)

In [38]:
# Peek at a single sample: tensors become their shapes, anything else is kept as-is.
tree.map_structure(
    lambda leaf: leaf.shape if torch.is_tensor(leaf) else leaf,
    next(dataset),
)

{'sigma_a': torch.Size([1600]),
 'sigma_t': torch.Size([1600]),
 'boundary': torch.Size([1920]),
 'scattering_kernel': torch.Size([576]),
 'phase_coords': torch.Size([20, 4]),
 'psi_label': torch.Size([20])}

In [39]:
BATCH_SIZE = 16
NUM_WORKERS = 2

# Training loader built by the project helper; this replaces the earlier
# hand-rolled torch DataLoader experiments (persistent_workers / custom
# worker_init_fn), which have been removed as dead code.
dataloader = create_loader(
    dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, is_training=True
)

In [40]:
# Shapes of one collated batch, taken from a fresh loader iterator.
tree.map_structure(lambda tensor: tensor.shape, next(iter(dataloader)))

{'sigma_a': torch.Size([16, 1600]),
 'sigma_t': torch.Size([16, 1600]),
 'boundary': torch.Size([16, 1920]),
 'scattering_kernel': torch.Size([16, 576]),
 'phase_coords': torch.Size([16, 20, 4]),
 'psi_label': torch.Size([16, 20])}

In [10]:
# Pull three consecutive batches from a single iterator and keep only their
# `sigma_a` tensors, so the next cell can check that they differ.
d = iter(dataloader)
a, b, c = [next(d)["sigma_a"] for _ in range(3)]

In [11]:
# Consecutive batches should all differ. With NUM_WORKERS=2 they are presumably
# served by the workers in turn, so identical batches would suggest the workers
# share RNG state (the classic duplicated-seed DataLoader pitfall).
# Assert rather than print, so a regression fails Restart & Run All loudly.
batch_pairs = {"a vs b": (a, b), "b vs c": (b, c), "a vs c": (a, c)}
identical = {label: torch.allclose(x, y) for label, (x, y) in batch_pairs.items()}
assert not any(identical.values()), f"identical batches found: {identical}"
identical

False False False


In [12]:
# NOTE(review): stale scratch line. Batches are dicts keyed by input name (see
# the shape summary of `next(iter(dataloader))` above), so `next(d)[1]` would
# raise a KeyError — it looks left over from an earlier tuple-based batch
# format. Safe to delete.
# print(next(d)[1], next(d)[1], next(d)[1])

In [13]:
# Unpack the preprocess outputs; only the input-size dict is needed here to
# size each sub-network's input layer.
branch_dr, trunk_dr, label_dr, input_shape_dict = np_data

latent_size = cfg.model.latent_size


def create_model(model_cfg, shape_dict, latent_size=None):
    """Build one fully connected sub-network (trunk or branch) of the DeepONet.

    Args:
        model_cfg: Sub-network config; reads ``input_key``, ``output_key`` and
            ``hidden_units``.
        shape_dict: Mapping from input key to its flattened feature size, e.g.
            ``input_shape_dict["branch"]`` or ``input_shape_dict["trunk"]``.
        latent_size: Width of the output latent vector. Defaults to
            ``cfg.model.latent_size``; pass it explicitly so the function does
            not silently depend on notebook globals.

    Returns:
        A ``FullyConnected`` net mapping ``input_key`` to ``output_key``.
    """
    if latent_size is None:
        latent_size = cfg.model.latent_size
    input_name = model_cfg.get("input_key")
    return FullyConnected(
        [Key(input_name, size=shape_dict[input_name])],
        [Key(model_cfg.get("output_key"), latent_size)],
        model_cfg.hidden_units,
    )


trunk_net = create_model(cfg.model.trunk_net, input_shape_dict["trunk"], latent_size)

# One branch net per `cfg.model` entry whose key contains "branch", in config
# iteration order. The comprehension keeps its loop variables local, so they
# no longer overwrite notebook globals such as the batch iterator `d`.
branch_net_list = [
    create_model(branch_cfg, input_shape_dict["branch"], latent_size)
    for key, branch_cfg in cfg.model.items()
    if "branch" in key
]

model = DeepONet(branch_net_list, trunk_net, output_keys=[Key("psi", 1)])

In [14]:
# Display the assembled DeepONet (built from the branch nets and the trunk net) via its repr.
model

DeepONet(
  (branch0): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1600, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (branch1): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1600, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (branch2): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1920, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (branch3): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=576, out_features=256, 

In [15]:
# Count only trainable parameters (those with requires_grad=True).
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print("Number of parameters: ", num_params)

Number of parameters:  2775040
