In [35]:
import matplotlib.pyplot as plt
import numpy as np
import pickle
from numpy import newaxis as na
import codecs
from IPython.display import display, HTML

In [36]:
def rescale_score_by_abs (score, max_score, min_score):
    """
    Rescale a relevance score to the range [0, 1] so it can index a diverging colormap.

    The value 0.5 always corresponds to a score of zero. The score with the largest
    absolute value among all scores is mapped to one end of the range (1.0 if it is
    positive, 0.0 if it is negative); every other score is scaled linearly relative to it.

    Args:
    - score:     the score to rescale
    - max_score: maximum over all scores being visualized together
    - min_score: minimum over all scores being visualized together

    Returns:
    - the rescaled score, a float in [0, 1] as long as min_score <= score <= max_score
    """
    # CASE 1: positive AND negative scores occur --------------------
    if max_score>0 and min_score<0:
        # the deepest color belongs to whichever extreme has the larger magnitude
        if max_score >= abs(min_score):
            scale = max_score
        else:
            scale = abs(min_score)
        return 0.5 + 0.5*(score/scale)

    # CASE 2: ONLY positive scores occur -----------------------------
    elif max_score>0 and min_score>=0:
        if max_score == min_score:
            return 1.0
        else:
            return 0.5 + 0.5*(score/max_score)

    # CASE 3: ONLY negative scores occur -----------------------------
    elif max_score<=0 and min_score<0:
        if max_score == min_score:
            return 0.0
        else:
            return 0.5 - 0.5*(score/min_score)

    # CASE 4: all scores are zero (max_score == min_score == 0) ------
    # Nothing is relevant, so use the neutral midpoint instead of implicitly returning None.
    return 0.5

In [37]:
def getRGB (c_tuple):
    """
    Convert a color given as floats into an HTML hex color string "#rrggbb".

    Only the first three entries of c_tuple are read; each is multiplied by 255 and
    truncated to an integer before being formatted as two hex digits.
    """
    red, green, blue = (int(c_tuple[channel]*255) for channel in range(3))
    return "#%02x%02x%02x" % (red, green, blue)

In [38]:
def span_word (word, score, colormap):
    """
    Wrap a word in an HTML <span> whose background color encodes its score.

    Args:
    - word:     the text to display (inserted into the markup as is)
    - score:    value passed to colormap to obtain the color
    - colormap: callable mapping a score to a color tuple understood by getRGB
    """
    background = getRGB(colormap(score))
    return "".join(["<span style=\"background-color:", background, "\">", word, "</span>"])

In [39]:
def html_heatmap (words, scores, cmap_name="bwr"):
    """
    Render a sequence of words as an HTML heatmap string.

    Each word is wrapped in a <span> whose background color is obtained by rescaling
    its score with rescale_score_by_abs (relative to the min/max over all scores)
    and looking the result up in the matplotlib colormap cmap_name.

    Args:
    - words:     sequence of strings
    - scores:    sequence of numbers, one per word
    - cmap_name: name of a matplotlib colormap (default "bwr")

    Returns:
    - a string made of the colored spans, each followed by a single space, terminated
      by a newline; for empty input this is just "\n"
    """
    assert len(words)==len(scores)

    # Empty input: nothing to color, and max()/min() below would raise on an empty sequence.
    if len(words) == 0:
        return "\n"

    colormap = plt.get_cmap(cmap_name)
    max_s    = max(scores)
    min_s    = min(scores)

    # Build all spans first and join once instead of growing a string inside the loop.
    spans = [span_word(w, rescale_score_by_abs(s, max_s, min_s), colormap) + " "
             for w, s in zip(words, scores)]

    return "".join(spans) + "\n"

