In [1]:
# Work from speaker_old so that the relative imports (models/, utils/) and paths
# (saved_models/, speaker_outputs/) used below resolve as in speaker_old/pretrained_speaker_TEST.py.
# Guarded so that re-running this cell does not try to descend into speaker_old/speaker_old.
import os
if os.path.basename(os.getcwd()) != 'speaker_old':
    os.chdir('speaker_old')

In [2]:
# Show the current working directory to confirm the previous cell moved into speaker_old
pwd

In [3]:
# Based on speaker_old/pretrained_speaker_TEST.py,
# which evaluates the trained ReRef generation models on the test set.

In [4]:
import torch
import numpy as np

from models.model_speaker_base import SpeakerModelBase
from models.model_speaker_hist_att import SpeakerModelHistAtt

from utils.SpeakerDataset import SpeakerDataset
from utils.Vocab import Vocab

# from evals import eval_beam_base, eval_beam_histatt
# For demo purposes, this notebook uses eval_beam_histatt_DEMO (defined below) instead of the evals module.

from nlgeval import NLGEval

import os

import datetime

In [5]:
def mask_attn(actual_num_tokens, max_num_tokens, device):
    """Build a boolean padding mask for attention over variable-length sequences.

    Row ``n`` contains ``actual_num_tokens[n]`` False entries (real tokens, kept)
    followed by True entries (padding, to be masked out) up to ``max_num_tokens``.

    :param actual_num_tokens: per-sequence lengths (indexable, with ``len``)
    :param max_num_tokens: padded length of every row
    :param device: device the mask is moved to
    :return: bool tensor of shape (len(actual_num_tokens), max_num_tokens, 1) when every
             length is <= max_num_tokens; the trailing singleton dim lets it broadcast
             against per-position attention scores in ``masked_fill_``.
    """
    padded_rows = [
        [False] * actual_num_tokens[idx] + [True] * (max_num_tokens - actual_num_tokens[idx])
        for idx in range(len(actual_num_tokens))
    ]
    return torch.tensor(padded_rows).unsqueeze(2).to(device)

In [6]:
import torch
import json

import torch.nn as nn
import torch.nn.functional as F

import os

from bert_score import score

