In [1]:
import collections

import datasets

import spacy
from spacy import displacy
from spacy.tokens import Span

In [2]:
# Load the "validation" split of the "coref-data/preco_indiscrim" dataset.
data = datasets.load_dataset("coref-data/preco_indiscrim", split="validation")

In [3]:
# Display the dataset schema (field names and value types) as the cell output.
data.features

{'id': Value(dtype='string', id=None),
 'sentences': [{'id': Value(dtype='int64', id=None),
   'speaker': Value(dtype='null', id=None),
   'text': Value(dtype='string', id=None),
   'tokens': [{'id': Value(dtype='int64', id=None),
     'text': Value(dtype='string', id=None)}]}],
 'text': Value(dtype='string', id=None),
 'coref_chains': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'genre': Value(dtype='string', id=None),
 'meta_data': {'comment': Value(dtype='string', id=None)}}

In [4]:
## https://stackoverflow.com/questions/470690/how-to-automatically-generate-n-distinct-colors

from typing import Iterable, Tuple
import colorsys
import itertools
from fractions import Fraction
from pprint import pprint

def zenos_dichotomy() -> Iterable[Fraction]:
    """
    Endlessly yield the unit fractions 1/1, 1/2, 1/4, 1/8, ... (powers of one half).

    Reference:
    http://en.wikipedia.org/wiki/1/2_%2B_1/4_%2B_1/8_%2B_1/16_%2B_%C2%B7_%C2%B7_%C2%B7
    """
    denominator = 1
    while True:
        yield Fraction(1, denominator)
        # Each term halves the previous one, i.e. the denominator doubles.
        denominator *= 2

def fracs() -> Iterable[Fraction]:
    """
    Endlessly yield fractions in [0, 1) so that each new value splits an existing gap.

    Starts with 0, then for every power-of-two denominator emits all fractions
    with an odd numerator below that denominator:

    [Fraction(0, 1), Fraction(1, 2), Fraction(1, 4), Fraction(3, 4), Fraction(1, 8), Fraction(3, 8), Fraction(5, 8), Fraction(7, 8), Fraction(1, 16), Fraction(3, 16), ...]
    [0.0, 0.5, 0.25, 0.75, 0.125, 0.375, 0.625, 0.875, 0.0625, 0.1875, ...]
    """
    yield Fraction(0)
    for unit in zenos_dichotomy():
        denom = unit.denominator  # 1, 2, 4, 8, 16, ...
        # Odd numerators only: even ones reduce to fractions already emitted.
        yield from (Fraction(numer, denom) for numer in range(1, denom, 2))

# can be used for the v in hsv to map linear values 0..1 to something that looks equidistant
# bias = lambda x: (math.sqrt(x/3)/Fraction(2,3)+Fraction(1,3))/Fraction(6,5)

# (hue, saturation, value) triple expressed as exact Fractions.
HSVTuple = Tuple[Fraction, Fraction, Fraction]
# (red, green, blue) triple of floats, the form returned by colorsys.hsv_to_rgb.
RGBTuple = Tuple[float, float, float]

def hue_to_tones(h: Fraction) -> Iterable[HSVTuple]:
    """
    Yield the (hue, saturation, value) tones generated for a single hue.

    With the current fixed settings each hue produces two tones: one
    saturation (6/10) combined with a brighter (8/10) and a darker (5/10) value.
    """
    saturations = (Fraction(6, 10),)  # optionally use a range of saturations
    values = (Fraction(8, 10), Fraction(5, 10))  # could be a range too
    # Apply a bias to the value component here if a range is used for `values`.
    yield from ((h, sat, val) for sat in saturations for val in values)

def hsv_to_rgb(x: HSVTuple) -> RGBTuple:
    """Convert an HSV triple to an RGB triple via colorsys, casting each component to float first."""
    components = [float(component) for component in x]
    return colorsys.hsv_to_rgb(*components)

