<a href="https://colab.research.google.com/github/Denev6/practice/blob/main/transformer/emoberta_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Record the interpreter version this notebook was run with.
!python --version

Python 3.7.15


In [None]:
# NOTE(review): unpinned install. This notebook imports `transformers.AdamW`
# and passes `correct_bias`, which newer transformers releases deprecate;
# pin a version known to work with this code.
!pip install transformers

import nltk

# NOTE(review): the only nltk tokenizer used below is TweetTokenizer, which is
# regex-based and presumably does not need "punkt" — verify before removing.
nltk.download("punkt")

In [None]:
import os
import re
import gc
import warnings

import numpy as np
import pandas as pd
from tqdm.auto import tqdm, trange
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, RobertaTokenizerFast, RobertaForSequenceClassification
from nltk.tokenize import TweetTokenizer
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from google.colab import drive

# Mount Google Drive; data, checkpoints and the submission are read from and
# written under /content/drive/MyDrive/DACON.
drive.mount("/content/drive")
# NOTE(review): this silences *every* warning, including library deprecation
# notices — consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

In [None]:
# Project folder on the mounted Google Drive; every path below lives under it.
DRIVE_ROOT = "/content/drive/MyDrive/DACON"


def join_path(*args):
    """Return a path under ``DRIVE_ROOT`` built from the given components."""
    return os.path.join(DRIVE_ROOT, *args)


# Seed the RNGs used later (classifier-head init, dropout, DataLoader
# shuffling) so Restart & Run All reproduces the same training run.
# torch.manual_seed also seeds CUDA devices.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TRAIN_CSV = join_path("data", "train.csv")
TEST_CSV = join_path("data", "test.csv")

# Training hyper-parameters. Effective batch size = batch_size * grad_step.
ARGS = {
    "model": "tae898/emoberta-large",
    "model_path": join_path("emoberta", "model.pth"),
    "batch_size": 8,
    "grad_step": 8,
    "epochs": 10,
    "max_len": 256,
    "lr": 1e-6,
    "patience": 2,
}

In [None]:
class EarlyStoppingCallback(object):
    """Checkpoint the model on validation-loss improvement and signal when to stop.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``should_stop`` returns True.
        min_delta: minimum decrease of the loss that counts as an improvement.
            The default 0.0 means "any strictly lower loss".
    """

    def __init__(self, patience=2, min_delta=0.0):
        self._min_eval_loss = np.inf
        self._patience = patience
        self._min_delta = min_delta
        self._counter = 0  # consecutive calls without improvement

    def should_stop(self, eval_loss, model, save_path):
        """Record ``eval_loss`` and decide whether training should stop.

        If the loss improved, the model's ``state_dict`` is saved to
        ``save_path`` and the counter is reset. Otherwise the counter is
        incremented. A loss *equal* to the best so far (or NaN) also counts as
        no improvement, so a plateau eventually stops training.

        Returns:
            True once ``patience`` consecutive calls did not improve.
        """
        if eval_loss < self._min_eval_loss - self._min_delta:
            self._min_eval_loss = eval_loss
            self._counter = 0
            torch.save(model.state_dict(), save_path)
            return False

        self._counter += 1
        return self._counter >= self._patience

In [None]:
class LabelEncoder(object):
    """Translate between the seven emotion names and integer class ids.

    A label's id is its position in ``_targets``.
    NOTE(review): if the pretrained checkpoint's classification head is
    reused, this order presumably has to match its ``id2label`` — verify.
    """

    def __init__(self):
        # Class id == index in this list.
        self._targets = "neutral joy surprise anger sadness disgust fear".split()
        self.num_classes = len(self._targets)

    def encode(self, label):
        """Return the class id for emotion name ``label`` (ValueError if unknown)."""
        position = self._targets.index(label)
        return position

    def decode(self, label):
        """Return the emotion name for integer class id ``label``."""
        name = self._targets[label]
        return name

