In [176]:
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
%matplotlib inline

In [177]:
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.

    Args:
        encoder: callable invoked as ``encoder(embedded_src, src_mask)``.
        decoder: callable invoked as
            ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_embed: callable applied to ``src`` before encoding.
        tgt_embed: callable applied to ``tgt`` before decoding.
        generator: callable applied to the decoder output in ``forward``.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Encode `src`, decode `tgt` against the result, then apply the generator."
        memory = self.encode(src, src_mask)
        decoded = self.decode(memory, tgt, src_mask, tgt_mask)
        return self.generator(decoded)

    def encode(self, src, src_mask):
        "Embed `src` and run it through the encoder."
        _src = self.src_embed(src)
        return self.encoder(_src, src_mask)

    def decode(self, memory, tgt, src_mask, tgt_mask):
        "Embed `tgt` and run it through the decoder, attending over `memory`."
        _tgt = self.tgt_embed(tgt)
        # The decoder takes the embedded target first and the encoder
        # memory second (see Decoder.forward's `(x, memory, ...)` signature).
        return self.decoder(_tgt, memory, src_mask, tgt_mask)

In [178]:
class Generator(nn.Module):
    "Project features onto the vocabulary and return log-probabilities."
    def __init__(self, d_model, vocab):
        super().__init__()
        # One logit per vocabulary entry.
        self.linear = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.linear(x)
        # Normalise over the last (vocabulary) dimension.
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

In [179]:
from copy import deepcopy

def clones(module, N):
    "Return an nn.ModuleList holding N independent deep copies of `module`."
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(deepcopy(module))
    return copies

class LayerNorm(nn.Module):
    "Normalise over the last dimension with a learnable scale and bias."
    def __init__(self, size, eps=1e-6):
        super().__init__()
        # eps is added to the standard deviation to avoid division by zero.
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(size))
        self.bias = nn.Parameter(torch.zeros(size))

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        centered = x - mu
        scaled = self.scale * centered
        return scaled / (sigma + self.eps) + self.bias

In [180]:
class Encoder(nn.Module):
    "Stack of N cloned layers followed by a final LayerNorm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # The norm width is taken from the layer's `size` attribute.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        hidden = x
        # Feed the input (and the same mask) through every layer in order.
        for sub in self.layers:
            hidden = sub(hidden, mask)
        return self.norm(hidden)

