In [None]:
import sys, os
# Make the project root (one directory above notebooks/) importable so the
# `src2` package used below can be found.
ROOT = os.path.abspath(os.pardir)
root_already_listed = ROOT in sys.path
if not root_already_listed:
    sys.path.append(ROOT)

print(ROOT)


In [None]:
from src2.networks import get_all_networks
from src2.config import CONFIG

from src2.sampling import (
    sample_domain_points,
    sample_top_surface,
    sample_interface,
    sample_far_field
)

from src2.losses import total_loss

from src2.pde_residuals import (
    residual_layer_coupled,
    residual_halfspace
)

from src2.boundary_conditions import (
    top_surface_bc,
    interface_layer_halfspace,
    halfspace_far_field_bc
)


# Torch + Device

In [None]:
import torch
import torch.optim as optim

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)


# Build Models

In [None]:
# Instantiate the finite-layer and half-space networks, then move their
# parameters to DEVICE. Assuming both are torch.nn.Module instances,
# .to() returns the module itself, so rebinding is harmless.
model_layer, model_half = get_all_networks()

model_layer = model_layer.to(DEVICE)
model_half.to(DEVICE)


# Load Config & Geometry

In [None]:
# Geometry plus the material-parameter dicts of the two regions.
geom = CONFIG["GEOMETRY"]
params_layer, params_half = CONFIG["LAYER"], CONFIG["HALFSPACE"]

dispersion = []   # accumulator for [k, c] rows


# Define trainable phase velocity (c)

In [None]:
# Trainable phase velocity c, initialised to sqrt(mu44_0 / rho_0) of the
# LAYER material parameters and stored on DEVICE.
c0 = (params_layer["mu44_0"] / params_layer["rho_0"]) ** 0.5
c = torch.nn.Parameter(torch.tensor(c0, device=DEVICE))


# Optimizer

In [None]:
# One Adam optimizer over both networks' weights plus the phase velocity c,
# all sharing lr = 1e-2.
trainable = [*model_layer.parameters(), *model_half.parameters(), c]
optimizer = optim.Adam(trainable, lr=1e-2)


# Training Loop (Simple + Transparent)

In [None]:
# Training Loop (Dispersion): the wavenumbers k swept by the continuation
# loop, num_k evenly spaced values from k_min to k_max (all from CONFIG).
_geometry_cfg = CONFIG["GEOMETRY"]
k_values = torch.linspace(
    _geometry_cfg["k_min"],
    _geometry_cfg["k_max"],
    _geometry_cfg["num_k"],
)
k_values

In [None]:
# --------------------------------------------------
# Build models ONCE (outside the k-loop) so that training at each k starts
# from the weights and c reached at the previous k (continuation).
# --------------------------------------------------
N_EPOCHS = 30      # epochs per wavenumber (range(1, 30) previously ran only 29)
N_POINTS = 1000    # collocation points sampled per region, every epoch
LOG_EVERY = 10     # print progress every LOG_EVERY epochs
C_MIN = 5          # lower bound imposed on c after every optimizer step
LOSS_WEIGHTS = dict(w_pde=1.0, w_bc=50, w_int=100, w_far=100)

model_layer, model_half = get_all_networks()
model_layer.to(DEVICE)
model_half.to(DEVICE)

# Trainable phase velocity c, initialised to sqrt(mu44_0 / rho_0) of the layer.
c = torch.nn.Parameter(
    torch.tensor(
        (params_layer["mu44_0"] / params_layer["rho_0"]) ** 0.5,
        device=DEVICE,
    )
)

optimizer = optim.Adam(
    [
        {"params": model_layer.parameters(), "lr": 1e-3},
        {"params": model_half.parameters(), "lr": 1e-3},
        {"params": [c], "lr": 1e-3},
    ]
)

dispersion = []   # rows of [k, c], one appended after training at each k

