# Passive Scalar Transport in pygSQuiG

This notebook demonstrates:
1. Adding passive scalars to gSQG simulations
2. Configuring scalar diffusivity
3. Analyzing scalar mixing and transport
4. Multiple scalar species
5. Scalar variance spectra

## 1. Setup and Imports

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm

# pygSQuiG imports
from pygsquig.core.grid import make_grid, ifft2
from pygsquig.core.operators import compute_velocity_from_theta
from pygsquig.core.solver_with_scalars import gSQGSolverWithScalars
from pygsquig.scalars.diagnostics import (
    compute_scalar_variance,
    compute_scalar_variance_spectrum,
    compute_scalar_dissipation,
)
from pygsquig.utils.diagnostics import compute_energy_spectrum, compute_total_energy

# Banner: reaching this line means every import in this cell succeeded.
print("pygSQuiG passive scalar demonstration")

## 2. Physical Setup

Passive scalars are advected by the flow but don't affect it (one-way coupling).
The equation for each scalar $c_i$ is:

$$\frac{\partial c_i}{\partial t} + \mathbf{u} \cdot \nabla c_i = \kappa_i \nabla^2 c_i + S_i$$

where:
- $\mathbf{u}$ is the velocity from the gSQG dynamics
- $\kappa_i$ is the molecular diffusivity
- $S_i$ is an optional source term

In [None]:
# Spectral grid: N x N points on a square box of side L
N = 128
L = 2 * np.pi
grid = make_grid(N, L)

# gSQG dynamics for θ
alpha = 1.0      # alpha = 1 is the SQG case
nu_p = 1e-16     # hyperviscosity coefficient applied to θ
p = 8            # order of the hyperviscous operator

# Passive scalars, keyed by name. A single weakly diffusive 'dye' tracer:
# kappa is its diffusivity (the κ∇²c term) and source=None means no forcing.
passive_scalars = dict(
    dye=dict(kappa=1e-4, source=None),
)

# Echo the configuration
print("Configuration:")
print(f"  Grid: {N}×{N}, L={L:.2f}")
print(f"  gSQG: α={alpha}, ν_{p}={nu_p:.1e}")
print(f"  Scalar 'dye': κ={passive_scalars['dye']['kappa']:.1e}")

In [None]:
# Build the gSQG solver for θ, handing it the passive-scalar configuration
# defined above.
solver = gSQGSolverWithScalars(
    grid=grid, alpha=alpha, nu_p=nu_p, p=p,
    passive_scalars=passive_scalars,
)

print("Solver created with passive scalar capability!")
print("Number of scalars:", len(passive_scalars))

## 4. Initialize with Scalar Field

We'll initialize:
- Random turbulent flow for θ
- A Gaussian blob for the scalar (to watch it get mixed)

In [None]:
# Initial state: random θ field (seed=42) plus a Gaussian dye blob.
# NOTE(review): this initialization is repeated in several cells; re-running
# any of them resets `state` to the same seed-42 initial condition.
x, y = grid.x, grid.y
x0, y0 = L/2, L/2  # Center of domain
width = L/8        # Blob width (standard deviation of the Gaussian)

# Gaussian blob
r2 = (x - x0)**2 + (y - y0)**2
dye_init = jnp.exp(-r2 / (2 * width**2))

# A single initialize() call both seeds θ and attaches the dye field.
scalar_init = {'dye': dye_init}
state = solver.initialize(seed=42, scalar_init=scalar_init)

print("Initial state created with scalar field")
# `state` is indexed like a mapping, consistent with the other cells that
# read state['theta_hat'] / state['scalar_state'].
print(f"Initial dye variance: {float(compute_scalar_variance(state['scalar_state'].scalars['dye'])):.6f}")

In [None]:
# Initial state: random θ field (seed=42) plus a Gaussian dye blob.
x, y = grid.x, grid.y
x0, y0 = L/2, L/2  # Center of domain
width = L/8        # Blob width (standard deviation of the Gaussian)

# Gaussian blob
r2 = (x - x0)**2 + (y - y0)**2
dye_init = jnp.exp(-r2 / (2 * width**2))

# A single initialize() call both seeds θ and attaches the dye field.
scalar_init = {'dye': dye_init}
state = solver.initialize(seed=42, scalar_init=scalar_init)

print("Initial state created with scalar field")
print(f"Initial dye variance: {float(compute_scalar_variance(state['scalar_state'].scalars['dye'])):.6f}")

