# In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW, LBFGS
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np
import os
import pandas as pd
from itertools import cycle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage import gaussian_filter
import os
import warnings

# Process-wide: every UserWarning and FutureWarning is dropped from here on,
# presumably to keep the training/progress output readable.
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)

import torch
import torch.nn as nn

class FFNN(nn.Module):
    """Single-hidden-layer tanh MLP applied to the column-wise concatenation of (x, t).

    Args:
        input_dim_x: number of feature columns in ``x`` (default 1).
        input_dim_t: number of feature columns in ``t`` (default 1).
        hidden_dim: width of the hidden tanh layer (default 256).
        output_dim: number of output columns (default 1).

    The defaults reproduce the original fixed architecture, so ``FFNN()``
    behaves exactly as before.
    """

    def __init__(self, input_dim_x=1, input_dim_t=1, hidden_dim=256, output_dim=1):
        super().__init__()
        self.input_dim_x = input_dim_x
        self.input_dim_t = input_dim_t
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Layers are created in the same order as before (fc1, then fc2) so the
        # random initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(self.input_dim_x + self.input_dim_t, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.output_dim)
        self.activation = nn.Tanh()

    def forward(self, x, t):
        """Return ``fc2(tanh(fc1([x | t])))``.

        ``x`` and ``t`` are concatenated along dim 1, so both must be 2-D with
        the same number of rows.
        """
        concatenated_inputs = torch.cat((x, t), dim=1)
        hidden_output = self.activation(self.fc1(concatenated_inputs))
        return self.fc2(hidden_output)

# Smoke test at import time: build one network, print its layer summary,
# and push a tiny evenly spaced (x, t) batch through it.
ffnn_model = FFNN()
print(ffnn_model)

num_samples = 10
x = torch.linspace(0, 1, num_samples).reshape(-1, 1)
t = torch.linspace(0, 1, num_samples).reshape(-1, 1)
model_output = ffnn_model(x, t)
print("Model output:", model_output, sep="\n")


def grad(x, t):
    """Differentiate the tensor ``x`` with respect to ``t`` using autograd.

    Mind the argument names: ``x`` is the *output* being differentiated and
    ``t`` is the *input* variable. The cotangent is a tensor of ones shaped
    like ``x``, i.e. this is the gradient of ``x.sum()`` w.r.t. ``t``.
    ``create_graph=True`` keeps the result differentiable so it can be
    differentiated again (second derivatives).
    """
    cotangent = torch.ones_like(x)
    gradients = torch.autograd.grad(
        outputs=x,
        inputs=t,
        grad_outputs=cotangent,
        create_graph=True,
    )
    return gradients[0]

def laplacian(field, x, t):
    """Return the unmixed second derivatives ``(d2field/dx2, d2field/dt2)``.

    Despite the name, the two components are returned separately rather than
    summed, so callers can combine them with their own signs. Both are built
    with ``grad`` (which keeps the graph), so they remain differentiable.
    """
    second_derivatives = []
    for variable in (x, t):
        first = grad(field, variable)
        second_derivatives.append(grad(first, variable))
    return tuple(second_derivatives)

# Define the PDE system for the coupled Higgs field equations
def coupled_higgs(u_real, u_imag, v, x, t):
    """Return the residuals of the coupled Higgs equations at the sampled points.

    With u = u_real + i * u_imag, the implemented system is

        u_tt - u_xx + |u|^2 u - 2 u v = 0    (split into real and imaginary parts)
        v_tt + v_xx - (|u|^2)_xx      = 0

    Args:
        u_real, u_imag, v: field tensors that must be computed from ``x`` and
            ``t`` inside the autograd graph; second derivatives are taken
            with respect to both.
        x, t: input tensors with ``requires_grad=True``.

    Returns:
        Tuple ``(du_r, du_i, dv)`` of residual tensors, one per equation; an
        exact solution makes all three vanish.
    """
    u_r_xx, u_r_tt = laplacian(u_real, x, t)
    u_i_xx, u_i_tt = laplacian(u_imag, x, t)
    v_xx, v_tt = laplacian(v, x, t)

    # Squared modulus |u|^2 (not |u|, despite the variable name).
    u_abs = torch.square(u_real) + torch.square(u_imag)
    # Only the x-curvature of |u|^2 enters the v-equation; computing it
    # directly avoids building the unused t-derivatives that laplacian()
    # would also create (and keep alive via create_graph=True).
    u_abs_xx = grad(grad(u_abs, x), x)

    # Field equation residuals
    du_r = u_r_tt - u_r_xx + u_abs * u_real - 2 * u_real * v
    du_i = u_i_tt - u_i_xx + u_abs * u_imag - 2 * u_imag * v
    dv = v_tt + v_xx - u_abs_xx

    return du_r, du_i, dv

