In [None]:
%%time
!pip install pycocoevalcap
!pip install bert_score

In [None]:
## 计算模型测试结果的bleu4和bert_score
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from bert_score import BERTScorer
import numpy as np

bleu4_scorers = {"Bleu": Bleu(4)}
bert_scorer = BERTScorer(lang="en",
                         model_type='roberta-large',
                         device='cuda:1',
                         rescale_with_baseline=True
                        )

# 测评问题的流畅性
def fluencyScore(preds_list, gold_list):
    gts = {}
    res = {}
    for i, (p, g) in enumerate(zip(preds_list, gold_list)):
        gts[i] = [p]
        res[i] = [g]
    scores = {}
    for name, scorer in bleu4_scorers.items():
        score, all_scores = scorer.compute_score(gts, res)
        if isinstance(score, list):
            for i, sc in enumerate(score, 1):
                scores[name + str(i)] = sc
        else:
            scores[name] = score
    return scores,all_scores[-1]

# 测评语义相似度
def SemanticScore(preds_list, gold_list):
    p,r,f1 = bert_scorer.score(preds_list, gold_list, verbose=True)
    bert_score = np.mean(f1.tolist())
    return bert_score

## 合并结果
def getTotalScore(preds_list,gold_list):
    bert_score = SemanticScore(preds_list,gold_list)
    scores,all_scores = fluencyScore(preds_list,gold_list)
    last_score = (bert_score/2+scores['Bleu4']/2)*100
    return {'TotalScore':last_score, 
            'BERTScore':bert_score,
            'Bleu1':scores['Bleu1'],
            'Bleu2':scores['Bleu2'],
            'Bleu3':scores['Bleu3'],
            'Bleu4':scores['Bleu4'],
           },all_scores


In [4]:
"""
model_name: 模型名称，本代码中支持不同T5型号
train_data_path,dev_data_path,test_data_path:三个数据集路径
"""
class Args:
    model_name = 'google/flan-t5-base'
    train_data_path = '/kaggle/input/scnu-ai-challenge-5/train.json'
    dev_data_path = '/kaggle/input/scnu-ai-challenge-dataset-with-sorted-pred-facts/dev.json'
    test_data_path = '/kaggle/input/scnu-ai-challenge-dataset-with-sorted-pred-facts/test.json'

## 用于记录训练过程中的loss和score
class MyLogging:
    best_score = 0
    loss_list=[]
    scores_list=[]

args = Args()
logging=MyLogging()

def plot_loss():
    import matplotlib.pyplot as plt
    loss_list=logging.loss_list
    plt.plot(range(len(loss_list)),loss_list)
    plt.show()

In [5]:
import json
import random
from tqdm import tqdm
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch
from torch import nn
import torch.nn.functional as F

## 数据整理器
from transformers import DataCollatorForSeq2Seq
## 模型及分词器
from transformers import T5ForConditionalGeneration as ConditionalGeneration
from transformers import AutoTokenizer
## 优化器
from transformers.optimization import Adafactor, AdafactorSchedule


2024-06-25 12:54:48.488604: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-25 12:54:48.488697: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-25 12:54:48.619911: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
class Trainer_config:
    pg_num=1
    max_source_length=1024
    max_target_length=128
    training_batch=512
    testing_batch=10
    updata_batch=8
    epochs=2
    batch_size=16
    eval_model=True
def get_trainer_config(pg_num,max_source_length,batch_size,updata_batch,epochs):
    config=Trainer_config()
    config.pg_num=pg_num
    config.max_source_length=max_source_length
    config.batch_size=batch_size
    config.updata_batch=updata_batch
    config.epochs=epochs
    return config

## 测试
def get_eval_config():
    config=Trainer_config()
    config.pg_num=7
    config.max_source_length=1024
    config.batch_size=4
    return config
def show_config(config):
    print('pg_num:',config.pg_num)
    print('max_souce_length:',config.max_source_length)
    print('train_batch:',config.training_batch)
    print('updata_batch:',config.updata_batch)
    print('epochs:',config.epochs)
    print('batch_size:',config.batch_size)



In [7]:
## 读入数据
data_dict={}
def read_json(data_path,data_type):
    if data_dict.get(data_type,-1)!=-1:
        return data_dict[data_type]
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            data =json.load(f)
        data_dict[data_type]=data
        return data

