In [None]:
# export
from fastai2.basics import *

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer

from fastai2_utils.pytorch.transformer import *
from fastai_transformers_utils.tokenizers import GPT2DecoderTokenizer
from fastai_transformers_utils.generated_lm import GeneratedLM, GenerateArgs

In [None]:
# default_exp models.bert2gpt2

In [None]:
# Names of the pretrained huggingface checkpoints used by the example/test cells in this notebook.
enc_model_name = 'hfl/chinese-bert-wwm-ext' # passed to BertEncoder and to the encoder tokenizer
dec_model_name = 'distilgpt2' # passed to GPT2Decoder and to the decoder tokenizer

In [None]:
# Tokenizers loaded from the same checkpoints as the encoder / decoder models.
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model_name)
# NOTE(review): GPT2DecoderTokenizer presumably appends a bos and a pad token to the gpt2 vocabulary — confirm in fastai_transformers_utils.
dec_tokenizer = GPT2DecoderTokenizer.from_pretrained(dec_model_name)

# Models Bert2GPT2
> Seq2seq model that pairs a pretrained BERT encoder with a pretrained GPT-2 decoder.

## Helper functions

In [None]:
# export
def gen_attention_mask(inp_ids, pad_id):
    '''
        Build the attention mask for a batch of token ids.
        inp_ids: (bs, seq_len) token ids
        pad_id: id of the padding token
        returns: (bs, seq_len) long tensor, 0 where `inp_ids` holds `pad_id` and 1 everywhere else
    '''
    is_pad = gen_key_padding_mask(inp_ids, pad_id)
    # The attention mask is the logical inverse of the key padding mask.
    not_pad = ~is_pad
    return not_pad.long()

In [None]:
# Toy batch in which id 0 plays the role of the pad token.
input_ids = torch.tensor([[12, 11, 0, 0], 
                          [9, 1, 5, 0]])
attention_mask = gen_attention_mask(input_ids, 0)
# Pad positions must be 0 in the mask, every other position 1.
test_eq(attention_mask, torch.tensor([[1, 1, 0, 0],
                                      [1, 1, 1, 0]]))

## BertEncoder

In [None]:
# export
class BertEncoder(nn.Module):
    ''' Encoder that wraps a pretrained huggingface bert model and returns the first element of its output. '''
    def __init__(self, model_name):
        ''' model_name: name of the pretrained bert model to load from huggingface '''
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        bert = self.bert
        # Ordered groups: the embeddings, then every encoder layer, then the pooler.
        self.layer_groups = [bert.embeddings] + list(bert.encoder.layer) + [bert.pooler]
    def forward(self, src_input_ids, src_attention_mask):
        '''
        src_input_ids: (bs, enc_seq_len)
        src_attention_mask: (bs, enc_seq_len)
        returns: (bs, enc_seq_len, embed_size)
        '''
        bert_outputs = self.bert(src_input_ids, attention_mask=src_attention_mask)
        # Only the first output (one vector per source position) is used.
        return bert_outputs[0]

In [None]:
enc_seq_len = 10
src_strs = ['測試', '你好嗎', '早安']
# Encode every string to exactly `enc_seq_len` ids (shorter sequences are padded).
# NOTE(review): `pad_to_max_length` is deprecated in newer transformers releases in favour of `padding='max_length'` — confirm against the pinned version before changing.
src_input_ids = torch.tensor([enc_tokenizer.encode(src_str, max_length=enc_seq_len, pad_to_max_length=True) for src_str in src_strs])
src_attention_mask = gen_attention_mask(src_input_ids, enc_tokenizer.pad_token_id)
# Display both tensors as the cell output.
src_input_ids, src_attention_mask

(tensor([[ 101, 3947, 6275,  102,    0,    0,    0,    0,    0,    0],
         [ 101,  872, 1962, 1621,  102,    0,    0,    0,    0,    0],
         [ 101, 3193, 2128,  102,    0,    0,    0,    0,    0,    0]]),
 tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]))