def real_u1(x, t, k, omega, r):
    """Real part of the analytical field u_1(x, t).

    u_1 = i * r * sqrt(1 + omega^2) * exp(i r (omega x + t))
              * tanh(r (k + x + omega t) / sqrt(2))

    Evaluated in complex arithmetic; the real component is returned.
    """
    wave = torch.exp(1j * r * (omega * x + t))
    root_two = torch.sqrt(torch.tensor(2.0))
    kink = torch.tanh((r * (k + x + omega * t)) / root_two)
    amplitude = torch.sqrt(torch.tensor(1) + omega**2)
    return torch.real(1j * r * wave * amplitude * kink)

def imag_u1(x, t, k, omega, r):
    """Imaginary part of the analytical field u_1(x, t).

    u_1 = i * r * sqrt(1 + omega^2) * exp(i r (omega x + t))
              * tanh(r (k + x + omega t) / sqrt(2))

    Evaluated in complex arithmetic; the imaginary component is returned.
    """
    wave = torch.exp(1j * r * (omega * x + t))
    root_two = torch.sqrt(torch.tensor(2.0))
    kink = torch.tanh((r * (k + x + omega * t)) / root_two)
    amplitude = torch.sqrt(torch.tensor(1) + omega**2)
    return torch.imag(1j * r * wave * amplitude * kink)

def real_v1(x, t, k, omega, r):
    """Analytical field v_1(x, t) = (r * tanh(r (k + x + omega t) / sqrt(2)))^2 (real-valued)."""
    argument = (r * (k + x + omega * t)) / torch.sqrt(torch.tensor(2.0))
    scaled_kink = r * torch.tanh(argument)
    return scaled_kink**2

def compute_analytical_boundary_loss(model_ur, model_ui, model_v, x, t, mse_cost_function, k, omega, r):
    """Data loss between the composed network solution and the analytical solution.

    The networks supply only the correction term ``N`` of the hard-constrained
    trial solution

        f(x, t) = (1 - x) f_exact(0, t) + x f_exact(1, t) + x (1 - x) N(x, t),

    which is how the physics loss (``u_r0 + y1 * pred``) and the plotting code
    interpret the network outputs. The same composition is applied here before
    comparing against the analytical fields. Matching the raw outputs instead
    would pull ``N`` toward ``f_exact`` itself and contradict the physics term.

    Despite the name, ``(x, t)`` may be any points; the training loop passes
    random samples in the unit square.

    Args:
        model_ur, model_ui, model_v: networks for Re(u), Im(u) and v.
        x, t: evaluation coordinates (2-D column tensors, as FFNN concatenates on dim 1).
        mse_cost_function: loss callable ``(input, target) -> scalar``.
        k, omega, r: parameters of the analytical solution.

    Returns:
        Tuple ``(loss_ur, loss_ui, loss_v)`` of scalar losses.
    """
    pred_u_r, pred_u_i, pred_v = model_ur(x, t), model_ui(x, t), model_v(x, t)

    # Boundary interpolation weights of the ansatz; the analytical edge values
    # are evaluated at x = 0 and x = 1 for the same t.
    x0 = torch.zeros_like(x)
    x1 = torch.ones_like(x)
    left = 1 - x
    bubble = x * (1 - x)

    u_r = left * real_u1(x0, t, k, omega, r) + x * real_u1(x1, t, k, omega, r) + bubble * pred_u_r
    u_i = left * imag_u1(x0, t, k, omega, r) + x * imag_u1(x1, t, k, omega, r) + bubble * pred_u_i
    v = left * real_v1(x0, t, k, omega, r) + x * real_v1(x1, t, k, omega, r) + bubble * pred_v

    real_u1_val = real_u1(x, t, k, omega, r)
    imag_u1_val = imag_u1(x, t, k, omega, r)
    real_v1_val = real_v1(x, t, k, omega, r)

    boundary_loss_ur = mse_cost_function(u_r, real_u1_val)
    boundary_loss_ui = mse_cost_function(u_i, imag_u1_val)
    boundary_loss_v = mse_cost_function(v, real_v1_val)

    return boundary_loss_ur, boundary_loss_ui, boundary_loss_v

