In [1]:
%matplotlib qt
import rom_inverse as ri
import torch
import matplotlib.pyplot as plt
# import sklearn.gaussian_process.kernels as kernels
from poisson_fem import PoissonFEM
import ROM
import numpy as np
import scipy as sp
import petsc4py
import sys
petsc4py.init(sys.argv)
from petsc4py import PETSc
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import os
import pyro.contrib.gp as gp
smoke_test = ('CI' in os.environ)  # ignore; used to check code integrity in the Pyro repo
assert pyro.__version__.startswith('1.2.1')
pyro.enable_validation(True)       # can help with debugging
import time
import Data
from CG_Surrogate import cg_surrogate

In [2]:
lin_dim_fom = 32                      # Linear number of fom elements (per side of the square mesh)

# Gaussian-process prior for the permeability random field: RBF kernel with variance 2 and
# anisotropic length scales (0.3 along the first input coordinate, 0.1 along the second).
# NOTE(review): `nugget` is presumably a small diagonal jitter for a stable Cholesky factor — confirm.
kernel = gp.kernels.RBF(input_dim=2, variance=torch.tensor(2.), lengthscale=torch.tensor([.3, .1]))
permeability_random_field = ri.DiscretizedRandomField(kernel=kernel, resolution=[lin_dim_fom],
                                                     nugget=1e-3)

In [3]:
# Visual sanity check: plot a few prior realizations (plotting is done inside rom_inverse).
permeability_random_field.plot_realizations()

In [4]:
a = torch.tensor([1, 1, 0])               # Boundary condition function coefficients
# NOTE(review): integer dtype; `flux` combines these with the point coordinates. The short
# global name `a` is easy to clobber in later (scratch) cells.


# Define mesh and boundary conditions
# lin_dim_fom element widths of 1/lin_dim_fom (total length 1); presumably used for both axes,
# giving the unit square assumed by the boundary predicates below.
mesh = PoissonFEM.RectangularMesh(torch.ones(lin_dim_fom)/lin_dim_fom)
# mesh.plot()

def origin(x):
    """Predicate for the essential boundary: is point ``x`` at (0, 0)?

    Both coordinates are compared against float32 machine epsilon. The result is the
    0-d boolean tensor produced by the comparisons (truthy iff ``x`` is at the origin).
    """
    tol = torch.finfo(torch.float32).eps
    return torch.abs(x[0]) < tol and torch.abs(x[1]) < tol

def ess_boundary_fun(x):
    """Prescribed value on the essential boundary: homogeneous (0.0) for every point ``x``."""
    prescribed_value = 0.0
    return prescribed_value
# Prescribe the solution value (0.0) only at the node(s) matched by `origin`.
# NOTE(review): since the natural boundary below is the whole domain boundary, this single
# pinned node presumably fixes the additive constant of the solution — confirm.
mesh.set_essential_boundary(origin, ess_boundary_fun)

def domain_boundary(x):
    """Predicate for the natural boundary: does ``x`` lie on the edge of the unit square?

    Uses float32 machine epsilon as tolerance on each of the four edges x0=0, x1=0, x0=1, x1=1
    (via absolute values). Returns the short-circuited ``or`` of the comparison tensors, i.e.
    a truthy 0-d boolean tensor iff the point is on the boundary.
    """
    tol = torch.finfo(torch.float32).eps
    abs_x0 = torch.abs(x[0])
    abs_x1 = torch.abs(x[1])
    return (abs_x0 < tol or abs_x1 < tol
            or abs_x0 > 1.0 - tol or abs_x1 > 1.0 - tol)
# Natural (flux) boundary conditions on every point accepted by `domain_boundary`.
mesh.set_natural_boundary(domain_boundary)

def flux(x, coeffs=None):
    """Prescribed boundary flux q(x) = (c0 + c2*x[1], c1 + c2*x[0]).

    Args:
        x: point with two coordinates, indexable as ``x[0]`` and ``x[1]``.
        coeffs: optional length-3 coefficients ``(c0, c1, c2)``. When omitted, the
            module-level boundary coefficients ``a`` are read at call time (original
            behaviour). Passing them explicitly avoids depending on a short global name
            that later cells may rebind.

    Returns:
        np.ndarray with the two flux components.
    """
    c = a if coeffs is None else coeffs
    return np.array([c[0] + c[2] * x[1], c[1] + c[2] * x[0]])

