In [1]:
import json
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import pandas as pd

In [2]:
#load document
# file_path = 'data/doc2dial/doc2dial_qa_train.csv'
# loader = CSVLoader(file_path)
# document = loader.load()
# print(f'documents:{len(document)}')

In [3]:
#load qa pairs
file_path = 'data/doc2dial/doc2dial_qa_train.csv'
df = pd.read_csv(file_path)

In [4]:
doc1= df.loc[0]

In [5]:
doc1['domain']

'dmv'

In [6]:
file_path = 'data/doc2dial/doc2dial_doc.json'
with open(file_path, 'r') as f:
    doc2dial_doc = json.load(f)

In [7]:
doc1_text = doc2dial_doc['doc_data'][doc1['domain']][doc1['doc_id']]['doc_text']
print(doc1_text)

Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them. 

1. Forgetting to Update Address 
By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. It is not sufficient to only: write your new address on the back of your old license; tell the United States Postal Service; or inform the police officer writing you a ticket. If you fail to keep your address current , you will miss a suspension order and may be charged with operating an unregistered vehicle and/or aggravated unlicensed operation, both misdemeanors. This really happens , but the good news is this is a problem tha

# Embedding: split the document into chunks and build a FAISS index

In [8]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import DataFrameLoader
from langchain.docstore.document import Document

