Skip to content

Commit

Permalink
Created a sklearn wrapper for the QA Pipeline (#101)
Browse files Browse the repository at this point in the history
* Implemented QAPipeline object that do the whole process for question-answering

* Added option to attribute model: path (string) or joblib object

* corrected typo

* Created example of jupyter notebook for use of qa_pipeline

* Update notebook example

* Added description of QAPipeline class"

* Added descriptions to all methods of QAPipeline class"

* Corrected typo

* Added download of CPU version of model to download.py (#100)

* update example notebook and docstrings (#92, #90,  #79) (#102)

* update example notebook and docstrings (#92, #90,  #79)

* update docstrings #79

* continue #79

* add flake8 to pytest in CI

* start integrating rest api #35

* add info readme

* basic api #35

* update reqs

* add refs and badges #87 (#105)

* add refs and badges #87

* sync HF

* first version of paper

*  Add sklearn wrapper for retriever as well #95

* Add sklearn wrapper for retriever as well #95

* update readme and clean repo

* update evaluation section in README

* debug-minor-updates (#106)

* Add github badges #87

* Disable verbose during predictions #103

* fix typos and tests #95

* Rename variables and scripts #108

* adapt notebook to new retriever class (#109)

* adapt notebook to new retriever class

* remove samples dir

* clean up repo and rename #108

* Fix predict berqa (#113)

* Rename variables and scripts #108

* Rename variables and scripts #108

* BertQA().predict() should return only 1 final predictions object #110

* Implemented QAPipeline object that do the whole process for question-answering

* Added option to attribute model: path (string) or joblib object

* corrected typo

* Created example of jupyter notebook for use of qa_pipeline

* Update notebook example

* Added description of QAPipeline class"

* Added descriptions to all methods of QAPipeline class

* Corrected typo

* Changed code from qa_pipeline.py to cdqa_sklearn.py

* seperated kwargs for declaration of different classes within QAPipeline

* removed qa_pipeline.py

* Implemented predict() and retriever part of fit()

* Implemented reader training in fit() and completed documentation

* Modified documentation for predict() method

* Deleted useless tutorial

* Created notebook example for pipeline
  • Loading branch information
andrelmfarias authored and fmikaelian committed May 2, 2019
1 parent 7d7de7a commit 2fb5f06
Show file tree
Hide file tree
Showing 2 changed files with 488 additions and 0 deletions.
111 changes: 111 additions & 0 deletions cdqa/pipeline/cdqa_sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import joblib

import pandas as pd
import numpy as np

from sklearn.base import BaseEstimator

from cdqa.retriever.tfidf_sklearn import TfidfRetriever
from cdqa.utils.converter import filter_paragraphs, generate_squad_examples
from cdqa.reader.bertqa_sklearn import BertProcessor, BertQA


class QAPipeline(BaseEstimator):
"""
A scikit-learn implementation of the whole cdQA pipeline
Parameters
----------
metadata : pandas.DataFrame
dataframe containing your corpus of documents metadata
header should be of format: date, title, category, link, abstract, paragraphs, content.
model : str or .joblib object of a version of BERT model with sklearn wrapper, optional
bert_version : str
Bert pre-trained model selected in the list: bert-base-uncased,
bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,
bert-base-multilingual-cased, bert-base-chinese.
Examples
--------
>>> from cdqa.pipeline.qa_pipeline import QAPipeline
>>> qa_pipe = QAPipeline(model='bert_qa_squad_vCPU-sklearn.joblib', metadata=df)
>>> qa_pipe.fit()
>>> prediction = qa_pipe.predict(X='When BNP Paribas was created?')
>>> from cdqa.pipeline.qa_pipeline import QAPipeline
>>> qa_pipe = QAPipeline(metadata=df)
>>> qa_pipe.fit('train-v1.1.json', fit_reader=True)
>>> qa_pipe.fit()
>>> prediction = qa_pipe.predict(X='When BNP Paribas was created?')
"""

def __init__(self, metadata, model=None, bert_version='bert-base-uncased', **kwargs):

# Separating kwargs
kwargs_bertqa = {key: value for key, value in kwargs.items()
if key in BertQA.__init__.__code__.co_varnames}

kwargs_processor = {key: value for key, value in kwargs.items()
if key in BertProcessor.__init__.__code__.co_varnames}

kwargs_retriever = {key: value for key, value in kwargs.items()
if key in TfidfRetriever.__init__.__code__.co_varnames}

if not model:
self.model = BertQA(self.bert_version, **kwargs_bertqa)
elif type(model) == str:
self.model = joblib.load(model)
else:
self.model = model

self.metadata = metadata
self.bert_version = bert_version

self.processor_train = BertProcessor(self.bert_version,
is_training=True,
**kwargs_processor)

self.processor_predict = BertProcessor(self.bert_version,
is_training=False,
**kwargs_processor)

self.retriever = TfidfRetriever(self.metadata, **kwargs_retriever)

def fit(self, X=None, y=None, fit_reader=False):
""" Fit the QAPipeline retriever to a list of documents in a dataframe if fit_reader is false,
fit the reader (QABert model) to a json file squad-like with questions and answers
Parameters
----------
X: dict or str
Dictionaire with questions and answers in SQUAD format or path to json file in SQUAD format
fit_reader: boolean, default false
Whether to fit reader (BertQA model) or retriever
"""
if not fit_reader:
self.retriever.fit(self.metadata['content'])
else:
if not X:
raise RuntimeError(
'fit_reader is True, please pass a json file in SQUAD format as input')
train_examples, train_features = self.processor_train.fit_transform(X)
self.model.fit(X=(train_examples, train_features))

return self

def predict(self, X):
""" Compute prediction of an answer to a question
"""

closest_docs_indices = self.retriever.predict(X)
squad_examples = generate_squad_examples(question=X,
closest_docs_indices=closest_docs_indices,
metadata=self.metadata)
examples, features = self.processor_predict.fit_transform(X=squad_examples)
prediction = self.model.predict((examples, features))

return prediction
Loading

0 comments on commit 2fb5f06

Please sign in to comment.