In [5]:
# Specify right-hand side and stiffness matrix on `mesh`
rhs = PoissonFEM.RightHandSide(mesh)
rhs.set_natural_rhs(mesh, flux)  # natural-boundary contributions from the prescribed `flux`
K = PoissonFEM.StiffnessMatrix(mesh)
# NOTE(review): the rhs stencil depends on K — presumably to account for the essential BC; confirm in PoissonFEM.
rhs.set_rhs_stencil(mesh, K)

In [6]:
# define fom
# NOTE(review): lin_dim_fom**2 presumably is the number of permeability parameters (one per
# element); fom_autograd below is indeed fed vectors of that length.
fom = ROM.ROM(mesh, K, rhs, lin_dim_fom**2)
# Change for non unit square domains!!
# Regular lin_dim_fom x lin_dim_fom grid on [0, 1]^2, flattened to X of shape (lin_dim_fom**2, 2).
xx, yy = torch.meshgrid((torch.linspace(0, 1, lin_dim_fom), torch.linspace(0, 1, lin_dim_fom)))
X = torch.cat((xx.flatten().unsqueeze(1), yy.flatten().unsqueeze(1)), 1)
# Return value is discarded; presumably the interpolation matrix is stored on the mesh — confirm.
fom.mesh.get_interpolation_matrix(X)

In [7]:
# Differentiable fom solve: maps a permeability vector to the fom output and supports
# backward() (exercised in the smoke-test cell below).
fom_autograd = fom.get_autograd_fun()

In [8]:
# Three prior permeability samples; indexed as lmbda[0, :] below, so presumably (3, lin_dim_fom**2).
# NOTE(review): `lmbda` is rebound to a random vector in the autograd smoke-test cell.
lmbda = permeability_random_field.sample_permeability(n_samples=3)

In [9]:
# Solve the fom for the first prior permeability sample and show the output as an image.
# An explicit new figure is created: with the interactive qt backend, pyplot would otherwise
# draw into whichever figure is currently active (possibly the realizations plot above).
fig, ax = plt.subplots()
img = ax.imshow(fom_autograd(lmbda[0, :]).view(lin_dim_fom, lin_dim_fom))
ax.set_title('fom output for permeability sample 0')
fig.colorbar(img, ax=ax);

<matplotlib.colorbar.Colorbar at 0x7fa6ca0b0310>

In [10]:
# Smoke test of autograd through the fom: random permeabilities in [0, 1), a scalar loss
# (norm of the fom output) and a backward pass; lmbda.grad is populated afterwards.
lmbda = torch.rand(lin_dim_fom**2, requires_grad=True)
x = fom_autograd(lmbda)
loss = torch.norm(x)
loss.backward()

In [11]:
def lossfun(x):
    """Scalar test loss: the norm of the fom output for permeability input ``x``.

    Intended for the (commented-out) ``gradcheck`` experiment on fom gradients.
    """
    fom_output = fom_autograd(x)
    return torch.norm(fom_output)

In [12]:
# from torch.autograd import gradcheck
# lmbda = torch.randn(lin_dim_fom**2, dtype=torch.double, requires_grad=True)
# test = gradcheck(lossfun, lmbda, eps=1e-3, atol=1e-4)
# print(test)

In [13]:
# define pyro model
beta = 1.0  # inverse temperature of observations


def uncertainty_propagation(beta=beta):
    """Forward model: draw a permeability field, solve the fom, add Gaussian noise (std 1/beta).

    ``beta`` defaults to the global value at definition time (1.0). Without this, a later cell
    that rebinds the global ``beta`` (to 20) would silently change this model.

    NOTE(review): the permeability comes from ``permeability_random_field.sample()`` rather than
    a ``pyro.sample`` site, so Pyro inference (e.g. NUTS) sees no latent variable here. Also
    confirm that ``sample()`` exists; other cells use ``sample_permeability`` /
    ``sample_log_permeability``.
    """
    lambda_f = permeability_random_field.sample()
    uf = fom_autograd(lambda_f)
    uf_observed = pyro.sample('uf_observed', dist.Normal(uf, torch.ones_like(uf) / beta))
    return uf_observed

