In [1]:
from livedoor import LivedoorNewsCorpus

# Set up the Livedoor news corpus wrapper, using ./lldc as the extraction directory.
corpus = LivedoorNewsCorpus(extract_dir='./lldc')
# Fetch and unpack the archive. Per the recorded cell output, each step is
# skipped when the archive / extraction directory already exists.
corpus.download_and_extract()
# The corpus's categories; their count is later passed to get_model as num_labels.
categories = corpus.categories

./ldcc-20140209.tar.gz already exists. download stopped
./lldc already exists. extract stopped


In [2]:
import pandas as pd
# Pull the texts and their labels out of the corpus as two parallel collections.
all_text, all_label = corpus.get_text_and_labels()
# Pair them up column-wise: one row per document, columns 'text' and 'label'.
df = pd.DataFrame(dict(text=all_text, label=all_label))

In [3]:
import torch

def convert_sentence_to_ids(_sentences, tokenizer, max_length=37):
    """Encode sentences into stacked token-id and attention-mask tensors.

    Args:
        _sentences: Iterable of sentences to encode, processed one at a time.
        tokenizer: Object exposing ``encode_plus``; it is expected to return a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors for
            each sentence.
        max_length: Sequence length passed to the tokenizer for padding.
            Defaults to 37, the value this function previously hard-coded.

    Returns:
        A tuple ``(input_ids, attention_masks)``. Each element is the
        concatenation along dim 0 of the per-sentence tensors returned by the
        tokenizer. For empty input, two empty ``int64`` tensors of shape
        ``(0, max_length)`` are returned instead of raising from ``torch.cat``.
    """
    input_ids = []
    attention_masks = []

    # Process one sentence at a time.
    for sent in _sentences:
        # NOTE(review): recent Hugging Face transformers releases deprecate
        # `pad_to_max_length` in favour of `padding='max_length'` (plus
        # `truncation=True`). The tokenizer version is not visible here, so the
        # original keyword is kept -- confirm before upgrading.
        encoded_dict = tokenizer.encode_plus(
            sent,
            add_special_tokens=True,
            max_length=max_length,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',  # return PyTorch tensors
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    # torch.cat raises on an empty list, so handle "no sentences" explicitly.
    if not input_ids:
        empty = torch.empty((0, max_length), dtype=torch.long)
        return empty, empty.clone()

    # Concatenate the per-sentence tensors vertically (dim=0).
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks


In [4]:
from factory import get_model, get_optimizers, get_tokenizer, get_dataloader

# Build the model with one label per corpus category.
model = get_model(num_labels=len(categories))
# Optimizer created for this model by the project's factory helper.
optimizer = get_optimizers(model)
# Tokenizer handed to convert_sentence_to_ids, which calls its encode_plus method.
tokenizer = get_tokenizer()

In [5]:
# Raw texts taken from the DataFrame's 'text' column.
_sentences = df.text.values
# Map the string labels to integer codes. pandas' category dtype is used here
# for convenience; the conversion does not have to rely on pandas.
_labels = df.label.astype('category').cat.codes

# Tokenize every text -> (input_ids, attention_masks).
id_and_masks = convert_sentence_to_ids(_sentences, tokenizer)
# Go through NumPy rather than handing the Series to torch.tensor directly:
# torch would otherwise read the Series element by element via label-based
# indexing, which depends on the index being exactly 0..n-1.
labels = torch.tensor(_labels.to_numpy(), dtype=torch.int64)
# Bundle as (input_ids, attention_masks, labels) for the dataloader factory.
ds_source = id_and_masks + (labels,)

train_dl, valid_dl = get_dataloader(ds_source, batch_size=32)

In [6]:
from fit import fit
# Train for 4 epochs; per the recorded output, validation runs after each
# training pass and `stats` holds one dict of metrics per epoch.
stats = fit(model, train_dl, valid_dl, optimizer, epochs=4)


Training...
  Training took: 0:00:28

Running Validation...
  Validation took: 0:00:01

Training...
  Training took: 0:00:28

Running Validation...
  Validation took: 0:00:01

Training...
  Training took: 0:00:28

Running Validation...
  Validation took: 0:00:01

Training...
  Training took: 0:00:28

Running Validation...
  Validation took: 0:00:01

Training complete!
Total training took 0:01:58 (h:mm:ss)


In [7]:
# Display the per-epoch statistics returned by fit().
# NOTE(review): in the recorded output, validation loss is lowest at epoch 3 and
# rises at epoch 4 while training loss keeps falling -- possibly overfitting.
stats

[{'epoch': 1,
  'Training Loss': 1.0470423494967132,
  'Valid. Loss': 0.46964893117547035,
  'Valid. Accur.': 0.85546875,
  'Training Time': '0:00:28',
  'Validation Time': '0:00:01'},
 {'epoch': 2,
  'Training Loss': 0.3440260239876807,
  'Valid. Loss': 0.38929920829832554,
  'Valid. Accur.': 0.8763020833333334,
  'Training Time': '0:00:28',
  'Validation Time': '0:00:01'},
 {'epoch': 3,
  'Training Loss': 0.17243757750838995,
  'Valid. Loss': 0.3836299367249012,
  'Valid. Accur.': 0.8919270833333334,
  'Training Time': '0:00:28',
  'Validation Time': '0:00:01'},
 {'epoch': 4,
  'Training Loss': 0.08538434787008625,
  'Valid. Loss': 0.4766979559014241,
  'Valid. Accur.': 0.8815104166666666,
  'Training Time': '0:00:28',
  'Validation Time': '0:00:01'}]