<a href="https://colab.research.google.com/github/jmerizia/city-circuits/blob/main/city_circuits_prep_wikipedia_cities_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount the user's Google Drive inside the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# define encoder

In [2]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/encoder.py
See above for original license.
'''

"""Byte pair encoding utilities"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte value (0..255) is assigned, in
    ascending byte order, to the next unused code point starting at 256, so
    no byte maps onto one of the excluded whitespace/control positions.

    Returns a dict with exactly 256 entries {int byte: one-character str}.
    The result is memoised by lru_cache, so callers share one dict object.
    """
    identity_bytes = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in identity_bytes}

    # Remaining bytes are shifted above the single-byte range, in order.
    shifted = 0
    for byte in range(2**8):
        if byte in table:
            continue
        table[byte] = chr(2**8 + shifted)
        shifted += 1
    return table

def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: a sequence of symbols (a tuple of variable-length strings, or a
            plain string). It must support slicing.

    Returns:
        A set of ``(left, right)`` tuples, one for each pair of neighbouring
        symbols. Words with fewer than two symbols, including an empty word,
        yield an empty set instead of raising IndexError.
    """
    # Zipping the word with itself shifted by one walks every adjacent pair
    # and naturally produces nothing for 0- or 1-symbol words.
    return set(zip(word, word[1:]))

class Encoder:
    """Byte-level BPE tokenizer: text -> token ids and back.

    Args:
        encoder: dict mapping a BPE token string to its integer id.
        bpe_merges: ordered sequence of ``(first, second)`` merge pairs; a
            pair's position is its rank, and lower ranks are merged first.
        errors: error-handling mode passed to ``bytes.decode`` in ``decode``.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}  # token -> space-joined BPE symbols, filled by bpe()

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions.
        # NOTE(review): the pattern is deliberately left as-is; changing it would change how text is
        # split and therefore the ids produced -- presumably it must stay in sync with the upstream encoder.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the ranked merges to ``token`` and return its symbols joined by spaces.

        Results for multi-symbol tokens are memoised in ``self.cache``. A token
        with no adjacent pairs (fewer than two characters) is returned as-is.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the pair with the lowest rank; unknown pairs rank as +inf.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    # Only ValueError is expected from tuple.index; anything
                    # else (e.g. KeyboardInterrupt) must propagate.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                # word[i] == first here; merge only if `second` follows it.
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize ``text`` and return the list of BPE token ids.

        Raises KeyError if a resulting BPE symbol is missing from ``self.encoder``.
        """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Re-express the token's UTF-8 bytes in the printable byte alphabet.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn a sequence of token ids back into text.

        The ids may end in the middle of a multi-byte character; such bytes are
        handled according to ``self.errors``.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(encoder_path='./encoder.json', vocab_path='./vocab.bpe'):
    """Load the vocabulary and merge list from disk and build an Encoder.

    Args:
        encoder_path: JSON file holding the token-string -> id mapping.
        vocab_path: text file of merges, one whitespace-separated pair per
            line. Its first line is skipped (presumably a version header).

    Returns:
        An ``Encoder`` built from the two files.
    """
    # Read both files as UTF-8 explicitly so the result does not depend on
    # the platform's default locale encoding.
    with open(encoder_path, 'r', encoding='utf-8') as f:
        encoder = json.load(f)
    with open(vocab_path, 'r', encoding='utf-8') as f:
        bpe_lines = f.read().split('\n')
    # Skip the first line and any blank lines. Filtering blanks, rather than
    # unconditionally dropping the last line, keeps the final merge when the
    # file has no trailing newline.
    bpe_merges = [tuple(line.split()) for line in bpe_lines[1:] if line.strip()]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

In [3]:
!curl -o encoder.json https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/master/GPT2/encoder.json
!curl -o vocab.bpe https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/master/GPT2/vocab.bpe

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1017k  100 1017k    0     0  3497k      0 --:--:-- --:--:-- --:--:-- 3497k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  445k  100  445k    0     0  1912k      0 --:--:-- --:--:-- --:--:-- 1912k


# build the dataset

In [None]:
import tensorflow_datasets as tfds

# Load the train split of the 'wikipedia/20190301.en' dataset, using
# ./wikipedia as the data directory, and keep the accompanying info object.
dataset, dataset_info = tfds.load(
    split=tfds.Split.TRAIN,
    with_info=True,
    data_dir='./wikipedia',
    name='wikipedia/20190301.en',
)

# Cell output: how many examples the train split contains.
dataset_info.splits['train'].num_examples

In [None]:
from tqdm import tqdm
import pickle

# Module-level BPE encoder; process() below calls enc.encode / enc.decode to
# truncate each article to a fixed number of tokens.
enc = get_encoder()

def process(example, num_tokens=30):
    """Turn one raw Wikipedia example into a short text snippet plus categories.

    Args:
        example: mapping whose ``'text'`` entry exposes ``.numpy()`` returning
            the article as UTF-8 bytes.
        num_tokens: maximum number of BPE tokens kept in the snippet.

    Returns:
        * ``None`` when the article is skipped (shorter than 300 characters,
          a disambiguation page, or nothing usable left after cleaning);
        * the sentinel string ``'a'`` for the debugging hook described below;
        * otherwise ``{'text': str, 'categories': list[str]}``.

    Relies on the module-level ``enc`` encoder.
    """
    text = example['text'].numpy().decode('utf-8')

    # only take articles with a considerable number of words
    if len(text) < 300:
        return None

    lines = text.split('\n')

    # skip disambiguation pages
    if 'may refer to' in lines[0]:
        return None

    # NOTE(review): debugging hook -- prints the first article starting with
    # 'Paris' and returns a sentinel that makes the driving loop stop. Kept
    # because the caller checks for 'a'; remove both sides together.
    if text.startswith('Paris'):
        print(text)
        return 'a'

    # extract categories (lines of the form "Category:<name>")
    categories = [
        stripped[len('Category:'):].strip()
        for stripped in (line.strip() for line in lines)
        if stripped.startswith('Category:')
    ]

    # try to skip titles as much as possible: keep only the longer lines
    text = '\n'.join(line for line in lines if len(line) > 50)

    # remove parenthesized portions
    kept = []
    depth = 0
    for ch in text:
        if ch == '(':
            depth += 1
        elif ch == ')':
            # Clamp at zero: a stray ')' must not push the counter negative,
            # which would make all following top-level text look nested.
            depth = max(depth - 1, 0)
        elif depth == 0:
            kept.append(ch)
    # Joining once avoids quadratic string concatenation on long articles.
    text = ''.join(kept)

    # put everything on one line and collapse runs of white space
    text = ' '.join(text.split())

    # possible degenerate cases
    if len(text) < 5:
        return None

    # only take the first few tokens; the character pre-cut just bounds the
    # amount of text handed to the encoder
    text = text[:num_tokens * 10]
    tokens = enc.encode(text)[:num_tokens]
    return {'text': enc.decode(tokens), 'categories': categories}


# Run process() over the dataset and collect the resulting records.
bucket = []
bucket_idx = 0
for example in tqdm(dataset, miniters=1000):
    res = process(example)
    # 'a' is the sentinel process() returns from its debugging hook.
    if res == 'a':
        break
    # process() returns None for articles it decided to skip; do not store
    # those placeholders alongside real records.
    if res is None:
        continue
    bucket.append(res)
    # if len(bucket) > 200*1000:
    #     print('saving')
    #     fn = f'examples-{bucket_idx:05}.pickle'
    #     with open(fn, 'wb') as f:
    #         pickle.dump(bucket, f)
    #     bucket = []
    #     bucket_idx += 1

# fn = f'examples-{bucket_idx:05}.pickle'
# with open(fn, 'wb') as f:
#     pickle.dump(bucket, f)

In [112]:
import requests
import urllib.parse  # `import urllib` alone does not guarantee the `parse` submodule is loaded
from bs4 import BeautifulSoup


city_list_urls = [
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_A-B',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_C-D-E-F',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_G-H-I-J-K',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_L-M-N-O',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_P-Q-R-S',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_T-U-V-W-Y-Z',
]

# Tables carrying one of these classes are skipped instead of being scanned for city rows.
_SKIPPED_TABLE_CLASSES = ('box-More_citations_needed', 'box-Cleanup')


def _table_links(content_div):
    """Return the unquoted href of the first link in the first cell of each table row."""
    found = []
    for table in content_div.find_all('table'):
        # .get() instead of ['class']: a table without a class attribute must not raise KeyError.
        table_classes = table.get('class', [])
        if any(cls in table_classes for cls in _SKIPPED_TABLE_CLASSES):
            continue
        # The first row of each table is skipped, as before.
        for tr in table.find_all('tr')[1:]:
            td = tr.find('td')
            if td is None:
                continue
            a = td.find('a')
            # A cell may hold plain text only; skip it rather than crash.
            if a is None or not a.get('href'):
                continue
            found.append(urllib.parse.unquote(a['href']))
    return found


def _list_links(content_div, list_tag):
    """Return unquoted /wiki/ hrefs of the first link in every <li> under each `list_tag`."""
    found = []
    for list_node in content_div.find_all(list_tag):
        for li in list_node.find_all('li'):
            a = li.find('a')
            if a is None:
                continue
            link_url = a.get('href', '')
            if link_url.startswith('/wiki/') and a.text != 'World largest cities':
                found.append(urllib.parse.unquote(link_url))
    return found


links = []
for url in city_list_urls:
    r = requests.get(url)
    # Fail loudly on an HTTP error instead of parsing an error page.
    r.raise_for_status()
    soup = BeautifulSoup(r.text, 'html.parser')
    content_divs = soup.find_all('div', { 'class': 'mw-parser-output' } )
    assert len(content_divs) == 1
    content_div = content_divs[0]
    links.extend(_table_links(content_div))
    # some lists are in <ul> or <ol> tags
    links.extend(_list_links(content_div, 'ul'))
    links.extend(_list_links(content_div, 'ol'))

print('Collected', len(links), 'links to cities from various lists on Wikipedia.')

Collected 4394 links to cities from various lists on Wikipedia.


In [160]:
# NOTE(review): no cell visible in this notebook imports aiohttp -- confirm this install is still needed.
!pip install aiohttp

Collecting aiohttp
  Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)
[?25l[K     |▎                               | 10 kB 15.8 MB/s eta 0:00:01[K     |▌                               | 20 kB 18.3 MB/s eta 0:00:01[K     |▊                               | 30 kB 21.7 MB/s eta 0:00:01[K     |█                               | 40 kB 23.7 MB/s eta 0:00:01[K     |█▎                              | 51 kB 25.2 MB/s eta 0:00:01[K     |█▌                              | 61 kB 13.5 MB/s eta 0:00:01[K     |█▊                              | 71 kB 12.5 MB/s eta 0:00:01[K     |██                              | 81 kB 13.6 MB/s eta 0:00:01[K     |██▏                             | 92 kB 14.6 MB/s eta 0:00:01[K     |██▌                             | 102 kB 15.7 MB/s eta 0:00:01[K     |██▊                             | 112 kB 15.7 MB/s eta 0:00:01[K     |███                             | 122 kB 15.7 MB/s eta 0:00:01[K     |███▏                           

In [230]:
# scrape and save the files
import time
import requests

# Drop the `data` list left over from a previous run of this cell before it is
# rebuilt below (presumably to release the old responses' memory -- confirm).
if 'data' in globals():
    del data

def download(url, attempts=2, timeout=30):
    """Fetch ``url`` with a small number of retries.

    Args:
        url: address to GET.
        attempts: how many times to try before giving up.
        timeout: per-request timeout in seconds, so a stalled connection
            cannot hang the whole scrape.

    Returns:
        The first response whose status code is 200, or ``None`` when every
        attempt failed (non-200 status or a network-level error).
    """
    for _ in range(attempts):
        try:
            r = requests.get(url, timeout=timeout)
        except requests.RequestException as exc:
            # Connection errors and timeouts count as a failed attempt
            # instead of aborting the caller's loop.
            print(f'failure for url {url}! ({exc}) trying again...')
            continue
        if r.status_code == 200:
            return r
        print(f'failure for url {url}! trying again...')
    return None

# The hrefs in `links` are site-relative paths scraped from en.wikipedia.org,
# so resolve them against that host directly rather than relying on a
# redirect from another domain.
urls = ['https://en.wikipedia.org' + urllib.parse.quote(link) for link in links]

st = time.time()
data = []         # successful responses only
failed_urls = []  # pages that could not be fetched after all retries
for idx, url in enumerate(urls):
    r = download(url)
    if r is None:
        # download() returns None when it gave up; record the URL and move on
        # instead of crashing on r.status_code or storing a None placeholder.
        failed_urls.append(url)
        print(idx, '/', len(urls), 'FAILED')
        continue
    data.append(r)
    print(idx, '/', len(urls), r.status_code)
en = time.time()
print(en-st)
if failed_urls:
    print('could not download', len(failed_urls), 'pages:', failed_urls)

# fn = f'page-{idx:05}.txt'
# with open(fn, 'w') as f:
#     f.write(r.text)

0 / 4394 200
1 / 4394 200
2 / 4394 200
3 / 4394 200
4 / 4394 200
5 / 4394 200
6 / 4394 200
7 / 4394 200
8 / 4394 200
9 / 4394 200
10 / 4394 200
11 / 4394 200
12 / 4394 200
13 / 4394 200
14 / 4394 200
15 / 4394 200
16 / 4394 200
17 / 4394 200
18 / 4394 200
19 / 4394 200
20 / 4394 200
21 / 4394 200
22 / 4394 200
23 / 4394 200
24 / 4394 200
25 / 4394 200
26 / 4394 200
27 / 4394 200
28 / 4394 200
29 / 4394 200
30 / 4394 200
31 / 4394 200
32 / 4394 200
33 / 4394 200
34 / 4394 200
35 / 4394 200
36 / 4394 200
37 / 4394 200
38 / 4394 200
39 / 4394 200
40 / 4394 200
41 / 4394 200
42 / 4394 200
43 / 4394 200
44 / 4394 200
45 / 4394 200
46 / 4394 200
47 / 4394 200
48 / 4394 200
49 / 4394 200
50 / 4394 200
51 / 4394 200
52 / 4394 200
53 / 4394 200
54 / 4394 200
55 / 4394 200
56 / 4394 200
57 / 4394 200
58 / 4394 200
59 / 4394 200
60 / 4394 200
61 / 4394 200
62 / 4394 200
63 / 4394 200
64 / 4394 200
65 / 4394 200
66 / 4394 200
67 / 4394 200
68 / 4394 200
69 / 4394 200
70 / 4394 200
71 / 4394 200
72

In [234]:
def fix_typos(text, page_name):
    """Patch known markup mistakes in specific Wikipedia pages.

    Only the page named 'Djelfa' is handled: its text contains an opening
    parenthesis that is never closed, so that fragment is removed. Text of
    any other page is returned unchanged.
    """
    # fix typos in Wikipedia (we should fix these in Wikipedia shortly)
    if page_name != 'Djelfa':
        return text
    # missing paren
    unclosed_fragment = '(Arabic: الجلفة\u200e, romanized:\xa0al-Ǧilfah '
    return text.replace(unclosed_fragment, '')

def clean_text(text, num_tokens=30):
    """Reduce scraped article text to a short, single-line snippet.

    Steps: keep the first 1000 characters, drop everything inside (), [] and
    {} (brackets included), tidy punctuation left behind by those removals,
    collapse white space, then cut to at most ``num_tokens`` BPE tokens using
    the module-level ``enc`` encoder.

    Args:
        text: raw text of a page.
        num_tokens: maximum number of tokens kept in the result.

    Returns:
        The cleaned snippet, or ``None`` if fewer than 5 characters remain
        after cleaning.
    """
    # we probably won't need more than 1k characters
    text = text[:1000]

    # remove parenthesized portions; each bracket family has its own depth
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    depth = {'(': 0, '[': 0, '{': 0}
    kept = []
    for ch in text:
        if ch in depth:
            depth[ch] += 1
        elif ch in closer_to_opener:
            opener = closer_to_opener[ch]
            # Clamp at zero: a stray closing bracket must not push the
            # counter negative and hide the top-level text that follows it.
            depth[opener] = max(depth[opener] - 1, 0)
        elif not any(depth.values()):
            kept.append(ch)
    # Joining once avoids quadratic string concatenation.
    text = ''.join(kept)

    # fix strange punctuation
    text = text.replace(' .', '.')
    text = text.replace(' ,', ',')
    text = text.replace(' ; ', '')

    # put everything on one line and collapse runs of white space
    text = ' '.join(text.split())

    # possible degenerate cases
    if len(text) < 5:
        return None

    # only take the first few tokens; the character pre-cut just bounds the
    # amount of text handed to the encoder
    text = text[:num_tokens * 10]
    tokens = enc.encode(text)[:num_tokens]
    return enc.decode(tokens)


# Extract a cleaned lead snippet from every downloaded page.
results = []
for idx, request in enumerate(data):
    soup = BeautifulSoup(request.text, 'html.parser')
    content_divs = soup.find_all('div', { 'class': 'mw-parser-output' } )
    assert len(content_divs) == 1
    content_div = content_divs[0]

    # remove coordinates section so it does not end up in the paragraph text
    coordinate_spans = [
        span for span in content_div.find_all('span')
        if span.get('id') == 'coordinates'
    ]
    for span in coordinate_spans:
        span.extract()

    # concatenate all non-empty paragraphs, each one preceded by a space
    stripped_paragraphs = [p.text.strip() for p in content_div.find_all('p')]
    text = ''.join(' ' + line for line in stripped_paragraphs if len(line) > 0)

    # the page name is the last path segment of the (unquoted) final URL
    page_name = urllib.parse.unquote(request.url).split('/')[-1]
    text = clean_text(fix_typos(text, page_name))

    results.append({ 'url': request.url, 'text': text })
    print(idx, '/', len(data), request.url, ':', text)

    # TODO: also collect the tokenized version


0 / 4394 https://en.wikipedia.org/wiki/Kabul : Kabul is the capital and largest city of Afghanistan, located in the eastern section of the country. It is also a municipality, forming part of the
1 / 4394 https://en.wikipedia.org/wiki/Herat : Herāt is an oasis city and the third-largest city of Afghanistan. In 2020, it had an estimated population of 574,276
2 / 4394 https://en.wikipedia.org/wiki/Kandahar : Kandahar is a city in Afghanistan, located in the south of the country on the Arghandab River, at an elevation of 1,
3 / 4394 https://en.wikipedia.org/wiki/Mazar-i-Sharif : Mazār-i-Sharīf, also called Mazār-e Sharīf, or just Mazar, is the fourth
4 / 4394 https://en.wikipedia.org/wiki/Jalalabad : Jalalabad is the fifth-largest city of Afghanistan. It has a population of about 356,274, and serves as the capital of N
5 / 4394 https://en.wikipedia.org/wiki/Kunduz : Kunduz is a city in northern Afghanistan, the capital of Kunduz Province. The city has a population of about 374,746, making


In [236]:
# Write the collected {'url', 'text'} records to dataset.json as one JSON array.
import json

with open('dataset.json', 'w') as out_file:
    json.dump(results, out_file)