In [None]:
def fix_layout(width:int=95):
    """Set the width of the notebook's ``.container`` element.

    Args:
        width: Width in percent that is written into the injected CSS rule.
    """
    # `display` belongs to the public `IPython.display` module; importing it
    # from `IPython.core.display` is deprecated.
    from IPython.display import display, HTML
    display(HTML('<style>.container { width:' + str(width) + '% !important; }</style>'))

fix_layout()

In [None]:
import sys

import numpy as np
import matplotlib.pyplot as plt
#from logs import enable_logging, logging 
from importlib import reload
import nnpde.functions.iterative_methods as im
from nnpde.functions import geometries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Model: one 3x3 convolution, single channel in and out, no bias term.
# Further identical layers can be appended to the Sequential to deepen it.
net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False),
)

# Plain SGD. The learning rate needs tuning: too large a value produces NaNs.
# Alternatives that were considered: Adadelta, Adam (lr=1e-6) and ASGD;
# SGD seems much faster.
optim = torch.optim.SGD(params=net.parameters(), lr=1e-6)

# Print the freshly initialised weights.
for name, param in net.named_parameters():
    print(name, param)

In [None]:
class dirichlet_problem:
    """Container for one problem instance.

    Parameters
    ----------
    B_idx, B :
        Boundary tensors. When ``B_idx`` is None both are generated with
        ``geometries.square_geometry(N)`` (a ``B`` passed alongside is ignored);
        otherwise both are stored exactly as given.
    f :
        Forcing term. Defaults to a fresh ``torch.zeros(1, 1, N, N)``.
    k :
        Stored as ``self.k`` (presumably the number of iterations - confirm).
    initial_u :
        Initial guess. Defaults to a fresh ``torch.zeros(1, 1, N, N)``.
    domain_type :
        Accepted for interface compatibility; not used by this class.
    N :
        Grid size used for the generated geometry and the default tensors.
    """

    def __init__(self, B_idx=None, B=None, f=None, k=100, initial_u=None, domain_type="Square", N=16):

        if B_idx is None:
            # NOTE(review): other cells of this notebook unpack the same call as
            # `B, B_idx = ...`. The order is kept as written here - confirm it
            # against the return order of `geometries.square_geometry`.
            self.B_idx, self.B = geometries.square_geometry(N)
        else:
            self.B_idx = B_idx
            self.B = B

        # Defaults are created per instance and sized by the `N` argument.
        # Tensor defaults in the signature would be evaluated once, at class
        # definition time, against a global `N`, and shared by all instances.
        self.initial_u = torch.zeros(1, 1, N, N) if initial_u is None else initial_u
        self.f = torch.zeros(1, 1, N, N) if f is None else f
        self.k = k
        
        

In [None]:
# Two independent problem instances built with the default arguments.
# NOTE(review): the name `B` is reassigned to a tensor further down the notebook.
A, B = dirichlet_problem(), dirichlet_problem()

In [None]:
# Collect both instances in a list.
C = [A, B]


In [None]:
# Ten-slot placeholder list; put instance A in the first slot and print its `B` attribute.
l = [None for _ in range(10)]
l[0] = A
print(l[0].B)

New approach based only on convolutions and pointwise tensor operations; see the `nnpde.functions.iterative_methods` module (imported above as `im`).

Define a set of problem instances

In [None]:
def compute_loss(net, k, B_idx, B, f, initial_u, ground_truth, nb_problem_instances):
    """Sum the MSE between ``im.H_method`` outputs and the ground truth over all instances.

    Args:
        net: Network handed to ``im.H_method``.
        k: Indexable; ``k[i]`` is passed to ``im.H_method`` for instance ``i``.
        B_idx, B, initial_u, ground_truth: 5-D tensors whose last axis indexes
            the problem instance.
        f: Forcing term, shared by every instance.
        nb_problem_instances: Number of instances (slices of the last axis) to use.

    Returns:
        A tensor of shape ``(1,)`` holding the summed MSE; zeros when
        ``nb_problem_instances`` is 0.
    """
    loss = torch.zeros(1, requires_grad=False)

    for i in range(nb_problem_instances):
        # Use the solver output directly rather than copying it into a scratch
        # buffer sized by a global `N`; the function then depends only on its
        # arguments.
        u_i = im.H_method(net, B_idx[:, :, :, :, i], B[:, :, :, :, i], f, initial_u[:, :, :, :, i], k[i])
        loss = loss + F.mse_loss(ground_truth[:, :, :, :, i], u_i)

    return loss

In [None]:
# Number of problem instances, and a placeholder list with one empty slot per instance.
nb_problem_instances = 50
problem_instances_list = [None for _ in range(nb_problem_instances)]

In [None]:
# Side length of the square training grid.
N = 16

# Number of training problem instances.
nb_problem_instances = 50

