In [1]:
# Switch between the Google Colab setup (True) and a local setup (False);
# it selects where the dataset CSV is read from.
WORKING_ON_COLAB = True

if WORKING_ON_COLAB:
    from google.colab import drive
    # Mount Google Drive so the dataset stored under 'My Drive' is reachable.
    drive.mount('/content/drive')
    DATASET_FILE_PATH = '/content/drive/My Drive/Projects/IRBoardGameComplexity/dataset.csv'
    # Install fastcoref, restricted to the 2.0.x releases.
    !pip install fastcoref==2.0.*
else:
    # Local run: the dataset path is relative to the current working directory.
    DATASET_FILE_PATH = 'data/dataset.csv'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:

from fastcoref import spacy_component
import spacy

# Load the small English spaCy pipeline, leaving out the components listed in
# `exclude`, then append the "fastcoref" component, which exposes
# `doc._.coref_clusters` (read later in get_rules).
nlp = spacy.load("en_core_web_sm", exclude=["parser", "lemmatizer", "ner", "textcat"])
nlp.add_pipe("fastcoref")

Downloading:   0%|          | 0.00/819 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/393 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/362M [00:00<?, ?B/s]

<fastcoref.spacy_component.spacy_component.FastCorefResolver at 0x7ff155423f10>

In [3]:
import logging

