### Imports

In [1]:
from mpi4py import MPI
from petsc4py import PETSc

import numpy as np
from scipy.sparse import csr_matrix

import ufl
from basix.ufl import element, mixed_element
from dolfinx import fem, la, plot
from dolfinx.fem import (
    Constant,
    Function,
    dirichletbc,
    extract_function_spaces,
    form,
    functionspace,
    locate_dofs_topological,
)
from dolfinx.fem.petsc import assemble_matrix_block, assemble_vector_block
from dolfinx.io import XDMFFile
from dolfinx.mesh import CellType, create_rectangle, locate_entities_boundary
from ufl import div, dx, grad, inner

import matplotlib.pyplot as plt

import tt
from tt import cross

import als_cross

### Parameters

In [None]:
# NOTE(review): unseeded generator, so `delta` (and every result below)
# differs between runs; pass a seed for reproducibility.
rng = np.random.default_rng()

# number of sinkers
n_sinker = 2
# sinker exponential decay rate: one value per sinker, drawn uniformly from [10, 200)
delta = rng.uniform(10, 200, n_sinker)
# size of sinker center: chi is exactly 0 within a distance 0.5 * omega of a
# sinker center, i.e. omega acts as the diameter of the flat core
omega = 0.1
# driving field strength (magnitude of the body force inside the sinkers)
beta = 10
# center of sinkers, one (x, y) row per sinker
# c = np.random.rand(n_sinker,2)
c = np.array([[0.2, 0.3],[0.6,0.7]])
# dynamic ratio of viscosity: mu_max / mu_min == DR
DR = 100
mu_min = 1 / np.sqrt(DR)
mu_max = np.sqrt(DR)


def chi(x, centers, deltas, center_size=None):
  """Smooth sinker indicator: exactly 0 at the sinker cores, tending to 1 far away.

  For each sinker ``i`` the factor ``1 - exp(-deltas[i] * d_i**2)`` is
  multiplied in. Here ``d_i`` is the distance of a point from ``centers[i]``
  minus ``0.5 * center_size``, clipped at 0, so the factor is exactly 0
  inside a disc of diameter ``center_size``.

  Args:
    x: point coordinates with points along axis 1; only the first two rows
      (x and y) are used.
    centers: sequence of 2D sinker centers, one per sinker.
    deltas: per-sinker decay rates. Each ``deltas[i]`` may be a scalar or an
      array broadcastable against the number of points (e.g. one rate per
      point).
    center_size: diameter of the flat sinker core; defaults to the
      module-level ``omega``.

  Returns:
    1D array with one value per point.

  Raises:
    ValueError: if ``centers`` and ``deltas`` have different lengths.
  """
  if center_size is None:
    center_size = omega
  # Iterate over the sinkers actually supplied rather than a global count,
  # so mismatched inputs are reported instead of silently misused.
  if len(centers) != len(deltas):
    raise ValueError(
        f"got {len(centers)} sinker centers but {len(deltas)} decay rates"
    )
  x = np.asarray(x)
  res = np.ones(x.shape[1])
  for center, rate in zip(centers, deltas):
    dist = np.linalg.norm(x[:2] - np.asarray(center).reshape(-1, 1), axis=0)
    excess = np.maximum(0, dist - 0.5 * center_size)
    res *= 1 - np.exp(-rate * np.square(excess))
  return res

def viscosity_expr(x, centers=c, deltas=delta):
  """Return the viscosity at the points ``x``.

  Blends linearly between ``mu_min`` (where ``chi`` is 1, far from the
  sinkers) and ``mu_max`` (where ``chi`` is 0, at the sinker cores).
  """
  outside = chi(x, centers, deltas)
  return mu_min + (mu_max - mu_min) * (1 - outside)

