In [1]:
import torch

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Seed both RNGs up front so the sampled point sets (and any torch state
# created later) are reproducible across kernel restarts.
np.random.seed(1)
torch.manual_seed(1)

import sys

# Make the project package `user_fun` importable. The path is resolved
# relative to the notebook's working directory; the membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
REPO_ROOT = './../../'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from user_fun.solver.cp_solver import CloudPointSolver
from user_fun.geom import line_linspace, generate_points_in_rectangle

# Rectangular domain [X_MIN, X_MAX] x [Y_MIN, Y_MAX] shared by every point set
# below. NOTE(review): the second coordinate appears to act as time, given the
# "init condition" set placed on Y = Y_MIN -- confirm.
X_MIN, X_MAX = 0, 1
Y_MIN, Y_MAX = 0, 3

# Base sampling density; each point set below uses a multiple of it.
density = 32

# Initial condition on the bottom edge: u(x, Y_MIN) = sin(2*pi*x).
init_input = line_linspace([X_MIN, Y_MIN], [X_MAX, Y_MIN], density*2)
init_output = np.sin(2*np.pi * init_input[:, [0]])

# Boundary values on the left and right edges: u = sin(pi*y).
left_input = line_linspace([X_MIN, Y_MIN], [X_MIN, Y_MAX], density*3)
left_output = np.sin(np.pi * left_input[:, [1]])
right_input = line_linspace([X_MAX, Y_MIN], [X_MAX, Y_MAX], density*3)
right_output = np.sin(np.pi * right_input[:, [1]])

# Points sampled over the whole rectangle; used as the "residual points"
# (PDE) set when the point sets are assembled for plotting.
field_input = generate_points_in_rectangle([X_MIN, Y_MIN], [X_MAX, Y_MAX], density*density*3)

%matplotlib widget
from visual import process_point_sets,visualize_point_sets_interactive

# 示例数据
org_dict = {
    "left bound": [left_input, left_output],
    "right bound": [right_input, right_output],
    "init condition": [init_input, init_output],
    "residual points": [field_input, 'pde']
}

plot_cp_dict = process_point_sets(org_dict)


import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image

import time
def plot_cloudpoints(selected_cloudpoints, save_path="my_figure.png"):
    """Render the selected point sets as a 3-D scatter plot for the Gradio UI.

    Parameters
    ----------
    selected_cloudpoints : list of str or None
        Names of the point sets to draw; each must be a key of the global
        ``plot_cp_dict``. ``None`` (nothing ticked) is treated as empty.
    save_path : str or None, optional
        If given, the rendered PNG is also written to this path (default
        ``"my_figure.png"`` in the working directory, as before). Pass
        ``None`` to skip writing to disk.

    Returns
    -------
    PIL.Image.Image
        The rendered figure, fully loaded into memory.
    """
    selected_cloudpoints = selected_cloudpoints or []

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        for name in selected_cloudpoints:
            x, y, c = plot_cp_dict[name]
            ax.scatter(x, y, c, label=name)
        ax.set_title("Selected Point Sets")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Height")
        ax.set_xlim([-0.2, 1.2])
        ax.set_ylim([0, 3])
        ax.set_zlim([-1, 1])
        # legend() warns when there are no labelled artists (nothing selected).
        if selected_cloudpoints:
            ax.legend()

        # Act on this figure explicitly: plt.tight_layout() targets pyplot's
        # "current" figure, which need not be `fig` when Gradio serves
        # requests from worker threads.
        fig.tight_layout()

        # Render once, in memory. The returned image is built from these bytes
        # rather than re-read from a shared file on disk, so concurrent
        # requests cannot pick up each other's plot.
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)

    png_bytes = buf.getvalue()
    if save_path is not None:
        with open(save_path, "wb") as f:
            f.write(png_bytes)

    img = Image.open(BytesIO(png_bytes))
    img.load()  # force decoding now; PIL opens images lazily
    return img

# One checkbox per named point set; every change re-renders the 3-D scatter
# for the ticked sets via plot_cloudpoints.
point_set_names = list(plot_cp_dict)
checkboxes = gr.CheckboxGroup(point_set_names, label="Cloudpoints Set to Show")

iface = gr.Interface(
    fn=plot_cloudpoints,
    inputs=checkboxes,
    outputs=gr.Image(plot=True),
)
iface.launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


