## Counting mentions of Wikipedia articles in Reddit submission titles

### Prerequisites

1. Download a dump of Wikipedia's articles, named `enwiki-{date_string}-pages-articles-multistream.xml.bz2`
2. Download the `enwiki-{date_string}-pages-articles-multistream-index.txt.bz2` file
3. Move those files into the same folder, removing the `enwiki-{date_string}` prefix
4. Process the `xml.bz2` file into a Parquet file using `wikiplain.load_bz2`
5. Run `PageRank.ipynb`
6. Download some of the `RS_{yyyy-mm}.zst` files from https://files.pushshift.io/reddit/submissions/
    - I only use 2015-present, and I cut each download off at 10%

In [121]:
import asyncio
import glob
import gzip
import io
import itertools
import json
import math
import operator
import os
import pickle
import random
import re
import shutil
import socket
import struct
import subprocess
import sys
import tarfile
import time
from collections import ChainMap, defaultdict, deque
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from functools import lru_cache, partial
from urllib.parse import urlencode, urlsplit, quote as urlquote, unquote as urlunquote
from xml.sax.saxutils import unescape as xml_unescape
from typing import Any, Awaitable, Callable, Literal, TypeVar

import httpx
import ijson
import mmh3
import numpy as np
import pyarrow.parquet as pq
import polars as pl
import sqlalchemy as sa
import scipy.sparse
import toolz
from dotenv import load_dotenv
from ipywidgets import interact
from spacy.lang.en import English
from sqlalchemy import create_engine
from sqlalchemy.sql import column as sqlcolumn, select, text as sqltext
from tqdm.auto import tqdm
from zstandard import ZstdDecompressor, ZstdDecompressionReader
from arsenal.datastructures.unionfind import UnionFind
from arsenal.datastructures.heap import MinMaxHeap

import wikiplain
from special_cases import SECOND_LEVEL_DOMAINS
from nbhelpers.polars import pager, searcher
from umbc_web.process_possf2 import PENN_TAGS, PENN_TAGS_BY_ID

In [2]:
# Load variables from a local .env file into the process environment.
# RedditRankFiles (below) reads ENWIKI_DIR, REDDIT_DIR and, optionally,
# ENWIKI_PARQUET_DIR from os.environ.
load_dotenv()

True

In [3]:
# Widen polars' string-cell display limit to 160 characters so long values
# are not cut short when frames are shown.
pl.Config.set_fmt_str_lengths(160)

polars.config.Config

In [4]:
class RedditRankFiles:
    """File locations used by this notebook for one Wikipedia dump.

    Directories come from the environment:

    - ``ENWIKI_DIR`` (required): parent of the per-dump directory
      ``{ENWIKI_DIR}/{date_string}``.
    - ``ENWIKI_PARQUET_DIR`` (optional): where the Parquet/SQLite files live;
      defaults to the per-dump directory.
    - ``REDDIT_DIR`` (required): stored as ``reddit_dir``.

    Constructing an instance creates ``{enwiki_dir}/pagerank`` if it does not
    exist yet (the per-dump directory itself must already exist).
    """

    def __init__(self, date_string):
        env = os.environ
        self.date_string = date_string
        self.enwiki_dir = "{}/{}".format(env['ENWIKI_DIR'], date_string)
        self.parquet_dir = env.get('ENWIKI_PARQUET_DIR', self.enwiki_dir)
        self.reddit_dir = "{}".format(env['REDDIT_DIR'])
        try:
            os.mkdir(f"{self.enwiki_dir}/pagerank")
        except FileExistsError:
            # Already created by an earlier run.
            pass

    def _in_pagerank_dir(self, basename):
        """Return the path of ``basename`` inside the ``pagerank`` directory."""
        return f"{self.enwiki_dir}/pagerank/{basename}"

    def _in_parquet_dir(self, basename):
        """Return the path of ``basename`` inside the Parquet directory."""
        return f"{self.parquet_dir}/{basename}"

    @property
    def enwiki_parquet_filename(self):
        """Parquet file holding the dump's pages."""
        return self._in_parquet_dir(f"enwiki_{self.date_string}.parquet")

    @property
    def pagerank_parquet_filename(self):
        """Parquet file holding the PageRank results."""
        return self._in_parquet_dir(f"enwiki_{self.date_string}_pagerank.parquet")

    @property
    def enwiki_tokenized_database_uri(self):
        """SQLAlchemy URI of the SQLite database with the tokenized terms."""
        return "sqlite:///" + self._in_parquet_dir(f"enwiki_tokenized_{self.date_string}.sqlite")

    @property
    def nub_filename(self):
        return self._in_pagerank_dir("nub.pkl")

    @property
    def id_maps_filename(self):
        return self._in_pagerank_dir("id_maps.pkl")

    @property
    def dense_id_arr_filename(self):
        return self._in_pagerank_dir("dense_id_arr.pkl")

    @property
    def disambig_arr_filename(self):
        return self._in_pagerank_dir("disambig_arr.pkl")

    @property
    def top_cite_domains_filename(self):
        return self._in_pagerank_dir("top_cite_domains.pkl")

    @property
    def in_degree_filename(self):
        return self._in_pagerank_dir("in_degree.pkl")

    @property
    def out_degree_filename(self):
        return self._in_pagerank_dir("out_degree.pkl")

    def edge_filenames(self, num_partitions):
        """Return the edge pickle paths ``edges_0.pkl`` .. ``edges_{num_partitions-1}.pkl``."""
        paths = []
        for index in range(num_partitions):
            paths.append(self._in_pagerank_dir(f"edges_{index}.pkl"))
        return paths