def eval_beam_histatt_DEMO(split_data_loader, model, args, best_score, print_gen, device,
                      beam_size, max_len, vocab, mask_attn, nlgeval_obj, isValidation, timestamp, isTest):
    """
    Beam-search evaluation of a history-attention speaker model on one data split.

    For every instance yielded by ``split_data_loader`` (batch size 1 is assumed: only
    ``data['reference_chain'][0]`` is used as the reference set), the previous utterance is
    encoded with ``model.lstm_encoder``, whose initial state comes from the visual context and
    target-image features. A caption is then generated by beam search, attending over the
    encoded previous utterance. The best-scoring hypotheses are evaluated against the reference
    chains with nlg-eval and BERTScore.

    :param split_data_loader: DataLoader over the split (batch size 1)
    :param model: trained history-attention speaker model (caller sets eval mode / no_grad)
    :param args: checkpoint args; reads ``breaking``, ``model_type`` and ``metric``
    :param best_score: unused; kept for signature compatibility
    :param print_gen: if True, print the reference(s) and the hypothesis for each instance
    :param device: torch device for the tensors created here
    :param beam_size: beam size at which to generate captions for evaluation
    :param max_len: generation stops after at most ``max_len + 1`` tokens per beam
    :param vocab: Vocab providing ``word2index`` / ``index2word`` and the <sos>/<eos>/<pad> tokens
    :param mask_attn: callable building the padding mask for the attention scores
    :param nlgeval_obj: NLGEval instance used for the corpus-level metrics
    :param isValidation: selects the 'val' split name for the output file
    :param timestamp: string appended to the reference-dump file name
    :param isTest: selects the 'test' split name (when not isValidation); otherwise 'train'
    :return: the score selected by ``args.metric`` ('cider' -> CIDEr, 'bert' -> BERTScore F1),
             or None for any other metric value
    """

    # hypotheses[i] is the generated string for instance i; references[i] is that
    # instance's reference chain (list of strings)
    references = []
    hypotheses = []

    count = 0
    breaking = args.breaking  # debug mode: stop after 5 instances

    sos_token = torch.tensor(vocab['<sos>']).to(device)
    eos_token = torch.tensor(vocab['<eos>']).to(device)

    # token ids stripped from the generated sequence before detokenising
    special_ids = {vocab.word2index['<sos>'], vocab.word2index['<eos>'], vocab.word2index['<pad>']}

    if isValidation:
        split = 'val'
    elif isTest:
        split = 'test'
    else:
        split = 'train'

    out_dir = 'speaker_outputs'
    file_name = args.model_type + '_' + args.metric + '_' + split + '_' + timestamp  # overwrites previous versions!
    refs_path = out_dir + '/refs_' + file_name + '.json'
    # if a reference dump already exists it is reused instead of being rebuilt
    refs_cached = os.path.isfile(refs_path)

    for data in split_data_loader:

        if breaking and count == 5:
            break
        count += 1

        completed_sentences = []
        completed_scores = []
        beam_k = beam_size

        orig_text_reference = data['orig_utterance']  # original reference without unk, eos, sos, pad
        reference_chain = data['reference_chain'][0]  # batch size 1: full reference set for this instance

        prev_utterance = data['prev_utterance']
        prev_utt_lengths = data['prev_length']

        visual_context = data['concat_context']
        target_img_feats = data['target_img_feats']

        # ---- encode the previous utterance, conditioned on the visual input ----

        max_length_tensor = prev_utterance.shape[1]
        masks = mask_attn(prev_utt_lengths, max_length_tensor, device)

        visual_context_hid = model.relu(model.lin_viscontext(visual_context))
        target_img_hid = model.relu(model.linear_separate(target_img_feats))
        concat_visual_input = model.relu(model.linear_hid(torch.cat((visual_context_hid, target_img_hid), dim=1)))

        embeds_words = model.embedding(prev_utterance)  # b, l, d

        # pack_padded_sequence expects sequences sorted by decreasing length
        sorted_prev_utt_lens, sorted_idx = torch.sort(prev_utt_lengths, descending=True)
        embeds_words = embeds_words[sorted_idx]
        concat_visual_input = concat_visual_input[sorted_idx]

        # lengths must live on the CPU for pack_padded_sequence in newer PyTorch releases
        packed_input = nn.utils.rnn.pack_padded_sequence(embeds_words, sorted_prev_utt_lens.cpu(),
                                                         batch_first=True)

        # the same visual vector initialises both h0 and c0; it is stacked twice along dim 0
        # (presumably one copy per direction of a bidirectional encoder - TODO confirm)
        concat_visual_input = torch.stack((concat_visual_input, concat_visual_input), dim=0)

        packed_outputs, hidden = model.lstm_encoder(packed_input, hx=(concat_visual_input, concat_visual_input))

        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)

        # un-sort back to the original batch order (only outputs and hidden are reordered)
        _, reversed_idx = torch.sort(sorted_idx)
        outputs = outputs[reversed_idx]
        batch_out_hidden = hidden[0][:, reversed_idx]

        decoder_hid = model.linear_dec(torch.cat((batch_out_hidden[0], batch_out_hidden[1]), dim=1))
        history_att = model.lin2att_hist(outputs)

        # ---- beam search ----

        # every beam starts from the same encoder-derived decoder state
        h1 = c1 = decoder_hid.expand(beam_k, -1)

        decoder_input = sos_token.expand(beam_k, 1)  # beam_k <sos> copies
        gen_sentences_k = decoder_input  # all beams start with <sos>
        top_scores = torch.zeros(beam_k, 1, device=device)  # cumulative log-prob of each beam
        gen_len = 0

        while gen_len <= max_len:

            decoder_embeds = model.embedding(decoder_input).squeeze(1)
            h1, c1 = model.lstm_decoder(decoder_embeds, hx=(h1, c1))

            # attention over the encoded previous utterance; padding positions get -inf
            h1_att = model.lin2att_hid(h1)
            attention_out = model.attention(model.tanh(history_att + h1_att.unsqueeze(1)))
            attention_out = attention_out.masked_fill_(masks, float('-inf'))
            att_weights = model.softmax(attention_out)
            att_context_vector = (history_att * att_weights).sum(dim=1)

            word_pred = F.log_softmax(model.lin2voc(torch.cat((h1, att_context_vector), dim=1)), dim=1)

            # accumulate log-probabilities along each beam
            word_pred = top_scores.expand_as(word_pred) + word_pred

            if gen_len == 0:
                # all beams are identical at the first step: take the top-k words of one row
                top_scores, top_words = word_pred[0].topk(beam_k, 0, True, True)
            else:
                # rank all (beam, word) continuations jointly
                top_scores, top_words = word_pred.view(-1).topk(beam_k, 0, True, True)

            # decompose flat indices into (beam, word). The divisor is the actual output width
            # (the original used len(vocab) - 1, i.e. the vocab without <nohs>). Integer division
            # is required: `/` errors in PyTorch 1.6 and yields floats in later releases.
            num_words = word_pred.size(1)
            sentence_index = top_words // num_words  # which beam the word extends
            word_index = top_words % num_words  # predicted word

            gen_len += 1

            gen_sentences_k = torch.cat((gen_sentences_k[sentence_index], word_index.unsqueeze(1)), dim=1)

            # split beams into finished (contain <eos>) and still-running ones
            finished = [eos_token in sent for sent in gen_sentences_k]
            complete_sents_inds = [idx for idx, done in enumerate(finished) if done]
            incomplete_sents_inds = [idx for idx, done in enumerate(finished) if not done]

            if complete_sents_inds:
                completed_sentences.extend(gen_sentences_k[complete_sents_inds].tolist())
                completed_scores.extend(top_scores[complete_sents_inds].tolist())
                beam_k -= len(complete_sents_inds)  # each finished sentence closes a beam

            if beam_k == 0:
                break

            # continue only the unfinished beams, carrying over the decoder state of the
            # beam each surviving sentence extends
            gen_sentences_k = gen_sentences_k[incomplete_sents_inds]
            surviving_beams = sentence_index[incomplete_sents_inds]
            h1, c1 = h1[surviving_beams], c1[surviving_beams]

            top_scores = top_scores[incomplete_sents_inds].unsqueeze(1)
            decoder_input = word_index[incomplete_sents_inds]

        if not completed_scores:
            # no beam produced <eos> within max_len: fall back to the unfinished beams,
            # which are exactly the rows left in gen_sentences_k / top_scores
            completed_sentences.extend(gen_sentences_k.tolist())
            completed_scores.extend(top_scores.reshape(-1).tolist())

        best_idx = max(range(len(completed_scores)), key=completed_scores.__getitem__)
        best_seq = completed_sentences[best_idx]

        # drop <sos>, <eos> and <pad> before detokenising
        hypothesis_string = ' '.join(vocab.index2word[w] for w in best_seq if w not in special_ids)
        hypotheses.append(hypothesis_string)

        if not refs_cached:
            references.append(reference_chain)

        if print_gen:
            print('REF:', orig_text_reference)  # single one
            print('REF chain:', reference_chain)
            print('HYP:', hypothesis_string, '\n')

    if refs_cached:
        with open(refs_path, 'r') as f:
            references = json.load(f)
    else:
        os.makedirs(out_dir, exist_ok=True)  # open(..., 'w') fails if the directory is missing
        with open(refs_path, 'w') as f:
            json.dump(references, f)

    # From https://github.com/Maluuba/nlg-eval: "references is a list of lists of ground truth
    # reference text strings and hypothesis is a list of hypothesis text strings. Each inner list
    # in references is one set of references for the hypothesis (a list of single reference
    # strings for each sentence in hypothesis in the same order)."
    # NOTE(review): by that description each inner list is aligned with `hypotheses`, whereas here
    # references[i] is the reference chain of hypothesis i - confirm the orientation expected by
    # the installed nlg-eval version.
    metrics_dict = nlgeval_obj.compute_metrics(references, hypotheses)
    print(metrics_dict)

    (P, R, Fs), hashname = score(hypotheses, references, lang='en', return_hash=True, model_type="bert-base-uncased")
    print(f'{hashname}: P={P.mean().item():.6f} R={R.mean().item():.6f} F={Fs.mean().item():.6f}')

    selected_metric_score = None
    if args.metric == 'cider':
        selected_metric_score = metrics_dict['CIDEr']
    elif args.metric == 'bert':
        selected_metric_score = Fs.mean().item()

    if selected_metric_score is not None:
        print(round(selected_metric_score, 5))

    return selected_metric_score

