In [1]:
run visdial/encoders/text_encoder.py

In [2]:
# History encoder: a 2-layer LSTM (300-d inputs, 512-d hidden state) wrapped in
# DynamicRNN and handed to HistEncoder.
hist_encoder = HistEncoder(
    DynamicRNN(torch.nn.LSTM(input_size=300, hidden_size=512, num_layers=2)),
    hidden_size=512,
)
hist_encoder  # last expression: display the module tree

HistEncoder(
  (encoder): DynamicRNN(
    (rnn_module): LSTM(300, 512, num_layers=2)
  )
)

In [3]:
# Random stand-in for an embedded dialog history. The five axes are presumably
# (batch, rounds, history turns, tokens, embedding dim) -- TODO confirm.
hist = torch.randn((16, 10, 10, 20, 300))
# One length per history sequence, drawn from 1..19 (the upper bound is exclusive).
hist_len = torch.randint(low=1, high=20, size=(16, 10, 10))

In [4]:
# Run the history encoder on the random batch. Per the shapes printed by the
# next cell, the result iterates as two tensors: (160, 10, 512) and (160, 10).
out = hist_encoder(hist, hist_len)

In [5]:
# Print the shape of every tensor returned by the history encoder, one per line.
for tensor_shape in (tensor.shape for tensor in out):
    print(tensor_shape)

torch.Size([160, 10, 512])
torch.Size([160, 10])


In [6]:
# Question encoder: same recurrent backbone as the history encoder -- a 2-layer
# LSTM (300-d inputs, 512-d hidden state) wrapped in DynamicRNN.
ques_encoder = QuesEncoder(
    DynamicRNN(torch.nn.LSTM(input_size=300, hidden_size=512, num_layers=2)),
    hidden_size=512,
)
ques_encoder  # last expression: display the module tree

QuesEncoder(
  (encoder): DynamicRNN(
    (rnn_module): LSTM(300, 512, num_layers=2)
  )
)

In [7]:
# Random stand-in for embedded questions. The four axes are presumably
# (batch, rounds, tokens, embedding dim) -- TODO confirm.
ques = torch.randn((16, 10, 20, 300))
# One length per question, drawn from 1..19 (the upper bound is exclusive).
ques_len = torch.randint(low=1, high=20, size=(16, 10))

In [8]:
# Run the question encoder on the random batch. Per the shapes printed by the
# next cell, the result iterates as two tensors: (160, 20, 512) and (160, 20).
ques_out = ques_encoder(ques, ques_len)

In [9]:
# Print the shape of every tensor returned by the question encoder, one per line.
for tensor_shape in (tensor.shape for tensor in ques_out):
    print(tensor_shape)

torch.Size([160, 20, 512])
torch.Size([160, 20])


In [10]:
# Embedding layer for an 11322-token vocabulary with 300-d vectors; the
# position flag is on and the hidden-layer flag is off.
text_embed = TextEmbeddings(
    vocab_size=11322,
    embedding_size=300,
    hidden_size=512,
    has_position=True,
    has_hidden_layer=False,
)

In [12]:
# Random token ids in [0, 11322) with shape (16, 10, 25).
tokens = torch.randint(low=0, high=11322, size=(16, 10, 25))

In [13]:
# Embed the ids and show the resulting shape (a trailing 300-d axis is added).
text_embed(tokens).shape

torch.Size([16, 10, 25, 300])

In [14]:
# Synthetic mini-batch keyed the way the text encoder is fed below.
# Token ids lie in [0, 11322); lengths lie in 1..24 (upper bound exclusive).
batch = {
    # questions: ids (16, 10, 25) and one length per question (16, 10)
    'ques_tokens': torch.randint(low=0, high=11322, size=(16, 10, 25)),
    'ques_len': torch.randint(low=1, high=25, size=(16, 10)),
    # history: ids (16, 10, 10, 25) and one length per history entry (16, 10, 10)
    'hist_tokens': torch.randint(low=0, high=11322, size=(16, 10, 10, 25)),
    'hist_len': torch.randint(low=1, high=25, size=(16, 10, 10)),
    # image features: (16, 36, 2048) floats
    'img_feat': torch.randn((16, 36, 2048)),
}

In [15]:
# Combine the embedding layer with the history and question encoders built above.
text_encoder = TextEncoder(text_embed, hist_encoder, ques_encoder)

In [16]:
# Display the assembled module tree.
text_encoder

