# Preamble

In [None]:
from flair.models import SequenceTagger
from flair.data import Sentence, Token
import stanfordnlp
import json
import tqdm
import numpy as np
from itertools import repeat, takewhile


In [None]:
# Input: tab-separated extraction file for ClueWeb12.
PATH_CLUEWEB_EXTRACTION = (
    "/mnt/ceph/storage/data-in-progress/data-research/web-search/health-question-answering/causenet-data/causality-graphs/extraction/"
    "clueweb12/clueweb12-extraction.tsv"
)

# Folder holding the flair model file loaded for concept spotting.
PATH_FLAIR_FOLDER = (
    "/mnt/ceph/storage/data-in-progress/data-research/web-search/health-question-answering/causenet-data/flair-models/texts/"
)

# Directory used for the stanfordnlp download and as the pipeline's models_dir.
PATH_STANFORD_RESOURCES = (
    "/mnt/ceph/storage/data-in-progress/data-research/web-search/health-question-answering/causenet-data/external/stanfordnlp/"
)

# Output: JSON file the resulting graph is written to.
PATH_OUTPUT_GRAPH = (
    "/mnt/ceph/storage/data-in-progress/data-research/web-search/health-question-answering/causenet-data/causality-graphs/spotting/"
    "clueweb12/clueweb-graph.json"
)

In [None]:
def count_lines(file, verbose=False):
    """Count newline bytes in a binary file object, reading 1 MiB at a time.

    Args:
        file: File-like object opened in binary mode. It is read from its
            current position until ``read`` returns an empty result.
        verbose: When true, print the running count after every chunk
            (ending with a carriage return) and a final newline at the end.

    Returns:
        The number of ``b"\\n"`` bytes that were read.
    """
    chunk_size = 1024 * 1024
    newline_total = 0
    while True:
        chunk = file.read(chunk_size)
        if not chunk:
            break
        newline_total += chunk.count(b"\n")
        if verbose:
            print(newline_total, end="\r")
    if verbose:
        print()
    return newline_total

In [None]:
# Download the English ('en') stanfordnlp resources into PATH_STANFORD_RESOURCES.
stanfordnlp.download('en', PATH_STANFORD_RESOURCES)
# Pipeline limited to the 'tokenize' and 'pos' processors, reading its models
# from the same directory. tokenize_pretokenized=True is set because
# pos_tagging() passes text that was built by joining existing tokens.
# NOTE(review): use_gpu=True presumably needs a CUDA device - confirm.
stanford_nlp = stanfordnlp.Pipeline(processors='tokenize,pos',
                                    tokenize_pretokenized=True,
                                    models_dir=PATH_STANFORD_RESOURCES,
                                    treebank='en_ewt',
                                    use_gpu=True)

# Loading sentence data

In [None]:
sentence_set = []

# Pass 1 (binary mode): count the lines so the progress bar knows its total.
with open(PATH_CLUEWEB_EXTRACTION, "rb") as file:
    num_lines = count_lines(file)

# Pass 2 (text mode): every causal match of a 'clueweb12_sentence' row becomes
# one record holding the match and the sentence it was found in.
with open(PATH_CLUEWEB_EXTRACTION, encoding="utf-8") as file:
    for row in tqdm.tqdm(file, total=num_lines):
        fields = row.strip().split('\t')
        if fields[0] == 'clueweb12_sentence':
            assert len(fields) == 8
            (row_type, page_id, page_reference, page_timestamp,
             surface_json, tokens_json, dependencies_json,
             matches_json) = fields

            for match in json.loads(matches_json):
                # The JSON columns are decoded again for every match, so the
                # records never share decoded objects with each other.
                sentence_payload = {
                    "surface": json.loads(surface_json),
                    "tokens": json.loads(tokens_json),
                    "dependencies": json.loads(dependencies_json),
                }
                source = {
                    "type": "clueweb12_sentence",
                    "payload": {
                        "clueweb12_page_id": page_id,
                        "clueweb12_page_reference": page_reference,
                        "clueweb12_page_timestamp": page_timestamp,
                        "sentence": sentence_payload,
                        "path_pattern": match['Pattern'],
                    },
                }
                sentence_set.append({
                    "causal_relation": match,
                    "sources": [source],
                })

# POS-Tagging