In [5]:
# All paths below refer to the dump whose date string is 20230301.
files = RedditRankFiles("20230301")

### Re-use outputs computed by PageRank.ipynb

1. Pages with the same title
2. `id_map` from non-redirecting article titles to node number, and `id_map2` from redirecting article titles to node number

In [6]:
# Open the article Parquet file; the functions below stream record batches
# from it with pqf.iter_batches().
pqf = pq.ParquetFile(files.enwiki_parquet_filename)

In [7]:
# NOTE: pickle.load can execute arbitrary code; only load pickles that were
# produced locally (here: by PageRank.ipynb).
# nub.pkl holds a pair: `overwritten` (used below for `page_id in overwritten`
# membership tests) and `pqf_size` (used below as the row count for progress bars).
with open(files.nub_filename, "rb") as fp:
    overwritten, pqf_size = pickle.load(fp)
# id_maps.pkl holds the two title -> node number mappings described above.
with open(files.id_maps_filename, "rb") as fp:
    id_map, id_map2 = pickle.load(fp)

### Build representation of articles/links as a graph

1. Create `id_map` from non-redirecting article titles to node number, and `id_map2` from redirecting article titles to node number
2. Use `wikiplain` to extract link titles, and use above maps to convert to (src_id, dest_id) pairs

In [8]:
# Node ids are grouped into partitions of 2**16 ids; a node's partition is
# `node_id >> LOG_PARTITION_SIZE` (see the term_map cell below).
LOG_PARTITION_SIZE = 16
PARTITION_SIZE = 1 << LOG_PARTITION_SIZE
# Number of entries in id_map, used as the length of the per-node arrays.
N = len(id_map)
# Number of partitions needed to cover all N nodes (last one may be partial).
NUM_PARTITIONS = math.ceil(N / PARTITION_SIZE)

In [9]:
class Vec:
    """Append-only, growable one-dimensional NumPy buffer.

    ``array`` is the backing storage and ``length`` the number of values
    appended so far; only ``array[:length]`` holds meaningful data.
    """

    def __init__(self, dtype):
        # Start with 1024 uninitialised slots of the requested dtype.
        self.length = 0
        self.array = np.ndarray((1024,), dtype=dtype)

    @property
    def capacity(self):
        """Number of slots currently allocated in the backing array."""
        (slots,) = self.array.shape
        return slots

    def append(self, v):
        """Store ``v`` in the next free slot, doubling the storage when it is full."""
        slot = self.length
        if slot >= self.capacity:
            # Grow by the current capacity (at least 2), zero-filling the new tail.
            extra = np.zeros((max(2, self.capacity),), dtype=self.array.dtype)
            self.array = np.hstack((self.array, extra))
        self.array[slot] = v
        self.length = slot + 1

In [10]:
def get_disambig_arr():
    """Return a length-``N`` boolean array flagging disambiguation pages.

    Pages are streamed from ``pqf`` in batches of 100. A page is kept when its
    ``redirect`` value is not valid (null), its ``ns`` is 0 and its ``id`` is
    not in ``overwritten``. Kept pages are numbered consecutively from 0 in
    stream order, and that number is the index written in the result
    (presumably the same node numbering as ``id_map`` -- verify against
    PageRank.ipynb).
    """
    batch_size = 100
    batches = tqdm(pqf.iter_batches(batch_size=batch_size), total=math.ceil(pqf_size / batch_size))

    def article_texts():
        # Yield the wikitext of every page that survives the filter, in order.
        for batch in batches:
            rows = zip(
                batch["id"].to_numpy(),
                batch["ns"].to_numpy(),
                batch["redirect"],
                batch["text"].to_pylist(),
            )
            for page_id, ns, redirect, text in rows:
                if redirect.is_valid or ns != 0 or page_id in overwritten:
                    continue
                yield text

    flags = np.zeros(N, dtype=np.bool_)
    for node_id, text in enumerate(article_texts()):
        flags[node_id] = wikiplain.is_disambiguation_page(text)
    return flags