def compute_physics_loss(model_ur, model_ui, model_v, x, t, u_r0, u_i0, v0, y1, device, mse_cost_function, k, omega, r):
    """Mean-squared PDE residual of the hard-constrained trial solution.

    Each field is represented as ``f = f0 + y1 * N(x, t)``, where ``N`` is the
    network output, ``y1 = x (1 - x)`` vanishes at x = 0 and x = 1, and ``f0``
    linearly interpolates the analytical solution between those two edges.
    The composed fields go through ``coupled_higgs`` and each residual is
    compared with zero via ``mse_cost_function``.

    Side effect: ``x`` and ``t`` are switched to ``requires_grad=True`` in place.

    Args:
        model_ur, model_ui, model_v: networks for Re(u), Im(u) and v.
        x, t: collocation coordinates (assumed (N, 1) columns; TODO confirm).
        u_r0, u_i0, v0, y1: precomputed ansatz terms on (x, t). They are used
            only if they are part of the autograd graph of ``x``/``t``. If any
            of them does not require grad (for example, it was built before
            gradients were enabled), all four are rebuilt here from ``k``,
            ``omega`` and ``r``.
        device: device for the zero targets.
        mse_cost_function: loss callable ``(input, target) -> scalar``.
        k, omega, r: parameters of the analytical solution.

    Returns:
        Tuple of scalar losses ``(loss_r, loss_i, loss_v)`` for Re(u), Im(u), v.
    """
    x.requires_grad = True
    t.requires_grad = True

    # The residual differentiates the composed fields w.r.t. x and t, so every
    # ansatz term must be connected to the autograd graph of x and t. Terms
    # computed before requires_grad was enabled are constants to autograd, so
    # d(y1)/dx and the derivatives of f0 would silently drop out of the residual.
    ansatz_terms = (u_r0, u_i0, v0, y1)
    if not all(torch.is_tensor(term) and term.requires_grad for term in ansatz_terms):
        x0 = torch.zeros_like(x)
        x1 = torch.ones_like(x)
        y1 = x * (1 - x)
        u_r0 = (1 - x) * real_u1(x0, t, k, omega, r) + x * real_u1(x1, t, k, omega, r)
        u_i0 = (1 - x) * imag_u1(x0, t, k, omega, r) + x * imag_u1(x1, t, k, omega, r)
        v0 = (1 - x) * real_v1(x0, t, k, omega, r) + x * real_v1(x1, t, k, omega, r)

    pred_u_r, pred_u_i, pred_v = model_ur(x, t), model_ui(x, t), model_v(x, t)

    u_r = u_r0 + y1 * pred_u_r
    u_i = u_i0 + y1 * pred_u_i
    v = v0 + y1 * pred_v

    du_eq_r, du_eq_i, dv_eq = coupled_higgs(u_r, u_i, v, x, t)

    # Zero targets with the same shape as the residuals
    zeros_r = torch.zeros_like(du_eq_r, device=device)
    zeros_i = torch.zeros_like(du_eq_i, device=device)
    zeros_v = torch.zeros_like(dv_eq, device=device)

    # MSE of each residual against zero
    loss_r = mse_cost_function(du_eq_r, zeros_r)
    loss_i = mse_cost_function(du_eq_i, zeros_i)
    loss_v = mse_cost_function(dv_eq, zeros_v)

    return loss_r, loss_i, loss_v