In [None]:
class EmoDataset(Dataset):
    """Utterances normalised with a tweet tokenizer and encoded for RoBERTa.

    Args:
        data: DataFrame with an ``Utterance`` column (plus ``Target`` when
            ``mode == "train"``). It is copied; the caller's frame is not
            modified. Any index is accepted (rows are accessed by position).
        twt_tokenizer: object with ``tokenize(str) -> list[str]``, used only
            for text normalisation.
        roberta_tokenizer: HF-style callable tokenizer applied per item.
        label_encoder: object with ``encode(label) -> int``; used in train mode.
        max_length: pad/truncate length passed to ``roberta_tokenizer``.
        mode: ``"train"`` -> items are ``(input_ids, attention_mask, target)``;
            anything else -> ``(input_ids, attention_mask)``.
    """

    def __init__(
        self,
        data,
        twt_tokenizer,
        roberta_tokenizer,
        label_encoder,
        max_length=256,
        mode=None,
    ):
        self._label_encoder = label_encoder
        self._twt_tokenizer = twt_tokenizer
        self._roberta_tokenizer = roberta_tokenizer
        self._max_length = max_length
        self._mode = mode
        self._dataset = self._init_dataset(data)

    def _init_dataset(self, data):
        """Return a new frame with preprocessed text (and encoded targets in train mode)."""
        # Work on a copy so the caller's DataFrame is left untouched.
        data = data.copy()
        data["Utterance"] = data["Utterance"].map(self._preprocess)

        if self._mode == "train":
            data["Target"] = data["Target"].map(self._label_encoder.encode)
            return data.loc[:, ["Utterance", "Target"]]
        return data.loc[:, "Utterance"].to_frame()

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        # Positional lookup, so the result does not depend on the frame's index labels.
        text = self._dataset["Utterance"].iloc[idx]
        inputs = self._roberta_tokenizer(
            text,
            max_length=self._max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # The tokenizer returns a batch of one; drop the batch dimension.
        input_ids = inputs["input_ids"][0]
        attention_mask = inputs["attention_mask"][0]

        if self._mode == "train":
            y = self._dataset["Target"].iloc[idx]
            return input_ids, attention_mask, y
        return input_ids, attention_mask

    def _preprocess(self, sentence):
        """Tokenize, collapse repeated hyphen parts, and re-join into one string."""
        twt_tokens = self._twt_tokenizer.tokenize(sentence)
        twt_tokens = self._shorten_repeated_words(twt_tokens)
        return self._decode_tokens(twt_tokens)

    def _shorten_repeated_words(self, tokens):
        """Drop repeated hyphen-separated parts of a token ("i-i-i" -> "i").

        Only the first occurrence of each part is kept. Tokens that become
        empty (a bare "-" or "--") are dropped. Kept as "", they would leave a
        double space that ``_decode_tokens`` deletes, gluing the neighbouring
        words together ("hello - world" -> "helloworld").
        """
        shortened = []
        for token in tokens:
            if "-" in token:
                token = "-".join(dict.fromkeys(token.split("-")))
            if token:
                shortened.append(token)
        return shortened

    def _decode_tokens(self, tokens):
        """Join tokens with spaces, then attach punctuation to the preceding word.

        For each space + non-word-character match, the leading whitespace is
        removed. For apostrophes (' and ’), the trailing whitespace is removed too.
        """
        sentence = " ".join(tokens)
        marks = re.findall(r"\s\W\s*", sentence)
        for mark in marks:
            if mark.strip() in ["'", "’"]:
                sentence = sentence.replace(mark, mark.strip())
            else:
                sentence = sentence.replace(mark, mark.lstrip())
        return sentence

# Dataset

In [None]:
# Hold out a fixed 20% of the labelled data for validation and early stopping.
# NOTE(review): consider stratify=train_csv["Target"] so the rare classes are
# represented proportionally in the validation split.
train_csv = pd.read_csv(TRAIN_CSV)
df_train, df_val = train_test_split(
    train_csv,
    test_size=0.2,
    random_state=42,
    shuffle=True,
)
df_train.head()

Unnamed: 0,ID,Utterance,Speaker,Dialogue_ID,Target
8349,TRAIN_8349,"Well yeah, sure, what’s up?",Rachel,879,neutral
9518,TRAIN_9518,"Y'know, I-I don't even feel like I know you an...",Joey,994,anger
9042,TRAIN_9042,Do it!,Phoebe,946,joy
9858,TRAIN_9858,Come on!,Rachel,1025,anger
2012,TRAIN_2012,Action!,The Director,209,neutral


In [None]:
label_encoder = LabelEncoder()
roberta_tokenizer = RobertaTokenizerFast.from_pretrained(ARGS["model"], truncation=True)
# Normalisation tokenizer: lower-cases, strips @handles, shortens long character runs.
twt_tokenizer = TweetTokenizer(
    preserve_case=False, strip_handles=True, reduce_len=True
)


def _make_labelled_dataset(df):
    """Wrap a labelled split in an EmoDataset that uses the shared tokenizers."""
    return EmoDataset(
        df.reset_index(drop=True),
        twt_tokenizer,
        roberta_tokenizer,
        label_encoder,
        max_length=ARGS["max_len"],
        mode="train",
    )


train_set = _make_labelled_dataset(df_train)
val_set = _make_labelled_dataset(df_val)

# Reshuffle the training batches every epoch; keep the validation order fixed.
train_dataloader = DataLoader(train_set, batch_size=ARGS["batch_size"], shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=ARGS["batch_size"], shuffle=False)

# Model

In [None]:
def evaluate(model, criterion, val_loader, device, mode=None):
    """Run ``model`` over ``val_loader`` in eval mode without gradients.

    Args:
        model: classifier whose forward ``(input_ids, attention_mask)``
            returns an object with ``.logits``.
        criterion: loss taking ``(logits, labels)``.
        val_loader: iterable of ``(input_ids, attention_mask, label)`` batches.
        device: device the batches are moved to.
        mode: ``"train"`` -> return the list of per-batch loss values (used by
            ``train`` for early stopping). Otherwise return
            ``(accuracy, macro_f1)`` over all samples.
    """
    model.eval()
    collect_metrics = mode != "train"

    batch_losses = []
    predictions = []
    targets = []

    with torch.no_grad():
        for ids_batch, mask_batch, label_batch in val_loader:
            labels = label_batch.to(device)
            logits = model(ids_batch.to(device), mask_batch.to(device)).logits
            batch_losses.append(criterion(logits, labels.long()).item())

            if collect_metrics:
                predictions.extend(logits.argmax(1).cpu().tolist())
                targets.extend(labels.cpu().tolist())

    if not collect_metrics:
        return batch_losses

    accuracy = accuracy_score(targets, predictions)
    macro_f1 = f1_score(targets, predictions, average="macro")
    return accuracy, macro_f1


def train(model, optimizer, criterion, train_loader, val_loader, device):
    """Fine-tune ``model`` with gradient accumulation and early stopping.

    Hyper-parameters (``epochs``, ``grad_step``, ``patience``, ``model_path``)
    are read from the global ``ARGS``. After each epoch the mean per-batch
    validation loss from ``evaluate(..., mode="train")`` is passed to an
    ``EarlyStoppingCallback``, which checkpoints the weights to ``model_path``
    whenever the loss improves.

    Args:
        model: classifier whose forward ``(input_ids, attention_mask)``
            returns an object with ``.logits``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, labels)``.
        train_loader: yields ``(input_ids, attention_mask, label)`` batches.
        val_loader: same batch format, used for early stopping.
        device: device the model and batches are moved to.

    Returns:
        ``model`` with the best-validation checkpoint loaded, whether training
        stopped early or ran all epochs.
    """
    torch.cuda.empty_cache()
    gc.collect()

    epoch_progress = trange(1, ARGS["epochs"] + 1)
    early_stopper = EarlyStoppingCallback(patience=ARGS["patience"])
    criterion.to(device)

    grad_step = ARGS["grad_step"]
    model_path = ARGS["model_path"]
    model.to(device)
    model.zero_grad()

    # Mirrors the callback's "strictly lower loss => checkpoint" rule so we know
    # which epoch (if any) was saved to model_path.
    best_val_loss = np.inf
    best_epoch = None

    for epoch in epoch_progress:
        model.train()
        train_loss = list()
        num_batches = len(train_loader)
        for batch_id, data in enumerate(train_loader, start=1):
            input_ids, attention_mask, train_label = data
            train_label = train_label.to(device)
            input_id = input_ids.to(device)
            mask = attention_mask.to(device)

            output = model(input_id, mask)

            batch_loss = criterion(output.logits, train_label.long())
            train_loss.append(batch_loss.item())

            # Scale so the accumulated gradient averages over grad_step batches.
            batch_loss /= grad_step
            batch_loss.backward()

            # Also step on the epoch's last batch. Otherwise a trailing partial
            # window is never applied and its gradients leak into the next epoch.
            if batch_id % grad_step == 0 or batch_id == num_batches:
                optimizer.step()
                model.zero_grad()

        val_loss = evaluate(model, criterion, val_loader, device, mode="train")
        train_loss = np.mean(train_loss)
        val_loss = np.mean(val_loss)
        tqdm.write(
            f"Epoch {epoch}, Train-Loss: {train_loss:.5f},  Val-Loss: {val_loss:.5f}"
        )

        if val_loss < best_val_loss:
            best_val_loss, best_epoch = val_loss, epoch

        if early_stopper.should_stop(val_loss, model, model_path):
            tqdm.write(f"\n\n -- EarlyStoppingCallback: [Epoch: {best_epoch}]")
            break

    # Restore the best checkpoint, also when all epochs ran without stopping early.
    if best_epoch is not None:
        model.load_state_dict(torch.load(model_path, map_location=device))
        tqdm.write(f"Model saved at '{model_path}'.")

    return model

In [None]:
# Pretrained EmoBERTa with a sequence-classification head sized to our labels.
model = RobertaForSequenceClassification.from_pretrained(
    ARGS["model"],
    num_labels=label_encoder.num_classes,
)
# transformers' AdamW; correct_bias=False disables Adam's bias correction.
optimizer = AdamW(
    model.parameters(),
    lr=ARGS["lr"],
    weight_decay=1e-3,
    correct_bias=False,
)
criterion = CrossEntropyLoss()

In [None]:
# Train the model (restores the best validation checkpoint).
best_model = train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=DEVICE,
)