## 清洗数据及准备数据   
def get_datas(data_path,data_type):
    data=read_json(data_path,data_type)
    answers=[]
    questions=[]
    title_to_contexts=[]
    support_facts=[]
    for i,d in tqdm(enumerate(data)):
        ## 删除训练集中的单跳问题
        if data_type=='train':
            if d['level']=='easy':
                continue
        
        ## 验证集只保留难度为hard的文段，结果较为稳定
        if data_type=='dev':
            if d['level']!='hard':
                continue
        
        answer=d['answer']
        
        answer_token_num=len(answer.split(' '))
        ## 根据1-0.25*len的概率删除数据
        if data_type=='train':
            p_drop=max(1-0.25*answer_token_num,0)
            if random.random()<p_drop:
                continue
                
        ## support_facts
        support_facts_list=[]
        fact_set=set()
        if data_type=='train':
            titles=[]
            for item in d['supporting_facts']:
                fact_set.add(item[0])
            support_facts_list=list(fact_set)
        else:
            for item in d['pred_support_facts']:
                support_facts_list.append(item[0])
     
    
        ## title_to_context
        title_to_context={}   
        for x in d['context']:
            s=' '.join(x[1])
            context=f'{x[0]}. {s} '
            title_to_context[x[0]]=context
        
        ## question
        question=[]
        if data_type!='test':
            question=d['question']
        
        ## append
        questions.append(question)
        title_to_contexts.append(title_to_context)
        answers.append(answer)
        support_facts.append(support_facts_list)
        
    return title_to_contexts,questions,answers,support_facts

In [8]:
%%time
## 生成训练集并打乱
train_T2C,train_Q,train_A,train_SF=get_datas(args.train_data_path,'train')
def get_infinite_data(datas):
    while True:
        for value in datas:
            yield value

train_data=list(zip(train_T2C,train_Q,train_A,train_SF))
print('train_data_num:',len(train_data))
random.shuffle(train_data)
train_data_generator=get_infinite_data(train_data)
def get_example():
    return train_data_generator.__next__()

## 将数据整理成输入格式
def get_context(title_to_context,support_facts_list,data_type='train',pg_num=7):
    if data_type=='train':
        if len(support_facts_list)>pg_num:
                support_facts_list=support_facts_list[:pg_num]
        else:
            num=pg_num-len(support_facts_list)
            titles=[]
            for title in list(title_to_context.keys()):
                if title not in support_facts_list:
                    titles.append(title)
            support_facts_list.extend(random.choices(titles,k=min(len(titles),num)))
        random.shuffle(support_facts_list)
    else:
        support_facts_list=support_facts_list[:pg_num]
        
    context=''   
    for title in support_facts_list:
        context+=title_to_context[title]
    return context

title_to_context,question,answer,support_facts_list=get_example()
get_context(title_to_context,support_facts_list,data_type='train',pg_num=2)

88947it [00:00, 118844.51it/s]

train_data_num: 37382
CPU times: user 7.99 s, sys: 2.27 s, total: 10.3 s
Wall time: 19.3 s





'Albany Medical College. Albany Medical College (AMC) is a medical school located in Albany, New York, United States.  It was founded in 1839 by Alden March and James H. Armsby and is one of the oldest medical schools in the nation.  The college is part of the Albany Medical Center, which includes the Albany Medical Center Hospital. Henry Hun. Henry Hun (March 21, 1854 – March 14, 1924) an American physician, was professor of Nervous Diseases at the Albany Medical College in New York for 30 years.  He published several unique teaching volumes for his students as well as numerous journal articles on neurological disorders. '

