# データダウンロード

In [1]:
import os
import urllib.request
import tarfile

In [2]:
# フォルダ「data」が存在しない場合は作成する
data_dir = "./data"
if not os.path.exists(data_dir):
    os.mkdir(data_dir)

In [3]:
# ダウンロードをする時エラーがありました、これはそのSSLの修正
import ssl

if (not os.environ.get('PYTHONHTTPSVERIFY', '') and
    getattr(ssl, '_create_unverified_context', None)): 
    ssl._create_default_https_context = ssl._create_unverified_context

In [4]:
url = "https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JWTD/jwtd_v2.0.tar.gz&name=JWTDv2.0.tar.gz"
save_path = "./data/JWTD.tar.gz"
if not os.path.exists(save_path):
    urllib.request.urlretrieve(url, save_path)

In [5]:
# ファイルを解凍し、カテゴリー数と内容を確認
import tarfile
import os

# 解凍
tar = tarfile.open("./data/JWTD.tar.gz")
tar.extractall("./data/JWTD/")
tar.close()


In [6]:
JWTD_TRAIN = "./data/JWTD/jwtd_v2.0/train.jsonl"
JWTD_TEST = "./data/JWTD/jwtd_v2.0/test.jsonl"

In [7]:
import itertools
import json

# 行ずつ読んで
def load_jwtd_data(path=JWTD_TRAIN):
    # ファイルを読む
    with open(path, encoding="utf-8") as file:
        # # 行ずつ読んで
        for line in file:
            # iteratorを変える
            yield json.loads(line)


In [8]:
for i, line in enumerate(load_jwtd_data(JWTD_TEST)):
    print(line)
    if i > 5:
        break

{'page': '239', 'title': 'うすた京介', 'pre_rev': '19775709', 'post_rev': '31820058', 'pre_text': 'セガのマニアであることを公言しており、作中よくセガの商品が登場する（ロボピッチャという昔のおもちゃのピッチングマシーンや、セガサターンで発売された『バーチャコップ』等）。', 'post_text': 'セガのマニアであることを公言しており、作中によくセガの商品が登場する（ロボピッチャという昔のおもちゃのピッチングマシーンや、セガサターンで発売された『バーチャコップ』等）。', 'diffs': [{'pre_str': '', 'post_str': 'に', 'pre_bart_likelihood': -32.5, 'post_bart_likelihood': -21.72, 'category': 'deletion'}], 'lstm_average_likelihood': -3.26}
{'page': '239', 'title': 'うすた京介', 'pre_rev': '56097633', 'post_rev': '56097638', 'pre_text': '一時期、の甲本ヒロトと交流があると記されていたが、２０１２年７月２２日、ｔｗｉｔｔｅｒ上での質問に「事実無根」と回答している。', 'post_text': '一時期、甲本ヒロトと交流があると記されていたが、２０１２年７月２２日、ｔｗｉｔｔｅｒ上での質問に「事実無根」と回答している。', 'diffs': [{'pre_str': 'の', 'post_str': '', 'pre_bart_likelihood': -44.55, 'post_bart_likelihood': -21.65, 'category': 'insertion_a'}], 'lstm_average_likelihood': -4.56}
{'page': '326', 'title': 'アーミッシュ', 'pre_rev': '19336024', 'post_rev': '27253461', 'pre_text': 'そのため自動車は運転しないが。', 'post_text': 'そのため自動車は運転しない。'

In [9]:
def filter_insertion_data(jwtd_data):
    categories = ["insertion_a", "insertion_b"]
    for line in jwtd_data:
        if len(line["diffs"]) == 1 and (line["diffs"][0]["category"] in categories):
            yield line

In [10]:
list(itertools.islice(filter_insertion_data(load_jwtd_data(JWTD_TEST)), 101, 103))

[{'page': '33524',
  'title': '大和 (百貨店)',
  'pre_rev': '72550844',
  'post_rev': '72550863',
  'pre_text': '２０１９年２月１５日にのパトリアの運営会社である「七尾都市開発」が破産したことにより、３月に閉店。',
  'post_text': '２０１９年２月１５日にパトリアの運営会社である「七尾都市開発」が破産したことにより、３月に閉店。',
  'diffs': [{'pre_str': 'の',
    'post_str': '',
    'pre_bart_likelihood': -31.04,
    'post_bart_likelihood': -22.51,
    'category': 'insertion_a'}],
  'lstm_average_likelihood': -3.25},
 {'page': '33524',
  'title': '大和 (百貨店)',
  'pre_rev': '70102421',
  'post_rev': '70425997',
  'pre_text': 'また２０１６年（平成２８年）には、石川県で初めてのサテライトショップとして野々市市サテライトショップがオープンした。',
  'post_text': 'また２０１６年（平成２８年）には、石川県で初めてのサテライトショップとして野々市サテライトショップがオープンした。',
  'diffs': [{'pre_str': '市',
    'post_str': '',
    'pre_bart_likelihood': -27.64,
    'post_bart_likelihood': -30.42,
    'category': 'insertion_b'}],
  'lstm_average_likelihood': -2.94}]

# Bertモデル

In [11]:
# !pip install fugashi unidic-lite transformers datasets
from transformers import BertJapaneseTokenizer, BertForTokenClassification

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
TAG_KEEP = '$KEEP'
TAG_DELETE = '$DELETE'

In [13]:
PRETRAINED_MODEL = "cl-tohoku/bert-base-japanese-v2"

tokenizer = BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL,
                    do_lower_case=False, word_tokenizer_type="mecab",
                    subword_tokenizer_type="wordpiece",
                    mecab_kwargs={"mecab_dic": "unidic_lite"})

