In [16]:
import torch
import torch.nn as nn
import json
import numpy as np
from text_processing import tokenize_text, untokenize, pad_text, Toks

In [2]:
gpu_id = 0

device = torch.device('cuda:{}'.format(gpu_id)) \
    if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [22]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        assert hidden_size % 2 == 0

        self.hidden_size = hidden_size
        self.input_size = input_size
        
        self.hidden_init_tensor = torch.zeros(2, 1, int(self.hidden_size/2), requires_grad=True)
        nn.init.normal_(self.hidden_init_tensor, mean=0, std=0.05)
        self.hidden_init = torch.nn.Parameter(self.hidden_init_tensor, requires_grad=True)
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.emb_drop = nn.Dropout(0.2)
        self.gru = nn.GRU(hidden_size, int(hidden_size/2), batch_first=True, bidirectional=True)
        self.gru_out_drop = nn.Dropout(0.2)
        self.gru_hid_drop = nn.Dropout(0.3)
        
    def forward(self, input, hidden, lengths):
        emb = self.emb_drop(self.embedding(input))
        pp = torch.nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True)
        out, hidden = self.gru(pp, hidden)
        out = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)[0]
        out = self.gru_out_drop(out)
        hidden = self.gru_hid_drop(hidden)
        return out, hidden
    
    def initHidden(self, bs):
        return self.hidden_init.expand(2, bs, int(self.hidden_size/2)).contiguous()

In [4]:
class DecoderAttn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DecoderAttn, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.emb_drop = nn.Dropout(0.2)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.gru_drop = nn.Dropout(0.2)
        self.mlp = nn.Linear(hidden_size*2, output_size)
        self.logsoftmax = nn.LogSoftmax(dim=2)
        
        self.att_mlp = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attn_softmax = nn.Softmax(dim=2)
    
    def forward(self, input, hidden, encoder_outs):
        emb = self.embedding(input)
        out, hidden = self.gru(self.emb_drop(emb), hidden)
        
        out_proj = self.att_mlp(out)
        enc_out_perm = encoder_outs.permute(0, 2, 1)
        e_exp = torch.bmm(out_proj, enc_out_perm)
        attn = self.attn_softmax(e_exp)
        
        ctx = torch.bmm(attn, encoder_outs)
        
        full_ctx = torch.cat([self.gru_drop(out), ctx], dim=2)
        
        out = self.mlp(full_ctx)
        out = self.logsoftmax(out)
        return out, hidden, attn

In [23]:
def build_model(enc_vocab_size, dec_vocab_size, 
                hid_size=512, loaded_state = None):
    
    enc = Encoder(enc_vocab_size, hid_size)
    dec = DecoderAttn(dec_vocab_size, hid_size, dec_vocab_size)
    
    if loaded_state is not None:
        enc.load_state_dict(loaded_state['enc'])
        dec.load_state_dict(loaded_state['dec'])
        
    enc = enc.to(device)
    dec = dec.to(device)
    
    return enc, dec

In [24]:
def setup_test():
    
    if torch.cuda.is_available():
        loaded_state = torch.load(model_path + seq_to_seq_test_model_fname,
                                  map_location=device)
    else:
        loaded_state = torch.load(model_path + seq_to_seq_test_model_fname,
                                  map_location='cpu')

    enc_idx_to_word = loaded_state['enc_idx_to_word']
    enc_word_to_idx = loaded_state['enc_word_to_idx']
    enc_vocab_size = len(enc_idx_to_word)

    dec_idx_to_word = loaded_state['dec_idx_to_word']
    dec_word_to_idx = loaded_state['dec_word_to_idx']
    dec_vocab_size = len(dec_idx_to_word)
    
    enc, dec = build_model(enc_vocab_size, 
                           dec_vocab_size, 
                           loaded_state = loaded_state)
    
    return {'enc': enc, 'dec': dec,
            'enc_idx_to_word': enc_idx_to_word,
            'enc_word_to_idx': enc_word_to_idx,
            'enc_vocab_size': enc_vocab_size,
            'dec_idx_to_word': dec_idx_to_word,
            'dec_word_to_idx': dec_word_to_idx,
            'dec_vocab_size': dec_vocab_size}

