# Prepare Dataset

In [4]:
from functools import reduce
import glob
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import os
import os.path
import pprint
import random
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tarfile
import time
import unicodedata
import urllib.request
import zipfile

In [10]:
def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``.

    Prints one ``.`` every 100 blocks and a single newline once the transferred
    byte count reaches the reported total size.

    Args:
        count: number of blocks transferred so far (0 on the initial call).
        block_size: size of one block in bytes.
        total_size: total size in bytes, or a value <= 0 when the size is
            unknown (``urlretrieve`` passes -1 in that case).

    On the initial call the current ``time.time()`` is stored in the
    module-level ``start_time``; nothing in this function reads it.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return
    # Only detect completion when the size is actually known: with
    # total_size == -1 the comparison would be true on every block and
    # print a blank line per block instead of progress dots.
    if total_size > 0 and count * block_size >= total_size:
        print('')
        return
    if count % 100 == 0:
        print('.', end='')

def maybe_download(url, destination_path):
    """Download ``url`` to ``destination_path`` unless that file already exists.

    The data is written to ``destination_path + '.part'`` first and renamed into
    place only after ``urlretrieve`` returns. An interrupted download therefore
    never satisfies the existence check on a later call.

    Args:
        url: source URL handed to ``urllib.request.urlretrieve``.
        destination_path: local file path; missing parent directories are created.
    """
    if os.path.exists(destination_path):
        return

    dirpath = os.path.dirname(destination_path)
    # os.makedirs('') raises, so only create a directory when the path has one.
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    partial_path = destination_path + '.part'
    urllib.request.urlretrieve(url, partial_path, reporthook)
    os.replace(partial_path, destination_path)

def reduce_folder_depth(folder_path):
    """Flatten ``folder_path/<name>/...`` into ``folder_path/...`` when redundant.

    If ``folder_path`` contains exactly one entry, and that entry is a
    directory with the same name as ``folder_path``, the directory's contents
    are moved up one level. The now-empty directory is then removed.
    Otherwise the folder is left untouched.
    """
    folder_name = os.path.basename(folder_path)
    entries = os.listdir(folder_path)
    if len(entries) != 1 or entries[0] != folder_name:
        return
    nested_path = os.path.join(folder_path, folder_name)
    # A lone *file* with the folder's name is left alone; listing it as a
    # directory below would raise NotADirectoryError.
    if not os.path.isdir(nested_path):
        return
    # Move the nested directory aside first so its children can be placed in
    # folder_path without colliding with the nested directory's own name.
    tmp_dir_path = os.path.join(folder_path, '.__tmp_dir__')
    os.rename(nested_path, tmp_dir_path)
    for entry in os.listdir(tmp_dir_path):
        os.rename(os.path.join(tmp_dir_path, entry), os.path.join(folder_path, entry))
    os.rmdir(tmp_dir_path)
        

def maybe_unzip(source_zip_file_path):
    """Extract a ``.zip`` archive next to itself unless it was already extracted.

    ``dir/foo.zip`` is extracted into ``dir/foo``. A redundant top-level
    ``foo/`` directory inside the archive is flattened by
    ``reduce_folder_depth``.

    Returns:
        The destination folder path, whether it was just created or already
        existed (consistent with ``maybe_extract_tar``).
    """
    import shutil

    dirpath, filename = os.path.split(source_zip_file_path)
    name, _ = os.path.splitext(filename)
    destination_folder = os.path.join(dirpath, name)

    # An existing folder is taken as a completed extraction; the cleanup
    # below keeps that assumption valid when extraction fails.
    if os.path.exists(destination_folder):
        return destination_folder
    os.makedirs(destination_folder)
    try:
        with zipfile.ZipFile(source_zip_file_path) as archive:
            archive.extractall(destination_folder)
    except BaseException:
        # Remove the half-extracted folder so the next call retries instead
        # of returning early on incomplete data.
        shutil.rmtree(destination_folder, ignore_errors=True)
        raise
    reduce_folder_depth(destination_folder)
    return destination_folder
        