In [None]:
def get_offset_of_tags_in_sentence(sentence, tags):
    """Locate the text of each tag in ``sentence``, scanning left to right.

    Every tag is searched for behind the end of the previously located tag,
    so repeated words map to successive occurrences.

    Args:
        sentence: The string to search in.
        tags: Sequence of tuples whose first element is the text to locate.

    Returns:
        A list with one int per tag: the character offset of the tag text in
        ``sentence``, or ``-1`` if the text does not occur behind the
        previously located tag.
    """
    offsets = []
    cursor = 0  # position in `sentence` right behind the last located tag
    for tag in tags:
        label = tag[0]
        offset = sentence.find(label, cursor)
        offsets.append(offset)

        # A missing label must not move the cursor: treating find()'s -1 as a
        # real position would skip len(label) - 1 characters and could jump
        # past the occurrences of the tags that follow.
        if offset != -1:
            cursor = offset + len(label)
    return offsets


def get_pos_tags_of_sentence(sentence):
    """Flatten a parsed sentence into ``(text, pos)`` pairs.

    Args:
        sentence: Object exposing ``tokens``; each token exposes ``words``,
            and each word exposes ``text`` and ``pos``.

    Returns:
        A list of ``(word.text, word.pos)`` tuples in token/word order.
    """
    return [(word.text, word.pos)
            for token in sentence.tokens
            for word in token.words]


def calculate_pos_tags_for_string(doc):
    """Collect the POS tags of every sentence of a parsed document.

    Args:
        doc: Object exposing ``sentences``; each entry is handed to
            ``get_pos_tags_of_sentence``.

    Returns:
        A list with one list of ``(text, pos)`` tuples per sentence.
    """
    return [list(get_pos_tags_of_sentence(parsed_sentence))
            for parsed_sentence in doc.sentences]

In [None]:
def pos_tagging(sentences_to_predict):
    """POS-tag a batch of samples and store the tags on each sample in place.

    For every sample, the entry
    ``sample['sources'][0]['payload']['sentence']['POS']`` is set to a list of
    ``(text, pos, offset)`` tuples, where ``offset`` is the stringified result
    of ``get_offset_of_tags_in_sentence`` on the sentence surface.

    Args:
        sentences_to_predict: Sequence of sample dicts as built while loading
            the sentence data.

    Raises:
        AssertionError: If the pipeline does not return exactly one tagged
            sentence per sample.
    """
    token_lists = [sample['sources'][0]['payload']['sentence']['tokens']
                   for sample in sentences_to_predict]

    # batch processing is faster: one pipeline call over all sentences,
    # separated by blank lines
    joined_batch = '\n\n'.join(' '.join(tokens) for tokens in token_lists)
    tags = calculate_pos_tags_for_string(stanford_nlp(joined_batch))

    assert len(tags) == len(sentences_to_predict)

    for sample, sentence_tags in zip(sentences_to_predict, tags):
        sentence_info = sample['sources'][0]['payload']['sentence']
        offsets = get_offset_of_tags_in_sentence(sentence_info['surface'],
                                                 sentence_tags)
        sentence_info['POS'] = [(tag[0], tag[1], str(offset))
                                for tag, offset in zip(sentence_tags, offsets)]

# Text-Spotter: Prediction

In [None]:
def prepare(batch):
    """Turn POS-tagged samples into flair ``Sentence`` objects.

    Each token receives a 'POS' tag and an 'idx' tag taken from the sample's
    ``POS`` entries. Samples with more than 200 tokens are left out.

    Args:
        batch: Iterable of sample dicts that already carry the ``POS`` entry.

    Returns:
        A list of ``Sentence`` objects, one per sample that was not left out.
    """
    prepared = []

    for sample in batch:
        flair_sentence = Sentence(use_tokenizer=False)

        sentence_info = sample['sources'][0]['payload']['sentence']
        tokens = sentence_info['tokens']
        pos_entries = sentence_info['POS']

        # sentences with too many tokens are skipped
        # due to GPU memory limitation
        if len(tokens) <= 200:
            for entry in pos_entries:
                flair_token = Token(entry[0])
                flair_token.add_tag('POS', entry[1])
                flair_token.add_tag('idx', entry[2])
                flair_sentence.add_token(flair_token)

            prepared.append(flair_sentence)
    return prepared

