In [1]:
import argparse
import logging
import math
import os
import random

import datasets
import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AdamW,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    get_scheduler,
    set_seed,
    XLMProphetNetTokenizer, XLMProphetNetForConditionalGeneration, XLMProphetNetConfig
)
from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version
from preprocess_data import load_xglue
import pickle
import sacrebleu

In [2]:
# Module-level logger used by the cells below.
logger = logging.getLogger(__name__)

# Set CUDA_VISIBLE_DEVICES (device ids 0 and 1) for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'

In [3]:
class Args:
    """Static settings for this evaluation notebook.

    Plays the role of a parsed command-line namespace: every option is a
    plain class attribute and is read as ``args.<name>`` by later cells.
    """

    # --- model / tokenizer ---
    model_name_or_path = 'microsoft/xprophetnet-large-wiki100-cased-xglue-ntg'
    cache_dir = "/home/work/xiaoyu/ckpt/xprophtnet_ntg"
    use_fast_tokenizer = True

    # --- data ---
    data_folder = "/home/work/xiaoyu/datasets/xglue_full_dataset/sampled_NTG"
    pad_to_max_length = False
    ignore_pad_token_for_loss = True

    # --- evaluation / generation ---
    per_device_eval_batch_size = 2
    max_target_length = 128
    # None here means "reuse max_target_length"; a later cell fills it in.
    val_max_target_length = None
    num_beams = 10


args = Args()

In [4]:
# Accelerator is created with its defaults; its state is logged below.
accelerator = Accelerator()

# Timestamped, INFO-level log lines: "<time> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Record the environment Accelerate detected (see the captured output).
logger.info(accelerator.state)

08/18/2021 22:08:50 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Use FP16 precision: False



In [5]:
# Config, tokenizer and model all come from the same checkpoint name and
# share one cache directory.
config = AutoConfig.from_pretrained(
    args.model_name_or_path,
    cache_dir=args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path,
    cache_dir=args.cache_dir,
    use_fast=args.use_fast_tokenizer,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    args.model_name_or_path,
    config=config,
    cache_dir=args.cache_dir,
    # Treat the checkpoint as TensorFlow only when its name contains ".ckpt".
    from_tf=".ckpt" in args.model_name_or_path,
)

# Size the embedding matrix to the tokenizer's vocabulary; as the cell's last
# expression, the resulting embedding module is displayed.
model.resize_token_embeddings(len(tokenizer))

Embedding(250012, 1024, padding_idx=0)

In [6]:
# Hand the model to Accelerate, then print the device of its first parameter
# as a quick placement check.
model = accelerator.prepare(model)
print(next(iter(model.parameters())).device)

cuda:0


In [7]:
# Label padding value: -100 when padded positions should be ignored by the
# loss, otherwise the tokenizer's own pad id.
if args.ignore_pad_token_for_loss:
    label_pad_token_id = -100
else:
    label_pad_token_id = tokenizer.pad_token_id

In [8]:
# Load the pre-tokenized datasets that were pickled into the data folder.
processed_datasets_path = os.path.join(args.data_folder, "processed_datasets.pkl")

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles produced by a trusted preprocessing run.
# The context manager guarantees the handle is closed even if unpickling fails.
with open(processed_datasets_path, "rb") as tmp_file:
    processed_datasets = pickle.load(tmp_file)

In [9]:
# Language under evaluation; it selects the "test.<lg>" entry of the pickle.
lg = "fr"
test_dataset = {lg: processed_datasets["test." + lg]}

# Print the label ids of the first three samples and the last one as a
# sanity check on the preprocessed data.
index_list = [0, 1, 2, -1]
for index in index_list:
    input_tokens = test_dataset[lg][index]
    print(f"Sample {index} of the test_dataset[{lg}] set: {input_tokens['labels']}.")

Sample 0 of the test_dataset[fr] set: [1745, 30719, 18, 3435, 92872, 10851, 115, 36, 1822, 213677, 24, 350, 83863, 203322, 32, 27834, 2197, 264, 107, 36, 192142, 40578, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100].
Sample 1 of the test_dataset[fr] set: [35188, 22797, 18, 126, 8943, 580, 2075, 16888, 2528, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 

In [10]:
# NOTE(review): `padding` is not referenced by any cell visible here; it is
# kept in case code elsewhere reads it.
if args.pad_to_max_length:
    padding = "max_length"
else:
    padding = False

# The collator pads labels with `label_pad_token_id`; sequence lengths are
# rounded up to a multiple of 8 only when fp16 is enabled.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=(8 if accelerator.use_fp16 else None),
)

In [11]:
# One evaluation DataLoader per language, wrapped by Accelerate.
test_dataloader = {
    lg: accelerator.prepare(
        DataLoader(
            test_dataset[lg],
            batch_size=args.per_device_eval_batch_size,
            collate_fn=data_collator,
        )
    )
}

In [12]:
# Switch the model to evaluation mode before generating.
model.eval()

