In [29]:
import jax.numpy as jnp
import scipy.integrate as spi

# Define a harmonic function
def phi(x, y):
    return jnp.sin(-x**2) #jnp.log(jnp.sqrt((x)**2+(y-1)**2))
Delta_x = 1e-2
L = 10; N = int(2*L/Delta_x) + 1
x_values = jnp.linspace(-L, L, N)
phi_samples = jnp.array([phi(x, 0) for x in x_values])
LIMIT = 2000

jj = int(N/2) - 20
x0 = x_values[jj]

# Compute vertical derivative (partial derivative w.r.t y) at (x0,0)
#x0 = 10.0  # Test point
#eps = 0.1
#numerical_phi_z = (phi(x0, eps) - phi(x0, -eps))/(2*eps)

# Define the integral function (principal value integral)
#def integrand(x):
#    return (phi(x0, 0.0) - phi(x, 0.0)) / (x - x0) ** 2

# Integrating near the origin
#Dx = eps/2
#f_evals = jnp.array([phi(x, 0) for x in [x0-2*Dx, x0-Dx, x0, x0+Dx, x0+2*Dx]])
#int_near_x0 = jnp.dot(f_evals, jnp.array([-1.0, -32.0, 66.0, -32.0, -1.0]))/(18*Dx)

# Use SciPy's to integrate. 
#LIMIT = 2000
#integral_value = spi.quad(integrand, -jnp.inf, x0-eps, limit=LIMIT)[0] + spi.quad(integrand, x0+eps, jnp.inf, limit=LIMIT)[0] + int_near_x0
#integral_uncer = spi.quad(integrand, -jnp.inf, x0-eps, limit=LIMIT)[1] + spi.quad(integrand, x0+eps, jnp.inf, limit=LIMIT)[1] 
#approx_phi_z = integral_value / jnp.pi

# Compare results
#print(f"Numerical phi_z: {numerical_phi_z}")
#print(f"Integral near x0: {int_near_x0} ")
#print(f"Integral approximation: {approx_phi_z} +- {integral_uncer/jnp.pi}")

In [30]:
def DtN_generator(Delta_x = 1/jnp.float32(100), N = None):
    '''
    This script will generate the matrix so that Aphi is an approximation of dphi/dz
    '''
    N = int(1/Delta_x) if N is None else N

    # Create the main diagonal with 66's
    DtN1 = jnp.diag(jnp.full(N, 66))
    
    # Fill the first sub- and super-diagonals with -32's
    if N > 1:
        DtN1 += jnp.diag(jnp.full(N-1, -32), k=1)
        DtN1 += jnp.diag(jnp.full(N-1, -32), k=-1)
        
    # Fill the second sub- and super-diagonals with -1's
    if N > 2:
        DtN1 += jnp.diag(jnp.full(N-2, -1), k=2)
        DtN1 += jnp.diag(jnp.full(N-2, -1), k=-2)
        
    DtN1 = DtN1 / 18.0 # This is the integral around the origin
    DtN2 = jnp.diag(jnp.full(N, 1.0)) # First integral away of the origin. 
    
    # Now second integral away from the origin
    coefficients = [0 for _ in range(N+1)]
    coef = lambda n, d: -jnp.float32(n)/(n+d) + (2*n - d)/2 * jnp.log((n+1)/(n-1)) - 1.0
    for jj in range(1, int(N/2)):
        n = 2 * jj + 1
        coefficients[n-1] += coef(n, -1.0)
        coefficients[n+1] += coef(n, +1.0)
        coefficients[n]   += -2*coef(n, 0.0)

    coefficients = jnp.array(coefficients)  
    print(coefficients[:5])
    #i = jnp.arange(N)
    #j = jnp.arange(N)
    #I, J = jnp.meshgrid(i, j, indexing='ij')
    #diff = jnp.abs(J - I)

    #DtN3 = jnp.zeros((N, N))
    #for ii in range(N):
    #    for jj in range(N):
    #        if ii >= jj:
    #            DtN3 = DtN3.at[(ii, jj)].set(coefficients[ii-jj])
    #        else:
    #            DtN3 = DtN3.at[(ii, jj)].set(coefficients[jj-ii])
    i, j = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing='ij')
    DtN3 = coefficients[jnp.abs(i - j)]

    # Use jnp.where to apply the function elementwise
    #DtN3 = jnp.where(diff >= 0, 
    #              jnp.take(coefficients, diff),  
    #              -jnp.take(coefficients, -diff)) 
    #DtN3 = jnp.take(coefficients, diff)
    
    # Apply the function f(i, j) to all pairs (i, j)
    DtN = DtN1 + DtN2 + DtN3  # Broadcasting will handle the rest

    return DtN/(jnp.pi * Delta_x), DtN1/(jnp.pi * Delta_x), DtN2/(jnp.pi * Delta_x), DtN3/(jnp.pi * Delta_x)

