In [1]:
import numpy as np
import pyamg
import scipy.sparse.linalg as spla
from numpy import linalg as LA

import inspect
import time

import plotly.graph_objects as go
import html

from IPython.core.display import display, HTML

display(HTML("<style>.container { width:100% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
args, varargs, varkw, defaults = inspect.getargspec(spla.cg)
args

  args, varargs, varkw, defaults = inspect.getargspec(spla.cg)


['A', 'b', 'x0', 'tol', 'maxiter', 'M', 'callback', 'atol']

In [3]:
N = 2048

iterat = 0
A = pyamg.gallery.poisson((N - 2, N - 2), format="csr")

In [4]:
def iteration(_):
    global iterat
    iterat = iterat + 1
    pass


def timer(func):
    def _wrapper(*args, **kwargs):
        global iterat
        iterat = 0
        start = time.perf_counter()
        x, exit_code = func(*args, callback=iteration, **kwargs)
        runtime = time.perf_counter() - start
        print(
            f"RunTime: {runtime:.4f} sec \t Norm Ax-b : {LA.norm(A.dot(x)-b):.3e} \t Norm x_true-x_pred : {LA.norm(x-y)/LA.norm(y):.3e} \t #  of iter: {iterat} \t ExitCode: {exit_code} "
        )
        return x, exit_code

    return _wrapper


def timer_several(func):
    def _wrapper(*args, **kwargs):
        global iterat
        number_of_iterations = []
        runtimes = []
        for _ in range(10):
            iterat = 0
            start = time.perf_counter()
            x, exit_code = func(*args, callback=iteration, **kwargs)
            runtime = time.perf_counter() - start

            number_of_iterations.append(iterat)
            runtimes.append(runtime)

        print(
            f"AvgRunTime: {sum(runtimes)/len(runtimes):.4f} sec \t Norm Ax-b : {LA.norm(A.dot(x)-b):.3e} \t Norm x_true-x_pred : {LA.norm(x-y)/LA.norm(y):.3e} \t Avg#OfIter: {sum(number_of_iterations)/len(number_of_iterations)}"
        )

        return x, exit_code

    return _wrapper


@timer
def cg_solution(**kwargs):
    return spla.cg(**kwargs)


@timer_several
def cg_solution_several(**kwargs):
    return spla.cg(**kwargs)

In [5]:
def plot_3d(data, name):
    fig = go.Figure(
        go.Surface(
            z=data,
            x=np.linspace(0, 1, data.shape[0]),
            y=np.linspace(0, 1, data.shape[0]),
        )
    )

    title_font_size = 30
    font_size = 18

    fig.update_layout(
        margin=dict(l=0, r=0, t=20, b=0),
        font=dict(family="Times New Roman,italic", size=font_size),
        scene={
            "aspectratio": {"x": 1, "y": 1, "z": 0.6},
        },
    )

    fig.update_layout(
        scene=dict(
            xaxis_title="<i>x</i>",
            yaxis_title="<i>y</i>",
            zaxis_title="<i>z</i>",
            xaxis_title_font_size=title_font_size,
            yaxis_title_font_size=title_font_size,
            zaxis_title_font_size=title_font_size,
            zaxis=dict(
                gridcolor="white",
                showbackground=True,
                nticks=5,
                tickwidth=6,
                zerolinecolor="white",
                ticklen=10,
                ticks="outside",
            ),
            yaxis=dict(
                tickvals=np.linspace(0.1, 0.9, 5),
                ticklen=5,
                tickwidth=6,
                ticks="outside",
                backgroundcolor="rgb(230, 200,230)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
            ),
            xaxis=dict(
                tickvals=np.linspace(0.1, 0.9, 5),
                ticklen=5,
                tickwidth=6,
                ticks="outside",
                backgroundcolor="rgb(200, 200, 230)",
                gridcolor="white",
                showbackground=True,
                zerolinecolor="white",
            ),
        )
    )

    fig.update_traces(
        cauto=False,
        selector=dict(type="surface"),
        colorbar=dict(
            lenmode="fraction",
            len=0.4,
            thickness=15,
            tickfont=dict(size=font_size),
        ),
    )

    fig.update_traces(colorbar_outlinecolor="black", selector=dict(type="surface"))
    fig.update_traces(colorbar_outlinewidth=1, selector=dict(type="surface"))
    fig.update_traces(colorbar_x=0.83, selector=dict(type="surface"))
    fig.update_scenes(zaxis_tickangle=0, xaxis_tickangle=0, yaxis_tickangle=0)

    # fig.show()
    fig.write_image(f"{name}.png", scale=2)

# Mean time and num. of iters (mean for 10 runs)

In [7]:
shape = 2048

for item in range(10):
    pred = np.load(f"{shape}/pred_{item}_.npz")["arr_0"][
        0, 1 : N - 1, 1 : N - 1
    ].reshape(
        -1
    )  # you shoud save solutions from NN first
    b = np.load(f"{shape}/x_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(
        -1
    )  # load from dataset
    y = np.load(f"{shape}/y_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(
        -1
    )  # load from dataset

    print(f"imem # {item}: \n")
    x_0, exit_code = cg_solution(A=A, b=b, atol=1e-4, tol=0)
    x_pred, exit_code = cg_solution(A=A, b=b, atol=1e-4, tol=0, x0=pred)
    print(f"\n\n")

np.linalg.norm(pred-y)/np.linalg.norm(y))
imem # 0: 

RunTime: 884.4101 sec 	 Norm Ax-b : 9.992e-05 	 Norm x_true-x_pred : 1.795e+04 	 #  of iter: 6251 	 ExitCode: 0 
RunTime: 843.7687 sec 	 Norm Ax-b : 9.981e-05 	 Norm x_true-x_pred : 1.795e+04 	 #  of iter: 6253 	 ExitCode: 0 



imem # 1: 



In [111]:
for item in range(11):
    pred = np.load(f"1024/pred_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(
        -1
    )
    b = np.load(f"1024/x_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(-1)
    y = np.load(f"1024/y_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(-1)

    print(f"imem # {item}: \n")
    x_0, exit_code = cg_solution_several(A=A, b=b, atol=1e-4, tol=0)
    x_pred, exit_code = cg_solution_several(A=A, b=b, atol=1e-4, tol=0, x0=pred)
    print(f"\n\n")

imem # 0: 

AvgRunTime: 67.9554 sec 	 Norm Ax-b : 9.860e-05 	 Norm x_true-x_pred : 8.387e-05 	 Avg#OfIter: 2203.0
AvgRunTime: 52.9094 sec 	 Norm Ax-b : 9.955e-05 	 Norm x_true-x_pred : 7.216e-05 	 Avg#OfIter: 1816.0



imem # 1: 

AvgRunTime: 64.4702 sec 	 Norm Ax-b : 9.874e-05 	 Norm x_true-x_pred : 9.252e-05 	 Avg#OfIter: 2137.0
AvgRunTime: 63.1811 sec 	 Norm Ax-b : 9.968e-05 	 Norm x_true-x_pred : 6.660e-05 	 Avg#OfIter: 1929.0



imem # 2: 

AvgRunTime: 66.4469 sec 	 Norm Ax-b : 9.964e-05 	 Norm x_true-x_pred : 5.759e-05 	 Avg#OfIter: 2213.0
AvgRunTime: 57.8455 sec 	 Norm Ax-b : 9.887e-05 	 Norm x_true-x_pred : 4.105e-05 	 Avg#OfIter: 1796.0



imem # 3: 

AvgRunTime: 69.6126 sec 	 Norm Ax-b : 9.856e-05 	 Norm x_true-x_pred : 5.762e-05 	 Avg#OfIter: 2234.0
AvgRunTime: 51.1413 sec 	 Norm Ax-b : 9.885e-05 	 Norm x_true-x_pred : 8.444e-05 	 Avg#OfIter: 1751.0



imem # 4: 

AvgRunTime: 64.7653 sec 	 Norm Ax-b : 9.841e-05 	 Norm x_true-x_pred : 5.185e-05 	 Avg#OfIter: 2134.0
AvgRunTime

#  Plots: X0=pred + 10 iterations

In [46]:
for item in range(11):
    pred = np.load(f"1024/pred_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(
        -1
    )
    b = np.load(f"1024/x_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(-1)
    y = np.load(f"1024/y_{item}_.npz")["arr_0"][0, 1 : N - 1, 1 : N - 1].reshape(-1)

    print(f"imem # {item}: \n")
    x_0, exit_code = cg_solution(A=A, b=b, maxiter=10)
    x_pred, exit_code = cg_solution(A=A, b=b, maxiter=10, x0=pred)
    print(f"\n\n")

    x_new = x_pred.reshape((N - 2, N - 2))
    pred = pred.reshape((N - 2, N - 2))
    y_true = y.reshape((N - 2, N - 2))

    plot_3d(data=x_new, name=f"pic/{item}_after_10_iterations")
    plot_3d(data=pred, name=f"pic/{item}_predicted")
    plot_3d(data=y_true, name=f"pic/{item}_true")

imem # 0: 

RunTime: 0.3648 sec 	 Norm Ax-b : 8.863e+00 	 Norm x_true-x_pred : 9.772e-01 	 #  of iter: 10 	 ExitCode: 10 
RunTime: 0.3708 sec 	 Norm Ax-b : 1.633e+00 	 Norm x_true-x_pred : 5.600e-02 	 #  of iter: 10 	 ExitCode: 10 



imem # 1: 

RunTime: 0.3775 sec 	 Norm Ax-b : 9.358e+00 	 Norm x_true-x_pred : 9.749e-01 	 #  of iter: 10 	 ExitCode: 10 
RunTime: 0.3588 sec 	 Norm Ax-b : 2.242e+00 	 Norm x_true-x_pred : 1.044e-01 	 #  of iter: 10 	 ExitCode: 10 



imem # 2: 

RunTime: 0.4550 sec 	 Norm Ax-b : 9.476e+00 	 Norm x_true-x_pred : 9.795e-01 	 #  of iter: 10 	 ExitCode: 10 
RunTime: 0.4651 sec 	 Norm Ax-b : 1.984e+00 	 Norm x_true-x_pred : 6.660e-02 	 #  of iter: 10 	 ExitCode: 10 



imem # 3: 

RunTime: 0.4006 sec 	 Norm Ax-b : 9.494e+00 	 Norm x_true-x_pred : 9.771e-01 	 #  of iter: 10 	 ExitCode: 10 
RunTime: 0.2647 sec 	 Norm Ax-b : 1.906e+00 	 Norm x_true-x_pred : 6.948e-02 	 #  of iter: 10 	 ExitCode: 10 



imem # 4: 

RunTime: 0.3370 sec 	 Norm Ax-b : 9.831e+00 	 No