# --------------------------------------------------
# k-loop (CONTINUATION)
# --------------------------------------------------
for k in k_values:
    k_val = k.item()   # plain Python float, used for total_loss and logging

    print(f"\nTraining for k = {k_val:.3f}")

    for epoch in range(1, N_EPOCHS + 1):

        # Fresh collocation points every epoch.
        z_layer, z_half = sample_domain_points(N_POINTS, geom)
        z_top = sample_top_surface(N_POINTS, geom)
        z_int = sample_interface(N_POINTS)
        z_far = sample_far_field(N_POINTS, geom)

        optimizer.zero_grad()

        loss, logs = total_loss(
            model_layer,
            model_half,
            z_layer,
            z_half,
            z_top,
            z_int,
            z_far,
            params_layer,
            params_half,
            k_val,
            c,
            **LOSS_WEIGHTS,
        )

        loss.backward()
        optimizer.step()

        # Keep c within its allowed range; done under no_grad so the clamp
        # itself is not recorded by autograd.
        with torch.no_grad():
            c.clamp_(min=C_MIN)

        if epoch % LOG_EVERY == 0:
            print(
                f"k={k_val:.2f} | "
                f"Epoch {epoch} | "
                f"Loss={loss.item():.3e} | "
                f"PDE={logs['pde']:.2e} | "
                f"BC={logs['bc_top']:.2e} | "
                f"INT={logs['interface']:.2e} | "
                f"FAR={logs['far']:.2e} | "
                f"c={c.item():.4f}    "
            )

    # The converged c at this k is one point of the dispersion curve.
    dispersion.append([k_val, c.item()])


In [None]:
# Persist both networks' state dicts together with the final phase velocity
# (detached from autograd and copied to the CPU).
checkpoint = {
    "model_layer": model_layer.state_dict(),
    "model_half": model_half.state_dict(),
    "c": c.detach().cpu(),
}
torch.save(checkpoint, "dispersion_pinn.pth")

print("Model saved.")


Choose the post-processing geometry (layer thickness H and half-space depth L)

In [None]:
import torch

DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Geometry needed for the post-processing grids.
H, L = geom["H"], geom["L"]   # layer thickness, half-space depth


Create a z-grid of points in each layer

In [None]:
# z-grids for plotting / post-processing: 200 evenly spaced depths per
# region, shaped as (200, 1) column vectors on DEVICE.
n_grid = 200
z_layer = torch.linspace(-H, 0.0, n_grid).unsqueeze(-1).to(DEVICE)   # layer:      z in [-H, 0]
z_half = torch.linspace(0.0, L, n_grid).unsqueeze(-1).to(DEVICE)     # half-space: z in [0, L]


Feed the grids to the models

In [None]:
# Evaluate both trained networks on the plotting grids without building an
# autograd graph. Previously only `scale = 1e-2` was inside the no_grad
# block; the model calls sat at column 0 and ran with gradient tracking.
with torch.no_grad():

    scale = 1e-2   # NOTE(review): not used in this cell — confirm whether still needed

    # Layer (complex amplitude): presumably column 0 = real, column 1 = imaginary part
    V_layer = model_layer(z_layer)
    V_R = V_layer[:, 0:1]
    V_I = V_layer[:, 1:2]

    # Half-space (real amplitude)
    V_half = model_half(z_half)



In [None]:
import matplotlib.pyplot as plt
import torch

# dispersion: list (or tensor) of rows [k, c] collected by the training loop.
# Build a separate (N, 2) tensor instead of overwriting `dispersion`:
# torch.as_tensor does not warn when its input is already a tensor, so the
# cell can be re-run, and reshape keeps two columns even for an empty result.
dispersion_t = torch.as_tensor(dispersion, dtype=torch.float32).reshape(-1, 2)
k_vals = dispersion_t[:, 0]
c_vals = dispersion_t[:, 1]


plt.figure(figsize=(6, 5))
plt.plot(k_vals, c_vals, 'o-', linewidth=2, markersize=6)

plt.xlabel("Wave number $k$")
plt.ylabel("Phase velocity $c$")
plt.title("Dispersion Relation")
plt.grid(True)

plt.show()


Plot Dispersion curves


In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

# Dispersion data from the training loop: rows of [k, c]. torch.as_tensor
# (unlike torch.tensor) does not warn if `dispersion` is already a tensor,
# and reshape keeps two columns even when no rows were collected.
dispersion_t = torch.as_tensor(dispersion, dtype=torch.float32).reshape(-1, 2)
k_vals = dispersion_t[:, 0]
c_vals = dispersion_t[:, 1]

