In [2]:
import numpy as np
import torch
import scipy.linalg as sla
import matplotlib.pyplot as plt
from models import *
from pl_modules.LAModel import LAModel
from utils.traditions import *
from utils.deltas import normal
from scipy import sparse
from pathlib import Path
from matplotlib import cm



In [25]:
def test_cases(cases, ckpt_path, imgs_save_path, a=1):
    for folder in ckpt_path.iterdir():

        # Get all parameters from the ckpt name
        backward_method, n, model_name, features, bs, data_type = folder.name.split('_')
        n = int(n)
        layers = int(model_name[-1])
        model_name = model_name[:-1]
        features = int(features)
        bs = int(bs[2:])
        boundary_type = data_type[-1]
        data_type = data_type[:-1]

        h = 2*a / (n -1)
        h2 = h ** 2
        x = np.linspace(-a, a, n)
        y = np.linspace(-a, a, n)
        xx, yy = np.meshgrid(x, y)
        
        # get network
        mat_path = f'./data/{n}/mat/'
        net = model_names[model_name](layers = layers, features = features, boundary_type = boundary_type)
        pl_module = LAModel(net, a, n, data_path = mat_path, backward_type=backward_method, boundary_type=boundary_type)

        # get matrix
        # A = pl_module.A
        lu, piv = sla.lu_factor(pl_module.A.to_dense().numpy())
        
        #get image save path and ckpy file path
        img_path = imgs_save_path/folder.name
        ckpt = folder/'version_0'/'checkpoints'/'last.ckpt'

        # Load ckpt
        ckpt = torch.load(ckpt)
        pl_module.load_state_dict(ckpt['state_dict'])
        pl_module.freeze()

        # Test all cases
        for i, case in enumerate(cases):
            case_img_save_path = img_path/f'case{i}'
            if not case_img_save_path.is_dir():
                case_img_save_path.mkdir(parents=True, exist_ok = True)
            f = np.zeros((n, n))
            for info in case:
                px, py, q = info
                f += q * normal(xx, yy, h, (px, py))
            
            # Get input tensor for networks
            input_tensor = np.stack([xx, yy, f], axis=0)
            input_tensor = torch.from_numpy(input_tensor).float()
        
            # Get b for linear equations, If the value of boundary changed should fix here
            b = f.reshape(n**2) * h2
            if boundary_type == 'D':
                b = apply_diri_bc(b, {'top':0, 'bottom':0, 'left':0, 'right':0})
            elif boundary_type == 'N':
                b = apply_diri_bc(b)
                b = apply_neumann_bc(b, h, f)

            # get predicted value and real ans
            ans = sla.lu_solve((lu, piv), b).reshape(n, n)

            pre = pl_module(input_tensor[None, ...])
            pre = pl_module.padder(pre).numpy().reshape(n, n)



            # Draw
            fig = plt.figure()
            fig.suptitle(f'{folder.name}_case{i}', fontsize=20)
            fig.set_figheight(10)
            fig.set_figwidth(30)

            ax1 = fig.add_subplot(1, 3, 1, projection='3d')
            ax2 = fig.add_subplot(1, 3, 2, projection='3d')
            ax3 = fig.add_subplot(1, 3, 3)


            ax1.set_title(f'$Pre U$', fontsize=20)
            surf_pre = ax1.plot_surface(xx, yy, pre, cmap=cm.coolwarm,)
            plt.colorbar(surf_pre, shrink=0.8, ax=ax1)

            ax2.set_title(f'$Ans U$', fontsize=20)
            surf_ans = ax2.plot_surface(xx, yy, ans, cmap=cm.coolwarm,)
            plt.colorbar(surf_ans, shrink=0.8, ax=ax2)

            ax3.set_title(f'Differnce', fontsize=20)
            im = ax3.imshow(np.abs(ans - pre))
            plt.colorbar(im, shrink=0.8, ax=ax3)
            fig.tight_layout()
            plt.close(fig)


            # Save images
            # Pre image
            fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
            fig.set_figheight(10)
            fig.set_figwidth(10)
            surf_pre = ax.plot_surface(xx, yy, pre, cmap=cm.coolwarm,)
            plt.colorbar(surf_pre, shrink=0.8, ax=ax)
            fig.savefig(f"{case_img_save_path/'pre.png'}", bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
            fig.set_figheight(10)
            fig.set_figwidth(10)
            surf_pre = ax.plot_surface(xx, yy, ans, cmap=cm.coolwarm,)
            plt.colorbar(surf_pre, shrink=0.8, ax=ax)
            fig.savefig(f"{case_img_save_path/'ans.png'}", bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots()
            fig.set_figheight(10)
            fig.set_figwidth(10)
            im = ax.imshow(np.abs(ans - pre))
            plt.colorbar(im, shrink=0.8, ax=ax)
            fig.savefig(f"{case_img_save_path/'diff.png'}", bbox_inches='tight')
            plt.close(fig)




# Checkpoints are read from ckpt_path; plots are written under imgs_save_path.
ckpt_path, imgs_save_path = Path('./lightning_logs/'), Path('./images/')

# Each case lists point sources as (x, y, strength) on the [-1, 1]^2 domain.
cases = [
    [(0, 0, 1)],                        # one unit source at the origin
    [(-0.5, -0.5, 1), (0.5, 0.5, 1)],   # two unit sources on the main diagonal
]

test_cases(cases, ckpt_path, imgs_save_path, a=1)