In [None]:
import sys
sys.path.append('../')

import copy

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch

from user_funn.field import D2Field
from user_funn.get_net import ForwardNetwork
from user_funn.ds import get_data_loader
from user_funn.solver import CloudPointSolver
from user_funn.geom import line_linspace
from user_funn.geom import add_t
import user_funn.plot

np.random.seed(1)
torch.manual_seed(2)

# STEP 1: generate training points for the lid-driven cavity on [0, 2] x [0, 2].
batch_num = 1
nx = 21  # grid points per spatial axis

field = D2Field([0, 2], [0, 2])
pde_batch_size = nx * nx
pde_xy = field.get_field_mesh([nx, nx])  # interior collocation grid, spatial only

points_num_per_line = nx * 2  # number of points used on each boundary edge
bc_left_input = line_linspace([0, 0], [0, 2], points_num_per_line)
bc_up_input = line_linspace([0, 2], [2, 2], points_num_per_line)
bc_right_input = line_linspace([2, 2], [2, 0], points_num_per_line)
bc_down_input = line_linspace([2, 0], [0, 0], points_num_per_line)

# Left, bottom and right edges are no-slip walls (u = v = 0); the top edge is the lid.
bc_uv_zero_xy = np.vstack([bc_left_input, bc_down_input, bc_right_input])

user_funn.plot.scatter_2d_cloud_point_kind([pde_xy, bc_uv_zero_xy, bc_up_input])


def prepend_time(points_xy, times):
    """Pair every spatial point with every time value.

    Args:
        points_xy: array-like of shape (N, 2) with columns (x, y).
        times: array-like holding M time values.

    Returns:
        ndarray of shape (N * M, 3) with columns (t, x, y). Time is column 0,
        which is the order ``pde_loss`` reads. Rows are grouped by time value.

    Raises:
        ValueError: if ``points_xy`` is not a 2-column matrix.
    """
    points_xy = np.asarray(points_xy, dtype=float)
    if points_xy.ndim != 2 or points_xy.shape[1] != 2:
        raise ValueError(f"expected (N, 2) points, got shape {points_xy.shape}")
    times = np.asarray(times, dtype=float).reshape(-1)
    t_col = np.repeat(times, len(points_xy)).reshape(-1, 1)
    xy = np.tile(points_xy, (len(times), 1))
    return np.hstack([t_col, xy])


# Build the space-time inputs explicitly so the time column is guaranteed to be
# column 0. This matches pde_loss and the lid profile below, which reads x from
# column 1.
time_linspace = np.arange(0, 1, 0.1)  # t = 0.0, 0.1, ..., 0.9
tc_input = prepend_time(pde_xy, [0.0])
pde_input = prepend_time(pde_xy, time_linspace)
bc_uv_zero_input = prepend_time(bc_uv_zero_xy, time_linspace)
bc_uv_up_input = prepend_time(bc_up_input, time_linspace)

# Initial condition: fluid at rest, so the (p, u, v) targets are all zero.
tc_batchsize = tc_input.shape[0]
tc_output = np.zeros([tc_batchsize, 3])

# PDE targets: the three residuals returned by pde_loss are driven to zero.
pde_batchsize = pde_input.shape[0]
pde_output = np.zeros([pde_batchsize, 3])

# Boundary targets are (u, v).
bc_uv_zero_batchsize = bc_uv_zero_input.shape[0]
bc_uv_up_batchsize = bc_uv_up_input.shape[0]
bc_uv_zero_output = np.zeros([bc_uv_zero_batchsize, 2])

# Lid velocity u = sin(pi * x / 2). It vanishes (to rounding) at x = 0 and x = 2,
# so it agrees with the no-slip walls at the two top corners.
bc_uv_up_output = np.zeros([bc_uv_up_batchsize, 2])
bc_uv_up_output[:, 0] = np.sin(np.pi * 0.5 * bc_uv_up_input[:, 1])  # column 1 is x


In [None]:
from user_funn.pde import diff

MU = 0.1  # viscosity coefficient multiplying the Laplacian (acts as 1/Re if nondimensional)
loss_fn = torch.nn.MSELoss()


