In [6]:
import numpy as np
from PhysicsRegression import PhyReg

In [2]:
# Build the PhyReg object from the checkpoint file in the working directory.
# NOTE(review): "./model.pt" is presumably the pretrained model weights — confirm.
phyreg = PhyReg(path="./model.pt")

## First, try a toy example with the Divide-and-Conquer strategy

基本的符号回归流程，使用Oracle分治策略

先随机生成输入 x，再用确定的目标函数 $y = 2.2 \cdot x_0 \cdot x_1$ 计算 y，然后尝试拟合 x 与 y 之间的关系。


In [5]:
# Toy dataset: 100 samples of two inputs drawn uniformly from [0, 3),
# with the noise-free target y = 2.2 * x_0 * x_1.
# A locally seeded generator makes the sampled inputs identical on every run.
# NOTE(review): the fit itself may still vary between runs if PhyReg uses
# its own sources of randomness — confirm.
rng = np.random.default_rng(0)
x = rng.random((100, 2)) * 3
y = 2.2 * x[:, 0] * x[:, 1]

# Divide-and-conquer only: MCTS and GP refinement are switched off, and the
# oracle network is not saved.
phyreg.fit(
    x, y,
    use_Divide=True, use_MCTS=False, use_GP=False,
    save_oracle_model=False, oracle_name="demo",
)

Training oracle Newral Network...
Generating formula through End-to-End...
Finished forward in 1.6989436149597168 secs
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Back aggregating formulas...
idx : 0
expr: (2.2 * (x_0 * x_1))
mse : 4.269093351926522e-31



  完整执行流程:                                                                                                                                                                        
                                                                                                                                                                                       
  ┌─────────────────────────────────────────┐                                                                                                                                          
  │ 阶段1: 数据预处理                        │                                                                                                                                         
  ├─────────────────────────────────────────┤                                                                                                                                          
  │ • 转换格式: x=(100,2), y=(100,1)        │                                                                                                                                          
  │ • 编码提示: units/complexity/unarys     │                                                                                                                                          
  │ • 数据验证: 检查维度和数值有效性         │                                                                                                                                         
  └─────────────────────────────────────────┘                                                                                                                                          
      ↓                                                                                                                                                                                
  ┌─────────────────────────────────────────┐                                                                                                                                          
  │ 阶段2: Oracle 分治策略                   │                                                                                                                                         
  │ (use_Divide=True)                       │                                                                                                                                          
  ├─────────────────────────────────────────┤                                                                                                                                          
  │ self.oracle.oracle_fit()                │                                                                                                                                          
  │ [Oracle/oracle.py:686-903]              │                                                                                                                                          
  │                                         │                                                                                                                                          
  │ 1. 训练SimpleNet逼近目标函数             │                                                                                                                                         
  │    SimpleNet架构:                       │                                                                                                                                          
  │    ├─ Linear(2→128) + tanh             │                                                                                                                                           
  │    ├─ Linear(128→128) + tanh           │                                                                                                                                           
  │    ├─ Linear(128→64) + tanh            │                                                                                                                                           
  │    ├─ Linear(64→64) + tanh             │                                                                                                                                           
  │    └─ Linear(64→1) [无激活]            │                                                                                                                                           
  │                                         │                                                                                                                                          
  │ 2. oracle_seperate() 分解尝试           │                                                                                                                                          
  │    [oracle.py:483-531]                  │                                                                                                                                          
  │    ├─ 计算导数矩阵(0/1/2阶)             │                                                                                                                                          
  │    ├─ 尝试分离策略:                     │                                                                                                                                          
  │    │  • "id,add": z=φ(x), y=z+ψ(x)     │                                                                                                                                           
  │    │  • "inv,mul": z=1/φ(x), y=z*ψ(x)  │                                                                                                                                           
  │    │  • "arcsin,add": z=arcsin(φ(x))   │                                                                                                                                           
  │    │  • "sqrt,mul": z=√φ(x)            │                                                                                                                                           
  │    │  └─ "arccos,add": z=arccos(φ(x))  │                                                                                                                                           
  │    └─ 返回: 主问题 + 残差子问题         │                                                                                                                                          
  └─────────────────────────────────────────┘                                                                                                                                          
      ↓                                                                                                                                                                                
  ┌─────────────────────────────────────────┐                                                                                                                                          
  │ 阶段3: 端到端 Transformer 预测          │                                                                                                                                          
  ├─────────────────────────────────────────┤                                                                                                                                          
  │ self.dstr.fit(total_x, total_y, hints)  │                                                                                                                                          
  │                                         │                                                                                                                                          
  │ 3.1 LinearPointEmbedder                 │                                                                                                                                          
  │     [embedders.py:41+]                  │                                                                                                                                          
  │     ├─ 浮点数编码: [符号,指数,尾数]     │                                                                                                                                          
  │     ├─ 投影: Linear((2+1)*10→512)      │                                                                                                                                           
  │     └─ 输出: (100, 512) 嵌入向量       │                                                                                                                                           
  │                                         │                                                                                                                                          
  │ 3.2 TransformerEncoder                  │                                                                                                                                          
  │     [transformer.py:245+]               │                                                                                                                                          
  │     ├─ 位置编码: Sinusoidal            │                                                                                                                                           
  │     ├─ 2层Transformer:                 │                                                                                                                                           
  │     │  ├─ MultiHeadAttention (16头)   │                                                                                                                                            
  │     │  ├─ LayerNorm                   │                                                                                                                                            
  │     │  ├─ FFN: 512→2048→512          │                                                                                                                                             
  │     │  └─ LayerNorm                   │                                                                                                                                            
  │     └─ 输出: (100, 512) 上下文表示    │                                                                                                                                            
  │                                         │                                                                                                                                          
  │ 3.3 TransformerDecoder + BeamSearch    │                                                                                                                                           
  │     [transformer.py:500+]               │                                                                                                                                          
  │     ├─ 自回归生成(因果掩码)             │                                                                                                                                          
  │     ├─ 束搜索(beam_size=10):           │                                                                                                                                           
  │     │  ├─ 保持top-10最优路径           │                                                                                                                                           
  │     │  ├─ 每步扩展10个候选token        │                                                                                                                                           
  │     │  └─ 长度惩罚(length_penalty)     │                                                                                                                                           
  │     └─ 输出: 10个公式候选及评分        │                                                                                                                                           
  └─────────────────────────────────────────┘                                                                                                                                          
      ↓                                                                                                                                                                                
  ┌─────────────────────────────────────────┐                                                                                                                                          
  │ 阶段4: 公式聚合                          │                                                                                                                                         
  ├─────────────────────────────────────────┤                                                                                                                                          
  │ self.oracle.reverse()                   │                                                                                                                                          
  │ • 将子问题解合并为完整公式               │                                                                                                                                         
  │ • 应用分解策略逆变换                    │                                                                                                                                          
  └─────────────────────────────────────────┘  

