In [None]:
!pip install spacy==3.2.4
!pip install ginza==5.1.0
!pip install ja-ginza==5.1.0
!pip install flair==0.11.3

ginzaをインストールした後、ランタイム（カーネル）を再起動してから以降のセルを実行してください。

In [None]:
# Make Google Drive (and the project files stored in it) available under
# /content/drive.
DRIVE_MOUNT_POINT = '/content/drive'

from google.colab import drive
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# 自身の環境のパスを指定
base_folder = "drive/MyDrive/Colab\ Notebooks/cpt-hanrei-1st-refactor/src"

In [None]:
cd {base_folder}

In [None]:
# Extract the downloaded train.zip into data/input/
# (-o: overwrite existing files without prompting, -q: quiet).
!unzip -oq data/input/train.zip -d data/input/

In [None]:
import subprocess
from glob import glob
import pickle
import pandas as pd
import spacy
from tqdm.notebook import tqdm
import random
import numpy as np
# Fix the seed of Python's global `random` generator for reproducible runs.
random.seed(a=123)


def save(obj, path):
    """Pickle ``obj`` into the file at ``path``, overwriting it if it exists.

    Args:
        obj: Any picklable object.
        path: Destination file path.
    """
    with open(path, mode='wb') as out_file:
        pickle.dump(obj, out_file)


