In [None]:
# Shell escape: print the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
!pip install transformers==4.20.1
!pip install seqeval==1.2.2
!pip install wandb==0.12.21
!pip install pytorch-crf==0.7.2
!pip install torchcontrib==0.0.2

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set this to the project's `src` directory in your own environment.
# The backslash before the space is kept on purpose (presumably so the space
# survives the `cd {base_folder}` magic in the next cell). A raw string yields
# exactly the same value without relying on the invalid "\ " escape sequence.
base_folder = r"drive/MyDrive/Colab\ Notebooks/cpt-hanrei-1st-refactor/src"

In [None]:
cd {base_folder}

In [None]:
from transformers import BertForTokenClassification
from torch.optim import AdamW
from utils import save, get_save_dir, save_model, model_path_map
from model import NERModel
from dataloader import create_dataloader
import torch
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm.notebook import tqdm
from seqeval.metrics import f1_score, classification_report
from collections import Counter


def correct_idx(input_tags, tokens):
    """Return a cleaned copy of ``input_tags``; the input list is not modified.

    Two passes are made, both relying on ``get_seq_idx``:

    1. every run reported as valid whose joined token text is at most two
       characters long is reset to ``"O"``;
    2. every index reported as part of an invalid run is reset to ``"O"``.

    Args:
        input_tags: iterable of tag strings, positionally aligned with ``tokens``.
        tokens: sequence of token strings, indexed by tag position.

    Returns:
        A new list of tags with the resets applied.
    """
    cleaned = list(input_tags)

    # Pass 1: valid runs that are too short to be kept.
    for span in get_seq_idx(cleaned, True, return_type="idx", flatten=False):
        text = "".join(tokens[pos] for pos in span)
        if len(text) > 2:
            continue
        for pos in span:
            cleaned[pos] = "O"

    # Pass 2: everything that belongs to an invalid run.
    for pos in get_seq_idx(cleaned, False, return_type="idx", flatten=True):
        cleaned[pos] = "O"

    return cleaned


def is_valid_seq(seq):
    """Return True when the first tag of ``seq`` has the prefix ``"B"``.

    The prefix is the part of the tag before its first ``"-"``. ``seq`` must
    be non-empty; an empty sequence raises ``IndexError``.
    """
    first_prefix = seq[0].split("-")[0]
    return first_prefix == "B"


def get_seq_idx(tags, valid_flag=True, return_type="idx", flatten=True):
    """Collect runs of consecutive non-"O" tags that share one category.

    A run ends at an ``"O"`` tag, at a tag whose category (the part after the
    first ``"-"``) differs from the run's first tag, or at the end of ``tags``.
    A run is kept when ``is_valid_seq(run) == valid_flag``, i.e. when its
    first tag does (``valid_flag=True``) or does not (``valid_flag=False``)
    start with ``"B"``.

    Notes:
        * Entries of ``tags`` that are not ``str`` are skipped entirely; they
          neither join nor terminate a run.
        * A ``B-`` tag of the *same* category inside a run does not start a
          new run.

    Args:
        tags: sequence of tag strings such as ``"O"``, ``"B-PERSON"``.
        valid_flag: keep valid runs (True) or invalid runs (False).
        return_type: ``"idx"`` for positions, ``"tag"`` for the tag strings.
        flatten: when True return one flat list, otherwise a list of runs.

    Returns:
        A list of runs (``flatten=False``) or a flat list of their elements.

    Raises:
        ValueError: if ``return_type`` is neither ``"idx"`` nor ``"tag"``.
    """
    if return_type not in ("idx", "tag"):
        raise ValueError(
            f"return_type must be 'idx' or 'tag', got {return_type!r}"
        )

    all_idx_ls = []
    all_tag_ls = []
    idx_ls = []
    tag_ls = []

    def _flush():
        # Close the run being built (if any) and keep it when it matches valid_flag.
        nonlocal idx_ls, tag_ls
        if tag_ls and is_valid_seq(tag_ls) == valid_flag:
            all_idx_ls.append(idx_ls)
            all_tag_ls.append(tag_ls)
        idx_ls = []
        tag_ls = []

    for i, tag in enumerate(tags):
        if not isinstance(tag, str):
            continue
        if tag == "O":
            _flush()
            continue
        category = tag.split("-")[1]
        # A category change closes the current run; the tag then opens a new
        # one, which is itself closed by the next "O" or category change.
        if tag_ls and category != tag_ls[0].split("-")[1]:
            _flush()
        idx_ls.append(i)
        tag_ls.append(tag)

    # A run that reaches the end of the input has to be closed as well.
    _flush()

    result = all_idx_ls if return_type == "idx" else all_tag_ls
    if flatten:
        return [item for run in result for item in run]
    return result