TextEncoder(
  (text_embeddings): TextEmbeddings(
    (tok_embedding): Embedding(11322, 300, padding_idx=0)
    (pos_embedding): PositionalEmbedding()
  )
  (hist_encoder): HistEncoder(
    (encoder): DynamicRNN(
      (rnn_module): LSTM(300, 512, num_layers=2)
    )
  )
  (ques_encoder): QuesEncoder(
    (encoder): DynamicRNN(
      (rnn_module): LSTM(300, 512, num_layers=2)
    )
  )
)

In [17]:
# Run the full text encoder on the synthetic batch. Note this rebinds `out`,
# which previously held the history-encoder result. Per the shapes printed by
# the next cell, the result iterates as four tensors.
out = text_encoder(batch)

In [18]:
# Print the shape of every tensor returned by the text encoder, one per line.
for tensor_shape in (tensor.shape for tensor in out):
    print(tensor_shape)

torch.Size([160, 10, 512])
torch.Size([160, 25, 512])
torch.Size([160, 10])
torch.Size([160, 25])


# Run model 

In [8]:
run visdial/model.py

In [14]:
run debug.py

HOME_PATH /home/quanguet
DATA_PATH /home/quanguet/datasets/visdial


[nltk_data] Downloading package punkt to /home/quanguet/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
  2%|▏         | 1072/45238 [00:00<00:04, 10713.14it/s]

[val2018] Tokenizing questions...


100%|██████████| 45238/45238 [00:03<00:00, 12647.35it/s]
  4%|▍         | 1310/34822 [00:00<00:02, 13098.98it/s]

[val2018] Tokenizing answers...


100%|██████████| 34822/34822 [00:02<00:00, 13413.98it/s]
100%|██████████| 2064/2064 [00:00<00:00, 10963.35it/s]


[val2018] Tokenizing captions...
img_ids         torch.Size([2])
num_rounds      torch.Size([2])
opts            torch.Size([2, 10, 100, 25])
opts_in         torch.Size([2, 10, 100, 25])
opts_out        torch.Size([2, 10, 100, 25])
opts_len        torch.Size([2, 10, 100])
opts_in_len     torch.Size([2, 10, 100])
opts_out_len    torch.Size([2, 10, 100])
ans             torch.Size([2, 10, 25])
ans_in          torch.Size([2, 10, 25])
ans_out         torch.Size([2, 10, 25])
ans_len         torch.Size([2, 10])
ans_in_len      torch.Size([2, 10])
ans_out_len     torch.Size([2, 10])
ans_ind         torch.Size([2, 10])
gt_relevance    torch.Size([2, 100])
round_id        torch.Size([2])
img_feat        torch.Size([2, 36, 2048])
ques_tokens     torch.Size([2, 10, 25])
hist_tokens     torch.Size([2, 10, 10, 50])
ques_len        torch.Size([2, 10])
hist_len        torch.Size([2, 10, 10])


In [15]:
run configs/lstm_config.py

HOME_PATH /home/quanguet
DATA_PATH /home/quanguet/datasets/visdial


In [16]:
# Load the configuration; later cells index it as config['model'][...].
config = get_config()

In [17]:
# Build the embedding layer from the keyword arguments stored under
# config['model']['text_embeddings'].
# NOTE(review): `text_embeddings` is not referenced by any later cell visible here.
text_embeddings = TextEmbeddings(**config['model']['text_embeddings'])

In [18]:
# Build the model from the configuration (its repr below is a VisdialModel).
model = get_lstm_model(config)

In [19]:
# Display the model's module tree.
model

