In [1]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
# import torchvision.transforms as transforms
from data.load_datasets import *
from experiment.utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
from nlgeval import NLGEval
from transformers import (WEIGHTS_NAME, BertConfig,
                                  BertForSequenceClassification, BertTokenizer,
                                  )

In [None]:
# Parameters
data_folder = 'preprocessed_dataset'  # folder with data files saved by create_input_files.py
data_name = 'preprocessed_coco'  # base name shared by data files
checkpoint_file = 'BEST_34checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device for the model and all tensors
# Let cuDNN autotune its kernels; only worthwhile when input sizes are fixed,
# otherwise the re-tuning adds a lot of overhead.
cudnn.benchmark = True

# Load model.
# The checkpoint stores the whole pickled decoder module (not just a state_dict), so
# torch.load unpickles arbitrary objects: only load checkpoints from a trusted source.
# dump_patches: if the decoder's class source changed since it was pickled, torch
# writes .patch files describing the difference instead of only warning.
torch.nn.Module.dump_patches = True
checkpoint = torch.load(checkpoint_file, map_location=device)
decoder = checkpoint['decoder'].to(device)
decoder.eval()  # eval mode for inference (affects layers such as dropout)

nlgeval = NLGEval()  # loads the evaluator

# BERT tokenizer. Both the dataset captions and the decoder's predictions are BERT
# token ids; the evaluation cell uses this tokenizer to turn them back into text.
model_name_or_path = "bert-base-uncased"
tokenizer_class = BertTokenizer
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)

# Size of the decoder's output vocabulary, used to unflatten beam-search indices.
# Read from the tokenizer instead of hard-coding it so it stays in sync with
# model_name_or_path (30522 for bert-base-uncased).
BERT_VOCA_SIZE = tokenizer.vocab_size
vocab_size = BERT_VOCA_SIZE

In [None]:
# Evaluation: decode every TEST image with beam search and score the generated captions
# with the official MSCOCO evaluator scores (bleu4, cider, rouge, meteor) via nlg-eval.

# Beam size at which to generate captions for evaluation (1 = greedy decoding).
beam_size = 5

# BERT special-token ids, read from the tokenizer instead of hard-coded
# (101 / 102 / 0 for bert-base-uncased).
CLS_IDX = tokenizer.cls_token_id
SEP_IDX = tokenizer.sep_token_id
PAD_IDX = tokenizer.pad_token_id
start_idx = CLS_IDX  # every caption / beam starts with [CLS]
end_idx = SEP_IDX  # a beam is finished once it emits [SEP]
SPECIAL_IDS = {start_idx, end_idx, PAD_IDX}  # ids stripped before detokenizing