## Use the following code to try on your specified data.

- Modify the `oracle_name` parameter to point to another directory in `Oracle_model` for different symbolic regression problems

- Set `save_oracle_model` to `True` if you wish to save the weights of the oracle neural network

In [None]:
# Template for your own data: fill in x and y below, then uncomment the lines.
# The assert requires x and y to share their first dimension and y to be 1-D
# (the toy example above uses x of shape (100, 2)).

# x = ...
# y = ...

# assert x.shape[0] == y.shape[0] and len(y.shape) == 1

# phyreg.fit(x, y, use_Divide=True, use_MCTS=False, use_GP=False,
#            save_oracle_model=False, oracle_name="demo1")

## You can then use the MCTS and GP algorithms to further refine the results

数据仍以同样的方式生成，但在分治策略之外又启用了蒙特卡洛树搜索（MCTS）和遗传编程（GP）对公式进行精炼，并开启了常数优化。

In [4]:
# Toy dataset: 100 samples of two inputs drawn uniformly from [0, 3),
# with the noise-free target y = 2.2 * x_0 * x_1.
# A locally seeded generator makes the sampled inputs identical on every run.
# NOTE(review): the search itself may still vary between runs if PhyReg uses
# its own sources of randomness — confirm.
rng = np.random.default_rng(0)
x = rng.random((100, 2)) * 3
y = 2.2 * x[:, 0] * x[:, 1]

# Divide-and-conquer followed by MCTS and GP refinement and a final
# constant-optimization pass; the oracle network is not saved.
phyreg.fit(
    x, y,
    use_Divide=True, use_MCTS=True, use_GP=True,
    use_pysr_init=True, use_const_optimization=True,
    save_oracle_model=False, oracle_name="demo",
)

Training oracle Newral Network...
Generating formula through End-to-End...
Finished forward in 1.8899621963500977 secs
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 0/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Back aggregating formulas...
idx : 0
expr: (2.2 * (x_0 * x_1))
mse : 6.7097088146518035e-31

Refining formula through MCTS...
idx : 0
expr: (2.2 * (x_0 * x_1))
mse : 6.7097088146518035e-31

Refining formula through GP...
idx : 0
expr: (2.2 * (x_0 * x_1))
mse : 6.7097088146518035e-31

Refining constants...
idx : 0
expr: (2.2 * (x_0 * x_1))
mse : 6.7097088146518035e-31



### Physical priors — physical units, complexity, candidate unary operators, and candidate constants — can also be included in the search process.

In [None]:
# Fit with physical priors supplied as hints.
# Relies on `x`, `y` and `phyreg` already existing from earlier cells.
phyreg.fit(
    x,
    y,
    use_Divide=True,
    use_MCTS=False,
    use_GP=False,
    # NOTE(review): three unit strings for a two-column x — presumably one per
    # input variable followed by one for y; confirm against PhyReg.fit.
    units=["kg1m1s0T0V0", "kg0m0s2T-1V-1", "kg1m1s2T-1V-2"],
    complexitys=8,
    unarys=["sin", "cos"],
    # Each candidate constant is given as a [value, unit string] pair.
    consts=[[2.1, "kg0m0s0T0V-1"]],
    save_oracle_model=False,
    oracle_name="demo",
)

Training oracle Newral Network...
Generating formula through End-to-End...
Finished forward in 4.378088474273682 secs
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
Removed 1/2 skeleton duplicata
[{'refinement_type': 'NoRef', 'predicted_tree': (2.061 mul (x_0 mul (x_1 mul (1 add (sin(1))**2)))), 'time': 4.380510091781616, 'relabed_predicted_tree': (2.061 mul (x_0 mul (x_1 mul (1 add (sin(1))**2))))}, {'refinement_type': 'NoRef', 'predicted_tree': ((2.5 mul (x_0 mul (x_1 mul (sin(1) mul inv(sqrt((cos(1))**2)))))) add -0.006562999999999999), 'time': 4.394014358520508, 'relabed_predicted_tree': ((2.5 mul (x_0 mul (x_1 mul (sin(1) mul inv(sqrt((cos(1))**2)))))) add -0.006562999999999999)}, {'refinement_type': 'NoRef', 'predicted_tree': ((10.0 mul (x_0 mul (sin(1) mul inv((1 add cos(1)))))) ad