In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.optim import AdamW, LBFGS
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import os
import pandas as pd
from itertools import cycle 

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
def fourier_features(x, B):
    """Project `x` through the frequency matrix `B` and return ``[sin(x @ B), cos(x @ B)]``.

    The sine and cosine blocks are concatenated along the last axis, so the output's
    last dimension is twice ``B.shape[-1]``.
    """
    projected = x @ B
    return torch.cat((projected.sin(), projected.cos()), dim=-1)

def init_fixed_frequency_matrix(size, scale=1.0):
    """Build a deterministic float32 frequency matrix of shape `size`.

    The ``size[0] * size[1]`` entries are evenly spaced over ``[-scale, scale]``
    and laid out row by row.
    """
    entry_count = size[0] * size[1]
    return torch.linspace(-scale, scale, steps=entry_count).view(size).float()

class FourierFeatureNN(nn.Module):
    """Two-branch Fourier-feature MLP mapping coordinates (x, t) to (Re u, Im u, v).

    x and t are each lifted with fixed (non-trainable) Fourier features and sent through
    their own MLP branch (`path_x`, `path_t`). The two branch outputs are multiplied
    element-wise and decoded by two heads: `ffn_u` (2 outputs: real and imaginary part
    of u) and `ffn_v` (1 output: v).

    Args:
        input_dim: size of the last dimension of x and of t.
        shared_units: width of each Fourier embedding (half sin, half cos); must be even.
        neuron_units: hidden width of every MLP layer.
        scale: frequencies in the fixed matrices are evenly spaced in [-scale, scale].
        activation: activation module class, instantiated after each hidden Linear layer.
        device: device the fixed frequency matrices are created on.

    Raises:
        ValueError: if `shared_units` is not a positive even number (the sin/cos
            embedding would not match the first Linear layer's input size).
    """

    def __init__(self, input_dim=1, shared_units=16, neuron_units=32, scale=1.0,
                 activation=nn.Tanh, device='cpu'):
        super().__init__()
        if shared_units < 2 or shared_units % 2 != 0:
            raise ValueError(f"shared_units must be a positive even number, got {shared_units}")
        half_units = shared_units // 2

        # Fixed frequency matrices (Bx and Bt are identical: the construction is deterministic).
        # Registered as non-persistent buffers so model.to(...)/.double() move/cast them together
        # with the parameters, while state_dict keys stay the same (existing checkpoints still load).
        self.register_buffer(
            'Bx', init_fixed_frequency_matrix((input_dim, half_units), scale=scale).to(device), persistent=False)
        self.register_buffer(
            'Bt', init_fixed_frequency_matrix((input_dim, half_units), scale=scale).to(device), persistent=False)

        # Coordinate branches: sin/cos features (shared_units wide) -> 3 x (Linear + activation).
        branch_widths = [shared_units, neuron_units, neuron_units, neuron_units]
        self.path_x = self._mlp(branch_widths, activation, activate_last=True)
        self.path_t = self._mlp(branch_widths, activation, activate_last=True)

        # Output heads: 2 x (Linear + activation), then a linear read-out.
        self.ffn_u = self._mlp([neuron_units, neuron_units, neuron_units, 2], activation,
                               activate_last=False)  # real and imaginary part of u
        self.ffn_v = self._mlp([neuron_units, neuron_units, neuron_units, 1], activation,
                               activate_last=False)  # v

        self.apply(self.initialize_weights)

    @staticmethod
    def _mlp(widths, activation, activate_last):
        """Stack Linear(widths[i], widths[i + 1]) layers, each followed by `activation()`.

        The final Linear layer gets no activation when `activate_last` is False.
        """
        layers = []
        last_index = len(widths) - 2
        for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if activate_last or index < last_index:
                layers.append(activation())
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the network.

        Args:
            x, t: tensors whose last dimension is `input_dim` (they are matrix-multiplied
                by Bx / Bt of shape (input_dim, shared_units // 2)).

        Returns:
            (u_real, u_imag, v): three tensors, each with last dimension 1.
        """
        # Apply Fourier feature transformations
        x_fourier = fourier_features(x, self.Bx)
        t_fourier = fourier_features(t, self.Bt)

        # Pass through separate paths
        x_path_output = self.path_x(x_fourier)
        t_path_output = self.path_t(t_fourier)

        # Pointwise multiplication of the separate path outputs
        combined_features = x_path_output * t_path_output

        # Directly pass through different FFNs for u and v
        final_output_u = self.ffn_u(combined_features)
        final_output_v = self.ffn_v(combined_features)

        # Split the 2-wide u output into its real and imaginary parts
        output_1, output_2 = final_output_u.split(1, dim=-1)
        output_3 = final_output_v

        return output_1, output_2, output_3

    def initialize_weights(self, m):
        """Xavier-uniform weights and zero biases for every nn.Linear submodule."""
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)

In [3]:
def real_u1(x, t, k, omega, r):
    """Real part of the analytical
    u1(x, t) = i r sqrt(1 + omega^2) exp(i r (omega x + t)) tanh(r (k + x + omega t) / sqrt(2)).
    """
    phase = torch.exp(1j * r * (omega * x + t))
    envelope = torch.tanh((r * (k + x + omega * t)) / torch.sqrt(torch.tensor(2.0)))
    u1 = 1j * r * phase * torch.sqrt(torch.tensor(1) + omega**2) * envelope
    return u1.real

def imag_u1(x, t, k, omega, r):
    """Imaginary part of the analytical
    u1(x, t) = i r sqrt(1 + omega^2) exp(i r (omega x + t)) tanh(r (k + x + omega t) / sqrt(2)).
    """
    phase = torch.exp(1j * r * (omega * x + t))
    envelope = torch.tanh((r * (k + x + omega * t)) / torch.sqrt(torch.tensor(2.0)))
    u1 = 1j * r * phase * torch.sqrt(torch.tensor(1) + omega**2) * envelope
    return u1.imag

def real_v1(x, t, k, omega, r):
    """Analytical v1(x, t) = (r * tanh(r (k + x + omega t) / sqrt(2)))^2 (real-valued)."""
    envelope = torch.tanh((r * (k + x + omega * t)) / torch.sqrt(torch.tensor(2.0)))
    return (r * envelope) ** 2

def compute_analytical_boundary_loss(model, x, t, mse_cost_function, k, omega, r):
    """Compare `model(x, t)` with the analytical solution at the points (x, t).

    Returns a 3-tuple ``mse_cost_function(prediction, target)`` for, in order:
    Re u vs real_u1, Im u vs imag_u1, and v vs real_v1.
    """
    pred_u_r, pred_u_i, pred_v = model(x, t)
    pairs = (
        (pred_u_r, real_u1(x, t, k, omega, r)),
        (pred_u_i, imag_u1(x, t, k, omega, r)),
        (pred_v, real_v1(x, t, k, omega, r)),
    )
    return tuple(mse_cost_function(prediction, target) for prediction, target in pairs)

def cyclic_iterator(items):
    """Return an endless iterator that repeats the elements of `items` in order."""
    endless = cycle(items)
    return endless

In [4]:
def _grad(outputs, inputs):
    """d(outputs.sum()) / d(inputs), with create_graph=True so the result can be differentiated again."""
    return torch.autograd.grad(outputs.sum(), inputs, create_graph=True)[0]


def _physics_residual_losses(model, x, t, mse_cost_function):
    """Return the MSE (against zero) of the three PDE residuals at collocation points (x, t).

    With |u|^2 = u_real^2 + u_imag^2 the residuals are
        du_r = u_real_tt - u_real_xx + |u|^2 u_real - 2 u_real v
        du_i = u_imag_tt - u_imag_xx + |u|^2 u_imag - 2 u_imag v
        dv   = v_tt + v_xx - (|u|^2)_xx
    `x` and `t` must have requires_grad=True.

    NOTE(review): differentiating the analytical u1 = i r sqrt(1+omega^2) e^{i r (omega x + t)}
    tanh(r (k + x + omega t)/sqrt(2)) and v1 = r^2 tanh^2(...) by hand suggests they make dv
    vanish, but make du_r / du_i vanish only if the |u|^2 u - 2 u v terms carry the opposite
    sign. Verify the intended equation before training with gamma > 0.
    """
    u_real, u_imag, v = model(x, t)
    u_sq = torch.square(u_real) + torch.square(u_imag)  # |u|^2

    u_real_xx = _grad(_grad(u_real, x), x)
    u_real_tt = _grad(_grad(u_real, t), t)
    u_imag_xx = _grad(_grad(u_imag, x), x)
    u_imag_tt = _grad(_grad(u_imag, t), t)
    v_xx = _grad(_grad(v, x), x)
    v_tt = _grad(_grad(v, t), t)
    u_sq_xx = _grad(_grad(u_sq, x), x)

    du_r = u_real_tt - u_real_xx + u_sq * u_real - 2 * u_real * v
    du_i = u_imag_tt - u_imag_xx + u_sq * u_imag - 2 * u_imag * v
    dv = v_tt + v_xx - u_sq_xx

    zero_target = torch.zeros_like(du_r)
    return (mse_cost_function(du_r, zero_target),
            mse_cost_function(du_i, zero_target),
            mse_cost_function(dv, zero_target))


def LBFGS_training(model, model_save_path, mse_cost_function, device, num_epochs, lr, num_samples, r, k, omega, gamma, beta, line_search_fn, include_x1_boundary=False):
    """Fine-tune `model` with L-BFGS on analytical-solution data, plus an optional PDE residual.

    Each closure evaluation minimises, summed over the three outputs (Re u, Im u, v):
        gamma * residual_mse + beta * (data MSE at x = -2 [+ at x = +2] + at t = 0 + at interior points)
    The data targets come from real_u1 / imag_u1 / real_v1 via compute_analytical_boundary_loss.
    Boundary points are sampled once. Interior points are resampled uniformly in
    x in [-2, 2], t in [0, 1] on every closure call.

    Every 10 epochs the loss is re-evaluated and printed, a checkpoint is saved to
    `model_save_path`, and plot_model_results is called. `C_HIGGS_second_training.pth`
    is saved at the end.

    Args:
        model: network called as model(x, t) -> (u_real, u_imag, v).
        model_save_path: checkpoint directory (created if missing).
        mse_cost_function: loss callable, e.g. torch.nn.MSELoss().
        device: device for all sampled points.
        num_epochs: number of optimizer.step calls (each runs up to max_iter=20 L-BFGS iterations).
        lr: currently unused; the L-BFGS step size is fixed at 1 (see note in the body).
        num_samples: number of points in each boundary set and in each interior batch.
        r, k, omega: parameters of the analytical solution.
        gamma: weight of the PDE residual. When 0, the residual and its second derivatives
            are not computed, since the term would be multiplied by zero anyway.
        beta: weight of the analytical-data terms.
        line_search_fn: forwarded to torch.optim.LBFGS (e.g. "strong_wolfe" or None).
        include_x1_boundary: also fit the analytical solution on the right edge x = +2.
            The default False keeps the established loss, which never included that edge.
    """
    print('Starting LBFGS Fine Tuning')
    # NOTE(review): `lr` is ignored. The step size stays at 1 to preserve existing training
    # behaviour; pass lr=lr here deliberately if the argument was meant to take effect.
    optimizer = LBFGS(model.parameters(), lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07,
                      tolerance_change=1e-09, history_size=100, line_search_fn=line_search_fn)
    os.makedirs(model_save_path, exist_ok=True)

    x_low, x_span = -2, 4  # x is sampled in [x_low, x_low + x_span] = [-2, 2]; t in [0, 1]

    def sample_x(n):
        return (torch.rand(n, 1) * x_span + x_low).to(device)

    def sample_t(n):
        return torch.rand(n, 1).to(device)

    # Fixed anchor point sets (sampled once), in the order their losses are summed.
    anchors = [(torch.full((num_samples, 1), float(x_low)).to(device), sample_t(num_samples))]  # left edge x = -2
    if include_x1_boundary:
        anchors.append((torch.full((num_samples, 1), float(x_low + x_span)).to(device),
                        sample_t(num_samples)))  # right edge x = +2
    anchors.append((sample_x(num_samples), torch.zeros(num_samples, 1).to(device)))  # initial time t = 0

    def closure():
        optimizer.zero_grad()
        # Fresh interior collocation points on every evaluation.
        x_dom = sample_x(num_samples)
        t_dom = sample_t(num_samples)

        if gamma != 0:
            x_dom.requires_grad_(True)
            t_dom.requires_grad_(True)
            physics = _physics_residual_losses(model, x_dom, t_dom, mse_cost_function)
        else:
            # gamma == 0 zeroes this term anyway; skip the costly second derivatives.
            no_physics = torch.zeros((), device=device)
            physics = (no_physics, no_physics, no_physics)

        data_terms = [compute_analytical_boundary_loss(model, xs, ts, mse_cost_function, k, omega, r)
                      for xs, ts in anchors]
        data_terms.append(compute_analytical_boundary_loss(model, x_dom, t_dom, mse_cost_function, k, omega, r))

        # Per output (Re u, Im u, v): gamma * residual + beta * sum of the data MSEs.
        total_loss = sum(gamma * physics[i] + beta * sum(term[i] for term in data_terms)
                         for i in range(3))
        total_loss.backward()
        return total_loss

    for epoch in tqdm(range(num_epochs),
                      desc='Progress:',
                      leave=False,
                      ncols=75,
                      mininterval=0.1,
                      bar_format='{l_bar} {bar} | {remaining}',  # Only show the bar without any counters
                      colour='blue'):
        model.train()
        optimizer.step(closure)
        if epoch % 10 == 0:
            # Re-evaluate on fresh interior points purely for logging. This also refills .grad,
            # which the next closure call zeroes.
            current_loss = closure()
            tqdm.write(f' Epoch {epoch}, Loss: {current_loss.item()}')  # does not break the progress bar
            model_filename = os.path.join(model_save_path, f'C_HIGGS_second_training_epoch_{epoch}.pth')
            torch.save(model.state_dict(), model_filename)
            # plot_model_results reads this checkpoint back from 'model_weights', so
            # model_save_path should point there.
            plot_model_results(epoch, model, device, k, omega, r, sigma=1, cmap='viridis', image_save_path='results')

    model_filename = os.path.join(model_save_path, 'C_HIGGS_second_training.pth')
    torch.save(model.state_dict(), model_filename)
    print('TRAINING COMPLETED')

In [5]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage import gaussian_filter

def plot_model_results(epoch, model, device, k, omega, r, sigma=1, cmap='viridis', image_save_path='results',
                       model_save_path='model_weights'):
    """Load the checkpoint for `epoch` and plot predicted vs. analytical fields as 3D surfaces.

    The checkpoint `C_HIGGS_second_training_epoch_{epoch}.pth` is read from `model_save_path`
    and loaded into `model` (overwriting its parameters). The model is then evaluated on a
    400 x 400 grid over x in [-1.8, 1.8], t in [0.2, 0.8]. A 2 x 3 figure is written to
    `{image_save_path}/chiggs_model_comparison_3d_epoch_{epoch}.png`:
        top row:    predicted Re u1, Im u1, and v1 (v1 Gaussian-smoothed with `sigma`);
        bottom row: analytical Re u1, Im u1, v1.

    Args:
        epoch: epoch number used in the checkpoint and image filenames.
        model: network called as model(x, t) -> (u_real, u_imag, v).
        device: device the grid is evaluated on.
        k, omega, r: parameters of the analytical solution.
        sigma: standard deviation of the Gaussian filter applied to the predicted v1.
        cmap: matplotlib colormap for all surfaces.
        image_save_path: output directory for the figure (created if missing).
        model_save_path: directory containing the checkpoint.

    The model's train/eval mode is restored before returning.
    """
    x = torch.linspace(-1.8, 1.8, 400)
    t = torch.linspace(0.2, 0.8, 400)
    # 'ij': X varies along rows, T along columns (the legacy default, stated explicitly).
    X, T = torch.meshgrid(x, t, indexing='ij')
    X_flat = X.flatten().unsqueeze(-1).to(device)
    T_flat = T.flatten().unsqueeze(-1).to(device)

    checkpoint_file = os.path.join(model_save_path, f'C_HIGGS_second_training_epoch_{epoch}.pth')
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred_u_r, pred_u_i, pred_v = model(X_flat, T_flat)
    finally:
        model.train(was_training)

    def to_grid(values):
        # Flattened (N*N, 1) values back onto the (N, N) grid as a NumPy array.
        return values.cpu().reshape(X.shape).numpy()

    pred_u_r = to_grid(pred_u_r)
    pred_u_i = to_grid(pred_u_i)
    pred_v = to_grid(pred_v)

    real_u1_analytical = to_grid(real_u1(X_flat, T_flat, k, omega, r))
    imag_u1_analytical = to_grid(imag_u1(X_flat, T_flat, k, omega, r))
    real_v1_analytical = to_grid(real_v1(X_flat, T_flat, k, omega, r))

    pred_v_smooth = gaussian_filter(pred_v, sigma=sigma)

    # (surface values, panel title, z-axis label) in subplot order 231..236.
    panels = [
        (pred_u_r, 'Predicted Real Part of $u_1(x, t)$', 'Real part of $u_1$'),
        (pred_u_i, 'Predicted Imaginary Part of $u_1(x, t)$', 'Imag part of $u_1$'),
        (pred_v_smooth, 'Predicted Real Part of $v_1(x, t)$', 'Real part of $v_1$'),
        (real_u1_analytical, 'Analytical Real Part of $u_1(x, t)$', 'Real part of $u_1$'),
        (imag_u1_analytical, 'Analytical Imaginary Part of $u_1(x, t)$', 'Imag part of $u_1$'),
        (real_v1_analytical, 'Analytical Real Part of $v_1(x, t)$', 'Real part of $v_1$'),
    ]
    X_np, T_np = X.numpy(), T.numpy()

    os.makedirs(image_save_path, exist_ok=True)
    fig = plt.figure(figsize=(24, 16))
    try:
        for position, (surface, title, zlabel) in enumerate(panels, start=1):
            ax = fig.add_subplot(2, 3, position, projection='3d')
            ax.plot_surface(X_np, T_np, surface, cmap=cmap)
            ax.set_title(title)
            ax.set_xlabel('x')
            ax.set_ylabel('t')
            ax.set_zlabel(zlabel)

        fig.tight_layout()
        fig.savefig(os.path.join(image_save_path, f'chiggs_model_comparison_3d_epoch_{epoch}.png'))
    finally:
        plt.close(fig)  # free the figure even if plotting or saving fails


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print("CUDA is available! Training on GPU.")
else:
    print("CUDA is not available. Training on CPU.")

# Seed torch's RNG before the model is built. Weight initialisation and every collocation point
# sampled during training draw from it, so runs become repeatable (up to nondeterministic GPU kernels).
SEED = 0
torch.manual_seed(SEED)

model = FourierFeatureNN(device=device).to(device)

print(model)
num_epochs_lbfgs = 500  # Number of L-BFGS outer steps (optimizer.step calls)
num_samples_lbfgs = 1000*3  # Points per boundary set and per interior batch
# NOTE(review): the *_sq settings and `losses` below are not used by any visible cell.
num_epochs_sq = 36000
num_samples_sq = 1000
lr_sq = 1e-4
# NOTE(review): LBFGS_training ignores its `lr` argument (step size fixed at 1), so this has no effect.
lr_lbfgs = 1e-3
# Analytical-solution parameters (used by real_u1 / imag_u1 / real_v1)
r = 1.1
omega = 5
k = 0.5
gamma = 0  # weight of the PDE residual loss; 0 disables it
beta = 1   # weight of the analytical-data (boundary + interior) loss
model_save_path = 'model_weights'
mse_cost_function = torch.nn.MSELoss()
os.makedirs(model_save_path, exist_ok=True)
os.makedirs('results', exist_ok=True)  # plot_model_results writes figures here by default
losses = []
line_search_fn = "strong_wolfe"

CUDA is available! Training on GPU.
FourierFeatureNN(
  (path_x): Sequential(
    (0): Linear(in_features=16, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=32, bias=True)
    (5): Tanh()
  )
  (path_t): Sequential(
    (0): Linear(in_features=16, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=32, bias=True)
    (5): Tanh()
  )
  (ffn_u): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=2, bias=True)
  )
  (ffn_v): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_fe

In [7]:
# Run the L-BFGS fine-tuning with the settings from the configuration cell.
LBFGS_training(
    model=model,
    model_save_path=model_save_path,
    mse_cost_function=mse_cost_function,
    device=device,
    num_epochs=num_epochs_lbfgs,
    lr=lr_lbfgs,
    num_samples=num_samples_lbfgs,
    r=r,
    k=k,
    omega=omega,
    gamma=gamma,
    beta=beta,
    line_search_fn=line_search_fn,
)

Starting LBFGS Fine Tuning


Progress::   0%| [34m                                                      [0m | ?[0m

 Epoch 0, Loss: 40.32649612426758


Progress::   2%| [34m█                                                 [0m | 15:13[0m

 Epoch 10, Loss: 11.141274452209473


Progress::   4%| [34m██                                                [0m | 16:05[0m

 Epoch 20, Loss: 0.36170750856399536


Progress::   6%| [34m███                                               [0m | 11:55[0m

 Epoch 30, Loss: 0.2836231589317322


Progress::   8%| [34m████                                              [0m | 13:01[0m

 Epoch 40, Loss: 0.12260878086090088


Progress::  10%| [34m█████                                             [0m | 11:34[0m

 Epoch 50, Loss: 0.11815682798624039


Progress::  12%| [34m██████                                            [0m | 09:44[0m

 Epoch 60, Loss: 0.1038997620344162


Progress::  14%| [34m███████                                           [0m | 10:38[0m

 Epoch 70, Loss: 0.09809969365596771


Progress::  16%| [34m████████                                          [0m | 10:10[0m

 Epoch 80, Loss: 0.07776270806789398


Progress::  18%| [34m█████████                                         [0m | 13:05[0m

 Epoch 90, Loss: 0.07059863954782486


Progress::  20%| [34m██████████                                        [0m | 16:16[0m

 Epoch 100, Loss: 0.05988970398902893


Progress::  22%| [34m███████████                                       [0m | 07:44[0m

 Epoch 110, Loss: 0.057288337498903275


Progress::  24%| [34m████████████                                      [0m | 08:42[0m

 Epoch 120, Loss: 0.05308811366558075


Progress::  26%| [34m█████████████                                     [0m | 12:48[0m

 Epoch 130, Loss: 0.05034566670656204


Progress::  28%| [34m██████████████                                    [0m | 09:53[0m

 Epoch 140, Loss: 0.04841305688023567


Progress::  30%| [34m███████████████                                   [0m | 05:49[0m

 Epoch 150, Loss: 0.045653000473976135


Progress::  32%| [34m████████████████                                  [0m | 06:57[0m

 Epoch 160, Loss: 0.04532000795006752


Progress::  34%| [34m█████████████████                                 [0m | 04:54[0m

 Epoch 170, Loss: 0.043359436094760895


Progress::  36%| [34m██████████████████                                [0m | 09:21[0m

 Epoch 180, Loss: 0.041150398552417755


Progress::  38%| [34m███████████████████                               [0m | 09:42[0m

 Epoch 190, Loss: 0.039673756808042526


Progress::  40%| [34m████████████████████                              [0m | 07:32[0m

 Epoch 200, Loss: 0.037515297532081604


Progress::  42%| [34m█████████████████████                             [0m | 06:02[0m

 Epoch 210, Loss: 0.037325166165828705


Progress::  44%| [34m██████████████████████                            [0m | 07:17[0m

 Epoch 220, Loss: 0.035366810858249664


Progress::  46%| [34m███████████████████████                           [0m | 08:18[0m

 Epoch 230, Loss: 0.03297027572989464


Progress::  48%| [34m████████████████████████                          [0m | 07:23[0m

 Epoch 240, Loss: 0.030817843973636627


Progress::  50%| [34m█████████████████████████                         [0m | 09:24[0m

 Epoch 250, Loss: 0.029109612107276917


Progress::  52%| [34m██████████████████████████                        [0m | 07:40[0m

 Epoch 260, Loss: 0.02840748056769371


Progress::  54%| [34m███████████████████████████                       [0m | 08:24[0m

 Epoch 270, Loss: 0.027436252683401108


Progress::  56%| [34m████████████████████████████                      [0m | 06:43[0m

 Epoch 280, Loss: 0.026334404945373535


Progress::  58%| [34m████████████████████████████▉                     [0m | 07:59[0m

 Epoch 290, Loss: 0.025803817436099052


Progress::  60%| [34m██████████████████████████████                    [0m | 04:25[0m

 Epoch 300, Loss: 0.024911588057875633


Progress::  62%| [34m███████████████████████████████                   [0m | 05:31[0m

 Epoch 310, Loss: 0.023505810648202896


Progress::  64%| [34m████████████████████████████████                  [0m | 06:49[0m

 Epoch 320, Loss: 0.023122970014810562


Progress::  66%| [34m█████████████████████████████████                 [0m | 05:07[0m

 Epoch 330, Loss: 0.021974485367536545


Progress::  68%| [34m██████████████████████████████████                [0m | 04:51[0m

 Epoch 340, Loss: 0.020499899983406067


Progress::  70%| [34m███████████████████████████████████               [0m | 03:39[0m

 Epoch 350, Loss: 0.020918721333146095


Progress::  72%| [34m████████████████████████████████████              [0m | 04:46[0m

 Epoch 360, Loss: 0.020186275243759155


Progress::  74%| [34m█████████████████████████████████████             [0m | 02:44[0m

 Epoch 370, Loss: 0.019499406218528748


Progress::  76%| [34m██████████████████████████████████████            [0m | 04:05[0m

 Epoch 380, Loss: 0.019053664058446884


Progress::  78%| [34m███████████████████████████████████████           [0m | 03:04[0m

 Epoch 390, Loss: 0.018598293885588646


Progress::  80%| [34m████████████████████████████████████████          [0m | 02:28[0m

 Epoch 400, Loss: 0.017988935112953186


Progress::  82%| [34m█████████████████████████████████████████         [0m | 03:05[0m

 Epoch 410, Loss: 0.01772921159863472


Progress::  84%| [34m██████████████████████████████████████████        [0m | 02:26[0m

 Epoch 420, Loss: 0.016262587159872055


Progress::  86%| [34m███████████████████████████████████████████       [0m | 01:51[0m

 Epoch 430, Loss: 0.01592869684100151


Progress::  88%| [34m████████████████████████████████████████████      [0m | 02:12[0m

 Epoch 440, Loss: 0.015365242958068848


Progress::  90%| [34m█████████████████████████████████████████████     [0m | 01:41[0m

 Epoch 450, Loss: 0.014659320935606956


Progress::  92%| [34m██████████████████████████████████████████████    [0m | 00:46[0m

 Epoch 460, Loss: 0.014546466991305351


Progress::  94%| [34m███████████████████████████████████████████████   [0m | 01:15[0m

 Epoch 470, Loss: 0.01357647217810154


Progress::  96%| [34m████████████████████████████████████████████████  [0m | 00:22[0m

 Epoch 480, Loss: 0.012881951406598091


Progress::  98%| [34m█████████████████████████████████████████████████ [0m | 00:13[0m

 Epoch 490, Loss: 0.01287191640585661


                                                                           m | 00:00[0m

TRAINING COMPLETED


