In [3]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
import xml.etree.ElementTree as ET
import numpy as np
import time
import os
import random

## Load model and metadata from disk

In [4]:
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
meta_data_df = pd.read_excel('data/category_configuration_09-08-2022_08-08-01.xlsx', sheet_name = 'article_names_matching')
title2category = dict(zip(meta_data_df["Article Title"], meta_data_df["Category 2"]))

## Code for calculate similarity between two documents

In [5]:
def elem2sent(article, break_sentence = True):
    '''
    Break article into sentences. Break sentences with "."
    article: xml element
    '''
    #if article is a list of sentences
    sentences = [_ for _ in article.itertext() if not _.isspace() and len(_.split()) > 5]
    
    #break sentences with "."
    if break_sentence:
        sentences = [sent for sub in map(lambda x: x.split('.'), sentences) for sent in sub if len(sent.split()) > 5]
    
    #remove empty sentences
    return list(filter(None, sentences))


def get_article_alignment(all_articles1, all_articles2, title2category, sanity_check = True):
    '''
    Get article alignment between two documents
    article_body: xml element
    title2category: dictionary of title to category
    sanity_check: use similarity score to check if the alignment is correct
    return a dictionary of article alignment
    '''
    alignment_match = {}

    #if both articles attribute includes title
    if 'title' in all_articles1[0].attrib and 'title' in all_articles2[0].attrib:
        article_title1 = [(article.get('title').lower(), article.get('num')) for article in all_articles1 if article.get('title')]
        article_title2 = [(article.get('title').lower(), article.get('num')) for article in all_articles2 if article.get('title')]
        for index1, (title1, num1) in enumerate(article_title1):
            for index2, (title2, num2) in enumerate(article_title2):
                if title2category.get(title1) == title2category.get(title2):
                    #to make sure that they have a high similarity
                    if sanity_check and max(util.cos_sim(
                        model.encode(''.join(elem2sent(all_articles1[index1]))),
                        model.encode(''.join(elem2sent(all_articles2[index2])))
                    )).item() < 0.7:
                        continue
                    alignment_match[num1] = num2
                    break
                
    #title not included in article attri, use sentence similarity instead
    else:
        for article1 in all_articles1:
            text1_embd = model.encode(''.join(elem2sent(article1)))
            text2_list_embd = model.encode([''.join(elem2sent(article2)) for article2 in all_articles2])
            scores = util.cos_sim(text1_embd, text2_list_embd)
            if max(scores[0]).item() > 0.7:
                index = np.argmax(scores[0])
                alignment_match[article1.get('num')] = all_articles2[index].get('num')

    return alignment_match
    

def extract_similar_sentences_from_article(article1, article2):
    '''
    article1, article2: xml element
    return: list of similar sentences: (sentence1, sentence2, similarity)
    '''
    article1_sents, article2_sents = elem2sent(article1, break_sentence = False), elem2sent(article2, break_sentence = False)
    #article1_sents, article2_sents = elem2sent(article1), elem2sent(article2)

    #Embed article1 and article2
    article1_embd, article2_embd = model.encode(article1_sents), model.encode(article2_sents)

    #Get similarity between article1 and article2
    scores = util.cos_sim(article1_embd, article2_embd)

    visited = set() #to make sure that we don't add the same sentence twice

    #filter out the sentence with similarity greater than 0.98, this means they are perfect match and no need to compare
    identical = (scores > 0.98).to(torch.int64)
    for i, j in identical.nonzero().tolist():
        visited.add('row' + str(i))
        visited.add('col' + str(j))

    #filter out the sentences with similarity between 0.5 and 0.98
    mask = (scores > 0.5) & (scores < 0.98)
    scores *= mask.to(torch.int64) 

    #get the index of the sentences with similarity between 0.5 and 0.98
    sim_pairs = [(scores[i][j], i, j) for i, j in mask.nonzero().tolist()]
    sim_pairs.sort(key = lambda x: x[0]) #sort by similarity score
    
    ret = []
    while sim_pairs:
        score, i, j = sim_pairs.pop()
        if 'row' + str(i) not in visited and 'col' + str(j) not in visited:
            ret.append((article1_sents[i], article2_sents[j], scores[i][j].item()))
            visited.add('row' + str(i))
            visited.add('col' + str(j))

    return ret