In [9]:
def split(document, chunk_size=500, chunk_overlap=0):
    """Split documents into smaller chunks and print how many were produced.

    Args:
        document: The documents to split; passed unchanged to
            ``RecursiveCharacterTextSplitter.split_documents`` (the notebook
            passes a list containing one LangChain ``Document``).
        chunk_size: ``chunk_size`` given to the splitter. Defaults to 500,
            the value that used to be hard-coded here.
        chunk_overlap: ``chunk_overlap`` given to the splitter. Defaults to 0,
            the value that used to be hard-coded here.

    Returns:
        The chunks returned by ``split_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # NOTE: nothing is filtered here -- the result is exactly what the
    # splitter returns (an earlier comment claimed empty documents were
    # removed, but no such step exists).
    split_documents = text_splitter.split_documents(document)
    print(f'documents:{len(split_documents)}')
    return split_documents

def embedding(documents, model_name='sentence-transformers/gtr-t5-base') -> FAISS:
    """Embed documents with a Hugging Face model and index them in FAISS.

    Args:
        documents: Documents passed unchanged to ``FAISS.from_documents``
            (the notebook passes the chunks returned by ``split``).
        model_name: ``model_name`` given to ``HuggingFaceEmbeddings``.
            Defaults to the model that used to be hard-coded here.

    Returns:
        The vector store built by ``FAISS.from_documents``.
    """
    # load embeddings
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    db = FAISS.from_documents(documents, embeddings)
    return db

def save_db(db, path="data/faiss_index"):
    """Persist a vector store to disk via its ``save_local`` method.

    Args:
        db: Object exposing ``save_local(path)`` (the notebook passes the
            FAISS store returned by ``embedding``).
        path: Destination passed to ``save_local``. Defaults to
            ``"data/faiss_index"``, the location that used to be hard-coded.
    """
    db.save_local(path)

def load_db(embeddings, path="data/faiss_index"):
    """Load a FAISS vector store from disk.

    Args:
        embeddings: Embeddings object passed to ``FAISS.load_local``.
            NOTE(review): presumably this must be the same embedding model
            that was used when the index was built -- confirm.
        path: Location passed to ``FAISS.load_local``. Defaults to
            ``"data/faiss_index"``, the location that used to be hard-coded
            (and the default that ``save_db`` writes to).

    Returns:
        The vector store returned by ``FAISS.load_local``.
    """
    # NOTE(review): LangChain's FAISS persistence is understood to store the
    # docstore with pickle -- only load indexes from trusted locations; verify
    # against the installed LangChain version.
    new_db = FAISS.load_local(path, embeddings)
    return new_db

In [10]:
document = Document(page_content=doc1_text, metadata={"source": doc1['doc_id']})

In [11]:
split_documents = split([document])

documents:17


In [12]:
split_documents

[Document(page_content='Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='1. Forgetting to Update Address', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. It is not sufficient to only: write your new address on the back of your old license; tell the United States Postal Service; or inform the police officer writing you a ticket. If you fail to keep your address current , y

In [13]:
db = embedding(split_documents)

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
save_db(db)

In [85]:
query = doc1['question']
docs = db.similarity_search(query)
docs

[Document(page_content='1. Forgetting to Update Address', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. It is not sufficient to only: write your new address on the back of your old license; tell the United States Postal Service; or inform the police officer writing you a ticket. If you fail to keep your address current , you will miss a suspension order and may be', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='possible mail correspondence can reach you. Also , turning in your plates is important to avoid an insurance lapse.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='receive their DRA assessment because we do 

In [16]:
doc1['references']

"[{'sp_id': '6', 'label': 'solution'}, {'sp_id': '7', 'label': 'solution'}]"

In [17]:
doc2dial_doc['doc_data'][doc1['domain']][doc1['doc_id']]['spans']['56']

{'id_sp': '56',
 'tag': 'u',
 'start_sp': 4496,
 'end_sp': 4527,
 'text_sp': 'Sign up or log into MyDMV [6 ] ',
 'title': '5. Not Bringing Proper Documentation to DMV Office',
 'parent_titles': [],
 'id_sec': '15',
 'start_sec': 4496,
 'text_sec': 'Sign up or log into MyDMV [6 ] ',
 'end_sec': 4527}

# Answer agent: answer the question from the retrieved chunks

In [18]:
from langchain import VectorDBQA
from langchain.chains import qa_with_sources
from langchain.chains.question_answering import load_qa_chain
from langchain import HuggingFaceHub
import os


In [19]:
# Read the API keys from api.txt. Each non-blank line is expected to look like
# "<name>:<key>"; the first entry is used as the OpenAI key and the second as
# the Hugging Face key.
with open('api.txt', 'r') as f:
    # Split on the first ':' only so a key that itself contains ':' stays
    # intact, strip whitespace around the key (e.g. "name: key"), and skip
    # blank lines such as a trailing newline, which previously raised
    # IndexError.
    lines = [line.split(":", 1)[1].strip() for line in f if line.strip()]
    openai_api = lines[0]
    hf_api = lines[1]

In [20]:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_api
os.environ["OPENAI_API_KEY"] = openai_api

In [98]:
qa_model = HuggingFaceHub(repo_id='google/flan-t5-large')
query = "Please answer the following question.\n"+doc1['question']
chain = load_qa_chain(llm=qa_model, chain_type="map_reduce")
chain.run(input_documents=docs, question=query, raw_response=True)

Token indices sequence length is longer than the specified maximum sequence length for this model (1517 > 1024). Running this sequence through the model will result in indexing errors


'By statute , you must report a change of address to DMV within ten days'

In [22]:
docs

[Document(page_content='About ten percent of customers visiting a DMV office do not bring what they need to complete their transaction, and have to come back a second time to finish their business. This can be as simple as not bringing sufficient funds to pay for a license renewal or not having the proof of auto insurance required to register a car. Better yet , don t visit a DMV office at all, and see if your transaction can be performed online, like an address change, registration renewal, license renewal, replacing', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(

In [23]:
query

'Please answer the following question.\nCan I do my DMV transactions online?'

In [24]:
docs

[Document(page_content='About ten percent of customers visiting a DMV office do not bring what they need to complete their transaction, and have to come back a second time to finish their business. This can be as simple as not bringing sufficient funds to pay for a license renewal or not having the proof of auto insurance required to register a car. Better yet , don t visit a DMV office at all, and see if your transaction can be performed online, like an address change, registration renewal, license renewal, replacing', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(

In [25]:
# from langchain.chains.qa_with_sources import load_qa_with_sources_chain

# chain = load_qa_with_sources_chain(qa_model, chain_type="map_reduce")
# chain({"input_documents": docs, "question": query}, return_only_outputs=True)

Token indices sequence length is longer than the specified maximum sequence length for this model (1673 > 1024). Running this sequence through the model will result in indexing errors


{'output_text': 'SOURCES: Top 5 DMV Mistakes and How to Avoid Them#3'}

In [26]:
# from langchain import OpenAI

# query = "Please answer the following question using the given documents only.\n"+doc1['question']
# openai_model = OpenAI(model_name="text-davinci-003", max_tokens=1024)
# chain = load_qa_chain(llm=qa_model, chain_type="stuff")
# chain.run(input_documents=docs, question=query, raw_response=True)

'b).'

In [27]:
query

'Please answer the following question using the given documents only.\nHello, I forgot o update my address, can you help me with that?'

In [28]:
docs

[Document(page_content='About ten percent of customers visiting a DMV office do not bring what they need to complete their transaction, and have to come back a second time to finish their business. This can be as simple as not bringing sufficient funds to pay for a license renewal or not having the proof of auto insurance required to register a car. Better yet , don t visit a DMV office at all, and see if your transaction can be performed online, like an address change, registration renewal, license renewal, replacing', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(

# Evaluation: compare retrieved chunks and answers against the references

In [45]:
import re

In [44]:
doc1['references']

"[{'sp_id': '6', 'label': 'solution'}, {'sp_id': '7', 'label': 'solution'}]"

In [71]:
df.loc[1]['references']

question                   Can I do my DMV transactions online?
answer        Yes, you can sign up for MyDMV for all the onl...
domain                                                      dmv
doc_id             Top 5 DMV Mistakes and How to Avoid Them#3_0
references               [{'sp_id': '56', 'label': 'solution'}]
dial_id                        9f44c1539efe6f7e79b02eb1b413aa43
Name: 1, dtype: object

In [75]:
res = re.findall(r"\d+", doc1['references'])
[int(i) for i in res]

[6, 7]

In [80]:
#global dictionary
# doc2dial_doc

def get_ref(doc):
    """Look up the reference spans of a QA row in the global ``doc2dial_doc``.

    Args:
        doc: Mapping-like row providing ``'references'``, ``'domain'`` and
            ``'doc_id'``. ``'references'`` is either the string repr of a
            list of dicts (as stored in the CSV, e.g.
            ``"[{'sp_id': '6', 'label': 'solution'}]"``) or that list itself.

    Returns:
        A list with one span dict per reference, in reference order. A blank
        ``'references'`` string yields an empty list.

    Raises:
        KeyError: If the domain, document id or a span id is not present in
            ``doc2dial_doc``.
        ValueError, SyntaxError: If ``'references'`` is a non-blank string
            that is not a valid Python literal.
    """
    import ast  # local import so the function does not rely on another cell

    references = doc['references']
    if isinstance(references, str):
        # Parse the literal instead of grabbing every digit run with a regex,
        # so digits that appear outside 'sp_id' (e.g. in a label) are never
        # mistaken for span ids.
        references = ast.literal_eval(references) if references.strip() else []

    spans = doc2dial_doc['doc_data'][doc['domain']][doc['doc_id']]['spans']
    # Span keys are strings; normalise through int as the previous
    # implementation did so ids such as '06' still resolve to '6'.
    return [spans[str(int(ref['sp_id']))] for ref in references]

get_ref(doc1)


[{'id_sp': '6',
  'tag': 'u',
  'start_sp': 346,
  'end_sp': 416,
  'text_sp': 'you must report a change of address to DMV within ten days of moving. ',
  'title': '1. Forgetting to Update Address',
  'parent_titles': [],
  'id_sec': '2',
  'start_sec': 333,
  'text_sec': 'By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. ',
  'end_sec': 567},
 {'id_sp': '7',
  'tag': 'u',
  'start_sp': 416,
  'end_sp': 567,
  'text_sp': 'That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. ',
  'title': '1. Forgetting to Update Address',
  'parent_titles': [],
  'id_sec': '2',
  'start_sec': 333,
  'text_sec': 'By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address

In [90]:
get_ref(doc1)[0]['title']

'1. Forgetting to Update Address'

In [91]:
get_ref(doc1)[0]['text_sp']

'you must report a change of address to DMV within ten days of moving. '

In [86]:
docs

[Document(page_content='1. Forgetting to Update Address', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. It is not sufficient to only: write your new address on the back of your old license; tell the United States Postal Service; or inform the police officer writing you a ticket. If you fail to keep your address current , you will miss a suspension order and may be', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='possible mail correspondence can reach you. Also , turning in your plates is important to avoid an insurance lapse.', metadata={'source': 'Top 5 DMV Mistakes and How to Avoid Them#3_0'}),
 Document(page_content='receive their DRA assessment because we do 

In [106]:
docs[0].page_content

'1. Forgetting to Update Address'

In [None]:
em = 0
def eval_refs(refs, retrieved):
    """Placeholder for retrieval evaluation; not implemented yet.

    The notes below record the intent to compute exact match and F1, but no
    metric is calculated: the body is a bare ``pass``, so the function always
    returns ``None`` and neither argument is read.

    Args:
        refs: Currently unused. NOTE(review): presumably the reference spans
            returned by ``get_ref`` -- confirm when implementing.
        retrieved: Currently unused. NOTE(review): presumably the retrieved
            documents -- confirm when implementing.
    """
    # TODO: calculate exact match
    # TODO: calculate f1
    
    pass

In [105]:
doc1

question      Hello, I forgot o update my address, can you h...
answer        hi, you have to report any change of address t...
domain                                                      dmv
doc_id             Top 5 DMV Mistakes and How to Avoid Them#3_0
references    [{'sp_id': '6', 'label': 'solution'}, {'sp_id'...
dial_id                        9f44c1539efe6f7e79b02eb1b413aa43
Name: 0, dtype: object

In [99]:
short_answer = "By statute , you must report a change of address to DMV within ten days"

In [101]:
short_answer in doc1['answer']

False

In [114]:
# Build the single example that is fed to the AutoAIS check in one literal.
example = {
    'question': doc1['question'],
    # TODO should be model result?
    'answer': "By statute , you must report a change of address to DMV within ten days",
    # TODO only first document. If use all documents, need to change
    "passage": docs[0].page_content,
}

In [109]:
from transformers import T5ForConditionalGeneration
from transformers import T5Tokenizer

PASSAGE_FORMAT = re.compile("« ([^»]*) » « ([^»]*) » (.*)")

# def format_passage_for_autoais(passage):
#   """Produce the NLI format for a passage.

