In [None]:
# export
from fastai2.basics import *

from transformers import PreTrainedTokenizer, AutoTokenizer

from fastai_transformers_utils.generated_lm import GeneratedLM, GenerateArgs
from fastai_transformers_utils.tokenizers import GPT2DecoderTokenizer

In [None]:
# default_exp models.gru2gru

In [None]:
# Small batch / sequence sizes used only by the shape checks in the test cells below
bs, enc_seq_len, dec_seq_len = 3, 50, 40

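# Models GRU2GRU
> A bidirectional GRU encoder whose final hidden states initialise a GRU decoder, plus a wrapper for generating target token ids.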

In [None]:
enc_vocab_size = 21128
enc_pad_id = 0

dec_vocab_size = 50259
dec_pad_id = 50258

embeded_size = 512
num_encoder_layers = 2
num_decoder_layers = 2
drop_p = 0.1

## Encoder

In [None]:
# export
class GRUEncoder(nn.Module):
    '''
        Token embedding followed by a bidirectional, batch-first GRU.

        vocab_size: number of source tokens
        embeded_size: embedding size, also used as the GRU hidden size of every layer
        pad_id: padding index of the embedding (its row is kept at zero)
        num_layers: number of stacked GRU layers
        drop_p: dropout between GRU layers (nn.GRU applies it to all but the last layer)
    '''
    def __init__(
        self,
        vocab_size, embeded_size, pad_id,
        num_layers=1, drop_p=0
    ):
        super().__init__()
        # Built in this order with these attribute names so parameter initialisation
        # and state_dict keys are unchanged.
        self.embedding = nn.Embedding(vocab_size, embeded_size, padding_idx=pad_id)
        self.encoder = nn.GRU(
            input_size=embeded_size,
            hidden_size=embeded_size,
            num_layers=num_layers,
            dropout=drop_p,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, src_inp_ids):
        '''
            src_inp_ids: (bs, enc_seq_len) token ids
            returns: (output, h)
                output: (bs, enc_seq_len, 2*embeded_size) -- both directions concatenated per step
                h: (2*num_layers, bs, embeded_size) -- final hidden state per layer and direction
        '''
        # nn.GRU already returns the (output, h) pair, so it is passed straight through
        return self.encoder(self.embedding(src_inp_ids))

In [None]:
# Smoke test: features from both directions are concatenated (2*embeded_size) and
# one final hidden state is returned per layer and direction (2*num_encoder_layers).
src_input_ids = torch.randint(0, enc_vocab_size, (bs, enc_seq_len))
encoder = GRUEncoder(
    enc_vocab_size, embeded_size, enc_pad_id,
    num_layers=num_encoder_layers, drop_p=drop_p,
)
output, h = encoder(src_input_ids)
test_eq(output.shape, (bs, enc_seq_len, 2*embeded_size))
test_eq(h.shape, (2*num_encoder_layers, bs, embeded_size))

## Decoder

In [None]:
# export
class GRUDecoder(nn.Module):
    '''
        Token embedding -> unidirectional batch-first GRU -> linear layer producing vocabulary logits.
        The initial hidden state is always supplied by the caller.

        vocab_size: number of target tokens (also the number of output logits)
        embeded_size: embedding size, also used as the GRU hidden size of every layer
        pad_id: padding index of the embedding
        num_layers: number of stacked GRU layers
        drop_p: dropout between GRU layers
    '''
    def __init__(
        self,
        vocab_size, embeded_size, pad_id,
        num_layers=1, drop_p=0,
    ):
        super().__init__()
        # Same construction order and attribute names as before (init order / state_dict keys).
        self.embedding = nn.Embedding(vocab_size, embeded_size, padding_idx=pad_id)
        self.decoder = nn.GRU(
            input_size=embeded_size,
            hidden_size=embeded_size,
            num_layers=num_layers,
            dropout=drop_p,
            batch_first=True,
        )
        self.classifier = nn.Linear(embeded_size, vocab_size)

    def forward(self, tgt_inp_ids, h):
        '''
            tgt_inp_ids: (bs, dec_seq_len) token ids
            h: (num_decoder_layers, bs, embeded_size) initial hidden state
            returns: (logits, h)
                logits: (bs, dec_seq_len, dec_vocab_size)
                h: (num_decoder_layers, bs, embeded_size) final hidden state
        '''
        hidden_seq, h_last = self.decoder(self.embedding(tgt_inp_ids), h)
        logits = self.classifier(hidden_seq)
        return logits, h_last

In [None]:
# Smoke test: logits over the target vocabulary, and a final hidden state shaped like h0.
decoder = GRUDecoder(
    dec_vocab_size, embeded_size, dec_pad_id,
    num_layers=num_decoder_layers, drop_p=drop_p,
)

tgt_input_ids = torch.randint(0, dec_vocab_size, (bs, dec_seq_len))
h0 = torch.randn((num_decoder_layers, bs, embeded_size))

output, h = decoder(tgt_input_ids, h0)
test_eq(output.shape, (bs, dec_seq_len, dec_vocab_size))
test_eq(h.shape, (num_decoder_layers, bs, embeded_size))

## GRU2GRU

In [None]:
# export
class GRU2GRU(nn.Module):
    '''
        Sequence-to-sequence model built from a GRUEncoder and a GRUDecoder.

        The encoder's final hidden states, (2*num_encoder_layers, bs, embeded_size), are mixed
        across their layer/direction axis by a learned linear map `proj` into
        (num_decoder_layers, bs, embeded_size), which becomes the decoder's initial hidden state.

        encoder: GRUEncoder (bidirectional GRU exposed as `encoder.encoder`)
        decoder: GRUDecoder (GRU exposed as `decoder.decoder`)
        num_encoder_layers / num_decoder_layers: when omitted (None) they are read from
            `encoder.encoder.num_layers` / `decoder.decoder.num_layers`. When given they must
            match those modules, otherwise forward fails with a shape mismatch.
    '''
    def __init__(
        self, 
        encoder: GRUEncoder, decoder: GRUDecoder, 
        num_encoder_layers=None, num_decoder_layers=None,
    ):
        super().__init__()
        if num_encoder_layers is None:
            num_encoder_layers = encoder.encoder.num_layers
        if num_decoder_layers is None:
            num_decoder_layers = decoder.decoder.num_layers
        self.encoder = encoder
        self.decoder = decoder
        # 2*: the bidirectional encoder returns one final state per layer *and* direction
        self.proj = nn.Linear(2*num_encoder_layers, num_decoder_layers)

    def _init_decoder_hidden(self, src_input_ids):
        '''
            Encode src_input_ids (bs, enc_seq_len) and project the encoder's final hidden states
            into the decoder's initial hidden state (num_decoder_layers, bs, embeded_size).
        '''
        _, enc_h = self.encoder(src_input_ids) # (2*num_encoder_layers, bs, embeded_size)
        # nn.Linear acts on the last axis, so move the layer/direction axis there and back again
        enc_h_proj = self.proj(enc_h.permute(1, 2, 0)) # (bs, embeded_size, num_decoder_layers)
        # .contiguous(): the permuted view is non-contiguous; cuDNN RNNs reject non-contiguous hidden states
        return enc_h_proj.permute(2, 0, 1).contiguous() # (num_decoder_layers, bs, embeded_size)

    def forward(self, src_input_ids, tgt_input_ids):
        '''
            src_input_ids: (bs, enc_seq_len)
            tgt_input_ids: (bs, dec_seq_len)
            returns: (bs, dec_seq_len, dec_vocab_size) decoder logits
        '''
        dec_h0 = self._init_decoder_hidden(src_input_ids)
        output, _ = self.decoder(tgt_input_ids, dec_h0) # final decoder state is not needed here
        return output

In [None]:
# Smoke test: reuse the encoder/decoder built above and check the logits shape end to end.
gru2gru = GRU2GRU(encoder, decoder, num_encoder_layers, num_decoder_layers)

src_input_ids, tgt_input_ids = (
    torch.randint(0, enc_vocab_size, (bs, enc_seq_len)),
    torch.randint(0, dec_vocab_size, (bs, dec_seq_len)),
)
output = gru2gru(src_input_ids, tgt_input_ids)
test_eq(output.shape, (bs, dec_seq_len, dec_vocab_size))

## GeneratedGRU2GRU

In [None]:
# export
class GeneratedGRU2GRU():
    '''
        Generation wrapper around a GRU2GRU model.

        seq2seq: the GRU2GRU model; its `decoder` is what the GeneratedLM drives step by step
        enc_tokenizer: source-side tokenizer (stored on the instance)
        dec_tokenizer: target-side tokenizer; its length, pad id and eos id configure the GeneratedLM

        Tensors created during generation are placed on the device of the input ids,
        not on a device passed to this class.
    '''
    def __init__(
        self,
        seq2seq: GRU2GRU, 
        enc_tokenizer: PreTrainedTokenizer,
        dec_tokenizer: PreTrainedTokenizer,
    ):
        self.seq2seq = seq2seq
        self.enc_tokenizer = enc_tokenizer
        self.dec_tokenizer = dec_tokenizer
        # NOTE(review): support_past=False presumably because GRUDecoder.forward takes only
        # (ids, hidden state) and no past/cache argument — confirm against GeneratedLM.
        self.generatedLM = GeneratedLM(
            seq2seq.decoder,
            len(dec_tokenizer),
            dec_tokenizer.pad_token_id,
            [dec_tokenizer.eos_token_id],
            support_past=False,
        )

In [None]:
# Source tokenizer from the 'hfl/chinese-bert-wwm-ext' checkpoint, target tokenizer from distilgpt2
# (loaded in this order; both are fetched via from_pretrained).
enc_tokenizer, dec_tokenizer = (
    AutoTokenizer.from_pretrained('hfl/chinese-bert-wwm-ext'),
    GPT2DecoderTokenizer.from_pretrained('distilgpt2'),
)

### generate_from_ids

In [None]:
# export
@patch
@torch.no_grad()
def generate_from_ids(self: GeneratedGRU2GRU, src_input_ids, generate_args: GenerateArgs):
    '''
        Generate target token ids for a batch of source token ids.

        src_input_ids: (bs, enc_seq_len) token ids; new tensors are created on the same device
        generate_args: GenerateArgs passed to `GeneratedLM.generate`; its `num_beams` is also used
            to expand the decoder's initial hidden state
        returns: the result of `self.generatedLM.generate`

        The model is switched to eval mode for generation and restored to its previous
        train/eval mode afterwards, so calling this during training does not leave dropout disabled.
    '''
    seq2seq = self.seq2seq
    was_training = seq2seq.training
    seq2seq.eval()
    try:
        bs = src_input_ids.shape[0]
        # every sequence starts with the decoder BOS token, allocated directly on the input's device
        tgt_input_ids = torch.full(
            (bs, 1), self.dec_tokenizer.bos_token_id,
            dtype=torch.long, device=src_input_ids.device,
        ) # (bs, 1)

        _, enc_h = seq2seq.encoder(src_input_ids) # (2*num_encoder_layers, bs, embeded_size)
        enc_h_proj = seq2seq.proj(enc_h.permute(1, 2, 0)) # (bs, embeded_size, num_decoder_layers)
        # The beam helper is given a batch-first tensor (expanded to bs*num_beams along dim 0), so the
        # permute to nn.GRU's (layers, batch, hidden) layout happens only after expansion.
        model_otherargs = self.generatedLM.build_model_otherargs_for_beam([enc_h_proj], generate_args.num_beams) # (bs*num_beams, embeded_size, num_decoder_layers)
        dec_h0 = model_otherargs[0].permute(2, 0, 1).contiguous() # (num_decoder_layers, bs*num_beams, embeded_size)

        return self.generatedLM.generate(tgt_input_ids, generate_args, [dec_h0])
    finally:
        # restore whatever mode the caller had the model in
        seq2seq.train(was_training)

In [None]:
# Smoke test: beam search (2 beams) should yield one row of max_length ids per source sequence.
generated_gru2gru = GeneratedGRU2GRU(gru2gru, enc_tokenizer, dec_tokenizer)

max_length = 20
generate_args = GenerateArgs(max_length=max_length, num_beams=2)
src_input_ids = torch.randint(0, enc_vocab_size, (bs, enc_seq_len))
result = generated_gru2gru.generate_from_ids(src_input_ids, generate_args)
test_eq(result.shape, (bs, max_length))

## Export -

In [None]:
# hide
# Regenerate the library modules from every notebook's `# export` cells.
import nbdev.export
nbdev.export.notebook2script()

Converted 02_data.tatoeba.ipynb.
Converted 03a_models.core.ipynb.
Converted 03b_models.tran2tran.ipynb.
Converted 03c_models.gru2gru.ipynb.
Converted index.ipynb.