In [40]:
def lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor=0.0, debug=False):
    """
    Redistribute the relevance of a linear layer's outputs onto its inputs (LRP with eps stabilizer).

    D denotes the number of inputs and M the number of outputs of the layer.

    Args:
    - hin:           layer input, indexed as a 1-D array of length D
    - w:             connection weights, broadcast against hin[:,na], i.e. shape (D, M)
    - b:             bias, indexed as a 1-D array of length M
    - hout:          layer output, indexed as a 1-D array of length M
    - Rout:          relevance of the layer output, indexed as a 1-D array of length M
    - bias_nb_units: number of units among which the bias/stabilizer share is divided
    - eps:           stabilizer magnitude; it is added to the denominator with the sign of hout
    - bias_factor:   multiplies the bias/stabilizer share in the numerator (default 0.0)
    - debug:         if True, print the difference between total output and total input relevance

    Returns:
    - Rin: relevance of the layer input, shape (D,)
    """
    out_sign   = np.where(hout[na,:]>=0, 1., -1.)            # shape (1, M)
    stabilizer = eps*out_sign*1.                             # shape (1, M)

    # Note: bias_factor multiplies both the bias b and the stabilizer eps, because including
    # (b + eps*sign) / bias_nb_units in the numerator is only useful for the conservation sanity check
    # (the initial paper version used (bias_factor*b + eps*sign) / bias_nb_units instead).
    bias_share    = bias_factor * (b[na,:]*1. + stabilizer) / bias_nb_units
    contributions = (w * hin[:,na]) + bias_share             # shape (D, M)

    stabilized_out = hout[na,:] + stabilizer                 # shape (1, M)

    Rin = ((contributions/stabilized_out) * Rout[na,:]).sum(axis=1)   # shape (D,)

    # Note:
    # - local  layer   relevance conservation if bias_factor==1.0 and bias_nb_units==D (i.e. when only one incoming layer)
    # - global network relevance conservation if bias_factor==1.0 and bias_nb_units set accordingly to the total number of lower-layer connections
    # -> can be used for sanity check
    if debug:
        print("local diff: ", Rout.sum() - Rin.sum())

    return Rin

