In [None]:
import sys
sys.path.append('../')

import numpy as np
import torch
import matplotlib

from user_funn.field import D1Field
from user_funn.get_net import ForwardNetwork
from user_funn.ds import get_data_loader
from user_funn.solver import CloudPointSolver
from user_funn.pde import grad

# Run on the first CUDA device when one is visible; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# NOTE(review): this stand-alone network is not referenced anywhere else in
# this cell (the solvers below build their own [1, 100, 100, 100, 1] nets);
# kept in case later cells use it.
model = ForwardNetwork([1, 50, 50, 50, 1]).to(device)

# Mean-squared-error criterion used by pde_loss and by data_loss.
loss_fn = torch.nn.MSELoss()

# Placeholder solver: only a network spec and optimizer are passed, no data
# sets or loss terms. NOTE(review): whether train_step() does anything useful
# on it depends on CloudPointSolver's defaults — confirm before training it.
solver = CloudPointSolver(model=[1, 100, 100, 100, 1], optimizer="adam")

def pde_loss(model, data):
    """Return the residual loss of the ODE du/dx = pi * cos(pi * x).

    Args:
        model: callable network; column 0 of its output is taken as u(x).
            Presumably maps an (N, 1) tensor of x values — TODO confirm.
        data: pair ``(x_in, y_real)``. ``y_real`` is the residual target
            (the training loop feeds zeros).

    Returns:
        ``loss_fn(du/dx - pi*cos(pi*x), y_real)`` using the module-level
        ``loss_fn``.

    Note:
        ``requires_grad`` is switched on in place on the caller's ``x_in``
        tensor so that du/dx can be differentiated with respect to it.
    """
    x_points, target = data
    x_points.requires_grad = True
    u = model(x_points)[:, [0]]
    du_dx = grad(u, x_points)[0]
    forcing = torch.pi * torch.cos(torch.pi * x_points)
    return loss_fn(du_dx - forcing, target)

from user_funn.bc import data_loss_factory
# Loss for the boundary-condition data set (paired with [bc_input, bc_output]
# when the solvers are built), built from the same loss_fn criterion as
# pde_loss. NOTE(review): [0] presumably selects output column 0 of the
# network, mirroring U[:, [0]] in pde_loss — confirm against data_loss_factory.
data_loss = data_loss_factory(loss_fn,[0])

# Collocation (PDE-residual) points drawn per time window, and how many
# batches they are split into.
pde_epoch_size, pde_batch_num = 16, 1
pde_batch_size = pde_epoch_size // pde_batch_num

# Boundary-condition points per window (one point at the window start) and
# their batch split.
bc_epoch_size, bc_batch_num = 1, 1
bc_batch_size = bc_epoch_size // bc_batch_num

# Number of consecutive time windows, and training epochs spent on each.
T_iter, epoch_per_iter = 3, 1000

# Sequential time-marching training.
#
# The domain is cut into T_iter windows [2k, 2k + 2]. For each window a fresh
# CloudPointSolver is built, trained and evaluated before moving on:
#   * PDE residual points are drawn on [2k, 2k + 2.2] (0.2 overlap past the
#     hand-off point) with zero residual targets;
#   * the boundary condition at the window start is u(0) = 1 for the first
#     window, and otherwise the value the *trained* previous-window model
#     predicts at that point.
# Previously the hand-off value was read from the untrained placeholder
# `solver` before any training, and only that data-less placeholder was
# trained; the data-carrying solvers were never trained.
solver_list = []
pde_input_list = []
pde_output_list = []
last_flame_data_list = []  # u(window end) handed to the next window's BC

x_all = []  # stacked evaluation grid over all windows (100 points/window)
y_all = []  # model predictions on x_all

for T_id in range(T_iter):
    t_span_start = 2 * T_id
    t_span_end = 2 * T_id + 2 + 0.2  # 0.2 overlap beyond the hand-off point

    ## define a pde: random collocation points, residual target zero
    pde_input_list.append(
        D1Field([t_span_start, t_span_end]).get_field_rand(pde_epoch_size)
    )
    pde_output_list.append(np.zeros([pde_epoch_size, 1]))

    ## define a bc at the window start, chained from the previous window
    bc_input = np.array([[t_span_start]])
    if T_id == 0:
        bc_output = np.ones([bc_epoch_size, 1])
    else:
        bc_output = np.array(last_flame_data_list[T_id - 1]).reshape(
            [bc_epoch_size, 1]
        )

    # `solver` is rebound to the current window's solver so that, after the
    # loop, it refers to the last trained window (as later cells may expect).
    solver = CloudPointSolver(
        [[pde_input_list[T_id], pde_output_list[T_id]], [bc_input, bc_output]],
        [pde_loss, data_loss],
        model=[1, 100, 100, 100, 1],
        optimizer="adam",
        batchsize=[pde_batch_size, bc_batch_size],
    )
    solver_list.append(solver)

    for epoch_local_id in range(epoch_per_iter):
        solver.train_step()
        if epoch_local_id % 100 == 0:
            solver.test_step(print_flag=True)

    # Evaluate the trained model on its own window (overlap excluded).
    x_cpu = np.linspace(t_span_start, t_span_start + 2, 100).reshape(100, 1)
    y_cpu = solver.model_eval(x_cpu)

    # Value at the window end: the next window's boundary condition.
    last_flame_data = solver.model_eval([[t_span_start + 2]]).item()
    last_flame_data_list.append(last_flame_data)
    print(last_flame_data)

    if T_id == 0:
        x_all = x_cpu
        y_all = y_cpu
    else:
        x_all = np.vstack([x_all, x_cpu])
        y_all = np.vstack([y_all, y_cpu])