# Dedicated logger for this notebook.
logger = logging.getLogger('bgg_predict')
# Re-running the cell must not stack a new handler on top of the old ones.
logger.handlers.clear()
# Keep records from bubbling up to the root logger: when the root logger also
# has a handler attached, every message would otherwise be printed twice
# (once with the format below, once as "LEVEL:name:message").
logger.propagate = False
handler = logging.StreamHandler()
formatter = logging.Formatter(
        '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

# Smoke test: should produce exactly one formatted line.
logger.debug('test')

2022-11-07 22:02:44,545 bgg_predict  DEBUG    test
DEBUG:bgg_predict:test


In [4]:
import re
from typing import List
from dataclasses import dataclass

# Matches e-mail-like tokens: word characters (optionally dot-separated),
# an '@', then a domain made of word characters containing at least one dot.
regex_mail = re.compile(r'\w+(?:\.\w+)*?@\w+(?:\.\w+)+')

@dataclass
class Sentence:
    '''A piece of text together with a pair of character offsets.

    `start` and `end` are both treated as inclusive bounds by
    `does_include_pos`.
    '''
    content: str
    start: int
    end: int

    def does_include_pos(self, pos: int) -> bool:
        '''Return True when `pos` lies within [start, end], bounds included.'''
        if not self.start <= pos:
            return False
        return pos <= self.end

def clean_text(text: str) -> str:
    '''Return `text` with every match of `regex_mail` removed.'''
    return re.sub(regex_mail, '', text)

def get_sentences_from_text(text: str) -> List[Sentence]:
    '''Split `text` on '.' and record the character span of every piece.

    Each returned Sentence holds the piece without its terminating dot, the
    offset of its first character and the offset of its last character
    (inclusive). A trailing empty piece (text ending with '.') is dropped.
    '''
    # consecutive dots are not expected: the text was cleared while building the dataset
    collected = []
    offset = 0
    for piece in text.split('.'):
        last_char = offset + len(piece) - 1
        collected.append(Sentence(piece, offset, last_char))
        # skip past the last character and the dot that followed it
        offset = last_char + 2

    if collected[-1].content == '':
        del collected[-1]

    return collected

In [8]:
import sys

# `bisect.bisect_left` only accepts the `key=` argument from Python 3.10 on.
# Decide on the interpreter version (not on where the notebook runs) so that
# any environment gets a `bisect_left` that supports `key=`.
if sys.version_info >= (3, 10):
    from bisect import bisect_left
else:
    # backport from https://github.com/python/cpython/blob/main/Lib/bisect.py#L68
    def bisect_left(a, x, lo=0, hi=None, *, key=None):
        '''Return the leftmost index where `x` can be inserted in sorted `a`.

        All items before the returned index compare less than `x`. `lo` and
        `hi` bound the slice that is searched. When `key` is given it is
        applied to the items of `a` (not to `x`) before comparing.

        Raises ValueError if `lo` is negative.
        '''
        if lo < 0:
            raise ValueError('lo must be non-negative')
        if hi is None:
            hi = len(a)
        # Note, the comparison uses "<" to match the
        # __lt__() logic in list.sort() and in heapq.
        if key is None:
            while lo < hi:
                mid = (lo + hi) // 2
                if a[mid] < x:
                    lo = mid + 1
                else:
                    hi = mid
        else:
            while lo < hi:
                mid = (lo + hi) // 2
                if key(a[mid]) < x:
                    lo = mid + 1
                else:
                    hi = mid
        return lo
from typing import List, Tuple
from dataclasses import dataclass

# necessary to use bisect_left with ranges
@dataclass
class Interval:
    '''Closed integer range [start, end] with the comparisons bisect_left needs.

    `a < b` means "a lies entirely before b"; `a == b` means "b starts inside a".
    These are deliberately not a total order: they only exist so that a span
    can be located inside a sorted list of non-overlapping ranges.
    '''
    start: int
    end: int

    def __lt__(self, other) -> bool:
        # Only the end of this range matters. Requiring start < end as well
        # would make one-character (start == end) and empty ranges never
        # compare less than anything, which breaks the binary search.
        return self.end < other.start

    def __eq__(self, other) -> bool:
        # True when the other range starts within this one.
        return self.start <= other.start <= self.end

def get_sentences_from_clusters(clusters: List[List[Tuple[int, int]]], sentences: List[Sentence]) -> List[List[int]]:
    '''Map every mention of every cluster to a sentence index.

    For each (start, end) mention, `bisect_left` is run over `sentences`
    (each viewed as an Interval) and the index it returns is stored. The
    result has one list per cluster, in the same order as the input.
    '''
    def as_interval(sentence):
        return Interval(sentence.start, sentence.end)

    sentence_clusters = []
    for cluster in clusters:
        located = []
        for mention in cluster:
            target = Interval(mention[0], mention[1])
            located.append(bisect_left(sentences, target, key=as_interval))
        sentence_clusters.append(located)

    return sentence_clusters

# Quick check on a two-sentence example with one hand-written cluster whose
# spans cover 'Alice' (0, 5), 'she' (39, 42) and 'her' (79, 82): the first
# mention is in sentence 0 and the other two in sentence 1 -> [[0, 1, 1]].
text = 'Alice goes down the rabbit hole. Where she would discover a new reality beyond her expectations.'
sentences = get_sentences_from_text(text)
clusters = [[(0, 5), (39, 42), (79, 82)]]
get_sentences_from_clusters(clusters, sentences)

[[0, 1, 1]]

In [9]:
from typing import List, Set
from itertools import groupby
from operator import itemgetter
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

def get_rule_groups_from_sentence_clusters(sentences: List["Sentence"], sentence_clusters: List[List[int]]) -> List[List[int]]:
    '''Group sentence indices that are linked through the given clusters.

    Sentences are the nodes of a graph; inside every cluster the first
    sentence is linked to each of the others. Every connected component of
    that graph is then split into runs of consecutive indices, and each run
    is returned as one group.

    Args:
        sentences: the sentences of the text; only their number is used.
        sentence_clusters: for every cluster, the indices of the sentences
            its mentions belong to.

    Returns:
        A list of groups, each a sorted list of consecutive sentence indices.
        Every index in range(len(sentences)) appears in exactly one group.
        An empty `sentences` yields an empty list.

    Raises:
        IndexError: if a linked sentence index is outside
            range(len(sentences)).
    '''
    def normalize_group(group: Set[int]) -> List[List[int]]:
        '''Split a set of indices into its runs of consecutive integers.'''
        # https://stackoverflow.com/a/23861347/5587393
        # inside a run of consecutive values, (position - value) stays constant
        return [list(map(itemgetter(1), run))
                for _, run in groupby(enumerate(sorted(group)), lambda x: x[0] - x[1])]

    n_sentences = len(sentences)
    if n_sentences == 0:
        # nothing to group, and a 0-node graph cannot be handed to scipy
        return []

    # Collect the edges (first sentence of a cluster -> every other sentence
    # of the same cluster) as coordinate lists, so the sparse matrix can be
    # built directly instead of going through a dense n x n list.
    rows = []
    cols = []
    for cluster in sentence_clusters:
        for sentence in cluster[1:]:
            for index in (cluster[0], sentence):
                if not 0 <= index < n_sentences:
                    raise IndexError(
                        f'sentence index {index} out of range for {n_sentences} sentences')
            rows.append(cluster[0])
            cols.append(sentence)

    # Repeated edges are summed by the constructor; only non-zero-ness matters
    # for connectivity.
    graph = csr_matrix(([1] * len(rows), (rows, cols)), shape=(n_sentences, n_sentences))

    # find the connected components of the graph created from the clusters returned after coref
    n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)
    groups = [set() for _ in range(n_components)]
    for i, label in enumerate(labels):
        groups[label].add(i)

    return [norm_group for group in groups for norm_group in normalize_group(group)]