def maybe_extract_tar(source_tar_file_path):
    """Extract a tar archive next to itself unless it was already extracted.

    ``dir/foo.tgz`` is extracted into ``dir/foo``. Only the last extension is
    stripped, so ``foo.tar.gz`` would extract into ``foo.tar``. A redundant
    top-level ``foo/`` directory inside the archive is flattened by
    ``reduce_folder_depth``.

    Member paths are not trusted, because the notebook downloads its
    archives over plain HTTP:
    - When ``tarfile.data_filter`` exists, the ``'data'`` extraction filter
      is used.
    - Otherwise, every member name must resolve inside the destination
      before anything is written.

    Returns:
        The destination folder path, whether it was just created or already existed.

    Raises:
        tarfile.TarError: the archive cannot be read, or the ``'data'`` filter
            rejects a member.
        ValueError: (fallback check only) a member would land outside the
            destination folder.
    """
    import shutil

    dirpath, filename = os.path.split(source_tar_file_path)
    name, _ = os.path.splitext(filename)
    destination_folder = os.path.join(dirpath, name)

    # An existing folder is taken as a completed extraction; the cleanup
    # below keeps that assumption valid when extraction fails.
    if os.path.exists(destination_folder):
        return destination_folder
    os.makedirs(destination_folder)
    try:
        with tarfile.open(source_tar_file_path) as archive:
            if hasattr(tarfile, 'data_filter'):
                archive.extractall(destination_folder, filter='data')
            else:
                # Basic traversal guard for Pythons without extraction filters
                # (does not inspect symlink/hardlink targets).
                root = os.path.realpath(destination_folder)
                for member in archive.getmembers():
                    target = os.path.realpath(os.path.join(destination_folder, member.name))
                    if os.path.commonpath([root, target]) != root:
                        raise ValueError(
                            'Refusing to extract %r outside %r' % (member.name, destination_folder))
                archive.extractall(destination_folder)
    except BaseException:
        # Remove the half-extracted folder so the next call retries instead
        # of returning early on incomplete data.
        shutil.rmtree(destination_folder, ignore_errors=True)
        raise
    reduce_folder_depth(destination_folder)
    return destination_folder

In [6]:
def bundle_dataset(list, base_path=None):
    """Describe a set of data files and their combined size in bytes.

    Args:
        list: file names relative to ``base_path``. (The name shadows the
            builtin ``list``; it is kept so existing keyword callers still work.)
        base_path: directory containing the files. Defaults to the
            module-level ``data_path`` assigned by the dataset-preparation cells.

    Returns:
        ``{'files': [...], 'total_size': int}``. Each file entry has ``name``,
        ``path``, ``size`` (from ``os.stat``) and ``handle`` (set to ``None``).
        ``total_size`` is 0 for an empty file list.
    """
    if base_path is None:
        base_path = data_path
    files = []
    for file_name in list:
        file_path = os.path.join(base_path, file_name)
        files.append({
            'name': file_name,
            'path': file_path,
            'size': os.stat(file_path).st_size,
            'handle': None,
        })
    # sum() handles the empty case; reduce() without an initial value raised TypeError.
    total_size = sum(entry['size'] for entry in files)
    return {
        'files': files,
        'total_size': total_size,
    }

In [7]:
# CBTest corpus: download once, extract, and split the text files into
# train / valid / test bundles based on the suffix in each file name.
DATA_URL = 'http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz'
DATA_SOURCE_PATH = os.path.join('data', 'CBTest.tgz')

maybe_download(DATA_URL, DATA_SOURCE_PATH)
extracted_dir = maybe_extract_tar(DATA_SOURCE_PATH)

# bundle_dataset() resolves file names against this module-level data_path.
data_path = os.path.join(extracted_dir, 'data')
# os.listdir() returns entries in arbitrary order; sort so the bundles (and
# anything that later reads their files in order) are reproducible.
input_files = sorted(os.listdir(data_path))

train_dataset = bundle_dataset([f for f in input_files if '_train' in f])
valid_dataset = bundle_dataset([f for f in input_files if '_valid' in f])
test_dataset = bundle_dataset([f for f in input_files if '_test' in f])

pp = pprint.PrettyPrinter(indent=2)
pp.pprint(train_dataset)
pp.pprint(valid_dataset)
pp.pprint(test_dataset)

