In [1]:
from controller import LearnableTree
from tree import BinaryTree
import torch
import torch.nn as nn
import function as func
from possion.pde import possion_eq, cal_l2_relative_error
from possion.pde import possion_true_solution as true_solution
from possion.pde import possion_sample_bc_x as sample_bc_x


In [2]:
unary = func.unary_functions
binary = func.binary_functions

In [3]:
    # build a computation tree
    #     X
    #    / \
    #   X   X
    #  / \  |
    # X  X  X  
tree = BinaryTree(node_operator=None, is_unary=False)
tree.leftChild = BinaryTree(node_operator=None, is_unary=False)
tree.rightChild = BinaryTree(node_operator=None, is_unary=True)
tree.leftChild.leftChild = BinaryTree(node_operator=None, is_unary=True)
tree.leftChild.rightChild = BinaryTree(node_operator=None, is_unary=True)
tree.rightChild.leftChild = BinaryTree(node_operator=None, is_unary=True)
operator_idxs = torch.tensor([5, 1, 8, 0, 8, 3])

In [4]:
left = -1
right = 1
# 获取boundary points和pde points
boundary_points = torch.rand(1000, 2) * (right - left) + left
true_boundary = true_solution(boundary_points)
pde_points = (torch.rand(1000, 2) * (right - left) + left).requires_grad_(True)

In [5]:
learnable_tree = LearnableTree(tree, dim=2)
params = learnable_tree.get_parameters()

In [6]:
def train_learnable_tree(tree, operator_idxs, boundary_points, true_boundary, pde_points, epochs=1000, lr=0.01):
    # 初始化LearnableTree
    learnable_tree = LearnableTree(tree, dim=2)
    # 定义优化器
    optimizer = torch.optim.Adam(learnable_tree.get_parameters(), lr=lr)
    # print('learnable tree parameters: ', learnable_tree.get_parameters())
    
    # 开始训练
    for epoch in range(epochs):
        # 前向传播
        pred_boundary = learnable_tree(boundary_points, operator_idxs)
        pred_pde = learnable_tree(pde_points, operator_idxs)
        # 计算pde loss
        pde_loss = (possion_eq(pred_pde, pde_points, 2)**2).mean()
        # 计算boundary loss
        boundary_loss = ((pred_boundary - true_boundary)**2).mean()
        # 计算总loss
        loss = pde_loss + 100 * boundary_loss
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 打印loss
        if epoch % 100 == 0:
            x_test = torch.rand(20000, 2) * (right - left) + left
            true = true_solution(x_test)
            pred = learnable_tree(x_test, operator_idxs)
            l2_err = cal_l2_relative_error(pred, true)
            print('epoch: {}, pde loss: {}, boundary loss: {}, l2 error: {}'.format(epoch, pde_loss, boundary_loss, l2_err))

            


In [7]:
train_learnable_tree(tree, operator_idxs, boundary_points, true_boundary, pde_points, epochs=1000, lr=0.01)

epoch: 0, pde loss: 10.427647590637207, boundary loss: 0.1359492391347885, l2 error: 0.8953984379768372
epoch: 100, pde loss: 0.559153139591217, boundary loss: 0.0070700980722904205, l2 error: 0.2090953290462494
epoch: 200, pde loss: 0.10199389606714249, boundary loss: 0.00040378657286055386, l2 error: 0.04920344054698944
epoch: 300, pde loss: 0.0311147291213274, boundary loss: 5.459683598019183e-05, l2 error: 0.018310105428099632
epoch: 400, pde loss: 0.011383498087525368, boundary loss: 5.135366791364504e-06, l2 error: 0.00612399447709322
epoch: 500, pde loss: 0.0065131415612995625, boundary loss: 6.482395292550791e-06, l2 error: 0.006411783397197723
epoch: 600, pde loss: 0.004578797146677971, boundary loss: 1.0634038517309818e-05, l2 error: 0.008128189481794834
epoch: 700, pde loss: 0.0035551704932004213, boundary loss: 1.377131775370799e-05, l2 error: 0.009251370094716549
epoch: 800, pde loss: 0.00298957247287035, boundary loss: 1.6115474863909185e-05, l2 error: 0.00997280515730381