In [3]:
"""
In this notebook I build a simplified version of the LSTM activation visualization from:
https://github.com/Praneet9/Visualising-LSTM-Activations

The sample above is hard to interpret because it works on free text, so we cannot be
sure of the distribution of the data in the book.

Instead I use the LSTM addition example from the Keras documentation.
The mapping the model has to learn is:
 x         y
1+3        4
8+6        14

The idea is to see which LSTM cells get activated for which output values.
The output is very simple: at most 2 digits.
Also, for the visualization I don't generate random inputs;
I enumerate the set of possible inputs to see when and how the cells are activated.
"""
from random import seed
from random import randint
from numpy import array
from math import ceil
from math import log10
from math import sqrt
from numpy import argmax
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.layers import RepeatVector
from itertools import combinations 


# generate lists of random integers and their sum
def random_sum_pairs(n_examples, n_numbers, largest):
	"""Draw random addition problems together with their answers.

	Args:
		n_examples: number of problems to generate.
		n_numbers: how many operands each problem has.
		largest: inclusive upper bound for each operand (lower bound is 1).

	Returns:
		A tuple ``(X, y)`` where ``X`` is a list of operand lists and ``y`` is
		the list of the corresponding sums.
	"""
	# Operands are drawn problem by problem, left to right, from the global
	# ``random`` generator, so results are reproducible after ``seed(...)``.
	X = [[randint(1, largest) for _ in range(n_numbers)] for _ in range(n_examples)]
	y = [sum(operands) for operands in X]
	return X, y

def possible_sum_pairs( n_numbers, largest):
	"""Enumerate every combination of distinct operands and its sum.

	Unlike ``random_sum_pairs`` this is deterministic: it walks through all
	``n_numbers``-sized combinations of the integers ``1..largest``.

	Args:
		n_numbers: how many operands each problem has.
		largest: inclusive upper bound for each operand (lower bound is 1).

	Returns:
		A tuple ``(X, y)`` where ``X`` is a list of operand lists (in
		lexicographic order) and ``y`` is the list of the corresponding sums.

	Note:
		``itertools.combinations`` yields each unordered selection once and
		never repeats an element, so problems such as ``5+5`` or ``3+1``
		(the mirror of ``1+3``) are not produced.
	"""
	X, y = list(), list()
	# Build the operand pool from the arguments instead of a hard-coded
	# [1..10] list so the enumeration follows the configured problem size.
	for in_pattern in combinations(range(1, largest + 1), n_numbers):
		X.append(list(in_pattern))
		y.append(sum(in_pattern))
	return X, y

# convert data to strings
def to_string(X, y, n_numbers, largest):
	"""Render problems and answers as fixed-width, left-padded strings.

	Args:
		X: list of operand lists, e.g. ``[[1, 3], [8, 6]]``.
		y: list of integer sums.
		n_numbers: number of operands per problem.
		largest: largest possible operand; determines the field widths.

	Returns:
		A tuple ``(Xstr, ystr)``: operands joined with ``'+'`` and the sums,
		each right-aligned with spaces to a common width.
	"""
	# Widest input: every operand at full digit count plus the '+' separators.
	in_width = n_numbers * ceil(log10(largest+1)) + n_numbers - 1
	# Width reserved for the sum.
	out_width = ceil(log10(n_numbers * (largest+1)))

	Xstr = ['+'.join(str(n) for n in operands).rjust(in_width) for operands in X]
	ystr = [str(total).rjust(out_width) for total in y]
	return Xstr, ystr

# integer encode strings
def integer_encode(X, y, alphabet):
	"""Map every character of the input/output strings to its alphabet index.

	Args:
		X: list of input strings.
		y: list of output strings.
		alphabet: sequence of characters; a character's position is its code.

	Returns:
		A tuple ``(Xenc, yenc)`` of lists of integer lists.

	Raises:
		KeyError: if a string contains a character that is not in ``alphabet``.
	"""
	lookup = {symbol: position for position, symbol in enumerate(alphabet)}
	Xenc = [[lookup[symbol] for symbol in text] for text in X]
	yenc = [[lookup[symbol] for symbol in text] for text in y]
	return Xenc, yenc

# one hot encode
def one_hot_encode(X, y, max_int):
	"""Turn integer-coded sequences into one-hot vectors of length ``max_int``.

	Args:
		X: list of integer sequences (inputs).
		y: list of integer sequences (outputs).
		max_int: length of every one-hot vector.

	Returns:
		A tuple ``(Xenc, yenc)``; each element is a list of sequences, each
		sequence a list of 0/1 vectors.
	"""
	def encode_all(sequences):
		# Same procedure for inputs and outputs: one vector per symbol.
		encoded = list()
		for seq in sequences:
			rows = list()
			for index in seq:
				vector = [0] * max_int
				vector[index] = 1
				rows.append(vector)
			encoded.append(rows)
		return encoded

	Xenc = encode_all(X)
	yenc = encode_all(y)
	return Xenc, yenc

