In [1]:
# an example for the inference on the essay test dataset

from fastNLP import cache_results

# Run configuration. `_first` is forwarded unchanged to the pipe and to the metric below.
_first = False
bart_name = "facebook/bart-base"
dataset_name='essay'
# NOTE: bart_name contains a "/", so the cache file lands in a nested directory
# (caches/data_facebook/bart-base_essay_False.pt for the values above).
cache_fn = f"caches/data_{bart_name}_{dataset_name}_{_first}.pt"

# Select the dataset-specific pipe and metric under common aliases so the
# remaining cells do not need to know which dataset is active.
if 'essay' in dataset_name:
    from data.pipe import BartAMPipe_essay as BartPipe
    from model.metrics import Seq2SeqSpanMetric_essay as Seq2SeqSpanMetric
elif 'cdcp' in dataset_name:
    from data.pipe import BartAMPipe_cdcp as BartPipe
    from model.metrics import Seq2SeqSpanMetric_cdcp as Seq2SeqSpanMetric
else:
    # Fail here instead of with a NameError on BartPipe / Seq2SeqSpanMetric in a later cell.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}: expected it to contain 'essay' or 'cdcp'"
    )


In [2]:
import pickle

# Load the pickled test results; later cells index it by 'input_batch' and 'output_batch'.
# NOTE(security): unpickling can execute arbitrary code -- only load a test_batch.pkl you produced yourself.
# The context manager closes the file handle, which the bare open() call never did.
with open("./test_batch.pkl", 'rb') as batch_file:
    test_res = pickle.load(batch_file)


In [3]:
# prepare the tokenizer and the metric
@cache_results(cache_fn, _refresh=False)
def get_data():
    """Run the dataset pipe and return its outputs as a 7-tuple.

    The tuple holds, in order: the processed data bundle, then the pipe's
    ``tokenizer``, ``mapping2id``, ``mapping2targetid``, ``relation_ids``,
    ``component_ids`` and ``none_ids`` attributes.

    The function is wrapped by ``cache_results`` with ``cache_fn``; the
    recorded run printed "Read cache from ...", so the body is presumably
    skipped when that cache file already exists.
    """
    am_pipe = BartPipe(tokenizer=bart_name, _first=_first)
    bundle = am_pipe.process_from_file(f'./data/{dataset_name}', demo=False)
    return (
        bundle,
        am_pipe.tokenizer,
        am_pipe.mapping2id,
        am_pipe.mapping2targetid,
        am_pipe.relation_ids,
        am_pipe.component_ids,
        am_pipe.none_ids,
    )

# Unpack everything get_data() hands back (data bundle plus the pipe's lookup tables).
(
    data_bundle,
    tokenizer,
    mapping2id,
    mapping2targetid,
    relation_ids,
    component_ids,
    none_ids,
) = get_data()

# Special-token ids; only eos_token_id is passed to the metric below.
bos_token_id, eos_token_id = 0, 1

# Ids stored as the values of mapping2id; their count is used as num_labels.
label_ids = [*mapping2id.values()]
print(label_ids)

# essay dataset setting: label token -> index, numbered 0..7 in this order
mapping = {
    token: index
    for index, token in enumerate(
        ['<s>', '</s>', '<<positive>>', '<<negative>>', '<<none>>', '<<MC>>', '<<C>>', '<<P>>']
    )
}

metric = Seq2SeqSpanMetric(
    eos_token_id,
    num_labels=len(label_ids),
    label_mapping=mapping,
    _first=_first,
)

Read cache from caches/data_facebook/bart-base_essay_False.pt.
[50265, 50266, 50267, 50268, 50269, 50270]


In [4]:
# Take element 0 of the recorded inputs and of the recorded outputs
# (presumably the first evaluation batch -- confirm against how test_batch.pkl was written).
input_batch = test_res['input_batch'][0]
output_batch = test_res['output_batch'][0]

In [7]:
# Show the target token-id tensor stored under batch_y for the selected batch.
print(input_batch['batch_y']['tgt_tokens'])

tensor([[  0, 241, 256,  ...,   1,   1,   1],
        [  0,  14,  22,  ...,   1,   1,   1],
        [  0, 204, 225,  ...,   1,   1,   1],
        ...,
        [  0,  23,  42,  ...,   1,   1,   1],
        [  0,  13,  30,  ...,   1,   1,   1],
        [  0, 117, 129,  ...,   1,   1,   1]])


In [8]:
def _print_pairs(pairs, tokens):
    """Print each pair as target text, source text, and its three labels.

    pairs  -- tuples as returned by metric.build_pair; positions 0-1 and 3-4
              are used as inclusive span bounds into ``tokens``, positions
              2, 5 and 6 as keys into ``metric.id2label``.
    tokens -- source tokens of the current sample.

    Reads ``tokenizer`` and ``metric`` from the notebook namespace.
    """
    # Span bounds are shifted down by the number of labels before slicing
    # (presumably because decoded indices are offset past the label ids -- confirm).
    offset = len(metric.id2label)
    for tup in pairs:
        # sent1 target
        # sent2 src
        sent1 = tokenizer.convert_tokens_to_string(tokens[tup[0]-offset:tup[1]-offset+1])
        lab1 = metric.id2label[tup[2]]
        sent2 = tokenizer.convert_tokens_to_string(tokens[tup[3]-offset:tup[4]-offset+1])
        lab2 = metric.id2label[tup[5]]
        rel = metric.id2label[tup[6]]
        print('target:',sent1)
        print('source:',sent2)
        print(lab1,lab2,rel)


# Walk the batch sample by sample and print predicted vs. gold pairs side by side.
for sample_id,(src_token,pred,target) in enumerate(zip(input_batch['batch_x']['src_tokens'],output_batch['batch_res']['pred'],input_batch['batch_y']['tgt_tokens'])):
    print("ID{}".format(sample_id))
    tokens = tokenizer.convert_ids_to_tokens(src_token)
    # use metric to convert sequence to spans
    ps,_ = metric.build_pair(pred.tolist())
    ts,_ = metric.build_pair(target.tolist())

    # preds
    print("----------prediction results----------")
    _print_pairs(ps, tokens)

    print("----------target results----------")
    _print_pairs(ts, tokens)


ID0
----------prediction results----------
(241, 256, 6, 34, 50, 7, 2)
target: sporteventshelptowakeuploveandresponsibilitiesstronglyineachc
source: Seeingnationalflagsbehonoredmaybethemostemotionalmomentstoeachath
<<C>> <<P>> <<positive>>
(241, 256, 6, 53, 79, 7, 2)
target: sporteventshelptowakeuploveandresponsibilitiesstronglyineachc
source: Theathletealsoseemstobemoreawareofhisresponsibilitiesandhewantstocontributemoretohisbelovedcountry
<<C>> <<P>> <<positive>>
(241, 256, 6, 81, 107, 7, 2)
target: sporteventshelptowakeuploveandresponsibilitiesstronglyineachc
source: Whenwitnessingournationalflagsflyingproudlyamongothers,manyofuscanstophidingourprideandh
<<C>> <<P>> <<positive>>
(241, 256, 6, 113, 144, 7, 2)
target: sporteventshelptowakeuploveandresponsibilitiesstronglyineachc
source: whenVietnambecamethechampionofSEAGAMESforthefirsttimein2008,afestivalatmospherepermeatedintothestre
<<C>> <<P>> <<positive>>
(241, 256, 6, 147, 176, 7, 2)
target: sporteventshelptowakeuploveandresponsi