In [14]:

    
def create_tags(wrong_text:str=None, correct_text:str=None, tokenizer=None):
    """
    誤ってる文章と正解の文章から、いらないトーケンを見つける。
    return:
        tags: list
    """
    
    wrong_tokens = tokenizer.tokenize(wrong_text)
    correct_tokens = tokenizer.tokenize(correct_text)

    # 必要の誤りだけ
    if len(wrong_tokens) != (len(correct_tokens) + 1):
        raise ValueError("誤り文字は一つ以上")

    tags = []
    keep_correct_token = None
    for wrong_token in wrong_tokens:
        if keep_correct_token is None:
            # アンパック（Unpacking）a_token, *remaining = a_doc
            correct_token, *correct_tokens = correct_tokens
        
        # 同時にリストをロープしてるからトーケンが違うならそれは誤り
        if wrong_token != correct_token:
            tags.append(TAG_DELETE)
            keep_correct_token = correct_token
        else:
            tags.append(TAG_KEEP)
            keep_correct_token = None

    return tags
    
    

In [15]:
tokenizer.tokenize("これはテストでです"), tokenizer("これはテストでです").input_ids, create_tags("これはテストでです", "これはテストです",tokenizer)

(['これ', 'は', 'テスト', 'で', 'です'],
 [2, 11190, 897, 13744, 889, 12461, 3],
 ['$KEEP', '$KEEP', '$KEEP', '$DELETE', '$KEEP'])

In [16]:
for line in itertools.islice(filter_insertion_data(load_jwtd_data(JWTD_TEST)), 110, 112):
    wrong = line["pre_text"]
    correct = line["post_text"]
    
    try:
        tags = create_tags(wrong, correct, tokenizer)
    
    except ValueError:
        continue
        
    print(tokenizer.tokenize(wrong))
    print(tags)
    print(wrong)
    print(correct)
    print()

