In [153]:
import numpy as np
import tensorflow as tf
import collections

### Reading data and converting to bigrams

In [154]:
def read_data(file):
    with open(file, 'r') as f:
        text = f.read()
        text = text.replace('\n', '')
        start_idx = text.find('VIKRAM AND THE VAMPIRE')
        end_idx = text.find('FOOTNOTES')
        text = text[start_idx:end_idx]
        text = text.lower().strip()
        
        bigram_text = [''.join(text[char:char + 2]) for char in range(0, len(text) - 2, 2)]
        
    return bigram_text

In [155]:
bigram_text = read_data('2400-0.txt')
print('no. of bigrams:', len(bigram_text))
bigram_text[0:10]

no. of bigrams: 203402


['vi', 'kr', 'am', ' a', 'nd', ' t', 'he', ' v', 'am', 'pi']

### Creating Dictionary

In [156]:
def create_dict(bigrams):
    
    count = []
    count.extend(collections.Counter(bigrams).most_common())
    
    dictionary = dict({'UNK':0})
    for char, freq in count:
    
        dictionary[char] = len(dictionary)    
            
    rev_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    
    return dictionary, rev_dictionary, count


In [157]:
dictionary, rev_dictionary, count = create_dict(bigram_text)
vocab_size = len(dictionary)

print('dictionary', list(dictionary)[:10])
print('reverse dictionary', list(rev_dictionary)[:10])
print('most common words:', count[0:5])
print('len of dictionary:', len(dictionary))

dictionary ['ul', 'nz', 'nh', 'r ', 'y.', '93', 'ha', '24', 'er', '14']
reverse dictionary [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
most common words: [('e ', 5889), (' t', 5669), ('th', 5586), ('he', 5425), (' a', 3910)]
len of dictionary: 1126


### Converting from bigrams to int

In [158]:
def bigram_to_int(bigram_text, dictionary):
    
    bigram_int = []
    
    for bigram in bigram_text:
        
        if bigram in dictionary:
            bigram_int.append(dictionary[bigram])
        else:
            bigram_int.append(dictionary['UNK'])

    return bigram_int

In [159]:
bigram_int = bigram_to_int(bigram_text, dictionary)
print(list(bigram_int[0:10]))

[147, 288, 109, 5, 15, 2, 4, 155, 109, 196]


### Generating batches

In [117]:
def next_batch(bigrams, batch_size, batch_num):
    global vocab_size
    
    segments = len(bigrams) // batch_size
    
    seg_idx = [offset * segments for offset in range(batch_size)]

    batch_data = np.zeros((batch_size,vocab_size),dtype=np.float32)
    batch_labels = np.zeros((batch_size,vocab_size),dtype=np.float32)
    
    # + 1 for extra label
    #bigram_data = bigrams[batch_num*batch_size : batch_num*batch_size + batch_size + 1]
    
    for bi in range(batch_size):
        batch_data[bi, dictionary[bigram_data[bi]]] = 1
        batch_labels[bi, dictionary[bigram_data[bi]] + 1] = 1
        
    return batch_data, batch_labels

In [118]:
batch_data, batch_labels = next_batch(bigram_text, 64, 0)
#np.set_printoptions(threshold=np.nan)
#batch_labels.shape
#batch_data.shape

### Unrolling batches

In [162]:
def unroll_batches(num_unroll, batch_size):
    
    unroll_data = []
    unroll_labels = []
    
    for u_idx in range(num_unroll):
        
        batch_data, batch_labels = next_batch(bigram_text, batch_size, u_idx)
        
        unroll_data.append(batch_data)
        unroll_labels.append(batch_labels)
        
    return unroll_data, unroll_labels

In [163]:
unroll_data, unroll_labels = unroll_batches(5, 5)
print('unroll data shape:', np.array(unroll_data).shape)
print('unroll labels shape:', np.array(unroll_labels).shape)

for ui,(dat,lbl) in enumerate(zip(unroll_data,unroll_labels)):   
    print('\n\nUnrolled index %d'%ui)
    dat_ind = np.argmax(dat,axis=1)
    lbl_ind = np.argmax(lbl,axis=1)
    print('\tInputs:')
    for sing_dat in dat_ind:
        print('\t%s (%d)'%(rev_dictionary[sing_dat],sing_dat),end=", ")
    print('\n\tOutput:')
    for sing_lbl in lbl_ind:        
        print('\t%s (%d)'%(rev_dictionary[sing_lbl],sing_lbl),end=", ")

unroll data shape: (5, 5, 1126)
unroll labels shape: (5, 5, 1126)


Unrolled index 0
	Inputs:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 
	Output:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 

Unrolled index 1
	Inputs:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 
	Output:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 

Unrolled index 2
	Inputs:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 
	Output:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 

Unrolled index 3
	Inputs:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 
	Output:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 

Unrolled index 4
	Inputs:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 
	Output:
	vi (147), 	ne (64), 	he (4), 	re (16), 	 s (17), 

### Recurrent Neural Network

In [160]:
def next_batch(bigrams, batch_size):
    global vocab_size
    
    segments = len(bigrams) // batch_size
    
    seg_idx = [offset * segments for offset in range(batch_size)]

    batch_data = np.zeros((batch_size,vocab_size),dtype=np.float32)
    batch_labels = np.zeros((batch_size,vocab_size),dtype=np.float32)
    
    # + 1 for extra label
    #bigram_data = bigrams[batch_num*batch_size : batch_num*batch_size + batch_size + 1]
    
    for bi in range(batch_size):
        
        if seg_idx[bi]+1 >= len(bigrams):
            self.seg_idx[bi] = bi * segments
        
        batch_data[bi, dictionary[bigrams[seg_idx[bi]]]] = 1
        batch_labels[bi, dictionary[bigrams[seg_idx[bi]]]] = 1
        
        seg_idx[bi] = (seg_idx[bi]+1) % len(bigrams)
        
    return batch_data, batch_labels

In [161]:
batch_data, batch_labels = next_batch(bigram_text, 64, 0)

In [164]:
batch_size = 64
segments = len(bigram_text) // batch_size
seg_idx = [offset * segments for offset in range(batch_size)]
seg_idx

[0,
 3178,
 6356,
 9534,
 12712,
 15890,
 19068,
 22246,
 25424,
 28602,
 31780,
 34958,
 38136,
 41314,
 44492,
 47670,
 50848,
 54026,
 57204,
 60382,
 63560,
 66738,
 69916,
 73094,
 76272,
 79450,
 82628,
 85806,
 88984,
 92162,
 95340,
 98518,
 101696,
 104874,
 108052,
 111230,
 114408,
 117586,
 120764,
 123942,
 127120,
 130298,
 133476,
 136654,
 139832,
 143010,
 146188,
 149366,
 152544,
 155722,
 158900,
 162078,
 165256,
 168434,
 171612,
 174790,
 177968,
 181146,
 184324,
 187502,
 190680,
 193858,
 197036,
 200214]