# AttnGAN
## Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks

http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf  
https://github.com/taoxugit/AttnGAN

---

## TODO

- run the code in debug mode in IntelliJ
- annotate the tensor shapes in the code of this notebook
- copy and analyze the DAMSM `train` function

---

![](images/AttnGAN-framework.png)

- **z** - a noise vector usually sampled from a standard normal distribution.  
- **F**<sup>*ca*</sup> - represents the Conditioning Augmentation that converts the sentence vector **E** to the conditioning vector.
  - **E** is a global sentence vector, and **e** is the matrix of word vectors.
- **F**<sub>i</sub><sup>*attn*</sup> is the proposed attention model at the i<sup>*th*</sup> stage of the AttnGAN.
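Conditioning Augmentation (introduced in StackGAN) does not pass **E** to the generator directly. It samples the conditioning vector from a Gaussian whose mean and diagonal covariance are computed from **E**, and a KL-divergence term added to the generator loss keeps that Gaussian close to N(0, I). Below is a minimal NumPy sketch of the idea. It assumes a single random linear layer `W` in place of the learned one and uses placeholder sizes `sent_dim` and `c_dim`. The function name is hypothetical, and the AttnGAN implementation differs in its details.

```python
import numpy as np

sent_dim, c_dim = 256, 100  # assumed sizes: E is 256-d in this notebook; c_dim is a placeholder

rng = np.random.default_rng(0)
# Hypothetical stand-in for the learned layer mapping E to (mu, log sigma^2)
W = rng.normal(scale=0.01, size=(2 * c_dim, sent_dim))


def conditioning_augmentation(E, rng):
    """Sample c_hat ~ N(mu(E), diag(sigma(E)^2)) and return it with its KL term."""
    h = W @ E
    mu, logvar = h[:c_dim], h[c_dim:]
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(c_dim)
    c_hat = mu + std * eps  # reparameterization: the randomness lives in eps only
    # KL( N(mu, sigma^2) || N(0, I) ), used as a regularizer during training
    kl = 0.5 * np.sum(mu ** 2 + std ** 2 - 1.0 - logvar)
    return c_hat, kl


E = rng.standard_normal(sent_dim)
c1, kl = conditioning_augmentation(E, rng)
c2, _ = conditioning_augmentation(E, rng)  # same sentence, different sample
```

Because the conditioning vector is sampled, the same caption yields slightly different vectors on each call, which is where the "augmentation" comes from.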


- **F**<sup>*ca*</sup>, **F**<sub>i</sub><sup>*attn*</sup>, **F**<sub>i</sub> , and **G**<sub>i</sub> are neural networks.

- The attention model **F**<sup>*attn*</sup> has two inputs: the word features **e** and the image features **h** from the previous hidden layer.

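The word features are first converted into the common semantic space of the image features by a new perceptron layer: **e**' = **U** **e**, where **U** ∈ ℝ<sup>D̂×D</sup> (D is the word-feature size and D̂ the image-feature size).  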

Then, a word-context vector is computed for each sub-region of the image, using its hidden features **h** as the query. Each column of **h** is the feature vector of one sub-region. For the j<sup>th</sup> sub-region, the word-context vector is a weighted sum of the word vectors, where each word is weighted by its relevance to **h**<sub>j</sub>, so the representation changes from one sub-region to the next.  

Finally, image features and the corresponding word-context features are combined to generate images at the next stage.
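These three steps can be sketched in NumPy, following the formulation of the AttnGAN paper. The perceptron `U` maps each word to e'<sub>i</sub> = **U** **e**<sub>i</sub>. The weights β<sub>j,i</sub> are a softmax over words of the dot products h<sub>j</sub>·e'<sub>i</sub>, and the word-context vector is c<sub>j</sub> = Σ<sub>i</sub> β<sub>j,i</sub> e'<sub>i</sub>. The sizes, the random inputs and the function name `word_context` are illustrative assumptions, and masking of padded words is left out. The last line sketches the combination step as a concatenation along the feature axis.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)


def word_context(h, e, U):
    """
    h: (D_hat, N) image features, one column per sub-region (the queries)
    e: (D, T)     word features, one column per word
    U: (D_hat, D) perceptron projecting words into the image-feature space
    Returns c: (D_hat, N) word-context vectors and beta: (N, T) attention weights.
    """
    e_proj = U @ e                  # e' = U e, shape (D_hat, T)
    scores = h.T @ e_proj           # scores[j, i] = h_j . e'_i, shape (N, T)
    beta = softmax(scores, axis=1)  # for each sub-region j, weights over the T words
    c = e_proj @ beta.T             # c_j = sum_i beta[j, i] * e'_i, shape (D_hat, N)
    return c, beta


rng = np.random.default_rng(0)
D, T = 256, 10        # word-feature size and caption length, as in this notebook
D_hat, N = 32, 8 * 8  # assumed hidden size and number of sub-regions at this stage
h = rng.standard_normal((D_hat, N))
e = rng.standard_normal((D, T))
U = rng.standard_normal((D_hat, D)) * 0.01

c, beta = word_context(h, e, U)
h_next_input = np.concatenate([h, c], axis=0)  # (2 * D_hat, N), input to the next stage
```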

# Deep Attentional Multimodal Similarity Model (DAMSM)

The DAMSM **learns two neural networks** that map sub-regions of the image and words of the sentence to a common semantic space. It then measures the **image-text similarity** at the word level to compute a fine-grained **loss for image generation**.
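In the AttnGAN paper, the word-level matching score attends in the other direction: each word attends over the 289 image regions. The inputs are word features **e** (D×T) and region features **v** (D×289). The products s = **e**<sup>T</sup>**v** are normalized over words and turned into attention over regions with a sharpness factor γ<sub>1</sub>. Each word **e**<sub>i</sub> is then compared, by cosine similarity, with its region-context vector **c**<sub>i</sub>. Finally, the per-word scores are pooled as R(Q, D) = log(Σ<sub>i</sub> exp(γ<sub>2</sub> R(**c**<sub>i</sub>, **e**<sub>i</sub>)))<sup>1/γ<sub>2</sub></sup>. Below is a NumPy sketch for a single image-caption pair. The values of γ<sub>1</sub> and γ<sub>2</sub>, the sizes and the function name are placeholder assumptions.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)


def word_level_similarity(e, v, gamma1=5.0, gamma2=5.0):
    """
    e: (D, T) word features of one caption
    v: (D, R) region features of one image (R = 17 * 17 = 289)
    Returns R(Q, D): how well the image matches the caption, word by word.
    """
    s = e.T @ v                              # (T, R): word-region dot products
    s_bar = softmax(s, axis=0)               # normalize over words, for each region
    alpha = softmax(gamma1 * s_bar, axis=1)  # attention over regions, for each word
    c = v @ alpha.T                          # (D, T): region-context vector c_i of word i
    cos = (c * e).sum(axis=0) / (np.linalg.norm(c, axis=0) * np.linalg.norm(e, axis=0))
    # log-sum-exp pooling: a soft maximum over the per-word cosine similarities
    return np.log(np.exp(gamma2 * cos).sum()) / gamma2


rng = np.random.default_rng(0)
D, T, R = 256, 10, 289
e = rng.standard_normal((D, T))
v = rng.standard_normal((D, R))
score = word_level_similarity(e, v)
```

A larger γ<sub>2</sub> makes the pooled score follow the best-matching word more closely.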

## Text Encoder 

A **bi-directional LSTM** that extracts semantic vectors from the text description.

![](images/text_encoder.png)

**word features**: The feature matrix of words, denoted by **e** ∈ ℝ<sup>D×T</sup> (D is the dimension of a word vector, T the number of words). Its i<sup>*th*</sup> column **e**<sub>i</sub> is the feature vector of the i<sup>*th*</sup> word.  

**sentence feature**: The last hidden states of the bi-directional LSTM are concatenated to form the global sentence vector, denoted by **E**.
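With PyTorch's `nn.LSTM`, the "last hidden states" of the two directions are the forward state after the last word and the backward state after the first word, because the backward pass reads the caption right to left. The sketch below uses assumed sizes matching this notebook (300-d embeddings, 128 hidden units per direction, 10 words) and random inputs in place of real embeddings. It checks that concatenating these two states gives the same 256-d vector as reshaping `h_n`, which is what `RNN_ENCODER` does with `hidden[0]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, ninput, nhidden = 4, 10, 300, 128  # 128 units per direction -> 256-d features

lstm = nn.LSTM(ninput, nhidden, num_layers=1, batch_first=True, bidirectional=True)
emb = torch.randn(batch, seq_len, ninput)  # stand-in for the word embeddings

output, (h_n, c_n) = lstm(emb)
# output: (batch, seq_len, 2 * nhidden) -- per-word features, forward half then backward half
# h_n:    (num_layers * 2, batch, nhidden) -- final state of each direction

words_emb = output.transpose(1, 2)                           # e: (batch, 256, seq_len)
sent_emb = h_n.transpose(0, 1).contiguous().view(batch, -1)  # E: (batch, 256)

forward_last = output[:, -1, :nhidden]  # the forward direction finishes on the last word
backward_last = output[:, 0, nhidden:]  # the backward direction finishes on the first word
```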

In [None]:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from miscc.config import cfg  # config object of the AttnGAN repository


class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken=2932, ninput=300, drop_prob=0.5, nhidden=256, nlayers=1, bidirectional=True):
        '''
        ntoken  -- size of the dictionary computed from the dataset captions; 2932 tokens in FashionGen2
        ninput  -- size of each embedding vector (300 by default)
        nhidden -- total hidden size over both directions: TEXT.EMBEDDING_DIM = 256 for FashionGen2
        nlayers -- number of recurrent layers
        '''
        super(RNN_ENCODER, self).__init__()
        self.n_steps = cfg.TEXT.WORDS_NUM  # 10 in FashionGen2 (maximum number of words per caption)
        # ...
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        # number of features in the hidden state of each direction
        self.nhidden = nhidden // self.num_directions  # 128 = 256 / 2 (1 Bi-LSTM layer, 128 units per direction)

    def define_module(self):
        # ...
        self.encoder = nn.Embedding(self.ntoken, self.ninput)  # nn.Embedding(2932, 300)
        self.drop = nn.Dropout(self.drop_prob)                  # applied to the word embeddings

        if self.rnn_type == 'LSTM':
            self.rnn = nn.LSTM(self.ninput, self.nhidden,
                               self.nlayers, batch_first=True,
                               dropout=self.drop_prob,
                               bidirectional=self.bidirectional)

    def forward(self, captions, cap_lens, hidden, mask=None):
        # captions: LongTensor (batch, n_steps) --> emb: (batch, n_steps, ninput)
        # (bs, 10) --> (bs, 10, 300)
        emb = self.drop(self.encoder(captions))

        cap_lens = cap_lens.data.tolist()
        # cap_lens -> list, sorted in decreasing order (required by pack_padded_sequence by default):
        # [10, 10, 10, 10, 10, 10, 10, 9, 9, 8, 8, 8, 7, 7, 7,
        #   7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5]

        # Returns a PackedSequence object that drops the padding positions
        emb = pack_padded_sequence(emb, cap_lens, batch_first=True)
        # emb.data.shape -> torch.Size([237, 300]) for the batch above (237 = sum(cap_lens));
        # many entries of emb.data are 0 because of the dropout on the embeddings (p = 0.5)

        # hidden = (h_0, c_0), each (num_layers * num_directions, batch, hidden_size):
        #     initial hidden and cell states for each element in the batch
        # output: PackedSequence of the features h_t from the last layer of the RNN
        output, hidden = self.rnn(emb, hidden)

        # PackedSequence --> (batch, seq_len, hidden_size * num_directions), seq_len = max(cap_lens);
        # positions past a caption's length are filled with zeros
        output = pad_packed_sequence(output, batch_first=True)[0]

        # (dropout on the output is disabled: output = self.drop(output))
        # --> (batch, hidden_size * num_directions, seq_len)
        words_emb = output.transpose(1, 2)

        # final hidden states h_n: (num_directions, batch, hidden_size) with nlayers = 1
        # --> (batch, num_directions * hidden_size)
        if self.rnn_type == 'LSTM':
            sentence_emb = hidden[0].transpose(0, 1).contiguous()

        sentence_emb = sentence_emb.view(-1, self.nhidden * self.num_directions)  # (-1, 128*2)

        # words_emb.shape    --> torch.Size([32, 256, 10])
        # sentence_emb.shape --> torch.Size([32, 256])
        return words_emb, sentence_emb

In [None]:
RNN_ENCODER(
  (encoder): Embedding(2932, 300)
  (drop): Dropout(p=0.5)
  (rnn): LSTM(300, 128, batch_first=True, dropout=0.5, bidirectional=True)
)
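Two details of this module are easy to miss. First, `pack_padded_sequence` expects the lengths in decreasing order by default (hence the sorted `cap_lens`; PyTorch also accepts `enforce_sorted=False`). With packing, the LSTM never reads the padding: each caption's final forward state is taken at its own last word. Second, the `dropout=0.5` argument of `nn.LSTM` only acts between stacked layers, so with a single layer it has no effect (PyTorch emits a warning about this). The dropout that actually applies is `self.drop` on the embeddings. The small sketch below illustrates the first point, with toy sizes chosen for illustration rather than the notebook's.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
nhidden = 4
lstm = nn.LSTM(input_size=3, hidden_size=nhidden, batch_first=True, bidirectional=True)

cap_lens = [5, 3, 2]      # caption lengths, sorted in decreasing order
x = torch.randn(3, 5, 3)  # padded batch: (batch, max_len, features)

packed = pack_padded_sequence(x, cap_lens, batch_first=True)
# packed.data keeps only the real words: sum(cap_lens) = 10 rows of 3 features

out_packed, (h_n, _) = lstm(packed)
out, out_lens = pad_packed_sequence(out_packed, batch_first=True)
# out: (3, 5, 2 * nhidden); positions past each caption's length are filled with 0

for b, n in enumerate(cap_lens):
    # forward final state = output at the caption's own last word, not at the padding
    assert torch.allclose(h_n[0, b], out[b, n - 1, :nhidden])
    # backward final state = output at the first word
    assert torch.allclose(h_n[1, b], out[b, 0, nhidden:])
```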

---

## Image Encoder

An **Inception-v3** CNN pretrained on ImageNet (input images of 299×299) that maps images to semantic vectors.

![](images/inceptionv3.png)

The **intermediate layers** of the CNN learn **local features** of different **sub-regions of the image**, while the later layers learn global features of the image.    

We extract the **local feature** matrix **f** ∈ ℝ<sup>768×289</sup> (reshaped from 768×17×17, since 17×17 = 289) from the **“mixed 6e” layer** of Inception-v3.  
Each column of **f** is the **feature vector** of a **sub-region of the image**.  

**f** shape is (768, 289):  
768 is the dimension of the local feature vector, and  
289 is the number of sub-regions in the image. 
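Going from the 768×17×17 activation to **f** is a plain reshape: column j is the feature vector of the sub-region at row j // 17, column j % 17. In the `CNN_ENCODER` code that follows, `emb_features` is a 1×1 convolution from 768 to `nef` channels. It applies the same 768→`nef` matrix to every sub-region, which on the flattened matrix is just **W f**. The NumPy sketch below checks both facts, with a random `W` standing in for the learned weights (an assumption for illustration).

```python
import numpy as np

rng = np.random.default_rng(0)
C, height, width = 768, 17, 17  # "mixed 6e" output size for a 299x299 input
nef = 256                       # common-space size used in this notebook

features = rng.standard_normal((C, height, width))  # stand-in for one image's mixed 6e activation
f = features.reshape(C, height * width)             # (768, 289): one column per sub-region

W = rng.standard_normal((nef, C)) * 0.01  # stand-in for the learned 1x1 convolution weights
v = W @ f                                 # (256, 289): sub-regions in the common space

# A 1x1 convolution multiplies the channel vector of every pixel by the same W
conv_out = np.einsum('dc,chw->dhw', W, features)  # (256, 17, 17)
```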

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torchvision import models

from miscc.config import cfg  # config object of the AttnGAN repository


class CNN_ENCODER(nn.Module):
    def __init__(self, nef):
        '''
        nef <-- TEXT.EMBEDDING_DIM = 256 (does 'nef' stand for 'Number of Embedding Features'?)
        '''
        super(CNN_ENCODER, self).__init__()
        if cfg.TRAIN.FLAG:
            self.nef = nef
        else:
            self.nef = 256  # define a uniform ranker

        model = models.inception_v3()
        url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
        model.load_state_dict(model_zoo.load_url(url))
        # freeze the pretrained Inception-v3: only emb_features and emb_cnn_code are trained
        for param in model.parameters():
            param.requires_grad = False

        self.define_module(model)
        self.init_trainable_weights()

    def define_module(self, model):
        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
        # ...
        self.Mixed_5d = model.Mixed_5d
        self.Mixed_6a = model.Mixed_6a
        self.Mixed_6b = model.Mixed_6b
        # ...
        # conv1x1 is a helper from the AttnGAN repository (1x1 nn.Conv2d without bias)
        self.emb_features = conv1x1(768, self.nef)     # local features: 768 --> nef for each sub-region
        self.emb_cnn_code = nn.Linear(2048, self.nef)  # global feature: 2048 --> nef

    def init_trainable_weights(self):
        initrange = 0.1
        self.emb_features.weight.data.uniform_(-initrange, initrange)
        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        # shape comments below are H x W x C for one image (PyTorch tensors are N x C x H x W)
        features = None
        # --> fixed-size input: batch x 3 x 299 x 299
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)

        # ...

        x = self.Mixed_6e(x)
        # 17 x 17 x 768

        # --- image region (local) features ---
        features = x
        # 17 x 17 x 768

        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048

        # --- global image features ---
        cnn_code = self.emb_cnn_code(x)
        # nef (256)
        if features is not None:
            features = self.emb_features(features)
            # 17 x 17 x nef (17 x 17 x 256)

        return features, cnn_code

In [None]:
CNN_ENCODER(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Mixed_5b): InceptionA(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_1): BasicConv2d(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_2): BasicConv2d(
      (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3): BasicConv2d(
      (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )

...
    
  (Mixed_6e): InceptionC(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_2): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_3): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_2): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_3): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_4): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_5): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
    
...
    
  (Mixed_7c): InceptionE(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(2048, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_1): BasicConv2d(
      (conv): Conv2d(2048, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(2048, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(448, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(448, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(2048, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (emb_features): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (emb_cnn_code): Linear(in_features=2048, out_features=256, bias=True)
)

# Training

In [None]:
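import torch

cfg.TEXT.EMBEDDING_DIM = 256
dataset.n_words = 2932   # for the FashionGen subset; computed by dataset.load_text_data() by parsing all captions

def build_models():
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    # labels[i] = i: in a batch, image i matches caption i (targets of the DAMSM losses)
    labels = torch.arange(batch_size)
    # ...
    return text_encoder, image_encoder, labels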

---

In [None]:
# Excerpt from the DAMSM training loop; dataloader, rnn_model, cnn_model, labels, optimizer,
# prepare_data, words_loss and sent_loss are defined elsewhere in the AttnGAN code.
batch_size = 32
s_total_loss0 = s_total_loss1 = 0.0
w_total_loss0 = w_total_loss1 = 0.0

for step, data in enumerate(dataloader, 0):
    rnn_model.zero_grad()
    cnn_model.zero_grad()

    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

    # imgs -- list of 1 image tensor
    # imgs[0].shape --> torch.Size([32, 3, 299, 299])
    #
    # cap_lens, shape [32], sorted in decreasing order:
    #   tensor([10, 10, 10, 10, 10, 10, 10,  9,  9,  8,  8,  8,  7,  7,  7,  7,
    #            7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5])
    #
    # captions, shape [32, 10], word indices (0 after the end of captions shorter than 10 words):
    #   tensor([[1151,   16,  526,  241, 1240, 1944, 1443,  303,  526, 1147],
    #           [2331,  195, 1624, 1151, 1078, 2859, 1837,   16,  526, 1147],
    #           ...
    #           [2153, 1837, 2538,  526, 1147,    0,    0,    0,    0,    0]])
    #
    # nef -- cfg.TEXT.EMBEDDING_DIM = 256 (for FashionGen)

    # Image encoder. Despite the names, both outputs are image features:
    # words_features: batch_size x nef x 17 x 17 (sub-region features, matched against words)
    # sentence_feature: batch_size x nef (global feature, matched against the sentence)
    words_features, sentence_feature = cnn_model(imgs[-1])
    # words_features.shape   --> torch.Size([32, 256, 17, 17])
    # sentence_feature.shape --> torch.Size([32, 256])

    nef, att_sze = words_features.size(1), words_features.size(2)
    # nef -> 256, att_sze -> 17
    # the 17 x 17 grid is not flattened here (batch_size x nef x 17*17):
    # words_features = words_features.view(batch_size, nef, -1)

    hidden = rnn_model.init_hidden(batch_size)
    # hidden -> (h_0, c_0), a pair of tensors
    # hidden[0].shape -> torch.Size([2, 32, 128]) -- all zeros
    # hidden[1].shape -> torch.Size([2, 32, 128]) -- all zeros

    # Text encoder
    # words_emb: batch_size x nef x seq_len
    # sent_emb: batch_size x nef
    words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

    # word-level DAMSM losses (one per matching direction)
    w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                             cap_lens, class_ids, batch_size)
    w_total_loss0 += w_loss0.item()
    w_total_loss1 += w_loss1.item()
    loss = w_loss0 + w_loss1

    # sentence-level DAMSM losses (one per matching direction)
    s_loss0, s_loss1 = sent_loss(sentence_feature, sent_emb, labels, class_ids, batch_size)
    loss += s_loss0 + s_loss1
    s_total_loss0 += s_loss0.item()
    s_total_loss1 += s_loss1.item()

    loss.backward()

    # clip_grad_norm_ rescales the RNN gradients so their total norm is at most RNN_GRAD_CLIP,
    # which helps prevent exploding gradients in RNNs / LSTMs
    torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP)
    optimizer.step()
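`sent_loss` is not shown in this notebook. Following the AttnGAN paper, the sentence-level DAMSM loss treats the batch as a classification problem. The cosine similarity between every global image feature and every sentence vector is scaled by a smoothing factor γ<sub>3</sub> and fed to a softmax, and the matching pair (image i, caption i, hence `labels[i] = i`) should win in both directions. Below is a NumPy sketch under those assumptions. The function name and the value of γ<sub>3</sub> are placeholders, and the `class_ids` argument is ignored here.

```python
import numpy as np


def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def sentence_loss(img_codes, sent_codes, gamma3=10.0):
    """
    img_codes:  (B, D) global image features (cnn_code / sentence_feature)
    sent_codes: (B, D) sentence vectors E (sent_emb)
    Row i of both matrices is a matching image-caption pair.
    """
    img = img_codes / np.linalg.norm(img_codes, axis=1, keepdims=True)
    txt = sent_codes / np.linalg.norm(sent_codes, axis=1, keepdims=True)
    scores = gamma3 * img @ txt.T    # scores[i, j] = gamma3 * cos(image i, caption j)
    labels = np.arange(len(scores))  # caption i describes image i
    # cross-entropy in both directions: caption given image, then image given caption
    loss0 = -log_softmax(scores, axis=1)[labels, labels].mean()
    loss1 = -log_softmax(scores, axis=0)[labels, labels].mean()
    return loss0, loss1


rng = np.random.default_rng(0)
B, D = 8, 256
img = rng.standard_normal((B, D))
matched = sentence_loss(img, img + 0.1 * rng.standard_normal((B, D)))
unmatched = sentence_loss(img, rng.standard_normal((B, D)))
```

The loss is close to 0 when each caption vector points in the same direction as its own image vector and not toward the others. It is close to log(B) when the pairs are unrelated.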