In [1]:
import torch
import torch.nn as nn
import numpy as np
import sys
import os
import random
import matplotlib.pyplot as plt

# Make the local `src/` directory (resolved against the current working
# directory) importable, without adding a duplicate entry on re-runs.
src_path = os.path.join(os.getcwd(), 'src')
src_path = os.path.abspath(src_path)
if src_path not in sys.path:
    sys.path.append(src_path)
    
from utils import MIMONetDataset, DeepONetDataset, ChannelScaler
from mimonet_drop import MIMONet_Drop

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('Using device:', device)

Using device: cuda


In [3]:
# Set the working directory.
# The original cluster path remains the default so existing runs are unchanged;
# set the MIMONET_WORKING_DIR environment variable to point the notebook at a
# different checkout without editing this cell.
working_dir = os.environ.get(
    "MIMONET_WORKING_DIR",
    "/projects/bcnx/kazumak2/MIMONet/Subchannel/",
)
# All input arrays are read from the `data` sub-directory of the working directory.
data_dir = os.path.join(working_dir, "data")

In [4]:
# trunk dataset
# Archives that are only needed for one array are opened in a `with` block so
# the underlying file handle is closed as soon as the array has been read.
with np.load(os.path.join(data_dir, "share/trunk_input.npz")) as _npz:
    trunk_input = _npz['trunk']

# training data
# `train_branch` stays bound to the loaded archive, exactly as before, in case
# other cells index it again.
train_branch = np.load(os.path.join(data_dir, "training/train_branch_input.npz"))
train_branch_1 = train_branch['func_params']
train_branch_2 = train_branch['stat_params']

# [samples, channel, gridpoints]
with np.load(os.path.join(data_dir, "training/train_target.npz")) as _npz:
    train_target = _npz['target']
# convert to [samples, gridpoints, channel]
train_target = np.moveaxis(train_target, 1, 2)

print("train_branch_1 shape:", train_branch_1.shape)
print("train_branch_2 shape:", train_branch_2.shape)
print("train_target shape:", train_target.shape)

# scaling the functional input data using predefined mean and std
# (the statistics archive is opened once and both arrays are read from it)
with np.load(os.path.join(data_dir, "share/func_mean_std_params.npz")) as _npz:
    f_mean, f_std = _npz['mean'], _npz['std']

train_branch_1 = (train_branch_1 - f_mean) / f_std

# scaling the static input data using predefined mean and std
with np.load(os.path.join(data_dir, "share/stat_mean_std_params.npz")) as _npz:
    s_mean, s_std = _npz['mean'], _npz['std']

# Column i of the static input is standardised with s_mean[i] and s_std[i];
# the result is written back into `train_branch_2` in place.
for i in range(s_mean.shape[0]):
    train_branch_2[:, i] = (train_branch_2[:, i] - s_mean[i]) / s_std[i]

train_branch_1 shape: (4000, 100)
train_branch_2 shape: (4000, 2)
train_target shape: (4000, 1733, 3)


In [5]:
# Test split: branch inputs (functional and static parameters).
test_branch = np.load(os.path.join(data_dir, "test/test_branch_input.npz"))
test_branch_1, test_branch_2 = test_branch['func_params'], test_branch['stat_params']

# Test targets, with axes 1 and 2 swapped right after loading
# (same layout conversion as applied to the training targets).
test_target = np.moveaxis(
    np.load(os.path.join(data_dir, "test/test_target.npz"))['target'],
    1,
    2,
)

print("test_branch_1 shape:", test_branch_1.shape)
print("test_branch_2 shape:", test_branch_2.shape)
print("test_target shape:", test_target.shape)

# Standardise the functional input with the predefined mean and std.
test_branch_1 = (test_branch_1 - f_mean) / f_std

# Standardise the static input column by column with the predefined mean and
# std; each column is overwritten in place.
num_static = s_mean.shape[0]
for i in range(num_static):
    test_branch_2[:, i] = (test_branch_2[:, i] - s_mean[i]) / s_std[i]

test_branch_1 shape: (1000, 100)
test_branch_2 shape: (1000, 2)
test_target shape: (1000, 1733, 3)


