Train
---

In this notebook, we will train the CNN-RNN model. The training run is controlled by the following variables:

- `batch_size` - the batch size of each training batch.  It is the number of image-caption pairs used to update the model weights in each training step.
- `vocab_threshold` - the minimum word count threshold.  Note that a larger threshold will result in a smaller vocabulary, whereas a smaller threshold will include rarer words and result in a larger vocabulary.  
- `vocab_from_file` - a Boolean that decides whether to load the vocabulary from file. 
- `embed_size` - the dimensionality of the image and word embeddings.  
- `hidden_size` - the number of features in the hidden state of the RNN decoder.  
- `num_epochs` - the number of epochs to train the model.  
- `save_every` - determines how often to save the model weights.  We recommend that you set `save_every=1`, to save the model weights after each epoch.  This way, after the `i`th epoch, the encoder and decoder weights will be saved in the `models/` folder as `encoder-i.pkl` and `decoder-i.pkl`, respectively.
- `print_every` - determines how often to print the batch loss to the Jupyter notebook while training.
- `log_file` - the name of the text file that records the loss and perplexity at every training step.
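
To make the effect of `vocab_threshold` concrete, here is a minimal sketch of count-based vocabulary filtering. It assumes a word is kept when it occurs at least `vocab_threshold` times; `build_vocab` and the toy captions are hypothetical and are not the notebook's own vocabulary code, which may also add special tokens.

```python
from collections import Counter


def build_vocab(captions, vocab_threshold):
    """Return the sorted words that occur at least `vocab_threshold` times.

    Hypothetical helper for illustration only.
    """
    counts = Counter(
        word for caption in captions for word in caption.lower().split()
    )
    return sorted(word for word, count in counts.items() if count >= vocab_threshold)


# Toy captions (made up for this example).
captions = [
    "a dog runs on the beach",
    "a dog sits on the grass",
    "a cat sleeps on the sofa",
]

small_vocab = build_vocab(captions, vocab_threshold=3)  # only very frequent words
large_vocab = build_vocab(captions, vocab_threshold=1)  # every word that appears
print(len(small_vocab), len(large_vocab))
```

Raising the threshold from 1 to 3 shrinks this toy vocabulary from eleven words to the three that appear in every caption.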

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import sys
sys.path.append('/opt/cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN
import math

batch_size = 32          # batch size
vocab_threshold = 6        # minimum word count threshold
vocab_from_file = True    # if True, load existing vocab file
embed_size = 512           # dimensionality of image and word embeddings
hidden_size = 512          # number of features in hidden state of the RNN decoder
num_epochs = 10            # number of training epochs
save_every = 1             # determines frequency of saving model weights
print_every = 100          # number of steps between printed batch-loss lines
log_file = 'training_log.txt'       # name of file with saved training loss and perplexity

transform_train = transforms.Compose([ 
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(224),                      # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),      # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))])

# Build data loader.
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file)

# The size of the vocabulary.
vocab_size = len(data_loader.dataset.vocab)

# Initialize the encoder and decoder. 
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move models to GPU if CUDA is available. 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

# Define the loss function. 
criterion = nn.CrossEntropyLoss().to(device)

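# Specify the learnable parameters: the whole decoder plus the encoder's `embed` layer.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Define the optimizer.
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)

# Set the number of training steps per epoch.
total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)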

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.69s)
creating index...


  0%|          | 663/414113 [00:00<01:02, 6626.12it/s]

index created!
Obtaining caption lengths...


100%|██████████| 414113/414113 [00:52<00:00, 7939.96it/s]
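
As a quick check of the `total_step` formula, the sketch below redoes the arithmetic with the 414113 captions reported by the progress bar and `batch_size = 32`. The function name `steps_per_epoch` is introduced here only for illustration.

```python
import math


def steps_per_epoch(num_captions, batch_size):
    """Number of batches needed to cover every caption once, rounding up."""
    return math.ceil(num_captions / batch_size)


# 414113 / 32 = 12941.03..., so one extra step is needed for the remainder.
print(steps_per_epoch(414113, 32))  # 12942
```

This matches the `Step [.../12942]` counter in the training log.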


## Step 2: Train the Model

```python
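import os

# To resume training from a checkpoint, load the saved weights first.
# Set encoder_file and decoder_file to saved checkpoint names (e.g. 'encoder-3.pkl').
encoder.load_state_dict(torch.load(os.path.join('./models', encoder_file)))
decoder.load_state_dict(torch.load(os.path.join('./models', decoder_file)))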
```

In [2]:
import torch.utils.data as data
import numpy as np
import os
import time

# Open the training log file.
f = open(log_file, 'w')

old_time = time.time()

for epoch in range(1, num_epochs+1):
    
    for i_step in range(1, total_step+1):
        
        if time.time() - old_time > 60:
            old_time = time.time()
        
        # Randomly sample a caption length, and sample indices with that length.
        indices = data_loader.dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler
        
        # Obtain the batch.
        images, captions = next(iter(data_loader))

        # Move batch of images and captions to GPU if CUDA is available.
        images = images.to(device)
        captions = captions.to(device)
        
        # Zero the gradients.
        decoder.zero_grad()
        encoder.zero_grad()
        
        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        outputs = decoder(features, captions)
        
        # Calculate the batch loss.
        loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))
        
        # Backward pass.
        loss.backward()
        
        # Update the parameters in the optimizer.
        optimizer.step()
            
        # Get training statistics.
        stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (epoch, num_epochs, i_step, total_step, loss.item(), np.exp(loss.item()))
        
        # Print training statistics (on same line).
        print('\r' + stats, end="")
        sys.stdout.flush()
        
        # Print training statistics to file.
        f.write(stats + '\n')
        f.flush()
        
        # Print training statistics (on different line).
        if i_step % print_every == 0:
            print('\r' + stats)
            
    # Save the weights.
    if epoch % save_every == 0:
        torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-%d.pkl' % epoch))
        torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-%d.pkl' % epoch))

# Close the training log file.
f.close()

Epoch [1/10], Step [100/12942], Loss: 4.5089, Perplexity: 90.8187
Epoch [1/10], Step [200/12942], Loss: 3.7137, Perplexity: 41.00571
Epoch [1/10], Step [300/12942], Loss: 3.4036, Perplexity: 30.0713
Epoch [1/10], Step [400/12942], Loss: 3.7827, Perplexity: 43.9331
Epoch [1/10], Step [500/12942], Loss: 3.3950, Perplexity: 29.8136
Epoch [1/10], Step [600/12942], Loss: 3.7504, Perplexity: 42.5375
Epoch [1/10], Step [700/12942], Loss: 3.4433, Perplexity: 31.2898
Epoch [1/10], Step [800/12942], Loss: 3.7652, Perplexity: 43.1715
Epoch [1/10], Step [900/12942], Loss: 3.4139, Perplexity: 30.3829
Epoch [1/10], Step [1000/12942], Loss: 3.7258, Perplexity: 41.5041
Epoch [1/10], Step [1100/12942], Loss: 2.8036, Perplexity: 16.5033
Epoch [1/10], Step [1200/12942], Loss: 3.1838, Perplexity: 24.1387
Epoch [1/10], Step [1300/12942], Loss: 3.6778, Perplexity: 39.5576
Epoch [1/10], Step [1400/12942], Loss: 2.9622, Perplexity: 19.3411
Epoch [1/10], Step [1500/12942], Loss: 3.2180, Perplexity: 24.9791
Epo

Epoch [2/10], Step [11500/12942], Loss: 2.1744, Perplexity: 8.79675
Epoch [2/10], Step [11600/12942], Loss: 2.1598, Perplexity: 8.66960
Epoch [2/10], Step [11700/12942], Loss: 2.3013, Perplexity: 9.98719
Epoch [2/10], Step [11800/12942], Loss: 2.1057, Perplexity: 8.21259
Epoch [2/10], Step [11900/12942], Loss: 2.3473, Perplexity: 10.4569
Epoch [2/10], Step [12000/12942], Loss: 2.6425, Perplexity: 14.0482
Epoch [2/10], Step [12100/12942], Loss: 2.2353, Perplexity: 9.34917
Epoch [2/10], Step [12200/12942], Loss: 2.2098, Perplexity: 9.11355
Epoch [2/10], Step [12300/12942], Loss: 2.2745, Perplexity: 9.72296
Epoch [2/10], Step [12400/12942], Loss: 1.8171, Perplexity: 6.15377
Epoch [2/10], Step [12500/12942], Loss: 2.0060, Perplexity: 7.43340
Epoch [2/10], Step [12600/12942], Loss: 2.2640, Perplexity: 9.62149
Epoch [2/10], Step [12700/12942], Loss: 2.0468, Perplexity: 7.74305
Epoch [2/10], Step [12800/12942], Loss: 2.6293, Perplexity: 13.8639
Epoch [2/10], Step [12900/12942], Loss: 2.2310, 

Epoch [4/10], Step [10100/12942], Loss: 2.0121, Perplexity: 7.47937
Epoch [4/10], Step [10200/12942], Loss: 1.9931, Perplexity: 7.33856
Epoch [4/10], Step [10300/12942], Loss: 1.9858, Perplexity: 7.28495
Epoch [4/10], Step [10400/12942], Loss: 3.1465, Perplexity: 23.2546
Epoch [4/10], Step [10500/12942], Loss: 1.8773, Perplexity: 6.53579
Epoch [4/10], Step [10600/12942], Loss: 2.0325, Perplexity: 7.63292
Epoch [4/10], Step [10700/12942], Loss: 2.1606, Perplexity: 8.67632
Epoch [4/10], Step [10800/12942], Loss: 1.9516, Perplexity: 7.04002
Epoch [4/10], Step [10900/12942], Loss: 1.9791, Perplexity: 7.23644
Epoch [4/10], Step [11000/12942], Loss: 2.6141, Perplexity: 13.6549
Epoch [4/10], Step [11100/12942], Loss: 1.8595, Perplexity: 6.42072
Epoch [4/10], Step [11200/12942], Loss: 1.9440, Perplexity: 6.98677
Epoch [4/10], Step [11300/12942], Loss: 1.9298, Perplexity: 6.88839
Epoch [4/10], Step [11400/12942], Loss: 2.0280, Perplexity: 7.59899
Epoch [4/10], Step [11500/12942], Loss: 2.0533, 

Epoch [6/10], Step [8500/12942], Loss: 1.9078, Perplexity: 6.73842
Epoch [6/10], Step [8600/12942], Loss: 2.5375, Perplexity: 12.6482
Epoch [6/10], Step [8700/12942], Loss: 2.0051, Perplexity: 7.42718
Epoch [6/10], Step [8800/12942], Loss: 1.9192, Perplexity: 6.81572
Epoch [6/10], Step [8900/12942], Loss: 2.5365, Perplexity: 12.6354
Epoch [6/10], Step [9000/12942], Loss: 2.1204, Perplexity: 8.33457
Epoch [6/10], Step [9100/12942], Loss: 2.0195, Perplexity: 7.53491
Epoch [6/10], Step [9200/12942], Loss: 1.8601, Perplexity: 6.42436
Epoch [6/10], Step [9300/12942], Loss: 2.3877, Perplexity: 10.8881
Epoch [6/10], Step [9400/12942], Loss: 2.0701, Perplexity: 7.92563
Epoch [6/10], Step [9500/12942], Loss: 2.0374, Perplexity: 7.67055
Epoch [6/10], Step [9600/12942], Loss: 2.1042, Perplexity: 8.20041
Epoch [6/10], Step [9700/12942], Loss: 2.4447, Perplexity: 11.5269
Epoch [6/10], Step [9800/12942], Loss: 2.0400, Perplexity: 7.69095
Epoch [6/10], Step [9900/12942], Loss: 2.2947, Perplexity: 9.9

Epoch [8/10], Step [6900/12942], Loss: 1.8613, Perplexity: 6.43220
Epoch [8/10], Step [7000/12942], Loss: 2.1885, Perplexity: 8.92155
Epoch [8/10], Step [7100/12942], Loss: 1.8373, Perplexity: 6.27933
Epoch [8/10], Step [7200/12942], Loss: 2.0605, Perplexity: 7.85029
Epoch [8/10], Step [7300/12942], Loss: 2.0627, Perplexity: 7.86758
Epoch [8/10], Step [7400/12942], Loss: 1.7088, Perplexity: 5.52252
Epoch [8/10], Step [7500/12942], Loss: 1.9174, Perplexity: 6.80309
Epoch [8/10], Step [7600/12942], Loss: 1.9285, Perplexity: 6.87891
Epoch [8/10], Step [7700/12942], Loss: 2.1947, Perplexity: 8.97773
Epoch [8/10], Step [7800/12942], Loss: 2.2011, Perplexity: 9.03533
Epoch [8/10], Step [7900/12942], Loss: 2.2119, Perplexity: 9.13295
Epoch [8/10], Step [8000/12942], Loss: 2.0483, Perplexity: 7.75512
Epoch [8/10], Step [8100/12942], Loss: 1.8893, Perplexity: 6.61517
Epoch [8/10], Step [8200/12942], Loss: 2.1489, Perplexity: 8.57508
Epoch [8/10], Step [8300/12942], Loss: 2.4882, Perplexity: 12.

Epoch [10/10], Step [5300/12942], Loss: 2.0209, Perplexity: 7.54513
Epoch [10/10], Step [5400/12942], Loss: 2.1191, Perplexity: 8.32346
Epoch [10/10], Step [5500/12942], Loss: 2.1146, Perplexity: 8.28653
Epoch [10/10], Step [5600/12942], Loss: 1.9655, Perplexity: 7.13859
Epoch [10/10], Step [5700/12942], Loss: 2.1341, Perplexity: 8.44980
Epoch [10/10], Step [5800/12942], Loss: 2.2284, Perplexity: 9.28530
Epoch [10/10], Step [5900/12942], Loss: 2.5382, Perplexity: 12.6569
Epoch [10/10], Step [6000/12942], Loss: 2.2623, Perplexity: 9.60566
Epoch [10/10], Step [6100/12942], Loss: 1.9360, Perplexity: 6.93086
Epoch [10/10], Step [6200/12942], Loss: 2.1702, Perplexity: 8.76047
Epoch [10/10], Step [6300/12942], Loss: 2.4349, Perplexity: 11.4142
Epoch [10/10], Step [6400/12942], Loss: 2.3649, Perplexity: 10.6432
Epoch [10/10], Step [6500/12942], Loss: 2.1212, Perplexity: 8.34112
Epoch [10/10], Step [6600/12942], Loss: 1.8315, Perplexity: 6.24318
Epoch [10/10], Step [6700/12942], Loss: 1.7112, 
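
The perplexity column in the log is the exponential of the cross-entropy loss, as computed by `np.exp(loss.item())` in the training loop. A minimal sketch that reproduces the first logged value using only the standard library (the `perplexity` helper is introduced here for illustration):

```python
import math


def perplexity(cross_entropy_loss):
    """Perplexity is the exponential of the (natural-log) cross-entropy loss."""
    return math.exp(cross_entropy_loss)


# First logged step: Loss 4.5089, Perplexity 90.8187.
# The result agrees with the log up to the rounding of the printed loss.
print(perplexity(4.5089))
```

A perplexity near 7 late in training means the decoder is, on average, about as uncertain as if it were choosing uniformly among roughly seven words at each position.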