In [None]:
encoder = BertEncoder(enc_model_name)
# The encoder must return one hidden-size vector per source position.
test_eq(encoder(src_input_ids, src_attention_mask).shape, (3, enc_seq_len, encoder.bert.config.hidden_size))

## GPT2Decoder

In [None]:
# export
def _adujsted_gpt2wte(gpt2):
    ''' Adjust pretrained gpt2 wte layer to adapt the GPT2DecoderTokenizer.
    Add bos_token and pad_token at the last of gpt2.wte.
    Use GPT2DecoderTokenizer or make sure the pad token is at the last of your tokenizer and the bos token is at the second-last.
    '''
    old_wte = gpt2.wte
    old_weight = old_wte.weight
    num_embeddings = old_wte.num_embeddings+2
    embedding_dim = old_wte.embedding_dim
    
    bos_weight = old_weight.mean(dim=0)[None] # (1, embedding_dim)
    pad_weight = torch.zeros((1, embedding_dim))
    
    new_weight = torch.cat([old_weight, bos_weight, pad_weight], dim=0) # (num_embeddings, embedding_dim)
    new_wte = nn.Embedding(num_embeddings, embedding_dim, padding_idx=num_embeddings-1)
    new_wte.weight.data = new_weight
    
    return new_wte

In [None]:
gpt2 = AutoModel.from_pretrained(dec_model_name)
old_wte = gpt2.wte
new_wte = _adujsted_gpt2wte(gpt2)

# The original rows are kept unchanged, followed by the two appended rows.
test_eq(old_wte.weight, new_wte.weight[:-2])
test_eq(new_wte.weight[-1], torch.zeros((old_wte.embedding_dim))) # last row (pad): all zeros
test_eq(new_wte.weight[-2], old_wte.weight.mean(dim=0)) # second-last row (bos): mean of the original rows

In [None]:
# export
class GPT2Decoder(nn.Module):
    ''' Decoder: pretrained gpt2 (with an extended wte), followed by cross attention over the encoder memory and a linear classifier. '''
    def __init__(
        self, 
        model_name, pad_id, # for GPT2
        vocab_size, # for classifier
        num_heads=1, drop_p=0, num_layers=1, # for CrossAttention
    ):
        '''
            model_name: pretrained gpt2 model name from huggingface
            pad_id: id of the pad token in the target ids
            vocab_size: output size of the classifier
            num_heads, drop_p, num_layers: forwarded to CrossAttention
        '''
        super().__init__()
        self.pad_id = pad_id

        # Load gpt2 and swap in the embedding that has the extra bos / pad rows.
        gpt2 = AutoModel.from_pretrained(model_name)
        gpt2.wte = _adujsted_gpt2wte(gpt2)
        self.gpt2 = gpt2

        embed_size = gpt2.config.n_embd
        self.cross_attn = CrossAttention(embed_size, num_heads, drop_p, num_layers)
        self.classifier = nn.Linear(embed_size, vocab_size)

        # Ordered groups: gpt2 embeddings, gpt2 blocks, final layer norm, cross attention layers, classifier.
        self.layer_groups = (
            [gpt2.wte, gpt2.wpe]
            + list(gpt2.h)
            + [gpt2.ln_f]
            + list(self.cross_attn.cross_attn_layers)
            + [self.classifier]
        )
    def forward(self, tgt_input_ids, memory, memory_key_padding_mask):
        '''
            tgt_input_ids: (bs, dec_seq_len)
            memory: (bs, enc_seq_len, embed_size)
            memory_key_padding_mask: (bs, enc_seq_len)
            returns: output, attn_weight
                output: (bs, dec_seq_len, dec_vocab_size)
                attn_weight: (bs, dec_seq_len, enc_seq_len)
        '''
        # Mask out pad positions of the target for gpt2.
        tgt_mask = gen_attention_mask(tgt_input_ids, self.pad_id) # (bs, dec_seq_len)
        gpt2_outputs = self.gpt2(tgt_input_ids, attention_mask=tgt_mask)
        hidden = gpt2_outputs[0] # (bs, dec_seq_len, embed_size)

        # Attend from the decoder states to the encoder memory.
        attended, attn_weight = self.cross_attn(hidden, memory, src_key_padding_mask=memory_key_padding_mask)

        logits = self.classifier(attended) # (bs, dec_seq_len, dec_vocab_size)
        return logits, attn_weight

