## TSVファイルを読み込み、BERTサーバー経由で文ベクトル（Embedding）に変換する

後で使うモデルは文単位の入力を扱うため、事前に文ごとのEmbeddingを計算しておく。

In [3]:
# Client used to turn sentences into vectors with a remote BERT server.
from bert_serving.client import BertClient

server_ip = '172.16.16.171'  # IP address of the BERT server
bc_jp = BertClient(
    ip=server_ip,
    port=5555,
    port_out=5556,
    check_length=False,
    check_version=False,
)

import sentencepiece as spm
import numpy as np
import torch
import re

# SentencePiece tokenizer, loaded once from the local model file and shared
# by every call to parse().
s = spm.SentencePieceProcessor()
s.Load("./wiki-ja.model")

def parse(text):
    """Lower-case *text* and return its SentencePiece pieces as a list."""
    lowered = text.lower()
    pieces = s.EncodeAsPieces(lowered)
    return pieces

#tensorを返す関数を定義
def sentence2vec(text):
    sentence_lists = []
    sentence_lists = re.split('[。!]',text) #「。」と「！」で文を分割
    if len(sentence_lists) > 1:
        del sentence_lists[-1] #一番最後はNoneになるので消しておく
    
    parsed_texts = list(map(parse,sentence_lists))
    parsed_texts = [txt for txt in parsed_texts if (txt != [] and txt != '')]
    array = bc_jp.encode(parsed_texts,is_tokenized = True)
    return array.tolist() #すべての文の分散表現を並べたリストを返す。

先ほど作成したTSVを読み込み、各文書を文ベクトルのリストに変換する関数を定義する。

In [16]:
import pandas as pd
from tqdm import tqdm

def doc2tensor(text, random_state=None):
    """Load one livedoor TSV split, shuffle its rows, and embed every document.

    Args:
        text: Which split to load: ``'train'``, ``'val'`` or ``'test'``.
            The file read is ``./<split>_livedoor.tsv``.
        random_state: Seed forwarded to ``DataFrame.sample`` for the row
            shuffle. The default ``None`` keeps the previous unseeded
            (non-reproducible) shuffle.

    Returns:
        A DataFrame with columns ``main_doc`` (the list of sentence vectors
        returned by :func:`sentence2vec` for each document) and ``label``
        (copied from the TSV). Rows are in shuffled order with a fresh
        0..n-1 index.

    Raises:
        ValueError: If *text* is not one of the known split names.
    """
    filepaths = {
        'train': './train_livedoor.tsv',
        'val': './val_livedoor.tsv',
        'test': './test_livedoor.tsv',
    }
    if text not in filepaths:
        # Previously an unknown name fell through to pd.read_csv(''), which
        # failed with an unrelated-looking error.
        raise ValueError(
            "unknown split {!r}; expected one of {}".format(text, sorted(filepaths))
        )

    origin_df = pd.read_csv(filepaths[text], sep='\t')
    origin_df = origin_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Collect the embeddings in a list and build the frame once, instead of
    # assigning row by row into a pre-allocated object DataFrame.
    main_docs = [sentence2vec(doc) for doc in tqdm(origin_df['main_doc'])]
    return pd.DataFrame({'main_doc': main_docs, 'label': origin_df['label'].tolist()})

train・val・testのそれぞれについてEmbeddingを計算し、pickleとして保存する。

In [18]:
import pickle

# Embed each split and pickle the resulting DataFrame to
# ./embedded_<split>_livedoor.pickle. After the loop, `df` holds the last
# split processed ('test').
tmp = ['train', 'val', 'test']

for kind in tmp:
    # Compute before opening the output file: opening with 'wb' creates or
    # truncates the file immediately, so a failure inside doc2tensor (e.g. the
    # BERT server being unreachable) used to leave an empty pickle behind.
    df = doc2tensor(kind)
    with open('./embedded_{}_livedoor.pickle'.format(kind), 'wb') as f:
        pickle.dump(df, f)


  0%|          | 0/5000 [00:00<?, ?it/s][A
  0%|          | 1/5000 [00:00<27:52,  2.99it/s][A
  0%|          | 2/5000 [00:00<27:07,  3.07it/s][A
  0%|          | 3/5000 [00:01<36:32,  2.28it/s][A
  0%|          | 4/5000 [00:02<52:28,  1.59it/s][A
  0%|          | 5/5000 [00:02<45:34,  1.83it/s][A
  0%|          | 6/5000 [00:03<46:05,  1.81it/s][A
  0%|          | 7/5000 [00:03<38:47,  2.14it/s][A
  0%|          | 8/5000 [00:04<48:44,  1.71it/s][A
  0%|          | 9/5000 [00:05<52:20,  1.59it/s][A
  0%|          | 10/5000 [00:05<47:15,  1.76it/s][A
  0%|          | 11/5000 [00:06<47:53,  1.74it/s][A
  0%|          | 12/5000 [00:07<55:19,  1.50it/s][A
  0%|          | 13/5000 [00:07<48:06,  1.73it/s][A
  0%|          | 14/5000 [00:08<52:19,  1.59it/s][A
  0%|          | 15/5000 [00:08<51:14,  1.62it/s][A
  0%|          | 16/5000 [00:09<46:49,  1.77it/s][A
  0%|          | 17/5000 [00:10<1:13:13,  1.13it/s][A
  0%|          | 18/5000 [00:12<1:26:09,  1.04s/it][A
  0%| 

中身は次のとおり（ループで最後に処理したtestデータ）。

In [19]:
# Preview the first five rows of the last split processed by the loop above.
df.head(n=5)

Unnamed: 0,main_doc,label
0,"[[0.22710075974464417, -0.30410313606262207, 0...",8
1,"[[0.6103750467300415, -0.5876586437225342, -0....",2
2,"[[-0.23334041237831116, 0.1732082962989807, 0....",6
3,"[[0.21237729489803314, -0.5661413669586182, 0....",1
4,"[[0.044950246810913086, 0.3979344964027405, 0....",7
