
Running a fine-tuned TAPAS checkpoint
---
This notebook shows how to load a TAPAS model and make predictions with it. TAPAS was introduced in the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349).

In [2]:
! pip install tapas-table-parsing

Collecting tapas-table-parsing
  Downloading tapas_table_parsing-0.0.1.dev0-py3-none-any.whl (195 kB)
[?25l[K     |█▊                              | 10 kB 27.0 MB/s eta 0:00:01[K     |███▍                            | 20 kB 7.6 MB/s eta 0:00:01[K     |█████                           | 30 kB 6.8 MB/s eta 0:00:01[K     |██████▊                         | 40 kB 6.5 MB/s eta 0:00:01[K     |████████▍                       | 51 kB 5.2 MB/s eta 0:00:01[K     |██████████                      | 61 kB 5.3 MB/s eta 0:00:01[K     |███████████▊                    | 71 kB 5.3 MB/s eta 0:00:01[K     |█████████████▍                  | 81 kB 5.9 MB/s eta 0:00:01[K     |███████████████                 | 92 kB 4.9 MB/s eta 0:00:01[K     |████████████████▊               | 102 kB 5.2 MB/s eta 0:00:01[K     |██████████████████▍             | 112 kB 5.2 MB/s eta 0:00:01[K     |████████████████████            | 122 kB 5.2 MB/s eta 0:00:01[K     |█████████████████████▉          | 133 

# Fetch models from Google Storage

Next, we fetch a pretrained checkpoint from Google Storage. For the sake of speed, this is a base-sized model trained on [SQA](https://www.microsoft.com/en-us/download/details.aspx?id=54253). Note that the best results in the paper were obtained with a large model, which has 24 layers instead of 12.

In [3]:

! gsutil cp gs://tapas_models/2020_04_21/tapas_sqa_base.zip . && unzip tapas_sqa_base.zip

Copying gs://tapas_models/2020_04_21/tapas_sqa_base.zip...
/ [1 files][  1.0 GiB/  1.0 GiB]   44.7 MiB/s                                   
Operation completed over 1 objects/1.0 GiB.                                      
Archive:  tapas_sqa_base.zip
   creating: tapas_sqa_base/
  inflating: tapas_sqa_base/model.ckpt.data-00000-of-00001  
  inflating: tapas_sqa_base/model.ckpt.index  
  inflating: tapas_sqa_base/README.txt  
  inflating: tapas_sqa_base/vocab.txt  
  inflating: tapas_sqa_base/bert_config.json  
  inflating: tapas_sqa_base/model.ckpt.meta  
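To check which model size was downloaded, a minimal sketch can read the layer count from the unzipped `bert_config.json`. The helper `describe_bert_config` is hypothetical, and it assumes the file uses the standard BERT key names `num_hidden_layers` and `hidden_size`.

```python
import json
from pathlib import Path


def describe_bert_config(config_path):
    """Return (num_hidden_layers, hidden_size) from a BERT-style config file.

    Assumes the standard BERT key names; a KeyError means the file differs.
    """
    config = json.loads(Path(config_path).read_text())
    return config["num_hidden_layers"], config["hidden_size"]


# Usage once tapas_sqa_base.zip has been unzipped:
# layers, hidden = describe_bert_config("tapas_sqa_base/bert_config.json")
```

For the base checkpoint the layer count should be 12; a large checkpoint would report 24.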


# Imports

In [4]:
import tensorflow.compat.v1 as tf
import os 
import shutil
import csv
import pandas as pd
import IPython

tf.get_logger().setLevel('ERROR')

In [5]:
from tapas.utils import tf_example_utils
from tapas.protos import interaction_pb2
from tapas.utils import number_annotation_utils
from tapas.scripts import prediction_utils

# Load checkpoint for prediction

Here's the prediction code. The table is first fetched from a PostgreSQL database; the code then creates an `interaction_pb2.Interaction` protobuf object, which is the data structure we use to store examples, and calls the prediction script.
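The prediction script does not return answer text directly: for each question it returns a list of selected cell coordinates. A minimal sketch of how those are turned back into an answer string, assuming (as the `table[row + 1][col]` lookup in `predict` does) that row indices count data rows only, so the header row is skipped. `coordinates_to_answer` is a hypothetical helper name, and the toy table reuses two rows of the query result.

```python
def coordinates_to_answer(table, coordinates):
    """Join the text of the selected cells into one answer string.

    `table` is a list of rows whose first row is the header; `coordinates`
    is a list of (row, column) pairs over the data rows, hence `row + 1`.
    """
    return ", ".join(table[row + 1][col] for row, col in coordinates)


table = [["officer_name", "current_salary"],
         ["Wilman Dones", "82008"],
         ["Vernard Ross", "107988"]]
print(coordinates_to_answer(table, [(1, 1)]))          # 107988
print(coordinates_to_answer(table, [(0, 1), (1, 1)]))  # 82008, 107988
print(repr(coordinates_to_answer(table, [])))          # ''
```

An empty coordinate list yields an empty string, which is how blank entries end up in the list of answers.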

In [6]:
import psycopg2


In [7]:
# access the postgresql server
conn = psycopg2.connect(
    host="codd04.research.northwestern.edu",
    port="5433",
    database="postgres",
    user="cpdbstudent",
    password="DataSci4AI")
cursor = conn.cursor()

In [23]:
edges_query = '''
DROP TABLE IF EXISTS da_category_ids;
CREATE TEMP TABLE da_category_ids AS (
  SELECT data_officerallegation.id, data_allegationcategory.allegation_name, data_allegationcategory.category
  FROM data_officerallegation
  JOIN data_allegationcategory ON data_officerallegation.allegation_category_id = data_allegationcategory.id
  WHERE data_allegationcategory.category = 'Drug / Alcohol Abuse'
     OR data_allegationcategory.category = 'Medical'
     OR allegation_name LIKE 'Medical Roll%'
     OR data_allegationcategory.category_code IN ('024', '003', '003A', '003B', '003C', '003D', '003E')
);

SELECT gender, race, birth_year, first_name || ' ' || last_name AS officer_name,
       allegation_count, sustained_count, current_salary
FROM data_officer
JOIN da_category_ids ON data_officer.id = da_category_ids.id
WHERE data_officer.current_salary IS NOT NULL
  AND allegation_count > 2
  AND sustained_count > 1
ORDER BY officer_name DESC;
'''

In [29]:
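cursor.execute(edges_query)
# Column names of the final SELECT become the table's header row
colnames = [desc[0] for desc in cursor.description]
edges = cursor.fetchall()
# TAPAS expects every cell as text, so convert all values to strings
res = [[str(x) for x in row] for row in edges]
res.insert(0, colnames)
print(res)
# Gold answers: the current_salary column (index 6), without the header
b = [row[6] for row in res[1:]]
print(b)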

[['gender', 'race', 'birth_year', 'officer_name', 'allegation_count', 'sustained_count', 'current_salary'], ['M', 'Hispanic', '1952', 'Wilman Dones', '23', '6', '82008'], ['M', 'Black', '1963', 'Vernard Ross', '23', '3', '107988'], ['M', 'White', '1957', 'Thomas Motzny', '31', '2', '102978'], ['M', 'White', '1949', 'Thomas Biggane', '30', '2', '82878'], ['M', 'White', '1969', 'Steven Nowicki', '27', '3', '111474'], ['M', 'White', '1970', 'Steven Bechina', '43', '4', '107988'], ['M', 'White', '1948', 'Ronald Blake', '6', '2', '83604'], ['M', 'Black', '1948', 'Rollins Johnson', '6', '2', '88260'], ['M', 'Hispanic', '1985', 'Roger Farias', '14', '2', '84054'], ['M', 'Black', '1964', 'Ricky Bean', '11', '3', '100980'], ['M', 'White', '1961', 'Raymond Gadomski', '17', '3', '96060'], ['M', 'Black', '1953', 'Prentiss Jackson', '15', '5', '102978'], ['M', 'White', '1962', 'Philip Paluch', '56', '2', '111474'], ['F', 'Black', '1950', 'Paularie Draine', '7', '2', '78012'], ['F', 'Black', '1948',
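The cell above converts every value with `str(x)`, so a SQL NULL would reach TAPAS as the literal text `'None'`. A minimal alternative sketch, using a hypothetical helper `rows_to_table`, maps NULLs to empty cells and checks that every row has one cell per column before the table is handed to the converter.

```python
def rows_to_table(column_names, rows):
    """Build a table as a header row followed by rows of strings.

    None (SQL NULL) becomes an empty string instead of 'None', and a row
    with the wrong number of cells raises ValueError.
    """
    table = [list(column_names)]
    for row in rows:
        if len(row) != len(column_names):
            raise ValueError(f"expected {len(column_names)} cells, got {len(row)}")
        table.append(["" if value is None else str(value) for value in row])
    return table


# Usage with the cursor results above:
# res = rows_to_table(colnames, edges)
```

The query already filters out NULL salaries, so for this table the result should match `res`; the check matters mostly if the query changes.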

In [10]:
os.makedirs('results/sqa/tf_examples', exist_ok=True)
os.makedirs('results/sqa/model', exist_ok=True)
with open('results/sqa/model/checkpoint', 'w') as f:
  f.write('model_checkpoint_path: "model.ckpt-0"')
for suffix in ['.data-00000-of-00001', '.index', '.meta']:
  shutil.copyfile(f'tapas_sqa_base/model.ckpt{suffix}', f'results/sqa/model/model.ckpt-0{suffix}')

In [19]:
max_seq_length = 512
vocab_file = "tapas_sqa_base/vocab.txt"
config = tf_example_utils.ClassifierConversionConfig(
    vocab_file=vocab_file,
    max_seq_length=max_seq_length,
    max_column_id=max_seq_length,
    max_row_id=max_seq_length,
    strip_column_names=False,
    add_aggregation_candidates=False,
)
converter = tf_example_utils.ToClassifierTensorflowExample(config)

def convert_interactions_to_examples(tables_and_queries):
  """Calls Tapas converter to convert interaction to example."""
  for idx, (table, queries) in enumerate(tables_and_queries):
    interaction = interaction_pb2.Interaction()
    for position, query in enumerate(queries):
      question = interaction.questions.add()
      question.original_text = query
      question.id = f"{idx}-0_{position}"
    for header in table[0]:
      interaction.table.columns.add().text = header
    for line in table[1:]:
      row = interaction.table.rows.add()
      for cell in line:
        row.cells.add().text = cell
    number_annotation_utils.add_numeric_values(interaction)
    for i in range(len(interaction.questions)):
      try:
        yield converter.convert(interaction, i)
      except ValueError as e:
        print(f"Can't convert interaction: {interaction.id} error: {e}")
        
def write_tf_example(filename, examples):
  with tf.io.TFRecordWriter(filename) as writer:
    for example in examples:
      writer.write(example.SerializeToString())

def predict(table_data, queries):
  """Answers `queries` over `table_data` (a header row followed by data rows, all strings)."""
  answers = []
  table = table_data
  examples = convert_interactions_to_examples([(table, queries)])
  write_tf_example("results/sqa/tf_examples/test.tfrecord", examples)
  # Empty dev split; only the test split is predicted
  write_tf_example("results/sqa/tf_examples/random-split-1-dev.tfrecord", [])

  # Run TAPAS prediction; stderr is written to a file named `error`
  ! python -m tapas.run_task_main \
    --task="SQA" \
    --output_dir="results" \
    --noloop_predict \
    --test_batch_size={len(queries)} \
    --tapas_verbosity="ERROR" \
    --compression_type= \
    --init_checkpoint="tapas_sqa_base/model.ckpt" \
    --bert_config_file="tapas_sqa_base/bert_config.json" \
    --mode="predict" 2> error


  results_path = "results/sqa/model/test_sequence.tsv"
  all_coordinates = []
  with open(results_path) as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t')
    for row in reader:
      coordinates = prediction_utils.parse_coordinates(row["answer_coordinates"])
      all_coordinates.append(coordinates)
      # Coordinates index data rows only, so +1 skips the header row
      answers.append(', '.join([table[r + 1][c] for r, c in coordinates]))
      position = int(row['position'])
      print(">", queries[position])
      print(answers)
  return answers

# Predict

In [20]:
# Example nu-1000-0: one salary question per officer, in the query's row order
names = ["Wilman Dones", "Vernard Ross", "Thomas Motzny", "Thomas Biggane", "Steven Nowicki", "Steven Bechina", "Ronald Blake", "Rollins Johnson", "Roger Farias", "Ricky Bean", "Raymond Gadomski", "Prentiss Jackson", "Philip Paluch", "Paularie Draine", "Patricia Ballentine",
         "Orlando Fonseca", "Nicholas Dimaggio", "Michael Overstreet", "Marvin Randolph", "Marshall Pufundt", "Marco Johnson", "Luis Lopez", "Lisa William-Handley", "Linda Brumfield", "Larry Dotson", "Lakisa Anderson", "Kimberly Hill", "Kevin Keyes", "Joseph Thompson", "Joseph Battaglia",
         "John Brownridge", "Jesus Avila", "Jeanetta Brown Cunningha", "Jacklyn Mueller", "Irwin Negron", "Haki Akintunde", "Frederick Anthony", "Emil Bux", "Donald Banks", "Dennis Peca", "Brian Burton", "Anthony Singleton", "Anita Ashton", "Alvin Ward", "Alvin Greenup"]
queries = [f"what is the current salary of {name}?" for name in names]
result = predict(res, queries)

is_built_with_cuda: True
is_gpu_available: True
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Training or predicting ...
Evaluation finished after training step 0.
> what is the current salary of Wilman Dones?
['82008']
> what is the current salary of Vernard Ross?
['82008', '83706']
> what is the current salary of Thomas Motzny?
['82008', '83706', '82878']
> what is the current salary of Thomas Biggane?
['82008', '83706', '82878', '']
> what is the current salary of Steven Nowicki?
['82008', '83706', '82878', '', '107988']
> what is the current salary of Steven Bechina?
['82008', '83706', '82878', '', '107988', '']
> what is the current salary of Ronald Blake?
['82008', '83706', '82878', '', '107988', '', '83604']
> what is the current salary of Rollins Johnson?
['82008', '83706', '82878', '', '107988', '', '83604', '']
> what is the current salary of Roger Farias?
['82008', '83706', '82878', '', '107988', '', '83604', '', '84054']
> what is the current sala

In [13]:
print(result)

['82008', '83706', '82878', '', '107988', '', '83604', '', '84054', '100980', '96060, 107988', '102978', '', '78012', '79926', '68262', '87888', '78006', '', '', '', '', '', '68262', '90540', '93354', '86130, 100980', '', '83706', '', '', '', '99888', '87006', '', '70656', '92430', '93354', '92316', '', '93354', '', '71682', '93354', '93354']


In [30]:
b

['82008',
 '107988',
 '102978',
 '82878',
 '111474',
 '107988',
 '83604',
 '88260',
 '84054',
 '100980',
 '96060',
 '102978',
 '111474',
 '78012',
 '79926',
 '83616',
 '87888',
 '78006',
 '90024',
 '58884',
 '86130',
 '93354',
 '90618',
 '68262',
 '90540',
 '93354',
 '86130',
 '69270',
 '101442',
 '83706',
 '100980',
 '95106',
 '99888',
 '87006',
 '96060',
 '70656',
 '92430',
 '93354',
 '92316',
 '90456',
 '93354',
 '90024',
 '71682',
 '69264',
 '93354']

In [32]:
# Counts every equal (prediction, gold) pair, not matches per question
count = 0
for x in result:
  for y in b:
    if x == y:
      print(x, y)
      count += 1
print(count)

82008 82008
83706 83706
82878 82878
107988 107988
107988 107988
83604 83604
84054 84054
100980 100980
100980 100980
102978 102978
102978 102978
78012 78012
79926 79926
68262 68262
87888 87888
78006 78006
68262 68262
90540 90540
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
83706 83706
99888 99888
87006 87006
70656 70656
92430 92430
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
92316 92316
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
71682 71682
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
93354 93354
50
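The nested loop above compares every prediction with every gold salary, so a value that occurs several times in both lists (93354 appears five times in each) is counted once per pair; that is why the count, 50, exceeds the 45 questions asked. A minimal sketch of a per-question exact-match score instead compares `result` and `b` position by position. The helper name `exact_match_accuracy` is hypothetical, and treating a multi-cell answer such as `'96060, 107988'` as a set of cells is an assumed scoring choice.

```python
def exact_match_accuracy(predictions, gold):
    """Return (matches, accuracy), comparing answers position by position.

    A multi-cell answer ("96060, 107988") is split on commas and compared as
    a set, so it only counts when it selects exactly the gold cells.
    """
    if len(predictions) != len(gold):
        raise ValueError("need one prediction per gold answer")

    def cells(answer):
        return {part.strip() for part in answer.split(",") if part.strip()}

    matches = sum(cells(p) == cells(g) for p, g in zip(predictions, gold))
    return matches, matches / len(gold)


# Usage with the lists above:
# matches, accuracy = exact_match_accuracy(result, b)
```

On the `result` and `b` lists printed above, this gives 21 exact matches out of 45 questions (about 47%).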
