In [None]:
import os
from os.path import join, expanduser, exists
from urllib.error import URLError
from urllib.request import urlopen

In [None]:
import numpy as np

In [None]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchtext import vocab, data

In [None]:
# Local destination of the corpus; the leading '~' is expanded to the user's home.
PATH = expanduser(
    join('~', 'data', 'fastai', 'nietzsche', 'nietzsche.txt')
)

In [None]:
def set_random_seed(state=1):
    """Seed NumPy's global RNG, PyTorch's CPU RNG and the current CUDA device RNG.

    Args:
        state: Integer seed applied to every generator.
    """
    np.random.seed(state)
    torch.manual_seed(state)
    torch.cuda.manual_seed(state)

In [None]:
RANDOM_STATE = 1  # one seed shared by the NumPy and PyTorch generators
set_random_seed(state=RANDOM_STATE)

## Dataset Downloading

In [None]:
def download(url, download_path, expected_size=None, timeout=60):
    """Download a UTF-8 text file from ``url`` and save it to ``download_path``.

    Nothing is downloaded when ``download_path`` already exists. Every failure
    (network/read error, non-200 HTTP status, size mismatch, invalid UTF-8) is
    reported with ``print`` and leaves no file at ``download_path``; none of
    them raise.

    Args:
        url: Source URL; any scheme accepted by ``urllib.request.urlopen``.
        download_path: Destination file; missing parent directories are created.
        expected_size: Expected payload length in bytes, or ``None`` to skip
            the size check.
        timeout: Socket timeout in seconds passed to ``urlopen``.
    """
    if exists(download_path):
        print('The file was already downloaded')
        return

    try:
        # The context manager closes the connection even on the early return.
        with urlopen(url, timeout=timeout) as r:
            # Non-HTTP schemes (e.g. file://) report no status code at all.
            status = getattr(r, 'status', None)
            if status is not None and status != 200:
                print(f'HTTP Error: {status}')
                return
            payload = r.read()
    except OSError as e:  # covers URLError/HTTPError and socket timeouts
        print(f'Cannot download the data. Error: {e}')
        return

    if expected_size is not None and len(payload) != expected_size:
        print(f'Invalid downloaded array size: {len(payload)}')
        return

    try:
        payload.decode(encoding='utf-8')  # make sure we are saving valid text
    except UnicodeDecodeError as e:
        print(f'Downloaded data is not valid UTF-8: {e}')
        return

    parent = os.path.dirname(download_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Save the raw bytes (no locale encoding or newline translation) to a
    # temporary file and rename it into place, so an interrupted write never
    # leaves a truncated file that the `exists` check would accept later.
    partial_path = f'{download_path}.part'
    with open(partial_path, 'wb') as file:
        file.write(payload)
    os.replace(partial_path, download_path)

    print(f'Downloaded: {download_path}')

In [None]:
# Source URL of nietzsche.txt; the download cell below saves it to PATH.
URL = 'https://s3.amazonaws.com/text-datasets/nietzsche.txt'

In [None]:
# 600901 is the expected payload size in bytes; on a mismatch nothing is saved.
download(URL, PATH, expected_size=600901)

In [None]:
def split(path, train_size=0.8, encoding='utf-8'):
    """Read a text file and split its characters into train and validation parts.

    Args:
        path: Path of the text file to read.
        train_size: Fraction of characters (between 0 and 1) put in the
            training part.
        encoding: Text encoding of the file; the corpus is saved as UTF-8.

    Returns:
        ``(train_text, valid_text)``: the first ``int(len(content) * train_size)``
        characters of the file and the remaining characters.

    Raises:
        ValueError: If ``train_size`` is outside ``[0, 1]``.
    """
    if not 0 <= train_size <= 1:
        raise ValueError(f'train_size must be within [0, 1], got {train_size}')
    # An explicit encoding makes the read independent of the OS locale.
    with open(path, encoding=encoding) as file:
        content = file.read()
    n = int(len(content) * train_size)
    # Contiguous split (no shuffling): validation text is one continuous passage.
    return content[:n], content[n:]

In [None]:
train_text, valid_text = split(PATH)
# Character counts of the training and validation parts, one per line.
print(len(train_text), len(valid_text), sep='\n')

In [None]:
text = train_text + valid_text
chars = sorted(set(text))  # distinct characters in code-point order
# One extra slot for the '\0' character that the next cell inserts at
# index 0 (presumably a padding symbol).
vocab_size = len(chars) + 1
print(f'Vocab size: {vocab_size}')

In [None]:
# Reserve index 0 for '\0' (presumably a padding symbol). The guard keeps the
# cell idempotent: re-running it must not insert '\0' again and shift every
# character's index by one.
if chars[:1] != ['\0']:
    chars.insert(0, '\0')

In [None]:
# Lookup tables between characters and their integer ids (position in `chars`).
index_to_char = dict(enumerate(chars))
char_to_index = {char: idx for idx, char in index_to_char.items()}
# The whole corpus encoded as a list of ids. NOTE(review): "indicies" is a
# misspelling of "indices", kept because later cells may use this name.
indicies = [char_to_index[c] for c in text]

## Dataset Preparation

In [None]:
cs = 8