def f_expr(x, centers=c, deltas=delta):
  """Return the 2-component body force ``(0, beta * (chi - 1))`` at ``x``.

  The force vanishes far from the sinkers (chi = 1) and equals ``-beta`` in
  the y direction at the sinker cores (chi = 0).
  """
  n_points = x.shape[1]
  horizontal = np.zeros(n_points)
  vertical = beta * (chi(x, centers, deltas) - 1)
  return np.vstack([horizontal, vertical])

### FEM setup

In [None]:
# Create the mesh: the unit square [0, 1] x [0, 1] with 32 x 32 subdivisions
# and triangular cells
msh = create_rectangle(
    MPI.COMM_WORLD, [np.array([0, 0]), np.array([1, 1])], [32, 32], CellType.triangle
)

# Marker for the no-slip boundary: the whole boundary (x = 0, x = 1, y = 0, y = 1)
def noslip_boundary(x):
    """Mark every candidate point as part of the no-slip boundary.

    ``locate_entities_boundary`` already restricts the search to boundary
    facets, so accepting every point selects the entire boundary.

    Args:
        x: point coordinates with points along axis 1.

    Returns:
        Boolean array with one ``True`` per point.
    """
    n_points = x.shape[1]
    return np.ones(n_points, dtype=bool)

P2 = element("Lagrange", msh.basix_cell(), 2, shape=(msh.geometry.dim,))
P1 = element("Lagrange", msh.basix_cell(), 1)
# Taylor-Hood pair: vector P2 velocity space V, scalar P1 pressure space Q
V, Q = functionspace(msh, P2), functionspace(msh, P1)

# No-slip condition on boundaries. Facets have topological dimension
# tdim - 1 (1 for this triangle mesh); deriving it from the mesh instead of
# hard-coding it keeps the setup valid if the mesh is changed.
noslip = np.zeros(msh.geometry.dim, dtype=PETSc.ScalarType)  # type: ignore
facets = locate_entities_boundary(msh, msh.topology.dim - 1, noslip_boundary)
bc0 = dirichletbc(noslip, locate_dofs_topological(V, msh.topology.dim - 1, facets), V)

# Collect Dirichlet boundary conditions
bcs = [bc0]

# Define variational problem
(u, p) = ufl.TrialFunction(V), ufl.TrialFunction(Q)
(v, q) = ufl.TestFunction(V), ufl.TestFunction(Q)
# Body force f (vector, on V) interpolated from f_expr
f = fem.Function(V)
f.interpolate(f_expr)

# Viscosity mu (scalar, on Q) interpolated from viscosity_expr
mu = fem.Function(Q)
mu.interpolate(viscosity_expr)

# Block bilinear form [[mu grad(u):grad(v), p div(v)], [div(u) q, 0]]
a = form([[mu * inner(grad(u), grad(v)) * dx, inner(p, div(v)) * dx], [inner(div(u), q) * dx, None]])
# Velocity-velocity block on its own (assembled separately by stokes_PDE_fun)
a00 = form(mu * inner(grad(u), grad(v)) * dx)
# Block linear form: body force for the momentum equation, zero for continuity
L = form([inner(f, v) * dx, inner(Constant(msh, PETSc.ScalarType(0)), q) * dx])  # type: ignore
L0 = form(inner(f, v) * dx)

# preconditioner: block-diagonal form made of the velocity block of `a` and
# the pressure mass matrix (not used by nested_direct_solver)
a_p11 = form(inner(p, q) * dx)
a_p = [[a[0][0], None], [None, a_p11]]

### Nested matrix direct (LU) solver

