In [None]:
# export
import json
from fastai2.basics import *
from fastai2.text.all import *
from fastai2.vision.gan import *

In [None]:
# default_exp model

# Model
> Text encoder, image decoder, and the generator module that chains them.

## Encoder

In [None]:
# export
class Encoder(nn.Module):
    """Text encoder: runs an `AWD_LSTM` over token ids, mean-pools over time and projects to `out_size`.

    Args:
        vocab_size: vocabulary size passed to `AWD_LSTM`.
        out_size: width of the summary vector returned by `forward`.
    """
    def __init__(self, vocab_size: int, out_size: int):
        super().__init__()
        self.awd_lstm = AWD_LSTM(
            vocab_size,
            emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, bidir=False, hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2,
        )
        # Linear map from the pooled AWD-LSTM output (width `emb_sz`) to the requested summary width.
        self.hid_proj = nn.Linear(self.awd_lstm.emb_sz, out_size)

    def forward(self, inp_ids):
        "inp_ids: (bs, seq_len), summary: (bs, out_size)"
        # Reset the AWD-LSTM before every batch (presumably clears its recurrent state so
        # batches are encoded independently — confirm against the fastai2 implementation).
        self.reset()
        awd_lstm_out = self.awd_lstm(inp_ids) # (bs, seq_len, emb_sz=400)

        # Pool by averaging over the sequence axis.
        # NOTE(review): every position contributes to the mean, including any padding
        # positions (pad_token=1) — confirm that this is intended.
        # Alternative that is deliberately not used: the last layer's final hidden state,
        # `self.awd_lstm.hidden[2][0].squeeze(dim=0)`, which is also (bs, 400).
        summary = awd_lstm_out.mean(dim=1) # (bs, emb_sz)
        return self.hid_proj(summary) # (bs, out_size)

    def reset(self):
        "Delegate to `self.awd_lstm.reset()`."
        self.awd_lstm.reset()

    @classmethod
    def from_pretrained(cls, awd_lstm_path: Path, vocab_path: Path,
                        out_size):
        """Build an encoder sized to a saved vocab and load AWD-LSTM weights into it.

        Args:
            awd_lstm_path: file containing a state dict for the `awd_lstm` sub-module.
            vocab_path: JSON file; the length of its top-level value is used as `vocab_size`.
            out_size: width of the summary vector returned by `forward`.

        Returns:
            A new instance whose `awd_lstm` weights come from `awd_lstm_path`;
            `hid_proj` keeps its fresh initialisation.
        """
        vocab = json.loads(vocab_path.read_text())
        ret = cls(len(vocab), out_size)
        # map_location='cpu' lets a checkpoint written on a GPU be read on a CPU-only host.
        # load_state_dict copies values into the existing parameters, so the device of `ret`
        # is not affected by where the tensors were deserialised.
        ret.awd_lstm.load_state_dict(torch.load(awd_lstm_path, map_location='cpu'))
        return ret

In [None]:
# Smoke test: a randomly initialised Encoder maps a (16, 20) batch of token ids to a (16, 100) summary.
encoder = Encoder(vocab_size=8472, out_size=100)
inp_ids = torch.randint(0, 100, (16, 20))  # random ids in [0, 100)
summary = encoder(inp_ids)
test_eq(summary.shape, (16, 100))

In [None]:
# skip
# Same shape check, but with AWD-LSTM weights and vocab read from ./coco_small
# (requires awd_lstm-1.pt and vocab.json to exist there). Rebinds the global `encoder`.
encoder = Encoder.from_pretrained(Path('./coco_small/awd_lstm-1.pt'), Path('./coco_small/vocab.json'), 100)
inp_ids = torch.randint(0, 100, (16, 20))
summary = encoder(inp_ids)
test_eq(summary.shape, (16, 100))

## Decoder