In [41]:
class LSTM_bidi:
    """
    Bidirectional LSTM classifier implemented with numpy.

    The "Left" encoder reads the input sequence in its given order, the "Right" encoder
    reads the reversed sequence; the class scores are the sum of two linear read-outs of
    the final hidden states. Besides the forward pass, the class provides gradient
    backpropagation (backward) and Layer-wise Relevance Propagation (lrp).

    The gate ordering in the LSTM weights is assumed to be i, g, f, o.
    """

    def __init__(self, vocab_path="vocab", embeddings_path="embeddings.npy", model_path="model"):
        """
        Load vocabulary, word embeddings and model weights from disk.

        Args:
        - vocab_path:      pickled vocabulary (default "vocab")
        - embeddings_path: .npy file with the embedding matrix (default "embeddings.npy")
        - model_path:      pickled dict holding the model weights (default "model")

        NOTE(security): pickle.load can execute arbitrary code while unpickling;
        only point these paths at files from a trusted source.
        """
        # vocabulary
        with open(vocab_path, 'rb') as f_voc:
            self.voc  = pickle.load(f_voc)

        # word embeddings (memory-mapped, read-only)
        self.E    = np.load(embeddings_path, mmap_mode='r') # shape V*e

        # model weights
        with open(model_path, 'rb') as f_model:
            model = pickle.load(f_model)

        # LSTM left encoder
        self.Wxh_Left  = model["Wxh_Left"]  # shape 4d*e
        self.bxh_Left  = model["bxh_Left"]  # shape 4d
        self.Whh_Left  = model["Whh_Left"]  # shape 4d*d
        self.bhh_Left  = model["bhh_Left"]  # shape 4d
        # LSTM right encoder
        self.Wxh_Right = model["Wxh_Right"]
        self.bxh_Right = model["bxh_Right"]
        self.Whh_Right = model["Whh_Right"]
        self.bhh_Right = model["bhh_Right"]
        # linear output layer
        self.Why_Left  = model["Why_Left"]  # shape C*d
        self.Why_Right = model["Why_Right"] # shape C*d


    def _hidden_dim(self):
        """Return the hidden layer dimension d (the gate weights stack 4 gates of size d)."""
        return self.Wxh_Left.shape[0] // 4


    @staticmethod
    def _gate_indices(d):
        """
        Return the row indices of the gates inside the stacked gate vectors of length 4*d.

        Returns (idx, idx_i, idx_g, idx_f, idx_o), where idx holds the indices of the
        gates i, f, o together and the others the indices of each gate separately.
        """
        idx_i, idx_g, idx_f, idx_o = np.arange(0,d), np.arange(d,2*d), np.arange(2*d,3*d), np.arange(3*d,4*d)
        idx = np.hstack((idx_i, idx_f, idx_o)).astype(int)
        return idx, idx_i, idx_g, idx_f, idx_o


    def set_input(self, w, delete_pos=None):
        """
        Set the input sequence and reset the hidden and cell states.

        Args:
        - w:          sequence of word IDs (row indices into the embedding matrix)
        - delete_pos: optional position(s) in the sequence whose embedding is replaced
                      by zeros (anything usable as a numpy row index)
        """
        T      = len(w)                         # sequence length
        d      = self._hidden_dim()             # hidden layer dimension
        e      = self.E.shape[1]                # word embedding dimension
        x      = np.zeros((T, e))
        x[:,:] = self.E[w,:]
        if delete_pos is not None:
            x[delete_pos, :] = 0.0

        self.w              = w
        self.x              = x
        self.x_rev          = x[::-1,:].copy()

        # T+1 rows: the extra last row (index -1) stays zero and acts as the initial state at t=0
        self.h_Left         = np.zeros((T+1, d))
        self.c_Left         = np.zeros((T+1, d))
        self.h_Right        = np.zeros((T+1, d))
        self.c_Right        = np.zeros((T+1, d))


    def forward(self):
        """
        Run the forward pass on the sequence given to set_input.

        Stores all intermediate gate values and states on the instance and returns a copy
        of the prediction scores self.s.
        """
        T      = len(self.w)
        d      = self._hidden_dim()
        # gate indices (assuming the gate ordering in the LSTM weights is i,g,f,o):
        idx, idx_i, idx_g, idx_f, idx_o = self._gate_indices(d)

        # initialize
        self.gates_xh_Left  = np.zeros((T, 4*d))
        self.gates_hh_Left  = np.zeros((T, 4*d))
        self.gates_pre_Left = np.zeros((T, 4*d))  # gates pre-activation
        self.gates_Left     = np.zeros((T, 4*d))  # gates activation

        self.gates_xh_Right = np.zeros((T, 4*d))
        self.gates_hh_Right = np.zeros((T, 4*d))
        self.gates_pre_Right= np.zeros((T, 4*d))
        self.gates_Right    = np.zeros((T, 4*d))

        for t in range(T):
            # at t=0 the index t-1 == -1 selects the extra zero row, i.e. the initial state
            self.gates_xh_Left[t]     = np.dot(self.Wxh_Left, self.x[t])
            self.gates_hh_Left[t]     = np.dot(self.Whh_Left, self.h_Left[t-1])
            self.gates_pre_Left[t]    = self.gates_xh_Left[t] + self.gates_hh_Left[t] + self.bxh_Left + self.bhh_Left
            self.gates_Left[t,idx]    = 1.0/(1.0 + np.exp(- self.gates_pre_Left[t,idx]))
            self.gates_Left[t,idx_g]  = np.tanh(self.gates_pre_Left[t,idx_g])
            self.c_Left[t]            = self.gates_Left[t,idx_f]*self.c_Left[t-1] + self.gates_Left[t,idx_i]*self.gates_Left[t,idx_g]
            self.h_Left[t]            = self.gates_Left[t,idx_o]*np.tanh(self.c_Left[t])

            self.gates_xh_Right[t]    = np.dot(self.Wxh_Right, self.x_rev[t])
            self.gates_hh_Right[t]    = np.dot(self.Whh_Right, self.h_Right[t-1])
            self.gates_pre_Right[t]   = self.gates_xh_Right[t] + self.gates_hh_Right[t] + self.bxh_Right + self.bhh_Right
            self.gates_Right[t,idx]   = 1.0/(1.0 + np.exp(- self.gates_pre_Right[t,idx]))
            self.gates_Right[t,idx_g] = np.tanh(self.gates_pre_Right[t,idx_g])
            self.c_Right[t]           = self.gates_Right[t,idx_f]*self.c_Right[t-1] + self.gates_Right[t,idx_i]*self.gates_Right[t,idx_g]
            self.h_Right[t]           = self.gates_Right[t,idx_o]*np.tanh(self.c_Right[t])

        self.y_Left  = np.dot(self.Why_Left,  self.h_Left[T-1])
        self.y_Right = np.dot(self.Why_Right, self.h_Right[T-1])
        self.s       = self.y_Left + self.y_Right

        return self.s.copy() # prediction scores


    def backward(self, w, sensitivity_class):
        """
        Backpropagate the gradient of one class score to the word embeddings.

        Args:
        - w:                 sequence of word IDs
        - sensitivity_class: index of the class score whose gradient is computed

        Returns:
        - (dx, dx_rev): gradients w.r.t. the input embeddings through the Left and through
          the Right encoder, each of the same shape as self.x and both in the original
          word order
        """
        # forward pass
        self.set_input(w)
        self.forward()

        T      = len(self.w)
        d      = self._hidden_dim()
        C      = self.Why_Left.shape[0]   # number of classes
        idx, idx_i, idx_g, idx_f, idx_o = self._gate_indices(d)

        # initialize
        self.dx               = np.zeros(self.x.shape)
        self.dx_rev           = np.zeros(self.x.shape)

        self.dh_Left          = np.zeros((T+1, d))
        self.dc_Left          = np.zeros((T+1, d))
        self.dgates_pre_Left  = np.zeros((T, 4*d))  # gates pre-activation
        self.dgates_Left      = np.zeros((T, 4*d))  # gates activation

        self.dh_Right         = np.zeros((T+1, d))
        self.dc_Right         = np.zeros((T+1, d))
        self.dgates_pre_Right = np.zeros((T, 4*d))
        self.dgates_Right     = np.zeros((T, 4*d))

        # one-hot gradient at the output: only the selected class score is differentiated
        ds                    = np.zeros((C))
        ds[sensitivity_class] = 1.0
        dy_Left               = ds.copy()
        dy_Right              = ds.copy()

        self.dh_Left[T-1]     = np.dot(self.Why_Left.T,  dy_Left)
        self.dh_Right[T-1]    = np.dot(self.Why_Right.T, dy_Right)

        for t in reversed(range(T)):
            self.dgates_Left[t,idx_o]    = self.dh_Left[t] * np.tanh(self.c_Left[t])  # do[t]
            self.dc_Left[t]             += self.dh_Left[t] * self.gates_Left[t,idx_o] * (1.-(np.tanh(self.c_Left[t]))**2) # dc[t]
            self.dgates_Left[t,idx_f]    = self.dc_Left[t] * self.c_Left[t-1]         # df[t]
            self.dc_Left[t-1]            = self.dc_Left[t] * self.gates_Left[t,idx_f] # dc[t-1]
            self.dgates_Left[t,idx_i]    = self.dc_Left[t] * self.gates_Left[t,idx_g] # di[t]
            self.dgates_Left[t,idx_g]    = self.dc_Left[t] * self.gates_Left[t,idx_i] # dg[t]
            self.dgates_pre_Left[t,idx]  = self.dgates_Left[t,idx] * self.gates_Left[t,idx] * (1.0 - self.gates_Left[t,idx]) # d ifo pre[t]
            self.dgates_pre_Left[t,idx_g]= self.dgates_Left[t,idx_g] *  (1.-(self.gates_Left[t,idx_g])**2) # d g pre[t]
            self.dh_Left[t-1]            = np.dot(self.Whh_Left.T, self.dgates_pre_Left[t])
            self.dx[t]                   = np.dot(self.Wxh_Left.T, self.dgates_pre_Left[t])

            self.dgates_Right[t,idx_o]    = self.dh_Right[t] * np.tanh(self.c_Right[t])
            self.dc_Right[t]             += self.dh_Right[t] * self.gates_Right[t,idx_o] * (1.-(np.tanh(self.c_Right[t]))**2)
            self.dgates_Right[t,idx_f]    = self.dc_Right[t] * self.c_Right[t-1]
            self.dc_Right[t-1]            = self.dc_Right[t] * self.gates_Right[t,idx_f]
            self.dgates_Right[t,idx_i]    = self.dc_Right[t] * self.gates_Right[t,idx_g]
            self.dgates_Right[t,idx_g]    = self.dc_Right[t] * self.gates_Right[t,idx_i]
            self.dgates_pre_Right[t,idx]  = self.dgates_Right[t,idx] * self.gates_Right[t,idx] * (1.0 - self.gates_Right[t,idx])
            self.dgates_pre_Right[t,idx_g]= self.dgates_Right[t,idx_g] *  (1.-(self.gates_Right[t,idx_g])**2)
            self.dh_Right[t-1]            = np.dot(self.Whh_Right.T, self.dgates_pre_Right[t])
            self.dx_rev[t]                = np.dot(self.Wxh_Right.T, self.dgates_pre_Right[t])

        # dx_rev is indexed by reversed position; flip it back to the original word order
        return self.dx.copy(), self.dx_rev[::-1,:].copy()


    def lrp(self, w, LRP_class, eps=0.001, bias_factor=0.0):
        """
        Layer-wise Relevance Propagation of one class score down to the word embeddings.

        Args:
        - w:           sequence of word IDs
        - LRP_class:   index of the class score whose relevance is propagated
        - eps:         stabilizer passed to lrp_linear (default 0.001)
        - bias_factor: bias factor passed to lrp_linear (default 0.0)

        Returns:
        - Rx:     relevance of the input embeddings through the Left encoder, same shape as self.x
        - Rx_rev: relevance of the input embeddings through the Right encoder, in the original word order
        - the summed relevance that ended up in the initial hidden and cell states of both
          encoders (a scalar)
        """
        # forward pass
        self.set_input(w)
        self.forward()

        T      = len(self.w)
        d      = self._hidden_dim()
        e      = self.E.shape[1]
        C      = self.Why_Left.shape[0]  # number of classes
        idx, idx_i, idx_g, idx_f, idx_o = self._gate_indices(d)

        # initialize
        Rx       = np.zeros(self.x.shape)
        Rx_rev   = np.zeros(self.x.shape)

        Rh_Left  = np.zeros((T+1, d))
        Rc_Left  = np.zeros((T+1, d))
        Rg_Left  = np.zeros((T,   d)) # gate g only
        Rh_Right = np.zeros((T+1, d))
        Rc_Right = np.zeros((T+1, d))
        Rg_Right = np.zeros((T,   d)) # gate g only

        # only the score of the selected class is propagated
        Rout_mask            = np.zeros((C))
        Rout_mask[LRP_class] = 1.0

        # format reminder: lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor)
        Rh_Left[T-1]  = lrp_linear(self.h_Left[T-1],  self.Why_Left.T , np.zeros((C)), self.s, self.s*Rout_mask, 2*d, eps, bias_factor, debug=False)
        Rh_Right[T-1] = lrp_linear(self.h_Right[T-1], self.Why_Right.T, np.zeros((C)), self.s, self.s*Rout_mask, 2*d, eps, bias_factor, debug=False)

        for t in reversed(range(T)):
            # at t=0 the index t-1 == -1 selects the extra last row, which collects the
            # relevance reaching the initial states
            Rc_Left[t]   += Rh_Left[t]
            Rc_Left[t-1]  = lrp_linear(self.gates_Left[t,idx_f]*self.c_Left[t-1],         np.identity(d), np.zeros((d)), self.c_Left[t], Rc_Left[t], 2*d, eps, bias_factor, debug=False)
            Rg_Left[t]    = lrp_linear(self.gates_Left[t,idx_i]*self.gates_Left[t,idx_g], np.identity(d), np.zeros((d)), self.c_Left[t], Rc_Left[t], 2*d, eps, bias_factor, debug=False)
            Rx[t]         = lrp_linear(self.x[t],        self.Wxh_Left[idx_g].T, self.bxh_Left[idx_g]+self.bhh_Left[idx_g], self.gates_pre_Left[t,idx_g], Rg_Left[t], d+e, eps, bias_factor, debug=False)
            Rh_Left[t-1]  = lrp_linear(self.h_Left[t-1], self.Whh_Left[idx_g].T, self.bxh_Left[idx_g]+self.bhh_Left[idx_g], self.gates_pre_Left[t,idx_g], Rg_Left[t], d+e, eps, bias_factor, debug=False)

            Rc_Right[t]  += Rh_Right[t]
            Rc_Right[t-1] = lrp_linear(self.gates_Right[t,idx_f]*self.c_Right[t-1],         np.identity(d), np.zeros((d)), self.c_Right[t], Rc_Right[t], 2*d, eps, bias_factor, debug=False)
            Rg_Right[t]   = lrp_linear(self.gates_Right[t,idx_i]*self.gates_Right[t,idx_g], np.identity(d), np.zeros((d)), self.c_Right[t], Rc_Right[t], 2*d, eps, bias_factor, debug=False)
            Rx_rev[t]     = lrp_linear(self.x_rev[t],     self.Wxh_Right[idx_g].T, self.bxh_Right[idx_g]+self.bhh_Right[idx_g], self.gates_pre_Right[t,idx_g], Rg_Right[t], d+e, eps, bias_factor, debug=False)
            Rh_Right[t-1] = lrp_linear(self.h_Right[t-1], self.Whh_Right[idx_g].T, self.bxh_Right[idx_g]+self.bhh_Right[idx_g], self.gates_pre_Right[t,idx_g], Rg_Right[t], d+e, eps, bias_factor, debug=False)

        return Rx, Rx_rev[::-1,:], Rh_Left[-1].sum()+Rc_Left[-1].sum()+Rh_Right[-1].sum()+Rc_Right[-1].sum()

