In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
from transformers import BertForSequenceClassification, AdamW
import faiss
from sentence_transformers import SentenceTransformer

from tqdm.notebook import tqdm

import os
import sys
import json
import re
import string
import random
import time
import datetime
import copy
import pickle

from argparse import Namespace

# Preprocess Data 
The raw corpus must be a comma-delimited CSV file with the following columns:
1. **name** : name of the document 
2. **main** : the main text of the document (what will be encoded and searched against)
3. **court** : OPTIONAL court that is associated with document 
4. **country** : OPTIONAL country of origin of document 

In [2]:
# Filesystem locations used by the notebook, gathered into a single namespace
# so they can be changed in one place.
args = Namespace(**{
    'data_path': './corpus.csv',
    'model_path': './../../Notebooks/models/multi_parallel',
    'faiss_index_path': './faiss_index_file.index',
})

In [3]:
# Load the corpus into a DataFrame.
# NOTE(review): this path is hard-coded and differs from args.data_path
# ('./corpus.csv') -- confirm which file is the intended input.
df = pd.read_csv('./cleaned_corpus.csv')

In [4]:
# Store each row's index label as a regular column; breakdown_documents copies
# 'idx' onto every chunk it creates, so a chunk can be traced back to its row.
df['idx'] = df.index

### Break each document into overlapping sub-chunks to account for BERT's 512-token input limit

In [5]:
def breakdown_documents(documents, max_length=1024, stride=128):
    """Expand a corpus so that every long document becomes several chunk rows.

    Rows whose ``main`` text is longer than ``max_length`` characters are
    replaced by one row per section returned by ``breakdown_document``. Each
    chunk row holds the section in ``main`` plus the parent row's ``name``,
    ``court``, ``country`` and ``idx`` values (only those of these columns that
    exist in ``documents``). Rows at or below the limit are copied unchanged
    with all of their columns. Rows whose ``main`` is not a string (for
    example the NaN pandas uses for missing text) are skipped.

    Args:
        documents: DataFrame with at least a ``main`` column.
        max_length: Character threshold above which a document is chunked;
            also forwarded to ``breakdown_document``.
        stride: Forwarded to ``breakdown_document``.

    Returns:
        A new DataFrame with one row per short document or chunk.
    """
    # Metadata copied from the parent row onto each chunk. Restricting the list
    # to columns that are present keeps a corpus without the optional
    # court/country columns from raising KeyError.
    keys = [key for key in ('name', 'court', 'country', 'idx')
            if key in documents.columns]

    new_documents = []

    for _, document in documents.iterrows():
        text = document['main']

        # Missing text is not a str (NaN / None); len() on it would either fail
        # or be meaningless, so skip the row.
        if not isinstance(text, str):
            continue

        if len(text) > max_length:
            sections = breakdown_document(text, max_length=max_length, stride=stride)
            for section in sections:
                obj = {'main': section}
                for key in keys:
                    obj[key] = document[key]
                new_documents.append(obj)
        else:
            new_documents.append(document.to_dict())

    return pd.DataFrame(new_documents)
                


def breakdown_document(document, max_length=1024, stride=128):
    """Split ``document`` into overlapping, whitespace-aligned sections.

    A window of at most ``max_length`` characters is cut at the last space
    inside it (or hard-cut at the window end when it contains no space). The
    window start then moves to just after the last space found within the next
    ``stride`` characters, or by exactly ``stride`` characters when that span
    contains no space, so text without spaces is still covered instead of being
    dropped. Sections are stripped of surrounding whitespace and empty sections
    are omitted.

    Args:
        document: The text to split.
        max_length: Maximum number of characters per section (>= 1).
        stride: Maximum number of characters the window start advances per
            step (>= 1). Expected to be no larger than ``max_length``.

    Returns:
        List of section strings in document order.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is smaller than 1.
    """
    if max_length < 1 or stride < 1:
        raise ValueError("max_length and stride must both be >= 1")

    total = len(document)
    sections = []
    start = 0

    while start < total:
        end = min(start + max_length, total)
        if end == total:
            # The window reaches the end of the text: take everything left.
            split_index = total
        else:
            split_index = document.rfind(' ', start, end)
            if split_index == -1:
                # No space to break on inside the window: hard cut.
                split_index = end

        section = document[start:split_index].strip()
        if section:
            sections.append(section)

        if start + stride >= total or split_index >= total:
            break

        # Start the next window on a word boundary when one exists within the
        # stride. Otherwise advance by the full stride rather than jumping to
        # the end, which would silently discard the remaining text.
        next_start = document.rfind(' ', start, start + stride)
        start = next_start + 1 if next_start != -1 else start + stride

    return sections

