Don't forget to install pytorch-scatter as it's a dependency for the TAPAS model

In [1]:
!pip install torch-scatter -f -q https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q transformers

Looking in links: -q
Collecting https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
  Downloading https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
[31m  ERROR: Cannot unpack file /tmp/pip-unpack-6kuof7q5/torch-1.8.0+cu101.html (downloaded from /tmp/pip-req-build-q_m6ygcp, content-type: text/html); cannot detect archive format[0m
[31mERROR: Cannot determine archive format of /tmp/pip-req-build-q_m6ygcp[0m
[K     |████████████████████████████████| 2.1MB 9.7MB/s 
[K     |████████████████████████████████| 3.3MB 11.8MB/s 
[K     |████████████████████████████████| 901kB 48.9MB/s 
[?25h

In [2]:
import os
import ast
import torch
import collections
import requests, zipfile, io

import numpy as np
import pandas as pd

from transformers import TapasTokenizer
from transformers import TapasForQuestionAnswering,AdamW

### Download SQA Dataset

In [3]:
def download_files(dir_name):
  if not os.path.exists(dir_name): 
    # 28 training examples from the SQA training set + table csv data
    urls = ["https://www.dropbox.com/s/2p6ez9xro357i63/sqa_train_set_28_examples.zip?dl=1",
            "https://www.dropbox.com/s/abhum8ssuow87h6/table_csv.zip?dl=1"
    ]
    for url in urls:
      r = requests.get(url)
      z = zipfile.ZipFile(io.BytesIO(r.content))
      z.extractall()

dir_name = "sqa_data"
download_files(dir_name)

### Data Preprocessing

In [4]:
data = pd.read_excel("sqa_train_set_28_examples.xlsx")
data.head()

Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text
0,nt-639,0,0,where are the players from?,table_csv/203_149.csv,"['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4, ...","['Louisiana State University', 'Valley HS (Las..."
1,nt-639,0,1,which player went to louisiana state university?,table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald']
2,nt-639,1,0,who are the players?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke..."
3,nt-639,1,1,which ones are in the top 26 picks?,table_csv/203_149.csv,"['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...","['Ben McDonald', 'Tyler Houston', 'Roger Salke..."
4,nt-639,1,2,"and of those, who is from louisiana state univ...",table_csv/203_149.csv,"['(0, 1)']",['Ben McDonald']



As you can see, each row corresponds to a question related to a table.

- The position column identifies whether the question is the first, second, ... in a sequence of questions related to a table.
= The table_file column identifies the name of the table file, which refers to a CSV file in the table_csv directory.
= The answer_coordinates and answer_text columns indicate the answer to the question. The answer_coordinates is a list of tuples, each tuple being a (row_index, column_index) pair. The answer_text column is a list of strings, indicating the cell values.

However, the answer_coordinates and answer_text columns are currently not recognized as real Python lists of Python tuples and strings respectively. Let's do that first using the .literal_eval()function of the ast module:

In [5]:
def _parse_answer_coordinates(answer_coordinate_str):
  """Parses the answer_coordinates of a question.
  Args:
    answer_coordinate_str: A string representation of a Python list of tuple
      strings.
      For example: "['(1, 4)','(1, 3)', ...]"
  """

  try:
    answer_coordinates = []
    # make a list of strings
    coords = ast.literal_eval(answer_coordinate_str)
    # parse each string as a tuple
    for row_index, column_index in sorted(
        ast.literal_eval(coord) for coord in coords):
      answer_coordinates.append((row_index, column_index))
  except SyntaxError:
    raise ValueError('Unable to evaluate %s' % answer_coordinate_str)
  
  return answer_coordinates


def _parse_answer_text(answer_text):
  """Populates the answer_texts field of `answer` by parsing `answer_text`.
  Args:
    answer_text: A string representation of a Python list of strings.
      For example: "[u'test', u'hello', ...]"
    answer: an Answer object.
  """
  try:
    answer = []
    for value in ast.literal_eval(answer_text):
      answer.append(value)
  except SyntaxError:
    raise ValueError('Unable to evaluate %s' % answer_text)

  return answer

data['answer_coordinates'] = data['answer_coordinates'].apply(lambda coords_str: _parse_answer_coordinates(coords_str))
data['answer_text'] = data['answer_text'].apply(lambda txt: _parse_answer_text(txt))

data.head(10)

Unnamed: 0,id,annotator,position,question,table_file,answer_coordinates,answer_text
0,nt-639,0,0,where are the players from?,table_csv/203_149.csv,"[(0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, 4...","[Louisiana State University, Valley HS (Las Ve..."
1,nt-639,0,1,which player went to louisiana state university?,table_csv/203_149.csv,"[(0, 1)]",[Ben McDonald]
2,nt-639,1,0,who are the players?,table_csv/203_149.csv,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...","[Ben McDonald, Tyler Houston, Roger Salkeld, J..."
3,nt-639,1,1,which ones are in the top 26 picks?,table_csv/203_149.csv,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...","[Ben McDonald, Tyler Houston, Roger Salkeld, J..."
4,nt-639,1,2,"and of those, who is from louisiana state univ...",table_csv/203_149.csv,"[(0, 1)]",[Ben McDonald]
5,nt-639,2,0,who are the players in the top 26?,table_csv/203_149.csv,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...","[Ben McDonald, Tyler Houston, Roger Salkeld, J..."
6,nt-639,2,1,"of those, which one was from louisiana state u...",table_csv/203_149.csv,"[(0, 1)]",[Ben McDonald]
7,nt-11649,0,0,what are all the names of the teams?,table_csv/204_135.csv,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...","[Cordoba CF, CD Malaga, Granada CF, UD Las Pal..."
8,nt-11649,0,1,"of these, which teams had any losses?",table_csv/204_135.csv,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...","[Cordoba CF, CD Malaga, Granada CF, UD Las Pal..."
9,nt-11649,0,2,"of these teams, which had more than 21 losses?",table_csv/204_135.csv,"[(15, 1)]",[CD Villarrobledo]
