# A BERT-based algorithm for edit distance between text trees

This notebook implements a simple yet informative metric for comparing text trees. It can be used, for example, to compare salient sentence-based mind maps generated by a neural network with reference maps. The algorithm applies the Zhang-Shasha algorithm for tree edit distance to text trees, using the semantic distance between sentences as the cost of node updates. To measure it, we embed the sentence in each tree node with a BERT-like language model, optionally prefixed with the sentences of all its ancestor nodes as context, and compare the resulting embeddings directly.

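The Zhang-Shasha algorithm implementation is taken from the `zss` Python library developed by Tim Henderson and Steve Johnson (2013). In the example at the end of this notebook, sentences are embedded with the `sentence-transformers/distiluse-base-multilingual-cased-v1` model, and embeddings are compared by cosine distance (one minus cosine similarity).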

In [10]:
import warnings
warnings.filterwarnings("ignore")

import zss
from sentence_transformers import SentenceTransformer

First, we will slightly extend the `Node` class from `zss` to include depth information in tree nodes. This will later be useful for similarity weight adjustment for nodes of different depth.

In [20]:
class Node(zss.Node):
    def __init__(self, label, children=None, depth=0):
        '''
        Updated version of Node constructor containing self.depth initialization.
        '''
        self.label = label
        self.children = children or list()
        self.depth = depth

    @staticmethod
    def get_depth(node):
        '''
        Method that returns the node's depth field.
        '''
        return node.depth

The following helper builds a `similarity_func` for our algorithm from a language model's encoder and an embedding distance measure. Despite its name, the resulting function returns a distance: lower values mean more similar sentences.

In [16]:
def sentence_similarity(encoder, embedding_dist):
    '''
    Build the sentence distance function used for the edit costs in Zhang-Shasha's algorithm.
    It uses a language model's encoder and an embedding distance measure to estimate semantic similarity.

    Arguments:
    encoder - a language model's callable encoder that takes a string as input and returns its embedding;
    embedding_dist - a distance measure that takes two embeddings as input and returns a non-negative distance value.

    Returns:
    similarity_func - a function of two string arguments that computes the distance between two sentences using the given encoder and distance measure.
    '''
    def similarity_func(sentence_a, sentence_b):
        a_embedding = encoder(sentence_a)
        b_embedding = encoder(sentence_b)

        return embedding_dist(a_embedding, b_embedding)

    return similarity_func
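
As a dependency-free illustration of how the encoder and the distance measure plug together, the sketch below swaps the language model for a toy encoder. Both `toy_encoder` (a set of lowercased words) and `jaccard_distance` are hypothetical stand-ins and not part of this notebook's method; they only show that any encoder and distance pair with matching types can be combined this way.

```python
def sentence_similarity(encoder, embedding_dist):
    # Same construction as above: encode both sentences, then compare.
    def similarity_func(sentence_a, sentence_b):
        return embedding_dist(encoder(sentence_a), encoder(sentence_b))
    return similarity_func


def toy_encoder(sentence):
    # Hypothetical stand-in for a language model: a set of lowercased words.
    return frozenset(sentence.lower().split())


def jaccard_distance(a, b):
    # Hypothetical stand-in for an embedding distance; 0 means identical sets.
    if not a and not b:
        return 0.0
    return 1.0 - len(a & b) / len(a | b)


toy_distance = sentence_similarity(toy_encoder, jaccard_distance)

print(toy_distance("tree edit distance", "Tree edit distance"))   # identical word sets
print(toy_distance("tree edit distance", "semantic similarity"))  # no shared words
print(toy_distance("tree edit distance", ""))                     # compared with an empty sentence
```

The last call mirrors how insertion and removal costs are defined later in this notebook: as the distance between a node's label and an empty string.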

We will also need a function that relabels a text tree so that each node's label contains the labels of all its ancestor nodes as context.

In [29]:
def tree_with_context(input_node: Node):
    '''
    A function that creates a relabeled tree based on the input one by adding sentences from parent nodes as context to child nodes.

    Arguments:
    input_node - root node of the tree to be relabeled.

    Returns: root node of relabeled tree.
    '''
    def add_context(node: Node, context):
        new_node = Node(context + Node.get_label(node), depth=Node.get_depth(node))
        
        for child_node in Node.get_children(node):
            new_node.addkid(add_context(child_node, Node.get_label(new_node)))
        
        return new_node

    return add_context(input_node, "")
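
To make the relabeling concrete, here is a minimal sketch on a tiny tree that avoids the `zss` dependency. The nested `(label, children)` tuple representation and the `relabel` helper are assumptions made for illustration only. As in `tree_with_context`, labels are concatenated directly, without a separator, so the toy labels carry their own trailing space.

```python
def relabel(tree, context=""):
    # tree is a (label, children) pair; this representation is hypothetical.
    label, children = tree
    new_label = context + label
    # Each child receives the already-extended label, so context accumulates
    # along the whole path from the root.
    return (new_label, [relabel(child, new_label) for child in children])


tree = ("Root. ", [("Child. ", [("Grandchild.", [])]), ("Sibling.", [])])
relabeled = relabel(tree)

print(relabeled[0])              # root label, unchanged
print(relabeled[1][0][0])        # child label, prefixed with the root label
print(relabeled[1][0][1][0][0])  # grandchild label, prefixed with both ancestors
```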

Finally, we combine the functions above to compute the tree distance with `zss.distance`. The optional `depth_factor` multiplies each update cost by `depth_factor**depth`, where `depth` is the depth of the node from the first tree. Node depth could also be factored in differently.
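
Read off the implementation below, the edit costs passed to Zhang-Shasha's algorithm can be summarized as follows. The symbols are notation introduced here for convenience and do not appear in the code: $d$ is the distance returned by `similarity_func`, $\ell(v)$ is the label of node $v$ (extended with context if `use_context` is set), $\lambda$ is `depth_factor`, and $\varepsilon$ is the empty string.

```latex
\[
\begin{aligned}
c_{\text{update}}(a, b) &= d\bigl(\ell(a), \ell(b)\bigr)\,\lambda^{\operatorname{depth}(a)},\\
c_{\text{insert}}(v)    &= d\bigl(\ell(v), \varepsilon\bigr),\\
c_{\text{remove}}(v)    &= d\bigl(\ell(v), \varepsilon\bigr).
\end{aligned}
\]
```

Only the update cost is scaled by depth, and it uses the depth of the node from the first tree.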

In [32]:
def text_tree_distance(tree_a: Node, tree_b: Node, similarity_func, depth_factor=1.0, use_context=True):
    '''
    Compute the tree edit distance between two trees given a distance function for sentence pairs.

    Arguments:
    tree_a, tree_b - the two trees to compare, given as root nodes of the extended Node type defined above;
    similarity_func - a function of two string arguments that computes the distance between two sentences;
    depth_factor - a hyperparameter; each update cost is multiplied by depth_factor**depth, where depth is the depth of the node from tree_a;
    use_context - a flag indicating whether the labels of a node's ancestors are prepended to its label as context before comparison.

    Returns: the tree edit distance between tree_a and tree_b.
    '''
    if use_context:
        tree_a = tree_with_context(tree_a)
        tree_b = tree_with_context(tree_b)

    # Define the update_cost, insert_cost and remove_cost functions required by Zhang-Shasha's algorithm, using the provided similarity function.
    def update_cost(node_a: Node, node_b: Node):
        return similarity_func(Node.get_label(node_a), Node.get_label(node_b)) * depth_factor**Node.get_depth(node_a)

    def insert_cost(node: Node):
        # We define node insertion and deletion costs as similarity of the node's label to an empty string
        return similarity_func(Node.get_label(node), "")

    def remove_cost(node: Node):
        return similarity_func(Node.get_label(node), "")

    dist = zss.distance(tree_a, tree_b, Node.get_children, insert_cost, remove_cost, update_cost)
    return dist
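
To see what `depth_factor` does to the update cost, the sketch below tabulates the multiplier `depth_factor**depth` for depths 0 to 2 and for the factor values used in this notebook. The helper `depth_weight` is a hypothetical name introduced for illustration.

```python
def depth_weight(depth_factor, depth):
    # Multiplier applied to the update cost of a node at the given depth.
    return depth_factor ** depth


for factor in (0.7, 1.0, 1.5):
    weights = [round(depth_weight(factor, depth), 4) for depth in range(3)]
    print(factor, weights)
```

A factor below 1 down-weights updates of deeper nodes, a factor above 1 emphasizes them, and the root (depth 0) is never rescaled.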

Below we provide an example use case of the functions above utilizing a model from `sentence_transformers`:

In [18]:
model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1')

def cos_dist(a_embedding, b_embedding):
    return float(1 - model.similarity(a_embedding, b_embedding))

similarity_func = sentence_similarity(model.encode, cos_dist)
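
For reference, the quantity `cos_dist` is meant to compute is one minus the cosine similarity of the two embeddings, assuming `model.similarity` computes cosine similarity, as the name `cos_dist` suggests. A dependency-free sketch on plain Python lists is shown below; the function name `cosine_distance` and the toy vectors are illustrative assumptions.

```python
import math


def cosine_distance(u, v):
    # 1 - cos(angle between u and v); assumes both vectors are non-zero.
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


print(cosine_distance([1.0, 0.0], [2.0, 0.0]))   # same direction
print(cosine_distance([1.0, 0.0], [0.0, 3.0]))   # orthogonal
print(cosine_distance([1.0, 0.0], [-1.0, 0.0]))  # opposite direction
```

Cosine distance lies between 0 and 2, so it gives the non-negative values that `sentence_similarity` expects from `embedding_dist`.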

In [35]:
A = (
    Node('We present a new metric for text tree comparison.', depth=0)
        .addkid(Node('It uses Zhang-Shasha\'s algorithm and a BERT-like model.', depth=1)
            .addkid(Node('Zhang-Shasha\'s algorithm is used to measure tree edit distance effectively.', depth=2))
            .addkid(Node('The BERT-like model is used to measure semantic similarity.', depth=2)))
        .addkid(Node('The algorithm is presented as an informative metric for text tree comparison.', depth=1)
            .addkid(Node('There hasn\'t yet been a metric that allows to compare tree-structured text data such as mind maps informatively.', depth=2))
            .addkid(Node('This metric can be used, for example, to evaluate automatic salient sentence-based mind map generation.', depth=2)))
)

B = (
    Node('A new metric for text tree comparison based on tree edit distance and semantic similarity.', depth=0)
        .addkid(Node('Zhang-Shasha\'s algorithm is used to compute tree edit distance.', depth=1))
        .addkid(Node('Semantic similarity is measured using a BERT-like language model.', depth=1)
            .addkid(Node('To measure it, the sentences with all parent nodes as context are passed to the language model.', depth=2))
            .addkid(Node('Semantic similarity is measured as the similarity of the model\'s embeddings of the sentences.', depth=2)))
        .addkid(Node('This metric can be used to compare text trees.', depth=1)
            .addkid(Node('For example, it can be utilized in automatic mind map generation evaluation against reference maps.', depth=2)))
)

C = (
    Node('A new algorithm for text tree edit distance based on Zhang-Shasha\'s algorithm and BERT-like model embedding similarity.', depth=0)
        .addkid(Node('The algorithm\'s novelty is in its similarity measure based on BERT-like model embeddings.', depth=1)
            .addkid(Node('Embedding distance is used as a measure of semantic similarity.', depth=2))
            .addkid(Node('The language model allows to capture semantic meaning of sentences and model their similarity.', depth=2)))
        .addkid(Node('Zhang-Shasha\'s algorithm is used to compute tree edit distance with new edit costs.', depth=1)
            .addkid(Node('Semantic similarity is used as the update cost in the algorithm.', depth=2))
            .addkid(Node('The costs of insertion and removal of nodes are defined as the similarity of the node and an empty sentence.', depth=2)))
        .addkid(Node('The proposed algorithm is presented as a more informative metric of similarity between text trees', depth=1)
            .addkid(Node('The current ways of comparing text trees overlook overlook their tree structure or the meaning of their labels.', depth=2))
            .addkid(Node('This new method can be used, for example, to compare mind maps or hierarchical summaries.', depth=2)))
)

A_B_dist = text_tree_distance(A, B, similarity_func)
A_C_dist = text_tree_distance(A, C, similarity_func)
C_B_dist = text_tree_distance(C, B, similarity_func)

In [36]:
print(A_B_dist)
print(A_C_dist)
print(C_B_dist)

3.2734222412109375
4.964913785457611
5.261121928691864


In [37]:
A_B_dist_0_7 = text_tree_distance(A, B, similarity_func, depth_factor=0.7)
A_B_dist_1_5 = text_tree_distance(A, B, similarity_func, depth_factor=1.5)

In [38]:
print(A_B_dist_0_7)
print(A_B_dist_1_5)

2.860270586013794
4.1851606369018555
