In [10]:
import argparse
import os
import sys
import logging
import pickle
from functools import partial
import time
from tqdm import tqdm
from collections import Counter
import random
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
import json, wandb

from typing import Dict, Optional
from datasets import Dataset, load_dataset, DatasetDict

import wandb
from peft import LoraConfig

from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor


from torch.optim import AdamW
from transformers import T5ForConditionalGeneration, T5Tokenizer, TrainingArguments
from transformers import get_linear_schedule_with_warmup, AutoModelForSeq2SeqLM, AutoTokenizer, EarlyStoppingCallback

from trl import DPOTrainer
from tqdm import tqdm

from atoss.data_utils import *
from atoss.eval_utils import *
from atoss.process import *


In [11]:
class Args:
    def __init__(self):
        # dataset parameters
        self.path = '/home/elicer/ATOSS'
        self.method = 'dpo' # task
        self.load_model = 'sft/rest15_top_12' # 읽어올 모델 폴더
        self.save_model = 'dpo/rest15_top_12' # 저장할 모델 폴더
        self.model = 'rest15_top_12' # result.txt애 찍힐 모델 이름
        self.task = 'acos' # MVP에서 평가할 데이터 셑
        self.dataset = 'rest16' # MVP에서 평가할 데이터 셑
        self.train = 'rest15' # reward 이름(데이터 셑이랑 같으면 성능 높음)
        self.dev = 'rest16' # dpo에서 eval을 하지 않음(현재는 안씀)
        self.eval_data_split = 'test' # test or dev
        self.data_path = f'{self.path}/data/{self.method}'
        self.ctrl_token = "post"
        self.data_ratio = 1.0
        self.model_name_or_path = 't5-base' # used base model
        self.load_ckpt_name = None # 사전 훈련된 모델의 체크포인트 파일로드 
        self.beam_size = 1
        self.constrained_decode = False
        self.lowercase = True
        self.load_ckpt_name = None # 사전 훈련된 모델의 체크포인트 파일로드 
        self.load_path_cache = False
        self.max_seq_length = 512 # 입력 시퀀스 최대 길이
        self.eval_batch_size = 16
        self.do_train = True # train or not
        self.do_inference = True # inference or not

        # training parameters
        self.beta = 0.1
        self.learning_rate = 1e-4
        self.gradient_accumulation_steps = 1
        self.max_length = 512
        self.max_prompt_length = 512
        self.max_target_length = 512
        self.label_pad_token_id = -100
        self.num_train_epochs = 1
        self.batch_size = 16
        self.max_steps = -1
        
        # lora parameters
        self.use_peft = False
        self.peft_lora_r = 64
        self.peft_lora_alpha = 16
        
        # instrumentation
        self.sanity_check = False
        self.report_to = 'wandb'  # 결과와 로그를 보고할 통합 목록
        
        # debug argument for distributed training
        self.ignore_bias_buffers = False
        self.gradient_checkpointing = False
        self.gradient_checkpointing_kwargs = None
        
        # wandb parameters
        self.project_name = "huggingface"

def init_args():
    args = Args()

    args.output_dir =  f'{args.path}/outputs'

    # set up output dir which looks like './outputs/rest15/'
    if not os.path.exists(f'{args.path}/outputs'):
        os.mkdir(f'{args.path}/outputs')

    if not os.path.exists(args.output_dir):
        #os.mkdir(args.output_dir)
        os.makedirs(args.output_dir, exist_ok=True)

    return args

# Args 인스턴스 생성
args = init_args()

print('method:', args.method)
print('data path:', args.data_path)
print('output path:', args.output_dir)
print('load model path:', os.path.join(args.output_dir, args.load_model))
print('save model path:', os.path.join(args.output_dir, args.save_model))
print(args.model)


method: dpo
data path: /home/elicer/ATOSS/data/dpo
output path: /home/elicer/ATOSS/outputs
load model path: /home/elicer/ATOSS/outputs/sft/rest15_top_12
save model path: /home/elicer/ATOSS/outputs/dpo/rest15_top_12
rest15_top_12


