# In [None]:  (Jupyter cell marker from notebook export)
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel as C
from sklearn.preprocessing import StandardScaler
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
import pandas as pd


# Reproducibility: fix NumPy's and torch's global RNGs before any weights are
# initialised or collocation points are sampled.
np.random.seed(1234)
torch.manual_seed(1234)

# Problem setup: rectangular spatial domain observed over a fixed time window.
L_x = 1.0  # domain extent in x
L_y = 1.0  # domain extent in y
T = 1.0  # final time
alpha = 0.1  # thermal diffusivity used in the temperature residual

# Edge (top/bottom/left/right) values and the uniform initial value of each field.
# Temperature (K)
T_top = T_bottom = 303.15
T_left = T_right = 274.15
T_init = 295.65

# Oxygen pressure (Pa)
Pb_top = Pb_bottom = 1e7
Pb_left = Pb_right = 1e6
Pb_init = 1e7

# Carbonyl area
CAb_top = CAb_bottom = 1.0
CAb_left = CAb_right = 0.94
CAb_init = 0.94

# Arrhenius-form rate parameters for kf and kc: pre-exponential factors Af/Ac,
# activation energies Eaf/Eac (J/mol, consistent with R), pressure exponent aa.
Af = 5.8446e5
Ac = 5.8264e8
Eaf = 75400
Eac = 1.038e5
aa = 0.27
R = 8.314  # gas constant, J/(mol·K)

# Per-field min/max over every boundary and initial value; these define the
# min-max scaling applied to the network's targets and outputs.
T_min = min(T_left, T_right, T_bottom, T_top, T_init)
T_max = max(T_left, T_right, T_bottom, T_top, T_init)
Pb_min = min(Pb_left, Pb_right, Pb_bottom, Pb_top, Pb_init)
Pb_max = max(Pb_left, Pb_right, Pb_bottom, Pb_top, Pb_init)
CAb_min = min(CAb_left, CAb_right, CAb_bottom, CAb_top, CAb_init)
CAb_max = max(CAb_left, CAb_right, CAb_bottom, CAb_top, CAb_init)

# Scaling functions
def scale_var(var, var_min, var_max):
    """Min-max normalise ``var`` so that ``var_min`` maps to 0 and ``var_max`` to 1.

    Pure elementwise arithmetic, so scalars and array-like values work alike.
    There is no guard for ``var_max == var_min``.
    """
    span = var_max - var_min
    shifted = var - var_min
    return shifted / span

def unscale_var(var_scaled, var_min, var_max):
    """Inverse of ``scale_var``: map a [0, 1]-scaled value back onto [var_min, var_max]."""
    span = var_max - var_min
    return var_min + var_scaled * span

# Map every edge and initial value into [0, 1] using the per-field ranges above.
# These scaled numbers are the regression targets built in generate_bc_data and
# generate_ic_data.
T_top_scaled = scale_var(T_top, T_min, T_max)
T_bottom_scaled = scale_var(T_bottom, T_min, T_max)
T_left_scaled = scale_var(T_left, T_min, T_max)
T_right_scaled = scale_var(T_right, T_min, T_max)
T_init_scaled = scale_var(T_init, T_min, T_max)

Pb_top_scaled = scale_var(Pb_top, Pb_min, Pb_max)
Pb_bottom_scaled = scale_var(Pb_bottom, Pb_min, Pb_max)
Pb_left_scaled = scale_var(Pb_left, Pb_min, Pb_max)
Pb_right_scaled = scale_var(Pb_right, Pb_min, Pb_max)
Pb_init_scaled = scale_var(Pb_init, Pb_min, Pb_max)

CAb_top_scaled = scale_var(CAb_top, CAb_min, CAb_max)
CAb_bottom_scaled = scale_var(CAb_bottom, CAb_min, CAb_max)
CAb_left_scaled = scale_var(CAb_left, CAb_min, CAb_max)
CAb_right_scaled = scale_var(CAb_right, CAb_min, CAb_max)
CAb_init_scaled = scale_var(CAb_init, CAb_min, CAb_max)

