In [1]:
class Args:
    dev_output_dir = '.'
    learning_rate = 5e-5
    train_data_path = '/kaggle/input/scnu-ai-challenge-5/train.json'
    dev_data_path = '/kaggle/input/scnu-ai-challenge-5/dev.json'
    test_data_path = '/kaggle/input/scnu-ai-challenge-5/test.json'
    train_epochs = 10
    save_checkpoint_dir = '.'
    test_output_dir = '.'
    best_score = 0
    max_source_length = 1024
    max_target_length = 128
    train_batch_size = 4
    predict_batch_size =8 
    seed=7
    gpu_id = 0
    default_checkpoint_path='/kaggle/input/bart-large-checkpoint2/checkpoint_best.pkl'
    scores_list=[]
args = Args()

In [2]:
import random
import os
import numpy as np
import torch
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(args.seed)

In [3]:
import json
import random
from tqdm import tqdm

pg_num=6 #文段数量最大值
def get_datas(data_path,data_type):
    with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    
    d_len={i:0 for i in range(11)}## 选取的文段数量分布
    d_level={} ## 各难度文段数量
    answer_len={i:0 for i in range(15)} 
    answers=[]
    contexts=[]
    questions=[]
    context_with_ans=0
    bad_answer_type1={i:[] for i in range(1,10)}
    bad_answer_type2={i:[] for i in range(1,10)}
    bad_answer_type1_num=0
    bad_answer_type2_num=0
    support_fact_num=0
    pred_support_fact_num=0
    for i,d in tqdm(enumerate(data)):
        ## 删除训练集中的单跳问题
        if data_type=='train':
            if d['level']=='easy':
                continue
                #pass
        
        ## 验证集只保留难度为hard的文段，结果较为稳定
        if data_type=='dev':
            if d['level']!='hard':
                continue
        
        support_facts_list=[]
        fact_set=set()
        
        if data_type=='train':
            titles=[]
            for item in d['supporting_facts']:
                fact_set.add(item[0])
            support_facts_list=list(fact_set)
            for x in d['context']:
                if x[0] not in support_facts_list:
                    titles.append(x[0])
            num=pg_num-len(support_facts_list)
            support_facts_list.extend(random.choices(titles,k=min(len(titles),num)))
            ## 训练集需要打乱文段顺序避免模型利用位置信息
            random.shuffle(support_facts_list)
        else:
            for item in d['pred_support_facts']:
                if len(support_facts_list)>=pg_num:
                    break
                support_facts_list.append(item[0])
            ## 测试集和验证集不需要打乱，避免截断
            #random.shuffle(support_facts_list)
            
        d_len[len(support_facts_list)]+=1   
        
        ## 检查损失信息比例
        if data_type=='dev':
            real_support_fact=set()
            for item in d['supporting_facts']:
                real_support_fact.add(item[0])
            support_fact_num+=len(real_support_fact)
            for x in support_facts_list:
                if x in real_support_fact:
                    pred_support_fact_num+=1
        
        
        answer=d['answer']#.replace(',',' , ')
        ## 检查答案被包含在多少个文段中，如果答案出现的文段数多于两个，说明答案的质量不佳
        paragraph_with_answer=0
        ## 拼接得到context
        context=''   
        for x in d['context']:
            s=' '
            if x[0] in support_facts_list:
                for text in x[1]:
                    s+=text
                ## 段落开头和结尾用特殊字符进行标识
                context+=f'<bop> {x[0]}. {s} <eop> '
                if s.find(answer)!=-1:
                    paragraph_with_answer+=1
        #context=context.replace('.',' . ').replace(',',' , ')
        if context.find(answer)!=-1:
            context_with_ans+=1
        context=context.replace(f' {answer} ',f' <ans> {answer} </ans> ')
        answer_token_num=len(answer.split())
        answer_len[answer_token_num]=answer_len.get(answer_token_num,0)+1
        
        ## 答案出现在多个段落中认为答案质量不佳，对应问题存在多个
        if paragraph_with_answer>2:
            bad_answer_type1[answer_token_num].append(answer)
            bad_answer_type1_num+=1
            if data_type=='train':
                pass
                #continue
        ## 删除训练集中答案信息不充分的数据如yes,no
        ## 答案很短答案未出现在文段中视为信息不充分
        if answer_token_num<=3:
            if paragraph_with_answer==0:
                bad_answer_type2[answer_token_num].append(answer)
                bad_answer_type2_num+=1
                if data_type=='train':
                    pass
                    #continue
            
        answer=answer 
        contexts.append(context)
        answers.append(answer)
        
        ## 统计问题难度
        if data_type!='test':
            questions.append(d['question']) 
            level=d['level']
            d_level[level]=d_level.get(level,0)+1
        
    if data_type=='dev':
        print('check info loss:',pred_support_fact_num,support_fact_num,pred_support_fact_num/support_fact_num)
    print('chosen paragraph num:',d_len)
    print('counting question level:',d_level)
    print('counting num of answer token:',answer_len)
    print('context_with_ans:',context_with_ans)
    print('bad_answer_num_type1:',bad_answer_type1_num)
    print('bad_answer_num_type2:',bad_answer_type2_num)
    print(bad_answer_type1)
    print(bad_answer_type2)
    return contexts,questions,answers

