In [None]:
import os
import importlib
import logging
import time
importlib.reload(logging)
import framework
importlib.reload(framework)
import bert_qa
importlib.reload(bert_qa)
import infer_bert_qa
importlib.reload(infer_bert_qa)
import bert_utils
importlib.reload(bert_utils)
import pandas as pd
from framework import DataCuration, FeatureEngineering
from bert_qa import TaskQA, FeatureEngineeringQA, BERTQA
from bert_maskedLM import BERTMaskedLM

# Define some constants and configurations
logging.getLogger().setLevel(logging.INFO)

# Credentials must not be committed with the notebook, so the access token is
# read from the ACCESS_TOKEN environment variable instead of being hardcoded.
# The token that used to be embedded here should be treated as leaked and rotated.
# NOTE(review): the token is handed to DataCuration below; whether it is actually
# required when both configs use 'is_local': True cannot be told from this
# notebook -- confirm against framework.DataCuration.
ACCESS_TOKEN = os.environ.get('ACCESS_TOKEN', '')
if not ACCESS_TOKEN:
    logging.warning(
        'ACCESS_TOKEN environment variable is not set; '
        'continuing with an empty access token'
    )

Set up the task details. This notebook handles Question Answering for the CARTA dataset.

Example

context = "New Zealand (Māori: Aotearoa) is a sovereign island country in the southwestern Pacific Ocean. It has a total land area of 268,000 square kilometres (103,500 sq mi), and a population of 4.9 million. New Zealand's capital city is Wellington, and its most populous city is Auckland."

questions = "How many people live in New Zealand?", "What's the largest city?"


In [None]:
# Name of the dataset this notebook works on.
# NOTE(review): the previous comment here read "supports w2 and resume", which
# does not match the 'carta' value -- confirm which dataset names are valid.
DATASET = 'carta'

# Task settings passed to TaskQA; this notebook runs the 'qa' task.
TASK_CONFIG = dict(task='qa')
task = TaskQA(TASK_CONFIG)

Set paths for datasets and goldens (local or ib, both work).
Specify configurations

In [None]:
# Root folder of the annotated CARTA samples. Export CARTA_ROOT to point the
# notebook at another location; the default is the original author's local
# folder, so behaviour is unchanged when the variable is not set.
CARTA_ROOT = os.environ.get(
    'CARTA_ROOT',
    '/Users/ahsaasbajaj/Documents/Data/CARTA/Annotated Samples'
)

# Input documents (.ibocr files) and the golden CSV, both derived from CARTA_ROOT.
CARTA_DATA = [
    os.path.join(CARTA_ROOT, 'out', 's1_process_files')
]
CARTA_GOLDEN = [
    os.path.join(CARTA_ROOT, 'golden', 'output.csv')
]

# Golden (ground-truth) source: a local CSV whose index column is 'filename'.
GOLDEN_CONFIG = {
    'path': CARTA_GOLDEN,
    'is_local': True,
    'index_field_name': 'filename',
    'file_type': 'csv',
    'identifier': 'file'
}

# Dataset source: local ibocr files. The identifier is the file's base name
# with everything from the first '.ibocr' onwards removed.
DATASET_CONFIG = {
    'path': CARTA_DATA,
    'is_local': True,
    'file_type': 'ibocr',
    'identifier': lambda path: os.path.basename(path).split('.ibocr')[0],
    'convert2txt': True
}

# Build the DataCuration object from the access token plus the dataset and
# golden configurations defined above.
data = DataCuration(
    ACCESS_TOKEN,
    DATASET_CONFIG,
    GOLDEN_CONFIG,
)

In [None]:
# Display the `golden` attribute of the curated data as this cell's output
# (presumably the golden labels loaded according to GOLDEN_CONFIG -- verify).
data.golden

In [None]:
# Open-ended questions, phrased in natural language. These are the queries
# used when DATA_ARGS['is_closed_query'] is False (the BERTQA path).
open_queries = [
    "Who is incorporating the company?",
    "How many shares are being created?",
    "What are the number of authorized shares?",
    "What are the Preferred stocks?",
    "What are the Non-cumulative dividends?",
    "What are the Common stocks?",
    "What is the Dividend rate per annum per preferred share type?",
    "What is the original issue price per share?",
    "What is the seniority of preferred share?",
    "What is the liquidation preference?",
    "What is the conversion price",
]

# Closed queries: statement prefixes rather than questions. These are the
# queries meant for the DATA_ARGS['is_closed_query'] == True (BERTMaskedLM) path.
closed_queries = [
    "The company is incorporated by",
    "The number of shares being created are",
    "The common stocks are",
    "The Preferred stocks are",
    "The Non-cumulative dividends are",
    "The Dividend rate per annum per preferred share type are",
    "The number of authorized shares are",
    "The Original Issue Price per share is",
    "The Liquidation preference is",
]

In [None]:
# Run every query against every file in the dataset and time the whole pass.
NUM_FILES = len(data.dataset.keys())
stime = time.time()

DATA_ARGS = {
    'task': task,
    'dataset': data,
    'is_closed_query': False  # if False, then use BERTQA, otherwise use BERTMaskedLM
}

if DATA_ARGS['is_closed_query']:
    # Question Answering using Masked Language Model.
    # Bug fix: this branch used to reassign `queries = open_queries` right after
    # selecting the closed queries, so the masked LM was fed the open questions.
    queries = closed_queries
    TRAINING_ARGS = {
        'model_file_or_path': "bert-large-uncased",  # plain pretrained checkpoint (not SQuAD-finetuned)
        'gpu': False,
        'output_dir': '../outputs/bert_maskedLM'
    }
    model = BERTMaskedLM(DATA_ARGS, TRAINING_ARGS)
else:
    # Standard Question Answering Model
    queries = open_queries
    TRAINING_ARGS = {
        'model_file_or_path': "bert-large-uncased-whole-word-masking-finetuned-squad",  # finetuned checkpoint available directly
        'gpu': False,
        'output_dir': '../outputs/bert_qa'
    }
    model = BERTQA(DATA_ARGS, TRAINING_ARGS)

# Both model classes are called the same way, so predict once after the branch.
output = model.predict(queries)

etime = time.time()
logging.info('Total time for {} files and {} queries each is {} seconds'.format(NUM_FILES, len(queries), (etime - stime)))