In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
from numpy import random as npr
from math import gamma
from math import factorial
from sobol_seq import sobol_seq
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp
import scipy.integrate as integrate
from tqdm import tqdm

from config import DataConfig, device
from Wfamily import*
from Model import*

In [None]:
# Initialize the configuration
config = DataConfig()

# Generate all training points
points = config.generate_training_points()

# Unpack the point sets produced by DataConfig.generate_training_points().
len_collocation, x_collocation, y_collocation = points['collocation']
x_bc, y_bc_lower, y_bc_upper, y_bc, x_bc_left, x_bc_right = points['boundary']
u_x_validation, u_y_validation, v_x_validation, v_y_validation = points['validation']
n_test, xtest, ytest, x_test, y_test = points['test']

family = wavelet_family().to(device)
fam = family.cpu()  # CPU copy of the family, used only for the test-grid basis below
print("family_len: ", len(family))


def _basis_matrix(basis_fn, x, y, params):
    """Evaluate ``basis_fn`` for every member of a wavelet family at (x, y).

    Args:
        basis_fn: callable ``basis_fn(x, y, a, b, c, d)`` (e.g. ``gaussian`` or
            one of its derivative variants).
        x, y: point coordinates passed straight through to ``basis_fn``.
        params: 2-D tensor; row ``i`` holds the parameters of family member
            ``i`` (only the first four columns are used).

    Returns:
        The per-member evaluations stacked along a new first axis and then
        transposed, so column ``i`` corresponds to family member ``i``.
    """
    return torch.stack(
        [basis_fn(x, y, row[0], row[1], row[2], row[3]) for row in params]
    ).T


# Basis values and first/second derivatives at the collocation points.
# The models consume these in the order [Wfamily, DWx, DWy, DW2x, DW2y].
Wfamily = _basis_matrix(gaussian, x_collocation, y_collocation, family)

DWx = _basis_matrix(D1xgaussian, x_collocation, y_collocation, family)
DWy = _basis_matrix(D1ygaussian, x_collocation, y_collocation, family)
DW2x = _basis_matrix(D2xgaussian, x_collocation, y_collocation, family)
DW2y = _basis_matrix(D2ygaussian, x_collocation, y_collocation, family)

# Basis values on the four boundary point sets (used for the BC loss).
Wbc_x_left = _basis_matrix(gaussian, x_bc_left, y_bc, family)
Wbc_x_right = _basis_matrix(gaussian, x_bc_right, y_bc, family)
Wbc_y_lower = _basis_matrix(gaussian, x_bc, y_bc_lower, family)
Wbc_y_upper = _basis_matrix(gaussian, x_bc, y_bc_upper, family)

# Basis values at the validation points (u and v have separate point sets)
# and on the test grid (evaluated with the CPU copy of the family).
WValidation_u = _basis_matrix(gaussian, u_x_validation, u_y_validation, family)
WValidation_v = _basis_matrix(gaussian, v_x_validation, v_y_validation, family)
WTest = _basis_matrix(gaussian, x_test, y_test, fam)


# Reference velocities for validation at Re=100 and Re=400 (17 values each).
# NOTE(review): these appear to be the Ghia, Ghia & Shin (1982) lid-driven
# cavity centerline benchmark values, listed from the wall (0.0) towards the
# lid -- confirm the source and that they line up with the validation points.
u_validation_100 = torch.tensor([0.0, -0.03717, -0.04192, -0.04775, -0.06434, -0.10150, -0.15662, -0.21090, -0.20581, -0.13641, 0.00332, 0.23151, 0.68717, 0.73722, 0.78871, 0.84123, 1.0]).to(device)
u_validation_400 = torch.tensor([0.0, -0.08186, -0.09266, -0.10338, -0.14612, -0.24299, -0.32726, -0.17119, -0.11477, 0.02135, 0.16256, 0.29093, 0.55892, 0.61756, 0.68439, 0.75837, 1.0]).to(device)

v_validation_100 = torch.tensor([0.0, 0.09233, 0.10091, 0.10890, 0.12317, 0.16077, 0.17507, 0.17527, 0.05454, -0.24533, -0.22445, -0.16914, -0.10313, -0.08864, -0.07391, -0.05906, 0.0]).to(device)
v_validation_400 = torch.tensor([0.0, 0.18360, 0.19713, 0.20920, 0.22965, 0.28124, 0.30203, 0.30174, 0.05186, -0.38598, -0.44993, -0.23827, -0.22847, -0.19254, -0.15663, -0.12146, 0.0]).to(device)


In [None]:
# Stage 1: one WPINN per field (u, v, p), all sharing the same wavelet family.
# Models are constructed in the order U, V, P.
n_basis = len(family)
U_model_WPINN, V_model_WPINN, P_model_WPINN = (
    WPINN(len_collocation, n_basis).to(device) for _ in range(3)
)