# Layer material properties and thickness, read from the same CONFIG entries
# that initialised the trainable c, instead of hand-copied constants
# (previously mu44 = 5.3e9, rho = 3400, H = 1 — values that match the SAFE
# cell's half-space mu44h / rho2 rather than its layer).
mu44_layer = params_layer["mu44_0"]
rho_layer = params_layer["rho_0"]
H = geom["H"]   # layer thickness

# Reference velocity sqrt(mu44 / rho) of the layer (mu44, not mu66)
c_ref_layer = np.sqrt(mu44_layer / rho_layer)

print(f"Layer properties:")
print(f"  μ₄₄ = {mu44_layer:.2e} Pa")
print(f"  ρ = {rho_layer:.0f} kg/m³")
print(f"  c_ref = √(μ₄₄/ρ) = {c_ref_layer:.1f} m/s")
print(f"  H = {H}")

# Non-dimensionalize
kH = k_vals * H                     # Dimensionless wavenumber
c_norm = c_vals / c_ref_layer       # Dimensionless phase velocity

# Plot
plt.figure(figsize=(6, 5))
plt.plot(kH, c_norm, 'o-', linewidth=2, markersize=6)

plt.xlabel("Dimensionless wavenumber $kH$")
plt.ylabel("Dimensionless phase velocity $c/\\sqrt{\\mu_{44}/\\rho}$")
plt.title("Dispersion Relation (Non-dimensional)")
plt.grid(True)

plt.show()

# Print the data
print("\nDispersion data (non-dimensional):")
print("   kH        c/c_ref")
for kH_i, c_i in zip(kH.tolist(), c_norm.tolist()):
    print(f"   {kH_i:.3f}     {c_i:.3f}")

In [None]:
import matplotlib.pyplot as plt
import torch

# NOTE(review): this cell repeats the raw (dimensional) dispersion plot.
# dispersion: list (or tensor) of rows [k, c] collected by the training loop.
# Build a separate (N, 2) tensor instead of overwriting `dispersion`:
# torch.as_tensor does not warn when its input is already a tensor, so the
# cell can be re-run, and reshape keeps two columns even for an empty result.
dispersion_t = torch.as_tensor(dispersion, dtype=torch.float32).reshape(-1, 2)
k_vals = dispersion_t[:, 0]
c_vals = dispersion_t[:, 1]


plt.figure(figsize=(6, 5))
plt.plot(k_vals, c_vals, 'o-', linewidth=2, markersize=6)

plt.xlabel("Wave number $k$")
plt.ylabel("Phase velocity $c$")
plt.title("Dispersion Relation")
plt.grid(True)

plt.show()


In [None]:
# Display the wavenumbers k taken from the [k, c] dispersion rows above.
k_vals

In [None]:
# Display the trained phase velocities c paired with k_vals.
c_vals

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

# Convert the non-dimensional PINN data (kH, c_norm) to NumPy arrays;
# accepts either torch tensors or plain array-likes.
kH_np = kH.cpu().numpy() if hasattr(kH, 'cpu') else np.array(kH)
c_norm_np = c_norm.cpu().numpy() if hasattr(c_norm, 'cpu') else np.array(c_norm)

# Sort by kH: make_interp_spline requires increasing x values.
sort_idx = np.argsort(kH_np)
kH_sorted = kH_np[sort_idx]
c_norm_sorted = c_norm_np[sort_idx]

# Quadratic (k=2) interpolating spline — the old comment said "cubic".
SPLINE_DEGREE = 2

# Plot discrete points, plus a smooth curve when there are enough points
plt.figure(figsize=(6, 5))
plt.plot(kH_sorted, c_norm_sorted, 'o', label='Discrete points')

# A degree-k interpolating spline needs at least k + 1 points; with fewer,
# make_interp_spline raises, so only the discrete points are shown.
if kH_sorted.size > SPLINE_DEGREE:
    kH_smooth = np.linspace(kH_sorted.min(), kH_sorted.max(), 300)
    spline = make_interp_spline(kH_sorted, c_norm_sorted, k=SPLINE_DEGREE)
    c_norm_smooth = spline(kH_smooth)
    plt.plot(kH_smooth, c_norm_smooth, '-', label='Smooth curve', linewidth=2)