def cyclic_iterator(items):
    """Repeat ``items`` forever, in order, as a lazy iterator.

    This is a thin wrapper around ``itertools.cycle``; an empty ``items``
    yields nothing.
    """
    endless = cycle(items)
    return endless

def plot_predictions(epoch, model_ur, model_ui, model_v, model_save_path, device, k, omega, r, image_save_path):
    """Plot the checkpoint saved at ``epoch`` against the analytical solution.

    The three networks are re-created from
    ``C_HIGGS_{ur,ui,v}_first_training_epoch_{epoch}.pth`` in ``model_save_path``.
    The ``model_*`` arguments are therefore not used for inference; they are
    kept for call compatibility.

    Network outputs are composed with the boundary ansatz
    ``(1 - x) f(0, t) + x f(1, t) + x (1 - x) N(x, t)`` on a 300 x 300 grid
    over [0, 1]^2. The predicted v surface is Gaussian-smoothed (sigma = 10)
    before plotting.

    A 2 x 3 panel of 3-D surfaces (predictions in the top row, analytical
    fields in the bottom row) is saved to
    ``image_save_path/chiggs_model_comparison_epoch_{epoch}.png``.
    """
    def _load_checkpoint(tag):
        # Fresh FFNN on `device` carrying the weights saved for `tag` at `epoch`.
        path = os.path.join(model_save_path, f'C_HIGGS_{tag}_first_training_epoch_{epoch}.pth')
        state = torch.load(path, map_location=device)
        net = FFNN().to(device)
        net.load_state_dict(state)
        net.eval()  # Set the model to evaluation mode
        return net

    model_ur = _load_checkpoint('ur')
    model_ui = _load_checkpoint('ui')
    model_v = _load_checkpoint('v')

    x = torch.linspace(0, 1, 300)
    t = torch.linspace(0, 1, 300)
    # Explicit 'ij' indexing (X[i, j] = x[i], T[i, j] = t[j]) matches the old
    # default and avoids torch's warning about the default changing.
    X, T = torch.meshgrid(x, t, indexing='ij')
    X_flat = X.flatten().unsqueeze(-1).to(device)
    T_flat = T.flatten().unsqueeze(-1).to(device)
    x0 = torch.zeros_like(X_flat)
    x1 = torch.ones_like(X_flat)

    with torch.no_grad():
        pred_u_r, pred_u_i, pred_v = model_ur(X_flat, T_flat), model_ui(X_flat, T_flat), model_v(X_flat, T_flat)

    # Same hard-constrained composition as used in training.
    left = 1 - X_flat
    bubble = X_flat * (1 - X_flat)
    pred_u_r = left * real_u1(x0, T_flat, k, omega, r) + X_flat * real_u1(x1, T_flat, k, omega, r) + bubble * pred_u_r
    pred_u_i = left * imag_u1(x0, T_flat, k, omega, r) + X_flat * imag_u1(x1, T_flat, k, omega, r) + bubble * pred_u_i
    pred_v = left * real_v1(x0, T_flat, k, omega, r) + X_flat * real_v1(x1, T_flat, k, omega, r) + bubble * pred_v

    pred_u_r = pred_u_r.cpu().reshape(X.shape).numpy()
    pred_u_i = pred_u_i.cpu().reshape(X.shape).numpy()
    pred_v = pred_v.cpu().reshape(X.shape).numpy()

    real_u1_analytical = real_u1(X_flat, T_flat, k, omega, r).cpu().reshape(X.shape).numpy()
    imag_u1_analytical = imag_u1(X_flat, T_flat, k, omega, r).cpu().reshape(X.shape).numpy()
    real_v1_analytical = real_v1(X_flat, T_flat, k, omega, r).cpu().reshape(X.shape).numpy()

    # NOTE(review): only the v prediction is smoothed, which can hide
    # high-frequency error in that panel.
    sigma = 10
    pred_v_smooth = gaussian_filter(pred_v, sigma=sigma)
    shrink = 0.3
    cmap = 'viridis'
    aspect = 50
    # (surface data, title, z-axis label) per subplot, row-major
    data_to_plot = [
        (pred_u_r, 'Predicted Real Part of $u_1(x, t)$', 'Real part of $u_1$'),
        (pred_u_i, 'Predicted Imaginary Part of $u_1(x, t)$', 'Imag part of $u_1$'),
        (pred_v_smooth, 'Predicted Real Part of $v_1(x, t)$', 'Real part of $v_1$'),
        (real_u1_analytical, 'Analytical Real Part of $u_1(x, t)$', 'Real part of $u_1$'),
        (imag_u1_analytical, 'Analytical Imaginary Part of $u_1(x, t)$', 'Imag part of $u_1$'),
        (real_v1_analytical, 'Analytical Real Part of $v_1(x, t)$', 'Real part of $v_1$')
    ]
    fig = plt.figure(figsize=(24, 16))

    for idx, (data, title, zlabel) in enumerate(data_to_plot, start=1):
        ax = fig.add_subplot(2, 3, idx, projection='3d')
        surf = ax.plot_surface(X.numpy(), T.numpy(), data, cmap=cmap)
        fig.colorbar(surf, ax=ax, shrink=shrink, aspect=aspect)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('t')
        ax.set_zlabel(zlabel)

    plt.tight_layout()
    plt.savefig(os.path.join(image_save_path, f'chiggs_model_comparison_epoch_{epoch}.png'), dpi=100)
    plt.close(fig)  # Close the figure to free memory