In [None]:
def nested_direct_solver(elliptic_system=False):
    """Solve the Stokes problem using nest matrices and a direct (MUMPS LU) solver.

    Assembles the module-level block forms ``a`` and ``L`` with the boundary
    conditions ``bcs`` into PETSc nest objects. It then attaches the
    constant-pressure null space and solves with MUMPS.

    Args:
        elliptic_system: If True, also return the assembled velocity block
            ``A00`` as a SciPy CSR matrix and the velocity part of the
            right-hand side as a NumPy array.

    Returns:
        ``(u, p)``: the velocity ``Function`` on ``V`` and the pressure
        ``Function`` on ``Q``. When ``elliptic_system`` is True, returns
        ``(u, p, Asp, bnp)`` instead.

    Raises:
        RuntimeError: if the constant-pressure vector is not a null vector
            of the assembled operator.
        SystemExit: if PETSc reports the requested factorization as
            unavailable (error code 92).
    """
    # Assemble nested matrix operators
    A = fem.petsc.assemble_matrix_nest(a, bcs=bcs)
    A.assemble()

    # Assemble right-hand side vector
    b = fem.petsc.assemble_vector_nest(L)

    # Modify ('lift') the RHS for Dirichlet boundary conditions
    fem.petsc.apply_lifting_nest(b, a, bcs=bcs)

    # Sum contributions for vector entries that are shared across
    # parallel processes
    for b_sub in b.getNestSubVecs():
        b_sub.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)

    # Set Dirichlet boundary condition values in the RHS vector
    bcs0 = fem.bcs_by_block(extract_function_spaces(L), bcs)
    fem.petsc.set_bc_nest(b, bcs0)

    # The pressure field is determined only up to a constant. We supply a
    # vector that spans the null space to the solver: zero velocity part and
    # a constant, normalized pressure part.
    null_vec = fem.petsc.create_vector_nest(L)
    velocity_null, pressure_null = null_vec.getNestSubVecs()
    velocity_null.set(0.0)
    pressure_null.set(1.0)
    null_vec.normalize()
    nsp = PETSc.NullSpace().create(vectors=[null_vec])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not nsp.test(A):
        raise RuntimeError("Constant pressure is not in the null space of the Stokes operator.")
    A.setNullSpace(nsp)

    # Create a solver that applies the preconditioner (LU) only
    ksp = PETSc.KSP().create(msh.comm)
    ksp.setOperators(A)
    ksp.setType("preonly")

    # Set the solver type to MUMPS (LU solver) and configure MUMPS to
    # handle the pressure null space
    pc = ksp.getPC()
    pc.setType("lu")
    pc.setFactorSolverType("mumps")
    try:
        pc.setFactorSetUpSolverType()
    except PETSc.Error as e:
        if e.ierr == 92:
            print("The required PETSc solver/preconditioner is not available. Exiting.")
            print(e)
            # Same effect as the `exit(0)` builtin, without relying on `site`.
            raise SystemExit(0) from e
        raise
    pc.getFactorMatrix().setMumpsIcntl(icntl=24, ival=1)  # For pressure nullspace
    pc.getFactorMatrix().setMumpsIcntl(icntl=25, ival=0)  # For pressure nullspace

    # Velocity (on V) and pressure (on Q) Functions whose vectors are wrapped
    # into one nested PETSc vector that receives the solution.
    u, p = Function(V), Function(Q)
    x = PETSc.Vec().createNest([la.create_petsc_vector_wrap(u.x), la.create_petsc_vector_wrap(p.x)])
    ksp.solve(b, x)
    # Propagate owned values to ghost entries so u and p are consistent
    # across processes in parallel runs.
    u.x.scatter_forward()
    p.x.scatter_forward()

    if elliptic_system:
        A00 = A.getNestSubMatrix(0, 0)
        ai, aj, av = A00.getValuesCSR()
        # Give the shape explicitly (local rows x global columns) instead of
        # letting SciPy infer it from the largest column index present.
        n_rows = A00.getLocalSize()[0]
        n_cols = A00.getSize()[1]
        Asp = csr_matrix((av, aj, ai), shape=(n_rows, n_cols))
        bnp = b.getNestSubVecs()[0].getArray()
        return u, p, Asp, bnp
    return u, p

### TT decomposition of the viscosity coefficient

