<h1>Text data demo</h1>
This notebook trains a two-layer LSTM network character by character on an input text (here, the US Constitution), much like Andrej Karpathy's char-rnn, and then samples new text from it.
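For reference, one time step of a standard LSTM cell can be sketched in NumPy as below. This is the textbook formulation with input, forget and output gates plus a candidate cell value. Whether `ntb.Lstm` uses the same gate ordering, parameter layout or initialization is an assumption, and the names `lstm_step`, `Wx`, `Wh` and `b` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, Wx, Wh, b):
    """One LSTM time step for a batch (hypothetical helper, textbook gates).

    x_t:    (N, D) inputs at time t
    h_prev: (N, H) previous hidden state
    c_prev: (N, H) previous cell state
    Wx:     (D, 4H) input-to-gate weights
    Wh:     (H, 4H) hidden-to-gate weights
    b:      (4H,)   gate biases
    """
    H = h_prev.shape[1]
    a = x_t @ Wx + h_prev @ Wh + b      # (N, 4H) pre-activations for all gates
    i = sigmoid(a[:, 0 * H:1 * H])      # input gate
    f = sigmoid(a[:, 1 * H:2 * H])      # forget gate
    o = sigmoid(a[:, 2 * H:3 * H])      # output gate
    g = np.tanh(a[:, 3 * H:4 * H])      # candidate cell values
    c_t = f * c_prev + i * g            # new cell state
    h_t = o * np.tanh(c_t)              # new hidden state
    return h_t, c_t

# Shapes matching the notebook's first layer: embedding_dim=32, hidden_dim=64, batch_size=25
rng = np.random.default_rng(0)
N, D, H = 25, 32, 64
x_t = rng.standard_normal((N, D))
h0, c0 = np.zeros((N, H)), np.zeros((N, H))
Wx = rng.standard_normal((D, 4 * H)) * 0.1
Wh = rng.standard_normal((H, 4 * H)) * 0.1
b = np.zeros(4 * H)
h1, c1 = lstm_step(x_t, h0, c0, Wx, Wh, b)
print(h1.shape, c1.shape)  # (25, 64) (25, 64)
```

In the notebook the second layer consumes the first layer's hidden states `h1` as its input sequence, which is what stacking two LSTMs means.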

In [1]:
from ntb.datasets import textdata
import ntb
import matplotlib.pyplot as pp
import numpy as np

%matplotlib inline

In [2]:
seq_length = 30
text_data,encode,decode = textdata.load("us_constitution.txt",seq_length=seq_length,stride=10)
num_train = len(text_data['X_train'])
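`textdata.load` is not documented here. As an assumption about what `seq_length` and `stride` mean, a next-character dataset is commonly built by sliding a window of `seq_length` characters across the text, advancing `stride` characters each time, with the target being the same window shifted one character to the right. With `seq_length=30` and `stride=10`, consecutive windows would overlap by 20 characters. A sketch with the hypothetical helper `make_windows`:

```python
import numpy as np

def make_windows(text, seq_length, stride):
    """Hypothetical sketch: build (X, y) id arrays for next-character prediction."""
    chars = sorted(set(text))
    charmap = {ch: i for i, ch in enumerate(chars)}   # character -> integer id
    ids = np.array([charmap[ch] for ch in text], dtype=np.int64)
    X, y = [], []
    # each window needs seq_length inputs plus one extra character for the shifted target
    for start in range(0, len(ids) - seq_length, stride):
        X.append(ids[start:start + seq_length])
        y.append(ids[start + 1:start + seq_length + 1])
    return np.array(X), np.array(y), charmap

X, y, charmap = make_windows("We the People of the United States", seq_length=10, stride=5)
print(X.shape, y.shape)  # (5, 10) (5, 10)
```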

In [3]:
vocab_length = len(text_data['charmap'])
embedding_dim = 32
hidden_dim = 64
batch_size = 25

# equal per-timestep loss weights of 1/seq_length
mask_ones = np.ones([batch_size,seq_length])/seq_length
graph = ntb.ComputationGraph()

with ntb.default_graph(graph):
    # inputs: character ids, target character ids, learning rate,
    # per-timestep loss mask and sampling temperature
    x = ntb.Placeholder(shape=[-1,-1])
    y = ntb.Placeholder(shape=[-1,-1])
    lr = ntb.Placeholder()
    mask = ntb.Placeholder(shape=[-1,-1])
    sample_temperature = ntb.Placeholder()
    
    # learned character embedding
    W_embed = ntb.Variable(initializer=ntb.xavier_init(shape=(vocab_length,embedding_dim)))
    x_emb = ntb.Embed(x,W_embed)
    
    # two stacked LSTM layers; the initial states are graph nodes
    # so that print_sample can feed in carried-over states
    lstm_1 = ntb.Lstm(num_units=hidden_dim)
    init_state_1 = lstm_1.get_zero_state(batch_size,as_node=True)
    h1,out_state_1 = lstm_1(x_emb,init_state_1)
    
    lstm_2 = ntb.Lstm(num_units=hidden_dim)
    init_state_2 = lstm_2.get_zero_state(batch_size,as_node=True)
    h2,out_state_2 = lstm_2(h1,init_state_2)
    
    # per-timestep projection to vocabulary scores and masked cross-entropy loss
    Wscores = ntb.Variable(shape=[hidden_dim,vocab_length])
    bscores = ntb.Variable(value=np.zeros(vocab_length))
    scores = ntb.TemporalAffine(h2,Wscores,bscores)
    loss = ntb.TemporalCE(scores,y,mask)
    
    # temperature-scaled softmax and sampling, used for text generation
    sample_prob = ntb.Softmax(scores/sample_temperature)
    sample = ntb.Sample(sample_prob)
    
    # Adam optimizer on the training loss
    optim = ntb.Optim(loss_node=loss,lr=lr,update_rule='adam')
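Assuming `ntb.TemporalCE` sums the mask-weighted negative log-likelihood of each target character over the time steps and averages over the batch (an assumption, not confirmed by the source), filling the mask with `1/seq_length`, as `mask_ones` does, turns that sum into a per-character average. A NumPy sketch with the hypothetical helper `temporal_ce`:

```python
import numpy as np

def temporal_ce(scores, y, mask):
    """Hypothetical masked temporal cross-entropy.

    scores: (N, T, V) unnormalized scores; y: (N, T) integer targets; mask: (N, T) weights.
    Sums weighted per-step losses over time and averages over the batch.
    """
    shifted = scores - scores.max(axis=2, keepdims=True)        # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=2, keepdims=True))
    N, T = y.shape
    # log-probability assigned to the correct next character at every (n, t)
    nll = -log_probs[np.arange(N)[:, None], np.arange(T)[None, :], y]
    return (mask * nll).sum() / N

N, T, V = 2, 30, 5
scores = np.zeros((N, T, V))          # uniform predictions over 5 characters
y = np.zeros((N, T), dtype=int)
print(temporal_ce(scores, y, np.ones((N, T)) / T))  # equals np.log(5)
```

With uniform predictions every step costs `ln(V)`, so the `1/T` weights report `ln(V)` rather than `T * ln(V)`, which keeps the printed loss comparable across different `seq_length` values.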

In [4]:
def get_batch():
    idx = np.random.choice(num_train,batch_size)
    return text_data['X_train'][idx],text_data['y_train'][idx]

def print_sample(length,temp=1,prime_text=" "):
    sample_out = []
    prime_text_seq = encode(prime_text).reshape(1,-1)
    state_1,state_2 = lstm_1.get_zero_state(1),lstm_2.get_zero_state(1)
    state_1,state_2,smp = graph.run([out_state_1,out_state_2,sample],
                                    assign_dict={x:prime_text_seq,
                                                 init_state_1:state_1,
                                                 init_state_2:state_2,
                                                 sample_temperature:temp})
    smp = smp[:,-1:]
    sample_out.append(smp[0,0])
    for i in range(length):
        state_1,state_2,smp = graph.run([out_state_1,out_state_2,sample],
                                        assign_dict={x:smp,init_state_1:state_1,
                                                     init_state_2:state_2,
                                                     sample_temperature:temp})
        sample_out.append(smp[0,0])
    sample_out = decode(np.array(sample_out))
    print(sample_out)
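`print_sample` feeds the prime text once, then repeatedly feeds back the last sampled character together with the carried-over LSTM states. Each character is drawn from `Softmax(scores/sample_temperature)`. The effect of the temperature can be seen in a small NumPy sketch; `tempered_softmax` is a hypothetical helper, and using NumPy's `Generator.choice` for the draw stands in for whatever `ntb.Sample` does internally:

```python
import numpy as np

def tempered_softmax(scores, temperature):
    """Softmax of scores / temperature (hypothetical helper)."""
    z = np.asarray(scores, dtype=float) / temperature
    z -= z.max()                      # numerical stability
    p = np.exp(z)
    return p / p.sum()

scores = np.array([2.0, 1.0, 0.0])
for t in (1.0, 0.7, 0.5):
    print(t, tempered_softmax(scores, t).round(3))

# draw one character id from the tempered distribution
rng = np.random.default_rng(0)
next_id = rng.choice(len(scores), p=tempered_softmax(scores, 0.7))
```

Lower temperatures concentrate probability on the highest-scoring characters, giving more conservative and repetitive text; higher temperatures give more varied text with more spelling errors. The training loop samples at `temp=.7`, and the long sample further down uses `temp=.5`.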

In [9]:
num_iterations = 15000
print_every = 1000
def train(num_iterations=num_iterations,print_every=print_every,learning_rate=3e-4):
    cumuloss_tr = 0
    for i in range(num_iterations+1):
        X,Y = get_batch()
        loss_tr,_ = graph.run([loss,optim],assign_dict={x:X,y:Y,lr:learning_rate,mask:mask_ones})
        cumuloss_tr+=loss_tr
        if i == 0:
            cumuloss_tr*=print_every
        if i % print_every == 0:
            print("Train loss:",cumuloss_tr/print_every,"\n")
            print("Sample:\n----------------------------------------")
            print_sample(200,prime_text=" ",temp=.7)
            print("----------------------------------------\n")
            cumuloss_tr = 0
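The `if i == 0` branch in `train` is a small logging trick: the first batch's loss is multiplied by `print_every`, so dividing by `print_every` reports that single loss. After each print the accumulator is reset, so later reports are means over exactly `print_every` batches. The same logic as a hypothetical standalone helper, `reported_losses`, which makes it easy to check:

```python
def reported_losses(losses, print_every):
    """Return the values train() would print for a sequence of per-batch losses."""
    reports, cumuloss = [], 0.0
    for i, loss in enumerate(losses):
        cumuloss += loss
        if i == 0:
            cumuloss *= print_every   # so the first report is just the first batch's loss
        if i % print_every == 0:
            reports.append(cumuloss / print_every)
            cumuloss = 0.0
    return reports

print(reported_losses([4.0] + [2.0] * 3 + [1.0] * 3, print_every=3))  # [4.0, 2.0, 1.0]
```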

In [5]:
train()

Loss: 4.33766992884 

Sample:
----------------------------------------
gWtxwVuWW7fD)HRyHFzlIIf,7gO)WAzmWHBWJyVHnd51IUW;aD9RgmG08i7:DEkRq,ML ynK(B;:Ub2K5W
5Kq3dDE.qkSnOt1BW4Jgdmx
WQFmulQD4"PYb9S"(2DqMH6ldu:,(nSO;P33QN;bdy))Py3H,ByBdN5llWGDG(3bJEqFcQDDne)YV(yFjc)f.Q)foj))J4
----------------------------------------

Loss: 3.03453152001 

Sample:
----------------------------------------
edt Thl an .arsif usacslhan the of be tevy the of on the  on un fhe leceteift SarLo qed als ucer e, gol ofnanes or teat ord ubes onotinf ters on an an atiby  ens Thavhe itase Pos the af te fhans otecte
----------------------------------------

Loss: 2.31354856243 

Sample:
----------------------------------------
on he erectatont; and be
the Erent the te beerss the so the the cocege on ole and
of there Erece sos innt voor thal of he pontenr of the Ceureseses of the in on thall fofent of Reint be Are the PunonSs
----------------------------------------

Loss: 2.01018536052 

Sample:
-------------------------

In [8]:
train()

Loss: 1.23290025043 

Sample:
----------------------------------------
the President of the
United States, the Torom of the United States, other such Senator of Senators, of the United States, in one for the Senates shall have such Part, and all the Congress shall have be
----------------------------------------

Loss: 1.13735143121 

Sample:
----------------------------------------
the Legislature at the cimes shall exceed to the United States
shall be present.

Amendment 10
Whe President of the same other State; he shall recessed yithort the Congress shall not be decurmed of the
----------------------------------------

Loss: 1.1092698446 

Sample:
----------------------------------------
temy shall be a Members for the Congress shall shall be proose the sight, with the State of the United States.

Evengerions by Law be appointed to any person have devolveed in the Vice Pother constitut
----------------------------------------

Loss: 1.08937795578 

Sample:
--------------------------

In [49]:
# Generate a longer sample, priming the RNN with "The se"; the output continues from the prime text, which itself is not printed

print_sample(length=1000,prime_text="The se",temp=.5)

veral States shall be appointed by all such Senators and Representatives shall have been a President or Importation thereof, institution of the United States, and post the United States, and such Bills of the suple States, and with he shall not act all State in the President of the such House of Representatives in the several States, and the several States shall be presence of the United States, without the Consent of the United States; and in a majority of the United States, shall be made on the United States or of his a President of the Congress may by Law appoint and person shall be appropriate legislation.

Amendment 11
The executive thereof the two-thirds of the United States of account of the United States, and be entitled to all grants of the United States or Tribe, and a Manner shall be a Representatives, shall be appointed to any such Measess oncess of the Senate and House of Males of the United States, and which such Consist of the President or Citizens of the United States, 

In [10]:
import pickle

In [11]:
#run this to save the current model
with open('textdata_model.pickle','wb') as f:
    pickle.dump(graph.save(),f)

In [8]:
#run this to load the saved model
with open('textdata_model.pickle','rb') as f:
    graph.load(pickle.load(f))