def seq2seq_training(model_ur, model_ui, model_v, model_save_path, image_save_path, mse_cost_function, device, num_epochs, lr, num_samples, r, k, omega, gamma):
    """Jointly train the Re(u), Im(u) and v networks, one AdamW optimizer each.

    Each epoch does the following:
      - evaluates the PDE residual loss on ``num_samples`` collocation points;
      - evaluates the analytical-solution data loss on ``num_samples`` fixed
        random points;
      - combines them as ``lr * physics + (1 - lr) * data``;
      - steps all three optimizers.

    Every 1000 epochs it prints the loss, checkpoints each network's state
    dict to ``model_save_path`` and calls ``plot_predictions`` to write a
    comparison figure to ``image_save_path``.

    Args:
        model_ur, model_ui, model_v: networks for Re(u), Im(u) and v (already on ``device``).
        model_save_path: directory for ``C_HIGGS_{ur,ui,v}_first_training_epoch_{epoch}.pth``.
        image_save_path: directory for the comparison figures.
        mse_cost_function: loss callable ``(input, target) -> scalar``.
        device: torch device for all sampled tensors.
        num_epochs: number of optimisation steps.
        lr: AdamW learning rate; also used as the physics-loss weight (see note below).
        num_samples: number of collocation points and of data points.
        r, k, omega: parameters of the analytical solution.
        gamma: currently unused.
    """
    adamw_kwargs = dict(lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    optimizer_ur = AdamW(model_ur.parameters(), **adamw_kwargs)
    optimizer_ui = AdamW(model_ui.parameters(), **adamw_kwargs)
    optimizer_v = AdamW(model_v.parameters(), **adamw_kwargs)
    optimizers = (optimizer_ur, optimizer_ui, optimizer_v)
    models = (model_ur, model_ui, model_v)
    # Checkpoint file tag for each network; each file gets its own network's weights.
    checkpoints = (('ur', model_ur), ('ui', model_ui), ('v', model_v))

    print(' Starting Seq2Seq Training')
    factor = 0  # reported in the progress log; never changed in this routine

    # Collocation points for the PDE residual. x is evenly spaced and t holds
    # the same values in shuffled order. Pairing two identical linspaces
    # (x_i == t_i) would put every point on the diagonal x == t and leave the
    # rest of the unit square unconstrained by the physics loss.
    x = torch.linspace(0, 1, num_samples).unsqueeze(-1).to(device)
    t = torch.linspace(0, 1, num_samples)[torch.randperm(num_samples)].unsqueeze(-1).to(device)

    # Hard-constrained ansatz pieces: linear interpolation of the analytical
    # edge values plus a bubble x(1 - x) that vanishes at x = 0 and x = 1.
    x0 = torch.zeros_like(x)
    x1 = torch.ones_like(x)
    y = 1 - x
    y1 = x * (1 - x)
    u_r0 = y * real_u1(x0, t, k, omega, r) + x * real_u1(x1, t, k, omega, r)
    u_i0 = y * imag_u1(x0, t, k, omega, r) + x * imag_u1(x1, t, k, omega, r)
    v0 = y * real_v1(x0, t, k, omega, r) + x * real_v1(x1, t, k, omega, r)

    # Fixed random points for the analytical-solution (data) loss.
    x_dom = torch.rand(num_samples, 1).to(device)
    t_dom = torch.rand(num_samples, 1).to(device)

    for epoch in tqdm(range(num_epochs),
                      desc='Progress:',
                      leave=False,
                      ncols=75,
                      mininterval=0.1,
                      bar_format='{l_bar} {bar} | {remaining}',  # Only show the bar without any counters
                      colour='blue'):

        for model in models:
            model.train()
        for optimizer in optimizers:
            optimizer.zero_grad()

        physics_loss_ur, physics_loss_ui, physics_loss_v = compute_physics_loss(model_ur, model_ui, model_v, x, t, u_r0, u_i0, v0, y1, device, mse_cost_function, k, omega, r)
        domain_loss_ur_t, domain_loss_ui_t, domain_loss_v_t = compute_analytical_boundary_loss(model_ur, model_ui, model_v, x_dom, t_dom, mse_cost_function, k, omega, r)
        # NOTE(review): the learning rate doubles as the physics/data mixing
        # weight here, while `gamma` is never used. Confirm whether gamma was
        # meant to be this weight.
        total_loss = lr * (physics_loss_ur + physics_loss_ui + physics_loss_v) + (1 - lr) * (domain_loss_ur_t + domain_loss_ui_t + domain_loss_v_t)

        total_loss.backward()

        for optimizer in optimizers:
            optimizer.step()

        if epoch % 1000 == 0:
            print(f' Epoch {epoch}, Factor {factor}, Total Loss {total_loss.item()}')
            for tag, model in checkpoints:
                model_filename = os.path.join(model_save_path, f'C_HIGGS_{tag}_first_training_epoch_{epoch}.pth')
                torch.save(model.state_dict(), model_filename)

            plot_predictions(epoch, model_ur, model_ui, model_v, model_save_path, device, k, omega, r, image_save_path)
    print('COMPLETED Seq2Seq Training')

def main():
    """Train the three coupled-Higgs networks with ``seq2seq_training``.

    The function selects CUDA when available and builds one FFNN each for
    Re(u), Im(u) and v. It then creates the checkpoint and figure output
    directories and runs ``seq2seq_training`` with the hyper-parameters below.
    """
    # Check if CUDA is available and set the default device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        print("CUDA is available! Training on GPU.")
    else:
        print("CUDA is not available. Training on CPU.")

    model_ur = FFNN().to(device)
    model_ui = FFNN().to(device)
    model_v = FFNN().to(device)

    num_epochs_sq = 12000
    num_samples_sq = 3000
    lr = 1e-4
    # Parameters of the analytical solution
    r = 1.1
    omega = 3
    k = 0.5
    gamma = 1e-3  # passed through; not used by seq2seq_training as written
    mse_cost_function = torch.nn.MSELoss()

    model_save_path = 'model_weights_test'
    image_save_path = 'results_test'
    os.makedirs(model_save_path, exist_ok=True)
    os.makedirs(image_save_path, exist_ok=True)

    seq2seq_training(model_ur, model_ui, model_v, model_save_path, image_save_path, mse_cost_function, device, num_epochs_sq, lr, num_samples_sq, r, k, omega, gamma)
    # NOTE: a follow-up L-BFGS stage (LBFGS_training) used to be invoked here,
    # commented out; that function is not defined in this module as shown.
# Start training only when run as a script. Note that the module-level FFNN
# smoke test above still runs on import.
if __name__ == '__main__':
    main()