DtN, B_near, C_away_1, D_away_2 = DtN_generator(N=N, Delta_x=Delta_x)


[ 0.          0.         -0.07398486 -0.1588831  -0.03707409]


In [31]:
phi_away_1_dtn = C_away_1 @ phi_samples
# Define the integral function (principal value integral)
def integrand(x):
    return phi(x0, 0.0) / (x - x0) ** 2

integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / jnp.pi

print(f"DtN approximation away 1:       {phi_away_1_dtn[jj]}")
print(f"Numerical approximation away 1: {integral_value}")


  the requested tolerance from being achieved.  The error may be 
  underestimated.
  integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / jnp.pi


DtN approximation away 1:       -1.2728915214538574
Numerical approximation away 1: -1.2728918694869809


In [32]:
phi_away_2_dtn = D_away_2 @ phi_samples
# Define the integral function (principal value integral)
def integrand(x):
    return -phi(x, 0.0) / ((x - x0) ** 2)

integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / jnp.pi

print(f"DtN approximation away 2:       {phi_away_2_dtn[jj]}")
print(f"Numerical approximation away 2: {integral_value}")


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / jnp.pi


DtN approximation away 2:       1.9926036596298218
Numerical approximation away 2: 1.9926232270339486


In [33]:
#numerical_phi_z = (phi(x0, Delta_x/100) - phi(x0, -Delta_x/100))/(2*Delta_x/100)

#from DtN import DtN_generator
#DtN = DtN_generator(N)
DtN_phi_z = DtN @ phi_samples

Dx = Delta_x
f_evals = jnp.array([phi(x, 0) for x in [x0-2*Dx, x0-Dx, x0, x0+Dx, x0+2*Dx]])
int_near_x0 = jnp.dot(f_evals, jnp.array([-1.0, -32.0, 66.0, -32.0, -1.0]))/(18*jnp.pi*Dx)

def integrand(x):
    return (phi(x0, 0.0)-phi(x, 0.0)) / ((x - x0) ** 2)

integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / (jnp.pi)

print(f"DtN approximation:       {DtN_phi_z[jj]}")
print(f"Numerical integration: {integral_value + int_near_x0}")

  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  integral_value = (spi.quad(integrand, -jnp.inf, x0-2*Delta_x, limit=LIMIT)[0] + spi.quad(integrand, x0+2*Delta_x, jnp.inf, limit=LIMIT)[0]) / (jnp.pi)


DtN approximation:       0.7323896288871765
Numerical integration: 0.7324122190475464


In [76]:
#import numpy as jnp
from jax import numpy as jnp
from findiff import Diff, coefficients

x = jnp.linspace(0, 10, 100)
dx = x[1] - x[0]
f = jnp.sin(x)
g = jnp.cos(x)

d_dx = Diff(0, dx)

coefficients(deriv=1, acc=2)
d2_dx2 = Diff(0, float(dx), acc=2) ** 2
result = d2_dx2(f)

x, y, z = jnp.linspace(-2, 2, 5), jnp.linspace(-3, 3, 7), jnp.linspace(-1, 1, 7)
dx, dy, dz = float(x[1] - x[0]), float(y[1] - y[0]), float(z[1] - z[0])
X, Y= jnp.meshgrid(x, y, indexing='ij')
boundaries = (jnp.abs(X) == 2) | (jnp.abs(Y) == 3)
f = jnp.sin(X) * jnp.cos(Y) # * jnp.sin(Z)

linear_op =  Diff(0, dx)**2 + Diff(1, dy)**2
print(boundaries)
#print(linear_op.matrix((5, 5)))

