In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.column import Column
from pyspark.sql.types import StructType, ArrayType, StringType, DataType

# CAREFUL: pandas must be imported AFTER pyspark!
import pandas as pd

import itertools
from trec_car.read_data import iter_annotations
from tools.dataloaders import WikipediaCBOR, PageFormat


spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

spark.createDataFrame(sc.parallelize([('Alice', 1)]))
#sqlContext = SQLContext(sc)

DataFrame[_1: string, _2: bigint]

In [2]:
wikipedia_cbor = WikipediaCBOR("wikipedia/car-wiki2020-01-01/enwiki2020.cbor",
                                "wikipedia/car-wiki2020-01-01/partitions",
                                # repreprocess=True, page_lim=1000
                                )

Loaded from cache


In [58]:
from typing import cast, Any, Optional, Tuple, Union, NamedTuple

from trec_car.read_data import (AnnotationsFile, Page, Para,
                                Section, List as ParaList, ParaLink, ParaText, ParaBody)

from tokenizers.pre_tokenizers import Whitespace
from collections import defaultdict

from io import StringIO

def preprocess_page_new(enumerated_page: Tuple[int, Page]):
    """
    Transform a list of pages into a flattened representation that can
    then be easily (de)serialized.
    """

    id, page = enumerated_page

    # For the sake of easy link spans they are byte oriented to make
    # it easier for the rust std
    # page_content = StringIO()
    split_content = []
    orig_page_content = StringIO()
    links = []

    splitter = Whitespace()
    # splitter = 0

    # Encode a link. Cast to padding if the link was not "common".
    # Call this method only after preprocessing has been done!
    def encode_link(link):
        #return self.key_encoder.get(link, 0)
        return wikipedia_cbor.key_encoder.get(link, 0)

    # People say pattern matching is overrated.
    # I beg to differ.
    # (It's also true that a tree structure for tokenization makes
    # absolutely no sense - but I don't get to decide things apparently).
    def handle_section(skel: Section):
        for subskel in skel.children:
            visit_section(subskel)

    def handle_list(skel: ParaList):
        visit_section(skel.body)

    def handle_para(skel: Para):
        paragraph = skel.paragraph
        bodies = paragraph.bodies

        for body in bodies:
            visit_section(body)

    def handle_paratext(body: ParaBody):
        split_body = splitter.pre_tokenize_str(body.get_text())
        #print('from handle_paratext: "' + body.get_text() + '"')

        # take care of the space...
        running_prefix = 0
        if len(split_content) > 0:
            running_prefix = split_content[-1][1][1]
        
        while running_prefix < len(orig_page_content.getvalue()) and orig_page_content.getvalue()[running_prefix].isspace():
            running_prefix += 1

        orig_page_content.write(body.get_text())
        
        # there may be a case where the tokenizer skipped a newline. In this case we need to remove one byte:
        
        # No I have no idea on how to make this more readable...
        """
        if len(split_body):
            tok, (begin, end) = split_body[0]
            begin += running_prefix
            end += running_prefix

            if orig_page_content.getvalue()[begin:end] != tok:
                running_prefix -= 1
        """

        #print(f"After skipping: {orig_page_content.getvalue()[running_prefix:]}")

        split_body_shifted = [(text, (begin_offset + running_prefix,
                                end_offset + running_prefix)) for text, (begin_offset, end_offset) in split_body]

        # print(split_body)

        for tok, (begin, end) in split_body_shifted:
            # assert tok == orig_page_content.getvalue()[begin:end], f"generated {orig_page_content.getvalue()[begin:end]} but expected {tok}"
            assert tok == orig_page_content.getvalue()[begin:end], f"generated {orig_page_content.getvalue()[begin:end]} but expected {tok}. Context: {orig_page_content.getvalue()[-100:]}, body: {body.get_text()}. split_body[0]: {split_body[0]}"
        split_content.extend(split_body_shifted)

    def handle_paralink(body: ParaLink):
        encoded_link = encode_link(body.page)
        split_body = splitter.pre_tokenize_str(body.get_text())

        #print('from handle_paralink: "' + body.get_text() + '"')

        #running_prefix = running_length_count# + int(running_length_count != 0)
        running_prefix = 0
        if len(split_content) > 0:
            running_prefix = split_content[-1][1][1]
        
        while running_prefix < len(orig_page_content.getvalue()) and orig_page_content.getvalue()[running_prefix].isspace():
            running_prefix += 1

        orig_page_content.write(body.get_text())
        

        split_body = [(text, (begin_offset + running_prefix,
                                end_offset + running_prefix)) for text, (begin_offset, end_offset) in split_body]

        split_content.extend(split_body)

        if len(split_body) > 0:
            end_byte_span = split_body[-1][1][1] - 1
            start_mention_idx = len(split_content) - len(split_body)
            links.append((encoded_link, start_mention_idx, len(split_content)))

        for tok, (begin, end) in split_body:
            assert tok == orig_page_content.getvalue()[begin:end], f"generated {orig_page_content.getvalue()[begin:end]} but expected {tok}. Context: {orig_page_content.getvalue()[-100:]}, body: {body.get_text()}"


    def nothing():
        return lambda body: None

    handler = defaultdict(nothing, {Section: handle_section,
                                    Para: handle_para,
                                    ParaList: handle_list,
                                    ParaLink: handle_paralink,
                                    ParaText: handle_paratext})

    def visit_section(skel):
        # Recur on the sections
        handler[type(skel)](skel)

    for skel in page.skeleton:
        visit_section(skel)

    return id, page.page_name, orig_page_content.getvalue(), split_content, links

    # return PageFormat(id, page.page_name, page_content.getvalue(), links)

