In [1]:
!pip install pytorch-pretrained-bert pytorch-ignite ipdb

Collecting pytorch-pretrained-bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |████████████████████████████████| 133kB 2.8MB/s 
[?25hCollecting pytorch-ignite
[?25l  Downloading https://files.pythonhosted.org/packages/35/55/41e8a995876fd2ade29bdba0c3efefa38e7d605cb353c70f3173c04928b5/pytorch_ignite-0.3.0-py2.py3-none-any.whl (103kB)
[K     |████████████████████████████████| 112kB 7.8MB/s 
[?25hCollecting ipdb
  Downloading https://files.pythonhosted.org/packages/2c/bb/a3e1a441719ebd75c6dac8170d3ddba884b7ee8a5c0f9aefa7297386627a/ipdb-0.13.2.tar.gz
Building wheels for collected packages: ipdb
  Building wheel for ipdb (setup.py) ... [?25l[?25hdone
  Created wheel for ipdb: filename=ipdb-0.13.2-cp36-none-any.whl size=10522 sha256=20649f49704e805b5b25bdd11a4a60c9efd8d9716d031c44a8127e2f1f9bb0b5
  Stored in directory: /root/.cache/pip/wheel

In [4]:
""" Creating the GPT-2 architecture (using ReLU activation instead of the GELU used in the original paper).
You can find the original paper at: Language Models are Unsupervised Multitask Learners (Radford et al., 2019) """

' Creating the GPT-2 architecture (using ReLU activation instead of the GELU used in the original paper).\nYou can find the original paper at: Language Models are Unsupervised Multitask Learners (Radford et al., 2019) '

In [0]:
import torch
import torch.nn as nn

class Transformer(nn.Module):
    """Stack of self-attention / feed-forward blocks over token + position embeddings.

    Each block applies LayerNorm -> multi-head self-attention -> dropout -> residual,
    then LayerNorm -> feed-forward (Linear, ReLU, Linear) -> dropout -> residual.
    Note that both residual connections add the *layer-normed* activations, not
    the activations that entered the LayerNorm.
    """

    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout, causal):
        """Build the embeddings and the per-layer sub-modules.

        Args:
            embed_dim: embedding dimensions
            hidden_dim: hidden dimensions of the feed-forward sub-layer
            num_embeddings: vocabulary size
            num_max_positions: maximum length of a sequence (size of the
                position-embedding table)
            num_heads: number of heads of each self-attention layer
            num_layers: number of transformer blocks
            dropout: dropout probability (used for the embedding/residual
                dropout and passed to every attention layer)
            causal: None/True. When truthy, an attention mask hides future
                positions from every position (causal language modelling)
        """
        super().__init__()
        self.causal = causal
        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Sub-modules are kept in four parallel lists, one entry per block.
        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        """Encode a batch of token ids.

        Args:
            x: token ids of shape [seq length, batch]
            padding_mask: optional mask of shape [batch, seq length], forwarded
                to every attention layer as ``key_padding_mask``

        Returns:
            Hidden states of shape [seq length, batch, embed_dim].
        """
        seq_len = len(x)
        # [seq length, 1] so the position embeddings broadcast over the batch axis.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        attn_mask = None
        if self.causal:
            # Additive mask: -inf strictly above the diagonal, 0 elsewhere, so
            # position i cannot attend to any position j > i.
            attn_mask = torch.full((seq_len, seq_len), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            # Self-attention sub-layer.
            h = layer_norm_1(h)
            attn_out, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False,
                                    key_padding_mask=padding_mask)
            attn_out = self.dropout(attn_out)
            h = attn_out + h

            # Position-wise feed-forward sub-layer.
            h = layer_norm_2(h)
            ff_out = feed_forward(h)
            ff_out = self.dropout(ff_out)
            h = ff_out + h
        return h
      





In [15]:
""" Adding a language-modelling head on top of the Transformer
"""

' Adding a language-modelling head on top of the Transformer\n'

In [0]:
class TransformerWithLMHead(nn.Module):
    """Transformer backbone with a language-modelling head whose weight is tied
    to the token-embedding matrix."""

    def __init__(self, config):
        """ Transformer with a language modeling head on top (tied weights)

        Args:
            config: object exposing ``embed_dim``, ``hidden_dim``,
                ``num_embeddings``, ``num_max_positions``, ``num_heads``,
                ``num_layers``, ``dropout``, ``mlm`` and ``initializer_range``.
                The backbone is causal exactly when ``config.mlm`` is falsy.
        """
        super().__init__()
        self.config = config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       config.dropout, causal=not config.mlm)

        self.lm_head = nn.Linear(config.embed_dim, config.num_embeddings, bias=False)
        # Initialise first, then tie: after tying, the head shares the
        # (already initialised) token-embedding parameter.
        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """Make the output projection share its weight with the input token embeddings."""
        self.lm_head.weight = self.transformer.tokens_embeddings.weight

    def init_weights(self, module):
        """ initialize weights - nn.MultiheadAttention is already initialized by PyTorch (xavier)

        Weights of Linear / Embedding / LayerNorm modules are drawn from
        N(0, config.initializer_range); biases of Linear / LayerNorm modules
        are zeroed when present.
        """
        # NOTE(review): `apply` also visits the sub-modules of nn.MultiheadAttention;
        # if its output projection is an nn.Linear it is re-initialised here as
        # well - confirm this is intended.
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, labels=None, padding_mask=None):
        """Compute vocabulary logits and, when labels are given, the LM loss.

        Args:
            x: token ids of shape [seq length, batch]
            labels: optional target ids with the same shape as ``x``; entries
                equal to -1 are ignored by the loss
            padding_mask: optional mask of shape [batch, seq length]

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(logits, loss)``.
            ``logits`` is always the full, unshifted output of the head.
        """
        hidden_states = self.transformer(x, padding_mask)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return logits

        if self.transformer.causal:
            # Causal LM: the output at position t is scored against token t + 1.
            shift_logits, shift_labels = logits[:-1], labels[1:]
        else:
            shift_logits, shift_labels = logits, labels

        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        # reshape (not view): labels are often a transposed [batch, seq] tensor,
        # i.e. non-contiguous, on which view(-1) raises a RuntimeError.
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        return logits, loss