In [1]:
# Clone the project repository into the Kaggle working directory; the local
# modules imported later (datasets, models, model2, utils) are presumably provided by it.
!git clone https://github.com/moaaztaha/Image-Captioning
py_files_path = '/kaggle/working/Image-Captioning/'
import sys
# Put the cloned directory on the module search path so its .py files can be imported
sys.path.append(py_files_path)

Cloning into 'Image-Captioning'...
remote: Enumerating objects: 191, done.[K
remote: Counting objects: 100% (191/191), done.[K
remote: Compressing objects: 100% (145/145), done.[K
remote: Total 191 (delta 107), reused 125 (delta 43), pack-reused 0[K
Receiving objects: 100% (191/191), 10.93 MiB | 16.53 MiB/s, done.
Resolving deltas: 100% (107/107), done.


In [2]:
# Re-import modified modules automatically before each cell execution
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook
%matplotlib inline

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms


import pandas as pd
import numpy as np
import spacy

from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import time, math, random

from datasets import build_vocab, get_loaders
from models import EncoderCNN
from model2 import Img2Seq, DecoderRNN
from utils import train, evaluate, epoch_time, print_examples, predict_test, print_scores
from utils import get_test_data

%matplotlib inline

In [4]:
# Make results reproducible: seed every random number generator used in this notebook
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Ask cuDNN to select deterministic algorithms
torch.backends.cudnn.deterministic = True

In [5]:
# Paths used on Kaggle: the Flickr8k images from the input dataset, plus the
# CSV files and test examples that live inside the cloned repository.
#MODEL_PATH = 'models/splits.pth'
IMAGES_PATH = '../input/flickr8k/Images/'
DF_PATH = '/kaggle/working/Image-Captioning/data.csv'
TEST_DF_PATH = '/kaggle/working/Image-Captioning/test.csv'
TEST_EXAMPLES_PATH = '/kaggle/working/Image-Captioning/test_examples/'

# Relative-path alternatives, presumably for running outside Kaggle (kept disabled)
# IMAGES_PATH = 'flickr/Images/'
# DF_PATH = 'data.csv'
# TEST_DF_PATH = 'test.csv'
# TEST_EXAMPLES_PATH = 'test_examples/'

In [6]:
# Build the vocabulary from the file at DF_PATH (see build_vocab in datasets)
vocab = build_vocab(DF_PATH)
# Index of the '<pad>' token; passed to the loss as ignore_index in a later cell
pad_idx = vocab.stoi['<pad>']

In [7]:
# Model hyperparameters handed to the encoder/decoder constructors below
HID_DIM = 256
EMB_DIM = 256
DROPOUT = .5
VOCAB_LENGTH = len(vocab)
# Value assigned to requires_grad for the non-fc Inception parameters in a later cell
TRAIN_CNN = False
# Batch size (used by get_loaders) and learning rate (used by the Adam optimizer)
bs = 256
lr = 3e-3

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

enc = EncoderCNN(HID_DIM, DROPOUT)
dec = DecoderRNN(EMB_DIM, HID_DIM, VOCAB_LENGTH, DROPOUT)

# NOTE(review): teacher_forcing_ratio=0 presumably disables teacher forcing in Img2Seq — confirm in model2
model = Img2Seq(enc, dec, device, teacher_forcing_ratio=0).to(device)

Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [8]:
# transforms 
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [9]:
# Cross-entropy loss; target positions equal to pad_idx are ignored
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# NOTE: the optimizer receives every model parameter here, before the
# requires_grad flags are adjusted by the loop below
optimizer = optim.Adam(model.parameters(), lr=lr)

# Configure which Inception parameters are trainable: the fc layer's weight and
# bias always are; every other Inception parameter only when TRAIN_CNN is True
for name, param in model.encoder.inception.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
        param.requires_grad = True
    else:
        param.requires_grad = TRAIN_CNN

In [10]:
def count_parameters(model):
    """Return the total element count of `model`'s parameters that have requires_grad set."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many parameters currently have requires_grad=True
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 5,688,429 trainable parameters


In [11]:
# Create the training and validation loaders via get_loaders (from datasets)
train_loader, valid_loader = get_loaders(bs, IMAGES_PATH, DF_PATH, transform, vocab)

Dataset split: train
Unique Image: 6000
Size: 30000
Dataset split: val
Unique Image: 1000
Size: 5000


In [12]:
def train(model, iterator, optimizer, criterion, clip):
    """Run one training pass over `iterator` and return the mean loss per batch.

    NOTE: this definition shadows the `train` imported from `utils` earlier in
    the notebook.

    Args:
        model: module exposing a `.device` attribute; it is called as
            `model(imgs, captions)` and its result is indexed as a 3-D tensor.
        iterator: sized iterable yielding `(imgs, captions)` batches.
        optimizer: optimizer that is zeroed and stepped once per batch.
        criterion: loss applied to the flattened outputs and flattened captions.
        clip: maximum gradient norm passed to `clip_grad_norm_`.

    Returns:
        The sum of the per-batch loss values divided by `len(iterator)`.
    """
    model.train()

    n_batches = len(iterator)
    running_loss = 0

    progress = tqdm(iterator, total=n_batches, position=0, leave=False, desc="training")
    for batch_imgs, batch_caps in progress:
        optimizer.zero_grad()

        batch_imgs = batch_imgs.to(model.device)
        batch_caps = batch_caps.to(model.device)

        # outputs = [trg len, batch size, output dim]
        outputs = model(batch_imgs, batch_caps)

        # Collapse the first two axes so the loss sees one row of scores per token
        flat_scores = outputs.reshape(-1, outputs.shape[2])
        flat_targets = batch_caps.reshape(-1)
        loss = criterion(flat_scores, flat_targets)

        loss.backward()

        # Clip the gradient norm before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / n_batches


def evaluate(model, iterator, criterion):
    """Score `model` on every batch of `iterator` and return the mean loss per batch.

    NOTE: this definition shadows the `evaluate` imported from `utils` earlier
    in the notebook.

    Args:
        model: module exposing a `.device` attribute; it is called as
            `model(images, captions)` and its result is indexed as a 3-D tensor.
        iterator: sized iterable yielding `(images, captions)` batches.
        criterion: loss applied to the flattened outputs and flattened captions.

    Returns:
        The sum of the per-batch loss values divided by `len(iterator)`.
    """
    model.eval()

    n_batches = len(iterator)
    running_loss = 0

    # No gradients are tracked while scoring
    with torch.no_grad():
        progress = tqdm(iterator, total=n_batches, position=0, leave=False, desc="Evaluating")
        for batch_imgs, batch_caps in progress:
            batch_imgs = batch_imgs.to(model.device)
            batch_caps = batch_caps.to(model.device)

            # outputs = [trg len, batch size, output dim]
            outputs = model(batch_imgs, batch_caps)

            flat_scores = outputs.reshape(-1, outputs.shape[2])
            flat_targets = batch_caps.reshape(-1)
            running_loss += criterion(flat_scores, flat_targets).item()

    return running_loss / n_batches



In [13]:
# Number of epochs to run, and the max gradient norm handed to train()
N_EPOCHS = 15
CLIP = 1

# Lowest validation loss seen so far; decides when a checkpoint is written
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_loader, criterion)
    
    end_time = time.time()
    
    # presumably splits the elapsed time into minutes and seconds — see epoch_time in utils
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Save the weights only when the validation loss beats the best so far
    if valid_loss < best_valid_loss:
        print('Model Saved!!')
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'gru_no_tf.pth')
        
    # Perplexity is reported as exp(loss)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

                                                           

Model Saved!!
Epoch: 01 | Time: 4m 42s
	Train Loss: 5.340 | Train PPL: 208.494
	 Val. Loss: 4.932 |  Val. PPL: 138.696


                                                           

Model Saved!!
Epoch: 02 | Time: 4m 23s
	Train Loss: 4.982 | Train PPL: 145.708
	 Val. Loss: 4.858 |  Val. PPL: 128.824


                                                           

Model Saved!!
Epoch: 03 | Time: 4m 22s
	Train Loss: 4.918 | Train PPL: 136.796
	 Val. Loss: 4.827 |  Val. PPL: 124.824


                                                           

Model Saved!!
Epoch: 04 | Time: 4m 21s
	Train Loss: 4.879 | Train PPL: 131.509
	 Val. Loss: 4.811 |  Val. PPL: 122.793


                                                           

Model Saved!!
Epoch: 05 | Time: 4m 21s
	Train Loss: 4.854 | Train PPL: 128.282
	 Val. Loss: 4.806 |  Val. PPL: 122.256


                                                           

Model Saved!!
Epoch: 06 | Time: 4m 21s
	Train Loss: 4.833 | Train PPL: 125.543
	 Val. Loss: 4.789 |  Val. PPL: 120.227


                                                           

Model Saved!!
Epoch: 07 | Time: 4m 22s
	Train Loss: 4.819 | Train PPL: 123.823
	 Val. Loss: 4.784 |  Val. PPL: 119.596


                                                           

Model Saved!!
Epoch: 08 | Time: 4m 20s
	Train Loss: 4.810 | Train PPL: 122.698
	 Val. Loss: 4.781 |  Val. PPL: 119.277


                                                           

Model Saved!!
Epoch: 09 | Time: 4m 22s
	Train Loss: 4.803 | Train PPL: 121.856
	 Val. Loss: 4.776 |  Val. PPL: 118.622


                                                           

Epoch: 10 | Time: 4m 23s
	Train Loss: 4.793 | Train PPL: 120.657
	 Val. Loss: 4.776 |  Val. PPL: 118.686


                                                           

Epoch: 11 | Time: 4m 21s
	Train Loss: 4.786 | Train PPL: 119.782
	 Val. Loss: 4.784 |  Val. PPL: 119.639


                                                           

Model Saved!!
Epoch: 12 | Time: 4m 22s
	Train Loss: 4.779 | Train PPL: 118.930
	 Val. Loss: 4.776 |  Val. PPL: 118.621


                                                           

Epoch: 13 | Time: 4m 23s
	Train Loss: 4.775 | Train PPL: 118.532
	 Val. Loss: 4.781 |  Val. PPL: 119.277


                                                         

KeyboardInterrupt: 

### Download the trained model
<a href='./gru_no_tf.pth'>download model</a>