In [2]:
import math
import random
import sys

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms

sys.path.append('./cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN



# Training configuration: every value a reader might want to tune lives here.
batch_size = 10            # batch size
# NOTE(review): a fractional *count* threshold is suspicious — presumably an integer
# minimum word count was intended. It is presumably unused while vocab_from_file is
# True (the vocab is then loaded from disk) — confirm against data_loader.py.
vocab_threshold = 0.4      # minimum word count threshold
vocab_from_file = True     # if True, load existing vocab file
embed_size = 256           # dimensionality of image and word embeddings
hidden_size = 512          # number of features in hidden state of the RNN decoder
num_epochs = 3             # number of training epochs
save_every = 1             # save model weights every `save_every` epochs
print_every = 100          # every N steps, keep the progress line on screen (that batch's loss, not an average)
log_file = 'training_log.txt'  # file receiving one loss/perplexity line per training step
seed = 42                  # RNG seed so sampling, augmentation and weight init are reproducible

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) before any data
# loader or model is built. torch.manual_seed seeds all devices, including CUDA.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


# Per-channel mean / std used to normalise inputs for the pre-trained encoder
# (the standard ImageNet statistics).
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Training-time preprocessing and augmentation, applied in order:
# resize -> random 224x224 crop -> random horizontal flip -> tensor -> normalise.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),             # shorter image side resized to 256
        transforms.RandomCrop(224),         # 224x224 window at a random location
        transforms.RandomHorizontalFlip(),  # mirror left/right with probability 0.5
        transforms.ToTensor(),              # PIL image -> tensor
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# Build the training data loader from the configuration above.
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file)

# Vocabulary size: the decoder's output dimension and the number of loss classes.
vocab_size = len(data_loader.dataset.vocab)

# Initialize the encoder and decoder.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Select the device once and reuse it for everything that must live on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

# Reuse `device` rather than re-checking CUDA availability and calling .cuda().
criterion = nn.CrossEntropyLoss().to(device)

# Only the decoder and the encoder's `embed` sub-module are handed to the optimizer;
# the rest of the encoder is not updated (presumably a frozen pre-trained CNN —
# confirm in model.py).
params = list(decoder.parameters()) + list(encoder.embed.parameters())

learning_rate = 0.001      # Adam step size
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Steps per epoch: number of batch-sized chunks in the caption set, rounded up.
total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...


  0%|          | 1306/414113 [00:00<01:07, 6143.08it/s]

