In [1]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

# Project-local modules live one directory up.
sys.path.append('../')

from MyPlot import *
from utils import force, fd_solve_nlinear
from ConstCofFVM import UniformFVM, BlockCofProblem
from BaseTester import BaseTester


In [3]:

class C2TestDs(Dataset):
	"""In-memory test dataset pairing forcing fields with mu fields.

	Both input arrays are converted once to tensors with the requested dtype on
	the requested device and indexed along their first axis. Item ``index`` is
	the triple ``(data, f, mu)``, where ``data`` stacks ``f`` and ``mu`` on a new
	leading channel axis and then prepends a batch axis of size 1.
	"""

	def __init__(self, F, M, dtype, device):
		# F: forcing samples, M: mu samples (NumPy arrays, sample axis first).
		def _to_tensor(array):
			return torch.from_numpy(array).to(dtype).to(device)

		self.force, self.mus = _to_tensor(F), _to_tensor(M)

	def __len__(self):
		# Number of samples = size of the leading axis of the mu tensor.
		return self.mus.size(0)

	def __getitem__(self, index):
		force_i, mu_i = self.force[index], self.mus[index]
		# Channel stack (f, mu), then a singleton batch axis in front.
		stacked = torch.stack((force_i, mu_i)).unsqueeze(0)
		return stacked, force_i, mu_i
	
class nonlinearTester(BaseTester):
	"""Tester for the nonlinear problem.

	Runs the trained network on a set of test cases, compares each prediction
	with the reference from ``fd_solve_nlinear``, saves comparison figures, and
	writes per-case L2 errors to ``<img_save_path>/<exp_name>/l2.csv``.
	Mesh, kwargs and checkpoint handling (``mesh``, ``load_kwargs``,
	``load_ckpt``, ``self.net``, ``self.xx``/``self.yy``, ...) come from
	``BaseTester``.
	"""

	def __init__(self, **kwargs):
		super().__init__(**kwargs)
		self.mesh()

	def init_test_ds(self):
		"""Build ``self.ds`` from ``self.centers``, ``self.mus`` and ``self.DataN``.

		Sample ``i`` is the forcing field ``force(xx, yy, centers[i])`` paired
		with a constant ``GridSize x GridSize`` field filled with ``mus[i]``.
		"""
		forces = np.stack([
			force(self.xx, self.yy, self.centers[i]) for i in range(self.DataN)
		])
		mu_fields = np.stack([
			np.ones((self.GridSize, self.GridSize)) * self.mus[i] for i in range(self.DataN)
		])
		self.ds = C2TestDs(forces, mu_fields, self.dtype, self.device)

	def hard_encode(self, x, gd):
		"""Pad the last two dimensions of ``x`` by one cell on every side with ``gd``.

		Presumably this re-attaches a constant boundary value to an
		interior-only prediction (TODO: confirm against the training code).
		"""
		return F.pad(x, (1, 1, 1, 1), 'constant', value=gd)

	def _l2_error(self, pre, ans):
		"""Relative L2 error ``||pre - ans||_2 / ||ans||_2`` over the grid.

		``pre`` is a tensor (any device); ``ans`` must be array-like and
		reshapeable to ``(GridSize, GridSize)``. Falls back to the absolute
		norm when the reference is identically zero.
		"""
		grid = (self.GridSize, self.GridSize)
		pre_np = pre.detach().cpu().numpy().reshape(grid).astype(np.float64)
		ans_np = np.asarray(ans, dtype=np.float64).reshape(grid)
		err = np.linalg.norm(pre_np - ans_np)
		ref = np.linalg.norm(ans_np)
		return float(err / ref) if ref > 0 else float(err)

	def test(self, centers, mus, exp_name, DataN, best_or_last):
		"""Evaluate the network on every test case and record its L2 error.

		Args:
			centers: per-case forcing centers as ``(x, y)`` pairs, or ``None``
				to sample ``DataN`` random centers.
			mus: per-case mu values, or ``None`` to sample ``DataN`` random values.
			exp_name: experiment name used by ``load_kwargs``/``load_ckpt`` and
				as the output sub-directory under ``img_save_path``.
			DataN: number of random cases. It is ignored when both ``centers``
				and ``mus`` are given; ``len(mus)`` is used instead.
			best_or_last: checkpoint selector forwarded to ``load_ckpt``.

		Raises:
			ValueError: if ``centers`` and ``mus`` differ in length, or there
				are no test cases.
		"""
		if centers is None or mus is None:
			# Draw 2-D (x, y) centers, matching the tuples callers pass
			# explicitly; a 1-D draw would yield scalar centers instead.
			self.centers = np.random.uniform(0.05, 0.95, (DataN, 2))
			self.mus = np.random.uniform(0.1, 1, DataN)
			self.DataN = DataN
		else:
			if len(centers) != len(mus):
				raise ValueError(
					f"centers and mus must have the same length, got {len(centers)} and {len(mus)}"
				)
			self.centers = centers
			self.mus = mus
			self.DataN = len(mus)
		if self.DataN < 1:
			raise ValueError("at least one test case is required")

		self.load_kwargs(exp_name)
		self.init_test_ds()
		self.load_ckpt(best_or_last, exp_name)

		out_dir = f"{self.img_save_path}/{exp_name}"
		records = []
		with torch.no_grad():
			# The dataset's mu tensor is unused: the solver needs the scalar mu,
			# which is read from self.mus below.
			for i, (data, f, _) in enumerate(self.ds):
				pre = self.hard_encode(self.net(data), gd=0)

				mu = self.mus[i]
				c = self.centers[i]
				ans = fd_solve_nlinear(self.GridSize, self.area, mu, c, Picard_maxiter=2000)

				records.append({'id': i, 'l2': self._l2_error(pre, ans)})
				self.save_img(f"{out_dir}/TestCase-{i}", pre, ans, f, mu)

		# Ensure the directory exists before writing, independent of save_img.
		Path(out_dir).mkdir(parents=True, exist_ok=True)
		pd.DataFrame(records, columns=['id', 'l2']).to_csv(f"{out_dir}/l2.csv", index=False)

	def save_img(self, path, pre, ans, force, mu):
		"""Save surface, filled-contour, contour and forcing plots for one case.

		Args:
			path: output directory; created if missing.
			pre: predicted field (tensor, any device).
			ans: reference field (array reshapeable to ``(GridSize, GridSize)``).
			force: forcing field (tensor, any device).
			mu: this case's mu value. It is accepted for interface
				compatibility but not used here.
		"""
		Path(path).mkdir(parents=True, exist_ok=True)

		grid = (self.GridSize, self.GridSize)
		pre = pre.cpu().numpy().reshape(grid)
		force = force.cpu().numpy().reshape(grid)
		ans = ans.reshape(grid)

		save_surf(path, pre, self.xx, self.yy, 'surf_pre')
		save_surf(path, ans, self.xx, self.yy, 'surf_ref')
		save_ctf(path, pre, ans, self.xx, self.yy)
		save_contour(path, pre, ans, self.xx, self.yy, levels=None)
		save_img_force(path, force, 'force')

