#### CODE ANALYSIS

Meaning of each parameter in the function  
  
|**param**|**type**|**meaning**|
|:---:|:---:|:---:|
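|**input_batch**|list|Indices of the words of the original text in the input vocabulary|
|**input_length**|list|Length of each original text's word sequence|
|**target_batch**|list|Indices of the tokens of the output equation in the output vocabulary|
|**target_length**|list|Length of each output equation's token sequence|
|**nums_stack_batch**|list|For numbers that occur more than once in the original text, their indices in the number_values list|
|**num_size_batch**|list|Number of numbers in each original text|
|**generate_nums**|list|Constants used by the dataset (e.g. 1, 3.14)|
|**num_pos**|list|Positions of the numbers within the original text|
|**batch_graph**|numpy|The constructed Quantity Cell Graph and Quantity Comparison Graph|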
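To make `num_pos`, `num_size_batch` and `nums_stack_batch` concrete, here is a minimal sketch for a single problem. It assumes, as a hypothetical preprocessing convention, that every number in the text has been replaced by the placeholder token `"NUM"` and that the original values are kept in order in `number_values`; the helpers `number_fields` and `repeated_number_candidates` are illustrative names, not part of the original code.

```python
from typing import Dict, List


def number_fields(tokens: List[str], placeholder: str = "NUM") -> Dict[str, object]:
    """Return num_pos (token positions of the number placeholders) and num_size (their count)."""
    num_pos = [i for i, tok in enumerate(tokens) if tok == placeholder]
    return {"num_pos": num_pos, "num_size": len(num_pos)}


def repeated_number_candidates(number_values: List[str], value: str) -> List[int]:
    """Indices in number_values holding `value`; more than one index means the number is repeated."""
    return [i for i, v in enumerate(number_values) if v == value]


# Hypothetical tokenized problem: "Tom has 3 apples and 5 pears and buys 3 more apples"
tokens = ["Tom", "has", "NUM", "apples", "and", "NUM", "pears",
          "and", "buys", "NUM", "more", "apples"]
number_values = ["3", "5", "3"]

print(number_fields(tokens))                           # {'num_pos': [2, 5, 9], 'num_size': 3}
print(repeated_number_candidates(number_values, "3"))  # [0, 2]
```

A value with more than one candidate index is ambiguous when it appears in the target equation; candidate lists of this kind are what `nums_stack_batch` carries, and they are resolved by score when `target_t` is built.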

#### 1. **encoder**

In [None]:
def train_tree():
    ...
    # input_var:    [seq_len, batch_size]
    # input_length: [batch_size]
    # batch_graph:  [batch_size, 5, seq_len, seq_len]
    encoder_outputs, problem_output = encoder(input_seqs=input_var,
                                              input_lengths=input_length,
                                              batch_graph=batch_graph)
    # encoder_outputs: [seq_len, batch_size, hidden_size]
    # problem_output:  [         batch_size, hidden_size]

**encoder\_outputs**: graph representation $Z_{g}$ (one vector per input token)  
**problem\_output**: goal vector $q_{0}$ of the root node $n_{0}$
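The cell above does not show how `problem_output` is obtained. The following is a minimal sketch assuming a bidirectional RNN whose forward and backward halves (each `hidden_size` wide) are summed, a common choice in GTS-style encoders that this notebook does not confirm. It ignores the graph layers that consume `batch_graph` and any padding. Nested lists stand in for time-major tensors `[seq_len][batch_size][features]`, and `summarize_bidirectional` is a hypothetical helper.

```python
from typing import List, Tuple

Tensor3 = List[List[List[float]]]  # nested lists indexed [seq_len][batch_size][features]


def summarize_bidirectional(rnn_outputs: Tensor3, hidden_size: int) -> Tuple[Tensor3, List[List[float]]]:
    """Assumed reduction of bidirectional outputs (width 2 * hidden_size) to hidden_size.

    encoder_outputs[s][b]: forward half + backward half at step s
    problem_output[b]:     forward state at the last step + backward state at the first step
    Assumes every sequence fills all seq_len steps (no padding).
    """
    encoder_outputs = [
        [[f + b for f, b in zip(vec[:hidden_size], vec[hidden_size:])] for vec in step]
        for step in rnn_outputs
    ]
    last_step, first_step = rnn_outputs[-1], rnn_outputs[0]
    problem_output = [
        [f + b for f, b in zip(last_step[i][:hidden_size], first_step[i][hidden_size:])]
        for i in range(len(last_step))
    ]
    return encoder_outputs, problem_output


# seq_len = 2, batch_size = 1, hidden_size = 2
encoder_outputs, problem_output = summarize_bidirectional(
    [[[1.0, 2.0, 3.0, 4.0]], [[5.0, 6.0, 7.0, 8.0]]], hidden_size=2)
print(encoder_outputs)  # [[[4.0, 6.0]], [[12.0, 14.0]]]
print(problem_output)   # [[8.0, 10.0]]
```

Taking the forward state at the end and the backward state at the start means each direction contributes the summary it built after reading the whole problem.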

In [None]:
def train_tree():
    ...
    # encoder_outputs: [seq_len, batch_size, hidden_size]
    all_nums_encoder_outputs = get_all_number_encoder_outputs(encoder_outputs=encoder_outputs,
                                                              num_pos=num_pos,
                                                              batch_size=batch_size,
                                                              num_size=num_size,
                                                              hidden_size=encoder.hidden_size)
    # all_nums_encoder_outputs: [batch_size, num_size, hidden_size]

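**number embedding**: taken from encoder\_outputs at the positions given by num\_pos  
**constant embedding**: taken from a learned embedding table in the Prediction module, looked up again every time a new node is generated  
**operator embedding**: taken from nn.Embedding in the GenerateNode module, looked up again every time a new node is generated  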
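Here is a sketch of how the number embeddings can be gathered from the time-major `encoder_outputs` (indexing is `[position][sample]`) and zero-padded to the largest `num_size` in the batch. It then stacks constants and numbers into $M_{num}$. Nested lists stand in for tensors and the helper names are hypothetical. Placing the constants before the numbers is an assumption about the row order of $M_{num}$.

```python
from typing import List

Vector = List[float]


def gather_number_embeddings(encoder_outputs: List[List[Vector]],
                             num_pos: List[List[int]],
                             hidden_size: int) -> List[List[Vector]]:
    """Pick encoder_outputs[pos][b] for each number position of sample b.

    encoder_outputs is time-major ([seq_len][batch_size][hidden_size]); the result is
    batch-major ([batch_size][max_num_size][hidden_size]), zero-padded per sample.
    """
    max_num_size = max((len(p) for p in num_pos), default=0)
    batch = []
    for b, positions in enumerate(num_pos):
        rows = [list(encoder_outputs[pos][b]) for pos in positions]
        rows += [[0.0] * hidden_size for _ in range(max_num_size - len(positions))]
        batch.append(rows)
    return batch


def number_embedding_matrix(constant_embeddings: List[Vector],
                            number_embeddings: List[Vector]) -> List[Vector]:
    """M_num for one sample: constant embeddings first, then the numbers from the text (assumed order)."""
    return [list(v) for v in constant_embeddings] + [list(v) for v in number_embeddings]


# seq_len = 3, batch_size = 2, hidden_size = 1
encoder_outputs = [[[0.0], [10.0]], [[1.0], [11.0]], [[2.0], [12.0]]]
all_nums = gather_number_embeddings(encoder_outputs, num_pos=[[0, 2], [1]], hidden_size=1)
print(all_nums)                                       # [[[0.0], [2.0]], [[11.0], [0.0]]]
print(number_embedding_matrix([[9.0]], all_nums[0]))  # [[9.0], [0.0], [2.0]]
```

The zero rows added for padding are the slots that a number mask later has to exclude from prediction.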

In [None]:
def train_tree():
    ...
    node_stacks       = [[TreeNode(embedding=goal, left_flag=False)] for goal in problem_output.split(1, dim=0)]  # each goal: [1, hidden_size]
    embeddings_stacks = [[] for _ in range(batch_size)]
    left_childs       = [None for _ in range(batch_size)]

**node\_stacks**: one stack of TreeNode(embedding, left\_flag) per sample, holding the goal vector $q$ of every node still waiting to be expanded  
The root node $n_{0}$ starts with goal vector $q_{0}$ = problem\_output  
  
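**embeddings\_stacks**: one stack of TreeEmbedding(embedding, terminal) per sample, holding the embeddings of the nodes generated so far (the subtree embeddings $t$)  
If the token is an operator (non-leaf node): push the operator's token embedding $e(\hat{y}|P)$ with terminal=False  
If the token is an operand that is a left child: push the number's embedding (from $M_{num}$) with terminal=True  
If the token is an operand that is a right child:  
&emsp;&emsp;initialize the right child's subtree embedding $t_{r}$ with the number's embedding  
&emsp;&emsp;pop the left sibling's subtree embedding $t_{l}$ (terminal=True) and the parent operator's token embedding $e(\hat{y}|P)$ (terminal=False)  
&emsp;&emsp;merge them into the parent's subtree embedding $t$, repeating while the top of the stack is still terminal, then push the result with terminal=True  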

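**left\_childs**: for each sample, the subtree embedding $t_{l}$ handed to the next prediction step (None when there is none)  
If the token is an operator (non-leaf node): None  
If the token is an operand that is a left child: the left child's subtree embedding $t_{l}$  
If the token is an operand that is a right child: the subtree embedding obtained after the merge  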
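The stack bookkeeping described above can be traced with a toy sketch that replaces vectors by strings. Here `toy_merge` only records the subtree structure instead of computing the learned merge, goals are labels such as `left(q0)` rather than vectors, and the equation is given in prefix order. These are simplifications for illustration, and `decode` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

OPERATORS = {"+", "-", "*", "/"}


@dataclass
class TreeEmbedding:
    embedding: str  # toy stand-in for a vector
    terminal: bool  # True once the entry describes a finished subtree


def toy_merge(op: str, t_l: str, t_r: str) -> str:
    # Stand-in for the learned merge: it only records the subtree structure.
    return f"({t_l} {op} {t_r})"


def decode(prefix_tokens: List[str]) -> Tuple[List[str], List[TreeEmbedding], List[Optional[str]]]:
    node_stack = ["q0"]                         # goals still waiting to be expanded
    embeddings_stack: List[TreeEmbedding] = []
    left_childs: List[Optional[str]] = []
    for token in prefix_tokens:
        goal = node_stack.pop()                 # goal q of the node generated at this step
        if token in OPERATORS:
            # push the right goal first so that the left child is expanded next
            node_stack.append(f"right({goal})")
            node_stack.append(f"left({goal})")
            embeddings_stack.append(TreeEmbedding(token, terminal=False))
        else:
            current = token                     # the number's embedding from M_num
            # a terminal entry on top means this leaf is a right child: merge upward
            while embeddings_stack and embeddings_stack[-1].terminal:
                t_l = embeddings_stack.pop()
                op = embeddings_stack.pop()
                current = toy_merge(op.embedding, t_l.embedding, current)
            embeddings_stack.append(TreeEmbedding(current, terminal=True))
        top = embeddings_stack[-1]
        left_childs.append(top.embedding if top.terminal else None)
    return node_stack, embeddings_stack, left_childs


node_stack, stack, left_childs = decode(["*", "+", "n0", "n1", "n2"])
print(node_stack)                      # []
print([e.embedding for e in stack])    # ['((n0 + n1) * n2)']
print(left_childs)                     # [None, None, 'n0', '(n0 + n1)', '((n0 + n1) * n2)']
```

The `while` loop is what lets one leaf close several subtrees at once: after `n2`, both the `+` subtree (already merged) and the `*` subtree are complete.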

#### 2. **prediction**

In [None]:
def train_tree():
    ...
    # encoder_outputs:          node representations (words)
    # all_nums_encoder_outputs: node representations (numbers)

    # encoder_outputs:          [seq_len,  batch_size, hidden_size]
    # all_nums_encoder_outputs: [batch_size, num_size, hidden_size]
    # padding_hidden:           [1,       hidden_size]
    # seq_mask:                 [batch_size, seq_len]
    # num_mask:                 [batch_size, num_size + constant_size]
    num_score, op, current_embeddings, current_context, current_nums_embeddings = predict(
        node_stacks=node_stacks,
        left_childs=left_childs,
        encoder_outputs=encoder_outputs,
        num_pades=all_nums_encoder_outputs,
        padding_hidden=padding_hidden,
        seq_mask=seq_mask,
        mask_nums=num_mask)
    outputs = torch.cat((op, num_score), dim=1)

|**variable**|**meaning**|**shape**|
|:---:|:---:|:---:|
|**num\_score**|logit of each number and constant|\[batch\_size, num\_size + constant\_size\]|
|**op** (op\_score)|logit of each operator|\[batch\_size, operator\_size\]|
|**current\_embeddings**|goal vector $q$|\[batch\_size, 1, hidden\_size\]|
|**current\_context**|context vector $c$|\[batch\_size, 1, hidden\_size\]|
|**current\_nums\_embeddings**|current number embedding matrix $M_{num}$|\[batch\_size, num\_size + constant\_size, hidden\_size\]|
|**outputs**|output token logits (op first, then num\_score)|\[batch\_size, operator\_size + num\_size + constant\_size\]|
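Here is a sketch of how a row of `outputs` can be read back. Padded number slots are masked with a large negative value before concatenation, and the argmax column is mapped to operator, constant or number by its offset. It assumes that `num_mask` is True for padded slots, a masking value of `-1e12`, and the column layout `[operators | constants | numbers]`. The helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

NEG_INF = -1e12  # assumed masking value for padded number slots


def build_outputs(op_scores: Sequence[float], num_scores: Sequence[float],
                  num_mask: Sequence[bool]) -> List[float]:
    """Mask padded slots of num_score, then concatenate like torch.cat((op, num_score), dim=1) per row."""
    masked = [NEG_INF if is_pad else s for s, is_pad in zip(num_scores, num_mask)]
    return list(op_scores) + masked


def decode_token(outputs: Sequence[float], operator_size: int, constant_size: int) -> Tuple[str, int]:
    """Map the argmax column to (kind, index within that kind), assuming [operators | constants | numbers]."""
    best = max(range(len(outputs)), key=lambda i: outputs[i])
    if best < operator_size:
        return "operator", best
    if best < operator_size + constant_size:
        return "constant", best - operator_size
    return "number", best - operator_size - constant_size


# 2 operators, 1 constant, 2 number slots (the second one is padding)
outputs = build_outputs([0.1, 0.2], [0.3, 0.9, 5.0], [False, False, True])
print(decode_token(outputs, operator_size=2, constant_size=1))  # ('number', 0)
```

Without the mask, the padded slot's score of 5.0 would win the argmax even though no such number exists in the problem.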

In [None]:
def train_tree():
    ...
    # determine the target token for decoding step t
    target_t, generate_input = generate_tree_input(target=target[t].tolist(),
                                                   decoder_output=outputs,
                                                   nums_stack_batch=nums_stack_batch,
                                                   num_start=num_start,
                                                   unk=unk)
    # target_t:       [batch_size] = ground_truth
    # generate_input: [batch_size]

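**Note**: generate\_input is only meaningful at steps where the target token is an operator  
**target\_t**: ground-truth token index  
&emsp;&emsp;if the target number occurs more than once in the original text, the candidate position with the highest score is used as its num\_pos  
**generate\_input**: token index of the operator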
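The resolution of `target_t` and `generate_input` can be sketched in plain Python. The sketch assumes that token indices below `num_start` are operators and that the `UNK` index (`unk`) is at or above `num_start`. It also assumes that each entry of `nums_stack_batch` is a stack of candidate-position lists popped from the end, and that number positions in `generate_input` are filled with the placeholder index 0. The function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def generate_tree_input_sketch(target: Sequence[int],
                               decoder_output: Sequence[Sequence[float]],
                               nums_stack_batch: List[List[List[int]]],
                               num_start: int,
                               unk: int) -> Tuple[List[int], List[int]]:
    target_t = list(target)        # ground truth, with UNK resolved to a concrete number
    generate_input = list(target)  # operator indices fed to the generate step
    for i, token in enumerate(target):
        if token == unk:
            # candidate positions of a repeated number; pop() mutates nums_stack_batch
            candidates = nums_stack_batch[i].pop()
            best = max(candidates, key=lambda num: decoder_output[i][num_start + num])
            target_t[i] = num_start + best
        if token >= num_start:
            generate_input[i] = 0  # number positions get a placeholder index (assumed: 0)
    return target_t, generate_input


decoder_output = [
    [0.0] * 11,
    [0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.7, 0.0, 0.0, 0.0, 0.0],
    [0.0] * 11,
]
print(generate_tree_input_sketch(target=[2, 10, 5], decoder_output=decoder_output,
                                 nums_stack_batch=[[], [[0, 2]], []], num_start=4, unk=10))
# ([2, 6, 5], [2, 0, 0])
```

Choosing the candidate the model already scores highest means the loss does not penalize the model for picking either occurrence of a number that appears twice in the text.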

#### 3. **generate**

In [None]:
def train_tree():
    ...
    # current_embeddings: [batch_size, 1, hidden_size]
    # generate_input:     [batch_size]
    # current_context:    [batch_size, 1, hidden_size]
    left_child, right_child, node_label = generate(node_embedding=current_embeddings,
                                                   node_label=generate_input,
                                                   current_context=current_context)
    # left_child:  [batch_size,    hidden_size]
    # right_child: [batch_size,    hidden_size]
    # node_label:  [batch_size, embedding_size]

**node\_embedding**: parent goal vector $q$  
**current\_context**: parent context vector $c$  
**node\_label**: token index of the parent operator, embedded inside generate as $e(\hat{y}|P)$
  
$$o_{l} = \sigma(W_{ol}[q, c, e(\hat{y}|P)])$$  
$$C_{l} = \tanh(W_{cl}[q, c, e(\hat{y}|P)])$$  
$$h_{l} = o_{l} \odot C_{l}$$  
**left\_child**: goal vector $h_{l}$ of the current node's left child  
  
$$o_{r} = \sigma(W_{or}[q, c, e(\hat{y}|P)])$$  
$$C_{r} = \tanh(W_{cr}[q, c, e(\hat{y}|P)])$$  
$$h_{r} = o_{r} \odot C_{r}$$  
**right\_child**: goal vector $h_{r}$ of the current node's right child  
**node\_label**: token embedding $e(\hat{y}|P)$ of the operator  
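The gated child computation above can be checked numerically with a small pure-Python sketch. It follows the formulas literally, with no bias terms and no dropout. Plain lists represent vectors and row-major lists represent weight matrices, and the function names are hypothetical. A PyTorch module would normally use `nn.Linear` layers, which add a bias.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output unit


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def gated(W_gate: Matrix, W_cand: Matrix, z: Vector) -> Vector:
    """sigmoid(W_gate z) * tanh(W_cand z), element-wise; no bias, as in the formulas."""
    gate = [0.5 * (1.0 + math.tanh(a / 2.0)) for a in matvec(W_gate, z)]  # sigmoid, overflow-safe form
    cand = [math.tanh(a) for a in matvec(W_cand, z)]
    return [g * c for g, c in zip(gate, cand)]


def generate_children(q: Vector, c: Vector, e: Vector,
                      W_ol: Matrix, W_cl: Matrix,
                      W_or: Matrix, W_cr: Matrix) -> Tuple[Vector, Vector]:
    z = q + c + e               # concatenation [q, c, e(y_hat|P)]
    h_l = gated(W_ol, W_cl, z)  # o_l ⊙ C_l
    h_r = gated(W_or, W_cr, z)  # o_r ⊙ C_r
    return h_l, h_r


h_l, h_r = generate_children([1.0], [0.0], [0.0],
                             W_ol=[[0.0, 0.0, 0.0]], W_cl=[[1.0, 0.0, 0.0]],
                             W_or=[[0.0, 0.0, 0.0]], W_cr=[[-1.0, 0.0, 0.0]])
print(h_l, h_r)  # h_l ≈ [0.3808], h_r ≈ [-0.3808]
```

Both children read the same concatenated input $[q, c, e(\hat{y}|P)]$; only their separate weight matrices make the left and right goals differ.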

#### 4. **merge**

In [None]:
def train_tree():
    ...
    # op.embedding:        [1, embedding_size]
    # sub_stree.embedding: [1,    hidden_size]
    # current_num:         [1,    hidden_size]

    current_num = merge(node_embedding=op.embedding,
                        sub_tree_1=sub_stree.embedding,
                        sub_tree_2=current_num)
    # current_num: [1, hidden_size]

When the predicted token is an operand (leaf node), the tree embedding stack is updated  
&emsp;&emsp;If the leaf is a right child, the left sibling's subtree embedding $t_{l}$ and the right child's subtree embedding $t_{r}$ are merged into the parent's subtree embedding $t$  
  
**op.embedding**: token embedding $e(\hat{y}|P)$ of the parent operator  
**sub\_stree.embedding**: left subtree embedding $t_{l}$  
**current\_num**: right subtree embedding $t_{r}$  
  
$$g_{t} = \sigma(W_{gt}[t_{l}, t_{r}, e(\hat{y}|P)])$$  
$$C_{t} = \tanh(W_{ct}[t_{l}, t_{r}, e(\hat{y}|P)])$$  
$$\mathrm{comb}(t_{l}, t_{r}, \hat{y}) = g_{t} \odot C_{t}$$  
$$t = \mathrm{comb}(t_{l}, t_{r}, \hat{y})$$  
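Here is a pure-Python sketch of the merge gate. Plain lists stand in for tensors, bias terms are omitted as in the formulas, and `merge_subtrees` is a hypothetical name. The concatenation order $[t_{l}, t_{r}, e(\hat{y}|P)]$ follows the formulas. An implementation that concatenates in a different order only permutes the columns of $W_{gt}$ and $W_{ct}$.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output unit


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def merge_subtrees(e: Vector, t_l: Vector, t_r: Vector, W_gt: Matrix, W_ct: Matrix) -> Vector:
    """t = comb(t_l, t_r, y_hat) = sigmoid(W_gt z) ⊙ tanh(W_ct z), with z = [t_l, t_r, e(y_hat|P)]."""
    z = t_l + t_r + e
    g_t = [0.5 * (1.0 + math.tanh(a / 2.0)) for a in matvec(W_gt, z)]  # sigmoid, overflow-safe form
    c_t = [math.tanh(a) for a in matvec(W_ct, z)]
    return [g * c for g, c in zip(g_t, c_t)]


t = merge_subtrees(e=[0.0], t_l=[1.0], t_r=[1.0],
                   W_gt=[[0.0, 0.0, 0.0]], W_ct=[[0.5, 0.5, 0.0]])
print(t)  # ≈ [0.3808]: gate sigmoid(0) = 0.5 times tanh(0.5 + 0.5)
```

Because the gate lies in (0, 1) and the candidate in (-1, 1), every component of $t$ stays in (-1, 1), so repeated merges up a deep tree keep subtree embeddings on the same scale as the leaves.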