<a href="https://colab.research.google.com/github/jmerizia/city-circuits/blob/main/city_circuits_prep_wikipedia_cities_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# define encoder

In [2]:
'''
Modified from https://github.com/graykode/gpt-2-Pytorch/blob/master/GPT2/encoder.py
See above for original license.
'''

"""Byte pair encoding utilities"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """Map every byte value 0-255 to a printable unicode character.

    Byte-level BPE operates on unicode strings, so each raw byte needs a
    single-character stand-in. Covering all 256 bytes this way avoids needing
    thousands of extra unicode characters in the vocabulary just to avoid UNKs.

    Bytes that are already printable (``!``..``~``, ``¡``..``¬``, ``®``..``ÿ``)
    map to themselves. Every remaining byte (whitespace, control characters,
    the soft hyphen) is assigned the next free code point starting at 256, in
    ascending byte order. The BPE code therefore never sees whitespace or
    control characters. The mapping is a bijection, so it is reversible.

    The result is memoized by ``lru_cache``: every call returns the same dict.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    mapping = {byte: chr(byte) for byte in printable}
    already_mapped = set(printable)
    next_code_point = 2 ** 8
    for byte in range(2 ** 8):
        if byte in already_mapped:
            continue
        mapping[byte] = chr(next_code_point)
        next_code_point += 1
    return mapping

def get_pairs(word):
    """Return the set of adjacent symbol pairs in *word*.

    *word* is a sequence (normally a tuple) of symbols, where each symbol is a
    variable-length string. A word with fewer than two symbols, including an
    empty one, has no pairs and yields an empty set instead of raising.
    """
    return set(zip(word, word[1:]))