In [59]:

from pyspark.sql.types import *
# a = wikipedia_cbor[0]
k = 1000
with open(wikipedia_cbor.cbor_path, "rb") as cbor_file:
    topk = itertools.islice(iter_annotations(cbor_file), k)
    topk = map(preprocess_page_new, enumerate(topk))

    df = pd.DataFrame(topk, columns=["id", "title", "text", "tokenized_text", "links"])
    #df = spark.createDataFrame(df)
    # page = next(iter_annotations(cbor_file))
    # dataloader_page = wikipedia_cbor.preprocess_page(page)
    # spacy_collect_mentions(dataloader_page)

In [10]:
import spacy

nlp = spacy.load("xx_ent_wiki_sm")

In [11]:
from tools.dataloaders import TokenizedText
from trie_search import RecordTrieSearch

class MyRecordTrie(RecordTrieSearch):
    def __init__(self, records):
        super().__init__("<Q", records)

    def search_all_patterns(self, tokens: TokenizedText):
        words = [tok[0] for tok in tokens]

        for i, (word, span) in enumerate(tokens):
            for pattern in self._TrieSearch__search_prefix_patterns(word, words[i+1:]):
                weight = self[pattern][0][0]
                yield pattern, span[0], weight, i # exact token position
    
    def search_longest_patterns(self, tokens):
        # avoid overlapping mentions
        all_patterns = self.search_all_patterns(tokens)
        check_field = [0] * len(tokens)
        for pattern, start_idx, weight, idx in sorted(
                all_patterns, key=lambda x: len(x[0]), reverse=True):
            pattern_length = pattern.count(" ") + 1
            target_field = check_field[idx:idx + pattern_length]
            check_sum = sum(target_field)
            if check_sum == 0:
                for i in range(pattern_length):
                    check_field[idx + i] = 1
                yield pattern, start_idx, weight, idx

In [14]:
from typing import List, Tuple
#from pyspark.sql.functions import pandas_udf
# from spacy.attrs import LOWER, ENT_TYPE
import numpy as np
import pandas as pd

from trie_search import RecordTrieSearch
from heapq import merge
from tools.dataloaders import TokenizedText, Link

"""
def remap_links(
    text: str,
    links: List[Link]
) -> List[Link]:
    remapped_links = []
    for link in links:
        # revert the tokenization algorithm
        start_byte = tokenized_text[link[1]][1][0]
        end_byte = tokenized_text[link[2]-1][1][1]

        #print(text[start_byte:end_byte])

        exact_mentions[text[start_byte:end_byte]] = link[0]
        remapped_links.append((link[0], start_byte, end_byte))
    
    return remapped_links

def apply_ner(
    page_id: int,
    title: str,
    text: str
    links: List[Link]
) -> List[Link]:
    pass #nlp(text)
    
"""

def autolink(
    page_id: int,
    title: str,
    text: str,
    tokenized_text: List[Tuple[str, Tuple[int, int]]],
    remapped_links: List[Tuple[int, int, int]]
) -> List[Link]:
    # title_toks = nlp(title)
    # title_toks = Whitespace().pre_tokenize_str(title)

    link_idx = 0

    exact_mentions = {}
    mention_builder = []

    # TODO: deal with ambiguities...
    #exact_mentions[" ".join(str(tok[0]) for tok in title_toks)] = page_id
    exact_mentions[title] = page_id

    #print(exact_mentions)

    remapped_links = []
    for link in links:
        # revert the tokenization algorithm
        start_byte = tokenized_text[link[1]][1][0]
        end_byte = tokenized_text[link[2]-1][1][1]

        #print(text[start_byte:end_byte])

        exact_mentions[text[start_byte:end_byte]] = link[0]
        remapped_links.append((link[0], start_byte, end_byte))
    

    #print(list(map(lambda x: (x[0], (x[1],)), exact_mentions.items())))

    trie = MyRecordTrie(map(lambda x: (x[0], (x[1],)), exact_mentions.items()))

    #print(trie.keys())

    # print(trie.items())
    patterns = sorted(trie.search_longest_patterns(tokenized_text), key=lambda x: x[1]) # sort by apparition


    link = None
    link_idx = 0
    new_links = []
    
    for title, idx, new_link_id, token_position in patterns:
        
        new_link = (new_link_id, idx, idx + len(title))
        # print(title, new_link)

        if link_idx < len(remapped_links):
            link = remapped_links[link_idx]
        
        while link_idx < len(remapped_links) - 1 and idx > link[2]:
            link_idx += 1
            link = remapped_links[link_idx]
            #print("Increasing")

        if link_idx >= len(remapped_links):
            link = None
        
        if link is None or idx < link[1]:
            new_links.append(new_link)
            #print(title, new_link)
        else:
            #print("Found existing link",  link)
            #print("Compare with: ", new_link)
            pass
    
    # print(f"The page has {len(remapped_links)} links by default")
    # print("Added new ", len(new_links), "links")
    return list(merge(remapped_links, new_links, key=lambda x: x[1]))