In [None]:
# Build a decoder from the pretrained checkpoint; the classifier is sized to the decoder tokenizer's vocabulary.
decoder = GPT2Decoder(
    dec_model_name, dec_tokenizer.pad_token_id,
    vocab_size=len(dec_tokenizer),
    num_heads=2, drop_p=0, num_layers=2,
)

In [None]:
# test embedding: the decoder's wte has two more rows than the pretrained wte and the same embedding dim
test_eq((decoder.gpt2.wte.num_embeddings, decoder.gpt2.wte.embedding_dim), (gpt2.wte.num_embeddings+2, gpt2.wte.embedding_dim))

In [None]:
# test forward shape
# Sizes are derived from the tokenizer / model instead of hard-coded, so the test follows the chosen checkpoint.
tgt_input_ids = torch.randint(0, len(dec_tokenizer), (3, 40)) # (bs, dec_seq_len)
memory = torch.randn((3, 50, decoder.gpt2.config.n_embd)) # (bs, enc_seq_len, embed_size)
memory_key_padding_mask = torch.zeros((3, 50)).bool() # no memory position is treated as padding

output, attn_weight = decoder(tgt_input_ids, memory, memory_key_padding_mask)
test_eq(output.shape, (3, 40, len(dec_tokenizer)))
test_eq(attn_weight.shape, (3, 40, 50))

## Bert2GPT2

In [None]:
# export
class Bert2GPT2(nn.Module):
    ''' Seq2seq model: encodes the source ids with `encoder`, then lets `decoder` attend to the resulting memory. '''
    def __init__(
        self, 
        encoder: BertEncoder, decoder: GPT2Decoder,
        enc_pad_id, # for src_key_padding_mask and memory_key_padding_mask
    ):
        super().__init__()
        self.enc_pad_id = enc_pad_id
        self.encoder, self.decoder = encoder, decoder
    def forward(self, src_input_ids, tgt_input_ids):
        '''
            src_input_ids: (bs, enc_seq_len)
            tgt_input_ids: (bs, dec_seq_len)
            returns: first output of the decoder, (bs, dec_seq_len, dec_vocab_size)
        '''
        enc_attention_mask = gen_attention_mask(src_input_ids, self.enc_pad_id) # (bs, enc_seq_len)
        # Inverse of the attention mask: True at the source pad positions.
        enc_pad_positions = (1 - enc_attention_mask).bool()

        memory = self.encoder(src_input_ids, enc_attention_mask) # (bs, enc_seq_len, embed_size)

        # The attention weights returned by the decoder are not needed here.
        output, _attn_weight = self.decoder(tgt_input_ids, memory, memory_key_padding_mask=enc_pad_positions)
        return output

In [None]:
bert2gpt2 = Bert2GPT2(encoder, decoder, enc_tokenizer.pad_token_id)
# Random ids drawn from each tokenizer's id range; this reassigns the names used by earlier cells.
src_input_ids = torch.randint(0, len(enc_tokenizer), (3, 50)) # (bs, enc_seq_len)
tgt_input_ids = torch.randint(0, len(dec_tokenizer), (3, 40)) # (bs, dec_seq_len)

output = bert2gpt2(src_input_ids, tgt_input_ids)
# One score per decoder vocabulary entry for every target position.
test_eq(output.shape, (3, 40, len(dec_tokenizer)))

## GeneratedBert2GPT2