# PINN model
class CoupledPINN(nn.Module):
    """Fully connected tanh network mapping (x, y, t) to (T, Pb, CAb).

    The three outputs are temperature, oxygen pressure and carbonyl area. They are
    trained against min-max scaled targets. The full three-output network is
    always evaluated; ``stage`` only selects how many leading columns are returned.
    """

    def __init__(self):
        super().__init__()
        # Layer widths: 3 inputs -> four hidden layers of 10 -> 3 outputs.
        widths = [3, 10, 10, 10, 10, 3]
        last = len(widths) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx < last:  # no activation after the output layer
                layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x, stage):
        """Return the first 1 (stage 1), 2 (stage 2) or all 3 (otherwise) output columns."""
        out = self.net(x)
        if stage == 1:
            width = 1
        elif stage == 2:
            width = 2
        else:
            return out
        return out[:, :width]

# PDE residual
def _autograd_grad(out, inp):
    """Return d(sum(out))/d(inp), keeping the graph so the result can be differentiated again.

    Pass a single output column: for a multi-column ``out`` autograd sums the
    gradients of all columns, mixing the derivatives of different variables.
    """
    return torch.autograd.grad(out, inp, grad_outputs=torch.ones_like(out), create_graph=True)[0]


def _space_time_derivatives(u, xyts):
    """Return ``(u_t, u_xx, u_yy)`` for one output column ``u``.

    ``xyts`` is the input built as cat([x, y, t], dim=1): column 0 is x, 1 is y, 2 is t.
    """
    grads = _autograd_grad(u, xyts)
    u_x, u_y, u_t = grads[:, 0:1], grads[:, 1:2], grads[:, 2:3]
    u_xx = _autograd_grad(u_x, xyts)[:, 0:1]
    u_yy = _autograd_grad(u_y, xyts)[:, 1:2]
    return u_t, u_xx, u_yy


def pde_residual(model, x, y, t, stage):
    """Evaluate the PDE residuals of ``model`` at the collocation points (x, y, t).

    Stage 1: heat equation ``T_t - alpha * (T_xx + T_yy)``.
    Stage 2: also the Pb diffusion residual ``Pb_t - 1e-11 * (Pb_xx + Pb_yy)``.
    Otherwise: Pb gains a ``+ 8.3 * T * CAb_t`` coupling term, and a CAb rate residual
    built from the rate constants kf and kc is added.

    Each output column is differentiated separately. A single ``autograd.grad`` over
    the whole output would return d(T + Pb + ...)/d(x, y, t) and give every variable
    the same derivatives.

    Returns:
        A single (N, 1) tensor for stage 1, else a tuple of 2 (stage 2) or 3 residual tensors.
    """
    xyts = torch.cat([x, y, t], dim=1)
    xyts.requires_grad_(True)

    output = model(xyts, stage)

    if stage == 1:
        T_t, T_xx, T_yy = _space_time_derivatives(output, xyts)
        return T_t - alpha * (T_xx + T_yy)

    T_raw, Pb_raw = output[:, 0:1], output[:, 1:2]
    T_t, T_xx, T_yy = _space_time_derivatives(T_raw, xyts)
    Pb_t, Pb_xx, Pb_yy = _space_time_derivatives(Pb_raw, xyts)
    res_T = T_t - alpha * (T_xx + T_yy)

    if stage == 2:
        res_Pb = Pb_t - (1E-11 * (Pb_xx + Pb_yy))  # Simplified oxygen pressure equation
        return res_T, res_Pb

    CAb_t = _autograd_grad(output[:, 2:3], xyts)[:, 2:3]

    # Clamped copies are used only inside the source terms; derivatives come from the raw outputs.
    T = torch.clamp(T_raw, min=1e-6)  # Avoid negative temperatures
    Pb = torch.clamp(Pb_raw, min=1e-6)  # Avoid negative pressures

    # NOTE(review): the network is trained on min-max scaled targets (see generate_bc_data),
    # so T and Pb here appear to be scaled values (~[0, 1]), not K / Pa. With T <= ~1,
    # exp(-Eaf/(R*T)) underflows to 0, i.e. kf and kc vanish. Confirm whether T and Pb
    # should be unscaled before computing the rate constants.
    kf = Af * (Pb/101325)**aa * torch.exp(-Eaf/(R*T)) * 1E7
    kc = Ac * (Pb/101325)**aa * torch.exp(-Eac/(R*T)) * 1E7

    res_Pb = Pb_t - (1E-11 * (Pb_xx + Pb_yy)) + (((8.3 * T)) * CAb_t)  # Simplified coupled equations
    res_CAb = CAb_t - (0.05 * kf * torch.exp(-kf * t) + kc)  # Updated carbonyl area equation
    return res_T, res_Pb, res_CAb