In [68]:
page_id = 0
title = "Anarchism"
tokenized_text = df.iloc[0]["tokenized_text"]

text = df.iloc[0]["text"]
links = df.iloc[0]["links"]

from pyspark.sql.functions import udf
from pyspark.sql.types import *

autolink_udf = udf(autolink, )
autolink(page_id, title, text, tokenized_text, links)
#df.apply(lambda x: autolink(*x), axis=1)

The page has 444 links by default
Added new  212 links


[(0, 0, 9),
 (4986805, 16, 34),
 (7226817, 35, 44),
 (6089297, 49, 66),
 (5245450, 80, 91),
 (2345670, 143, 155),
 (2229207, 157, 170),
 (1969097, 201, 212),
 (5863601, 269, 288),
 (2628402, 405, 422),
 (0, 424, 433),
 (2967692, 500, 505),
 (0, 550, 559),
 (5831329, 585, 593),
 (2587732, 601, 619),
 (5294848, 637, 646),
 (6954572, 651, 667),
 (6227561, 676, 710),
 (7794129, 714, 723),
 (2638933, 725, 737),
 (5622587, 739, 750),
 (6051218, 752, 761),
 (6565959, 766, 789),
 (1755882, 885, 915),
 (4654925, 963, 991),
 (2299482, 1051, 1064),
 (7184765, 1077, 1089),
 (3190375, 1159, 1165),
 (656667, 1170, 1193),
 (1626157, 1444, 1448),
 (3454222, 1625, 1642),
 (5463212, 1799, 1813),
 (6791215, 1830, 1846),
 (7226817, 2032, 2041),
 (6409622, 2090, 2112),
 (7396139, 2239, 2253),
 (32917, 2397, 2411),
 (656667, 2424, 2462),
 (4756746, 2482, 2503),
 (2980391, 2535, 2558),
 (7226817, 2825, 2834),
 (5585653, 2857, 2866),
 (1505556, 2871, 2896),
 (6942452, 2912, 2917),
 (3032856, 2920, 2930),
 (41

In [20]:
import pyarrow

In [21]:
pyarrow.Table

['Array',
 'ArrayValue',
 'ArrowException',
 'ArrowIOError',
 'ArrowInvalid',
 'ArrowKeyError',
 'ArrowMemoryError',
 'ArrowNotImplementedError',
 'ArrowSerializationError',
 'ArrowTypeError',
 'BaseExtensionType',
 'BinaryArray',
 'BinaryValue',
 'BooleanArray',
 'BooleanValue',
 'Buffer',
 'BufferOutputStream',
 'BufferReader',
 'BufferedInputStream',
 'BufferedOutputStream',
 'ChunkedArray',
 'Codec',
 'CompressedInputStream',
 'CompressedOutputStream',
 'DataType',
 'Date32Array',
 'Date32Value',
 'Date64Array',
 'Date64Value',
 'Decimal128Array',
 'Decimal128Type',
 'DecimalValue',
 'DeserializationCallbackError',
 'DictionaryArray',
 'DictionaryMemo',
 'DictionaryType',
 'DictionaryValue',
 'DoubleValue',
 'DurationArray',
 'DurationType',
 'DurationValue',
 'ExtensionArray',
 'ExtensionType',
 'Field',
 'FileSystem',
 'FixedSizeBinaryArray',
 'FixedSizeBinaryType',
 'FixedSizeBinaryValue',
 'FixedSizeBufferWriter',
 'FixedSizeListArray',
 'FixedSizeListType',
 'FixedSizeListValu

In [None]:
df.pipe()

In [72]:
docs = nlp("Natasha went to the beach with her friend Monica")


Natasha
Monica


In [91]:
ent.label_

'PER'

In [10]:
from trie_search import RecordTrieSearch

patterns = [("Albert Einstein", (1,)), ("Einstein", (1,)), ("Albert Hitchcock", (2,))]
t = RecordTrieSearch("<Q", patterns)

[(1,)]

In [87]:

splitter = Whitespace()



for p in search_all_patterns("This is a a little lesson in trickery, this is going down in history!."):
    print(p)

['This', 'is', 'a', 'a', 'little', 'lesson', 'in', 'trickery', ',', 'this', 'is', 'going', 'down', 'in', 'history', '!.']
('trickery', 29, 3)
('history', 61, 2)


In [54]:
trie.splitter

'\\w+|[^\\w\\s]+'