plt.xlabel("Dimensionless wavenumber $kH$")
plt.ylabel("Dimensionless phase velocity $c/\\sqrt{\\mu_{44}/\\rho}$")
plt.title("Dispersion Relation (Smooth)")
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# Save PINN dispersion data: the kH-sorted, non-dimensional curve as a
# two-column CSV whose first line is the plain header "kH,c_pinn"
# (comments="" stops numpy prefixing the header with "# ").
pinn_table = np.column_stack((kH_sorted, c_norm_sorted))
np.savetxt("PINN_dispersion.csv", pinn_table, delimiter=",",
           header="kH,c_pinn", comments="")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Parameters ---
# Element 1 (layer) uses the "l"/"1" quantities, element 2 (PML) the "h"/"2" ones.
# NOTE(review): units presumably Pa, kg/m^3 and m/s^2 — confirm.
mu44l, mu66l = 4.35e9, 5.0e9
mu44h, mu66h = 5.3e9, 6.47e9
rho1, rho2 = 9890, 3400
g = 9.81

H = 1
beta1, beta2 = 0.5, 0.5     # exponents of the exp(beta * z) weight factors
P1, P2 = 1e9, 1e9
phi = np.pi / 6
mu_eH02 = 0.3 * mu44l

l = 29 * H                  # length of the second (PML) element, J2 = l / 2
v = 7 + 7j                  # Fixed to match MATLAB (presumably the PML stretch)
c_s = np.sqrt(mu44l / rho1) # reference speed: omega = k * cbar * c_s
# --- Shape functions ---
def chi(xi):
    """Quadratic Lagrange shape functions on the reference element [-1, 1].

    The nodes sit at xi = -1, 0, +1; each function is 1 at its own node and
    0 at the other two.

    Args:
        xi: reference coordinate (scalar or NumPy array).

    Returns:
        np.ndarray stacking [N1, N2, N3] along a new leading axis.
    """
    n_left = xi * (xi - 1) / 2
    n_mid = 1 - xi**2
    n_right = xi * (xi + 1) / 2
    return np.array([n_left, n_mid, n_right])

def dchi(xi):
    """Derivatives dN/dxi of the three quadratic reference shape functions.

    Args:
        xi: reference coordinate (scalar or NumPy array).

    Returns:
        np.ndarray stacking [dN1, dN2, dN3] along a new leading axis.
    """
    d_left = (2 * xi - 1) / 2
    d_mid = -2 * xi
    d_right = (2 * xi + 1) / 2
    return np.array([d_left, d_mid, d_right])

# --- Gauss integration ---
# 3-point Gauss-Legendre rule on [-1, 1]: points 0, ±sqrt(3/5) with
# weights 8/9 and 5/9 (exact for polynomials up to degree 5).
gp = np.sqrt(3/5) * np.array([-1.0, 0.0, 1.0])
gw = np.array([5, 8, 5]) / 9