[[ True  True  True  True  True  True  True]
 [ True False False False False False  True]
 [ True False False False False False  True]
 [ True False False False False False  True]
 [ True  True  True  True  True  True  True]]


In [81]:
A = linear_op.matrix((5, 7))
B = jnp.array(A.toarray().reshape(5, 7, 5, 7))
print(B[4, 6, :, :])

[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0. -1.]
 [ 0.  0.  0.  0.  0.  0.  4.]
 [ 0.  0.  0.  0.  0.  0. -5.]
 [ 0.  0.  0. -1.  4. -5.  4.]]


In [None]:
import jax.numpy as jnp

def solve_finite_diff(Operator, function_values, domain, boundaries, boundary_values):
    n = len(domain[0].shape) # dimension
    points = domain[0].shape

    matrix_form = jnp.array(A.toarray().reshape(points))


def solve_tensor_system(A, b):
    m = A.shape[0]
    x_shape = A.shape[1:]
    A_mat = A.reshape(m, -1)  # shape (m, N)
    b_vec = b.reshape(m)
    if m == b.shape[0]:
        x_flat = jnp.linalg.solve(A_mat, b_vec)
    else:
        x_flat, _, _, _ = jnp.linalg.lstsq(A_mat, b_vec, rcond=None)

    return x_flat.reshape(x_shape)

# Known solution
x_true = jnp.array([[1., -1.],
                    [2., 0.5]])

A_mat = jnp.array([
    [.1, 0., 0., 0.],
    [0., 2., 0., 0.1],
    [0., 0., 3., 0.],
    [1., 0., 0., 4.]
])
A = A_mat.reshape(4, 2, 2)


# Compute b from A and x_true
b = jnp.einsum('ijk,jk->i', A, x_true)  # shape (4,)

# Solve
x_computed = solve_tensor_system(A, b)

print("x_computed:\n", x_computed)
print("\nx_true:\n", x_true)


yes!
x_computed:
 [[ 0.99999976 -1.        ]
 [ 2.          0.50000006]]

x_true:
 [[ 1.  -1. ]
 [ 2.   0.5]]


In [142]:
import jax.numpy as jnp

def solve_tensor_with_known(A, b, x_known, known_mask):
    """
    Solves A * x = b where:
        - A: tensor of shape (p, q, p, q)
        - x_known: known entries of x (shape (p, q))
        - known_mask: boolean mask of shape (p, q) where True means 'known'
        - b: right-hand side tensor of shape (p, q)

    Returns:
        x_full: full solution of shape (p, q)
    """
    p, q = x_known.shape
    x_shape = (p, q)
    x_known_flat = x_known.reshape(-1)
    known_mask_flat = known_mask.reshape(-1)
    unknown_mask_flat = ~known_mask_flat

    # Build the full system matrix (reshaped for linear solve)
    A_mat_full = A.reshape(p * q, p * q)  # shape (eqns, vars)

    # Remove equations corresponding to known x[i,j]
    eq_mask_flat = ~known_mask_flat
    A_mat_reduced = A_mat_full[eq_mask_flat, :]       # shape (num_eqns, p*q)
    A_mat_unknown = A_mat_reduced[:, unknown_mask_flat]  # shape (num_eqns, num_unknowns)

    # Compute reduced RHS
    Ax_known = A_mat_reduced @ x_known_flat
    b_vec = b.reshape(-1)[eq_mask_flat]
    b_residual = b_vec - Ax_known

    # Solve for unknown entries
    if A_mat_unknown.shape[0] == A_mat_unknown.shape[1]:
        x_unknown = jnp.linalg.solve(A_mat_unknown, b_residual)
    else:
        x_unknown, *_ = jnp.linalg.lstsq(A_mat_unknown, b_residual, rcond=None)

    # Reconstruct full x
    x_full_flat = x_known_flat.at[unknown_mask_flat].set(x_unknown)
    return x_full_flat.reshape(x_shape)


A_mat = jnp.array([
    [1., 0., 0., 0.],
    [0., 2., 0., 0.1],
    [0., 0., 3., 0.],
    [1., 0., 0., 4.]
])
import jax
seed = 100
key = jax.random.key(seed)
A = jax.random.uniform(key, (2, 2, 2, 2))
A = A_mat.reshape(2, 2, 2, 2)
x_true = jnp.array([[1., 2.], [3., 4.]])
b = jnp.einsum('ijkl,kl->ij', A, x_true)