class Encoder:
    """Byte-level BPE encoder/decoder adapted from the GPT-2 reference encoder.

    Args:
        encoder: mapping from BPE token string to integer id.
        bpe_merges: ordered sequence of ``(first, second)`` symbol pairs; a pair
            earlier in the sequence has a lower rank and is merged first.
        errors: error handler passed to ``bytes.decode`` in :meth:`decode`.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # merge pair -> rank (its index in bpe_merges); lower rank merges first.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # Memo of token -> space-joined BPE symbols; grows without bound.
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized
        # versions of contractions; left as-is because adding it would change tokenization.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the BPE merges to *token* (a string in the byte-to-unicode alphabet).

        Returns the resulting symbols joined by single spaces. Tokens with
        fewer than two characters are returned unchanged (and are not cached).
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Merge the adjacent pair with the lowest rank; stop when no pair is mergeable.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Split *text* with ``self.pat``, byte-encode each piece, and return BPE token ids."""
        bpe_tokens = []
        for token in self.pat.findall(text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, decoding the raw bytes as UTF-8 with ``self.errors``."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(encoder_path='./encoder.json', bpe_path='./vocab.bpe'):
    """Load the vocabulary and merge list from disk and build an ``Encoder``.

    Args:
        encoder_path: JSON file mapping BPE token strings to integer ids.
        bpe_path: merges file; its first line is a header and is skipped, and
            every following non-blank line holds one whitespace-separated pair.

    Returns:
        An ``Encoder`` built from the two files.
    """
    # Token strings use the byte-to-unicode alphabet, which includes non-ASCII
    # characters, so never rely on the platform's default encoding.
    with open(encoder_path, 'r', encoding='utf-8') as f:
        encoder = json.load(f)
    with open(bpe_path, 'r', encoding='utf-8') as f:
        bpe_lines = f.read().split('\n')
    # Skip the header; filtering blank lines (instead of slicing off the last line)
    # keeps the final merge even when the file has no trailing newline.
    bpe_merges = [tuple(line.split()) for line in bpe_lines[1:] if line.strip()]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

In [3]:
!curl -o encoder.json https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/master/GPT2/encoder.json
!curl -o vocab.bpe https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/master/GPT2/vocab.bpe

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1017k  100 1017k    0     0  3497k      0 --:--:-- --:--:-- --:--:-- 3497k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  445k  100  445k    0     0  1912k      0 --:--:-- --:--:-- --:--:-- 1912k


# build the dataset

In [112]:
import requests
import urllib.parse  # `import urllib` alone does not guarantee urllib.parse is loaded
from bs4 import BeautifulSoup


city_list_urls = [
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_A-B',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_C-D-E-F',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_G-H-I-J-K',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_L-M-N-O',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_P-Q-R-S',
    'https://en.wikipedia.org/wiki/List_of_towns_and_cities_with_100,000_or_more_inhabitants/country:_T-U-V-W-Y-Z',
]

# Tables carrying these classes are maintenance banners, not city tables.
_BANNER_TABLE_CLASSES = {'box-More_citations_needed', 'box-Cleanup'}
# List entries with this link text are not city links.
_SKIPPED_LINK_TEXTS = {'World largest cities'}


def _links_from_tables(content_div):
    """Return the unquoted href of the first link in the first cell of each non-header table row."""
    found = []
    for table in content_div.find_all('table'):
        # .get: tables without a class attribute would raise KeyError on table['class'].
        if _BANNER_TABLE_CLASSES.intersection(table.get('class', [])):
            continue
        for tr in table.find_all('tr')[1:]:  # skip the header row
            td = tr.find('td')
            a = td.find('a') if td is not None else None
            if a is not None and a.get('href'):
                found.append(urllib.parse.unquote(a['href']))
    return found


def _links_from_lists(content_div, list_tag):
    """Return unquoted '/wiki/' hrefs of the first link in each <li> of every <list_tag> element."""
    found = []
    for list_element in content_div.find_all(list_tag):
        for li in list_element.find_all('li'):
            a = li.find('a')
            if a is None:
                continue
            link_url = a.get('href', '')
            if link_url.startswith('/wiki/') and a.text not in _SKIPPED_LINK_TEXTS:
                found.append(urllib.parse.unquote(link_url))
    return found


links = []
for url in city_list_urls:
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    soup = BeautifulSoup(r.text, 'html.parser')
    content_divs = soup.find_all('div', {'class': 'mw-parser-output'})
    if len(content_divs) != 1:
        raise RuntimeError(f'expected one mw-parser-output div on {url}, found {len(content_divs)}')
    content_div = content_divs[0]
    links.extend(_links_from_tables(content_div))
    # some lists are in <ul> or <ol> tags
    links.extend(_links_from_lists(content_div, 'ul'))
    links.extend(_links_from_lists(content_div, 'ol'))

print('Collected', len(links), 'links to cities from various lists on Wikipedia.')

Collected 4394 links to cities from various lists on Wikipedia.


In [None]:
import time
import urllib.parse
import requests

# Free responses from a previous run before downloading everything again.
if 'data' in globals():
    del data

# The links were scraped from en.wikipedia.org, so request them there directly.
WIKIPEDIA_BASE_URL = 'https://en.wikipedia.org'


def download(url, attempts=2, timeout=30, backoff=1.0):
    """Fetch *url*, retrying on non-200 responses and on network errors.

    Args:
        url: absolute URL to GET.
        attempts: total number of tries.
        timeout: per-request timeout in seconds.
        backoff: seconds to wait before retry k is ``backoff * k``.

    Returns:
        The ``requests.Response`` for the first HTTP 200, or ``None`` if every attempt failed.
    """
    for attempt in range(attempts):
        try:
            r = requests.get(url, timeout=timeout)
        except requests.RequestException as err:
            print(f'failure for url {url} ({err})! trying again...')
        else:
            if r.status_code == 200:
                return r
            print(f'failure for url {url}! trying again...')
        if attempt + 1 < attempts:
            time.sleep(backoff * (attempt + 1))
    return None


urls = [WIKIPEDIA_BASE_URL + urllib.parse.quote(link) for link in links]

st = time.time()
data = []
for idx, url in enumerate(urls):
    r = download(url)
    if r is None:
        # Keep `data` a list of responses only; the parsing step reads .text/.url from each.
        print(idx, '/', len(urls), 'FAILED', url)
        continue
    data.append(r)
    print(idx, '/', len(urls), r.status_code)
en = time.time()
print(en-st)

In [None]:
def fix_typos(text, page_name):
    """Patch known typos in specific Wikipedia pages before the text is cleaned.

    Args:
        text: the page's concatenated paragraph text.
        page_name: name of the page, e.g. 'Djelfa'.

    Returns:
        *text* with the page-specific fix applied; other pages are returned unchanged.
    """
    # These typos should eventually be fixed upstream on Wikipedia itself.
    if page_name != 'Djelfa':
        return text
    # The lead opens a parenthesis that is never closed, so the unclosed fragment is dropped.
    return text.replace('(Arabic: الجلفة\u200e, romanized:\xa0al-Ǧilfah ', '')

@lru_cache(maxsize=None)
def _default_encoder():
    """Build, once, the encoder used when ``clean_text`` is called without one."""
    return get_encoder()


def _strip_bracketed(text):
    """Return *text* with every (...), [...] and {...} span removed, delimiters included."""
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    depth = {'(': 0, '[': 0, '{': 0}
    kept = []
    for ch in text:
        if ch in depth:
            depth[ch] += 1
        elif ch in closer_to_opener:
            opener = closer_to_opener[ch]
            # Clamp at zero: a stray closer must not push the depth negative,
            # which would silently discard the rest of the text.
            depth[opener] = max(0, depth[opener] - 1)
        elif not any(depth.values()):
            kept.append(ch)
    return ''.join(kept)


def clean_text(text, enc=None, num_tokens=30, max_chars=1000):
    """Normalize an article's text and truncate it to its first BPE tokens.

    Steps:
      * keep the first ``max_chars`` characters;
      * drop everything inside (), [] and {};
      * tidy spaces left before '.', ',' and around ';';
      * collapse all whitespace (including newlines) to single spaces;
      * keep only the first ``num_tokens`` tokens.

    Args:
        text: raw article text.
        enc: object with ``encode(str) -> list`` and ``decode(list) -> str``
            (normally an ``Encoder``); defaults to one built by ``get_encoder()``,
            created on first use and reused afterwards.
        num_tokens: number of tokens to keep.
        max_chars: number of leading characters to consider at all.

    Returns:
        The cleaned text, or ``None`` when fewer than 5 characters remain.
    """
    # we probably won't need more than max_chars characters
    text = _strip_bracketed(text[:max_chars])

    # fix strange punctuation left behind once bracketed spans are removed
    text = text.replace(' .', '.')
    text = text.replace(' ,', ',')
    # Replace with a space rather than deleting, which would fuse the neighbouring words.
    text = text.replace(' ; ', ' ')

    # put everything on one line and collapse runs of whitespace
    text = ' '.join(text.split())

    # possible degenerate cases
    if len(text) < 5:
        return None

    if enc is None:
        enc = _default_encoder()
    # Only a prefix is tokenized; assumes tokens average well under 10 characters.
    text = text[:num_tokens * 10]
    tokens = enc.encode(text)[:num_tokens]
    return enc.decode(tokens)


results = []
for idx, request in enumerate(data):
    # A download helper that gives up returns None rather than a response.
    if request is None:
        print(idx, '/', len(data), 'skipped: no response')
        continue
    soup = BeautifulSoup(request.text, 'html.parser')
    content_divs = soup.find_all('div', {'class': 'mw-parser-output'})
    if len(content_divs) != 1:
        # One page with an unexpected layout should not abort the whole run.
        print(idx, '/', len(data), request.url, ': skipped, found', len(content_divs), 'content divs')
        continue
    content_div = content_divs[0]
    # remove the coordinates section so it does not end up in the article text
    for span in content_div.find_all('span', id='coordinates'):
        span.extract()
    # Each non-empty paragraph contributes ' ' + its stripped text.
    text = ''.join(
        ' ' + line
        for line in (paragraph.text.strip() for paragraph in content_div.find_all('p'))
        if line
    )
    page_name = urllib.parse.unquote(request.url).split('/')[-1]
    text = fix_typos(text, page_name)
    text = clean_text(text)
    results.append({'url': request.url, 'text': text})
    print(idx, '/', len(data), request.url, ':', text)

In [236]:
# save the results to disk
# Persist the list of {'url', 'text'} records as one JSON array.
import json
with open('dataset.json', 'w') as out_file:
    json.dump(results, out_file)