def get_voted_result(result_list):
    """Majority-vote several aligned prediction lists position by position.

    The number of positions is taken from the first list. For each position
    the most frequent value across all lists wins; on a tie the value that
    appears first (in list order) is chosen.

    Args:
        result_list: non-empty list of equally indexed prediction lists.

    Returns:
        A list with one winning value per position of ``result_list[0]``.
    """
    positions = range(len(result_list[0]))
    return [
        Counter(run[pos] for run in result_list).most_common(1)[0][0]
        for pos in positions
    ]


# Integer seed per key. No other cell shown in this notebook reads this map.
# NOTE(review): the keys look like base-model names, and "cl" -> 4306 matches
# the `seed` used by the training cell — confirm how this map is consumed.
seed_map = {
    "cl-wom": 71,
    "cl-charwom": 271,
    "cl": 4306,
    "cl-char": 1545,
    "NICT-100k": 8155,
    "NICT-32k": 1250,
}

In [None]:
import pickle


def load_pickle(path):
    """Read the pickle file at ``path`` and return the deserialized object.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(path, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload


class NERStackingDataset(torch.utils.data.Dataset):
    """Dataset that pairs stored logits with per-token embeddings.

    Each element of ``data_list`` is a mapping with the keys ``"logits"`` and
    ``"tokens"`` (and ``"labels"`` when ``labels`` is True). ``token_dict``
    maps a token to a tensor holding its embedding.
    """

    def __init__(self, data_list, token_dict, labels=True):
        self.data_list = data_list
        self.token_dict = token_dict
        self.labels = labels

    def __len__(self):
        return len(self.data_list)

    def get_embedding(self, tokens):
        """Look up every token in ``token_dict`` and stack the vectors into one NumPy array."""
        vectors = []
        for token in tokens:
            vectors.append(self.token_dict[token].cpu().numpy())
        return np.stack(vectors)

    def __getitem__(self, idx):
        record = self.data_list[idx]
        token_matrix = self.get_embedding(record["tokens"])
        item = {
            # The stored logits are passed through unchanged.
            "input_logits": record["logits"],
            "embeddings": torch.tensor(token_matrix),
        }
        if self.labels:
            item["labels"] = torch.tensor(record["labels"])
        return item


# Module-level loss used by NERStackingModel.forward; built with default arguments.
loss_fct = CrossEntropyLoss()


class NERStackingModel(nn.Module):
    """Stacking model built from two bidirectional LSTMs and an MLP head.

    One LSTM reads the input logits concatenated with the token embeddings,
    the other reads the input logits alone. Their outputs are concatenated
    and fed to a two-layer head that emits ``num_labels`` scores per position.

    Args:
        hidden_dim: hidden size of both LSTMs and of the head's inner layer.
        model_num: multiplier for the logit width; the logit input has
            ``num_labels * model_num`` features (presumably one block of
            label scores per base model — confirm against the data prep).
        num_labels: number of output classes (default 11, the previous
            hard-coded value).
        embedding_dim: width of the token embeddings (default 4096, the
            previous hard-coded value).
    """

    def __init__(self, hidden_dim, model_num, num_labels=11, embedding_dim=4096):
        super().__init__()
        self.num_labels = num_labels
        logits_dim = num_labels * model_num
        # Attribute names are kept stable so previously saved state_dicts still load.
        self.embedding_lstm = nn.LSTM(
            embedding_dim + logits_dim, hidden_dim, batch_first=True, bidirectional=True
        )
        self.lstm = nn.LSTM(logits_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.logits = nn.Sequential(
            # Two bidirectional LSTMs -> 4 * hidden_dim features per position.
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_labels),
        )
        self.high_dropout = nn.Dropout(p=0.5)

    def forward(self, input_logits, embeddings, labels=None):
        """Score every position.

        Returns ``(logits,)`` when ``labels`` is None, otherwise
        ``(loss, logits)`` with the loss computed by the module-level
        ``loss_fct`` over all positions.
        """
        concated = torch.cat([input_logits, embeddings], -1)
        embedding_out, _ = self.embedding_lstm(concated)
        logits_out, _ = self.lstm(input_logits)
        features = torch.cat([logits_out, embedding_out], -1)
        # Average the head's output over five independent dropout passes.
        logits = torch.mean(
            torch.stack(
                [self.logits(self.high_dropout(features)) for _ in range(5)],
                dim=0,
            ),
            dim=0,
        )
        if labels is None:
            return (logits,)
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return loss, logits


def save_model(model, path):
    """Write ``model``'s ``state_dict`` to ``path`` using ``torch.save``.

    This definition rebinds the ``save_model`` name that the notebook imports
    from ``utils`` in an earlier cell.
    """
    weights = model.state_dict()
    torch.save(weights, path)


class MAMeter:
    """Moving-average meter over the most recent ``windows`` values.

    Attributes are plain and public: ``windows`` is the window size and
    ``ls`` is the list of retained values, oldest first.
    """

    def __init__(self, windows=10):
        self.windows = windows
        self.ls = []

    def add(self, num):
        """Append ``num``; once the window is full, drop the oldest value."""
        # Compare against the configured window, not a fixed constant.
        if len(self.ls) < self.windows:
            self.ls = self.ls + [num]
        else:
            self.ls = self.ls[1:] + [num]

    def avg(self):
        """Return the mean of the retained values.

        Raises:
            ZeroDivisionError: if no value has been added yet.
        """
        return sum(self.ls) / len(self.ls)


def train_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one pass over ``data_loader``, updating ``model`` after every batch.

    Each batch must provide ``"input_logits"``, ``"labels"`` and
    ``"embeddings"`` tensors. The progress bar shows a moving average of the
    last ten batch losses.

    Args:
        data_loader: sized iterable of batches.
        model: module returning ``(loss, logits)`` when given labels.
        optimizer: optimizer stepped once per batch.
        device: device the model and the batch tensors are moved to.
        scheduler: optional object whose ``step()`` is called after each
            optimizer step.
    """
    model.train()
    model.to(device)
    running_loss = MAMeter(10)
    progress = tqdm(data_loader, total=len(data_loader))

    for batch in progress:
        model_inputs = {
            "input_logits": batch["input_logits"].to(device),
            "labels": batch["labels"].to(device),
            "embeddings": batch["embeddings"].to(device),
        }

        model.zero_grad()
        loss, _ = model(**model_inputs)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        running_loss.add(loss.item())
        progress.set_postfix(loss=running_loss.avg())


