In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
import xml.etree.ElementTree as ET
import numpy as np
import time
import os
import random
import xlwt
import string

from utils import *

## Load model and metadata from disk

In [2]:
CATEGORY_CONFIG_PATH = 'data/category_configuration_09-08-2022_08-08-01.xlsx'

# Sentence encoder used for every similarity computation below.
model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
# Alternative encoder: SentenceTransformer('nlpaueb/bert-base-uncased-contracts')

meta_data_df = pd.read_excel(CATEGORY_CONFIG_PATH, sheet_name = 'article_names_matching')
# Article title -> Category 2. Every lookup in this notebook uses a lower-cased title,
# so the keys are lower-cased too; non-string (empty) titles cannot be looked up and are skipped.
title2category = {
    title.lower(): category
    for title, category in zip(meta_data_df["Article Title"], meta_data_df["Category 2"])
    if isinstance(title, str)
}

keyword_df = (
    pd.read_excel(CATEGORY_CONFIG_PATH, sheet_name = 'keywords_category_2_mapping')
    .dropna(subset = ["Keyword"])
)

# keyword -> {Category 2 -> set of Category 3}
keyword2category = {}
for keyword, cat2, cat3 in zip(keyword_df['Keyword'], keyword_df['Category 2'], keyword_df['Category 3']):
    keyword2category.setdefault(keyword, {}).setdefault(cat2, set()).add(cat3)


No sentence-transformers model found with name C:\Users\Xiang/.cache\torch\sentence_transformers\nlpaueb_legal-bert-base-uncased. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at C:\Users\Xiang/.cache\torch\sentence_transformers\nlpaueb_legal-bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactl

In [3]:
# Inspect the mapping rows of a single keyword.
keyword_df.loc[keyword_df['Keyword'].eq('formal requirement')]

Unnamed: 0,Keyword,Category 4,Category 3,Category 2,Category 1,Additional Document Sections,Additional 2º level categories
474,formal requirement,Investor Information,Investor Information,Investor Information and Confidentiality,Non-Protection Provisions,,


## Code to calculate the similarity between two documents

In [50]:
def get_article_alignment(all_articles1, all_articles2, title2category, sanity_check = True, threshold = 0.7):
    '''
    Align the articles of two documents by the category of their titles.

    all_articles1, all_articles2: lists of article xml elements, one list per document
    title2category: dictionary mapping a lower-cased article title to its category
    sanity_check: if True, also require the cosine similarity of the two articles'
        embedded text to be at least `threshold`
    threshold: minimum similarity accepted by the sanity check
    return: dict mapping the `num` of an article of document 1 to the `num` of the
        first article of document 2 that has the same (known) title category

    Articles without a (non-empty) title, and articles whose title has no entry in
    `title2category`, are never aligned.
    '''
    # One filter produces title, num and element together, so a title can never be
    # paired with a different article's element in the sanity check.
    titled1 = [(a.get('title').lower(), a.get('num'), a) for a in all_articles1 if a.get('title')]
    titled2 = [(a.get('title').lower(), a.get('num'), a) for a in all_articles2 if a.get('title')]

    # Each article is embedded at most once, and only when the sanity check needs it.
    embeddings = {}

    def _embed(side, index, article):
        key = (side, index)
        if key not in embeddings:
            # NOTE(review): sentences are concatenated without a separator; confirm that
            # elem2sent's output already carries the spacing between sentences.
            embeddings[key] = model.encode(''.join(elem2sent(article)))
        return embeddings[key]

    alignment_match = {}
    for index1, (title1, num1, article1) in enumerate(titled1):
        category1 = title2category.get(title1)
        if category1 is None:
            # An unmapped title must not "match" other unmapped titles (None == None).
            continue
        for index2, (title2, num2, article2) in enumerate(titled2):
            if title2category.get(title2) != category1:
                continue
            if sanity_check:
                score = util.cos_sim(_embed(1, index1, article1), _embed(2, index2, article2)).max().item()
                if score < threshold:
                    continue
            alignment_match[num1] = num2
            break

    return alignment_match
    

