In [None]:
# Load the multi-document summarization (MDS) test dataset, generate a summary for each cluster, and evaluate.

In [None]:
import json
from pathlib import Path
import tqdm

from transformers import (modeling_utils,
                          BartTokenizer,
                          BartForConditionalGeneration,
                          BartConfig)

from transformer_decoding import decoding_utils
from transformer_decoding.test_decoding import get_start_state


In [None]:
# Load the pretrained 'bart-large-cnn' model and its matching tokenizer.
# NOTE: these assignments must not end with a comma -- `x = f(),` binds a
# 1-tuple `(f(),)` rather than the object itself, which breaks later calls
# such as `tokenizer.decode(...)`.
bart_cnndm_model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
bart_tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')

In [None]:
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example

def article_to_text(article):
    """Return the article's title and text joined by a single space.

    Args:
        article: mapping with ``"title"`` and ``"text"`` entries.

    Returns:
        A string of the form ``"<title> <text>"``.
    """
    return '{} {}'.format(article["title"], article["text"])


def evaluate(args):
    """Generate ensemble summaries for the clusters of a JSON-lines dataset.

    Args:
        args: dict with the following keys:
            'model', 'tokenizer': passed through to ``get_start_state``;
                the tokenizer is also used to decode the generated ids.
            'max_length', 'num_beams': decoding hyperparameters.
            'evaluation_dataset': path to a file with one JSON object per
                line; each object must provide 'articles' (a list of dicts
                with 'title' and 'text') and 'summary'.
            'max_articles': maximum number of articles used per cluster.
            'max_clusters' (optional): number of leading clusters to
                process. Defaults to 2, which was previously hard-coded;
                pass ``None`` to process the whole dataset.

    Returns:
        A list with one entry per processed cluster; each entry is the list
        of decoded predictions for that cluster.
    """
    model = args['model']
    tokenizer = args['tokenizer']
    decoding_hyperparams = {
        'max_length': args['max_length'],
        'num_beams': args['num_beams']
    }
    max_clusters = args.get('max_clusters', 2)

    # Use a context manager so the file handle is closed deterministically,
    # and skip blank lines (e.g. a trailing newline) that json.loads rejects.
    with open(args['evaluation_dataset'], encoding='utf-8') as dataset_file:
        dataset = [json.loads(line) for line in dataset_file if line.strip()]

    summaries = []
    # get summary for each cluster
    for cluster in tqdm.tqdm(dataset[:max_clusters]):
        articles = [article_to_text(a) for a in cluster['articles'][:args['max_articles']]]

        # One decoding state per article in the cluster.
        component_states = [get_start_state(a, model, tokenizer, decoding_hyperparams)
                            for a in articles]
        # The ensemble state is initialised from the first article only.
        # NOTE(review): presumably only its output-side fields matter once
        # generation starts -- confirm against decoding_utils.generate.
        ensemble_state = get_start_state(articles[0], model, tokenizer, decoding_hyperparams)

        component_states, ensemble_state = \
            decoding_utils.generate(component_states, decoding_hyperparams['max_length'],
                                    ensemble_state=ensemble_state)

        #assert len(ensemble_state['input_ids']) == 1, 'We currently have batch size=1 (we decode one cluster at a time)'
        print(f'input_ids shape: {ensemble_state["input_ids"].shape}')
        print(f'Reference Summary:\n{cluster["summary"]}')
        # Decode every row of the ensemble's generated ids.
        predictions = [tokenizer.decode(input_ids,
                                        skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False)
                       for input_ids in ensemble_state['input_ids']]
        print(f'Predictions: \n{predictions}')
        print()

        summaries.append(predictions)
    return summaries


In [None]:
# Settings consumed by ``evaluate``: dataset location, model/tokenizer, and
# decoding limits.
# NOTE(review): the dataset path is an absolute, machine-specific location;
# adjust it when running on another machine.
evaluation_args = dict(
    evaluation_dataset='/home/chris/projects/aylien/dynamic-ensembles/data/WCEP/test.jsonl',
    model=bart_cnndm_model,
    tokenizer=bart_tokenizer,
    max_length=40,
    num_beams=2,
    max_articles=10,
)

cluster_summaries = evaluate(evaluation_args)

In [None]:
# Do the first two decoded predictions of the first cluster coincide?
first_cluster_predictions = cluster_summaries[0]
first_cluster_predictions[0] == first_cluster_predictions[1]

In [None]:
# `predictions` only exists inside evaluate(); the value it held on the final
# loop iteration is the last entry of `cluster_summaries`, so recover it from
# there instead of relying on a name that is undefined in the notebook scope.
predictions = cluster_summaries[-1]
predictions[0] == predictions[1]

In [None]:
# These fixtures are written as attributes of `cls` (as in a test class
# setup). In the notebook no `cls` exists, so bind it to a plain attribute
# container to make the cell runnable on a fresh kernel.
from types import SimpleNamespace

cls = SimpleNamespace()

# Reference article.
cls.test_news_article_1 = (
    'New Zealand says it has stopped community transmission of Covid-19, '
    'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday '
    '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned '
    'against complacency, saying it does not mean a total end to new coronavirus cases. '
    'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. '
    'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. '
    'Most people will still be required to remain at home at all times and avoid all social interactions.'
)

# Variant of the article above with the sentences reordered and the
# entities swapped (Germany / HIV / Angela Merkle).
# NOTE(review): there is no space between 'social interactions.' and
# 'Germany says' in the concatenated string; the text is kept byte-identical
# here -- confirm whether that is intended before changing it.
cls.test_news_article_2 = (
    'But officials have warned against complacency, saying it does not mean a total end to new HIV cases. '
    'Most people will still be required to remain at home at all times and avoid all social interactions.'
    'Germany says it has stopped community transmission of HIV, '
    'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday '
    '- Prime Minister Angela Merkle said the virus was "currently" eliminated. '
    'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. '
    'The news comes hours before Germany is set to move out of its toughest level of social restrictions. '
)
