In [None]:
import pandas as pd
# Display every row of a DataFrame instead of a truncated preview
pd.options.display.max_rows = None

In [None]:
# Do one-time large imports
# (the pretrained-model downloads below are slow; run this cell once per kernel session)

# For sentence tokenization
from nltk import tokenize

# For coreference resolution
from allennlp.predictors.predictor import Predictor
import allennlp_models.coref  # not referenced by name; presumably imported to register the coref model classes
coref_predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz"
)

# For part-of-speech tagging
import nltk
nltk.download('averaged_perceptron_tagger')
# tokenize.sent_tokenize and nltk.word_tokenize (used throughout the notebook) need the
# Punkt tokenizer models, which were never downloaded; on a fresh environment they fail
# with LookupError. NOTE: newer NLTK releases may look for 'punkt_tab' instead - confirm.
nltk.download('punkt')

# For dependency parsing
from allennlp.predictors.predictor import Predictor
import allennlp_models.structured_prediction  # presumably registers the parser model classes
dependency_predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/biaffine-dependency-parser-ptb-2020.04.06.tar.gz"
)

# For constituency parsing
constituency_predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
)

In [None]:
import os.path

### Basic input flags for the notebook / pipeline

# Identify the term we are splitting on (the "anchor")
# This also serves as the name for this entire cluster of results / practical pipeline run name
search_word = 'BERT'

# Other common names for this term that we should also consider as anchors
# e.g. [search_word, 'GPT', 'GPT-2', 'GPT-3']
# e.g. [search_word, 'Transformers', 'Transformer', 'transfer learning', 'transfer']
anchor_synonyms = [search_word]

# Flags
# Rerun coreference resolution when forced (change False -> True) or when the cached
# result is missing. The coreference cache is written and read as a pickle
# ('partial-coreference.pkl'), so that is the file whose existence decides the rerun;
# checking a '.hdf' file that is never written made the rerun unconditional.
flag_rerun_coreference = False or (not os.path.isfile(f'outputs/{search_word}/partial-coreference.pkl'))

In [None]:
import csv
import pandas as pd
from pathlib import Path

# Read in the dataframe containing entire paper abstracts (NOT pre-split into sentences)
# The restructuring below reads its 'URL', 'ID', 'Title' and 'Abstract' columns
df = pd.read_csv(f'data/nlp-align_{search_word}.csv')

# Create the outputs directory for this search word (no error if it already exists)
Path(f"outputs/{search_word}").mkdir(parents=True, exist_ok=True)

# Split apart the 'Title' and 'Abstract' columns, add period to 'Title' if not present
def separate_title_abstract(group):
    """Turn one paper's row into two rows: its title and its abstract.

    Parameters
    ----------
    group : pd.DataFrame
        The rows sharing one 'ID' (from ``df.groupby('ID')``); only the first row is used.
        Reads the 'URL', 'ID', 'Title' and 'Abstract' columns.

    Returns
    -------
    pd.DataFrame
        Two rows with columns 'URL', 'ID', 'Type' ('Title' / 'Abstract') and 'Text'.
        The title text is given a trailing period when it lacks one, so it is treated
        as a complete sentence by the sentence tokenizer.
    """
    # Positional access: groupby keeps the original index labels, so label 0 only exists
    # in the one group that happens to contain the frame's first row.
    row = group.iloc[0]
    title = row['Title']
    if not title.endswith('.'):
        title = title + '.'
    return pd.DataFrame({
        'URL': [row['URL']] * 2,
        'ID': [row['ID']] * 2,
        'Type': ['Title', 'Abstract'],
        'Text': [title, row['Abstract']],
    })

# Restructure the dataframe so that each paper contributes one 'Title' row and one
# 'Abstract' row, each carrying its text in a single 'Text' column
df = (
    df.groupby('ID', group_keys=False)
      .apply(separate_title_abstract)
      .reset_index(drop=True)
)

df

In [None]:
# Split the full-abstract CSV into a CSV containing individual sentences instead
def sentence_tokenize(group):
    """Expand one (ID, Type) text row into one row per sentence.

    Only the group's first row is used. The output has the columns 'URL', 'ID',
    'Type' (copied onto every sentence), 'Index' (0-based sentence position) and
    'Text' (the sentence itself), in that order.
    """
    first = group.iloc[0]
    pieces = tokenize.sent_tokenize(first['Text'])
    count = len(pieces)
    columns = {name: [first[name]] * count for name in ('URL', 'ID', 'Type')}
    columns['Index'] = list(range(count))
    columns['Text'] = pieces
    return pd.DataFrame(columns)

# One row per sentence; 'Index' is the sentence's position within its title / abstract
df_sentences = df.groupby(['ID', 'Type'], group_keys=False).apply(
    lambda group: sentence_tokenize(group)
).reset_index(drop=True)

df_sentences

In [None]:
# # Create a test dataframe so we can run models without taking impractically long
# # TODO: this is causing some type inconsistencies, fix those?

# temp_df = pd.DataFrame.from_dict(
#     {'URL': 'abc', 
#      'ID': '0', 
#      'Title': 'Paper Title',
#      'Abstract': 'The BERT language model (LM) (Devlin et al., 2019) is surprisingly good at answering cloze-style questions about relational facts. Petroni et al. (2019) take this as evidence that it memorizes factual knowledge during pre-training. We take issue with this interpretation and argue that the performance of BERT is partly due to reasoning about (the surface form of) entity names, e.g., guessing that a person with an Italian-sounding name speaks Italian.'
#     }
# )

# Splitting functions

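Each splitting function receives a one-row group of a sentence-level dataframe (such as `df_sentences`). The dataframe may have any number of columns, but it must include `Text`, which holds one sentence of a title or abstract, and the `ID`, `Type` and `Index` columns, so the split output can be joined back to the sentences.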

In [None]:
# Columns produced by every split_term_* function:
# split_0 is the text literally preceding the anchor
# split_1 is the text that the anchor consists of
# split_2 is the text literally following the anchor
# split_tokens is the list of tokens that split identifies
# split_anchor_span is a tuple (anchor_first_token_index, index one past the last anchor token) or None if there is no anchor
# split_anchor_indices is a tuple (anchor_start_char_index, anchor_end_char_index) or None if there is no anchor
# within_anchor_index is the character offset where a search term starts inside split_1
#   (always 0 for literal splits, -1 when a coreference mention contains no search term, None if there is no anchor)
splitting_headers = ['split_0','split_1','split_2', 
                     'split_tokens', 'split_anchor_span', 'split_anchor_indices', 
                     'within_anchor_index']
# Include ID, Type, Index in the split output to be able to join with df_sentences
join_headers = ['ID', 'Type', 'Index']
# The headers used for checking if rows should be eliminated as duplicate
duplicate_check_headers = splitting_headers[:3]+join_headers

In [None]:
import re

# Split on the search word, regardless of whitespace (if search word is A and we have word CAR, it slices it up)
# Splits on ALL instances of the search word
def split_term_literal(group, search_word, anchor_synonyms):
    """Split a sentence at every literal occurrence of each anchor synonym.

    Matching is case-insensitive and ignores word boundaries (with anchor 'A', the word
    'CAR' is sliced into 'C' / 'A' / 'R'). Synonyms are matched as plain text, not as
    regular expressions.

    Parameters
    ----------
    group : pd.DataFrame
        A one-row group of ``df_sentences``; reads 'Text' and the ``join_headers`` columns.
    search_word : str
        Unused; kept so every split_term_* function has the same signature.
    anchor_synonyms : list of str
        Terms to split on. Empty strings are ignored.

    Returns
    -------
    pd.DataFrame
        One row per occurrence with ``splitting_headers + join_headers`` columns, with
        duplicates on ``duplicate_check_headers`` removed. When nothing matches, a single
        row holds the whole sentence in split_0 and None in the anchor fields.
    """
    row = group.iloc[0]
    text = row['Text']
    join_values = list(row[join_headers])
    output = []
    for synonym in anchor_synonyms:
        if not synonym:
            continue
        # finditer visits every occurrence whatever its capitalization. Searching once and
        # then calling str.split only caught occurrences spelled like the first match.
        for match in re.finditer(re.escape(synonym), text, flags=re.IGNORECASE):
            anchor = match.group(0)
            # Tokenize each piece and re-join it, so the split text agrees with the tokens
            pre_split = nltk.word_tokenize(text[:match.start()].strip())
            mid_split = nltk.word_tokenize(anchor)
            post_split = nltk.word_tokenize(text[match.end():].strip())
            split_0 = ' '.join(pre_split)
            split_1 = ' '.join(mid_split)
            split_2 = ' '.join(post_split)
            output.append([
                split_0, split_1, split_2,
                # split_tokens
                pre_split + mid_split + post_split,
                # split_anchor_span: token range [start, end) of the anchor
                (len(pre_split), len(pre_split) + len(mid_split)),
                # split_anchor_indices
                # NOTE(review): when split_0 is non-empty, ' '.join(split_tokens) puts a
                # space before the anchor, so these offsets look one short. They are kept
                # as-is because consumers may rely on them - confirm before changing.
                (len(split_0), len(split_0) + len(split_1)),
                # within_anchor_index: for a literal split the anchor IS the search term
                0,
            ] + join_values)
    if not output:
        stripped = text.strip()
        output = [[stripped, '', '', stripped.split(' '), None, None, None] + join_values]
    # Transpose the output format so we can use it in zip for dataframe generation
    output_t = [list(t) for t in zip(*output)]
    return pd.DataFrame(
        dict(zip(splitting_headers + join_headers, output_t))
    ).drop_duplicates(duplicate_check_headers)

literal_output = df_sentences.groupby(df_sentences.index, group_keys=False).apply(
    lambda group: split_term_literal(group, search_word, anchor_synonyms)).reset_index(drop=True)

literal_output

In [None]:
import re

# Split on any token (whitespace delineated, or token-delineated) that contains an instance of the search word
# Splits on ALL instances of the search word
# a little bit misleadingly named, sorry. the intention is for it to be a simple split
def split_term_whitespace(group, search_word, anchor_synonyms):
    """Split a sentence at every whole word that contains an anchor synonym.

    The anchor is the complete word around each (case-insensitive) synonym match, so
    with synonym 'BERT' the sentence "BERT and RoBERTa" gives two splits, one anchored on
    'BERT' and one on 'RoBERTa'. Synonyms are matched as plain text, not as regexes.

    Parameters
    ----------
    group : pd.DataFrame
        A one-row group of ``df_sentences``; reads 'Text' and the ``join_headers`` columns.
    search_word : str
        Unused; kept so every split_term_* function has the same signature.
    anchor_synonyms : list of str
        Terms to look for. Empty strings are ignored.

    Returns
    -------
    pd.DataFrame
        One row per anchor word with ``splitting_headers + join_headers`` columns, with
        duplicates on ``duplicate_check_headers`` removed (e.g. when two synonyms select
        the same word). When nothing matches, a single row holds the whole sentence in
        split_0 and None in the anchor fields.
    """
    row = group.iloc[0]
    text = row['Text']
    join_values = list(row[join_headers])
    # Escaped so that synonyms containing regex metacharacters are matched literally
    patterns = [re.escape(a) for a in anchor_synonyms if a]
    output = []
    for pattern in patterns:
        # \w* on both sides widens each match to the whole word containing the synonym.
        # finditer yields every such word at its own position. Splitting the text on the
        # first matched word's string instead cut other words that merely contain it.
        for match in re.finditer(rf'\w*{pattern}\w*', text, flags=re.IGNORECASE):
            anchor = match.group(0)
            # Tokenize each piece and re-join it, so the split text agrees with the tokens
            pre_split = nltk.word_tokenize(text[:match.start()].strip())
            mid_split = nltk.word_tokenize(anchor)
            post_split = nltk.word_tokenize(text[match.end():].strip())
            split_0 = ' '.join(pre_split)
            split_1 = ' '.join(mid_split)
            split_2 = ' '.join(post_split)
            # within_anchor_index: where the first (in list order) synonym found inside
            # the anchor word starts
            hits = [re.search(p, anchor, flags=re.IGNORECASE) for p in patterns]
            within_anchor_index = [h.start() for h in hits if h is not None][0]
            output.append([
                split_0, split_1, split_2,
                # split_tokens
                pre_split + mid_split + post_split,
                # split_anchor_span: token range [start, end) of the anchor
                (len(pre_split), len(pre_split) + len(mid_split)),
                # split_anchor_indices
                # NOTE(review): when split_0 is non-empty, ' '.join(split_tokens) puts a
                # space before the anchor, so these offsets look one short. They are kept
                # as-is because consumers may rely on them - confirm before changing.
                (len(split_0), len(split_0) + len(split_1)),
                within_anchor_index,
            ] + join_values)
    if not output:
        stripped = text.strip()
        output = [[stripped, '', '', stripped.split(' '), None, None, None] + join_values]
    # Transpose the output format so we can use it in zip for dataframe generation
    output_t = [list(t) for t in zip(*output)]
    return pd.DataFrame(
        dict(zip(splitting_headers + join_headers, output_t))
    ).drop_duplicates(duplicate_check_headers)

whitespace_output = df_sentences.groupby(df_sentences.index, group_keys=False).apply(
    lambda group: split_term_whitespace(group, search_word, anchor_synonyms)).reset_index(drop=True)

whitespace_output

In [None]:
# Run coreference resolution over the entire abstract, not individual sentences,
# so that references spanning several sentences can be linked.
# Only runs when flag_rerun_coreference (set in the configuration cell) is true.
if flag_rerun_coreference:
    # One prediction per title/abstract row; result_type='expand' turns each returned dict
    # into columns (presumably including 'document' and 'clusters', which the next cell
    # reads - confirm against the AllenNLP coreference predictor's output)
    output = df.apply(
        lambda row: coref_predictor.predict(row['Text']), axis=1, result_type='expand')
    df_merged = df.join(output)

In [None]:
import re

# transform the output of coreference resolution into something that is more easily manipulated
# split it across multiple sentences so each indiv sentence row can still work
def reinterpret_coref_clusters(row, search_word, anchor_synonyms, sentences):
    """Map a document-level coreference result back onto individual sentences.

    Parameters
    ----------
    row : mapping
        One title/abstract row with 'ID', 'Type', 'document' (list of token strings from
        the coreference predictor) and 'clusters' (list of clusters, each a list of
        inclusive [start, end] token-index pairs).
    search_word : str
        Unused; kept for a uniform signature.
    anchor_synonyms : list of str
        A cluster is selected when any of its mentions contains one of these terms.
    sentences : pd.DataFrame
        Sentence-level frame with 'ID', 'Type' and 'Text'; the rows for this document
        must appear in sentence order.

    Returns
    -------
    dict
        'sent_mapping': token index -> sentence number
        'sent_content': sentence number -> list of token indices in that sentence
        'doct_mapping': token index -> offset from the first token of its sentence
        'selcluster_idxs': indices (ascending) of clusters whose mentions contain a synonym
    """
    tokens = row['document']
    # Combine both conditions in one mask. Chaining .loc[] with a mask built on the
    # unfiltered frame only worked through index re-alignment.
    src = sentences.loc[
        (sentences['ID'] == row['ID']) & (sentences['Type'] == row['Type']), 'Text']
    curr_sentence = 0
    consumed = 0        # characters of the current sentence already matched to tokens
    last_sent_end = 0   # document index of the current sentence's first token
    sent_mapping = {}
    sent_content = {}
    doct_mapping = {}
    for i, token in enumerate(tokens):
        if token.strip() != '':
            # A token that no longer fits in the rest of the current sentence starts the next one
            if token not in src.iloc[curr_sentence][consumed:]:
                last_sent_end = i
                curr_sentence += 1
                consumed = 0
            offset = src.iloc[curr_sentence][consumed:].index(token)
            consumed += offset + len(token)
        # Whitespace tokens are attached to the current sentence without consuming text
        sent_mapping[i] = curr_sentence
        # setdefault: a document starting with a whitespace token used to raise KeyError
        sent_content.setdefault(curr_sentence, []).append(i)
        doct_mapping[i] = i - last_sent_end
    # Select the clusters that contain search words. A mention's tokens are concatenated
    # without spaces, so each synonym is compared with its whitespace removed as well
    # (otherwise a multi-word synonym such as 'transfer learning' could never match).
    # Synonyms are escaped so they are matched as plain text.
    patterns = [re.escape(re.sub(r'\s+', '', a)) for a in anchor_synonyms]
    selcluster_idxs = [
        idx for idx, cluster in enumerate(row['clusters'])
        if any(re.search(p, ''.join(tokens[start:end + 1]), flags=re.IGNORECASE)
               for start, end in cluster for p in patterns)
    ]
    return {
        'sent_mapping': sent_mapping,
        'sent_content': sent_content,
        'doct_mapping': doct_mapping,
        'selcluster_idxs': selcluster_idxs,
    }

if flag_rerun_coreference:
    # Add the sentence-level mapping columns (sent_mapping, sent_content, doct_mapping,
    # selcluster_idxs) produced by reinterpret_coref_clusters for every document row
    output = df_merged.apply(
        lambda row: reinterpret_coref_clusters(row, search_word, anchor_synonyms, df_sentences), 
        axis=1, result_type='expand')
    df_merged = df_merged.join(output)

    # Cache the combined coreference results; the next cell reloads them from this file
    df_merged.to_pickle(f'outputs/{search_word}/partial-coreference.pkl')

In [None]:
# Load the cached coreference results (written by the previous cell whenever it reran).
# NOTE: unpickling can execute arbitrary code - only load files this pipeline produced.
df_merged = pd.read_pickle(f'outputs/{search_word}/partial-coreference.pkl')
df_merged

In [None]:
# Split based on co-references to any phrase containing search term, using allennlp coreference resolution
# This does NOT preserve the original sentence spacing
# REQUIRES THAT WE ALREADY RAN THE COREFERENCE PREDICTOR - this func does NOT do all of the work!
# Splits on ALL instances of references to the search word,
# including sub-references (e.g. "the accuracy of RoBERTa")
def split_term_coreference(group, search_word, anchor_synonyms, lookup, fallback):
    """Split a sentence at each coreference mention that refers to the search term.

    Requires the precomputed coreference frame (``lookup``); this function only slices
    its results. Every mention, from a selected cluster, that begins in this row's
    sentence produces one split. Token text is re-joined with single spaces, so the
    original sentence spacing is NOT preserved.

    Parameters
    ----------
    group : pd.DataFrame
        A one-row group of ``df_sentences``; reads the ``join_headers`` columns.
    search_word : str
        Passed through to ``fallback``.
    anchor_synonyms : list of str
        Used to locate the search term inside a mention (within_anchor_index), matched as
        plain text; passed through to ``fallback``.
    lookup : pd.DataFrame
        Coreference results with 'ID', 'Type', 'document', 'clusters', 'sent_mapping',
        'sent_content' and 'selcluster_idxs'; must contain a row for this ID/Type.
    fallback : callable
        ``fallback(group, search_word, anchor_synonyms)``; used when no selected cluster
        exists for the document or no selected mention starts in this sentence.

    Returns
    -------
    pd.DataFrame
        One row per mention with ``splitting_headers + join_headers`` columns, or whatever
        ``fallback`` returns.
    """
    row = group.iloc[0]
    # Combine both conditions in one mask rather than chaining .loc[] with a mask built
    # on the unfiltered frame
    matches = lookup.loc[(lookup['ID'] == row['ID']) & (lookup['Type'] == row['Type'])]
    lookup_row = matches.to_dict(orient='records')[0]
    if len(lookup_row['selcluster_idxs']) == 0:
        # if we didn't identify any clusters that match the search term, use our fallback method
        return fallback(group, search_word, anchor_synonyms)
    tokens = lookup_row['document']
    # Escaped so synonyms are matched as plain text
    patterns = [re.escape(a) for a in anchor_synonyms]
    output = []
    for cluster_id in lookup_row['selcluster_idxs']:
        for start, end in lookup_row['clusters'][cluster_id]:
            # Only mentions that begin in this row's sentence produce a split here
            if lookup_row['sent_mapping'][start] != row['Index']:
                continue
            sentence_tokens = lookup_row['sent_content'][row['Index']]
            sentence_start, sentence_end = sentence_tokens[0], sentence_tokens[-1]
            pre_split = tokens[sentence_start:start]
            anchor = tokens[start:end + 1]
            post_split = tokens[end + 1:sentence_end + 1]
            split_0 = ' '.join(pre_split)
            split_1 = ' '.join(anchor)
            split_2 = ' '.join(post_split)
            # within_anchor_index: first synonym (in list order) found in the mention, else -1
            hits = [re.search(p, split_1, flags=re.IGNORECASE) for p in patterns]
            hits = [h.start() for h in hits if h is not None]
            output.append([
                split_0, split_1, split_2,
                # split_tokens
                tokens[sentence_start:sentence_end + 1],
                # split_anchor_span: token range [start, end) of the mention in the sentence
                (len(pre_split), len(pre_split) + len(anchor)),
                # split_anchor_indices
                # NOTE(review): same convention as the other split functions; looks one
                # short when split_0 is non-empty - confirm before changing
                (len(split_0), len(split_0) + len(split_1)),
                hits[0] if hits else -1,
            ] + list(row[join_headers]))
    if not output:
        # if there wasn't any reference in the sentence found, use our fallback method
        return fallback(group, search_word, anchor_synonyms)
    # Transpose the output format so we can use it in zip for dataframe generation
    output_t = [list(t) for t in zip(*output)]
    return pd.DataFrame(dict(zip(splitting_headers + join_headers, output_t)))

coreference_output = df_sentences.groupby(df_sentences.index, group_keys=False).apply(
    lambda group: split_term_coreference(group, search_word, anchor_synonyms, df_merged, split_term_whitespace)
).reset_index(drop=True)

coreference_output

# Grouping functions

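Each grouping function takes one row of a dataframe produced by a splitting function (joined back onto the sentences). The dataframe must contain at least the columns `['split_0','split_1','split_2', 'split_tokens', 'split_anchor_span']`.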

`'split_1'` is the column that contains our search term / anchor point

`'split_0'` and `'split_2'` are the columns that contain text before and after the search terms respectively

In [None]:
# Output column names for the grouping functions: 'group' holds the primary key that
# identifies a group, and 'group2' ... 'group17' hold any extra values a grouping
# function returns (17 columns in total)
grouping_headers = ['group'] + [f'group{n}' for n in range(2, 18)]

In [None]:
# Attach the coreference-based splits to every sentence (outer join on ID / Type / Index)
sample_input = df_sentences.merge(coreference_output, how='outer', on=join_headers)

In [None]:
# Group on the literal first word that comes after the anchor point
# skips things that are punctuation
def group_first_word(row):
    """Group on the first word-like token after the anchor.

    Tokens are scanned from ``split_anchor_span[1]`` (the first token after the anchor).
    Leading/trailing characters from ``.,():-`` are stripped, and a token qualifies when
    what remains starts and ends with a word character. Rows without an anchor, or with
    no qualifying token after it, get ''.
    """
    word = ''
    span = row['split_anchor_span']
    if span is not None:
        for token in row['split_tokens'][span[1]:]:
            # Same language as the former '\w+(.+\w+)*' core (starts and ends with a word
            # character), but without nested quantifiers, which backtracked exponentially
            # on long tokens that fail to match (e.g. a URL ending in '/').
            found = re.search(r'^[.,():-]*(\w(?:.*\w)?)[.,():-]*$', token)
            if found is not None:
                word = found.group(1)
                break
    return dict(zip(grouping_headers, [word]))

output = sample_input.apply(
    lambda row: group_first_word(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
# Group on the first verb that comes after the anchor point, using NLTK part-of-speech tagging
def group_first_verb(row):
    """Group on the first verb after the anchor, using NLTK part-of-speech tags.

    The three split columns are re-tokenized and tagged together, so the tagger sees the
    whole sentence; only tokens after split_0 and split_1 are searched. A verb is any
    token whose tag starts with 'V'; returns '' when there is none.
    """
    before, anchor, after = (nltk.word_tokenize(row[col]) for col in ('split_0', 'split_1', 'split_2'))
    tagged = nltk.pos_tag(before + anchor + after)
    first_after_anchor = len(before) + len(anchor)
    verb = next((word for word, tag in tagged[first_after_anchor:] if tag.startswith('V')), '')
    return dict(zip(grouping_headers, [verb]))

output = sample_input.apply(
    lambda row: group_first_verb(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
# Guide on dependency parses:
# https://web.stanford.edu/~jurafsky/slp3/15.pdf

# POS tag info:
# https://universaldependencies.org/u/pos/
# https://cs.nyu.edu/grishman/jet/guide/PennPOS.html

# Do dependency parsing once for the entire sample_input to save processing time
# for groupings that require dependency parsing later
def parse_dependency(row):
    """Run the AllenNLP dependency parser on the row's full sentence.

    The sentence is ``split_tokens`` joined with single spaces. Returns
    ``{'dependency_parse': <predictor output>}`` so ``apply(result_type='expand')``
    produces a single 'dependency_parse' column.
    """
    sentence = ' '.join(row['split_tokens']).strip()
    return {'dependency_parse': dependency_predictor.predict(sentence=sentence)}

sample_input_dep = sample_input.apply(
    lambda row: parse_dependency(row), 
    axis=1, result_type='expand')

# Positional access on both axes. .iloc[0][0] relied on Series.__getitem__ treating an
# integer key as a position, which pandas deprecated in 2.1.
sample_input_dep.iloc[0, 0]

In [None]:
import copy

# Guide on constituency parses:
# https://web.stanford.edu/~jurafsky/slp3/13.pdf

# POS tag info:
# https://universaldependencies.org/u/pos/
# https://cs.nyu.edu/grishman/jet/guide/PennPOS.html
# http://www.surdeanu.info/mihai/teaching/ista555-fall13/readings/PennTreebankConstituents.html

# Side effect: modifies the given node to add a 'spans' field
def add_constituency_span_recursive(node):
    """Give every descendant of ``node`` a 'spans' entry of character offsets.

    Side effect: mutates the children of ``node`` in place (recursively). ``node`` must
    already have ``node['spans'][0]['start']``. Each child's 'word' is searched for in the
    part of the parent's 'word' not yet consumed by earlier children, so all offsets refer
    to the same string as the root span.

    Raises ValueError if a child's word does not occur in its parent's remaining text.
    """
    cursor = node['spans'][0]['start']  # absolute offset of the start of `remaining`
    remaining = node['word']
    for child in node.get('children', []):
        offset = remaining.index(child['word'])
        child_start = cursor + offset
        child_end = child_start + len(child['word'])
        child['spans'] = [{
            'start': child_start,
            'end': child_end
        }]
        add_constituency_span_recursive(child)
        # Advance past exactly what was consumed: any gap before the child plus the child.
        # The former fixed `len(word) + 1` step assumed single-space separators and could
        # re-match already consumed text (e.g. a repeated word after a wider gap).
        cursor = child_end
        remaining = remaining[offset + len(child['word']):]
    return

# Edit the constituency parse dict to contain a 'spans' field in the hierplane_tree
def add_constituency_span(parse):
    """Return a deep copy of ``parse`` whose hierplane tree nodes all have 'spans'.

    The root spans its whole 'word' ([0, len(word))); descendants get offsets into that
    same string. The input ``parse`` is left unchanged.
    """
    parse = copy.deepcopy(parse)
    root = parse['hierplane_tree']['root']
    root['spans'] = [{
        'start': 0, 
        'end': len(root['word'])
    }]
    add_constituency_span_recursive(root)
    return parse

# Do constituency parsing once for the entire sample_input to save processing time
# for groupings that require constituency parsing later
# Also add a span field to the parse dict so that it's easier to process later too
def parse_constituency(row):
    """Run the AllenNLP constituency parser on the row's full sentence.

    The sentence is ``split_tokens`` joined with single spaces. The parse is passed
    through add_constituency_span, so every hierplane node carries a 'spans' entry.
    Returns ``{'constituency_parse': <parse>}`` for ``apply(result_type='expand')``.
    """
    sentence = ' '.join(row['split_tokens']).strip()
    parse = constituency_predictor.predict(sentence=sentence)
    return {'constituency_parse': add_constituency_span(parse)}

sample_input_con = sample_input.apply(
    lambda row: parse_constituency(row), 
    axis=1, result_type='expand')

# Positional access on both axes. .iloc[0][0] relied on Series.__getitem__ treating an
# integer key as a position, which pandas deprecated in 2.1.
sample_input_con.iloc[0, 0]

In [None]:
# Group on sentence-level details, using allennlp constituency parsing
def group_type_sentence_con(row):
    """Group on the sentence's top-level constituency structure.

    Returns 'group' = node type of the parse root and 'group2' = the list of node types
    of the root's direct children, in order (minus any listed in ``tl_fillers``).
    """
    sentence = row['constituency_parse']['hierplane_tree']['root']
    # Node types to leave out of the breakdown; currently nothing is filtered.
    # Candidates: [',', '.', ':', 'HYPH']
    tl_fillers = []
    # .get: a root without children yields an empty breakdown instead of a KeyError
    tl_breakdown = [child['nodeType'] for child in sentence.get('children', [])
                    if child['nodeType'] not in tl_fillers]
    return dict(zip(grouping_headers, [sentence['nodeType'], tl_breakdown]))

output = sample_input.join(sample_input_con).apply(
    lambda row: group_type_sentence_con(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
# Returns n if the span is within split_n
def split_span_descriptor(span, split):
    """Locate which split part a sentence-level character span starts in.

    The sentence is taken to be the non-empty parts of ``split`` joined by single spaces
    (empty parts contribute neither text nor a separator).

    Parameters
    ----------
    span : tuple (start, end)
        Character offsets into the whole sentence.
    split : list of str
        e.g. [split_0, split_1, split_2]; more parts are allowed.

    Returns
    -------
    (int, tuple)
        ``(n, (start, end))`` where ``n`` is the index of the non-empty part in which
        ``span[0]`` falls, and the tuple is ``span`` shifted to be relative to the start of
        that part. If the span starts beyond the last part, ``n == len(split)`` and the
        offsets are relative to the end of the sentence plus one separator.
    """
    width = span[1] - span[0]
    part_start = 0  # offset of the current part within the whole sentence
    for n, part in enumerate(split):
        local_start = span[0] - part_start
        # Compare against this part's own extent. The former check used a running sum of
        # part lengths that ignored the separators, so spans starting in the last few
        # characters of a later part were attributed to the following part.
        if part and local_start < len(part):
            return n, (local_start, local_start + width)
        part_start += len(part) + (1 if len(part) > 0 else 0)
    local_start = span[0] - part_start
    return len(split), (local_start, local_start + width)

print(split_span_descriptor((0, 3), ['', 'abc', 'def ghi']), 
      split_span_descriptor((4, 7), ['', 'abc', 'def ghi']), 
      split_span_descriptor((8, 11), ['', 'abc', 'def ghi']))
print(split_span_descriptor((0, 3), ['abc', '', 'def ghi']), 
      split_span_descriptor((4, 7), ['abc', '', 'def ghi']), 
      split_span_descriptor((8, 11), ['abc', '', 'def ghi']))
print(split_span_descriptor((0, 3), ['abc', 'def', 'ghi']), 
      split_span_descriptor((4, 7), ['abc', 'def', 'ghi']), 
      split_span_descriptor((8, 11), ['abc', 'def', 'ghi']))

In [None]:
# Group on the main verb in the sentence, using allennlp dependency parsing
def group_main_verb_dep(row):
    """Group on the root (main verb) of the AllenNLP dependency parse.

    Returns, in grouping_headers order: the root word, the index of the split part it
    falls in (see split_span_descriptor), its start/end character offsets in the whole
    sentence, and its start/end offsets relative to that split part.
    """
    root = row['dependency_parse']['hierplane_tree']['root']
    first_span = root['spans'][0]
    start, end = first_span['start'], first_span['end']
    parts = [row['split_0'], row['split_1'], row['split_2']]
    location, (local_start, local_end) = split_span_descriptor((start, end), parts)
    values = [root['word'], location, start, end, local_start, local_end]
    return dict(zip(grouping_headers, values))

output = sample_input.join(sample_input_dep).apply(
    lambda row: group_main_verb_dep(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
# Overlaps two strings, filling in any whitespace characters in s1 with the non-whitespace char in s2
# If len(s2) > len(s1), the extra is added to the end
def string_union(s1, s2):
    """Overlay ``s2`` onto ``s1``.

    Each whitespace character of ``s1`` is replaced by the character at the same position
    in ``s2`` (non-whitespace characters of ``s1`` always win). Whichever string is longer
    contributes its extra tail unchanged.
    """
    overlap = ''.join(b if a.isspace() else a for a, b in zip(s1, s2))
    # At most one of these two tails is non-empty
    return overlap + s1[len(s2):] + s2[len(s1):]

# Overlaps two spans, producing their union (if there is a gap in between, it fills the gaps)
def span_union(s1, s2):
    """Return the smallest (start, end) span covering both spans."""
    lo = min(s1[0], s2[0])
    hi = max(s1[1], s2[1])
    return (lo, hi)

In [None]:
# Helper function
# Returns the entire phrase that comprises a dependency tree node and its children
def unroll_dependency_node(node, allowed_links=None, allowed_types=None):
    """Rebuild the phrase covered by a dependency node and its (filtered) descendants.

    Returns ``(text, span)``. In ``text`` every word sits at its own character offset, with
    spaces filling the gaps; ``span`` is the (start, end) hull of all included nodes. A
    child subtree is included only if its 'link' is in ``allowed_links`` and its first
    attribute is in ``allowed_types`` (None disables that filter); filters apply at every
    level.
    """
    first = node['spans'][0]
    phrase = ' ' * first['start'] + node['word']
    extent = (first['start'], first['end'])
    for child in node.get('children', []):
        if allowed_links is not None and child['link'] not in allowed_links:
            continue
        if allowed_types is not None and child['attributes'][0] not in allowed_types:
            continue
        child_phrase, child_extent = unroll_dependency_node(
            child,
            allowed_links=allowed_links,
            allowed_types=allowed_types)
        phrase = string_union(phrase, child_phrase)
        extent = span_union(extent, child_extent)
    return phrase, extent

# print(unroll_dependency_node(sample_input_dep.iloc[0, 0]['hierplane_tree']['root']))

In [None]:
def get_node_from_index(p, i):
    """Return the hierplane-tree node for token ``i`` (0-based) of a dependency parse.

    ``p['predicted_heads']`` holds 1-based head indices, with 0 marking the root. The path
    from the root is rebuilt by climbing heads; at each level the node's position among its
    parent's children is taken to be the number of earlier tokens sharing its head
    (assumes hierplane children are listed in token order - TODO confirm).
    """
    heads = p['predicted_heads']
    path = []
    token = i
    while heads[token] != 0:
        path.append(sum(1 for h in heads[:token] if h == heads[token]))
        token = heads[token] - 1
    node = p['hierplane_tree']['root']
    for position in reversed(path):
        node = node['children'][position]
    return node

In [None]:
def get_node_children(node, types):
    """Return the direct children of ``node`` whose 'nodeType' is in ``types``, in order.

    A node without a 'children' key yields an empty list.
    """
    return [child for child in node.get('children', []) if child['nodeType'] in types]

In [None]:
# Group on the main verb in the sentence, using allennlp dependency parsing
# Expanded to include subject, object, expand upon verb form
def group_main_verb_expanded_dep(row):
    """Group on the dependency-parse root verb phrase plus its subjects and objects.

    Returns, in grouping_headers order: the root word merged with its aux/auxpass/cop
    children (restricted to VERB/AUX/PART), the list of subject phrases, the list of object
    phrases, the index of the split part where the root phrase starts, the root phrase's
    start/end offsets in the sentence, and its start/end offsets relative to that part.
    """
    p = row['dependency_parse']
    root_node = p['hierplane_tree']['root']
    # Merge auxiliary / passive-auxiliary / copula children into the root verb phrase
    root_phrase, root_span = unroll_dependency_node(
        root_node,
        allowed_links=['aux', 'auxpass', 'cop'],
        allowed_types=['VERB', 'AUX', 'PART']
    )
    # unroll_dependency_node pads words to their offsets; collapse to single spaces
    root_phrase = ' '.join(root_phrase.strip().split())
    split_loc, cspan = split_span_descriptor(root_span, [row['split_0'], row['split_1'], row['split_2']])
    # Direct children of the root labelled as subjects / objects, each unrolled into its full phrase
    subjects = get_node_children(root_node, ['subj', 'nsubj', 'nsubjpass', 'csubj', 'csubjpass'])
    subjects = [' '.join(unroll_dependency_node(n)[0].strip().split()) for n in subjects]
    objects = get_node_children(root_node, ['obj', 'dobj', 'iobj', 'pobj'])
    objects = [' '.join(unroll_dependency_node(n)[0].strip().split()) for n in objects]
    # Pick out dependencies that come before/after the verb...
    # this targets dependencies that weren't clearly identified as subj/obj by the parser
    unsure_relation = get_node_children(root_node, ['dep'])
    unsure_relation = [(n['spans'][0]['start'],
                        ' '.join(unroll_dependency_node(n)[0].strip().split()))
                       for n in unsure_relation]
    # TODO improve this possibly
    # For now, assume a dependent phrase before the verb is more likely to be subject-like
    # and a dependent phrase after the verb is more likely to be object-like
    subjects += [n[1] for n in unsure_relation if n[0] < root_node['spans'][0]['start']]
    objects += [n[1] for n in unsure_relation if n[0] > root_node['spans'][0]['start']]
    output = [root_phrase, subjects, objects, 
              split_loc, root_span[0], root_span[1], cspan[0], cspan[1]]
    return dict(zip(grouping_headers, output))

output = sample_input.join(sample_input_dep).apply(
    lambda row: group_main_verb_expanded_dep(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
# TODO implement some kind of clustering for level-1 sentence structure
# (order of NP, VP, PP, commas or periods etc in the sentence)
# can start by making a distance metric between different structure lists?

In [None]:
# Helper function
# Returns the entire phrase that comprises a constituency tree node and its children
# If there's link or type restrictions, then enforces those.
def unroll_constituency_node(node, allowed_links=None, allowed_types=None):
    """Collect the leaf words (and their spans) under a constituency node.

    Returns ``(words, spans)``: the leaf words in order and their (start, end) character
    spans. A leaf returns itself. A child subtree is descended into only if its 'link' is
    in ``allowed_links`` and its first attribute is in ``allowed_types`` (None disables a
    filter). An internal node whose children are all filtered out contributes nothing.
    """
    if 'children' not in node:
        first = node['spans'][0]
        return [node['word']], [(first['start'], first['end'])]
    words, spans = [], []
    for child in node['children']:
        if allowed_links is not None and child['link'] not in allowed_links:
            continue
        if allowed_types is not None and child['attributes'][0] not in allowed_types:
            continue
        child_words, child_spans = unroll_constituency_node(
            child,
            allowed_links=allowed_links,
            allowed_types=allowed_types)
        words.extend(child_words)
        spans.extend(child_spans)
    return words, spans

# Preview how the subject-like and verb-like link filters unroll the first parses.
# min(): the frame may have fewer than 100 rows (range(100) raised IndexError then).
# .iloc[num, 0]: positional access on both axes instead of the deprecated .iloc[num][0].
for num in range(min(100, len(sample_input_con))):
    parse_root = sample_input_con.iloc[num, 0]['hierplane_tree']['root']
    print(parse_root['word'])
    print('=====')
    print(unroll_constituency_node(
        parse_root,
        allowed_links=['DT','JJ','JJR','JJS','NP','PRP','PRPS','NN','NNS','NNP','NNPS','POS']))
    print('=====')
    print(unroll_constituency_node(
        parse_root,
        allowed_links=['S','DT','JJ','JJR','JJS','NP','PRP','PRPS','NN','NNS','NNP','NNPS']))
    print('=====')
    print(unroll_constituency_node(
        parse_root,
        allowed_links=['VP','VB','VBD','VBG','VBN','VBP','VBZ','MD','CC']))
    print('=====')
    print(unroll_constituency_node(
        parse_root,
        allowed_links=['S','VP','VB','VBD','VBG','VBN','VBP','VBZ','MD']))
    print('===============')

In [None]:
import functools

# Group on the main subject and verb in the sentence, using allennlp constituency parsing
# Also include a misc collection of subject-like words and verb-like words
def group_main_verb_con(row):
    """Group on the top-level subject and verb phrases of the constituency parse.

    The subject-like words are the leaves reachable through noun-phrase-ish links; the
    verb-like words are those reachable through verb-phrase-ish links. For each, the
    output gives the joined words, the split part where their hull starts, the hull's
    sentence offsets, its offsets relative to that part, and the raw per-word spans ('' 
    placeholders when nothing was found). The last two values are the broader word lists
    that also descend through 'S' nodes.
    """
    sentence = row['constituency_parse']['hierplane_tree']['root']
    parts = [row['split_0'], row['split_1'], row['split_2']]

    def locate(spans):
        # Hull of all spans and where it falls among the split parts;
        # ('', '') / '' placeholders when there are no spans
        if not spans:
            return ('', ''), '', ('', '')
        hull = functools.reduce(span_union, spans)
        where, relative = split_span_descriptor(hull, parts)
        return hull, where, relative

    tl_subj, tl_subj_span = unroll_constituency_node(
        sentence,
        allowed_links=['DT','JJ','JJR','JJS','NP','PRP','PRPS','NN','NNS','NNP','NNPS','POS'])
    subj_hull, subj_where, subj_relative = locate(tl_subj_span)
    all_subj, _ = unroll_constituency_node(
        sentence,
        allowed_links=['S','DT','JJ','JJR','JJS','NP','PRP','PRPS','NN','NNS','NNP','NNPS'])
    tl_verb, tl_verb_span = unroll_constituency_node(
        sentence,
        allowed_links=['VP','VB','VBD','VBG','VBN','VBP','VBZ','MD','CC'])
    verb_hull, verb_where, verb_relative = locate(tl_verb_span)
    all_verb, _ = unroll_constituency_node(
        sentence,
        allowed_links=['S','VP','VB','VBD','VBG','VBN','VBP','VBZ','MD'])
    values = [
        ' '.join(tl_subj), subj_where,
        subj_hull[0], subj_hull[1],
        subj_relative[0], subj_relative[1],
        tl_subj_span,
        ' '.join(tl_verb), verb_where,
        verb_hull[0], verb_hull[1],
        verb_relative[0], verb_relative[1],
        tl_verb_span,
        all_subj, all_verb,
    ]
    return dict(zip(grouping_headers, values))

output = sample_input.join(sample_input_con).apply(
    lambda row: group_main_verb_con(row), 
    axis=1, result_type='expand').sort_values(by=['group'])

output = sample_input.join(output)
output

In [None]:
def list_elements_match(list1, list2, i1, i2, size1=None, size2=None):
    """
    Align a run of tokens in list1 (starting at i1) with a run of tokens in
    list2 (starting at i2) by comparing their concatenated text.

    Helper for phrase POS. The two token lists may split the same text
    differently, so one element on one side may correspond to several on the
    other (the ratio of element matches must be 1:some or some:1).

    Parameters
    ----------
    list1, list2 : list of str
        Token lists to compare.
    i1, i2 : int
        Start indices into list1 and list2.
    size1, size2 : int, optional
        Number of tokens to match from list1 / list2.
        - Both given: the two runs must have identical concatenated text.
        - Only one given: that run's text must be a prefix of the other list's
          text from its start index.
        - Neither given: the longest common character prefix is used.

    Returns
    -------
    list of tuple or None
        [(start1, end1), (start2, end2)], half-open token bounds of the match.
        Returns None if the texts do not match, or, when no size is given, if
        they share no leading characters. If the matched text ends partway
        through a token on one side, that side's bound falls back to one token.
    """
    text1 = ''.join(list1[i1:])
    text2 = ''.join(list2[i2:])
    if size1 is not None and size2 is not None:
        # check for exact text match
        phrase1 = ''.join(list1[i1:i1+size1])
        matchlen = len(phrase1)
        if phrase1 != ''.join(list2[i2:i2+size2]):
            return None
    elif size1 is not None:
        # only size1 given: list1's run must be a prefix of list2's text
        phrase1 = ''.join(list1[i1:i1+size1])
        matchlen = len(phrase1)
        if phrase1 != text2[:matchlen]:
            return None
    elif size2 is not None:
        # only size2 given: list2's run must be a prefix of list1's text
        phrase2 = ''.join(list2[i2:i2+size2])
        matchlen = len(phrase2)
        if phrase2 != text1[:matchlen]:
            return None
    else:
        # neither given: measure the longest common character prefix,
        # bounded by the shorter text so the scan always terminates
        limit = min(len(text1), len(text2))
        matchlen = 0
        while matchlen < limit and text1[matchlen] == text2[matchlen]:
            matchlen += 1
        if matchlen == 0:
            return None
    matchphrase = text1[:matchlen]

    def token_count(tokens, start):
        # Fewest tokens from `start` whose concatenation equals matchphrase;
        # defaults to 1 when the phrase ends inside a token.
        for count in range(len(tokens) - start + 1):
            if ''.join(tokens[start:start+count]) == matchphrase:
                return count
        return 1

    bound1 = token_count(list1, i1)
    bound2 = token_count(list2, i2)
    return [(i1, i1+bound1), (i2, i2+bound2)]

In [None]:
import math

def group_anchor_verb_dep(row):
    """
    Group a split sentence on the verb closest to (above) the anchor point,
    using the allennlp dependency parse. Also collects subject/object (SVO)
    information and the expanded anchor-verb phrase.

    Parameters
    ----------
    row : pd.Series
        Needs 'split_anchor_span' ((start, end) token indices into
        'split_tokens', or None), 'split_tokens', 'split_0'..'split_2', and
        'dependency_parse' (allennlp output with 'words', 'pos' and
        'predicted_heads').

    Returns
    -------
    dict
        Keyed by the global ``grouping_headers``:
        [anchor-verb phrase, subject phrases, object phrases, relation,
         split location, span start, span end, cspan start, cspan end].
        relation is 1 if the anchor comes before the anchor verb, 0 if the
        anchor is the verb, and -1 if it comes after.
        Only the first header (value '') is returned when there is no anchor,
        when the anchor cannot be aligned with the parse, when it is all
        punctuation, or when no verb is found above it.
    """
    if row['split_anchor_span'] is None:
        return dict(zip(grouping_headers, ['']))
    p = row['dependency_parse']
    pos = p['pos']

    # tree_array[n] = (head, level) for word n: head is the 1-based index of its
    # parent (0 for the root) and level its depth, so the root node is (0, 0).
    # Words never reached from the root keep level -1.
    tree_array = [(h, -1) for h in p['predicted_heads']]
    need_connection = [0]
    level = 0
    while len(need_connection) > 0:
        need_connection_update = []
        for i in range(len(tree_array)):
            if tree_array[i][0] in need_connection:
                tree_array[i] = (tree_array[i][0], level)
                need_connection_update.append(i+1)
        need_connection = need_connection_update
        level += 1

    def node_at(head):
        # (0-based word index, head, level) for the word with 1-based index `head`
        return (head-1, tree_array[head-1][0], tree_array[head-1][1])

    # Figure out which parser tokens the anchor covers.
    # Assumes the anchor is contiguous text. TODO: is this always true?
    anchor_start = row['split_anchor_span'][0]
    anchor_end = row['split_anchor_span'][1]
    anchor_offset = len(''.join(row['split_tokens'][:anchor_start]))
    match = None
    for i in range(len(p['words'])):
        if len(''.join(p['words'][:i])) >= anchor_offset:
            match = list_elements_match(
                p['words'], row['split_tokens'], i, anchor_start,
                size2=anchor_end-anchor_start)
            if match is not None:
                break
    if match is None:
        # The anchor text could not be aligned with the parser's tokens
        return dict(zip(grouping_headers, ['']))

    # Find the smallest dependency node containing the whole anchor
    matching_nodes = list(set(node_at(i+1) for i in range(match[0][0], match[0][1])))
    # Drop punctuation, which tends to produce odd dependency structures
    matching_nodes = [n for n in matching_nodes if pos[n[0]] != 'PUNCT']
    if len(matching_nodes) == 0:
        return dict(zip(grouping_headers, ['']))
    # Replace the deepest node by its parent until a single node remains
    while len(matching_nodes) > 1:
        matching_nodes.sort(key=lambda x: x[2])
        parent = matching_nodes.pop()[1]
        matching_nodes.append(node_at(parent))
        matching_nodes = list(set(matching_nodes))
    anchor_node = matching_nodes[0]

    # Walk up from the anchor to the nearest VERB or AUX
    node = anchor_node
    while pos[node[0]] not in ['VERB', 'AUX'] and node[1] != 0:
        node = node_at(node[1])
    if pos[node[0]] not in ['VERB', 'AUX']:
        # Reached the root without finding a verb
        return dict(zip(grouping_headers, ['']))

    # If the nearest verb is an auxiliary, climb through any VERB heads directly
    # above it so that the top of the anchor verb group becomes the anchor verb.
    if pos[node[0]] == 'AUX':
        while node[1] != 0 and pos[node[1]-1] == 'VERB':
            node = node_at(node[1])
    averb_node = node

    averb_tree = get_node_from_index(p, averb_node[0])
    averb_string, averb_span = unroll_dependency_node(
        averb_tree,
        allowed_links=['aux', 'auxpass', 'cop'],
        allowed_types=['VERB', 'AUX', 'PART']
    )
    averb_string = ' '.join(averb_string.strip().split())
    split_loc, cspan = split_span_descriptor(averb_span, [row['split_0'], row['split_1'], row['split_2']])

    def phrase(n):
        # Whitespace-normalized text of a dependency subtree
        return ' '.join(unroll_dependency_node(n)[0].strip().split())

    subjects = [phrase(n) for n in get_node_children(
        averb_tree, ['subj', 'nsubj', 'nsubjpass', 'csubj', 'csubjpass'])]
    objects = [phrase(n) for n in get_node_children(
        averb_tree, ['obj', 'dobj', 'iobj', 'pobj'])]

    # Generic 'dep' children were not classified as subj/obj by the parser.
    # Heuristic (TODO improve): a phrase before the averb is treated as
    # subject-like, and a phrase after it as object-like.
    averb_start = averb_tree['spans'][0]['start']
    unsure_relation = [(n['spans'][0]['start'], phrase(n))
                       for n in get_node_children(averb_tree, ['dep'])]
    subjects += [text for start, text in unsure_relation if start < averb_start]
    objects += [text for start, text in unsure_relation if start > averb_start]

    # Mark whether the anchor comes before (1), at (0) or after (-1) the averb.
    # TODO compare the anchor span with the averb span rather than node indices,
    # and add linguistic knowledge about subj/obj roles.
    relation = averb_node[0] - anchor_node[0]
    relation = int(math.copysign(1, relation)) if relation != 0 else 0
    output = [averb_string, subjects, objects, relation,
              split_loc, averb_span[0], averb_span[1], cspan[0], cspan[1]]
    return dict(zip(grouping_headers, output))

# Apply the anchor-verb grouping to the sample sentences and show it next to the inputs
output = (
    sample_input
    .join(sample_input_dep)
    .apply(group_anchor_verb_dep, axis=1, result_type='expand')
)
output = sample_input.join(output)
output

In [None]:
def group_anchor_pos_dep(row, context=1):
    """
    Group a split sentence on the part of speech of the anchor, using the
    allennlp dependency parse.

    The "POS of a phrase" is defined as the POS of the lowest dependency node
    that contains the entire phrase.

    Parameters
    ----------
    row : pd.Series
        Needs 'split_anchor_span' ((start, end) token indices into
        'split_tokens', or None), 'split_tokens', and 'dependency_parse'
        (allennlp output with 'words', 'pos' and 'predicted_heads').
    context : int, default 1
        Number of tiers to report: the anchor node plus up to ``context - 1``
        of its ancestors (fewer if the root is reached first).

    Returns
    -------
    dict
        Keyed by the global ``grouping_headers``: [POS tags, words], both lists
        ordered from the anchor node upward. Only the first header (value '')
        is returned when there is no anchor. Two empty lists are returned when
        the anchor cannot be aligned with the parse or is all punctuation.
    """
    if row['split_anchor_span'] is None:
        return dict(zip(grouping_headers, ['']))
    p = row['dependency_parse']
    pos = p['pos']

    # tree_array[n] = (head, level) for word n: head is the 1-based index of its
    # parent (0 for the root) and level its depth, so the root node is (0, 0).
    # Words never reached from the root keep level -1.
    tree_array = [(h, -1) for h in p['predicted_heads']]
    need_connection = [0]
    level = 0
    while len(need_connection) > 0:
        need_connection_update = []
        for i in range(len(tree_array)):
            if tree_array[i][0] in need_connection:
                tree_array[i] = (tree_array[i][0], level)
                need_connection_update.append(i+1)
        need_connection = need_connection_update
        level += 1

    def node_at(head):
        # (0-based word index, head, level) for the word with 1-based index `head`
        return (head-1, tree_array[head-1][0], tree_array[head-1][1])

    # Figure out which parser tokens the anchor covers.
    # Assumes the anchor is contiguous text. TODO: is this always true?
    anchor_start = row['split_anchor_span'][0]
    anchor_end = row['split_anchor_span'][1]
    anchor_offset = len(''.join(row['split_tokens'][:anchor_start]))
    match = None
    for i in range(len(p['words'])):
        if len(''.join(p['words'][:i])) >= anchor_offset:
            match = list_elements_match(
                p['words'], row['split_tokens'], i, anchor_start,
                size2=anchor_end-anchor_start)
            if match is not None:
                break
    if match is None:
        # The anchor text could not be aligned with the parser's tokens
        return dict(zip(grouping_headers, [[], []]))

    # Find the smallest dependency node containing the whole anchor
    matching_nodes = list(set(node_at(i+1) for i in range(match[0][0], match[0][1])))
    # Drop punctuation, which tends to produce odd dependency structures
    matching_nodes = [n for n in matching_nodes if pos[n[0]] != 'PUNCT']
    if len(matching_nodes) == 0:
        return dict(zip(grouping_headers, [[], []]))
    # Replace the deepest node by its parent until a single node remains
    while len(matching_nodes) > 1:
        matching_nodes.sort(key=lambda x: x[2])
        parent = matching_nodes.pop()[1]
        matching_nodes.append(node_at(parent))
        matching_nodes = list(set(matching_nodes))

    # Collect POS and word for the anchor node and up to context-1 ancestors
    node = matching_nodes[0]
    labeltiers = []
    labelwords = []
    while len(labeltiers) < context:
        labeltiers.append(pos[node[0]])
        labelwords.append(p['words'][node[0]])
        if node[1] == 0:
            break
        node = node_at(node[1])
    return dict(zip(grouping_headers, [labeltiers, labelwords]))

# Apply the anchor-POS grouping (anchor plus two ancestor tiers) to the sample sentences
output = (
    sample_input
    .join(sample_input_dep)
    .apply(lambda sample_row: group_anchor_pos_dep(sample_row, context=3),
           axis=1, result_type='expand')
)
output = sample_input.join(output)
output

# Export CSVs

This section can be run on its own, top to bottom, as long as all the functions above have been defined and the one-time large imports have already been performed.

In [None]:
import pandas as pd

def generate_outputs(search_word, anchor_type, try_rerun, split_data, df_sentences):
    """
    Run every grouping function over the split sentences of one anchor type and
    export the results under ``outputs/{search_word}/``.

    Each grouping is written to ``{anchor_type}_<grouping>.csv``. These files are
    then read back, their generic 'group', 'group2', ... columns are renamed, and
    the results are joined into one spreadsheet-friendly table saved as
    ``{anchor_type}.csv``.

    Parameters
    ----------
    search_word : str
        Name of the pipeline run; selects the ``outputs/{search_word}`` directory.
    anchor_type : str
        Name of the split method (e.g. 'whitespace' or 'coreference'); used in
        the name of every file written.
    try_rerun : bool
        Recompute the dependency and constituency parses even if their cached
        pickles already exist. The grouping CSVs are always regenerated.
    split_data : pd.DataFrame
        Split rows, outer-merged onto ``df_sentences`` on the columns in the
        global ``join_headers`` (which must be defined before calling).
    df_sentences : pd.DataFrame
        One row per sentence.

    Returns
    -------
    pd.DataFrame
        The merged table, also written to ``outputs/{search_word}/{anchor_type}.csv``.
    """
    import os.path

    out_dir = f'outputs/{search_word}'
    splitted_sentences = df_sentences.merge(
        split_data,
        how='outer',
        left_on=join_headers,
        right_on=join_headers)

    def load_parse(kind, parse_fn):
        # Parsing is expensive: cache the result as a pickle and reuse it
        # unless a rerun is forced.
        path = f'{out_dir}/partial-{kind}-{anchor_type}.pkl'
        if try_rerun or not os.path.isfile(path):
            splitted_sentences.apply(
                parse_fn, axis=1, result_type='expand').to_pickle(path)
        return pd.read_pickle(path)

    splitted_sentences_dep = load_parse('dependency', parse_dependency)
    splitted_sentences_con = load_parse('constituency', parse_constituency)

    # (file suffix, parse columns the grouping needs or None, grouping function)
    groupings = [
        ('firstword', None, group_first_word),
        ('firstverb', None, group_first_verb),
        ('sentstruct_con', splitted_sentences_con, group_type_sentence_con),
        ('mainverb_svo_con', splitted_sentences_con, group_main_verb_con),
        ('mainverb_dep', splitted_sentences_dep, group_main_verb_dep),
        ('mainverb_svo_dep', splitted_sentences_dep, group_main_verb_expanded_dep),
        ('anchorpos_dep', splitted_sentences_dep,
         lambda row: group_anchor_pos_dep(row, context=3)),
        ('anchorverb_dep', splitted_sentences_dep, group_anchor_verb_dep),
    ]
    for suffix, parse_columns, group_fn in groupings:
        group_input = (splitted_sentences if parse_columns is None
                       else splitted_sentences.join(parse_columns))
        output = group_input.apply(group_fn, axis=1, result_type='expand')
        splitted_sentences.join(output).to_csv(f'{out_dir}/{anchor_type}_{suffix}.csv')

    def read_renamed(suffix, new_names, keep_other_columns=False):
        # Grouping functions name their outputs 'group', 'group2', 'group3', ...;
        # map the first len(new_names) of them onto descriptive names.
        old_names = ['group'] + [f'group{i}' for i in range(2, len(new_names) + 1)]
        frame = pd.read_csv(f'{out_dir}/{anchor_type}_{suffix}.csv')
        if not keep_other_columns:
            frame = frame[old_names]
        return frame.rename(columns=dict(zip(old_names, new_names)))

    # Restructure the outputs into something that can be neatly imported into a
    # spreadsheet. The anchor-verb table keeps every column (including the
    # sentence/split metadata); the others contribute only their grouping columns.
    outputs_merged = read_renamed('anchorverb_dep', [
        'd_averb', 'd_averb_s', 'd_averb_o', 'd_averb_relation',
        'd_averb_split', 'd_averb_span0', 'd_averb_span1',
        'd_averb_cspan0', 'd_averb_cspan1'], keep_other_columns=True)
    renamed_groupings = [
        ('mainverb_dep', ['d_root']),
        ('mainverb_svo_dep', [
            'd_root_full', 'd_root_s', 'd_root_o',
            'd_root_split', 'd_root_span0', 'd_root_span1',
            'd_root_cspan0', 'd_root_cspan1']),
        ('firstverb', ['fverb']),
        ('anchorpos_dep', ['d_apos', 'd_apos_w']),
        ('firstword', ['fword']),
        ('sentstruct_con', ['c_senttype', 'c_sentparts']),
        ('mainverb_svo_con', [
            'c_subj_full', 'c_subj_split', 'c_subj_span0',
            'c_subj_span1', 'c_subj_cspan0', 'c_subj_cspan1',
            'c_subj_allspans', 'c_verb_full', 'c_verb_split',
            'c_verb_span0', 'c_verb_span1', 'c_verb_cspan0',
            'c_verb_cspan1', 'c_verb_allspans', 'c_subj_list',
            'c_verb_list']),
    ]
    for suffix, new_names in renamed_groupings:
        outputs_merged = outputs_merged.join(read_renamed(suffix, new_names))

    outputs_merged = outputs_merged[[
        'split_0', 'split_1', 'split_2', 
        'c_senttype', 'c_sentparts',
        'c_subj_full', 'c_subj_split', 
        'c_subj_span0', 'c_subj_span1', 'c_subj_cspan0', 'c_subj_cspan1', 'c_subj_allspans',
        'c_verb_full', 'c_verb_split', 
        'c_verb_span0', 'c_verb_span1', 'c_verb_cspan0', 'c_verb_cspan1', 'c_verb_allspans',
        'c_subj_list', 'c_verb_list',
        'd_averb', 'd_averb_s', 'd_averb_o', 'd_averb_relation', 
            'd_averb_split', 'd_averb_span0', 'd_averb_span1', 'd_averb_cspan0', 'd_averb_cspan1',
        'd_root', 'd_root_full', 'd_root_s', 'd_root_o', 
            'd_root_split', 'd_root_span0', 'd_root_span1', 'd_root_cspan0', 'd_root_cspan1',
        'fverb', 'fword', 'd_apos', 'd_apos_w',
        'URL', 'ID', 'Type', 'Index', 'Text', 
        'split_tokens', 'split_anchor_span', 'split_anchor_indices', 
        'within_anchor_index'
    ]]

    outputs_merged.to_csv(f'{out_dir}/{anchor_type}.csv')
    return outputs_merged

In [None]:
import os.path
import pickle

def pipeline(search_word, 
             anchor_synonyms, 
             try_rerun=False):
    """
    Run the whole analysis pipeline for one search phrase and its synonyms.

    Parameters
    ----------
    search_word : str
        The term we split on (the "anchor"). It also names this run: the input
        is ``data/nlp-align_{search_word}.csv`` and results go to
        ``outputs/{search_word}/``.
    anchor_synonyms : [str]
        Other common names for the term that are also treated as anchors.
    try_rerun : Boolean
        Forcibly recreate intermediate files such as the coreference or
        dependency parses.

    """
    # Whole paper abstracts (NOT pre-split into sentences)
    abstracts = pd.read_csv(f'data/nlp-align_{search_word}.csv')
    Path(f"outputs/{search_word}").mkdir(parents=True, exist_ok=True)

    # Separate each paper's title from its abstract
    abstracts = (abstracts
                 .groupby('ID', group_keys=False)
                 .apply(separate_title_abstract)
                 .reset_index(drop=True))

    # One row per sentence instead of one row per abstract
    sentences = (abstracts
                 .groupby(['ID', 'Type'], group_keys=False)
                 .apply(sentence_tokenize)
                 .reset_index(drop=True))

    # Whitespace-based splits, computed sentence by sentence
    whitespace_output = (sentences
                         .groupby(sentences.index, group_keys=False)
                         .apply(lambda group: split_term_whitespace(group, search_word, anchor_synonyms))
                         .reset_index(drop=True))

    # Coreference predictions are expensive, so they are cached as a pickle
    coref_cache = f'outputs/{search_word}/partial-coreference.pkl'
    if try_rerun or not os.path.isfile(coref_cache):
        predictions = abstracts.apply(
            lambda abstract: coref_predictor.predict(abstract['Text']),
            axis=1, result_type='expand')
        with_predictions = abstracts.join(predictions)
        clusters = with_predictions.apply(
            lambda abstract: reinterpret_coref_clusters(abstract, search_word, anchor_synonyms, sentences),
            axis=1, result_type='expand')
        with_predictions.join(clusters).to_pickle(coref_cache)
    coref_merged = pd.read_pickle(coref_cache)

    # Coreference-based splits, computed sentence by sentence
    coreference_output = (sentences
                          .groupby(sentences.index, group_keys=False)
                          .apply(lambda group: split_term_coreference(
                              group, search_word, anchor_synonyms, coref_merged, split_term_whitespace))
                          .reset_index(drop=True))

    for anchor_type, split_data in (('whitespace', whitespace_output),
                                    ('coreference', coreference_output)):
        generate_outputs(search_word, anchor_type, try_rerun, split_data, sentences)

In [None]:
# (search word, anchor synonyms) for every dataset run through the pipeline
all_targets = [
    ('BERT', ['BERT']),
    ('SQuAD', ['SQuAD']),
    ('DROP', ['DROP']),
    ('GPT-2', ['GPT', 'GPT-2', 'GPT-3']),
    ('Transformers', ['Transformers', 'Transformer', 'transfer learning', 'transfer']),
]

for search_term, synonyms in all_targets:
    pipeline(search_term, synonyms, try_rerun=False)
    print(search_term, 'done running')

# Misc utility functions

In [None]:
# p = dependency_predictor.predict(
#     sentence='the quick red fox jumped over the lazy dog.'
# )
# p

In [None]:
# Modified version of the pipeline to use EBM dataset
# and temporary-generated exports of PIO extraction for the dataset

import pandas as pd

def generate_outputs(search_word, anchor_type, try_rerun, split_data, df_sentences):
    """
    EBM variant of generate_outputs: run every grouping over the EBM/PIO splits
    and export the results under ``outputs/{search_word}/``.

    NOTE: this redefines (and shadows) the general-purpose generate_outputs from
    the "Export CSVs" section. Re-run that cell before calling pipeline() again.

    Unlike the general version, the dependency and constituency parses are
    computed once per sentence (on ``df_sentences``) rather than once per split,
    to save memory, on the assumption that every split of a sentence has the
    same tokenization. TODO: check whether this is true (constituency parsing
    may split tokens differently).

    Parameters
    ----------
    search_word : str
        Name of the run; selects the ``outputs/{search_word}`` directory.
    anchor_type : str
        Name of the split method (e.g. 'pio'); used in every file name written.
    try_rerun : bool
        Recompute the cached parses AND the cached grouping CSVs even if they
        already exist. Otherwise, existing files are reused.
    split_data : pd.DataFrame
        Split rows, outer-merged onto the sentences on the EBM join headers.
    df_sentences : pd.DataFrame
        One row per sentence, with the EBM join-header columns.

    Returns
    -------
    pd.DataFrame
        The merged table, also written to ``outputs/{search_word}/{anchor_type}.csv``.
    """
    import os.path

    out_dir = f'outputs/{search_word}'
    # these are EBM-specific join headers
    join_headers = ['ID', 'Type', 'Index', 'ann_section', 'ann_source']

    def load_parse(kind, parse_fn):
        # Parsing is expensive: cache the per-sentence result as a pickle and
        # reuse it unless a rerun is forced.
        path = f'{out_dir}/partial-{kind}-{anchor_type}.pkl'
        if try_rerun or not os.path.isfile(path):
            df_sentences.apply(parse_fn, axis=1, result_type='expand').to_pickle(path)
        return pd.read_pickle(path)

    df_sentences_dep = load_parse('dependency', parse_dependency)
    df_sentences_con = load_parse('constituency', parse_constituency)

    # Attach both parses to every sentence, then fan the sentences out to their splits
    sentences_parsed = df_sentences.join(df_sentences_dep).join(df_sentences_con)
    splitted_sentences = sentences_parsed.merge(
        split_data,
        how='outer',
        left_on=join_headers,
        right_on=join_headers,
        suffixes=(None, '_y'))

    # (file suffix, grouping function); the parses are already attached
    groupings = [
        ('firstword', group_first_word),
        ('firstverb', group_first_verb),
        ('sentstruct_con', group_type_sentence_con),
        ('mainverb_svo_con', group_main_verb_con),
        ('mainverb_dep', group_main_verb_dep),
        ('mainverb_svo_dep', group_main_verb_expanded_dep),
        ('anchorpos_dep', lambda row: group_anchor_pos_dep(row, context=3)),
        ('anchorverb_dep', group_anchor_verb_dep),
    ]
    for suffix, group_fn in groupings:
        path = f'{out_dir}/{anchor_type}_{suffix}.csv'
        # Cached like the parses, and regenerated on a forced rerun, because
        # these CSVs are derived from the parses.
        if try_rerun or not os.path.isfile(path):
            output = splitted_sentences.apply(group_fn, axis=1, result_type='expand')
            splitted_sentences.join(output).to_csv(path)

    def read_renamed(suffix, new_names, keep_other_columns=False):
        # Grouping functions name their outputs 'group', 'group2', 'group3', ...;
        # map the first len(new_names) of them onto descriptive names.
        old_names = ['group'] + [f'group{i}' for i in range(2, len(new_names) + 1)]
        frame = pd.read_csv(f'{out_dir}/{anchor_type}_{suffix}.csv')
        if not keep_other_columns:
            frame = frame[old_names]
        return frame.rename(columns=dict(zip(old_names, new_names)))

    # Restructure the outputs into something that can be neatly imported into a
    # spreadsheet. The anchor-verb table keeps every column (including the
    # sentence/split metadata); the others contribute only their grouping columns.
    outputs_merged = read_renamed('anchorverb_dep', [
        'd_averb', 'd_averb_s', 'd_averb_o', 'd_averb_relation',
        'd_averb_split', 'd_averb_span0', 'd_averb_span1',
        'd_averb_cspan0', 'd_averb_cspan1'], keep_other_columns=True)
    renamed_groupings = [
        ('mainverb_dep', ['d_root']),
        ('mainverb_svo_dep', [
            'd_root_full', 'd_root_s', 'd_root_o',
            'd_root_split', 'd_root_span0', 'd_root_span1',
            'd_root_cspan0', 'd_root_cspan1']),
        ('firstverb', ['fverb']),
        ('anchorpos_dep', ['d_apos', 'd_apos_w']),
        ('firstword', ['fword']),
        ('sentstruct_con', ['c_senttype', 'c_sentparts']),
        ('mainverb_svo_con', [
            'c_subj_full', 'c_subj_split', 'c_subj_span0',
            'c_subj_span1', 'c_subj_cspan0', 'c_subj_cspan1',
            'c_subj_allspans', 'c_verb_full', 'c_verb_split',
            'c_verb_span0', 'c_verb_span1', 'c_verb_cspan0',
            'c_verb_cspan1', 'c_verb_allspans', 'c_subj_list',
            'c_verb_list']),
    ]
    for suffix, new_names in renamed_groupings:
        outputs_merged = outputs_merged.join(read_renamed(suffix, new_names))

    outputs_merged = outputs_merged[[
        'split_0', 'split_1', 'split_2', 
        'c_senttype', 'c_sentparts',
        'c_subj_full', 'c_subj_split', 
        'c_subj_span0', 'c_subj_span1', 'c_subj_cspan0', 'c_subj_cspan1', 'c_subj_allspans',
        'c_verb_full', 'c_verb_split', 
        'c_verb_span0', 'c_verb_span1', 'c_verb_cspan0', 'c_verb_cspan1', 'c_verb_allspans',
        'c_subj_list', 'c_verb_list',
        'd_averb', 'd_averb_s', 'd_averb_o', 'd_averb_relation', 
            'd_averb_split', 'd_averb_span0', 'd_averb_span1', 'd_averb_cspan0', 'd_averb_cspan1',
        'd_root', 'd_root_full', 'd_root_s', 'd_root_o', 
            'd_root_split', 'd_root_span0', 'd_root_span1', 'd_root_cspan0', 'd_root_cspan1',
        'fverb', 'fword', 'd_apos', 'd_apos_w',
        'URL', 'ID', 'Type', 'Index', 'Text', 
        'split_tokens', 'split_anchor_span', 'split_anchor_indices', 
        'within_anchor_index', 'split_anchor_type'
    ]]

    outputs_merged.to_csv(f'{out_dir}/{anchor_type}.csv')
    return outputs_merged

search_word = 'ebm'

# Keys that identify one EBM sentence
ebm_sentence_keys = ['ID', 'Type', 'Index', 'ann_section', 'ann_source']

# EBM sentence data, its tokenization, and the PIO extraction output
pio_df_sentences = pd.read_pickle('temp/ebm-df_sentences.pkl')
pio_df_tokenized = pd.read_pickle('temp/ebm-df_tokenized.pkl')
pio_output = pd.read_pickle('temp/ebm-pio_output.pkl')

# Attach the tokenization to each sentence and generate the grouping outputs.
# TODO refine the join column lists into something that's actually reusable (not a high priority)
outputs_merged = generate_outputs(
    search_word,
    'pio',
    False,
    pio_output,
    pio_df_sentences.join(
        pio_df_tokenized.set_index(ebm_sentence_keys),
        on=ebm_sentence_keys))
outputs_merged.head(100)

# Individual element alignment / analysis

In [None]:
import gensim

# Load Google's pre-trained Word2Vec model (stored in the binary word2vec format).
# model source: https://code.google.com/archive/p/word2vec/
word2vec_path = 'model/GoogleNews-vectors-negative300.bin'
word2vec = gensim.models.KeyedVectors.load_word2vec_format(word2vec_path, binary=True)

In [None]:
def get_phrase_embed_word2vec(word2vec, phrase):
    """Embed a whitespace-tokenised phrase as the scaled sum of its word2vec vectors.

    Each token of ``phrase`` is looked up in ``word2vec``. Tokens missing from the
    vocabulary are skipped. The found vectors are summed and divided by the TOTAL
    number of tokens in the phrase (known and unknown). Unknown words therefore
    shrink the vector's magnitude but do not change its direction.

    Args:
        word2vec: mapping from word to vector. It must support ``word2vec[word]``
            and raise ``KeyError`` for out-of-vocabulary words (as gensim
            ``KeyedVectors`` does).
        phrase: the text to embed. Values without ``.split()`` (e.g. NaN/None
            from a missing DataFrame cell) are treated as having no embedding.

    Returns:
        A one-row DataFrame whose columns are the vector dimensions plus a
        ``'word'`` column holding ``phrase``. Returns an empty DataFrame when
        ``phrase`` cannot be split or none of its tokens are in the vocabulary.

    Raises:
        Any error from ``word2vec`` other than ``KeyError`` propagates, instead of
        being silently treated as an unknown word.
    """
    try:
        tokens = phrase.split()
    except AttributeError:
        # e.g. NaN/None from a missing cell: nothing to embed
        return pd.DataFrame()
    vectors = []
    for token in tokens:
        try:
            vectors.append(word2vec[token])
        except KeyError:
            # out-of-vocabulary word: skip it
            continue
    if not vectors:
        return pd.DataFrame()
    emb_sum = pd.DataFrame(vectors).sum() / len(tokens)
    emb_sum['word'] = phrase
    return pd.DataFrame([emb_sum])

# Quick sanity check: embed a short phrase and a whole sentence, show the phrase.
v = get_phrase_embed_word2vec(word2vec, 'test sentence')
sent_v = get_phrase_embed_word2vec(word2vec, 'This is a test sentence !')
v

In [None]:
# Load the extracted PIO spans; the parsing cells below read their 'text' column.
pio_extracted = pd.read_hdf('temp/ebm-pio_extracted.hdf', key='mydata')
# pio_extracted.iloc[0:200]

In [None]:
# dependency_predictor.predict(
#     sentence='hypertensive patients receiving drug treatment'
# )

In [None]:
# Dependency-parse every extracted span once up front, so later groupings
# can reuse the cached parses instead of re-running the predictor.
def _dependency_parse_row(row):
    # Strip surrounding whitespace so the parser only sees the span text.
    return dependency_predictor.predict(sentence=row['text'].strip())

depparse = pd.DataFrame()
depparse['dependency_parse'] = pio_extracted.apply(_dependency_parse_row, axis=1)
depparse.to_hdf('temp/ebm-pio_extracted_depparsed.hdf', key='mydata', mode='w')

In [None]:
# Reload the cached dependency parses.
depparse = pd.read_hdf('temp/ebm-pio_extracted_depparsed.hdf', key='mydata')
# depparse.iloc[0:200]

In [None]:
# # Grab only the hierplane_tree segment to save space
# # TODO rework previous code to also use hierplane_tree ... ? ????
# # ... this is not a priority, it's just a code cleanliness thing
# # do we even need this? / can we even use this?
# # TODO also this needs to be updated to use hdf reading instead of pkl, no pkl export anymore
# pio_extracted = pd.read_pickle(f'temp/ebm-pio_extracted_parsed.pkl')
# pio_extracted['dependency_parse'] = pio_extracted.apply(
#     lambda row: row['dependency_parse']['hierplane_tree'], axis=1)
# pio_extracted.to_pickle(f'temp/ebm-pio_extracted_parsed.pkl')

In [None]:
# tcon = constituency_predictor.predict(
#     sentence='hypertensive patients receiving drug treatment'
# )
# {i:tcon[i] for i in tcon if i!='class_probabilities'}

In [None]:
def inplace_constparse(row):
    """Constituency-parse the whitespace-stripped ``text`` of one row.

    Returns the predictor's output dict without its ``'class_probabilities'``
    entry. That entry is dropped presumably to keep the cached HDF output small.
    """
    parse = constituency_predictor.predict(sentence=row['text'].strip())
    return {key: value for key, value in parse.items() if key != 'class_probabilities'}

# Constituency-parse every extracted span once up front, so later groupings
# can reuse the cached parses instead of re-running the predictor.
conparse = pd.DataFrame()
conparse['constituency_parse'] = pio_extracted.apply(inplace_constparse, axis=1)
conparse.to_hdf('temp/ebm-pio_extracted_conparsed.hdf', key='mydata', mode='w')

In [None]:
# Reload the cached constituency parses.
conparse = pd.read_hdf('temp/ebm-pio_extracted_conparsed.hdf', key='mydata')
# conparse.iloc[0:200]

In [None]:
# Attach both cached parses to the extracted spans (aligned on the index)
# and cache the combined frame for the analysis cells below.
depparse = pd.read_hdf('temp/ebm-pio_extracted_depparsed.hdf', key='mydata')
conparse = pd.read_hdf('temp/ebm-pio_extracted_conparsed.hdf', key='mydata')
pio_extracted = pio_extracted.assign(
    dependency_parse=depparse['dependency_parse'],
    constituency_parse=conparse['constituency_parse'],
)
pio_extracted.to_hdf('temp/ebm-pio_extracted_parsed.hdf', key='mydata', mode='w')

In [None]:
# Start running here if the dependency and constituency parses were already generated!
pio_ex_depparse = pd.read_hdf('temp/ebm-pio_extracted_parsed.hdf', key='mydata')
# pio_ex_depparse.loc[pio_ex_depparse['class'] == 'p']

In [None]:
# p_input = pio_ex_depparse.loc[pio_ex_depparse['class'] == 'p']
# p_input

In [None]:
p_input = pio_ex_depparse[pio_ex_depparse['class'].eq('p')]


def _dep_tree_root(row):
    # Root node of the row's dependency hierplane tree.
    return row['dependency_parse']['hierplane_tree']['root']


# Root word of each P span's dependency tree, and the first attribute of that root.
dep_root = pd.DataFrame()
dep_root['word'] = p_input.apply(lambda row: _dep_tree_root(row)['word'], axis=1)
dep_root['attr'] = p_input.apply(lambda row: _dep_tree_root(row)['attributes'][0], axis=1)

# dep_root

In [None]:
# Same combined parse cache, read again for the constituency-based analyses.
pio_ex_conparse = pd.read_hdf('temp/ebm-pio_extracted_parsed.hdf', key='mydata')
# pio_ex_conparse.loc[pio_ex_conparse['class'] == 'p']

In [None]:
# Keep only the P (population) spans.
p_input = pio_ex_conparse[pio_ex_conparse['class'].eq('p')]
# p_input

In [None]:
p_input = pio_ex_conparse[pio_ex_conparse['class'].eq('p')]

# Phrase type (root nodeType of the constituency tree) of each whole P text chunk.
con_root = pd.DataFrame()
con_root['src'] = p_input['text']
con_root['attr'] = p_input.apply(
    lambda row: row['constituency_parse']['hierplane_tree']['root']['nodeType'],
    axis=1,
)

# con_root

In [None]:
p_input = pio_ex_conparse[pio_ex_conparse['class'].eq('p')]

# pick out the "main subject" using constituency parsing
def ebmnlp_const_main_subj(row):
    """Find the "main subject" phrases of a row's constituency tree.

    Starting at the root, the search descends level by level into NP children
    only. A node with no NP children is an endpoint. It contributes its direct
    leaf children (children without their own 'children') as one phrase.

    Args:
        row: mapping with ``row['constituency_parse']['hierplane_tree']['root']``.

    Returns:
        dict with parallel lists:
          - 'mainsubj': each endpoint's leaf words joined by spaces
            ('' if it has no leaf children),
          - 'mainsubj_type': the nodeTypes of those leaves,
          - 'mainsubj_levels': the depth of each endpoint (root = 0).
    """
    phrases, phrase_types, phrase_levels = [], [], []
    frontier = [row['constituency_parse']['hierplane_tree']['root']]
    depth = 0
    while frontier:
        next_frontier = []
        for current in frontier:
            kids = current['children']
            np_kids = [kid for kid in kids if kid['nodeType'] == 'NP']
            if np_kids:
                # keep descending through the NP children only
                next_frontier.extend(np_kids)
                continue
            # endpoint: no child NPs, so collect its direct leaves
            leaves = [kid for kid in kids if 'children' not in kid]
            phrases.append(' '.join(leaf['word'] for leaf in leaves))
            phrase_types.append([leaf['nodeType'] for leaf in leaves])
            phrase_levels.append(depth)
        frontier = next_frontier
        depth += 1
    return {'mainsubj': phrases, 'mainsubj_type': phrase_types, 'mainsubj_levels': phrase_levels}

# One row per P span: the source text plus the expanded main-subject dict
# ('mainsubj', 'mainsubj_type', 'mainsubj_levels' columns).
con_mainsubj = pd.DataFrame()
con_mainsubj['src'] = p_input['text']
mainsubj_columns = p_input.apply(ebmnlp_const_main_subj, axis=1, result_type='expand')
con_mainsubj = con_mainsubj.join(mainsubj_columns)

# con_mainsubj

In [None]:
p_input = pio_ex_conparse[pio_ex_conparse['class'].eq('p')]

# pick out the "segments" of the phrase using constituency parsing
def const_node_align_segments(node, punct_types=()):
    """Split a constituency (hierplane) subtree into contiguous text segments.

    Consecutive leaf children of a node are merged into one segment labelled
    with that node's nodeType. A child that has its own 'children' is recursed
    into, and its segments are spliced in at that position. This also closes
    any run of leaves collected before it.

    Args:
        node: hierplane tree node dict with 'nodeType' and 'children'. Leaf
            children carry 'word' and 'nodeType' and no 'children'.
        punct_types: leaf nodeTypes to split out as their own one-word segments,
            labelled with their own nodeType (e.g. ['.', ':', 'HYPH', '-LRB-',
            '-RRB-']). Empty by default, so punctuation stays inside the
            surrounding segment (it is mostly visual decoration anyway).

    Returns:
        (segments, segment_types, segment_child_types): three parallel lists.
        They hold the segment text, the nodeType labelling each segment, and
        the list of leaf nodeTypes making up each segment.
    """
    output = []
    output_types = []
    output_child_types = []
    blob_words = []
    blob_types = []

    def flush_blob():
        # Emit the current run of leaf words (if any) as one segment of `node`.
        if blob_words:
            output.append(' '.join(blob_words))
            output_types.append(node['nodeType'])
            output_child_types.append(list(blob_types))
            blob_words.clear()
            blob_types.clear()

    for child in node['children']:
        if 'children' in child:
            # non-leaf child: close the current run, then splice in its segments
            flush_blob()
            sub_segments, sub_types, sub_child_types = const_node_align_segments(child, punct_types)
            output += sub_segments
            output_types += sub_types
            output_child_types += sub_child_types
        elif child['nodeType'] in punct_types:
            # punctuation leaf: close the current run and emit it on its own
            flush_blob()
            output.append(child['word'])
            output_types.append(child['nodeType'])
            output_child_types.append([child['nodeType']])
        else:
            # ordinary leaf: extend the current run
            blob_words.append(child['word'])
            blob_types.append(child['nodeType'])
    # finish off the last run we started
    flush_blob()
    return output, output_types, output_child_types

# print(const_node_align_segments(tcon['hierplane_tree']['root']))

# Segment every P span's constituency tree, then unpack the
# (segments, types, child types) tuple into three separate columns.
con_segments = pd.DataFrame()
con_segments['src'] = p_input['text']
con_segments['aligntup'] = p_input.apply(
    lambda row: const_node_align_segments(
        row['constituency_parse']['hierplane_tree']['root']
    ),
    axis=1
)
for tuple_pos, column in enumerate(['alignsegments', 'aligntypes', 'alignctypes']):
    con_segments[column] = con_segments.apply(
        lambda row, pos=tuple_pos: row['aligntup'][pos],
        axis=1
    )

# con_segments.to_csv(f'temp/ebm-pio_consegments.csv')
con_segments.to_hdf('temp/ebm-pio_consegments.hdf', key='mydata', mode='w')
# con_segments

In [None]:
# con_segments = pd.read_hdf(f'temp/ebm-pio_consegments.hdf','mydata')
# con_segments

In [None]:
# Preview: each row's segments paired with the nodeType that labels them.
con_segments.apply(
    lambda row: [*zip(row['alignsegments'], row['aligntypes'])],
    axis=1
)

# Simple 'main subject' alignment

In [None]:
temp_df = con_segments.join(con_mainsubj['mainsubj'])


def _anchor_segment_index(segments, anchor):
    """Locate the main-subject anchor among a row's constituency segments.

    The anchor phrase can span several segments, so trailing words are dropped
    one at a time until the remaining word prefix equals a whole segment.

    Args:
        segments: list of segment strings for the row (its 'alignsegments').
        anchor: the main-subject phrase to look for.

    Returns:
        Index of the first segment equal to the longest matching word prefix of
        ``anchor``. Returns -1 if no prefix matches, including the empty one.
    """
    words = anchor.split(' ')
    # Bounded search over words[:n], words[:n-1], ..., words[:0] (''). The old
    # unbounded loop with a bare except never terminated (and could not be
    # interrupted) once the prefix was trimmed down to ''.
    for end in range(len(words), -1, -1):
        try:
            return segments.index(' '.join(words[:end]))
        except ValueError:
            continue
    return -1


# Get simple alignment spacing information: record where each row's anchor
# segment sits, and the widest left/right context over all rows.
# For NOW, only the FIRST mainsubj identified that's not an empty string is used as the anchor.
num_before = 0
num_after = 0
for index, row in temp_df.iterrows():
    anchors = [e for e in row['mainsubj'] if e != '']
    # find the index of the first 'main subj element' (-1 when there is none)
    if len(anchors) > 0:
        temp_df.at[index, 'align_mainsubjsimple'] = _anchor_segment_index(row['alignsegments'], anchors[0])
    else:
        temp_df.at[index, 'align_mainsubjsimple'] = -1
    # calculate how much space we need before and after the 'main subj element'
    num_before = max(num_before, temp_df.at[index, 'align_mainsubjsimple'])
    num_after = max(num_after, len(row['alignsegments']) - temp_df.at[index, 'align_mainsubjsimple'] - 1)
    
def alignRowMainSubjSimple(row, num_before, num_after):
    """Pad a row's segments with '' so its anchor segment lands in column ``num_before``.

    Args:
        row: mapping with 'alignsegments' (list of segment strings) and
            'align_mainsubjsimple' (index of the anchor segment, -1 if none).
        num_before: widest left context over all rows.
        num_after: widest right context over all rows.

    Returns:
        dict mapping output column position -> segment text. When the padding
        amounts are non-negative, it has ``num_before + num_after + 1`` entries.
    """
    anchor_pos = row['align_mainsubjsimple']
    segments = row['alignsegments']
    lead = int(num_before - anchor_pos)
    trail = int(num_after - len(segments) + anchor_pos + 1)
    padded = [''] * lead + list(segments) + [''] * trail
    return {col: text for col, text in enumerate(padded)}

# Pad every row so its anchor segment shares one output column, then save and show it.
temp_df = temp_df.apply(
    alignRowMainSubjSimple,
    axis=1,
    result_type='expand',
    args=(num_before, num_after),
)
temp_df.to_csv('temp/ebm-pio_alignRowMainSubjSimple.csv')
temp_df