In [1]:
from models import Decoder
from models import Encoder
from models import Joint_Representaion_Learner

import torch
import pickle
import numpy as np

# Setara Opt Dictionary

In [2]:
# Decoder settings; the whole dict is passed to the class named by config['decoder'].
config = dict(
    pos_attention=False,
    enhance_input=2,
    watch=0,
    num_hidden_layers_decoder=1,
    decoding_type='ARFormer',
    decoder='BertDecoder',
    vocab_size=100,
    dim_hidden=512,
    max_len=30,
    with_category=True,
    num_category=20,
    layer_norm_eps=1e-5,
    hidden_dropout_prob=0.5,
    num_attention_heads=8,
    attention_probs_dropout_prob=0.0,
    with_layernorm=False,
    intermediate_size=2048,
    hidden_act='gelu_new',
)

# Encoder settings; the whole dict is passed to the class named by opt['encoder']
# and to the joint representation learner.
opt = dict(
    encoder='Encoder_HighWay',
    modality='mio',
    dim_m=2048,
    dim_i=1536,
    dim_o=1024,
    dim_hidden=512,
    no_encoder_bn=False,
)

# Sizes of the random dummy tensors created in the cells below.
batch_size, frame_len, num_objs = 8, 6, 5
seq_len, feat_dim = 16, 512

# Encoder Input

In [3]:
# Fixed seed so the random dummy features are the same on every run.
torch.manual_seed(1)

# The last dimension of each tensor is read from opt instead of being repeated as a
# literal, so the dummy inputs always match the widths the encoder is built with.
m = torch.randn(batch_size, frame_len, opt['dim_m'])
i = torch.randn(batch_size, frame_len, opt['dim_i'])
# The third input carries an extra num_objs axis before the feature axis.
o = torch.randn(batch_size, frame_len, num_objs, opt['dim_o'])

# Order matches the characters of opt['modality'] ('mio').
feats = [m, i, o]

# Deklarasi Encoder

In [4]:
# Look up the encoder class by the name stored in opt['encoder'].
encoder_cls = getattr(Encoder, opt['encoder'], None)
if encoder_cls is None:
    # Fail with a readable message instead of "'NoneType' object is not callable".
    raise ValueError("Unknown encoder {!r}: not found in Encoder".format(opt['encoder']))
encoder = encoder_cls(opt)
# Switch to evaluation mode; kept as the last expression so the module is displayed.
encoder.eval()

