In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
from numpy import random as npr
from math import gamma
from math import factorial
from sobol_seq import sobol_seq
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp
import scipy.integrate as integrate
from tqdm import tqdm

from config import DataConfig, device
from Wfamily import*
from maxwell import*
from Model import*

In [None]:
# Build every point set (collocation, IC, BC, interface, validation, test) and
# precompute the wavelet basis / derivative matrices evaluated on them.
# NOTE(review): mu2 and epsilon2 (used in wpinn_loss) are not defined in this
# notebook; presumably they come from one of the star imports — confirm.

# Initialize the configuration
config = DataConfig()

# Generate all training points
points = config.generate_training_points()

len_collocation, x_collocation1, x_collocation2, t_collocation = points['collocation']
x_ic1, x_ic2, t_ic = points['initial']
t_bc, x_bc_left, x_bc_right = points['boundary']
x_interface, t_interface = points['interface']
x_validation1, x_validation2, t_validation = points['validation']
n_test, x_grid, t_grid, x_test1, x_test2, t_test = points['test']


# Wavelet family: `family` lives on `device`; `fam` is a CPU copy used for the
# test-grid matrices (the test evaluation is done on the CPU).
family = wavelet_family().to(device)
fam = family.cpu()
print("family_len: ", len(family))


def _stack_basis(fn, x, t, params):
    """Evaluate fn(x, t, p0, p1, p2, p3) for every row of `params`.

    The per-wavelet results are stacked and transposed, so (for 1-D outputs)
    each column corresponds to one wavelet of the family. This is the same
    construction as torch.stack([...]).T over range(len(params)).
    """
    return torch.stack([
        fn(x, t, params[i, 0], params[i, 1], params[i, 2], params[i, 3])
        for i in range(len(params))
    ]).T


# Basis on the collocation points of each subdomain
Wfamily1 = _stack_basis(gaussian, x_collocation1, t_collocation, family)
Wfamily2 = _stack_basis(gaussian, x_collocation2, t_collocation, family)

# First-order derivative bases (x and t) on the collocation points
DWx1 = _stack_basis(D1xgaussian, x_collocation1, t_collocation, family)
DWt1 = _stack_basis(D1tgaussian, x_collocation1, t_collocation, family)
DWx2 = _stack_basis(D1xgaussian, x_collocation2, t_collocation, family)
DWt2 = _stack_basis(D1tgaussian, x_collocation2, t_collocation, family)

# Initial / boundary condition bases
Wic1 = _stack_basis(gaussian, x_ic1, t_ic, family)
Wic2 = _stack_basis(gaussian, x_ic2, t_ic, family)
Wbc_left = _stack_basis(gaussian, x_bc_left, t_bc, family)
Wbc_right = _stack_basis(gaussian, x_bc_right, t_bc, family)

# Interface basis (shared by both subdomains)
Wint = _stack_basis(gaussian, x_interface, t_interface, family)

# Validation bases
WValidation1 = _stack_basis(gaussian, x_validation1, t_validation, family)
WValidation2 = _stack_basis(gaussian, x_validation2, t_validation, family)

# Test bases, built from the CPU copy of the family
WTest1 = _stack_basis(gaussian, x_test1, t_test, fam)
WTest2 = _stack_basis(gaussian, x_test2, t_test, fam)

# ground truth
E_validation1, H_validation1 = analytical1(x_validation1, t_validation)
E_validation2, H_validation2 = analytical2(x_validation2, t_validation)
E_validation = torch.cat((E_validation1, E_validation2))
H_validation = torch.cat((H_validation1, H_validation2))

E_ic1, H_ic1 = analytical1(x_ic1, t_ic)
E_ic2, H_ic2 = analytical2(x_ic2, t_ic)
E_bc_left, H_bc_left = analytical1(x_bc_left, t_bc)
E_bc_right, H_bc_right = analytical2(x_bc_right, t_bc)

In [None]:
# Neural-network models: one WPINN per subdomain, each with its own Adam
# optimizer (lr=1e-5) and a MultiStepLR schedule (x0.1 at step 10000).
# Models are built in order (domain 1, then domain 2).
model_domain1, model_domain2 = (
    WPINN(len_collocation, len(family)).to(device) for _ in range(2)
)
optimizer1, optimizer2 = (
    optim.Adam(net.parameters(), lr=1e-5)
    for net in (model_domain1, model_domain2)
)
scheduler1, scheduler2 = (
    MultiStepLR(opt, milestones=[10000], gamma=0.1)
    for opt in (optimizer1, optimizer2)
)

# Sanity-check forward pass on subdomain 1; the cell displays u1[0].shape.
c1, b1, u1 = model_domain1(x_collocation1, t_collocation, Wfamily1)
u1[0].shape

In [None]:
# Independent copies of the collocation points.
# NOTE(review): these are not referenced in the visible cells — remove if unused.
x_interior1, x_interior2, t_interior = (
    pts.clone() for pts in (x_collocation1, x_collocation2, t_collocation)
)

