In [3]:
import ast
import glob
import os
import pickle
import string
from itertools import chain

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from nltk.tokenize import word_tokenize, sent_tokenize
from numpy import array
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, models

### Loading the data
Dataset: Flickr8k https://www.kaggle.com/shadabhussain/flickr8k

In [4]:
def load_doc(filename):
    """Read a text file and return its entire contents as one string.

    Args:
        filename (str): Path of the file to read.

    Returns:
        str: The full text of the file.
    """
    # The context manager closes the handle even when read() raises,
    # which the manual open()/close() pair did not guarantee.
    with open(filename, 'r') as file:
        return file.read()

# Read the raw caption file into one string and preview its first 300
# characters to check the line layout.
filename = "Flickr8k_text/Flickr8k.token.txt"
doc = load_doc(filename)
print(doc[:300])

1000268201_693b08cb0e.jpg#0	A child in a pink dress is climbing up a set of stairs in an entry way .
1000268201_693b08cb0e.jpg#1	A girl going into a wooden building .
1000268201_693b08cb0e.jpg#2	A little girl climbing into a wooden playhouse .
1000268201_693b08cb0e.jpg#3	A little girl climbing the s


In [11]:
def load_descriptions(doc):
    """Parse the raw caption text into a DataFrame with one row per caption.

    Every usable line is expected to look like
    ``<image file>#<caption number><whitespace><caption words...>``.

    Args:
        doc (str): Full contents of the caption file.

    Returns:
        pandas.DataFrame: Three columns labelled ``0`` (image file name),
        ``1`` (caption number, kept as a string) and ``2`` (caption text with
        its tokens re-joined by single spaces). Input without any usable line
        yields an empty frame with the same three columns.

    Raises:
        IndexError: If the first token of a line contains no ``#``.
    """
    # NOTE(review): ImageCaptionDataset reads 'img_name' and 'captions'
    # columns, so columns 0 and 2 must be renamed before this frame is
    # passed to it.
    records = []
    # process lines
    for line in doc.split('\n'):
        # split line by white space
        tokens = line.split()
        # Skip blank / one-character lines and lines made only of whitespace
        # (the latter used to crash on tokens[0]).
        if len(line) < 2 or not tokens:
            continue
        # take the first token as the image id, the rest as the description
        image_id_no, image_desc = tokens[0], tokens[1:]
        # "<file>#<n>" -> file name and caption number
        id_parts = image_id_no.split('#')
        image_id, caption_no = id_parts[0], id_parts[1]
        # convert description tokens back to string
        records.append((image_id, caption_no, ' '.join(image_desc)))
    # Build the frame once, after the loop: rebuilding it on every line was
    # quadratic and left `df` unbound when no line was usable.
    return pd.DataFrame(records, columns=range(3))

# Parse the raw caption text into a DataFrame, one row per caption line.
df = load_descriptions(doc)

In [13]:
# Preview the first parsed rows to sanity-check the three columns.
df.head()

Unnamed: 0,0,1,2
0,1000268201_693b08cb0e.jpg,0,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,1,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,2,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,3,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,4,A little girl in a pink dress going into a woo...


### Creating PyTorch Dataset
The code below is based on https://github.com/Mdhvince/Image_Captioning and https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning

