In [1]:
# import nltk
# nltk.download('punkt')
# nltk.download('omw-1.4')
# !mkdir /root/nltk_data/corpora
# !cp ./wordnet.zip /root/nltk_data/corpora

In [2]:
from functools import partial
import os
from torch.utils.data import DataLoader
import  torch
from torch import nn
from transformers import get_scheduler
import losses
import data_load_utils
import random
from Configs import Config
from model import Imporved_BRIO
import SELECT_MODEL_AND_DATASET
import time
from TestDataSet import ValidDataSet
from compare_mt.rouge.rouge_scorer import RougeScorer
# ROUGE scorer shared by the validation routine.
rouge_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)

# Experiment selection.
model_size = 'large'
model_name = 'pegasus'
dataset_name = 'xsum'
is_model_base = model_size == 'base'
args = Config(data_name=dataset_name)

# Seed every RNG this notebook can touch. Seeding CUDA alone leaves the CPU
# generator (used by DataLoader shuffling) and Python's `random` unseeded,
# which makes runs non-reproducible.
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Output locations for loss logs and model checkpoints.
LOSS_DIR = f'./loss_res-{dataset_name}'
CHECK_POINTS_DIR = './check_points'

# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(LOSS_DIR, exist_ok=True)
os.makedirs(CHECK_POINTS_DIR, exist_ok=True)

