In [18]:
import pandas
import numpy as np
import zipfile
import tensorflow as tf
import collections
import random

In [19]:
text8_filepath = "text8.zip"

def loadFile(filename):
    """Read the first member of a zip archive and split it into tokens.

    Args:
        filename: Path to the zip archive (e.g. ``text8_filepath``).

    Returns:
        List of str tokens: the first archive member decoded as UTF-8 and
        split on runs of whitespace.

    Raises:
        ValueError: if the archive contains no members.
    """
    with zipfile.ZipFile(filename) as archive:
        members = archive.namelist()
        if not members:
            raise ValueError("zip archive {!r} is empty".format(filename))
        raw = archive.read(members[0])
    # Plain stdlib UTF-8 decoding; no need to go through TensorFlow for this.
    return raw.decode("utf-8").split()


In [37]:
def build_dataset(words, n_words):
    """Encode a token sequence as integer ids over a fixed-size vocabulary.

    The vocabulary is the placeholder token ``"UNK"`` (id 0) followed by the
    ``n_words - 1`` most frequent distinct tokens of ``words`` (ids 1, 2, ...
    in descending frequency). Every token outside that vocabulary, including a
    literal ``"UNK"`` token in the corpus, is encoded as 0.

    Args:
        words: Sequence of str tokens.
        n_words: Total vocabulary size, including the ``"UNK"`` slot.

    Returns:
        ``(data, count, dictionary, reversed_dictionary)`` where
        - data: list of ids, one per token in ``words``;
        - count: ``[["UNK", n_unknown], (word, freq), ...]``;
        - dictionary: token -> id;
        - reversed_dictionary: id -> token.
    """
    counter = collections.Counter(words)
    # A literal "UNK" in the corpus must not become a second vocabulary entry:
    # re-inserting the existing key would hand its new id to the next word too.
    counter.pop("UNK", None)

    count = [["UNK", -1]]
    count.extend(counter.most_common(n_words - 1))

    dictionary = {word: index for index, (word, _) in enumerate(count)}

    data = [dictionary.get(word, 0) for word in words]
    # Id 0 is produced only for out-of-vocabulary tokens (and literal "UNK").
    count[0][1] = data.count(0)

    reversed_dictionary = {index: word for word, index in dictionary.items()}
    return data, count, dictionary, reversed_dictionary

In [44]:
def generate_batch(n, data, batch_size=10, num_skips=10, skip_window=5):
    """Build one skip-gram batch of (center word, context word) id pairs.

    A window of ``span = 2 * skip_window + 1`` ids slides over ``data``
    starting at index ``n``. For each window position, ``num_skips`` distinct
    context positions are drawn with ``random.sample`` and each is paired with
    the window's center id. When the window reaches the end of ``data`` it
    wraps around to the beginning.

    Args:
        n: Start index into ``data``; reset to 0 if a full window does not fit.
        data: Sequence of integer word ids.
        batch_size: Number of pairs to produce; must be a multiple of
            ``num_skips``.
        num_skips: Context words sampled per center word; must be in
            ``1 .. 2 * skip_window``.
        skip_window: Number of words on each side of the center word.

    Returns:
        ``(batch, labels, n)``: ``batch`` is an int32 array of shape
        ``(batch_size,)`` holding center ids, ``labels`` is an int32 array of
        shape ``(batch_size, 1)`` holding context ids, and ``n`` is the index
        to pass to the next call.

    Raises:
        ValueError: if the batch cannot be filled with these arguments or
            ``data`` is shorter than one window.
    """
    span = 2 * skip_window + 1
    if num_skips <= 0 or batch_size % num_skips != 0:
        raise ValueError("batch_size must be a multiple of a positive num_skips")
    if num_skips > 2 * skip_window:
        raise ValueError("num_skips must not exceed 2 * skip_window")
    if len(data) < span:
        raise ValueError("data is shorter than one window of 2 * skip_window + 1")

    batch = np.empty(batch_size, dtype=np.int32)
    labels = np.empty((batch_size, 1), dtype=np.int32)

    buffer = collections.deque(maxlen=span)
    if n + span > len(data):
        n = 0
    buffer.extend(data[n:n + span])
    n += span

    # Positions inside the window other than the center one.
    context_words = [w for w in range(span) if w != skip_window]
    for i in range(batch_size // num_skips):
        words_to_use = random.sample(context_words, num_skips)
        for j, context_word in enumerate(words_to_use):
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[context_word]

        if n == len(data):
            # Wrap around. deque has no slice assignment; with maxlen=span,
            # extending by span items replaces the whole window.
            buffer.extend(data[:span])
            n = span
        else:
            buffer.append(data[n])
            n += 1

    # Step back one window so the next call starts from the current buffer,
    # whose center word has not been used yet.
    n = (n + len(data) - span) % len(data)
    return batch, labels, n


    

In [45]:
# Read the text8 corpus and encode it over a 50,000-entry vocabulary
# ("UNK" plus the 49,999 most frequent tokens).
raw_text = loadFile(text8_filepath)
data, count, dictionary, reverse_dictionary = build_dataset(
    raw_text, n_words=50000
)

In [47]:
n = 0
batch, labels, n = generate_batch(n, data, batch_size=8, num_skips=2, skip_window=1)

# Print each (center -> context) training pair with its decoded words.
# Iterate over the arrays themselves rather than a hardcoded range(8).
for center_id, context_id in zip(batch, labels[:, 0]):
    # Report whichever id is actually missing, not always the label id.
    missing_ids = [k for k in (center_id, context_id) if k not in reverse_dictionary]
    if missing_ids:
        print("Key not found: {}".format(missing_ids[0]))
        continue
    print(center_id, reverse_dictionary[center_id],
          '->', context_id, reverse_dictionary[context_id])

3081 originated -> 12 as
3081 originated -> 5234 anarchism
12 as -> 3081 originated
12 as -> 6 a
6 a -> 12 as
6 a -> 195 term
195 term -> 6 a
195 term -> 2 of