# generate an encoded dataset
def generate_data(n_samples, n_numbers, largest, alphabet,is_random=True):
	"""Build a one-hot encoded dataset of addition problems.

	Args:
		n_samples: number of random problems to draw; only used when
			``is_random`` is true.
		n_numbers: number of operands per problem.
		largest: largest possible operand.
		alphabet: characters used for the integer / one-hot encoding.
		is_random: when true, sample problems with ``random_sum_pairs``;
			otherwise enumerate them with ``possible_sum_pairs``.

	Returns:
		A tuple ``(X, y)`` of numpy arrays holding the one-hot encoded input
		and output strings.
	"""
	# Stage 1: raw integer problems and their sums.
	if is_random:
		pairs, sums = random_sum_pairs(n_samples, n_numbers, largest)
	else:
		pairs, sums = possible_sum_pairs( n_numbers, largest)
	# Stage 2: fixed-width strings.
	in_text, out_text = to_string(pairs, sums, n_numbers, largest)
	# Stage 3: characters -> alphabet indices.
	in_codes, out_codes = integer_encode(in_text, out_text, alphabet)
	# Stage 4: indices -> one-hot vectors.
	in_hot, out_hot = one_hot_encode(in_codes, out_codes, len(alphabet))
	# Stage 5: hand back numpy arrays.
	return array(in_hot), array(out_hot)

# invert encoding
def invert(seq, alphabet):
	"""Decode a sequence of one-hot (or probability) vectors back to a string.

	Args:
		seq: iterable of vectors; the position of each vector's maximum
			selects the character.
		alphabet: sequence of characters indexed by code.

	Returns:
		The decoded string, one character per vector in ``seq``.
	"""
	int_to_char = dict(enumerate(alphabet))
	return ''.join(int_to_char[argmax(vector)] for vector in seq)