def assemble_KM(kH, cbar):
    """Assemble the reduced SAFE stiffness and mass matrices for one (kH, cbar).

    The model has two quadratic (3-node) elements sharing node 3:
      * element 1 (layer): reference map z = H * (xi + 1) / 2, Jacobian J1 = H / 2;
      * element 2 (PML):   Jacobian J2 = l / 2, with the complex factor v.
    Element integrals use the 3-point Gauss rule (gp, gw), and the shape
    functions come from chi / dchi. Material and load constants are read from
    module-level globals (mu44l, mu66l, mu44h, mu66h, rho1, rho2, g, H, l, v,
    beta1, beta2, P1, P2, phi, mu_eH02, c_s).

    Args:
        kH: dimensionless wavenumber k * H.
        cbar: phase velocity scaled by c_s (c = cbar * c_s).

    Returns:
        (K_reduced, M_reduced, omega): 4x4 complex stiffness and mass matrices
        with global DOF 5 removed, and the angular frequency omega = k * cbar * c_s.
    """
    k = kH / H
    omega = (kH / H) * cbar * c_s  # w = k * c
    
    # Initialize global matrices (5 nodes: 3 per element, node 3 shared)
    K_global = np.zeros((5, 5), dtype=complex)
    M_global = np.zeros((5, 5), dtype=complex)
    
    # === Element 1 ===
    J1 = H / 2
    Ke1 = np.zeros((3, 3), dtype=complex)
    Me1 = np.zeros((3, 3), dtype=complex)
    
    for ig in range(3):
        xi = gp[ig]
        wg = gw[ig]
        
        N = chi(xi)
        dN = dchi(xi)
        # exp(beta1 * z) evaluated at the physical point z = H * (xi + 1) / 2
        wfac1 = np.exp(beta1 * H * (xi + 1) / 2)
        
        for i in range(3):
            for j in range(3):
                term1 = -k**2 * mu66l * N[i] * N[j]
                # dN/dz = dN/dxi / J1, hence the 1 / J1**2 on derivative products
                term2 = (1 / J1**2) * mu44l * dN[i] * dN[j]
                term3 = (P1 / 2) * k**2 * N[i] * N[j]
                # phi-dependent term scaled by mu_eH02; the i*k*sin(2*phi)
                # cross term makes the element matrix complex
                term4 = mu_eH02 * (
                    -k**2 * (np.cos(phi))**2 * N[i] * N[j]
                    + 1j * k * np.sin(2 * phi)  * N[i] * dN[j]/ J1
                    + (np.sin(phi))**2 / J1**2 * dN[i] * dN[j]
                )
                
                Ke1[i, j] += (wfac1 * (term1 + term2 + term3+ term4)) * J1 * wg
                Me1[i, j] += rho1 * wfac1 * N[i] * N[j] * J1 * wg
    
    # === Element 2 (PML) ===
    J2 = l / 2
    Ke2 = np.zeros((3, 3), dtype=complex)
    Me2 = np.zeros((3, 3), dtype=complex)
    
    for ig in range(3):
        xi = gp[ig]
        wg = gw[ig]
        
        N = chi(xi)
        dN = dchi(xi)
        # NOTE(review): this element's Jacobian is J2 = l / 2, but wfac2 and
        # x_val map xi with H, not l — confirm against the MATLAB reference.
        wfac2 = np.exp(beta2 * H * (xi + 1) / 2)
        x_val = H * (xi + 1) / 2  # z-coordinate
        
        for i in range(3):
            for j in range(3):
                # Basic terms (derivatives scaled by 1 / (J2 * v))
                term1 = -k**2 * mu66h * N[i] * N[j]
                term2 = (1 / (J2**2 * v**2)) * mu44h * dN[i] * dN[j]
                term3 = (P2 / 2) * k**2 * N[i] * N[j]
                
                # Gravity terms (terms 4-7 from MATLAB)
                term4a = -(rho2 * g / 2) * x_val * N[i] * N[j] / (J2**2 * v**2)

                term4b = (k**2 * rho2 * g / 2) * x_val * N[i] * N[j]

                term5 = -(g / 2) * rho2 * ((beta2 * H * (xi + 1)) / 2 + 1) \
                        * N[i] * dN[j] / (J2 * v)

                term6 = -(g / 2) * rho2 * ((beta2 * H * (xi + 1)) / 2 + 1) \
                        * dN[i] * N[j] / (J2 * v)

                term7 = -(g / 2) * rho2 * (
                        (beta2**2 * H * (xi + 1)) / 2 + 2 * beta2
                    ) * N[i] * N[j]

                # Combine all terms; the element measure is v * J2 * wg
                Ke2[i, j] += (wfac2 * (term1 + term2 + term3 + term4a + term4b + term5 + term6 + term7)) * v * J2 * wg
                Me2[i, j] += rho2 * wfac2 * N[i] * N[j] * v * J2 * wg
    
    # === Global Assembly ===
    # Element 1: nodes 1,2,3
    K_global[0:3, 0:3] += Ke1
    M_global[0:3, 0:3] += Me1
    
    # Element 2: nodes 3,4,5 (index 2 receives contributions from both elements)
    K_global[2:5, 2:5] += Ke2
    M_global[2:5, 2:5] += Me2
    
    # === Apply boundary condition (remove DOF 5) ===
    # Dropping row and column 4 presumably imposes zero displacement at the
    # outer end of the PML element — confirm.
    K_reduced = np.delete(K_global, 4, axis=0)
    K_reduced = np.delete(K_reduced, 4, axis=1)
    M_reduced = np.delete(M_global, 4, axis=0)
    M_reduced = np.delete(M_reduced, 4, axis=1)
    
    return K_reduced, M_reduced, omega

