In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy

from collections import defaultdict
import openai
from typing import List, Tuple
import os
import os.path
from os.path import join
import string
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from typing import List
from IPython.display import display, HTML
import clin.viz
import joblib
# Location of the file holding the OpenAI key. It can be overridden with the
# OPENAI_API_KEY_PATH environment variable so the notebook is not tied to one
# machine's home directory; otherwise it falls back to ~/.OPENAI_KEY.
openai.api_key_path = os.environ.get('OPENAI_API_KEY_PATH', os.path.expanduser('~/.OPENAI_KEY'))

# download and extract data
# !wget https://github.com/bepnye/EBM-NLP/raw/master/ebm_nlp_2_00.tar.gz
# !tar -xvf ebm_nlp_2_00.tar.gz

# Locations inside the extracted EBM-NLP archive.
DATA_DIR = 'ebm_nlp_2_00'
DOC_DIR = join(DATA_DIR, 'documents')
ANNOT_DIR = join(DATA_DIR, 'annotations', 'aggregated', 'starting_spans', 'interventions', 'test', 'gold')

# A document id is everything before the first dot of a file name.
annot_fnames = os.listdir(ANNOT_DIR)
doc_ids_gold = sorted([fname.split('.')[0] for fname in annot_fnames])
doc_names = sorted(os.listdir(DOC_DIR))

# Keep only the documents that also have a gold annotation. A set intersection
# avoids scanning the gold list once per document; the result is the same
# sorted list of unique ids.
doc_ids = sorted({doc_name.split('.')[0] for doc_name in doc_names}.intersection(doc_ids_gold))

In [None]:
def _read_text(path: str) -> str:
    """Return the entire contents of the text file at ``path``, closing it afterwards."""
    with open(path, 'r') as f:
        return f.read()


def _find_contiguous_sequences(annot_arr: np.ndarray) -> List[Tuple[int, int]]:
    """Return half-open ``(start, stop)`` index pairs, one per run of 1s in ``annot_arr``.

    e.g. annot_arr = [0, 1, 1, 0, 1, 1, 1, 0, 0, 1] -> [(1, 3), (4, 7), (9, 10)]
    """
    # Pad with a zero on each side so runs touching either end still produce
    # a rising and a falling edge.
    padded = np.concatenate([[0], annot_arr, [0]])
    diffs = np.diff(padded)
    starts = np.where(diffs == 1)[0]
    stops = np.where(diffs == -1)[0]
    return [(int(start), int(stop)) for start, stop in zip(starts, stops)]


def get_doc_and_interventions(doc_id, doc_dir=None, annot_dir=None):
    """Load one document and the de-duplicated intervention spans annotated in it.

    Params
    ------
    doc_id
        File-name stem; ``<doc_id>.txt`` and ``<doc_id>.tokens`` are read from
        ``doc_dir`` and ``<doc_id>.AGGREGATED.ann`` from ``annot_dir``.
    doc_dir
        Directory holding the document files (defaults to the module-level ``DOC_DIR``).
    annot_dir
        Directory holding the annotation files (defaults to the module-level ``ANNOT_DIR``).

    Returns
    -------
    doc: str
        Raw contents of the ``.txt`` file.
    unique_interventions: List[str]
        Annotated spans (tokens joined by a space) in order of first appearance.
        A span is dropped when it matches an earlier one case-insensitively,
        optionally up to a single trailing 's'.
    """
    doc_dir = DOC_DIR if doc_dir is None else doc_dir
    annot_dir = ANNOT_DIR if annot_dir is None else annot_dir

    doc = _read_text(join(doc_dir, doc_id + '.txt'))
    toks = _read_text(join(doc_dir, doc_id + '.tokens'))
    annot = _read_text(join(annot_dir, doc_id + '.AGGREGATED.ann'))

    toks_list = toks.split()
    # Any positive label marks the token as part of an intervention.
    annot_arr = (np.array([int(label) for label in annot.split()], dtype=int) > 0).astype(int)

    unique_interventions = []
    seen_lower = set()  # lower-cased forms of everything kept so far
    for start, stop in _find_contiguous_sequences(annot_arr):
        intervention = toks_list[start:stop]
        # Trim one punctuation token from each end. The emptiness guards keep a
        # span that is only punctuation from raising IndexError.
        if intervention and intervention[-1] in string.punctuation:
            intervention = intervention[:-1]
        if intervention and intervention[0] in string.punctuation:
            intervention = intervention[1:]
        if not intervention:
            # Nothing left after trimming, so there is no intervention text to keep.
            continue

        iv = ' '.join(intervention)
        iv_lower = iv.lower()
        iv_lower_without_s = iv_lower[:-1] if iv_lower.endswith('s') else iv_lower
        # Treat "x" and "xs" as the same intervention when de-duplicating.
        if seen_lower.isdisjoint((iv_lower, iv_lower + 's', iv_lower_without_s)):
            unique_interventions.append(iv)
            seen_lower.add(iv_lower)

    return doc, unique_interventions