In [17]:
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self,
                 title_to_contexts,
                 support_facts_lists,
                 questions,
                 answers,
                 tokenizer,
                 config,
                 data_type):
        self.title_to_contexts=title_to_contexts
        self.support_facts_lists=support_facts_lists
        self.questions = questions
        self.answers=answers
        self.tokenizer = tokenizer
        self.config=config
        self.data_type=data_type

    def __len__(self):
        if self.data_type=='train':
            return self.config.epochs*self.config.training_batch*self.config.batch_size
        return len(self.answers)
    
    def __getitem__(self,item):
        if self.data_type=='train':
            title_to_context,question,answer,support_facts_list=get_example()
        else:
            title_to_context= self.title_to_contexts[item]
            support_facts_list=self.support_facts_lists[item]
            answer=self.answers[item]
            question=''
            if self.data_type=='dev':
                question=self.questions[item]

        context=get_context(title_to_context,
                            support_facts_list,
                            self.data_type,
                            self.config.pg_num)

        source=f'Please generate only one question according to the answer and the paragraphs.\n\
                   Answer: {answer} \n Content: {context}'
        target=question
            
        source_encoding = self.tokenizer(text=source,
                      max_length=self.config.max_source_length,
                      padding=True,
                      truncation=True,
                      add_special_tokens=True,
                      return_attention_mask=True,
                      )
        if self.data_type=='test':
            return {
                "input_ids":source_encoding['input_ids'],
               "attention_mask":source_encoding['attention_mask'],
               }
        
        target_encoding=self.tokenizer(text=target,
                      max_length=self.config.max_target_length,
                      padding=True,
                      truncation=True,
                      add_special_tokens=True,
                      return_attention_mask=True,
                    )
   

        return {
                "input_ids":torch.LongTensor(source_encoding['input_ids']),
                "attention_mask":torch.LongTensor(source_encoding['attention_mask']),
                "labels":torch.LongTensor(target_encoding['input_ids']),
               }

    
def create_data_loader(data_path,data_type,tokenizer,config,shuffle=True):
    pg_num=config.pg_num
    max_len=config.max_source_length
    if data_type=='train':
        batch_size=config.batch_size
        title_to_contexts=None
        questions=None
        answers=None
        support_facts=None
    else:
        title_to_contexts,questions,answers,support_facts=get_datas(data_path,data_type)
        batch_size=config.batch_size*2

    ds = MyDataSet(
                   title_to_contexts=title_to_contexts,
                   support_facts_lists=support_facts,
                   questions=questions,
                   answers=answers,
                   tokenizer = tokenizer,
                   config=config,
                   data_type=data_type
                  )
    
    collate_fn = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8,padding=True)
    
    return torch.utils.data.DataLoader(ds,batch_size=batch_size,collate_fn=collate_fn,shuffle=shuffle)