# A separate Adam optimizer per model, learning rate 1e-6 for this stage.
U_optimizer1, V_optimizer1, P_optimizer1 = (
    optim.Adam(model.parameters(), lr=1e-6)
    for model in (U_model_WPINN, V_model_WPINN, P_model_WPINN)
)

# Sanity-check forward pass of the u-model on the collocation points;
# the cell displays len(u[0]).
basis_ops = [Wfamily, DWx, DWy, DW2x, DW2y]
u_c, u_b, u = U_model_WPINN(x_collocation, y_collocation, basis_ops)
len(u[0])

In [None]:
# The loss is evaluated on copies of the collocation points.
x_interior = x_collocation.clone()
y_interior = y_collocation.clone()

# Reynolds number for the viscous terms. This is a plain module-level name;
# functions below read it as a global, so no `global` statement is needed here.
Re = 100

def _boundary_predictions(coeffs, bias):
    """Evaluate one field on the four boundary point sets.

    Returns ``torch.mv(W, coeffs) + bias`` for the left, right, lower and upper
    boundaries (in that order), using the precomputed boundary basis matrices
    ``Wbc_x_left``, ``Wbc_x_right``, ``Wbc_y_lower`` and ``Wbc_y_upper``.
    """
    return tuple(
        torch.mv(basis, coeffs) + bias
        for basis in (Wbc_x_left, Wbc_x_right, Wbc_y_lower, Wbc_y_upper)
    )


def wpinn_loss(model1, model2, model3, reynolds=None):
    """Physics-informed loss for the (u, v, p) wavelet models.

    Args:
        model1, model2, model3: models for u, v and p respectively. Each is
            called as ``model(x_interior, y_interior, [Wfamily, DWx, DWy, DW2x, DW2y])``
            and returns ``(coefficients, bias, outputs)``. ``outputs[k]`` is
            presumably the field evaluated with the k-th basis matrix, i.e.
            (value, d/dx, d/dy, d2/dx2, d2/dy2) -- confirm against WPINN.forward.
        reynolds: Reynolds number used in the viscous terms. Defaults to the
            module-level ``Re``, read at call time.

    Returns:
        ``(total_loss, pde_loss, bc_loss)`` with ``total_loss = pde_loss + bc_loss``.
        ``pde_loss`` is the sum of mean-squared residuals of continuity
        (u_x + v_y) and the x/y momentum equations. ``bc_loss`` penalizes
        u != 1 on the upper boundary and u != 0, v != 0 on all other boundaries
        (v = 0 on the upper boundary as well).

    Side effects:
        Rebinds the module globals ``u_c, u_b, v_c, v_b, p_c, p_b`` to the
        coefficients/biases from this call. ``train_wpinn`` and the refinement
        stage read them from there.
    """
    global u_c, u_b, v_c, v_b, p_c, p_b
    if reynolds is None:
        reynolds = Re

    basis = [Wfamily, DWx, DWy, DW2x, DW2y]
    u_c, u_b, u = model1(x_interior, y_interior, basis)
    v_c, v_b, v = model2(x_interior, y_interior, basis)
    p_c, p_b, p = model3(x_interior, y_interior, basis)

    u_left, u_right, u_lower, u_upper = _boundary_predictions(u_c, u_b)
    v_left, v_right, v_lower, v_upper = _boundary_predictions(v_c, v_b)

    # Residuals: continuity, then x- and y-momentum (convection + pressure
    # gradient - Laplacian / Re).
    continuity = u[1] + v[2]
    momentum_x = u[0] * u[1] + v[0] * u[2] + p[1] - (u[3] + u[4]) / reynolds
    momentum_y = u[0] * v[1] + v[0] * v[2] + p[2] - (v[3] + v[4]) / reynolds
    pde_loss = (torch.mean(continuity ** 2)
                + torch.mean(momentum_x ** 2)
                + torch.mean(momentum_y ** 2))

    # The only non-zero boundary target is u = 1 on the upper boundary.
    bc_loss = (torch.mean(u_left ** 2) + torch.mean(u_right ** 2)
               + torch.mean(u_lower ** 2) + torch.mean((u_upper - 1) ** 2)
               + torch.mean(v_left ** 2) + torch.mean(v_right ** 2)
               + torch.mean(v_lower ** 2) + torch.mean(v_upper ** 2))

    total_loss = pde_loss + bc_loss
    return total_loss, pde_loss, bc_loss

def _relative_l2(reference, predicted):
    """Relative L2 error ``||reference - predicted|| / ||reference||``."""
    return torch.linalg.norm(reference - predicted) / torch.linalg.norm(reference)