# Lazily concatenate an iterable of iterables into one flat stream.
flatten = itertools.chain.from_iterable

def hsvs() -> Iterable[HSVTuple]:
    """Return an endless, lazy stream of HSV tones: every tone of every hue from fracs(), in order."""
    tones_per_hue = (hue_to_tones(hue) for hue in fracs())
    return flatten(tones_per_hue)

def rgbs() -> Iterable[RGBTuple]:
    """Return an endless, lazy stream of RGB triples, one per tone produced by hsvs()."""
    return (hsv_to_rgb(tone) for tone in hsvs())

def rgb_to_css(x: RGBTuple) -> str:
    """Format an RGB triple as a CSS "rgb(r,g,b)" string, scaling each channel by 255."""
    # int() truncates toward zero, so e.g. 0.5 * 255 = 127.5 becomes 127.
    channels = [int(channel * 255) for channel in x]
    return "rgb({},{},{})".format(*channels)

def css_colors() -> Iterable[str]:
    """Return an endless, lazy stream of CSS "rgb(r,g,b)" strings, one per color from rgbs()."""
    return (rgb_to_css(color) for color in rgbs())

def rgb_to_hex(x: RGBTuple) -> str:
    """Format an RGB triple as a "#rrggbb" hex string, scaling each channel by 255."""
    # int() truncates toward zero, so e.g. 0.5 * 255 = 127.5 becomes 0x7f.
    channels = [int(channel * 255) for channel in x]
    return '#{:02x}{:02x}{:02x}'.format(*channels)

def hex_colors() -> Iterable[str]:
    """Return an endless, lazy stream of "#rrggbb" strings, one per color from rgbs()."""
    return (rgb_to_hex(color) for color in rgbs())

In [5]:
# Preview the first three generated hex colors (the generator itself is endless).
list(itertools.islice(hex_colors(), 3))

['#cc5151', '#7f3333', '#51cccc']

In [30]:
from random import randint
from spacy.tokens import Doc

def visualize_document(sentences, coref_chains):
    """
    Render a document's coreference chains inline with displaCy's span style.

    Args:
        sentences: sequence of dicts, each with a "tokens" list whose items
            carry a "text" string.
        coref_chains: sequence of chains; each chain is a sequence of
            (sentence index, start token, end token) mentions, where the
            token indices are local to the sentence and `end` is inclusive.

    Chains with fewer than two mentions are not drawn. Every chain index `i`
    is labelled "e{i}" and assigned its own color from hex_colors().
    """
    # Flatten the tokens into one word list while recording, for every
    # (sentence, token) pair, its position in the flattened document.
    words = []
    local_to_global_idx = {}
    for sent_i, sentence in enumerate(sentences):
        for tok_i, token in enumerate(sentence["tokens"]):
            local_to_global_idx[(sent_i, tok_i)] = len(words)
            words.append(token["text"])

    # A blank English pipeline is only needed for its vocabulary.
    doc = Doc(spacy.blank("en").vocab, words)

    # One labelled Span per mention; Span's end offset is exclusive, hence the +1.
    spans = [
        Span(
            doc,
            local_to_global_idx[(sent, start)],
            local_to_global_idx[(sent, end)] + 1,
            f"e{chain_i}",
        )
        for chain_i, chain in enumerate(coref_chains)
        if len(chain) >= 2
        for sent, start, end in chain
    ]

    # Colors are allocated for every chain index, including chains not drawn.
    palette = itertools.islice(hex_colors(), len(coref_chains))
    colors = {f"e{chain_i}": color for chain_i, color in enumerate(palette)}

    # "sc" is presumably the span group displaCy reads by default -- confirm.
    doc.spans["sc"] = spans

    displacy.render(doc, style="span", options={"colors": colors}, jupyter=True)
    # Alternative outside a notebook:
    # displacy.serve(doc, style="span", options={"colors": colors}, port=1001)

