In [113]:
from typing import Tuple, Dict
import pandas as pd
import torch
from torch.nn import functional as F
from torch import optim
from torch import nn
from torch.utils.data import DataLoader, Dataset
import sys
sys.path.insert(0, '../')

from config import *
from transformers import BertModel, BertConfig, BertForSequenceClassification
from models import load_model
from dataset import REDataset, load_data, LabelEncoder, COLUMNS
from preprocessing import preprocess_text
from utils import set_seed
from tokenization import load_tokenizer, tokenize, SpecialToken
from criterions import *
from optimizers import *

In [114]:
# Build the sequence-classification model from the multilingual checkpoint and keep it on CPU.
# NOTE(review): the positional arguments 42, None, 0 are forwarded to load_model unchanged —
# confirm their meaning against load_model's signature.
model = load_model(
    ModelType.SequenceClf,
    PreTrainedType.MultiLingual,
    42,
    None,
    0,
    dropout=None,
).cpu()
print('MODEL')

Load Model...	Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized fro

In [115]:
# Project REDataset with its default arguments; the logged output shows it applies the
# 'EntityMarker' preprocessing and loads the matching tokenizer.
dataset = REDataset()

Load raw data...	apply preprocess 'EntityMarker'...	done!
Load Tokenizer for EntityMarker...	done!


In [117]:
# Grow the embedding matrix so the tokenizer's added special tokens get their own rows.
model.resize_token_embeddings(new_num_tokens=len(dataset.tokenizer))

Embedding(119551, 768)

In [118]:
# Small batches are enough for a shape/sanity check of the pipeline.
loader = DataLoader(dataset=dataset, batch_size=4)

In [119]:
# Take a single batch for inspection. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving `sents`/`labels` unbound or holding stale values.
sents, labels = next(iter(loader))

In [2]:
def _add_entity_marker_scores(token_type_ids, tokenized: list) -> None:
    """Add ENTITY_SCORE in place to ``token_type_ids`` over every entity-marker span.

    Spans come from ``find_entity_indices(tokenized)`` as ``(start, end)`` pairs on
    ``tokenizer.tokenize`` positions. They are shifted by OFFSET (presumably the leading
    [CLS] added by the encoding call — TODO confirm), and ``end`` is inclusive.
    """
    for start, end in find_entity_indices(tokenized).values():
        token_type_ids[OFFSET + start : OFFSET + end + 1] += ENTITY_SCORE


def _add_sep_score(token_type_ids, tokenized: list) -> None:
    """Add SEP_SCORE in place to ``token_type_ids`` from OFFSET through the last separator index."""
    # NOTE(review): unlike the entity spans, last_sep_idx is not shifted by OFFSET. If OFFSET
    # accounts for a leading special token, the slice stops one position before the last
    # separator in the encoded sequence — confirm this is intended. `fine_sep_indices` may be
    # a typo of a `find_*` helper; the name is kept as-is.
    last_sep_idx = fine_sep_indices(tokenized).pop()
    token_type_ids[OFFSET : last_sep_idx + 1] += SEP_SCORE


def tokenize(sentence, tokenizer, type: str=PreProcessType.Base, max_length: int = 128) -> dict:
    """Encode one sentence and optionally boost ``token_type_ids`` for marker tokens.

    NOTE(review): this definition shadows ``tokenize`` imported from ``tokenization``
    in the import cell.

    Args:
        sentence: text to encode.
        tokenizer: tokenizer that is callable for encoding and exposes ``.tokenize``.
        type: PreProcessType value.
            - Base: plain encoding.
            - EM: entity-marker spans get +ENTITY_SCORE.
            - ESP: the span up to the last separator gets +SEP_SCORE.
            - EMSP: both, entity spans first.
        max_length: pad/truncate length (default 128, the previously hard-coded value).

    Returns:
        The tokenizer output with the leading batch dimension removed, and
        ``token_type_ids`` modified in place according to ``type``.
    """
    outputs = tokenizer(
        sentence,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
        add_special_tokens=True,
    )
    # squeeze(0) drops only the batch dim; a bare squeeze() would also collapse a length-1 sequence.
    for key in outputs.keys():
        outputs[key] = outputs[key].squeeze(0)

    if type == PreProcessType.Base:
        return outputs

    tokenized = tokenizer.tokenize(sentence)
    token_type_ids = outputs["token_type_ids"]
    if type in (PreProcessType.EM, PreProcessType.EMSP):
        _add_entity_marker_scores(token_type_ids, tokenized)
    if type in (PreProcessType.ESP, PreProcessType.EMSP):
        _add_sep_score(token_type_ids, tokenized)
    return outputs