In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
def load_model(model_name):
    print('model_name:',model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
    model = ConditionalGeneration.from_pretrained(model_name,device_map="auto")#.to(device)
    print('Successfully loaded model')
    return model,tokenizer
def load_checkpoint(checkpoint):
    if checkpoint:
        try:
            model.load_state_dict(torch.load(checkpoint))
            print('Successfully loaded checkpoint!')
        except:
            print('Failed to load checkpoint!')
    


## 加载模型和检查点
model_name='t5-base'
checkpoint='/kaggle/input/t5-base-checkpoint5/checkpoint_best (34.90).pkl'
if model_name==None:
    model_name=args.defalut_model_name
model,tokenizer=load_model(model_name)
load_checkpoint(checkpoint)

optimizer = Adafactor(model.parameters(), relative_step=True, warmup_init=True,lr=None,clip_threshold=1.0)
scheduler = AdafactorSchedule(optimizer)

model_name: t5-base




config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Successfully loaded model
Failed to load checkpoint!


In [11]:
%%time
## about 23s
def prepare_data(tokenizer,config):
    train_data_path=args.train_data_path
    dev_data_path=args.dev_data_path
    test_data_path=args.test_data_path
    print('preparing data...')
    train_data_loader=create_data_loader(train_data_path,'train',tokenizer,config,shuffle=True,)
    dev_data_loader=create_data_loader(dev_data_path,'dev',tokenizer,config,shuffle=False)
    test_data_loader=create_data_loader(test_data_path,'test',tokenizer,config,shuffle=False)
    print('prepare data success!')
    return train_data_loader,dev_data_loader,test_data_loader
train_data_loader,dev_data_loader,test_data_loader=prepare_data(tokenizer,config=get_eval_config())


preparing data...


1500it [00:00, 200735.63it/s]
7405it [00:00, 59629.88it/s]

prepare data success!
CPU times: user 1.03 s, sys: 593 ms, total: 1.62 s
Wall time: 2.93 s





In [None]:
train_generator=get_infinite_data(train_data_loader)
def get_batch():
    return train_generator.__next__()
def get_answer(batch):
    bad_answers=[]
    for i in range(batch['input_ids'].shape[0]):
        bad_answers.append(tokenizer.convert_ids_to_tokens(batch['input_ids'][i])[10:20])
    return bad_answers
x=get_batch()
inputs=tokenizer.convert_ids_to_tokens(x['input_ids'][0])
#labels=tokenizer.convert_ids_to_tokens(x['labels'][0].reshape(-1))
x['input_ids'].shape,inputs

In [None]:
## 评估模型函数
def eval_model(epoch=1,testing_batch=-1):
    model.eval()
    preds_list = []
    real_list=[]
    if testing_batch<=0:
        testing_batch=len(dev_data_loader)
    with tqdm(total=testing_batch, desc=f'Validation Epoch {epoch}', unit='batch') as pbar:
        for i,batch in enumerate(dev_data_loader):
            if i>testing_batch:break
            with torch.no_grad():
                input_ids = batch['input_ids'].to(device)
                attention_mask=batch['attention_mask'].to(device)
                labels=batch['labels'].to(device)
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                generated_ids = model.generate(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    do_sample=False,
                    max_new_tokens=128,
                    num_beams=3, 
                )
                preds_list.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
                labels=torch.where(labels!=-100,labels,tokenizer.pad_token_id)
                real_list.extend(tokenizer.batch_decode(labels, skip_special_tokens=True))
                
            pbar.update(1)
    
    scores,all_score = getTotalScore(preds_list, real_list)
    for i in range(20):
        print('pred: ',preds_list[i])
        print("true: ",real_list[i])
        print('bleu4:',all_score[i])

    if scores['TotalScore'] > logging.best_score:
        torch.save(model.state_dict(),  "/kaggle/working/checkpoint_best.pkl")
        print(f"Total score: {logging.best_score} -> {scores['TotalScore'] }")
        print(f"checkpoint_best.pkl已存储至/kaggle/working/")
        logging.best_score = scores['TotalScore'] 

    logging.scores_list.append(scores)
    print(f"Scores: {scores}")
    return preds_list
pred_list=eval_model(epoch=0,testing_batch=10)

In [14]:
def train(model,config,show=False):
    if show:
        show_config(config)
        
    pg_num=config.pg_num
    training_batch=config.training_batch
    testing_batch=config.testing_batch
    updata_batch=config.updata_batch
    epochs=config.epochs
    global dev_data_loader
    train_data_loader,dev_data_loader,_=prepare_data(tokenizer,config)
    train_generator=get_infinite_data(train_data_loader)
    def get_batch():
        return train_generator.__next__()
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        batch_loss=0
        with tqdm(total=training_batch//updata_batch, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch',mininterval=3) as pbar:
            for i in range(training_batch):
                batch=get_batch()
                input_ids= batch['input_ids'].to(device)
                attention_mask=batch['attention_mask'].to(device)
                labels=batch['labels'].to(device)
 
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()

                outputs = model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                                labels=labels, 
                                return_dict=True,
                               )
                
               
                loss = outputs.loss
                loss.backward() 

                loss_t = loss.detach().cpu()
                epoch_loss += loss_t[0] if len(loss_t.shape) > 0 else loss_t.numpy()
                batch_loss +=loss_t[0] if len(loss_t.shape) > 0 else loss_t.numpy()
                
                if (i+1)%updata_batch==0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    batch_loss=batch_loss/updata_batch
                    #lr=optimizer.param_groups[0]['lr']
                    logging.loss_list.append(batch_loss)
                    pbar.set_postfix({
                        'loss': f'{ batch_loss:.4f}',
                        #'lr':f'{lr:.6f}'
                    })
                    batch_loss=0
                    pbar.update(1)
                    
        print(f"Average Loss: {epoch_loss / training_batch:.4f}")
        # 验证部分
        if config.eval_model:
            eval_model(epoch,testing_batch)
        torch.save(model.state_dict(), "/kaggle/working/checkpoint_last.pkl")

In [14]:
## stage1
config1={
    'pg_num':2,
    'max_source_length':384,
    'batch_size':12,
    'updata_batch':2,
    'epochs':10,
}
trainer_config=get_trainer_config(**config1)
train(model,trainer_config,True)

pg_num: 2
max_souce_length: 384
train_batch: 512
updata_batch: 2
epochs: 10
batch_size: 12
preparing data...


1500it [00:00, 137913.06it/s]
7405it [00:00, 56261.29it/s]


prepare data success!


Epoch 1/10: 100%|██████████| 256/256 [11:04<00:00,  2.60s/batch, loss=1.6687]


Average Loss: 2.0888


Validation Epoch 0: 11batch [01:21,  7.37s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.82 seconds, 107.81 sentences/sec
{'testlen': 1329, 'reflen': 1534, 'guess': [1329, 1241, 1153, 1065], 'correct': [295, 84, 29, 11]}
ratio: 0.8663624511076491
pred:  Plantago is a genus of about 200 species of small, inconspicuous plants commonly called plantains or fleaworts. Plantago princeps is a rare species of flowering plant in the plantain family known by the common name ale.
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 4.9687989121574917e-14
pred:  What season of the Indian reality TV series "Bigg Boss" was hosted by Salman Khan?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.2928298013524784
pred:  Which United States Senator was elected to the open seat left vacant by resigning U.S. Senator John Ensign?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.284411468590456e-09
pred:  The episod

Epoch 2/10: 100%|██████████| 256/256 [11:08<00:00,  2.61s/batch, loss=1.5812]


Average Loss: 1.4212


Validation Epoch 1: 11batch [01:15,  6.90s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.81 seconds, 109.21 sentences/sec
{'testlen': 1329, 'reflen': 1496, 'guess': [1329, 1241, 1153, 1065], 'correct': [360, 110, 47, 20]}
ratio: 0.8883689839566254
pred:  Are Plantago and Plantago hawaiensis both species of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 6.376715692219053e-09
pred:  What season of the Indian reality TV series "Bigg Boss" was hosted by Salman Khan?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.2928298013524784
pred:  Which Nevada senator was elected to the open seat left vacant by resigning U.S. Senator John Ensign?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 3.4159919055542935e-13
pred:  The homer they fall was the first episode to air on what date?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
bleu4: 3.6

Epoch 3/10: 100%|██████████| 256/256 [11:08<00:00,  2.61s/batch, loss=1.2998]


Average Loss: 1.3044


Validation Epoch 2: 11batch [01:09,  6.34s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.74 seconds, 118.60 sentences/sec
{'testlen': 1329, 'reflen': 1387, 'guess': [1329, 1241, 1153, 1065], 'correct': [394, 128, 53, 24]}
ratio: 0.9581831290548247
pred:  Are Plantago and Plantago hawaiensis both endemic to Hawaii?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 6.030725358915047e-09
pred:  What season of the Indian reality TV series "Bigg Boss" was hosted by Salman Khan?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.2928298013524784
pred:  Which Republican U.S. Senator from Nevada was elected to the open seat left vacant by resigning U.S. Senator John Ensign?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 3.6351253345722926e-13
pred:  When did the episode that was the first episode to air on the Fox network in the United States?
true:  When was the Simpson's episode broadcasted that int

Epoch 4/10: 100%|██████████| 256/256 [11:04<00:00,  2.60s/batch, loss=1.2828]


Average Loss: 1.2477


Validation Epoch 3: 11batch [01:08,  6.21s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.79 seconds, 111.58 sentences/sec
{'testlen': 1329, 'reflen': 1488, 'guess': [1329, 1241, 1153, 1065], 'correct': [388, 122, 44, 20]}
ratio: 0.8931451612897223
pred:  Are both Plantago and Plantago hawaiensis native to the same country?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 4.82902736160217e-09
pred:  What season of the Indian reality TV series "Bigg Boss" did Lopamudra Raut win the third "Best National Costume" award for India?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.2496907208867028
pred:  Who was the senior United States Senator from Nevada who was elected to the open U.S. Senate seat in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.570421774539086e-09
pred:  When did the episode in which the character who manages the bar in Fort Greene, Brooklyn first air?
true:  When was

Epoch 5/10: 100%|██████████| 256/256 [11:05<00:00,  2.60s/batch, loss=1.1920]


Average Loss: 1.2218


Validation Epoch 4: 11batch [01:28,  8.01s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 1.13 seconds, 77.61 sentences/sec
{'testlen': 1329, 'reflen': 1496, 'guess': [1329, 1241, 1153, 1065], 'correct': [406, 134, 52, 20]}
ratio: 0.8883689839566254
pred:  Are both Plantago and Plantago hawaiensis genuses of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 6.376715692219053e-09
pred:  What season of Indian reality TV series did Lopamudra Raut compete in?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.4623953023675646e-05
pred:  Which Republican Senator was elected to the open seat in the United States Senate in Nevada in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.225025328784018e-05
pred:  When did the Simpsons episode that aired in the United States air?
true:  When was the Simpson's episode broadcasted that introduced the character Morris "Moe" Szyslak?
bleu4: 7.3470531

Epoch 6/10: 100%|██████████| 256/256 [11:07<00:00,  2.61s/batch, loss=1.4339]


Average Loss: 1.2004


Validation Epoch 5: 11batch [01:33,  8.50s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 1.18 seconds, 74.64 sentences/sec
{'testlen': 1329, 'reflen': 1615, 'guess': [1329, 1241, 1153, 1065], 'correct': [439, 164, 66, 26]}
ratio: 0.8229102167177568
pred:  Are Plantago and Plantago both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 4.264366797236323e-05
pred:  What season of the Indian reality TV series "Bigg Boss" did Lopamudra Raut win the third Best National Costume Award for India?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.2496907208867028
pred:  United States Senate election in Wyoming, 2012 and the 2012 United States Senate election in Utah, 2012 were held alongside which other election to the United States Senate and House of Representatives?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 0.1263410722664223
pred:  When was the first episode of the Simpsons tha

Epoch 7/10: 100%|██████████| 256/256 [11:08<00:00,  2.61s/batch, loss=1.2209]


Average Loss: 1.1775


Validation Epoch 6: 11batch [01:18,  7.17s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.87 seconds, 101.25 sentences/sec
{'testlen': 1329, 'reflen': 1416, 'guess': [1329, 1241, 1153, 1065], 'correct': [403, 139, 58, 23]}
ratio: 0.9385593220332354
pred:  Are both Plantago and Plantago a rare species of flowering plant?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 5.10607476252634e-09
pred:  What season of the Indian reality TV series "Bigg Boss" did Lopamudra Raut win the title of Miss United Continents 2016?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.26481899767115935
pred:  What senior United States Senator from Nevada was recently appointed to the seat left vacant by resigning U.S. Senator John Ensign?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.1400047283729616e-09
pred:  What date did the episode that starred in the episode "Flaming Moe's" air?
true:  When was the Simpso

Epoch 8/10: 100%|██████████| 256/256 [11:08<00:00,  2.61s/batch, loss=1.3926]


Average Loss: 1.1588


Validation Epoch 7: 11batch [01:07,  6.13s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.72 seconds, 121.54 sentences/sec
{'testlen': 1329, 'reflen': 1339, 'guess': [1329, 1241, 1153, 1065], 'correct': [442, 165, 78, 40]}
ratio: 0.9925317401038144
pred:  Are Plantago asiatica and Plantago hawaiensis both flowering plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 6.030725358915047e-09
pred:  What season of the Indian reality TV series "Bigg Boss" did Lopamudra Raut win the "Best National Costume" award for India?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 0.26481899767115935
pred:  Which United States Senator from Nevada was elected to Nevada's open U.S. Senate seat in 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.8724953479163113e-09
pred:  When did the Simpsons episode that aired "The Homer They Fall" originally air?
true:  When was the Simpson's episode broadcasted th

Epoch 9/10: 100%|██████████| 256/256 [11:12<00:00,  2.63s/batch, loss=1.1154]


Average Loss: 1.1353


Validation Epoch 8: 11batch [01:28,  8.02s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 1.23 seconds, 71.51 sentences/sec
{'testlen': 1329, 'reflen': 1565, 'guess': [1329, 1241, 1153, 1065], 'correct': [432, 146, 54, 20]}
ratio: 0.849201277954729
pred:  Are Plantago hawaiensis and Plantago a genus of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 6.030725358915047e-09
pred:  What season of Bigg Boss was hosted by Salman Khan?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.4325425696936665e-09
pred:  Which United States Senator from Nevada was elected to the Senate in 2012, United States Senate election in Utah, 2012, or United States Senate election in Nevada, 2012, alongside 33 other elections to the United States Senate in other states, as well as other elections to the United States Senate in other states, as well as other elections to the United States Senate in other states, as well as other elections to the United States Senate in other states, as well as oth

Epoch 10/10: 100%|██████████| 256/256 [11:05<00:00,  2.60s/batch, loss=1.2723]


Average Loss: 1.1160


Validation Epoch 9: 11batch [01:14,  6.77s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 0.82 seconds, 107.69 sentences/sec
{'testlen': 1329, 'reflen': 1391, 'guess': [1329, 1241, 1153, 1065], 'correct': [428, 148, 62, 25]}
ratio: 0.9554277498195863
pred:  Are both plantago and plantago asiatica plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 9.980099402512292e-13
pred:  What season of Indian reality TV series "Bigg Boss" did Lopamudra Raut contest?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.735838424446947e-05
pred:  Which senior United States Senator from Nevada was elected to his first full term in the United States Senate election in Wyoming, 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 0.15939331383660452
pred:  When was the episode of "The Simpsons" that the character who manages Moe Goes from Rags to Riches first appeared?
true:  When was the Simpson's episode br

In [15]:
## stage2
config2={
    'pg_num':7,
    'max_source_length':1024,
    'batch_size':4,
    'updata_batch':8,
    'epochs':20,
}
trainer_config=get_trainer_config(**config2)
train(model,trainer_config,True)

pg_num: 7
max_souce_length: 1024
train_batch: 512
updata_batch: 8
epochs: 20
batch_size: 4
preparing data...


1500it [00:00, 157046.90it/s]
7405it [00:00, 61402.31it/s]


prepare data success!


Epoch 1/20: 100%|██████████| 64/64 [11:57<00:00, 11.22s/batch, loss=1.0425]


Average Loss: 1.0679


Validation Epoch 0: 11batch [01:47,  9.81s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.52 seconds, 83.97 sentences/sec
{'testlen': 635, 'reflen': 729, 'guess': [635, 591, 547, 503], 'correct': [284, 132, 64, 29]}
ratio: 0.8710562414254169
pred:  Are Plantago and Trichosanthes in the same genus?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.0724314734549807e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator from Nevada was narrowly elected to his first full term in the 2012 United States Senate election?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 0.13090213860020217
pred:  Moe Szyslak first appeared in an episode that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode bro

Epoch 2/20: 100%|██████████| 64/64 [11:59<00:00, 11.24s/batch, loss=1.0497]


Average Loss: 1.0414


Validation Epoch 1: 11batch [01:50, 10.03s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.51 seconds, 85.87 sentences/sec
{'testlen': 635, 'reflen': 749, 'guess': [635, 591, 547, 503], 'correct': [286, 130, 66, 31]}
ratio: 0.8477970627492019
pred:  Are Plantago and Trichosanthes both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.1868405217902023e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator was elected to the open U.S. Senate seat in 2012 and was narrowly elected to his first full term?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 3.6592701640592855e-13
pred:  Moe Szyslak first appeared in an episode that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episod

Epoch 3/20: 100%|██████████| 64/64 [11:57<00:00, 11.21s/batch, loss=1.0393]


Average Loss: 1.0248


Validation Epoch 2: 11batch [01:55, 10.51s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.52 seconds, 84.83 sentences/sec
{'testlen': 635, 'reflen': 793, 'guess': [635, 591, 547, 503], 'correct': [301, 145, 81, 44]}
ratio: 0.8007566204277418
pred:  Are Plantago and Trichosanthes both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.1868405217902023e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator from Nevada was elected to the open U.S. Senate seat in 2012 and was narrowly elected to his first full term over Shelley Berkley?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.9307894654014826e-13
pred:  Moe Szyslak first appeared in an episode that originally aired on the Fox network in the United States on what date?
tru

Epoch 4/20: 100%|██████████| 64/64 [11:56<00:00, 11.20s/batch, loss=1.0217]


Average Loss: 1.0169


Validation Epoch 3: 11batch [01:44,  9.46s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.47 seconds, 94.04 sentences/sec
{'testlen': 635, 'reflen': 731, 'guess': [635, 591, 547, 503], 'correct': [291, 142, 80, 44]}
ratio: 0.8686730506144068
pred:  Are Plantago and Trichosanthes both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.1868405217902023e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator was elected to the seat that was held on November 6, 2012?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 3.4159919055542935e-13
pred:  Moe Szyslak first appeared in an episode that originally aired on the Fox network in the United States on what date?
true:  When was the Simpson's episode broadcasted that introduced the chara

Epoch 5/20: 100%|██████████| 64/64 [11:55<00:00, 11.19s/batch, loss=1.0550]


Average Loss: 1.0180


Validation Epoch 4: 11batch [01:46,  9.73s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.44 seconds, 101.09 sentences/sec
{'testlen': 635, 'reflen': 765, 'guess': [635, 591, 547, 503], 'correct': [300, 149, 86, 48]}
ratio: 0.8300653594760391
pred:  Are Plantago and Trichosanthes both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.1868405217902023e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator from Nevada was elected to the Senate seat in 2012 and was narrowly elected to his first full term over Shelley Berkley?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 3.275213114053928e-13
pred:  Moe Szyslak first appeared in an episode that originally aired in the United States on what date?
true:  When was the Simpson's ep

Epoch 6/20: 100%|██████████| 64/64 [11:56<00:00, 11.20s/batch, loss=0.9460]


Average Loss: 1.0004


Validation Epoch 5: 11batch [02:01, 11.01s/batch]                     

calculating scores...
computing bert embedding.





  0%|          | 0/2 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.49 seconds, 89.06 sentences/sec
{'testlen': 635, 'reflen': 790, 'guess': [635, 591, 547, 503], 'correct': [291, 143, 82, 46]}
ratio: 0.8037974683534129
pred:  Are Plantago and Trichosanthes both types of plants?
true:  Are Trichosanthes and Plantago both forms of plant life?
bleu4: 1.1868405217902023e-12
pred:  Lopamudra Raut was a contestant of what season of Indian reality TV series "Bigg Boss"?
true:  In which season of the Indian reality TV show "Big Boss" did the model Lopamundra Raut participate?
bleu4: 2.3693055762206496e-05
pred:  Which Republican U.S. Senator was elected to the seat that was held on November 6, 2012 and was narrowly elected to his first full term over Shelley Berkley?
true:  The 2012 United States Senate election in Nevada concluded with a close victory for which current Republican incumbent?
bleu4: 2.3307710740126315e-13
pred:  Moe Szyslak first appeared in an episode that originally aired in the United States on what date?
true:  When was the Simps

Epoch 7/20:   0%|          | 0/64 [00:04<?, ?batch/s]


KeyboardInterrupt: 

## 在测试集上测试

In [None]:
with open(args.test_data_path, 'r', encoding='utf-8') as json_file:
    data = json.load(json_file)
test_id_list = [item['_id'] for item in data]

train_data_loader,dev_data_loader,test_data_loader=prepare_data(tokenizer,config=get_eval_config())
test_generator=get_infinite_data(test_data_loader)
def get_test_batch():
    return test_generator.__next__()

# 测试
generated_questions = []
generated_questions_dict = []
model.eval()
with tqdm(total=len(test_data_loader), desc=f'Test epoch {1}/{1}', unit='batch') as pbar:
    for i in range(len(test_data_loader)):
        batch=get_test_batch()
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask=batch['attention_mask'].to(device)
            generated_ids = model.generate(input_ids=input_ids, 
                                           max_new_tokens=128,
                                           do_sample=False)
            #preds_list.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
        #generated_ids = generate_question(batch)
        qs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        generated_questions.extend(qs)
        pbar.update(1)

for _,item in enumerate(generated_questions):
    generated_questions_dict.append({'_id':test_id_list[_],'question':item})

with open('output.json', 'w', encoding='utf-8') as json_file:
        json.dump(generated_questions_dict, json_file, ensure_ascii=False, indent=4)