{ 'files': [ { 'handle': None,
               'name': 'cbt_train.txt',
               'path': 'data/CBTest/data/cbt_train.txt',
               'size': 25742364},
             { 'handle': None,
               'name': 'cbtest_CN_train.txt',
               'path': 'data/CBTest/data/cbtest_CN_train.txt',
               'size': 295933246},
             { 'handle': None,
               'name': 'cbtest_NE_train.txt',
               'path': 'data/CBTest/data/cbtest_NE_train.txt',
               'size': 248333387},
             { 'handle': None,
               'name': 'cbtest_P_train.txt',
               'path': 'data/CBTest/data/cbtest_P_train.txt',
               'size': 836819208},
             { 'handle': None,
               'name': 'cbtest_V_train.txt',
               'path': 'data/CBTest/data/cbtest_V_train.txt',
               'size': 247098043}],
  'total_size': 1653926248}
{ 'files': [ { 'handle': None,
               'name': 'cbt_valid.txt',
               'path': 'data/CBTest/data/c

In [16]:
# textfiles.com "stories" archive. It is saved as stories.tgz (not .tar.gz)
# because maybe_extract_tar() strips only the last extension; this way it
# extracts into data/stories rather than data/stories.tar.
DATA_URL = 'http://archives.textfiles.com/stories.tar.gz'
DATA_SOURCE_PATH = os.path.join('data', 'stories.tgz')

maybe_download(DATA_URL, DATA_SOURCE_PATH)
extracted_dir = maybe_extract_tar(DATA_SOURCE_PATH)

# The story files sit directly in the extracted folder; bundle_dataset()
# resolves names against this module-level data_path.
data_path = extracted_dir
# os.listdir() order is arbitrary; sort for a reproducible bundle.
input_files = sorted(os.listdir(data_path))

# This archive has no train/valid/test split: every regular file goes into
# the training bundle and subdirectories are skipped.
train_dataset = bundle_dataset(
    [f for f in input_files if os.path.isfile(os.path.join(data_path, f))])

pp = pprint.PrettyPrinter(indent=2)
# Only train_dataset belongs to this corpus. valid_dataset / test_dataset
# would still hold the CBTest bundles from the previous cell (or be undefined
# when this cell runs on its own), so they are not printed here.
pp.pprint(train_dataset)

{ 'files': [ { 'handle': None,
               'name': '100west.txt',
               'path': 'data/stories/100west.txt',
               'size': 20839},
             { 'handle': None,
               'name': '13chil.txt',
               'path': 'data/stories/13chil.txt',
               'size': 8457},
             { 'handle': None,
               'name': '14.lws',
               'path': 'data/stories/14.lws',
               'size': 5261},
             { 'handle': None,
               'name': '16.lws',
               'path': 'data/stories/16.lws',
               'size': 15294},
             { 'handle': None,
               'name': '17.lws',
               'path': 'data/stories/17.lws',
               'size': 10853},
             { 'handle': None,
               'name': '18.lws',
               'path': 'data/stories/18.lws',
               'size': 26624},
             { 'handle': None,
               'name': '19.lws',
               'path': 'data/stories/19.lws',
               'size': 17902

               'name': 'gold3ber.txt',
               'path': 'data/stories/gold3ber.txt',
               'size': 5558},
             { 'handle': None,
               'name': 'goldbug.poe',
               'path': 'data/stories/goldbug.poe',
               'size': 78372},
             { 'handle': None,
               'name': 'goldenp.txt',
               'path': 'data/stories/goldenp.txt',
               'size': 49790},
             { 'handle': None,
               'name': 'goldfish.txt',
               'path': 'data/stories/goldfish.txt',
               'size': 4948},
             { 'handle': None,
               'name': 'goldgoos.txt',
               'path': 'data/stories/goldgoos.txt',
               'size': 5003},
             { 'handle': None,
               'name': 'grav',
               'path': 'data/stories/grav',
               'size': 22515},
             { 'handle': None,
               'name': 'graymare.txt',
               'path': 'data/stories/graymare.txt',
              

               'size': 47601},
             { 'handle': None,
               'name': 'tcoa.txt',
               'path': 'data/stories/tcoa.txt',
               'size': 13216},
             { 'handle': None,
               'name': 'tctac.txt',
               'path': 'data/stories/tctac.txt',
               'size': 2301},
             { 'handle': None,
               'name': 'tearglas.txt',
               'path': 'data/stories/tearglas.txt',
               'size': 31407},
             { 'handle': None,
               'name': 'telefone.txt',
               'path': 'data/stories/telefone.txt',
               'size': 3405},
             { 'handle': None,
               'name': 'terrorbears.txt',
               'path': 'data/stories/terrorbears.txt',
               'size': 4816},
             { 'handle': None,
               'name': 'testpilo.hum',
               'path': 'data/stories/testpilo.hum',
               'size': 16609},
             { 'handle': None,
               'name': 'textfil