In [42]:
def predict(words):
    """
    Classify a tokenized sentence with the trained bidirectional LSTM.

    Args:
    - words: sequence of tokens; each one is looked up with net.voc.index, so every
             token must be contained in the vocabulary

    Returns:
    - index of the class with the highest prediction score
    """
    model    = LSTM_bidi()                          # loads the trained model from disk on every call
    word_ids = list(map(model.voc.index, words))    # tokens -> word IDs
    model.set_input(word_ids)
    return np.argmax(model.forward())

In [43]:
# Input sentence to classify and explain, given as a list of tokens.
# To try another example, comment out the active line and uncomment one of the alternatives below,
# or define your own sequence (only words contained in the vocabulary are supported!)
#words = ['this','movie','was','actually','neither','that','funny',',','nor','super','witty','.']
#words = ['this', 'film', 'does', 'n\'t', 'care', 'about', 'cleverness', ',', 'wit', 'or', 'any', 'other', 'kind', 'of', 'intelligent', 'humor', '.']
words = ['i','hate','the','movie','though','the','plot','is','interesting','.']
#words = ['used', 'to', 'be', 'my', 'favorite']
#words = ['not', 'worth', 'the', 'time']
#words = ['is', 'n\'t', 'a', 'bad', 'film'] # Note: misclassified sample!
#words = ['is', 'n\'t', 'very', 'interesting'] 
#words = ['it', '\'s', 'easy' ,'to' ,'love' ,'robin' ,'tunney' ,'--' ,'she' ,'\'s' ,'pretty' ,'and' ,'she' ,'can' ,'act' ,'--' ,'but' ,'it' ,'gets' ,'harder' ,'and' ,'harder' ,'to' ,'understand' ,'her' ,'choices', '.']