def wpinn_loss(model1, model2):
    """Composite WPINN loss for the two-subdomain problem.

    Both models are run on their collocation points. The wavelet expansions
    ``W @ coeffs[k] + bias[k]`` are then evaluated on the initial, boundary and
    interface point sets. The derivative expansions ``DW @ coeffs[k]`` (without
    bias) are evaluated on the collocation points.
    Index k = 0 is used for E and k = 1 for H, matching the ground-truth pairs.

    Side effect: the latest model outputs are stored in the module-level
    ``c1, b1, u1, c2, b2, u2``; train_wpinn and later cells read them.

    Returns:
        (total_loss, pde_loss, ic_loss, bc_loss). total_loss also contains the
        interface-continuity term, which is not returned on its own.
    """
    global c1, b1, u1, c2, b2, u2
    c1, b1, u1 = model1(x_collocation1, t_collocation, Wfamily1)
    c2, b2, u2 = model2(x_collocation2, t_collocation, Wfamily2)

    def expand(basis, coeffs, bias, k):
        # Field k evaluated through a basis matrix, including its bias.
        return torch.mv(basis, coeffs[k]) + bias[k]

    # PDE residuals on the collocation points; subdomain 2 is scaled by the
    # module-level mu2 / epsilon2.
    residual_a = torch.cat((
        torch.mv(DWx1, c1[0]) + torch.mv(DWt1, c1[1]),
        torch.mv(DWx2, c2[0]) + mu2 * torch.mv(DWt2, c2[1]),
    ))
    residual_b = torch.cat((
        torch.mv(DWx1, c1[1]) + torch.mv(DWt1, c1[0]),
        torch.mv(DWx2, c2[1]) + epsilon2 * torch.mv(DWt2, c2[0]),
    ))
    pde_loss = torch.mean(residual_a**2) + torch.mean(residual_b**2)

    # Initial-condition mismatch, pooled over both subdomains per field.
    ic_err_E = torch.cat((expand(Wic1, c1, b1, 0) - E_ic1,
                          expand(Wic2, c2, b2, 0) - E_ic2))
    ic_err_H = torch.cat((expand(Wic1, c1, b1, 1) - H_ic1,
                          expand(Wic2, c2, b2, 1) - H_ic2))
    ic_loss = torch.mean(ic_err_E**2) + torch.mean(ic_err_H**2)

    # Boundary-condition mismatch: left edge belongs to subdomain 1, right to 2.
    bc_loss = (
        torch.mean((expand(Wbc_left, c1, b1, 0) - E_bc_left)**2)
        + torch.mean((expand(Wbc_right, c2, b2, 0) - E_bc_right)**2)
        + torch.mean((expand(Wbc_left, c1, b1, 1) - H_bc_left)**2)
        + torch.mean((expand(Wbc_right, c2, b2, 1) - H_bc_right)**2)
    )

    # Continuity of both fields across the interface.
    int_loss = (
        torch.mean((expand(Wint, c1, b1, 0) - expand(Wint, c2, b2, 0))**2)
        + torch.mean((expand(Wint, c1, b1, 1) - expand(Wint, c2, b2, 1))**2)
    )

    total_loss = pde_loss + ic_loss + bc_loss + int_loss
    return total_loss, pde_loss, ic_loss, bc_loss