In [None]:
def predict(sentences, mini_batches):
    """Run the module-level ``classifier`` and collect the 'chunk_BIO' spans.

    Args:
        sentences: List of flair ``Sentence`` objects; ``classifier.predict``
            is called on them before the spans are read.
        mini_batches: Passed through as second positional argument of
            ``classifier.predict``.

    Returns:
        A list with one ``[extraction, indices]`` entry per sentence.
        ``indices`` holds, per span, the token positions (``token.idx - 1``);
        ``extraction`` holds, per span, the space-joined token texts.
    """
    classifier.predict(sentences, mini_batches)

    prediction = []
    for sentence in sentences:
        spans = sentence.get_spans('chunk_BIO')
        indices = [[token.idx-1 for token in span.tokens] for span in spans]
        extraction = [
            ' '.join(sentence.tokens[position].text for position in span_positions)
            for span_positions in indices
        ]
        prediction.append([extraction, indices])

    return prediction

In [None]:
def find_match(relation, indices):
    """Find the spans that contain the cause and the effect token.

    Args:
        relation: Dict whose ``'Cause'`` and ``'Effect'`` entries carry the
            token index at position 1 (converted with ``int``).
        indices: List of spans, each a list of token indices.

    Returns:
        ``[cause_span, effect_span]`` (positions within ``indices``) as soon
        as both tokens are found in two different spans, otherwise ``[]``.
    """
    cause_index = int(relation['Cause'][1])
    effect_index = int(relation['Effect'][1])

    cause_match = effect_match = None

    for span in indices:
        holds_cause = cause_index in span
        holds_effect = effect_index in span
        if holds_cause or holds_effect:
            position = indices.index(span)
            if holds_cause:
                cause_match = position
            if holds_effect:
                effect_match = position

        if cause_match is None or effect_match is None:
            continue
        if cause_match != effect_match:
            return [cause_match, effect_match]
    return []


def get_relations(batch, prediction):
    """Combine the pattern matches of a batch with the spotted spans.

    Args:
        batch: Sequence of POS-tagged sample dicts.
        prediction: Result of ``predict``. It holds entries only for samples
            with at most 200 tokens, in batch order.

    Returns:
        A list of dicts with the keys ``'causal_relation'`` (cause and effect,
        each with ``'concept'``, ``'POS'`` and ``'idcs'``) and ``'sources'``.
        Samples for which ``find_match`` yields no span pair are left out.
    """
    def tags_with_offsets(concept, span_positions, pos_tags):
        # concept POS (save for later post-processing): the span's tags,
        # with offsets recomputed relative to the concept string
        raw_tags = [pos_tags[j] for j in span_positions]
        offsets = get_offset_of_tags_in_sentence(concept, raw_tags)
        return [(tag[0], tag[1], str(offset))
                for tag, offset in zip(raw_tags, offsets)]

    relations = []
    consumed_predictions = 0

    for sample in batch:
        sentence_info = sample['sources'][0]['payload']['sentence']
        tokens = sentence_info['tokens']
        pos_tags = sentence_info['POS']
        if len(tokens) > 200:
            # prepare(batch) skipped this sample (GPU memory limitation),
            # so there is no prediction entry to consume for it
            continue

        path_pattern_extraction = sample['causal_relation']
        spotting_extraction, indices = prediction[consumed_predictions]
        consumed_predictions += 1

        match = find_match(path_pattern_extraction, indices)
        if len(match) < 2:
            # In cases the tagger failed,
            # we disregarded the causal concepts
            continue
        cause_match, effect_match = match

        cause = spotting_extraction[cause_match]
        effect = spotting_extraction[effect_match]

        cause_pos = tags_with_offsets(cause, indices[cause_match], pos_tags)
        effect_pos = tags_with_offsets(effect, indices[effect_match], pos_tags)

        relations.append({
            'causal_relation': {
                'cause': {'concept': cause,
                          'POS': cause_pos,
                          'idcs': indices[cause_match]},
                'effect': {'concept': effect,
                           'POS': effect_pos,
                           'idcs': indices[effect_match]},
            },
            'sources': sample['sources'],
        })
    return relations

In [None]:
# Load the flair SequenceTagger stored as 'final-model.pt' in PATH_FLAIR_FOLDER.
# predict() reads this module-level `classifier`.
classifier = SequenceTagger.load(PATH_FLAIR_FOLDER + 'final-model.pt')

In [None]:
text_graph = []

# Batch size: 512 when there are more than 512 sentences, otherwise 32.
batch_size = 512 if len(sentence_set) > 512 else 32