# Extract (doc, interventions) for every gold-annotated document; tqdm shows progress.
docs_and_interventions = list(map(get_doc_and_interventions, tqdm(doc_ids)))

# Split the pairs into two parallel lists aligned with doc_ids.
docs = [pair[0] for pair in docs_and_interventions]
interventions = [pair[1] for pair in docs_and_interventions]

# One row per document, saved as the raw (uncleaned) gold table.
df = pd.DataFrame({
    'doc_id': doc_ids,
    'interventions': interventions,
    'doc': docs,
})
df.to_csv('ebm_interventions_gold_raw.csv', index=False)

# Clean up annotations

In [15]:
# Export tokens and token-level labels for the leading documents of `df`, and
# render the last few of them as colorized HTML.
N_SPAN_DOCS = 125    # number of leading rows of df to export
VIZ_START_IDX = 110  # only documents from this row index onward are rendered

ebm_interventions_spans = defaultdict(list)
# min(...) keeps the loop from indexing past the end of a shorter df.
for i in range(min(N_SPAN_DOCS, len(df))):
    doc_id = df.iloc[i]['doc_id']

    # Context managers make sure each file handle is closed right after reading.
    with open(join(DOC_DIR, doc_id + '.txt'), 'r') as f:
        doc = f.read()
    with open(join(DOC_DIR, doc_id + '.tokens'), 'r') as f:
        toks = f.read()
    with open(join(ANNOT_DIR, doc_id + '.AGGREGATED.ann'), 'r') as f:
        annot = f.read()

    toks_list = toks.split()
    # Integer labels are halved before use; presumably this scales them for
    # colorize -- TODO confirm against clin.viz.colorize.
    annot_list = np.array([int(label) for label in annot.split()]).astype(float)/2

    if i >= VIZ_START_IDX:
        color_str = clin.viz.colorize(toks_list, annot_list, char_width_max=60, title=str(i) + " " + doc_id)
        display(HTML(color_str))

    ebm_interventions_spans['doc_id'].append(doc_id)
    ebm_interventions_spans['doc'].append(doc)
    ebm_interventions_spans['toks_list'].append(toks_list)
    ebm_interventions_spans['annot_list'].append(annot_list)
pd.DataFrame(ebm_interventions_spans).to_pickle('ebm_interventions_spans.pkl')

In [None]:
# Template for manual annotation: maps each doc id in rows 110-124 to its raw
# extracted interventions. Paste the displayed dict to start filling in the annotations.
d = dict((doc_ids[k], interventions[k]) for k in range(110, 125))
d

## Read cleaned annotations and save as pkl

In [None]:
from ebm_interventions_labels_cleaned import ANNOTS
# Pair every manually cleaned annotation with its source document. The sorted
# annotation ids are expected to line up with the leading rows of `df`.
annot_doc_ids = sorted(ANNOTS.keys())
n_clean = len(ANNOTS)
df_cleaned_rows = defaultdict(list)
for i, doc_id in enumerate(annot_doc_ids):
    row = df.iloc[i]
    # Fail loudly if the annotation order and the dataframe order disagree.
    assert row['doc_id'] == doc_id, f'{row["doc_id"]} != {doc_id}'
    df_cleaned_rows['doc_id'].append(doc_id)
    df_cleaned_rows['doc'].append(row['doc'].strip())
    df_cleaned_rows['interventions'].append(ANNOTS[doc_id])

In [None]:
# Assemble the cleaned rows into a dataframe and persist it with joblib.
df_cleaned = pd.DataFrame(df_cleaned_rows)
joblib.dump(df_cleaned, 'ebm_interventions_cleaned.pkl')