In [None]:
class ImageCaptionDataset(Dataset):
    """Image Caption dataset.

    Each item is a dict ``{'image': ..., 'caption': ...}`` holding the decoded
    image and its caption encoded as a fixed-length array of integer token ids.
    """

    def __init__(self, df, root_dir, mapper_file, max_seq_length=20, transform=None):
        """
        Args:
            df (pandas.DataFrame): Must provide an 'img_name' column (image
                file names relative to ``root_dir``) and a 'captions' column
                (raw caption strings).
            root_dir (string): Directory with all the images.
            mapper_file (string): Path to a pickle holding the token -> integer
                id mapping. It must contain '<start>', '<end>', '<PAD>' and
                every token that occurs in the captions.
            max_seq_length (int): Length every encoded caption is padded or
                truncated to.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.max_seq_length = max_seq_length

        print("Reading data...")
        # Work on a re-indexed copy: the caller's frame is not modified by the
        # 'length' column added below, and positional indices handed out by a
        # sampler always match the rows regardless of the original index.
        self.df = df.reset_index(drop=True)
        self.captions_column = self.df['captions']
        self.img_name_column = self.df['img_name']

        print("Calculating length...")
        self.df['length'] = self.captions_column.apply(lambda x: len(x.split()))
        self.length_column = self.df['length']

        self.root_dir = root_dir
        self.transform = transform

        print("Reading Mapper file...")
        # Honour the path passed by the caller (it used to be ignored in
        # favour of a hard-coded 'mapping.pkl').
        # NOTE: pickle.load can execute arbitrary code - only load trusted files.
        with open(mapper_file, 'rb') as f:
            self.mapper_file = pickle.load(f)

        print("Ready !")

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            dict: ``'image'`` is the decoded RGB image as a numpy array and
            ``'caption'`` is an int64 array of length ``max_seq_length``;
            the whole dict is passed through ``transform`` when one is set.

        Raises:
            KeyError: If a caption token is missing from the mapping.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # take the image name contained in the data frame
        image_name = os.path.join(self.root_dir, self.img_name_column.iloc[idx])

        # Decode with PIL, which this notebook already imports; the previously
        # used `mpimg` was never imported. Forcing RGB also gives grayscale
        # files three channels.
        # NOTE(review): assumes 8-bit JPEG input as in Flickr8k - confirm
        # before feeding other formats.
        with Image.open(image_name) as img:
            image = np.array(img.convert('RGB'))

        # read the caption, lower-case and tokenize it
        caption = self.captions_column.iloc[idx].lower()
        tokens = word_tokenize(caption)

        # wrap with the sequence markers, then map every token to its integer id
        caption = ['<start>', *tokens, '<end>']
        caption = [self.mapper_file[token] for token in caption]

        # pad / truncate to the fixed length
        caption = self.pad_data(caption)

        sample = {'image': image, 'caption': caption}

        if self.transform:
            sample = self.transform(sample)

        return sample

    def pad_data(self, s):
        """Pad ``s`` with the '<PAD>' id, or truncate it, to ``max_seq_length``.

        Truncation keeps the first ``max_seq_length`` ids, so the trailing
        '<end>' id of an over-long caption is dropped.

        Args:
            s (list[int]): Encoded caption.

        Returns:
            numpy.ndarray: int64 array of shape ``(max_seq_length,)``.
        """
        padded = np.full((self.max_seq_length,), self.mapper_file['<PAD>'], dtype=np.int64)
        keep = min(len(s), self.max_seq_length)
        padded[:keep] = s[:keep]
        return padded

In [None]:
def create_dataset(df, root_dir, mapper_file):
    """Build the captioning dataset and load its vocabulary.

    Args:
        df: Frame of image names and captions handed to ImageCaptionDataset.
        root_dir: Directory that contains the image files.
        mapper_file: Path of the pickled vocabulary mapping.

    Returns:
        tuple: ``(dataset, vocab)`` - the ImageCaptionDataset and the object
        unpickled from ``mapper_file``.
    """
    # Unpickle the vocabulary first so it can be returned next to the dataset.
    with open(mapper_file, 'rb') as handle:
        vocab = pickle.load(handle)

    # NOTE(review): Rescale, Normalize and ToTensor are referenced by bare
    # name and are not defined in the visible notebook - presumably custom
    # transforms that accept the sample dict; confirm they are in scope.
    pipeline = transforms.Compose([Rescale(224), Normalize(), ToTensor()])

    dataset = ImageCaptionDataset(
        df=df,
        root_dir=root_dir,
        mapper_file=mapper_file,
        transform=pipeline,
    )
    return dataset, vocab


## Train Valid Split

In [None]:
def train_valid_split(training_set, validation_size, seed=None):
    """Split a dataset into random training and validation samplers.

    Args:
        training_set: Any sized collection; only ``len()`` is used.
        validation_size (float): Fraction of the samples, between 0 and 1,
            reserved for validation.
        seed (int, optional): When given, the shuffle uses a private random
            generator seeded with this value so the split is reproducible.
            When omitted the global NumPy random state is used.

    Returns:
        tuple: ``(train_sampler, valid_sampler)``, two SubsetRandomSampler
        objects over disjoint index sets that together cover the dataset.

    Raises:
        ValueError: If ``validation_size`` lies outside ``[0, 1]``.
    """
    if not 0 <= validation_size <= 1:
        # A fraction outside [0, 1] used to produce a silently degenerate split.
        raise ValueError(
            f"validation_size must be between 0 and 1, got {validation_size}")

    # obtain training indices that will be used for validation
    num_train = len(training_set)
    indices = list(range(num_train))
    if seed is None:
        np.random.shuffle(indices)
    else:
        np.random.RandomState(seed).shuffle(indices)
    split = int(np.floor(validation_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    return train_sampler, valid_sampler

In [None]:
# Both loaders read from the same dataset; the two samplers restrict each of
# them to its own disjoint subset of indices.
# NOTE(review): train_set, valid_size, batch_size and num_workers are not
# defined in the visible notebook - they must exist before this cell runs.
train_sampler, valid_sampler = train_valid_split(train_set, valid_size)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_set, sampler=train_sampler, **loader_kwargs)
valid_loader = DataLoader(train_set, sampler=valid_sampler, **loader_kwargs)

### CNN Image Encoder using ResNet

In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: a frozen pretrained ResNet-50 plus a trainable projection.

    The ResNet's final fully connected layer is removed and replaced by a
    linear layer mapping the pooled features to ``embed_size`` so they can be
    fed to the caption decoder. Only that linear layer is trainable.
    """

    def __init__(self, embed_size):
        """
        Args:
            embed_size (int): Size of the feature vector produced per image.
        """
        super(EncoderCNN, self).__init__()

        # import pre trained model
        resnet = models.resnet50(pretrained=True)

        # freeze the backbone: its weights are never updated
        for param in resnet.parameters():
            param.requires_grad_(False)

        # remove last fully connected layer
        modules = list(resnet.children())[:-1]

        # build the new resnet
        self.resnet = nn.Sequential(*modules)
        # A frozen backbone must also keep its BatchNorm running statistics
        # fixed, so it stays in eval mode from the start.
        self.resnet.eval()

        # our additional fully connected layer with an output = the embed size
        # to feed the rnn
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        Without this, ``encoder.train()`` would let the BatchNorm layers of the
        frozen ResNet keep updating their running statistics.
        """
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images (torch.Tensor): Image batch accepted by the ResNet backbone
                (assumes shape (batch, 3, H, W) - TODO confirm against the
                transform pipeline).

        Returns:
            torch.Tensor: Features of shape ``(batch, embed_size)``.
        """
        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            features = self.resnet(images)

        # flatten for our additional fc layer
        features = features.view(features.size(0), -1)

        # project to the embedding size expected by the decoder
        return self.embed(features)