# Visualize initial condition. Fields are held in spectral space, so they
# are transformed back with ifft2 before plotting.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# θ field
theta = ifft2(state['theta_hat']).real
im1 = ax1.imshow(theta, cmap='RdBu_r', origin='lower', extent=[0, L, 0, L])
ax1.set_title('Initial θ (buoyancy)')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
plt.colorbar(im1, ax=ax1)

# Dye field
dye = ifft2(state['scalar_state'].scalars['dye']).real
im2 = ax2.imshow(dye, cmap='viridis', origin='lower', extent=[0, L, 0, L])
ax2.set_title('Initial dye concentration')
ax2.set_xlabel('x')
ax2.set_ylabel('y')
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()

In [None]:
# Initial state: random θ field (seed=42) plus a Gaussian dye blob.
# NOTE(review): this initialization is repeated in several cells; re-running
# any of them resets `state` to the same seed-42 initial condition.
x, y = grid.x, grid.y
x0, y0 = L/2, L/2  # Center of domain
width = L/8        # Blob width (standard deviation of the Gaussian)

# Gaussian blob
r2 = (x - x0)**2 + (y - y0)**2
dye_init = jnp.exp(-r2 / (2 * width**2))

# A single initialize() call both seeds θ and attaches the dye field.
scalar_init = {'dye': dye_init}
state = solver.initialize(seed=42, scalar_init=scalar_init)

print("Initial state created with scalar field")
# `state` is indexed like a mapping, consistent with the other cells that
# read state['theta_hat'] / state['scalar_state'].
print(f"Initial dye variance: {float(compute_scalar_variance(state['scalar_state'].scalars['dye'])):.6f}")

In [None]:
# Initial state: random θ field (seed=42) plus a Gaussian dye blob.
# NOTE(review): this initialization is repeated in several cells; re-running
# any of them resets `state` to the same seed-42 initial condition.
x, y = grid.x, grid.y
x0, y0 = L/2, L/2  # Center of domain
width = L/8        # Blob width (standard deviation of the Gaussian)

# Gaussian blob
r2 = (x - x0)**2 + (y - y0)**2
dye_init = jnp.exp(-r2 / (2 * width**2))

# A single initialize() call both seeds θ and attaches the dye field.
scalar_init = {'dye': dye_init}
state = solver.initialize(seed=42, scalar_init=scalar_init)

print("Initial state created with scalar field")
print(f"Initial dye variance: {float(compute_scalar_variance(state['scalar_state'].scalars['dye'])):.6f}")

## 6. Mixing Diagnostics

In [None]:
# Evolve the dye first, then plot: the figures below read `times`,
# `dye_variances` and `dye_gradients`, which only exist after this loop.

# Time stepping parameters
dt = 0.001
n_steps = 500
plot_interval = 100
diag_interval = 10                           # steps between diagnostic samples
dye_kappa = passive_scalars['dye']['kappa']  # loop-invariant, hoisted

# Storage for diagnostics (plain floats so they plot and format cleanly)
times = [float(state['time'])]
dye_variances = [float(compute_scalar_variance(state['scalar_state'].scalars['dye']))]
dye_dissipation = [float(compute_scalar_dissipation(state['scalar_state'].scalars['dye'], grid, dye_kappa))]

# Evolution snapshots
snapshots = [(float(state['time']), state['scalar_state'].scalars['dye'].copy())]

print("Starting evolution...")
for step in range(n_steps):
    state = solver.step(state, dt)

    # Diagnostics
    if (step + 1) % diag_interval == 0:
        dye_hat = state['scalar_state'].scalars['dye']
        times.append(float(state['time']))
        dye_variances.append(float(compute_scalar_variance(dye_hat)))
        dye_dissipation.append(float(compute_scalar_dissipation(dye_hat, grid, dye_kappa)))

    # Save snapshots
    if (step + 1) % plot_interval == 0:
        snapshots.append((float(state['time']), state['scalar_state'].scalars['dye'].copy()))
        print(f"  Step {step+1}: t={float(state['time']):.3f}")

print("Evolution complete!")

# RMS scalar gradient, recovered from the tracked dissipation rate.
# Assumes compute_scalar_dissipation returns χ = κ⟨|∇c|²⟩ — TODO confirm.
# Any constant prefactor in χ cancels in the enhancement ratio printed below.
# Requires dye_kappa > 0.
dye_gradients = np.sqrt(np.asarray(dye_dissipation) / dye_kappa)

# Plot mixing diagnostics
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

# Scalar variance decay
ax1.plot(times, dye_variances, 'b-', linewidth=2)
ax1.set_ylabel('Scalar Variance ⟨c²⟩')
ax1.set_title('Scalar Variance Decay (Mixing)')
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