def det_A(kH, cbar):
    """Real part of det(K - omega**2 * M) for the reduced SAFE system.

    Args:
        kH: dimensionless wavenumber k * H.
        cbar: phase velocity scaled by c_s.

    Returns:
        float: Re(det(A)). NOTE(review): only the real part is kept, so a zero
        of this value is necessary but not sufficient for det(A) = 0.
    """
    stiffness, mass, omega = assemble_KM(kH, cbar)
    dynamic_matrix = stiffness - omega**2 * mass
    return np.linalg.det(dynamic_matrix).real

# === Root finding for dispersion curve ===
# For every kH, evaluate det_A on a grid of trial phase velocities. The first
# sign change is taken as the first (physical) mode, refined by a linear
# estimate between the two bracketing grid points.
kH_range = np.linspace(0.5, 2.0, 100)
c_range = np.linspace(0.1, 3.0, 700)

kH_points = []
c_points = []

# The loop variables are kH_val / cbar, not kH / c, so they no longer
# overwrite the PINN tensor `kH` or the trainable phase velocity `c`.
for kH_val in kH_range:
    det_vals = np.array([det_A(kH_val, cbar) for cbar in c_range])

    for i in range(len(c_range) - 1):
        d_lo, d_hi = det_vals[i], det_vals[i + 1]
        if not (np.isfinite(d_lo) and np.isfinite(d_hi)):
            continue   # a NaN/inf determinant cannot bracket a root

        # sign change = det(A)=0 in between
        if np.sign(d_lo) != np.sign(d_hi):
            # Zero of the chord through (c_lo, d_lo) and (c_hi, d_hi).
            # np.interp(0, [d_lo, d_hi], ...) was used before, but it requires
            # increasing xp, so every + to - crossing silently returned c_lo.
            c_lo, c_hi = c_range[i], c_range[i + 1]
            c_root = c_lo - d_lo * (c_hi - c_lo) / (d_hi - d_lo)

            kH_points.append(kH_val)
            c_points.append(c_root)
            break   # first (physical) mode only


# === Plotting ===
plt.figure(figsize=(8, 6))
plt.plot(kH_points, c_points, 'k-', linewidth=2)
plt.xlabel('kH', fontsize=12)
plt.ylabel(r'$c/c_s$', fontsize=12)
plt.title('SAFE-PML Dispersion Relation', fontsize=14)
plt.grid(True, alpha=0.3)
plt.xlim([0.5, 2.0])
plt.ylim([0, 3.0])
plt.tight_layout()
plt.show()

# Optional verification plot: the determinant map with its zero contour.
# The grid uses its own names (kH_grid / c_grid) so it no longer overwrites
# the PINN `c_vals` tensor from the dispersion plotting cells.
PLOT_DET_CONTOUR = True
if PLOT_DET_CONTOUR:
    kH_grid = np.linspace(0.5, 2.0, 100)
    c_grid = np.linspace(0.1, 3.0, 100)
    KHH, CC = np.meshgrid(kH_grid, c_grid)

    Z = np.zeros_like(KHH)
    for i in range(KHH.shape[0]):
        for j in range(KHH.shape[1]):
            Z[i, j] = det_A(KHH[i, j], CC[i, j])

    plt.figure(figsize=(8, 6))
    plt.contour(KHH, CC, Z, levels=[0], colors='r', linewidths=2)
    plt.contourf(KHH, CC, np.log10(np.abs(Z) + 1), levels=50, cmap='RdBu_r', alpha=0.7)
    plt.colorbar(label='log10|det|')
    plt.xlabel('kH')
    plt.ylabel(r'$c/c_s$')
    plt.title('Dispersion determinant (contour at 0)')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Save FEM dispersion data: two columns with the plain header "kH,c_fem".
np.savetxt(
    "FEM_dispersion.csv",
    np.column_stack((kH_points, c_points)),
    delimiter=",",
    header="kH,c_fem",
    comments=""
)

