In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [50]:
import torch
import numpy as np

from datasets import Dataset
from datasets import load_dataset
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EvalPrediction

In [3]:
# CONFIG

SEED = 42
MODEL_CHECKPOINT = 'results/checkpoint-202500'
TOKENIZER_CHECKPOINT = 'bert-large-cased'
NUM_OF_OPTIONS = 2605
PROBLEM_TYPE = 'multi_label_classification'

DATA_FILES = {
    # 'train': '../preprocess/option/train.json',
    # 'valid': '../preprocess/option/valid.json',
    'test': '../preprocess/option/test.json'
}

In [4]:
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=NUM_OF_OPTIONS, problem_type=PROBLEM_TYPE)

In [5]:
# Load the dataset and apply the tokenizer
dataset = load_dataset("json", data_files=DATA_FILES)

Found cached dataset json (/home/nlplab11/.cache/huggingface/datasets/json/default-4f4ce6c5e70910e9/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)
100%|██████████| 1/1 [00:00<00:00, 557.68it/s]


In [6]:
# Define a function to tokenize the data
def tokenize(batch):
    return tokenizer(batch["prompt"], padding=True, truncation=True, max_length=256)

In [7]:
import json

with open('option_list.json', 'r', encoding='utf-8') as f:
    option_list = json.load(f)

In [8]:
mlb = MultiLabelBinarizer()
mlb.fit_transform([option_list])

array([[1, 1, 1, ..., 1, 1, 1]])

In [9]:
mlb.classes_

array(['#wow', '2d', '2d animation', ..., 'zine', 'zoom', 'zoom lens'],
      dtype=object)

In [10]:
dataset = dataset.map(tokenize, batched=True)

Loading cached processed dataset at /home/nlplab11/.cache/huggingface/datasets/json/default-4f4ce6c5e70910e9/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-97cd2f9fbdd8beab.arrow


In [11]:
def list_to_numpy(batch):
    batch['labels'] = np.array(mlb.transform(batch['option'])[0], dtype=np.float32)
    return batch

In [12]:
dataset = dataset.map(list_to_numpy, batched=False)

Loading cached processed dataset at /home/nlplab11/.cache/huggingface/datasets/json/default-4f4ce6c5e70910e9/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-3153ac2839d7a966.arrow


In [13]:
dataset = dataset.remove_columns(["prompt", "option"])
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

In [14]:
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))

    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1

    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)

    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    
    return metrics

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    result = multi_label_metrics(
        predictions=logits, 
        labels=labels
        )
    
    return result

In [15]:
# Set up the training arguments and trainer
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    fp16=True
)

In [16]:
trainer = Trainer(
    model=model,
    args=training_args,
    # train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    compute_metrics=compute_metrics,   
)

In [17]:
trainer.evaluate()

{'eval_loss': 0.001053415471687913,
 'eval_f1': 0.9738370384727603,
 'eval_roc_auc': 0.9771094167790283,
 'eval_accuracy': 0.7872,
 'eval_runtime': 634.1005,
 'eval_samples_per_second': 157.704,
 'eval_steps_per_second': 19.713}

In [51]:
def infer_keywords_from_prompt(s):
    example = Dataset.from_dict({'prompt': [s]})
    example = example.map(tokenize)
    pred = trainer.predict(example)
    
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(pred.predictions))
    
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= 0.5)] = 1

    print(mlb.inverse_transform(y_pred))    

In [52]:
infer_keywords_from_prompt('Butterfly Garden, filled with flowers and abundance of pink. An illustration radiating warmth and happiness.')

[('butterfly', 'fill', 'illustration', 'ink')]