def train_wpinn(model1, model2, optimizer1, optimizer2, num_prints,
                num_epochs=None, schedulers=None):
    """Jointly train the two subdomain models on the WPINN loss.

    Each epoch does the following:
    - one forward/backward pass through ``wpinn_loss``, which also refreshes
      the module-level ``c1, b1, c2, b2``;
    - a step of both optimizers and of the LR schedulers;
    - records the PDE and BC losses.
    Periodically it also prints the loss terms together with the relative-L2
    and max errors on the validation set.

    Args:
        model1, model2: models for subdomain 1 / 2, forwarded to wpinn_loss.
        optimizer1, optimizer2: optimizers over model1 / model2 parameters.
        num_prints: number of progress reports after epoch 0 (approximate when
            ``num_epochs - 1`` is not a multiple of it).
        num_epochs: number of epochs. Defaults to the module-level
            ``num_epochs``, which is what the original version always used.
        schedulers: LR schedulers to step once per epoch. Defaults to the
            module-level ``scheduler1`` / ``scheduler2``, but only those attached
            to ``optimizer1`` or ``optimizer2``. A scheduler bound to a different
            optimizer would change that other optimizer's LR, not this run's.

    Returns:
        ``[pde_losses, ic_loss, bc_losses]``: per-epoch PDE and BC loss values
        as Python floats, and the last IC loss tensor (None if no epochs ran).
    """
    if num_epochs is None:
        num_epochs = globals()["num_epochs"]
    if schedulers is None:
        schedulers = [
            sched
            for sched in (globals().get("scheduler1"), globals().get("scheduler2"))
            if sched is not None
            and (sched.optimizer is optimizer1 or sched.optimizer is optimizer2)
        ]

    # Integer print interval: the former float modulo only matched epoch 0 when
    # (num_epochs - 1) was not divisible by num_prints, and raised
    # ZeroDivisionError for num_epochs == 1.
    print_every = max(1, (num_epochs - 1) // max(1, num_prints))

    # Training loop
    pde_losses = []
    bc_losses = []
    ic_loss = None
    for epoch in tqdm(range(num_epochs)):
        optimizer1.zero_grad()
        optimizer2.zero_grad()

        total_loss, pde_loss, ic_loss, bc_loss = wpinn_loss(model1, model2)

        total_loss.backward()
        optimizer1.step()
        optimizer2.step()
        for sched in schedulers:
            sched.step()

        # Loss history (these lists used to be returned empty).
        pde_losses.append(pde_loss.item())
        bc_losses.append(bc_loss.item())

        if epoch % print_every == 0:
            # Metrics only: no autograd graph is needed here.
            with torch.no_grad():
                E_numerical = torch.cat((torch.mv(WValidation1, c1[0]) + b1[0], torch.mv(WValidation2, c2[0]) + b2[0]))
                H_numerical = torch.cat((torch.mv(WValidation1, c1[1]) + b1[1], torch.mv(WValidation2, c2[1]) + b2[1]))

                E_errL2 = (torch.sum(torch.abs(E_validation-E_numerical)**2))**0.5 / (torch.sum(torch.abs(E_validation)**2))**0.5
                E_errMax = torch.max(torch.abs(E_validation-E_numerical))

                H_errL2 = (torch.sum(torch.abs(H_validation-H_numerical)**2))**0.5 / (torch.sum(torch.abs(H_validation)**2))**0.5
                H_errMax = torch.max(torch.abs(H_validation-H_numerical))

            print(f'Epoch [{epoch}/{num_epochs-1}], '
                  f'Total Loss: {total_loss.item():.6f}, '
                  f'PDE Loss: {pde_loss.item():.6f}, '
                  f'IC Loss: {ic_loss.item():.6f}, '
                  f'BC Loss: {bc_loss.item():.6f}, \n'
                  f'\t\t  E:  RelativeL2: {E_errL2}, '
                  f'Max: {E_errMax} \n'
                  f'\t\t  B:  RelativeL2: {H_errL2}, '
                  f'Max: {H_errMax}\n' )

    return [pde_losses, ic_loss, bc_losses]

In [None]:
# Stage 1: train both subdomain networks jointly, with 20 progress reports.
num_epochs = 20_001  # == 2*10**4 + 1, i.e. epochs 0..20000
l = train_wpinn(
    model_domain1, model_domain2,
    optimizer1, optimizer2,
    num_prints=20,
)

In [None]:
# Stage 2: coefficient refinement network. Each refinement model is seeded
# with the coefficients / biases from the last forward pass of stage 1, then
# trained with a lower learning rate (1e-6).
model1_refined, model2_refined = (
    CoefficientRefinementNetwork(
        initial_coefficients=coeffs, initial_bias=bias, family_size=len(family)
    ).to(device)
    for coeffs, bias in ((c1, b1), (c2, b2))
)

refined_optimizer1, refined_optimizer2 = (
    optim.Adam(net.parameters(), lr=1e-6)
    for net in (model1_refined, model2_refined)
)

num_epochs = 10_001  # == 10**4 + 1
l = train_wpinn(
    model1_refined, model2_refined,
    refined_optimizer1, refined_optimizer2,
    num_prints=20,
)

In [None]:
# Testing: evaluate both subdomain expansions on the test grid (on the CPU,
# using the CPU basis matrices WTest1 / WTest2) and compare with the
# analytical solutions.
with torch.no_grad():
    E_pred = torch.cat((torch.mv(WTest1, c1[0].cpu()) + b1[0].cpu(), torch.mv(WTest2, c2[0].cpu()) + b2[0].cpu()))
    H_pred = torch.cat((torch.mv(WTest1, c1[1].cpu()) + b1[1].cpu(), torch.mv(WTest2, c2[1].cpu()) + b2[1].cpu()))

# Evaluate each analytical solution once and reuse both of its outputs
# (index 0 -> E, index 1 -> H); previously each was evaluated twice.
exact_sub1 = analytical1(x_test1, t_test)
exact_sub2 = analytical2(x_test2, t_test)
E_exact = torch.cat((exact_sub1[0], exact_sub2[0]))
H_exact = torch.cat((exact_sub1[1], exact_sub2[1]))


def _relative_l2(exact, approx):
    """Relative L2 error ||exact - approx||_2 / ||exact||_2."""
    return (torch.sum(torch.abs(exact-approx)**2))**0.5 / (torch.sum(torch.abs(exact)**2))**0.5


E_errL2 = _relative_l2(E_exact, E_pred)
E_errMax = torch.max(torch.abs(E_exact-E_pred))

H_errL2 = _relative_l2(H_exact, H_pred)
H_errMax = torch.max(torch.abs(H_exact-H_pred))

print(f'E:  RelativeL2: {E_errL2}, '
      f'Max: {E_errMax} \n\n'
      f'B:  RelativeL2: {H_errL2}, '
      f'Max: {H_errMax}')