In [1]:
from livedoor import LivedoorNewsCorpus

# Fetch the livedoor news corpus archive and unpack it into ./lldc. As the
# recorded output shows, each step is skipped when its archive / extracted
# directory already exists.
corpus = LivedoorNewsCorpus(extract_dir='./lldc')
corpus.download_and_extract()
# Category names of the corpus; later used to size the classifier (num_labels).
categories = corpus.categories

./ldcc-20140209.tar.gz already exists. download stopped
./lldc already exists. extract stopped


In [2]:
import pandas as pd
# Pull the texts and their labels out of the corpus and keep them side by side
# in one DataFrame with columns "text" and "label".
all_text, all_label = corpus.get_text_and_labels()
df = pd.DataFrame(dict(text=all_text, label=all_label))

In [3]:
from transformers import BertConfig, BertForPreTraining, BertForSequenceClassification

# Configuration: reuse the bert-base-uncased architecture settings, with one
# output label per corpus category and no extra attention/hidden-state outputs.
bertconfig = BertConfig.from_pretrained(
    'bert-base-uncased',
    num_labels=len(categories),
    output_attentions=False,
    output_hidden_states=False,
)
# Patch the vocabulary size from bert-base-uncased's 30522 to 32000
# (presumably the size of the SentencePiece vocabulary used below — confirm).
bertconfig.vocab_size = 32000

# An untrained BERT "shell": every parameter is randomly initialized until a
# checkpoint is loaded into it.
pretrained = BertForPreTraining(bertconfig)

In [4]:
import os

# Directory holding the pretrained Japanese BERT (TensorFlow checkpoint and
# SentencePiece files). Set the BERT_KIKUTA_DIR environment variable instead of
# editing this machine-specific default. os.path.join(..., '') guarantees a
# trailing separator, because other cells build paths by plain concatenation.
DIR_BERT_KIKUTA = os.path.join(
    os.environ.get('BERT_KIKUTA_DIR',
                   '/home/miyamonz/2019-12-12-Bert-example/model/bert-baseline/'),
    '')
BASE_CKPT = 'model.ckpt-1400000'    # checkpoint prefix: no file extension
# Scratch directory used to hand the converted weights from BertForPreTraining
# over to BertForSequenceClassification.
TMP_MODEL_DIR = "./tmp/"

# Load the TensorFlow checkpoint's weight matrices (can take a few minutes).
pretrained.load_tf_weights(bertconfig, os.path.join(DIR_BERT_KIKUTA, BASE_CKPT))
# Older transformers releases assert that the save directory already exists,
# so create it to keep a fresh Restart & Run All working.
os.makedirs(TMP_MODEL_DIR, exist_ok=True)
pretrained.save_pretrained(TMP_MODEL_DIR)
# Reload as a sequence classifier: encoder weights come from the saved model;
# weights it lacks (presumably the classification head) are freshly initialized.
model = BertForSequenceClassification.from_pretrained(TMP_MODEL_DIR)

In [5]:
import sentencepiece as sp
# File names of the SentencePiece model / vocabulary shipped with the
# pretrained model (the .model file is read from DIR_BERT_KIKUTA below).
BASE_SPM = 'wiki-ja.model'
BASE_VOCAB = 'wiki-ja.vocab'  # NOTE(review): not referenced by the cells shown here

# Load the SentencePiece model into the global `spm` that spm_encode reads.
spm = sp.SentencePieceProcessor()
spm.Load(DIR_BERT_KIKUTA + BASE_SPM)

#bert tokenizerのencode_plusと似たような出力になるようにする
def spm_encode(example, max_length = 512):
    raw_pieces  = spm.EncodeAsPieces(example)

    # if input size is over max_length, truncate them
    # Account for [CLS], [SEP] with `- 2`
    if len(raw_pieces) > max_length-2:
        raw_pieces = raw_pieces[:max_length-2]


    pieces = []

    # first token must be CLS
    pieces.append("[CLS]")

    for piece in raw_pieces:
        pieces.append(piece)

    # last token must be SEP
    pieces.append('[SEP]')

    # convert pieces to ids
    input_ids = [ spm.PieceToId(p) for p in pieces ]
    attention_mask = [1] * len(input_ids)

    #fill 0 in the rest list space
    while len(input_ids) < max_length:
        input_ids.append(0)
        attention_mask.append(0)

    return {
            "input_ids":[input_ids],
            "attention_mask":[attention_mask],
            }

In [6]:
import torch
from torch.utils.data import TensorDataset, random_split

def _convert_sentence_to_ids(_sentences):
    input_ids = []
    attention_masks = []

    for sent in _sentences:
        encoded_dict = spm_encode(sent)
        _ = torch.tensor(encoded_dict['input_ids'])
        input_ids.append(_)
        _ = torch.tensor(encoded_dict['attention_mask'])
        attention_masks.append(_)

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return (input_ids, attention_masks)

In [7]:
from factory import get_optimizers, get_dataloader

# Optimizer for the classifier `model` built above; its construction details
# (optimizer type, learning rate, ...) live in the project's factory module.
optimizer = get_optimizers(model)

In [8]:
_sentences = df.text.values
# Map string labels to integer class ids with pandas' category dtype (any label
# encoder would do; pandas is not strictly required here). Codes index into
# df.label.astype('category').cat.categories.
_labels = df.label.astype('category').cat.codes

id_and_masks = _convert_sentence_to_ids(_sentences)
# Convert through the underlying NumPy array so the result is purely
# positional: building the tensor straight from the Series can fall back to
# element-wise Series indexing, which is label-based and depends on df's index.
labels = torch.tensor(_labels.values, dtype=torch.int64)
# (input_ids, attention_masks, labels) handed to get_dataloader.
ds_source = id_and_masks + (labels,)

train_dl, valid_dl = get_dataloader(ds_source, batch_size=4)

In [9]:
from fit import fit
# Fine-tune for 4 epochs; per the recorded output, fit returns one dict of
# losses, validation accuracy and timings per epoch.
stats = fit(model, train_dl, valid_dl, optimizer, epochs=4)


Training...
  Training took: 0:06:20

Running Validation...
  Validation took: 0:00:13

Training...
  Training took: 0:06:21

Running Validation...
  Validation took: 0:00:13

Training...
  Training took: 0:06:22

Running Validation...
  Validation took: 0:00:13

Training...
  Training took: 0:06:22

Running Validation...
  Validation took: 0:00:13

Training complete!
Total training took 0:26:18 (h:mm:ss)


In [10]:
# Per-epoch training/validation loss, validation accuracy and timings.
stats

[{'epoch': 1,
  'Training Loss': 0.40225998480299846,
  'Valid. Loss': 0.1971782452351338,
  'Valid. Accur.': 0.9594594594594594,
  'Training Time': '0:06:20',
  'Validation Time': '0:00:13'},
 {'epoch': 2,
  'Training Loss': 0.1434100202865566,
  'Valid. Loss': 0.21691012028101328,
  'Valid. Accur.': 0.9608108108108108,
  'Training Time': '0:06:21',
  'Validation Time': '0:00:13'},
 {'epoch': 3,
  'Training Loss': 0.08883859060785596,
  'Valid. Loss': 0.19121961851377745,
  'Valid. Accur.': 0.9689189189189189,
  'Training Time': '0:06:22',
  'Validation Time': '0:00:13'},
 {'epoch': 4,
  'Training Loss': 0.06136125998416482,
  'Valid. Loss': 0.21733691982320838,
  'Valid. Accur.': 0.9675675675675676,
  'Training Time': '0:06:22',
  'Validation Time': '0:00:13'}]