In [None]:
# export
class GeneratedBert2GPT2():
    ''' Bundles a Bert2GPT2 model with its two tokenizers and a GeneratedLM built around the model's decoder. '''
    def __init__(
        self, 
        seq2seq: Bert2GPT2, 
        enc_tokenizer: PreTrainedTokenizer,
        dec_tokenizer: PreTrainedTokenizer,
    ):
        self.seq2seq, self.enc_tokenizer, self.dec_tokenizer = seq2seq, enc_tokenizer, dec_tokenizer

        dec_vocab_size = len(dec_tokenizer)
        eos_ids = [dec_tokenizer.eos_token_id]
        # Generation runs on the decoder only.
        # NOTE(review): the positional arguments are presumably (model, vocab size, pad id, eos ids) — confirm against GeneratedLM.
        self.generatedLM = GeneratedLM(
            seq2seq.decoder, dec_vocab_size, dec_tokenizer.pad_token_id, eos_ids,
            support_past=False,
        )

### generate_from_ids

In [None]:
# export
@patch
@torch.no_grad()
def generate_from_ids(self: GeneratedBert2GPT2, src_input_ids, generate_args: GenerateArgs):
    '''
        Generate target ids for a batch of source ids. Puts `self.seq2seq` in eval mode.
        src_input_ids: (bs, enc_seq_len)
        generate_args: generation settings, passed on to `self.generatedLM.generate`
        returns: (bs, max_length)
    '''
    self.seq2seq.eval()

    # Every target sequence starts with a single bos token.
    bs = src_input_ids.shape[0]
    bos_id = self.dec_tokenizer.bos_token_id
    tgt_input_ids = torch.full((bs, 1), bos_id, dtype=torch.long, device=src_input_ids.device) # (bs, 1)

    # Encode the source once; the decoder reuses this memory during generation.
    enc_attention_mask = gen_attention_mask(src_input_ids, self.enc_tokenizer.pad_token_id) # (bs, enc_seq_len)
    memory = self.seq2seq.encoder(src_input_ids, enc_attention_mask) # (bs, enc_seq_len, embed_size)
    enc_pad_positions = (1 - enc_attention_mask).bool()

    # NOTE(review): presumably adapts the batch dimension of memory and mask to the number of beams — confirm against GeneratedLM.
    beam_args = self.generatedLM.build_model_otherargs_for_beam(
        [memory, enc_pad_positions], generate_args.num_beams
    )

    # Memory goes to the decoder positionally, the padding mask by keyword.
    return self.generatedLM.generate(
        tgt_input_ids, generate_args,
        [beam_args[0]],
        dict(memory_key_padding_mask=beam_args[1]),
    )

In [None]:
generated_bert2gpt2 = GeneratedBert2GPT2(bert2gpt2, enc_tokenizer, dec_tokenizer)

generate_args = GenerateArgs(max_length=10, num_beams=2)
src_input_ids = torch.randint(0, len(enc_tokenizer), (3, 50)) # (bs, enc_seq_len)
result = generated_bert2gpt2.generate_from_ids(src_input_ids, generate_args)
# One generated sequence of `max_length` ids per source sequence.
test_eq(result.shape, (3, 10))

## Export -

In [None]:
# hide
# Run nbdev's notebook-to-script conversion; the cell output lists the converted notebooks.
from nbdev.export import notebook2script
notebook2script()

Converted 02_data.news_commentary.ipynb.
Converted 02_data.tatoeba.ipynb.
Converted 03a_models.patch.ipynb.
Converted 03c_models.bert2gpt2.ipynb.
Converted 03c_models.gru2gru.ipynb.
Converted 03c_models.qrnn2qrnn.ipynb.
Converted 03c_models.tran2tran.ipynb.
Converted 04_metrics.ipynb.
Converted 99_fulltest_bert2gpt2.ipynb.
Converted 99_fulltest_gru2gru.ipynb.
Converted 99_fulltest_qrnn2qrnn.ipynb.
Converted 99_fulltest_tran2tran.ipynb.
Converted index.ipynb.
