In [33]:
import sys
import torch
from torch.utils.data import DataLoader
from torch import nn
from transformers import BertTokenizer, BertTokenizerFast, BertForSequenceClassification, BertModel, BertConfig
from tokenizers import BertWordPieceTokenizer


sys.path.insert(0, '../')
from models import load_model
from dataset import REDataset
from config import Config, ModelType, PreTrainedType

In [6]:
# Load the configuration that ships with the pretrained checkpoint, then
# override the size of the classification head.
config = BertConfig.from_pretrained(PreTrainedType.MultiLingual)
# NOTE(review): 42 is hard-coded here; the VanillaBert_v2 signature further
# down annotates Config.NumClasses as 42 -- presumably the same quantity,
# confirm and prefer the constant.
config.num_labels = 42

In [7]:
# Build a BertForSequenceClassification from the pretrained checkpoint using
# the modified config. The load log printed below reports the checkpoint's
# `cls.*` weights as unused by this architecture.
model = BertForSequenceClassification.from_pretrained(PreTrainedType.MultiLingual, config=config)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model ch

In [17]:
# Construct the project's REDataset with device='cpu'. Per the log below,
# construction loads the raw data, loads a tokenizer and applies tokenization.
dataset = REDataset(device='cpu')

Load raw data...	preprocessing for 'Base'...	done!
Load Tokenizer...	done!
Apply Tokenization...	done!


In [19]:
# Batch the dataset four samples at a time (no shuffle argument is passed).
loader = DataLoader(dataset, batch_size=4)

In [20]:
# Grab a single batch to probe the model with. `next(iter(...))` states the
# intent directly and raises StopIteration on an empty loader, instead of
# silently leaving `sents` / `labels` unbound (or stale from an earlier run).
sents, labels = next(iter(loader))

In [24]:
# Run only the BERT encoder sub-module (not the full classification model);
# the batch is unpacked into keyword arguments, so `sents` must be a mapping.
output = model.bert(**sents)

In [28]:
# Take the last-layer hidden vector at sequence position 0 for every sample
# (presumably the [CLS] token -- confirm against the tokenizer output).
# NOTE: despite the name, this is the raw hidden state sliced from
# `last_hidden_state`, not the encoder's `pooler_output`.
pooler = output.last_hidden_state[:, 0, :]

In [30]:
# Sanity check: feed the position-0 vectors through the classification head
# and inspect the resulting shape (shown below as [4, 42]).
model.classifier(pooler).size()

torch.Size([4, 42])

In [42]:
class VanillaBert_v2(nn.Module):
    """BERT encoder, dropout and linear head taken from a pretrained
    ``BertForSequenceClassification``.

    The three sub-modules (``bert``, ``dropout``, ``classifier``) are lifted
    out of the loaded model and re-wired so that the classifier is fed the
    last-layer hidden vector at a configurable sequence position.

    Args:
        model_type: Which architecture to load; only ``ModelType.SequenceClf``
            is supported.
        pretrained_type: Checkpoint identifier forwarded to ``from_pretrained``.
        num_labels: Output size of the classification head.
        pooler_idx: Index along the sequence dimension of
            ``last_hidden_state`` whose vector is classified.

    Raises:
        NotImplementedError: If ``model_type`` is not supported.
    """

    def __init__(
        self,
        model_type: str = ModelType.SequenceClf,  # BertForSequenceClassification
        pretrained_type: str = PreTrainedType.MultiLingual,  # bert-base-multilingual-cased
        num_labels: int = Config.NumClasses,  # 42
        pooler_idx: int = 0
    ):
        super(VanillaBert_v2, self).__init__()
        bert = self.load_bert(
            model_type=model_type,
            pretrained_type=pretrained_type,
            # Previously dropped: the head size was fixed at 42 regardless of
            # what the caller requested.
            num_labels=num_labels,
        )
        self.backbone = bert.bert
        self.dropout = bert.dropout
        self.clf = bert.classifier
        self.idx = pooler_idx

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode the batch and classify the hidden vector at ``self.idx``.

        Returns:
            The output of the classification head applied (after dropout) to
            ``last_hidden_state[:, self.idx, :]``.
        """
        x = self.backbone(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Select one sequence position per sample instead of using the
        # encoder's own pooled output.
        x = x.last_hidden_state[:, self.idx, :]
        x = self.dropout(x)
        output = self.clf(x)
        return output

    @staticmethod
    def load_bert(model_type, pretrained_type, num_labels: int = Config.NumClasses):
        """Load a pretrained classification model with ``num_labels`` outputs.

        Args:
            model_type: Architecture selector; must be ``ModelType.SequenceClf``.
            pretrained_type: Checkpoint identifier for ``from_pretrained``.
            num_labels: Value written to ``config.num_labels`` before loading.

        Returns:
            The loaded ``BertForSequenceClassification`` instance.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        # Reject unsupported types before fetching any config or weights.
        if model_type != ModelType.SequenceClf:
            raise NotImplementedError()

        config = BertConfig.from_pretrained(pretrained_type)
        config.num_labels = num_labels
        model = BertForSequenceClassification.from_pretrained(pretrained_type, config=config)

        return model

In [43]:
# Instantiate the wrapper with its defaults.
# NOTE: this rebinds `model`, which until now held the raw
# BertForSequenceClassification. VanillaBert_v2 assigns `backbone` / `dropout`
# / `clf` rather than `bert` / `classifier`, so the earlier `model.bert(...)`
# and `model.classifier(...)` cells fail if re-run after this one.
model = VanillaBert_v2()

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model ch

In [44]:
# Forward the sample batch through the wrapper. `sents` is unpacked into
# keyword arguments, so its keys must match forward()'s parameters:
# input_ids, token_type_ids, attention_mask. Rebinds `output`.
output = model(**sents)

In [45]:
# Display the classification-head outputs for the sample batch.
output

tensor([[-0.3085, -0.0310,  0.1683, -0.1096,  0.0322,  0.2899, -0.2387, -0.1071,
          0.1112, -0.3434, -0.0009, -0.3300, -0.0189, -0.0739,  0.0318, -0.1613,
          0.1797, -0.2363, -0.1003, -0.0215, -0.3221, -0.0310,  0.0659, -0.0153,
          0.1841,  0.1360,  0.0297, -0.2904,  0.0027, -0.1875, -0.1578, -0.0472,
          0.2544,  0.3674,  0.0224, -0.3766,  0.2296, -0.2980,  0.4823, -0.1805,
          0.1215,  0.0884],
        [-0.3178, -0.0420,  0.5670, -0.1458,  0.3103,  0.3062, -0.4243, -0.1841,
          0.0180, -0.4984,  0.0176, -0.2405,  0.0719, -0.3171,  0.0287, -0.3817,
          0.1676, -0.0818, -0.1787,  0.0245, -0.2646,  0.0456,  0.2499,  0.0211,
          0.5569,  0.2341, -0.1474, -0.2639, -0.1208, -0.1170, -0.2079,  0.0302,
         -0.0138,  0.3647,  0.0919, -0.4525,  0.3361, -0.0224,  0.3537, -0.3153,
         -0.0371,  0.0761],
        [-0.1305, -0.0609,  0.4915, -0.0239,  0.0627,  0.1599, -0.1920, -0.2792,
          0.1001, -0.3685,  0.0419, -0.2608, -0.0087,