In [1]:
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from collections import Counter 
import math
import random
from torchvision import transforms

In [20]:
# Pin all tensors and the model to the CPU. The commented-out line is the
# alternative that would pick CUDA whenever it is available.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [2]:
# Caption length in tokens. It is captured as the default `max_len` of
# PositionalEncoding when that class is defined and is used as the size of the
# decoding buffer in generate_caption; a later cell recomputes it from the data.
max_seq_len = 33

In [3]:
# Load the comma-separated captions table; the `image` and `caption` columns are used below.
df = pd.read_csv("./captions.txt", sep=',')

In [4]:
# Validation split: the last 10% of the rows, selected by position.
valid = df.iloc[int(0.9*len(df)):]

In [5]:
# One row per distinct image file name in the validation split
# (the original row index of the first occurrence is kept).
unq_valid_imgs = valid[['image']].drop_duplicates()

In [21]:
# def generate_caption(K, img_nm): 
#     img_loc = './Images/'+str(img_nm)
#     image = Image.open(img_loc).convert("RGB")
#     # plt.imshow(image)

#     model.eval() 
#     valid_img_df = valid[valid['image']==img_nm]
#     print("Actual Caption : ")
#     print(valid_img_df['caption'].tolist())
#     img_embed = valid_img_embed[img_nm].to(device)


#     img_embed = img_embed.permute(0,2,3,1)
#     img_embed = img_embed.view(img_embed.size(0), -1, img_embed.size(3))


#     input_seq = [pad_token]*max_seq_len
#     input_seq[0] = start_token

#     input_seq = torch.tensor(input_seq).unsqueeze(0).to(device)
#     predicted_sentence = []
#     with torch.no_grad():
#         for eval_iter in range(0, max_seq_len):

#             output, padding_mask = model.forward(img_embed, input_seq)

#             output = output[eval_iter, 0, :]

#             values = torch.topk(output, K).values.tolist()
#             indices = torch.topk(output, K).indices.tolist()

#             next_word_index = random.choices(indices, values, k = 1)[0]

#             next_word = index_to_word[next_word_index]

#             input_seq[:, eval_iter+1] = next_word_index


#             if next_word == '<end>' :
#                 break

#             predicted_sentence.append(next_word)
#     print("\n")
#     print("Predicted caption : ")
#     a = " ".join(predicted_sentence+['.'])
#     print(a)
#     return a


