In [1]:
import os
import re
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

from transformers import (
    AutoConfig, AutoTokenizer, 
    T5TokenizerFast, T5ForConditionalGeneration, 
    AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, 
    AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
)

from datasets import Dataset

import evaluate

from konlpy.tag import Komoran

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings('ignore')

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
NGPU = torch.cuda.device_count()
NCPU = os.cpu_count()
NGPU, NCPU

(6, 64)

# Paths and Names

In [4]:
### paths and names

DATA_PATH = 'data/model_dev/model_dev_v3.pickle'
MODEL_CHECKPOINT = '.log/paust_pko_t5_base_v3_run_5/checkpoint-11310'

# Model & Tokenizer

In [5]:
config = AutoConfig.from_pretrained(MODEL_CHECKPOINT)

In [6]:
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT, config=config).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Inputs and Labels

In [7]:
prefix = "generate keyphrases: "

max_input_length = 1024
max_target_length = 64

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["input_text"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding="max_length")

    labels = tokenizer(examples["target_text"], max_length=max_target_length, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
data_df = pd.read_pickle(DATA_PATH)

In [9]:
dataset = Dataset.from_pandas(data_df).shuffle(seed=100).train_test_split(0.2, seed=100)
train_dataset = dataset['train']
eval_dataset = dataset['test']

In [10]:
train_dataset = train_dataset.map(preprocess_function, 
                                  batched=True, 
                                  num_proc=NCPU, 
                                  remove_columns=train_dataset.column_names)

eval_dataset = eval_dataset.map(preprocess_function, 
                                batched=True, 
                                num_proc=NCPU, 
                                remove_columns=eval_dataset.column_names)
print(train_dataset)
print(eval_dataset)

Map (num_proc=64):   0%|          | 0/9346 [00:00<?, ? examples/s]

Map (num_proc=64):   0%|          | 0/2337 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 9346
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2337
})


In [11]:
inputs = eval_dataset[:100]
inputs = eval_dataset

input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

In [12]:
labels = tokenizer.batch_decode(inputs['labels'], skip_special_tokens=True)

In [14]:
batch_size = 64
predictions = None
with torch.no_grad():
    start = 0
    for idx in tqdm(range(batch_size, len(input_ids), batch_size)):
        ids, mask = input_ids[start:idx], attention_mask[start:idx]
        ids, mask = torch.tensor(ids).to(device), torch.tensor(mask).to(device)
        # print(start, idx)
        prediction = model.generate(input_ids=ids, attention_mask=mask, max_length=64)
        if predictions == None:
            predictions = prediction.detach().cpu().tolist()
        else:
            predictions.extend(prediction.detach().cpu().tolist())
        start = idx

100%|██████████| 36/36 [01:59<00:00,  3.31s/it]


In [16]:
len(predictions)

2304

In [None]:
predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

### ROUGE

In [23]:
komoran = Komoran()

In [24]:
rouge = evaluate.load('rouge')

In [25]:
def rouge_for_sampale(label, prediction):
    return rouge.compute(references=[label], predictions=[prediction], tokenizer=komoran.morphs)

In [26]:
def rouge_for_batch(labels, predictions):
    rouge_scores = None
    
    for label, prediction in zip(labels, predictions):
        if rouge_scores == None:
            rouge_scores = rouge_for_sampale(label, prediction)
        else:
            rouge_score = rouge_for_sampale(label, prediction)
            for key in rouge_scores.keys():
                rouge_scores[key] = rouge_scores[key] + rouge_score[key]
    
    for key in rouge_scores.keys():
        rouge_scores[key] = rouge_scores[key] / len(labels)
    
    return rouge_scores

In [27]:
rouge_for_batch(labels, predictions)

{'rouge1': 0.6481759823542651,
 'rouge2': 0.4481389571495031,
 'rougeL': 0.5329436625976698,
 'rougeLsum': 0.5329436625976698}

### F1

In [28]:
# ### V1

# def f1_score_at_k_for_sample(label_str, prediction_str, k):
#     true_positives = 0
#     false_positives = 0
#     false_negatives = 0
    
#     # convert label and prediction strings to sets of key-phrases
#     label_lst = [key_phrase.strip() for key_phrase in label_str.split(';') if key_phrase != '']
#     label_lst = [key_phrase for key_phrase in label_lst if key_phrase != '']
#     label_set = set(label_lst)
    
#     # split the predicted key-phrases and their scores
#     prediction_lst = [key_phrase.strip() for key_phrase in prediction_str.split(';') if key_phrase != '']
#     prediction_lst = [key_phrase for key_phrase in prediction_lst if key_phrase != ''][:k]
#     prediction_set = set(prediction_lst)
    
#     # calculate true positives, false positives, and false negatives
#     for keyphrase in prediction_set:
#         if keyphrase in label_set:
#             true_positives += 1
#         else:
#             false_positives += 1
    
