In [16]:
import torch
import torch.nn as nn
import sys
import os

sys.path.append('../')
from deeponet.modules import FullyConnected, DeepONet, ModifiedMlp, ResNet
from deeponet.utils import Key
import yaml
import ml_collections
import numpy as np
from deeponet.data.dataset import preprocess
import tree
from deeponet.data.loader import create_loader

from typing import Any

import numpy as np

In [17]:
def compute_rmse(label, pre):
    """Return the relative L2 error of a prediction against its reference.

    Despite the name, this is not a plain root-mean-square error: the summed
    squared error is normalised by the summed squared magnitude of ``label``,
    i.e. ``||label - pre||_2 / ||label||_2`` taken over every element.

    Args:
        label: Reference tensor.
        pre: Predicted tensor, broadcast-compatible with ``label``.

    Returns:
        A zero-dimensional tensor holding the relative error. An all-zero
        ``label`` makes the denominator zero, so the result is ``nan``/``inf``.
    """
    squared_error = (label - pre).pow(2).sum()
    reference_energy = label.pow(2).sum()
    return (squared_error / reference_energy).sqrt()

In [36]:
# Evaluation inputs: the dataset to score and the run directory holding the
# checkpoint. Alternative datasets are kept below for quick switching.
# NOTE(review): these are absolute, machine-specific paths — adjust per host.
DATA_PATH = "/root/projects/deeponet/data/g0.1-sigma_a3-sigma_t6_test_normalized.npz"
# DATA_PATH = "/root/projects/deeponet/data/g0.1-sigma_a3-sigma_t6_train_normalized.npz"
# DATA_PATH = "/root/projects/deeponet/data/test/bc1-g0.1/bc1-g0.1.npz"
result_path = '/root/projects/deeponet/deeponet/output/train/20230928-053257-inverter'

config_path = os.path.join(result_path, "config.yaml")
model_ckpt_path = os.path.join(result_path, "model_best.pth.tar")

# Read the configuration stored next to the checkpoint so the model can be
# rebuilt from the same settings.
with open(config_path, "r") as f:
    cfg = ml_collections.ConfigDict(yaml.safe_load(f))

# np.load on an .npz returns a lazy NpzFile that keeps the archive open.
# dict() reads every array eagerly, so the handle can be closed right away
# instead of being leaked for the lifetime of the kernel.
# NOTE(review): allow_pickle=True unpickles object arrays — only load trusted files.
with np.load(DATA_PATH, allow_pickle=True) as npz_file:
    np_data = dict(npz_file)
weights = np_data["weights"]

In [37]:
# tree.map_structure(lambda x: x.shape, np_data)
# NOTE: `np_data` is rebound here from the raw npz dict to the result of
# preprocess(), so this cell is not idempotent — re-run the data-loading cell
# before running this one a second time.
np_data = preprocess(cfg, np_data)
# The result unpacks into four parts: branch inputs, trunk inputs, labels and a
# dict of input sizes (consumed later when the sub-networks are built).
branch_dr, trunk_dr, label_dr, input_shape_dict = np_data

In [38]:
# Summarise the preprocessed structure: tensors are replaced by their shape,
# every other leaf (e.g. the integer input sizes) is shown unchanged.
tree.map_structure(
    lambda leaf: leaf.shape if isinstance(leaf, torch.Tensor) else leaf,
    np_data,
)

({'sigma_a': torch.Size([200, 1600]),
  'sigma_t': torch.Size([200, 1600]),
  'boundary': torch.Size([200, 1920]),
  'scattering_kernel': torch.Size([200, 576])},
 {'phase_coords': torch.Size([38400, 4])},
 {'psi_label': torch.Size([200, 38400])},
 {'branch': {'sigma_a': 1600,
   'sigma_t': 1600,
   'boundary': 1920,
   'scattering_kernel': 576},
  'trunk': {'phase_coords': 4}})

In [39]:
# Width of the output key given to every sub-network built by create_model().
latent_size = cfg.model.latent_size

def _get_activation():
    """Return the activation class selected by ``cfg.model.activation``.

    The class itself (not an instance) is returned. When the config has no
    ``activation`` entry, ``"relu"`` is used.

    Returns:
        ``nn.Tanh`` for ``"tanh"`` or ``nn.ReLU`` for ``"relu"``.

    Raises:
        ValueError: If the configured name is not one of the supported
            activations. Previously this case silently returned ``None``.
    """
    supported = {"tanh": nn.Tanh, "relu": nn.ReLU}
    act = cfg.model.get("activation", "relu")
    if act not in supported:
        raise ValueError(
            f"Unsupported activation {act!r}; expected one of {sorted(supported)}"
        )
    return supported[act]

def create_model(model_cfg, shape_dict):
    """Build one branch or trunk sub-network from its config entry.

    Args:
        model_cfg: Config node providing ``type``, ``input_key``,
            ``output_key`` and ``hidden_units``.
        shape_dict: Maps input key names to their feature size.

    Returns:
        A ``FullyConnected``, ``ModifiedMlp`` or ``ResNet`` instance (chosen by
        ``type``) whose output key has size ``latent_size``.

    Raises:
        ValueError: If ``type`` is not ``"mlp"``, ``"modified_mlp"`` or
            ``"resnet"``. Previously this surfaced as an UnboundLocalError.
    """
    # All three architectures take the same constructor arguments, so the
    # type string only selects which class to instantiate.
    builders = {
        "mlp": FullyConnected,
        "modified_mlp": ModifiedMlp,
        "resnet": ResNet,
    }
    model_type = model_cfg.get("type")
    if model_type not in builders:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of {sorted(builders)}"
        )

    input_name = model_cfg.get("input_key")
    return builders[model_type](
        [Key(input_name, size=shape_dict[input_name])],
        [Key(model_cfg.get("output_key"), latent_size)],
        model_cfg.hidden_units,
        activation=_get_activation(),
    )