In [6]:
# Scale the target data with a min-max scaler configured for the range (-1, 1).
# The scaler is fitted on the training targets only, and that fitted scaler is
# then applied to both the training and the test targets.
#
# Note: to reverse the scaling for the target data use
#   train_target = scaler.inverse_transform(train_target_scaled)
#   test_target = scaler.inverse_transform(test_target_scaled)
scaler = ChannelScaler(method='minmax', feature_range=(-1, 1))
scaler.fit(train_target)
train_target_scaled, test_target_scaled = (
    scaler.transform(split) for split in (train_target, test_target)
)

In [7]:
# Test dataset: both (scaled) branch inputs as a list, the shared trunk input,
# and the scaled test targets.
test_dataset = MIMONetDataset([test_branch_1, test_branch_2], trunk_input, test_target_scaled)

# Test dataloader: one sample per batch, samples kept in dataset order.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=0
)

In [8]:
# Training dataset: both (scaled) branch inputs as a list, the shared trunk
# input, and the scaled training targets.
train_dataset = MIMONetDataset([train_branch_1, train_branch_2], trunk_input, train_target_scaled)

In [9]:
# Architecture parameters
dim = 256  # last entry of every branch/trunk layer list below
branch_input_dim1 = 100
branch_input_dim2 = 2
trunk_input_dim = 2

# Keyword arguments shared by every MIMONet_Drop instance built in this notebook.
model_args = dict(
    # one layer list per branch input; only the first (input) width differs
    branch_arch_list=[
        [in_dim, 512, 512, 512, dim]
        for in_dim in (branch_input_dim1, branch_input_dim2)
    ],
    trunk_arch=[trunk_input_dim, 256, 256, 256, dim],
    num_outputs=3,
    activation_fn=nn.ReLU,
    merge_type='mul',
    dropout_p=0.1,  # Dropout rate
)

model = MIMONet_Drop(**model_args).to(device)

# Print parameter count
num_params = sum(param.numel() for param in model.parameters())
print(f"Total number of parameters: {num_params:,}")

Total number of parameters: 1,696,259


In [10]:
import copy, random

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

N = 20  # Number of ensemble members

# NOTE(review): this checkpoint path is relative to the current working
# directory, whereas the data paths are built from `working_dir` - confirm the
# notebook is always launched from the directory that contains `Subchannel/`.
checkpoint_path = 'Subchannel/checkpoints/best_model_dropout.pt'

# Read the checkpoint from disk once instead of once per member, and remap its
# tensors onto the active device so a checkpoint saved on a GPU machine also
# loads on a CPU-only one.
state_dict = torch.load(checkpoint_path, map_location=device)

# Every member receives the same weights (load_state_dict copies the values
# into each model's own parameters). The members are left in train mode, so
# differences between their predictions come from the stochastic layers that
# are active in that mode (dropout, per the model name and `dropout_p`).
ensemble = []
for _ in range(N):
    m = MIMONet_Drop(**model_args)
    m.load_state_dict(state_dict)
    m.to(device)
    m.train()  # Enable dropout during inference
    ensemble.append(m)

In [11]:
# Example: get predictions from all models in the ensemble
def ensemble_predict(ensemble, branch_batch, trunk_batch):
    """Evaluate every ensemble member on one batch and stack the results.

    Each member is called as ``member(branch_batch, trunk_batch)`` with
    gradient tracking disabled; its output is moved to the CPU and converted
    to a NumPy array.

    Returns:
        NumPy array with the member outputs stacked along a new leading axis,
        i.e. shape ``(len(ensemble), *member_output_shape)``.
    """
    with torch.no_grad():
        member_outputs = [
            member(branch_batch, trunk_batch).cpu().numpy()
            for member in ensemble
        ]
    return np.stack(member_outputs, axis=0)  # shape: (N, batch_size, ...)

