In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
import xml.etree.ElementTree as ET
import numpy as np
import time
import os
import random
import xlwt
import string
from tqdm import tqdm

from utils import *

import warnings
warnings.filterwarnings('ignore')
from transformers import logging
logging.set_verbosity_error()

## Load model and metadata from disk

In [2]:
model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')
#model = SentenceTransformer('nlpaueb/bert-base-uncased-contracts')
meta_data_df = pd.read_excel('data/category_configuration_09-08-2022_08-08-01.xlsx', sheet_name = 'article_names_matching')
title2category = dict(zip(meta_data_df["Article Title"], meta_data_df["Category 2"]))

keyword_df = pd.read_excel('data/category_configuration_09-08-2022_08-08-01.xlsx', sheet_name = 'keywords_category_2_mapping')
#keyword2category2 = dict(zip(keyword_df["Keyword"], keyword_df["Category 2"]))
#keyword2category3 = dict(zip(keyword_df["Keyword"], keyword_df["Category 3"]))
keyword_df.dropna(subset = ["Keyword"], inplace=True)
keyword2category = {}
for i in range(len(keyword_df)):
    cat2 = keyword_df.iloc[i]['Category 2']
    cat3 = keyword_df.iloc[i]['Category 3']
    keyword = keyword_df.iloc[i]['Keyword']
    if keyword not in keyword2category:
        keyword2category[keyword] = {}
    if cat2 not in keyword2category[keyword]:
        keyword2category[keyword][cat2] = set()
    keyword2category[keyword][cat2].add(cat3)


No sentence-transformers model found with name C:\Users\Xiang/.cache\torch\sentence_transformers\nlpaueb_legal-bert-base-uncased. Creating a new one with MEAN pooling.


## Code for calculate similarity between two documents

In [3]:
def get_article_alignment(all_articles1, all_articles2, title2category, selected_cat2 = [], sanity_check = True):
    '''
    Get article alignment between two documents
    article_body: xml element
    title2category: dictionary of title to category
    sanity_check: use similarity score to check if the alignment is correct
    return a dictionary of article alignment
    '''
    alignment_match = {}

    #filter out the arrticles that don't have title
    all_articles1 = [article for article in all_articles1 if article.get("title") is not None]
    all_articles2 = [article for article in all_articles2 if article.get("title") is not None]

    #if both articles attribute includes title
    #if 'title' in all_articles1[0].attrib and 'title' in all_articles2[0].attrib:
    article_title1 = [(article.get('title').lower(), article.get('num')) for article in all_articles1 if article.get('title')]
    article_title2 = [(article.get('title').lower(), article.get('num')) for article in all_articles2 if article.get('title')]


    if selected_cat2:
        #remove the articles that don't have a target category
        article_title1 = [(title, page, title2category.get(title)) for title, page in article_title1 if title2category.get(title) in selected_cat2]
        article_title2 = [(title, page, title2category.get(title)) for title, page in article_title2 if title2category.get(title) in selected_cat2]
    else:
        article_title1 = [(title, page, title2category.get(title)) for title, page in article_title1]
        article_title2 = [(title, page, title2category.get(title)) for title, page in article_title2]

    for index1, (title1, num1, cat2_1) in enumerate(article_title1):
        for index2, (title2, num2, cat2_2) in enumerate(article_title2):
            if cat2_1 == cat2_2:
                #to make sure that they have a high similarity
                if sanity_check and max(util.cos_sim(
                    model.encode(''.join(elem2sent(all_articles1[index1]))),
                    model.encode(''.join(elem2sent(all_articles2[index2])))
                )).item() < 0.7:
                    continue
                alignment_match[num1] = num2
                break
            
    return alignment_match
    