PyTorch version 1.6.0 available.


AttributeError: type object 'BertConfig' has no attribute 'pretrained_config_archive_map'

In [None]:
if __name__ == '__main__':

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    nlge = NLGEval(no_skipthoughts=True, no_glove=True)

    speaker_files = ['saved_models/model_speaker_hist_att_42_bert_2020-05-21-15-13-22.pkl']
    # other checkpoints evaluated previously:
    # ['saved_models/model_speaker_hist_att_1_bert_2020-05-22-16-40-11.pkl',
    #  'saved_models/model_speaker_hist_att_2_bert_2020-05-22-16-41-12.pkl',
    #  'saved_models/model_speaker_hist_att_3_bert_2020-05-22-16-42-13.pkl',
    #  'saved_models/model_speaker_hist_att_4_bert_2020-05-22-16-43-13.pkl']

    SEED = 28
    MAX_GEN_LEN = 30  # max_len passed to beam search
    IMG_DIM = 2048  # img_dim passed to SpeakerModelHistAtt

    for speaker_file in speaker_files:

        # for reproducibility (re-seeded for every checkpoint)
        print(SEED)
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        print(speaker_file)

        # NOTE: torch.load unpickles the file - only load trusted checkpoints
        checkpoint = torch.load(speaker_file, map_location=device)
        args = checkpoint['args']

        print("Loading the vocab...")
        vocab = Vocab(os.path.join(args.data_path, args.vocab_file))
        # append <nohs> (special token placeholder for "no previous utterance") as the last index.
        # Statement order matters: len(vocab) depends on word2index, so it still returns the old
        # size on both right-hand sides.
        vocab.index2word[len(vocab)] = '<nohs>'
        vocab.word2index['<nohs>'] = len(vocab)

        testset = SpeakerDataset(
            data_dir=args.data_path,
            utterances_file='test_' + args.utterances_file,
            vectors_file=args.vectors_file,
            chain_file='test_' + args.chains_file,
            orig_ref_file='test_' + args.orig_ref_file,
            split='test',
            subset_size=1  # demo setting; the original script used args.subset_size
        )

        print('vocab len', len(vocab))
        print('test len', len(testset), 'longest sentence', testset.max_len)

        model = SpeakerModelHistAtt(len(vocab), args.embedding_dim, args.hidden_dim, IMG_DIM,
                                    args.dropout_prob, args.attention_dim).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # eval_beam_histatt_DEMO assumes batch size 1 (it reads reference_chain[0])
        load_params_test = {'batch_size': 1, 'shuffle': False,
                            'collate_fn': SpeakerDataset.get_collate_fn(device, vocab['<sos>'], vocab['<eos>'],
                                                                        vocab['<nohs>'])}

        test_loader = torch.utils.data.DataLoader(testset, **load_params_test)

        with torch.no_grad():
            model.eval()

            print('\nTest Eval')
            print('beam')

            # best_score and timestamp are not essential here
            best_score = checkpoint['accuracy']  # cider or bert
            t = datetime.datetime.now()
            timestamp = str(t.date()) + '-' + str(t.hour) + '-' + str(t.minute) + '-' + str(t.second)

            # USING THE DEMO METHOD PROVIDED ABOVE; generations are always printed in the demo
            eval_beam_histatt_DEMO(test_loader, model, args, best_score,
                                   print_gen=True, device=device,
                                   beam_size=args.beam_size, max_len=MAX_GEN_LEN, vocab=vocab,
                                   mask_attn=mask_attn, nlgeval_obj=nlge,
                                   isValidation=False, timestamp=timestamp, isTest=True)