In [1]:
from controller import LearnableTree
from tree import BinaryTree
import torch
import torch.nn as nn
import function as func
from pde import possion_eq, true_solution, generate_boundary_points, cal_l2_relative_error

In [2]:
unary = func.unary_functions
binary = func.binary_functions

In [3]:
    # build a computation tree
    #     X
    #    / \
    #   X   X
    #  / \  |
    # X  X  X  
tree = BinaryTree(node_operator=None, is_unary=False)
tree.leftChild = BinaryTree(node_operator=None, is_unary=False)
tree.rightChild = BinaryTree(node_operator=None, is_unary=True)
tree.leftChild.leftChild = BinaryTree(node_operator=None, is_unary=True)
tree.leftChild.rightChild = BinaryTree(node_operator=None, is_unary=True)
tree.rightChild.leftChild = BinaryTree(node_operator=None, is_unary=True)
operator_idxs = torch.tensor([5, 1, 8, 0, 8, 3])

In [4]:
left = -1
right = 1
# 获取boundary points和pde points
boundary_points = torch.rand(1000, 2) * (right - left) + left
true_boundary = true_solution(boundary_points)
pde_points = (torch.rand(1000, 2) * (right - left) + left).requires_grad_(True)

In [5]:
learnable_tree = LearnableTree(tree, dim=2)
params = learnable_tree.get_parameters()

In [6]:
def train_learnable_tree(tree, operator_idxs, boundary_points, true_boundary, pde_points, epochs=1000, lr=0.01):
    # 初始化LearnableTree
    learnable_tree = LearnableTree(tree, dim=2)
    # 定义优化器
    optimizer = torch.optim.Adam(learnable_tree.get_parameters(), lr=lr)
    # print('learnable tree parameters: ', learnable_tree.get_parameters())
    
    # 开始训练
    for epoch in range(epochs):
        # 前向传播
        pred_boundary = learnable_tree(boundary_points, operator_idxs)
        pred_pde = learnable_tree(pde_points, operator_idxs)
        # 计算pde loss
        pde_loss = (possion_eq(pred_pde, pde_points, 2)**2).mean()
        # 计算boundary loss
        boundary_loss = ((pred_boundary - true_boundary)**2).mean()
        # 计算总loss
        loss = pde_loss + 100 * boundary_loss
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 打印loss
        if epoch % 100 == 0:
            print('epoch: {}, loss: {}'.format(epoch, loss.item()))
            params = learnable_tree.get_parameters()
            grads = [p.grad for p in params]
            


In [7]:
train_learnable_tree(tree, operator_idxs, boundary_points, true_boundary, pde_points, epochs=1000, lr=0.01)

epoch: 0, loss: 543.5640869140625
epoch: 100, loss: 8.505796432495117
epoch: 200, loss: 4.204105377197266
epoch: 300, loss: 3.4457178115844727
epoch: 400, loss: 3.142218589782715
epoch: 500, loss: 2.8415656089782715
epoch: 600, loss: 2.545720338821411
epoch: 700, loss: 2.2618842124938965
epoch: 800, loss: 1.9943264722824097
epoch: 900, loss: 1.7451786994934082
