In [None]:
!pip install transformers

In [None]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering

# Single source of truth for the checkpoint name, so the tokenizer and the
# model cannot accidentally be loaded from two different checkpoints.
MODEL_CHECKPOINT = "google/tapas-base-finetuned-wtq"

# Load the tokenizer and the table-question-answering model from that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

model = AutoModelForTableQuestionAnswering.from_pretrained(MODEL_CHECKPOINT)

In [4]:
# Small income ledger used as the table to query. Every cell, including the
# numeric-looking "Income" column, is stored as a string.
data = {
    "Date": ["2022-10-01", "2022-10-02", "2022-10-02", "2022-10-03"],
    "Country": ["India", "Germany", "India", "France"],
    "Income": ["60", "40", "50", "120"],
}

# Questions that filter on a single date, country or period.
regular_queries = [
    "What was the total income on 2022-10-02?",
    "What was the income from Germany on 2022-10-02?",
    "What was the total income from India?",
    "What was the total income in October 2022?",
    "What was the total income in the first week of October 2022?",
    "What was the total income from Indonesia?",
]

# Questions flagged as problematic: the first two combine several filter
# values in one question, the last one is unrelated to the table.
problematic_queries = [
    "What was the total income on 2022-10-02 and 2022-10-03?",
    "What was the total income from France and India in October 2022?",
    "How many are the planets in the Solar system?",
]

# All questions, asked in this order.
queries = regular_queries + problematic_queries

table = pd.DataFrame(data)

# Encode the table together with every query, padded to the maximum length
# and returned as PyTorch tensors.
inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")

outputs = model(**inputs)

# Detach both logit tensors before passing them to the tokenizer's
# prediction-conversion helper.
detached_logits = outputs.logits.detach()
detached_aggregation_logits = outputs.logits_aggregation.detach()

(
    predicted_answer_coordinates,
    predicted_aggregation_indices,
) = tokenizer.convert_logits_to_predictions(
    inputs,
    detached_logits,
    detached_aggregation_logits,
)

# Readable names for the aggregation indices returned alongside the answer
# coordinates.
id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

aggregation_predictions_string = [
    id2aggregation[x] for x in predicted_aggregation_indices
]

# One answer string per query: the text of every selected cell, comma-separated.
# Joining handles all cases uniformly: a single selected cell yields just that
# cell's text, and an empty selection yields "". str() guards against cells
# that are not already strings, which would otherwise make the join (or the
# string concatenation done when printing) raise a TypeError.
answers = [
    ",".join(str(table.iat[coordinate]) for coordinate in coordinates)
    for coordinates in predicted_answer_coordinates
]

# Show the table, then each question followed by the predicted answer.
display(table)
print("")
for question, answer_text, aggregation in zip(queries, answers, aggregation_predictions_string):
    print(question)

    # The aggregation name is shown only when it is something other than "NONE".
    aggregation_prefix = "" if aggregation == "NONE" else aggregation + " --> "
    print("Predicted answer: " + aggregation_prefix + answer_text)
    print("")

Unnamed: 0,Date,Country,Income
0,2022-10-01,India,60
1,2022-10-02,Germany,40
2,2022-10-02,India,50
3,2022-10-03,France,120



What was the total income on 2022-10-02?
Predicted answer: SUM --> 40,50

What was the income from Germany on 2022-10-02?
Predicted answer: AVERAGE --> 40

What was the total income from India?
Predicted answer: SUM --> 60,50

What was the total income in October 2022?
Predicted answer: AVERAGE --> 60,40,50,120

What was the total income in the first week of October 2022?
Predicted answer: SUM --> 60,40,50

What was the total income from Indonesia?
Predicted answer: AVERAGE --> 

What was the total income on 2022-10-02 and 2022-10-03?
Predicted answer: SUM --> 120

What was the total income from France and India in October 2022?
Predicted answer: SUM --> 60,120

How many are the planets in the Solar system?
Predicted answer: COUNT --> 2022-10-01,2022-10-02,2022-10-02,2022-10-03