In [12]:
class T5FineTuner(pl.LightningModule):
    """
    Fine tune a pre-trained T5 model
    """

    def __init__(self, config, tfm_model, tokenizer):
        super().__init__()
        self.save_hyperparameters(ignore=['tfm_model'])
        self.config = config
        self.model = tfm_model
        self.tokenizer = tokenizer

    def forward(self,
                input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )

    def _step(self, batch):
        lm_labels = batch["target_ids"]
        lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(input_ids=batch["source_ids"],
                       attention_mask=batch["source_mask"],
                       labels=lm_labels,
                       decoder_attention_mask=batch['target_mask'])

        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        # get f1
        outs = self.model.generate(input_ids=batch['source_ids'],
                                   attention_mask=batch['source_mask'],
                                   max_length=self.config.max_seq_length,
                                   return_dict_in_generate=True,
                                   output_scores=True,
                                   num_beams=1)

        dec = [
            self.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in outs.sequences
        ]
        target = [
            self.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in batch["target_ids"]
        ]
        scores, _, _ = compute_scores(dec, target, verbose=False)
        f1 = torch.tensor(scores['f1'], dtype=torch.float64)

        # get loss
        loss = self._step(batch)

        if stage:
            self.log(f"{stage}_loss",
                     loss,
                     prog_bar=True,
                     on_step=False,
                     on_epoch=True)
            self.log(f"{stage}_f1",
                     f1,
                     prog_bar=True,
                     on_step=False,
                     on_epoch=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """ Prepare optimizer and schedule (linear warmup and decay) """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.config.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.config.learning_rate,
                          eps=self.config.adam_epsilon)
        scheduler = {
            "scheduler":
            get_linear_schedule_with_warmup(optimizer,
                                            **self.config.lr_scheduler_init),
            "interval":
            "step",
        }
        return [optimizer], [scheduler]

    def train_dataloader(self):
        print("load training data.")
        train_dataset = ABSADataset(args=args,
                                    tokenizer=tokenizer,
                                    task_name=args.task,
                                    data_type=args.train,
                                    max_len=args.max_seq_length)
        dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.train_batch_size,
            drop_last=True
            if args.data_ratio > 0.3 else False, # don't drop on few-shot
            shuffle=True,
            num_workers=2)

        return dataloader

    def val_dataloader(self):
        val_dataset = ABSADataset(args=args,
                                    tokenizer=tokenizer,
                                    task_name=args.task,
                                    data_type=args.dev,
                                    max_len=args.max_seq_length)
        return DataLoader(val_dataset,
                          batch_size=self.config.eval_batch_size,
                          num_workers=2)

    @staticmethod
    def rindex(_list, _value):
        return len(_list) - _list[::-1].index(_value) - 1


In [15]:
# do train
if args.do_train:

    file_path1 = f'{args.data_path}/{args.train}.txt'  # Replace with your file path
    file_path2 = f'{args.data_path}/{args.dev}.txt'  # Replace with your file path
    dataset_dict = create_dataset(file_path1, file_path2)
    train_dataset = dataset_dict['train'] #get_hh("train", sanity_check=args.sanity_check)
    eval_dataset = dataset_dict['dev'] #get_hh("train", sanity_check=args.sanity_check)
    
    print('train data load : total num = ', train_dataset)
    print('dev: data load : total num = ', eval_dataset)
    
    
    model = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(args.output_dir, args.load_model))
    
    if args.ignore_bias_buffers:
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]
    model_ref = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(args.output_dir, args.load_model))
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.output_dir, args.load_model), lecacy=False, use_fast=False)
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    
    training_args = TrainingArguments(
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        load_best_model_at_end = True,
        max_steps=args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_first_step=True,
        logging_steps=1000,  # match results in blog post
        eval_steps=500,
        num_train_epochs=args.num_train_epochs,
        output_dir=os.path.join(args.output_dir, args.save_model),
        optim="adamw_hf",
        adam_epsilon=1e-8,
        warmup_steps=150,
        report_to=args.report_to,
        bf16=True,
        gradient_checkpointing=args.gradient_checkpointing,
    )
    if args.use_peft:
        peft_config = LoraConfig(
            r=args.peft_lora_r,
            lora_alpha=args.peft_lora_alpha,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None

    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        max_target_length=args.max_target_length,
        max_prompt_length=args.max_prompt_length,
        generate_during_eval=True,
        peft_config=peft_config,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    
    dpo_trainer.train()

    dpo_trainer.save_model(os.path.join(args.output_dir, args.save_model))
    tokenizer.save_pretrained(os.path.join(args.output_dir, args.save_model))
    print("Finish training and saving the model!")

train data load : total num =  Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 834
})
dev: data load : total num =  Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 606
})


