In [None]:
!pip install transformers==4.20.1
!pip install seqeval==1.2.2

In [None]:
# Mount Google Drive (Colab only); the project folder referenced below lives on it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Path to the project's src directory; adjust to your own environment.
# Raw string: the backslash escapes the space for the `cd` magic, and "\ " is not
# a valid Python escape sequence in a normal string literal.
base_folder = r"drive/MyDrive/Colab\ Notebooks/cpt-hanrei-1st-refactor/src"

In [None]:
cd {base_folder}

In [None]:
import random
from sklearn.model_selection import KFold
from utils import load_pickle
from seqeval.metrics import classification_report,f1_score
from collections import Counter
import pandas as pd
import torch
import numpy as np
from utils import load_pickle, save
import os
pd.set_option("display.max_rows", 10000)


# Per-model seed used to shuffle the file ids before the CV split
# (consumed by ModelResult.get_file_ids via random.Random(seed).shuffle).
seed_map = dict([
    ("cl-wom", 71),
    ("cl-charwom", 271),
    ("cl", 4306),
    ("cl-char", 1545),
    ("NICT-100k", 8155),
    ("NICT-32k", 1250),
])


def correct_idx(input_tags, tokens, max_drop_len=2):
    """Post-process a predicted IOB2 tag sequence with two heuristics.

    1. Spans starting with "B-" whose concatenated token string is at most
       ``max_drop_len`` characters long are reset to "O".
    2. Spans that do not start with "B-" (e.g. a dangling "I-X") are reset to "O".

    Args:
        input_tags: Tag strings aligned with ``tokens``; the list is not modified.
        tokens: Token strings; ``tokens[i]`` is the surface form tagged by ``input_tags[i]``.
        max_drop_len: Longest surface string (in characters) that is still treated as
            too short and discarded. Defaults to 2, the previously hard-coded value.

    Returns:
        A new list of tags.
    """
    tags = list(input_tags)
    for span in get_seq_idx(tags, True, return_type="idx", flatten=False):
        surface = "".join(tokens[i] for i in span)
        if len(surface) <= max_drop_len:
            for i in span:
                tags[i] = "O"
    # Computed on the already-updated tags, after short entities were removed.
    for i in get_seq_idx(tags, False, return_type="idx", flatten=True):
        tags[i] = "O"
    return tags


def is_valid_seq(seq):
    """Return True when a span's first tag is a "B-" tag.

    ``seq`` must be a non-empty list of tag strings, e.g. ["B-PERSON", "I-PERSON"].
    """
    return seq[0].split("-")[0] == "B"


def get_seq_idx(tags, valid_flag=True, return_type="idx", flatten=True):
    """Find entity spans in a tag sequence and return those of the requested validity.

    A span is a maximal run of consecutive non-"O" string tags sharing the same
    category (the part after "-"). A "B-" tag directly following a tag of the same
    category does NOT start a new span, so adjacent same-category entities are merged.
    Non-string entries (e.g. NaN) are skipped without closing the current span.

    Args:
        tags: Sequence of tags such as "B-PERSON", "I-PERSON", "O".
        valid_flag: If True, keep spans starting with "B-"; if False, keep the others.
        return_type: "idx" for token positions, "tag" for the tag strings.
        flatten: If True, return one flat list instead of one list per span.

    Returns:
        A list of lists (one per span), or a flat list when ``flatten`` is True.

    Raises:
        ValueError: if ``return_type`` is neither "idx" nor "tag".
    """
    if return_type not in ("idx", "tag"):
        raise ValueError(f"return_type must be 'idx' or 'tag', got {return_type!r}")

    spans = []  # (positions, tags) for every span, in order of appearance
    cur_idx, cur_tags = [], []
    for i, tag in enumerate(tags):
        if not isinstance(tag, str):
            continue
        if tag == "O":
            if cur_tags:
                spans.append((cur_idx, cur_tags))
                cur_idx, cur_tags = [], []
            continue
        category = tag.split("-")[1]
        if cur_tags and category != cur_tags[0].split("-")[1]:
            # Category switch: close the current span; this tag opens a new one.
            spans.append((cur_idx, cur_tags))
            cur_idx, cur_tags = [], []
        cur_idx.append(i)
        cur_tags.append(tag)
    if cur_tags:
        # A span running to the end of the sequence must be emitted as well.
        spans.append((cur_idx, cur_tags))

    kept = [(idx, tg) for idx, tg in spans if is_valid_seq(tg) == valid_flag]
    if return_type == "idx":
        result = [idx for idx, _ in kept]
    else:
        result = [tg for _, tg in kept]
    if flatten:
        return [i for ls in result for i in ls]
    return result