# Loss function
def curriculum_coupled_pinn_loss(model, x, y, t, x_bc, y_bc, t_bc, T_bc, Pb_bc, CAb_bc, x_ic, y_ic, t_ic, T_ic, Pb_ic, CAb_ic, stage):
    """Compute the unweighted PINN loss for one curriculum stage.

    The stage fixes how many variables are supervised: 1 -> T, 2 -> T and Pb,
    anything else -> T, Pb and CAb.
    - Boundary and initial terms: the per-variable MSEs are summed.
    - PDE term: the mean over collocation points of the summed squared residuals
      returned by ``pde_residual``.

    Returns:
        ``(total_loss, parts)``. ``total_loss = bc + ic + pde`` is a differentiable
        tensor; ``parts`` maps 'bc', 'ic' and 'pde' to their float values.
    """
    criterion = nn.MSELoss()

    bc_out = model(torch.cat([x_bc, y_bc, t_bc], dim=1), stage)
    ic_out = model(torch.cat([x_ic, y_ic, t_ic], dim=1), stage)

    if stage == 1:
        n_vars = 1
    elif stage == 2:
        n_vars = 2
    else:
        n_vars = 3

    def supervised_loss(pred, targets):
        # Column k of the prediction is compared with targets[k], summed in T, Pb, CAb order.
        acc = criterion(pred[:, 0:1], targets[0])
        for k in range(1, n_vars):
            acc = acc + criterion(pred[:, k:k + 1], targets[k])
        return acc

    loss_bc = supervised_loss(bc_out, (T_bc, Pb_bc, CAb_bc))
    loss_ic = supervised_loss(ic_out, (T_ic, Pb_ic, CAb_ic))

    residuals = pde_residual(model, x, y, t, stage)
    if n_vars == 1:
        residuals = (residuals,)  # stage 1 returns a bare tensor
    squared_sum = torch.square(residuals[0])
    for extra in residuals[1:]:
        squared_sum = squared_sum + torch.square(extra)
    loss_pde = torch.mean(squared_sum)

    total_loss = loss_bc + loss_ic + loss_pde
    return total_loss, {'bc': loss_bc.item(), 'ic': loss_ic.item(), 'pde': loss_pde.item()}

# Generate collocation points
def generate_collocation_points(n_points, L_x, L_y, T):
    """Sample ``n_points`` interior points uniformly from [0, L_x) x [0, L_y) x [0, T).

    Returns:
        Tuple ``(x, y, t)`` of tensors, each of shape (n_points, 1). Values are drawn in
        x, y, t order from torch's global RNG.
    """
    return tuple(torch.rand(n_points, 1) * extent for extent in (L_x, L_y, T))

# Generate boundary condition data
def generate_bc_data(n_points, L_x, L_y, T):
    """Sample points on the four edges of [0, L_x] x [0, L_y] with their scaled boundary values.

    Each edge gets ``n_points // 2`` points, ordered top, bottom, left, right.
    Top and bottom reuse one set of x/t draws; left and right reuse one set of
    y/t draws. Times are uniform in [0, T).

    Returns:
        ``(x_bc, y_bc, t_bc, T_bc, Pb_bc, CAb_bc)``, each a (4 * (n_points // 2), 1) tensor.
    """
    per_edge = n_points // 2

    # Random draws in this fixed order: x and t for the horizontal edges, then
    # y and t for the vertical edges.
    x_h = torch.rand(per_edge, 1) * L_x
    t_h = torch.rand(per_edge, 1) * T
    y_v = torch.rand(per_edge, 1) * L_y
    t_v = torch.rand(per_edge, 1) * T

    # (x, y, t, T value, Pb value, CAb value) for top, bottom, left, right.
    edges = (
        (x_h, torch.full_like(x_h, L_y), t_h, T_top_scaled, Pb_top_scaled, CAb_top_scaled),
        (x_h, torch.zeros_like(x_h), t_h, T_bottom_scaled, Pb_bottom_scaled, CAb_bottom_scaled),
        (torch.zeros_like(y_v), y_v, t_v, T_left_scaled, Pb_left_scaled, CAb_left_scaled),
        (torch.full_like(y_v, L_x), y_v, t_v, T_right_scaled, Pb_right_scaled, CAb_right_scaled),
    )

    coords = [torch.cat([edge[i] for edge in edges], dim=0) for i in range(3)]
    targets = [
        torch.cat([torch.full_like(edge[2], edge[i]) for edge in edges], dim=0)
        for i in range(3, 6)
    ]
    return (*coords, *targets)

