In [None]:
import os
import re
import warnings

import pandas as pd
import numpy as np
import torch

from transformers import (
    AutoConfig, AutoTokenizer, 
    T5TokenizerFast, T5ForConditionalGeneration, 
    AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, 
    AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
)

from datasets import load_metric, Dataset

import wandb
import nltk

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')
# Export TOKENIZERS_PARALLELISM="false" into this process's environment.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Download the NLTK 'punkt' resource (network access; no-op if already cached).
nltk.download('punkt')

In [None]:
# Hardware visible to this process: number of CUDA devices and CPU count.
NGPU, NCPU = torch.cuda.device_count(), os.cpu_count()
(NGPU, NCPU)

In [None]:
# Display whether PyTorch reports CUDA as usable in this environment.
torch.cuda.is_available()

# Paths and Names

In [None]:
### paths and names

# W&B project name and the identifier of this particular training run.
PROJECT_NAME = 'news-topic-keyphrase-generation-model-dev'
RUN_ID = 'v3_run_5'

# Pickled DataFrame that is loaded in the "Inputs and Labels" section.
DATA_PATH = 'data/model_dev/model_dev_v3.pickle'

# Hub id of the pretrained checkpoint, plus a name-safe variant of it:
# '/' and '-' are replaced by '_' and the result is lower-cased.
MODEL_CHECKPOINT = 'paust/pko-t5-base'
model_name = re.sub(r'[/-]', r'_', MODEL_CHECKPOINT).lower()

METRIC_NAME = 'rouge'

NOTEBOOK_NAME = './train_seq2seq_plm.ipynb'

# All run artefacts are written below <ROOT_PATH>/.log/<run_name>.
ROOT_PATH = './'
SAVE_PATH = os.path.join(ROOT_PATH, '.log')

run_name = f'{model_name}_{RUN_ID}'
output_dir = os.path.join(SAVE_PATH, run_name)

print(run_name)
print(output_dir)

# Create the log directory in pure Python (equivalent to `mkdir -p`), so the
# cell does not depend on a POSIX shell and remains valid Python.
os.makedirs(SAVE_PATH, exist_ok=True)

In [None]:
# W&B settings exported through the environment before logging in.
os.environ.update({
    'WANDB_PROJECT': PROJECT_NAME,
    'WANDB_NOTEBOOK_NAME': NOTEBOOK_NAME,
    'WANDB_LOG_MODEL': 'true',
    'WANDB_WATCH': 'all',
})

wandb.login()

# Training Args

In [None]:
# Hyper-parameters for the run. Every name below is forwarded unchanged to
# Seq2SeqTrainingArguments in the "Train" section.
report_to="wandb"

num_train_epochs = 30
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1

optim = 'adamw_torch' # 'adamw_torch' or 'adamw_hf'

learning_rate = 3e-6 # 3e-6 * (per_device_train_batch_size * NGPU) / 8
weight_decay = 0.01
adam_epsilon = 1e-8

lr_scheduler_type = 'linear' # 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup'
warmup_ratio = 0

save_total_limit = 2

load_best_model_at_end = True
metric_for_best_model = 'eval_loss'

# Save and evaluate on the same schedule (once per epoch).
save_strategy = "epoch"
evaluation_strategy = "epoch"