print(f"FEM data saved to 'FEM_dispersion.csv' with {len(kH_points)} points")

In [None]:
# Save FEM (SAFE) dispersion data (the old comment wrongly said "PINN").
# NOTE(review): this rewrites FEM_dispersion.csv with the same kH_points /
# c_points the SAFE cell already saved; it is redundant unless those lists
# changed in between.
np.savetxt(
    "FEM_dispersion.csv",
    np.column_stack((kH_points, c_points)),
    delimiter=",",
    header="kH,c_fem",
    comments=""
)


In [None]:
# Overwrite PINN_dispersion.csv with the PINN points in their original
# (unsorted) training order, under the plain header "kH,c_pinn".
pinn_columns = np.column_stack((kH_np, c_norm_np))
np.savetxt("PINN_dispersion.csv", pinn_columns, delimiter=",",
           header="kH,c_pinn", comments="")


In [None]:
import numpy as np

# Reload the SAFE/FEM curve. ndmin=2 keeps a (rows, 2) shape even when the
# file holds a single data row; np.loadtxt would otherwise return a 1-D array.
fem = np.loadtxt("FEM_dispersion.csv", delimiter=",", skiprows=1, ndmin=2)

print("FEM array shape:", fem.shape)
print(fem)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

# ==============================
# Load data
# ==============================
# ndmin=2 keeps a (rows, 2) table even when a file holds only one data row,
# so the column slicing below always works.
pinn = np.loadtxt("PINN_dispersion.csv", delimiter=",", skiprows=1, ndmin=2)
fem  = np.loadtxt("FEM_dispersion.csv",  delimiter=",", skiprows=1, ndmin=2)

kH_pinn = pinn[:, 0]
c_pinn  = pinn[:, 1]

kH_fem = fem[:, 0]
c_fem  = fem[:, 1]

# NOTE(review): the PINN column is c / sqrt(mu44/rho) from the non-dimensional
# PINN plotting cell, while the FEM column is c / c_s with
# c_s = sqrt(mu44l / rho1) from the SAFE cell. Confirm both use the same
# reference speed before reading the comparison.

# ==============================
# Smooth PINNs curve
# ==============================
SPLINE_DEGREE = 2   # quadratic interpolating spline

sort_idx = np.argsort(kH_pinn)
kH_pinn_sorted = kH_pinn[sort_idx]
c_pinn_sorted  = c_pinn[sort_idx]

# make_interp_spline needs at least SPLINE_DEGREE + 1 points; skip the
# smooth curve instead of raising when there are fewer.
have_smooth = kH_pinn_sorted.size > SPLINE_DEGREE
if have_smooth:
    kH_smooth = np.linspace(kH_pinn_sorted.min(),
                            kH_pinn_sorted.max(), 300)
    spline = make_interp_spline(kH_pinn_sorted, c_pinn_sorted, k=SPLINE_DEGREE)
    c_pinn_smooth = spline(kH_smooth)

# ==============================
# Plot comparison
# ==============================
plt.figure(figsize=(7, 5))

# PINNs (dashed + squares). The legend entry sits on the dashed curve, or on
# the markers when no curve could be drawn.
if have_smooth:
    plt.plot(
        kH_smooth,
        c_pinn_smooth,
        '--',
        color='tab:blue',
        linewidth=2,
        label='Present PINNs'
    )
marker_kwargs = {} if have_smooth else {"label": "Present PINNs"}
plt.plot(
    kH_pinn,
    c_pinn,
    's',
    color='tab:blue',
    markersize=5,
    **marker_kwargs
)

# SAFE / FEM (circles)
plt.plot(
    kH_fem,
    c_fem,
    'o',
    color='red',
    markersize=6,
    markerfacecolor='none',
    label='SAFE (Theoretical)'
)

# ==============================
# Styling (match paper figure)
# ==============================
plt.xlabel(r'Dimensionless wavenumber $kH$', fontsize=12)
plt.ylabel(r'Dimensionless phase velocity $c/c_s$', fontsize=12)

plt.grid(True, linestyle=':', alpha=0.6)
plt.legend(fontsize=11, frameon=False)

plt.xlim([0.1, 2.0])
plt.ylim([0.0, 3.0])

plt.tight_layout()
plt.show()