def pde_loss(model, data):
    """Mean-squared residual of the 2-D incompressible Navier-Stokes equations.

    Args:
        model: callable mapping an (N, 3) tensor of (t, x, y) to a tensor whose
            first three columns are (p, u, v).
        data: pair ``(x_in, y_real)``.
            ``x_in`` holds collocation points with columns (t, x, y).
            NOTE(review): that column order must match how the inputs are built
            elsewhere in the notebook; confirm.
            ``y_real`` holds one target column per residual. All zeros means
            "satisfy the PDE".

    Returns:
        Scalar tensor: the sum of the MSEs of the x-momentum, y-momentum and
        continuity residuals against ``y_real[:, 0]``, ``[:, 1]`` and ``[:, 2]``.
    """
    x_in, y_real = data
    # Derivatives w.r.t. the input coordinates need gradient tracking on x_in.
    x_in.requires_grad = True

    # Slice first, then re-concatenate. Each coordinate is then its own graph
    # node that `diff` can differentiate against.
    t = x_in[:, [0]]
    x = x_in[:, [1]]
    y = x_in[:, [2]]
    U = model(torch.cat((t, x, y), dim=1))
    p = U[:, [0]]
    u = U[:, [1]]
    v = U[:, [2]]

    # First derivatives. `diff(a, b)` is presumed to return da/db with the
    # graph kept for higher orders (defined in user_funn.pde).
    dudt = diff(u, t)
    dvdt = diff(v, t)
    dudx = diff(u, x)
    dudy = diff(u, y)
    dvdx = diff(v, x)
    dvdy = diff(v, y)
    dpdx = diff(p, x)
    dpdy = diff(p, y)

    # Second derivatives for the viscous (Laplacian) terms.
    du2dx2 = diff(dudx, x)
    du2dy2 = diff(dudy, y)
    dv2dx2 = diff(dvdx, x)
    dv2dy2 = diff(dvdy, y)

    # Momentum: u_t + u u_x + v u_y + p_x - MU (u_xx + u_yy) = 0, and the same for v.
    # The time derivative must enter with a + sign; "- dudt" would describe the
    # flow evolving backwards in time.
    eq1 = dudt + u * dudx + v * dudy + dpdx - MU * (du2dx2 + du2dy2)
    eq2 = dvdt + u * dvdx + v * dvdy + dpdy - MU * (dv2dx2 + dv2dy2)
    # Continuity (incompressibility): u_x + v_y = 0.
    eq3 = dudx + dvdy
    return (loss_fn(eq1, y_real[:, [0]])
            + loss_fn(eq2, y_real[:, [1]])
            + loss_fn(eq3, y_real[:, [2]]))


from user_funn.bc import data_loss_factory
# Data-fit losses built by the project helper. The index lists select model
# output columns. With the (p, u, v) output order used by pde_loss, [1, 2]
# selects (u, v) and [0, 1, 2] selects (p, u, v).
# NOTE(review): assumes data_loss_factory compares those output columns with
# the target columns in order; confirm.
bc_uv_zero_loss = data_loss_factory(loss_fn, [1, 2])
bc_uv_up_loss = data_loss_factory(loss_fn, [1, 2])
tc_loss = data_loss_factory(loss_fn, [0, 1, 2])

# One network per time window. pde_loss feeds 3 inputs (t, x, y) and reads
# 3 outputs (p, u, v), so the layer spec is [3, ..., 3]. This assumes the list
# is [in, hidden..., out].
num_model = 2
solver_list = [
    CloudPointSolver(model=[3, 50, 50, 50, 3], optimizer="adam")
    for _ in range(num_model)
]

# ---------------------------------------------------------------------------
# Time-window (temporal domain decomposition) configuration.
# - Model k covers global times [k*T, (k+1)*T + overlap_length], where
#   T = t_length_per_iter.
# - On the overlap [k*T, k*T + overlap_length], model k takes its initial
#   condition from model k-1. Model 0 uses the rest state at t = 0.
# - Each model is trained in its local time t - k*T.
# ---------------------------------------------------------------------------
T_iter = len(solver_list)  # one time window per solver
t_length_per_iter = 1      # window length T
overlap_length = 0.1

pde_epoch_size = 128
bc_epoch_size = 64
tc_epoch_size = 256

# Spatial extent of the cavity; must match `field = D2Field([0, 2], [0, 2])`.
X_RANGE = (0.0, 2.0)
Y_RANGE = (0.0, 2.0)

# Spatial boundary points (columns x, y) from the data-generation cell.
bc_wall_xy = np.vstack([bc_left_input, bc_down_input, bc_right_input])
bc_lid_xy = np.asarray(bc_up_input, dtype=float)


def _random_interior_xy(n):
    """Return n uniform random points inside the cavity, columns (x, y)."""
    return np.column_stack([np.random.uniform(*X_RANGE, n),
                            np.random.uniform(*Y_RANGE, n)])


def _random_rows(points, n):
    """Return n rows drawn with replacement from `points`."""
    return points[np.random.randint(0, len(points), size=n)]


