In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
from numpy import random as npr
from math import gamma
from math import factorial
from sobol_seq import sobol_seq
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp
import scipy.integrate as integrate
from tqdm import tqdm

from config import DataConfig, device
from Wfamily import*
from maxwell import*
from Model import*

In [None]:
# Initialize the configuration
config = DataConfig()

# Generate all training points
points = config.generate_training_points()

len_collocation, x_collocation, t_collocation = points['collocation']
x_ic, t_ic = points['initial']
t_bc, x_bc_left, x_bc_right = points['boundary']
x_validation, t_validation = points['validation']
n_test, x_grid, t_grid, x_test, t_test = points['test']

family = wavelet_family().to(device)
fam = family.cpu()  # CPU copy; used only for the test-set basis (WTest) below
print("family_len: ", len(family))


def _basis_matrix(basis_fn, x, t, fam_params):
    """Evaluate ``basis_fn`` for every family member at the points ``(x, t)``.

    Row ``i`` of ``fam_params`` supplies the four parameters passed to
    ``basis_fn`` after ``x`` and ``t``. The per-member results are stacked
    along a new leading axis and then transposed (``.T``), so for 1-D point
    tensors each column corresponds to one family member.
    NOTE(review): 1-D ``x``/``t`` assumed — confirm against DataConfig.
    """
    return torch.stack([
        basis_fn(x, t, fam_params[i, 0], fam_params[i, 1], fam_params[i, 2], fam_params[i, 3])
        for i in range(len(fam_params))
    ]).T


# Basis values and first derivatives on the interior collocation points.
Wfamily = _basis_matrix(gaussian, x_collocation, t_collocation, family)
DWx = _basis_matrix(D1xgaussian, x_collocation, t_collocation, family)
DWt = _basis_matrix(D1tgaussian, x_collocation, t_collocation, family)

# Initial-condition and boundary bases (values, and x-derivatives at the edges).
Wic = _basis_matrix(gaussian, x_ic, t_ic, family)
Wbc_left = _basis_matrix(gaussian, x_bc_left, t_bc, family)
Wbc_right = _basis_matrix(gaussian, x_bc_right, t_bc, family)
DWbc_left = _basis_matrix(D1xgaussian, x_bc_left, t_bc, family)
DWbc_right = _basis_matrix(D1xgaussian, x_bc_right, t_bc, family)

# Validation basis uses the device copy of the family; the test basis uses the CPU copy.
WValidation = _basis_matrix(gaussian, x_validation, t_validation, family)
WTest = _basis_matrix(gaussian, x_test, t_test, fam)


E_validation, H_validation = analytical(x_validation, t_validation)
E_ic, H_ic = analytical(x_ic, t_ic)
E_exact, H_exact = analytical(x_test, t_test)

In [None]:
# Wavelet-PINN model sized by the number of collocation points and family members.
model_WPINN = WPINN(len_collocation, len(family)).to(device)

# Stage-1 optimizer (no weight decay).
optimizer1 = optim.Adam(params=model_WPINN.parameters(), lr=1e-5)

# Multiply the learning rate by 0.1 once the scheduler has been stepped 10000 times.
# NOTE(review): train_wpinn steps the scheduler once per epoch and the stage-1 cell
# runs num_epochs = 10**4 + 1 epochs, so this decay only affects the final
# optimizer step — confirm that is intended.
scheduler = MultiStepLR(optimizer=optimizer1, milestones=[10000], gamma=0.1)

# Smoke-test forward pass on the collocation points; the cell displays the shape
# of the first element of the model's third output.
c, b, u = model_WPINN(x_collocation, t_collocation, Wfamily)
u[0].shape

In [None]:
# Independent copies of the collocation points used as the loss function's interior inputs.
x_interior, t_interior = x_collocation.clone(), t_collocation.clone()