def extract_similar_sentences_from_article(article1, article2):
    '''
    article1, article2: xml element
    return: list of similar sentences and their category: (sentence1, sentence2, similarity), category
    '''
    if article1.get('title') and article2.get('title'):
        if (title2category.get(article1.get('title').lower()) != title2category.get(article2.get('title').lower())):
            return [], 'TITLE_MISMATCH'

    article1_sents, article2_sents = elem2sent(article1, break_sentence = True), elem2sent(article2, break_sentence = False)
    #article1_sents, article2_sents = elem2sent(article1), elem2sent(article2)

    #Embed article1 and article2
    article1_embd, article2_embd = model.encode(article1_sents), model.encode(article2_sents)

    #Get similarity between article1 and article2
    scores = util.cos_sim(article1_embd, article2_embd)

    visited = set() #to make sure that we don't add the same sentence twice

    #filter out the sentence with similarity greater than 0.98, this means they are perfect match and no need to compare
    identical = (scores > 0.999).to(torch.int64)
    for i, j in identical.nonzero().tolist():
        visited.add('row' + str(i))
        visited.add('col' + str(j))

    #filter out the sentences with similarity between 0.5 and 0.98
    mask = (scores >= 0.8) & (scores < 1)
    scores *= mask.to(torch.int64) 

    #get the index of the sentences with similarity between 0.5 and 0.98
    sim_pairs = [(scores[i][j], i, j) for i, j in mask.nonzero().tolist()]
    sim_pairs.sort(key = lambda x: x[0]) #sort by similarity score
    
    ret = []
    while sim_pairs:
        score, i, j = sim_pairs.pop()
        if 'row' + str(i) not in visited and 'col' + str(j) not in visited:
            ret.append((article1_sents[i], article2_sents[j], scores[i][j].item()))
            visited.add('row' + str(i))
            visited.add('col' + str(j))

    return ret, title2category.get(article1.get('title').lower())

def get_cat3(sentence, keyword2category, cat2):
    '''
    sentence: string
    keyword2category: dictionary of keyword to category
    cat2: string
    return: category 3
    '''
    kws = set()
    res = set()
    sentence = sentence.lower().translate(str.maketrans('', '', string.punctuation))
    for keyword in keyword2category.keys():
        if keyword in sentence and cat2 in keyword2category[keyword]:
            kws.add(keyword)
            res |= keyword2category[keyword][cat2]
            #return keyword, keyword2category[keyword].get(cat2, set())
    return kws, res

def extract_similar_from_doc(doc1_path, doc2_path, title2category, target_category = [], min_length = 5):
    '''
    doc_path: path to first document
    title2category: dictionary of title to category
    '''

    try:   
        doc1, doc2 = ET.parse(doc1_path), ET.parse(doc2_path)
        doc_root1, doc_root2 = doc1.getroot(), doc2.getroot()
        doc_body1, doc_body2 = doc_root1[1][2], doc_root2[1][2]
        
    except Exception as e:
        print(e); return []
    
    #get article alignment between two documents
    all_articles1, all_articles2 = doc_body1.findall(".//div[@type='article']"), doc_body2.findall(".//div[@type='article']")
    alignment_match = get_article_alignment(all_articles1, all_articles2, title2category, selected_cat2 = target_category)

    #print(alignment_match)

    ret = []
    for page1, page2 in alignment_match.items():
        if not page1 or not page2: continue
        article1, article2 = doc_body1.find(".//div[@num='" + page1 + "']"), doc_body2.find(".//div[@num='" + page2 + "']") 
        similar_sents, category2 = extract_similar_sentences_from_article(article1, article2)
        
        #print(similar_sents)

        if category2 == 'TITLE_MISMATCH' or category2 not in target_category:
            continue
        #print(category2)

        for (sentence1, sentence2, score) in similar_sents:

            k1, s1_cate3 = get_cat3(sentence1, keyword2category, category2)
            k2, s2_cate3 = get_cat3(sentence2, keyword2category, category2)


            #filter out the pairs which has a category and which both sentences are longer than min_length words and length difference is less than 4 * min_length
            if (s1_cate3 and s2_cate3 and len(sentence1.split()) >= min_length and len(sentence2.split()) >= min_length and abs(len(sentence1) - len(sentence2)) <= 4 * min_length):
                # if s1_cate3 == s2_cate3 or s1_cate3.issubset(s2_cate3) or s2_cate3.issubset(s1_cate3):
                if s1_cate3 == s2_cate3:
                    if 0.965 < score < 0.98:
                        ret.append(('STYLYSTIC', sentence1, sentence2, score, list(s1_cate3)))
                elif s1_cate3 != s2_cate3 and iou(s1_cate3, s2_cate3) < 1/3 and score > 0.9:
                    ret.append(('RELEVANT', sentence1, sentence2, score, [list(s1_cate3), list(s2_cate3)]))
            elif 0.84 < score < 0.85 and len(sentence1.split()) >= min_length and len(sentence2.split()) >= min_length and abs(len(sentence1) - len(sentence2)) <= 4 * min_length:
                ret.append(('IRRELEVANT', sentence1, sentence2, score, []))

    #sort by similarity score
    ret.sort(key = lambda x: x[3], reverse = True)

    return ret