# One value of k per instance, drawn from the integers 1..19.
k = np.random.randint(low=1, high=20, size=nb_problem_instances)

# Training uses a zero forcing term, shared by all instances.
f = torch.zeros(1, 1, N, N)

# Boundary tensors; the last axis indexes the problem instance.
B = torch.zeros(1, 1, N, N, nb_problem_instances)
B_idx = torch.zeros(1, 1, N, N, nb_problem_instances)

# Random starting guess for every instance.
initial_u = torch.randn(1, 1, N, N, nb_problem_instances, requires_grad=True)

# Reference solutions, filled in below.
ground_truth = torch.zeros(1, 1, N, N, nb_problem_instances)

for i in range(nb_problem_instances):
    # Geometry of instance i.
    B[..., i], B_idx[..., i] = geometries.square_geometry(N)

    # Reference solution of instance i: Jacobi method with k = 1000.
    ground_truth[..., i] = im.jacobi_method(B_idx[..., i], B[..., i], f, initial_u=None, k=1000)


In [None]:
# History of the loss over all instances; filled in by the training loop.
losses = []

# Loss over all problem instances before any optimisation step.
prev_total_loss = compute_loss(
    net=net,
    k=k,
    B_idx=B_idx,
    B=B,
    f=f,
    initial_u=initial_u,
    ground_truth=ground_truth,
    nb_problem_instances=nb_problem_instances,
)
print(prev_total_loss)

In [None]:
# Solve the same problems at every iteration; the only thing that changes are the
# weights, which are optimised.
# TODO why though? wouldn't it make much more sense to train it more times on different problems? isn't this the same as oversampling each training sample?

batch_size = 10  # problem instances drawn per optimisation step
tol = 1e-2       # stop when the total loss, or its change between two steps, is below this

# Baseline for the first convergence check: the loss computed before training.
total_loss = prev_total_loss

for _ in range(1000):
    net.zero_grad()
    loss = torch.zeros(1)

    # Draw a mini-batch of distinct problem instances.
    problem_idx = np.random.choice(np.arange(nb_problem_instances), batch_size, replace=False)

    for idx in problem_idx:
        # Compute the solution with the current weights.
        u_idx = im.H_method(net, B_idx[:, :, :, :, idx], B[:, :, :, :, idx], f, initial_u[:, :, :, :, idx], k[idx])

        # Define the loss, CHECK if it is correct wrt paper
        loss = loss + F.mse_loss(ground_truth[:, :, :, :, idx], u_idx)

    # TODO: penalise a spectral radius above 1, e.g.
    #   spectral_radius = TODO
    #   regularization = 1e10
    #   if spectral_radius > 1:
    #       loss += regularization

    # Backpropagation
    loss.backward(retain_graph=False)

    # SGD step
    optim.step()

    # Keep the previous step's loss so the stopping test measures the change
    # between consecutive steps, not the change since before training.
    prev_total_loss = total_loss

    # Monitoring only: no gradients are needed for the loss over all instances.
    with torch.no_grad():
        total_loss = compute_loss(net, k, B_idx, B, f, initial_u, ground_truth, nb_problem_instances)

    # Store losses for visualisation. Done before the stopping test so the
    # final value is recorded as well.
    losses.append(total_loss.item())

    # Exit optimisation
    if total_loss.item() <= tol or np.abs(total_loss.item() - prev_total_loss.item()) < tol:
        break

for name, param in net.named_parameters():
    print(name, param)

Plot the losses

In [None]:
# Absolute difference between the last total loss and the value held in `prev_total_loss`.
np.abs(total_loss.item() - prev_total_loss.item())

In [None]:
color_map = plt.get_cmap('cubehelix')
colors = color_map(np.linspace(0.1, 1, 10))

# Plot the whole recorded history. Skipping the first 2500 entries left nothing
# to draw, since the training loop runs for at most 1000 iterations.
losses_fig = plt.figure()
n_iter = np.arange(len(losses))
plt.plot(n_iter, losses, color = colors[0], linewidth = 1, linestyle = "-", marker = "",  label='Loss')