In [14]:
# nuts_kernel = NUTS(uncertainty_propagation)
# mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100, num_chains=1)
# mcmc.run()
# mcmc.summary()

In [15]:
# Inverse temperature of the observations: the observation-noise std is 1/beta.
beta = 20
# Presumably (re)computes the prior covariance and its factor used below — confirm in rom_inverse.
permeability_random_field.set_covariance_matrix()
# Lower-triangular covariance factor and zero mean of the log-permeability prior
# (one entry per row of the random field's point set X).
scale_tril = permeability_random_field.log_permeability_scale_tril
mu_zero = torch.zeros(permeability_random_field.X.shape[0])


def joint_posterior(beta=beta):
    """Pyro model: GP prior on the log-permeability ``x`` and Gaussian likelihood on the fom output.

    ``x ~ MultivariateNormal(0, scale_tril)``. The permeability ``exp(x)`` is pushed through the
    fom, and site ``'uf_observed'`` is Normal around the fom output with std ``1/beta``.
    Condition ``'uf_observed'`` (e.g. with ``pyro.condition``) to target the posterior over ``x``.

    ``beta`` defaults to the global value at definition time (20). Re-running an earlier cell that
    rebinds the global ``beta`` therefore cannot silently change the likelihood. MCMC calls the
    model without arguments, so the default is used.
    """
    x = pyro.sample('x', dist.MultivariateNormal(mu_zero, scale_tril=scale_tril))
    lambdaf = torch.exp(x)
    uf = fom_autograd(lambdaf)
    uf_observed = pyro.sample('uf_observed', dist.Normal(uf, torch.ones_like(uf) / beta))
    return uf_observed

In [16]:
# nuts_kernel = NUTS(joint_posterior)
# mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100, num_chains=1)
# mcmc.run()
# mcmc.summary()

In [17]:
# Draw a "true" log-permeability field from the prior, print it, then compute and print the
# corresponding noise-free fom response; that response is used as the observation below.
# NOTE(review): the data come from the same forward model with no added noise ("inverse crime");
# consider adding noise with std 1/beta for a more honest test.
x = permeability_random_field.sample_log_permeability()
print('x == ', x)
uf_observed = fom_autograd(torch.exp(x))
print('uf_observed == ', uf_observed)

x ==  tensor([[-2.0050, -2.0618, -1.9911,  ...,  1.5879,  1.3933,  1.1066]])
uf_observed ==  tensor([0.0000, 0.2268, 0.4609,  ..., 3.8524, 3.8547, 3.8617])


In [18]:
def conditioned_posterior(uf_observed):
    """Return ``joint_posterior`` with its ``'uf_observed'`` site clamped to ``uf_observed``.

    The returned callable is a Pyro model whose remaining free sample site is ``'x'``, suitable
    for MCMC over the posterior of the log-permeability.
    """
    observed_data = {"uf_observed": uf_observed}
    return pyro.condition(joint_posterior, data=observed_data)

In [19]:
# Sample the posterior of x given the synthetic observation with NUTS.
nuts_kernel = NUTS(conditioned_posterior(uf_observed))
# NOTE(review): in the saved run, r_hat is up to ~2.6, n_eff is ~2-4 and the adapted step size is
# ~5e-5. 10 warmup / 40 samples are far too few for convergence, so treat the summary as a smoke
# test only.
mcmc = MCMC(nuts_kernel, num_samples=40, warmup_steps=10, num_chains=1)
mcmc.run()
mcmc.summary()

Sample: 100%|██████████| 50/50 [01:16,  1.54s/it, step size=5.12e-05, acc. prob=0.855]


                mean       std    median      5.0%     95.0%     n_eff     r_hat
      x[0]     -2.71      1.41     -3.82     -3.85     -0.77      2.76      2.11
      x[1]      0.78      0.22      0.74      0.51      1.09      2.42      2.63
      x[2]     -0.79      0.15     -0.88     -0.93     -0.57      3.05      1.70
      x[3]      1.41      0.13      1.39      1.25      1.60      2.44      2.60
      x[4]     -0.21      0.01     -0.21     -0.23     -0.20      3.71      1.25
      x[5]     -1.48      0.05     -1.47     -1.55     -1.42      4.02      1.07
      x[6]      1.09      0.06      1.12      0.99      1.15      2.51      2.41
      x[7]     -0.06      0.01     -0.06     -0.07     -0.05      3.77      1.37
      x[8]     -1.59      0.06     -1.62     -1.64     -1.50      2.49      2.52
      x[9]     -1.55      0.02     -1.56     -1.58     -1.53      2.64      2.18
     x[10]     -1.78      0.02     -1.79     -1.80     -1.75      3.32      1.49
     x[11]      0.81      0




