# Graph Deep Fakes - Training Pipeline

This notebook trains a **Graph VAE + Latent Diffusion** model to generate synthetic FEA solution fields (steady heat-equation solutions) on a triangular mesh around a cylinder.

## Architecture
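1. **Graph Autoencoder (VAE)**: Compresses each per-node field (6523 values at the default mesh resolution) into a 64-dim latent code, combining a projection onto Laplacian eigenvectors with a direct spatial projection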
2. **Graph-Aware Diffusion (GAD)**: score-based diffusion in the latent space with a per-dimension (graph-aware) noise schedule; the denoiser is built from polynomial graph filters
3. **Generation**: Sample latent codes from diffusion → decode to mesh field

## 1. Setup and Imports

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.tri import Triangulation
from scipy.sparse import csr_matrix, diags, linalg as splinalg, save_npz, load_npz
from scipy.spatial import Delaunay
import time
import os
from itertools import product

# Check for PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Report the PyTorch build and pick the compute device: the first visible
# CUDA GPU when one is available, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print(f"PyTorch {torch.__version__}")
print(f"Using device: {device}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# --- Dataset generation ----------------------------------------------------
DATASET_DIR = "dataset"
MESH_RESOLUTION = 0.005  # spacing of the background grid used by generate_mesh

# --- Parameter sweep (every combination is solved) -------------------------
DIFFUSIVITY_VALUES = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
SOURCE_VALUES = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 7.0]
# The cylinder is fixed at the origin (0, 0): the mesh hole and the boundary
# conditions in solve_heat_equation both assume this location.

# --- Domain geometry --------------------------------------------------------
CYLINDER_RADIUS = 0.05
DOMAIN_RADIUS = 0.15
# The rectangle extends 2.5x further toward +x than toward -x.
X_MIN = -DOMAIN_RADIUS
X_MAX = DOMAIN_RADIUS * 2.5
Y_MIN = -DOMAIN_RADIUS
Y_MAX = DOMAIN_RADIUS

# --- Model architecture -----------------------------------------------------
LATENT_DIM = 64
HIDDEN_DIM = 256
N_ENCODER_LAYERS = 4
# NOTE(review): GraphVAE is built with N_ENCODER_LAYERS for both halves; this
# value is only recorded in the saved config.
N_DECODER_LAYERS = 4

# --- Diffusion --------------------------------------------------------------
DIFFUSION_STEPS = 100
# NOTE(review): BETA_START / BETA_END are not referenced elsewhere in this
# notebook; GraphAwareDiffusion builds its own noise schedule.
BETA_START = 1e-4
BETA_END = 0.02

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
AE_EPOCHS = 500
DIFF_EPOCHS = 800

OUTPUT_DIR = "training_output"
for _dir in (OUTPUT_DIR, DATASET_DIR):
    os.makedirs(_dir, exist_ok=True)

n_total_samples = len(DIFFUSIVITY_VALUES) * len(SOURCE_VALUES)
print(f"Will generate {n_total_samples} FEA solutions")

## 3. Generate Mesh

In [None]:
def generate_mesh(x_min, x_max, y_min, y_max, cx, cy, r, resolution):
    """Generate 2D triangular mesh for rectangular domain with circular obstacle.

    A regular grid with spacing ~``resolution`` covers the rectangle. Grid points
    within 1.1 * r of the centre are dropped. Three refined rings of points (radii
    r, 1.15 r and 1.3 r, spaced ~resolution / 2 along the arc) are added. The point
    cloud is Delaunay-triangulated, and triangles whose centroid lies inside the
    circle are discarded.

    Args:
        x_min, x_max, y_min, y_max: rectangle bounds.
        cx, cy: circle centre.
        r: circle radius.
        resolution: target grid spacing.

    Returns:
        (points, triangles): ``points`` is an (n, 2) array of coordinates.
        ``triangles`` is an (m, 3) integer array of point indices.
    """
    def _radius(xy):
        # Distance of every row of xy from the circle centre.
        return np.sqrt((xy[:, 0] - cx)**2 + (xy[:, 1] - cy)**2)

    grid_x = np.linspace(x_min, x_max, int((x_max - x_min) / resolution) + 1)
    grid_y = np.linspace(y_min, y_max, int((y_max - y_min) / resolution) + 1)
    gx, gy = np.meshgrid(grid_x, grid_y)
    grid = np.column_stack([gx.ravel(), gy.ravel()])

    # Clear a slightly enlarged hole so grid points do not crowd the rings.
    chunks = [grid[_radius(grid) > r * 1.1]]

    for factor in (1.0, 1.15, 1.3):
        count = int(2 * np.pi * r * factor / (resolution * 0.5))
        angles = np.linspace(0, 2*np.pi, count, endpoint=False)
        ring_r = r * factor
        chunks.append(np.column_stack([
            cx + ring_r * np.cos(angles),
            cy + ring_r * np.sin(angles),
        ]))
    all_points = np.vstack(chunks)

    simplices = Delaunay(all_points).simplices
    # Delaunay fills the convex hull, so drop the triangles covering the obstacle.
    keep = _radius(all_points[simplices].mean(axis=1)) > r
    return all_points, simplices[keep]

print("Generating mesh...")
start = time.time()
# The cylinder is centred at the origin; solve_heat_equation and the plots assume this.
points, triangles = generate_mesh(
    x_min=X_MIN, x_max=X_MAX, y_min=Y_MIN, y_max=Y_MAX,
    cx=0.0, cy=0.0, r=CYLINDER_RADIUS, resolution=MESH_RESOLUTION,
)
n_nodes, n_elements = len(points), len(triangles)
print(f"Mesh: {n_nodes} nodes, {n_elements} elements ({time.time()-start:.2f}s)")

# Quick visual check of the triangulation around the obstacle.
fig, ax = plt.subplots(figsize=(12, 4))
ax.triplot(points[:, 0], points[:, 1], triangles, linewidth=0.2, color='blue')
circle = plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='cyan', ec='black')
ax.add_patch(circle)
ax.set_aspect('equal')
ax.set_title(f'Mesh: {n_nodes} nodes, {n_elements} triangles')
plt.show()

## 4. Build Graph Laplacian

In [None]:
def build_cotangent_laplacian(points, triangles):
    """Build cotangent-weighted Laplacian matrix.

    For every triangle and each of its three edges (vi, vj), the edge receives
    0.5 * cot(angle at the opposite vertex vk). Each per-triangle contribution
    is clamped at zero, so obtuse angles add nothing rather than a negative
    weight. Contributions from triangles sharing an edge are summed, because
    duplicate COO entries are added when the sparse matrix is assembled.

    The computation is vectorised over all triangles: one NumPy pass per local
    corner ordering instead of a Python loop per triangle corner.

    Args:
        points: (n, 2) array of node coordinates.
        triangles: (m, 3) integer array of node indices; may be empty.

    Returns:
        (L, W): ``W`` is the symmetric CSR weight matrix. ``L = D - W`` is a sparse
        matrix, where ``D`` is the diagonal matrix of row sums of ``W``.
    """
    points = np.asarray(points, dtype=float)
    # reshape keeps an empty triangle list well-formed as a (0, 3) array.
    tris = np.asarray(triangles, dtype=np.intp).reshape(-1, 3)
    n = len(points)

    rows, cols, weights = [], [], []
    for i in range(3):
        j, k = (i + 1) % 3, (i + 2) % 3
        vi, vj, vk = tris[:, i], tris[:, j], tris[:, k]
        e1 = points[vi, :2] - points[vk, :2]
        e2 = points[vj, :2] - points[vk, :2]
        # |e1||e2| cos(theta) and |e1||e2| sin(theta); their ratio is cot(theta).
        cos_term = e1[:, 0] * e2[:, 0] + e1[:, 1] * e2[:, 1]
        sin_term = np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
        # 1e-10 guards degenerate (zero-area) triangles.
        cot_weight = np.maximum(cos_term / (sin_term + 1e-10) * 0.5, 0.0)

        rows.extend([vi, vj])
        cols.extend([vj, vi])
        weights.extend([cot_weight, cot_weight])

    W = csr_matrix(
        (np.concatenate(weights), (np.concatenate(rows), np.concatenate(cols))),
        shape=(n, n),
    )
    D = diags(np.asarray(W.sum(axis=1)).ravel())
    L = D - W
    return L, W

print("Building Laplacian...")
start = time.time()
L, W = build_cotangent_laplacian(points, triangles)
print(f"Laplacian: {L.shape}, nnz={L.nnz} ({time.time()-start:.2f}s)")

# Lowest-frequency eigenpairs via shift-invert around ~0. The tiny diagonal
# shift keeps L_reg from being exactly singular.
print("Computing eigenvectors...")
start = time.time()
n_eigs = 50
L_reg = L + diags(np.full(n_nodes, 1e-8))
raw_vals, raw_vecs = splinalg.eigsh(
    L_reg, k=n_eigs, which='LM', sigma=1e-6, tol=1e-4, maxiter=5000
)
raw_vals, raw_vecs = np.real(raw_vals), np.real(raw_vecs)
order = np.argsort(raw_vals)
eigenvalues = raw_vals[order]
eigenvectors = raw_vecs[:, order]
print(f"Computed {n_eigs} eigenvectors ({time.time()-start:.2f}s)")

# Show the first eight modes on the mesh.
triang = Triangulation(points[:, 0], points[:, 1], triangles)
fig, axes = plt.subplots(2, 4, figsize=(14, 6))
fig.suptitle('First 8 Laplacian Eigenvectors (Manifold Harmonics)', fontweight='bold')
for mode, ax in enumerate(axes.flat):
    ax.tripcolor(triang, eigenvectors[:, mode], cmap='RdBu', shading='gouraud')
    ax.add_patch(plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='gray', ec='black'))
    ax.set_aspect('equal')
    ax.set_title(f'Mode {mode}: λ={eigenvalues[mode]:.4f}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## 5. Generate FEA Solutions

In [None]:
def solve_heat_equation(points, L, diffusivity, source_strength):
    """Solve steady heat equation with parameter-dependent source and BCs.

    Solves ``(diffusivity * L + 1e-6 I) u = f`` with Dirichlet values imposed by
    row replacement. The source ``f`` is a Gaussian-like ring around the cylinder
    plus a sinusoidal term on the x > 0 half. The fixed values are set on the
    rectangle edges (within 1.5 * MESH_RESOLUTION) and on nodes within
    1.3 * CYLINDER_RADIUS of the origin. Where these regions overlap, later
    assignments win (left, right, top, bottom, cylinder).

    Args:
        points: (n, 2) node coordinates.
        L: (n, n) sparse Laplacian.
        diffusivity: scale of L; also sets the cylinder boundary value.
        source_strength: amplitude of the source; also sets the top/bottom values.

    Returns:
        (n,) solution min-max scaled to [0, 1]. A constant solution is returned
        unscaled; NaN/inf entries are replaced before scaling.
    """
    x, y = points[:, 0], points[:, 1]
    n = len(points)
    r = CYLINDER_RADIUS
    # Cylinder fixed at the origin, matching the mesh geometry.
    centre_x, centre_y = 0.0, 0.0
    dist = np.sqrt((x - centre_x)**2 + (y - centre_y)**2)

    # Right-hand side: localised ring source plus a sinusoidal term downstream.
    near = dist < r * 3
    rhs = np.zeros(n)
    rhs[near] = source_strength * np.exp(-(dist[near] - r)**2 / (r**2))
    rhs += source_strength * 0.3 * np.sin(np.pi * y / Y_MAX) * (x > 0)

    A = diffusivity * L + 1e-6 * diags(np.ones(n))

    margin = MESH_RESOLUTION * 1.5
    left = x < X_MIN + margin
    right = x > X_MAX - margin
    top = y > Y_MAX - margin
    bottom = y < Y_MIN + margin
    on_cylinder = dist < r * 1.3

    u_bc = np.zeros(n)
    for mask, value in (
        (left, 1.0),
        (right, 0.0),
        (top, 0.5 + 0.3 * source_strength / 5),
        (bottom, 0.5 - 0.3 * source_strength / 5),
        (on_cylinder, 0.2 + 0.6 * diffusivity),
    ):
        u_bc[mask] = value
    fixed = left | right | top | bottom | on_cylinder

    # Replace each constrained row by the identity row and pin its RHS value.
    A_lil = A.tolil()
    for node in np.flatnonzero(fixed):
        A_lil[node, :] = 0
        A_lil[node, node] = 1.0
    rhs[fixed] = u_bc[fixed]

    u = splinalg.spsolve(A_lil.tocsr(), rhs)
    u = np.nan_to_num(u, nan=0.5, posinf=1.0, neginf=0.0)

    lo, hi = u.min(), u.max()
    return (u - lo) / (hi - lo) if hi > lo else u

In [None]:
# Solve the heat problem for every (diffusivity, source) pair in the sweep.
param_combinations = list(product(DIFFUSIVITY_VALUES, SOURCE_VALUES))
n_samples = len(param_combinations)

print(f"Generating {n_samples} FEA solutions...")
solutions = np.zeros((n_samples, n_nodes))
parameters = np.zeros((n_samples, 2))  # columns: [diffusivity, source]

start_time = time.time()
for idx, (k_val, q_val) in enumerate(param_combinations):
    solutions[idx] = solve_heat_equation(points, L, k_val, q_val)
    parameters[idx] = (k_val, q_val)

    done = idx + 1
    if idx == 0 or done % 10 == 0:
        elapsed = time.time() - start_time
        # Linear extrapolation from the average time per solve so far.
        remaining = (n_samples - done) / (done / elapsed)
        print(f"  [{idx+1:3d}/{n_samples}] {elapsed:.1f}s elapsed, ~{remaining:.1f}s remaining")

print(f"\nCompleted {n_samples} solutions in {time.time()-start_time:.1f}s")

In [None]:
# Visualize nine solutions spread across the parameter sweep.
fig, axes = plt.subplots(3, 3, figsize=(14, 11))
fig.suptitle(f'Sample FEA Solutions ({n_samples} total)', fontsize=14, fontweight='bold')

sample_indices = [0, n_samples//4, n_samples//2, 3*n_samples//4, n_samples-1,
                  n_samples//8, 3*n_samples//8, 5*n_samples//8, 7*n_samples//8]
fill_levels = np.linspace(0, 1, 30)

for ax, idx in zip(axes.flat, sample_indices):
    k_val, q_val = parameters[idx]
    field = solutions[idx]
    ax.tricontourf(triang, field, levels=fill_levels, cmap='inferno', extend='both')
    ax.tricontour(triang, field, levels=10, colors='white', linewidths=0.3, alpha=0.5)
    ax.add_patch(plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='cyan', ec='black', lw=1))
    ax.set_aspect('equal')
    ax.set_title(f'k={k_val:.1f}, Q={q_val:.1f}')
    ax.axis('off')

# Shared colour bar in its own axis to the right of the grid.
fig.subplots_adjust(right=0.92)
cbar_ax = fig.add_axes([0.94, 0.15, 0.02, 0.7])
sm = plt.cm.ScalarMappable(cmap='inferno', norm=plt.Normalize(0, 1))
fig.colorbar(sm, cax=cbar_ax, label='Temperature')
plt.show()

In [None]:
# Persist the mesh, spectral basis, sparse operators and solutions.
np.savez(
    f"{DATASET_DIR}/mesh.npz",
    points=points,
    triangles=triangles,
    eigenvalues=eigenvalues,
    eigenvectors=eigenvectors,
)
for _stem, _matrix in (("laplacian", L), ("adjacency", W)):
    save_npz(f"{DATASET_DIR}/{_stem}.npz", _matrix)
np.savez(
    f"{DATASET_DIR}/solutions.npz",
    solutions=solutions,
    parameters=parameters,
    param_names=['diffusivity', 'source'],
)

print(f"Saved dataset to {DATASET_DIR}/")

## 6. Define Model Architectures

In [None]:
class SpectralGraphEncoder(nn.Module):
    """Encode a nodal field into the mean and log-variance of a Gaussian latent.

    Two views of the input are embedded and concatenated: its coefficients in the
    supplied eigenvector basis (``x @ eigenvectors``) and the raw nodal values.
    A stack of Linear/LayerNorm/GELU/Dropout blocks then maps the concatenation
    to ``latent_dim`` means and log-variances.
    """

    def __init__(self, n_nodes, n_eigenvectors, latent_dim, hidden_dim, n_layers):
        super().__init__()
        self.spectral_proj = nn.Linear(n_eigenvectors, hidden_dim)
        self.spatial_proj = nn.Linear(n_nodes, hidden_dim)

        # Only the first block sees the concatenated [spectral, spatial] width.
        widths = [2 * hidden_dim if block == 0 else hidden_dim for block in range(n_layers)]
        modules = []
        for width in widths:
            modules += [nn.Linear(width, hidden_dim), nn.LayerNorm(hidden_dim),
                        nn.GELU(), nn.Dropout(0.1)]
        self.layers = nn.Sequential(*modules)
        self.to_latent_mu = nn.Linear(hidden_dim, latent_dim)
        self.to_latent_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, eigenvectors):
        """Return ``(mu, logvar)`` for fields ``x`` (rows = samples, cols = nodes)."""
        spectral_coeffs = x @ eigenvectors
        features = torch.cat(
            (self.spectral_proj(spectral_coeffs), self.spatial_proj(x)), dim=-1
        )
        hidden = self.layers(features)
        return self.to_latent_mu(hidden), self.to_latent_logvar(hidden)

class SpectralGraphDecoder(nn.Module):
    """Decode a latent vector back to a nodal field.

    The output is a spectral reconstruction plus a direct per-node residual
    scaled by 0.1. The spectral part consists of predicted coefficients
    projected back through ``eigenvectors.T``.
    """

    def __init__(self, n_nodes, n_eigenvectors, latent_dim, hidden_dim, n_layers):
        super().__init__()
        self.from_latent = nn.Linear(latent_dim, hidden_dim)
        self.layers = nn.Sequential(*[
            module
            for _ in range(n_layers)
            for module in (nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim),
                           nn.GELU(), nn.Dropout(0.1))
        ])
        self.to_spectral = nn.Linear(hidden_dim, n_eigenvectors)
        self.to_spatial_residual = nn.Linear(hidden_dim, n_nodes)

    def forward(self, z, eigenvectors):
        """Map latents ``z`` to fields with one value per row of ``eigenvectors``."""
        hidden = self.layers(self.from_latent(z))
        smooth_part = self.to_spectral(hidden) @ eigenvectors.T
        return smooth_part + 0.1 * self.to_spatial_residual(hidden)


class GraphVAE(nn.Module):
    """Variational autoencoder built from SpectralGraphEncoder and SpectralGraphDecoder.

    Both halves use the same ``n_layers``. ``forward`` returns
    ``(reconstruction, mu, logvar, z)``. ``encode`` returns the posterior mean and
    ``decode`` maps latents back to fields.
    """

    def __init__(self, n_nodes, n_eigenvectors, latent_dim, hidden_dim, n_layers):
        super().__init__()
        shape_args = (n_nodes, n_eigenvectors, latent_dim, hidden_dim, n_layers)
        self.encoder = SpectralGraphEncoder(*shape_args)
        self.decoder = SpectralGraphDecoder(*shape_args)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def forward(self, x, eigenvectors):
        mu, logvar = self.encoder(x, eigenvectors)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, eigenvectors)
        return recon, mu, logvar, z

    def encode(self, x, eigenvectors):
        """Deterministic encoding: the posterior mean only."""
        return self.encoder(x, eigenvectors)[0]

    def decode(self, z, eigenvectors):
        return self.decoder(z, eigenvectors)

print("GraphVAE defined")

In [None]:
class SinusoidalPositionalEmbedding(nn.Module):
    """Map scalar (time) values to sinusoidal features of width ``dim``.

    Uses ``dim // 2`` geometrically spaced frequencies from 1 down to 1/10000 and
    concatenates ``[sin, cos]``. For odd ``dim`` a zero column is appended, so the
    output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` (any shape); a trailing feature axis of size ``dim`` is added."""
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero (inf/NaN frequencies) when half_dim == 1.
        emb_scale = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -emb_scale)
        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2:
            # Odd widths: pad so callers always get exactly `dim` features.
            emb = F.pad(emb, (0, 1))
        return emb


class GraphFilterTap(nn.Module):
    """One polynomial graph-filter layer: ``LayerNorm(SiLU(sum_k S^k x Theta_k))``.

    Args:
        in_channels: per-node input feature width.
        out_channels: per-node output feature width.
        filter_order: highest power K of the shift operator; ``theta`` holds
            K + 1 taps of shape (in_channels, out_channels).
    """

    def __init__(self, in_channels, out_channels, filter_order=4):
        super().__init__()
        self.filter_order = filter_order
        self.theta = nn.Parameter(torch.randn(filter_order + 1, in_channels, out_channels) * 0.01)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, S_powers):
        """Apply the filter.

        Args:
            x: (batch, n_nodes, in_channels) node features.
            S_powers: sequence whose k-th entry is the (n_nodes, n_nodes) matrix
                S^k, for k = 0..filter_order.

        Returns:
            (batch, n_nodes, out_channels) tensor.
        """
        out = x.new_zeros(x.shape[0], x.shape[1], self.theta.shape[2])
        for k in range(self.filter_order + 1):
            # Propagate features k hops over the graph, then mix channels with tap k.
            Sk_x = torch.matmul(S_powers[k].unsqueeze(0), x)
            # theta[k] is 2-D (in, out); a plain matmul replaces the previous 3-index
            # einsum, which did not match the operand's rank and raised at runtime.
            out = out + torch.matmul(Sk_x, self.theta[k])
        return self.norm(F.silu(out))


class PolynomialGraphFilterDenoiser(nn.Module):
    """Noise predictor that treats the latent vector as a signal on a learnable graph.

    Each of the ``latent_dim`` latent coordinates is a graph node carrying one
    scalar. Per-node features are [z_i, truncated time embedding]. They are lifted
    to 32 channels and passed through ``n_layers`` residual GraphFilterTap layers,
    which use powers of a learnable shift operator ``S``. The result is projected
    back to one value per node, and a learnable multiple of ``z`` is added.
    """

    def __init__(self, latent_dim, hidden_dim, filter_order=4, n_layers=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.filter_order = filter_order

        # Timestep -> hidden_dim features (sinusoidal encoding + 2-layer MLP).
        self.time_embed = nn.Sequential(
            SinusoidalPositionalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Learnable shift operator, initialised as a path graph over the latent
        # coordinates (0.5 on the first off-diagonals) plus 0.1 * identity.
        S_init = torch.zeros(latent_dim, latent_dim)
        for i in range(latent_dim - 1):
            S_init[i, i+1] = 0.5
            S_init[i+1, i] = 0.5
        S_init = S_init + 0.1 * torch.eye(latent_dim)
        self.S = nn.Parameter(S_init)

        hidden_channels = 32
        # Per-node input width = 1 (the z value) + (hidden_dim // latent_dim + 1)
        # time channels; forward() truncates the time embedding to that many.
        # NOTE(review): assumes hidden_dim // latent_dim + 1 <= hidden_dim, otherwise
        # the truncation yields fewer channels than input_proj expects.
        self.input_proj = nn.Linear(1 + hidden_dim // latent_dim + 1, hidden_channels)
        self.layers = nn.ModuleList([GraphFilterTap(hidden_channels, hidden_channels, filter_order)
                                     for _ in range(n_layers)])
        self.output_proj = nn.Sequential(nn.Linear(hidden_channels, hidden_channels),
                                         nn.SiLU(), nn.Linear(hidden_channels, 1))
        # Learnable skip from the noisy input straight to the prediction.
        self.residual_weight = nn.Parameter(torch.tensor(0.1))

    def _compute_S_powers(self):
        """Return ``[I, S, S^2, ..., S^filter_order]`` (filter_order + 1 matrices)."""
        S_powers = [torch.eye(self.latent_dim, device=self.S.device)]
        S_current = self.S
        for k in range(self.filter_order):
            S_powers.append(S_current.clone())
            S_current = torch.matmul(S_current, self.S)
        return S_powers

    def forward(self, z, t, n_steps):
        """Predict the noise in ``z``.

        Args:
            z: (batch, latent_dim) noisy latents (shape implied by the concat below).
            t: (batch,) integer timesteps; divided by ``n_steps`` before embedding.
            n_steps: total number of diffusion steps.

        Returns:
            (batch, latent_dim) prediction.
        """
        # NOTE(review): batch_size is unused.
        batch_size = z.shape[0]
        t_norm = t.float() / n_steps
        t_embed = self.time_embed(t_norm)
        # Broadcast the same time features to every latent "node", keeping only
        # the first hidden_dim // latent_dim + 1 channels.
        t_per_node = t_embed.unsqueeze(1).expand(-1, self.latent_dim, -1)
        t_per_node = t_per_node[..., :self.hidden_dim // self.latent_dim + 1]

        z_expanded = z.unsqueeze(-1)
        x = torch.cat([z_expanded, t_per_node], dim=-1)
        x = self.input_proj(x)

        # Residual graph-filter stack sharing one set of S powers.
        S_powers = self._compute_S_powers()
        for layer in self.layers:
            x = layer(x, S_powers) + x

        out = self.output_proj(x).squeeze(-1)
        return out + self.residual_weight * z


class GraphAwareDiffusion(nn.Module):
    """Latent diffusion with a per-dimension ("graph-aware") Ornstein-Uhlenbeck noising process.

    Each latent coordinate i has a fixed eigenvalue ``lambda_i``: the squares of
    ``linspace(0, 2, latent_dim)``. It is noised by the forward SDE

        dz_i = -c(t) * lambda_i * z_i dt + sqrt(c(t)) dW_i,     t in [0, 1],
        c(t) = c_min + k * t**alpha,    s(t) = integral_0^t c(u) du.

    The marginal on the grid ``t_j = linspace(0, 1, n_steps)[j]`` is

        z_t = exp(-s lambda) z_0 + sqrt((1 - exp(-2 s lambda)) / (2 lambda)) * eps,

    which ``decay`` and ``marginal_std`` tabulate. The denoiser is trained to
    predict ``eps``.

    Args:
        latent_dim: width of the latent vectors.
        hidden_dim: hidden width passed to the denoiser.
        n_steps: number of discrete timesteps.
        n_eigenvectors: accepted for API compatibility; not used here.
    """

    def __init__(self, latent_dim, hidden_dim, n_steps, n_eigenvectors):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_steps = n_steps

        eigenvalues = torch.linspace(0, 2, latent_dim) ** 2
        self.register_buffer('eigenvalues', eigenvalues)
        self._setup_gasde_schedule(n_steps)
        self.denoiser = PolynomialGraphFilterDenoiser(latent_dim, hidden_dim)

    def _setup_gasde_schedule(self, n_steps):
        """Tabulate c(t), s(t), the mean decay and the marginal std on the t-grid."""
        c_min, alpha, k = 0.1, 2.0, 2.0
        t = torch.linspace(0, 1, n_steps)
        c_t = c_min + k * (t ** alpha)
        self.register_buffer('c_t', c_t)
        # Closed-form integral of c(t) from 0 to t.
        s_t = c_min * t + k * (t ** (alpha + 1)) / (alpha + 1)
        self.register_buffer('s_t', s_t)

        decay = torch.exp(-s_t.unsqueeze(-1) * self.eigenvalues.unsqueeze(0))
        self.register_buffer('decay', decay)

        # Clamp so the lambda = 0 coordinate uses the small-lambda limit (variance ~ s).
        eigenvalues_safe = self.eigenvalues.clamp(min=1e-6)
        marginal_var = (1 - torch.exp(-2 * s_t.unsqueeze(-1) * eigenvalues_safe.unsqueeze(0))) / (2 * eigenvalues_safe.unsqueeze(0))
        self.register_buffer('marginal_std', torch.sqrt(marginal_var.clamp(min=1e-8)))

        # DDPM-style summary buffers; not used by add_noise/sample, but kept so
        # existing state_dicts still load.
        alphas_cumprod = (decay ** 2).mean(dim=-1).clamp(min=1e-6, max=1.0)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

    def forward(self, z_noisy, t):
        """Predict the noise ``eps`` contained in ``z_noisy`` at integer steps ``t``."""
        return self.denoiser(z_noisy, t, self.n_steps)

    def add_noise(self, z, t, noise=None):
        """Sample ``z_t = decay[t] * z + marginal_std[t] * noise``; returns ``(z_t, noise)``."""
        if noise is None:
            noise = torch.randn_like(z)
        decay_t = self.decay[t]
        std_t = self.marginal_std[t]
        return decay_t * z + std_t * noise, noise

    @torch.no_grad()
    def sample(self, n_samples, device):
        """Draw ``n_samples`` latents by integrating the reverse-time SDE.

        Sampling starts from N(0, marginal_std[-1]^2), the noise part of the
        terminal marginal. It then takes Euler-Maruyama steps backwards along the
        t-grid. For the forward SDE above (f = -c lambda z, g^2 = c), one reverse
        step of size dt is

            z <- z + (c * lambda * z + c * score) * dt + sqrt(c * dt) * xi,

        with ``score ~= -eps_pred / std_t``. Both drift terms are *added*: one
        undoes the forward decay and the other moves toward higher density. The
        last step (t index 0, where std is clamped to ~1e-4) uses the Tweedie
        estimate ``(z - std_0 * eps_pred) / decay_0``. Dividing by the clamped
        std there would blow the sample up.

        Returns:
            (n_samples, latent_dim) tensor on ``device``.
        """
        lam = self.eigenvalues.unsqueeze(0)
        z = torch.randn(n_samples, self.latent_dim, device=device) * self.marginal_std[-1].unsqueeze(0)
        # Spacing of the torch.linspace(0, 1, n_steps) grid used for the schedule.
        dt = 1.0 / max(self.n_steps - 1, 1)

        for t in reversed(range(self.n_steps)):
            t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
            noise_pred = self.forward(z, t_batch)
            std_t = self.marginal_std[t].unsqueeze(0)

            if t > 0:
                c = self.c_t[t]
                score_estimate = -noise_pred / (std_t + 1e-6)
                z = (z + (c * lam * z + c * score_estimate) * dt
                     + torch.sqrt(c * dt) * torch.randn_like(z))
            else:
                z = (z - std_t * noise_pred) / self.decay[t].unsqueeze(0)

        return z

print("GraphAwareDiffusion defined")

## 7. Prepare Data for Training

In [None]:
n_eigenvectors = eigenvectors.shape[1]

# Move the dataset and the spectral basis to the training device once, as float32.
solutions_tensor, parameters_tensor, eigenvectors_tensor = (
    torch.FloatTensor(array).to(device) for array in (solutions, parameters, eigenvectors)
)

dataset = TensorDataset(solutions_tensor, parameters_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for _line in (
    f"Solutions: {solutions_tensor.shape}",
    f"Eigenvectors: {eigenvectors_tensor.shape}",
    f"DataLoader: {len(dataloader)} batches",
):
    print(_line)

## 8. Initialize Models

In [None]:
# GraphVAE takes a single depth, so N_ENCODER_LAYERS sets both encoder and decoder.
vae = GraphVAE(n_nodes, n_eigenvectors, LATENT_DIM, HIDDEN_DIM, N_ENCODER_LAYERS).to(device)
diffusion = GraphAwareDiffusion(LATENT_DIM, HIDDEN_DIM, DIFFUSION_STEPS, n_eigenvectors).to(device)

n_vae_params, n_diff_params = (
    sum(p.numel() for p in model.parameters()) for model in (vae, diffusion)
)

print(f"GraphVAE: {n_vae_params:,} parameters")
print(f"GraphAwareDiffusion: {n_diff_params:,} parameters")
print(f"Total: {n_vae_params + n_diff_params:,} parameters")

## 9. Train VAE

In [None]:
optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE)
scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, AE_EPOCHS)
vae_ckpt = f"{OUTPUT_DIR}/vae_best.pt"
kl_weight = 0.001  # small KL term: reconstruction dominates the objective

vae_losses = []
best_loss = float('inf')

print(f"Training VAE for {AE_EPOCHS} epochs...")
start_time = time.time()

for epoch in range(AE_EPOCHS):
    vae.train()
    batch_losses = []

    for batch_x, _batch_params in dataloader:
        optimizer_vae.zero_grad()
        x_recon, mu, logvar, _z = vae(batch_x, eigenvectors_tensor)
        recon_loss = F.mse_loss(x_recon, batch_x)
        # KL(q(z|x) || N(0, I)), averaged over batch and latent dimensions.
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kl_weight * kl_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
        optimizer_vae.step()
        batch_losses.append(loss.item())

    scheduler_vae.step()
    avg_loss = sum(batch_losses) / len(dataloader)
    vae_losses.append(avg_loss)

    # Checkpoint whenever the mean training loss improves.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(vae.state_dict(), vae_ckpt)

    if epoch == 0 or (epoch + 1) % 50 == 0:
        elapsed = time.time() - start_time
        print(f"  Epoch {epoch+1:3d}/{AE_EPOCHS} | Loss: {avg_loss:.6f} | {elapsed:.1f}s")

print(f"VAE training complete. Best loss: {best_loss:.6f}")
vae.load_state_dict(torch.load(vae_ckpt))

In [None]:
# Plot VAE training curve (mean loss per epoch, log scale).
_fig, _ax = plt.subplots(figsize=(10, 4))
_ax.plot(vae_losses)
_ax.set_xlabel('Epoch')
_ax.set_ylabel('Loss')
_ax.set_title('VAE Training Loss')
_ax.set_yscale('log')
_ax.grid(True, alpha=0.3)
plt.show()

## 10. Encode to Latent Space

In [None]:
# Deterministic latent codes (posterior means) for every training field.
vae.eval()
with torch.no_grad():
    all_latents = vae.encode(solutions_tensor, eigenvectors_tensor)

latent_mean, latent_std = all_latents.mean().item(), all_latents.std().item()
print(f"Latent codes: {all_latents.shape}")
print(f"Mean: {latent_mean:.4f}, Std: {latent_std:.4f}")

latent_dataset = TensorDataset(all_latents)
latent_dataloader = DataLoader(latent_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Visualize latent space with PCA. `pca` and `latents_2d` are reused later to
# overlay synthetic samples.
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
latents_2d = pca.fit_transform(all_latents.cpu().numpy())

_fig, _ax = plt.subplots(figsize=(8, 8))
scatter = _ax.scatter(latents_2d[:, 0], latents_2d[:, 1],
                      c=parameters[:, 0], cmap='viridis', alpha=0.7)
_fig.colorbar(scatter, ax=_ax, label='Diffusivity')
_ax.set_xlabel('PC1')
_ax.set_ylabel('PC2')
_ax.set_title('Latent Space (colored by diffusivity)')
_ax.grid(True, alpha=0.3)
plt.show()

## 11. Train Diffusion Model

In [None]:
optimizer_diff = torch.optim.AdamW(diffusion.parameters(), lr=LEARNING_RATE)
scheduler_diff = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_diff, DIFF_EPOCHS)
diff_ckpt = f"{OUTPUT_DIR}/diffusion_best.pt"

diff_losses = []
best_diff_loss = float('inf')

print(f"Training Diffusion for {DIFF_EPOCHS} epochs...")
start_time = time.time()

for epoch in range(DIFF_EPOCHS):
    diffusion.train()
    batch_losses = []

    for (batch_z,) in latent_dataloader:
        optimizer_diff.zero_grad()
        # One uniformly random timestep per sample; the target is the injected noise.
        t = torch.randint(0, DIFFUSION_STEPS, (batch_z.size(0),), device=device)
        z_noisy, noise = diffusion.add_noise(batch_z, t)
        loss = F.mse_loss(diffusion(z_noisy, t), noise)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
        optimizer_diff.step()
        batch_losses.append(loss.item())

    scheduler_diff.step()
    avg_loss = sum(batch_losses) / len(latent_dataloader)
    diff_losses.append(avg_loss)

    # Checkpoint whenever the mean training loss improves.
    if avg_loss < best_diff_loss:
        best_diff_loss = avg_loss
        torch.save(diffusion.state_dict(), diff_ckpt)

    if epoch == 0 or (epoch + 1) % 50 == 0:
        elapsed = time.time() - start_time
        print(f"  Epoch {epoch+1:3d}/{DIFF_EPOCHS} | Loss: {avg_loss:.6f} | {elapsed:.1f}s")

print(f"Diffusion training complete. Best loss: {best_diff_loss:.6f}")
diffusion.load_state_dict(torch.load(diff_ckpt))

In [None]:
# Plot diffusion training curve (mean loss per epoch, log scale).
_fig, _ax = plt.subplots(figsize=(10, 4))
_ax.plot(diff_losses)
_ax.set_xlabel('Epoch')
_ax.set_ylabel('Loss')
_ax.set_title('Diffusion Training Loss')
_ax.set_yscale('log')
_ax.grid(True, alpha=0.3)
plt.show()

## 12. Generate Synthetic Samples

In [None]:
# Sample latents with reverse diffusion, then decode them to mesh fields.
for _model in (vae, diffusion):
    _model.eval()

N_SYNTHETIC = 16

print(f"Generating {N_SYNTHETIC} synthetic samples...")
with torch.no_grad():
    synthetic_latents = diffusion.sample(N_SYNTHETIC, device)
    synthetic_fields = vae.decode(synthetic_latents, eigenvectors_tensor).cpu().numpy()

print(f"Generated fields shape: {synthetic_fields.shape}")
print(f"Value range: [{synthetic_fields.min():.3f}, {synthetic_fields.max():.3f}]")

In [None]:
# Visualize synthetic samples, clipped to the [0, 1] range that
# solve_heat_equation normalises the training data to.
fig, axes = plt.subplots(4, 4, figsize=(14, 12))
fig.suptitle('Synthetic FEA Fields (Generated)', fontsize=14, fontweight='bold')
levels = np.linspace(0, 1, 25)

for i, ax in enumerate(axes.flatten()):
    ax.axis('off')
    # Leave surplus panels blank when fewer than 16 samples were generated
    # (previously this raised IndexError).
    if i >= len(synthetic_fields):
        continue
    field = np.clip(synthetic_fields[i], 0, 1)
    ax.tricontourf(triang, field, levels=levels, cmap='inferno', extend='both')
    ax.tricontour(triang, field, levels=8, colors='white', linewidths=0.3, alpha=0.5)
    circle = plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='cyan', ec='black', lw=0.5)
    ax.add_patch(circle)
    ax.set_aspect('equal')
    ax.set_title(f'Sample {i+1}')

plt.tight_layout()
plt.show()

## 13. Compare Real vs Synthetic

In [None]:
def _plot_field_panel(ax, field, title):
    """Draw one filled-contour field panel with the cylinder overlay."""
    ax.tricontourf(triang, field, levels=np.linspace(0, 1, 25), cmap='inferno', extend='both')
    ax.add_patch(plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='cyan', ec='black', lw=0.5))
    ax.set_aspect('equal')
    ax.set_title(title)
    ax.axis('off')


fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('Real vs Synthetic Comparison', fontsize=14, fontweight='bold')

# Top row: four distinct real samples (repeats only if the dataset has fewer than 4).
real_indices = np.random.choice(n_samples, size=4, replace=n_samples < 4)
for i, (ax, idx) in enumerate(zip(axes[0], real_indices)):
    _plot_field_panel(ax, solutions[idx], f'Real {i+1}')

# Bottom row: the first synthetic samples, clipped to the training range.
for i, (ax, field) in enumerate(zip(axes[1], synthetic_fields[:4])):
    _plot_field_panel(ax, np.clip(field, 0, 1), f'Synthetic {i+1}')

# axis('off') hides y-axis labels, so draw the row labels as text in axes coordinates.
for ax, label in ((axes[0, 0], 'REAL'), (axes[1, 0], 'SYNTHETIC')):
    ax.text(-0.08, 0.5, label, transform=ax.transAxes, rotation=90,
            va='center', ha='right', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Project synthetic latents with the PCA fitted on the real ones and overlay them.
synthetic_2d = pca.transform(synthetic_latents.cpu().numpy())

_fig, _ax = plt.subplots(figsize=(8, 8))
_ax.scatter(latents_2d[:, 0], latents_2d[:, 1], c='blue', alpha=0.5, s=20, label='Real')
_ax.scatter(synthetic_2d[:, 0], synthetic_2d[:, 1], c='red', alpha=0.8, s=100,
            marker='*', label='Synthetic')
_ax.set_xlabel('PC1')
_ax.set_ylabel('PC2')
_ax.set_title('Latent Space: Real vs Synthetic')
_ax.legend()
_ax.grid(True, alpha=0.3)
plt.show()

## 14. Save Models and Results

In [None]:
# Save the generated fields/latents and a final checkpoint with the model config.
np.savez(f"{OUTPUT_DIR}/synthetic_samples.npz",
         fields=synthetic_fields,
         latents=synthetic_latents.cpu().numpy())

model_config = dict(
    latent_dim=LATENT_DIM,
    hidden_dim=HIDDEN_DIM,
    n_encoder_layers=N_ENCODER_LAYERS,
    n_decoder_layers=N_DECODER_LAYERS,
    diffusion_steps=DIFFUSION_STEPS,
    n_nodes=n_nodes,
    n_eigenvectors=n_eigenvectors,
)
checkpoint = {
    'vae_state_dict': vae.state_dict(),
    'diffusion_state_dict': diffusion.state_dict(),
    'config': model_config,
}
torch.save(checkpoint, f"{OUTPUT_DIR}/models_final.pt")

print(f"Saved to {OUTPUT_DIR}/")
for _artifact in ("synthetic_samples.npz", "models_final.pt", "vae_best.pt", "diffusion_best.pt"):
    print(f"  - {_artifact}")

## 15. Generate More Samples (Interactive)

In [None]:
def generate_and_plot(n_samples=4):
    """Generate new synthetic samples on demand and plot them.

    Uses the notebook-level ``vae``, ``diffusion``, ``eigenvectors_tensor``,
    ``triang`` and ``device``. Panels are laid out up to four per row, and unused
    panels in the last row are hidden.

    Args:
        n_samples: number of fields to generate (>= 1).

    Returns:
        NumPy array of decoded (unclipped) fields, one row per sample.

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Disable the decoder's Dropout layers regardless of earlier training cells.
    vae.eval()
    diffusion.eval()
    with torch.no_grad():
        z = diffusion.sample(n_samples, device)
        fields = vae.decode(z, eigenvectors_tensor).cpu().numpy()

    cols = min(4, n_samples)
    rows = (n_samples + cols - 1) // cols
    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3.5*rows), squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if i >= n_samples:
            continue
        field = np.clip(fields[i], 0, 1)
        ax.tricontourf(triang, field, levels=25, cmap='inferno')
        circle = plt.Circle((0, 0), CYLINDER_RADIUS, fill=True, color='cyan', ec='black')
        ax.add_patch(circle)
        ax.set_aspect('equal')

    plt.tight_layout()
    plt.show()
    return fields

# Generate 4 new samples
new_samples = generate_and_plot(4)

In [None]:
summary_lines = [
    "Training complete!",
    "\nSummary:",
    f"  - Dataset: {n_samples} FEA solutions, {n_nodes} nodes each",
    f"  - VAE final loss: {vae_losses[-1]:.6f}",
    f"  - Diffusion final loss: {diff_losses[-1]:.6f}",
    f"  - Model parameters: {n_vae_params + n_diff_params:,}",
]
for _line in summary_lines:
    print(_line)