In [1]:
import numpy as np

In [15]:
class DataGen:
    """Batch generator that encodes corpus words as per-character one-hot matrices.

    The corpus file is read once and split on single spaces into ``self.words``.
    Two lookup tables are derived from it:

    * ``char2int`` / ``int2char`` -- one index per distinct corpus character
      (``n_chars`` symbols in total, built from the whole corpus).
    * ``word2int`` / ``int2word`` -- one index per distinct word of the
      (possibly truncated) word list, in sorted order (``vocab``).

    ``gen`` then yields ``(x, y)`` batches where ``x`` holds the encoded
    context words around a centre word and ``y`` the one-hot index of that
    centre word.
    """

    # Symbol whose one-hot marks the unused trailing rows of a word matrix.
    # Words come from ``split(' ')``, so a space never occurs inside a word.
    pad_char = ' '

    def __init__(self, percent=100, path='data/text8.txt'):
        """Load the corpus and build the character and word tables.

        Args:
            percent: Share (0-100) of the corpus words to keep, counted from
                the start of the file.
            path: Location of the space-separated corpus text file.
        """
        self.n_chars = 0
        self.max_word_len = 34
        # Context manager so the file handle is released immediately.
        with open(path) as corpus_file:
            self.corpus = corpus_file.read()
        # NOTE: split(' ') keeps empty strings produced by leading, trailing
        # or repeated spaces; they are treated as ordinary (empty) words.
        self.words = self.corpus.split(' ')
        width = int(len(self.words) * percent / 100)
        self.words = self.words[:width]
        self.char2int, self.int2char = self.load_charset()
        self.vocab, self.word2int, self.int2word = self.get_vocab(self.words)
        # The raw text is only needed to build the tables above.
        del self.corpus

    def get_vocab(self, words):
        """Build the sorted vocabulary and its index tables.

        Side effect: raises ``self.max_word_len`` to the longest word seen and
        prints every word that sets a new maximum.

        Args:
            words: Iterable of word strings.

        Returns:
            Tuple ``(vocab, word2int, int2word)`` where ``vocab`` is the sorted
            list of distinct words and the two dicts map word <-> position.
        """
        vocab = sorted(set(words))
        word2int = {}
        int2word = {}
        for index, word in enumerate(vocab):
            int2word[index] = word
            word2int[word] = index
            if len(word) > self.max_word_len:
                self.max_word_len = len(word)
                print(word, self.max_word_len)
        return vocab, word2int, int2word

    def load_charset(self):
        """Index every distinct character of ``self.corpus`` in sorted order.

        Side effect: sets ``self.n_chars`` to the number of distinct characters.

        Returns:
            Tuple ``(char2int, int2char)`` of mutually inverse dicts.
        """
        charset = sorted(set(self.corpus))
        self.n_chars = len(charset)
        int2char = dict(enumerate(charset))
        char2int = {char: index for index, char in int2char.items()}
        return char2int, int2char

    def word2vec(self, word):
        """Encode ``word`` as a ``(max_word_len, n_chars)`` float32 matrix.

        Row ``r`` is the one-hot of ``word[r]``. Rows past the end of the word
        are filled with the one-hot of ``pad_char`` when that symbol is in the
        charset, and are left all-zero otherwise. An empty word therefore
        yields a matrix made only of padding rows.

        Raises:
            ValueError: If ``word`` is longer than ``self.max_word_len``.
            KeyError: If ``word`` contains a character missing from ``char2int``.
        """
        length = len(word)
        if length > self.max_word_len:
            raise ValueError(
                "word of length %d exceeds max_word_len=%d"
                % (length, self.max_word_len))
        vec = np.zeros((self.max_word_len, self.n_chars), dtype=np.float32)
        for row, char in enumerate(word):
            vec[row, self.char2int[char]] = 1
        pad_index = self.char2int.get(self.pad_char)
        if pad_index is not None:
            # Slice by length (not by a loop variable) so empty words work too.
            vec[length:, pad_index] = 1
        return vec

    def one_hot(self, n, size):
        """Return a length-``size`` vector of zeros with a 1 at position ``n``."""
        v = np.zeros((size,))
        v[n] = 1
        return v

    def one_hot_decode(self, vec):
        """Map each row of ``vec`` to the word at that row's arg-max index."""
        indexes = np.argmax(vec, axis=1)
        return [self.int2word[int(i)] for i in indexes]

    def sentense_to_vec(self, words):
        """Stack the matrices of ``words`` along axis 0.

        Returns an array of shape ``(len(words) * max_word_len, n_chars)``.
        """
        return np.concatenate([self.word2vec(w) for w in words])

    def gen(self, batch_size=100, n_batches=-1, windows_size=4):
        """Yield ``(x, y)`` training batches forever, cycling over the corpus.

        Each sample takes the ``windows_size // 2`` words on either side of a
        centre word as context and the centre word itself as target. The centre
        advances by one word per sample and wraps back to the first valid
        position as soon as the window would run past the usable words, so
        every sample has a full-size context.

        Args:
            batch_size: Number of samples per yielded batch.
            n_batches: If positive, restrict one pass to
                ``batch_size * n_batches`` centre words (clamped to what the
                corpus can supply); otherwise use the whole word list.
            windows_size: Total number of context words; must be at least 2.

        Yields:
            ``x`` of shape ``(batch_size, 2 * (windows_size // 2) * max_word_len,
            n_chars, 1)`` and ``y`` of shape ``(batch_size, len(vocab))``.

        Raises:
            ValueError: On first iteration, if ``batch_size`` or ``windows_size``
                is too small or the corpus is shorter than one full window.
        """
        half = windows_size // 2
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if half < 1:
            raise ValueError("windows_size must be at least 2")

        first_centre = half
        # Exclusive upper bound on centre positions that still have `half`
        # words available to their right.
        end_centre = len(self.words) - half
        if n_batches > 0:
            end_centre = min(end_centre, first_centre + batch_size * n_batches)
        if end_centre <= first_centre:
            raise ValueError(
                "corpus has too few words for windows_size=%d" % windows_size)

        vocab_size = len(self.vocab)
        c_word = first_centre
        while True:
            x = []
            y = []
            for _ in range(batch_size):
                context = self.words[c_word - half:c_word + half + 1]
                target = context.pop(half)
                x.append(self.sentense_to_vec(context))
                y.append(self.one_hot(self.word2int[target], vocab_size))
                c_word += 1
                # Wrap per sample (not per batch) so a batch that straddles the
                # end of the pass never slices beyond the word list.
                if c_word >= end_centre:
                    print("word ", c_word)
                    c_word = first_centre
            x = np.stack(x)
            # Trailing singleton axis, as in the original 4-D layout.
            x = x.reshape(x.shape + (1,))
            yield x, np.stack(y)
            
            
                
        
        


In [16]:
# Build the generator over the full corpus (percent defaults to 100).
# Constructing it prints every word that sets a new longest-word record.
# NOTE(review): the recorded output below lists words of 34 characters or
# fewer, which a class starting at max_word_len = 34 would not print -- the
# output looks stale; re-run this cell on a fresh kernel to confirm.
dg = DataGen()

a 1
aa 2
aaa 3
aaaa 4
aaaaaacceglllnorst 18
aaupaukunukunumuhumuh 21
abcdefghijklmnopqrstuvwxyz 26
accusativeillumillamilludill 28
antidisestablishmentarianistically 34
bababadalgharaghtakamminarronnkonnbronntonnerronntuonnthunntrovarrhounawnskawntoohoohoordenenthurnuk 100


In [7]:
# Longest word length recorded while the vocabulary was built, followed by
# the character -> index table (rendered as this cell's output value).
print(dg.max_word_len)
dg.char2int

100


{' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

In [5]:
# dg.char2tup

In [6]:
# gen = dg.gen(n_batches=40, batch_size=10)
# for i in range(120):
#     x, y = next(gen)