# Consecutive slices of at most `batch_size` samples. Plain slicing replaces
# np.array_split(sentence_set, len(sentence_set)/batch_size), which received a
# float section count, raised ValueError for fewer than `batch_size` sentences
# (zero sections) and produced batches larger than `batch_size`.
batches = [sentence_set[start:start + batch_size]
           for start in range(0, len(sentence_set), batch_size)]

for batch in tqdm.tqdm(batches):
    # pos_tagging() adds the 'POS' entry to the samples in place; the slices
    # hold the same dict objects as sentence_set.
    pos_tagging(batch)
    prepared_sentences = prepare(batch)
    prediction = predict(prepared_sentences, mini_batches=batch_size)
    batch_relations = get_relations(batch, prediction)

    text_graph.extend(batch_relations)

# Postprocessing

In [None]:
def get_token_idcs(tokens, sentence):
    """Compute ``(start, end)`` character spans of tokens inside a sentence.

    Each token is searched for from the end of the previous token's span.

    Args:
        tokens: Token strings in sentence order.
        sentence: The string the tokens are located in.

    Returns:
        A list with one ``(start, end)`` tuple per token. For a token that is
        not found, ``find`` yields -1, so its span starts one character
        before the current search position.
    """
    spans = []
    cursor = 0
    for token in tokens:
        start = cursor + sentence[cursor:].find(token)
        end = start + len(token)
        spans.append((start, end))
        cursor = end
    return spans

def post_process(pos_tags, tokens, indices):
    """Trim leading and trailing function words and punctuation off a concept.

    Entries at either end of ``pos_tags`` are cut while their tag (element 1)
    is one of CC, DT, PRP, PRP$ or a punctuation tag.

    Args:
        pos_tags: Sequence of tuples whose second element is the POS tag.
        tokens: Token strings that ``indices`` point into.
        indices: Token indices of the concept, parallel to ``pos_tags``.

    Returns:
        ``(value, indices)``: the space-joined remaining tokens and the
        remaining indices. Both are empty if every entry is cut.
    """
    cutoff = ['CC', 'DT', 'PRP', 'PRP$',
              '.', ',', ';', '(', ')', '``', "''"]

    def leading_run(tags):
        # number of consecutive cutoff tags at the start of `tags`
        run = 0
        for tag in tags:
            if tag[1] not in cutoff:
                break
            run += 1
        return run

    left = leading_run(pos_tags)
    right = len(pos_tags) - leading_run(reversed(pos_tags))

    kept = indices[left:right]
    return " ".join(tokens[idx] for idx in kept), kept

In [None]:
# Trim every concept, attach character and token index ranges to the source
# payload, and drop the intermediate POS data.
processed_graph = []

for relation in tqdm.tqdm(text_graph):
    payload = relation['sources'][0]['payload']
    tokens = payload['sentence']['tokens']
    sentence = payload['sentence']['surface']
    token_idcs = get_token_idcs(tokens, sentence)

    cause_data = relation['causal_relation']['cause']
    effect_data = relation['causal_relation']['effect']

    cause, cause_token_idcs = post_process(
        cause_data['POS'], tokens, cause_data['idcs'])
    effect, effect_token_idcs = post_process(
        effect_data['POS'], tokens, effect_data['idcs'])

    if not cause_token_idcs or not effect_token_idcs:
        # post_process() cut every token of a concept (e.g. it consisted only
        # of a pronoun). Indexing [0] / [-1] below would raise IndexError, and
        # an empty concept carries no information, so the relation is dropped.
        continue

    # character span: start of the first kept token to end of the last one
    cause_idcs = [token_idcs[cause_token_idcs[0]][0],
                  token_idcs[cause_token_idcs[-1]][1]]
    effect_idcs = [token_idcs[effect_token_idcs[0]][0],
                   token_idcs[effect_token_idcs[-1]][1]]

    relation['causal_relation'] = {
        'cause': {'concept': cause},
        'effect': {'concept': effect}
    }
    payload['idcs'] = {
        'surface': {'cause': cause_idcs, 'effect': effect_idcs},
        'tokens': {'cause': cause_token_idcs, 'effect': effect_token_idcs}
    }

    # further cleanup
    del payload['sentence']['POS']

    processed_graph.append(relation)

text_graph = processed_graph

# Save Text-graph

In [None]:
# Serialize first, then open the file, so a serialization error cannot leave
# a truncated output file behind.
serialized_graph = json.dumps(text_graph)
with open(PATH_OUTPUT_GRAPH, "w+") as output_file:
    output_file.write(serialized_graph)