In [20]:
# Compare the true (flattened) log-permeability with the MCMC posterior mean.
# Use a new figure; otherwise pyplot would draw the lines into the last active (image) figure.
fig, ax = plt.subplots()
ax.plot(x[0], label='true x')
ax.plot(torch.mean(mcmc.get_samples()['x'], 0), label='posterior mean of x')
ax.set_xlabel('flattened index')
ax.set_ylabel('log-permeability')
ax.legend();

[<matplotlib.lines.Line2D at 0x7fa6c90905e0>]

In [21]:
# Side-by-side maps of the true log-permeability field and its MCMC posterior mean.
# Both panels share one colour scale; with independent scales the two maps cannot be
# compared by eye.
posterior_mean_x = torch.mean(mcmc.get_samples()['x'], 0)
panels = [('true log-permeability x', x[0]), ('posterior mean of x', posterior_mean_x)]
vmin = min(field.min().item() for _, field in panels)
vmax = max(field.max().item() for _, field in panels)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
for ax, (title, field) in zip(axes, panels):
    im = ax.imshow(field.view(lin_dim_fom, lin_dim_fom), vmin=vmin, vmax=vmax)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

<matplotlib.colorbar.Colorbar at 0x7fa6a47495b0>

In [22]:
# Re-create the permeability random field. `lin_dim_fom` stays an int (the resolution list is
# built inline), so earlier cells using `.view(lin_dim_fom, lin_dim_fom)` still work if re-run.
lin_dim_fom = 32                      # Linear number of fom elements (per side)

kernel = gp.kernels.RBF(input_dim=2, variance=torch.tensor(2.), lengthscale=torch.tensor([.3, .1]))
permeability_random_field = ri.DiscretizedRandomField([lin_dim_fom], kernel=kernel, nugget=1e-3)

In [23]:
# Darcy data set; sample indices 0..1023 are flagged as supervised (meaning defined in Data.DarcyData).
dd = Data.DarcyData(supervised_samples=set(range(1024)))

In [24]:
# Build the data set's FEM mesh (used below as dd.mesh).
dd.set_mesh()

In [25]:
# Assemble rhs and stiffness for the data-set mesh.
# NOTE(review): this was interrupted (KeyboardInterrupt) in the saved run. It is expensive at this
# resolution; consider caching the result so Restart & Run All stays feasible.
dd.set_rhs_and_stiffness()

KeyboardInterrupt: 

In [None]:
# Create dd.solver (called below as dd.solver(permeability)).
dd.set_solver()

In [None]:
# Three permeability samples from the data set's own random field (not the global one above).
perm_smps = dd.permeability_random_field.sample_permeability(n_samples=3)

In [None]:
# Show the first Darcy-data permeability sample as an image.
# The side length is derived from the sample size instead of hard-coding 128, so the cell
# keeps working if the data-set resolution changes (a square grid is assumed).
perm_side = round(perm_smps[0].numel() ** 0.5)
assert perm_side * perm_side == perm_smps[0].numel(), 'permeability sample is not a square grid'
fig = plt.figure(figsize=(15, 7))
plt.imshow(perm_smps[0].view(perm_side, perm_side))

