In [4]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import faiss

from tqdm.notebook import tqdm

import os
import sys
import json
import re
import string
import random
import time
import datetime
import copy
import pickle

from argparse import Namespace

In [5]:
args = Namespace(
    data_path = './corpus.csv',
    model_path='./../../Notebooks/models/parallel_combined'
)

In [6]:
df = pd.read_csv(args.data_path)

In [11]:
df = df.sample(frac=1, random_state=42)
df = df[0:10000]

In [12]:
df.head()

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,name,main,court,country,text,title
16735,16949,16949,"Jonathan Gibbons, administrator of Hiram Kimba...","Jonathan Gibbons, administrator of Hiram Kimba...",Illinois Supreme Court,USA,,
12346,12498,12498,Caroline C. Holden et al. v. The City of Chicago,Caroline C. Holden et al. v. The City of Chica...,Illinois Supreme Court,USA,,
12830,12992,12992,Ludwig Baker and Caroline Baker v. Augusta Young,Ludwig Baker and Caroline Baker v. Augusta You...,Illinois Supreme Court,USA,,
14453,14632,14632,The National Insurance Company v. Sidney T. We...,The National Insurance Company v. Sidney T. We...,Illinois Supreme Court,USA,,
584,592,592,James F. Stevenson v. Emma V. Stevenson,James F. Stevenson v. Emma V. Stevenson.\nOpin...,Illinois Supreme Court,USA,,


In [10]:
# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')
model = BertModel.from_pretrained(args.model_path, output_attentions=True)
model.eval()

Downloading tokenizer_config.json: 100%|██████████| 300/300 [00:00<00:00, 82.8kB/s]
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.51MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 48.6kB/s]
Downloading config.json: 100%|██████████| 740/740 [00:00<00:00, 465kB/s]


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [14]:
def generate_embeddings(text):
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # embeddings = outputs.last_hidden_state[:, 0, :].numpy()
    embeddings = outputs.last_hidden_state.mean(dim=1)

    attention_weights = outputs.attentions 
    tokens = tokenizer.tokenize(text)

    # Determine top 10 document tokens based on attention weight 
    start_index_m = 1  # Assuming [SEP] token between query and document
    end_index_m = start_index_m + len(tokens)

    document_attention = attention_weights[-1][0, :, start_index_m:end_index_m, start_index_m:end_index_m].mean(dim=0)
    k = min(10, len(document_attention))
    
    top_attentions, top_indices = torch.topk(document_attention, k)  # Get top 10 attentions and their indices

    # Ensure top_indices and top_attentions are properly flattened
    top_indices_flat = top_indices.cpu().numpy().flatten()
    top_attentions_flat = top_attentions.cpu().numpy().flatten()


    top_tokens_with_weights = {tokens[idx]: float(attention) for idx, attention in zip(top_indices_flat, top_attentions_flat)}
    stringified_weights = json.dumps(top_tokens_with_weights)

    return embeddings, stringified_weights


## Generate the embedding lists 

In [15]:
# df = df.dropna(subset=['main'])
df['main'].isnull().any()
len(df)

10000

In [16]:
progress = tqdm(total=len(df), desc='Train Batches', leave=True)

embeddings_list = []
attention_list = []

for i,row in df.iterrows():
    embeddings, attention_data = generate_embeddings(row['main'])
    embeddings_list.append(embeddings)
    attention_list.append(attention_data)

    progress.update(1)

embeddings_matrix = np.concatenate(embeddings_list)
attention_array = np.array(attention_list)

Train Batches:   0%|          | 0/10000 [00:00<?, ?it/s]

### Build FAISS Index 

In [18]:
D = embeddings_matrix.shape[1] # dimension of embeddings 
index = faiss.IndexFlatL2(D)
index.add(embeddings_matrix)

### Write index and attention data to static files 

In [19]:
faiss.write_index(index, "faiss_index_file.index")
np.save("attention.npy", attention_array)

## Search


In [15]:
loaded_index = faiss.read_index("faiss_index_file.index")
loaded_attention = np.load("attention.npy", allow_pickle=True)
loaded_titles = df['name']

In [34]:
def search(query_embedding, k=5):
    # Perform search using cosine similarity
    D, I = loaded_index.search(query_embedding.numpy(), k)
    return D, I

def get_titles(indices):
    # Retrieve corresponding titles from metadata
    titles = [loaded_titles[i] for i in indices]
    return titles

In [31]:
def search1(query):
    query_tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True, max_length=512)
    query_tokens = {key: value.to('cpu') for key, value in query_tokens.items()}

    with torch.no_grad():
        query_output = model(**query_tokens, output_attentions=True, output_hidden_states=True)
        query_embedding = query_output.last_hidden_state.mean(dim=1)


        q_tokens = tokenizer.tokenize(query)

        query_attention = query_output.attentions
        last_layer_attention = query_attention[-1]
        query_attention_weights = last_layer_attention[0,1:1+len(q_tokens),1:1+len(q_tokens)].mean(dim=0).cpu().numpy()
        query_attention_matrix_serialized = json.dumps(query_attention_weights.tolist())

    # Perform search
    D, I = loaded_index.search(query_embedding, 10)

    # OR 

    D, I = loaded_index.search(query_embedding.reshape(1,-1), 10)

    
    

In [35]:
query = 'murder'
query_tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True, max_length=512)
query_tokens = {key: value.to('cpu') for key, value in query_tokens.items()}

with torch.no_grad():
    query_output = model(**query_tokens, output_attentions=True, output_hidden_states=True)
    query_embedding = query_output.last_hidden_state.mean(dim=1)


    q_tokens = tokenizer.tokenize(query)

    query_attention = query_output.attentions
    last_layer_attention = query_attention[-1]
    query_attention_weights = last_layer_attention[0,1:1+len(q_tokens),1:1+len(q_tokens)].mean(dim=0).cpu().numpy()
    query_attention_matrix_serialized = json.dumps(query_attention_weights.tolist())
    

# Perform search
distances, indices = search(query_embedding, 10)

print(indices)

# Retrieve corresponding titles
titles = get_titles(indices[0])

print("Distances:", distances)
print("--------------")
print("Nearest Neighbor Titles:")
for i, title in enumerate(titles[0:], start=1):  # Skip the first title as it's the query document
    print(f"{i}. {title}")

Query Document:
--------------
Title:  The City of Elmhurst, Appellee, vs. William J. Buettgen, Appellant
Distances: [[67.054665 68.089966 68.14056  68.42929  68.8322   69.070335 69.51305
  69.63351  69.759514 70.0788  ]]
Nearest Neighbor Titles:
1. The People of the State of Illinois, Defendant in Error, vs. Wayne Jeffers, Plaintiff in Error
2. Charles J. Meadowcroft et al. v. The People of the State of Illinois
3. Charles T. Schueler et al. v. Joseph Mueller
4. Chicago, Burlington and Quincy Railroad Co. v. Clara M. Harwood
5. Margarett Williams, Appellee, vs. Chalon Garvin et al., Appellants
6. The City of Dixon v. Eli B. Baker
7. Louis Glanz v. Charles S. Gloeckler
8. Beatrice Fitch et al. v. Joseph H. Gray et al.
9. Hiram H. Rosencrantz v. William W. Mason