# Use the cached array when it can be loaded; on any failure (for example the
# cache file does not exist yet) recompute it from the Parquet file and write
# the cache for next time.
# NOTE(review): `except Exception` also hides unrelated problems such as a
# corrupt pickle, silently triggering the slow recomputation.
try:
    with open(files.disambig_arr_filename, "rb") as fp:
        disambig_arr = pickle.load(fp)
except Exception:
    disambig_arr = get_disambig_arr()
    with open(files.disambig_arr_filename, "wb") as fp:
        pickle.dump(disambig_arr, fp)

In [11]:
!curl https://raw.githubusercontent.com/timvieira/arsenal/master/arsenal/datastructures/heap/heap.pyx

# cython: language_level=3, boundscheck=False, infer_types=True, nonecheck=False
# cython: overflowcheck=False, initializedcheck=False, wraparound=False, cdivision=True

"""
Heap data structures with optional
 - Locators
 - Top-k (Bounded heap)

"""
import numpy as np

Vt = np.double
cdef double NaN = np.nan

# TODO: Use the C++ standard library's implementation of a vector of doubles.
cdef class Vector:

    cdef public:
        int cap
        int end
        double[:] val

    def __init__(self, cap):
        self.cap = cap
        self.val = np.zeros(self.cap, dtype=Vt)
        self.end = 0

    cpdef int push(self, double x):
        i = self.end
        self.ensure_size(i)
        self.val[i] = x
        self.end += 1
        return i

    cpdef object pop(self):
        "pop from the end"
        assert 0 < self.end
        self.end -= 1
        v = self.val[self.end]
        self.val[self.end] = NaN
        return v

    cdef void grow(

In [13]:
def get_top_cite_domains():
    """Approximate the most-cited site domains across all articles.

    Articles are streamed with the same filter as ``get_disambig_arr`` (no
    valid ``redirect``, ``ns == 0``, ``id`` not in ``overwritten``). For each
    article the citation URLs are reduced to a site domain (last two labels,
    or last three when the last two are in ``SECOND_LEVEL_DOMAINS``) and
    counted; the per-article counts are merged into a bounded ``MinMaxHeap``.

    The bound starts at 256 * 1024 entries and is halved after every 750000
    articles, so the result is approximate: a domain evicted early loses the
    count it had accumulated.

    Returns the heap contents as a list in ``popmax()`` order (used later as
    ``(domain, count)`` rows).
    """
    batch_size = 100
    batches = tqdm(pqf.iter_batches(batch_size=batch_size), total=math.ceil(pqf_size / batch_size))

    def article_texts():
        # Yield the wikitext of every page that survives the filter, in order.
        for batch in batches:
            rows = zip(
                batch["id"].to_numpy(),
                batch["ns"].to_numpy(),
                batch["redirect"],
                batch["text"].to_pylist(),
            )
            for page_id, ns, redirect, text in rows:
                if redirect.is_valid or ns != 0 or page_id in overwritten:
                    continue
                yield text

    heap = MinMaxHeap()
    heap_limit = 256 * 1024
    for node_id, text in enumerate(article_texts()):
        if (node_id + 1) % 750000 == 0:
            # Tighten the bound and evict the smallest entries beyond it.
            heap_limit /= 2
            while len(heap) > heap_limit:
                heap.popmin()

        # Count this article's citations per site domain.
        page = defaultdict(int)
        for url in wikiplain.get_cite_urls(text):
            # Keep only what precedes the first ':' or '/'.
            labels = re.sub(r"[:/].*", "", url).split('.')
            if len(labels) < 2:
                continue
            site_domain = '.'.join(labels[-2:])
            if site_domain in SECOND_LEVEL_DOMAINS:
                if len(labels) < 3:
                    continue
                site_domain = labels[-3] + '.' + site_domain
            page[site_domain] += 1

        # Merge into the bounded heap.
        for k, v in page.items():
            if k in heap:
                heap[k] = heap.max[k] + v
            elif len(heap) < heap_limit:
                heap[k] = v
            elif v > heap.peekmin()[1]:
                heap.popmin()
                heap[k] = v

    ranked = []
    while len(heap) > 0:
        ranked.append(heap.popmax())
    return ranked

# Use the cached list when it can be loaded; on any failure (for example the
# cache file does not exist yet) recompute it and write the cache.
# NOTE(review): `except Exception` also hides unrelated problems such as a
# corrupt pickle, silently triggering the slow recomputation.
try:
    with open(files.top_cite_domains_filename, "rb") as fp:
        top_cite_domains = pickle.load(fp)
except Exception:
    top_cite_domains = get_top_cite_domains()
    with open(files.top_cite_domains_filename, "wb") as fp:
        pickle.dump(top_cite_domains, fp)

In [14]:
# Browse the (domain, count) rows interactively; 16 is presumably the page size.
pager(pl.DataFrame(top_cite_domains, schema=['domain', 'count']), 16)

interactive(children=(Dropdown(description='page', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …

<function nbhelpers.polars.pager.<locals>.<lambda>(page)>

In [15]:
# Interactive search over the `domain` column; a `rank` column (position in
# the list, 0 = first popped from the heap) is added so hits show their rank.
searcher(pl.DataFrame(top_cite_domains, schema=['domain', 'count'])
          .with_columns(pl.Series("rank", range(len(top_cite_domains)))),
         ["domain"],
         16)

interactive(children=(Text(value='', description='q'), Output()), _dom_classes=('widget-interact',))

<function nbhelpers.polars.searcher.<locals>.searcher_run(q)>

In [16]:
# Domains removed from the set even though they are frequently cited
# (presumably because their links are media/hosting rather than articles --
# confirm with the author).
excluded_domains = {"imgur.com", "twitter.com", "youtube.com", "soundcloud.com",
                    "instagram.com", "amazon.com", "github.com", "vimeo.com",
                    "google.com"
                   }
# Submission URLs are filtered against this set in submission_titles().
top_cite_domain_set = {domain for domain, _ in top_cite_domains} - excluded_domains

In [17]:
# Memory diagnostic: the 50 largest objects bound in the notebook namespace.
# Rows are grouped by id() so that several names bound to the same object are
# listed together. sys.getsizeof is shallow (it does not follow references
# into contained objects), so container sizes are underestimates.
localsizes = (pl.DataFrame([(id(value), name, sys.getsizeof(value)) for name, value in locals().items()],
                          schema=['id', 'name', 'size'])
              .groupby('id')
              .agg(pl.max("size"), pl.col("name").apply(lambda ser: ser.to_list()))
              .sort('size', descending=True)
              .head(50)
             )
localsizes

id,size,name
i64,i64,list[str]
139903491947072,335544408,"[""id_map2""]"
139903491968640,335544408,"[""id_map""]"
139899293787584,2378992,"[""batch""]"
139899299262336,65752,"[""top_cite_domain_set""]"
139903498434176,65752,"[""SECOND_LEVEL_DOMAINS""]"
139907559281856,9304,"[""top_cite_domains""]"
94605535217072,2043,"[""_i13""]"
94602671107920,1856,"[""_i4""]"
94602571469328,1383,"[""_i1""]"
94602560147520,1072,"[""auto""]"


In [18]:
# Release the Parquet file handle now that the streaming passes above are done.
pqf.close()

In [19]:
@lru_cache(maxsize=1)
def load_edges(partition):
    """Unpickle and return the edge data stored for ``partition``.

    Only the most recently requested partition is kept in memory
    (``maxsize=1``), so repeated lookups within one partition are cheap while
    switching partitions reloads from disk.
    """
    edge_paths = files.edge_filenames(NUM_PARTITIONS)
    with open(edge_paths[partition], "rb") as edge_file:
        return pickle.load(edge_file)

### Edge format (copied from PageRank)

- `edges_{n}.pkl` stores the outgoing links from `PARTITION_SIZE*n ..< PARTITION_SIZE*(n+1)`
- These are stored in a list where element `i` contains the links out to `PARTITION_SIZE*i ..< PARTITION_SIZE*(i+1)`

In [64]:
# Blank English spaCy pipeline; the cells below only use its tokenizer
# (nlp.tokenizer) to split titles into normalized tokens.
nlp = English()

In [21]:
# PageRank results; PR_value is indexed by node id in the cells below
# (e.g. PR_value[dest_id]).
PR = pl.read_parquet(files.pagerank_parquet_filename)
PR_value = PR["value"].to_numpy()

In [92]:
# SQLAlchemy engine for the SQLite database that stores term_map / max_span_map.
engine = create_engine(files.enwiki_tokenized_database_uri)

In [95]:
def _title_parts(title):
    """Return the normalized tokens of ``title``, without the tokens spaCy
    flags as left or right punctuation (``is_left_punct`` / ``is_right_punct``).

    Submission titles are tokenized with the same filter before lookup, so
    terms must be built this way to be matchable.
    """
    return [
        token.norm_
        for token in nlp.tokenizer(title)
        if not (token.is_left_punct or token.is_right_punct)
    ]


# Build the `term_map` table: one row per (term, node id, weight), where a
# term is the space-joined normalized tokens of a title.
#   * article titles get weight 2.0;
#   * a disambiguation page instead contributes one row per page it links to,
#     weighted by that page's share of the linked pages' total PageRank;
#   * redirect titles get weight 1.0 (skipped when the node is a
#     disambiguation page).
# `max_span_map` records, for each first token, the longest term (in tokens)
# that starts with it; it bounds the spans tried when scanning a title.
# NOTE: this cell creates the table and is therefore not re-runnable against
# an existing database.
max_span_map = {}
with engine.begin() as conn:
    conn.execute(sqltext("CREATE TABLE term_map (term TEXT NOT NULL, id INTEGER NOT NULL, weight FLOAT NOT NULL)"))
    stmt = sqltext("INSERT INTO term_map (term, id, weight) VALUES (:term, :id, :weight)")
    for title, node_id in tqdm(id_map.items(), total=len(id_map)):
        parts = _title_parts(title)
        if not parts:
            continue
        if len(parts) > max_span_map.get(parts[0], -1):
            max_span_map[parts[0]] = len(parts)
        term = " ".join(parts)
        if disambig_arr[node_id]:
            partition = node_id >> LOG_PARTITION_SIZE
            # Collect this page's link targets from every table of its partition.
            destinations = []
            for pair_tbl in load_edges(partition):
                destinations.append(pair_tbl[pair_tbl[:, 0] == node_id][:, 1])
            destination_arr = np.unique(np.hstack(destinations))
            pr_value_total = PR_value[destination_arr].sum()
            for dest_id in destination_arr.tolist():
                conn.execute(stmt, {"term": term, "id": dest_id, "weight": PR_value[dest_id] / pr_value_total})
        else:
            # Store the punctuation-filtered term computed above. Re-joining
            # the unfiltered tokens here would keep brackets/quotes in the
            # term, and such a term can never equal a span built from a
            # (filtered) submission title, nor agree with max_span_map.
            conn.execute(stmt, {"term": term, "id": node_id, "weight": 2.0})
    for title, node_id in tqdm(id_map2.items(), total=len(id_map2)):
        if disambig_arr[node_id]:
            continue
        parts = _title_parts(title)
        if not parts:
            continue
        if len(parts) > max_span_map.get(parts[0], -1):
            max_span_map[parts[0]] = len(parts)
        term = " ".join(parts)
        conn.execute(stmt, {"term": term, "id": node_id, "weight": 1.0})

  0%|          | 0/6625358 [00:00<?, ?it/s]

  0%|          | 0/10408919 [00:00<?, ?it/s]

In [100]:
# Collapse duplicate (term, id) rows, keeping the largest weight: build the
# deduplicated copy, drop the original table, and rename the copy into place.
with engine.begin() as conn:
    conn.execute(sqltext("CREATE TABLE term_map_2 AS SELECT term, id, MAX(weight) AS weight FROM term_map GROUP BY term, id"))
    conn.execute(sqltext("DROP TABLE term_map"))
    conn.execute(sqltext("ALTER TABLE term_map_2 RENAME TO term_map"))

In [101]:
# Index term_map by term (the relevance cell looks rows up by term) and
# persist the in-memory max_span_map as a key/value table.
with engine.begin() as conn:
    conn.execute(sqltext("CREATE INDEX ix_term_map_term ON term_map (term)"))
    conn.execute(sqltext("CREATE TABLE max_span_map (k TEXT NOT NULL, v INTEGER NOT NULL)"))
    stmt = sqltext("INSERT INTO max_span_map (k, v) VALUES (:k, :v)")
    for k, v in max_span_map.items():
        conn.execute(stmt, {"k": k, "v": v})

In [31]:
# SQLAlchemy description of the existing term_map table, used to build
# SELECT statements; it does not create anything in the database by itself.
term_map_sql = sa.Table(
    "term_map",
    sa.MetaData(),
    sa.Column("term", sa.String),
    sa.Column("id", sa.Integer),
    sa.Column("weight", sa.Float),
)

In [66]:
# Token replacements applied to submission-title tokens after spaCy's own
# normalization (see submission_titles): a '|' token is treated as '-'.
renormalizations = {'|': '-'}

In [74]:
@contextmanager
def ensure_shutdown(sock):
    """Context manager that shuts down and closes ``sock`` on exit.

    Yields ``sock`` unchanged. On exit -- normal or via an exception -- both
    directions are shut down and the socket is closed.

    ``shutdown()`` raises ``OSError`` when the socket was never connected or
    the connection is already gone. That error is ignored so that it can
    neither replace the exception raised inside the ``with`` body nor prevent
    ``close()`` from running.
    """
    try:
        yield sock
    finally:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Not connected (e.g. connect() failed) or already torn down.
            pass
        finally:
            sock.close()

In [116]:
# def submission_titles(filename):
#     statinfo = os.stat(filename)
#     with tqdm.wrapattr(open(filename, "rb"), "read", total=statinfo.st_size) as compressed:
#         dctx = ZstdDecompressor()
#         reader = dctx.stream_reader(compressed)
#         dedup_q = deque()
#         dedup_set = set()
#         for submission in ijson.items(reader, "", multiple_values=True):
#             if "url" not in submission or "title" not in submission:
#                 continue
#             url = urlsplit(submission["url"])
#             if url.netloc not in top_cite_domain_set:
#                 continue
#             dedup_key = (url.netloc, url.path)
#             if dedup_key in dedup_set:
#                 continue
#             if len(dedup_q) > 1000:
#                 dedup_set.discard(dedup_q.popleft())
#             dedup_q.append(dedup_key)
#             dedup_set.add(dedup_key)
#             title = xml_unescape(submission["title"])
#             tokens = [
#                 renormalizations.get(token.norm_, token.norm_)
#                 for token in nlp.tokenizer(title)
#                 if not (token.is_left_punct or token.is_right_punct)
#             ]
#             spans = []
#             for i, w0 in enumerate(tokens):
#                 max_size = min(len(tokens) - i, max_span_map.get(w0, 0))
#                 for j in range(i + 1, i + max_size + 1):
#                     spans.append((i, j))
#             yield tokens, spans            

In [123]:
import io
from types import TracebackType
from typing import IO, Type, Iterator, Iterable, Optional, List

class ReadableIterator(IO[bytes]):
    """Read-only, non-seekable binary file object backed by an iterator of chunks.

    Adapts an iterator that yields ``bytes`` chunks of arbitrary size (for
    example an HTTP response body) to the ``read(size)`` interface expected
    by consumers of file objects. Line-based, iteration, seeking and writing
    methods are not supported and raise ``ValueError``.
    """

    # Source of chunks; None once the object has been closed.
    __inner: Optional[Iterator[bytes]]
    # Bytes already pulled from the iterator but not yet handed to a reader.
    __buffered: bytearray

    def __init__(self, inner: Iterator[bytes]):
        self.__inner = inner
        self.__buffered = bytearray()

    def __enter__(self) -> IO[bytes]:
        return self

    def __exit__(self,
                 __t: Optional[Type[BaseException]],
                 __value: Optional[BaseException],
                 __traceback: Optional[TracebackType]) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the IO object.

        Once closed, any read that needs more data from the iterator raises
        OSError. This method has no effect if the object is already closed.
        """
        self.__inner = None

    def fileno(self) -> int:
        """
        Always raises OSError: there is no underlying file descriptor.

        (Returning a number here, such as 0, would make callers operate on an
        unrelated descriptor -- 0 is standard input.)
        """
        raise OSError("ReadableIterator is not backed by a file descriptor")

    def readable(self) -> bool:
        """
        Returns True if the IO object can be read.
        """
        return True

    def __require_inner(self) -> Iterator[bytes]:
        if self.__inner is None:
            raise OSError("Can't read a closed file")
        return self.__inner

    def read(self, size: int = -1) -> bytes:
        """
        Read at most size bytes, returned as a bytes object.

        If the size argument is negative, read until EOF is reached.
        Return an empty bytes object at EOF.

        Chunks pulled from the iterator are appended to an internal buffer
        in place, so data already received is kept for the next call even if
        the iterator raises part-way through.
        """
        if size == 0:
            return b""
        pending = self.__buffered
        while size < 0 or len(pending) < size:
            try:
                pending += next(self.__require_inner())
            except StopIteration:
                break
        if size > 0:
            result = bytes(pending[:size])
            del pending[:size]
            return result
        result = bytes(pending)
        pending.clear()
        return result

    def readinto(self, buffer: bytearray) -> int:
        """
        Read bytes into buffer.

        ``buffer`` must be writable and support slice assignment (a
        ``bytearray``, or a ``memoryview`` of one). At most ``len(buffer)``
        bytes are stored, starting at index 0.

        Returns the number of bytes read (0 for EOF).
        """
        content = self.read(len(buffer))
        buffer[:len(content)] = content
        return len(content)

    def readline(self, __limit: int = -1) -> bytes:
        raise ValueError("Line-based methods are not available on ReadableIterator")

    def readlines(self, __hint: int = -1) -> List[bytes]:
        raise ValueError("Line-based methods are not available on ReadableIterator")

    def seekable(self) -> bool:
        return False

    def seek(self, __offset: int, __whence: int = io.SEEK_CUR) -> int:
        raise ValueError("Cannot seek")

    def tell(self) -> int:
        raise ValueError("Cannot tell")

    def writable(self) -> bool:
        return False

    def flush(self) -> None:
        raise ValueError("Cannot write")

    def truncate(self, __size: Optional[int] = None) -> int:
        raise ValueError("Cannot write")

    def write(self, __s: bytes) -> int:
        raise ValueError("Cannot write")

    def writelines(self, __lines: Iterable[bytes]) -> None:
        raise ValueError("Cannot write")

    def isatty(self) -> bool:
        return False

    def __next__(self) -> bytes:
        raise ValueError("Iterable methods are not available on ReadableIterator")

    def __iter__(self) -> Iterator[bytes]:
        raise ValueError("Iterable methods are not available on ReadableIterator")

In [128]:
def _site_domain(host):
    """Reduce a host name to the site-level domain used in ``top_cite_domain_set``.

    Mirrors the reduction in ``get_top_cite_domains``: the last two labels,
    or the last three when the last two form an entry of
    ``SECOND_LEVEL_DOMAINS``. Returns ``None`` when the host has too few
    labels.
    """
    parts = host.split('.')
    if len(parts) < 2:
        return None
    site_domain = parts[-2] + '.' + parts[-1]
    if site_domain in SECOND_LEVEL_DOMAINS:
        if len(parts) < 3:
            return None
        site_domain = parts[-3] + '.' + site_domain
    return site_domain


def submission_titles(pushshift_url):
    """Stream a zstd-compressed Pushshift submissions dump and yield tokenized titles.

    The dump at ``pushshift_url`` is downloaded, decompressed and parsed
    incrementally as a sequence of JSON values. A submission is used when:

    - it has both ``url`` and ``title``;
    - the site domain of its URL is in ``top_cite_domain_set``;
    - its (netloc, path) was not among the most recent ~1000 accepted
      submissions (cheap de-duplication of reposts).

    Yields ``(tokens, spans)`` where ``tokens`` are the normalized title
    tokens (left/right punctuation removed, ``renormalizations`` applied) and
    ``spans`` is a list of half-open index pairs ``(i, j)``: every candidate
    token range starting at ``i`` whose length does not exceed
    ``max_span_map[tokens[i]]``.

    Raises ``httpx.HTTPStatusError`` if the server answers with an error
    status, instead of feeding an error page to the decompressor.
    """
    with httpx.stream("GET", pushshift_url) as response:
        response.raise_for_status()
        fh = ReadableIterator(response.iter_bytes())
        total_s = response.headers.get("content-length")
        total = None
        if total_s is not None:
            total = int(total_s)
        with tqdm.wrapattr(fh, "read", total=total) as compressed:
            dctx = ZstdDecompressor()
            reader = dctx.stream_reader(compressed)
            # FIFO of recently seen keys plus a set for O(1) membership tests.
            dedup_q = deque()
            dedup_set = set()
            for submission in ijson.items(reader, "", multiple_values=True):
                if "url" not in submission or "title" not in submission:
                    continue
                url = urlsplit(submission["url"])
                # Compare the reduced site domain, not the raw netloc: the
                # keys of top_cite_domain_set are reduced domains, so a
                # netloc such as "www.example.com" (or one carrying a port)
                # would otherwise never match, and subdomains of excluded
                # sites would not be recognised as excluded.
                if _site_domain(url.hostname or "") not in top_cite_domain_set:
                    continue
                dedup_key = (url.netloc, url.path)
                if dedup_key in dedup_set:
                    continue
                if len(dedup_q) > 1000:
                    dedup_set.discard(dedup_q.popleft())
                dedup_q.append(dedup_key)
                dedup_set.add(dedup_key)
                title = xml_unescape(submission["title"])
                tokens = [
                    renormalizations.get(token.norm_, token.norm_)
                    for token in nlp.tokenizer(title)
                    if not (token.is_left_punct or token.is_right_punct)
                ]
                spans = []
                for i, w0 in enumerate(tokens):
                    # No known term starting with w0 is longer than this.
                    max_size = min(len(tokens) - i, max_span_map.get(w0, 0))
                    for j in range(i + 1, i + max_size + 1):
                        spans.append((i, j))
                yield tokens, spans

In [129]:
# Accumulate, per node id, a relevance score from the titles of one month of
# Reddit submissions.
#
# For each batch of up to 50 titles:
#   1. The token lists are sent to a local tagging service on 127.0.0.1:31323
#      as newline-delimited JSON, preceded by a header of two big-endian
#      uint32 values (payload length, then 0). The service answers with one
#      JSON line per title: a list with, for each token, its per-tag "scores"
#      and an "observations" count.
#   2. Each token gets a proper-noun score: the larger of its NNP/NNPS
#      probabilities under a softmax of the halved scores, smoothed towards a
#      prior (1 for alphabetic tokens, else 0) with a pseudo-count of 20.
#   3. Every candidate span whose mean token score is >= 0.01 becomes a query;
#      a span also records its two sub-spans that are one token shorter.
#   4. Query texts are looked up in term_map (weight >= 0.01). Sub-spans of a
#      matched span are suppressed, so only the longest match counts, and
#      each remaining match adds weight * span score to its node.
relevance = np.zeros(N, dtype=np.float64)
with (engine.connect() as conn,
      ensure_shutdown(socket.socket()) as sock,
):
    sock.connect(("127.0.0.1", 31323))
    print(sock.getsockname())
    # Close the reader explicitly: the socket's descriptor is only released
    # once every file object created by makefile() has been closed.
    with sock.makefile("r", encoding="utf-8") as rfile:
        for group in toolz.partition_all(50, submission_titles("https://files.pushshift.io/reddit/submissions/RS_2023-01.zst")):
            sentences = [json.dumps(tokens).encode("utf-8") + b"\n" for tokens, _ in group]
            req = b"".join(sentences)
            # sendall() keeps sending until the whole request is written; a
            # single unbuffered write may transmit only part of it.
            sock.sendall(struct.pack(">2I", len(req), 0) + req)
            # queries[qid] = (span score, child query ids)
            queries = []
            # span text -> ids of the queries with that text
            query_rev = defaultdict(list)
            for sentence, spans in group:
                # Always consume this title's response line to stay in step
                # with the service.
                resp = json.loads(rfile.readline())
                if not spans:
                    # No candidate spans: nothing below would be recorded.
                    continue
                scores = np.array([e["scores"] for e in resp])
                observations = np.array([e["observations"] for e in resp])
                priors = np.array([int(w.isalpha()) for w in sentence])
                # Numerically stable softmax over tags, computed in place.
                scores -= scores.max(axis=1)[:, None]
                scores *= 0.5
                np.exp(scores, out=scores)
                scores /= np.sum(scores, axis=1)[:, None]
                score = np.max(scores[:, [PENN_TAGS["NNP"], PENN_TAGS["NNPS"]]], axis=1)
                score = (score * observations + priors * 20) / (observations + 20)
                span_dict = {}
                for i, j in spans:
                    qid = len(queries)
                    span_score = score[i:j].mean()
                    if span_score >= 0.01:
                        span_dict[i, j] = qid
                        k = ' '.join(sentence[i:j])
                        queries.append((span_score, []))
                        query_rev[k].append(qid)
                for (i, j), qid in span_dict.items():
                    if j - i > 1:
                        children = []
                        if (i, j - 1) in span_dict:
                            children.append(span_dict[i, j - 1])
                        if (i + 1, j) in span_dict:
                            children.append(span_dict[i + 1, j])
                        span_score, _ = queries[qid]
                        queries[qid] = (span_score, children)
            rs = conn.execute(
                select(term_map_sql.c.term, term_map_sql.c.id, term_map_sql.c.weight)
                .where(term_map_sql.c.weight >= 0.01)
                .where(term_map_sql.c.term.in_(list(query_rev.keys())))
            )
            rs = list(rs)
            # Ids of queries that lie inside some matched, longer span.
            inner_matches = set()
            for term, _, _ in rs:
                for qid in query_rev[term]:
                    _, children = queries[qid]
                    stack = deque(children)
                    while len(stack):
                        curr = stack.pop()
                        inner_matches.add(curr)
                        _, grandchildren = queries[curr]
                        for c in grandchildren:
                            if c not in inner_matches:
                                stack.append(c)
            for term, node_id, weight in rs:
                for qid in query_rev[term]:
                    if qid in inner_matches:
                        continue
                    relevance[node_id] += weight * queries[qid][0]

('127.0.0.1', 59310)


  0%|          | 0/12445460428 [00:00<?, ?it/s]

In [130]:
# Show the PageRank table with the new `relevance` column, most relevant
# first; 20 is presumably the page size.
pager(PR.with_columns(pl.Series("relevance", relevance)).sort("relevance", descending=True), 20)

interactive(children=(IntSlider(value=165634, description='page', max=331268), Output()), _dom_classes=('widge…

<function nbhelpers.polars.pager.<locals>.<lambda>(page)>