In [1]:
# default_exp trainingUtils

In [1]:
# append our Aeye package
# TODO: fix this import issue when using the package.
import sys 
sys.path.append('../aeye')

from preprocessing import Lang, tensorForImageCaption, get_preprocessed_data, SOS_token, EOS_token
from models import EncoderRNN, DecoderRNN

In [96]:
import random
import time
import math

import torch
from torch import nn
from torch import optim

device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')

In [3]:
feature_dict, sentence_list, lang = get_preprocessed_data('train')
sentence_list = random.sample(sentence_list, len(sentence_list))

In [106]:
img, sent = tensorForImageCaption(feature_dict, sentence_list[3], lang)

In [107]:
img.shape, sent.shape

(torch.Size([1, 512]), torch.Size([11, 1]))

In [108]:
img

tensor([[1.9842e-01, 5.9648e+00, 2.3226e+00, 1.2697e+00, 7.8906e-01, 7.9215e-02,
         1.1086e-01, 1.9074e-02, 1.8796e+00, 2.5521e-03, 3.8115e-01, 8.5905e-01,
         8.9893e-01, 6.7055e-02, 1.4348e+00, 2.1541e+00, 2.4470e+00, 3.4091e+00,
         1.6080e+00, 2.7844e+00, 1.9986e+00, 2.2561e+00, 2.9838e+00, 1.3658e+00,
         2.4977e-01, 1.4338e+00, 5.7418e+00, 6.8935e-01, 1.6719e+00, 3.0016e-01,
         2.0613e+00, 1.8420e+00, 2.2487e-01, 1.2603e+00, 1.1628e+00, 6.0581e+00,
         2.1356e+00, 2.4543e+00, 3.6082e+00, 1.5615e+00, 3.5274e-01, 4.0782e-01,
         3.7784e+00, 8.5897e-01, 3.6654e+00, 9.5126e-01, 1.6369e+00, 7.4365e-01,
         4.3271e+00, 3.8370e+00, 1.6275e+00, 2.5214e+00, 1.8694e+00, 4.4382e-01,
         4.9051e-02, 2.0352e+00, 2.1343e+00, 1.0767e+00, 3.0240e-01, 6.5968e-01,
         7.5922e-01, 3.5785e+00, 1.5097e+00, 4.0003e-01, 4.5587e+00, 3.2672e-01,
         3.3936e-01, 4.8536e-01, 1.4820e-01, 1.9782e+00, 1.6055e+00, 1.9909e+00,
         2.6264e+00, 5.3603e

In [109]:
sent

tensor([[ 110],
        [ 105],
        [   9],
        [ 665],
        [1592],
        [  28],
        [   8],
        [  47],
        [ 144],
        [2815],
        [   1]])

In [33]:
SOS_tensor = torch.tensor([[SOS_token]], device=device)
EOS_tensor = torch.tensor([[EOS_token]], device=device)

EOS_tensor.dtype

torch.int64

In [121]:
teacher_forcing_ratio = 0.5
MAX_LENGTH = 30

def train(img_tensor, sent_tensor, decoder, criterion, 
          decoder_optimizer, max_length = MAX_LENGTH):
        
#     encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    length = sent_tensor.size(0)
    img_tensor = img_tensor.unsqueeze(0)
    sent_tensor = sent_tensor.unsqueeze(1)
    loss = 0
    
    decoder_input = SOS_tensor
    decoder_hidden = img_tensor
    use_teacher_forcing = True if random.random() \
                < teacher_forcing_ratio else True
    
    if use_teacher_forcing:
        for di in range(length):
            decoder_output, decoder_hidden = decoder(
                                        decoder_input, 
                                        decoder_hidden)
            
            print('\n\n', decoder_output)
            print('\n\n', decoder_hidden)
            loss += criterion(decoder_output[0], sent_tensor[di].squeeze(0))
            decoder_input = sent_tensor[di]
            
    else:
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(
                                        decoder_input,
                                        decoder_hidden)
            
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += criterion(decoder_output[0], sent_tensor[0].squeeze(0))
    
    
    loss.backward()
    
    decoder_optimizer.step()
    
    return loss.item()/length

In [117]:
img.dtype

torch.float32

In [118]:
hidden_size = 512

decoder = DecoderRNN(hidden_size, lang.n_words)



In [59]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [122]:
learning_rate = 0.1

def trainIters(decoder, print_every=1000, plot_every=100,
               learning_rate=0.1):
    
    print_loss_total = 0
    start = time.time()
    
    # Get data
    feature_dict, sentence_list, lang = get_preprocessed_data('train')
    sentence_list = random.sample(sentence_list, len(sentence_list))
    n_iters = len(sentence_list)
    
    criterion = nn.NLLLoss()
    decoder_optimizer = optim.SGD(decoder.parameters(), lr = learning_rate)
    
    for iter, sentence in enumerate(sentence_list):
        img, sent = tensorForImageCaption(feature_dict, sentence, lang)
        loss = train(img, sent, decoder, criterion, decoder_optimizer)
        
        print(loss)
        break
        if (iter+1) % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                        iter, iter / n_iters * 100, print_loss_avg))
            


In [123]:
trainIters(decoder, print_every=10)



 tensor([[[0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<LogSoftmaxBackward>)


 tensor([[[ 3.8382e-01,  4.1358e+00,  3.6584e-02,  7.9183e-01,  2.3832e-01,
          -4.3980e-01,  3.1525e-01, -1.7522e-01,  1.4950e+00,  2.1150e-01,
           1.5722e-01,  9.8403e-01,  9.4312e-01,  1.7420e-01,  1.2316e+00,
           2.4075e+00,  3.2609e+00,  2.9444e+00,  1.0421e+00,  1.8282e-01,
          -7.4717e-02,  7.3485e-01,  2.7300e+00,  2.6305e-01, -2.3797e-01,
           6.1523e-01,  3.7121e-01,  9.6492e-01,  1.7639e+00,  2.0498e-01,
           2.2484e+00,  1.8146e+00,  5.3975e-01,  9.8505e-01,  2.8041e-01,
           8.0842e+00,  2.0589e+00,  2.9190e+00,  3.9769e+00,  3.9014e-01,
           7.9187e-01,  8.0430e-01,  5.4187e+00,  1.4145e+00,  1.9304e+00,
           4.7413e-01,  1.7654e+00,  9.3459e-01,  1.3940e+00,  1.2964e+00,
           1.1734e+00,  1.0900e+00, -5.2843e-01,  7.9904e-01,  8.1107e-02,
           1.3953e+00,  1.3443e+00, -9.1299e-02,  5.2127e-01, -5.3464e-01,
           7.2838e-01