In [4]:
import re
import shutil
import os
import random
import pickle
import pandas as pd
import numpy as np
import spacy
import re

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
#from torch.nn.utils.rnn import pad_sequence

from collections import Counter
import nltk


import skimage.io
import skimage.transform

## preprocess

### select_7k_images

In [17]:
def select_7k_images(c_type='humor'):
    '''8k -> 7k'''
    # open data/type/train.p
    img_lst = pickle.load(open( "data/FlickrStyle_v0.9/humor/train.p", "rb" ) )
    
    # copy imgs
    for img_name in img_lst:
        shutil.copyfile('data/Flicker8k_Dataset/' + img_name,
                        'data/Flickr7k/' + img_name)



In [None]:
# Run for first time 
select_7k_images(c_type='humor')

### select_factual_captions

In [5]:
flickr8k_filename = "data/Flickr8k_text/Flickr8k.token.txt"

In [6]:
img_id_lst = pickle.load(open( "data/FlickrStyle_v0.9/humor/train.p", "rb" ) )

In [7]:
# get filenames in flickr7k_images
filenames = os.listdir('data/Flickr7k/')
# open factual caption: Flickr8k.token.txt
with open(flickr8k_filename, 'r') as f:
    res = f.readlines()

# write out
with open('data/factual_train.txt', 'w') as f:
    r = re.compile(r'#\d*')
    for line in res:
        img_id = r.split(line)[0]
        if img_id in img_id_lst:
            f.write(line)

### random_select_test_images

In [38]:

def random_select_test_images(num=100):
    '''select test images randomly'''
    # get filenames in flickr7k, 30k_images
    filenames_7k = os.listdir('data/Flickr7k/')
    filenames_30k = os.listdir('data/Flicker8k_Dataset')

    filenames = list(set(filenames_30k) - set(filenames_7k))
    print("img_num: " + str(len(filenames)))
    random.seed(24)
    selected = random.sample(filenames, num)

    # copy images
    for img_name in selected:
        shutil.copyfile('data/Flicker8k_Dataset/' + img_name,
                        'data/test_images/' + img_name)

In [8]:
# Run for first time
random_select_test_images(num=100)

## Build vocab

In [9]:
class Vocab:
    '''vocabulary'''
    def __init__(self):
        self.w2i = {}
        self.i2w = {}
        self.ix = 0

    def add_word(self, word):
        if word not in self.w2i:
            self.w2i[word] = self.ix
            self.i2w[self.ix] = word
            self.ix += 1

    def __call__(self, word):
        if word not in self.w2i:
            return self.w2i['<unk>']
        return self.w2i[word]

    def __len__(self):
        return len(self.w2i)

In [10]:
def build_vocab(mode_list=['factual', 'humorous']):
    '''build vocabulary'''
    # define vocabulary
    vocab = Vocab()
    # add special tokens
    vocab.add_word('<pad>')
    vocab.add_word('<s>')
    vocab.add_word('</s>')
    vocab.add_word('<unk>')

    # add words
    for mode in mode_list:
        if mode == 'factual':
            captions = extract_captions(mode=mode)
            words = nltk.tokenize.word_tokenize(captions)
            counter = Counter(words)
            words = [word for word, cnt in counter.items() if cnt >= 2]
        else:
            captions = extract_captions(mode=mode)
            words = nltk.tokenize.word_tokenize(captions)

        for word in words:
            vocab.add_word(word)

    return vocab

In [11]:
with open("data/factual_train.txt", 'r') as f:
    res = f.readlines()

In [12]:
def extract_captions(mode='factual'):
    '''extract captions from data files for building vocabulary'''
    text = ''
    if mode == 'factual':
        with open("data/factual_train.txt", 'r') as f:
            res = f.readlines()

    elif mode == 'humorous':
        with open("data/FlickrStyle_v0.9/humor/funny_train.txt", 'r') as f:
            res = f.readlines()
    else:
        with open("data/FlickrStyle_v0.9/romantic/romantic_train.txt", 'r') as f:
            res = f.readlines()

    for line in res:
        line = line.replace('.', '')
        line = line.strip()
        text += line + ' '

    return text.strip().lower()