def wpinn_loss(model):
    """Physics-informed loss of the wavelet PINN on the fixed training points.

    ``model(x_interior, t_interior, Wfamily)`` must return a 3-tuple whose first
    two items are the coefficients ``c`` and biases ``b``. Index 0 of each is
    used for E and index 1 for H. Fields are expanded over the precomputed
    notebook-level basis matrices as ``W @ c[k] (+ b[k])``.

    Side effect: rebinds the notebook globals ``c`` and ``b`` to the latest
    model output, which the training loop and later cells read.

    Returns
    -------
    (total_loss, pde_loss, ic_loss, bc_loss)
        Scalar tensors with ``total_loss = pde_loss + ic_loss + bc_loss``.
    """
    global c, b
    c, b, _ = model(x_interior, t_interior, Wfamily)

    def msq(values):
        # Mean of the squared entries.
        return torch.mean(values ** 2)

    # Initial condition: predicted E and H at t_ic.
    ic_pred_E = torch.mv(Wic, c[0]) + b[0]
    ic_pred_H = torch.mv(Wic, c[1]) + b[1]

    # Boundaries: E itself and the x-derivative of H (the bias has no derivative).
    edge_E_left = torch.mv(Wbc_left, c[0]) + b[0]
    edge_E_right = torch.mv(Wbc_right, c[0]) + b[0]
    edge_dH_left = torch.mv(DWbc_left, c[1])
    edge_dH_right = torch.mv(DWbc_right, c[1])

    # First derivatives on the interior collocation points.
    dE_dx = torch.mv(DWx, c[0])
    dE_dt = torch.mv(DWt, c[0])
    dH_dx = torch.mv(DWx, c[1])
    dH_dt = torch.mv(DWt, c[1])

    pde_loss = msq(dE_dx + dH_dt) + msq(dH_dx + dE_dt)
    ic_loss = msq(ic_pred_E - E_ic) + msq(ic_pred_H - H_ic)
    bc_loss = msq(edge_E_left) + msq(edge_E_right) + msq(edge_dH_left) + msq(edge_dH_right)

    total_loss = pde_loss + ic_loss + bc_loss
    return total_loss, pde_loss, ic_loss, bc_loss