def extract_similar_sentences_from_article(article1, article2, identical_threshold = 0.999, min_similarity = 0.91):
    '''
    Pair up highly similar (but not identical) sentences of two aligned articles.

    article1, article2: article xml elements
    identical_threshold: a pair scoring above this is treated as identical; its two
        sentences are then excluded from any further pairing
    min_similarity: lowest cosine similarity for a pair to be reported (scores of
        exactly 1 or more are never reported)
    return: (list of (sentence1, sentence2, similarity), category). The category is
        the Category 2 of article1's title, or None when article1 has no title or the
        title is unmapped. Returns ([], 'TITLE_MISMATCH') when both articles have
        titles whose categories differ.

    Pairs are matched greedily from the highest score down, so each sentence is
    used at most once.
    '''
    title1, title2 = article1.get('title'), article2.get('title')
    category1 = title2category.get(title1.lower()) if title1 else None
    if title1 and title2 and category1 != title2category.get(title2.lower()):
        return [], 'TITLE_MISMATCH'

    # assumes elem2sent(..., break_sentence=False) returns a list of strings — TODO confirm
    sents1 = elem2sent(article1, break_sentence = False)
    sents2 = elem2sent(article2, break_sentence = False)
    if not sents1 or not sents2:
        # Nothing to compare; avoid encoding / cos_sim on empty input.
        return [], category1

    # scores[i, j] = cosine similarity of sents1[i] and sents2[j]
    scores = util.cos_sim(model.encode(sents1), model.encode(sents2))

    # Sentences with an (almost) identical counterpart need no further comparison.
    used_rows, used_cols = set(), set()
    for i, j in (scores > identical_threshold).nonzero().tolist():
        used_rows.add(i)
        used_cols.add(j)

    # Candidate pairs in ascending score order; they are consumed from the end
    # (highest score first) for a greedy one-to-one matching.
    mask = (scores >= min_similarity) & (scores < 1)
    candidates = sorted(
        ((scores[i, j].item(), i, j) for i, j in mask.nonzero().tolist()),
        key = lambda cand: cand[0],
    )

    ret = []
    for score, i, j in reversed(candidates):
        if i in used_rows or j in used_cols:
            continue
        ret.append((sents1[i], sents2[j], score))
        used_rows.add(i)
        used_cols.add(j)

    return ret, category1

def get_cat3(sentence, keyword2category, cat2):
    '''
    Derive the Category 3 labels of a sentence by keyword matching.

    sentence: string
    keyword2category: dict mapping keyword -> {Category 2 -> set of Category 3}
    cat2: Category 2 the sentence belongs to; only keywords mapped under it count
    return: (set of matched keywords, union of their Category 3 labels)

    Matching is a plain substring test on text that is lower-cased and stripped of
    ASCII punctuation. Each keyword gets the same normalization, so keywords with
    upper-case letters or punctuation can match too. A keyword that normalizes to
    the empty string never matches.
    '''
    strip_punct = str.maketrans('', '', string.punctuation)
    text = sentence.lower().translate(strip_punct)

    kws, res = set(), set()
    for keyword, cat2_map in keyword2category.items():
        if cat2 not in cat2_map:
            continue
        needle = str(keyword).lower().translate(strip_punct)
        if needle and needle in text:
            kws.add(keyword)
            res |= cat2_map[cat2]
    return kws, res

def extract_similar_from_doc(doc1_path, doc2_path, title2category, target_category = ('Fair and Equitable Treatment',), min_length = 5, score_range = (0.95, 0.99), verbose = True):
    '''
    Find near-duplicate sentence pairs between two treaty XML documents.

    doc1_path, doc2_path: paths to the two XML documents
    title2category: dictionary mapping a lower-cased article title to its Category 2
    target_category: Category 2 values to keep; an empty collection keeps every category
    min_length: both sentences must have more than this many words, and their word
        counts must differ by less than 4 * min_length
    score_range: (low, high) exclusive bounds on the similarity of a kept pair
    verbose: print pairs whose keyword-derived Category 3 labels differ (for inspection)
    return: list of (sentence1, sentence2, score, category-3 list, 'Stylystic Change'),
        highest score first; [] if either document cannot be loaded
    '''
    try:
        # The body element sits at root[1][2] in this corpus' layout.
        doc_body1 = ET.parse(doc1_path).getroot()[1][2]
        doc_body2 = ET.parse(doc2_path).getroot()[1][2]
    except Exception as e:
        # Best effort: report which pair failed and skip it.
        print(f'Could not load {doc1_path!r} / {doc2_path!r}: {e}')
        return []

    all_articles1 = doc_body1.findall(".//div[@type='article']")
    all_articles2 = doc_body2.findall(".//div[@type='article']")
    alignment_match = get_article_alignment(all_articles1, all_articles2, title2category)

    # Resolve aligned `num`s among the article divs only (first occurrence wins).
    # An XPath on bare @num could hit a nested paragraph div with the same number,
    # and it breaks when `num` is missing or contains a quote.
    by_num1 = {a.get('num'): a for a in reversed(all_articles1)}
    by_num2 = {a.get('num'): a for a in reversed(all_articles2)}

    ret = []
    for num1, num2 in alignment_match.items():
        article1, article2 = by_num1.get(num1), by_num2.get(num2)
        if article1 is None or article2 is None:
            continue
        similar_sents, category2 = extract_similar_sentences_from_article(article1, article2)

        if category2 == 'TITLE_MISMATCH':
            continue
        if target_category and category2 not in target_category:
            continue

        for sentence1, sentence2, score in similar_sents:
            k1, s1_cate3 = get_cat3(sentence1, keyword2category, category2)
            k2, s2_cate3 = get_cat3(sentence2, keyword2category, category2)

            # Inspection output for likely content changes. Note that this length check
            # compares characters, while the final filter below compares words.
            if (verbose and s1_cate3 and s2_cate3 and s1_cate3 != s2_cate3
                    and len(sentence1.split()) >= min_length and len(sentence2.split()) >= min_length
                    and abs(len(sentence1) - len(sentence2)) <= 4 * min_length):
                print('s1_cate3: ', s1_cate3)
                print('s2_cate3: ', s2_cate3)
                print('sentence1: ', sentence1)
                print('sentence2: ', sentence2)
                print('category2: ', category2)
                print('k1: ', k1)
                print('k2: ', k2)
                print('score: ', score)
                print('-----------------------------')

            if s1_cate3 == s2_cate3:
                # NOTE(review): the label is misspelled ('Stylistic'); kept as-is because
                # downstream consumers may match on it.
                ret.append((sentence1, sentence2, score, list(s1_cate3 & s2_cate3), 'Stylystic Change'))

    # Keep pairs where both sentences have more than min_length words, the word counts
    # differ by less than 4 * min_length, and the score lies strictly inside score_range.
    low, high = score_range
    ret = [
        x for x in ret
        if len(x[0].split()) > min_length
        and len(x[1].split()) > min_length
        and abs(len(x[0].split()) - len(x[1].split())) < 4 * min_length
        and low < x[2] < high
    ]
    ret.sort(key = lambda x: x[2], reverse = True)
    return ret