# define dataset
# NOTE: seed() only seeds Python's `random` module, which drives the data
# generator (randint). No numpy / backend seed is set in this cell, so the
# network's weight initialisation is not pinned by it.
seed(1)
n_samples = 1000
n_numbers = 2
largest = 10
alphabet = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
n_chars = len(alphabet)
# input width: digits of every operand plus the '+' separators (same formula as to_string)
n_in_seq_length = n_numbers * ceil(log10(largest+1)) + n_numbers - 1
# output width: digits reserved for the sum (same formula as to_string)
n_out_seq_length = ceil(log10(n_numbers * (largest+1)))
# define LSTM configuration
n_batch = 10
n_epoch = 30
# create LSTM
# encoder LSTM -> repeat its final state once per output character ->
# decoder LSTM -> per-timestep softmax over the alphabet
model = Sequential()
model.add(LSTM(100, input_shape=(n_in_seq_length, n_chars)))
model.add(RepeatVector(n_out_seq_length))
model.add(LSTM(50, return_sequences=True))
model.add(TimeDistributed(Dense(n_chars, activation='softmax')))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# summary() prints the table itself and returns None; wrapping it in print()
# only adds a stray "None" line to the cell output.
model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_3 (LSTM)                (None, 100)               45200     
_________________________________________________________________
repeat_vector_2 (RepeatVecto (None, 2, 100)            0         
_________________________________________________________________
lstm_4 (LSTM)                (None, 2, 50)             30200     
_________________________________________________________________
time_distributed_2 (TimeDist (None, 2, 12)             612       
Total params: 76,012
Trainable params: 76,012
Non-trainable params: 0
_________________________________________________________________
None


In [4]:
# train LSTM
# Each of the n_epoch iterations draws a fresh random dataset of n_samples
# problems and fits the model on it for exactly one epoch, so the network
# keeps seeing newly sampled problems instead of one fixed training set.
for i in range(n_epoch):
	X, y = generate_data(n_samples, n_numbers, largest, alphabet)
	model.fit(X, y, epochs=1, batch_size=n_batch)

Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1


In [5]:

# evaluate on some new patterns
# (a freshly sampled random dataset, generated the same way as the training data)
X, y = generate_data(n_samples, n_numbers, largest, alphabet)
result = model.predict(X, batch_size=n_batch, verbose=0)
# decode the one-hot targets and the predicted distributions back to strings
# (no aggregate error metric is computed here; the comparison is visual)
expected = [invert(x, alphabet) for x in y]
predicted = [invert(x, alphabet) for x in result]
# show some examples
# (the first 20 samples; assumes at least 20 were generated)
for i in range(20):
	print('Expected=%s, Predicted=%s' % (expected[i], predicted[i]))

Expected=13, Predicted=13
Expected=13, Predicted=13
Expected=13, Predicted=13
Expected= 9, Predicted= 9
Expected=11, Predicted=11
Expected=18, Predicted=18
Expected=15, Predicted=15
Expected=14, Predicted=14
Expected= 6, Predicted= 6
Expected=15, Predicted=15
Expected= 9, Predicted= 9
Expected=10, Predicted=10
Expected= 8, Predicted= 8
Expected=14, Predicted=14
Expected=14, Predicted=14
Expected=19, Predicted=19
Expected= 4, Predicted= 4
Expected=13, Predicted=13
Expected= 9, Predicted= 9
Expected=12, Predicted=12


In [6]:
import re

# Imports for visualisations
from IPython.display import HTML as html_print
from IPython.display import display
import keras.backend as K

# Pick the layer whose activations we want to inspect. Any layer could be
# used; model.layers[2] is the second (decoder) LSTM, which sits close to the
# output, on the assumption that layers near the output capture the
# higher-level features.
lstm = model.layers[2]
# Backend function mapping (model input, learning-phase flag) to the chosen
# layer's output, so its activations can be read for arbitrary inputs.
attn_func = K.function(inputs = [model.get_input_at(0), K.learning_phase()],
           outputs = [lstm.output]
          )

In [7]:
# get html element
def cstr(s, color='black'):
	"""Wrap ``s`` in a ``<text>`` element with the given background colour.

	A single space gets extra left padding so that it remains visible as a
	coloured box; any other value is rendered followed by a space.

	Args:
		s: value to display (formatted with ``str.format``).
		color: CSS colour used as the background.

	Returns:
		The HTML snippet as a string.
	"""
	if s == ' ':
		template = "<text style=color:#000;padding-left:10px;background-color:{}> </text>"
	else:
		template = "<text style=color:#000;background-color:{}>{} </text>"
	# The blank template has a single placeholder; format() ignores the
	# surplus positional argument.
	return template.format(color, s)
	
# print html
def print_color(t):
	"""Display labelled, colour-coded two-character results as one HTML line.

	Args:
		t: iterable of 5-tuples ``(char0, char1, color0, color1, label)``.
			Each entry is rendered as ``label=`` followed by the two
			characters on their respective background colours.
	"""
	fragments = []
	for ti0, ti1, ci0, ci1, test in t:
		fragments.append(str(test) + "=" + cstr(ti0, color=ci0) + cstr(ti1, color=ci1))
	markup = '&nbsp;&nbsp;&nbsp;'.join(fragments)
	display(html_print(markup))

def print_color2(t):
	"""Display a run of colour-coded values as one HTML line.

	Args:
		t: iterable of ``(value, color)`` pairs; each value is rendered on its
			background colour, with no separator between entries.
	"""
	pieces = []
	for ti0, ci0 in t:
		pieces.append(cstr(ti0, color=ci0))
	display(html_print(''.join(pieces)))

# get appropriate color for value
def get_clr(value):
	"""Map an activation level to a colour on a blue-to-red scale.

	The scale has 20 buckets of width 0.05: values near 0 map to blue,
	values near 1 map to red.

	Args:
		value: activation level, expected in ``[0, 1]``. Values outside that
			range are clamped to the first / last bucket.

	Returns:
		A ``#rrggbb`` colour string.
	"""
	# The comma after '#a1d0e8' matters: without it Python concatenates the two
	# adjacent literals into one invalid colour and the palette loses a bucket.
	colors = ['#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
		'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
		'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
		'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e']
	bucket = int((value * 100) / 5)
	# Clamp so that value == 1.0 (bucket 20) does not raise IndexError and
	# negative values do not silently wrap around to the red end.
	bucket = max(0, min(bucket, len(colors) - 1))
	return colors[bucket]

# sigmoid function
def sigmoid(x):
	"""Logistic function ``1 / (1 + e^-x)``, applied element-wise via numpy.

	Args:
		x: scalar or numpy array.

	Returns:
		Value(s) in the open interval (0, 1), same shape as ``x``.
	"""
	decay = np.exp(-x)
	return 1 / (1 + decay)

def visualize(output_values, result_list, test_list, cell_no):
	"""Show, for one LSTM cell, how strongly it fires on each output digit.

	For every sample the two predicted output characters are displayed on a
	background colour derived from the cell's activation at the matching
	output timestep, labelled with the decoded input.

	Args:
		output_values: per-sample activations, indexed ``[sample][timestep][cell]``.
		result_list: per-sample predicted strings (two characters are read).
		test_list: per-sample encoded inputs, decoded for the label.
		cell_no: index of the cell to visualise.
	"""
	print("\nCell Number:", cell_no, "\n")
	rows = []
	for sample_no, activations in enumerate(output_values):
		first_clr = get_clr(activations[0][cell_no])
		second_clr = get_clr(activations[1][cell_no])
		predicted = result_list[sample_no]
		label = invert_input(sample_no, test_list)
		rows.append((predicted[0], predicted[1], first_clr, second_clr, label))
	print_color(rows)

In [8]:
# Sanity-check the colour scale: render one swatch per activation level
# from 0.00 to 0.90 in steps of 0.05 (each label is the level in percent).
print_color2([(i,get_clr(i/100)) for i in range(0,95,5) ])

In [13]:
def get_predictions():
	"""Run the model on the enumerated inputs and record layer activations.

	Uses the notebook-level ``model``, ``attn_func``, ``n_numbers``,
	``largest`` and ``alphabet``.

	Returns:
		A tuple ``(output_values, result_list, test_list)`` with one entry per
		input:
		- ``output_values``: sigmoid-squashed output of the probed layer,
		- ``result_list``: the decoded predicted string,
		- ``test_list``: the encoded input, reshaped to a batch of one.
	"""
	result_list, output_values, test_list = [], [], []
	# With is_random=False the inputs are enumerated, so the sample count
	# argument (1) is not used by generate_data.
	xs, _ = generate_data(1, n_numbers, largest, alphabet, False)
	for sample in xs:
		# The model and the backend function both expect a leading batch axis.
		x = sample.reshape(1, sample.shape[0], sample.shape[1])
		test_list.append(x)
		prediction = model.predict(x, batch_size=1, verbose=0)
		# Learning-phase flag 0 = inference, matching model.predict(); the
		# activations must come from the same mode as the predictions shown.
		output = attn_func([x, 0])[0][0]
		# Squash to (0, 1) so the values fit the colour scale.
		output_values.append(sigmoid(output))
		result_list.append(invert(prediction[0], alphabet))
	return output_values, result_list, test_list

In [14]:
# NOTE: sigmoid() and get_predictions(), defined in earlier cells, reference
# `np`, so this import has to run before either of them is called.
import numpy as np
# Activations, decoded predictions and encoded inputs for the enumerated inputs.
output_values, result_list,test_list = get_predictions()


In [17]:
def invert_input(index,arr):
    """Decode the encoded input stored at ``arr[index]`` back to text.

    Each element of ``arr[index]`` is decoded with ``invert`` using the
    notebook-level ``alphabet`` and the pieces are concatenated.
    """
    decoded = []
    for x in arr[index]:
        decoded.append(invert(x, alphabet))
    return "".join(decoded)
# Dump the activations of every cell to find which ones look most informative.
# The cell count is read from the recorded activations (last axis) rather than
# hard-coded, so the loop stays correct if the probed layer's width changes.
n_cells = output_values[0].shape[-1]
for cell_no in range(n_cells):
	visualize(output_values, result_list,test_list, cell_no)


Cell Number: 0 




Cell Number: 1 




Cell Number: 2 




Cell Number: 3 




Cell Number: 4 




Cell Number: 5 




Cell Number: 6 




Cell Number: 7 




Cell Number: 8 




Cell Number: 9 




Cell Number: 10 




Cell Number: 11 




Cell Number: 12 




Cell Number: 13 




Cell Number: 14 




Cell Number: 15 




Cell Number: 16 




Cell Number: 17 




Cell Number: 18 




Cell Number: 19 




Cell Number: 20 




Cell Number: 21 




Cell Number: 22 




Cell Number: 23 




Cell Number: 24 




Cell Number: 25 




Cell Number: 26 




Cell Number: 27 




Cell Number: 28 




Cell Number: 29 




Cell Number: 30 




Cell Number: 31 




Cell Number: 32 




Cell Number: 33 




Cell Number: 34 




Cell Number: 35 




Cell Number: 36 




Cell Number: 37 




Cell Number: 38 




Cell Number: 39 




Cell Number: 40 




Cell Number: 41 




Cell Number: 42 




Cell Number: 43 




Cell Number: 44 




Cell Number: 45 




Cell Number: 46 




Cell Number: 47 




Cell Number: 48 




Cell Number: 49 



In [18]:
visualize(output_values, result_list,test_list, 22)
# Observation: cell 22 seems to activate when the sum is less than 10, and more strongly on the 2nd digit.


Cell Number: 22 



In [19]:
visualize(output_values, result_list,test_list, 23)
# Observation: cell 23 seems to activate on the 1st digit of every sum.


Cell Number: 23 



In [21]:
visualize(output_values, result_list,test_list, 38)
# Observation: cell 38 seems to activate on the 2nd digit whenever the sum is less than 9.


Cell Number: 38 



In [23]:
visualize(output_values, result_list,test_list, 43)
# Observation: cell 43 seems to activate on the 2nd digit when the sum is more than 14.


Cell Number: 43 