def get_voted_result(result_list):
    """Majority-vote position by position across several sequences.

    The output length is that of ``result_list[0]``; every other sequence must be
    at least that long. Ties go to the candidate seen first, i.e. the earliest
    sequence in ``result_list`` holding one of the tied values.
    """
    voted = []
    for pos in range(len(result_list[0])):
        candidates = [seq[pos] for seq in result_list]
        winner, _count = Counter(candidates).most_common(1)[0]
        voted.append(winner)
    return voted

# `seed_map` is defined once near the top of this cell; a verbatim second copy
# is intentionally not repeated here so the two cannot drift apart.

# IOB2 tag -> class id mapping for the five entity categories plus "O".
# "mask" -> -100: presumably the ignore index for non-scored positions during
# training (TODO confirm against the training code).
tag2id = {'B-LOCATION': 7,
          'B-MISC': 9,
          'B-ORGFACPOS': 5,
          'B-PERSON': 3,
          'B-TIMEX': 1,
          'I-LOCATION': 8,
          'I-MISC': 10,
          'I-ORGFACPOS': 6,
          'I-PERSON': 4,
          'I-TIMEX': 2,
          'O': 0,
          'mask': -100}

In [None]:
# Token-level training data; columns used below: file_id, token, tag.
GINZA_DATA_PATH = "data/preprocessed/ginza_train_data.csv"
# Order matters: per-model logits are concatenated in this order for stacking.
MODEL_NAME_LIST = ["cl-charwom", "cl-wom", "cl", "NICT-100k", "NICT-32k", "cl-char"]
GINZA_TRAIN_DF = pd.read_csv(GINZA_DATA_PATH)
FILE_IDS = GINZA_TRAIN_DF["file_id"].unique()
# Test tokens after dropping rows with any missing value.
TEST_TOKENS = pd.read_csv("data/input/test_token.csv").dropna()["token"].tolist()


class Result:
    """Prediction output of one (trial, model, fold, split) run, loaded from a pickle."""

    def __init__(self, trial_name, model_name, fold, data_type):
        # NOTE: the attribute is spelled "trail_name" (sic); kept for compatibility.
        self.trail_name = trial_name
        self.model_name = model_name
        self.fold = fold
        self.data_type = data_type

    @staticmethod
    def _extract_f1_from_report(report):
        """Return the micro-averaged F1 from a seqeval text classification report.

        The row reads "micro avg <precision> <recall> <f1> <support>", so the F1
        value is the 4th whitespace-separated token after "micro".
        """
        words = report.split()
        micro_pos = words.index("micro")
        return float(words[micro_pos + 4])

    def retrieve_results(self, path=None):
        """Load (tags, labels, logits, report) from ``path`` and set tags/labels/logits/f1.

        When ``path`` is falsy, the default location under save/train/ for this
        trial, model, split and fold is used.
        """
        source = path if path else f"save/train/{self.trail_name}_{self.model_name}/output/{self.data_type}_pred_{self.fold}.pk"
        loaded = load_pickle(source)
        self.tags, self.labels, self.logits, report = loaded
        self.f1 = self._extract_f1_from_report(report)