In [4]:
!pip install pycocoevalcap
!pip install bert_score
!pip install peft

Collecting pycocoevalcap
  Downloading pycocoevalcap-1.2-py3-none-any.whl.metadata (3.2 kB)
Collecting pycocotools>=2.0.2 (from pycocoevalcap)
  Downloading pycocotools-2.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading pycocoevalcap-1.2-py3-none-any.whl (104.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.3/104.3 MB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading pycocotools-2.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (426 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m426.2/426.2 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycocotools, pycocoevalcap
Successfully installed pycocoevalcap-1.2 pycocotools-2.0.7
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [5]:
import json
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset

from transformers import T5ForConditionalGeneration, T5Tokenizer, DataCollatorForSeq2Seq


from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from bert_score import BERTScorer


scorers = {
        "Bleu": Bleu(4),
        #"Meteor": Meteor(),
        "Rouge": Rouge(),
    }
import bert_score
bert_scorer = BERTScorer(lang="en",model_type='roberta-large',rescale_with_baseline=True)
# 测评问题的流畅性
def fluencyScore(preds_list, gold_list):
    
    gts = {}
    res = {}
    for i, (p, g) in enumerate(zip(preds_list, gold_list)):
        gts[i] = [p]
        res[i] = [g]
    scores = {}
    for name, scorer in scorers.items():
        score, all_scores = scorer.compute_score(gts, res)
        if isinstance(score, list):
            for i, sc in enumerate(score, 1):
                scores[name + str(i)] = sc
        else:
            scores[name] = score
    return scores

# 测评语义相似度
def SemanticScore(preds_list, gold_list):
    p,r,f1 = bert_scorer.score(preds_list, gold_list, verbose=True)
    bert_score = np.mean(f1.tolist())
    return bert_score

def getTotalScore(preds_list,gold_list):
    bert_score = SemanticScore(preds_list,gold_list)
    scores = fluencyScore(preds_list,gold_list)
    last_score = (bert_score/2+scores['Bleu4']/2)*100
   # print(scores)
    return {'TotalScore':last_score, 
            'BERTScore':bert_score,
            'Bleu1':scores['Bleu1'],
            'Bleu2':scores['Bleu2'],
            'Bleu3':scores['Bleu3'],
            'Bleu4':scores['Bleu4'],
            #'Meteor':scores['Meteor'],
            'Rouge':scores['Rouge'],
           }
    #return {'TotalScore':last_score, 'BERTScore':bert_score,'Bleu4':scores['Bleu4']}
    
def saveJsonResult(generated_questions:list[dict], data_type = 'dev', score_type = 'last'):
    '''
    保存生成结果
    data_type: 'dev'和'test'
    score_type: best（最好结果）和last（最新一次epoch的结果）
    '''
    if data_type == 'dev':
        if score_type != 'best' and score_type != 'last':
            path = os.path.join(args.dev_output_dir,'output_last.json')
        else:
            path = os.path.join(args.dev_output_dir,f'output_{score_type}.json')
    elif data_type == 'test':
        path = os.path.join(args.test_output_dir,'output.json')
    else:
        print("未写明data_type")
        return False
    with open(path, 'w', encoding='utf-8') as json_file:
        json.dump(generated_questions, json_file, ensure_ascii=False, indent=4)
print('done')

2024-06-12 09:40:08.585126: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-12 09:40:08.585227: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-12 09:40:08.710097: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


done


In [6]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self,contexts,questions,answers,tokenizer,max_len,data_type):
        self.contexts = contexts
        self.questions = questions
        self.answers=answers
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data_type=data_type
    def __len__(self):
        return len(self.contexts)
    def __getitem__(self,item):
        context=self.contexts[item]
        answer=self.answers[item]
        source=f'{answer} [SEP] {context}'
        #source = f'Given the context and answer, please help me generate a multi-hop question.\nAnswer: {answer}\nContext: {context}\nQuestion:'
        source_encoding = self.tokenizer(text=source,
                      max_length=args.max_source_length,
                      padding=True,
                      truncation=True,
                      add_special_tokens=True,
                      return_attention_mask=True,
                      #return_tensors='pt'
                                        )
        if self.data_type=='test':
            return {
                #"text":source,
                "input_ids":source_encoding['input_ids'],
               "attention_mask":source_encoding['attention_mask'],
               }
        
        question=self.questions[item]
        target_encoding=self.tokenizer(text=question,
                      max_length=args.max_target_length,
                      padding=True,
                      truncation=True,
                      add_special_tokens=True,
                      return_attention_mask=False,
                      #return_tensors='pt'
                                      )
        #if item==0:
        #print(source)
            #print(question)
            #print(source_encoding['input_ids'].shape)
            #print(target_encoding['input_ids'])
        #print(torch.LongTensor(source_encoding['input_ids']).squeeze(0).shape)
        return {
                #"text":source,
                "input_ids":torch.LongTensor(source_encoding['input_ids']).squeeze(0),
                "attention_mask":torch.LongTensor(source_encoding['attention_mask']),
                "labels":torch.LongTensor(target_encoding['input_ids'])
               }#,[source,question,source_encoding['input_ids'].shape,target_encoding['input_ids']]

    dev_gold_question_list=[]
def create_data_loader(data_path,data_type,tokenizer,max_len,batch_size=4,shuffle=True):
    contexts,questions,answers=get_datas(data_path,data_type)
    ds = MyDataSet(
                   contexts=contexts,
                   questions=questions,
                   answers=answers,
                   tokenizer = tokenizer,
                   max_len=max_len,
                   data_type=data_type
                  )
    if data_type=='dev':
        global dev_gold_question_list
        dev_gold_question_list=questions
    collate_fn = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8,padding=True)
    return torch.utils.data.DataLoader(ds,batch_size=batch_size,collate_fn=collate_fn,shuffle=shuffle)