In [None]:
# Solve the Darcy problem for the first permeability sample.
u = dd.solver(perm_smps[0])

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
from matplotlib import cm
# 3-D surface plot of the Darcy solution u over [0, 1]^2.
# The node count per side is derived from u instead of hard-coding 129.
n_nodes = round(u.numel() ** 0.5)
assert n_nodes * n_nodes == u.numel(), 'u is expected to hold a square grid of nodal values'
fig = plt.figure(figsize=(15, 7))
# fig.gca(projection=...) is deprecated since matplotlib 3.4 and removed in 3.6; add_subplot works everywhere.
ax = fig.add_subplot(111, projection='3d')
# plt.imshow(u.view(n_nodes, n_nodes))
xx, yy = torch.meshgrid(torch.linspace(0, 1, n_nodes), torch.linspace(0, 1, n_nodes))
surf = ax.plot_surface(xx.detach().numpy(), yy.detach().numpy(),
                       u.view(n_nodes, n_nodes).detach().numpy(),
                       cmap=cm.inferno, linewidth=0, antialiased=False)

In [None]:
# Design-matrix object of the coarse-grained surrogate; filled by get_masks/assemble below.
Phi = cg_surrogate.DesignMatrix()

In [None]:
# Display the (still empty) design-matrix object.
Phi

In [None]:
# Fine (64 elements per side) and coarse (4 per side) meshes of the unit interval widths.
# NOTE(review): `fine_mesh` is not used in the visible cells; get_masks below uses dd.mesh instead.
fine_mesh = PoissonFEM.RectangularMesh(torch.ones(64)/64)
coarse_mesh = PoissonFEM.RectangularMesh(torch.ones(4)/4)

In [None]:
# Masks relating the data-set mesh to the coarse mesh (presumably one per coarse element — confirm).
Phi.get_masks(dd.mesh, coarse_mesh)

In [None]:
# Assemble the design matrix for the three permeability samples.
# NOTE(review): 16 is presumably the number of coarse elements (4 x 4 `coarse_mesh`); derive it if so.
Phi.assemble(perm_smps, 16)

In [None]:
# Inspect the first entry of the assembled design matrix (presumably the one for sample 0).
Phi.matrix[0]

In [None]:
# Scratch: placeholder dict keyed by 'theta_c'; not referenced in the visible cells.
d = {'theta_c': None}

In [None]:
# Scratch: random (3, 4, 5) tensor for shape experiments.
# NOTE(review): this rebinds `x`, the ground-truth log-permeability used by the comparison plots;
# re-running those plots after this cell shows the wrong data.
x = torch.rand(3, 4, 5)

In [None]:
# Scratch: unpack the shape of `x`. `x` is 3-D here, so unpacking into two names raised ValueError.
# Distinct names also avoid rebinding `a`, which holds the boundary-flux coefficients read by `flux`.
dim0, dim1, dim2 = x.shape

In [None]:
# Scratch: reduce over the second dimension; (3, 4, 5) -> (3, 5). `dim=` is torch's canonical
# keyword (the numpy-style `axis=` alias is not available in older torch releases).
x.sum(dim=1)

In [None]:
# Scratch: display `x`.
x

In [None]:
# Scratch: vector of ones matching the last dimension (5) of `x`.
y = torch.ones(5)

In [None]:
# Scratch: batched matrix-vector product, (3, 4, 5) @ (5,) -> (3, 4).
z = x @ y

In [None]:
# Scratch: a draw from the log-permeability prior. Outside any model/inference handler,
# pyro.sample just returns a fresh sample. NOTE(review): this rebinds `x` once more.
x = pyro.sample('x', dist.MultivariateNormal(mu_zero, scale_tril=scale_tril))

In [None]:
# Scratch: one manual NUTS transition. MCMC kernels take the current parameters as a dict
# {site_name: value}, not a bare tensor.
# NOTE(review): the kernel may need `setup(...)` again after mcmc.run() has finished — confirm.
nuts_kernel.sample({'x': x})

In [None]:
# Scratch: params dict keyed by sample-site name. dict(x) on a tensor raises TypeError because
# its elements are not key/value pairs.
xl = {'x': x}

In [None]:
# Debug: show the initial parameters the NUTS kernel starts from.
print(nuts_kernel.initial_params)

In [None]:
# Scratch: drive a single NUTS transition by hand.
# A kernel must be `setup` before `sample` (this traces the model and creates the initial
# parameters), and `sample` expects the current params dict rather than None.
nuts_kernel = NUTS(conditioned_posterior(uf_observed))
nuts_kernel.setup(warmup_steps=0)
nuts_kernel.sample(nuts_kernel.initial_params)