In [10]:
import itertools
from typing import List, Tuple
import pandas as pd

def get_rules(text: str) -> List[str]:
    '''Split `text` into rules: runs of sentences tied together by coreference.

    The text is cleaned, split into sentences and run through the `nlp`
    pipeline; the resulting coreference clusters are mapped to sentence
    indices and grouped. Each group's sentences are joined with '. '.
    '''
    cleaned = clean_text(text)
    sentence_list = get_sentences_from_text(cleaned)

    doc = nlp(cleaned, component_cfg={"fastcoref": {'resolve_text': True}})
    mention_clusters = doc._.coref_clusters
    logger.debug(mention_clusters)

    clusters_as_indices = get_sentences_from_clusters(mention_clusters, sentence_list)
    groups = get_rule_groups_from_sentence_clusters(sentence_list, clusters_as_indices)

    rules = []
    for group in groups:
        members = [sentence_list[position].content for position in group]
        rules.append('. '.join(members))
    return rules

def get_rules_features(text: str) -> Tuple[int, float]:
    '''Return (number of rules in `text`, len(text) divided by that number).'''
    n_rules = len(get_rules(text))
    chars_per_rule = len(text) / n_rules
    return n_rules, chars_per_rule

def remove_columns_prefix(df: pd.core.frame.DataFrame) -> None:
    '''Rename the columns of `df` in place, keeping only what follows the last '.'.

    A column such as 'info.name' becomes 'name'; columns without a dot are
    left untouched.
    '''
    def last_dotted_part(column):
        return column.rpartition('.')[2]

    df.rename(columns=last_dotted_part, inplace=True)
    
# Read the dataset a few rows at a time, compute the rule features of every
# rulebook and keep them next to the selected dataset columns.
# The per-chunk frames are collected in a list and concatenated once at the
# end: growing a DataFrame with pd.concat inside the loop re-copies all the
# rows accumulated so far on every iteration.
feature_chunks = []
with pd.read_csv(DATASET_FILE_PATH, chunksize=5) as reader:
    for df in reader:
        remove_columns_prefix(df)
        df_rules_features = df.apply(lambda x: pd.Series(get_rules_features(x.rulebook), 
                                     index=['rule_count', 'avg_rule_len']), axis='columns')
        feature_chunks.append(df[['numweights', 'averageweight', 'playingtime', 'family']].join(df_rules_features))

# pd.concat refuses an empty list, so fall back to an empty frame when the
# dataset yielded no chunks.
df_features = pd.concat(feature_chunks) if feature_chunks else pd.DataFrame()
        