def beam_search_caption(decoder, image_features, beam_size, start_idx, end_idx, vocab_size, max_steps=50):
    """Generate one caption for a single image with beam search.

    Args:
        decoder: loaded captioning decoder; must expose ``embedding``, ``top_down_attention``,
            ``attention``, ``language_model``, ``fc`` and ``init_hidden_state`` as called below.
        image_features: features of ONE image, already on the target device. Assumed to be
            (1, num_regions, feature_dim) -- TODO confirm against CaptionDataset.
        beam_size: number of beams kept at each step.
        start_idx: token id every beam starts with.
        end_idx: token id that marks a beam as finished.
        vocab_size: number of scores ``decoder.fc`` produces per beam.
        max_steps: decoding stops once ``step`` exceeds this, even if no beam finished.

    Returns:
        list[int]: token ids of the highest-scoring sequence, including the start token
        (and the end token when the beam finished).
    """
    k = beam_size
    device = image_features.device

    # Mean-pooled image feature, one copy per beam.
    image_features_mean = image_features.mean(1).expand(k, -1)  # (k, feature_dim)

    # Top-k previous words and sequences; initially every beam holds just <start>.
    k_prev_words = torch.full((k, 1), start_idx, dtype=torch.long, device=device)  # (k, 1)
    seqs = k_prev_words  # (k, 1)
    # Cumulative log-probabilities of the beams; initially 0.
    top_k_scores = torch.zeros(k, 1, device=device)  # (k, 1)

    complete_seqs, complete_seqs_scores = [], []

    h1, c1 = decoder.init_hidden_state(k)  # state of the top-down attention LSTM
    h2, c2 = decoder.init_hidden_state(k)  # state of the language-model LSTM

    # s (in the shape comments) is the current number of live beams, <= beam_size,
    # because beams leave the search once they emit end_idx.
    step = 1
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
        h1, c1 = decoder.top_down_attention(
            torch.cat([h2, image_features_mean, embeddings], dim=1), (h1, c1))
        # image_features keeps batch dim 1; decoder.attention is relied on to
        # broadcast it against the s rows of h1.
        attention_weighted_encoding = decoder.attention(image_features, h1)
        h2, c2 = decoder.language_model(
            torch.cat([attention_weighted_encoding, h1], dim=1), (h2, c2))

        scores = F.log_softmax(decoder.fc(h2), dim=1)  # (s, vocab_size)
        # Cumulative log-probability of every (beam, next word) extension.
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        if step == 1:
            # All beams are identical at the first step (same words and states),
            # so pick the top k words from beam 0 only.
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Best k extensions over all beams, as indices into the flattened scores.
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Unflatten the indices. Integer division is required: on recent PyTorch `/`
        # is true division and yields a float tensor, which cannot be used as an index.
        prev_word_inds = top_k_words // vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        # Append the new words to their parent sequences.
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

        next_words = next_word_inds.tolist()
        incomplete_inds = [ind for ind, word in enumerate(next_words) if word != end_idx]
        complete_inds = [ind for ind, word in enumerate(next_words) if word == end_idx]

        # Set aside finished sequences and shrink the beam accordingly.
        if complete_inds:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
        k -= len(complete_inds)
        if k == 0:
            break

        # Continue with unfinished beams only; recurrent state follows each beam's parent.
        parent_inds = prev_word_inds[incomplete_inds]
        seqs = seqs[incomplete_inds]
        h1, c1 = h1[parent_inds], c1[parent_inds]
        h2, c2 = h2[parent_inds], c2[parent_inds]
        image_features_mean = image_features_mean[parent_inds]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Stop if decoding has been going on too long.
        if step > max_steps:
            break
        step += 1

    if not complete_seqs:
        # No beam emitted end_idx within max_steps: fall back to the unfinished beams
        # instead of crashing on max() of an empty list.
        complete_seqs = seqs.tolist()
        complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    best = complete_seqs_scores.index(max(complete_seqs_scores))
    return complete_seqs[best]


def ids_to_text(tokenizer, ids, skip_ids):
    """Convert a sequence of BERT token ids into a caption string.

    Ids in ``skip_ids`` (start/end/pad) are dropped. The remaining tokens are joined by
    the tokenizer, which merges WordPiece continuation pieces ("##...") back onto the
    previous token, so metrics are computed on words rather than sub-words.
    """
    tokens = tokenizer.convert_ids_to_tokens([w for w in ids if w not in skip_ids])
    return tokenizer.convert_tokens_to_string(tokens)


# DataLoader over the TEST split, one image per batch (beam search decodes one image at a time).
# No shuffling: corpus-level metrics do not depend on order, and a fixed order is reproducible.
loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TEST'),
    batch_size=1, shuffle=False, num_workers=1, pin_memory=torch.cuda.is_available())

# Per-image results: references[i] holds all ground-truth captions of image i and
# hypotheses[i] the generated caption, i.e.
# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b, ref2c], ...], hypotheses = [hyp1, hyp2, ...]
references = list()
hypotheses = list()

# Inference only: no_grad avoids building autograd graphs for every decoding step.
with torch.no_grad():
    for image_features, caps, caplens, allcaps in tqdm(
            loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size)):
        seq = beam_search_caption(decoder, image_features.to(device), beam_size,
                                  start_idx, end_idx, vocab_size)

        # allcaps[0]: every reference caption of this image, as padded id lists.
        references.append([ids_to_text(tokenizer, cap, SPECIAL_IDS) for cap in allcaps[0].tolist()])
        hypotheses.append(ids_to_text(tokenizer, seq, SPECIAL_IDS))

assert len(references) == len(hypotheses)

# NLGEval.compute_metrics expects the references grouped the other way round: one list
# per reference "slot", each holding that reference for every hypothesis
# (ref_list[j][i] = j-th reference of image i). Transpose the per-image lists.
# Assumes every image has the same number of references (zip truncates otherwise).
references_by_slot = [list(refs) for refs in zip(*references)]
metrics_dict = nlgeval.compute_metrics(references_by_slot, hypotheses)
print(metrics_dict)