Sanaity Check

In [4]:
start = time.time()
diff = extract_similar_from_doc(
    'data/full data/t1989-9-canada-russian-federation-bit-1989.xml', 
    'data/full data/t1990-14-canada-czech-republic-bit-1990.xml',
    title2category,
    ['Definition', 'Promotion and Admission']
)
print(time.time() - start)
diff

3.33526349067688


[]

Filtering out the documents that contain the target category

In [5]:
docs = ['data/full data/' + _ for _ in os.listdir('data/full data') ]
print("total treaty is " + str(len(docs)))
target_category = []
for cat2 in keyword_df['Category 2'].unique():
    if 3 <= keyword_df['Category 2'].value_counts()[cat2] <= 50:
        target_category.append(cat2)
target_treaty = []


cat2_doc_map = {}
for cat2 in target_category:
    cat2_doc_map[cat2] = set()

for doc in docs:
    try:   
        doc_content = ET.parse(doc)
        doc_root = doc_content.getroot()
        doc_body = doc_root[1][2]
            
        for article in doc_body:
            try:
                cat2 = title2category.get(article.get("title").lower())
                if cat2 in target_category:
                    cat2_doc_map[cat2].add(doc)
            except:
                pass

    except Exception as e:
        pass
    

total treaty is 3309


In [6]:
#Write all sentences to a excel file
wb = xlwt.Workbook()
ws = wb.add_sheet('sheet1')

#add header
ws.write(0, 0, 'label')
ws.write(0, 1, 'sentence1')
ws.write(0, 2, 'sentence2')
ws.write(0, 3, 'similarity')
ws.write(0, 4, 'category3')
ws.write(0, 5, 'doc1')
ws.write(0, 6, 'doc2')


row = 1
res = []

counter = {}
for cat2 in cat2_doc_map:
    counter[cat2] = 0


for cat2 in tqdm(cat2_doc_map):
    docs = list(cat2_doc_map[cat2])

    if not docs:
        continue

    visited = set()

    for i in tqdm(range(20000), disable = 1):
        if len(docs) < 20 and len(visited) >= len(docs) * (len(docs) - 1) / 2:
            break
        doc1 = random.choice(docs)
        doc1name = doc1.replace('.','/').split('/')[2]

        doc2 = random.choice(docs)
        doc2name = doc2.replace('.','/').split('/')[2]

        if doc1 == doc2 or (doc1, doc2) in visited or (doc2, doc1) in visited:
            continue
        visited.add((doc1, doc2))
        visited.add((doc2, doc1))

        # try:
        diff = extract_similar_from_doc(doc1, doc2, title2category, target_category = [cat2])
        if any(diff):
            res.extend(diff)
            counter[cat2] += 1

            for i, (label, sentence1, sentence2, score, category3) in enumerate(diff):
                ws.write(row, 0, label)
                ws.write(row, 1, sentence1)
                ws.write(row, 2, sentence2)
                ws.write(row, 3, score)
                ws.write(row, 4, str(category3))
                ws.write(row, 5, doc1name)
                ws.write(row, 6, doc2name)
                row += 1

wb.save('generated_data/similar_sentences_large.xlsx')


100%|██████████| 47/47 [4:53:14<00:00, 374.35s/it]  


In [1]:
len(res)

NameError: name 'res' is not defined

In [8]:
sum([1 for _ in res if _[0] == 'STYLYSTIC'])

8430

In [9]:
sum([1 for _ in res if _[0] == 'RELEVANT'])

2069