class ModelResult:
    """Predictions of one model for one trial, merged over all CV folds.

    For ``data_type == "valid"`` the fold pickles hold out-of-fold predictions:
    they are concatenated, reordered to ``GINZA_TRAIN_DF`` row order, and labels
    plus seqeval F1 (raw and after ``correct_idx``) are computed.
    For any other ``data_type`` (e.g. "test") every fold predicts the whole split:
    tags are majority-voted across folds and logits are stacked along a new first
    axis in fold order.
    """

    # Number of CV folds the models were trained with.
    N_FOLDS = 5

    def __init__(self, trail_name, model_name, data_type):
        self.trial_name = trail_name
        self.model_name = model_name
        self.data_type = data_type
        self.tokens = self.get_tokens()
        self.fold_results = self.retrieve_results_for_each_fold()
        self.index = self.get_index()
        self.tags = self.get_tags()
        self.tags_corrected = correct_idx(self.tags, self.tokens)
        self.logits = self.get_logits()

        if data_type == "valid":
            self.labels = self.get_labels()
            self.f1 = f1_score([self.labels], [self.tags])
            self.f1_corrected = f1_score([self.labels], [self.tags_corrected])

    def get_tokens(self):
        """Tokens of the split: training tokens for "valid", TEST_TOKENS otherwise."""
        if self.data_type == "valid":
            return GINZA_TRAIN_DF.token.tolist()
        return TEST_TOKENS

    def retrieve_results_for_each_fold(self, folds=N_FOLDS):
        """Load a ``Result`` for every fold 0..folds-1."""
        fold_results = []
        for fold in range(folds):
            result = Result(self.trial_name,
                            self.model_name,
                            fold,
                            self.data_type)
            result.retrieve_results()
            fold_results.append(result)
        return fold_results

    def get_file_ids(self):
        """Return a copy of FILE_IDS shuffled with this model's seed from seed_map."""
        file_ids = FILE_IDS.copy()
        seed = seed_map[self.model_name]
        random.Random(seed).shuffle(file_ids)
        return file_ids

    def get_index(self):
        """Return GINZA_TRAIN_DF row labels in out-of-fold concatenation order.

        Validation files of fold 0 come first, then fold 1, and so on.
        NOTE(review): assumes training split the shuffled file ids with an
        unshuffled KFold(n_splits=N_FOLDS) exactly like this - confirm.
        """
        file_ids = self.get_file_ids()
        df_list = []
        for _, valid_idx in KFold(n_splits=self.N_FOLDS).split(file_ids):
            for file_id in file_ids[valid_idx]:
                df_list.append(GINZA_TRAIN_DF[GINZA_TRAIN_DF.file_id == file_id])
        return pd.concat(df_list).index

    def get_from_all_fold(self, attr):
        """Collect ``attr`` from every fold result.

        For "valid" the per-fold sequences are concatenated (extend); otherwise
        one item per fold is returned (append).
        """
        item_list = []
        for result in self.fold_results:
            item = getattr(result, attr)
            if self.data_type == "valid":
                item_list.extend(item)
            else:
                item_list.append(item)
        return item_list

    def get_tags(self):
        """Valid: out-of-fold tags in row order. Otherwise: tags voted across folds."""
        tags = self.get_from_all_fold("tags")
        if self.data_type == "valid":
            return self.sort_according_idx(tags)
        return get_voted_result(tags)

    def get_labels(self):
        """Gold labels of the valid split, in GINZA_TRAIN_DF row order."""
        labels = self.get_from_all_fold("labels")
        return self.sort_according_idx(labels)

    def get_logits(self):
        """Stack the logits into one tensor.

        Valid: per-token logits reordered to row order. Otherwise: one logits
        entry per fold, kept in fold order (``self.index`` describes training
        rows and has nothing to do with the fold list, so it must not reorder it).
        """
        logits = self.get_from_all_fold("logits")
        if self.data_type == "valid":
            logits = self.sort_according_idx(logits)
        return torch.stack(logits, 0)

    def sort_according_idx(self, X):
        """Reorder X (aligned with ``self.index``) by ascending index value."""
        # Sort on the index only, so payloads (e.g. tensors) are never compared.
        pairs = sorted(zip(self.index, X), key=lambda pair: pair[0])
        return [x for _, x in pairs]