def flush_res(text):
    """Overwrite ``./pred_text.txt`` with ``text``.

    The file is truncated on every call, so it always holds only the most
    recent message.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open('./pred_text.txt', 'w') as f:
        f.write(text)
def _append_to_loss_log(file_name, text, terminator='\n'):
    """Append ``text + terminator`` to ``file_name`` inside ``LOSS_DIR``."""
    # The context manager closes the file; no explicit close() is needed.
    with open(os.path.join(LOSS_DIR, file_name), 'a+') as f:
        f.write(text + terminator)


def flush_mle(text):
    """Append one line to the MLE-loss log."""
    _append_to_loss_log('mle_loss.txt', text)


def flush_contraste(text):
    """Append one line to the contrastive-loss log."""
    _append_to_loss_log('contraste_loss.txt', text)


def flush_loss(text):
    """Append one line to the total-loss log."""
    _append_to_loss_log('loss.txt', text)


def flush_acc_rank(text):
    """Append one line to the acc-rank-loss log."""
    _append_to_loss_log('acc_rank_loss.txt', text)


def flush_ce(text):
    """Append one line to the CE-loss log."""
    _append_to_loss_log('ce_loss.txt', text)


def flush_acc(text):
    """Append one line to the accuracy log."""
    _append_to_loss_log('acc.txt', text)


def flush_lp(text):
    """Append one line to the lp-loss log."""
    _append_to_loss_log('lp_loss.txt', text)


def flush_val_rouge(text):
    """Append a validation ROUGE value to its log, space-separated (no newline)."""
    _append_to_loss_log('val_rouge.txt', text, terminator=' ')

In [3]:
# Build the model, move it to the configured device, and reuse the tokenizer
# it exposes for every later encode/decode step.
model = Imporved_BRIO(
    model_name=model_name,
    dataset_name=dataset_name,
    args=args,
)
model = model.to(args.device)
tokenizer = model.tokenizer

In [4]:
# Training dataset read from the candidates directory named in the config.
train_set = data_load_utils.BrioDataset(
    fdir=args.candidates_train_dir,
    tokenizer=tokenizer,
    max_len=args.max_len,
    total_len=args.total_len,
    is_pegasus=args.is_pegasus,
)

In [5]:
# Display the hyper-parameters that drive this run (tuple is the cell output).
(
    args.is_pegasus,
    args.length_penalty,
    args.accumulate_step,
    args.total_len,
    args.val_size,
    args.mle_weight,
    args.contraste_weight,
    args.margin,
    args.gold_margin,
    args.batch_size,
)

(True, 0.6, 4, 512, 25)

In [6]:
# Training batches: collate with the tokenizer's pad id, shuffle, and drop the
# last incomplete batch.
collate_fn = partial(
    data_load_utils.collate_mp_brio,
    pad_token_id=tokenizer.pad_token_id,
    is_test=False,
)
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Validation batches: fixed order, one batch holds `args.val_size` examples.
validation_set = ValidDataSet('./xsum_validation_set/', l=args.val_size)
validation_loader = DataLoader(
    dataset=validation_set,
    batch_size=args.val_size,
    shuffle=False,
    drop_last=True,
)

In [7]:
# Adam with default hyper-parameters; the training loop overwrites the
# learning rate of every param group before each optimizer step.
optimizer = torch.optim.Adam(model.parameters())
def validation(model,args):
    """Return the mean ROUGE-1 F-measure on the first validation batch.

    Only the first batch produced by the notebook-level ``validation_loader``
    is evaluated. Summaries are generated with beam search (4 beams) and
    scored against the references with the notebook-level ``rouge_scorer``.

    Args:
        model: wrapper whose ``model.model.generate`` performs generation.
        args: config providing ``total_len``, ``gen_max_len``, ``gen_min_len``
            and ``device``.

    Returns:
        float: average ``rouge1`` F-measure over the evaluated examples.

    Raises:
        ValueError: if ``validation_loader`` yields no batch (for example when
            ``drop_last=True`` discards an undersized validation set).
    """
    first_batch = next(iter(validation_loader), None)
    if first_batch is None:
        raise ValueError(
            'validation_loader produced no batch; check the validation set '
            'size against its batch_size/drop_last settings'
        )

    # Text is lower-cased only when dataset_name is 'cnndam'.
    lowercase = dataset_name == 'cnndam'

    def _normalize(texts):
        return [t.lower().strip() if lowercase else t.strip() for t in texts]

    sources = _normalize(first_batch[0])
    references = _normalize(first_batch[1])

    # .to() moves every tensor of the encoding, so the fields below are
    # already on the target device.
    encoded = tokenizer.batch_encode_plus(
        sources,
        max_length=args.total_len,
        truncation=True,
        return_tensors='pt',
        padding='max_length',
    ).to(args.device)

    with torch.no_grad():
        generated = model.model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_length = args.gen_max_len + 2,
            min_length=args.gen_min_len + 1,
            no_repeat_ngram_size=3,
            early_stopping=True,
            num_beams=4
        )

    predictions = _normalize(
        tokenizer.batch_decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    )

    total = 0
    for reference, prediction in zip(references, predictions):
        rouge = rouge_scorer.score(target=reference, prediction=prediction)
        total += rouge['rouge1'].fmeasure
    return total / len(predictions)


(0.1, 10, 0.001, 0, 4)

In [None]:
# Label-smoothed loss used for the MLE term of the objective.
mle_fn = losses.label_smooth_loss(tokenizer.pad_token_id, epsilon=args.smooth)

# Number of batches between two writes of the averaged losses to disk.
flush_step = 50

# Running buffers of per-batch values, emptied after every flush.
# (The original tuple unpacking listed ce_list twice; each buffer is now
# created exactly once.)
mle_list = []
ce_list = []
contraste_list = []
loss_list = []
acc_rank_list = []
acc_list = []
lp_list = []

# Count of optimizer updates performed so far, across all epochs.
all_step_cnt = 0

def _mean(values):
    """Return the arithmetic mean of ``values`` (0.0 when empty)."""
    return sum(values) / len(values) if values else 0.0


def _flush_running_losses():
    """Append the mean of every running buffer to its log file, then empty the buffers."""
    for flush_fn, buffer in (
        (flush_loss, loss_list),
        (flush_mle, mle_list),
        (flush_contraste, contraste_list),
        (flush_acc_rank, acc_rank_list),
        (flush_lp, lp_list),
        (flush_acc, acc_list),
    ):
        # Divide by the real number of samples: the first window holds
        # flush_step + 1 batches (indices 0..flush_step), later ones flush_step.
        flush_fn(str(_mean(buffer)))
        buffer.clear()


def _checkpoint_rouge(file_name):
    """Parse the ROUGE value embedded in a checkpoint file name (7th '-' field)."""
    return float(file_name.split('-')[6])


def _save_checkpoint(model, epoch, step, val_rouge):
    """Save ``model``; once more than 12 checkpoints exist, replace the worst one.

    When the directory already holds more than 12 ``.bin`` files, the new
    checkpoint is written only if its ROUGE is at least that of the current
    worst file, which is deleted first. The worst file is picked by numeric
    comparison of the parsed ROUGE values.
    """
    model_save_name = f'based_on-{model_name}-{dataset_name}-{epoch}-{step}-rouge-{val_rouge}-.bin'
    save_path = os.path.join(CHECK_POINTS_DIR, model_save_name)
    existing = [f for f in os.listdir(CHECK_POINTS_DIR) if f.endswith('.bin')]
    if len(existing) > 12:
        worst = min(existing, key=_checkpoint_rouge)
        if val_rouge >= _checkpoint_rouge(worst):
            os.remove(os.path.join(CHECK_POINTS_DIR, worst))
            # NOTE(review): this pickles the whole module object; kept as-is
            # because loading code elsewhere may rely on it.
            torch.save(model, save_path)
    else:
        torch.save(model, save_path)


for epoch in range(1, args.epoch + 1):
    t1 = time.time()
    # Drop any gradients left over from an incomplete accumulation window.
    optimizer.zero_grad()
    step_cnt = 0
    for i, batch in enumerate(train_loader):

        step_cnt += 1
        model.train()
        x = batch["src_input_ids"].to(args.device)
        y = batch["candidate_ids"].to(args.device)
        output, candidate_id, cand_mask, all_prob = model(
            x, y,
            normalize=args.normalize,
            score_mode='log',
            length_penalty=args.length_penalty,
            require_gold=True,
            adding=0,
        )

        # Ranking (contrastive) loss over scaled candidate / gold scores.
        similarity = output['score'].to(args.device) * args.scale
        gold_similarity = output['summary_score'].to(args.device) * args.scale
        contraste_loss = losses.RankingLoss(similarity, gold_similarity, args.margin, args.gold_margin, args.gold_weight)

        # MLE loss on the first candidate: drop the last prediction step and
        # the first target token so predictions and targets line up.
        probs = output["probs"].to(args.device)[:, :-1]
        gold = batch["candidate_ids"][:, 0, 1:].to(args.device)
        mle_loss = mle_fn(probs.transpose(1, 2), gold)

        acc_rank_loss, acc, lp_loss = losses.acc_rankling_loss(
            all_prob, candidate_id, cand_mask,
            lenth_penalty=args.length_penalty,
            is_lpsum=args.is_lpsum,
        )
        loss = (
            args.contraste_weight * contraste_loss
            + args.mle_weight * mle_loss
            + args.lp_weight * lp_loss
            + args.acc_rank_weight * acc_rank_loss
        )

        mle_list.append(mle_loss.item())
        contraste_list.append(contraste_loss.item())
        loss_list.append(loss.item())
        acc_rank_list.append(acc_rank_loss.item())
        lp_list.append(lp_loss.item())
        acc_list.append(acc.item())

        if i % flush_step == 0 and i > 0:
            _flush_running_losses()

        # Scale so the accumulated gradient is the mean over the window.
        loss = loss / args.accumulate_step
        loss.backward()

        if step_cnt == args.accumulate_step:
            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            step_cnt = 0
            all_step_cnt += 1
            # adjust learning rate: max_lr * min(step^-0.5, step * warmup_steps^-1.5)
            lr = args.max_lr * min(all_step_cnt ** (-0.5), all_step_cnt * (args.warmup_steps ** (-1.5)))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            optimizer.step()
            optimizer.zero_grad()

        if i % 10 == 0:
            t2 = time.time()
            t = int(t2 - t1)
            flush_res(f'{epoch}:[{str(i)}/{len(train_loader)}]:{loss.item()}\n[10 items / {t}s]')
            t1 = time.time()

        # Release references to this step's tensors before running validation.
        del similarity, gold_similarity, loss, mle_loss, output, probs, acc_rank_loss, acc, contraste_loss, lp_loss

        if i % 20 == 0:
            model.eval()
            val_rouge = validation(model, args)
            if val_rouge >= args.stand_rouge:
                # The helper uses its own loop names, so the batch index `i`
                # is no longer overwritten by a checkpoint file name.
                _save_checkpoint(model, epoch, i, val_rouge)
            model.train()
            # Log every second validation result (every 40 batches).
            if i % 40 == 0:
                flush_val_rouge(str(val_rouge))


  next_indices = next_tokens // vocab_size