# Suppose we know some entries
known_mask = jnp.array([[True, False], [False, True]])
x_known = jnp.array([[1., 0.], [0., 4.]])

x_solved = solve_tensor_with_known(A, b, x_known, known_mask)

print(x_true)
print(x_solved)
print("---")
print(jnp.einsum('ijkl,kl->ij', A, x_true))
print(jnp.einsum('ijkl,kl->ij', A, x_solved))


[[1. 2.]
 [3. 4.]]
[[1. 2.]
 [3. 4.]]
---
[[ 1.   4.4]
 [ 9.  17. ]]
[[ 1.   4.4]
 [ 9.  17. ]]


## Solving poisson problem with boundary conditions:

$$ \Delta \phi = f,  \quad (x, y) \in \Omega$$
$$\phi = 0, \quad (x, y) \in \partial\Omega$$

In [143]:

X_points = 20; Y_points = 20
x, y = jnp.linspace(-1, 1, X_points), jnp.linspace(-1, 1, Y_points)
dx, dy, dz = float(x[1] - x[0]), float(y[1] - y[0]), float(z[1] - z[0])
X, Y= jnp.meshgrid(x, y, indexing='ij')
boundaries = (jnp.abs(X) == 1) | (jnp.abs(Y) == 1)
phi = lambda x: jnp.sin(x[0]) * jnp.cos(x[1])
f_evaluator = lambda x: jnp.trace(jax.jacfwd(jax.grad(phi))(x))
zeroes = jnp.zeros(X.shape)

linear_op =  Diff(0, dx, acc=4)**2 + Diff(1, dy, acc=4)**2
#print(boundaries)
#print(linear_op.matrix((5, 5)))

In [144]:
#f_evaluator(jnp.array([1., 1.]))

batched_bc = jax.vmap(f_evaluator)
batched_phi = jax.vmap(phi)

input_values = jnp.stack([X.reshape(-1), Y.reshape(-1)], axis = 1)

f = batched_bc(input_values).reshape(X.shape)
phi_true = batched_phi(input_values).reshape(X.shape)
#print(f.shape)

In [154]:
A = linear_op.matrix((X_points, Y_points))
discretized_operator = jnp.array(A.toarray().reshape(X_points, Y_points, X_points, Y_points))

solution = solve_tensor_with_known(discretized_operator, f, phi_true, boundaries)

#print(solution)

#print(x_true)
#print(x_solved)
print("---")
#print(jnp.einsum('ijkl,kl->ij', discretized_operator, solution))
print(jnp.linalg.norm(solution - phi_true , ord=2))
print(jnp.einsum('ijkl,kl->ij', discretized_operator, solution) - f)
#print(jnp.einsum('ijkl,kl->ij', A, x_solved))

---
7.7877045
[[ 4.41074371e-06 -1.78740799e+02 -2.01148178e+02 -2.21337494e+02
  -2.39080048e+02 -2.54177399e+02 -2.66461273e+02 -2.75793152e+02
  -2.82072418e+02 -2.85229401e+02 -2.85229248e+02 -2.82071960e+02
  -2.75792603e+02 -2.66460999e+02 -2.54178604e+02 -2.39081573e+02
  -2.21337891e+02 -2.01144363e+02 -1.78725662e+02  4.41074371e-06]
 [-1.43080765e+02 -9.76196289e-01  2.07122135e+00 -1.20886230e+00
  -1.30566406e+00 -1.38818359e+00 -1.45532227e+00 -1.50622559e+00
  -1.54040527e+00 -1.55773926e+00 -1.55773926e+00 -1.54040527e+00
  -1.50610352e+00 -1.45520020e+00 -1.38830566e+00 -1.30578613e+00
  -1.20886230e+00  2.07122111e+00 -9.75952148e-01 -1.43061829e+02]
 [-1.30218063e+02  3.07132673e+00  6.34179306e+00  3.80382681e+00
   4.10872126e+00  4.36815739e+00  4.57927227e+00  4.73963356e+00
   4.84756136e+00  4.90180588e+00  4.90180540e+00  4.84756136e+00
   4.73966455e+00  4.57928753e+00  4.36815643e+00  4.10871983e+00
   3.80381155e+00  6.34179354e+00  3.07144880e+00 -1.3021055