In [13]:
import json 
import pandas as pd 

# Read the raw JSON-lines file; every element of `json_lines` is one
# still-unparsed line (its trailing newline included).
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open("causenet-precision.jsonl", "r", encoding="utf-8") as f:
    json_lines = f.readlines()

In [126]:
from tqdm.notebook import tqdm
# One {"cause", "effect", "text"} record per input line that has at least one
# source carrying a sentence.
extracted_examples = []

for raw_line in tqdm(json_lines, total=len(json_lines)):
    record = json.loads(raw_line)
    relation = record["causal_relation"]

    # Turn "_'" into "'" first, then every remaining underscore into a space.
    cause = relation["cause"]["concept"].replace("_'", "'").replace("_", " ")
    effect = relation["effect"]["concept"].replace("_'", "'").replace("_", " ")

    # Only the first source whose payload has a "sentence" entry is used.
    for source in record["sources"]:
        payload = source["payload"]
        if "sentence" not in payload:
            continue
        extracted_examples.append(
            {"cause": cause, "effect": effect, "text": payload["sentence"].strip()}
        )
        break

HBox(children=(FloatProgress(value=0.0, max=197806.0), HTML(value='')))




In [127]:
# One row per extracted example. The columns are named explicitly so the
# frame always exposes `cause`, `effect` and `text` (in that order), even if
# `extracted_examples` happens to be empty.
df = pd.DataFrame(extracted_examples, columns=["cause", "effect", "text"])

In [128]:
from sklearn.model_selection import train_test_split

# Fixed seed so the shuffled partition is the same on every run; without it
# each execution produces different train/val/test membership.
SPLIT_SEED = 42

# Hold out 15% of all rows for testing, then 10% of what remains for
# validation; the rest is the training set.
train_val, test = train_test_split(
    df, test_size=.15, shuffle=True, random_state=SPLIT_SEED
)
train, val = train_test_split(
    train_val, test_size=.1, shuffle=True, random_state=SPLIT_SEED
)

In [129]:
import spacy 

# Load the spaCy pipeline named "en_core_web_sm" into `nlp`.
# NOTE(review): `nlp` is not referenced by any other cell visible here —
# confirm it is still needed before keeping this load.
nlp = spacy.load("en_core_web_sm")

In [173]:
from difflib import SequenceMatcher

def matcher(string, pattern):
    '''
    Locate the first occurrence of ``pattern`` in ``string`` and mask it.

    Surrounding whitespace is stripped from ``pattern`` before searching.

    Returns a tuple ``(match_list, string)``:
      * ``match_list`` -- ``[(start, end)]`` (``end`` exclusive) for the first
        occurrence, or ``[]`` when the pattern does not occur;
      * ``string`` -- the input with that first occurrence overwritten by
        ``"X"`` characters of the same length (unchanged when nothing matched).
    '''
    needle = pattern.strip()

    # The earliest full-length match is simply the first substring hit.
    first = string.find(needle)
    if first < 0:
        return [], string

    stop = first + len(needle)
    masked = string[:first] + "X" * len(needle) + string[stop:]
    return [(first, stop)], masked

def mark_sentence(s, match_list, word_dict):
    '''
    Write BIO labels for every matched span into ``word_dict``.

    ``match_list`` holds ``(start, end, e_type)`` triples that index into
    ``s``. For a span of several whitespace-separated words, the first word is
    labelled ``B-<e_type>`` and the others ``I-<e_type>``; otherwise the span
    text itself is labelled ``B-<e_type>``. ``word_dict`` is updated in place
    and also returned.
    '''
    for begin, stop, label in match_list:
        span = s[begin:stop]
        words = span.split()

        # Zero or one word: the key is the raw span text, not the split word.
        if len(words) <= 1:
            word_dict[span] = 'B-' + label
            continue

        head, *rest = words
        word_dict[head] = 'B-' + label
        word_dict.update(dict.fromkeys(rest, 'I-' + label))

    return word_dict

