# Create Wiki Dataset

## Download the wiki dataset

In [None]:
# This code was taken from https://github.com/noanabeshima/wikipedia-downloader
# The dataset that is downloaded is the same as in the Pile: https://github.com/EleutherAI/the-pile?tab=readme-ov-file

import os
import json
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
from joblib import Parallel, delayed
import fire

def process_article(article):
    """Flatten one dataset record into a single string.

    ``article`` is indexed by ``'title'`` and ``'text'``; each value is
    converted with ``.numpy()`` and decoded as UTF-8. The result is the
    title, a blank line, and then the article text.
    """
    pieces = [article[field].numpy().decode('UTF-8') for field in ('title', 'text')]
    return "\n\n".join(pieces)

def main(n_jobs: int = 1):
    """Download English Wikipedia into ten JSON files under ``output/``.

    The ``wikipedia/20200301.en`` train split is loaded in ten 10% slices
    with ``tfds.load``. Each slice is converted with ``process_article`` and
    saved as a JSON list of strings in ``output/wikipedia-en-<i>.json``.
    A slice whose file already exists is skipped (this includes a file left
    incomplete by an interrupted run; delete it to force a re-download).

    Args:
        n_jobs: Number of joblib workers used to convert the articles.
    """
    # exist_ok replaces the former bare ``except: pass``, which also hid
    # genuine failures such as a permissions error.
    os.makedirs('output', exist_ok=True)

    for interval in range(10):
        out_path = f"output/wikipedia-en-{interval}.json"
        if os.path.exists(out_path):
            continue

        ds = tfds.load('wikipedia/20200301.en', split=f'train[{interval}0%:{interval + 1}0%]')

        result = Parallel(n_jobs=n_jobs)(delayed(process_article)(article) for article in tqdm(ds))

        # The context manager closes the file even if serialisation fails.
        with open(out_path, "w", encoding="utf-8") as file:
            file.write(json.dumps(result))

# Command-line entry point: python-fire builds the CLI from main's signature,
# so n_jobs can be supplied as a flag (e.g. --n_jobs=4).
if __name__ == '__main__':
    fire.Fire(main)

# Package versions for the download code above (pip requirement format):
# tensorflow==2.2.0
# tfds-nightly==3.1.0.dev202007060105
# fire==0.3.1
# tqdm==4.47.0
# joblib==0.15.1
# apache-beam==2.22.0

## Select random subset of the dataset

In [21]:
import numpy as np
import pandas as pd
import os
def wiki_get_n_random_pages(n, file, num_chars):
    """Sample random articles from one JSON dump and split off their titles.

    Args:
        n: Number of articles to draw without replacement. If the file holds
            fewer than ``n`` articles, all of them are returned.
        file: Path to a JSON file containing a list of article strings whose
            first line is the title.
        num_chars: Maximum number of characters kept per article, counted
            from the start of the string (so the title is included).

    Returns:
        A list of ``[title, text]`` pairs in random order. ``title`` is the
        first line of the truncated article; ``text`` is everything after
        it, still beginning with the newline separator.
    """
    # Load everything up front so the file is closed before sampling starts.
    with open(file, 'r', encoding='utf-8') as raw_text:
        articles = json.load(raw_text)

    if not articles:
        return []

    # np.random.choice raises ValueError when asked for more items than exist
    # with replace=False, so cap the sample size at the number of articles.
    sample_size = min(n, len(articles))
    picked = np.random.choice(len(articles), size=sample_size, replace=False)

    collection = []
    for i in picked:
        # Slicing already clamps to the string length.
        subsection = articles[i][:num_chars]
        title = subsection.split('\n')[0]
        collection.append([title, subsection[len(title):]])
    return collection

def collect_from_all_wiki(dir, n, num_chars, outfile):
    """Sample articles from every JSON dump in a directory into one CSV.

    Args:
        dir: Directory holding the ``*.json`` article dumps.
        n: Number of articles to sample from each file.
        num_chars: Maximum number of characters kept per article.
        outfile: Path of the CSV file to write.

    Returns:
        The combined DataFrame with ``title`` and ``text`` columns, as
        written to ``outfile``. Index values restart at 0 for each source
        file.
    """
    columns = ["title", "text"]
    # Skip anything that is not a JSON dump (checkpoint folders, hidden
    # files). Sorting makes the row order independent of os.listdir's
    # arbitrary ordering.
    filenames = sorted(name for name in os.listdir(dir) if name.endswith('.json'))

    dfs = []
    for count, filename in enumerate(filenames, start=1):
        path = os.path.join(dir, filename)
        collection = wiki_get_n_random_pages(n=n, file=path, num_chars=num_chars)
        dfs.append(pd.DataFrame(collection, columns=columns))
        print(f'finished file: {count}')

    # pd.concat raises ValueError on an empty list, so fall back to an empty
    # frame with the expected columns.
    df = pd.concat(dfs) if dfs else pd.DataFrame(columns=columns)
    df.to_csv(outfile)
    return df

In [30]:
# Sample 200 articles (at most 2000 characters each) from every dump in
# 'output' and write them to 'wiki_2000.csv', then shuffle the rows in memory.
# reset_index(drop=True) renumbers the shuffled rows 0..N-1; the previous
# argument-less .reindex() call left the index unchanged.
df = collect_from_all_wiki('output', 200, 2000, 'wiki_2000.csv').sample(frac=1).reset_index(drop=True)


finished file: 1
finished file: 2
finished file: 3
finished file: 4
finished file: 5
finished file: 6
finished file: 7
finished file: 8
finished file: 9
finished file: 10