logging_strategy = "steps"
logging_first_step = True 
# Log roughly every 500 steps divided across GPUs. max(NGPU, 1) avoids a
# ZeroDivisionError on CPU-only machines, and the outer max keeps the
# interval at least 1 step.
logging_steps = max(1, 500 // max(NGPU, 1))

predict_with_generate=False
generation_max_length=64
# generation_num_beams=5

fp16 = False

# Model & Tokenizer & Metric

In [None]:
# Fetch the configuration associated with MODEL_CHECKPOINT; it is handed to
# the model loader when the model is instantiated.
config = AutoConfig.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# Tokenizer and seq2seq model both come from the same pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT, config=config)

# Metric object selected by METRIC_NAME.
# NOTE(review): `load_metric` is reported as deprecated in recent `datasets`
# releases — confirm against the installed version.
metric = load_metric(METRIC_NAME)

# Functions

In [None]:
# Task prefix prepended to every source document.
prefix = "generate keyphrases: "

# Truncation limits (in tokens) for the source and target sides.
max_input_length = 1024
max_target_length = 64

def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq training.

    Args:
        examples: mapping with an "input_text" and a "target_text" entry,
            each an iterable of strings.

    Returns:
        The tokenizer output for the prefixed inputs (truncated to
        ``max_input_length``), with an added "labels" entry holding the
        ``input_ids`` of the targets (truncated to ``max_target_length``).
    """
    prompts = []
    for document in examples["input_text"]:
        prompts.append(prefix + document)

    encoded = tokenizer(prompts, max_length=max_input_length, truncation=True)
    encoded_targets = tokenizer(
        examples["target_text"],
        max_length=max_target_length,
        truncation=True,
    )

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
#     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#     # Replace -100 in the labels as we can't decode them.
#     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
#     # Rouge expects a newline after each sentence
#     decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
#     decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
#     result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
#     # Extract a few results
#     result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
#     # Add mean generated length
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
#     result["gen_len"] = np.mean(prediction_lens)
    
#     return {k: round(v, 4) for k, v in result.items()}

# Inputs and Labels

In [None]:
# Load the modelling DataFrame from the pickle file at DATA_PATH.
# NOTE(review): unpickling can execute arbitrary code — only load this file
# when it comes from a trusted source.
data_df = pd.read_pickle(DATA_PATH)

In [None]:
# Convert the frame to a Dataset, shuffle it, and hold out 20% as the
# evaluation split; both steps use a fixed seed of 100.
dataset = (
    Dataset.from_pandas(data_df)
    .shuffle(seed=100)
    .train_test_split(test_size=0.2, seed=100)
)
train_dataset, eval_dataset = dataset['train'], dataset['test']

In [None]:
def _tokenize_split(split):
    """Apply preprocess_function to `split` in batches, dropping its original columns."""
    return split.map(
        preprocess_function,
        batched=True,
        num_proc=NCPU,
        remove_columns=split.column_names,
    )


# Both splits go through the identical tokenization step.
train_dataset = _tokenize_split(train_dataset)
eval_dataset = _tokenize_split(eval_dataset)

print(train_dataset)
print(eval_dataset)

In [None]:
# Sanity check: decode the first tokenized training input back to text.
# Index the row first so only one example is read, rather than pulling the
# whole 'input_ids' column just to look at its first element.
tokenizer.decode(train_dataset[0]['input_ids'])

In [None]:
# Sanity check: decode the first tokenized training target back to text.
# Index the row first so only one example is read, rather than pulling the
# whole 'labels' column just to look at its first element.
tokenizer.decode(train_dataset[0]['labels'])

# Train

In [None]:
# Gather the values chosen in the "Training Args" section, grouped by purpose,
# then build the arguments object from them in a single call.
training_kwargs = dict(
    # run identity / reporting
    output_dir=output_dir,
    run_name=run_name,
    report_to=report_to,
    # schedule and batch sizes
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # optimizer
    optim=optim,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    adam_epsilon=adam_epsilon,
    # learning-rate schedule
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    # checkpointing and model selection
    save_total_limit=save_total_limit,
    load_best_model_at_end=load_best_model_at_end,
    metric_for_best_model=metric_for_best_model,
    save_strategy=save_strategy,
    evaluation_strategy=evaluation_strategy,
    # logging
    logging_strategy=logging_strategy,
    logging_first_step=logging_first_step,
    logging_steps=logging_steps,
    # generation during evaluation
    predict_with_generate=predict_with_generate,
    generation_max_length=generation_max_length,
    # generation_num_beams=generation_num_beams,
    # precision
    fp16=fp16,
)

training_args = Seq2SeqTrainingArguments(**training_kwargs)

In [None]:
# Collator that batches the tokenized examples; it is given the model as well
# as the tokenizer.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
)

# Trainer wired to the model, the arguments and the two tokenized splits.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # compute_metrics=compute_metrics,
)

In [None]:
# Run training and always close the W&B run afterwards. If training raises or
# is interrupted, the run is finished with a non-zero exit code and the
# original exception is re-raised, instead of leaving the run open.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [1]:
# Files to retain inside each checkpoint directory; every other regular file
# in a checkpoint is deleted to save disk space.
keep = [
    'added_tokens.json',
    'config.json',
    'pytorch_model.bin',
    # NOTE(review): weights may be saved under this name instead of
    # 'pytorch_model.bin' depending on the installed transformers version;
    # kept so pruning can never delete the model weights.
    'model.safetensors',
    'special_tokens_map.json',
    'tokenizer.json',
    'tokenizer_config.json',
    'vocab.txt'
]

for ckpt_name in os.listdir(output_dir):
    ckpt_dir = os.path.join(output_dir, ckpt_name)
    # Skip loose files in output_dir; listing them would raise.
    if not os.path.isdir(ckpt_dir):
        continue
    for item in os.listdir(ckpt_dir):
        item_path = os.path.join(ckpt_dir, item)
        # Only remove regular files; os.remove fails on sub-directories.
        if item not in keep and os.path.isfile(item_path):
            os.remove(item_path)

# Generate

In [None]:
# preds = trainer.predict(eval_dataset)

In [None]:
# preds.metrics

In [None]:
# for data, pred in zip(eval_dataset, preds.predictions):
#     pred = np.where(pred != -100, pred, tokenizer.pad_token_id)
#     # context = tokenizer.decode(data['input_ids'], skip_special_tokens=True)
#     summary = tokenizer.decode(data['labels'], skip_special_tokens=True)
#     pred = tokenizer.decode(pred, skip_special_tokens=True)
#     # print(f'입력: {context}')
#     print(f'정답: {summary}')
#     print(f'예측: {pred}', end='\n\n')