In [None]:
"""
Refs:
    http://karpathy.github.io/2015/05/21/rnn-effectiveness/
    https://www.youtube.com/watch?v=1ZbLA7ofasY
    https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff
"""

from pathlib import Path

from IPython.display import HTML as html_print
from IPython.display import display


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lstm import WordLSTM, LSTM, LSTMCell
from data import WordTokenizer


__author__ = '__Girish_Hegde__'

In [None]:
def neuron_firing_hook(module, input, output):
    """ Forward hook that records the hidden state of the hooked module.

    The module output is unpacked as an ``(h_t, c_t)`` pair and a detached
    version of ``h_t`` is stored on ``module.net.activations`` so it can be
    read once the forward pass has finished. The hook returns nothing.

    Args:
        module: hooked module; must expose the owning network as ``module.net``.
        input: inputs of the forward call (unused).
        output: two-element ``(h_t, c_t)`` result of the forward call.

    Refs:
        https://www.youtube.com/watch?v=1ZbLA7ofasY
    """
    hidden, _cell = output
    # detach so the stored tensor is cut off from the autograd graph
    module.net.activations = hidden.detach()


In [None]:
def attach_hook(net, layers=(0, ), type=LSTMCell):
    """ Attach `neuron_firing_hook` to selected sub-modules of `net`.

    Sub-modules that are instances of `type` are numbered 0, 1, 2, ... in
    the order `net.named_modules()` yields them; those whose number is in
    `layers` get the forward hook. `net.activations` is reset to None and is
    filled by the hook on the next forward pass.

    The function is safe to call repeatedly (e.g. when a notebook cell is
    re-run): a hook registered by an earlier call is removed first, so hooks
    never stack up and a previously selected layer cannot keep overwriting
    `net.activations`.

    Args:
        net: network whose sub-modules are inspected.
        layers: iterable of 0-based indices (among the matching sub-modules)
            to hook.
        type: class of the sub-modules to count and hook. The name shadows
            the builtin but is kept because callers pass it by keyword.
    """
    layers = set(layers)
    net.activations = None
    i = 0
    for _name, module in net.named_modules():
        if not isinstance(module, type):
            continue
        # Remove the hook left behind by a previous call, if any.
        stale = getattr(module, 'firing_hook', None)
        if stale is not None:
            stale.remove()
            module.firing_hook = None
        if i in layers:
            # Bypass nn.Module.__setattr__: a plain `module.net = net` would
            # register the parent network as a child of its own sub-module,
            # creating a cycle that breaks recursive calls such as
            # net.to(), net.eval() or net.state_dict().
            object.__setattr__(module, 'net', net)
            module.firing_hook = module.register_forward_hook(neuron_firing_hook)
        i += 1


In [None]:
def get_hexcolor(rgb):
    """ Convert an (r, g, b) triple to a ``#rrggbb`` hex colour string.

    If the first channel is a float, all three channels are treated as
    fractions of full intensity and scaled by 255 (truncating towards
    zero); otherwise they are taken as 0..255 integers. Every channel is
    then clamped to 0..255 so that out-of-range input still yields a valid
    six-digit colour instead of a malformed string.

    Args:
        rgb: iterable of exactly three numbers.

    Returns:
        str: colour in ``#rrggbb`` form (lowercase hex digits).
    """
    r, g, b = rgb
    if isinstance(r, float):
        r, g, b = [int(v) for v in [r*255, g*255, b*255]]
    # clamp so negative or >255 values cannot produce '-7f' or 3-digit hex
    channels = [min(255, max(0, int(v))) for v in (r, g, b)]
    return '#' + ''.join(format(v, '02x') for v in channels)


# get html element
def cstr(s, color='black'):
	"""
	Build the HTML element that shows one token on a coloured background.

	A single space is rendered as a fixed-width padded box; any other token
	is HTML-escaped (so characters such as ``<`` or ``&`` coming from the
	corpus cannot break the markup) and followed by a trailing space.

	Args:
		s: token to display.
		color: CSS colour used as the element's background.

	Returns:
		str: HTML snippet for the token.

	Refs:
		https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff
	"""
	from html import escape

	if s == ' ':
		return "<text style=color:#000;padding-left:10px;background-color:{}> </text>".format(color)
	else:
		return "<text style=color:#000;background-color:{}>{} </text>".format(color, escape(str(s)))


# print html
def print_color(t):
	"""
	Display a sequence of (text, colour) pairs as one run of highlighted HTML.

	Each pair is turned into an HTML element by `cstr`; the elements are
	concatenated and shown through IPython's rich display.

	Args:
		t: iterable of ``(text, colour)`` pairs.

	Refs:
		https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff
	"""
	markup = ''.join(cstr(piece, color=shade) for piece, shade in t)
	display(html_print(markup))