In [22]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of the position scaled by geometrically
    spaced frequencies.

    Args:
        d_model: Size of the embedding (last) dimension.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions to precompute; inputs may be at most
            this long.
    """

    def __init__(self, d_model, dropout=0.1, max_len=max_seq_len):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the frequency vector to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton dimension lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + positional_table)``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with
                ``seq_len <= max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The buffer is never reassigned here, so its shape (and therefore the
        # module's state_dict) does not depend on the batches that were seen.
        # Taking only the first row also tolerates a buffer whose first
        # dimension is larger than 1, e.g. one restored from a checkpoint that
        # stored a batch-expanded copy of the table; all rows are identical.
        positions = self.pe[:1, :x.size(1), :]
        return self.dropout(x + positions)

In [23]:
class ImageCaptionModel(nn.Module):
    """Transformer decoder that scores next caption tokens given image features.

    Token ids are embedded, scaled by ``sqrt(embedding_size)``, combined with
    positional encodings and decoded against the encoded image, which serves
    as the decoder memory. A final linear layer maps every decoder state to
    vocabulary logits. Token id 0 is treated as padding.

    Args:
        n_head: Number of attention heads per decoder layer.
        n_decoder_layer: Number of stacked decoder layers.
        vocab_size: Number of distinct token ids.
        embedding_size: Embedding / model dimension.
    """

    def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
        super(ImageCaptionModel, self).__init__()
        # Attribute names and creation order are kept unchanged: they determine
        # the state_dict keys / pickled layout and the weight-init RNG order.
        self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
        self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model =  embedding_size, nhead = n_head)
        self.TransformerDecoder = nn.TransformerDecoder(decoder_layer = self.TransformerDecoderLayer, num_layers = n_decoder_layer)
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocab_size , embedding_size)
        self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Initialise embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.last_linear_layer.bias.data.zero_()
        self.last_linear_layer.weight.data.uniform_(-initrange, initrange)

    def generate_Mask(self, size, decoder_inp):
        """Build the attention masks for a batch of decoder inputs.

        All masks are created on the same device as ``decoder_inp``.

        Args:
            size: Sequence length; the causal mask is ``(size, size)``.
            decoder_inp: Integer tensor of token ids, where 0 marks padding.

        Returns:
            A tuple ``(causal_mask, pad_mask_float, pad_mask_bool)``:
            ``causal_mask`` is 0.0 on and below the diagonal and ``-inf``
            above it; ``pad_mask_float`` is 1.0 for ids > 0 and 0.0 for id 0;
            ``pad_mask_bool`` is True exactly where the id is 0.
        """
        # triu(..., diagonal=1) keeps -inf strictly above the diagonal and
        # zeroes everything else, so position i can only attend to j <= i.
        decoder_input_mask = torch.triu(
            torch.full((size, size), float('-inf'), device=decoder_inp.device),
            diagonal=1,
        )

        decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
        decoder_input_pad_mask_bool = decoder_inp == 0

        return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool

    def forward(self, encoded_image, decoder_inp):
        """Compute vocabulary logits for every position of ``decoder_inp``.

        Args:
            encoded_image: 3-D tensor of image features; its first two axes
                are swapped before it is used as decoder memory (presumably
                ``(batch, positions, embedding_size)`` - confirm with the
                encoder that produced it).
            decoder_inp: ``(batch, seq_len)`` integer token ids, 0 = padding.

        Returns:
            ``(final_output, decoder_input_pad_mask)`` where ``final_output``
            has shape ``(seq_len, batch, vocab_size)`` and
            ``decoder_input_pad_mask`` is the float padding mask described in
            :meth:`generate_Mask`.
        """
        # The decoder layers are sequence-first, so swap batch and sequence axes.
        encoded_image = encoded_image.permute(1,0,2)

        decoder_inp_embed = self.embedding(decoder_inp) * math.sqrt(self.embedding_size)
        decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
        decoder_inp_embed = decoder_inp_embed.permute(1,0,2)

        # Masks follow the device of decoder_inp, so no module-level device
        # variable is needed here.
        decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)

        decoder_output = self.TransformerDecoder(tgt = decoder_inp_embed, memory = encoded_image, tgt_mask = decoder_input_mask, tgt_key_padding_mask = decoder_input_pad_mask_bool)

        final_output = self.last_linear_layer(decoder_output)

        return final_output,  decoder_input_pad_mask

In [24]:
# Drop every word that is at most one character long (empty strings included).
def remove_single_char_word(word_list):
    """Return a new list containing only the words longer than one character."""
    return [token for token in word_list if len(token) > 1]
# Reload the captions table, then show its row count and a three-row preview.
df = pd.read_csv("./captions.txt", sep=',')
print(len(df))
display(df.head(3))

# Data cleaning: lowercase purely alphabetic tokens, blank out all others,
# wrap the caption in start/end markers, then drop tokens of length <= 1
# (which also removes the blanked-out entries).
df['cleaned_caption'] = df['caption'].apply(
    lambda caption: remove_single_char_word(
        ['<start>']
        + [token.lower() if token.isalpha() else '' for token in caption.split(" ")]
        + ['<end>']
    )
)

# The longest cleaned caption fixes the padded sequence length.
max_seq_len = df['cleaned_caption'].apply(len).max()
print(max_seq_len)

# Sequence padding: right-pad every caption with '<pad>' up to max_seq_len.
df['cleaned_caption'] = df['cleaned_caption'].apply(
    lambda tokens: tokens + ['<pad>'] * (max_seq_len - len(tokens))
)

40455


Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .


33


In [25]:
# Flatten all padded captions into a single token stream (join on spaces, split again).
word_list = " ".join(" ".join(tokens) for tokens in df['cleaned_caption']).split(' ')
# Vocabulary as a list ordered by descending frequency; tokens with equal
# counts keep the order in which they were first seen.
word_counts = Counter(word_list)
word_dict = [token for token, _ in word_counts.most_common()]

In [26]:
# Token -> integer id, where the id is the token's position in the frequency-ordered vocabulary.
word_to_index = dict(zip(word_dict, range(len(word_dict))))

In [27]:
# Sanity check: show the first five vocabulary entries (ids 0-4).
print(word_dict[:5])

['<pad>', '<start>', '<end>', 'in', 'the']


In [35]:
# SECURITY NOTE: the checkpoint is unpickled as a whole module object, which
# can execute arbitrary code - only load './BestModel' from a trusted source.
# map_location remaps tensors saved on another device (e.g. a GPU) onto
# `device`; weights_only=False is stated explicitly because newer PyTorch
# releases default it to True, which rejects fully pickled modules.
model = torch.load('./BestModel', map_location=device, weights_only=False)
model = model.to(device)

# Ids of the special tokens in the frequency-ordered vocabulary.
start_token = word_to_index['<start>']
end_token = word_to_index['<end>']
pad_token = word_to_index['<pad>']
# ImageCaptionModel.generate_Mask treats token id 0 as padding.
assert pad_token == 0, "the model's padding mask assumes '<pad>' has id 0"

# Length of the decoding buffer used by generate_caption.
max_seq_len = 33
print(start_token, end_token, pad_token)

  model = torch.load('./BestModel')


1 2 0


In [36]:
# Precomputed image encodings, looked up by image file name in generate_caption
# (each entry is moved to `device` there, so entries are presumably torch tensors).
# NOTE(review): unpickling can execute arbitrary code - only read trusted files.
valid_img_embed = pd.read_pickle('EncodedImageValidResNet.pkl')

In [37]:
# Repeats the earlier device cell: inference stays on the CPU, with the
# CUDA-selecting alternative left commented out.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [38]:
# Integer id -> token; the inverse of word_to_index.
index_to_word = dict(enumerate(word_dict))

In [39]:
# File name of the validation image at position 50 (the 51st unique image).
unq_valid_imgs.iloc[50]['image']

'449287870_f17fb825d7.jpg'

In [40]:
def generate_caption(K, img_nm):
    """Generate, print and return a caption for one validation image.

    Decoding is autoregressive: starting from ``<start>``, the model is run on
    the padded token buffer, the next token is sampled from the ``K`` highest
    scoring candidates, and the buffer is extended until ``<end>`` is produced
    or the buffer is full. ``K=1`` is greedy decoding.

    Only the precomputed embedding in ``valid_img_embed`` is read; the image
    file itself is not needed. Relies on the notebook globals ``model``,
    ``valid_img_embed``, ``device``, ``max_seq_len``, ``start_token``,
    ``pad_token`` and ``index_to_word``.

    Args:
        K: Number of top candidates to sample from at each step (>= 1).
        img_nm: Image file name used as the key into ``valid_img_embed``.

    Returns:
        The predicted caption: the generated words joined by spaces, followed
        by `` .``.

    Raises:
        ValueError: If ``K`` is smaller than 1.
        KeyError: If ``img_nm`` has no precomputed embedding.
    """
    if K < 1:
        raise ValueError("K must be at least 1")

    model.eval()
    img_embed = valid_img_embed[img_nm].to(device)

    # Move axis 1 to the end and merge axes 2 and 3 into a single axis of
    # positions (assumes a 4-D (batch, channels, height, width) encoding -
    # confirm with the encoder). reshape() is used instead of view() so a
    # non-contiguous permuted tensor is always accepted.
    img_embed = img_embed.permute(0, 2, 3, 1)
    img_embed = img_embed.reshape(img_embed.size(0), -1, img_embed.size(3))

    # Token buffer of shape (1, max_seq_len): <start> followed by padding.
    input_seq = torch.full((1, max_seq_len), pad_token, dtype=torch.long, device=device)
    input_seq[0, 0] = start_token

    predicted_sentence = []
    with torch.no_grad():
        # The output at position i scores the token for position i + 1, so at
        # most max_seq_len - 1 tokens fit into the buffer.
        for eval_iter in range(max_seq_len - 1):
            output, _ = model(img_embed, input_seq)
            logits = output[eval_iter, 0, :]

            top = torch.topk(logits, min(K, logits.size(-1)))
            # The model emits raw logits, which may be negative; softmax turns
            # the top-K scores into valid (positive) sampling weights.
            weights = torch.softmax(top.values, dim=-1).tolist()
            next_word_index = random.choices(top.indices.tolist(), weights=weights, k=1)[0]
            next_word = index_to_word[next_word_index]

            if next_word == '<end>':
                break

            input_seq[0, eval_iter + 1] = next_word_index
            predicted_sentence.append(next_word)

    print("\n")
    print("Predicted caption : ")
    a = " ".join(predicted_sentence+['.'])
    print(a)
    return a


In [41]:
# Greedy decoding (K=1) for the validation image at position 50.
a = generate_caption(1, unq_valid_imgs.iloc[50]['image'])



Predicted caption : 
little girl in pink shirt is jumping on playground .




'little girl in pink shirt is jumping on playground .'