### LSTM Decoder to generate captions

In [None]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder.

    The image feature vector is fed as the first input step, followed by the
    embedded caption tokens; each output step is projected onto the vocabulary.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        """
        Args:
            embed_size (int): Size of the word embeddings and of the image
                feature vector.
            hidden_size (int): Size of the LSTM hidden state.
            vocab_size (int): Number of tokens in the vocabulary.
            num_layers (int): Number of stacked LSTM layers.
        """
        super().__init__()

        self.hidden_dim = hidden_size
        # Kept so init_hidden can build states that match the LSTM depth.
        self.num_layers = num_layers

        # Our embedding layer
        self.word_embeddings = nn.Embedding(vocab_size, embed_size)

        self.lstm = nn.LSTM(embed_size, self.hidden_dim, num_layers, batch_first=True)

        # The linear layer maps the hidden state output of the LSTM to the number of words we want:
        # vocab_size
        self.linear = nn.Linear(self.hidden_dim, vocab_size)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero-initialised ``(h0, c0)`` for the LSTM.

        Args:
            batch_size (int): Number of sequences in the batch.
            device (torch.device, optional): Where to allocate the states;
                defaults to PyTorch's default device.
            dtype (torch.dtype, optional): Dtype of the states; defaults to
                PyTorch's default dtype.

        Returns:
            tuple: Two tensors of shape ``(num_layers, batch_size, hidden_dim)``.
        """
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device, dtype=dtype),
                torch.zeros(shape, device=device, dtype=dtype))

    def forward(self, features, captions):
        """Predict a token distribution for every caption position.

        Args:
            features (torch.Tensor): Image features, ``(batch, embed_size)``.
            captions (torch.Tensor): Integer token ids, ``(batch, seq_len)``.

        Returns:
            torch.Tensor: Scores of shape ``(batch, seq_len, vocab_size)``;
            step ``t`` is produced after seeing the image and tokens ``< t``.
        """
        # Make sure that features shape are: batch_size, embed_size
        batch_size = features.shape[0]

        # Initialize the hidden state on the same device/dtype as the inputs,
        # otherwise the LSTM fails once the model lives on the GPU.
        self.hidden = self.init_hidden(
            batch_size, device=features.device, dtype=features.dtype)

        # Create embedded word vectors for each word in the captions
        embeddings = self.word_embeddings(captions)

        # Stack the features and captions: the image is the first time step
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)

        lstm_out, self.hidden = self.lstm(embeddings, self.hidden)

        out = self.linear(lstm_out)

        # The input has seq_len + 1 steps; drop the last output so the
        # predictions line up one-to-one with the caption tokens.
        out = out[:, :-1]

        return out

### Training and validation

In [None]:
# Build the two halves of the captioning model.
# NOTE(review): embed_size, hidden_size and vocab_size are not defined in the
# visible notebook - they must be set before this cell runs.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move to GPU, if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = encoder.to(device)
decoder = decoder.to(device)

# NOTE(review): the loss also counts '<PAD>' positions; consider passing
# ignore_index=<PAD id> once the vocabulary is available here.
criterion = nn.CrossEntropyLoss().to(device)
# Only the decoder and the encoder's projection layer are optimised; the
# ResNet backbone is frozen.
params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)

# Best validation loss seen so far; starts at +inf so the first epoch always
# saves a checkpoint. (`np.Inf` was removed in NumPy 2.0; `np.inf` works on
# every version.)
valid_loss_min = np.inf

In [None]:

# Train for n_epochs, validating after every epoch and checkpointing whenever
# the validation loss does not get worse.
# NOTE(review): n_epochs, vocab_size and save_location_path are not defined in
# the visible notebook - they must be set before this cell runs.
for epoch in range(1, n_epochs+1):

    # Keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0

    encoder.train()
    decoder.train()
    for data in train_loader:
        # Tensor.to() returns a new tensor, so the result must be kept;
        # the discarded calls left every batch on the CPU.
        images = data['image'].to(device, dtype=torch.float32)
        captions = data['caption'].to(device)

        decoder.zero_grad()
        encoder.zero_grad()

        features = encoder(images)
        outputs = decoder(features, captions)

        loss = criterion(outputs.contiguous().view(-1, vocab_size), captions.view(-1))
        loss.backward()
        optimizer.step()

        # weight the batch-mean loss by the batch size -> per-sample sum
        train_loss += loss.item()*images.size(0)

    encoder.eval()
    decoder.eval()
    # No gradients are needed for validation.
    with torch.no_grad():
        for data in valid_loader:
            images = data['image'].to(device, dtype=torch.float32)
            captions = data['caption'].to(device)

            features = encoder(images)
            outputs = decoder(features, captions)

            loss = criterion(outputs.contiguous().view(-1, vocab_size), captions.view(-1))

            valid_loss += loss.item()*images.size(0)

    # Average losses once per epoch, after both loops. The sums are weighted
    # per sample, so divide by the number of samples each sampler yields
    # (guarded against an empty split), not by the number of batches.
    train_loss = train_loss/max(len(train_loader.sampler), 1)
    valid_loss = valid_loss/max(len(valid_loader.sampler), 1)

    print(f"Epoch: {epoch} \tTraining Loss: {train_loss} \tValidation Loss: {valid_loss}")

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print(f"Validation loss decreased ({valid_loss_min} --> {valid_loss}).  Saving model ...")
        # f-strings so {n_epochs} is actually substituted into the file name
        torch.save(encoder.state_dict(), os.path.join(save_location_path, f'encoder{n_epochs}.pt'))
        torch.save(decoder.state_dict(), os.path.join(save_location_path, f'decoder{n_epochs}.pt'))
        valid_loss_min = valid_loss