In [None]:
# Release cached GPU memory left over from training before the final evaluation.
torch.cuda.empty_cache()
gc.collect()

val_acc, val_f1 = evaluate(best_model, criterion, val_dataloader, DEVICE)
for metric_name, metric_value in (("Accuracy", val_acc), ("F1-macro", val_f1)):
    print(f"{metric_name}: {metric_value:.5f}")

# Prediction

In [None]:
# Test data: unlabelled, so the dataset is built in its default (inference) mode.
df_test = pd.read_csv(TEST_CSV)
test_set = EmoDataset(
    data=df_test.reset_index(drop=True),
    twt_tokenizer=twt_tokenizer,
    roberta_tokenizer=roberta_tokenizer,
    label_encoder=label_encoder,
    max_length=ARGS["max_len"],
)
test_dataloader = DataLoader(
    test_set,
    batch_size=ARGS["batch_size"],
    shuffle=False,
)

In [None]:
def predict(model, test_loader, device):
    """Return the arg-max class id of every sample in ``test_loader``.

    ``test_loader`` yields ``(input_ids, attention_mask)`` batches. The model
    runs in eval mode without gradients, and predictions come back as a flat
    list of Python ints in loader order.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        for ids_batch, mask_batch in test_loader:
            logits = model(ids_batch.to(device), mask_batch.to(device)).logits
            predictions.extend(logits.argmax(1).cpu().tolist())
    return predictions

In [None]:
# Predict class ids for the test set with the best checkpoint.
preds = predict(model=best_model, test_loader=test_dataloader, device=DEVICE)

In [None]:
# Convert predicted class ids back to emotion names for the submission file.
df_test["Target"] = [label_encoder.decode(class_id) for class_id in preds]
submit = df_test[["ID", "Target"]]
submit.head()

In [None]:
# Save the predictions (ID, Target) without the DataFrame index.
submit_path = join_path("emoberta", "submit.csv")
submit.to_csv(submit_path, index=False)