Map:   0%|          | 0/834 [00:00<?, ? examples/s]

Map:   0%|          | 0/606 [00:00<?, ? examples/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen
1,0.7238,0.136239,0.094817,-3.549674,0.945254,3.644491,-49.901299,-10.408525,-34.307175,-35.319221


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Finish training and saving the model!


In [16]:
# do inference
if args.do_inference:
        
    tfm_model = T5ForConditionalGeneration.from_pretrained(os.path.join(args.output_dir, args.save_model))
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.output_dir, args.save_model), lecacy=False, use_fast=False)
    model = T5FineTuner(args, tfm_model, tokenizer)
    # inference
    print("\n****** Conduct inference on trained checkpoint ******")
    
    args.output_dir = f'{args.output_dir}/{args.method}'
    
    if args.load_ckpt_name:
        ckpt_path = os.path.join(args.output_dir, args.load_ckpt_name)
        print("Loading ckpt:", ckpt_path)
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint["state_dict"])
    
    log_file_path = os.path.join(args.output_dir, "result.txt")
    
    # compute the performance scores
    with open(log_file_path, "a+") as f:
        config_str = f" model: {args.model}, beam: {args.beam_size}, constrained: {args.constrained_decode}\n"
        print(config_str)
        f.write(config_str)
        scores = evaluate(args,
                          model,
                          args.task,
                          data_type=args.eval_data_split)
    
        exp_results = "model: {} data: {} precision: {:.2f} recall: {:.2f} F1 = {:.2f}".format(
            args.model, args.eval_data_split, scores['precision'], scores['recall'], scores['f1'])
        print(exp_results)
        f.write(exp_results + "\n")
        f.flush()


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.



****** Conduct inference on trained checkpoint ******
 model: rest15_top_12, beam: 1, constrained: False

Total examples = 583
Total examples = 583


100%|██████████| 37/37 [01:12<00:00,  1.97s/it]

pred labels count Counter({1: 583})
gold  yum!
pred  yum!

gold  serves really good sushi.
pred  serves really good sushi.

gold  not the biggest portions but adequate.
pred  not the biggest portions but adequate.

gold  green tea creme brulee is a must!
pred  green tea creme brulee is a must!

gold  it has great sushi and even better service.
pred  it has great sushi and even better service.

gold  the entire staff was extremely accomodating and tended to my every need.
pred  the entire staff was extremely accomodating and tended to my every need.

gold  i've been to this restaurant over a dozen times with no complaints to date.
pred  i've been to this restaurant over a dozen times with no complaints to date.

gold  the owner is belligerent to guests that have a complaint.
pred  the owner is belligerent to guests that have a complaint.

gold  good food!
pred  good food!

gold  this is a great place to get a delicious meal.
pred  this is a great place to get a delicious meal.

number o




In [17]:
file_path = os.path.join(
        args.output_dir, "rst_{}{}_{}{}_{}_{}{}beam{}.pickle".format(
            args.method,
            args.model,
            args.task,
            args.dataset,
            args.eval_data_split,
            "best_" if args.load_ckpt_name else "",
            "cd_" if args.constrained_decode else "",
            args.beam_size))
with open(file_path, 'rb') as f:
    loaded_object = pd.read_pickle(f)
loaded_object[0][0]

'yum!'

In [18]:
file = f'/home/elicer/ABSA/data/{args.task}/{args.dataset}/test.txt'
_, targets = split_sharp(file)
sft = merge_sharp_n(loaded_object[0],targets)
file_name = f'/home/elicer/ABSA/data/{args.task}/{args.dataset}/{args.method}_{args.model}_test.txt'
print(file_name)
with open(file_name, 'w', encoding='UTF-8') as file:
    for line in sft:
        file.write(line + '\n')

Data read. Total count:  583
Input count: 583
Expanded target count: 583
Merged data count: 583
Data return. Total count: 583
/home/elicer/ABSA/data/acos/rest16/dpo_rest15_top_12_test.txt


In [55]:
args.model

'zero1440'