def train_wpinn(model, optimizer, num_prints, epochs=None, sched=None):
    """Train ``model`` by minimising ``wpinn_loss`` and report validation error.

    Parameters
    ----------
    model : object passed straight to ``wpinn_loss``.
    optimizer : torch optimizer updating the trainable parameters.
    num_prints : number of evenly spaced progress reports. When positive, the
        first and the last epoch are always reported; 0 disables reporting.
    epochs : number of epochs; defaults to the notebook-level ``num_epochs``.
    sched : LR scheduler stepped once per epoch. When omitted, the notebook-level
        ``scheduler`` is used, but only if it is attached to ``optimizer``.
        Stepping a scheduler bound to a different optimizer has no effect on
        this run, so it is skipped.

    Returns
    -------
    [pde_losses, ic_losses, bc_losses]
        Per-epoch loss histories as lists of Python floats.
    """
    n_epochs = num_epochs if epochs is None else epochs
    if sched is None:
        fallback = globals().get("scheduler")
        if getattr(fallback, "optimizer", None) is optimizer:
            sched = fallback

    # Integer reporting interval. A float interval misses reports when
    # (n_epochs - 1) is not a multiple of num_prints, and divides by zero when
    # n_epochs == 1.
    print_every = max(1, (n_epochs - 1) // num_prints) if num_prints > 0 else None

    pde_losses, ic_losses, bc_losses = [], [], []
    for epoch in tqdm(range(n_epochs)):
        optimizer.zero_grad()

        # wpinn_loss also rebinds the globals c, b that the validation below reads.
        total_loss, pde_loss, ic_loss, bc_loss = wpinn_loss(model)

        total_loss.backward()
        optimizer.step()
        if sched is not None:
            sched.step()

        pde_losses.append(pde_loss.item())
        ic_losses.append(ic_loss.item())
        bc_losses.append(bc_loss.item())

        if print_every is None or not (epoch % print_every == 0 or epoch == n_epochs - 1):
            continue

        # Validation error of the coefficients from this epoch's forward pass
        # (i.e. before optimizer.step()), consistent with the printed losses.
        # no_grad: these diagnostics must not build an autograd graph.
        with torch.no_grad():
            E_numerical = torch.mv(WValidation, c[0]) + b[0]
            H_numerical = torch.mv(WValidation, c[1]) + b[1]

            E_errL2 = (torch.sum(torch.abs(E_validation-E_numerical)**2))**0.5 / (torch.sum(torch.abs(E_validation)**2))**0.5
            E_errMax = torch.max(torch.abs(E_validation-E_numerical))
            H_errL2 = (torch.sum(torch.abs(H_validation-H_numerical)**2))**0.5 / (torch.sum(torch.abs(H_validation)**2))**0.5
            H_errMax = torch.max(torch.abs(H_validation-H_numerical))

        print(f'Epoch [{epoch}/{num_epochs-1}], '
              f'Total Loss: {total_loss.item():.6f}, '
              f'PDE Loss: {pde_loss.item():.6f}, '
              f'IC Loss: {ic_loss.item():.6f}, '
              f'BC Loss: {bc_loss.item():.6f}, \n'
              f'\t\t  E:  RelativeL2: {E_errL2}, '
              f'Max: {E_errMax} \n'
              f'\t\t  B:  RelativeL2: {H_errL2}, '
              f'Max: {H_errMax}\n' )

    return [pde_losses, ic_losses, bc_losses]

In [None]:
# Stage 1: train the wavelet PINN. train_wpinn reads the global num_epochs.
num_epochs = 10_001  # == 10**4 + 1
l = train_wpinn(model_WPINN, optimizer1, num_prints=20)

In [None]:
# Stage 2: coefficient refinement network, seeded with the coefficients and
# biases (globals c, b) left by the last wpinn_loss call of stage 1.
# NOTE(review): c and b still carry stage-1 autograd history; if the constructor
# stores them directly rather than copying, consider passing detached copies — confirm.
model_refined = CoefficientRefinementNetwork(
    initial_coefficients=c,
    initial_bias=b,
    family_size=len(family),
).to(device)

# Learning rate 10x lower than stage 1; no LR scheduler is created for optimizer2.
optimizer2 = optim.Adam(params=model_refined.parameters(), lr=1e-6)

num_epochs = 10_001  # == 10**4 + 1
l = train_wpinn(model_refined, optimizer2, num_prints=20)

In [None]:
#Testing
# Evaluate the final expansion (globals c, b from the last wpinn_loss call) on
# the test grid and compare it with the analytical solution.


def _error_metrics(exact, approx):
    """Return ``(relative L2 error, max absolute error)`` of ``approx`` vs ``exact``."""
    abs_err = torch.abs(exact - approx)
    rel_l2 = (torch.sum(abs_err ** 2)) ** 0.5 / (torch.sum(torch.abs(exact) ** 2)) ** 0.5
    return rel_l2, torch.max(abs_err)


with torch.no_grad():
    E_pred = torch.mv(WTest, c[0].cpu()) + b[0].cpu()
    H_pred = torch.mv(WTest, c[1].cpu()) + b[1].cpu()

# E_exact, H_exact were already computed as analytical(x_test, t_test) in the
# data-setup cell, so the analytical solution is not evaluated a second time.
E_errL2, E_errMax = _error_metrics(E_exact, E_pred)
H_errL2, H_errMax = _error_metrics(H_exact, H_pred)

print(f'E:  RelativeL2: {E_errL2}, '
      f'Max: {E_errMax} \n\n'
      f'B:  RelativeL2: {H_errL2}, '
      f'Max: {H_errMax}')