class TrailResult:
    """Results of every model in MODEL_NAME_LIST for one trial, plus voted ensembles.

    Typical use on the valid split: construct, then ``generate_voted()``,
    ``generate_report()`` and ``print_all()`` (the latter two run their
    prerequisites automatically if needed).
    """

    def __init__(self, trail_name, data_type):
        self.trail_name = trail_name
        self.data_type = data_type
        self.tokens = self.get_tokens()

        print("loding result files")
        self.model_result_dict = {
            model_name: ModelResult(trail_name, model_name, data_type)
            for model_name in MODEL_NAME_LIST
        }
        self.tag_list = [
            model_result.tags for model_result in self.model_result_dict.values()
        ]
        self.tag_list_corrected = [
            model_result.tags_corrected for model_result in self.model_result_dict.values()
        ]

    def generate_voted_result(self, tag_list, correct=False):
        """Majority-vote ``tag_list`` position-wise; optionally apply ``correct_idx``."""
        voted = get_voted_result(tag_list)
        if correct:
            voted = correct_idx(voted, self.tokens)
        return voted

    def get_tokens(self):
        """Tokens of the split: training tokens for "valid", TEST_TOKENS otherwise."""
        if self.data_type == "valid":
            return GINZA_TRAIN_DF.token.tolist()
        return TEST_TOKENS

    def generate_voted(self):
        """Compute three ensembles.

        - ``voted``: vote over the raw tags.
        - ``voted_before_corrected``: vote over the raw tags, then correct.
        - ``voted_after_corrected``: vote over the per-model corrected tags, then correct.
        """
        print("generating voted result")
        self.voted = self.generate_voted_result(self.tag_list)
        self.voted_before_corrected = self.generate_voted_result(self.tag_list, correct=True)
        self.voted_after_corrected = self.generate_voted_result(self.tag_list_corrected, correct=True)

    def generate_report(self):
        """Build seqeval classification reports for the three voted ensembles.

        Raises:
            ValueError: if ``data_type`` is not "valid"; gold tags come from
                GINZA_TRAIN_DF and only match the valid split.
        """
        if self.data_type != "valid":
            raise ValueError(
                "generate_report() needs gold labels, which are only available "
                f"for data_type='valid' (got {self.data_type!r})"
            )
        if not hasattr(self, "voted"):
            self.generate_voted()
        label_list = GINZA_TRAIN_DF.tag.tolist()
        print("generating report")
        self.f1_report_orig = \
            classification_report([label_list], [self.voted], digits=4)
        self.f1_report_correct = \
            classification_report([label_list], [self.voted_before_corrected], digits=4)
        self.f1_report_vote_after_correct = \
            classification_report([label_list], [self.voted_after_corrected], digits=4)

    def print_all(self):
        """Print per-model F1 (raw / corrected) and the three ensemble reports."""
        if not hasattr(self, "f1_report_vote_after_correct"):
            self.generate_report()
        print("#"*50)
        print("f1_orig".ljust(50, '-'))
        for model, result in self.model_result_dict.items():
            print(model, round(result.f1, 4))
        print("f1_dict_corrected".ljust(50, '-'))
        for model, result in self.model_result_dict.items():
            print(model, round(result.f1_corrected, 4))
        print("f1_report_orig".ljust(50, '-'))
        print(self.f1_report_orig)
        print("f1_report_correct".ljust(50, '-'))
        print(self.f1_report_correct)
        print("f1_report_vote_after_correct".ljust(50, '-'))
        print(self.f1_report_vote_after_correct)
        print("#" * 50)

In [None]:
# Load every model's per-fold validation predictions for this trial; ModelResult
# reorders them to GINZA_TRAIN_DF row order.
result = TrailResult("seed_data_for_each_model", "valid")

In [None]:
# Out-of-fold logits of all models, concatenated along the last dimension in
# MODEL_NAME_LIST order.
logits_list = [
    result.model_result_dict[model_name].logits
    for model_name in MODEL_NAME_LIST
]
logits = torch.cat(logits_list, dim=-1)
# Gold labels (row order), taken from the first model; presumably identical for all models.
first_model = MODEL_NAME_LIST[0]
labels = pd.Series(result.model_result_dict[first_model].labels)

In [None]:
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm

# Build the 2nd-stage (stacking) datasets from the concatenated out-of-fold logits.
# Documents are split into 5 folds after shuffling with seed 4306 (the value of
# seed_map["cl"]); each fold's train and valid parts are saved as a list of
# per-document records {"logits", "tokens", "labels"}.


def _build_stacking_records(doc_ids):
    """Return one {"logits", "tokens", "labels"} record per document in ``doc_ids``.

    Reads the module-level ``logits`` tensor and ``labels`` Series by
    GINZA_TRAIN_DF row label.
    NOTE(review): ``labels[row_index]`` is a label-based Series lookup; it lines up
    with GINZA_TRAIN_DF only if that frame has a default RangeIndex - confirm.
    """
    records = []
    for doc_id in doc_ids:
        doc_df = GINZA_TRAIN_DF[GINZA_TRAIN_DF.file_id == doc_id]
        row_index = doc_df.index.tolist()
        records.append({
            "logits": logits[row_index],
            "tokens": doc_df.token.tolist(),
            "labels": [tag2id[tag] for tag in labels[row_index]],
        })
    return records


file_ids = GINZA_TRAIN_DF.file_id.unique()
seed = 4306
random.Random(seed).shuffle(file_ids)

fold_splits = KFold(n_splits=5).split(file_ids)
for fold, (train_idx, valid_idx) in enumerate(tqdm(fold_splits, total=5, desc="stacking folds")):
    for split_name, split_idx in (("train", train_idx), ("valid", valid_idx)):
        path = f"data/preprocessed/{split_name}_stacking_data_seed_{seed}_fold_{fold}.pk"
        save(_build_stacking_records(file_ids[split_idx]), path)

In [None]:
# Test tokens with incomplete rows dropped and a fresh 0..N-1 index, so row positions
# match TEST_TOKENS (read with the same dropna); presumably also the test logits.
test_df = (
    pd.read_csv("data/input/test_token.csv")
    .dropna()
    .reset_index(drop=True)
)
file_ids = test_df["file_id"].unique()

In [None]:
# Reload the trial for the test split (rebinds `result`): tags are voted across
# folds, and each fold's test logits remain available via fold_results.
result = TrailResult("seed_data_for_each_model", "test")

In [None]:
# Fold-averaged test logits per model (plain sum / count, in fold order), then
# concatenated across models along the last dimension in MODEL_NAME_LIST order.
test_logits = []
for model_name in MODEL_NAME_LIST:
    per_fold = [
        result.model_result_dict[model_name].fold_results[k].logits for k in range(5)
    ]
    test_logits.append(sum(per_fold) / len(per_fold))
logits = torch.cat(test_logits, dim=-1)

In [None]:
# One record per test document: its rows of the stacked logits plus its tokens.
test_data_list = []
for doc_id in file_ids:
    doc_df = test_df.loc[test_df["file_id"] == doc_id]
    row_positions = doc_df.index.tolist()
    test_data_list.append(
        {"logits": logits[row_positions], "tokens": doc_df["token"].tolist()}
    )

In [None]:
# Persist the fold-averaged test stacking records.
path = "data/preprocessed/test_stacking_data_new_aug.pk"
save(test_data_list, path)

In [None]:
# Per-fold (not fold-averaged) test stacking data: for each fold, concatenate every
# model's test logits from that fold and save one record per test document.

# Row positions and tokens of each test document do not depend on the fold, so
# compute them once instead of re-filtering test_df inside the fold loop.
test_doc_rows = []
for doc_id in file_ids:
    doc_df = test_df[test_df.file_id == doc_id]
    test_doc_rows.append((doc_df.index.tolist(), doc_df.token.tolist()))

for fold in range(5):
    fold_model_logits = [
        result.model_result_dict[model_name].fold_results[fold].logits
        for model_name in MODEL_NAME_LIST
    ]
    logits = torch.cat(fold_model_logits, dim=-1)
    test_data_list = [
        {"logits": logits[rows], "tokens": list(tokens)}
        for rows, tokens in test_doc_rows
    ]
    path = f"data/preprocessed/test_stacking_data_{fold}_new_aug.pk"
    save(test_data_list, path)