In [1]:
from CG_Surrogate import cg_surrogate
import torch
from poisson_fem import PoissonFEM
import numpy as np
import ROM
import Data

### Load data

In [2]:
# Number of supervised training samples; the sample indices used are 0 .. Ntrain-1.
# Ntrain is also passed to rhs.expand_rhs_stencil_torch in a later cell.
Ntrain = 8
data = Data.DarcyData(supervised_samples=set(range(Ntrain)))
# NOTE(review): set_mesh() is called before load(); presumably load() relies on
# the mesh being set up first — confirm against Data.DarcyData.
data.set_mesh()
data.load()

### Set up boundary conditions

In [3]:
# Coefficients of the boundary condition function (read by `flux` below)
a = torch.tensor([1, 1, 0])
# NOTE(review): appears to be the number of ROM mesh cells per direction — confirm
lin_dim_rom = 4

# Define mesh and boundary conditions
# lin_dim_rom entries, each equal to 1/lin_dim_rom
uniform_spacing = torch.ones(lin_dim_rom) / lin_dim_rom
mesh = PoissonFEM.RectangularMesh(uniform_spacing)
# mesh.plot()

def origin(x):
    """Tell whether the point ``x`` coincides with the origin (0, 0).

    Both ``x[0]`` and ``x[1]`` must be smaller in magnitude than float32
    machine epsilon. ``x`` has to be indexable and its entries must be
    accepted by ``torch.abs``. The result is a boolean torch tensor.
    """
    tol = torch.finfo(torch.float32).eps
    return torch.abs(x[0]) < tol and torch.abs(x[1]) < tol

def ess_boundary_fun(x):
    """Return the prescribed essential boundary value at ``x``.

    The value is the constant 0.0; the argument ``x`` is not used.
    """
    prescribed_value = 0.0
    return prescribed_value
# Register the essential boundary: `origin` selects the location and
# `ess_boundary_fun` supplies the value (constant 0.0) there.
# NOTE(review): presumably this pins the solution at a single point so that the
# otherwise flux-only boundary problem has a unique solution — confirm.
mesh.set_essential_boundary(origin, ess_boundary_fun)

def domain_boundary(x):
    """Tell whether the point ``x`` lies on the boundary of the unit square.

    A point counts as a boundary point when either coordinate is, in
    magnitude, below float32 machine epsilon or above one minus that
    epsilon. The result is a boolean torch tensor.
    """
    tol = torch.finfo(torch.float32).eps
    upper = 1.0 - tol

    # Edges through the origin: x[0] == 0 or x[1] == 0 (up to tolerance)
    on_lower_edge = torch.abs(x[0]) < tol or torch.abs(x[1]) < tol
    if on_lower_edge:
        return on_lower_edge

    # Opposite edges: x[0] == 1 or x[1] == 1 (up to tolerance)
    return torch.abs(x[0]) > upper or torch.abs(x[1]) > upper
# Register the natural boundary using the `domain_boundary` predicate,
# which is true on all four edges of the unit square.
mesh.set_natural_boundary(domain_boundary)

def flux(x):
    """Evaluate the boundary flux vector at the point ``x``.

    Uses the notebook-level coefficient tensor ``a``:
    the first component is ``a[0] + a[2]*x[1]`` and the second is
    ``a[1] + a[2]*x[0]``. Returns both components in a numpy array.
    """
    q_first = a[0] + a[2]*x[1]
    q_second = a[1] + a[2]*x[0]
    return np.array([q_first, q_second])

### Specify right hand side and stiffness matrix

In [4]:
# Right-hand side object on the ROM mesh, created with batched_version=True.
rhs = PoissonFEM.RightHandSide(mesh, batched_version=True)
# Natural boundary contribution, evaluated with the `flux` function defined above.
rhs.set_natural_rhs(mesh, flux)
# Stiffness matrix on the same mesh; it must exist before set_rhs_stencil, which takes it as input.
K = PoissonFEM.StiffnessMatrix(mesh)
rhs.set_rhs_stencil(mesh, K)
# Expand the stencil with Ntrain (the number of supervised samples).
# NOTE(review): presumably this creates one copy per training sample for batched solves — confirm.
rhs.expand_rhs_stencil_torch(Ntrain)

### Define ROM

In [5]:
rom = ROM.ROM(mesh, K, rhs, lin_dim_rom**2)

# Interpolation points: a lin_dim_rom x lin_dim_rom lattice on [0, 1] x [0, 1].
# The hard-coded [0, 1] bounds must be changed for non unit square domains!
# NOTE(review): recent PyTorch releases warn when torch.meshgrid is called
# without an explicit `indexing` argument; the current default corresponds to
# 'ij' — confirm before upgrading PyTorch.
xx, yy = torch.meshgrid((torch.linspace(0, 1, lin_dim_rom), torch.linspace(0, 1, lin_dim_rom)))

# One row per lattice point, columns are (xx, yy)
X = torch.stack((xx.flatten(), yy.flatten()), dim=1)
rom.mesh.get_interpolation_matrix(X)

In [6]:
# Build the surrogate model around the ROM; it is trained in a later cell via model.train(...).
model = cg_surrogate.CG_Surrogate(rom)

In [None]:
# Build the design matrix that is passed to model.train(...) in the next cell.
designMatrix = cg_surrogate.DesignMatrix()
# get_masks receives both the data mesh and the ROM mesh.
# NOTE(review): presumably it relates the cells of the two meshes — confirm against cg_surrogate.DesignMatrix.
designMatrix.get_masks(data.mesh, rom.mesh)
# Assemble from the loaded inputs and the number of ROM mesh cells; the result is read from designMatrix.matrix.
designMatrix.assemble(data.input, rom.mesh.n_cells)

In [None]:
# Train the surrogate on the assembled design matrix and the loaded data.output.
# NOTE(review): presumably designMatrix.matrix holds the features and data.output the targets — confirm.
model.train(designMatrix.matrix, data.output)

In [None]:
# Inspection only: show the shape of rhs.natural_rhs_torch as the cell output.
rhs.natural_rhs_torch.shape