🚨 **NOTE** 🚨: to run the notebooks, move them to the main directory. Simply run:

```bash
cp notebook_name.ipynb ../
```

Once the data is prepared (see Notebook 01), let's have a look at the HAN implementation.

Let's start by bringing back the figure of the network architecture and the mathematical expressions for the attention mechanism used by the authors.

<p align="center">
  <img width="400" src="figures/HAN_arch.png">
</p>

**Word attention:**

$
u_{it} = \text{tanh}(W_wh_{it} + b_w)
$

$
\alpha_{it} = \frac{\exp(u_{it}u_w^{\mathsf{T}})}{\sum_{t}\exp(u_{it}u_w^{\mathsf{T}})}
$

$
s_i = \sum_{t}\alpha_{it}h_{it}
$

Here $u_{it}$ can be seen as a hidden representation of $h_{it}$ (the GRU output). The importance of a word is measured as the similarity of $u_{it}$ with a context vector $u_{w}$, which is then normalized through a softmax function, resulting in $\alpha_{it}$, the so-called normalized importance weights. The sentence vector $s_i$ is the weighted sum of the word annotations $h_{it}$, using the weights $\alpha_{it}$. For more details, please have a look at the paper [Zichao Yang et al., 2016](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf).
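To make the three expressions concrete, here is a minimal sketch that evaluates them with raw tensors for a single sentence. The sizes and the randomly initialised `W_w`, `b_w` and `u_w` are assumptions made purely for illustration; in the model these are learned parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, hidden = 4, 6             # hypothetical sizes: 4 words, 6-dim GRU outputs
h = torch.randn(seq_len, hidden)   # h_{it}: one annotation per word

W_w = torch.randn(hidden, hidden)  # projection matrix W_w (random here, learned in practice)
b_w = torch.zeros(hidden)          # bias b_w
u_w = torch.randn(hidden)          # word-level context vector u_w

u = torch.tanh(h @ W_w.T + b_w)      # u_{it}, shape (seq_len, hidden)
alpha = F.softmax(u @ u_w, dim=0)    # alpha_{it}, shape (seq_len,), one weight per word
s = (alpha.unsqueeze(1) * h).sum(0)  # s_i, shape (hidden,), weighted sum of annotations

print(alpha.sum())  # the importance weights sum to 1
print(s.shape)
```

Since the weights are non-negative and sum to 1, $s_i$ is a weighted average of the word annotations.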

The same reasoning applies to the sentence attention mechanism, this time over the sentence annotations $h_i$ instead of the word annotations $h_{it}$.

**Sentence attention:**

$
u_{i} = \text{tanh}(W_sh_i + b_s)
$

$
\alpha_{i} = \frac{\exp(u_iu_s^{\mathsf{T}})}{\sum_{i}\exp(u_{i}u_s^{\mathsf{T}})}
$

$
v = \sum_{i}\alpha_{i}h_{i}
$

In a simplified way, the flow of the data is: Word Embeddings $\rightarrow$ GRU $\rightarrow$ Word Attention $\rightarrow$ GRU $\rightarrow$ Sentence Attention $\rightarrow$ FC + Softmax

Let's build the pieces one by one, starting with the attention mechanism. I have kept the names of the variables as close as possible to the notation of the paper.  

In [1]:
import numpy as np
import torch
import torch.nn.functional as F

from torch import nn

In [31]:
class AttentionWithContext(nn.Module):
    def __init__(self, hidden_dim):
        super(AttentionWithContext, self).__init__()

        self.attn = nn.Linear(hidden_dim, hidden_dim)
        self.contx = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, inp):
        # The first expression in the attention mechanism is simply a linear layer that receives 
        # the output of the Word-GRU referred here as 'inp' and h_{it} in the paper
        u = torch.tanh_(self.attn(inp))
        # The second expression is...the same but without bias, wrapped up in a Softmax function
        a = F.softmax(self.contx(u), dim=1)
        # And finally, an element-wise multiplication taking advantage of PyTorch's broadcasting abilities
        s = (a * inp).sum(1)
        # we will also return the normalized importance weights
        return a.permute(0, 2, 1), s

Note that one could easily implement Attention without the context vector $u_w$ as:

In [2]:
class Attention(nn.Module):
    def __init__(self, hidden_dim, seq_len):
        super(Attention, self).__init__()

        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.weight = nn.Parameter(nn.init.kaiming_normal_(torch.Tensor(hidden_dim, 1)))
        self.bias = nn.Parameter(torch.zeros(seq_len))

    def forward(self, inp):
        # 1. Matrix Multiplication
        x = inp.contiguous().view(-1, self.hidden_dim)
        u = torch.tanh_(torch.mm(x, self.weight).view(-1, self.seq_len) + self.bias)
        # 2. Softmax on 'u_{it}' directly
        a = F.softmax(u, dim=1)
        # 3. Broadcasting and out
        s = (inp * torch.unsqueeze(a, 2)).sum(1)
        return a, s

Ok, so attention takes just a few lines of code. Let's now use these modules to implement the **Word Attention Net** and the **Sentence Attention Net**.

Word Attention Net looks like this:

In [32]:
class WordAttnNet(nn.Module):
    def __init__(
        self,
        vocab_size,
        hidden_dim=32,
        padding_idx=1,
        embed_dim=50,
        embedding_matrix=None,
    ):
        super(WordAttnNet, self).__init__()

        if isinstance(embedding_matrix, np.ndarray):
            self.word_embed = nn.Embedding(
                vocab_size, embedding_matrix.shape[1], padding_idx=padding_idx
            )
            self.word_embed.weight = nn.Parameter(torch.Tensor(embedding_matrix))
            embed_dim = embedding_matrix.shape[1]
        else:
            self.word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

        self.rnn = nn.GRU(embed_dim, hidden_dim, bidirectional=True, batch_first=True)

        self.word_attn = AttentionWithContext(hidden_dim * 2)

    def forward(self, X, h_n):
        embed = self.word_embed(X.long())
        h_t, h_n = self.rnn(embed, h_n)
        a, s = self.word_attn(h_t)
        return a, s.unsqueeze(1), h_n

Note that the `WordAttnNet` class in the `models` module has a few more bells and whistles that I will discuss in a separate notebook. However, the main parts are all contained in the cell above, so let's comment on them:

1. Word Embeddings: we allow the user to pass some pre-trained word embeddings. If none are passed, we simply initialize them with the PyTorch defaults (random).
2. Word GRU: `batch_first=True`. I am a maniac and I want my batches first.
3. Word Attention (with context)
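As a side note on the first point, PyTorch also offers `nn.Embedding.from_pretrained` for loading a pre-trained matrix. The sketch below is an alternative to the assignment used in `WordAttnNet`, not what the class above does; the small random `embedding_matrix` stands in for real pre-trained vectors and is an assumption for illustration.

```python
import numpy as np
import torch
from torch import nn

vocab_size, embed_dim, padding_idx = 10, 4, 1

# Stand-in for a real pre-trained matrix (e.g. loaded from disk)
embedding_matrix = np.random.rand(vocab_size, embed_dim).astype("float32")

# freeze=False keeps the embeddings trainable, as in WordAttnNet
word_embed = nn.Embedding.from_pretrained(
    torch.from_numpy(embedding_matrix),
    freeze=False,
    padding_idx=padding_idx,
)

tokens = torch.tensor([[0, 2, 1]])  # (bsz=1, seq_len=3)
print(word_embed(tokens).shape)     # (bsz, seq_len, embed_dim)
```

With `freeze=True` (the default of `from_pretrained`) the embeddings would stay fixed during training.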

Let's manually run the forward pass step by step to better understand what is going on:

In [11]:
bsz = 16
maxlen_sent = 20 
hidden_dim  = 32    
embed_dim   = 100    
vocab_size  = 1000
padding_idx = 1

# net
word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
rnn = nn.GRU(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
attn = nn.Linear(hidden_dim*2, hidden_dim*2)
contx = nn.Linear(hidden_dim*2, 1, bias=False)

# inputs
X = torch.from_numpy(np.random.choice(vocab_size, (bsz, maxlen_sent)))
h_n = torch.zeros((2, bsz, hidden_dim))

`X` holds, for each review in the batch, one sentence at a time. This is because `WordAttnNet` will run in a loop over sentences, so that we can initialize the hidden state of the next sentence with the last hidden state of the previous one. We could say that the model is *"sentence stateful"*.
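To illustrate what "sentence stateful" means in practice, here is a small sketch with a standalone GRU. The sizes and the random inputs are assumptions for illustration. It encodes a second sentence twice: once starting from the final state of the first sentence, and once starting from zeros.

```python
import torch
from torch import nn

torch.manual_seed(0)

bsz, maxlen_sent, embed_dim, hidden_dim = 2, 5, 8, 4  # hypothetical small sizes
rnn = nn.GRU(embed_dim, hidden_dim, bidirectional=True, batch_first=True)

sent_1 = torch.randn(bsz, maxlen_sent, embed_dim)  # embedded first sentence
sent_2 = torch.randn(bsz, maxlen_sent, embed_dim)  # embedded second sentence
h_0 = torch.zeros(2, bsz, hidden_dim)              # (num_directions, bsz, hidden_dim)

# Stateful: the second sentence starts from the last state of the first one
_, h_n = rnn(sent_1, h_0)
out_stateful, _ = rnn(sent_2, h_n)

# Stateless: the second sentence starts from zeros
out_stateless, _ = rnn(sent_2, h_0)

print(torch.allclose(out_stateful, out_stateless))
```

The two encodings of the same sentence differ, which is exactly the information that the loop over sentences carries forward.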

In [7]:
# 1. Word Embeddings
# (bsz, maxlen_sent, embed_dim)
embed = word_embed(X)
embed.shape

torch.Size([16, 20, 100])

The first sentence of each review in the batch is now represented by a sequence of 20 words with 100-dimensional embeddings. Now we pass these embeddings to the GRU, which will return a new encoding of dim `hidden_dim * 2`, since the GRU is bidirectional.

In [8]:
# 2. GRU
h_t, h_n = rnn(embed, h_n)
# (bsz, seq_len, hidden_dim*2)
h_t.shape

torch.Size([16, 20, 64])
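For readers less familiar with bidirectional RNNs in PyTorch, the sketch below shows how `h_t` and `h_n` relate. It uses a fresh GRU with random inputs, which is an assumption for illustration. The first `hidden_dim` features of `h_t` come from the forward direction and the rest from the backward direction.

```python
import torch
from torch import nn

torch.manual_seed(0)

bsz, seq_len, embed_dim, hidden_dim = 16, 20, 100, 32
rnn = nn.GRU(embed_dim, hidden_dim, bidirectional=True, batch_first=True)

embed = torch.randn(bsz, seq_len, embed_dim)  # stand-in for the word embeddings
h_t, h_n = rnn(embed)                         # zero initial state by default

# h_t: (bsz, seq_len, 2 * hidden_dim), h_n: (2, bsz, hidden_dim)
print(h_t.shape, h_n.shape)

# Final forward state sits at the last time step, first half of the features
fwd_last = h_t[:, -1, :hidden_dim]
# Final backward state sits at the first time step, second half of the features
bwd_last = h_t[:, 0, hidden_dim:]

print(torch.allclose(h_n[0], fwd_last, atol=1e-6))
print(torch.allclose(h_n[1], bwd_last, atol=1e-6))
```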

The attention mechanism will *"collapse"* each sentence into a `hidden_dim * 2` tensor that is the result of a weighted average using the "importance weights" $\alpha_{it}$. Therefore, the resulting output will be of dim `(bsz, hidden_dim * 2)`.

In [14]:
# 3. Attention
u = torch.tanh_(attn(h_t))
a = F.softmax(contx(u), dim=1)
print(h_t.shape, a.shape)

torch.Size([16, 20, 64]) torch.Size([16, 20, 1])


And now we are going to multiply, broadcasting `a` along the last dim of `h_t`:

In [18]:
# RNN outputs scaled by their importance weights
s = (a * h_t)
print(s.shape)
# Sum along the seq dim so we end up with one representation per sentence
s = s.sum(1)
print(s.shape)
# Because this will be stacked for all sentences, we do the `.unsqueeze(1)`
print(s.unsqueeze(1).shape)

torch.Size([16, 20, 64])
torch.Size([16, 64])
torch.Size([16, 1, 64])
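The broadcast-multiply-and-sum above is the same operation as a batched matrix product between the weights and the GRU outputs. A quick sketch to check this, with random tensors standing in for `a` and `h_t` (an assumption for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, seq_len, dim = 16, 20, 64
h_t = torch.randn(bsz, seq_len, dim)                # stand-in for the GRU outputs
a = F.softmax(torch.randn(bsz, seq_len, 1), dim=1)  # stand-in for the importance weights

# Broadcasting version, as in the cell above
s_broadcast = (a * h_t).sum(1)

# Batched matrix product: (bsz, 1, seq_len) x (bsz, seq_len, dim) -> (bsz, 1, dim)
s_bmm = torch.bmm(a.permute(0, 2, 1), h_t).squeeze(1)

print(torch.allclose(s_broadcast, s_bmm, atol=1e-5))
```

Either form works; the broadcasting version keeps the code closer to the notation of the paper.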


Ok, now we have a 64-dimensional representation of the first sentence of each review. Let's have a look at an implementation of the Sentence Attention Net:

In [33]:
class SentAttnNet(nn.Module):
    def __init__(
        self, word_hidden_dim=32, sent_hidden_dim=32, padding_idx=1
    ):
        super(SentAttnNet, self).__init__()

        self.rnn = nn.GRU(
            word_hidden_dim * 2, sent_hidden_dim, bidirectional=True, batch_first=True
        )

        self.sent_attn = AttentionWithContext(sent_hidden_dim * 2)

    def forward(self, X):
        h_t, h_n = self.rnn(X)
        a, v = self.sent_attn(h_t)
        return a.permute(0,2,1), v

As you can see, it is rather easy. The same comment I made before for `WordAttnNet` applies here: the class in the `models` module has some additional functionality that I will discuss in a separate notebook, but for now this will be enough. I don't think this needs much explanation, does it?

The forward pass will receive a tensor of dim `(bsz, review_len, word_hidden_dim*2)` and it will return a tensor of dim `(bsz, sent_hidden_dim*2)`. 

Ok, so, we are ready to implement **Hierarchical Attention Networks** as:

In [34]:
class HierAttnNet(nn.Module):
    def __init__(
        self,
        vocab_size,
        maxlen_sent,
        maxlen_doc,
        word_hidden_dim=32,
        sent_hidden_dim=32,
        padding_idx=1,
        embed_dim=50,
        embedding_matrix=None,
        num_class=4,
    ):
        super(HierAttnNet, self).__init__()

        self.word_hidden_dim = word_hidden_dim

        self.wordattnnet = WordAttnNet(
            vocab_size=vocab_size,
            hidden_dim=word_hidden_dim,
            padding_idx=padding_idx,
            embed_dim=embed_dim,
            embedding_matrix=embedding_matrix,
        )

        self.sentattnnet = SentAttnNet(
            word_hidden_dim=word_hidden_dim,
            sent_hidden_dim=sent_hidden_dim,
            padding_idx=padding_idx,
        )

        self.fc = nn.Linear(sent_hidden_dim * 2, num_class)

    def forward(self, X):
        x = X.permute(1, 0, 2)
        # Initial Word RNN hidden state, created on the same device as the input
        word_h_n = torch.zeros(2, X.shape[0], self.word_hidden_dim, device=X.device)
        # alpha and s Tensor Lists
        word_a_list, word_s_list = [], []
        for sent in x:
            word_a, word_s, word_h_n = self.wordattnnet(sent, word_h_n)
            word_a_list.append(word_a)
            word_s_list.append(word_s)
        # Importance attention weights per word in sentence
        self.sent_a = torch.cat(word_a_list, 1)
        # Sentences representation
        sent_s = torch.cat(word_s_list, 1)
        # Importance attention weights per sentence in doc and document representation
        self.doc_a, doc_s = self.sentattnnet(sent_s)
        return self.fc(doc_s)

Let's again see what happens in the forward pass step by step: 

In [35]:
maxlen_sent = 20
maxlen_doc = 5
num_class = 4
word_hidden_dim = 32
sent_hidden_dim = 32

wordattnnet = WordAttnNet(vocab_size, hidden_dim, padding_idx, embed_dim, embedding_matrix=None)
sentattnnet = SentAttnNet(word_hidden_dim, sent_hidden_dim, padding_idx)
fc = nn.Linear(sent_hidden_dim * 2, num_class)

In [36]:
X = torch.from_numpy(np.random.choice(vocab_size, (bsz, maxlen_doc, maxlen_sent)))

We first permute the axes so that the input has the form `(maxlen_doc, bsz, maxlen_sent)`. This is because we are going to loop through the sentences of each document in the batch. Note that if you don't care about the stateful nature of the sentences in the document (although you should), you could just apply attention *"a la TimeDistributed"*. That is, reshape the input along the sequence dimension, apply `wordattnnet` to the resulting tensor and reshape back to the original form before applying `sentattnnet`. However, you should care about the fact that sentences naturally follow each other, so let's do it right.
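For reference, the stateless *"a la TimeDistributed"* alternative could be sketched as follows. This is not what this notebook does. It rebuilds the word-level pieces from raw layers with the same sizes as before, and every sentence starts from a zero hidden state.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

bsz, maxlen_doc, maxlen_sent = 16, 5, 20
vocab_size, embed_dim, hidden_dim = 1000, 100, 32

word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
rnn = nn.GRU(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
attn = nn.Linear(hidden_dim * 2, hidden_dim * 2)
contx = nn.Linear(hidden_dim * 2, 1, bias=False)

X = torch.randint(0, vocab_size, (bsz, maxlen_doc, maxlen_sent))

# Fold the document dimension into the batch: every sentence becomes a sample
flat = X.reshape(bsz * maxlen_doc, maxlen_sent)
h_t, _ = rnn(word_embed(flat))  # no state is passed between sentences
a = F.softmax(contx(torch.tanh(attn(h_t))), dim=1)
s = (a * h_t).sum(1)            # (bsz * maxlen_doc, hidden_dim * 2)

# Unfold back to one row of sentence vectors per document
sent_s = s.reshape(bsz, maxlen_doc, hidden_dim * 2)
sent_a = a.squeeze(2).reshape(bsz, maxlen_doc, maxlen_sent)

print(sent_a.shape, sent_s.shape)
```

This avoids the Python loop over sentences, at the cost of ignoring the order of the sentences at the word level. The rest of the notebook continues with the stateful loop.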

In [37]:
x = X.permute(1, 0, 2)
x.shape

torch.Size([5, 16, 20])

In [38]:
# Initial Word RNN hidden state
word_h_n = nn.init.zeros_(torch.Tensor(2, X.shape[0], word_hidden_dim))

In [41]:
# Loop through sentences:
word_a_list, word_s_list = [], []
for sent in x:
    word_a, word_s, word_h_n = wordattnnet(sent, word_h_n)
    word_a_list.append(word_a)
    word_s_list.append(word_s)
# Importance attention weights per word in sentence
sent_a = torch.cat(word_a_list, 1)
# Sentences representation
sent_s = torch.cat(word_s_list, 1)
# (bsz, maxlen_doc, maxlen_sent)
print(sent_a.shape)
# (bsz, maxlen_doc, hidden_dim*2)
print(sent_s.shape)

torch.Size([16, 5, 20])
torch.Size([16, 5, 64])


In [45]:
doc_a, doc_s = sentattnnet(sent_s)
# (bsz, maxlen_doc, 1). One could .squeeze(2)
print(doc_a.shape)
# (bsz, hidden_dim*2)
print(doc_s.shape)

torch.Size([16, 5, 1])
torch.Size([16, 64])


Finally, the predictions (without softmax, since the loss we will be using, `F.cross_entropy`, already applies log-softmax internally):

In [47]:
out = fc(doc_s)
out.shape

torch.Size([16, 4])
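As a quick check of why the softmax is left out of the model, the sketch below compares `F.cross_entropy` on raw logits with `F.nll_loss` applied to `F.log_softmax`. The random logits and labels are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

bsz, num_class = 16, 4
out = torch.randn(bsz, num_class)         # stand-in for the raw logits from `fc`
y = torch.randint(0, num_class, (bsz,))   # stand-in for the class labels

# cross_entropy takes raw logits...
loss_ce = F.cross_entropy(out, y)
# ...and matches log_softmax followed by the negative log-likelihood loss
loss_manual = F.nll_loss(F.log_softmax(out, dim=1), y)

print(torch.allclose(loss_ce, loss_manual))
```

Applying a softmax inside the model and then using `F.cross_entropy` would therefore normalize twice.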