# An unset validation length falls back to the generic target length.
if args.val_max_target_length is None:
    args.val_max_target_length = args.max_target_length

# Keyword arguments forwarded to `generate()` in the evaluation loop.
# The original condition tested `args is not None`, which is always true and
# made the `config.max_length` fallback unreachable; test the value itself so
# the model's configured maximum is used when no target length is set at all.
gen_kwargs = {
    "max_length": (
        args.val_max_target_length
        if args.val_max_target_length is not None
        else config.max_length
    ),
    "num_beams": args.num_beams,
}

In [17]:
def postprocess_text(preds, labels):
    """Normalize decoded strings before they are handed to the metric.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple where ``preds`` is a list of
        whitespace-stripped strings and ``labels`` is a list of one-element
        lists, each holding the stripped reference for the matching
        prediction.
    """
    stripped_preds = []
    for pred in preds:
        stripped_preds.append(pred.strip())

    # Each reference is wrapped in its own list: one reference per prediction.
    wrapped_labels = []
    for label in labels:
        wrapped_labels.append([label.strip()])

    return stripped_preds, wrapped_labels

In [14]:
# sacreBLEU metric object; batches are added in the evaluation loop and the
# score is computed once at the end.
metric = datasets.load_metric("sacrebleu")

Downloading:   0%|          | 0.00/2.38k [00:00<?, ?B/s]

In [21]:
# Debug cap: only the first MAX_EVAL_BATCHES batches are generated and scored,
# so the BLEU reported below is a smoke-test figure, not a full test-set score.
MAX_EVAL_BATCHES = 6

results = {}
for step, batch in enumerate(test_dataloader[lg]):
    if step >= MAX_EVAL_BATCHES:
        break

    # Only generation needs gradient tracking disabled.
    with torch.no_grad():
        generated_tokens = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            **gen_kwargs,
        )

    # Pad generations to a common length across processes before gathering.
    generated_tokens = accelerator.pad_across_processes(
        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
    )

    labels = batch["labels"]
    if not args.pad_to_max_length:
        # If we did not pad to max length, we need to pad the labels too
        labels = accelerator.pad_across_processes(
            labels, dim=1, pad_index=tokenizer.pad_token_id
        )

    # NOTE: the original checked `isinstance(generated_tokens, tuple)` *after*
    # `.cpu().numpy()`; a tuple would already have failed on `.cpu()`, so that
    # branch was unreachable and has been removed.
    generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
    labels = accelerator.gather(labels).cpu().numpy()

    if args.ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Show the start of the first source text of the batch next to the
    # predictions and references for manual inspection.
    input_seq = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
    print("\ninput_seq", input_seq[0][:200])
    print("decoded_preds", decoded_preds)
    print("decoded_labels", decoded_labels)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    metric.add_batch(predictions=decoded_preds, references=decoded_labels)

# Score everything accumulated above and keep the rounded BLEU per language.
res = metric.compute()
print(res)
results[lg] = round(res["score"], 2)
logger.info(f"language {lg} results:")
logger.info(results[lg])



input_seq Vice-présidente de l'Assemblée nationale, la macroniste Carole Bureau-Bonnard était chargée mardi après-midi d'animer la séance d'examen du projet de loi «confiance dans l'action publique». C'était sa
decoded_preds ["Carole Bureau-Bonnard, vice-présidente de l'Assemblée nationale, a connu une séance éprouvante", "Les plus grands fauteuils de l'île d'Antiparos"]
decoded_labels ["Les débuts balbutiants d'une députée LREM provoque la pagaille à l'Assemblée nationale", 'Ces maisons du sud qui nous inspirent']

input_seq Le procès d'un Turc de 17 ans qui avait agressé en janvier 2016 à la machette un enseignant d'une école juive de Marseille portant une kippa, s'ouvre mercredi devant le tribunal pour enfants (TPE) de 
decoded_preds [',,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,', 'The S.O.A.A.D.:,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,']
decoded_labels ['Un jeune djihadiste de 17 ans en procès à Paris', 'Canada : la forme de ce nuage est invra

08/18/2021 22:15:38 - INFO - __main__ - language fr results:
08/18/2021 22:15:38 - INFO - __main__ - 3.24



input_seq Le kaki fait son grand come-back dans notre dressing. Par petites touches ou en total look, voici 20 tenues repérées sur Pinterest pour être stylée en kaki.. Un blouson satiné kaki avec une jupe fleur
decoded_preds ['20 tenues pour être stylée en kaki', 'La tuerie de Las Vegas relance le débat sur le contrôle des armes à feu aux Etats-Unis']
decoded_labels ['Pinterest : 20 façons de porter du kaki ce printemps', 'Fusillades: Les Etats-Unis pays développé le plus meurtrier au monde']
{'score': 3.23696458177316, 'counts': [23, 9, 5, 3], 'totals': [249, 237, 225, 213], 'precisions': [9.236947791164658, 3.7974683544303796, 2.2222222222222223, 1.408450704225352], 'bp': 1.0, 'sys_len': 249, 'ref_len': 125}