Done (t=0.80s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 414113/414113 [01:01<00:00, 6757.60it/s]


In [3]:
import torch.utils.data as data
import numpy as np
import os
import requests
import time

# Train the CNN-RNN captioning model.
#
# Every step writes one loss/perplexity line to `log_file` and a progress line to
# stdout. Weights are saved to `checkpoint_dir` every `save_every` epochs.
checkpoint_dir = './models'
# torch.save does not create missing directories; fail fast here rather than after
# a full epoch of training.
os.makedirs(checkpoint_dir, exist_ok=True)

# Ensure training mode even if an earlier cell switched the modules to eval().
encoder.train()
decoder.train()

# Width of the longest progress line printed so far. Shorter lines are padded to this
# width so the '\r' overwrite does not leave stale trailing characters on screen.
line_width = 0

# `with` guarantees the log file is closed even if training is interrupted
# (e.g. KeyboardInterrupt).
with open(log_file, 'w') as log_f:
    for epoch in range(1, num_epochs + 1):
        for i_step in range(1, total_step + 1):
            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch and move it to the training device.
            images, captions = next(iter(data_loader))
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Forward pass through the CNN-RNN model.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Flatten to (N, vocab_size) logits vs. N targets. reshape (unlike view)
            # also works if the decoder returns a non-contiguous tensor.
            loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            # Training statistics; perplexity is exp(cross-entropy loss).
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (epoch, num_epochs, i_step, total_step, loss.item(), np.exp(loss.item()))

            # Overwrite the current console line, padding to clear longer previous text.
            line_width = max(line_width, len(stats))
            padded = stats.ljust(line_width)
            print('\r' + padded, end="")
            sys.stdout.flush()

            # Persist the unpadded statistics to the log file.
            log_f.write(stats + '\n')
            log_f.flush()

            # Every `print_every` steps, leave the progress line on screen.
            if i_step % print_every == 0:
                print('\r' + padded)

        # Save the weights.
        if epoch % save_every == 0:
            torch.save(decoder.state_dict(), os.path.join(checkpoint_dir, 'decoder-%d.pkl' % epoch))
            torch.save(encoder.state_dict(), os.path.join(checkpoint_dir, 'encoder-%d.pkl' % epoch))

Epoch [1/3], Step [100/41412], Loss: 4.0002, Perplexity: 54.6077
Epoch [1/3], Step [200/41412], Loss: 4.0060, Perplexity: 54.92536
Epoch [1/3], Step [300/41412], Loss: 3.8952, Perplexity: 49.1673
Epoch [1/3], Step [400/41412], Loss: 3.7748, Perplexity: 43.58790
Epoch [1/3], Step [500/41412], Loss: 3.4968, Perplexity: 33.0094
Epoch [1/3], Step [600/41412], Loss: 3.3878, Perplexity: 29.59977
Epoch [1/3], Step [700/41412], Loss: 3.3744, Perplexity: 29.20778
Epoch [1/3], Step [800/41412], Loss: 3.3292, Perplexity: 27.9165
Epoch [1/3], Step [900/41412], Loss: 3.0967, Perplexity: 22.1246
Epoch [1/3], Step [1000/41412], Loss: 3.2101, Perplexity: 24.7805
Epoch [1/3], Step [1100/41412], Loss: 3.1116, Perplexity: 22.4580
Epoch [1/3], Step [1200/41412], Loss: 3.7805, Perplexity: 43.8390
Epoch [1/3], Step [1300/41412], Loss: 2.9828, Perplexity: 19.7428
Epoch [1/3], Step [1400/41412], Loss: 3.0570, Perplexity: 21.2630
Epoch [1/3], Step [1500/41412], Loss: 2.7453, Perplexity: 15.5687
Epoch [1/3], St

Epoch [1/3], Step [24600/41412], Loss: 2.8468, Perplexity: 17.2331
Epoch [1/3], Step [24700/41412], Loss: 2.1604, Perplexity: 8.67463
Epoch [1/3], Step [24800/41412], Loss: 2.3503, Perplexity: 10.4885
Epoch [1/3], Step [24900/41412], Loss: 2.6070, Perplexity: 13.5580
Epoch [1/3], Step [25000/41412], Loss: 2.4647, Perplexity: 11.7599
Epoch [1/3], Step [25100/41412], Loss: 1.7195, Perplexity: 5.58186
Epoch [1/3], Step [25200/41412], Loss: 3.0046, Perplexity: 20.1775
Epoch [1/3], Step [25300/41412], Loss: 2.8998, Perplexity: 18.1701
Epoch [1/3], Step [25400/41412], Loss: 1.7098, Perplexity: 5.528078
Epoch [1/3], Step [25500/41412], Loss: 2.9788, Perplexity: 19.6639
Epoch [1/3], Step [25600/41412], Loss: 1.8994, Perplexity: 6.68169
Epoch [1/3], Step [25700/41412], Loss: 1.8971, Perplexity: 6.66664
Epoch [1/3], Step [25800/41412], Loss: 2.6620, Perplexity: 14.3254
Epoch [1/3], Step [25900/41412], Loss: 2.5661, Perplexity: 13.0148
Epoch [1/3], Step [26000/41412], Loss: 2.4726, Perplexity: 11

Epoch [2/3], Step [7700/41412], Loss: 2.1660, Perplexity: 8.72350
Epoch [2/3], Step [7800/41412], Loss: 1.8076, Perplexity: 6.09617
Epoch [2/3], Step [7900/41412], Loss: 2.1157, Perplexity: 8.29583
Epoch [2/3], Step [8000/41412], Loss: 2.3574, Perplexity: 10.5637
Epoch [2/3], Step [8100/41412], Loss: 1.8706, Perplexity: 6.49192
Epoch [2/3], Step [8200/41412], Loss: 2.2369, Perplexity: 9.36410
Epoch [2/3], Step [8300/41412], Loss: 2.4201, Perplexity: 11.24693
Epoch [2/3], Step [8400/41412], Loss: 2.3064, Perplexity: 10.0381
Epoch [2/3], Step [8500/41412], Loss: 2.7628, Perplexity: 15.8436
Epoch [2/3], Step [8600/41412], Loss: 2.4618, Perplexity: 11.7253
Epoch [2/3], Step [8700/41412], Loss: 2.2161, Perplexity: 9.17196
Epoch [2/3], Step [8800/41412], Loss: 2.2620, Perplexity: 9.60276
Epoch [2/3], Step [8900/41412], Loss: 2.2665, Perplexity: 9.64558
Epoch [2/3], Step [9000/41412], Loss: 2.5345, Perplexity: 12.6107
Epoch [2/3], Step [9100/41412], Loss: 1.7856, Perplexity: 5.96347
Epoch [2/

Epoch [2/3], Step [32100/41412], Loss: 2.3916, Perplexity: 10.9314
Epoch [2/3], Step [32200/41412], Loss: 2.0705, Perplexity: 7.92914
Epoch [2/3], Step [32300/41412], Loss: 2.0191, Perplexity: 7.53161
Epoch [2/3], Step [32400/41412], Loss: 2.9614, Perplexity: 19.3253
Epoch [2/3], Step [32500/41412], Loss: 2.0338, Perplexity: 7.64346
Epoch [2/3], Step [32600/41412], Loss: 2.1434, Perplexity: 8.52850
Epoch [2/3], Step [32700/41412], Loss: 2.5693, Perplexity: 13.0573
Epoch [2/3], Step [32800/41412], Loss: 2.3739, Perplexity: 10.7395
Epoch [2/3], Step [32900/41412], Loss: 2.1707, Perplexity: 8.76429
Epoch [2/3], Step [33000/41412], Loss: 2.1805, Perplexity: 8.85068
Epoch [2/3], Step [33100/41412], Loss: 2.1485, Perplexity: 8.57217
Epoch [2/3], Step [33200/41412], Loss: 2.3559, Perplexity: 10.5477
Epoch [2/3], Step [33300/41412], Loss: 2.0779, Perplexity: 7.98782
Epoch [2/3], Step [33400/41412], Loss: 2.3659, Perplexity: 10.6538
Epoch [2/3], Step [33500/41412], Loss: 2.0117, Perplexity: 7.4

KeyboardInterrupt: 