#   Args:
#     passage: A passage from the Wikipedia scrape.

#   Returns:
#     a formatted string, e.g.

#       Luke Cage (season 2), Release. The second season of Luke Cage was released
#       on June 22, 2018, on the streaming service Netflix worldwide, in Ultra HD
#       4K and high dynamic range.
#   """
#   m = PASSAGE_FORMAT.match(passage)
#   if not m:
#     return passage

#   headings = m.group(2)
#   passage = m.group(3)
#   return f"{headings}. {passage}"

def format_example_for_autoais(example):
  """Render an example as the premise/hypothesis prompt used for AutoAIS.

  Args:
    example: Mapping read via its "passage", "question" and "answer" keys.

  Returns:
    The prompt string, with the passage as premise and the question/answer
    pair as hypothesis.
  """
  template = "premise: {} hypothesis: The answer to the question '{}' is '{}'"
  fields = (example["passage"], example["question"], example["answer"])
  return template.format(*fields)

def infer_autoais(example, tokenizer, model):
  """Asks an NLI model whether the example's passage supports its answer.

  The prompt is built with ``format_example_for_autoais``, tokenized, passed
  to ``model.generate`` and the first generated sequence is decoded.

  Side effect: the verdict is also written to ``example["autoais"]``.

  Args:
    example: Dict with the example data; it is modified in place.
    tokenizer: A huggingface tokenizer object.
    model: A huggingface model object.

  Returns:
    "Y" when the decoded model output is exactly "1", otherwise "N".
  """
  prompt = format_example_for_autoais(example)
  encoded = tokenizer(prompt, return_tensors="pt")
  generated = model.generate(encoded.input_ids)
  decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

  if decoded == "1":
    verdict = "Y"
  else:
    verdict = "N"

  example["autoais"] = verdict
  return verdict

