In [5]:
import torch
import torch.nn as nn
import pytorch_transformers

<h1 align="center">1. Pretraining a language model (LM) — training takes a few days</h1>

## Transformer

In [2]:
class Transformer(nn.Module):
    """Stack of self-attention blocks over learned token + position embeddings.

    Args:
        embed_dim: size of the token/position embeddings and of every hidden state.
        hidden_dim: inner width of each position-wise feed-forward network.
        num_embeddings: vocabulary size (rows of the token embedding table).
        num_max_positions: number of learned position embeddings, i.e. the longest
            sequence ``forward`` can accept.
        num_heads: attention heads per layer (``nn.MultiheadAttention`` requires
            ``embed_dim`` to be divisible by it).
        num_layers: number of attention + feed-forward blocks.
        dropout: dropout probability used on the embeddings, inside attention and on
            every sub-layer output.
        causal: if True, position i may only attend to positions <= i (left-to-right
            LM); if False, attention is bidirectional.
    """

    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout, causal):

        super().__init__()

        #### Initial input embedding
        self.tokens_embeddings   = nn.Embedding(num_embeddings,   embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)

        #### Mask
        self.causal = causal

        #### Transformer block components (one entry per layer in each list)
        self.attentions    = nn.ModuleList()
        self.feed_forwards = nn.ModuleList()
        self.layer_norms_1 = nn.ModuleList()
        self.layer_norms_2 = nn.ModuleList()
        self.dropout       = nn.Dropout(dropout)

        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))

            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))

            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        """Encode a batch of token ids.

        Args:
            x: integer token ids of shape [seq length, batch] (sequence-first, which is
                the default layout of ``nn.MultiheadAttention``).
            padding_mask: optional mask of shape [batch, seq length], forwarded to every
                attention layer as ``key_padding_mask`` (for a bool mask, True marks keys
                to ignore).

        Returns:
            Hidden states of shape [seq length, batch, embed_dim].
        """
        seq_len = len(x)

        #### Initial input embedding
        # positions has shape [seq length, 1]; its embeddings are broadcast over the batch.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        #### Mask
        # Additive float mask: 0 on/below the diagonal, -inf above it, so no position
        # can attend to a later one.
        attn_mask = None
        if self.causal:
            attn_mask = torch.full((seq_len, seq_len), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        #### Transformer blocks
        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            # NOTE(review): the residual is added to the *normalized* state
            # (h = LN(h); h = sublayer(h) + h), which is neither classic post-LN nor
            # standard pre-LN. Kept unchanged for checkpoint compatibility — confirm intent.
            h = layer_norm_1(h)
            sublayer_out, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False,
                                        key_padding_mask=padding_mask)
            sublayer_out = self.dropout(sublayer_out)
            h = sublayer_out + h

            h = layer_norm_2(h)
            sublayer_out = feed_forward(h)
            sublayer_out = self.dropout(sublayer_out)
            h = sublayer_out + h

        return h

## Transformer with LM head

In [4]:
class TransformerWithLMHead(nn.Module):
    """``Transformer`` followed by a linear LM head whose weight is tied to the token embeddings.

    ``config`` must expose these attributes: ``embed_dim``, ``hidden_dim``, ``num_embeddings``,
    ``num_max_positions``, ``num_heads``, ``num_layers``, ``dropout``, ``mlm`` and
    ``initializer_range``.

    With ``mlm=False`` the transformer is causal and the loss scores next-token prediction.
    With ``mlm=True`` attention is bidirectional and each position is scored against its
    own label.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       config.dropout, causal=not config.mlm)

        self.lm_head = nn.Linear(config.embed_dim, config.num_embeddings, bias=False)
        self.apply(self.init_weights)
        # Tie after initialisation so the shared matrix keeps the embedding's init.
        self.tie_weights()

    def tie_weights(self):
        """Share one parameter between the LM head and the token embedding table.

        Both weights have shape [num_embeddings, embed_dim].
        """
        self.lm_head.weight = self.transformer.tokens_embeddings.weight

    def init_weights(self, module):
        """Initialise one sub-module; called on every module via ``self.apply``.

        - Linear and Embedding weights ~ N(0, config.initializer_range); Linear biases are set to 0.
        - LayerNorm starts as the identity transform (weight 1, bias 0).

        NOTE: the packed input projection of ``nn.MultiheadAttention`` keeps PyTorch's own
        (xavier) init. Its ``out_proj`` is an ``nn.Linear`` subclass, so it IS re-initialised here.
        """
        if isinstance(module, nn.LayerNorm):
            # A tiny random LayerNorm scale would shrink every normalized activation
            # towards zero; start from the identity instead.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, labels=None, padding_mask=None):
        """Compute vocabulary logits and, optionally, the language-modelling loss.

        Args:
            x: token ids of shape [seq length, batch].
            labels: optional targets, expected in the same [seq length, batch] layout as ``x``.
                Entries equal to -1 are ignored by the loss.
            padding_mask: optional [batch, seq length] mask forwarded to the transformer.

        Returns:
            ``logits`` of shape [seq length, batch, num_embeddings], or ``(logits, loss)``
            when ``labels`` is given.
        """
        hidden_states = self.transformer(x, padding_mask)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return logits

        if self.transformer.causal:
            # Next-token prediction: the output at position t is scored against label t+1.
            shift_logits, shift_labels = logits[:-1], labels[1:]
        else:
            # Bidirectional (masked LM): every position is scored against its own label.
            shift_logits, shift_labels = logits, labels
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        # reshape (rather than view) also accepts non-contiguous label tensors.
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        return logits, loss

## Tokenizer

In [6]:
# Reuse the pretrained 'bert-base-uncased' tokenizer and its vocabulary; the config
# cell sizes the LM's embedding table from len(tokenizer.vocab).
tokenizer = pytorch_transformers.BertTokenizer.from_pretrained(
    'bert-base-uncased'
)

## Configuration

In [None]:
class AttrDict(dict):
    """A ``dict`` whose keys can also be read as attributes.

    ``TransformerWithLMHead`` reads its hyper-parameters with attribute access
    (``config.embed_dim``, ``config.mlm``, ``config.initializer_range``), which a plain
    ``dict`` does not support. Subclassing ``dict`` keeps item access, iteration and
    ``**config`` unpacking working exactly as before.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Translate a missing key into
        # AttributeError so hasattr()/getattr(default)/copy/pickle behave normally.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


# Pre-training hyper-parameters, usable both as config["lr"] and config.lr.
config = AttrDict({
    "embed_dim":         410,
    "hidden_dim":        2100,
    "num_max_positions": 256,
    "num_embeddings":    len(tokenizer.vocab), # 30522 tokens
    "num_heads":         10,   # must divide embed_dim (410 / 10 = 41 per head)
    "num_layers":        16,
    "dropout":           0.1,
    "initializer_range": 0.02,
    "batch_size":        16,
    "lr":                2.5e-4,
    "max_norm":          1.0,
    "n_epochs":          50,
    "n_warmup":          1000,
    "mlm":               False,  # False -> causal (left-to-right) LM; True -> bidirectional attention
    "gradient_accumulation_steps": 4,
    "device":            "cuda" if torch.cuda.is_available() else "cpu",
    "log_dir":           "./",
    "dataset_cache":     "./dataset_cache.bin"
})