In [None]:
# export
class Decoder(nn.Module):
    """Image decoder: turns a (bs, inp_size) summary into a (bs, 3, out_size, out_size) tensor in [-1, 1].

    The summary is reshaped to a 1x1 map, upsampled to a seed grid of
    `out_size // 2**num_layers` pixels per side, and then doubled `num_layers` times by
    stride-2 transposed convolutions. None of the learnable layers depend on `out_size`.

    Args:
        out_size: side length of the generated image; should be a multiple of `2**num_layers`.
        inp_size: width of the input summary vector.
        num_layers: number of stride-2 transposed-conv stages.

    Raises:
        ValueError: if `out_size` is smaller than `2**num_layers` (the seed grid would be empty).
    """
    def __init__(self, out_size: int, inp_size: int, num_layers=3):
        super().__init__()
        # Plain assignments rather than `store_attr(self, '...')`: same effect, but without
        # relying on caller-frame inspection or on a particular `store_attr` signature.
        self.out_size = out_size
        self.inp_size = inp_size
        self.num_layers = num_layers
        layers = self.gen_layers()
        self.layers = nn.Sequential(*layers)

    def _upsample(self, out_size: int):
        "Build the Upsample layer that produces the seed grid for images of side `out_size`."
        scale = 2**self.num_layers
        seed = out_size // scale
        if seed < 1:
            # Upsample(0) would only fail later, inside forward, with an obscure error.
            raise ValueError(
                f"out_size={out_size} is too small for num_layers={self.num_layers}; "
                f"it must be at least {scale}"
            )
        if out_size % scale:
            import warnings
            warnings.warn(
                f"out_size={out_size} is not a multiple of {scale}; "
                f"generated images will be {seed * scale}x{seed * scale}"
            )
        return nn.Upsample(seed)

    def gen_layers(self):
        "Return the ordered list of layers; the Upsample layer is at index 1."
        up_sample = self._upsample(self.out_size)
        num_c = 64*2**self.num_layers  # channels halve at every stage and end at 64
        first_conv = ConvLayer(self.inp_size, num_c, 3, 1)
        # kernel 4 / stride 2 / padding 1 transposed convs: each one doubles the spatial size.
        other_conv = [ConvLayer(num_c//2**i, num_c//2**(i+1), 4, 2, 1, transpose=True) for i in range(self.num_layers)]
        last_conv = nn.Conv2d(64, 3, 3, 1, 1, bias=False)
        last_act = nn.Tanh()
        return [AddChannels(2), up_sample, first_conv] + other_conv + [last_conv, last_act]

    def change_out_size(self, out_size: int):
        "Switch to a new output resolution in place by swapping the Upsample layer; weights are kept."
        up_sample = self._upsample(out_size)  # validate first so a bad size leaves the module untouched
        self.out_size = out_size
        self.layers[1] = up_sample  # index 1 is the Upsample slot from gen_layers

    def forward(self, enc_summary):
        "enc_summary: (bs, inp_size), out: (bs, 3, out_size, out_size)"
        return self.layers(enc_summary)

In [None]:
# Smoke test: a Decoder built for 64x64 output turns a (16, 100) summary into (16, 3, 64, 64).
decoder = Decoder(64, 100)
enc_summary = torch.randn(16, 100)
out = decoder(enc_summary)
test_eq(out.shape, (16, 3, 64, 64))
# Print the per-layer output shapes and parameter counts.
decoder.summary(enc_summary)

Decoder (Input shape: ['16 x 100'])
Layer (type)         Output Shape         Param #    Trainable 
AddChannels          16 x 100 x 1 x 1     0          False     
________________________________________________________________
Upsample             16 x 100 x 8 x 8     0          False     
________________________________________________________________
Conv2d               16 x 512 x 8 x 8     460,800    True      
________________________________________________________________
BatchNorm2d          16 x 512 x 8 x 8     1,024      True      
________________________________________________________________
ReLU                 16 x 512 x 8 x 8     0          False     
________________________________________________________________
ConvTranspose2d      16 x 256 x 16 x 16   2,097,152  True      
________________________________________________________________
BatchNorm2d          16 x 256 x 16 x 16   512        True      
______________________________________________________________

In [None]:
# A Decoder built for a different resolution (96) accepts the state dict of the 64x64 `decoder`
# from the previous cell, and produces 96x96 output.
decoder2 = Decoder(96, 100)
decoder2.load_state_dict(decoder.state_dict())
enc_summary = torch.randn(16, 100)
out = decoder2(enc_summary)
test_eq(out.shape, (16, 3, 96, 96))
decoder2.summary(enc_summary)

Decoder (Input shape: ['16 x 100'])
Layer (type)         Output Shape         Param #    Trainable 
AddChannels          16 x 100 x 1 x 1     0          False     
________________________________________________________________
Upsample             16 x 100 x 12 x 12   0          False     
________________________________________________________________
Conv2d               16 x 512 x 12 x 12   460,800    True      
________________________________________________________________
BatchNorm2d          16 x 512 x 12 x 12   1,024      True      
________________________________________________________________
ReLU                 16 x 512 x 12 x 12   0          False     
________________________________________________________________
ConvTranspose2d      16 x 256 x 24 x 24   2,097,152  True      
________________________________________________________________
BatchNorm2d          16 x 256 x 24 x 24   512        True      
______________________________________________________________

In [None]:
# Resize the existing `decoder2` in place to 128x128 and check the new output shape.
decoder2.change_out_size(128)
enc_summary = torch.randn(16, 100)
out = decoder2(enc_summary)
test_eq(out.shape, (16, 3, 128, 128))
decoder2.summary(enc_summary)

Decoder (Input shape: ['16 x 100'])
Layer (type)         Output Shape         Param #    Trainable 
AddChannels          16 x 100 x 1 x 1     0          False     
________________________________________________________________
Upsample             16 x 100 x 16 x 16   0          False     
________________________________________________________________
Conv2d               16 x 512 x 16 x 16   460,800    True      
________________________________________________________________
BatchNorm2d          16 x 512 x 16 x 16   1,024      True      
________________________________________________________________
ReLU                 16 x 512 x 16 x 16   0          False     
________________________________________________________________
ConvTranspose2d      16 x 256 x 32 x 32   2,097,152  True      
________________________________________________________________
BatchNorm2d          16 x 256 x 32 x 32   512        True      
______________________________________________________________

## Generator

In [None]:
# export
class MGenerator(nn.Module):
    """Generator that chains an `Encoder` and a `Decoder`: token ids in, decoder output out.

    Args:
        encoder: module mapping token ids to a summary vector.
        decoder: module mapping that summary vector to the generated output.
    """
    def __init__(self, encoder: Encoder, decoder: Decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, inp_ids):
        "Encode `inp_ids` with `self.encoder`, then decode the resulting summary with `self.decoder`."
        summary = self.encoder(inp_ids)
        return self.decoder(summary)

    @classmethod
    def from_pretrained(cls, model_path: Path, vocab_path: Path,
                       enc_out_size, dec_out_size, num_dec_layers):
        """Rebuild a generator with the given sizes and load its full state dict from disk.

        Args:
            model_path: file containing a state dict for the whole generator.
            vocab_path: JSON file; the length of its top-level value is used as the vocab size.
            enc_out_size: encoder summary width, also used as the decoder input width.
            dec_out_size: output size passed to `Decoder`.
            num_dec_layers: number of decoder layers passed to `Decoder`.

        Returns:
            A new `MGenerator` with weights loaded from `model_path`.
        """
        vocab = json.loads(vocab_path.read_text())
        encoder = Encoder(len(vocab), enc_out_size)
        decoder = Decoder(dec_out_size, enc_out_size, num_dec_layers)
        ret = cls(encoder, decoder)
        # map_location='cpu' lets a checkpoint written on a GPU be read on a CPU-only host.
        # load_state_dict copies values into the existing parameters, so the device of `ret`
        # is unchanged.
        ret.load_state_dict(torch.load(model_path, map_location='cpu'))
        return ret

In [None]:
# End-to-end shape check. Reuses the `encoder` and `decoder` objects created in earlier cells;
# `decoder` was built with out_size=64, hence the 64x64 expectation.
generator = MGenerator(encoder, decoder)
inp_ids = torch.randint(0, 100, (16, 20))
out = generator(inp_ids)
test_eq(out.shape, (16, 3, 64, 64))

## Export -

In [None]:
# hide
# Run nbdev's notebook-to-script conversion; the output below lists the notebooks it converted.
from nbdev.export import notebook2script
notebook2script()

Converted 01_eda[script].ipynb.
Converted 01_gen_coco_tiny_data[script].ipynb.
Converted 02_data_coco.ipynb.
Converted 03_model.ipynb.
Converted 04_loss.ipynb.
Converted 05_leaner.ipynb.
Converted 90a_fulltest_train_lm.ipynb.
Converted 90b_fulltest_train_generator.ipynb.
Converted 95a_train_lm[script].ipynb.
Converted 95b_train_generator[script].ipynb.
Converted index.ipynb.