def get_ensemble_predictions(ensemble, data_loader, device=device):
    """Run the whole ensemble over a data loader and gather predictions and targets.

    The train/eval mode of the models is deliberately left untouched: whatever
    mode the caller put them in (e.g. train mode to keep dropout active) is the
    mode used for prediction.

    Args:
        ensemble: collection of models, forwarded unchanged to ``ensemble_predict``.
        data_loader: iterable yielding ``(branch_data, trunk_data, target_data)``
            batches, where ``branch_data`` is a sequence of tensors.
        device: device the branch and trunk tensors are moved to before prediction.

    Returns:
        Tuple ``(all_preds, all_targets)`` of NumPy arrays. ``all_preds`` joins
        the per-batch ensemble predictions along axis 1 (axis 0 indexes the
        ensemble members); ``all_targets`` joins the per-batch targets along
        axis 0.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    all_targets = []
    all_preds = []

    for branch_data, trunk_data, target_data in data_loader:
        branch_data = [bd.to(device).float() for bd in branch_data]
        trunk_data = trunk_data.to(device).float()

        # Predict from ensemble: returns shape (E, batch_size, ...)
        preds = ensemble_predict(ensemble, branch_data, trunk_data)
        all_preds.append(preds)  # Collect each batch's ensemble predictions

        # Targets are only collected, never fed to a model, so they are
        # converted directly instead of being copied to `device` and back.
        all_targets.append(target_data.float().cpu().numpy())

    if not all_preds:
        # np.concatenate would fail on an empty list with a less helpful message.
        raise ValueError("data_loader yielded no batches; nothing to predict")

    # Concatenate across batches
    all_preds = np.concatenate(all_preds, axis=1)  # [E, total_samples, ...]
    all_targets = np.concatenate(all_targets, axis=0)  # [total_samples, ...]

    print('Shape of all_preds:', all_preds.shape)

    return all_preds, all_targets

In [12]:
# Run every ensemble member over the test loader and collect the stacked
# predictions together with the matching targets.
ensemble_preds, all_targets = get_ensemble_predictions(
    ensemble,
    test_loader,
    device=device,
)

Shape of all_preds: (20, 1000, 1733, 3)


In [13]:
# Ensemble statistics: reduce over axis 0, the ensemble-member axis.
mean_preds = ensemble_preds.mean(axis=0)   # [total_samples, ...]
stddev_preds = ensemble_preds.std(axis=0)  # [total_samples, ...]

In [14]:
# reverse scaling the predictions
mean_preds_rescaled = scaler.inverse_transform(mean_preds)

# A standard deviation is a spread, not a position: under an affine scaling
# x = a * s + b it maps to |a| * std(s), without the offset b. Passing it
# straight through inverse_transform would add that offset. Subtracting the
# image of zero removes the offset and leaves only the per-channel scale
# factor. Computed in float64 to limit cancellation error in the subtraction.
# NOTE(review): assumes ChannelScaler's transform is affine per channel, as a
# min-max scaling is - confirm against utils.ChannelScaler.
_stddev64 = stddev_preds.astype(np.float64)
stddev_preds_rescaled = np.abs(
    scaler.inverse_transform(_stddev64)
    - scaler.inverse_transform(np.zeros_like(_stddev64))
)

all_targets_rescaled = scaler.inverse_transform(all_targets)

In [15]:
# compute mean relative l2 error per output channel
def _mean_relative_l2(channel):
    """Mean over samples of ||pred - target||_2 / ||target||_2 for one channel.

    The norms are taken along axis 1 of the selected channel slice of the
    rescaled ensemble-mean predictions and rescaled targets.
    """
    truth = all_targets_rescaled[..., channel]
    residual = mean_preds_rescaled[..., channel] - truth
    return np.mean(np.linalg.norm(residual, axis=1) / np.linalg.norm(truth, axis=1))


mean_l2_0, mean_l2_1, mean_l2_2 = (_mean_relative_l2(ch) for ch in range(3))
print(f"Mean relative L2 error (channel 0): {mean_l2_0:.4f}")
print(f"Mean relative L2 error (channel 1): {mean_l2_1:.4f}")
print(f"Mean relative L2 error (channel 2): {mean_l2_2:.4f}")

Mean relative L2 error (channel 0): 0.0846
Mean relative L2 error (channel 1): 0.0030
Mean relative L2 error (channel 2): 0.0486