In [181]:
class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sublayer: the input is layer-normalised,
    passed to the sublayer, run through dropout and added back to the
    un-normalised input (the norm comes first rather than last).
    """
    def __init__(self, size, drop_prob):
        super().__init__()
        self.size = size
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x, layer):
        normed = self.norm(x)
        residual = self.dropout(layer(normed))
        return x + residual

In [182]:
class EncoderLayer(nn.Module):
    "One encoder layer: self-attention, then feed forward, then a LayerNorm."
    def __init__(self, size, self_attn, feed_forward, drop_prob):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm = LayerNorm(size)
        # Two residual wrappers: one for attention, one for feed forward.
        self.sublayers = clones(SublayerConnection(size, drop_prob), 2)

    def forward(self, x, mask):
        attn_sublayer, ff_sublayer = self.sublayers

        def self_attend(h):
            # Query, key and value are all the same (normalised) input.
            return self.self_attn(h, h, h, mask)

        x = attn_sublayer(x, self_attend)
        x = ff_sublayer(x, self.feed_forward)
        # NOTE(review): this layer normalises its own output, which
        # DecoderLayer does not do; Encoder then applies a further norm.
        return self.norm(x)

In [183]:
class Decoder(nn.Module):
    "Stack of N cloned decoder layers followed by a final LayerNorm."
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # The norm width is taken from the layer's `size` attribute.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        hidden = x
        # Every layer sees the same memory and the same pair of masks.
        for sub in self.layers:
            hidden = sub(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [184]:
class DecoderLayer(nn.Module):
    "One decoder layer: self-attention, attention over memory, feed forward."
    def __init__(self, size, self_attn, src_attn, feed_forward, drop_prob):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn  = src_attn
        self.feed_forward = feed_forward
        # Three residual wrappers, one per sublayer.
        self.sublayers = clones(SublayerConnection(size, drop_prob), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        self_sub, src_sub, ff_sub = self.sublayers

        def self_attend(h):
            # Attend over the decoder's own input, restricted by tgt_mask.
            return self.self_attn(h, h, h, tgt_mask)

        def memory_attend(h):
            # Queries come from the decoder; keys and values from memory.
            return self.src_attn(h, memory, memory, src_mask)

        x = self_sub(x, self_attend)
        x = src_sub(x, memory_attend)
        return ff_sub(x, self.feed_forward)

In [185]:
def subsequent_mask(size):
    "Build a (1, size, size) uint8 mask that is 1 strictly above the diagonal."
    ones = np.ones((1, size, size))
    # k=1 keeps only entries above the main diagonal, i.e. later positions.
    upper = np.triu(ones, k=1)
    return torch.from_numpy(upper.astype('uint8'))

In [186]:
# Demo: build the mask for 5 positions and show its single (size, size) slice.
# NOTE(review): nothing is drawn on the figure created below, so the cell
# only emits an empty Figure repr alongside the tensor.
plt.figure(figsize=(5, 5))
mask = subsequent_mask(5)
mask[0]

tensor([[ 0,  1,  1,  1,  1],
        [ 0,  0,  1,  1,  1],
        [ 0,  0,  0,  1,  1],
        [ 0,  0,  0,  0,  1],
        [ 0,  0,  0,  0,  0]], dtype=torch.uint8)

<matplotlib.figure.Figure at 0x7f2e67fd7940>

In [187]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'.

    Args:
        query: tensor of shape (..., input_length, key_dim).
        key: tensor of shape (..., memory_length, key_dim).
        value: tensor of shape (..., memory_length, value_dim).
        mask: optional mask broadcastable to
            (..., input_length, memory_length). Positions where the mask is
            set are filled with -1e9 before the softmax, i.e. they are
            excluded from attention.
        dropout: optional callable applied to the attention weights.

    Returns:
        A pair ``(output, weights)`` where ``weights`` has shape
        (..., input_length, memory_length) and ``output`` has shape
        (..., input_length, value_dim).
    """
    key_dim = query.size(-1)
    scores = torch.matmul(query, key.transpose(-1, -2)) / np.sqrt(key_dim)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    # Normalise over the memory positions (last dimension). Leaving `dim`
    # out relies on a deprecated implicit choice that picks dim 0 for 3-D
    # input, which would normalise across the batch instead.
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    output = torch.matmul(weights, value)
    return output, weights

