# Language Modelling Sandbox - Lilian

In [1]:
# Imports
from typing import List, Tuple, Union

from torch import nn
import torch

from transformers import GPT2Tokenizer, GPT2Model

In [2]:
class LanguageModel(nn.Module):
    """
    Thin wrapper around a pretrained HuggingFace GPT-2 model and its tokenizer that supports
    incremental (cached) decoding by carrying along the attention mask of previously seen tokens.
    """

    def __init__(self, pretrained: str = "gpt2"):
        """
        :param pretrained: name or path of the pretrained GPT-2 checkpoint (e.g. "gpt2", "gpt2-large")
        """
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
        # Reuse EOS as the padding token so that batched tokenization with padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2Model.from_pretrained(pretrained)

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,
                past_key_values: Union[None, Tuple[torch.Tensor, ...]] = None,
                past_attention_mask: Union[None, torch.LongTensor] = None
                ) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[torch.Tensor, ...]]:
        """
        :param input_ids: (batch_size, input_id_len) token indices
        :param attention_mask: (batch_size, input_id_len) 1 for tokens that should be attended to, 0 for padding, etc. that shouldn't
        :param past_key_values: Optional, num_layers-len tuple holding the cache output of a previous call to the
                                language model (see that call's `key_values` return value)
                                Note: tokens represented in past_key_values should NOT be included in input_ids
        :param past_attention_mask: Optional, (batch_size, past_tokens_len) attention mask
                                    corresponding to past key values inputs; required whenever past_key_values is given
        :return: (last_hidden_state, new_attention_mask, key_values)
                 last_hidden_state: (batch_size, input_id_len, hidden_size) final hidden state associated with each input_id
                 new_attention_mask: (batch_size, past_tokens_len + input_id_len) new past attention mask
                 key_values: updated past_key_values, to be passed into the next call
        :raises ValueError: if past_key_values is given without past_attention_mask
        """
        if past_key_values is not None:
            if past_attention_mask is None:
                raise ValueError("past_attention_mask is required when past_key_values is given")
            # The model attends over past + current tokens, so the mask must cover both,
            # with the past tokens' mask placed first.
            new_attention_mask = torch.cat([past_attention_mask, attention_mask], dim=1)
        else:
            new_attention_mask = attention_mask
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=new_attention_mask,
            past_key_values=past_key_values,
            # Request the cache explicitly so that position 1 of the output is always the cache.
            use_cache=True,
        )
        # Index positionally instead of tuple-unpacking. Depending on the transformers version/config,
        # `outputs` is a plain tuple that may carry extra entries (hidden states, attentions), or a
        # ModelOutput, which cannot be unpacked directly. Both support integer indexing.
        last_hidden_state, key_values = outputs[0], outputs[1]
        return last_hidden_state, new_attention_mask, key_values

    def tokenize(self, inputs: List[str], add_space: bool = False) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """
        :param inputs: list of strings to tokenize
        :param add_space: True if a space should be added to the strings so that the first token is considered a new word,
                          else False
        :return: (input_ids, attention_mask)
                 each item has shape (batch_size=len(inputs), seq_len) where seq_len is the max # of tokens in any input
        """
        if add_space:
            inputs = [" " + x for x in inputs]

        # Pad every sequence to the longest one in the batch (padding uses the EOS token set in __init__).
        tokens = self.tokenizer.batch_encode_plus(
            inputs,
            padding="longest",
            return_attention_mask=True,
            return_tensors="pt"
        )

        return tokens["input_ids"], tokens["attention_mask"]

    def token_embedding(self) -> torch.Tensor:
        """
        Returns the token embedding tensor of shape (vocab_size, hidden_size).
        Note: this is `.data`, i.e. not tracked by autograd but sharing storage with the model weights.
        """
        return self.model.wte.weight.data

    def position_embedding(self) -> torch.Tensor:
        """
        Returns the position embedding tensor of shape (max_seq_len, hidden_size).
        Note: this is `.data`, i.e. not tracked by autograd but sharing storage with the model weights.
        """
        return self.model.wpe.weight.data
    

#### Important Notes
A few aspects of how GPT2Model and GPT2Tokenizer work that aren't immediately clear from the documentation:

1) Let past_key_values be the cache tuple associated with some `past_tokens_len` tokens, and let the number of input ids per batch element be `input_id_len`. Then the attention mask passed into the model must have shape `(batch_size, past_tokens_len + input_id_len)`. In other words, we need to keep the attention mask of the past tokens as well, and concatenate it in front of the new attention mask.

2) The GPT2Tokenizer encodes a space as "Ġ" and merges it into the following word, so a word preceded by a space becomes the token "Ġword". Consequently, the token at the very beginning of a text sample (which has no preceding space) and non-initial subword pieces do not have the Ġ prefix. This is important to consider if we want to append a new word to a sequence, for example in the following case:

We are interested in decoding the sentence "I am happy.", which would be represented as:
    ['<|endoftext|>' (implicitly), 'I', 'Ġam', 'Ġhappy', '.']

We tokenize and feed in '<|endoftext|>', the EOS token. This produces a hidden state that can help predict 'I'.
We tokenize and feed in 'I', using the past_key_values associated with the output of the previous step.
We find that the next word is "am". Tokenizing the string "am" gives the id of 'am', which is NOT EQUAL to the id of 'Ġam'! (Tokenizing " am", with a leading space, gives 'Ġam'.)

This only matters when we explicitly pass in strings. Ideally, our overall predictor outputs the token 'Ġam' directly instead of the string 'am', which bypasses the problem.

In [3]:
# Wrap the large GPT-2 checkpoint; the first run downloads the weights (~3.2 GB, judging by the download output).
m = LanguageModel(pretrained="gpt2-large")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=764.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3247202234.0, style=ProgressStyle(descr…




Some weights of GPT2Model were not initialized from the model checkpoint at gpt2-large and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'h.12.attn.masked_bias', 'h.13.attn.masked_bias', 'h.14.attn.masked_bias', 'h.15.attn.masked_bias', 'h.16.attn.masked_bias', 'h.17.attn.masked_bias', 'h.18.attn.masked_bias', 'h.19.attn.masked_bias', 'h.20.attn.masked_bias', 'h.21.attn.masked_bias', 'h.22.attn.masked_bias', 'h.23.attn.masked_bias', 'h.24.attn.masked_bias', 'h.25.attn.masked_bias', 'h.26.attn.masked_bias', 'h.27.attn.masked_bias', 'h.28.attn.masked_bias', 'h.29.attn.masked_bias', 'h.30.attn.masked_bias', 'h.31.attn.masked_bias', 'h.32.attn.masked_bias', 'h.33.attn.masked_bias', 'h.34.attn.masked_bias', 'h.35.attn.masked_bi

In [9]:
# Sanity check: look up the token embedding of a single EOS token (batch of 1, sequence of 1).
# Build the id tensor directly as int64 instead of going through a float tensor.
eos_ids = torch.tensor([[m.tokenizer.eos_token_id]], dtype=torch.long)
m.model.wte(eos_ids).shape

torch.Size([1, 1, 1280])