In [6]:
# Expand the corpus so that long documents become several chunk rows.
# NOTE(review): the embedding loop further down iterates `df`, not `x` --
# confirm whether these chunks were meant to be what gets embedded and indexed.
x = breakdown_documents(df)

In [7]:
# Number of rows in the chunked frame.
len(x)

203506

In [101]:
# Preview the first rows of the chunked frame.
x.head()

Unnamed: 0.1,Unnamed: 0,name,main,court,country,idx
0,26702.0,The Central Illinois Public Service Company et...,(No. 25992.\nThe Central Illinois Public Servi...,Illinois Supreme Court,USA,0
1,,Grant Johnson v. The People of the State of Il...,Grant Johnson v. The People of the State of Il...,Illinois Supreme Court,USA,1
2,,Grant Johnson v. The People of the State of Il...,acknowledgment of guilt. A confession is a vol...,Illinois Supreme Court,USA,1
3,,Grant Johnson v. The People of the State of Il...,in their nature.\n2. Same—when giving of instr...,Illinois Supreme Court,USA,1
4,,Grant Johnson v. The People of the State of Il...,free and voluntary confession of guilt is the ...,Illinois Supreme Court,USA,1


In [7]:
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Load the encoder from the configured path and place it on the device chosen
# above, so the `device` selection is actually honoured instead of being unused.
model = SentenceTransformer(args.model_path, device=device)

No sentence-transformers model found with name ./../../Notebooks/models/combined/mlm_combined. Creating a new one with MEAN pooling.


In [19]:
def generate_embeddings(text):
    """Encode ``text`` with the notebook-level ``model`` and return the result.

    The encoder is asked for a tensor (``convert_to_tensor=True``).
    """
    return model.encode(text, convert_to_tensor=True)

## Generate the embedding lists 

In [20]:
progress = tqdm(total=len(df), desc='Train Batches', leave=True)

embeddings_list = []
attention_list = []  # never filled in this cell; kept so the name stays defined

# Encode the `main` text of every row, in row order, so that position n in
# embeddings_list corresponds to row n of df.
for i, row in df.iterrows():
    embeddings = generate_embeddings(row['main'])
    embeddings_list.append(embeddings)

    progress.update(1)

# Release the progress bar once the loop is finished.
progress.close()

# Stack the per-row embeddings along a new leading axis to get one row per
# document. np.concatenate would join them end to end along the existing axis,
# and the FAISS cell reads shape[1] as the embedding dimension.
embeddings_matrix = np.stack([embedding.cpu().numpy() for embedding in embeddings_list])

Train Batches:   0%|          | 0/10000 [00:00<?, ?it/s]

### Build FAISS Index 

In [25]:
# Convert every stored embedding to a NumPy array and stack them row-wise.
embeddings_matrix = np.stack(
    [tensor.cpu().numpy() for tensor in embeddings_list]
)

# The second axis of the matrix is the embedding dimension.
D = embeddings_matrix.shape[1]

# Build a flat L2 index of that dimension and add every row to it.
index = faiss.IndexFlatL2(D)
index.add(embeddings_matrix)

In [None]:
# Disabled earlier attempt at building the index, kept for reference only.
# Here D would be the whole shape tuple rather than a single dimension; the
# active cell above reads shape[1] instead.
# D = embeddings_matrix.shape# dimension of embeddings 
# index = faiss.IndexFlatL2(D)
# index.add(embeddings_matrix)

### Write the index to a static file

In [26]:
# Persist the index at the configured location instead of repeating the file
# name as a literal; args.faiss_index_path points at the same file.
faiss.write_index(index, args.faiss_index_path)

## Search


In [6]:
# Reload the persisted index from the configured location (the same file the
# write step produces) rather than from a duplicated literal.
loaded_index = faiss.read_index(args.faiss_index_path)

# Titles used to turn search hits back into document names.
loaded_titles = df['name']

In [7]:
def search(query_embedding, k=100):
    """Look up the ``k`` nearest stored vectors for one or more queries.

    The lookup is delegated to the notebook-level ``loaded_index``. The index
    built in this notebook is a ``faiss.IndexFlatL2``, so the returned scores
    are L2-based distances, not cosine similarities.

    Args:
        query_embedding: A torch tensor (on any device) or array-like holding
            one query vector, or a batch with one query per row.
        k: Number of neighbours to request per query.

    Returns:
        ``(D, I)`` exactly as returned by ``loaded_index.search``: distances
        and the matching ids, one row per query.
    """
    # Tensors may live on the GPU or track gradients; .numpy() alone fails for
    # both, so detach and move to the CPU first.
    if hasattr(query_embedding, 'detach'):
        query_embedding = query_embedding.detach().cpu().numpy()

    queries = np.ascontiguousarray(query_embedding, dtype='float32')

    # Promote a single vector to a batch of one.
    if queries.ndim == 1:
        queries = queries[np.newaxis, :]

    D, I = loaded_index.search(queries, k)
    return D, I

def get_titles(indices):
    """Map search-result ids to document titles.

    Args:
        indices: Iterable of integer ids for a single query, e.g. one row of
            the id array returned by ``search``.

    Returns:
        List of titles taken from the notebook-level ``loaded_titles``, in the
        order of ``indices``. Negative ids are skipped.
    """
    # Ids are row positions (vectors were added in row order), so look titles up
    # by position rather than by index label. FAISS pads with -1 when fewer
    # neighbours exist than were requested; those entries have no title.
    titles = [loaded_titles.iloc[int(i)] for i in indices if i >= 0]
    return titles

In [9]:
query = 'license'

# Encode the query text and add a leading axis to the resulting tensor.
query_embedding = model.encode(query, convert_to_tensor=True).unsqueeze(0)

# Search with a CPU copy of the tensor and show the raw neighbour ids.
distances, indices = search(query_embedding.cpu(), 10)
print(indices)

# A single query was submitted, so its hits are the first row of `indices`.
titles = get_titles(indices[0])

print("Distances:", distances)
print("--------------")
print("Nearest Neighbor Titles:")
for i, title in enumerate(titles, start=1):
    print(f"{i}. {title}")

[[7609  875 6797 8645 1306 8652 3736 9827 5714 3687]]
Distances: [[128.51797 136.04918 138.59277 141.84715 143.4627  144.2062  144.7921
  144.85117 145.04272 146.1094 ]]
--------------
Nearest Neighbor Titles:
1. The People of the State of Illinois ex rel. Henry Booth, v. Charles E. Lippincott, Auditor
2. The Santa Clara Female Academy v. Francis J. Sullivan et al.
3. The People of the State of Illinois, ex relatione The Merchants’ Savings, Loan and Trust Company of Chicago, v. The Auditor of Public Accounts
4. Charles Howard v. John W. Lakin
5. The Western Union Telegraph Company v. The Chicago and Paducah Railroad Company et al.
6. John Dolese et al. v. Daniel A. Pierce
7. The City of Olney v. J. N. Concur et al.
8. The W. Scheidel Coil Company, Appellant, vs. James A. Rose, Secretary of State, Appellee
9. Samuel Voris et al. v. William Renshaw, Jr.
10. The Illinois Starch Company v. The Ottawa Hydraulic Company


In [10]:
# Show the full text of the top hit. The row position is taken from the search
# result instead of a number copied from an earlier run's output, so this cell
# stays correct when the query, model or corpus changes.
df.iloc[indices[0][0]]['main']

"The People of the State of Illinois ex rel. Henry Booth, v. Charles E. Lippincott, Auditor.\nCircuit judges of Cook county—salaries of, urider the constitution of 1870. The circuit judges of Cook county elected under the constitution of 1870 were entitled to receive from the State, until the adjournment of the first session, of the general assembly after the adoption of such constitution, a salary of $1000 per annum, only, as provided h}' the constitution of 1848.\nMr. W. C. Goudy, for the relator."

: 