# Parse Dataset

In this notebook, we are going to parse the dataset obtained from the [PICO Dataset](https://ebm-nlp.herokuapp.com/annotations) ([download](https://github.com/bepnye/EBM-NLP)). Our goal is to turn it into a spaCy binary format. For now, we only need the `ebm_nlp_1_00` dataset that contains the first-phase annotations (P)articipants, (I)ntervention, and (O)utcome. 

In terms of the dataset, we'll only need the `annotations/aggregated/starting_spans` since we're only concerned with the P/I/O labels. The `hierarchical` tag is for finer annotations for each category, which we'll probably deal in the latter stages of this project. The documents themselves are already separated by tokens, but we can also access the text if we want to. 

Our strategy for parsing this dataset is that we do this from the annotations first, then match it with the document. Before doing that, I think I should
check if the PIO tags overlap. If they overlap, I should probably use the `Spans` or `SpanGroup` from spaCy.

In [130]:
from pathlib import Path
from glob import glob

TEXT_DIR = Path.cwd().parent / "assets" / "raw_data" / "documents"
ANNOTATIONS_DIR = Path.cwd().parent / "assets" / "raw_data" / "annotations" / "aggregated" / "starting_spans"
PARTICIPANTS_DIR = ANNOTATIONS_DIR / "participants"
INTERVENTIONS_DIR = ANNOTATIONS_DIR / "interventions"
OUTCOMES_DIR = ANNOTATIONS_DIR / "outcomes"

In [131]:
from typing import List
import numpy as np

# Just to be sure, I want to check if some labels overlap
# result = (A XOR B) | (B XOR C)
def xor_threes(a: List[int], b: List[int], c: List[int]):
    """Perform XOR on three vectors

    result = XOR(A, B, C) = (A XOR B) | (B XOR C)
    """
    x = np.logical_xor(a, b)
    y = np.logical_xor(b, c)
    return np.logical_or(x, y)

def get_annotations(filename: Path) -> List[int]:
    data = filename.read_text().split(",")
    # Typecast the labels right away
    data = [int(d) for d in data]
    return data

# Assuming all files are the same, I'll get all the filenames from the participants directory
train_files = [f.name for f in list((PARTICIPANTS_DIR / "train").glob("*_AGGREGATED.ann"))]
overlaps = 0
for file in train_files:
    p = get_annotations(PARTICIPANTS_DIR / "train" / file)
    i = get_annotations(INTERVENTIONS_DIR / "train" / file)
    o = get_annotations(OUTCOMES_DIR / "train" / file)

    res = xor_threes(p, i, o)
    if not all(res):
        print(f"{file}: overlapping labels")
        overlaps += 1
print(f"Total overlaps: {overlaps}/{len(train_files)}")



2015149_AGGREGATED.ann: overlapping labels
9430799_AGGREGATED.ann: overlapping labels
9735531_AGGREGATED.ann: overlapping labels
18410301_AGGREGATED.ann: overlapping labels
3704665_AGGREGATED.ann: overlapping labels
16806442_AGGREGATED.ann: overlapping labels
19264972_AGGREGATED.ann: overlapping labels
19515873_AGGREGATED.ann: overlapping labels
23104718_AGGREGATED.ann: overlapping labels
4154124_AGGREGATED.ann: overlapping labels
10928228_AGGREGATED.ann: overlapping labels
17302075_AGGREGATED.ann: overlapping labels
11167879_AGGREGATED.ann: overlapping labels
7708953_AGGREGATED.ann: overlapping labels
22721596_AGGREGATED.ann: overlapping labels
23811316_AGGREGATED.ann: overlapping labels
21148662_AGGREGATED.ann: overlapping labels
15764958_AGGREGATED.ann: overlapping labels
5322596_AGGREGATED.ann: overlapping labels
9648960_AGGREGATED.ann: overlapping labels
19512937_AGGREGATED.ann: overlapping labels
23873901_AGGREGATED.ann: overlapping labels
9230648_AGGREGATED.ann: overlapping labe

So there are overlapping entities...that's definitely expected. Our next step is to convert these entities into the spaCy serialized format. 

In [132]:
import re
from typing import Tuple

import numpy as np
import spacy
from spacy.tokens import Doc, Span
from spacy.util import get_words_and_spaces

nlp = spacy.blank("en")


def convert_to_doc(nlp, file_id: str) -> Tuple[Doc, str, List[str]]:
    """Convert raw text and tokens into a Doc object."""
    # Get the text and the tokens
    #text = re.sub('"', "", (TEXT_DIR / f"{file_id}.text").read_text())
    #text = re.sub("\n", "", text)
    text = (TEXT_DIR / f"{file_id}.tokens").read_text()
    tokens = text.split(" ")

    # Create a Doc object
    words, spaces = get_words_and_spaces(words=tokens, text=text)
    doc = Doc(nlp.vocab, words=words, spaces=spaces)
    return doc, text, tokens


def _get_contiguous_tokens(labels: List[int]) -> List:
    """Get contiguous tokens that can be used for creating Spans
    
    Remember that when you make the actual spans, you need to add +1 to 
    the end token. This is because you want to pas the index of the first
    token 'after' the span.

    [0, 1, 1, 0, 1, 1] -> [(1, 2), (4, 5)]
    """
    indices = np.asarray([i for i, x in enumerate(labels) if x])
    if np.all((indices == 0)):
        return []
    else:
        contig = np.split(indices, np.where(np.diff(indices) != 1)[0] + 1)
        span_indices = [(c[0],) if len(c) == 1 else (c[0], c[-1]) for c in contig]
        return span_indices


def attach_spans_to_doc(doc: Doc, file_id: str, span_key: str = "sc") -> Doc:
    """Attach spans to the spaCy Doc"""
    # Get annotations
    directories = [PARTICIPANTS_DIR, INTERVENTIONS_DIR, OUTCOMES_DIR]
    labels = ["PARTICIPANTS", "INTERVENTIONS", "OUTCOMES"]
    annotations = {l: get_annotations(list(d.glob(f"**/{file_id}_AGGREGATED.ann"))[0]) for d, l in zip(directories, labels)}
    
    # Sanity-check if all tokens are of the same length
    assert all([len(doc) == len(a) for a in annotations.values()]), f"Misaligned tokens in {file_id}"

    spans = []
    for annot, labels in annotations.items():
        # Get the token indices (where the value is 1)
        indices = _get_contiguous_tokens(labels)
        if indices:
            for idx in indices:
                start = idx[0]
                end = idx[0] if len(idx) == 1 else idx[-1]
                spans.append(Span(doc, start, end + 1, label=annot))

    doc.spans[span_key] = spans
    return doc


## Prepare training and test datasets

First, we need to get the IDs for each training and test dataset

In [133]:
def ids_only(s: str) -> str:
    return re.sub("\_AGGREGATED", "", s)

# Assumes that PARTICIPANTS_DIR is the same everywhere
train_file_ids = [ids_only(file.stem) for file in (PARTICIPANTS_DIR / "train").glob("*_AGGREGATED.ann")]
test_file_ids = [ids_only(file.stem) for file in (PARTICIPANTS_DIR / "test" / "gold").glob("*_AGGREGATED.ann")]


def to_spacy(file_id, nlp) -> Doc:
    doc, _, _ = convert_to_doc(nlp, file_id)
    doc = attach_spans_to_doc(doc, file_id)
    return doc

In [134]:
from spacy.tokens import DocBin

train_doc_bin = DocBin()
for idx, train_id in enumerate(train_file_ids):
    try:
        doc = to_spacy(train_id, nlp)
    except ValueError:
        print(f"Error in {train_id}")
    else:
        train_doc_bin.add(doc)

test_doc_bin = DocBin()
for idx, test_id in enumerate(test_file_ids):
    try:
        doc = to_spacy(test_id, nlp)
    except ValueError:
        print(f"Error in {test_id}")
    else:
        test_doc_bin.add(doc)


Processed 0 train files
Processed 500 train files
Processed 1000 train files
Processed 1500 train files
Processed 2000 train files
Processed 2500 train files
Processed 3000 train files
Processed 3500 train files
Processed 4000 train files
Processed 4500 train files
Processed 0 test files


In [141]:
train_doc_bin.to_disk(Path.cwd().parent / "assets" / "corpus" / "train.spacy")
test_doc_bin.to_disk(Path.cwd().parent / "assets" / "corpus" / "test.spacy")

In [None]:
# Things to do:
# 1. maybe put this in a script?
# 2. tomorrow, plan-out how you'll do the spancat training?