In [1]:
from fastcoref import spacy_component
import spacy

nlp = spacy.load("en_core_web_sm", exclude=["parser", "lemmatizer", "ner", "textcat"])
nlp.add_pipe("fastcoref")

11/07/2022 19:13:20 - INFO - 	 missing_keys: []
11/07/2022 19:13:20 - INFO - 	 unexpected_keys: []
11/07/2022 19:13:20 - INFO - 	 mismatched_keys: []
11/07/2022 19:13:20 - INFO - 	 error_msgs: []
11/07/2022 19:13:20 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M


<fastcoref.spacy_component.spacy_component.FastCorefResolver at 0x7f39189af8e0>

In [2]:
import logging

logger = logging.getLogger('bgg_predict')
logger.handlers.clear()
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

logger.debug('test')

2022-11-07 19:13:43,597 bgg_predict  DEBUG    test
11/07/2022 19:13:43 - DEBUG - 	 test


In [3]:
import re
from typing import List
from dataclasses import dataclass

regex_mail = re.compile(r'\w+(?:\.\w+)*?@\w+(?:\.\w+)+')

@dataclass
class Sentence:
    content: str
    start: int
    end: int
    
    def does_include_pos(self, pos: int) -> bool:
        return self.start <= pos <= self.end

def clean_text(text: str) -> str:
    return regex_mail.sub('', text)

def get_sentences_from_text(text: str) -> List[Sentence]:
    # assert no continuous dots because of text cleared while building the dataset
    sentences = text.split('.')
    
    res = []
    char_accumulator = 0
    for sentence in sentences:
        res.append(Sentence(sentence, char_accumulator, char_accumulator + len(sentence) - 1))
        char_accumulator += len(sentence) + 1
        
    if res[-1].content == '':
        res.pop()
        
    return res

In [4]:
from bisect import bisect_left
from typing import List, Tuple
from dataclasses import dataclass

# necessary to use bisect_left with ranges
@dataclass
class Interval:
    start: int
    end: int
    
    def __lt__(self, other) -> bool:
       return self.start < self.end < other.start
    
    def __eq__(self, other) -> bool:
       return self.start <= other.start <= self.end

def get_sentences_from_clusters(clusters: List[List[Tuple[int, int]]], sentences: List[Sentence]) -> List[List[int]]:
    '''find the sentence each cluster belongs to'''
    sentence_clusters = []
    for cluster in clusters:
        sentence_clusters.append([bisect_left(sentences, Interval(entity[0], entity[1]), key=lambda x: Interval(x.start, x.end)) for entity in cluster])

    return sentence_clusters

text = 'Alice goes down the rabbit hole. Where she would discover a new reality beyond her expectations.'
sentences = get_sentences_from_text(text)
clusters = [[(0, 5), (39, 42), (79, 82)]]
get_sentences_from_clusters(clusters, sentences)

[[0, 1, 1]]

In [5]:
from typing import List, Set
from itertools import groupby
from operator import itemgetter
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

def get_rule_groups_from_sentence_clusters(sentences: List[Sentence], sentence_clusters: List[List[int]]) -> List[List[int]]:
    def normalize_group(group: Set[int]) -> List[List[int]]:
        '''each group could contain multiple consecutive sublists. this method split these sublists'''
        res = []

        # https://stackoverflow.com/a/23861347/5587393
        for k, g in groupby(enumerate(sorted(list(group))), lambda x: x[0] - x[1]):
            res.append(list(map(itemgetter(1), g)))

        return res
    # the graph is built as a directed sparse graph where the first element of each cluster
    # is connected to the other elements in the same cluster
    graph = [[0 for _ in range(len(sentences))] for __ in range(len(sentences))]
    for cluster in sentence_clusters:
        for sentence in cluster[1:]:
            graph[cluster[0]][sentence] = 1

    # find the connected components of the graph created from the clusters returned after coref     
    graph = csr_matrix(graph)
    n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)
    groups = [set() for _ in range(n_components)]
    for i, label in enumerate(labels):
        groups[label].add(i)

    return [norm_group for group in groups for norm_group in normalize_group(group)]

In [None]:
import itertools
from typing import List, Tuple
import pandas as pd

DATASET_FILE_PATH = 'data/dataset.csv'

def get_rules(text: str) -> List[str]:
    text = clean_text(text)
    sentences = get_sentences_from_text(text)
    
    doc = nlp(text, component_cfg = { "fastcoref": {'resolve_text': True} })
    coref_clusters = doc._.coref_clusters
    logger.debug(coref_clusters)
    
    sentence_clusters = get_sentences_from_clusters(coref_clusters, sentences)
    rule_groups = get_rule_groups_from_sentence_clusters(sentences, sentence_clusters)
    
    return ['. '.join([sentences[s_index].content for s_index in group]) for group in rule_groups]

def get_rules_features(text: str) -> Tuple[int, float]:
    rules = get_rules(text)
    rule_count = len(rules)
    return rule_count, len(text) / rule_count

def remove_columns_prefix(df: pd.core.frame.DataFrame) -> None:
    '''remove prefix 'info.' from the columns of df'''
    df.rename(columns=lambda c: c.rsplit('.', 1)[-1], inplace=True)
    
df_features = pd.DataFrame()
with pd.read_csv(DATASET_FILE_PATH, chunksize=1) as reader:
    for df in reader:
        remove_columns_prefix(df)
        df_rules_features = df.apply(lambda x: pd.Series(get_rules_features(x.rulebook), 
                                     index=['rule_count', 'avg_rule_len']), axis='columns')
        df_features = pd.concat([df_features, df[['numweights', 'averageweight', 'playingtime', 'family']].join(df_rules_features)])
        
display(df_features)

11/07/2022 19:14:03 - INFO - 	 Tokenize 1 texts...


  0%|          | 0/1 [00:00<?, ?ba/s]

11/07/2022 19:14:05 - INFO - 	 ***** Running Inference on 1 texts *****


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 19:14:07,924 bgg_predict  DEBUG    [[(200, 211), (287, 296), (929, 938), (1520, 1529)], [(439, 453), (463, 467), (530, 546)], [(632, 648), (649, 651)], [(714, 741), (754, 756)], [(689, 700), (777, 781), (840, 851)], [(785, 799), (819, 821)], [(917, 938), (986, 998)], [(1217, 1228), (1235, 1238), (1267, 1269)], [(392, 406), (1295, 1311)], [(1484, 1492), (1499, 1502), (1533, 1535)], [(468, 489), (1975, 1996)], [(2027, 2033), (2069, 2082), (2108, 2110)], [(2219, 2227), (2237, 2245), (2569, 2577), (2681, 2689), (2708, 2716)], [(2278, 2286), (2306, 2309), (2350, 2391)], [(2406, 2451), (2443, 2446)], [(2511, 2522), (2591, 2603), (2645, 2649)], [(2581, 2590), (2664, 2666)], [(2336, 2349), (2723, 2736)], [(2999, 3011), (3100, 3112)], [(2912, 2921), (3123, 3132)]]
11/07/2022 19:14:07 - DEBUG - 	 [[(200, 211), (287, 296), (929, 938), (1520, 1529)], [(439, 453), (463, 467), (530, 546)], [(632, 648), (649, 651)], [(714, 741), (754, 756)], [(689, 700), (777, 781), (840, 851)], [(785, 799

  0%|          | 0/1 [00:00<?, ?ba/s]

11/07/2022 19:14:11 - INFO - 	 ***** Running Inference on 1 texts *****


Inference:   0%|          | 0/1 [00:00<?, ?it/s]