# Generate initial condition data
def generate_ic_data(n_points, L_x, L_y):
    """Sample ``n_points`` locations on the t = 0 slice and attach the scaled initial values.

    Returns:
        ``(x_ic, y_ic, t_ic, T_ic, Pb_ic, CAb_ic)``, each of shape (n_points, 1).
        ``t_ic`` is all zeros and the three targets are constant.
    """
    xs = torch.rand(n_points, 1) * L_x
    ys = torch.rand(n_points, 1) * L_y
    ts = torch.zeros_like(xs)
    init_values = (T_init_scaled, Pb_init_scaled, CAb_init_scaled)
    T0, Pb0, CAb0 = (torch.full_like(xs, value) for value in init_values)
    return xs, ys, ts, T0, Pb0, CAb0

# Create heatmap data
def create_heatmap_data(model, t_values, stage, nx=100, ny=100):
    """Evaluate ``model`` on an nx-by-ny grid over the full domain at each time in ``t_values``.

    Args:
        model: network called as ``model(inputs, stage)`` with inputs of shape (nx*ny, 3).
        t_values: iterable of times; ints such as ``0`` are accepted.
        stage: curriculum stage; 1 decodes T, 2 decodes T and Pb, otherwise T, Pb and CAb.
        nx, ny: grid resolution along x and y.

    Returns:
        ``(X, Y, results)``.
        - ``X``, ``Y``: (nx, ny) coordinate arrays with 'ij' indexing.
        - ``results``: one ``(T, Pb, CAb)`` tuple per time. Each entry is an (nx, ny)
          array unscaled to physical range, or None when the stage does not predict it.
    """
    x = torch.linspace(0, L_x, nx)
    y = torch.linspace(0, L_y, ny)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.column_stack((X.reshape(-1), Y.reshape(-1)))

    n_vars = 1 if stage == 1 else 2 if stage == 2 else 3
    ranges = ((T_min, T_max), (Pb_min, Pb_max), (CAb_min, CAb_max))

    results = []
    for t in t_values:
        # Pin the dtype: torch.full infers int64 from an int fill value (e.g. the 0 the
        # training loop passes), which would otherwise rely on torch.cat dtype promotion.
        t_array = torch.full((nx * ny, 1), t, dtype=xy.dtype)
        input_data = torch.cat((xy, t_array), dim=1)

        with torch.no_grad():
            output = model(input_data, stage)

        fields = [
            unscale_var(output[:, k].reshape(nx, ny), lo, hi).numpy()
            for k, (lo, hi) in enumerate(ranges[:n_vars])
        ]
        fields.extend([None] * (3 - n_vars))
        results.append(tuple(fields))

    return X.numpy(), Y.numpy(), results