Encoder_HighWay(
  (Encoder_M): Sequential(
    (0): Linear(in_features=2048, out_features=512, bias=True)
    (1): HighWay(
      (w1): Linear(in_features=512, out_features=512, bias=True)
      (w2): Linear(in_features=512, out_features=512, bias=True)
      (tanh): Tanh()
    )
    (2): Dropout(p=0.5, inplace=False)
  )
  (Encoder_I): Sequential(
    (0): Linear(in_features=1536, out_features=512, bias=True)
    (1): HighWay(
      (w1): Linear(in_features=512, out_features=512, bias=True)
      (w2): Linear(in_features=512, out_features=512, bias=True)
      (tanh): Tanh()
    )
    (2): Dropout(p=0.5, inplace=False)
  )
  (Encoder_O): ORG(
    (dropout): Dropout(p=0.3, inplace=False)
    (adjacency_dropout): Dropout(p=0.5, inplace=False)
    (object_projection): Linear(in_features=1024, out_features=512, bias=True)
    (sigma_r): Linear(in_features=1024, out_features=512, bias=True)
    (psi_r): Linear(in_features=1024, out_features=512, bias=True)
    (w_r): Linear(in_features=10

# Forward Pass Encoder

In [5]:
enc_output, enc_hidden = encoder(feats)  # runs without errors

In [6]:
# The last entry of enc_output is indexed at [0] and [1]; the message labels them as the
# enhanced object features and the projected object features respectively.
object_outputs = enc_output[-1]
print(
    "Shape untuk output enhanced object features {} \n"
    "Shape untuk output proyeksi object features {}".format(
        object_outputs[0].shape,
        object_outputs[1].shape,
    )
)

Shape untuk output enhanced object features torch.Size([8, 6, 5, 512]) 
Shape untuk output proyeksi object features torch.Size([8, 6, 5, 512])


# Joint Representation

In [7]:
# One entry of size opt['dim_hidden'] per character in opt['modality'].
feats_size = [opt['dim_hidden']] * len(opt['modality'])
# Put the module in evaluation mode, like the encoder and decoder in this notebook,
# so every stage of the pipeline runs in the same mode.
# NOTE(review): assumes Joint_Representaion_Learner is a torch.nn.Module, whose
# .eval() returns the module itself - confirm.
join_representation_learner = Joint_Representaion_Learner(feats_size, opt).eval()

In [8]:
# NOTE(review): this rebinds enc_output and enc_hidden to the joint learner's results,
# so running the cell a second time feeds those results back in as inputs.
# Re-run the encoder forward-pass cell before re-running this one.
enc_output, enc_hidden, enc_obj_output = join_representation_learner(enc_output, enc_hidden)

torch.Size([8, 512])

# Deklarasi Decoder

In [3]:
# Look up the decoder class by the name stored in config['decoder'].
decoder_cls = getattr(Decoder, config['decoder'], None)
if decoder_cls is None:
    # Fail with a readable message instead of "'NoneType' object is not callable".
    raise ValueError("Unknown decoder {!r}: not found in Decoder".format(config['decoder']))
decoder = decoder_cls(config)
# Switch to evaluation mode; kept as the last expression so the module is displayed.
decoder.eval()

BertDecoder(
  (embedding): BertEmbeddings(
    (word_embeddings): Embedding(100, 512, padding_idx=0)
    (position_embeddings): Embedding(30, 512)
    (category_embeddings): Embedding(20, 512)
    (LayerNorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (layer): ModuleList(
    (0): BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.5, inplace=False)
        )
      )
      (attend_to_enc_output): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=512, o

# Deklarasi Data yang Dibutuhkan

In [21]:
# Fixed seed so the random dummy data is the same on every run.
torch.manual_seed(1)

# torch.randint excludes its upper bound, so vocab_size itself is passed as `high`;
# the previous `vocab_size - 1` could never produce the last token id.
tgt_tokens = torch.randint(0, config['vocab_size'], (batch_size, seq_len))
# enc_output = [torch.randn(batch_size, frame_len, feat_dim) if i < 2 else torch.randn(batch_size, frame_len, num_objs, feat_dim) for i in range(3)]
# Random stand-ins for the encoder results, keyed the way prepare_inputs_for_decoder reads them.
encoder_outputs = {'enc_output' : torch.randn(batch_size, 2 * frame_len, feat_dim),
                   'enc_hidden' : torch.randn(batch_size, 1, feat_dim)}
# NOTE(review): category has a single element while the batch has batch_size rows;
# presumably the decoder broadcasts it - confirm.
category = torch.LongTensor([2])
decoding_type = config['decoding_type']

# Info Corpus

In [5]:
# Load the pickled info_corpus object; the path is relative to the working directory.
# NOTE(review): pickle.load can execute arbitrary code - only load files from a trusted source.
with open('MSRVTT/info_corpus.pkl', 'rb') as f:
    info_corpus = pickle.load(f)

# Get the index-to-word dictionary
i2w = info_corpus['info']['itow']

Preparation before feeding

In [6]:
def prepare_inputs_for_decoder(encoder_outputs, category):
    """Build the keyword arguments passed on to the decoder.

    Args:
        encoder_outputs: Mapping that must contain the key 'enc_output'.
            Other keys are ignored.
        category: Value stored unchanged under the 'category' key.

    Returns:
        A new dict with the keys 'category' and 'enc_output'. If the
        'enc_output' value is a list, it is replaced by its only element.

    Raises:
        KeyError: If 'enc_output' is missing from encoder_outputs.
        AssertionError: If 'enc_output' is a list whose length is not 1.
    """
    forwarded_keys = ('enc_output',)

    decoder_kwargs = {'category': category}
    # Copy the selected entries as they are (no cloning or conversion).
    decoder_kwargs.update({name: encoder_outputs[name] for name in forwarded_keys})

    enc_output = decoder_kwargs['enc_output']
    if not isinstance(enc_output, list):
        return decoder_kwargs

    # A list is only accepted when it wraps exactly one item, which is unwrapped.
    assert len(enc_output) == 1
    decoder_kwargs['enc_output'] = enc_output[0]
    return decoder_kwargs

# Rerun this cell if Data is Updated

In [23]:
inputs_for_decoder = prepare_inputs_for_decoder(encoder_outputs, category)
# Drop the last position along the second axis of the target tokens.
# This rebinds tgt_tokens, so every re-run of the cell trims one more position;
# regenerate tgt_tokens before running it again.
if isinstance(tgt_tokens, list):
    tgt_tokens = [tokens[:, :-1] for tokens in tgt_tokens]
else:
    tgt_tokens = tgt_tokens[:, :-1]

In [25]:
# Call the decoder; only the first two returned values are kept, the rest are discarded.
hidden_states, embs, *_ = decoder( 
    tgt_seq=tgt_tokens, 
    decoding_type=decoding_type,
    output_attentions=False,
    **inputs_for_decoder # each key/value pair is fed to the decoder's forward method as a keyword argument
)

In [29]:
# Display the shape of the first element of hidden_states.
hidden_states[0].shape

torch.Size([8, 15, 512])