In [None]:
# Point coordinates of the P1 space Q; the TT-cross below evaluates the
# viscosity at coords[j] for spatial index j.
# NOTE(review): assumes these rows follow the dof ordering of Q (so they line
# up with mu.x.array) -- confirm for the dolfinx version in use.
cells, cell_types, coords = plot.vtk_mesh(Q)

# Number of grid points for each sinker's decay-rate parameter
Ny = np.array(n_sinker * [10])
# Number of spatial points (size of the first TT mode)
Nx = coords.shape[0]

def mu_func(x):
  """Evaluate the parametric viscosity for a batch of TT-cross multi-indices.

  Each row of ``x`` is a multi-index ``(j, k_1, ..., k_n)``. ``j`` selects
  the point ``coords[j]``. Each ``k_i`` in ``0 .. Ny[i]-1`` is mapped
  linearly onto the decay rate ``190 * k_i / (Ny[i] - 1) + 10``, i.e. onto
  ``[10, 200]``.

  The batch is evaluated in a single vectorized call. ``chi`` multiplies each
  sinker's decay rates elementwise with the per-point distances, so one rate
  per point (column) yields the same values as evaluating row by row.

  Returns:
    1D float array with one viscosity value per row of ``x``.
  """
  idx = np.asarray(x)
  # Point indices must be integers to index into coords.
  nodes = idx[:, 0].astype(int)
  # Shape (n_sinker, batch): column b holds the decay rates of row b.
  rates = (190 * idx[:, 1:] / (Ny - 1) + 10).T
  return viscosity_expr(coords[nodes].T, deltas=rates)

# random init tensor with modes (Nx, Ny[0], ..., Ny[-1]) and TT ranks 10
C_mu = tt.rand(np.array([Nx] + Ny.tolist()),r=10)
# compute TT approx using TT-cross
C_mu = tt.cross.rect_cross.cross(mu_func, C_mu, nswp=5, eps = 1e-5, kickrank=5)
C_mu = C_mu.round(1e-12)

print("Coefficient ranks: ", C_mu.r)
print("Coefficient dims: ", C_mu.n)

# helper tensor that is constant 1: a rank-1 TT with a size-1 first mode
# followed by one all-ones core per parameter mode. Each core must use that
# mode's own size Ny[i]; passing the whole Ny array as a dimension fails.
cores = [np.ones((1, 1, 1))] + [np.ones((1, n, 1)) for n in Ny]

C_const = tt.vector.from_list(cores)

