<a class="anchor" id="top"></a>
# Code Playground

A quick way to explore the main functionality of the `mwsql` package, either locally or in PAWS (Wikimedia's hosted Jupyter notebook environment).

[Click here](#playground) to skip ahead and see the code in action.

In [1]:
!pip install wget



## Code

### Contents of /src/mwsql/utils.py

In [4]:
'''Helper functions used in src/mwsql.py'''

import csv
import gzip
import re
import sys

from typing import Iterator, List, Tuple

# Allow very long fields (raise the csv module's per-field size limit).
# `sys.maxsize` does not fit in a C long on some platforms (e.g. 64-bit
# Windows), where `csv.field_size_limit` raises OverflowError, so back off
# until the value is accepted.
_field_size_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_field_size_limit)
        break
    except OverflowError:
        _field_size_limit //= 10


def head(file_path, n_lines=10):
    '''Display top of compressed file, similar to `zcat | head` UNIX utility

    Parameters:
        file_path: path of a gzip-compressed text file, read as UTF-8.
        n_lines: maximum number of lines to print. Zero or a negative
            number prints nothing.

    Each line is printed with leading and trailing whitespace stripped.
    '''

    with gzip.open(file_path, 'rt', encoding='utf-8') as infile:
        # `zip` stops as soon as the range is exhausted, so at most
        # `n_lines` lines are read and a count <= 0 reads none at all
        # (a manual countdown never reaches 0 when it starts at <= 0).
        for _, line in zip(range(n_lines), infile):
            print(line.strip())


# mwsql helper functions
def is_insert_statement(line: str) -> bool:
    '''Return True when `line` begins with the SQL keywords `INSERT INTO`.'''

    keyword = 'INSERT INTO'
    return line.startswith(keyword)


def is_create_statement(line: str) -> bool:
    '''Return True when `line` begins with the SQL keywords `CREATE TABLE`.'''

    keyword = 'CREATE TABLE'
    return line.startswith(keyword)


def get_table_name(line: str) -> str:
    '''Return the backtick-quoted table name found in an SQL statement.

    The name is the first run of non-whitespace characters enclosed in
    backticks. A line without such a name raises AttributeError.
    '''

    match = re.search(r'`([\S]*)`', line)
    return match.group(1)


def has_col_name(line: str) -> bool:
    '''Return True when the first non-whitespace character of `line`
    is a backtick, i.e. the line starts with a quoted SQL column name.
    '''

    first_token = line.lstrip()
    return first_token.startswith('`')


def get_col_name(line: str) -> Tuple[str, str]:
    '''Extract an SQL column name and its data type from a column
    definition line of a `CREATE TABLE` statement.

    Parameters:
        line: a line such as
            "`ctd_id` int(10) unsigned NOT NULL AUTO_INCREMENT,".

    Returns:
        A `(col_name, col_dtype)` tuple, where `col_dtype` is everything
        that follows the quoted name, without the comma that separates
        column definitions.

    Raises:
        AttributeError: if the line contains no backtick-quoted name, or
            nothing follows the name.
    '''

    col_name_pattern = r'`([\S]*)`'
    col_name = re.search(col_name_pattern, line).group(1)

    # Take the rest of the line after the closing backtick, then drop the
    # separating comma only if it is present. The last column of a table
    # with no key definitions has no trailing comma, so the comma must not
    # be required for the line to match.
    col_dtype_pattern = r'` (.*)'
    col_dtype = re.search(col_dtype_pattern, line).group(1).rstrip()
    if col_dtype.endswith(','):
        col_dtype = col_dtype[:-1]

    return col_name, col_dtype


def has_primary_key(line: str) -> bool:
    '''Check whether a string contains an SQL primary key

    Returns True when `line`, ignoring surrounding whitespace, starts
    with the keywords `PRIMARY KEY`.
    '''

    return line.strip().startswith('PRIMARY KEY')


def get_primary_key(line: str) -> List[str]:
    '''Extract SQL table primary key from string

    Parameters:
        line: a `PRIMARY KEY (...)` line of a `CREATE TABLE` statement.

    Returns:
        The list of backtick-quoted column names that make up the key, in
        the order they appear. A composite key yields several names. The
        list is empty when the line has no quoted names.
    '''

    # Collect each quoted name individually, so that anything between the
    # names (commas, prefix lengths such as `col`(10)) is left out.
    return re.findall(r'`([^`]+)`', line)


def parse_records(records: List[str]):
    '''Wrap a list of SQL value tuples (rows) in a csv.reader.

    Each string in `records` holds the comma-separated values of one row.
    The reader is configured for the quoting used here: single-quoted
    strings, backslash as the escape character, no doubled quotes, and
    strict error reporting on malformed input.
    '''

    sql_dialect = {
        'delimiter': ',',
        'doublequote': False,
        'escapechar': '\\',
        'quotechar': "'",
        'strict': True,
    }
    return csv.reader(records, **sql_dialect)


# A single-quoted SQL string literal, honouring backslash escapes.
_SQL_STRING = r"'(?:[^'\\]|\\.)*'"
# One parenthesised value tuple; group 1 is the text between the parentheses.
# Quoted strings are consumed as a unit, so `(`, `)` and `),(` inside a
# string value do not end the tuple.
_RECORD_PATTERN = re.compile(
    r"\(((?:" + _SQL_STRING + r"|[^'()])*)\)", re.DOTALL)
# Either a quoted string (group 1) or a bare NULL keyword.
_NULL_PATTERN = re.compile(r"(" + _SQL_STRING + r")|NULL", re.DOTALL)


def _null_to_empty_string(match):
    '''Keep quoted strings untouched; turn a bare NULL into `''`.'''

    return match.group(1) or "''"


def get_records(line: str) -> Iterator[List[str]]:
    '''Split an SQL `insert into` statement into its value tuples and
    return them as a csv.reader, which yields one list of strings per row.

    Parameters:
        line: a full `INSERT INTO ... VALUES (...),(...);` statement.

    Returns:
        The csv.reader produced by `parse_records`. It yields nothing when
        `line` contains no value tuples.

    Unquoted `NULL` values are converted to empty strings. Text inside
    quoted strings is never altered, and a `),(` sequence inside a quoted
    string does not split a row.
    '''

    values = line.partition(' VALUES ')[-1].strip()

    # Scanning for whole tuples ignores the trailing `;` and is safe on
    # empty input.
    records = [
        _NULL_PATTERN.sub(_null_to_empty_string, record.group(1))
        for record in _RECORD_PATTERN.finditer(values)
    ]

    return parse_records(records)


### Contents of /src/mwsql/load.py

In [35]:
'''Utilities for easy loading of dump files from
PAWS or from the web'''

import os
# import sys
import wget

# from mwsql import Dump
from urllib.error import HTTPError


def progress_bar(current, total, width=60):
    '''Custom progress bar for wget downloads

    Parameters:
        current: number of bytes downloaded so far.
        total: total size of the download in bytes. A value of zero or
            less is treated as "size unknown".
        width: unused here; presumably kept so the signature matches what
            wget passes to its `bar` callback (TODO confirm).

    Writes the progress message to stdout, overwriting the current line.
    '''

    MB = 1024 * 1024
    unit = 'bytes'
    size_known = total > 0

    # Show file size in MB for large files. When the total is unknown,
    # fall back on the amount downloaded so far to pick the unit.
    reference_size = total if size_known else current
    if reference_size >= 100000:
        current = current / MB
        total = total / MB
        unit = 'MB'

    if size_known:
        progress = current / total
        progress_message = f"Progress: {progress:.0%} [{current:.1f} / {total:.1f}] {unit}"
    else:
        # No percentage can be computed without a total; dividing by it
        # would raise ZeroDivisionError or give a meaningless ratio.
        progress_message = f"Progress: {current:.1f} {unit}"

    sys.stdout.write('\r' + progress_message)
    sys.stdout.flush()


def get_source(db, filename):
    '''Build the location of a dump file and report whether it has to
    be downloaded.

    When the PAWS public dumps directory exists on this machine, the
    location is a path inside it and no download is needed. Otherwise it
    is a URL on dumps.wikimedia.org that must be downloaded.

    Returns a `(source, download)` tuple.
    '''

    paws_dir = '/public/dumps/public/'
    in_paws = os.path.exists(paws_dir)

    # Local directory in PAWS, web URL everywhere else
    prefix = paws_dir if in_paws else 'https://dumps.wikimedia.org/'

    source = f'{prefix}{db}/latest/{db}-latest-{filename}.sql.gz'
    return source, not in_paws


def load(db, filename):
    '''Load dump file from public dir if in PAWS, else download
    from the web if the file doesn't already exist in the user's
    current working directory

    Parameters:
        db: database name used in the dump path, e.g. 'simplewiki'.
        filename: table part of the dump file name, e.g. 'change_tag_def'.

    Returns:
        A `Dump` created from the file, or None when the download fails
        or the dump cannot be created.
    '''

    source, download = get_source(db, filename)

    if download:
        cwd = os.getcwd()
        file_path = os.path.join(cwd, os.path.basename(source))

        # Reuse a previously downloaded copy instead of fetching the
        # same file again, as the docstring promises.
        if os.path.exists(file_path):
            print(f'Using existing file {file_path}')
        else:
            try:
                print(f'Downloading {source}')
                file = wget.download(source, bar=progress_bar)
                file_path = os.path.join(cwd, file)

            except HTTPError:
                print('File not found')
                return None

    else:
        file_path = source

    try:
        dump = Dump.from_file(file_path)
        return dump

    # `Exception` rather than a bare `except`, so that KeyboardInterrupt
    # and SystemExit still propagate.
    except Exception:
        print("Couldn't create dump")
        return None


### Contents of /src/mwsql/mwsql.py

In [43]:
'''A set of utilities for processing MediaWiki SQL dump data'''

# __version__ = '0.1.0'

# import os
# import gzip

# from mwsql import utils


class Dump:
    '''Class for parsing an SQL dump file and processing its contents

    Attributes:
        name: table name taken from the `CREATE TABLE` statement.
        col_names: list of column names, in table order.
        dtypes: dict mapping each column name to its SQL type definition.
        primary_key: list of primary key column names, or None.
        size: size of the compressed source file in bytes.
        encoding: text encoding used to read the source file.
    '''

    def __init__(self, table_name, col_names, col_dtypes, primary_key,
                 source_file, encoding='ISO-8859-1'):

        self.name = table_name
        self.col_names = col_names
        self.dtypes = col_dtypes
        self.primary_key = primary_key
        self.size = os.path.getsize(source_file)
        self.encoding = encoding
        self._source_file = source_file

    def __iter__(self):
        return self.rows

    @property
    def rows(self):
        '''Create a generator object from the rows

        Every access returns a new generator that re-reads the source
        file from the start and yields one row (a list of strings) at a
        time.
        '''

        with gzip.open(self._source_file, 'rt', encoding=self.encoding) as infile:
            for line in infile:
                if is_insert_statement(line):
                    yield from get_records(line)

    @classmethod
    def from_file(cls, file_path, encoding='ISO-8859-1'):
        '''Initialize mwsql object from dump file

        Reads the file up to the first `INSERT INTO` statement and
        collects the table name, column names, column types and primary
        key along the way.

        Parameters:
            file_path: path of the gzip-compressed SQL dump.
            encoding: text encoding used to read the file.

        Returns:
            A `Dump`, or None when the file contains neither a
            `CREATE TABLE` nor an `INSERT INTO` statement.
        '''

        source_file = file_path
        table_name = None
        primary_key = None
        col_names = []
        col_dtypes = {}

        with gzip.open(file_path, 'rt', encoding=encoding) as infile:
            for line in infile:
                if is_insert_statement(line):
                    # All metadata is extracted, no need to read further
                    break
                if is_create_statement(line):
                    table_name = get_table_name(line)
                elif has_col_name(line):
                    col_name, dtype = get_col_name(line)
                    col_names.append(col_name)
                    col_dtypes[col_name] = dtype
                elif has_primary_key(line):
                    primary_key = get_primary_key(line)
            else:
                # End of file without any INSERT statement. A table
                # definition without rows is still a valid (empty) dump;
                # a file without a table definition is not.
                if table_name is None:
                    return None

        return cls(table_name, col_names, col_dtypes, primary_key,
                   source_file, encoding=encoding)

    def to_csv(self, file_path):
        '''Convert mwsql object into CSV file

        Not implemented yet. Intended behavior: create the specified
        outfile if it doesn't exist, and raise an error if it already
        exists to avoid overwriting.
        '''

        raise NotImplementedError

    def head(self, n_lines=5):
        '''Display first n rows

        Prints the column names and returns a list of at most `n_lines`
        rows. Fewer rows are returned when the dump holds fewer.
        '''

        rows = self.rows
        print(self.col_names)
        try:
            # `zip` stops at whichever runs out first, so a short dump
            # does not raise StopIteration.
            return [row for _, row in zip(range(n_lines), rows)]
        finally:
            # Close the generator so the underlying file is released now
            # rather than whenever it gets garbage collected.
            rows.close()


<a class="anchor" id="playground"></a>
## Playground
[Return to top](#top)

In [44]:
dump = load('simplewiki', 'change_tag_def')

Downloading https://dumps.wikimedia.org/simplewiki/latest/simplewiki-latest-change_tag_def.sql.gz
Progress: 100% [2122.0 / 2122.0] bytes

In [46]:
dump.name

'change_tag_def'

In [47]:
# display dump size in bytes
dump.size

2122

In [48]:
dump.col_names

['ctd_id', 'ctd_name', 'ctd_user_defined', 'ctd_count']

In [49]:
# original SQL data types
dump.dtypes

{'ctd_id': 'int(10) unsigned NOT NULL AUTO_INCREMENT',
 'ctd_name': 'varbinary(255) NOT NULL',
 'ctd_user_defined': 'tinyint(1) NOT NULL',
 'ctd_count': 'bigint(20) unsigned NOT NULL DEFAULT 0'}

In [51]:
dump.primary_key

['ctd_id']

In [50]:
dump.head()

['ctd_id', 'ctd_name', 'ctd_user_defined', 'ctd_count']


[['1', 'mw-replace', '0', '9840'],
 ['2', 'visualeditor', '0', '299284'],
 ['3', 'mw-undo', '0', '55938'],
 ['4', 'mw-rollback', '0', '69231'],
 ['5', 'mobile edit', '0', '222801']]

In [56]:
# create generator
rows = dump.rows

In [57]:
for _ in range(10):
    print(next(rows))

['1', 'mw-replace', '0', '9840']
['2', 'visualeditor', '0', '299284']
['3', 'mw-undo', '0', '55938']
['4', 'mw-rollback', '0', '69231']
['5', 'mobile edit', '0', '222801']
['6', 'mobile web edit', '0', '215421']
['7', 'very short new article', '0', '28126']
['8', 'visualeditor-wikitext', '0', '19553']
['9', 'mw-new-redirect', '0', '28885']
['10', 'visualeditor-switched', '0', '17278']