['この', '為', '、', '影', '##武', '##者', '説', 'や', '「', '任期', '中', 'に', '別人', 'と', 'の', '入れ替わっ', 'て', 'い', 'た', '」', '等', 'の', '説', 'が', '流れ', 'た', 'こと', 'が', 'ある', '。']
['$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$DELETE', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP']
この為、影武者説や「任期中に別人との入れ替わっていた」等の説が流れたことがある。
この為、影武者説や「任期中に別人と入れ替わっていた」等の説が流れたことがある。

['出演', '料', 'は', '自身', 'が', '設立', 'し', 'た', '環境', '保護', '団体', 'と', '、', 'の', '元', 'アメリカ', '副', '大統領', 'アル', '・', 'ゴア', 'の', '地球', '温暖', '化', '防止', '事業', 'に', '寄付', 'さ', 'れ', 'た', 'と', 'いう', '。']
['$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$DELETE', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$KEEP', '$K

In [17]:
def load_jwtd_tags_data(path=JWTD_TRAIN):
    for line in filter_insertion_data(load_jwtd_data(path)):
        wrong = line["pre_text"]
        correct = line["post_text"]
        
        try:
            tags = create_tags(wrong, correct, tokenizer)

        except ValueError:
            continue
        
        yield wrong, tags

## データセット

In [18]:
import torch
import torch.optim as optim
import torch.nn as nn

TAG_LABEL_MAP = {
    TAG_KEEP: 0,
    TAG_DELETE: 1,
}
LABEL_TAG_MAP = {
    0: TAG_KEEP,
    1: TAG_DELETE,
}

In [19]:
list_tags = []
list_pre_text = []

for text, tags in load_jwtd_tags_data(JWTD_TRAIN):
    list_pre_text.append(text)
    list_tags.append(tags)
    
len(list_pre_text), len(list_tags), list_pre_text[0:3], list_tags[0:3]

(111465,
 111465,
 ['このような言語はし死語と呼ばれ、死語が再び母語として使用される例はほとんどない。',
  '待遇表現の面では、文法的・語彙的に発達した敬語体系がありり、叙述される人物同士の微妙な関係を表現する（「待遇表現」の節参照）。',
  'という短歌は、冒頭から「ひとひらの」までが「雲」に係る長い修飾語でありる。'],
 [['$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$DELETE',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP'],
  ['$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$DELETE',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
  

In [20]:
import pandas as pd

df = pd.DataFrame({'pre_text': list_pre_text, 'tags': list_tags})

# 大きさを確認しておく
print(df.shape)

df.head()

(111465, 2)


Unnamed: 0,pre_text,tags
0,このような言語はし死語と呼ばれ、死語が再び母語として使用される例はほとんどない。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $DELETE, $..."
1,待遇表現の面では、文法的・語彙的に発達した敬語体系がありり、叙述される人物同士の微妙な関係を...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
2,という短歌は、冒頭から「ひとひらの」までが「雲」に係る長い修飾語でありる。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
3,２０１７年４月月現在、インターネット上の言語使用者数は、英語、中国語、スペイン語、アラビア語...,"[$KEEP, $KEEP, $KEEP, $KEEP, $DELETE, $KEEP, $..."
4,また、関西で「う」を唇を丸めてで発音する（円唇母音）のに対し、関東では唇を丸めずに発音するの...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."


In [21]:
# 順番をシャッフルする
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
df.head()


Unnamed: 0,pre_text,tags
0,詳細は、＃プレイヤー側の用語・設定のデデデの項をを参照。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
1,通常、フォースの構造と媒体は、超束積高エネルギー生命体（バイドの切れ端）を用いて生成し製作さ...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
2,ほとんどの魔物が攻撃系の魔法が主体であるのに、強力な盾や回復等、戦闘補助がほとんどというう極...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
3,「商標登録」がないとする認識のの論拠は何なのでしょう？,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
4,出場から再出場までの空白年数期間がを長かった歌手,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."


In [22]:
# 訓練と検証データを分割
# 全体の2割の文章数
len_0_2 = len(df) // 5

val_df_tags = df[:len_0_2]
train_df_tags = df[len_0_2:]

len_0_2, train_df_tags.shape, val_df_tags.shape, val_df_tags[0:3]

(22293,
 (89172, 2),
 (22293, 2),
                                             pre_text  \
 0                       詳細は、＃プレイヤー側の用語・設定のデデデの項をを参照。   
 1  通常、フォースの構造と媒体は、超束積高エネルギー生命体（バイドの切れ端）を用いて生成し製作さ...   
 2  ほとんどの魔物が攻撃系の魔法が主体であるのに、強力な盾や回復等、戦闘補助がほとんどというう極...   
 
                                                 tags  
 0  [$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE...  
 1  [$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE...  
 2  [$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE...  )

In [23]:
train_df_tags = train_df_tags.sample(frac=1, random_state=123).reset_index(drop=True)
train_df_tags.head()

Unnamed: 0,pre_text,tags
0,非戦闘員であるためアンナの手助けをしており、アノー号ではシスターミッチェルとともに生活班をに...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
1,１９４７年（昭和２２年年）５月１日　−　土浦市神立町に新治郡上大津村立上大津中学校として開校...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $DE..."
2,袁紫衣の母親をとの因縁がある。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $DELETE, $..."
3,五稜郭駅前停留所は、函館市交通局が１９５５年（昭和３０年）１１月２７日の鉄道工場前～五稜郭駅...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
4,１９９４年　新潟国際情報大学設立（情報文化学部学部）,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."


In [24]:
val_df_tags.head()

Unnamed: 0,pre_text,tags
0,詳細は、＃プレイヤー側の用語・設定のデデデの項をを参照。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
1,通常、フォースの構造と媒体は、超束積高エネルギー生命体（バイドの切れ端）を用いて生成し製作さ...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
2,ほとんどの魔物が攻撃系の魔法が主体であるのに、強力な盾や回復等、戦闘補助がほとんどというう極...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
3,「商標登録」がないとする認識のの論拠は何なのでしょう？,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
4,出場から再出場までの空白年数期間がを長かった歌手,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."


In [25]:
# トーケン数の長さを見ましょう
def map_to_length(x, tokenizer):

    x["pre_text_len"] = len(tokenizer(x["pre_text"]).input_ids)
    x["pre_text_longer_32"] = int(x["pre_text_len"] > 32)
    x["pre_text_longer_64"] = int(x["pre_text_len"] > 64)
    x["pre_text_longer_128"] = int(x["pre_text_len"] > 128)
    x["pre_text_longer_256"] = int(x["pre_text_len"] > 256)

    return x

In [26]:
from datasets import Dataset

tokenizer = BertJapaneseTokenizer.from_pretrained(
    PRETRAINED_MODEL,
    do_lower_case=False,
    word_tokenizer_type="mecab",
    subword_tokenizer_type="wordpiece",
    mecab_kwargs={"mecab_dic": "unidic_lite"}
)

dataset_from_df = Dataset.from_pandas(val_df_tags)
data_stats = dataset_from_df.select(range(len(dataset_from_df))).map(lambda x: map_to_length(x, tokenizer), num_proc=1)

Map: 100%|██████████| 22293/22293 [00:05<00:00, 4248.56 examples/s]


In [27]:
data_stats

Dataset({
    features: ['pre_text', 'tags', 'pre_text_len', 'pre_text_longer_32', 'pre_text_longer_64', 'pre_text_longer_128', 'pre_text_longer_256'],
    num_rows: 22293
})

In [28]:
def compute_and_print_stats(x):
    sample_size = len(x["pre_text_len"])
    print(
        "Pre text Mean: {}, %-Pre text > 32:{}, %-Pre text > 64:{}, %-Pre text > 128:{}, %-Pre text > 256:{}".format(
            sum(x["pre_text_len"]) / sample_size,
            sum(x["pre_text_longer_32"]) / sample_size,
            sum(x["pre_text_longer_64"]) / sample_size,
            sum(x["pre_text_longer_128"]) / sample_size,
            sum(x["pre_text_longer_256"]) / sample_size,
        )
    )

In [29]:
output = data_stats.map(
    compute_and_print_stats,
    batched=True,
    batch_size=-1)

Map: 100%|██████████| 22293/22293 [00:00<00:00, 495296.82 examples/s]

Pre text Mean: 43.88776746063787, %-Pre text > 32:0.6384963890010317, %-Pre text > 64:0.1634145247387072, %-Pre text > 128:0.004171713093796259, %-Pre text > 256:0.0





In [30]:
text = "朝太の名は、橘家圓喬が三遊亭圓朝に入門した折にに名乗ったのが最初である。"

encoding = tokenizer(text, max_length=16, padding="max_length",
                              truncation=True)

encoding.input_ids, tokenizer.convert_ids_to_tokens(encoding.input_ids)

([2,
  2821,
  1848,
  896,
  1564,
  897,
  828,
  3049,
  2000,
  1727,
  1665,
  862,
  26452,
  1727,
  2821,
  3],
 ['[CLS]',
  '朝',
  '太',
  'の',
  '名',
  'は',
  '、',
  '橘',
  '家',
  '圓',
  '喬',
  'が',
  '三遊亭',
  '圓',
  '朝',
  '[SEP]'])

In [19]:
from torch.utils.data import Dataset

class Typo_Dataset(Dataset):
    def __init__(self, df, tokenizer):
        self.tokenizer = tokenizer
        self.df = df
        self.max_length = 64
        
    def __len__(self):
        return len(self.df)
    
    def truncate_labels(self, labels):
        # tokenizerと同じように文章の長さから二個を排除する
        truncated_labels = labels[:(self.max_length) - 2]
        
        return truncated_labels
    
    def __getitem__(self, idx):
        pre_text, tags = self.df.loc[idx]
        
        encoding = self.tokenizer(pre_text, max_length=self.max_length, padding="max_length",
                                 truncation=True)
        
        # もし文書は切り捨てるとラベルは切り捨てないなら、長さが異なるからエラーです
        labels = [TAG_LABEL_MAP[t] for t in tags]
        labels = self.truncate_labels(labels)
        
        # tokenizerは[CLS]と[SEP]を追加するからラベルも含めて
        encoding["labels"] = [0] + labels + [0] + [0] * (
            len(encoding.input_ids) - len(labels) - 2)
        
        # Pytorchテンソルへ, return_tensors="pt"は新しい次元をつけるからめんどくさい
        encoding = {k: torch.tensor(v) for k, v in encoding.items()}

        return encoding
    

In [32]:
train_dataset = Typo_Dataset(train_df_tags, tokenizer)
val_dataset = Typo_Dataset(val_df_tags, tokenizer)

len(train_dataset), len(val_dataset), val_dataset[0], val_dataset[0]["input_ids"].shape

(89172,
 22293,
 {'input_ids': tensor([    2, 12916,   897,   828,    17, 12730,  1244,   896, 14120,  1025,
          11671,   896,   977,   977,   977,   896,  5664,   932,   932, 11854,
            829,     3,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0]),
  'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0,

In [33]:
from torch.utils.data import DataLoader

batch_size = 1024
# 128

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# 辞書オブジェクトにまとめる
dataloaders_dict = {"train": train_dataloader, "val": val_loader}


In [34]:
batch = next(iter(val_loader))
print(batch)
print(batch["input_ids"].shape)
print(batch["labels"])

{'input_ids': tensor([[    2, 12916,   897,  ...,     0,     0,     0],
        [    2, 11699,   828,  ..., 11878, 14467,     3],
        [    2, 11834,   896,  ...,     0,     0,     0],
        ...,
        [    2, 11574,   828,  ...,     0,     0,     0],
        [    2, 25573,  1026,  ...,     0,     0,     0],
        [    2, 11156, 16327,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 1, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 

In [35]:
model = BertForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL, num_labels=2)
model.config.id2label= LABEL_TAG_MAP
model.config.label2id = TAG_LABEL_MAP

model.config

Some weights of BertForTokenClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertConfig {
  "_name_or_path": "cl-tohoku/bert-base-japanese-v2",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "$KEEP",
    "1": "$DELETE"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "$DELETE": 1,
    "$KEEP": 0
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertJapaneseTokenizer",
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32768
}

In [36]:
# 訓練モードに設定
model.train()

print('ネットワーク設定完了')
model

ネットワーク設定完了


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

# 損失について

In [37]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
with  torch.set_grad_enabled(False):
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)

    loss = outputs["loss"]
    logits = outputs["logits"]

In [38]:
outputs.keys(), logits.shape, loss, logits[0][0:5], batch["labels"][0][0:5]

(odict_keys(['loss', 'logits']),
 torch.Size([1024, 64, 2]),
 tensor(0.7426, device='cuda:0'),
 tensor([[ 0.4149, -0.0088],
         [ 0.2802, -0.4258],
         [ 0.8826,  0.5264],
         [ 0.7072, -0.4151],
         [ 0.6948,  0.2499]], device='cuda:0'),
 tensor([0, 0, 0, 0, 0], device='cuda:0'))

In [39]:
# 損失はロジットと確率の場合は違う
torch.nn.CrossEntropyLoss()(logits[0][0:5], batch["labels"][0][0:5])

tensor(0.4425, device='cuda:0')

In [40]:
logits[0][0:5].softmax(dim=1), torch.nn.CrossEntropyLoss()(logits[0][0:5].softmax(dim=1), batch["labels"][0][0:5])

(tensor([[0.6044, 0.3956],
         [0.6695, 0.3305],
         [0.5881, 0.4119],
         [0.7544, 0.2456],
         [0.6094, 0.3906]], device='cuda:0'),
 tensor(0.5603, device='cuda:0'))

In [41]:
nn.CrossEntropyLoss()(logits, batch["labels"])

RuntimeError: Expected target size [1024, 2], got [1024, 64]

In [42]:
logits.view(-1, 2).shape, batch["labels"].view(-1).shape, logits.shape, batch["labels"].shape

(torch.Size([65536, 2]),
 torch.Size([65536]),
 torch.Size([1024, 64, 2]),
 torch.Size([1024, 64]))

In [43]:
nn.CrossEntropyLoss()(logits.view(-1, 2), batch["labels"].view(-1))

tensor(0.7426, device='cuda:0')

In [44]:
loss_fct = nn.CrossEntropyLoss()

active_loss = batch["attention_mask"].view(-1) == 1
active_logits = logits.view(-1, 2)
active_labels = torch.where(
    active_loss, batch["labels"].view(-1), torch.tensor(loss_fct.ignore_index).type_as(batch["labels"])
)
active_loss.shape, active_logits.shape, active_labels.shape

(torch.Size([65536]), torch.Size([65536, 2]), torch.Size([65536]))

In [45]:
active_loss, active_labels, active_loss.dtype, active_labels.dtype, batch["labels"].dtype

(tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0'),
 tensor([   0,    0,    0,  ..., -100, -100, -100], device='cuda:0'),
 torch.bool,
 torch.int64,
 torch.int64)

In [46]:
loss_fct(active_logits, active_labels)

tensor(0.7692, device='cuda:0')

In [47]:
weights = torch.tensor([0.2, 1]).to(device)
loss_fct = nn.CrossEntropyLoss(weight=weights)

loss_fct(active_logits, active_labels)

tensor(0.7553, device='cuda:0')

In [48]:
_, preds = torch.max(outputs["logits"], 2)
preds, preds.shape

(tensor([[0, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 1, 1, 0],
         [1, 1, 0,  ..., 1, 1, 1],
         ...,
         [0, 1, 0,  ..., 1, 0, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [0, 1, 0,  ..., 0, 0, 0]], device='cuda:0'),
 torch.Size([1024, 64]))

In [49]:
torch.mean(torch.sum(preds == batch["labels"] , dim=1) / batch["labels"].size(1))

tensor(0.4761, device='cuda:0')

# 学習

In [50]:
# 勾配計算を最後のBertLayerモジュールとClassifierモジュールだけ

# 1.まず全部を、勾配計算Falseにしてしまう
for param in model.parameters():
    param.requires_grad = False

# 2. BertLayerモジュールの最後を勾配計算ありに変更
for param in model.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

# 3. 識別器を勾配計算ありに変更
for param in model.classifier.parameters():
    param.requires_grad = True
    

In [51]:
# 最適化手法の設定

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Bertの元の部分はファインチューニング
optimizer = optim.Adam([
    {'params': model.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4},
])

# 損失関数の設定
weights = torch.tensor([0.35, 1]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)
# nn.LogSoftmax()を計算してからnn.NLLLoss(negative log likelihood loss)を計算

In [52]:
# モデルを学習させる関数を作成
from tqdm import tqdm

def train_one_epoch(model, dataloaders_dict, criterion, optimizer,
                   device, epoch, custom_loss=False):
    
    # モデルのモードを変わる
    for phase in ["train", "val"]:
        if phase == "train":
            model.train()
        else:
            model.eval()
         
        # ミニバッチのサイズ、確率とロスのため
        batch_size = dataloaders_dict[phase].batch_size
        
        epoch_loss = 0.0
        epoch_corrects = 0

        pbar = tqdm(dataloaders_dict[phase])

        for batch in pbar:
            # GPUに移動する
            batch = {k: v.to(device) for k, v in batch.items()}
            
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                if custom_loss:
                    # レベルを入れない
                    outputs = model(batch["input_ids"], batch["attention_mask"],
                                    batch["token_type_ids"])
                    
                    # ロスの計算
                    if batch["attention_mask"] is not None:
                        active_loss = batch["attention_mask"].view(-1) == 1
                        active_logits = outputs["logits"].view(-1, 2)
                        active_labels = torch.where(
                            active_loss, batch["labels"].view(-1), 
                            torch.tensor(criterion.ignore_index).type_as(batch["labels"]))
                        
                        loss = criterion(active_logits, active_labels)
                    
                    else:
                        # マスクがない場合
                        loss = criterion(logits.view(-1, 2), batch["labels"].view(-1))
                
                # 普通のロス
                else:
                    outputs = model(**batch)
                    loss = outputs["loss"]
                    
                _, preds = torch.max(outputs["logits"], 2)
                
                if phase == "train":
                    loss.backward()
                    optimizer.step()
                
                # 損失と正解数の合計を更新
                epoch_loss += loss.item() * batch_size
                # batch acc
                epoch_corrects += torch.mean(torch.sum(preds == batch["labels"] , dim=1) / batch["labels"].size(1))

                pbar.set_description(f"loss: {loss}")
                
    epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
    # epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
    epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase])

    print('Epoch {} | {:^5} |  Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, phase,
                                                                epoch_loss, epoch_acc))

In [53]:
torch.cuda.empty_cache()

In [54]:
num_epochs = 2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = model.to(device)

# ネットワークがある程度固定であれば、高速化させる
torch.backends.cudnn.benchmark = True

for epoch in range(num_epochs):
    train_one_epoch(model, dataloaders_dict, criterion, optimizer,
                    device, epoch, custom_loss=False)


loss: 0.597691535949707:   2%|▏         | 2/88 [00:18<13:08,  9.17s/it] 


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "./token_class_weighted.pth")

# テスト

In [20]:
from transformers.models.bert.modeling_bert import BertForTokenClassification
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
import torch

MODEL_NAME = "cl-tohoku/bert-base-japanese-v2"

tokenizer = BertJapaneseTokenizer.from_pretrained(
    MODEL_NAME,
    do_lower_case=False,
    word_tokenizer_type="mecab",
    subword_tokenizer_type="wordpiece",
    mecab_kwargs={"mecab_dic": "unidic_lite"}
)
TAG_KEEP = '$KEEP'
TAG_DELETE = '$DELETE'

TAG_LABEL_MAP = {
    TAG_KEEP: 0,
    TAG_DELETE: 1,
}
LABEL_TAG_MAP = {
    0: TAG_KEEP,
    1: TAG_DELETE,
}

model = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path=MODEL_NAME, num_labels=2)
model.config.id2label= LABEL_TAG_MAP
model.config.label2id = TAG_LABEL_MAP
model.config

Some weights of BertForTokenClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertConfig {
  "_name_or_path": "cl-tohoku/bert-base-japanese-v2",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "$KEEP",
    "1": "$DELETE"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "$DELETE": 1,
    "$KEEP": 0
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertJapaneseTokenizer",
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32768
}

In [21]:
load_weights = torch.load("./token_class_weighted.pth")

model.load_state_dict(load_weights)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.eval()   # モデルを検証モードに

model.to(device)

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32768, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

## テストデータをロードする

In [23]:
list_pre_text = []
list_tags = []

# change to JWTD_TEST
for text, tags in load_jwtd_tags_data(JWTD_TEST):
    list_pre_text.append(text)
    list_tags.append(tags)
    
len(list_pre_text), len(list_tags), list_pre_text[0:3], list_tags[0:3]

(1118,
 1118,
 ['一時期、の甲本ヒロトと交流があると記されていたが、２０１２年７月２２日、ｔｗｉｔｔｅｒ上での質問に「事実無根」と回答している。',
  'そのため自動車は運転しないが。',
  '随筆の中で自らの息子に対して１２歳で去勢させる計画を綴っていたり（実際に彼女が息子の性器切断を行ったという話はない）、彼に対するフフェラチオ体験談などにも触れたりしており、ルポライターの谷口玲が纏めた少年への性的虐待についての著書の中で批判を受けている。'],
 [['$KEEP',
   '$KEEP',
   '$KEEP',
   '$DELETE',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP'],
  ['$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$DELETE',
   '$KEEP'],
  ['$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',
   '$KEEP',


In [24]:
import pandas as pd

df = pd.DataFrame({'pre_text': list_pre_text, 'tags': list_tags})

# 大きさを確認してお
print(df.shape)

# 順番をシャッフルする
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
df.head()


(1118, 2)


Unnamed: 0,pre_text,tags
0,２００８年４月２６日、第１回福福島競馬５日目第６レースのサラブレット３歳未勝利戦にてデビュー。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
1,航空作戦国家センター（ＣＮＯＡ）が置かれている他、第０５．９４２探知管制センターは南部フラン...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
2,しかし直親も、「遠州錯乱」で小野政直の息子・小野道好の讒言により、主君の今川氏真から松平元康...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
3,イメージは６ｒｄ方式とは逆に、途中のＩＰｖ６空間にＩＰｖ４の信号を流すためのトンネルを設定し...,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."
4,プルートはドアの鍵穴を覗くと、そこには誘拐されたろロニーがいた。,"[$KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KEEP, $KE..."


In [26]:
test_dataset = Typo_Dataset(df, tokenizer)

len(test_dataset), test_dataset[0], test_dataset[0]["input_ids"].shape

(1118,
 {'input_ids': tensor([    2, 11431,  2181,    34,  2812, 11436,  2719,   828,  4036,    31,
           1708,  3933, 12745, 12385,    35,  2719,  3803,  4036,    36, 11995,
            896, 13423, 14731, 11160,    33,  3099,  2826, 11763,  2470, 11323,
          11617,   829,     3,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0]),
  'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0,

In [27]:
from torch.utils.data import DataLoader

batch_size = 128

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [28]:
batch = next(iter(test_dataloader))
print(batch)
print(batch["input_ids"].shape)
print(batch["labels"])

{'input_ids': tensor([[    2, 11431,  2181,  ...,     0,     0,     0],
        [    2, 11601, 12051,  ...,     0,     0,     0],
        [    2, 11258,  3805,  ..., 11158,   867,     3],
        ...,
        [    2, 31391,  2713,  ...,  1015, 19907,     3],
        [    2, 11145,  2293,  ...,     0,     0,     0],
        [    2, 12809,  1025,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 

In [29]:
# 損失関数の設定
weights = torch.tensor([0.35, 1]).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=weights)

from tqdm import tqdm

# ミニバッチのサイズ
batch_size = test_dataloader.batch_size

epoch_loss = 0.0
epoch_corrects = 0

pbar = tqdm(test_dataloader)
custom_loss = True
model.eval()

for batch in pbar:
    # GPUに移動する
    batch = {k: v.to(device) for k, v in batch.items()}

    #   optimizer.zero_grad()
    with torch.set_grad_enabled(False):
        if custom_loss:
            # to define a custom loss for the weights
            # input_ids=batch["input_ids"]
            # attention_mask=batch["attention_mask"]
            # token_type_ids=batch["token_type_ids"]

            # labels = batch["labels"]

            outputs = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])

            # print(outputs.keys(), outputs["logits"].shape, batch["labels"].shape)
            #dict_keys(['logits']) torch.Size([128, 128, 2]) torch.Size([128, 128])
            if batch["attention_mask"] is not None:
                active_loss = batch["attention_mask"].view(-1) == 1
                active_logits = outputs["logits"].view(-1, 2)
                active_labels = torch.where(
                    active_loss, batch["labels"].view(-1), torch.tensor(criterion.ignore_index).type_as(batch["labels"])
                )
                # print(active_logits.shape, active_labels.shape)
                loss = criterion(active_logits, active_labels)
            else:
                loss = criterion(logits.view(-1, 2), batch["labels"].view(-1))

            # loss = criterion(outputs["logits"], batch["labels"])

        else:
            outputs = model(**batch)

            loss = outputs["loss"]

        _, preds = torch.max(outputs["logits"], 2)


        # 損失と正解数の合計を更新
        epoch_loss += loss.item() * batch_size
        # epoch_corrects += torch.sum(preds == batch["labels"])
        # batch acc
        epoch_corrects += torch.mean(torch.sum(preds == batch["labels"] , dim=1) / batch["labels"].size(1))

        pbar.set_description(f"loss: {loss}")

epoch_loss = epoch_loss / len(test_dataloader.dataset)
# epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
epoch_acc = epoch_corrects.double() / len(test_dataloader)

print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

loss: 0.24980299174785614: 100%|██████████| 9/9 [00:07<00:00,  1.25it/s]

Loss: 0.2239 Acc: 0.9137





In [32]:
def prediction_remove(model, sentence):
    model.eval()
    with torch.no_grad():
        # 今回はreturn_tensors="pt"を使って、バッチのように扱うことが出来ます
        encoding = tokenizer(sentence, max_length=64, padding="max_length",
                            truncation=True, return_tensors="pt")
        
        encoding = {k: v.to(device) for k, v in encoding.items()}
        
        outputs = model(**encoding)
        logits = outputs["logits"]
        preds = logits[0].argmax(-1).detach().cpu().numpy().tolist()
        
        # print(preds[1:-1])は[CLS]と[SEP]のトーケンを排除する
        return preds[1:-1]


In [35]:
for text in [
    "これははあんまり面白くないテスト",
    "直ぐにエラーがが見つけるだろう",
    "今日はちょっと雲雲が多い。",
    "今日はちょっと雲光が多い。",
    "最近のパソコンは結構性能能がいいじゃない？",
    "ああんまりいいことを書けないけど",
    "これは問題ないと思う。"
      ]:
    tokenized_text = tokenizer.tokenize(text)
    predictions = prediction_remove(model, text)
    
    for token, pred in zip(tokenized_text, predictions):
        if pred == 1:
            # 赤いいるで表示
            print(f" \033[31m{token}\033[0m ", end="")
            
        else:
            print(f"{token}", end="")
    print()

これは [31mは[0m あん##まり面白##くないテスト
直##ぐにエラー [31mが[0m  [31mが[0m 見つけるだろう
今日はちょっと雲雲が多い。
今日はちょっと雲光が多い。
最近のパソコンは結##構性能 [31m能[0m がいいじゃない?
ああん##まりいいことを書##けないけど
これは問題ないと思う。