def extract_similar_from_doc(doc1_path, doc2_path, title2category, min_length = 5):
    '''
    doc_path: path to first document
    title2category: dictionary of title to category
    '''

    try:   
        doc1, doc2 = ET.parse(doc1_path), ET.parse(doc2_path)
        doc_root1, doc_root2 = doc1.getroot(), doc2.getroot()
        doc_body1, doc_body2 = doc_root1[1][2], doc_root2[1][2]
        
    except Exception as e:
        print(e); return []
    
    #get article alignment between two documents
    all_articles1, all_articles2 = doc_body1.findall(".//div[@type='article']"), doc_body2.findall(".//div[@type='article']")
    #alignment_match = get_article_alignment(doc_body1, doc_body2, title2category)
    alignment_match = get_article_alignment(all_articles1, all_articles2, title2category)


    ret = []
    for page1, page2 in alignment_match.items():
        article1, article2 = doc_body1.find(".//div[@num='" + page1 + "']"), doc_body2.find(".//div[@num='" + page2 + "']")
        ret.extend(extract_similar_sentences_from_article(article1, article2))
    
    #filter out the pairs in which both sentences are longer than min_length words and length difference is less than 4 * min_length
    ret = [x for x in ret if len(x[0].split()) > min_length and len(x[1].split()) > min_length and abs(len(x[0].split()) - len(x[1].split())) < 4 * min_length]

    #sort by similarity score
    ret.sort(key = lambda x: x[2], reverse = True)

    return ret

In [6]:
start = time.time()
diff = extract_similar_from_doc(
    'data/full data/t1989-9-canada-russian-federation-bit-1989.xml', 
    'data/full data/t1990-14-canada-czech-republic-bit-1990.xml',
    title2category
)
print(time.time() - start)

3.446413278579712


In [12]:
_ = extract_similar_from_doc(
    'data/full data/t1995-139-hong-kong-china-sar-italy-bit-1995.xml',
    'data/full data/t1995-140-hong-kong-china-sar-new-zealand-bit-1995.xml',
    title2category
)

In [8]:
#random sample 20 canada documents
random.seed(42)
canada_docs = ['data/canada data/' + _ for _ in os.listdir('data/canada data') if 'canada' in _]
#canada_docs = random.sample(canada_docs, 20)

Sanity check to make sure the code is not reporting errors:

In [10]:
for i in range(5):
    #randomly select one of the canada documents
    doc1 = random.choice(canada_docs)
    doc2 = random.choice(canada_docs)

    l = []

    try:
        _ = extract_similar_from_doc(doc1, doc2, title2category)
        l.append(_)
    except Exception as e:
        print(e)
        print(doc1, doc2)
        continue

In [13]:
for elem in l:
    print(elem)
    print('-----------------------------------------------------')

[("5. If a Tribunal's order designates information as confidential and a Contracting Party's law on access to information requires public access to that information, the Contracting Party's law on access to information prevails. However, the Contracting Party should try to apply its law on access to information so as to protect information that the Tribunal's order has designated as confidential.", "5. If a Tribunal's order designates information as confidential and a Party's law on access to information requires public access to that information, the Party's law on access to information prevails. However, the Party should try to apply its law on access to information so as to protect information that the Tribunal's order has designated as confidential.", 0.9788036346435547), ('2. A Contracting Party may require that a majority of the board of directors, ora committee thereof, of an enterprise of that Contracting Party that is a covered investment, be of a particular nationality or a r