In [1]:
import pickle
import string
from collections import Counter, defaultdict

import pandas as pd
from IPython.display import display, HTML
from nltk.parse import CoreNLPParser
from nltk.tree import ParentedTree
from tqdm.notebook import tqdm_notebook as tqdm

from notebook_utils.constants import PROJ_ROOT, GENRES
from notebook_utils.data_loader import load_all_books

In [2]:
def get_rules(tree, rules, seen: set):
    """Count the grammar productions of a parse tree into ``rules``.

    The tree is walked depth-first: every child subtree taller than a
    pre-terminal (``height() > 2``) is processed before the productions of
    the current node are examined.  Each production is recorded in two
    tables, once on its own and once prefixed with a grandparent label:

    * ``"G"``  -- lexical rules, formatted as ``"LHS -> word ..."``
    * ``"GG"`` -- the same lexical rules prefixed with ``"<grandparent> -> "``
    * ``"g"``  -- non-lexical rules, formatted with ``str(rule)``
    * ``"gG"`` -- the same non-lexical rules prefixed with ``"<grandparent> -> "``

    Args:
        tree: An nltk tree.  Anything that is not already a ``ParentedTree``
            is converted with ``ParentedTree.convert``.
        rules: Mapping with the keys ``"G"``, ``"GG"``, ``"g"`` and ``"gG"``.
            Each value must support ``mapping[key] += 1`` for unseen keys
            (e.g. ``defaultdict(int)``).  Mutated in place.
        seen: Productions that were already counted.  Mutated in place and
            shared across the recursion, so every distinct production is
            counted at most once per top-level call.

    Returns:
        None.  Results are accumulated in ``rules`` and ``seen``.

    Notes:
        * Productions whose left-hand side is a substring of
          ``string.punctuation`` are ignored and are not added to ``seen``.
        * The grandparent prefix is taken from the subtree currently being
          processed (``ptree``), not from the node that owns the individual
          production.  When there is no grandparent the prefix is
          ``"ROOT -> "``.
    """
    ptree = tree if isinstance(tree, ParentedTree) else ParentedTree.convert(tree)

    # Children first, so productions are attributed to the deepest
    # phrase-level subtree that contains them.
    for subtree in ptree:
        if isinstance(subtree, ParentedTree) and subtree.height() > 2:
            get_rules(subtree, rules, seen)

    # The prefix is the same for every production examined at this level.
    # parent() returns None at the top of the tree, which makes the chained
    # call raise AttributeError.
    try:
        context = ptree.parent().parent().label() + " -> "
    except AttributeError:
        context = "ROOT -> "

    for rule in ptree.productions():
        if rule in seen or str(rule.lhs()) in string.punctuation:
            continue

        if rule.is_lexical():
            rhs_text = " ".join(str(symbol) for symbol in rule.rhs())
            rule_str = f"{rule.lhs()} -> {rhs_text}"
            plain_key, context_key = "G", "GG"
        elif rule.is_nonlexical():
            rule_str = str(rule)
            plain_key, context_key = "g", "gG"
        else:
            continue

        rules[plain_key][rule_str] += 1
        seen.add(rule)
        # Lexical rules always go to "GG" and non-lexical rules to "gG",
        # including the "ROOT" fallback (the fallback for lexical rules was
        # previously written to "gG").
        rules[context_key][context + rule_str] += 1


# Load the corpus once; later cells index the result by genre.
all_books = load_all_books()

# Per-genre accumulator: one list of per-book rows for each rule table
# ("G", "GG", "g", "gG") filled in by the parsing cell.
all_rules = {
    genre: {tag_type: [] for tag_type in ("G", "GG", "g", "gG")}
    for genre in GENRES
}

In [3]:
# --- Parse every book with CoreNLP, count grammar rules, dump one pickle per genre/table ---

# Progress-bar budget per book; books are read through `first_1k_sentences`.
SENTENCES_PER_BOOK = 1000
# Rule tables filled by get_rules(), in the order they are written out.
TAG_TYPES = ("G", "GG", "g", "gG")
# Book numbers that are skipped without being parsed.
SKIPPED_BOOK_NUMBERS = {"19513", "19640", "19678", "19782", "19836", "22326", "1322"}
PARSE_PROPERTIES = {"annotators": "tokenize,ssplit,pos,parse"}

bar_length = sum(len(all_books[genre]) for genre in GENRES) * SENTENCES_PER_BOOK

parser = CoreNLPParser()

with tqdm(total=bar_length) as pbar:
    for genre in GENRES:
        # Start from empty lists so that re-running this cell does not append
        # a second copy of every book to the rows collected by an earlier run.
        all_rules[genre] = {tag_type: [] for tag_type in TAG_TYPES}
        genre_books = all_books[genre]

        for i, book in enumerate(genre_books):
            pbar.set_postfix_str(f" -- {genre} -- [{i + 1}/{len(genre_books)}] ")
            if book.book_number in SKIPPED_BOOK_NUMBERS:
                pbar.update(SENTENCES_PER_BOOK)
                continue

            ticks = 0  # progress already reported for this book
            try:
                counts = {tag_type: Counter() for tag_type in TAG_TYPES}

                for sentence in book.first_1k_sentences:
                    parses = list(parser.raw_parse(sentence, properties=PARSE_PROPERTIES))
                    # Fresh tables and a fresh `seen` set per sentence: a rule is
                    # counted at most once per sentence, then summed per book.
                    sent_rules = {tag_type: defaultdict(int) for tag_type in TAG_TYPES}
                    get_rules(parses[0], sent_rules, set())

                    for tag_type in TAG_TYPES:
                        counts[tag_type].update(sent_rules[tag_type])

                    pbar.update(1)
                    ticks += 1

                for tag_type in TAG_TYPES:
                    row = {"Book #": book.book_number, "@Genre": genre}
                    # NOTE(review): the keys are full rule strings built by
                    # get_rules and appear to always contain " -> ", so this
                    # filter probably never excludes anything. Confirm whether the
                    # intent was to drop rules whose left-hand side is '' or ``.
                    row.update({k: v for k, v in counts[tag_type].items()
                                if k != "''" and k != "``" and k not in string.punctuation})
                    row["@Outcome"] = book.success
                    all_rules[genre][tag_type].append(row)

            except (AssertionError, RuntimeError) as e:
                # The whole book is dropped; report why instead of discarding the error.
                print(f"{genre}, {book.success}, {book.book_number} -- skipped: {e!r}")

            # Pad the bar to exactly one book's budget, whether the book was short
            # or failed part-way through, so the bar never overshoots its total.
            pbar.update(max(0, SENTENCES_PER_BOOK - ticks))

        display(HTML(f"<b>Dumping {genre} CFG Rules</b>"))
        for tag_type in TAG_TYPES:
            out_path = PROJ_ROOT.joinpath("data", f"{genre}_{tag_type}_data")
            try:
                # Build the frame before opening the file, so a MemoryError here
                # does not truncate a dump left by a previous run.
                frame = pd.DataFrame(all_rules[genre][tag_type]).fillna(0)
                with open(str(out_path), "wb") as f:
                    pickle.dump(frame, f)
            except MemoryError:
                print(f"There was a MemoryError when dumping {genre}_{tag_type}_data")


HBox(children=(FloatProgress(value=0.0, max=799000.0), HTML(value='')))


