# Setup
### Imports

In [1]:
import sys
sys.path.append('../')
del sys

%reload_ext autoreload
%autoreload 2

from tesa.toolbox.parsers import standard_parser, add_annotations_arguments, add_task_arguments
from tesa.database_creation.annotation_task import AnnotationTask
from tesa.preprocess_annotations import filter_annotations
from tesa.toolbox.utils import load_task
from tesa.modeling.utils import format_context
from os.path import join as path_join
import pandas as pd
import os
import pickle

### Parameters

In [2]:
# A notebook has no real command line, so the project's argument parser is fed
# an explicit argument list; only the project root is overridden here.
ap = standard_parser()
for register_arguments in (add_annotations_arguments, add_task_arguments):
    register_arguments(ap)
args = ap.parse_args(["--root", ".."])

### Load the annotations data (first preprocessing step) and export the dataset

In [3]:
# Load the annotation results found under <root>/<annotations_path>. Only the
# verbosity and the results location come from `args`; every other option is
# passed explicitly as None.
annotation_task = AnnotationTask(
    silent=args.silent,
    results_path=path_join(args.root, args.annotations_path),
    years=None,
    max_tuple_size=None,
    short=None,
    short_size=None,
    random=None,
    debug=None,
    random_seed=None,
    save=None,
    corpus_path=None,
)

annotation_task.process_task(exclude_pilot=args.exclude_pilot)

# Short aliases for the processed results, used by the export cells below.
queries, annotations = annotation_task.queries, annotation_task.annotations

Processing the modeling task...
Computing the annotated queries...
Initial length of queries: 0.
Object loaded from ../results/annotation_task/annotations/v2_0/task/queries_short.pkl.
Object loaded from ../results/annotation_task/annotations/v2_1/task/queries.pkl.
Object loaded from ../results/annotation_task/annotations/v2_2/task/queries.pkl.
Final length of queries: 61056.
Done. Elapsed time: 1s.

Computing the annotations...
Initial length of annotations: 0.
batch_00 loaded from annotations/v2_0/results/batch_00_complete.csv
Correcting "n this article, Nevada and Ohio are discussed. The two American states..." to " The two American states..."
Correcting "In this article, California and Oregon are discussed. The two neighboring states..." to " The two neighboring states..."
Correcting "In this article, California and Oregon are discussed. The two West Coast states..." to " The two West Coast states..."
batch_01 loaded from annotations/v2_0/results/batch_01_complete.csv
Discarding "Th

In [4]:
# Flatten each annotated query into one record holding the query metadata and
# the non-empty answers of its annotations, stored as answer_0, answer_1, ...
MIN_ANSWER_COLUMNS = 3  # the original export always had answer_0..answer_2

global_data = []
max_answers = 0
for id_, annotation_list in annotations.items():
    query = queries[id_]
    data = {
        "entities_type": query.entities_type_,
        "entities": query.entities,
        "summaries": query.summaries,
        "urls": query.urls,
        "title": query.title,
        "date": query.date,
        "context": query.context,
        "context_type": query.context_type_,
    }

    # Number the non-empty answers consecutively. Numbering by annotation
    # position would leave gaps (e.g. a missing answer_0 next to a present
    # answer_1) whenever an annotation has no answers.
    non_empty_answers = [annotation.answers for annotation in annotation_list if annotation.answers]
    for i, answers in enumerate(non_empty_answers):
        data[f"answer_{i}"] = answers
    max_answers = max(max_answers, len(non_empty_answers))

    global_data.append(data)

# Select the published columns. Answers beyond the third are kept in extra
# columns instead of being silently dropped, and `reindex` fills an answer
# column that no query reaches with NaN instead of raising a KeyError.
answer_columns = [f"answer_{i}" for i in range(max(MIN_ANSWER_COLUMNS, max_answers))]
df = pd.DataFrame(global_data).reindex(
    columns=["entities_type", "entities", *answer_columns,
             "title", "date", "urls", "summaries", "context_type", "context"]
)

In [5]:
# Publish the flattened annotations dataset twice: as a CSV for inspection and
# as a pickle of the raw records, which keeps the original Python values
# instead of their CSV string form. Paths are derived from the project root,
# like the other paths in this notebook, and the output directory is created
# if it is missing.
publication_dir = path_join(args.root, "results", "publication")
os.makedirs(publication_dir, exist_ok=True)

df.to_csv(path_join(publication_dir, "dataset.csv"), index=False, mode="w")
# Use a context manager so the file is flushed and closed deterministically.
with open(path_join(publication_dir, "dataset.pickle"), "wb") as pickle_file:
    pickle.dump(global_data, pickle_file)

### Load the modeling task and export its ranking-task splits

In [6]:
# Load the modeling task; the export cell below reads its `train_loader`,
# `valid_loader` and `test_loader`. The file actually loaded is reported in
# the cell output (it presumably depends on the task arguments in `args` —
# confirm in `load_task`).
task = load_task(args)

Task loaded from ../results/modeling_task/context-dependent-same-type_50-25-25_rs24_bs4_cf-v0_tf-v0.pkl.



In [7]:
# Export each split of the modeling task as one row per ranking task: the inputs
# read from the task's first batch (entities, Wikipedia articles, NYT title and
# context) plus all candidates and labels gathered across the task's batches.
# Dedicated names (`ranking_records`, `ranking_df`) keep this cell from
# overwriting `global_data`/`df` of the dataset export, so re-running the
# dataset save cell afterwards still writes the right data.
RANKING_COLUMNS = ["entities_type", "entities", "wiki_articles", "nyt_title",
                   "nyt_context", "candidates", "labels"]

ranking_dir = path_join(args.root, "results", "publication")
os.makedirs(ranking_dir, exist_ok=True)

for split in ["train", "valid", "test"]:
    loader = getattr(task, f"{split}_loader")
    ranking_records = []

    for ranking_task in loader:
        # Query-level fields are read from the inputs of the first batch only.
        first_input = ranking_task[0][0]
        # The export keeps a single NYT title/context per ranking task.
        assert len(first_input["nyt_titles"]) == 1
        assert len(first_input["nyt_contexts"]) == 1

        candidates, labels = [], []
        for batch_inputs, batch_outputs in ranking_task:
            candidates.extend(batch_inputs["choices"])
            labels.extend(batch_outputs.tolist())

        ranking_records.append({
            "entities": first_input["entities"],
            "entities_type": first_input["entities_type"],
            "wiki_articles": first_input["wiki_articles"],
            "nyt_title": first_input["nyt_titles"][0],
            "nyt_context": first_input["nyt_contexts"][0],
            "candidates": candidates,
            "labels": labels,
        })

    # `reindex` also yields the expected columns for an empty split.
    ranking_df = pd.DataFrame(ranking_records).reindex(columns=RANKING_COLUMNS)

    ranking_df.to_csv(path_join(ranking_dir, f"ranking_task_{split}.csv"), index=False, mode="w")
    with open(path_join(ranking_dir, f"ranking_task_{split}.pickle"), "wb") as pickle_file:
        pickle.dump(ranking_records, pickle_file)