VisdialModel(
  (encoder): Encoder(
    (text_encoder): TextEncoder(
      (text_embeddings): TextEmbeddings(
        (tok_embedding): Embedding(11322, 300, padding_idx=0)
        (pos_embedding): PositionalEmbedding()
      )
      (hist_encoder): HistEncoder(
        (encoder): DynamicRNN(
          (rnn_module): LSTM(300, 512, num_layers=2)
        )
      )
      (ques_encoder): QuesEncoder(
        (encoder): DynamicRNN(
          (rnn_module): LSTM(300, 512, num_layers=2)
        )
      )
    )
    (img_encoder): ImageEncoder(
      (img_linear): Sequential(
        (0): Linear(in_features=2048, out_features=512, bias=True)
        (1): Dropout(p=0.2)
      )
    )
    (attn_encoder): CrossAttentionEncoder(
      (cross_attn_encoder): Sequential(
        (0): CrossAttentionLayer(
          (attns): ModuleList(
            (0): MultiHeadAttention(
              (dropout): Dropout(p=0.2)
              (x_proj_linear): Linear(in_features=512, out_features=512, bias=False)
         

In [26]:
# Check the shape of the history token ids in the current batch.
batch['hist_tokens'].shape

torch.Size([2, 10, 10, 50])

In [23]:
# Run only the encoder half of the model on the batch.
# NOTE(review): the recorded output of this cell is a RuntimeError complaining
# about an element of 'lengths' that is <= 0; presumably some entries of
# batch['hist_len'] or batch['ques_len'] are zero -- confirm and handle upstream.
model.encoder(batch)

RuntimeError: Length of all samples has to be greater than 0, but found an element in 'lengths' that is <= 0

In [29]:
# NOTE(review): `encoder_out` is not assigned in any cell visible here, so this
# line relies on kernel state left over from a cell that is no longer present.
encoder_out.shape

torch.Size([160, 512])

In [31]:
# Display the encoder output values.
# NOTE(review): `encoder_out` is not assigned in any cell visible here.
encoder_out

tensor([[ 0.0466,  0.0301, -0.0358,  ...,  0.0406, -0.0049,  0.0362],
        [ 0.0468,  0.0303, -0.0372,  ...,  0.0392, -0.0043,  0.0379],
        [ 0.0456,  0.0303, -0.0382,  ...,  0.0417, -0.0047,  0.0363],
        ...,
        [ 0.0406,  0.0287, -0.0372,  ...,  0.0387, -0.0041,  0.0340],
        [ 0.0394,  0.0309, -0.0355,  ...,  0.0385, -0.0015,  0.0363],
        [ 0.0414,  0.0316, -0.0344,  ...,  0.0399, -0.0042,  0.0372]],
       grad_fn=<AddmmBackward>)

In [30]:
# Small integer tensor of shape (1, 4, 1, 3) with values in 1..9, used as a toy
# input for the masking experiment below. Note this rebinds `hist`.
hist = torch.randint(low=1, high=10, size=(1, 4, 1, 3))
print(hist)

tensor([[[[3, 7, 6]],

         [[1, 8, 8]],

         [[3, 6, 7]],

         [[3, 2, 1]]]])


In [34]:
# The original line was the truncated attribute access `hist.rep`, which raises
# AttributeError on a tensor and stops a top-to-bottom run. It is completed here
# to the method that the next cell actually calls.
hist.repeat

In [39]:
# Insert a new axis at position 1 and tile it 4 times, so the (1, 4, 1, 3)
# tensor becomes (1, 4, 4, 1, 3). Note this rebinds `out`.
out = hist[:, None].repeat(1, 4, 1, 1, 1)
print(out.shape)

torch.Size([1, 4, 4, 1, 3])


In [54]:
# Move `out` to the GPU when one is available. On a CPU-only machine the tensor
# stays on the CPU instead of raising, so the notebook still runs top to bottom.
out = out.to('cuda' if torch.cuda.is_available() else 'cpu')

In [55]:
# Lower-triangular 10x10 integer mask on the same device as `out`: row i has
# ones in columns 0..i and zeros afterwards.
mask = torch.tril(torch.ones(10, 10, dtype=torch.long, device=out.device))
# Add a leading axis and two trailing singleton axes -> shape (1, 10, 10, 1, 1).
mask = mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

In [56]:
# Confirm the mask shape after the extra axes were added.
mask.shape

torch.Size([1, 10, 10, 1, 1])

In [57]:
# Display the mask values.
mask

tensor([[[[[1]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]]],


         [[[1]],

          [[1]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]]],


         [[[1]],

          [[1]],

          [[1]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]]],


         [[[1]],

          [[1]],

          [[1]],

          [[1]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]]],


         [[[1]],

          [[1]],

          [[1]],

          [[1]],

          [[1]],

          [[0]],

          [[0]],

          [[0]],

          [[0]],

          [[0]]],


         [[[1]],

          [[1]],

          [[1]],

          [[1]],

          [[1]],



In [49]:
# Zero the entries of `out` wherever the mask is 0. `mask` is built with shape
# (1, 10, 10, 1, 1) while `out` is (1, 4, 4, 1, 3), and those shapes do not
# broadcast, so the mask is first cropped to the sizes of axes 1 and 2 of `out`.
# The crop keeps the top-left corner of the lower-triangular mask and is a
# no-op when the sizes already agree.
out.masked_fill(mask[:, :out.size(1), :out.size(2)] == 0, 0)

tensor([[[[[3, 7, 6]],

          [[0, 0, 0]],

          [[0, 0, 0]],

          [[0, 0, 0]]],


         [[[3, 7, 6]],

          [[1, 8, 8]],

          [[0, 0, 0]],

          [[0, 0, 0]]],


         [[[3, 7, 6]],

          [[1, 8, 8]],

          [[3, 6, 7]],

          [[0, 0, 0]]],


         [[[3, 7, 6]],

          [[1, 8, 8]],

          [[3, 6, 7]],

          [[3, 2, 1]]]]])