Sanity Check

In [30]:
# Time one end-to-end comparison of two Canadian BITs, with no category filter.
start = time.time()
diff = extract_similar_from_doc(
    'data/full data/t1989-9-canada-russian-federation-bit-1989.xml',
    'data/full data/t1990-14-canada-czech-republic-bit-1990.xml',
    title2category,
    target_category = [],
)
elapsed = time.time() - start
print(elapsed)

0.9843745231628418


Select the documents that contain at least one article in the target category

In [44]:
# Sorted so that the seeded sampling later in the notebook is reproducible
# (os.listdir order is arbitrary).
docs = ['data/full data/' + name for name in sorted(os.listdir('data/full data'))]
print("total treaty is " + str(len(docs)))
target_category = ['Public Policy']
target_treaty = []
n_unloadable = 0

for treaty in docs:
    try:
        doc_body = ET.parse(treaty).getroot()[1][2]
    except Exception:
        # Best effort: skip unparseable files / unexpected layouts, but count them.
        n_unloadable += 1
        continue

    # Same article selection as extract_similar_from_doc, so every selected treaty
    # has at least one target-category article that the alignment can actually use.
    for article in doc_body.findall(".//div[@type='article']"):
        title = article.get("title")
        if title and title2category.get(title.lower()) in target_category:
            target_treaty.append(treaty)
            break

print(len(target_treaty))
print(f"skipped {n_unloadable} unloadable files")
    

total treaty is 3309
60


In [60]:
# Sampling pool: the target-category treaties selected above.
# NOTE(review): despite the name, these are not restricted to Canadian treaties.
canada_docs = target_treaty

N_PAIR_DRAWS = 100  # random pair draws; repeated pairs are skipped, so fewer may be compared
random.seed(42)     # make the sampled pairs (and therefore the results) reproducible

res = []         # every result tuple returned by extract_similar_from_doc
records = []     # the same results with their document names, one dict per row
visited = set()  # unordered document pairs already compared

if len(canada_docs) >= 2:
    for _ in range(N_PAIR_DRAWS):
        doc1, doc2 = random.sample(canada_docs, 2)  # two distinct documents
        pair = frozenset((doc1, doc2))
        if pair in visited:
            continue
        visited.add(pair)

        diff = extract_similar_from_doc(doc1, doc2, title2category, target_category)
        res.extend(diff)

        doc1name = os.path.splitext(os.path.basename(doc1))[0]
        doc2name = os.path.splitext(os.path.basename(doc2))[0]
        for sentence1, sentence2, score, subcategory, *_ in diff:
            records.append({
                'sentence1': sentence1,
                'sentence2': sentence2,
                'similarity': score,
                'doc1': doc1name,
                'doc2': doc2name,
                'Subcategory': ', '.join(map(str, subcategory)),
            })

# Tabular view of the results; export with
# similar_df.to_excel('generated_data/similar_sentences.xlsx', index = False) when needed.
similar_df = pd.DataFrame(records, columns = ['sentence1', 'sentence2', 'similarity', 'doc1', 'doc2', 'Subcategory'])
similar_df.head()


s1_cate3:  {'Non-Derogation Obligations', 'Environment'}
s2_cate3:  {'Non-Derogation Obligations', 'Environment', 'Labor rights'}
sentence1:   the contracting parties recognise that it is inappropriate to encourage investment by relaxing domestic environmental legislation accordingly each contracting party shall strive to ensure that it does not waive or otherwise derogate from or offer to waive or otherwise derogate from such legislation as an encouragement fm the establishment maintenance or expansion in its territory of an investment
sentence2:   the contracting parties recognize that it is inappropriate to encourage investment by relaxing domestic labor or environmental measures accordingly a contracting party should not waive or otherwise derogate from or offer to waive or otherwise derogate from such measures as an encouragement for the establishment acquisition expansion or retention in its territory of an investment of an investor
category2:  Public Policy
k1:  {'environmental'