#     for keyphrase in label_set:
#         if keyphrase not in prediction_set:
#             false_negatives += 1
    
#     # calculate precision, recall, and F1 score
#     precision = true_positives / (true_positives + false_positives)
#     recall = true_positives / (true_positives + false_negatives)
    
#     if precision == 0 or recall == 0:
#         return 0
    
#     f1_score = 2 * (precision * recall) / (precision + recall)
    
#     return f1_score

In [29]:
### V2

def f1_score_at_k_for_sample(label_str, prediction_str, k):
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    
    # convert label and prediction strings to sets of key-phrases
    label_lst = [key_phrase.strip() for key_phrase in label_str.split(';') if key_phrase != '']
    label_lst = [key_phrase for key_phrase in label_lst if key_phrase != '']
    
    # split the predicted key-phrases and their scores
    prediction_lst = [key_phrase.strip() for key_phrase in prediction_str.split(';') if key_phrase != '']
    prediction_lst = [key_phrase for key_phrase in prediction_lst if key_phrase != ''][:k]
    
    # calculate true positives, false positives, and false negatives
    for keyphrase in prediction_lst:
        similarity = False
        for label in label_lst:
            if keyphrase in label or label in keyphrase:
                similarity = True
                break
        if similarity == True:
            true_positives += 1
        else:
            false_positives += 1

    for label in label_lst:
        similarity = False
        for keyphrase in prediction_lst:
            if label in keyphrase or keyphrase in label:
                similarity = True
                break
        if similarity == False:
            false_negatives += 1            

    # calculate precision, recall, and F1 score
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    
    if precision == 0 or recall == 0:
        return 0
    
    f1_score = 2 * (precision * recall) / (precision + recall)
    
    return f1_score

In [30]:
# labels, predictions

In [31]:
def f1_score_at_k_for_batch(labels, predictions, k):
    f1_scores =[]

    for label, prediction in zip(labels, predictions):
        f1_scores.append(f1_score_at_k_for_sample(label, prediction, k))

    # print(f1_scores)
    return sum(f1_scores) / len(f1_scores)

In [32]:
f1_score_at_k_for_batch(labels, predictions, 10)

0.597830229966394

In [33]:
# f1_score_at_k_for_sample(labels[9], prediction[9], 10)

### Jaccard

In [34]:
def jaccard_similarity_for_sample(label, prediction, k):

    # convert label and prediction strings to sets of key-phrases
    label_lst = [key_phrase.strip() for key_phrase in label.split(';') if key_phrase != '']
    label_lst = [key_phrase for key_phrase in label_lst if key_phrase != '']
    # print(label_lst)
    
    # split the predicted key-phrases and their scores
    prediction_lst = [key_phrase.strip() for key_phrase in prediction.split(';') if key_phrase != '']
    prediction_lst = [key_phrase for key_phrase in prediction_lst if key_phrase != ''][:k]
    # print(prediction_lst)

    """Define Jaccard Similarity function for two sets"""
    intersection = len(list(set(label_lst).intersection(prediction_lst)))
    union = (len(label_lst) + len(prediction_lst)) - intersection

    # print(union)
    # print(intersection)

    return float(intersection) / union

In [35]:
def jaccard_similarity_for_batch(labels, predictions, k):
    jaccard_similarities =[]

    for label, prediction in zip(labels, predictions):
        jaccard_similarities.append(jaccard_similarity_for_sample(label, prediction, k))

    print(jaccard_similarities)
    return sum(jaccard_similarities) / len(jaccard_similarities)

In [36]:
jaccard_similarity_for_batch(labels, predictions, 10)

[0.3333333333333333, 0.1111111111111111, 0.0, 0.42857142857142855, 0.25, 0.17647058823529413, 0.1111111111111111, 0.17647058823529413, 0.1111111111111111, 0.17647058823529413, 0.3333333333333333, 0.25, 0.25, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.25, 0.1111111111111111, 0.3333333333333333, 0.16666666666666666, 0.17647058823529413, 0.3333333333333333, 0.42857142857142855, 0.1111111111111111, 0.05555555555555555, 0.17647058823529413, 0.25, 0.42857142857142855, 0.25, 0.17647058823529413, 0.1111111111111111, 0.25, 0.42857142857142855, 0.3333333333333333, 0.25, 0.6666666666666666, 0.125, 0.1111111111111111, 0.17647058823529413, 0.25, 0.25, 0.42857142857142855, 0.3333333333333333, 0.17647058823529413, 0.3333333333333333, 0.25, 0.42857142857142855, 0.25, 0.1111111111111111, 0.05263157894736842, 0.17647058823529413, 0.17647058823529413, 0.17647058823529413, 0.25, 0.25, 0.17647058823529413, 0.3333333333333333, 0.5384615384615384, 0.42857142857142855, 0.3333333333333333, 0

0.2610242723256392