In [44]:
# Classify the sentence, then use the predicted class as the class to be explained below.
predicted_class = predict(words)                                        # get predicted class
target_class    = predicted_class                                       # define relevance target class 

In [45]:
# Show the input tokens and the index of the predicted class.
print (words)
print ("\npredicted class:          ",   predicted_class)

['i', 'hate', 'the', 'movie', 'though', 'the', 'plot', 'is', 'interesting', '.']

predicted class:           0


In [46]:
# Hyperparameters of the LRP backward pass
eps         = 0.001    # stabilizer: small positive number
bias_factor = 0.0      # recommended value

# Load the trained LSTM model
net = LSTM_bidi()

# Map the input sentence to word IDs and propagate the relevance of the target class
w_indices          = list(map(net.voc.index, words))
Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps=eps, bias_factor=bias_factor)

# Word-level relevance: sum over the embedding dimensions of both encoder directions
R_words = (Rx + Rx_rev).sum(axis=1)

# Classification prediction scores of the forward pass performed inside lrp
scores = net.s.copy()

In [47]:
# Report the prediction scores and the word-level LRP relevances, as text and as a heatmap.
print("prediction scores:        ", scores)
print("\nLRP target class:         ", target_class)

print("\nLRP relevances:")
for w, relevance in zip(words, R_words):
    print("\t\t\t" + "{:8.2f}".format(relevance) + "\t" + w)