# Gradient growth (strain)
ax2.plot(times, dye_gradients, 'r-', linewidth=2)
ax2.set_xlabel('Time')
ax2.set_ylabel('⟨|∇c|²⟩^{1/2}')
ax2.set_title('Scalar Gradient (Strain Enhancement)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Mixing efficiency: fraction of the initial variance removed
initial_var = dye_variances[0]
final_var = dye_variances[-1]
mixing_efficiency = 1 - final_var/initial_var

print("Mixing analysis:")
print(f"  Initial variance: {initial_var:.6f}")
print(f"  Final variance: {final_var:.6f}")
print(f"  Mixing efficiency: {mixing_efficiency*100:.1f}%")
print(f"  Peak gradient enhancement: {dye_gradients.max()/dye_gradients[0]:.1f}x")

In [None]:
# Time stepping parameters
dt = 0.001
n_steps = 500
plot_interval = 100
diag_interval = 10                           # steps between diagnostic samples
dye_kappa = passive_scalars['dye']['kappa']  # loop-invariant, hoisted

# Storage for diagnostics (plain floats so they plot and format cleanly)
times = [float(state['time'])]
dye_variances = [float(compute_scalar_variance(state['scalar_state'].scalars['dye']))]
dye_dissipation = [float(compute_scalar_dissipation(state['scalar_state'].scalars['dye'], grid, dye_kappa))]

# Evolution snapshots
snapshots = [(float(state['time']), state['scalar_state'].scalars['dye'].copy())]

print("Starting evolution...")
for step in range(n_steps):
    state = solver.step(state, dt)

    # Diagnostics
    if (step + 1) % diag_interval == 0:
        dye_hat = state['scalar_state'].scalars['dye']
        times.append(float(state['time']))
        dye_variances.append(float(compute_scalar_variance(dye_hat)))
        dye_dissipation.append(float(compute_scalar_dissipation(dye_hat, grid, dye_kappa)))

    # Save snapshots
    if (step + 1) % plot_interval == 0:
        snapshots.append((float(state['time']), state['scalar_state'].scalars['dye'].copy()))
        print(f"  Step {step+1}: t={float(state['time']):.3f}")

print("Evolution complete!")

# Compute scalar spectrum
k_bins, C_k = compute_scalar_variance_spectrum(state['scalar_state'].scalars['dye'], grid)

# Also compute the θ energy spectrum for comparison
k_e, E_k = compute_energy_spectrum(state['theta_hat'], grid, alpha)

# Normalize each spectrum by its value in bin 1 so both shapes share an axis.
# Converting to NumPy up front keeps the boolean masking below plain NumPy.
k_np = np.asarray(k_bins)
C_np = np.asarray(C_k)
E_np = np.asarray(E_k)
C_norm = C_np / C_np[1]
E_norm = E_np / E_np[1]

# Plot spectra
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.loglog(k_np, C_norm, 'b-', linewidth=2, label='Scalar spectrum C(k)')
ax.loglog(np.asarray(k_e), E_norm, 'r--', linewidth=2, label='Energy spectrum E(k)')

# Reference slope, anchored at the first bin with k > 10
high_k = k_np > 10
if high_k.any():
    k_ref = k_np[high_k]
    # Batchelor spectrum k^{-1} for high Schmidt number
    C_batch = C_norm[high_k][0] * (k_ref / k_ref[0])**(-1)
    ax.loglog(k_ref, C_batch, 'k:', alpha=0.7, label='k⁻¹ (Batchelor)')

ax.set_xlabel('Wavenumber k')
ax.set_ylabel('Normalized Spectrum')
ax.set_title('Scalar vs Energy Spectrum')
ax.grid(True, alpha=0.3, which='both')
ax.legend()
ax.set_xlim(1, N/2)

plt.show()

print("Spectrum analysis:")
print("  - Scalar spectrum extends to smaller scales than energy")
print("  - This is due to lower diffusivity (higher Schmidt number)")
print("  - At high k, expect Batchelor scaling C(k) ~ k⁻¹")

In [None]:
# Two stacked panels: variance decay on top, dissipation rate
# χ = κ⟨|∇c|²⟩ underneath, both against the sampled times.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

ax1.plot(times, dye_variances, 'b-', linewidth=2)
ax1.set(ylabel='Scalar Variance ⟨c²⟩', title='Scalar Variance Decay (Mixing)')
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

ax2.plot(times, dye_dissipation, 'r-', linewidth=2)
ax2.set(xlabel='Time', ylabel='Dissipation Rate χ', title='Scalar Dissipation Rate')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Fraction of the initial variance that has been mixed away
initial_var, final_var = dye_variances[0], dye_variances[-1]
mixing_efficiency = 1 - final_var / initial_var

print("Mixing analysis:")
print(f"  Initial variance: {initial_var:.6f}")
print(f"  Final variance: {final_var:.6f}")
print(f"  Mixing efficiency: {mixing_efficiency*100:.1f}%")
print(f"  Peak dissipation rate: {max(dye_dissipation):.3e}")

In [None]:
# Visualize initial conditions of the multi-scalar run.
# NOTE(review): multi_scalars and state_multi are not defined in any cell
# visible here; a multi-scalar setup cell must run before this one — TODO confirm.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# θ field (state is indexed like a mapping, as in the other cells)
theta = ifft2(state_multi['theta_hat']).real
im = axes[0,0].imshow(theta, cmap='RdBu_r', origin='lower', extent=[0, L, 0, L])
axes[0,0].set_title('θ (buoyancy)')
plt.colorbar(im, ax=axes[0,0])

# Scalars fill panels (0,1), (1,0), (1,1); zip with three colormaps means at
# most the first three scalars are drawn.
scalar_names = list(multi_scalars.keys())
cmaps = ['hot', 'viridis', 'plasma']

for i, (name, cmap) in enumerate(zip(scalar_names, cmaps)):
    ax_idx = (i+1) // 2, (i+1) % 2
    scalar = ifft2(state_multi['scalar_state'].scalars[name]).real
    im = axes[ax_idx].imshow(scalar, cmap=cmap, origin='lower', extent=[0, L, 0, L])
    axes[ax_idx].set_title(f'{name} (κ={multi_scalars[name]["kappa"]:.1e})')
    plt.colorbar(im, ax=axes[ax_idx])

for ax in axes.flat:
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.suptitle('Initial Conditions: Multiple Scalars', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Visualize initial conditions
# NOTE(review): multi_scalars, solver_multi and state_multi are not defined in
# any cell visible here; a multi-scalar setup cell must run first — TODO confirm.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# θ field
theta = ifft2(state_multi['theta_hat']).real
im = axes[0,0].imshow(theta, cmap='RdBu_r', origin='lower', extent=[0, L, 0, L])
axes[0,0].set_title('θ (buoyancy)')
plt.colorbar(im, ax=axes[0,0])

# Scalars fill panels (0,1), (1,0), (1,1); at most three are drawn
scalar_names = list(multi_scalars.keys())
cmaps = ['hot', 'viridis', 'plasma']

for i, (name, cmap) in enumerate(zip(scalar_names, cmaps)):
    ax_idx = (i+1) // 2, (i+1) % 2
    scalar = ifft2(state_multi['scalar_state'].scalars[name]).real
    im = axes[ax_idx].imshow(scalar, cmap=cmap, origin='lower', extent=[0, L, 0, L])
    axes[ax_idx].set_title(f'{name} (κ={multi_scalars[name]["kappa"]:.1e})')
    plt.colorbar(im, ax=axes[ax_idx])

for ax in axes.flat:
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.suptitle('Initial Conditions: Multiple Scalars', fontsize=14)
plt.tight_layout()
plt.show()

# Evolve and compare mixing rates
n_steps_multi = 300
multi_diag_interval = 50  # steps between variance samples

# Record the actual solver time of every sample instead of rebuilding it as
# index * interval * dt, which silently assumes the run starts at t = 0.
# A dedicated name also avoids overwriting `times` from the single-dye run.
times_multi = [float(state_multi['time'])]
variances = {
    name: [float(compute_scalar_variance(state_multi['scalar_state'].scalars[name]))]
    for name in scalar_names
}

print("Evolving multiple scalars...")
for step in range(n_steps_multi):
    state_multi = solver_multi.step(state_multi, dt)

    if (step + 1) % multi_diag_interval == 0:
        times_multi.append(float(state_multi['time']))
        print(f"  Step {step+1}: t={times_multi[-1]:.3f}")
        for name in scalar_names:
            variances[name].append(
                float(compute_scalar_variance(state_multi['scalar_state'].scalars[name]))
            )

# Plot comparative mixing
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

colors = ['red', 'green', 'purple']
for name, color in zip(scalar_names, colors):
    # Normalize by initial variance
    var_norm = np.array(variances[name]) / variances[name][0]
    ax.plot(times_multi, var_norm, color=color, linewidth=2,
            label=f'{name} (κ={multi_scalars[name]["kappa"]:.1e})')

ax.set_xlabel('Time')
ax.set_ylabel('Normalized Variance ⟨c²⟩/⟨c²⟩₀')
ax.set_title('Comparative Mixing Rates')
ax.grid(True, alpha=0.3)
ax.legend()
ax.set_yscale('log')

plt.show()

print("\nMixing rate analysis:")
print("  - Higher diffusivity → faster mixing")
print("  - Pollutant (lowest κ) retains structure longest")
print("  - Heat (highest κ) mixes most rapidly")

In [None]:
# Estimate typical velocity scale: RMS speed of the flow induced by θ.
# compute_velocity_from_theta is imported with the other pygSQuiG imports.
u, v = compute_velocity_from_theta(state['theta_hat'], grid, alpha)
U = float(jnp.sqrt(jnp.mean(u**2 + v**2)))

# Compute Péclet numbers (printed once; this cell previously ran the whole
# analysis twice, the first time with attribute-style state access).
# NOTE(review): multi_scalars must come from a multi-scalar setup cell that is
# not visible here — TODO confirm it runs before this cell.
print("Péclet number analysis:")
print(f"Characteristic velocity U ≈ {U:.3f}")
print(f"Domain size L = {L:.2f}")
print("\nPéclet numbers (Pe = UL/κ):")

for name in multi_scalars:
    kappa = multi_scalars[name]['kappa']
    Pe = U * L / kappa
    print(f"  {name}: Pe = {Pe:.0f} (κ = {kappa:.1e})")

print("\nInterpretation:")
print("  Pe >> 1: Advection dominated (sharp gradients, slow mixing)")
print("  Pe ~ 1: Balance of advection and diffusion")
print("  Pe << 1: Diffusion dominated (smooth fields, fast mixing)")

In [None]:
# Example: Resolution check
def check_scalar_resolution(scalar_hat, grid, kappa, nu_p=1e-16, threshold=1e-8):
    """Print whether a scalar's variance spectrum decays well before the grid cutoff.

    The spectrum from ``compute_scalar_variance_spectrum`` is normalized by
    its peak value. The "cutoff" is the first wavenumber at or beyond the peak
    where the normalized spectrum drops below ``threshold``. The field is
    reported as well resolved when ``(grid.N // 2) / k_cutoff > 2``.

    Args:
        scalar_hat: spectral scalar field, as stored in ``scalar_state.scalars``.
        grid: pygSQuiG grid; ``grid.N`` is read here and the grid is passed on
            to ``compute_scalar_variance_spectrum``.
        kappa: scalar diffusivity. Not used by the check; kept for API stability.
        nu_p: hyperviscosity coefficient. Not used by the check; kept for API
            stability.
        threshold: relative spectral level that defines the cutoff (default
            1e-8, the value previously hard-coded).

    Returns:
        None. Results are printed.
    """
    k_bins, C_k = compute_scalar_variance_spectrum(scalar_hat, grid)
    k_bins = np.asarray(k_bins)
    C_k = np.asarray(C_k)

    if C_k.size == 0 or not C_k.max() > 0:
        print("  ⚠️ Cannot determine resolution - spectrum is empty, zero or non-finite")
        return

    # Normalize by the spectral peak rather than bin 0: bin 0 may be empty or
    # hold only the mean, which would give a meaningless ratio or divide by 0.
    peak_idx = int(np.argmax(C_k))
    C_k_norm = C_k / C_k[peak_idx]

    # Search only from the peak outward so that nearly empty low-k bins are not
    # mistaken for the small-scale cutoff.
    below = np.nonzero(C_k_norm[peak_idx:] < threshold)[0]

    if len(below) > 0:
        k_cutoff = k_bins[peak_idx + below[0]]
        k_max = grid.N // 2

        resolution_factor = k_max / k_cutoff
        print("Resolution check:")
        print(f"  Spectrum cutoff at k ≈ {k_cutoff:.0f}")
        print(f"  Maximum resolved k = {k_max}")
        print(f"  Resolution factor = {resolution_factor:.1f}")

        if resolution_factor > 2:
            print("  ✓ Well resolved")
        else:
            print("  ⚠️ Marginally resolved - consider higher resolution")
    else:
        print("  ⚠️ Cannot determine resolution - spectrum may be under-resolved")

# Check our scalars
# NOTE(review): state_multi / multi_scalars come from a multi-scalar setup
# cell that is not visible here — TODO confirm it runs before this cell.
for name in ['pollutant']:  # Check the highest Pe scalar
    print(f"\nChecking {name}:")
    check_scalar_resolution(
        state_multi['scalar_state'].scalars[name],
        grid,
        multi_scalars[name]['kappa'],
    )