def clean(text):
    '''
    Helper function that adds a space before each listed punctuation
    character for better tokenization.

    Every occurrence receives exactly one extra space, no matter how often the
    character appears in ``text``. (Calling ``str.replace`` once per occurrence
    re-pads all occurrences each time, so a character seen k times would end
    up with k spaces in front of it.)

    Returns the padded copy of ``text``; characters not in the list are left
    untouched.
    '''
    filters = ["!", "#", "$", "%", "&", "(", ")", "/", "*", ".",
               ":", ";", "<", "=", ">", "?", "@", "[",
               "\\", "]", "_", "`", "{", "}", "~", "'"]
    # Single pass over the text: map each listed character to " " + itself.
    padding = {ord(ch): " " + ch for ch in filters}
    return text.translate(padding)

def tagged_text(text, cause, effect):
    '''
    Build per-word BIO labels and POS tags for one example sentence.

    text   -- the sentence; it is lowercased and passed through ``clean``
              before tokenization.
    cause  -- phrase to label as CAUSE.
    effect -- phrase to label as EFFECT.

    Returns ``(tagged_seq, tag_dict)``. Both are dicts keyed by word string:
    ``tagged_seq`` maps to "O" / "B-..." / "I-..." labels and ``tag_dict``
    maps to the POS tag. Because the keys are word strings, a word occurring
    more than once in the sentence has a single entry.

    Raises IndexError when ``cause`` or ``effect`` cannot be found in the
    processed text.
    '''
    # Work on the arguments. Reading the module-level loop variable `r` here
    # made the result depend on whatever row the caller's loop last bound.
    text = clean(text.lower())
    word_dict = {}
    tag_dict = {}

    # NOTE(review): `word_tokenize` and `nltk` are not imported in the cells
    # visible here; they must already be in scope when this runs.
    toks = word_tokenize(text)
    pos = nltk.pos_tag(toks)

    for i, tok in enumerate(toks):
        word_dict[tok] = "O"
        tag_dict[tok] = pos[i][1]

    match_list = []
    for phrase, label in ((cause, "CAUSE"), (effect, "EFFECT")):
        # The text was lowercased above, so the phrase has to be lowercased
        # too; otherwise any phrase containing an uppercase letter can never
        # match.
        # NOTE(review): the phrase is not passed through clean(), so a phrase
        # containing one of the padded punctuation characters presumably
        # still fails to match — confirm whether that is intended.
        spans, _ = matcher(text, phrase.lower())
        if not spans:
            raise IndexError(f"{label} phrase {phrase!r} not found in text")
        start, end = spans[0]
        match_list.append((start, end, label))

    tagged_seq = mark_sentence(text, match_list, word_dict)
    return tagged_seq, tag_dict


In [178]:
from tqdm.notebook import tqdm
# Each split is written to its own file: one "word POS label" line per entry
# returned by tagged_text, with a blank line after every sentence.
pairs = [(train, "train.txt"), (val, "val.txt"), (test, "test.txt")]

# Running total of skipped rows; it is NOT reset between files, so the value
# printed after each file includes the rows skipped in the earlier files.
bad_ctr = 0
for data, out_path in pairs:
    print(f"working on {out_path}")
    with open(out_path, "w", encoding="utf-8") as f:
        for i, r in tqdm(data.iterrows(), total=len(data)):
            try:
                tagged_seq, pos = tagged_text(r.text, r.cause, r.effect)
                # Render the whole sentence before touching the file, so a
                # failure part-way through cannot leave a truncated sentence
                # (without its blank-line terminator) in the output.
                lines = [
                    tag + ' ' + pos[tag] + ' ' + label + "\n"
                    for tag, label in tagged_seq.items()
                ]
            except (IndexError, KeyError):
                # IndexError: a cause/effect phrase was not found in the text.
                # KeyError: a labelled word has no POS entry.
                # Anything else (e.g. a NameError from a missing import) is a
                # programming error and is allowed to surface instead of
                # silently discarding every row.
                bad_ctr += 1
                continue
            f.writelines(lines)
            f.write("\n")
    print("ignored: ", bad_ctr)

working on train.txt


HBox(children=(FloatProgress(value=0.0, max=138376.0), HTML(value='')))


ignored:  0
working on val.txt


HBox(children=(FloatProgress(value=0.0, max=15376.0), HTML(value='')))


ignored:  0
working on test.txt


HBox(children=(FloatProgress(value=0.0, max=27133.0), HTML(value='')))


ignored:  0