print("\nLRP heatmap:")
display(HTML(html_heatmap(words, R_words)))

prediction scores:         [ 0.83815775  0.55615975 -0.23874953  0.06664621 -0.48527585]

LRP target class:          0

LRP relevances:
			   -0.07	i
			    2.78	hate
			   -0.06	the
			    0.23	movie
			    0.75	though
			   -0.01	the
			    2.10	plot
			   -1.16	is
			   -6.09	interesting
			   -0.51	.

LRP heatmap:


In [48]:
# Sanity check of global relevance conservation: with bias_factor = 1.0 the relevances of all
# "inputs" (both encoder directions plus the remainder R_rest) should sum up to the target class score.
bias_factor = 1.0    # value to use for sanity check

Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps=eps, bias_factor=bias_factor)
R_tot = Rx.sum() + Rx_rev.sum() + R_rest.sum()    # sum of all "input" relevances

print(R_tot)
print("Sanity check passed? ", np.allclose(R_tot, net.s[target_class]))

0.8381577526347607
Sanity check passed?  True


In [49]:
# Load the trained LSTM model
net = LSTM_bidi()

# Map the input sentence to word IDs and backpropagate the gradient of the target class score
w_indices  = list(map(net.voc.index, words))
Gx, Gx_rev = net.backward(w_indices, target_class)