def train_wpinn(model1, model2, model3, optimizer1, optimizer2, optimizer3, num_prints, epochs=None):
    """Jointly train the u, v and p models by minimizing ``wpinn_loss``.

    Args:
        model1, model2, model3: models for u, v and p (passed to ``wpinn_loss``).
        optimizer1, optimizer2, optimizer3: their optimizers; all three are
            zeroed and stepped on every epoch.
        num_prints: approximate number of progress reports. A report is printed
            every ``max(1, (epochs - 1) // num_prints)`` epochs, starting at epoch 0.
        epochs: number of epochs. Defaults to the module-level ``num_epochs``.

    Returns:
        ``[pde_losses, bc_losses]``: per-epoch PDE and boundary losses as floats.

    Validation errors against the reference velocities are printed only when
    reference data exists for the current global ``Re`` (100 or 400). They use
    the coefficients exposed as globals by ``wpinn_loss``.
    """
    n_epochs = num_epochs if epochs is None else epochs
    # Integer interval. The old float interval raised ZeroDivisionError for a
    # single epoch and rarely matched when the division was not exact.
    print_every = max(1, (n_epochs - 1) // max(1, num_prints))
    reference = {
        100: (u_validation_100, v_validation_100),
        400: (u_validation_400, v_validation_400),
    }.get(Re)
    optimizers = (optimizer1, optimizer2, optimizer3)

    pde_losses = []
    bc_losses = []
    for epoch in range(n_epochs):
        for optimizer in optimizers:
            optimizer.zero_grad()

        total_loss, pde_loss, bc_loss = wpinn_loss(model1, model2, model3)

        total_loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        pde_losses.append(pde_loss.item())
        bc_losses.append(bc_loss.item())

        if epoch % print_every == 0:
            message = (f'Epoch [{epoch}/{n_epochs-1}], '
                       f'Total Loss: {total_loss.item():.6f}, '
                       f'PDE Loss: {pde_loss.item():.6f}, '
                       f'BC Loss: {bc_loss.item():.6f}, \n')
            if reference is not None:
                u_ref, v_ref = reference
                # Metrics only; no need to build an autograd graph here.
                with torch.no_grad():
                    u_numerical = torch.mv(WValidation_u, u_c) + u_b
                    v_numerical = torch.mv(WValidation_v, v_c) + v_b
                    u_errL2 = _relative_l2(u_ref, u_numerical).item()
                    u_errMax = torch.max(torch.abs(u_ref - u_numerical)).item()
                    v_errL2 = _relative_l2(v_ref, v_numerical).item()
                    v_errMax = torch.max(torch.abs(v_ref - v_numerical)).item()
                message += (f'\t\t  u:  RelativeL2: {u_errL2:.6f}, '
                            f'Max: {u_errMax:.6f} \n'
                            f'\t\t  v:  RelativeL2: {v_errL2:.6f}, '
                            f'Max: {v_errMax:.6f}\n')
            print(message)

    return [pde_losses, bc_losses]

In [None]:
# Stage 1 training: epochs 0..30000 inclusive (3*10**4 + 1 iterations).
num_epochs = 30_001
l = train_wpinn(
    model1=U_model_WPINN, model2=V_model_WPINN, model3=P_model_WPINN,
    optimizer1=U_optimizer1, optimizer2=V_optimizer1, optimizer3=P_optimizer1,
    num_prints=50,
)

In [None]:
# Stage 2: coefficient refinement network.
# Each refinement network is seeded with the coefficients/bias left in the
# globals u_c, u_b, v_c, v_b, p_c, p_b by the last stage-1 wpinn_loss call.
# They are passed as detached copies so that stage 2 neither shares storage
# with the stage-1 outputs nor backpropagates into the stage-1 autograd graph.
def _snapshot(value):
    """Return a detached copy of a tensor; non-tensors are returned unchanged."""
    return value.detach().clone() if isinstance(value, torch.Tensor) else value


U_model_refined = CoefficientRefinementNetwork(initial_coefficients=_snapshot(u_c), initial_bias=_snapshot(u_b), family_size=len(family)).to(device)
V_model_refined = CoefficientRefinementNetwork(initial_coefficients=_snapshot(v_c), initial_bias=_snapshot(v_b), family_size=len(family)).to(device)
P_model_refined = CoefficientRefinementNetwork(initial_coefficients=_snapshot(p_c), initial_bias=_snapshot(p_b), family_size=len(family)).to(device)

# Lower learning rate than stage 1 for fine refinement.
U_optimizer2 = optim.Adam(U_model_refined.parameters(), lr=1e-7)
V_optimizer2 = optim.Adam(V_model_refined.parameters(), lr=1e-7)
P_optimizer2 = optim.Adam(P_model_refined.parameters(), lr=1e-7)

# Epochs 0..10000 inclusive. Note this rebinds the global num_epochs and `l`.
num_epochs = 10**4+1
l = train_wpinn(U_model_refined, V_model_refined, P_model_refined, U_optimizer2, V_optimizer2, P_optimizer2, num_prints=50)