In [35]:
def make_packpadded(s, e, enc_padded_text):

    text = enc_padded_text[s:e]
    lengths = np.count_nonzero(text, axis=1)
    order = np.argsort(-lengths)
    new_text = text[order]
    new_enc = torch.tensor(new_text)
    new_enc = new_enc.to(device)
    
    leng = torch.tensor(lengths[order])
    leng.to(device)
    return order, new_enc, leng

def generate(enc, dec, enc_padded_text, L=20):
    enc.eval()
    dec.eval()
    with torch.no_grad():
        # run the encoder
        order, enc_pp, enc_lengths = make_packpadded(0, 
                                                     enc_padded_text.shape[0], 
                                                     enc_padded_text)
        hid = enc.initHidden(enc_padded_text.shape[0])
        out_enc, hid_enc = enc(enc_pp, hid, enc_lengths)
        
        hid_enc = torch.cat([hid_enc[0,:, :], hid_enc[1,:,:]], dim=1).unsqueeze(0)

        # run the decoder step by step
        dec_tensor = torch.ones((enc_padded_text.shape[0]), 
                                L + 1, 
                                dtype=torch.long) * Toks.SOS
        dec_tensor = dec_tensor.to(device)
        last_enc = hid_enc
        for i in range(L):
            out_dec, hid_dec, attn = dec.forward(dec_tensor[:,i].unsqueeze(1), 
                                                 last_enc, 
                                                 out_enc)
            out_dec[:, 0, Toks.UNK] = -np.inf # ignore unknowns
            #out_dec[torch.arange(dec_tensor.shape[0], dtype=torch.long), 0, dec_tensor[:, i]] = -np.inf
            chosen = torch.argmax(out_dec[:,0],dim=1)
            dec_tensor[:, i+1] = chosen
            last_enc = hid_dec
    
    return dec_tensor.data.cpu().numpy()[np.argsort(order)]

In [42]:
def test(setup_data, input_seqs, test_style=ROM_STYLE):
    input_rems_text = input_seqs
    slen = len(input_seqs)
    for i in range(slen):
        input_rems_text[i].append(COCO_STYLE)

    _, _, enc_tok_text, _ = tokenize_text(input_rems_text, 
                                          idx_to_word=setup_data['enc_idx_to_word'],
                                          word_to_idx = setup_data['enc_word_to_idx'])
    enc_padded_text = pad_text(enc_tok_text)

    dlen = enc_padded_text.shape[0]
    num_batch = int(dlen/BATCH_SIZE)
    
    print(num_batch)
    if dlen % BATCH_SIZE != 0:
        num_batch+=1
    res = []
    for i in range(num_batch):
        dec_tensor = generate(setup_data['enc'], 
                              setup_data['dec'], 
                              enc_padded_text[i*BATCH_SIZE:(i+1)*BATCH_SIZE])
        res.append(dec_tensor)

    all_text = []
    res = np.concatenate(res, axis=0)
    for row in res:
        utok = untokenize(row, setup_data['dec_idx_to_word'], to_text=True)
        all_text.append(utok)
    
    return all_text

In [43]:
model_path = "./models/"
seq_to_seq_test_model_fname = "seq_to_txt_state.tar"
BATCH_SIZE = 3
ROM_STYLE = "ROMANCETOKEN"
COCO_STYLE = "MSCOCOTOKEN"

setup_data = setup_test()
input_seqs = [['manNOUNNOUNNOUN', 'FRAMENETPosture', 'tennisNOUNNOUNNOUN', 
               'courtNOUNNOUNNOUN', 'FRAMENETContaining', 'racquetNOUNNOUNNOUN'], 
              ['manNOUNNOUNNOUN', 'FRAMENETPosture', 'fieldNOUNNOUNNOUN'], 
              ['bearNOUNNOUNNOUN', 'FRAMENETPlacing', 'topNOUNNOUNNOUN']]   # generate from img_to_text

all_text = test(setup_data, input_seqs)

1


In [45]:
len(input_seqs)

3

In [41]:
all_text

['a man standing on a tennis court holding a racquet',
 'a man standing in a field',
 'a teddy bear sitting on top of it']

In [37]:
all_text

['a man standing on a tennis court holding a racquet',
 'a man standing in a field of a very tall',
 'a teddy bear sitting on top of a wooden top']

In [32]:
all_text

['the man who was standing on the tennis court holding his racquet .',
 'the man standing in the field .',
 'i stuffed the bear on top of it .']