## Methodology

The purpose of this notebook is to take a pandas DataFrame of documents and match a previously unseen block of text against the domain-specific documents it contains.

To perform the matching task, we use a BERT model that has previously been fine-tuned on Sentence-BERT tasks (outlined in https://arxiv.org/abs/1908.10084 "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks"). The model embeds each document (fewer than 512 BERT tokens), and two document embeddings are then compared using dot-product or cosine similarity.

One weakness of this approach is that off-the-shelf models do not transfer well to unique domains. To further fine-tune a BERT model for our purposes, we use the transformer model FLAN-T5 to generate summaries for each of our documents, and then fine-tune the BERT model on a summary-source matching task. Inspiration for this approach comes from the BEIR paper, which describes generating synthetic text and appending it to documents so that they contain semantically similar text, improving BM25 search results (https://arxiv.org/abs/2104.08663 "BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models").

Once we have fine-tuned our BERT model on our specific domain, we can embed the documents with it and compare new blocks of text against our document store.

We use FAISS to index our documents and perform the matching. For more information on FAISS indices, see the project's GitHub repository (https://github.com/facebookresearch/faiss).


In [1]:
from sentence_transformers import SentenceTransformer, util, losses, models, datasets, InputExample, CrossEncoder
from torch import nn
import os
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration
import pandas as pd
import torch
import joblib
import numpy as np
from accelerate import Accelerator
from matplotlib import pyplot as plt
import faiss
from tqdm import tqdm
from cross_enc_search import cross_search, cross_fetch_article_info

# Load the document corpus into the notebook-level `data` frame that the
# following cells read from and add columns to.
data = pd.read_csv(filepath_or_buffer='swcs_text_data.csv')
# Alternative source, kept for reference:
# data = joblib.load('fulldataset.joblib')

In [2]:
# Derive the 'body' column: the 'text' column with leading and trailing
# whitespace removed from each entry.
data['body'] = data['text'].str.strip()

In [3]:
# Make sure every 'body' entry is a plain string before it is sent to be encoded.
def _body_to_text(value):
    """Return `value` as a single string.

    Strings pass through unchanged, missing values (None / NaN) become '',
    and any other iterable has its items stringified and concatenated.
    """
    if isinstance(value, str):
        return value
    # NaN is the only float that is not equal to itself; joining it would
    # raise "TypeError: 'float' object is not iterable".
    if value is None or (isinstance(value, float) and value != value):
        return ''
    return ''.join(map(str, value))

# Missing entries are replaced rather than dropped so row positions stay stable.
data['body'] = [_body_to_text(entry) for entry in data['body']]

In [4]:
def fetch_article_info(dataframe_idx, dataframe=None):
    """Fetch the article fields reported for a query-article match.

    Args:
        dataframe_idx: Positional (``iloc``) row index of the matched article.
        dataframe: Frame to read the row from. Defaults to the notebook-level
            ``data`` frame, which keeps the original one-argument call working.

    Returns:
        dict with a single key ``'Body'`` holding the row's ``body`` value.
        Add further keys here (for example a title column) to surface more
        fields with each match.
    """
    source = data if dataframe is None else dataframe
    info = source.iloc[dataframe_idx]
    return {'Body': info['body']}
    
def search(query, top_k, index, model):
    """Encode `query` and return its nearest articles from the index.

    Args:
        query: Text to match against the indexed documents.
        top_k: Number of neighbours to request from the index.
        index: Object exposing ``search(vectors, k) -> (distances, ids)``
            (a FAISS index in this notebook).
        model: Object exposing ``encode(list_of_texts)`` (a
            SentenceTransformer in this notebook).

    Returns:
        A ``zip`` of ``(article_info, score)`` pairs, in the order the index
        returned them. ``article_info`` is the dict produced by
        ``fetch_article_info`` for that id; each id appears at most once.
    """
    # Imported locally so the function does not depend on some other cell
    # having imported `time` before it is called.
    import time

    started = time.time()
    query_vector = model.encode([query])
    distances, neighbour_ids = index.search(query_vector, top_k)
    print('>>>> Time to return results: {}'.format(time.time()-started))

    # Walk ids and scores together so every article keeps its own score.
    # De-duplicating the ids separately (e.g. with np.unique, which also
    # sorts) would detach them from the score sequence.
    seen = set()
    kept_ids, kept_scores = [], []
    for article_id, score in zip(neighbour_ids.tolist()[0], distances[0]):
        # FAISS pads the id row with -1 when it has fewer than top_k results;
        # passed on to iloc, -1 would silently fetch the last row instead.
        if article_id < 0 or article_id in seen:
            continue
        seen.add(article_id)
        kept_ids.append(article_id)
        kept_scores.append(score)

    return zip([fetch_article_info(idx) for idx in kept_ids], kept_scores)

In [5]:
from pprint import pprint
import time

# Restore the saved FAISS index and the fine-tuned sentence encoder from disk.
index = faiss.read_index('body_paragraphs.index')
model = SentenceTransformer('search/search-model-t5-large-queries')

# Run one query through both search helpers: `search` (defined in this
# notebook) and `cross_search` (imported from cross_enc_search).
query = "Who is the WTU"
results = search(query, 5, index, model)
cross_results_all = cross_search(query, model=model, data=data)

# Only the cross_search results are printed by this cell.
print("\n")
for result in cross_results_all:
    print('\t', '\n', result)

  return torch._C._cuda_getDeviceCount() > 0


>>>> Time to return results: 0.02078390121459961
>>>> Time to return results: 0.016375303268432617
[[3042, 6.948469], [72, 6.943064], [3202, 6.7963085], [3793, 6.7963085], [3766, 6.795002], [3175, 6.7950015], [666, 6.709031], [636, 6.677759], [380, 4.4303737], [3039, 3.8847961], [3555, 3.8847961], [554, 3.0714765], [2724, 3.0152762], [3414, 3.0152762], [2827, 2.0439653]]


	 
 [[6.948469, 3042    Analyst Note:  The WTU is the USSP branch curr...
6       Open Source Intelligence (OSINT)/Human Intelli...
Name: body, dtype: object], [6.943064, 72    Analyst Note: The WTU is the USSP branch curre...
6     Open Source Intelligence (OSINT)/Human Intelli...
Name: body, dtype: object], [6.7963085, 3202    The WTU is promoting a dedicated, nonviolent r...
6       Open Source Intelligence (OSINT)/Human Intelli...
Name: body, dtype: object], [6.7963085, 3793    The WTU is promoting a dedicated, nonviolent r...
6       Open Source Intelligence (OSINT)/Human Intelli...
Name: body, dtype: object]]


In [None]:
# model_name = 'google/flan-t5-large' #Time to complete 3795 paragraphs w/ num_queries = 5, batch_size = 64, max_query_length = 64: 1:45:06
model_name = 'google/flan-t5-base' #Time to complete 3795 paragraphs w/ num_queries = 5, batch_size = 64, max_query_length = 64: 58:04

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU
# instead of failing on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

def _removeNonAscii(s):
    """Return `s` with every character whose code point is >= 128 removed."""
    return "".join(i for i in s if ord(i) < 128)

def _tsv_field(text):
    """Return `text` as a single TSV-safe cell.

    Non-ASCII characters are dropped and every run of whitespace (including
    tabs and newlines, which would otherwise split the row) becomes one space.
    """
    return " ".join(_removeNonAscii(text).split())

# Parameters for generation
batch_size = 64 #Batch size (NOTE: not read by the loop below, which generates one paragraph at a time)
num_queries = 5 #Number of queries to generate for every paragraph
max_length_query = 64   #Max length for output query

corpus = data.body

# Create the TSV that stores the summary-source pairs, one
# "<generated summary>\t<source paragraph>" row per generated sequence.
with open('generated_queries_t5-base(2).tsv', 'w') as fOut:
    for para in tqdm(corpus):
        para = str(para)
        clean_para = _tsv_field(para)
        if not clean_para:
            # Nothing to pair a summary with; a row with an empty paragraph
            # cannot be split back into two fields when the file is read.
            continue

        pre_para = 'summarize:' + para
        input_ids = tokenizer.encode(pre_para, truncation=True, return_tensors='pt').to(device)
        # Nucleus sampling yields `num_queries` different outputs per paragraph.
        outputs = model.generate(
            input_ids=input_ids,
            max_length=max_length_query,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=num_queries)

        for output in outputs:
            query = _tsv_field(tokenizer.decode(output, skip_special_tokens=True))
            if not query:
                # An empty first field would be stripped away on read,
                # leaving a row without a separator.
                continue
            fOut.write("{}\t{}\n".format(query, clean_para))

In [None]:
# Create the training dataset from the summary-source TSV.

from sentence_transformers import SentenceTransformer, InputExample, losses, models, datasets
from torch import nn
import os

log = []             # one entry per line that could not be parsed
train_examples = []  # one InputExample(texts=[query, paragraph]) per parsed line
with open('generated_queries_t5-base(2).tsv') as fIn:
    for line in fIn:
        try:
            query, paragraph = line.strip().split('\t', maxsplit=1)
        except ValueError:
            # No tab left after stripping, so the line does not hold two
            # fields. Only this parse failure is tolerated; anything else
            # (including KeyboardInterrupt) propagates.
            log.append("error")
            print(line)
            continue
        train_examples.append(InputExample(texts=[query, paragraph]))

# Report failures against the total number of lines read, not just the successes.
print("The following number of examples could not be appended into your training examples: {} out of {}".format(len(log), len(log) + len(train_examples)))

In [None]:
# Fine-tune the sentence encoder on the dataset built from the TSV.

model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
train_dataloader = datasets.NoDuplicatesDataLoader(train_examples, batch_size=8)
train_loss = losses.MultipleNegativesRankingLoss(model)
accelerator = Accelerator()

num_epochs = 3
# Warm up over the first 10% of all training steps.
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

# NOTE(review): `accelerator=` does not look like part of the upstream
# SentenceTransformer.fit signature — confirm the installed version accepts it.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    show_progress_bar=True,
    accelerator=accelerator,
)

# Save the fine-tuned model to disk.
# NOTE(review): the search/encoding cells load 'search/search-model-t5-large-queries',
# not the path written here — confirm which checkpoint is intended.
os.makedirs('search', exist_ok=True)
model.save('search/search-model-t5-base-queries')
#Time to fine tune "sentence-transformers/all-distilroberta-v1": 10 mins 1 sec with 3795 paragraphs w/ 5 queries each (18903)

In [None]:
# Load the fine-tuned checkpoint and encode the documents again with it.
model = SentenceTransformer('search/search-model-t5-large-queries')
#7 seconds to encode 3795 paragraphs using fine-tuned distilroberta model
encoded_data = model.encode(data.body.tolist(), show_progress_bar=True)
encoded_data = encoded_data.astype('float32')

# Size the index from the embeddings themselves so a checkpoint with a
# different output width does not need a code change (768 for the current one).
embedding_dim = encoded_data.shape[1]

# Inner-product index; each vector's id is its row position in `data`, which
# is what fetch_article_info later passes to `iloc`.
index = faiss.IndexIDMap(faiss.IndexFlatIP(embedding_dim))
ids = np.array(range(0, len(data)), dtype='int64')
index.add_with_ids(encoded_data, ids)

# Cosine-similarity variant, kept for reference:
# faiss.normalize_L2(encoded_data)
# index_cosine = faiss.IndexFlat(embedding_dim, faiss.METRIC_INNER_PRODUCT)
# index_cosine.add(encoded_data)


faiss.write_index(index, 'body_paragraphs.index')
# faiss.write_index(index_cosine, 'body_paragraphs_cosine.index')

In [None]:
import torch

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Ask PyTorch to release its cached CUDA memory before the search cells
# below reload the index and the model.
torch.cuda.empty_cache()

In [None]:
from pprint import pprint
import time

# Reload the saved index and fine-tuned encoder, then fetch the five
# nearest paragraphs for a short question.
index = faiss.read_index('body_paragraphs.index')
model = SentenceTransformer('search/search-model-t5-large-queries')

query = "Who is the WTU"
results = search(query, 5, index, model)

# One (article_info, score) pair per printed block.
print("\n")
for result in results:
    print('\t', '\n', result)

In [None]:
from pprint import pprint
import time

# Same lookup with a full-sentence query, reusing the `index` and `model`
# already loaded in the notebook.
query = "President Canteth recently made the controversial move to make Mr. David Patton the new Governor of the Northern Pineland Province."
results = search(query, 5, index, model)

# One (article_info, score) pair per printed block.
print("\n")
for result in results:
    print('\t', '\n', result)