In [None]:
class stokes_PDE_fun:
  """Parametric velocity (A00) system used as the PDE solver for ALS-cross.

  The operator is split into coefficient components indexed by ``k``. For
  each parameter sample ``i`` the component matrices and right-hand sides
  are summed and solved (see ``solve``):

  * ``k == 0``: viscosity-weighted stiffness ``mu * inner(grad(u), grad(v))``,
    assembled per sample from the nodal viscosity values ``coeff[0, :, i]``.
    Dirichlet rows/columns are zeroed with a zero diagonal.
  * ``k == 1``: constant part, assembled once in ``__init__`` with ``mu = 0``
    and ``f = 0``. It contributes the unit diagonal on Dirichlet rows and the
    boundary values.

  NOTE(review): because ``f`` is zeroed for the constant part, no component
  currently carries the body force, so the summed right-hand side is zero
  until a forcing component is added. TODO: confirm the intended design.
  """

  def __init__(self):
    # get constant matrix part. The module-level Functions mu and f are also
    # read by the forms used elsewhere, so they are restored afterwards.
    mu_saved = mu.x.array.copy()
    f_saved = f.x.array.copy()
    try:
      # Write in place: the x.array property cannot be rebound.
      mu.x.array[:] = 0
      f.x.array[:] = 0

      # Block assembly expects nested lists of forms, even for one block.
      self.A0 = self._to_csr(assemble_matrix_block([[a00]], bcs=bcs))
      b0 = assemble_vector_block([L0], [[a00]], bcs=bcs)
      self.b0 = b0.getArray()
    finally:
      mu.x.array[:] = mu_saved
      f.x.array[:] = f_saved

  @staticmethod
  def _to_csr(A):
    """Copy the locally owned rows of an assembled PETSc AIJ matrix into SciPy CSR."""
    ai, aj, av = A.getValuesCSR()
    # Explicit shape so all component matrices agree even if trailing
    # columns hold no entries.
    return csr_matrix((av, aj, ai), shape=(A.getLocalSize()[0], A.getSize()[1]))

  def _viscosity_system(self, mu_values):
    """Assemble the mu-weighted stiffness for every sample column of ``mu_values``.

    ``mu_values`` has shape (nx, I). Dirichlet rows/columns get a zero
    diagonal so the unit diagonal comes only from the constant part
    ``self.A0``. The right-hand side of this component is zero.
    NOTE(review): assumes nx equals the number of dofs of Q and uses the same
    ordering (coords from plot.vtk_mesh(Q)) -- confirm.
    """
    mu_saved = mu.x.array.copy()
    A, F = [], []
    try:
      for sample in mu_values.T:
        mu.x.array[:] = sample
        A.append(self._to_csr(assemble_matrix_block([[a00]], bcs=bcs, diagonal=0.0)))
        F.append(np.zeros_like(self.b0))
    finally:
      mu.x.array[:] = mu_saved
    return A, F

  def linear_system(self, k, coeff):
    """Return per-sample matrices and right-hand sides for component ``k``.

    Args:
      k: component index (0: viscosity stiffness, 1: constant part).
      coeff: array of shape (num_coeff, nx, I), where ``I`` is the number of
        parameter samples.

    Returns:
      (A, F): two lists of length ``I`` with the matrices and vectors.

    Raises:
      Exception: if component ``k`` is not implemented.
    """
    _, _, I = np.shape(coeff)

    if k == 0:
      return self._viscosity_system(np.asarray(coeff)[0])
    if k == 1:
      return [self.A0] * I, [self.b0] * I
    raise Exception(f"No component {k} implemented.")

  def solve(self, coeff):
    """Solve the summed system for every parameter sample.

    Args:
      coeff: sequence of per-component coefficient arrays. Component ``k`` is
        handled by ``linear_system(k, coeff[k])``. The sample count ``I`` is
        taken from ``coeff[0]``.

    Returns:
      List of ``I`` solution column vectors of shape (n, 1).
    """
    from scipy.sparse.linalg import spsolve

    _, _, I = np.shape(coeff[0])
    # Only the components actually supplied are assembled.
    systems = [self.linear_system(k, c) for k, c in enumerate(coeff)]

    U = []
    for i in range(I):
      A_i = sum(A[i] for A, _ in systems)
      b_i = sum(F[i] for _, F in systems)
      # The matrices are SciPy sparse, which np.linalg.solve cannot handle.
      U.append(np.asarray(spsolve(A_i, b_i)).reshape((-1, 1)))
    return U


### Run ALS cross

In [None]:
# Run ALS-cross on the parametric velocity problem.
# NOTE(review): `als_cross` is imported as a module (`import als_cross`), and
# calling a module raises TypeError; this presumably needs
# `als_cross.<solver class>(...)` or `from als_cross import ...` -- TODO confirm.
test = als_cross(
  [C_mu, C_const],  # coefficient TT tensors, presumably one per component k of stokes_PDE_fun
  stokes_PDE_fun,  # PDE solver class, passed uninstantiated
  1e-8,  # presumably the target tolerance -- TODO confirm
  # random_init=5,
  kickrank=3
  )

# presumably 3 ALS-cross sweeps -- TODO confirm
test.iterate(3)
print(test.get_stats())

# solution tensor (presumably in TT format, given the `.r` ranks below);
# note this rebinds the module-level name `u`
u = test.get_tensor()

print('Ranks', u.r)