In [188]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on top of `attention`.

    Query, key and value are each projected with their own linear layer,
    split into `h` heads along the feature dimension, and the heads are
    folded into the leading (batch) dimension so that a single `attention`
    call handles all heads at once.

    Args:
        h: number of heads.
        d_model: model width; must be divisible by `h`.
        drop_prob: dropout probability applied to the attention weights.

    Raises:
        ValueError: if `d_model` is not divisible by `h`.
    """
    def __init__(self, h, d_model, drop_prob=0.1):
        super().__init__()
        if d_model % h != 0:
            raise ValueError(
                "d_model (%d) must be divisible by h (%d)" % (d_model, h))
        self.h = h
        self.d_model = d_model
        self.h_dim = d_model // h
        self.linears = clones(nn.Linear(d_model, d_model), 3)
        self.o_linear = nn.Linear(h * self.h_dim, d_model)
        self.dropout = nn.Dropout(drop_prob)
        # Attention weights from the most recent forward pass.
        self.attn = None

    def _split_heads(self, x):
        "(batch, length, d_model) -> (h * batch, length, h_dim), head-major."
        return torch.cat(torch.chunk(x, self.h, dim=-1), dim=0)

    def _merge_heads(self, x):
        "Inverse of `_split_heads`: heads go back onto the last dimension."
        return torch.cat(torch.chunk(x, self.h, dim=0), dim=-1)

    def forward(self, query, key, value, mask):
        """Apply multi-head attention.

        `mask` may be None. Otherwise it is passed to `attention`; it is
        assumed to have the batch as its leading dimension (or a leading
        dimension of 1 so that it broadcasts) -- TODO confirm with callers.
        """
        projections = [
            self._split_heads(lin(x))
            for x, lin in zip((query, key, value), self.linears)
        ]
        proj_q, proj_k, proj_v = projections
        if mask is not None and mask.size(0) != 1:
            # Heads were stacked along dim 0 in head-major order, so the
            # per-example mask has to be tiled h times the same way.
            mask = mask.repeat(self.h, *([1] * (mask.dim() - 1)))
        output, attn = attention(proj_q, proj_k, proj_v,
                                 mask=mask, dropout=self.dropout)
        self.attn = self._merge_heads(attn)
        return self.o_linear(self._merge_heads(output))

In [189]:
class PositionwiseFeedForward(nn.Module):
    """Implements the FFN equation: lin2(dropout(relu(lin1(x)))).

    Args:
        d_model: input and output width.
        d_ff: width of the hidden layer.
        drop_prob: dropout probability applied after the ReLU.
    """
    def __init__(self, d_model, d_ff, drop_prob=0.1):
        super().__init__()
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # nn.Module dispatches calls to `forward`; without this method the
        # module cannot be called as `module(x)`.
        hidden = self.dropout(F.relu(self.lin1(x)))
        return self.lin2(hidden)

    def feed_forward(self, x):
        "Alias kept for code that used the previous method name."
        return self(x)

In [190]:
class Embeddings(nn.Module):
    """Token embedding lookup scaled by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab: number of rows in the embedding table.
    """
    def __init__(self, d_model, vocab):
        super().__init__()
        self.embeds = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # nn.Module (and nn.Sequential) dispatch to `forward`; without this
        # method the module cannot be called.
        return self.embeds(x) * np.sqrt(self.d_model)

    def feed_forward(self, x):
        "Alias kept for code that used the previous method name."
        return self(x)

In [191]:
class PositionalEncoding(nn.Module):
    """Implement the PE function.

    A table of sinusoids is precomputed for `max_len` positions: even
    feature columns hold sin(position * div_term) and odd columns hold
    cos(position * div_term), with div_term = 1e4 ** (-2i / d_model).

    Args:
        d_model: feature width of the inputs.
        drop_prob: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute.
    """
    def __init__(self, d_model, drop_prob, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(drop_prob)
        # Positional Encoding
        pe = torch.zeros(max_len, d_model)
        # Explicit float dtype: an integer arange would make the result
        # depend on the installed version's type-promotion rules.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        const = 1e4
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(np.log(const) / d_model))
        angles = position * div_term.unsqueeze(0)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        # Stored as a buffer, not a parameter, with a leading batch dim of 1.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Slice the first x.size(1) positions. Indexing with a plain integer
        # would add one single position's encoding to every time step.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [207]:
def make_model(src_vocab, tgt_vocab, N=6,
               d_model=512, d_ff=2048, h=8, drop_prob=0.1):
    "Assemble an EncoderDecoder from the given hyperparameters."
    # Prototype modules; every place they are used receives its own deep copy.
    attn = MultiHeadedAttention(h, d_model, drop_prob)
    ff = PositionwiseFeedForward(d_model, d_ff, drop_prob)
    pe = PositionalEncoding(d_model, drop_prob)

    encoder = Encoder(
        EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), drop_prob), N)
    decoder = Decoder(
        DecoderLayer(d_model, deepcopy(attn), deepcopy(attn),
                     deepcopy(ff), drop_prob), N)
    src_embed = nn.Sequential(Embeddings(d_model, src_vocab), deepcopy(pe))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), deepcopy(pe))
    generator = Generator(d_model, tgt_vocab)
    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

    # Xavier-uniform initialisation for every parameter with more than one
    # dimension; 1-D parameters keep their default initialisation.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model

In [208]:
# Smoke test: build a small model with source/target vocab sizes of 10 and N=2 layers.
tmp_model = make_model(10, 10, 2)

In [209]:
class Batch:
    """Hold a batch of data together with the masks used during training.

    Masks follow the convention of `attention` in this file, which fills
    positions where the mask is set with -1e9: a set entry means "do not
    attend here".

    Attributes set for every batch:
        src: the source tensor as given.
        src_mask: (src == pad) with a singleton dimension inserted before
            the last one, so it broadcasts over query positions.

    Attributes set only when `tgt` is given:
        tgt: `tgt` without its last column (decoder input).
        tgt_y: `tgt` without its first column (prediction targets).
        tgt_mask: padding positions combined with subsequent positions.
        ntokens: number of non-padding entries in `tgt_y`.
    """
    def __init__(self, src, tgt=None, pad=0):
        self.src = src
        # unsqueeze(-2) matches make_std_mask, so the padding mask runs
        # along the key/memory axis of the attention scores.
        self.src_mask = (src == pad).unsqueeze(-2)
        if tgt is not None:
            self.tgt = tgt[:, :-1]
            self.tgt_y = tgt[:, 1:]
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            # Count real tokens, not padding.
            self.ntokens = (self.tgt_y != pad).sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask that hides padding and subsequent positions."
        pad_mask = (tgt == pad).unsqueeze(-2)
        future_mask = subsequent_mask(tgt.size(-1)).type_as(pad_mask)
        # A position is hidden if it is padding OR lies in the future.
        return pad_mask | future_mask

In [195]:
# Scratch check of transpose() on a 3-D tensor.
# NOTE(review): no cell above defines an `x` that would give this shape; it relies on kernel state.
x.transpose(1, 2).size()

torch.Size([3, 5, 4])

In [196]:
# Same check with the two dimension arguments swapped; the resulting size is identical.
x.transpose(2, 1).size()

torch.Size([3, 5, 4])

In [211]:
def run_epoch(data_iter, model, loss_compute):
    """Standard Training and Logging Function.

    Args:
        data_iter: iterable of batches exposing `src`, `tgt`, `src_mask`,
            `tgt_mask`, `tgt_y` and `ntokens` (the attribute names used by
            `Batch` in this file).
        model: callable invoked as `model(src, tgt, src_mask, tgt_mask)`.
        loss_compute: callable invoked as
            `loss_compute(out, tgt_y, ntokens)`; its result is accumulated
            as the batch loss.

    Returns:
        Total loss divided by the total token count, or 0.0 when no tokens
        were seen (for example an empty `data_iter`).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        # Call the model itself rather than `.forward` so module hooks run.
        out = model(batch.src, batch.tgt,
                    batch.src_mask, batch.tgt_mask)
        loss = loss_compute(out, batch.tgt_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            # Guard against a zero interval on coarse clocks.
            elapsed = max(time.time() - start, 1e-9)
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    if total_tokens == 0:
        return 0.0
    return total_loss / total_tokens

In [210]:
class NoamOpt:
    """Optim wrapper that implements the warmup learning-rate schedule.

    The rate at step `s` is
    ``factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)``.

    Args:
        model_size: model width used in the schedule.
        factor: constant multiplier on the rate.
        warmup: number of warmup steps.
        optimizer: wrapped optimizer; it must expose `param_groups`
            (a list of dicts with an 'lr' key) and `step()`.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        "Return the learning rate for `step` (default: the current step)."
        if step is None:
            step = self._step
        if step <= 0:
            # step ** -0.5 is undefined at 0; the warmup branch of the
            # schedule tends to 0 there, so report 0 instead of raising.
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    "Wrap Adam in a NoamOpt sized from the model's source embedding width."
    d_model = model.src_embed[0].d_model
    # lr starts at 0; NoamOpt.step overwrites it on every step.
    adam = torch.optim.Adam(model.parameters(), lr=0,
                            betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, 2, 4000, adam)

In [199]:
# Scratch: display the current `mask`.
# NOTE(review): the 3x3 value shown (entries 0 and 3) is not produced by any cell above; it relies on kernel state.
mask

tensor([[ 0,  3,  3],
        [ 0,  0,  3],
        [ 0,  0,  0]], dtype=torch.uint8)

In [200]:
# Scratch: try masked_fill with a mask rescaled by dividing by 3.
# NOTE(review): the fill value here is +1e-9, whereas attention() fills with -1e9.
x.masked_fill(mask / 3, 1e-9)

tensor([[-4.7528e-01,  1.0000e-09,  1.0000e-09],
        [-5.0348e-01, -3.4601e-01,  1.0000e-09],
        [ 2.5800e+00,  7.9626e-01, -5.0441e-01]])

In [201]:
# Scratch: comparing the strictly upper-triangular matrix with 0 yields 1 on and below the diagonal.
x = np.ones((5, 5))
torch.from_numpy(np.triu(x, k=1)) == 0

tensor([[ 1,  0,  0,  0,  0],
        [ 1,  1,  0,  0,  0],
        [ 1,  1,  1,  0,  0],
        [ 1,  1,  1,  1,  0],
        [ 1,  1,  1,  1,  1]], dtype=torch.uint8)

In [202]:
# Look up torch.chunk's signature; it takes `chunks` as its second argument.
help(torch.chunk)

Help on built-in function chunk:

chunk(...)
    chunk(tensor, chunks, dim=0) -> List of Tensors
    
    Splits a tensor into a specific number of chunks.
    
    Last chunk will be smaller if the tensor size along the given dimension
    :attr:`dim` is not divisible by :attr:`chunks`.
    
    Arguments:
        tensor (Tensor): the tensor to split
        chunks (int): number of chunks to return
        dim (int): dimension along which to split the tensor



In [203]:
# Scratch: nn.Linear maps only the last dimension (10 -> 30); the leading dimensions are unchanged.
lin = nn.Linear(10, 30)
x = torch.randn(20, 15, 10)
lin(x).size()

torch.Size([20, 15, 30])

In [204]:
# Look up torch.unsqueeze, which inserts a size-one dimension at the given index.
help(torch.unsqueeze)

Help on built-in function unsqueeze:

unsqueeze(...)
    unsqueeze(input, dim, out=None) -> Tensor
    
    Returns a new tensor with a dimension of size one inserted at the
    specified position.
    
    The returned tensor shares the same underlying data with this tensor.
    
    A negative `dim` value within the range
    [-:attr:`input.dim()`, :attr:`input.dim()`) can be used and
    will correspond to :meth:`unsqueeze` applied at :attr:`dim` = :attr:`dim + input.dim() + 1`
    
    Args:
        input (Tensor): the input tensor
        dim (int): the index at which to insert the singleton dimension
        out (Tensor, optional): the output tensor
    
    Example::
    
        >>> x = torch.tensor([1, 2, 3, 4])
        >>> torch.unsqueeze(x, 0)
        tensor([[ 1,  2,  3,  4]])
        >>> torch.unsqueeze(x, 1)
        tensor([[ 1],
                [ 2],
                [ 3],
                [ 4]])



In [205]:
# Scratch: add a leading singleton dimension to `x`.
# NOTE(review): this prints the tensor's values; `.size()` would be enough to check the shape.
x.unsqueeze(0)

tensor([[[[ 0.3807,  0.0561,  0.2692,  ...,  0.6745,  0.7825,  0.5816],
          [-0.5833,  1.4313, -0.0649,  ...,  0.5097,  1.1076,  0.4495],
          [ 1.4976,  1.3701, -0.3923,  ..., -2.7077,  0.8734,  0.4760],
          ...,
          [-0.9734,  1.5111,  0.5513,  ...,  1.5626,  0.0551, -0.3939],
          [-0.7921,  2.0839,  0.8382,  ...,  2.0039, -1.2792, -0.5473],
          [-0.7719, -1.2014, -2.1645,  ..., -0.5173,  0.0699,  0.4990]],

         [[ 0.5323,  0.2479,  0.1900,  ..., -0.2317, -1.0883,  0.4674],
          [-1.1096,  0.9642, -0.2402,  ...,  0.3732, -0.3986,  0.0554],
          [ 0.2246, -1.5680, -1.5758,  ..., -0.0553, -0.1274,  0.3998],
          ...,
          [-0.4647,  1.9455,  1.3348,  ...,  0.9216,  0.4980, -0.1353],
          [ 0.8853,  0.5545,  1.1105,  ...,  0.6940, -0.8288, -1.3541],
          [ 0.8021,  1.9052,  0.4606,  ...,  0.1953,  0.9813, -2.7392]],

         [[-1.6473, -0.1218,  0.3845,  ...,  0.3449,  1.7841,  0.5223],
          [-0.9522,  0.5693,  

In [206]:
# Scratch: arange with an explicit float dtype.
torch.arange(0, 10, dtype=torch.float)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.])