# Word-level Sensitivity Analysis relevance: squared L2 norm of the gradient
R_words_SA = np.square(np.linalg.norm(Gx + Gx_rev, ord=2, axis=1))
# Word-level GradientxInput relevance: gradient times embedding, summed over the embedding dimensions
R_words_GI = np.sum((Gx + Gx_rev)*net.x, axis=1)

# Classification prediction scores of the forward pass performed inside backward
scores = net.s.copy()

In [50]:
# Report the prediction scores, then the word-level relevances and heatmap of each gradient-based method.
print("prediction scores:       ", scores)
print("\nSA/GI target class:      ", target_class)

reports = (
    ("\nSA relevances:", "\nSA heatmap:", R_words_SA),
    ("\nGI relevances:", "\nGI heatmap:", R_words_GI),
)
for relevances_title, heatmap_title, relevances in reports:
    print(relevances_title)
    for w, relevance in zip(words, relevances):
        print("\t\t\t" + "{:8.2f}".format(relevance) + "\t" + w)
    print(heatmap_title)
    display(HTML(html_heatmap(words, relevances)))

prediction scores:        [ 0.83815775  0.55615975 -0.23874953  0.06664621 -0.48527585]

SA/GI target class:       0

SA relevances:
			    0.92	i
			   50.03	hate
			    0.57	the
			    6.77	movie
			   14.29	though
			    0.60	the
			   23.59	plot
			   11.33	is
			   35.48	interesting
			    1.80	.

SA heatmap:



GI relevances:
			   -0.05	i
			    2.09	hate
			    0.47	the
			   -0.12	movie
			    0.50	though
			    0.58	the
			    0.98	plot
			    0.54	is
			   -2.68	interesting
			    0.43	.

GI heatmap:


In [51]:
# using Recursive Neural Tensor Network
# prediction scores:  [ 8.90618621e-01  1.01122260e+00  2.26404237e-04 -3.12216870e-01 -1.37174135e+00]