# Trains a multi-fidelity DeepONet (MF-DeepONet) and performs reliability analysis on the multi-fidelity 1D Poisson dataset (time-independent problem).
### High-fidelity (HF) training set size = 10

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer

In [None]:
# Seed the PyTorch and NumPy global RNGs so runs are reproducible.
SEED = 10
torch.manual_seed(SEED)
np.random.seed(SEED)

# **DeepONet**

In [None]:
class Deeponet(nn.Module):
    """Unstacked DeepONet with a learnable scalar output bias.

    The prediction for row ``i`` is ``sum_k B(xb)[i, k] * T(xt)[i, k] + b0``,
    where ``B`` (branch net) and ``T`` (trunk net) are stacks of
    ``Linear -> ReLU`` layers that share the same ``width``.

    Args:
        branchnetdepth: number of Linear+ReLU layers in the branch net.
        trunknetdepth: number of Linear+ReLU layers in the trunk net.
        width: hidden/output width of both nets (must match for the dot product).
        insize: feature size of the branch input ``xb``.
        space_dim: feature size of the trunk input ``xt``.
    """

    def __init__(self, branchnetdepth, trunknetdepth, width, insize, space_dim):
        super(Deeponet, self).__init__()
        # Branch is built before trunk so weight initialisation consumes the
        # global RNG in the same order as before.
        self.branchnet = self._mlp(insize, width, branchnetdepth)
        self.trunknet = self._mlp(space_dim, width, trunknetdepth)
        # Learnable scalar output bias b0 (initialised to 0). This attribute
        # used to be a ReLU. Because both nets end in ReLU (for depth >= 1),
        # their dot product is already >= 0, so that ReLU was a no-op and the
        # model could never output a negative value.
        self.bias = Parameter(torch.zeros(1))

    @staticmethod
    def _mlp(in_dim, width, depth):
        """Build ``depth`` Linear(., width)+ReLU layers; identity if depth == 0."""
        layers = []
        for _ in range(depth):
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU(inplace=True))
            in_dim = width
        return nn.Sequential(*layers)

    def hadprodsum(self, branch, trunk):
        """Row-wise dot product: ``(N, width), (N, width) -> (N,)``."""
        return torch.einsum("ij,ij->i", branch, trunk)

    def forward(self, xb, xt):
        """Evaluate rows of ``xb`` (branch) paired with rows of ``xt`` (trunk).

        Returns a tensor of shape ``(N, 1)``.
        """
        x = self.hadprodsum(self.branchnet(xb), self.trunknet(xt))
        return x.view(-1, 1) + self.bias

# Multifidelity

In [None]:
################################################################
#  configurations
################################################################
# Dataset sizes
ntrain_m = 10                  # high-fidelity training samples
ntest_m = 2000                 # test samples
n_total = ntrain_m + ntest_m
last_m = 400                   # NOTE(review): not used in the visible cells
s = 100                        # grid points per sample

# Optimisation
batch_size, learning_rate, epochs = 500, 0.001, 500

In [None]:
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
# np.load on an .npz returns an NpzFile that holds the file open; use it as a
# context manager so the handle is closed once the arrays have been read
# (indexing an NpzFile returns an in-memory ndarray, so they stay valid).
with np.load(PATH) as data:
    x_data_h = data['f_stoch']
    y_data_l = data['y_low_100']   # low-fidelity solution (per variable name)
    y_data_h = data['yhi']         # high-fidelity solution (per variable name)
    x_coords = data['xhi'].reshape((s,))  # grid coordinates, flattened to (s,)

In [None]:
# Inspect the shape of the high-fidelity target array.
np.shape(y_data_h)

In [None]:
# Branch input: HF input and LF solution stacked as two channels on the last axis
# (channel 0 = x_data_h, channel 1 = y_data_l).
x_mf = np.stack([x_data_h, y_data_l], axis=-1)
# Target: the correction to add to the LF solution to recover the HF solution.
y_mf = np.subtract(y_data_h, y_data_l)


In [None]:
print(*(arr.shape for arr in (x_mf, y_mf)))

