In [118]:
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as np
import numpy as np

import os
import codecs
import nltk.data
import collections
from glob import glob
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords as nltk_stopwords

punctuation_remover = RegexpTokenizer(r'\w+')
stopwords = nltk_stopwords.words('english')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

In [121]:
sentences = []
for filename in glob(os.path.join('./data', "*.txt")):
    if 'name.txt' in filename:
        names = read_name_file(filename)
    with codecs.open(filename, 'r', 'utf-8') as f:
        sentences.extend([sentence for sentence in tokenizer.tokenize(f.read())
                          if any(word in nameword2idx.keys() for word in sentence)])

In [131]:
with codecs.open(filename, 'r', 'utf-8') as f:
    sentences = tokenizer.tokenize(f.read())

sentences[1]
nameword2idx.keys() 

[]

In [138]:
with open(os.path.join('./data', 'name.txt')) as f:
    name_lists = f.readlines()

names = [name.strip().split('\t') for name in name_lists]

#name_counter = collections.Counter([word for name_set in names for name in name_set for word in name.split()])
#print(name_counter.most_common(100))

name2idx = {}
nameword2idx = {}

for idx, name_set in enumerate(names):
    for name in name_set:
        word_in_name = name.split()
        name_without_punctuation = " ".join(punctuation_remover.tokenize(name))

        for name in [name, name_without_punctuation]:
            try:
                if idx not in name_dict[name]:
                    name2idx[name].append(idx)
            except:
                name2idx[name] = []
                name2idx[name].append(idx)
        
        for name in word_in_name:
            try:
                if idx not in name_dict[name]:
                    nameword2idx[name].append(idx)
            except:
                nameword2idx[name] = []
                nameword2idx[name].append(idx)

idx2name = dict(zip(name2idx.values(), name2idx.keys()))

In [145]:
word_in_name

['Zacharias', 'Smith']

In [110]:
def stopword_filter(text):
    return " ".join([word for word in text.split() if word not in stopwords])

In [37]:
vocab_size = 50000

In [129]:
def read_name_data(data_dir):
    with open(os.path.join(data_dir, 'name.txt')) as f:
        name_lists = f.readlines()
        
    names = [name.strip().split('\t') for name in name_lists]

    #name_counter = collections.Counter([word for name_set in names for name in name_set for word in name.split()])
    #print(name_counter.most_common(100))

    name2idx = {}
    nameword2idx = {}

    for idx, name_set in enumerate(names):
        for name in name_set:
            name2idx[name] = idx
            word_in_name = name.split()
            name_without_punctuation = " ".join(punctuation_remover.tokenize(name))

            for name in [name, name_without_punctuation] + word_in_name:
                try:
                    if idx not in name_dict[name]:
                        name2idx[name].append(idx)
                except:
                    nameword2idx[name] = []
                    nameword2idx[name].append(idx)

    idx2name = dict(zip(name2idx.values(), name2idx.keys()))
    
    return names, name2idx, idx2name, nameword2idx

def read_data(data_dir):
    tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    
    text = ""
    for filename in glob(os.path.join(data_dir, "*.txt")):
        if 'name.txt' not in filename:
            continue
        with open(filename) as f:
            text += f.read()
    tokenizer.tokenize(text)
    return text.split()

def build_dataset(words):
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(vocab_size - 1))

    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)

    data = list()
    unk_count = 0
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0
            unk_count = unk_count + 1
        data.append(index)

    count[0][1] = unk_count
    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

    return dictionary, reverse_dictionary, data, count

def generate_batch(batch_size, num_skips, skip_window):
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    span = 2 * skip_window + 1 # [ skip_window target skip_window ]
    buffer = collections.deque(maxlen=span)
    
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    for i in range(batch_size // num_skips):
        target = skip_window    # target label at the center of the buffer
        targets_to_avoid = [ skip_window ]
        for j in range(num_skips):
            while target in targets_to_avoid:
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    return batch, labels

In [130]:
words = read_data('./data/')
print('Data size :', len(words))

names, name2idx, idx2name, nameword2idx = read_name_data("./data/")
print('# of names :', len(names))

dictionary, reverse_dictionary, data, count = build_dataset(words)
print('Most common words (+UNK) :', count[:5])

data_index = 0

Data size : 396
# of names : 189
Most common words (+UNK) : [['UNK', 0], ('Weasley', 9), ('Potter', 6), ('Longbottom', 4), ('Dursley', 4)]


In [None]:
def generate_batch(batch_)

In [None]:
flags = tf.app.flags

flags.DEFINE_string("data_dir", './data/', "Directory which contains data files")

FLAGS = flags.FLAGS

class Options(object):
    def __init__(self):
        self.data_dir = FLAGS.data_dir

def main():
    if not FLAGS.data_dir:
        print("--data_dir must be specified")
        sys.exit(1)

    opts = Options()
    read_data(opts.data_dir)