# One branch network per config entry whose key contains "branch", kept in the
# order the config yields them.
branch_net_list = [
    create_model(branch_cfg, input_shape_dict["branch"])
    for name, branch_cfg in cfg.model.items()
    if "branch" in name
]
trunk_net = create_model(cfg.model.trunk_net, input_shape_dict["trunk"])

# Combine the branch networks and the trunk network into a DeepONet exposing a
# single output key "psi" of size 1.
model = DeepONet(branch_net_list, trunk_net, output_keys=[Key("psi", 1)])

In [40]:
# Select the device the same way the evaluation loop does, so the notebook
# also runs on a CPU-only machine instead of failing on a hard-coded "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ckpt = torch.load(model_ckpt_path, map_location=device)
model.load_state_dict(ckpt["state_dict"])
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

DeepONet(
  (branch0): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1600, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=64, bias=True)
    )
  )
  (branch1): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1600, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=64, bias=True)
    )
  )
  (branch2): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=1920, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=64, bias=True)
    )
  )
  (branch3): FullyConnected(
    (layers): ModuleList(
      (0): Linear(in_features=576, out_features=32, bias=True)

In [41]:
# Total number of scalar entries across all parameter tensors of the model.
param_count = sum(p.numel() for p in model.parameters())
print('Model created, param count: %d' % param_count)

Model created, param count: 244960


In [42]:
# Number of samples per batch passed to create_loader() for evaluation.
eval_batch_size = 32

class MioValDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one flat dict per sample.

    Branch inputs and labels are indexed per sample, while every trunk entry
    is returned whole (un-indexed) with each sample.
    """

    def __init__(self, branch_dr, trunk_dr, label_dr):
        """Store the three input dicts and derive the dataset length.

        Args:
            branch_dr: Dict of per-sample branch inputs, indexable by sample.
            trunk_dr: Dict of trunk inputs shared by all samples.
            label_dr: Dict of per-sample labels; the length of its first
                entry defines the dataset size.
        """
        self.branch_dr = branch_dr
        self.trunk_dr = trunk_dr
        self.label_dr = label_dr
        first_label_key = list(self.label_dr.keys())[0]
        self.size = len(self.label_dr[first_label_key])

    def __len__(self):
        """Return the number of samples."""
        return self.size

    def __getitem__(self, idx):
        """Return sample ``idx`` as one dict of branch, trunk and label entries."""
        sample = {}
        for name, values in self.branch_dr.items():
            sample[name] = values[idx]
        # Trunk entries are not indexed: each sample carries the full value.
        for name, values in self.trunk_dr.items():
            sample[name] = values
        for name, values in self.label_dr.items():
            sample[name] = values[idx]
        return sample


In [43]:
# Wrap the preprocessed branch/trunk/label dicts in the dataset and batch them
# with the project's loader factory using the evaluation batch size.
val_dataset = MioValDataset(branch_dr, trunk_dr, label_dr)
val_loader = create_loader(val_dataset, eval_batch_size)

In [44]:
# Sanity check: pull the first batch from the loader and display the shape of
# every entry in it.
first_batch = next(iter(val_loader))
tree.map_structure(lambda entry: entry.shape, first_batch)

{'sigma_a': torch.Size([32, 1600]),
 'sigma_t': torch.Size([32, 1600]),
 'boundary': torch.Size([32, 1920]),
 'scattering_kernel': torch.Size([32, 576]),
 'phase_coords': torch.Size([32, 38400, 4]),
 'psi_label': torch.Size([32, 38400])}

In [45]:
from deeponet import utils

# Running statistics over the validation set: MSE loss and the relative error
# returned by compute_rmse().
# NOTE(review): presumably AverageMeter keeps a mean weighted by the count
# passed to update() — confirm in deeponet.utils.
losses_m = utils.AverageMeter()
rmse_m = utils.AverageMeter()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = nn.MSELoss().to(device)

with torch.no_grad():
    # The loop variable is `batch`, not `input`, so the builtin input() is not
    # shadowed in the notebook namespace once this cell has run.
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(batch)
        loss = loss_fn(output["psi"], batch["psi_label"])
        # Print both shapes to confirm prediction and label line up.
        print(output["psi"].shape, batch["psi_label"].shape)
        rmse = compute_rmse(batch["psi_label"], output["psi"])

        # Pass the batch size so a smaller final batch is not over-represented.
        batch_size = batch["psi_label"].size(0)
        losses_m.update(loss.item(), batch_size)
        rmse_m.update(rmse.item(), batch_size)

        print(
            f"Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g})  "
            f"RMSE: {rmse_m.val:#.3g} ({rmse_m.avg:#.3g})  "
        )
        

torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000199 (0.000199)  RMSE: 0.733 (0.733)  
torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000187 (0.000193)  RMSE: 0.709 (0.721)  
torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000187 (0.000191)  RMSE: 0.694 (0.712)  
torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000204 (0.000194)  RMSE: 0.726 (0.715)  
torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000210 (0.000197)  RMSE: 0.744 (0.721)  
torch.Size([32, 38400]) torch.Size([32, 38400])
Loss: 0.000202 (0.000198)  RMSE: 0.736 (0.724)  
torch.Size([8, 38400]) torch.Size([8, 38400])
Loss: 0.000207 (0.000198)  RMSE: 0.754 (0.725)  


In [46]:
# Final average of the per-batch compute_rmse() values (a relative L2 error,
# not an absolute RMSE) accumulated by the evaluation loop.
rmse_m.avg

0.7248536920547486

In [47]:
# Final average of the per-batch MSE loss accumulated by the evaluation loop.
losses_m.avg

0.00019843443355057388