In [None]:
# First ntrain_m samples for training, last ntest_m samples for testing.
# NOTE(review): the two ranges overlap if the file holds fewer than n_total samples.
train_idx = slice(None, ntrain_m)
test_idx = slice(-ntest_m, None)
x_train_mf, y_train_mf = x_mf[train_idx], y_mf[train_idx]
x_test_mf, y_test_mf = x_mf[test_idx], y_mf[test_idx]

In [None]:
print(*[arr.shape for arr in (x_train_mf, y_train_mf, x_test_mf, y_test_mf)])

In [None]:
# Inspect the grid-coordinate array shape.
np.shape(x_coords)

In [None]:
# Build DeepONet rows. Each (sample, grid point) pair becomes one row,
# ordered sample-major (all grid points of sample 0, then sample 1, ...):
#   branch input = the sample's flattened input, repeated once per grid point
#   trunk input  = the grid coordinate of that point
#   target       = the residual at that point


def _deeponet_rows(x, y, n_samples, coords):
    """Return float32 numpy arrays (branch, trunk, target), each with
    n_samples * len(coords) rows; trunk and target have a single column."""
    n_grid = coords.shape[0]
    branch = np.repeat(x.astype(np.float32).reshape((n_samples, -1)), n_grid, axis=0)
    trunk = np.tile(coords.astype(np.float32), n_samples).reshape((-1, 1))
    target = y.astype(np.float32).reshape(-1, 1)
    return branch, trunk, target


xb_train_mf, xt_train_mf, y_train_mf = _deeponet_rows(x_train_mf, y_train_mf, ntrain_m, x_coords)
xb_test_mf, xt_test_mf, y_test_mf = _deeponet_rows(x_test_mf, y_test_mf, ntest_m, x_coords)

print(xb_train_mf.shape, xt_train_mf.shape, y_train_mf.shape)
print(xb_test_mf.shape, xt_test_mf.shape, y_test_mf.shape)

# numpy -> torch (torch.from_numpy shares memory with the array)
xb_train_mf, xt_train_mf, y_train_mf = (torch.from_numpy(a) for a in (xb_train_mf, xt_train_mf, y_train_mf))
xb_test_mf, xt_test_mf, y_test_mf = (torch.from_numpy(a) for a in (xb_test_mf, xt_test_mf, y_test_mf))

train_dataset = torch.utils.data.TensorDataset(xb_train_mf, xt_train_mf, y_train_mf)
test_dataset = torch.utils.data.TensorDataset(xb_test_mf, xt_test_mf, y_test_mf)
train_loader_mf = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader_mf = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Model: branch/trunk depths and hidden width; input sizes are read from the data.
bdepth, tdepth, width = 3, 3, 5
inputsizeb = xb_train_mf.shape[-1]   # flattened branch features per row
spdim = xt_train_mf.shape[-1]        # trunk (coordinate) features per row

model = Deeponet(bdepth, tdepth, width, inputsizeb, spdim)
model = model.cuda()
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
# Show the model's module structure (displayed as this cell's output).
model

In [None]:
# Train on the residual target; report per-epoch mean of per-batch MSEs.
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0.0
    for xb, xt, y in train_loader_mf:
        xb, xt, y = xb.cuda(), xt.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(xb, xt)
        # `out` and `y` are both (batch, 1): compare them directly. Reshaping
        # to (batch_size, -1) fails whenever the final batch is shorter.
        mse = F.mse_loss(out, y, reduction='mean')
        mse.backward()  # plain MSE loss

        optimizer.step()
        train_mse += mse.item()

    model.eval()
    test_mse = 0.0
    with torch.no_grad():
        for xb, xt, y in test_loader_mf:
            xb, xt, y = xb.cuda(), xt.cuda(), y.cuda()
            out = model(xb, xt)
            test_mse += F.mse_loss(out, y, reduction='mean').item()

    train_mse /= len(train_loader_mf)
    test_mse /= len(test_loader_mf)
    t2 = default_timer()
    print(f'epoch {ep}, time_taken: {t2-t1}, train_mse: {train_mse},test_mse: {test_mse}')