max_len = args.max_source_length
batch_size = 4
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [7]:
from transformers import  BartTokenizer,BartForConditionalGeneration,BertTokenizer
model_name='facebook/bart-large'
tokenizer =BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).to(f'cuda:{args.gpu_id}')
model.config.bos_token_id=tokenizer.eos_token_id
tokenizer.add_special_tokens({'additional_special_tokens':["<ans>","</ans>","[SEP]","<eop>","<bop>"]})
print(tokenizer.SPECIAL_TOKENS_ATTRIBUTES)
model.resize_token_embeddings(len(tokenizer))

def load_checkpoint(path=None):
    if path==None:
        path=args.default_checkpoint_path
    print('load from: ',path)
    model.load_state_dict(torch.load(path))
    
mode='load'#'build'
print('mode: ',mode)
if mode=='load':
    checkpoint=None#'/kaggle/working/checkpoint_best.pkl'
    load_checkpoint(checkpoint)
elif mode=='build':
    pass

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']
mode:  load
load from:  /kaggle/input/bart-large-checkpoint2/checkpoint_best.pkl


In [9]:
%%time
train_data_path='/kaggle/input/scnu-ai-challenge-5/train.json'
dev_data_path='/kaggle/input/scnu-ai-challenge-dataset-with-sorted-pred-facts/dev.json'
test_data_path='/kaggle/input/scnu-ai-challenge-dataset-with-sorted-pred-facts/test.json'
train_data_loader=create_data_loader(train_data_path,
                                     'train',
                                     tokenizer,
                                     max_len,
                                     batch_size=batch_size,
                                     shuffle=True)

88947it [00:02, 29754.37it/s]