def _with_random_time(points_xy, t_low, t_high):
    """Prepend a uniform random time in [t_low, t_high] to each point, giving rows (t, x, y)."""
    t_col = np.random.uniform(t_low, t_high, size=(len(points_xy), 1))
    return np.hstack([t_col, points_xy])


t_span_start_list, overlap_start_list, t_span_end_list = [], [], []
pde_input_list, pde_output_list = [], []
tc_input_list = []
bc_zero_input_list, bc_zero_output_list = [], []
bc_up_input_list, bc_up_output_list = [], []

for T_id in range(T_iter):
    t_span_start = T_id * t_length_per_iter
    overlap_start = (T_id + 1) * t_length_per_iter  # where model T_id + 1 takes over
    t_span_end = overlap_start + overlap_length
    t_span_start_list.append(t_span_start)
    overlap_start_list.append(overlap_start)
    t_span_end_list.append(t_span_end)

    # Interior collocation points; targets are the 3 PDE residuals (all zero).
    pde_input_list.append(
        _with_random_time(_random_interior_xy(pde_epoch_size), t_span_start, t_span_end))
    pde_output_list.append(np.zeros([pde_epoch_size, 3]))

    # No-slip walls: (u, v) = (0, 0).
    wall_input = _with_random_time(
        _random_rows(bc_wall_xy, bc_epoch_size), t_span_start, t_span_end)
    bc_zero_input_list.append(wall_input)
    bc_zero_output_list.append(np.zeros([bc_epoch_size, 2]))

    # Moving lid: u = sin(pi * x / 2), v = 0 (column 1 of the input is x).
    lid_input = _with_random_time(
        _random_rows(bc_lid_xy, bc_epoch_size), t_span_start, t_span_end)
    lid_output = np.zeros([bc_epoch_size, 2])
    lid_output[:, 0] = np.sin(np.pi * 0.5 * lid_input[:, 1])
    bc_up_input_list.append(lid_input)
    bc_up_output_list.append(lid_output)

    # Initial-condition points: t = 0 for the first window, the overlap otherwise.
    if T_id == 0:
        tc_input = _with_random_time(_random_interior_xy(tc_epoch_size), 0.0, 0.0)
    else:
        tc_input = _with_random_time(
            _random_interior_xy(tc_epoch_size), t_span_start, t_span_start + overlap_length)
    tc_input_list.append(tc_input)

# Rest state at t = 0: (p, u, v) = 0.
tc_output_init = np.zeros([tc_epoch_size, 3])

# TRAIN:BEGIN
epoch_num = 6000
comm_period = 10    # communication period: refresh interface data from the previous window
test_period = 100   # test period: print the losses
loss_list = [pde_loss, tc_loss, bc_uv_zero_loss, bc_uv_up_loss]
batchsize_list = [pde_epoch_size, tc_epoch_size, bc_epoch_size, bc_epoch_size]
loss_weight_list = [1, 1, 1, 1]

cloud_point_list = [None] * T_iter
for epoch_id in range(epoch_num):
    for use_model_id in range(T_iter):
        # Shift global (t, x, y) into this model's local time frame.
        seg_start = np.array([t_span_start_list[use_model_id], 0.0, 0.0]).reshape(1, 3)

        if epoch_id % comm_period == 0:
            if use_model_id == 0:
                tc_output = tc_output_init
            else:
                # Evaluate the previous model on the overlap in ITS local time.
                # NOTE(review): assumes model_eval returns an (N, 3) array of
                # (p, u, v), matching tc_loss's columns; confirm.
                prev_start = np.array(
                    [t_span_start_list[use_model_id - 1], 0.0, 0.0]).reshape(1, 3)
                tc_output = solver_list[use_model_id - 1].model_eval(
                    tc_input_list[use_model_id] - prev_start)

            cloud_point_list[use_model_id] = [
                [pde_input_list[use_model_id] - seg_start,
                    pde_output_list[use_model_id]],
                [tc_input_list[use_model_id] - seg_start,
                    tc_output],
                [bc_zero_input_list[use_model_id] - seg_start,
                    bc_zero_output_list[use_model_id]],
                [bc_up_input_list[use_model_id] - seg_start,
                    bc_up_output_list[use_model_id]],
            ]

        if epoch_id % test_period == 0:
            print(f'model{use_model_id}', end='')
            solver_list[use_model_id].test_step(
                cloud_point_list=cloud_point_list[use_model_id],
                loss_list=loss_list,
                batchsize=batchsize_list,
                loss_weight_list=loss_weight_list,
                print_flag=True)

        solver_list[use_model_id].train_step(
            cloud_point_list=cloud_point_list[use_model_id],
            loss_list=loss_list,
            batchsize=batchsize_list,
            loss_weight_list=loss_weight_list)

# TRAIN:END