def load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    Only load files you produced yourself (e.g. with ``save``): unpickling
    untrusted data can execute arbitrary code.
    """
    with open(path, mode='rb') as in_file:
        restored = pickle.load(in_file)
    return restored

def read_json_fold(path_list):
    """Read JSON Lines files and stack their records into one DataFrame.

    Args:
        path_list: Iterable of paths to JSON Lines files (one JSON record per
            line). Files are read in iteration order, with a progress bar.

    Returns:
        A DataFrame holding the rows of all files in order, with a fresh
        0..n-1 index (each file would otherwise restart its index at 0).

    Raises:
        ValueError: If ``path_list`` yields no paths (e.g. a glob matched
            nothing because the archive was not extracted).
    """
    frames = [
        pd.read_json(path, orient='records', lines=True)
        for path in tqdm(path_list)
    ]
    if not frames:
        raise ValueError("read_json_fold: no input files were given (path_list is empty)")
    return pd.concat(frames, ignore_index=True)

def preprocess(input_df, label=True):
    """Flatten the metadata, add a stratification key and keep labelled docs.

    Args:
        input_df: Raw records (e.g. from ``read_json_fold``) with a ``meta``
            column of dicts holding ``filename`` and ``category``, plus
            ``annotation_approver`` and ``labels`` columns. Not modified.
        label: Currently unused; kept so existing calls keep working.

    Returns:
        A new DataFrame, sorted by ``file_id`` and re-indexed 0..n-1, without
        the ``meta``/``annotation_approver`` columns and without rows whose
        ``labels`` list is empty. Added columns:

        * ``file_id``: int parsed from ``meta['filename']``.
        * ``category``: copied from ``meta['category']``.
        * ``stratify``: ``category``, with rare categories merged into
          ``'その他'`` ("other").
    """
    suffix = '_hanrei.txt'

    def _file_id(meta):
        name = meta['filename']
        # Remove the literal suffix. str.rstrip('_hanrei.txt') would instead
        # strip any trailing characters from that *set* of letters.
        if name.endswith(suffix):
            name = name[:-len(suffix)]
        # Skip the first character, presumably a one-letter prefix before the
        # numeric id — TODO confirm against the dataset's file naming.
        return int(name[1:])

    # Used to split train/val by court category. The low-count
    # '労働事件裁判例' (labour cases) and '高裁判例' (high court) categories
    # are merged into 'その他' (other).
    rare_categories = ['労働事件裁判例', '高裁判例']

    df = input_df.copy()
    df['file_id'] = df['meta'].apply(_file_id)
    df['category'] = df['meta'].apply(lambda meta: meta['category'])
    df['stratify'] = df['category'].apply(
        lambda category: 'その他' if category in rare_categories else category
    )
    df = df.drop(columns=['meta', 'annotation_approver']).sort_values('file_id')
    df = df[df['labels'].apply(len) != 0]
    return df.reset_index(drop=True)


def get_token_idx(tokens, txt):
    """Locate each token in ``txt`` and return its ``(begin, end)`` character span.

    Tokens are searched for left to right. Each search starts where the
    previous match ended, so repeated strings are matched in order of appearance.

    Args:
        tokens: Token strings in text order.
        txt: The text the tokens were produced from.

    Returns:
        One ``(begin, end)`` tuple per token (``end`` exclusive). A token that
        does not occur at or after the current position gets the empty span
        ``(pos, pos)``, and the search position is left unchanged.
    """
    spans = []
    pos = 0
    for token in tokens:
        # find() with a start offset avoids copying the rest of txt per token.
        begin = txt.find(token, pos)
        if begin == -1:
            # Not found: adding find()'s -1 to pos would produce a span
            # shifted one character before the cursor.
            spans.append((pos, pos))
            continue
        end = begin + len(token)
        spans.append((begin, end))
        pos = end
    return spans


def get_ginza_token_tag(token_idx, tags):
    """Assign BIO tags to the tokens covered by each annotation span.

    Args:
        token_idx: ``(begin, end)`` character span of every token, in token
            order (as produced by ``get_token_idx``).
        tags: Annotation spans ``(span_begin, span_end, label)``.

    Returns:
        A list of ``[token_id, tag]`` pairs. For each span, the first token it
        touches gets ``"B-" + label`` and every further touched token gets
        ``"I-" + label``. Tokens outside all spans are not listed.

    Raises:
        ValueError: If a span touches no token at all. Without this check the
            code would fail with a bare IndexError.
    """
    tag_list = []
    for span_begin, span_end, label in tags:
        # A token is part of the span if it starts inside it, ends inside it,
        # or straddles the span's beginning.
        hits = [
            i
            for i, (first, last) in enumerate(token_idx)
            if span_begin <= first < span_end
            or span_begin < last <= span_end
            or first < span_begin < last
        ]
        if not hits:
            raise ValueError(
                "annotation span ({}, {}, {!r}) does not overlap any token".format(
                    span_begin, span_end, label
                )
            )
        tag_list.append([hits[0], "B-" + label])
        tag_list.extend([i, "I-" + label] for i in hits[1:])
    return tag_list


def get_ginza_token_tag_df(txt, tags, file_id):
    """Tokenize one document with GiNZA and attach a BIO tag to every token.

    Uses the module-level GiNZA pipeline ``nlp``.

    Args:
        txt: The raw document text.
        tags: Annotation spans ``(begin, end, label)`` as character offsets
            into ``txt`` (the caller passes them sorted by ``begin``).
        file_id: Value copied onto every output row.

    Returns:
        A DataFrame with one row per token and these columns:
        ``token_id`` (0-based position), ``tag`` (``"B-…"``/``"I-…"`` or
        ``"O"``), ``token``, ``file_id``, and ``token_idx`` (the token's
        ``(begin, end)`` span in ``txt``).

    Raises:
        ValueError: If two annotation spans tag the same token.
    """
    # GiNZA >= 5.1 cannot accept texts larger than 49149 bytes
    # (https://github.com/megagonlabs/ginza/issues/242), so the text is parsed
    # piece by piece, splitting on "。 ".
    # NOTE(review): a single piece above that limit would still fail.
    pieces = txt.split('。 ')
    last = len(pieces) - 1
    tokens = []
    for n, piece in enumerate(pieces):
        # Every piece except the last was followed by "。 " in txt, so its "。"
        # is restored. Nothing is appended to the last piece, so no "。" that
        # is absent from txt gets tokenized.
        sent = piece + '。' if n < last else piece
        if sent:
            tokens.extend(token.text for token in nlp(sent))

    token_idx = get_token_idx(tokens, txt)
    tag_list = get_ginza_token_tag(token_idx, tags)

    labels = ['O'] * len(tokens)
    for token_id, tag in tag_list:
        if labels[token_id] != 'O':
            raise ValueError(
                "token {} of file {} is tagged by more than one span ({} and {})".format(
                    token_id, file_id, labels[token_id], tag
                )
            )
        labels[token_id] = tag

    return pd.DataFrame({
        "token_id": range(len(tokens)),
        "tag": labels,
        "token": tokens,
        "file_id": file_id,
        "token_idx": token_idx,
    })


def generate_ginza_data(input_df):
    """Build the token/tag table for every document and stack the results.

    Args:
        input_df: DataFrame with ``text``, ``file_id`` and ``labels`` columns
            (e.g. the output of ``preprocess``). Rows are read by position, so
            any index works. Neither the frame nor its label lists are modified.

    Returns:
        The per-document frames from ``get_ginza_token_tag_df``, concatenated
        in row order. Each part keeps its own 0-based index.
    """
    rows = zip(input_df["text"], input_df["file_id"], input_df["labels"])
    token_frames = []
    for txt, file_id, tags in tqdm(rows, total=len(input_df)):
        # Sort a copy: list.sort() would reorder the list objects stored in
        # input_df. Those are shared with any frame copied from it, because
        # DataFrame.copy does not copy Python objects.
        ordered_tags = sorted(tags, key=lambda span: span[0])
        token_frames.append(get_ginza_token_tag_df(txt, ordered_tags, file_id))
    return pd.concat(token_frames)

trainデータの判例文をまとめて読み込み、GiNZAで形態素解析してトークンごとにBIOタグを付与したファイル（data/preprocessed/ginza_train_data.csv）を生成

In [None]:
# GiNZA pipeline; used as the global `nlp` by get_ginza_token_tag_df.
nlp = spacy.load('ja_ginza')
# glob returns files in arbitrary filesystem order; sort for a reproducible
# read order.
train_paths = sorted(glob('data/input/train/*'))

In [None]:
# Build the token-level training table: read all annotation files, clean
# them, then tokenize and BIO-tag every document.
ginza_train_csv = "data/preprocessed/ginza_train_data.csv"

train_df = read_json_fold(train_paths)
preprocessed_train_df = preprocess(train_df)
ginza_train_data = generate_ginza_data(preprocessed_train_df)

# index=False: the DataFrame index is not written as a column.
ginza_train_data.to_csv(ginza_train_csv, index=False)

Flairで各トークンのword embeddingを計算し、トークン→embeddingの辞書をpickleファイル（data/preprocessed/flair_embedding_dict.pk）として保存

In [None]:
from flair.embeddings import FlairEmbeddings, StackedEmbeddings
from flair.data import Sentence
# Forward and backward Japanese Flair embeddings, combined into one
# StackedEmbeddings model.
forward_embedding = FlairEmbeddings("ja-forward")
backward_embedding = FlairEmbeddings("ja-backward")
flair_embedding = StackedEmbeddings([forward_embedding, backward_embedding])

In [None]:
# Vocabulary to embed: every token seen in the train and test tokenizations.
train_tokens = ginza_train_data.token.tolist()
# Drop only missing tokens. A whole-row dropna() would also discard valid
# tokens whose other columns happen to be empty.
test_tokens = pd.read_csv("data/input/test_token.csv")["token"].dropna().tolist()
tokens = train_tokens + test_tokens
# Deduplicate while keeping first-seen order. Iteration order of a set of str
# depends on the per-process hash seed, so it would change between runs.
tokens = list(dict.fromkeys(tokens))

In [None]:
def get_word_embedding(sent):
    """Return the mean of the ``.embedding`` of every token in ``sent``.

    ``sent`` is iterated, e.g. a flair ``Sentence``. Raises ZeroDivisionError
    if it contains no tokens.
    """
    vectors = [tok.embedding for tok in sent]
    total = sum(vectors)
    return total / len(vectors)

# Embed every token as a one-token flair Sentence and average its token
# vectors. embedding_dict caches the result per token string, so a repeated
# token is embedded only once.
# NOTE(review): the stored tensors may live on flair's device (GPU when
# available); unpickling them may then require that device — consider
# .cpu() before saving. TODO confirm with the loading code.
embedding_list = []
embedding_dict = {}
for token in tqdm(tokens):
    if token not in embedding_dict:
        sentence = Sentence(token)
        flair_embedding.embed(sentence)
        token_embedding = get_word_embedding(sentence)
        embedding_dict[token] = token_embedding
    embedding_list.append(embedding_dict[token])

In [None]:
# Persist the token -> embedding mapping.
save(obj=embedding_dict, path="data/preprocessed/flair_embedding_dict.pk")