display(df_features)

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:07:50,180 bgg_predict  DEBUG    [[(200, 211), (287, 296), (929, 938), (1520, 1529)], [(439, 453), (463, 467), (530, 546)], [(632, 648), (649, 651)], [(714, 741), (754, 756)], [(689, 700), (777, 781), (840, 851)], [(785, 799), (819, 821)], [(917, 938), (986, 998)], [(1217, 1228), (1235, 1238), (1267, 1269)], [(392, 406), (1295, 1311)], [(1484, 1492), (1499, 1502), (1533, 1535)], [(468, 489), (1975, 1996)], [(2027, 2033), (2069, 2082), (2108, 2110)], [(2219, 2227), (2237, 2245), (2569, 2577), (2681, 2689), (2708, 2716)], [(2278, 2286), (2306, 2309), (2350, 2391)], [(2406, 2451), (2443, 2446)], [(2511, 2522), (2591, 2603), (2645, 2649)], [(2581, 2590), (2664, 2666)], [(2336, 2349), (2723, 2736)], [(2999, 3011), (3100, 3112)], [(2912, 2921), (3123, 3132)]]
DEBUG:bgg_predict:[[(200, 211), (287, 296), (929, 938), (1520, 1529)], [(439, 453), (463, 467), (530, 546)], [(632, 648), (649, 651)], [(714, 741), (754, 756)], [(689, 700), (777, 781), (840, 851)], [(785, 799), (819, 821)]

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:02,071 bgg_predict  DEBUG    [[(521, 547), (566, 570), (72673, 72699), (72718, 72722)], [(20, 37), (708, 742), (749, 766), (1497, 1514), (3046, 3064), (5212, 5230), (6635, 6643), (6990, 7007), (8983, 9000), (9229, 9237), (9354, 9366), (9425, 9434), (10869, 10878), (12948, 12969), (16841, 16850), (25328, 25345), (25369, 25403), (27413, 27426), (28322, 28335), (28407, 28435), (30642, 30651), (31968, 31981), (33147, 33160), (36784, 36797), (37963, 37976), (42185, 42198), (43094, 43107), (45414, 45423), (46456, 46471), (46935, 46944), (49054, 49071), (49095, 49129), (52954, 52967), (53092, 53101), (53197, 53206), (53315, 53329), (53365, 53374), (53420, 53433), (53483, 53491), (53835, 53848), (53887, 53900), (53910, 53913), (53991, 54004), (57599, 57612), (57973, 57982), (61610, 61624), (62660, 62677), (63102, 63111), (66193, 66206), (66625, 66642), (66666, 66683), (67414, 67431), (67722, 67736), (68963, 68981), (70323, 70336), (71129, 71147), (72172, 72189), (72860, 72877)

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:16,920 bgg_predict  DEBUG    [[(0, 10), (223, 237), (371, 385), (413, 427), (526, 540), (1050, 1060), (1098, 1107), (3873, 3882), (4043, 4057), (4487, 4501), (4555, 4569), (5364, 5378), (6750, 6764), (6816, 6820), (6859, 6864), (7408, 7422), (10059, 10073), (10791, 10805), (12411, 12426), (12699, 12709), (16121, 16135), (16433, 16442), (16833, 16842), (17095, 17104), (17358, 17367), (17617, 17631), (17657, 17667), (17723, 17732), (17905, 17914), (18422, 18431), (18528, 18542), (19082, 19096), (19645, 19654)], [(189, 215), (283, 285)], [(133, 163), (313, 325)], [(491, 507), (550, 574), (585, 589)], [(154, 163), (692, 701), (766, 775), (844, 853), (1772, 1781), (4159, 4168), (4384, 4393), (12198, 12207), (17677, 17686), (19535, 19544)], [(678, 701), (752, 775)], [(734, 748), (784, 786), (1895, 1909)], [(648, 674), (796, 815)], [(947, 960), (966, 971)], [(854, 866), (1001, 1013)], [(1247, 1255), (1349, 1359)], [(1389, 1401), (1433, 1438)], [(1453, 1476), (1537, 1547)], [(

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:27,338 bgg_predict  DEBUG    [[(88, 90), (182, 185), (492, 495), (603, 620), (12153, 12155), (12247, 12250), (12557, 12560), (12668, 12685)], [(1067, 1087), (1109, 1124), (13174, 13189), (13220, 13235)], [(1101, 1103), (1243, 1245)], [(156, 163), (1354, 1361), (12221, 12228), (13419, 13426)], [(1405, 1415), (1675, 1685)], [(1761, 1777), (1817, 1827)], [(1992, 2005), (2045, 2047)], [(2360, 2366), (2408, 2416)], [(2749, 2757), (2831, 2842), (4993, 5004)], [(2849, 2858), (2932, 2941)], [(3127, 3135), (3180, 3194)], [(2563, 2577), (3322, 3336)], [(3437, 3456), (3536, 3545), (3669, 3692), (8360, 8379), (8459, 8468), (8592, 8615)], [(2321, 2330), (3618, 3627), (7070, 7079), (7244, 7253), (8541, 8550), (11993, 12002)], [(3889, 3897), (3912, 3920), (4167, 4175), (4186, 4194), (6384, 6392), (6800, 6808), (8835, 8843), (9090, 9098), (9109, 9117), (11307, 11315)], [(3937, 3945), (4007, 4018), (4115, 4126)], [(3972, 3985), (4041, 4064)], [(4237, 4243), (4250, 4254)], [(4307, 4317)

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:31,974 bgg_predict  DEBUG    [[(103, 114), (142, 145)], [(255, 284), (294, 298)], [(544, 563), (583, 606)], [(657, 667), (669, 671)], [(765, 779), (906, 918)], [(1309, 1337), (1355, 1357)], [(1388, 1398), (1416, 1420)], [(1675, 1710), (1713, 1718)], [(1541, 1565), (1740, 1769)], [(2544, 2575), (2576, 2587)], [(3165, 3180), (3275, 3292)], [(1855, 1864), (3390, 3399), (3452, 3461)], [(3361, 3369), (3401, 3403)], [(3553, 3568), (3659, 3663)]]
DEBUG:bgg_predict:[[(103, 114), (142, 145)], [(255, 284), (294, 298)], [(544, 563), (583, 606)], [(657, 667), (669, 671)], [(765, 779), (906, 918)], [(1309, 1337), (1355, 1357)], [(1388, 1398), (1416, 1420)], [(1675, 1710), (1713, 1718)], [(1541, 1565), (1740, 1769)], [(2544, 2575), (2576, 2587)], [(3165, 3180), (3275, 3292)], [(1855, 1864), (3390, 3399), (3452, 3461)], [(3361, 3369), (3401, 3403)], [(3553, 3568), (3659, 3663)]]


  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:39,040 bgg_predict  DEBUG    [[(378, 395), (537, 546)], [(708, 711), (718, 729)], [(824, 830), (877, 879)], [(782, 792), (901, 913), (1002, 1014)], [(1188, 1205), (1314, 1318), (1324, 1326)], [(1468, 1520), (1545, 1550)], [(1449, 1520), (1565, 1568)], [(1160, 1169), (1626, 1635), (1850, 1859), (2507, 2516), (4428, 4437), (4455, 4464), (4626, 4635), (4778, 4787), (4853, 4862), (4937, 4946), (5045, 5054), (5097, 5106), (5155, 5164), (5387, 5396), (5534, 5543), (5700, 5709)], [(1595, 1613), (1663, 1671), (1732, 1740), (1751, 1753)], [(2128, 2136), (2154, 2156)], [(2281, 2288), (2321, 2325)], [(2174, 2187), (2490, 2498)], [(2676, 2682), (2787, 2797)], [(2907, 2916), (2957, 2961)], [(3447, 3458), (3474, 3476)], [(3625, 3637), (3683, 3685)], [(52, 66), (3811, 3825)], [(68, 72), (3827, 3831)], [(3969, 3982), (4221, 4234)], [(3969, 3985), (4221, 4237)], [(4001, 4011), (4253, 4263)], [(4027, 4041), (4286, 4300), (4337, 4351), (5027, 5041), (5078, 5092), (5669, 5683)], [(4304, 4

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:09:42,756 bgg_predict  DEBUG    [[(153, 164), (201, 204)], [(363, 374), (467, 471)], [(391, 404), (672, 685)], [(650, 685), (693, 698)], [(783, 825), (826, 835)], [(335, 344), (868, 877)], [(878, 895), (901, 906)], [(928, 940), (953, 964), (1027, 1030)], [(408, 414), (1147, 1153)], [(1125, 1153), (1161, 1166)], [(1266, 1288), (1311, 1315)], [(1354, 1413), (1423, 1428)], [(1779, 1813), (1820, 1823), (1859, 1862)], [(2055, 2066), (2105, 2109)], [(2045, 2054), (2268, 2277)], [(2721, 2746), (2779, 2783)], [(2979, 3007), (3013, 3018)], [(3031, 3056), (3099, 3102)], [(3070, 3102), (3109, 3111), (3144, 3146), (3157, 3160)], [(3383, 3387), (3420, 3424)], [(3431, 3440), (3519, 3535)]]
DEBUG:bgg_predict:[[(153, 164), (201, 204)], [(363, 374), (467, 471)], [(391, 404), (672, 685)], [(650, 685), (693, 698)], [(783, 825), (826, 835)], [(335, 344), (868, 877)], [(878, 895), (901, 906)], [(928, 940), (953, 964), (1027, 1030)], [(408, 414), (1147, 1153)], [(1125, 1153), (1161, 1166)], [(

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:10:27,000 bgg_predict  DEBUG    [[(302, 311), (326, 337), (338, 371)], [(432, 441), (499, 508), (560, 569), (593, 602), (777, 786), (1002, 1011), (1211, 1220)], [(810, 820), (830, 832)], [(746, 786), (868, 877)], [(362, 371), (1105, 1113), (3878, 3886), (6121, 6129), (6213, 6221), (6279, 6288), (6968, 6976), (8290, 8298), (8741, 8749), (9359, 9369), (10419, 10427), (10739, 10747), (10775, 10783), (11029, 11037), (11297, 11308), (19095, 19103), (31320, 31328), (31394, 31402), (33832, 33849), (36700, 36708), (36807, 36815), (36869, 36880), (38231, 38239), (38509, 38517), (39361, 39372), (39680, 39688), (41131, 41145), (42823, 42831)], [(1333, 1341), (1391, 1393)], [(1434, 1448), (1456, 1459)], [(1560, 1567), (1610, 1614), (1639, 1643), (1657, 1661), (1728, 1732), (1803, 1807), (1817, 1821), (1825, 1829)], [(1587, 1588), (1849, 1850), (43991, 43992), (44096, 44097), (44125, 44127), (44263, 44264), (44307, 44308), (44455, 44456)], [(2076, 2080), (2138, 2142)], [(2261, 2279), 

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:10:30,668 bgg_predict  DEBUG    [[(6, 19), (508, 517)], [(444, 473), (608, 635)], [(652, 675), (704, 706)], [(681, 695), (765, 779)], [(734, 779), (801, 805)], [(839, 862), (928, 932)], [(1150, 1155), (1188, 1193)], [(1599, 1610), (1662, 1670), (1706, 1711), (1731, 1739)], [(1830, 1837), (1866, 1870)], [(1910, 1944), (1993, 2010)], [(2046, 2047), (2078, 2079)], [(2269, 2288), (2313, 2317)], [(2443, 2450), (2490, 2495)], [(2507, 2514), (2564, 2568)], [(2664, 2672), (2765, 2769), (2791, 2802)], [(2815, 2831), (2863, 2867)], [(2782, 2790), (2873, 2887)], [(2877, 2902), (2937, 2941)]]
DEBUG:bgg_predict:[[(6, 19), (508, 517)], [(444, 473), (608, 635)], [(652, 675), (704, 706)], [(681, 695), (765, 779)], [(734, 779), (801, 805)], [(839, 862), (928, 932)], [(1150, 1155), (1188, 1193)], [(1599, 1610), (1662, 1670), (1706, 1711), (1731, 1739)], [(1830, 1837), (1866, 1870)], [(1910, 1944), (1993, 2010)], [(2046, 2047), (2078, 2079)], [(2269, 2288), (2313, 2317)], [(2443, 2450), (24

  0%|          | 0/1 [00:00<?, ?ba/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

2022-11-07 23:10:33,701 bgg_predict  DEBUG    [[(88, 104), (292, 300)], [(28, 42), (365, 376), (626, 637), (665, 670), (708, 712), (799, 810), (1394, 1405), (1620, 1631)], [(391, 399), (446, 454), (459, 462), (484, 492)], [(419, 428), (523, 532), (883, 892), (1416, 1425)], [(550, 559), (563, 568)], [(588, 603), (615, 620)], [(253, 261), (789, 797)], [(1285, 1289), (1322, 1326), (1351, 1360)], [(1611, 1631), (1665, 1675)], [(1676, 1685), (1688, 1692)], [(1723, 1734), (1742, 1752), (1818, 1822)], [(1923, 1936), (1940, 1944)], [(2079, 2088), (2118, 2122), (2131, 2136)], [(2093, 2103), (2206, 2215), (2316, 2326)], [(2177, 2182), (2272, 2279), (2310, 2312), (2346, 2348)]]
DEBUG:bgg_predict:[[(88, 104), (292, 300)], [(28, 42), (365, 376), (626, 637), (665, 670), (708, 712), (799, 810), (1394, 1405), (1620, 1631)], [(391, 399), (446, 454), (459, 462), (484, 492)], [(419, 428), (523, 532), (883, 892), (1416, 1425)], [(550, 559), (563, 568)], [(588, 603), (615, 620)], [(253, 261), (789, 797)], 

Unnamed: 0,numweights,averageweight,playingtime,family,rule_count,avg_rule_len
0,703,2.1579,60,['familygames'],35.0,90.371429
1,62,3.1452,90,['strategygames'],438.0,166.388128
2,100,1.81,90,['thematic'],94.0,210.553191
3,3030,1.4858,30,['familygames'],119.0,120.487395
4,439,2.7813,120,['strategygames'],27.0,137.259259
5,259,2.2162,60,"['strategygames', 'familygames']",43.0,137.651163
6,546,1.8718,45,['familygames'],37.0,98.378378
7,148,2.9392,240,['strategygames'],293.0,192.754266
8,576,1.158,90,"['partygames', 'familygames']",50.0,64.54
9,656,1.1265,30,"['partygames', 'familygames']",16.0,147.375