def plot_heatmaps(X, Y, results, stage, epoch, t_values=None):
    """Save a grid of heatmaps to ``heatmap_stage_{stage}_epoch_{epoch}.png``.

    The grid has 2 rows (snapshots) and 3 columns (T, Pb, CAb). Panels whose result
    entry is None are marked "Not Available".

    Args:
        X, Y: grid coordinates from ``create_heatmap_data``. They are not used for
            drawing; the extent comes from ``L_x``/``L_y``.
        results: sequence of ``(T, Pb, CAb)`` tuples of (nx, ny) arrays, one per
            snapshot. Only the first two are drawn.
        stage, epoch: shown in the figure title and used in the filename.
        t_values: time of each snapshot, used only in subplot titles. Defaults to
            (0, 1), the former hard-coded labels, which are wrong for snapshots not
            taken at t=0 and t=1 (e.g. at ``T_stage`` during curriculum training).
    """
    if t_values is None:
        t_values = (0, 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    variables = [
        ("Temperature (K)", 0),
        ("Oxygen Pressure (Pa)", 1),
        ("Carbonyl Area", 2)
    ]

    for i, (title, idx) in enumerate(variables):
        for j, (t, result) in enumerate(zip(t_values, results)):
            if j >= 2:
                break  # the figure only has two rows
            ax = axes[j, i]
            var = result[idx]
            if var is not None:
                # Arrays are [x, y]; transpose so rows follow y for imshow.
                im = ax.imshow(var.T, extent=[0, L_x, 0, L_y], origin='lower', cmap='turbo', aspect='auto')
                ax.set_title(f'PINN {title} (t={t})')
                plt.colorbar(im, ax=ax, label=title)
            else:
                ax.text(0.5, 0.5, "Not Available", ha='center', va='center')
                ax.set_title(f'{title} (t={t}, Not Available)')
            ax.set_xlabel('x')
            ax.set_ylabel('y')

    plt.suptitle(f'Stage {stage}, Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'heatmap_stage_{stage}_epoch_{epoch}.png')
    plt.close()

# Implement curriculum learning
def train_coupled_pinn_curriculum(model, epochs_per_stage=100000, plot_every=2000, lr=0.001, n_points=1000):
    """Train ``model`` through three curriculum stages with one shared Adam optimizer.

    Stage s (1..3) trains on the window [0, s/3*L_x] x [0, s/3*L_y] x [0, s/3*T] and
    supervises the first s variables (T; T and Pb; T, Pb and CAb). Fresh collocation,
    boundary and initial samples are drawn once at the start of each stage.

    Args:
        model: network called as ``model(inputs, stage)``.
        epochs_per_stage: optimizer steps per stage.
        plot_every: print losses and save heatmaps every this many epochs (starting
            at epoch 0). Values <= 0 disable this instead of raising ZeroDivisionError.
        lr: Adam learning rate.
        n_points: number of collocation / boundary / initial samples per stage.

    Returns:
        ``(model, all_losses)``. ``all_losses`` maps 'pde', 'bc', 'ic' and 'total' to
        per-step float lists concatenated across all stages.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    all_losses = {'pde': [], 'bc': [], 'ic': [], 'total': []}

    for stage in range(1, 4):
        print(f"Starting curriculum stage {stage}")

        # Curriculum: the space-time window grows with the stage index.
        L_x_stage = L_x * stage / 3
        L_y_stage = L_y * stage / 3
        T_stage = T * stage / 3

        x_train, y_train, t_train = generate_collocation_points(n_points, L_x_stage, L_y_stage, T_stage)
        x_bc, y_bc, t_bc, T_bc, Pb_bc, CAb_bc = generate_bc_data(n_points, L_x_stage, L_y_stage, T_stage)
        x_ic, y_ic, t_ic, T_ic, Pb_ic, CAb_ic = generate_ic_data(n_points, L_x_stage, L_y_stage)

        for epoch in range(epochs_per_stage):
            optimizer.zero_grad()
            loss, component_losses = curriculum_coupled_pinn_loss(
                model, x_train, y_train, t_train,
                x_bc, y_bc, t_bc, T_bc, Pb_bc, CAb_bc,
                x_ic, y_ic, t_ic, T_ic, Pb_ic, CAb_ic,
                stage
            )
            loss.backward()
            optimizer.step()

            total = loss.item()
            all_losses['pde'].append(component_losses['pde'])
            all_losses['bc'].append(component_losses['bc'])
            all_losses['ic'].append(component_losses['ic'])
            all_losses['total'].append(total)

            if plot_every > 0 and epoch % plot_every == 0:
                print(f"Stage {stage}, Epoch {epoch}, Total Loss: {total:.4f}")
                print(f"Component Losses - BC: {component_losses['bc']:.4f}, IC: {component_losses['ic']:.4f}, PDE: {component_losses['pde']:.4f}")

                # Snapshot the full domain at t=0 and at the end of this stage's time window.
                X, Y, results = create_heatmap_data(model, [0, T_stage], stage)
                plot_heatmaps(X, Y, results, stage, epoch)

    return model, all_losses

# Add a new function to plot the losses
def plot_losses(all_losses):
    """Plot per-step PDE/BC/IC/total losses and save the figure to ``training_losses.png``.

    The y-axis switches to log scale when the total loss spans more than a factor of
    100. That test is skipped for an empty history or a non-positive minimum, where
    ``max/min`` would raise (ValueError / ZeroDivisionError) or a log axis makes no sense.
    """
    totals = all_losses['total']

    plt.figure(figsize=(12, 8))
    steps = range(1, len(totals) + 1)

    for loss_type in ['pde', 'bc', 'ic', 'total']:
        plt.plot(steps, all_losses[loss_type], label=f'{loss_type.upper()} Loss')

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    lowest = min(totals, default=0.0)
    if lowest > 0 and max(totals) / lowest > 100:
        plt.yscale('log')
        plt.title('Training Losses (Log Scale)')
    else:
        plt.title('Training Losses')

    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()
    plt.savefig('training_losses.png', dpi=300)
    plt.close()

# Post-training export helpers: line profiles through the domain centre and Excel output

def generate_line_predictions(model, L_x, L_y, T, nx=100, ny=100):
    """Evaluate the stage-3 model along the two centre lines of the domain at t=0 and t=T.

    - Horizontal line: ``nx`` points on y = L_y / 2.
    - Vertical line: ``ny`` points on x = L_x / 2.

    Each line's time column is sized to that line, so ``nx != ny`` works. Previously
    the vertical inputs reused nx-long time columns, and torch.cat failed.

    Returns:
        Dict with keys 'horizontal_0', 'horizontal_1', 'vertical_0' and 'vertical_1'
        (suffix 0 -> t=0, 1 -> t=T). Each value is a tuple
        ``(coordinate, T, Pb, CAb)`` of numpy arrays; the outputs are unscaled to
        physical range.
    """
    x_horizontal = torch.linspace(0, L_x, nx).unsqueeze(1)
    y_vertical = torch.linspace(0, L_y, ny).unsqueeze(1)

    # name -> (coordinate along the line, (N, 2) spatial inputs on that line)
    lines = {
        'horizontal': (x_horizontal, torch.cat((x_horizontal, torch.full_like(x_horizontal, L_y / 2)), dim=1)),
        'vertical': (y_vertical, torch.cat((torch.full_like(y_vertical, L_x / 2), y_vertical), dim=1)),
    }

    predictions = {}
    for name, (coord, xy) in lines.items():
        for suffix, t_value in (('0', 0.0), ('1', T)):
            t_column = torch.full_like(coord, t_value)
            with torch.no_grad():
                out = model(torch.cat((xy, t_column), dim=1), stage=3)
            predictions[f'{name}_{suffix}'] = (
                coord.numpy(),
                unscale_var(out[:, 0], T_min, T_max).numpy(),
                unscale_var(out[:, 1], Pb_min, Pb_max).numpy(),
                unscale_var(out[:, 2], CAb_min, CAb_max).numpy(),
            )

    return predictions

def create_excel_file(predictions, filename='predictions.xlsx'):
    """Write the four line profiles from ``generate_line_predictions`` to an .xlsx workbook.

    There is one sheet per line/time. Each sheet has a header row followed by
    columns for the coordinate ('x' or 'y'), temperature, oxygen pressure and
    carbonyl area. openpyxl's default 'Sheet' is removed before saving.
    """
    workbook = Workbook()

    # (sheet title, profile tuple, coordinate column name); lookups happen up front.
    layout = [
        ('Horizontal t=0', predictions['horizontal_0'], 'x'),
        ('Horizontal t=1', predictions['horizontal_1'], 'x'),
        ('Vertical t=0', predictions['vertical_0'], 'y'),
        ('Vertical t=1', predictions['vertical_1'], 'y'),
    ]

    for title, profile, axis_name in layout:
        ws = workbook[title] if title in workbook.sheetnames else workbook.create_sheet(title)

        frame = pd.DataFrame({
            axis_name: profile[0].flatten(),
            'Temperature (K)': profile[1].flatten(),
            'Oxygen Pressure (Pa)': profile[2].flatten(),
            'Carbonyl Area': profile[3].flatten(),
        })

        for row in dataframe_to_rows(frame, index=False, header=True):
            ws.append(row)

    if 'Sheet' in workbook.sheetnames:
        workbook.remove(workbook['Sheet'])

    workbook.save(filename)
    print(f"Excel file '{filename}' has been created.")


# Main execution
if __name__ == "__main__":
    # Script entry point:
    #   1. train the PINN with the 3-stage curriculum;
    #   2. fit a Gaussian process to its stage-3 PDE residuals;
    #   3. print metrics, export line profiles to Excel and save the final heatmaps.

    # Create coupled PINN model
    coupled_pinn_model = CoupledPINN()
    
    # Train the model using curriculum learning and get the losses
    # (100000 epochs per stage x 3 stages; heatmaps saved every 2000 epochs).
    trained_model, all_losses = train_coupled_pinn_curriculum(coupled_pinn_model, epochs_per_stage=100000, plot_every=2000)
    
    # Plot the losses
    plot_losses(all_losses)
    
    # Generate validation data: fresh uniform samples over the full domain and time window.
    n_points = 1000
    x_val, y_val, t_val = generate_collocation_points(n_points, L_x, L_y, T)
    X_val = torch.cat([x_val, y_val, t_val], dim=1)
    
    # PINN predictions on validation data, unscaled back to the physical ranges
    with torch.no_grad():
        pinn_pred = trained_model(X_val, stage=3).numpy()
        T_pinn = unscale_var(pinn_pred[:, 0], T_min, T_max)
        Pb_pinn = unscale_var(pinn_pred[:, 1], Pb_min, Pb_max)
        CAb_pinn = unscale_var(pinn_pred[:, 2], CAb_min, CAb_max)
    
    # Calculate PDE residuals for GP training: one column each for the T, Pb and CAb residuals.
    res_T, res_Pb, res_CAb = pde_residual(trained_model, x_val, y_val, t_val, stage=3)
    residuals = torch.cat([res_T, res_Pb, res_CAb], dim=1).detach().numpy()
    
    # Preprocess the input data and residuals: standardize inputs and residual targets
    # (hence normalize_y=False below).
    scaler_X = StandardScaler()
    scaler_y = StandardScaler()

    X_val_scaled = scaler_X.fit_transform(X_val.numpy())
    residuals_scaled = scaler_y.fit_transform(residuals)

    # Anisotropic RBF (one length scale per input: x, y, t) plus a learned white-noise term.
    kernel = 1.0 * RBF(length_scale=[1.0] * X_val_scaled.shape[1]) + WhiteKernel(noise_level=1e-5)
    
    # Train GP on PDE residuals with optimized settings. A single kernel is fitted jointly
    # to all three residual columns.
    gp = GaussianProcessRegressor(
        kernel=kernel,
        n_restarts_optimizer=5,
        normalize_y=False,  # We're manually scaling
        random_state=42,
        alpha=1e-10,
        optimizer='fmin_l_bfgs_b',
    )
    
    gp.fit(X_val_scaled, residuals_scaled)
    
    # Predict residuals using GP.
    # Same data and same scaler as X_val_scaled, so the GP is evaluated on its own
    # training points.
    X_val_full_scaled = scaler_X.transform(X_val.numpy())
    residuals_gp_scaled, std_gp_scaled = gp.predict(X_val_full_scaled, return_std=True)
    
    # Unscale the predictions back to residual units
    residuals_gp = scaler_y.inverse_transform(residuals_gp_scaled)
    # NOTE(review): assumes std_gp_scaled is (n, 3) per-target std (multi-output return_std);
    # verify for the installed scikit-learn version. std_gp is not used below.
    std_gp = std_gp_scaled * scaler_y.scale_
    
    # Combine PINN and GP predictions.
    # NOTE(review): this subtracts PDE residuals (e.g. T_t - alpha*laplace(T)) from the field
    # values themselves; the quantities do not share units. Confirm the intended correction.
    T_combined = T_pinn - residuals_gp[:, 0]
    Pb_combined = Pb_pinn - residuals_gp[:, 1]
    CAb_combined = CAb_pinn - residuals_gp[:, 2]
    
    # Print GP kernel parameters
    print("\nOptimized GP Kernel:")
    print(gp.kernel_)
    
    # Calculate error metrics.
    # NOTE(review): these "MSE" values measure deviation from the constant initial values
    # (T_init, Pb_init, CAb_init), not error against a reference solution.
    mse_T_pinn = np.mean((T_pinn - T_init)**2)
    mse_T_combined = np.mean((T_combined - T_init)**2)
    mse_Pb_pinn = np.mean((Pb_pinn - Pb_init)**2)
    mse_Pb_combined = np.mean((Pb_combined - Pb_init)**2)
    mse_CAb_pinn = np.mean((CAb_pinn - CAb_init)**2)
    mse_CAb_combined = np.mean((CAb_combined - CAb_init)**2)
    
    print(f"MSE Temperature PINN: {mse_T_pinn:.6f}")
    print(f"MSE Temperature PINN+GP: {mse_T_combined:.6f}")
    print(f"MSE Oxygen Pressure PINN: {mse_Pb_pinn:.6f}")
    print(f"MSE Oxygen Pressure PINN+GP: {mse_Pb_combined:.6f}")
    print(f"MSE Carbonyl Area PINN: {mse_CAb_pinn:.6f}")
    print(f"MSE Carbonyl Area PINN+GP: {mse_CAb_combined:.6f}")


    # Generate predictions along the domain centre lines at t=0 and t=T
    predictions = generate_line_predictions(trained_model, L_x, L_y, T)
    
    # Create Excel file with predictions (written to predictions.xlsx by default)
    create_excel_file(predictions)

    
    # Create and plot final heatmaps (PINN only, no GP correction)
    X, Y, results = create_heatmap_data(trained_model, [0, T], stage=3)
    plot_heatmaps(X, Y, results, stage=3, epoch="Final")

import numpy as np
import matplotlib.pyplot as plt

def create_and_plot_pinn_gp_heatmaps(trained_model, gp, scaler_X, scaler_y, t_values, stage, nx=100, ny=100):
    """Plot GP-corrected PINN fields on an nx-by-ny grid and save them as a PNG.

    For each time in ``t_values``:
    - the PINN output is unscaled to physical range;
    - the GP's standardized residual prediction is inverse-transformed with ``scaler_y``;
    - ``PINN - GP`` is drawn for T, Pb and CAb, one row per time.

    The figure is saved to ``heatmap_pinn_gp_stage_{stage}.png``.

    Args:
        trained_model: network called as ``model(inputs, stage)``; must return 3 columns.
        gp: fitted regressor exposing ``predict`` on standardized (x, y, t) inputs.
        scaler_X, scaler_y: the fitted scalers used to train ``gp``.
        t_values: times to plot. The figure has max(2, len(t_values)) rows.
        stage: passed to the model and used in the title and filename.
        nx, ny: grid resolution along x and y.
    """
    t_values = list(t_values)

    x = np.linspace(0, L_x, nx)
    y = np.linspace(0, L_y, ny)
    # 'ij' indexing: flat index i * ny + j corresponds to (x[i], y[j]), so reshape(nx, ny)
    # gives an [x, y] array and .T the [y, x] layout imshow expects (same convention as
    # create_heatmap_data). The default 'xy' indexing drew a transposed image, and a
    # scrambled one whenever nx != ny.
    X, Y = np.meshgrid(x, y, indexing='ij')
    xy = np.column_stack((X.ravel(), Y.ravel()))

    n_rows = max(2, len(t_values))
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 6 * n_rows))

    for t_idx, t in enumerate(t_values):
        input_data = np.hstack((xy, np.full((nx * ny, 1), t)))

        # PINN predictions
        with torch.no_grad():
            pinn_pred = trained_model(torch.tensor(input_data).float(), stage).numpy()

        # Unscale PINN predictions
        T_pinn = unscale_var(pinn_pred[:, 0], T_min, T_max)
        Pb_pinn = unscale_var(pinn_pred[:, 1], Pb_min, Pb_max)
        CAb_pinn = unscale_var(pinn_pred[:, 2], CAb_min, CAb_max)

        # GP corrections (only the predictive mean is needed)
        residuals_gp_scaled = gp.predict(scaler_X.transform(input_data))
        residuals_gp = scaler_y.inverse_transform(residuals_gp_scaled)

        # Combine PINN and GP predictions
        variables = [
            ("Temperature (K)", T_pinn - residuals_gp[:, 0]),
            ("Oxygen Pressure (Pa)", Pb_pinn - residuals_gp[:, 1]),
            ("Carbonyl Area", CAb_pinn - residuals_gp[:, 2])
        ]

        for i, (title, var) in enumerate(variables):
            ax = axes[t_idx, i]
            im = ax.imshow(var.reshape(nx, ny).T, extent=[0, L_x, 0, L_y], origin='lower', cmap='turbo', aspect='auto')
            ax.set_title(f'PINN+GP {title} (t={t})')
            plt.colorbar(im, ax=ax, label=title)
            ax.set_xlabel('x')
            ax.set_ylabel('y')

    plt.suptitle(f'PINN+GP Predictions, Stage {stage}')
    plt.tight_layout()
    plt.savefig(f'heatmap_pinn_gp_stage_{stage}.png')
    plt.close()

# Usage: plot GP-corrected heatmaps from the objects created in the __main__ section.
# Guarded so that importing this module does not fail on the undefined trained_model / gp.
if __name__ == "__main__":
    create_and_plot_pinn_gp_heatmaps(trained_model, gp, scaler_X, scaler_y, [0, T], stage=3)