# get appropriate color for value
def get_clr(value):
	"""
	Map a value to a hex colour on a blue-to-red scale.

	The red channel is `value`, the blue channel is `1 - value` and green
	is always 0; the triple is converted by `get_hexcolor`.

	Args:
		value: number used as the red intensity.

	Returns:
		str: hex colour string produced by `get_hexcolor`.

	Refs:
		https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff
	"""
	return get_hexcolor((value, 0, 1 - value))


In [None]:

def visualize(tokens, activations, neuron):
	"""
	Show tokens coloured by the activation of one neuron.

	Tokens and activations are paired positionally (pairing stops at the
	shorter of the two); for each pair the entry `neuron` of the activation
	is converted to a colour with `get_clr` and the result is rendered with
	`print_color`.

	Args:
		tokens: sequence of tokens to display.
		activations: sequence of per-token activation tensors.
		neuron: index of the neuron whose activation selects the colour.

	Refs:
		https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff
	"""
	print_color([
		(token, get_clr(step_acts[neuron].item()))
		for token, step_acts in zip(tokens, activations)
	])

In [None]:
# --- Configuration -----------------------------------------------------------
# Checkpoint to load, device to run on, corpus to read, index of the hooked
# layer and number of words pushed through the network.
CKPT = Path('./data/runs/best.pt')
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
TEXT = './data/datasets/corpus.txt'
# 0-based index among the LSTMCell sub-modules, as counted by attach_hook()
LAYER = 2
# upper bound on the number of words fed to the network in the collection loop
N_WORDS = 10000

# --- Model -------------------------------------------------------------------
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The checkpoint also carries plain Python objects
# (e.g. 'int2token'), so a weights-only load may not be a drop-in replacement
# -- TODO confirm.
ckpt = torch.load(CKPT, map_location=DEVICE)
net = WordLSTM(ckpt['VOCAB_SIZE'], ckpt['EMBEDDING_DIM'], ckpt['HIDDEN_SIZE'], ckpt['NUM_LAYERS'])
net.load_state_dict(ckpt['state_dict'])
net = net.to(DEVICE)
net.eval()
init_states = net.init_hidden(1, DEVICE)

# --- Vocabulary and corpus ---------------------------------------------------
int2token = ckpt['int2token']
# inverse mapping: token -> integer id
token2int = {tk: i for i, tk in int2token.items()}
firing = []
text = WordTokenizer.read(TEXT, encoding='utf-8')
_, words, _ = WordTokenizer.tokenize(text, lowercase=True)
# keep only words that have an id in the checkpoint vocabulary
words = [w for w in words if w in token2int]

# --- Hook and colour legend --------------------------------------------------
# Attach only after to()/eval() above have been called on the network.
attach_hook(net, layers=[LAYER, ], type=LSTMCell)
print('Color Palette')
# legend: 255 swatches for the values 0/255 .. 254/255
print_color([[' ', get_clr(v.item())] for v in torch.arange(255)/255])

In [None]:
# Feed the first N_WORDS words through the network one token at a time and
# collect, for every token, the min-max normalised activations captured by the
# forward hook.
firing = []
init_states = net.init_hidden(1, DEVICE)
# Inference only: without no_grad() the recurrent state carried from step to
# step keeps the autograd graph of every previous step alive, so memory grows
# with the number of tokens.
with torch.no_grad():
    for i, token in enumerate(words[:N_WORDS]):
        enc = torch.tensor([[token2int[token]]], dtype=torch.int64, device=DEVICE)
        pred, init_states = net(enc, init_states)
        if net.activations is None:
            raise RuntimeError(
                'no activations were captured; check that attach_hook() '
                'matched a module for the requested LAYER'
            )
        act = net.activations.clone()[0]
        lo, hi = act.min(), act.max()
        # guard against a zero range (all activations equal), which would
        # otherwise yield NaN and break the colour conversion
        span = (hi - lo).clamp(min=torch.finfo(act.dtype).eps)
        act = (act - lo) / span
        firing.append(act)
        # report progress occasionally instead of one line per token
        if i % 1000 == 0:
            print(i)


In [None]:
# Optional, currently disabled: rank neurons by the standard deviation of
# their collected activations across tokens (e.g. to pick a neuron to inspect).
# std = torch.std(torch.stack(firing), dim=0)
# vs, idx = std.sort()
# Colour the words by the activation of a single neuron. `visualize` pairs
# words with `firing` positionally, so only as many words as there are
# collected activations are shown.
# NOTE(review): 247 is a hard-coded neuron index -- presumably it must be
# smaller than the hidden size of the hooked layer; confirm.
visualize(words, firing, neuron=247)