In [None]:
!pip install tensorflow_text

Collecting tensorflow_text
[?25l  Downloading https://files.pythonhosted.org/packages/c0/ed/bbb51e9eccca0c2bfdf9df66e54cdff563b6f32daed9255da9b9a541368f/tensorflow_text-2.5.0-cp37-cp37m-manylinux1_x86_64.whl (4.3MB)
[K     |████████████████████████████████| 4.3MB 6.5MB/s 
Installing collected packages: tensorflow-text
Successfully installed tensorflow-text-2.5.0


In [None]:
import tensorflow_datasets as tfds
import tensorflow_text as text
import tensorflow as tf

In [None]:
import torch
import torch.nn as nn

# Define Model

## Self-Attention

In [None]:
class SelfAttention(nn.Module):
    """Multi-head attention over separately supplied value, key and query tensors.

    The embedding dimension is split evenly across ``heads``, giving each head a
    slice of size ``head_dims = embed_size // heads``. Values, keys and queries
    each have one bias-free ``head_dims -> head_dims`` linear projection that is
    shared by all heads; the concatenated head outputs are passed through the
    final ``fc_out`` linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dims = embed_size // heads

        # Must reference the attribute: a bare `head_dims` is not defined here.
        assert (self.head_dims * heads == embed_size), "embed_size needs to be divisible by heads"

        # Per-head projections, applied to the last (head_dims) axis in forward().
        self.values = nn.Linear(self.head_dims, self.head_dims, bias=False)
        self.queries = nn.Linear(self.head_dims, self.head_dims, bias=False)
        self.keys = nn.Linear(self.head_dims, self.head_dims, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, queries, mask):
        """Compute the attention output for every query position.

        Args:
            values: Tensor of shape ``(batch, value_len, embed_size)``.
            keys: Tensor of shape ``(batch, key_len, embed_size)``. ``key_len``
                must equal ``value_len`` because the attention weights over the
                keys are contracted against the values.
            queries: Tensor of shape ``(batch, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(batch, heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        batch_size = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding axis into (heads, head_dims).
        values = values.reshape(batch_size, value_len, self.heads, self.head_dims)
        keys = keys.reshape(batch_size, key_len, self.heads, self.head_dims)
        queries = queries.reshape(batch_size, query_len, self.heads, self.head_dims)

        # Apply the learned projections. nn.Linear acts on the last axis, so the
        # same head_dims x head_dims weights are used for every head. Without
        # this step the layers created in __init__ would never be trained.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = <queries[n, q, h, :], keys[n, k, h, :]>
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            # A very large negative score makes softmax assign ~0 weight there.
            energy = energy.masked_fill(mask == 0, float("-1e28"))

        # NOTE(review): scores are scaled by sqrt(embed_size); the Transformer
        # paper scales by sqrt(per-head key dim). Left unchanged — confirm intent.
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        # Weighted sum of values per head, then concatenate the heads back into
        # a single embed_size axis.
        out = torch.einsum("nhqv,nvhd->nqhd", attention, values).reshape(
            batch_size, query_len, self.heads * self.head_dims
        )

        out = self.fc_out(out)
        return out