plt.legend(bbox_to_anchor=(0., -0.3), loc=3, borderaxespad=0.)
plt.xlabel('n iteration', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.title('Loss')
plt.grid(True, which = "both", linewidth = 0.5,  linestyle = "--")

# `losses` may be empty; guard the final-loss report against an IndexError.
if losses:
    print("final loss is {0}".format(losses[-1]))
else:
    print("no losses were recorded")
#losses_fig.savefig('gridSearch.eps', bbox_inches='tight')

Test on a bigger grid

In [None]:
# Grid size and iteration budget for the evaluation.
N = 50
nb_iters = 2000

# Geometry of a square domain at the new size.
B, B_idx = geometries.square_geometry(N)

# Constant forcing term equal to one everywhere.
f = torch.ones(1, 1, N, N)

# All three runs start from an all-ones guess:
#   gtt         - Jacobi with k = 10000, used as the reference solution
#   output      - H method with the trained net, k = nb_iters
#   jacoby_pure - Jacobi, k = nb_iters
gtt = im.jacobi_method(B_idx, B, f, torch.ones(1, 1, N, N), k=10000)
output = im.H_method(net, B_idx, B, f, torch.ones(1, 1, N, N), k=nb_iters)
jacoby_pure = im.jacobi_method(B_idx, B, f, torch.ones(1, 1, N, N), k=nb_iters)

In [None]:
# MSE threshold (against `gtt`) that the two timing cells iterate down to.
loss_to_be_achieved = 1e-3

# Common all-ones starting guess for both methods.
u_0 = torch.ones((1, 1, N, N))

In [None]:
%%timeit
# old method 
u_k_old = im.jacobi_method(B_idx, B, f, u_0, k = 1)
loss_of_old = F.mse_loss(gtt, u_k_old)
k_count_old = 1
while loss_of_old >= loss_to_be_achieved:
    u_k_old = im.jacobi_method(B_idx, B, f, u_k_old, k = 1)
    loss_of_old = F.mse_loss(gtt, u_k_old)
    k_count_old += 1

In [None]:
%%timeit
# new method
u_k_new = im.H_method(net, B_idx, B, f, u_0, k=1)

loss_new = F.mse_loss(gtt, u_k_new)
k_count_new = 1
while loss_new >= loss_to_be_achieved:
    u_k_new = im.H_method(net, B_idx, B, f, u_k_new, k=1)
    loss_new = F.mse_loss(gtt, u_k_new)
    k_count_new += 1

In [None]:
# Report both iteration counts (old first, then new) and their ratio old/new.
iteration_ratio = k_count_old / k_count_new
print("needed {0} iterations (compared to {1}), ratio: {2}".format(k_count_old, k_count_new, iteration_ratio))

In [None]:
# MSE of the H-method and plain-Jacobi results against the reference solution.
print("the loss of the new method is {0}, compared to the pure-jacoby one: {1}. computed with {2} iterations".format(F.mse_loss(gtt, output), F.mse_loss(gtt, jacoby_pure), nb_iters))

# 2-D numpy views of both solutions for plotting.
Z_gtt = gtt.view(N,N).numpy()
Z_output = output.detach().view(N, N).numpy()

fig, axes = plt.subplots(nrows = 1, ncols = 2)

fig.suptitle("Comparison")

im_gtt = axes[0].imshow(Z_gtt)
axes[0].set_title("Ground truth solution")

im_output = axes[1].imshow(Z_output)
axes[1].set_title("H method solution")

# One colorbar per panel: the two images are colour-scaled independently, so a
# single bar taken from the first image would not describe the second.
fig.colorbar(im_gtt, ax=axes[0])
fig.colorbar(im_output, ax=axes[1])
fig.tight_layout()

plt.show()

In [None]:
# Mean signed difference between the reference and the H-method solution.
(Z_gtt - Z_output).mean()

In [None]:
# `Z_jacoby` is not defined anywhere else in the notebook; build it from the
# plain-Jacobi result `jacoby_pure` before taking the mean signed difference.
Z_jacoby = jacoby_pure.detach().view(N, N).numpy()
np.mean(Z_gtt - Z_jacoby)

Test on L-shape domain

In [None]:
# Geometry of the L-shaped domain at the current grid size N.
B, B_idx = geometries.l_shaped_geometry(N)

# Constant forcing term equal to one everywhere.
f = torch.ones(1, 1, N, N)

# Both runs start from an all-ones guess:
#   gtt    - Jacobi with k = 10000, used as the reference solution
#   output - H method with the trained net, k = 2000
gtt = im.jacobi_method(B_idx, B, f, torch.ones(1, 1, N, N), k=10000)
output = im.H_method(net, B_idx, B, f, torch.ones(1, 1, N, N), k=2000)

In [None]:
# MSE of the H-method result against the reference solution.
print(F.mse_loss(gtt, output))

# 2-D numpy views of both solutions for plotting.
Z_gtt = gtt.view(N,N).numpy()
Z_output = output.detach().view(N, N).numpy()

fig, axes = plt.subplots(nrows = 1, ncols = 2)

fig.suptitle("Comparison")

im_gtt = axes[0].imshow(Z_gtt)
axes[0].set_title("Ground truth solution")

im_output = axes[1].imshow(Z_output)
axes[1].set_title("H method solution")

# One colorbar per panel: the two images are colour-scaled independently, so a
# single bar taken from the first image would not describe the second.
fig.colorbar(im_gtt, ax=axes[0])
fig.colorbar(im_output, ax=axes[1])
fig.tight_layout()

plt.show()