In [13]:
vocab = build_vocab(mode_list=['factual', 'humorous'])
print(vocab.__len__())
with open('data/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)

14889


## Data Loader

In [14]:

class Flickr7kDataset(Dataset):
    '''Flickr7k dataset'''
    def __init__(self, img_dir, caption_file, vocab, transform=None):
        '''
        Args:
            img_dir: Direcutory with all the images
            caption_file: Path to the factual caption file
            vocab: Vocab instance
            transform: Optional transform to be applied
        '''
        self.img_dir = img_dir
        self.imgname_caption_list = self._get_imgname_and_caption(caption_file)
        self.vocab = vocab
        self.transform = transform

    def _get_imgname_and_caption(self, caption_file):
        '''extract image name and caption from factual caption file'''
        with open(caption_file, 'r') as f:
            res = f.readlines()

        imgname_caption_list = []
        r = re.compile(r'#\d*')
        for line in res:
            img_and_cap = r.split(line)
            img_and_cap = [x.strip() for x in img_and_cap]
            imgname_caption_list.append(img_and_cap)

        return imgname_caption_list

    def __len__(self):
        return len(self.imgname_caption_list)

    def __getitem__(self, ix):
        '''return one data pair (image and captioin)'''
        img_name = self.imgname_caption_list[ix][0]
        img_name = os.path.join(self.img_dir, img_name)
        caption = self.imgname_caption_list[ix][1]

        image = skimage.io.imread(img_name)
        if self.transform is not None:
            image = self.transform(image)

        # convert caption to word ids
        r = re.compile("\.")
        tokens = nltk.tokenize.word_tokenize(r.sub("", caption).lower())
        caption = []
        caption.append(self.vocab('<s>'))
        caption.extend([self.vocab(token) for token in tokens])
        caption.append(self.vocab('</s>'))
        caption = torch.Tensor(caption)
        return image, caption

In [15]:
def collate_fn(data):
    '''create minibatch tensors from data(list of tuple(image, caption))'''
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # images : tuple of 3D tensor -> 4D tensor
    images = torch.stack(images, 0)

    # captions : tuple of 1D Tensor -> 2D tensor
    lengths = torch.LongTensor([len(cap) for cap in captions])
    captions = [pad_sequence_dl(cap, max(lengths)) for cap in captions]
    captions = torch.stack(captions, 0)

    return images, captions, lengths

In [16]:
def pad_sequence_dl(seq, max_len):
    seq = torch.cat((seq, torch.zeros(max_len - len(seq))))
    return seq

In [17]:
def get_data_loader(img_dir, caption_file, vocab, batch_size,
                    transform=None, shuffle=False, num_workers=0):
    '''Return data_loader'''
    if transform is None:
        transform = transforms.Compose([
            Rescale((224, 224)),
            transforms.ToTensor()
            ])

    flickr7k = Flickr7kDataset(img_dir, caption_file, vocab, transform)

    data_loader = DataLoader(dataset=flickr7k,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers,
                             collate_fn=collate_fn)
    return data_loader

In [18]:
class FlickrStyle7kDataset(Dataset):
    '''Styled caption dataset'''
    def __init__(self, caption_file, vocab):
        '''
        Args:
            caption_file: Path to styled caption file
            vocab: Vocab instance
        '''
        self.caption_list = self._get_caption(caption_file)
        self.vocab = vocab

    def _get_caption(self, caption_file):
        '''extract caption list from styled caption file'''
        with open(caption_file, 'r') as f:
            caption_list = f.readlines()

        caption_list = [x.strip() for x in caption_list]
        return caption_list

    def __len__(self):
        return len(self.caption_list)

    def __getitem__(self, ix):
        caption = self.caption_list[ix]
        # convert caption to word ids
        r = re.compile("\.")
        tokens = nltk.tokenize.word_tokenize(r.sub("", caption).lower())
        caption = []
        caption.append(self.vocab('<s>'))
        caption.extend([self.vocab(token) for token in tokens])
        caption.append(self.vocab('</s>'))
        caption = torch.Tensor(caption)
        return caption

In [19]:
def collate_fn_styled(captions):
    captions.sort(key=lambda x: len(x), reverse=True)

    # tuple of 1D Tensor -> 2D Tensor
    lengths = torch.LongTensor([len(cap) for cap in captions])
    captions = [pad_sequence_dl(cap, max(lengths)) for cap in captions]
    captions = torch.stack(captions, 0)

    return captions, lengths


In [20]:
def get_styled_data_loader(caption_file, vocab, batch_size,
                           shuffle=False, num_workers=0):
    '''Return data_loader for styled caption'''
    flickr_styled_7k = FlickrStyle7kDataset(caption_file, vocab)

    data_loader = DataLoader(dataset=flickr_styled_7k,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers,
                             collate_fn=collate_fn_styled)
    return data_loader


In [21]:
class Rescale:
    '''Rescale the image to a given size
    Args:
        output_size(int or tuple)
    '''
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image):
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        image = skimage.transform.resize(image, (new_h, new_w))

        return image