In [4]:
# Evaluation config: three explicit cases sharing one forcing center and
# sweeping mu. DataN only matters when centers/mus are None (random cases).
GridSize = 256
centers = [(0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]
mus = [0.1, 0.5, 1.0]
DataN = 10

# Seed so the random-case branch (centers/mus = None) is reproducible.
np.random.seed(0)

nonlinear_tester = nonlinearTester(
    GridSize=GridSize,
    area=((0, 0), (1, 1)),
    ckpt_save_path='model_save',
    hyper_parameters_save_path='hyper_parameters',
    img_save_path='./images',
    # Fall back to CPU so the notebook still runs on machines without CUDA.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dtype=torch.float,
)

exp_name = 'JuMu-ResBottleNeck-2#2#4#6#6-2#2#4#6#6-2-layer-max-replicate-GridSize:256-maxiter:5-trainN:10000-bs:5'
best_or_last = 'best_train'
nonlinear_tester.test(centers, mus, exp_name, DataN, best_or_last)


Itr: 0	 Delta: 1.217e+02	 Error: 1.577e-01	
Itr: 1	 Delta: 9.903e+01	 Error: 5.867e-01	
Itr: 2	 Delta: 5.041e+01	 Error: 1.004e+00	
Itr: 3	 Delta: 4.245e+01	 Error: 1.033e+00	
Itr: 4	 Delta: 2.027e+01	 Error: 2.456e+00	
Itr: 5	 Delta: 1.943e+01	 Error: 1.660e+00	
Itr: 6	 Delta: 9.815e+00	 Error: 3.715e+00	
Itr: 7	 Delta: 6.382e+00	 Error: 1.606e+00	
Itr: 8	 Delta: 1.024e+00	 Error: 4.211e-01	
Itr: 9	 Delta: 3.518e-01	 Error: 5.762e-03	
Itr: 10	 Delta: 1.076e-01	 Error: 1.745e-03	
Itr: 11	 Delta: 3.082e-02	 Error: 5.829e-04	
Itr: 12	 Delta: 8.084e-03	 Error: 1.701e-04	
Itr: 13	 Delta: 1.981e-03	 Error: 4.616e-05	
Itr: 14	 Delta: 4.545e-04	 Error: 1.158e-05	
Itr: 15	 Delta: 9.815e-05	 Error: 2.713e-06	
Itr: 16	 Delta: 2.003e-05	 Error: 5.966e-07	
Itr: 17	 Delta: 3.876e-06	 Error: 1.237e-07	
Itr: 18	 Delta: 7.135e-07	 Error: 2.427e-08	
Itr: 0	 Delta: 1.503e+01	 Error: 1.837e-01	
Itr: 1	 Delta: 9.991e+00	 Error: 3.862e-02	
Itr: 2	 Delta: 4.918e+00	 Error: 3.341e-02	
Itr: 3	 Delta: 2.299e+0