def get_result(labels_ids, logits):
    """Turn label ids and logits into two aligned lists of tag strings.

    Predictions are the argmax over the last axis of ``logits``. Positions
    whose label id is negative are dropped from both outputs; the remaining
    ids are mapped through the module-level ``id2tag``.

    Returns:
        ``(tags, labels)``: predicted tag strings and reference tag strings.
    """
    flat_labels = labels_ids.flatten()
    flat_preds = logits.argmax(axis=-1).flatten()
    keep = flat_labels >= 0
    tags = [id2tag[pred] for pred in flat_preds[keep].tolist()]
    labels = [id2tag[gold] for gold in flat_labels[keep].tolist()]
    return tags, labels


# Label id -> tag string. -100 maps to 'mask'; get_result drops positions
# with a negative label id before looking ids up in this table.
id2tag = {-100: 'mask',
          0: 'O',
          1: 'B-TIMEX',
          2: 'I-TIMEX',
          3: 'B-PERSON',
          4: 'I-PERSON',
          5: 'B-ORGFACPOS',
          6: 'I-ORGFACPOS',
          7: 'B-LOCATION',
          8: 'I-LOCATION',
          9: 'B-MISC',
          10: 'I-MISC'}


def valid_fn(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Predicted and reference tags of all batches are pooled into one flat
    sequence each and scored with seqeval's ``classification_report``; the F1
    value is parsed out of that report text.

    Returns:
        ``(tag_list, label_list, logits_list, f1, report)`` where
        ``logits_list`` holds the raw logits tensor of every batch.
    """
    model.eval()
    model.to(device)
    tag_list, label_list, logits_list = [], [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            input_logits = batch["input_logits"].to(device)
            labels_ids = batch["labels"].to(device)
            embeddings = batch["embeddings"].to(device)

            logits = model(input_logits=input_logits, embeddings=embeddings)[0]
            logits_list.append(logits)

            batch_tags, batch_labels = get_result(labels_ids, logits)
            tag_list += batch_tags
            label_list += batch_labels

    report = classification_report([label_list], [tag_list], digits=4)
    return tag_list, label_list, logits_list, _extract_f1_from_report(report), report


def test_fn(data_loader, model, device):
    model.eval()
    model.to(device)
    logits_list = []
    with torch.no_grad():
        bar = tqdm(data_loader, total=len(data_loader))
        for _, batch in enumerate(bar):
            input_logits = batch["input_logits"].to(device)
            embeddings = batch["embeddings"].to(device)
            outputs = model(input_logits=input_logits, embeddings=embeddings)
            logits = outputs[0]
            logits_list.append(logits)
    pred = torch.cat(logits_list,axis=1).argmax(-1).detach().cpu().numpy()
    return pred

def _extract_f1_from_report(report):
    return float(report.split()[report.split().index("micro") + 4])

# Token -> embedding lookup, passed to NERStackingDataset as `token_dict`.
# NOTE(review): presumably precomputed Flair embeddings, judging by the name — confirm.
flair_embedding_dict = load_pickle("data/preprocessed/flair_embedding_dict.pk")

In [None]:
import os

# --- Run configuration -------------------------------------------------------
hidden_dim = 64
model_num = 6  # multiplier for the logit width (presumably the number of base models)
lr = 0.001
seed = 4306  # selects the stacking-data files and seeds training below
n_folds = 5
n_epochs = 10
# Samples whose logits have this many rows or more are dropped to avoid GPU
# out-of-memory errors. The limit can be lifted on GPUs with 16 GB or more.
max_seq_len = 80000
save_dir = "save/stacking"

# Make sure the checkpoint directory exists before the first save.
os.makedirs(save_dir, exist_ok=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for fold in range(n_folds):
    train_path = f"data/preprocessed/train_stacking_data_seed_{seed}_fold_{fold}.pk"
    valid_path = f"data/preprocessed/valid_stacking_data_seed_{seed}_fold_{fold}.pk"
    train_data = load_pickle(train_path)
    valid_data = load_pickle(valid_path)
    train_data = [data for data in train_data if data['logits'].shape[0] < max_seq_len]
    valid_data = [data for data in valid_data if data['logits'].shape[0] < max_seq_len]
    train_dataset = NERStackingDataset(train_data, flair_embedding_dict)
    valid_dataset = NERStackingDataset(valid_data, flair_embedding_dict)
    train_loader = DataLoader(train_dataset, batch_size=1)
    valid_loader = DataLoader(valid_dataset, batch_size=1)

    # Seed before building the model so weight initialisation and dropout are
    # repeatable from run to run (GPU kernels may still be non-deterministic).
    torch.manual_seed(seed)
    stacking_model = NERStackingModel(hidden_dim, model_num)
    optimizer = AdamW(stacking_model.parameters(), lr=lr)

    # Start below any possible F1 so every fold writes at least one checkpoint.
    best_score = float("-inf")
    for epoch in range(n_epochs):
        train_fn(train_loader, stacking_model, optimizer, device)
        tag_list, label_list, logits_list, f1, report = valid_fn(valid_loader, stacking_model, device)
        if f1 > best_score:
            best_score = f1
            print("found better model, saving")
            save_model(stacking_model, f"{save_dir}/stacking_model_{fold}.pt")
        print(report)

In [None]:
# Predict the test set once per fold and collect the per-fold id arrays.
# This cell reuses `stacking_model` and `device` left in the kernel by the
# training cell; the model's weights are overwritten with each fold's checkpoint.
tag_list = []

for fold in range(5):
    test_path = f"data/preprocessed/test_stacking_data_{fold}_new_aug.pk"
    test_data = load_pickle(test_path)
    # labels=False: items carry no "labels" entry.
    test_dataset = NERStackingDataset(test_data, flair_embedding_dict, labels=False)
    test_loader = DataLoader(test_dataset, batch_size=1)
    # Load the checkpoint written for this fold by the training cell.
    stacking_model.load_state_dict(torch.load(f"save/stacking/stacking_model_{fold}.pt"))
    tags = test_fn(test_loader, stacking_model, device)
    tag_list.append(tags)

In [None]:
# Convert each fold's predicted ids to tag strings. test_fn concatenates the
# batches along axis 1 and the loaders use batch_size=1, so row 0 of each
# array holds the predictions for that fold.
tags = [[id2tag[i] for i in fold_id[0]] for fold_id in tag_list ]

In [None]:
import pandas as pd

# Token strings of the test set, with rows containing missing values dropped.
tokens = pd.read_csv("data/input/test_token.csv").dropna().token.tolist()
# Clean each fold's tags, majority-vote across folds, then clean the voted result again.
voted = get_voted_result([correct_idx(t, tokens) for t in tags])
corrected = correct_idx(voted, tokens)

In [None]:
# Build the submission file: attach the corrected tags to the test tokens and
# outer-merge them onto the sample submission's rows.
sub = pd.read_csv("data/input/sample_submission.csv")
tokens = (
    pd.read_csv("data/input/test_token.csv")
    .dropna()
    # `corrected` is aligned with the rows in file order, so it is attached before sorting.
    .assign(tag=corrected)
    .sort_values(["file_id", "token_id"])
)
predicted = tokens[["file_id", "token_id", "tag"]]
submission = sub.drop(columns="tag").merge(predicted, how="outer")
submission.to_csv("data/output/stacking_from_train.csv", index=False)