def score_predictions(predictions, nq_answers, tokenizer=None, model=None):
  """Scores model predictions with AutoAIS.

  Every question in ``nq_answers`` counts towards the denominator. A question
  contributes to the numerator only when it has a prediction with a non-empty
  passage and ``infer_autoais`` returns "Y" for it.

  Args:
    predictions: A dict from questions to prediction rows. Each row is read
      via its "answer" and "passage" keys.
    nq_answers: A dict from questions to lists of NQ reference answers.
    tokenizer: Optional already-loaded tokenizer. When omitted, it is loaded
      from the AutoAIS checkpoint named below.
    model: Optional already-loaded model. When omitted, it is loaded from the
      AutoAIS checkpoint named below.

  Returns:
    a dict of metric values, keyed by metric names. Currently only "AutoAIS"
    is reported; it is 0.0 when ``nq_answers`` is empty.
  """
  AUTOAIS = "google/t5_xxl_true_nli_mixture"

  # Nothing to score: avoid a ZeroDivisionError below, and skip loading the
  # checkpoint altogether.
  if not nq_answers:
    return {"AutoAIS": 0.0}

  # Reuse caller-supplied objects so the checkpoint is not loaded a second
  # time when the notebook already has it in memory.
  hf_tokenizer = tokenizer if tokenizer is not None else T5Tokenizer.from_pretrained(AUTOAIS)
  hf_model = model if model is not None else T5ForConditionalGeneration.from_pretrained(AUTOAIS)

  autoais = 0
  # Both lists are kept for the SQuAD scoring that is disabled below.
  target_answers = []
  predicted_answers = []
  for question, answers in nq_answers.items():
    target_answers.append(answers)
    example = predictions.get(question, None)
    if example is None:
    #   logging.error("Did not find prediction for '%s'", question)
      predicted_answers.append("")
      continue
    predicted_answers.append(example["answer"])
    if not example["passage"]:
      # No passage to attribute the answer to; counts as not supported.
      continue
    inference = infer_autoais(example, hf_tokenizer, hf_model)
    if inference == "Y":
      autoais += 1

  scores = {}
  scores["AutoAIS"] = autoais / len(target_answers)
#   for metric, score in squad(target_answers, predicted_answers).items():
#     scores[f"SQuAD ({metric})"] = score
  return scores

In [112]:
AUTOAIS = "google/t5_xxl_true_nli_mixture"
hf_tokenizer = T5Tokenizer.from_pretrained(AUTOAIS)
hf_model = T5ForConditionalGeneration.from_pretrained(AUTOAIS)

infer_autoais(example, hf_tokenizer, hf_model)


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

'N'

In [115]:
infer_autoais(example, hf_tokenizer, hf_model)

'N'