In [None]:
# Save the whole MF-DeepONet module (pickled object; loading it later requires
# the Deeponet class to be importable). torch.save does not create parent
# directories, so make sure the output directory exists first.
import os

MODEL_PATH = 'model/deeponet_poissons_10'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model, MODEL_PATH)

In [None]:
# Predict the residual for every test row. Batch order matches the row order
# because the test loader does not shuffle. Report per-batch and overall MSE.
pred_mf = []

model.eval()
with torch.no_grad():
    for index, (xb, xt, y) in enumerate(test_loader_mf):
        xb, xt, y = xb.cuda(), xt.cuda(), y.cuda()

        out = model(xb, xt)
        # out and y are both (batch, 1); no reshape needed (and reshaping to
        # (batch_size, -1) breaks on a short final batch).
        tmse = F.mse_loss(out, y).item()

        pred_mf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, tmse))

pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))
     

In [None]:
# Restore the (sample, grid point) layout of the flattened test rows.
grid_shape = (ntest_m, x_coords.shape[0])
pred_mf = pred_mf.reshape(grid_shape)
y_test_mf = y_test_mf.reshape(grid_shape)

print(pred_mf.shape, y_test_mf.device)

In [None]:
# Add the low-fidelity solution (channel 1 of the branch input) back onto the
# predicted and true residuals to obtain full high-fidelity fields.
y_low_test = torch.from_numpy(x_test_mf[:, :, 1])
out_actual = pred_mf + y_low_test
real_actual = y_test_mf + y_low_test


In [None]:
print(*(t.shape for t in (real_actual, out_actual)))

In [None]:
# MSE of the reconstructed HF field (LF + predicted residual) against the truth.
mse_pred = F.mse_loss(out_actual, real_actual).item()
print(f'MSE-Predicted solution-{mse_pred:0.6f}')


In [None]:
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

# Plot every 500th test sample: reference (solid) vs. MF-DeepONet (dashed).
plot_ids = range(0, ntest_m, 500)
colormap = plt.cm.jet
# Keep the original 5-colour palette, but grow it if more samples are plotted
# (indexing a fixed 5-entry list overflowed when ntest_m > 2500).
colors = [colormap(i) for i in np.linspace(0, 1, max(5, len(plot_ids)))]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
fig2.suptitle('1D Poisson - MF-DeepONet - High fidelity')

for index, i in enumerate(plot_ids):
    # real_actual = LF + true residual (reference);
    # out_actual  = LF + predicted residual (network prediction).
    plt.plot(x_coords, real_actual[i, :], color=colors[index], label='Actual')
    plt.plot(x_coords, out_actual[i, :], '--', color=colors[index], label='Prediction')
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# First-passage failure probability

In [None]:
# %%
# First-passage failure: a test sample fails if its HF solution exceeds the
# threshold `eh` at any grid point. The failure probability is the fraction of
# failed test samples (reference "MCS" data vs. the MF-DeepONet surrogate).
eh = 4.5

# Per-sample 0/1 failure indicators (float arrays), computed vectorised over
# the (sample, grid point) tensors instead of looping sample by sample.
eh_mcs = (real_actual > eh).any(dim=1).numpy().astype(np.float64)
eh_deeponet_mf = (out_actual > eh).any(dim=1).numpy().astype(np.float64)

pf_mcs = np.count_nonzero(eh_mcs) / ntest_m
# NOTE: as before, eh_deeponet_mf is rebound from the indicator array to the
# scalar failure probability.
eh_deeponet_mf = np.count_nonzero(eh_deeponet_mf) / ntest_m
print('Prob. of failure, MFDeepONet-{}, MCS-{}'.format(eh_deeponet_mf, pf_mcs))


In [None]:
# Export the reconstructed predictions, references and grid to a MATLAB file.
# scipy is not imported anywhere else in this notebook; without this line it
# would only be reachable if `from utils import *` happened to re-export it.
import scipy.io

scipy.io.savemat('data/deeponet_poissons_n10.mat', mdict={'out_actual': out_actual.cpu().numpy(),
                                                          'real_actual': real_actual.cpu().numpy(),
                                                          'x_coords': x_coords})