In [7]:
# Convert the dataset to a plain list so examples can be browsed by integer index.
examples = data.to_list()

In [8]:
def head(toks, ment):
    """
    Pick a head value for a mention from its tokens' "head" fields.

    Args:
        toks: token dicts of one sentence; each must provide an integer "head".
            Presumably this is the 1-based position of the token's syntactic
            parent -- confirm against the data source.
        ment: mention as (sentence index, start token, end token), end inclusive.

    Returns:
        The smallest "head" value among the mention's tokens for which
        `head - 1` lies inside the mention span. For a single-token mention
        with no such value, returns `start + 1`.

    Raises:
        AssertionError: if a multi-token mention has no token whose head
            falls inside the mention.
    """
    start, end = ment[1], ment[2]
    best = None
    for tok in toks[start:end + 1]:
        candidate = tok["head"]
        # Only a strictly smaller head can replace the current choice.
        if best and candidate >= best:
            continue
        # Accept only heads that point back inside the mention span.
        if start <= candidate - 1 <= end:
            best = candidate
    assert best is not None or end - start == 0
    return start + 1 if best is None else best

def nested_chain(sentences, coref_chains):
    """
    Report whether any chain has two different mentions in the same sentence
    that resolve to the same head() value.

    On the first such pair the two mentions and the shared head value are
    printed and True is returned; otherwise returns False.
    """
    for chain in coref_chains:
        for first in chain:
            for second in chain:
                # Only compare distinct mentions that sit in the same sentence.
                if first[0] == second[0] and first != second:
                    # (an earlier variant tested span overlap instead:
                    #  not (m1[2] < m2[1] or m1[1] > m2[2]))
                    toks = sentences[first[0]]["tokens"]
                    first_head = head(toks, first)
                    if first_head == head(toks, second):
                        print(first, second)
                        print(first_head)
                        return True
    return False

In [216]:
def bare_plurals(sentences, coref_chains):
    """
    Heuristically detect a chain that links two different plural-looking mentions.

    A mention counts as plural-looking when the text of its last token ends
    in "s".

    Args:
        sentences: sequence of dicts, each with a "tokens" list whose items
            carry a "text" string.
        coref_chains: sequence of chains; each chain is a sequence of
            mentions whose first element is the sentence index and whose
            last element is the index of the mention's final token.

    Returns:
        True if some chain contains at least two mentions that differ from
        each other and both end in "s"; False otherwise.
    """
    def ends_in_s(ment):
        # endswith() is safe on empty token text, unlike indexing with [-1].
        return sentences[ment[0]]["tokens"][ment[-1]]["text"].endswith("s")

    for chain in coref_chains:
        plural_ments = [ment for ment in chain if ends_in_s(ment)]
        # Two distinct plural mentions exist exactly when some plural mention
        # differs from the first one; this avoids scanning every pair.
        if any(ment != plural_ments[0] for ment in plural_ments[1:]):
            return True
    return False

SyntaxError: unterminated string literal (detected at line 7) (2298165345.py, line 7)

In [15]:
# Cursor into `examples`; starts at -1 so a leading `ex_i += 1` lands on index 0.
ex_i = -1

In [71]:
# Manually move the cursor back three positions; each re-run of this cell moves it back again.
ex_i -= 3

In [215]:
# Show the next example (after the current cursor) that satisfies bare_plurals.
# After the cell runs, ex_i is the index of the example shown, so re-running
# the cell continues with the example right after it and none is skipped.
ex_i += 1
while ex_i < len(examples):
    if bare_plurals(examples[ex_i]["sentences"], examples[ex_i]["coref_chains"]):
        print(examples[ex_i]["id"])
        visualize_document(examples[ex_i]["sentences"], examples[ex_i]["coref_chains"])
        break #36393
    # Advance past non-matching examples; without this the loop never ends.
    ex_i += 1
else:
    # Reached the end of the list without finding another match.
    print("no further matching examples")

train_36499
