In [8]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import faiss

from tqdm.notebook import tqdm

import os
import sys
import json
import re
import string
import random
import time
import datetime
import copy
import pickle

from argparse import Namespace

In [9]:
args = Namespace(
    data_path = './corpus.csv',
    model_path='./../../Notebooks/models/parallel_combined'
)

In [10]:
df = pd.read_csv(args.data_path)

In [11]:
df = df.sample(frac=1, random_state=42)
df = df[0:10000]

In [12]:
df.head()

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,name,main,court,country,text,title
16735,16949,16949,"Jonathan Gibbons, administrator of Hiram Kimba...","Jonathan Gibbons, administrator of Hiram Kimba...",Illinois Supreme Court,USA,,
12346,12498,12498,Caroline C. Holden et al. v. The City of Chicago,Caroline C. Holden et al. v. The City of Chica...,Illinois Supreme Court,USA,,
12830,12992,12992,Ludwig Baker and Caroline Baker v. Augusta Young,Ludwig Baker and Caroline Baker v. Augusta You...,Illinois Supreme Court,USA,,
14453,14632,14632,The National Insurance Company v. Sidney T. We...,The National Insurance Company v. Sidney T. We...,Illinois Supreme Court,USA,,
584,592,592,James F. Stevenson v. Emma V. Stevenson,James F. Stevenson v. Emma V. Stevenson.\nOpin...,Illinois Supreme Court,USA,,


In [13]:
# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')
model = BertModel.from_pretrained(args.model_path, output_attentions=True)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [14]:
def generate_embeddings(text):
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # embeddings = outputs.last_hidden_state[:, 0, :].numpy()
    embeddings = outputs.last_hidden_state.mean(dim=1)
    attention_weights = outputs.attentions 

    tokens = tokenizer.tokenize(text)

    # Determine top 10 document tokens based on attention weight 
    start_index_m = 1  # Assuming [SEP] token between query and document
    end_index_m = start_index_m + len(tokens)

    document_attention = attention_weights[-1][0, :, start_index_m:end_index_m, start_index_m:end_index_m].mean(dim=0)
    k = min(10, len(document_attention))
    
    top_attentions, top_indices = torch.topk(document_attention, k)  # Get top 10 attentions and their indices

    # Ensure top_indices and top_attentions are properly flattened
    top_indices_flat = top_indices.cpu().numpy().flatten()
    top_attentions_flat = top_attentions.cpu().numpy().flatten()


    top_tokens_with_weights = {tokens[idx]: float(attention) for idx, attention in zip(top_indices_flat, top_attentions_flat)}
    stringified_weights = json.dumps(top_tokens_with_weights)

    return embeddings, stringified_weights


## Generate the embedding lists 

In [15]:
# df = df.dropna(subset=['main'])
df['main'].isnull().any()
len(df)

10000

In [16]:
progress = tqdm(total=len(df), desc='Train Batches', leave=True)

embeddings_list = []
attention_list = []

for i,row in df.iterrows():
    embeddings, attention_data = generate_embeddings(row['main'])
    embeddings_list.append(embeddings)
    attention_list.append(attention_data)

    progress.update(1)

embeddings_matrix = np.concatenate(embeddings_list)
attention_array = np.array(attention_list)

Train Batches:   0%|          | 0/10000 [00:00<?, ?it/s]

### Build FAISS Index 

In [18]:
D = embeddings_matrix.shape[1] # dimension of embeddings 
index = faiss.IndexFlatL2(D)
index.add(embeddings_matrix)

### Write index and attention data to static files 

In [19]:
faiss.write_index(index, "faiss_index_file.index")
np.save("attention.npy", attention_array)

## Search


In [20]:
loaded_index = faiss.read_index("faiss_index_file.index")
loaded_metadata = np.load("attention.npy", allow_pickle=True)

In [21]:
def search(query_embedding, k=5):
    # Perform search using cosine similarity
    D, I = loaded_index.search(query_embedding, k)
    return D, I

def get_titles(indices):
    # Retrieve corresponding titles from metadata
    titles = [loaded_metadata[i] for i in indices]
    return titles

In [22]:
query = 'grand theft auto'
query_tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True, max_length=512)
query_tokens = {key: value.to('cpu') for key, value in query_tokens.items()}

with torch.no_grad():
    query_output = model(**query_tokens, output_attentions=True, output_hidden_states=True)
    query_embedding = query_output.last_hidden_state.mean(dim=1)

# Perform search
distances, indices = search(query_embedding, 10)

# Retrieve corresponding titles
titles = get_titles(indices[0])

print("Query Document:")
print("--------------")
print("Title: ", titles[0])  # Assuming the first returned document is the query document
print("Distances:", distances)
print("Nearest Neighbor Titles:")
for i, title in enumerate(titles[1:], start=1):  # Skip the first title as it's the query document
    print(f"{i}. {title}")

Query Document:
--------------
Title:  {".": 0.26979270577430725, ",": 0.17679546773433685, "tax": 0.0007339739240705967, "the": 0.0005150084034539759, "of": 0.0004657904792111367, "transferring": 0.008976730518043041, "proceeds": 0.00048569307546131313, "to": 0.008519104681909084, "auditor": 0.004310827702283859, "merchants": 0.004032318014651537, "state": 0.013412493281066418, "revenue": 0.014380929060280323, "illinois": 0.0005523016443476081, "chicago": 0.013388966210186481, "savings": 0.005269909743219614, "accounts": 0.00996419508010149, "loan": 0.002873433521017432, "and": 0.00892995111644268, "trust": 0.005016309674829245, "company": 0.005628583487123251, "act": 0.021552368998527527, "two": 0.0004801731847692281, "1861": 0.014034390449523926, "mill": 0.02923690527677536, "unconstitutional": 0.0055464175529778, "void": 0.008965053595602512}
Distances: [[ 97.23409   97.91188   98.28487  102.328606 105.81323  106.93414
  107.61917  107.87687  107.99062  108.01826 ]]
Nearest Neighbo