chosen paragraph num: {0: 0, 1: 0, 2: 160, 3: 114, 4: 65, 5: 67, 6: 72220, 7: 0, 8: 0, 9: 0, 10: 0}
counting question level: {'medium': 51684, 'hard': 20942}
counting num of answer token: {0: 0, 1: 25688, 2: 25013, 3: 13801, 4: 4480, 5: 2111, 6: 670, 7: 333, 8: 192, 9: 101, 10: 66, 11: 58, 12: 23, 13: 25, 14: 18, 26: 1, 16: 14, 17: 7, 19: 3, 15: 13, 27: 1, 34: 1, 18: 4, 22: 2, 29: 1}
context_with_ans: 70148
bad_answer_num_type1: 8407
bad_answer_num_type2: 2478
{1: ['yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 




In [10]:
%%time
dev_data_loader=create_data_loader(dev_data_path,
                                     'dev',
                                     tokenizer,max_len,batch_size=args.predict_batch_size,shuffle=False)
test_data_loader=create_data_loader(test_data_path,
                                     'test',
                                     tokenizer,max_len,batch_size=args.predict_batch_size,shuffle=False)

# 读取test的ids
with open(test_data_path, 'r', encoding='utf-8') as json_file:
    data = json.load(json_file)
test_id_list = [item['_id'] for item in data]

1500it [00:00, 102081.00it/s]


check info loss: 688 720 0.9555555555555556
chosen paragraph num: {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 360, 7: 0, 8: 0, 9: 0, 10: 0}
counting question level: {'hard': 360}
counting num of answer token: {0: 0, 1: 122, 2: 124, 3: 75, 4: 24, 5: 8, 6: 0, 7: 1, 8: 3, 9: 0, 10: 1, 11: 0, 12: 0, 13: 2, 14: 0}
context_with_ans: 344
bad_answer_num_type1: 58
bad_answer_num_type2: 16
{1: ['no', 'tennis', 'no', 'no', 'director', 'Disney', 'Thames', 'no', 'no', 'no', 'Kansas', 'British', 'singer', '1995', 'German', 'no', 'no', 'Cantonese', 'Mexico', 'two', 'Irish', 'Russian', 'no', 'Detroit', 'Marvel'], 2: ['Love Actually', 'Indianapolis, Indiana', 'Supreme Court', 'Nikita Khrushchev', 'Golf Digest', 'Bob Stoops', 'North Dakota', 'Boston University', 'Chicago Outfit', 'Porfirio Rubirosa', 'Atlanta, Georgia', 'Loughborough University', 'Eugene Onegin', 'Manhattan Center', 'Alan James', 'Relient K', 'National Geographic', 'documentary film', 'Sleeping Beauty', 'professional wrestler', 'Takahiro Mo

7405it [00:00, 33153.14it/s]


chosen paragraph num: {0: 0, 1: 0, 2: 23, 3: 15, 4: 12, 5: 6, 6: 7349, 7: 0, 8: 0, 9: 0, 10: 0}
counting question level: {}
counting num of answer token: {0: 0, 1: 2607, 2: 2542, 3: 1353, 4: 453, 5: 240, 6: 83, 7: 36, 8: 26, 9: 9, 10: 11, 11: 4, 12: 9, 13: 1, 14: 4, 32: 2, 15: 5, 35: 1, 22: 1, 34: 1, 19: 4, 16: 2, 26: 1, 25: 2, 20: 2, 17: 3, 21: 1, 18: 1, 24: 1}
context_with_ans: 7184
bad_answer_num_type1: 1494
bad_answer_num_type2: 221
{1: ['Europe', 'Bundesliga', 'no', '201', 'black', '1984', 'Masterpiece', 'American', '2010', 'Cumberland', 'Garbage', 'no', 'no', 'novelist', 'cocktail', 'Drosera', 'plants', 'singer', 'Eragrostis', 'Australia', 'no', 'no', 'Brazil', 'Aerie', 'Aechmea', 'basketball', '1967', 'opera', 'Canada', 'Indian', 'no', 'Cleveland', 'Canada', 'Scientology', 'Wikstroemia', 'leader', 'rifles', 'Diplomacy', '1989', 'no', 'Loa', 'no', 'American', 'no', 'no', 'no', 'composer', 'singer', '2001', 'Puyi', 'Mexico', 'no', 'film', 'India', 'Elamite', 'American', 'rock', 'C

In [11]:
def get_infinite_data(dataloader):
        while True:
            for images in dataloader:
                yield images
train_generator=get_infinite_data(train_data_loader)
dev_generator=get_infinite_data(test_data_loader)
def get_batch():
    return train_generator.__next__()
def get_dev_batch():
    return dev_generator.__next__()
def get_answer(batch):
    bad_answers=[]
    for i in range(4):
        bad_answers.append(tokenizer.convert_ids_to_tokens(batch['input_ids'][i])[:15])
    return bad_answers

In [12]:
x=get_batch()
inputs=tokenizer.convert_ids_to_tokens(x['input_ids'][0])
#labels=tokenizer.convert_ids_to_tokens(x['labels'][0].reshape(-1))
x['input_ids'].shape,get_answer(x)

(torch.Size([4, 1024]),
 [['<s>',
   'June',
   'Ġ28',
   ',',
   'Ġ1902',
   'Ġ',
   '[SEP]',
   'Ġ',
   '<bop>',
   'ĠAmerican',
   'ĠCompos',
   'er',
   'ĠSeries',
   '.',
   'Ġ'],
  ['<s>',
   'Michael',
   'ĠC',
   'im',
   'ino',
   'Ġ',
   '[SEP]',
   'Ġ',
   '<bop>',
   'ĠThe',
   'ĠDeer',
   'ĠHunter',
   '.',
   'Ġ',
   'ĠThe'],
  ['<s>',
   'Ay',
   'man',
   'am',
   'Ġ',
   '[SEP]',
   'Ġ',
   '<bop>',
   'ĠAy',
   'man',
   'am',
   '.',
   'Ġ',
   'Ġ',
   '<ans>'],
  ['<s>',
   'yes',
   'Ġ',
   '[SEP]',
   'Ġ',
   '<bop>',
   'ĠPop',
   'stars',
   'Ġ(',
   'Germany',
   'Ġseason',
   'Ġ1',
   ').',
   'Ġ',
   'ĠIn']])

In [13]:
## 评估模型函数
def eval_model(epoch=1,testing_batch=len(dev_data_loader)):
    model.eval()
    preds_list = []
    if testing_batch<=0:
        testing_batch=len(dev_data_loader)
    with tqdm(total=testing_batch, desc=f'Validation Epoch {epoch}', unit='batch') as pbar:
        for i,batch in enumerate(dev_data_loader):
            if i>testing_batch:break
            with torch.no_grad():
                input_ids = batch['input_ids'].to(device)
                attention_mask=batch['attention_mask'].to(device)
                generated_ids = model.generate(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    do_sample=False,
                    max_new_tokens=256,
                    num_beams=3, 
                    #num_beam_groups=5,
                    #diversity_penalty=1.0
                )
                preds_list.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
 
            pbar.update(1)
    

    for i in range(20):
        print('pred: ',preds_list[i])
        print("true: ",dev_gold_question_list[i])
        
    scores = getTotalScore(preds_list, dev_gold_question_list[:len(preds_list)])
    preds_dict = []
    if scores['TotalScore'] > args.best_score:
        torch.save(model.state_dict(), os.path.join(args.save_checkpoint_dir, "checkpoint_best.pkl"))
        print(f"Total score: {args.best_score} -> {scores['TotalScore'] }")
        print(f"checkpoint_best.pdparams已存储至{args.save_checkpoint_dir}")
        args.best_score = scores['TotalScore'] 
    args.scores_list.append(scores)
    print(f"Scores: {scores}")
    return preds_list
pred_list=eval_model(epoch=0,testing_batch=5)

Validation Epoch 0: 6batch [00:13,  2.20s/batch]                    

pred:  Are both Trichosanthes and Plantago considered flowering plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant of what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who was elected to the United States Senate seat in 2012 and was recently appointed to the seat of Nevada's open U.S. Senate seat?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which America




  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.45 seconds, 107.84 sentences/sec
{'testlen': 705, 'reflen': 771, 'guess': [705, 657, 609, 561], 'correct': [309, 159, 92, 51]}
ratio: 0.9143968871583471
Total score: 0 -> 32.7637042243212
checkpoint_best.pdparams已存储至.
Scores: {'TotalScore': 32.7637042243212, 'BERTScore': 0.47736939913981286, 'Bleu1': 0.3991278261156405, 'Bleu2': 0.29658091645327284, 'Bleu3': 0.2295795471428749, 'Bleu4': 0.17790468534661108, 'Rouge': 0.37532466330341957}


In [None]:
optimizer = Adafactor(model.parameters(), relative_step=True, warmup_init=True,lr=None,clip_threshold=1.0)
scheduler = AdafactorSchedule(optimizer)

In [None]:
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_
from transformers import  get_linear_schedule_with_warmup

import time
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" 
save_dir = args.save_checkpoint_dir
training_batch=1024
testing_batch=45
updata_batch=8
epochs=20
for epoch in range(epochs):
    # 训练部分
    model.train()
    epoch_loss = 0
    batch_loss=0
    with tqdm(total=training_batch//updata_batch, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch',mininterval=3) as pbar:
        for i in range(training_batch):
            batch=get_batch()
            input_ids= batch['input_ids'].to(device)
            attention_mask=batch['attention_mask'].to(device)
            labels=batch['labels'].to(device)
            #总是报显存不足的问题，是因为碎片没完全释放
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()

            outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels, 
                            return_dict=True,
                           )
            loss = outputs.loss
        

            loss.backward() 
            loss_t = loss.detach().cpu()
            """if loss_t.item()>1.80:
                print(loss_t.item(),input_ids.shape)
                bad_answers=get_answer(batch)
                for x in bad_answers:
                    print(x[:x.index('[SEP]')])
                bad_answers_list.append(bad_answers)
            """
            epoch_loss += loss_t[0] if len(loss_t.shape) > 0 else loss_t.numpy()
            batch_loss +=loss_t[0] if len(loss_t.shape) > 0 else loss_t.numpy()
            if (i+1)%updata_batch==0:
                clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                learning_rate=optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    'loss': f'{ batch_loss/updata_batch:.4f}',
                    'lr':f'{learning_rate:.5f}'
                })
                batch_loss=0
                -
                pbar.update(1)
                # break
    print(f"Average Loss: {epoch_loss / training_batch:.4f}")
    # 验证部分
    eval_model(epoch,testing_batch)
    torch.save(model.state_dict(), os.path.join(args.save_checkpoint_dir, "checkpoint_last.pkl"))

Epoch 1/20: 100%|██████████| 128/128 [15:30<00:00,  7.27s/batch, loss=0.9895, lr=0.00010]


Average Loss: 0.9734


Validation Epoch 0: 100%|██████████| 45/45 [01:31<00:00,  2.03s/batch]

pred:  Are both Trichosanthes and Plantago considered flowering plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who was elected to his first full term in the 2012 United States Senate election in Nevada?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of The Simpsons in which Moe Szyslak first appeared originally air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at the Readercon annual science fiction convention?
true: 




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.15 seconds, 167.62 sentences/sec
{'testlen': 5545, 'reflen': 5774, 'guess': [5545, 5185, 4825, 4465], 'correct': [2563, 1332, 758, 430]}
ratio: 0.9603394527189192
Total score: 32.7637042243212 -> 34.2570663036352
checkpoint_best.pdparams已存储至.
Scores: {'TotalScore': 34.2570663036352, 'BERTScore': 0.4875946831886217, 'Bleu1': 0.44351811145702175, 'Bleu2': 0.3306477437452711, 'Bleu3': 0.25448124035939496, 'Bleu4': 0.19754664288408222, 'Rouge': 0.40059092892791287}


Epoch 2/20: 100%|██████████| 128/128 [15:23<00:00,  7.21s/batch, loss=0.8681, lr=0.00009]


Average Loss: 0.9548


Validation Epoch 1: 100%|██████████| 45/45 [01:32<00:00,  2.05s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest 




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.08 seconds, 173.42 sentences/sec
{'testlen': 5545, 'reflen': 5829, 'guess': [5545, 5185, 4825, 4465], 'correct': [2572, 1343, 768, 441]}
ratio: 0.9512780922969718
Total score: 34.2570663036352 -> 34.308401467936676
checkpoint_best.pdparams已存储至.
Scores: {'TotalScore': 34.308401467936676, 'BERTScore': 0.4881060899992008, 'Bleu1': 0.44068271600225795, 'Bleu2': 0.3293100133160657, 'Bleu3': 0.254064192048558, 'Bleu4': 0.19806193935953276, 'Rouge': 0.4035147386072598}


Epoch 3/20: 100%|██████████| 128/128 [15:26<00:00,  7.24s/batch, loss=1.0150, lr=0.00009]


Average Loss: 0.9586


Validation Epoch 2: 100%|██████████| 45/45 [01:30<00:00,  2.02s/batch]

pred:  Are both Trichosanthes and Plantago considered plant life?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of The Simpsons in which Moe Szyslak first appeared air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at Readercon?
true:  Who is the science fiction writer, Sa




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.03 seconds, 177.71 sentences/sec
{'testlen': 5545, 'reflen': 5726, 'guess': [5545, 5185, 4825, 4465], 'correct': [2582, 1353, 773, 447]}
ratio: 0.9683898009079691
Total score: 34.308401467936676 -> 34.61573068062369
checkpoint_best.pdparams已存储至.
Scores: {'TotalScore': 34.61573068062369, 'BERTScore': 0.48895404816624555, 'Bleu1': 0.4506905358814383, 'Bleu2': 0.33738487759748365, 'Bleu3': 0.26036686710840834, 'Bleu4': 0.20336056544622833, 'Rouge': 0.4085983975402008}


Epoch 4/20: 100%|██████████| 128/128 [15:29<00:00,  7.26s/batch, loss=1.0448, lr=0.00008]


Average Loss: 0.9353


Validation Epoch 3: 100%|██████████| 45/45 [01:33<00:00,  2.07s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant of what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of h




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.06 seconds, 174.60 sentences/sec
{'testlen': 5545, 'reflen': 5840, 'guess': [5545, 5185, 4825, 4465], 'correct': [2568, 1333, 759, 445]}
ratio: 0.9494863013697004
Scores: {'TotalScore': 34.013086115459814, 'BERTScore': 0.4831740185961179, 'Bleu1': 0.4391253738559304, 'Bleu2': 0.3271767939887265, 'Bleu3': 0.25180721451830496, 'Bleu4': 0.19708770371307835, 'Rouge': 0.4000743702322026}


Epoch 5/20: 100%|██████████| 128/128 [15:38<00:00,  7.34s/batch, loss=1.0113, lr=0.00008]


Average Loss: 0.9563


Validation Epoch 4: 100%|██████████| 45/45 [01:30<00:00,  2.02s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of "The Simpsons" in which Moe Szyslak first appeared originally air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at Readercon?
true:  Who is the science fiction write




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.08 seconds, 173.36 sentences/sec
{'testlen': 5545, 'reflen': 5708, 'guess': [5545, 5185, 4825, 4465], 'correct': [2589, 1345, 765, 442]}
ratio: 0.9714435879465712
Total score: 34.61573068062369 -> 34.6316681452805
checkpoint_best.pdparams已存储至.
Scores: {'TotalScore': 34.6316681452805, 'BERTScore': 0.489876398545069, 'Bleu1': 0.453381759521069, 'Bleu2': 0.33793685298217385, 'Bleu3': 0.26002968146840155, 'Bleu4': 0.20275696436054105, 'Rouge': 0.4063051175379131}


Epoch 6/20: 100%|██████████| 128/128 [15:31<00:00,  7.28s/batch, loss=0.8707, lr=0.00007]


Average Loss: 0.9442


Validation Epoch 5: 100%|██████████| 45/45 [01:31<00:00,  2.04s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor 




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.08 seconds, 173.07 sentences/sec
{'testlen': 5545, 'reflen': 5824, 'guess': [5545, 5185, 4825, 4465], 'correct': [2572, 1318, 741, 428]}
ratio: 0.9520947802196167
Scores: {'TotalScore': 33.804483205974456, 'BERTScore': 0.48199301957308005, 'Bleu1': 0.44108026465911804, 'Bleu2': 0.3265248543257377, 'Bleu3': 0.24970880937704582, 'Bleu4': 0.19409664454640907, 'Rouge': 0.3981381317773622}


Epoch 7/20: 100%|██████████| 128/128 [15:26<00:00,  7.24s/batch, loss=1.2639, lr=0.00007]


Average Loss: 0.9511


Validation Epoch 6: 100%|██████████| 45/45 [01:31<00:00,  2.03s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of "The Simpsons" in which Moe Szyslak first appeared originally air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at Readercon?
true:  Who is the science fiction writer, Samue




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.12 seconds, 170.16 sentences/sec
{'testlen': 5545, 'reflen': 5798, 'guess': [5545, 5185, 4825, 4465], 'correct': [2568, 1326, 740, 419]}
ratio: 0.9563642635389864
Scores: {'TotalScore': 33.99201707137644, 'BERTScore': 0.4857122528088641, 'Bleu1': 0.4424641097702878, 'Bleu2': 0.32879764338086337, 'Bleu3': 0.25114553306512677, 'Bleu4': 0.1941280886186646, 'Rouge': 0.398750013510342}


Epoch 8/20: 100%|██████████| 128/128 [15:45<00:00,  7.39s/batch, loss=1.0155, lr=0.00006]


Average Loss: 0.9479


Epoch 9/20: 100%|██████████| 128/128 [15:33<00:00,  7.30s/batch, loss=0.7143, lr=0.00006]


Average Loss: 0.9448


Validation Epoch 8: 100%|██████████| 45/45 [01:32<00:00,  2.05s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality tv series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of "The Simpsons" in which Moe Szyslak first appeared air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at Readercon?
true:  Who is the science fiction writer, Samuel R. Delany




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.10 seconds, 171.33 sentences/sec
{'testlen': 5545, 'reflen': 5887, 'guess': [5545, 5185, 4825, 4465], 'correct': [2584, 1337, 755, 435]}
ratio: 0.9419058943433086
Scores: {'TotalScore': 34.01519151417066, 'BERTScore': 0.48579617597990565, 'Bleu1': 0.43813192227910236, 'Bleu2': 0.3259122922874285, 'Bleu3': 0.25000864018487695, 'Bleu4': 0.1945076543035075, 'Rouge': 0.39684381048075135}


Epoch 10/20: 100%|██████████| 128/128 [15:36<00:00,  7.32s/batch, loss=1.0796, lr=0.00005]


Average Loss: 0.9496


Validation Epoch 9: 100%|██████████| 45/45 [01:32<00:00,  2.05s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant of what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of h




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.10 seconds, 171.22 sentences/sec
{'testlen': 5545, 'reflen': 5841, 'guess': [5545, 5185, 4825, 4465], 'correct': [2569, 1324, 761, 446]}
ratio: 0.9493237459337528
Scores: {'TotalScore': 33.91392008070594, 'BERTScore': 0.4813005619759982, 'Bleu1': 0.4392171561041131, 'Bleu2': 0.32607509606100843, 'Bleu3': 0.2514469937000336, 'Bleu4': 0.19697783963812054, 'Rouge': 0.3970676611928179}


Epoch 11/20: 100%|██████████| 128/128 [15:29<00:00,  7.26s/batch, loss=0.9182, lr=0.00005]


Average Loss: 0.9625


Validation Epoch 10: 100%|██████████| 45/45 [01:31<00:00,  2.04s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant in what season of the Indian reality tv series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did Moe Szyslak first appear in the series premiere episode of The Simpsons that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of h




  0%|          | 0/12 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/6 [00:00<?, ?it/s]

done in 2.14 seconds, 168.01 sentences/sec
{'testlen': 5545, 'reflen': 5822, 'guess': [5545, 5185, 4825, 4465], 'correct': [2574, 1332, 753, 433]}
ratio: 0.95242184816198
Scores: {'TotalScore': 33.99308673962423, 'BERTScore': 0.48379157926990757, 'Bleu1': 0.4415824944632963, 'Bleu2': 0.3285005374671801, 'Bleu3': 0.25209151480865355, 'Bleu4': 0.19607015552257714, 'Rouge': 0.3978101067587732}


Epoch 12/20: 100%|██████████| 128/128 [15:34<00:00,  7.30s/batch, loss=1.0439, lr=0.00004]


Average Loss: 0.9505


Validation Epoch 11: 100%|██████████| 45/45 [01:30<00:00,  2.01s/batch]

pred:  Are Trichosanthes and Plantago in the same family?
true:  Are Trichosanthes and Plantago both forms of plant life?
pred:  Lopamudra Raut was a contestant of what season of the Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
pred:  Who won the 2012 United States Senate election in Nevada and was elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
pred:  When did the episode of "The Simpsons" in which Moe Szyslak first appeared air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
pred:  Are The Shipping News and Gene both post-rock bands?
true:  Are Shipping News and Gene both rock bands?
pred:  Which American writer of science fiction was a guest of honor at Readercon?
true:  Who is the science fiction writer, Samuel R. Delany




  0%|          | 0/12 [00:00<?, ?it/s]

In [38]:
args.scores_list

[{'TotalScore': 32.7637042243212,
  'BERTScore': 0.47736939913981286,
  'Bleu1': 0.3991278261156405,
  'Bleu2': 0.29658091645327284,
  'Bleu3': 0.2295795471428749,
  'Bleu4': 0.17790468534661108,
  'Rouge': 0.37532466330341957},
 {'TotalScore': 34.1206256316701,
  'BERTScore': 0.4848011009834914,
  'Bleu1': 0.44133567669368995,
  'Bleu2': 0.3292575376372817,
  'Bleu3': 0.2537743924964035,
  'Bleu4': 0.19761141164991072,
  'Rouge': 0.4014842854033838},
 {'TotalScore': 34.23306967849367,
  'BERTScore': 0.48562253667268024,
  'Bleu1': 0.4432651181683536,
  'Bleu2': 0.33052362052425627,
  'Bleu3': 0.25529412000355123,
  'Bleu4': 0.19903885689719322,
  'Rouge': 0.4021943254504971},
 {'TotalScore': 34.25004918130635,
  'BERTScore': 0.48733625842174255,
  'Bleu1': 0.4460437420523031,
  'Bleu2': 0.3325855944124768,
  'Bleu3': 0.2552949899567074,
  'Bleu4': 0.19766472520438452,
  'Rouge': 0.4019092878302473},
 {'TotalScore': 34.661992376937505,
  'BERTScore': 0.49149847903722227,
  'Bleu1': 0.4

In [None]:
test_generator=get_infinite_data(test_data_loader)
def get_test_batch():
    return test_generator.__next__()
# 测试
generated_questions = []
generated_questions_dict = []

with tqdm(total=len(test_data_loader), desc=f'Test epoch {1}/{1}', unit='batch',,mininterval=3) as pbar:
    for i in range(len(test_data_loader)):
        batch=get_test_batch()
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask=batch['attention_mask'].to(device)
            generated_ids = model.generate(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           max_new_tokens=256,
                                           do_sample=False,
                                           num_beams=3,)
            #preds_list.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
        #generated_ids = generate_question(batch)
        qs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_questions.extend(qs)
        pbar.update(1)

for _,item in enumerate(generated_questions):
    generated_questions_dict.append({'_id':test_id_list[_],'question':item})

saveJsonResult(generated_questions_dict, data_type = 'test')

Test epoch 1/1:  92%|█████████▏| 849/926 [35:12<03:07,  2.43s/batch]