In [124]:
from tqdm import tqdm

MAX_LENGTH = 128  # truncation length used when tokenizing the whole dataset
OFFSET = 1  # shift from tokenizer.tokenize() positions to encoded positions (presumably the leading [CLS]) — TODO confirm
ENTITY_SCORE = 1  # value added to token_type_ids over each entity-marker span

class REDataset_v0(Dataset):
    """Relation-extraction dataset whose token_type_ids are boosted over entity spans.

    Loads the tab-separated file at ``root``, label-encodes the ``label`` column, applies
    ``preprocess_text`` with ``preprocess_type``, tokenizes the ``input`` column once, and
    adds ENTITY_SCORE to token_type_ids over each entity-marker span. ``__getitem__``
    moves every returned tensor to ``device``.
    """

    def __init__(
        self,
        root: str = Config.Train,
        preprocess_type: str = PreProcessType.EM,
        device: str = Config.Device,
    ):
        self.data = self._load_data(root, preprocess_type=preprocess_type)
        self.labels = self.data["label"].tolist()
        self.tokenizer = load_tokenizer(type=preprocess_type)
        self.inputs = self._tokenize(self.data)
        self.device = device

    def __getitem__(self, idx) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Return ``(inputs, label)`` for row ``idx``, already placed on ``self.device`` (CPU or GPU).

        ``inputs`` maps each key produced by the tokenizer call in ``_tokenize``
        (including ``token_type_ids``) to that row's tensor.
        """
        sentence = {
            key: torch.as_tensor(val[idx]).to(self.device)  # move to the target device
            for key, val in self.inputs.items()
        }
        label = torch.as_tensor(self.labels[idx]).to(self.device)  # move to the target device
        return sentence, label

    def __len__(self):
        return len(self.labels)

    def _tokenize(self, data):
        """Encode ``data["input"]`` and add entity-span scores to ``token_type_ids``.

        Returns the tokenizer's batch output (padded to the longest row, truncated at
        MAX_LENGTH) with ``token_type_ids`` incremented by ENTITY_SCORE over each
        entity-marker span found by ``_find_entity_intervals``.
        """
        print("Apply Tokenization...", end="\t")
        data_tokenized = self.tokenizer(
            data["input"].tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            add_special_tokens=True,
        )

        # Marker positions are found on the tokenize() token list, then shifted by OFFSET.
        # Use this class's own helper rather than a global `find_entity_intervals` that is
        # not defined in this notebook's visible code (presumably leaked from a star import
        # or earlier kernel state).
        tokenized_decoded = data["input"].apply(self.tokenizer.tokenize)
        entity_intervals = tokenized_decoded.apply(self._find_entity_intervals)

        # padding=True pads to the longest row, which may be shorter than MAX_LENGTH. Size
        # the score tensor to the actual width so the in-place add cannot hit a shape mismatch.
        seq_len = data_tokenized["token_type_ids"].size(1)
        entity_interval_tensor = self.make_entity_interval_tensor(entity_intervals, max_length=seq_len)
        data_tokenized["token_type_ids"] += entity_interval_tensor.long()

        print("done!")
        return data_tokenized

    def _load_data(self, root: str, preprocess_type: str) -> pd.DataFrame:
        """Read the headerless TSV at ``root``, encode labels, and apply ``preprocess_text``.

        The ``id`` column is dropped, and ``label`` is mapped element-wise through
        ``LabelEncoder.transform``.
        """
        enc = LabelEncoder()
        print("Load raw data...", end="\t")
        raw = pd.read_csv(root, sep="\t", header=None)
        raw.columns = COLUMNS
        raw = raw.drop("id", axis=1)
        raw["label"] = raw["label"].apply(lambda x: enc.transform(x))
        print(f"preprocessing for '{preprocess_type}'...", end="\t")
        data = preprocess_text(raw, method=preprocess_type)
        print("done!")
        return data

    def make_entity_interval_tensor(self, entity_intervals: list, max_length: int = MAX_LENGTH):
        """Build a float ``(len(entity_intervals), max_length)`` tensor of entity-span scores.

        Each element of ``entity_intervals`` is ``((e1_start, e1_end), (e2_start, e2_end))``.
        ENTITY_SCORE is added over positions ``OFFSET + start`` through ``OFFSET + end``
        inclusive for both entities, so overlapping spans accumulate. Positions beyond
        ``max_length`` are dropped by slicing.
        """
        n_rows = len(entity_intervals)
        entity_interval_tensor = torch.zeros(n_rows, max_length)

        for idx, (e1, e2) in tqdm(
            enumerate(entity_intervals), total=n_rows, desc="Update token_type_ids"
        ):
            for start, end in (e1, e2):
                entity_interval_tensor[idx, OFFSET + start : OFFSET + end + 1] += ENTITY_SCORE

        return entity_interval_tensor

    @staticmethod
    def _find_entity_intervals(tokenized: list) -> list:
        """Return ``[(e1_open, e1_close), (e2_open, e2_close)]`` marker positions in ``tokenized``.

        Raises:
            ValueError: if any of the four entity-marker tokens is absent (from ``list.index``).
        """
        return [
            (tokenized.index(SpecialToken.E1Open), tokenized.index(SpecialToken.E1Close)),
            (tokenized.index(SpecialToken.E2Open), tokenized.index(SpecialToken.E2Close)),
        ]
    
    

In [125]:
# Build the REDataset_v0 variant defined above. Tensors stay on CPU to match the CPU model.
dataset_stable = REDataset_v0(device='cpu')

Load raw data...	preprocessing for 'EntityMarker'...	done!
Load Tokenizer for EntityMarker...	done!
Update token_type_ids: 9000it [00:00, 31860.58it/s]
done!


In [126]:
# Inspect the tokenizer. Its repr lists the entity-marker tokens registered as additional special tokens.
dataset_stable.tokenizer

PreTrainedTokenizer(name_or_path='bert-base-multilingual-cased', vocab_size=119547, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[E1]', '[/E1]', '[E2]', '[/E2]']})

In [128]:
# Rebind `loader` to iterate over the REDataset_v0 variant with the same small batch size.
loader = DataLoader(dataset=dataset_stable, batch_size=4)

In [129]:
# Take one batch from the REDataset_v0 loader. next(iter(...)) raises on an empty loader
# instead of silently keeping the `sents`/`labels` left over from the earlier loader.
sents, labels = next(iter(loader))

In [127]:
# The batches fed to the model now come from dataset_stable, so size the embedding
# matrix from its tokenizer rather than from the earlier `dataset`.
model.resize_token_embeddings(len(dataset_stable.tokenizer))

Embedding(119551, 768)

In [130]:
# Sanity-check forward pass on one batch; the displayed SequenceClassifierOutput has loss=None because no labels are passed.
# NOTE(review): no model.eval()/torch.no_grad() is applied in the visible cells — confirm whether dropout/autograd are intended here.
model(**sents)

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.1994, -0.0248, -0.0062, -0.0171, -0.0209, -0.1810, -0.0397,  0.2497,
         -0.0708,  0.0445,  0.1316, -0.0034, -0.0707, -0.0775, -0.0678, -0.0381,
          0.0455,  0.0807,  0.0078, -0.0194, -0.0158,  0.0933, -0.1661, -0.1252,
          0.0518,  0.1195, -0.0281,  0.0389, -0.0122,  0.1866, -0.0236,  0.0199,
         -0.0128, -0.1093,  0.0093, -0.1239, -0.0390, -0.2052, -0.0179, -0.0561,
         -0.0232,  0.1765],
        [ 0.2068, -0.0411,  0.0759,  0.0373, -0.0166, -0.0974, -0.0823,  0.1185,
         -0.0871, -0.0335,  0.1325, -0.0942, -0.0721, -0.0381, -0.0757,  0.0282,
         -0.0473,  0.1172,  0.0055, -0.0144, -0.0729,  0.1922, -0.0428, -0.1354,
          0.0109,  0.0942,  0.0104,  0.0799,  0.0140,  0.1804, -0.0604,  0.1042,
         -0.0479, -0.1226, -0.0580, -0.1510, -0.0386, -0.2189, -0.0229, -0.0726,
          0.0257,  0.1316],
        [ 0.1037, -0.1658,  0.2718,  0.0931, -0.1524, -0.0138, -0.0878,  0.0802,
         -