In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf

In [None]:
# Helper: temperature-scaled sampling from a probability vector.

def sample(preds, temperature=1.0, rng=None):
    """Draw one index from a probability vector after temperature scaling.

    The distribution is reshaped as ``p_i ** (1 / temperature)`` and then
    renormalised. ``temperature < 1`` sharpens it towards the argmax, and
    ``temperature > 1`` flattens it towards uniform.

    Args:
        preds: 1-D array-like of non-negative probabilities (e.g. a softmax
            output). Zero entries are allowed and are never sampled.
        temperature: Strictly positive scaling factor.
        rng: Optional ``np.random.Generator``. When ``None``, the legacy global
            ``np.random`` state is used, so ``np.random.seed`` still applies.

    Returns:
        The sampled index (a NumPy integer).

    Raises:
        ValueError: If ``temperature`` is not positive, or ``preds`` has no
            positive finite entry.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    probs = np.asarray(preds).astype('float64')
    # log(0) -> -inf is intended (zero-probability entries stay at zero),
    # so silence that specific warning.
    with np.errstate(divide='ignore'):
        logits = np.log(probs) / temperature
    max_logit = np.max(logits)
    if not np.isfinite(max_logit):
        raise ValueError("preds must contain at least one positive, finite probability")
    # Shift by the max before exponentiating. Without the shift, small
    # temperatures underflow every entry to 0, and normalising yields NaN.
    scaled = np.exp(logits - max_logit)
    scaled /= np.sum(scaled)
    draw = (np.random if rng is None else rng).multinomial(1, scaled, 1)
    return np.argmax(draw)

In [None]:
# Load the preprocessed speech corpus. NOTE: unpickling can execute arbitrary
# code, so only load this file from a trusted source.
processed_df = pd.read_pickle('../data/processed_df.pkl')
# Later cells filter on `bioguide_id` and read `speech`; fail early if they are absent.
assert {'bioguide_id', 'speech'}.issubset(processed_df.columns), \
    f"processed_df is missing required columns; has {list(processed_df.columns)}"

In [None]:
# Restrict the corpus to a single speaker, identified by their bioguide ID.
temp_bioguide_id = 'H001055'
temp_df = processed_df.loc[processed_df['bioguide_id'] == temp_bioguide_id]
# Later cells index temp_df.iloc[0] and train on these speeches, so an empty
# selection must fail here rather than further down.
assert not temp_df.empty, f"no speeches found for bioguide_id {temp_bioguide_id!r}"

In [None]:
# Peek at the first selected speech as a sanity check on the raw text.
temp_df['speech'].iloc[0]

In [None]:
# Split each speech into lower-cased characters, framed by [BOS]/[EOS] markers
# so the model can see where one speech ends and the next begins.
X_train = [['[BOS]', *speech.lower(), '[EOS]'] for speech in temp_df['speech']]

# Concatenate all per-speech token lists into one flat character stream.
merged_X_train = np.hstack(X_train)

In [None]:
# Character vocabulary learned from the training stream. There is no OOV
# bucket (num_oov_indices=0), so any token later fed to this layer must
# already appear in merged_X_train. Tokens are emitted as one-hot vectors.
lookup_layer = tf.keras.layers.StringLookup(
    num_oov_indices=0,
    output_mode='one_hot',
)
lookup_layer.adapt(merged_X_train)

In [None]:
def split_seq(seq):
    """Turn a batch of token windows into (inputs, targets) for next-token prediction.

    The windows are fixed-length slices of the merged character stream, so they
    are not aligned to [BOS]/[EOS] boundaries.

    Args:
        seq: rank-3 tensor/array indexed as (batch, time, feature).

    Returns:
        A pair ``(inputs, targets)``. ``inputs`` drops the last time step and
        ``targets`` drops the first, so position t of the inputs is paired with
        token t + 1 of the window.
    """
    return seq[:, :-1, :], seq[:, 1:, :]

In [None]:
# Windowing / batching configuration.
WINDOW_LEN = 101        # tokens per window; split_seq turns it into 100-step input/target pairs
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000   # number of windows held in the shuffle buffer

# NOTE: despite the name, these are one-hot rows (output_mode='one_hot'), not
# integer indices, so this is a dense (num_chars, vocab) tensor held in memory.
X_train_idx = lookup_layer(merged_X_train)

train_ds = (
    tf.data.Dataset.from_tensor_slices(X_train_idx)
    .batch(WINDOW_LEN, drop_remainder=True)
    # Model.fit ignores shuffle=True for tf.data inputs, so shuffle the windows
    # here (reshuffled every epoch by default).
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .map(split_seq)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
class RNN_LM(tf.keras.Model):
    """Character-level language model: one LSTM layer and a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of output classes. If omitted, it falls back to the
            size of the notebook-global ``lookup_layer`` vocabulary (the
            original behaviour).
        lstm_units: Hidden size of the LSTM layer.
    """

    def __init__(self, vocab_size=None, lstm_units=256):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: size the output from the global lookup layer.
            vocab_size = len(lookup_layer.get_vocabulary())
        # Attribute names lstm2/dense3 are kept from the original so layer and
        # weight names stay stable.
        self.lstm2 = tf.keras.layers.LSTM(units=lstm_units, return_sequences=True)
        self.dense3 = tf.keras.layers.Dense(units=vocab_size)

    def call(self, x, training=None):
        """Map a (batch, time, features) input to per-step logits over the vocabulary.

        In this notebook the features are one-hot characters. The outputs are
        raw logits (no softmax), so pair them with a ``from_logits=True`` loss.
        """
        hiddens = self.lstm2(x, training=training)
        return self.dense3(hiddens)

    def train_step(self, data):
        """Run one optimisation step on an ``(input_seq, target_seq)`` batch.

        Returns:
            A dict mapping each tracked metric name (including the loss) to its
            current value.
        """
        input_seq, target_seq = data

        with tf.GradientTape() as tape:
            predicted_seq = self(input_seq, training=True)
            loss = self.compiled_loss(target_seq, predicted_seq, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (this includes the tracker for the loss itself).
        self.compiled_metrics.update_state(target_seq, predicted_seq)
        return {m.name: m.result() for m in self.metrics}

In [None]:
# Build and train the character language model.
model = RNN_LM()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# The model emits raw logits, and the targets are one-hot character vectors.
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Both callbacks monitor the training loss, because there is no validation split.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.2, patience=4, verbose=1)

# Without restore_best_weights, training ends with the weights from `patience`
# epochs after the best loss, not the best weights themselves.
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor='loss', min_delta=0, patience=5, verbose=2, restore_best_weights=True
)

model.compile(optimizer=optimizer, loss=loss)

# fit() ignores `shuffle` for tf.data.Dataset inputs, so it is not passed here.
# Any shuffling has to happen inside train_ds itself.
history = model.fit(x=train_ds, epochs=500, verbose=2, callbacks=[early_stopper, lr_scheduler])

model.summary()

In [None]:
# Generate text by repeatedly sampling the next character and feeding it back in.
initial_text = ''
TEMPERATURE = 1.0   # < 1 sharpens towards greedy decoding, > 1 adds diversity
MAX_CHARS = 100

vocabulary = lookup_layer.get_vocabulary()   # hoisted: invariant across steps
vocab_size = len(vocabulary)
special_tokens = {'[BOS]', '[EOS]'}


def encode_tokens(tokens):
    """One-hot encode a list of tokens as a (1, len(tokens), vocab_size) float32 batch."""
    # Use reshape rather than expand_dims: StringLookup's one_hot mode encodes
    # onto a trailing size-1 axis, so a single token comes back rank-1.
    return tf.reshape(tf.cast(lookup_layer(tokens), tf.float32), (1, len(tokens), vocab_size))


# Training text was lower-cased, and the lookup has no OOV bucket, so the
# prompt must be lower-cased before encoding. It must also contain only
# characters seen in training.
preds_tensor = encode_tokens(['[BOS]'] + list(initial_text.lower()))
decoded_str = initial_text

# NOTE: each step re-runs the model over the whole prefix, which is O(n^2)
# but simple.
while len(decoded_str) < MAX_CHARS:
    next_logits = model(preds_tensor, training=False)[0, -1, :]
    next_probs = tf.math.softmax(next_logits).numpy()
    decoded_ch = vocabulary[sample(next_probs, temperature=TEMPERATURE)]
    # A special token marks a speech boundary. Stop here instead of splicing
    # the literal marker text into the output.
    if decoded_ch in special_tokens:
        break
    decoded_str += decoded_ch
    preds_tensor = tf.concat([preds_tensor, encode_tokens([decoded_ch])], axis=1)

print(decoded_str)