In [22]:
with open("data/vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)

img_path = "data/Flickr7k"
cap_path = "data/factual_train.txt"
cap_path_styled = "data/FlickrStyle_v0.9/humor/funny_train.txt"
data_loader = get_data_loader(img_path, cap_path, vocab, 3)
styled_data_loader = get_styled_data_loader(cap_path_styled, vocab, 3)


In [None]:
for i, (captions, lengths) in enumerate(styled_data_loader):
    print(i)
    # print(images.shape)
    print(captions[:, 1:])
    print(lengths - 1)
    print()
    if i == 3:
        break

## Loss

In [27]:
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.range(0, max_len - 1).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def masked_cross_entropy(logits, target, length):
    length = Variable(length)
    if torch.cuda.is_available():
        length = length.cuda()

    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.
    Returns:
        loss: An average loss value masked by the length.
    """

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = F.log_softmax(logits_flat)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss


In [28]:

length = torch.LongTensor([23, 21, 17])

print(sequence_mask(length))

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False, False, False]])


  """


## Model

In [29]:
import sys
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.autograd import Variable
#from constant import get_symbol_id

In [30]:
class EncoderCNN(nn.Module):
    def __init__(self, emb_dim):
        '''
        Load the pretrained ResNet152 and replace fc
        '''
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.A = nn.Linear(resnet.fc.in_features, emb_dim)

    def forward(self, images):
        '''Extract the image feature vectors'''
        features = self.resnet(images)
        features = Variable(features.data)
        if torch.cuda.is_available():
            features = features.cuda()
        features = features.view(features.size(0), -1)
        features = self.A(features)
        return features

In [31]:
class FactoredLSTM(nn.Module):
    def __init__(self, emb_dim, hidden_dim, factored_dim,  vocab_size):
        super(FactoredLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        # embedding
        self.B = nn.Embedding(vocab_size, emb_dim)

        # factored lstm weights
        self.U_i = nn.Linear(factored_dim, hidden_dim)
        self.S_fi = nn.Linear(factored_dim, factored_dim)
        self.V_i = nn.Linear(emb_dim, factored_dim)
        self.W_i = nn.Linear(hidden_dim, hidden_dim)

        self.U_f = nn.Linear(factored_dim, hidden_dim)
        self.S_ff = nn.Linear(factored_dim, factored_dim)
        self.V_f = nn.Linear(emb_dim, factored_dim)
        self.W_f = nn.Linear(hidden_dim, hidden_dim)

        self.U_o = nn.Linear(factored_dim, hidden_dim)
        self.S_fo = nn.Linear(factored_dim, factored_dim)
        self.V_o = nn.Linear(emb_dim, factored_dim)
        self.W_o = nn.Linear(hidden_dim, hidden_dim)

        self.U_c = nn.Linear(factored_dim, hidden_dim)
        self.S_fc = nn.Linear(factored_dim, factored_dim)
        self.V_c = nn.Linear(emb_dim, factored_dim)
        self.W_c = nn.Linear(hidden_dim, hidden_dim)

        # h - humorous
        self.S_hi = nn.Linear(factored_dim, factored_dim)
        self.S_hf = nn.Linear(factored_dim, factored_dim)
        self.S_ho = nn.Linear(factored_dim, factored_dim)
        self.S_hc = nn.Linear(factored_dim, factored_dim)

        # r - romantic
        # self.S_ri = nn.Linear(factored_dim, factored_dim)
        # self.S_rf = nn.Linear(factored_dim, factored_dim)
        # self.S_ro = nn.Linear(factored_dim, factored_dim)
        # self.S_rc = nn.Linear(factored_dim, factored_dim)

        # weight for output
        self.C = nn.Linear(hidden_dim, vocab_size)

    def forward_step(self, embedded, h_0, c_0, mode):
        # transform embedded from emb_dim --> factored_dim
        i = self.V_i(embedded)
        f = self.V_f(embedded)
        o = self.V_o(embedded)
        c = self.V_c(embedded)
        
        # factored_dim --> factored_dim
        if mode == "factual":
            i = self.S_fi(i)
            f = self.S_ff(f)
            o = self.S_fo(o)
            c = self.S_fc(c)
        elif mode == "humorous":
            i = self.S_hi(i)
            f = self.S_hf(f)
            o = self.S_ho(o)
            c = self.S_hc(c)
        # elif mode == "romantic":
        #     i = self.S_ri(i)
        #     f = self.S_rf(f)
        #     o = self.S_ro(o)
        #     c = self.S_rc(c)
        else:
            sys.stderr.write("mode name wrong!")

        i_t = F.sigmoid(self.U_i(i.double()) + self.W_i(h_0.double()))
        f_t = F.sigmoid(self.U_f(f.double()) + self.W_f(h_0.double()))
        o_t = F.sigmoid(self.U_o(o.double()) + self.W_o(h_0.double()))
        c_tilda = F.tanh(self.U_c(c.double()) + self.W_c(h_0.double()))

        c_t = f_t * c_0 + i_t * c_tilda
        h_t = o_t * c_t

        outputs = self.C(h_t)

        return outputs, h_t, c_t

    def forward(self, captions, features=None, mode="factual"):
        '''
        Args:
            features: fixed vectors from images, [batch, emb_dim]
            captions: [batch, max_len]
            mode: type of caption to generate
        '''
        batch_size = captions.size(0)
        embedded = self.B(captions)  # [batch, max_len, emb_dim]
        # concat image features and captions
        if mode == "factual":
            if features is None:
                sys.stderr.write("features is None!")
            embedded = torch.cat((features.unsqueeze(1), embedded), 1)

        # initialize hidden state
        h_t = Variable(torch.Tensor(batch_size, self.hidden_dim))
        c_t = Variable(torch.Tensor(batch_size, self.hidden_dim))
        nn.init.uniform(h_t)
        nn.init.uniform(c_t)

        if torch.cuda.is_available():
            h_t = h_t.cuda()
            c_t = c_t.cuda()

        all_outputs = []
        # iterate
        for ix in range(embedded.size(1) - 1):
            emb = embedded[:, ix, :]
            outputs, h_t, c_t = self.forward_step(emb, h_t, c_t, mode=mode)
            all_outputs.append(outputs)

        all_outputs = torch.stack(all_outputs, 1)

        return all_outputs


In [32]:
import os
import pickle
import argparse
import torch


In [33]:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

def eval_outputs(outputs, vocab):
    # outputs: [batch, max_len - 1, vocab_size]
    indices = torch.topk(outputs, 1)[1]
    indices = indices.squeeze(2)
    indices = indices.data
    for i in range(len(indices)):
        caption = [vocab.i2w[x] for x in indices[i]]
        print(caption)

In [None]:
batch_size = 50
data_loader = get_data_loader(img_path, cap_path, vocab, batch_size)
styled_data_loader = get_styled_data_loader(cap_path_styled, vocab, batch_size)

In [34]:
emb_dim = 300
hidden_dim = 512
factored_dim = 512
vocab_size = len(vocab)

In [35]:
vocab_size

14889

In [36]:
encoder = EncoderCNN(emb_dim)
decoder = FactoredLSTM(emb_dim, hidden_dim, factored_dim, vocab_size)

In [37]:
if torch.cuda.is_available():
    encoder = encoder.cuda()
    decoder = decoder.cuda()

In [38]:
lr_caption = 0.0002
lr_language = 0.0005

In [39]:
# loss and optimizer
criterion = masked_cross_entropy
cap_params = list(decoder.parameters()) + list(encoder.A.parameters())
lang_params = list(decoder.parameters())
optimizer_cap = torch.optim.Adam(cap_params, lr=lr_caption)
optimizer_lang = torch.optim.Adam(lang_params, lr=lr_language)

In [40]:
# train
total_cap_step = len(data_loader)
total_lang_step = len(styled_data_loader)
epoch_num = 1

In [41]:
model_path = 'pretrained_models'
if not os.path.exists(model_path):
    os.makedirs(model_path)

In [42]:
log_step_caption = 50
log_step_language = 10

In [43]:
encoder = encoder.double()
decoder = decoder.double()

In [44]:
def eval_outputs(outputs, vocab):
    # outputs: [batch, max_len - 1, vocab_size]
    indices = torch.topk(outputs, 1)[1]
    indices = indices.squeeze(2)
    indices = indices.data
    for i in range(len(indices)):
        caption = [vocab.i2w[x.item()] for x in indices[i]]
        print(caption)

In [45]:
for epoch in range(epoch_num):
    # caption
    for i, (images, captions, lengths) in enumerate(data_loader):
#         images = to_var(images, volatile=True)
#         captions = to_var(captions.long())
        if torch.cuda.is_available():
            images = images.cuda()
            captions = captions.cuda()
            
        # forward, backward and optimize
        decoder.zero_grad()
        encoder.zero_grad()
        features = encoder(images)
        outputs = decoder(captions.long(), features.double(), mode="factual")
        loss = criterion(outputs[:, 1:, :].contiguous(),
                         captions[:, 1:].contiguous().long(), lengths - 1)
        loss.backward()
        optimizer_cap.step()
        
        print("Epoch [%d/%d], CAP, Step [%d/%d], Loss: %.4f"
                  % (epoch+1, epoch_num, i, total_cap_step,
                      loss.data.mean()))
        
        #if i == 3: break

        # print log
        if i % log_step_caption == 0:
            print("Epoch [%d/%d], CAP, Step [%d/%d], Loss: %.4f"
                  % (epoch+1, epoch_num, i, total_cap_step,
                      loss.data.mean()))

    eval_outputs(outputs, vocab)

    # language
    for i, (captions, lengths) in enumerate(styled_data_loader):
        #captions = to_var(captions.long())
        if torch.cuda.is_available():
            captions = captions.cuda()

        # forward, backward and optimize
        decoder.zero_grad()
        outputs = decoder(captions.long(), mode='humorous')
        loss = criterion(outputs, captions[:, 1:].contiguous().long(), lengths-1)
        loss.backward()
        optimizer_lang.step()

        # print log
        if i % log_step_language == 0:
            print("Epoch [%d/%d], LANG, Step [%d/%d], Loss: %.4f"
                  % (epoch+1, epoch_num, i, total_lang_step,
                      loss.data.mean()))

#         print("Epoch [%d/%d], LANG, Step [%d/%d], Loss: %.4f"
#                   % (epoch+1, epoch_num, i, total_lang_step,
#                       loss.data.mean()))
        #if i == 3: break

    # save models
    torch.save(decoder.state_dict(),
               os.path.join(model_path, 'decoder-%d.pkl' % (epoch + 1,)))

    torch.save(encoder.state_dict(),
               os.path.join(model_path, 'encoder-%d.pkl' % (epoch + 1,)))



  """


Epoch [1/1], CAP, Step [0/700], Loss: 9.6081
Epoch [1/1], CAP, Step [0/700], Loss: 9.6081
Epoch [1/1], CAP, Step [1/700], Loss: 9.5723
Epoch [1/1], CAP, Step [2/700], Loss: 9.5490


KeyboardInterrupt: 

In [None]:
parser.add_argument('--caption_batch_size', type=int, default=64,
                        help='mini batch size for caption model training')
    parser.add_argument('--language_batch_size', type=int, default=96,
                        help='mini batch size for language model training')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding size of word, image')
    parser.add_argument('--hidden_dim', type=int, default=512,
                        help='hidden state size of factored LSTM')
    parser.add_argument('--factored_dim', type=int, default=512,
                        help='size of factored matrix')
    parser.add_argument('--lr_caption', type=int, default=0.0002,
                        help='learning rate for caption model training')
    parser.add_argument('--lr_language', type=int, default=0.0005,
                        help='learning rate for language model training')
    parser.add_argument('--epoch_num', type=int, default=30)
    parser.add_argument('--log_step_caption', type=int, default=50,
                        help='steps for print log while train caption model')
    parser.add_argument('--log_step_language', type=int, default=10,
                        help='steps for print log while train language model')