【前回バージョンとの差分】
- [Update] LB 0.715 のバージョンから fix_list 関数を修正
- OOF の計算上、現在の LB ベストより CV が改善したためサブミット

In [None]:
import gc
gc.enable()

import sys
sys.path.append("../input/tez-lib/")

import os
import random
from tqdm import tqdm

import numpy as np
import pandas as pd
from scipy.special import softmax

import tez
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from transformers import AutoConfig, AutoModel, AutoTokenizer
import pickle


# Config

In [None]:
class Config:
    """Static configuration for the ensemble inference notebook.

    Holds the dataset location, the pretrained backbone directories, the
    list of fine-tuned checkpoints to run, and the post-processing
    thresholds.  Every attribute is a plain class attribute.
    """

    # Root directory of the competition data; test essays are read from
    # `<input_dir>/test/<id>.txt`.
    input_dir = '../input/feedback-prize-2021'

    # Pretrained backbone directories, used for both the tokenizer and the
    # model config of each `kind`.
    model_longformer = '../input/longformerlarge4096/longformer-large-4096'
    model_led = '../input/led-large'
    model_deberta = '../input/deberta-large-mlm-1024'
    model_deberta_x = '../input/deberta-xlarge/deberta-xlarge'

    # Maximum sequence length per model family.  The longformer value is also
    # used for the DeBERTa kinds and as the padded width of the stored
    # prediction arrays.
    max_len_test_longformer = 1600
    max_len_test_led = 1024

    # Number of worker processes used when tokenizing the test essays.
    num_jobs = 4
    # Seed passed to `seed_everything`.
    seed = 1

    model_ckp_path = [
        #  (kind, head, num_labels, model_path, model_name, weight)
        #    kind       : 'lf-large' | 'led-large' | 'deberta-large' | 'deberta-xlarge'
        #    head       : None | 'LSTM' | 'LSTM_2head' (selects the model class)
        #    num_labels : width of the classifier output (15, 22 or 24)
        #    model_name : tag used to group checkpoints; one pickle is written per tag
        #    weight     : only summed into the long/short weight totals in `inference`
        # exp lf-large, lf-base, led-large, bb-large, ...

        # LB 0.710
        ('deberta-large', None, 15, '../input/2022022709-deberta-large-boe-bin/model_0.bin', '022709-deberta-large-boe-bin',0.8),
        #('deberta-large', None, 15, '../input/2022022709-deberta-large-boe-bin/model_1.bin', '022709-deberta-large-boe-bin',1.3),
        #('deberta-large', None, 15, '../input/2022022709-deberta-large-boe-bin/model_2.bin', '022709-deberta-large-boe-bin',1.3),
        #('deberta-large', None, 15, '../input/2022022709-deberta-large-boe-bin/model_3.bin', '022709-deberta-large-boe-bin',1.3),
        #('deberta-large', None, 15, '../input/2022022709-deberta-large-boe-bin/model_4.bin', '022709-deberta-large-boe-bin',1.3),

        # LB 0.702
        #('deberta-large', None, 22, '../input/fb-exp045-deberta/model_0.bin', 'yyama_exp045', 1.1),
        ('deberta-large', None, 22, '../input/fb-exp045-deberta/model_1.bin', 'yyama_exp045',1.1),
        ('deberta-large', None, 22, '../input/fb-exp045-deberta/model_2.bin', 'yyama_exp045',1.1),
        ('deberta-large', None, 22, '../input/fb-exp045-deberta/model_3.bin', 'yyama_exp045',1.1),
        #('deberta-large', None, 22,  '../input/fb-exp045-deberta/model_4.bin', 'yyama_exp045',1.1),

        # LB 0.709
        #('deberta-large', 'LSTM_2head', 22, '../input/feedback-101/fold-0.pt', 'makabe_feedback-101_LSTM_2head', 1.95),
        #('deberta-large', 'LSTM_2head', 22, '../input/feedback-101/fold-1.pt','makabe_feedback-101_LSTM_2head', 1.95),
        ('deberta-large', 'LSTM_2head', 22, '../input/feedback-101/fold-2.pt','makabe_feedback-101_LSTM_2head', 2.6),
        ('deberta-large', 'LSTM_2head', 22, '../input/feedback-101/fold-3.pt','makabe_feedback-101_LSTM_2head', 2.6),
        ('deberta-large', 'LSTM_2head', 22, '../input/feedback-101/fold-4.pt','makabe_feedback-101_LSTM_2head', 2.6),

        # LB 0.709
        ('deberta-large', 'LSTM', 22, '../input/feedback-098/fold-0.pt', 'makabe_feedback-098_LSTM',2.7),
        #('deberta-large', 'LSTM', 22, '../input/feedback-098/fold-1.pt', 'makabe_feedback-098_LSTM',2.025),
        #('deberta-large', 'LSTM', 22, '../input/feedback-098/fold-2.pt', 'makabe_feedback-098_LSTM',2.025),
        ('deberta-large', 'LSTM', 22, '../input/feedback-098/fold-3.pt', 'makabe_feedback-098_LSTM',2.7),
        ('deberta-large', 'LSTM', 22, '../input/feedback-098/fold-4.pt', 'makabe_feedback-098_LSTM',2.7),

        # LB 0.708
        ('deberta-large', None, 22, '../input/feedback-096/fold-0.pt', 'makabe_feedback-096', 2.1),
        ('deberta-large', None, 22, '../input/feedback-096/fold-1.pt', 'makabe_feedback-096',2.1),
        #('deberta-large', None, 22, '../input/feedback-096/fold-2.pt', 'makabe_feedback-096',2.1),
        #('deberta-large', None, 22, '../input/feedback-096/fold-3.pt', 'makabe_feedback-096',2.1),
        ('deberta-large', None, 22, '../input/feedback-096/fold-4.pt', 'makabe_feedback-096',2.1),

        # LB 0.707
        ('deberta-large', None, 24, '../input/fb-exp058-wo-mlm/model_0.bin', 'yyama_exp058-wo-mlm',1),
        ('deberta-large', None, 24, '../input/fb-exp058-wo-mlm/model_1.bin', 'yyama_exp058-wo-mlm',1),
        ('deberta-large', None, 24, '../input/fb-exp058-wo-mlm/model_2.bin', 'yyama_exp058-wo-mlm',1),
        #('deberta-large', None, 24, '../input/fb-exp058-wo-mlm/model_3.bin', 'yyama_exp058-wo-mlm',0.75),
        #('deberta-large', None, 24, '../input/fb-exp058-wo-mlm/model_4.bin', 'yyama_exp058-wo-mlm',0.75),

        # LB 0.707
        #('deberta-large', None, 15, '../input/2022030209-deberta-large-mnli-boe-bin/model_0.bin', 'makotu_030209-deberta-large-mnli-boe-bin', 1.2),
        #('deberta-large', None, 15, '../input/2022030209-deberta-large-mnli-boe-bin/model_1.bin', 'makotu_030209-deberta-large-mnli-boe-bin',1.2),
        #('deberta-large', None, 15, '../input/2022030209-deberta-large-mnli-boe-bin/model_2.bin', 'makotu_030209-deberta-large-mnli-boe-bin',1.2),
        #('deberta-large', None, 15, '../input/2022030209-deberta-large-mnli-boe-bin/model_3.bin', 'makotu_030209-deberta-large-mnli-boe-bin',1.0),
        #('deberta-large', None, 15, '../input/2022030209-deberta-large-mnli-boe-bin/model_4.bin', 'makotu_030209-deberta-large-mnli-boe-bin',1.0),

        # xlarge
        ('deberta-xlarge', None, 24, '../input/fb-exp062-xlarge/model_2.bin', 'yyama_fb_exp06x_xlarge_fold2-4', 1.4),
        ('deberta-xlarge', None, 24, '../input/fb-exp062-xlarge/model_3.bin', 'yyama_fb_exp06x_xlarge_fold2-4', 1.4),
        ('deberta-xlarge', None, 24, '../input/fb-exp062-xlarge/model_4.bin', 'yyama_fb_exp06x_xlarge_fold2-4', 1.4),


        # LB 0.697
        # NOTE(review): these disabled entries have only 5 fields (no model_name);
        # they would not unpack in `inference` if re-enabled as-is.
        #('lf-large', None, 22, '../input/2022021410-lf-bie-bin/model_0.bin', 0.5),
        #('lf-large', None, 22, '../input/2022021410-lf-bie-bin/model_1.bin', 0.5),
        #('lf-large', None, 22, '../input/2022021410-lf-bie-bin/model_2.bin', 0.5),

        # LB 0.695
        ('led-large', None, 24, '../input/fb-exp037-led/model_0.bin', 'yyama_fb-exp037-led', 1.6),
        ('led-large', None, 24, '../input/fb-exp037-led/model_1.bin','yyama_fb-exp037-led',  1.6),
        #('led-large', None, 24, '../input/fb-exp037-led/model_2.bin','yyama_fb-exp037-led',  1.6),
        #('led-large', None, 24, '../input/fb-exp037-led/model_3.bin','yyama_fb-exp037-led',  1.6),
        ('led-large', None, 24, '../input/fb-exp037-led/model_4.bin','yyama_fb-exp037-led',  1.6)
    ]

    # Per-class minimum mean probability; `thresh_prob` drops predictions whose
    # 'prob' is below the value.  The trailing numbers look like earlier
    # settings - TODO confirm.
    proba_thresh = {
        "Lead": 0.6, # 0.7
        "Position": 0.4, # 0.55
        "Evidence": 0.65,
        "Claim": 0.55,
        "Concluding Statement": 0.6, # 0.7
        "Counterclaim": 0.5,
        "Rebuttal": 0.55,
    }
    # Per-class minimum number of entries in the prediction string;
    # `thresh_min_token` drops shorter predictions.
    min_token_thresh = {
        "Lead": 5, # 9
        "Position": 4, # 5
        "Evidence": 14,
        "Claim": 2,
        "Concluding Statement": 7, # 11
        "Counterclaim": 6,
        "Rebuttal": 4,
    }
    # Per-class value passed to `link_class` as `thresh2` (the span limit used
    # when joining adjacent predictions of the same class).
    link = {
        'Evidence': 40,
        'Counterclaim': 200,
        'Rebuttal': 200,
    }


cfg = Config()

# Label vocabulary: every discourse type gets a "B-" id (even) followed by an
# "I-" id (odd), in the order listed below, i.e. B-Lead=0, I-Lead=1, ...,
# I-Rebuttal=13.  "O" is 14 and "PAD" is -100.
target_id_map = {
    f"{prefix}-{discourse}": 2 * position + offset
    for position, discourse in enumerate(
        (
            "Lead",
            "Position",
            "Evidence",
            "Claim",
            "Concluding Statement",
            "Counterclaim",
            "Rebuttal",
        )
    )
    for offset, prefix in enumerate(("B", "I"))
}
target_id_map.update({"O": 14, "PAD": -100})

# Reverse lookup: label id -> label name.
id_target_map = dict(zip(target_id_map.values(), target_id_map.keys()))


In [None]:
# short 
#cfg.model_ckp_path = cfg.model_ckp_path[:10]

# Modules

## Utils

In [None]:
def get_test_text(ids):
    """Return the raw text of one test essay.

    Args:
        ids: Document id; the essay is read from
            ``<cfg.input_dir>/test/<ids>.txt``.

    Returns:
        The full file content as a single string.

    The file is decoded explicitly as UTF-8 so that the character offsets
    produced by the tokenizer do not depend on the platform's default locale.
    """
    # To run on the training essays instead, read from
    # f'{cfg.input_dir}/train/{ids}.txt'.
    with open(f'{cfg.input_dir}/test/{ids}.txt', 'r', encoding='utf-8') as f:
        return f.read()


def seed_everything(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch, and exports
    ``PYTHONHASHSEED``.  When CUDA is available the CUDA generators are seeded
    too and cuDNN is switched to deterministic, non-benchmark mode.

    Args:
        seed: The seed value applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
        
        
# Seed all random number generators once, using the seed from the config.
seed_everything(cfg.seed)


## Pre Process

In [None]:
def _prepare_test_data_helper(tokenizer, ids, max_length=1600):
    """Tokenize a batch of test essays.

    Args:
        tokenizer: Tokenizer exposing ``encode_plus`` with offset-mapping
            support.
        ids: Iterable of document ids to load and tokenize.
        max_length: Truncation length passed to the tokenizer.  Defaults to
            1600, the value that used to be hard-coded here.

    Returns:
        A list with one dict per document, holding the keys ``"id"``,
        ``"input_ids"``, ``"text"`` and ``"offset_mapping"``.  Special tokens
        are not added at this stage.
    """
    test_samples = []
    for doc_id in ids:
        text = get_test_text(doc_id)

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
            max_length=max_length,
            truncation=True,
        )

        test_samples.append(
            {
                "id": doc_id,
                "input_ids": encoded_text["input_ids"],
                "text": text,
                "offset_mapping": encoded_text["offset_mapping"],
            }
        )
    return test_samples


def prepare_test_data(df, tokenizer, num_jobs):
    """Tokenize every unique document in ``df`` using parallel workers.

    Args:
        df: DataFrame with an ``"id"`` column; each unique id is processed
            once.
        tokenizer: Tokenizer forwarded to ``_prepare_test_data_helper``.
        num_jobs: Number of joblib worker processes.

    Returns:
        A flat list of sample dicts, in the order of ``df["id"].unique()``.
    """
    ids = df["id"].unique()

    # One chunk per worker.  The chunk count used to be fixed at 4 regardless
    # of `num_jobs`; 4 is kept as the fallback for non-positive values such as
    # joblib's -1 ("all cores").  The chunk count only affects parallelism,
    # not the order or content of the result.
    n_splits = num_jobs if num_jobs is not None and num_jobs > 0 else 4
    ids_splits = np.array_split(ids, n_splits)

    results = Parallel(n_jobs=num_jobs, backend="multiprocessing")(
        delayed(_prepare_test_data_helper)(tokenizer, chunk) for chunk in ids_splits
    )
    return [sample for chunk_samples in results for sample in chunk_samples]

## Dataset

In [None]:
class FeedbackTestDataset:
    """Index-based dataset that wraps tokenized samples for inference.

    Each item is the sample's token ids with the tokenizer's CLS id in front
    and its SEP id at the end, truncated so the result has at most
    ``max_len`` entries, plus an all-ones attention mask of the same length.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        tokens = [self.tokenizer.cls_token_id] + self.samples[idx]["input_ids"]

        # Leave exactly one slot for the closing SEP token.  Slicing is a
        # no-op when the sequence is already short enough.
        tokens = tokens[: self.max_len - 1]
        tokens.append(self.tokenizer.sep_token_id)

        return {
            "ids": tokens,
            "mask": [1] * len(tokens),
        }
    
    
class Collate:
    """Batch collator that pads samples to the longest sequence in the batch.

    Padding is placed on the side given by ``tokenizer.padding_side``; ids are
    padded with ``tokenizer.pad_token_id`` and masks with 0.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        id_rows = [item["ids"] for item in batch]
        mask_rows = [item["mask"] for item in batch]

        # Longest sequence in this batch decides the padded width.
        width = max(len(row) for row in id_rows)
        pad_on_right = self.tokenizer.padding_side == "right"

        def _pad(row, fill_value):
            filler = (width - len(row)) * [fill_value]
            if pad_on_right:
                return row + filler
            return filler + row

        padded_ids = [_pad(row, self.tokenizer.pad_token_id) for row in id_rows]
        padded_masks = [_pad(row, 0) for row in mask_rows]

        return {
            "ids": torch.tensor(padded_ids, dtype=torch.long),
            "mask": torch.tensor(padded_masks, dtype=torch.long),
        }

## Model

In [None]:
class FeedbackModel(tez.Model):
    """Transformer backbone followed by a per-token linear classifier.

    The backbone is built from the config found at ``model_name`` (no
    pretrained weights are loaded here); checkpoint weights are expected to be
    loaded afterwards by the caller.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config_overrides = {
            'output_hidden_states': True,
            'add_pooling_layer': False,
            'num_labels': self.num_labels,
            # for longformer
            'attention_probs_dropout_prob': 0.1,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps': 1e-7,
            # for LED
            'activation_dropout': 0.0,
            'attention_dropout': 0.0,
            'classif_dropout': 0.0,
            'classifier_dropout': 0.0,
            'decoder_layerdrop': 0.0,
            'encoder_layerdrop': 0.0,
        }
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update(config_overrides)

        self.transformer = AutoModel.from_config(self.config)
        self.output = nn.Linear(self.config.hidden_size, self.num_labels)

    def forward(self, ids, mask):
        """Return per-token class probabilities (softmax over the last axis).

        The extra ``0`` and ``{}`` keep the three-value return shape;
        presumably they fill tez's loss and metrics slots - TODO confirm.
        """
        hidden_states = self.transformer(ids, mask).last_hidden_state
        probabilities = torch.softmax(self.output(hidden_states), dim=-1)
        return probabilities, 0, {}


class FeedbackModelWithLSTM(tez.Model):
    """Transformer backbone, a single LSTM layer, then a per-token classifier.

    The backbone is built from the config found at ``model_name`` (no
    pretrained weights are loaded here); checkpoint weights are expected to be
    loaded afterwards by the caller.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config_overrides = {
            'output_hidden_states': True,
            'add_pooling_layer': False,
            'num_labels': self.num_labels,
            # for longformer
            'attention_probs_dropout_prob': 0.1,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps': 1e-7,
            # for LED
            'activation_dropout': 0.0,
            'attention_dropout': 0.0,
            'classif_dropout': 0.0,
            'classifier_dropout': 0.0,
            'decoder_layerdrop': 0.0,
            'encoder_layerdrop': 0.0,
        }
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update(config_overrides)

        hidden_size = self.config.hidden_size
        self.transformer = AutoModel.from_config(self.config)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, self.num_labels)

    def forward(self, ids, mask):
        """Return per-token class probabilities (softmax over the last axis).

        The extra ``0`` and ``{}`` keep the three-value return shape;
        presumably they fill tez's loss and metrics slots - TODO confirm.
        """
        hidden_states = self.transformer(ids, mask).last_hidden_state
        # The LSTM starts from its default (None) initial state.
        lstm_states, _ = self.lstm(hidden_states, None)
        probabilities = torch.softmax(self.output(lstm_states), dim=-1)
        return probabilities, 0, {}
    
    
class FeedbackModelWithLSTMTwoHead(tez.Model):
    """Transformer + LSTM model that owns two linear heads.

    Only the first head (``output``) is used in ``forward``.  The second head
    (``output_2``, ``num_labels - 9`` outputs) is created but never called
    here; presumably it exists so that checkpoints trained with both heads
    load cleanly - TODO confirm.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config_overrides = {
            'output_hidden_states': True,
            'add_pooling_layer': False,
            'num_labels': self.num_labels,
            # for longformer
            'attention_probs_dropout_prob': 0.1,
            'hidden_dropout_prob': 0.1,
            'layer_norm_eps': 1e-7,
            # for LED
            'activation_dropout': 0.0,
            'attention_dropout': 0.0,
            'classif_dropout': 0.0,
            'classifier_dropout': 0.0,
            'decoder_layerdrop': 0.0,
            'encoder_layerdrop': 0.0,
        }
        self.config = AutoConfig.from_pretrained(model_name)
        self.config.update(config_overrides)

        hidden_size = self.config.hidden_size
        self.transformer = AutoModel.from_config(self.config)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, self.num_labels)
        self.output_2 = nn.Linear(hidden_size, self.num_labels - 9)

    def forward(self, ids, mask):
        """Return per-token class probabilities from the first head.

        The extra ``0`` and ``{}`` keep the three-value return shape;
        presumably they fill tez's loss and metrics slots - TODO confirm.
        """
        hidden_states = self.transformer(ids, mask).last_hidden_state
        # The LSTM starts from its default (None) initial state.
        lstm_states, _ = self.lstm(hidden_states, None)
        probabilities = torch.softmax(self.output(lstm_states), dim=-1)
        return probabilities, 0, {}

## Inference

In [None]:
def inference(df, cfg, bs, mode):
    """Run every configured checkpoint and pickle the averaged predictions.

    Checkpoints in ``cfg.model_ckp_path`` are grouped by their model-name tag
    (tuple index 4).  For each tag, the per-token probabilities of all its
    checkpoints are averaged with equal weight into an array of shape
    ``(n_documents, cfg.max_len_test_longformer, 15)`` and written to
    ``pred_<tag>_<mode>.pkl`` in the working directory.

    Args:
        df: DataFrame with an ``"id"`` column identifying the documents.
        cfg: Configuration object (checkpoint list, backbone paths, max
            lengths, number of tokenizer jobs).
        bs: Batch size for prediction; capped at 20 when the backbone path
            contains ``'led'``.
        mode: Free-form tag that becomes part of the pickle file name.

    Returns:
        None.  The only outputs are the pickle files.
    """
    #raw_preds = []
    prev_kind = None
    prev_head = None
    prev_n_labels = None

    # Weight totals of the models covering the long window vs. the LED models
    # limited to max_len_test_led.
    # NOTE(review): 'deberta-xlarge' checkpoints are counted in neither sum
    # although they also run with the long max length - confirm this is
    # intended.  long_model_w == 0 would raise ZeroDivisionError below.
    long_model_w = sum([w for k, _, _, _, _, w in cfg.model_ckp_path if (k == 'lf-large') | (k == 'deberta-large')])
    short_model_w = sum([w for k, _, _, _, _, w in cfg.model_ckp_path if k == 'led-large'])
    total_w = long_model_w + short_model_w

    # Distinct model-name tags; iteration order of a set is not guaranteed.
    models = list(set([a[4] for a in cfg.model_ckp_path]))

    for model_type in models:
        # Number of checkpoints sharing this tag (used for the plain average).
        models_num = len([a[4] for a in cfg.model_ckp_path if a[4] == model_type])

        pred_all = np.zeros([df['id'].nunique(), cfg.max_len_test_longformer, 15])
        for i, (kind, head, n_labels, model_path, model_name, w) in enumerate(cfg.model_ckp_path):
            if model_name == model_type:
                raw_preds = []
                # From here on `model_name` is rebound from the tag to the
                # backbone directory of this kind.
                if kind == 'lf-large':
                    model_name = cfg.model_longformer
                    max_len = cfg.max_len_test_longformer
                elif kind == 'led-large':
                    model_name = cfg.model_led
                    max_len = cfg.max_len_test_led
                elif kind == 'deberta-large':
                    model_name = cfg.model_deberta
                    max_len = cfg.max_len_test_longformer
                elif kind == 'deberta-xlarge':
                    model_name = cfg.model_deberta_x
                    max_len = cfg.max_len_test_longformer

                # Re-initialise tokenizer, dataset and model only when the
                # model structure differs from the previous checkpoint.
                if kind != prev_kind or head != prev_head or n_labels != prev_n_labels:
                    print(f'use {model_name} tokenizer')
                    tokenizer = AutoTokenizer.from_pretrained(model_name)
                    samples = prepare_test_data(df, tokenizer, cfg.num_jobs)
                    dataset = FeedbackTestDataset(samples, max_len, tokenizer)

                    if head is None:
                        model = FeedbackModel(model_name=model_name, num_labels=n_labels)
                    elif head == 'LSTM':
                        model = FeedbackModelWithLSTM(model_name=model_name, num_labels=n_labels)
                    elif head == 'LSTM_2head':
                        model = FeedbackModelWithLSTMTwoHead(
                            model_name=model_name, num_labels=n_labels
                        )

                prev_kind, prev_head, prev_n_labels = kind, head, n_labels

                model.eval()
                # Load this checkpoint's weights into the (possibly reused) model.
                model.load(model_path, weights_only=True)

                collate = Collate(tokenizer=tokenizer)


                # Cap the batch size for LED backbones.
                if 'led' in model_name:
                    bs2 = min(bs, 20)
                else:
                    bs2 = bs

                preds_iter = model.predict(dataset, batch_size=bs2, n_jobs=-1, collate_fn=collate)

                # NOTE(review): current_idx is never read by the active code.
                current_idx = 0
                for preds in preds_iter:
                    preds = preds.astype(np.float16)
                    # Zero-pad the token axis up to max_len_test_longformer so
                    # all batches share one width.
                    preds = np.pad(preds, [(0, 0), (0, cfg.max_len_test_longformer - preds.shape[1]), (0, 0)], 'constant')

                    # Collapse 22/24-label outputs to the common 15 labels.
                    if n_labels == 22:
                        preds = reshape_preds_22_to_15(preds)
                    elif n_labels == 24:
                        preds = reshape_preds_24_to_15(preds)



                    # for debugging
                    #long_model_w = 1


                    #preds *= w / total_w
                    # Positions beyond 1024 (max_len_test_led) are scaled by a
                    # constant factor.
                    preds[:, cfg.max_len_test_led:, :] *= total_w / long_model_w

                    #if i == 0:
                    raw_preds.append(preds)
                    #else:
                    #    raw_preds[current_idx] += preds
                    #    current_idx += 1
                torch.cuda.empty_cache()
                gc.collect()
                # Equal-weight average over the checkpoints of this tag.
                pred_all += np.vstack(raw_preds) / models_num

        with open(f'pred_{model_type}_{mode}.pkl', 'wb') as f:
            pickle.dump(pred_all , f)

        #preds_class, preds_prob = [], []
        #for rp in raw_preds:
        #    c_arr = np.argmax(rp, axis=2)
        #    p_arr = np.max(rp, axis=2)
        #    for c, p in zip(c_arr, p_arr):
        #        c = c.tolist()
        #        p = p.tolist()
        #        preds_class.append(c)
        #        preds_prob.append(p)

    #for i in range(len(samples)):
        # Index 0 is a special token, so it is excluded.
    #    text_class = [id_target_map[c] for c in preds_class[i][1:]]
    #    text_prob = preds_prob[i][1:]
    #    samples[i]['pred_class'] = text_class
    #    samples[i]['pred_prob'] = text_prob

    #return samples


def reshape_preds_24_to_15(preds):
    """Collapse a 24-channel prediction array into the 15-channel layout.

    Args:
        preds: Array indexed as ``[document, token, channel]`` with at least
            24 channels.

    Returns:
        A new float array with the same first two dimensions and 15 channels.
        Channel 0 is copied; every other output channel is the sum of the
        source channels listed in the table below.
    """
    # Source channels feeding output channels 1..14, in order.
    source_groups = (
        (1, 2), (3,), (4, 5), (6,), (7, 8), (9,), (10, 11),
        (12,), (13, 14), (15,), (16, 17), (18,), (19, 20), (21, 22, 23),
    )

    merged = np.zeros((preds.shape[0], preds.shape[1], 15))
    merged[:, :, 0] = preds[:, :, 0]
    for target, sources in enumerate(source_groups, start=1):
        # Sum left to right in the input's own dtype (never in place, so
        # `preds` is left untouched).
        total = preds[:, :, sources[0]]
        for source in sources[1:]:
            total = total + preds[:, :, source]
        merged[:, :, target] += total
    return merged

def reshape_preds_22_to_15(preds):
    """Collapse a 22-channel prediction array into the 15-channel layout.

    Args:
        preds: Array indexed as ``[document, token, channel]`` with at least
            22 channels.

    Returns:
        A new float array with the same first two dimensions and 15 channels.
        Channel 0 is copied; every other output channel is the sum of the
        source channels listed in the table below.
    """
    # Source channels feeding output channels 1..14, in order.
    source_groups = (
        (1, 2), (3,), (4, 5), (6,), (7, 8), (9,), (10, 11),
        (12,), (13, 14), (15,), (16, 17), (18,), (19, 20), (21,),
    )

    merged = np.zeros((preds.shape[0], preds.shape[1], 15))
    merged[:, :, 0] = preds[:, :, 0]
    for target, sources in enumerate(source_groups, start=1):
        # Sum left to right in the input's own dtype (never in place, so
        # `preds` is left untouched).
        total = preds[:, :, sources[0]]
        for source in sources[1:]:
            total = total + preds[:, :, source]
        merged[:, :, target] += total
    return merged

## Post Process

In [None]:
def post_process(samples):
    """Turn per-token predictions into the final submission table.

    Pipeline: build phrase-level rows from the samples, drop rows below the
    per-class probability and length thresholds, keep only the most probable
    row for the single-instance classes, then join neighbouring Evidence,
    Counterclaim and Rebuttal rows.

    Args:
        samples: List of sample dicts as expected by ``create_submission``.

    Returns:
        DataFrame with the columns ``id``, ``class`` and ``predictionstring``.

    Note:
        This function used to be defined twice, verbatim; the duplicate was
        removed.
    """
    sub = create_submission(samples)
    sub = thresh_prob(sub, cfg)
    sub = thresh_min_token(sub, cfg)
    sub = get_max_prob(sub)
    for discourse_type in ('Evidence', 'Counterclaim', 'Rebuttal'):
        sub = link_class(sub, discourse_type, cfg.link[discourse_type])
    sub = sub.reset_index(drop=True)
    return sub[['id', 'class', 'predictionstring']]


def _strip_bio_prefix(tag):
    """Return the class name of a 'B-xxx'/'I-xxx' tag; 'O' stays 'O'."""
    return tag[2:] if tag != 'O' else 'O'


def create_submission(samples):
    """Convert per-token predictions into phrase-level prediction rows.

    Args:
        samples: List of dicts with the keys ``id``, ``text``,
            ``offset_mapping``, ``pred_class``/``pred_prob``,
            ``second_class``/``second_prob`` and ``third_class``/``third_prob``.
            The class and probability entries are per-token lists.  They are
            modified in place: ``pred_class`` is smoothed by ``fix_list`` and
            all six lists may be padded.

    Returns:
        DataFrame with one row per non-'O' phrase and the columns ``id``,
        ``class``, ``second_class``, ``third_class``, ``predictionstring``,
        ``prob``, ``second_prob``, ``third_prob`` and ``len`` (number of word
        indices).  Rows with an empty prediction string are removed.  An
        empty ``samples`` list yields an empty frame with these columns.

    Fixes compared to the previous version:
        * ``third_label`` is now derived from ``third_class`` (it was gated on
          ``pred_class`` by mistake, producing '' or 'O' labels).
        * A phrase's end offset now starts at its first token's end instead of
          relying on ``'end' in locals()``, which reused a stale value from an
          earlier phrase or even an earlier document for one-token phrases.
        * Empty input no longer makes ``pd.concat`` raise.
    """
    columns = ['id', 'class', 'second_class', 'third_class', 'predictionstring',
               'prob', 'second_prob', 'third_prob']
    sub = []
    for sample in samples:
        pred_class = sample['pred_class']
        offset_mapping = sample['offset_mapping']
        sample_id = sample['id']
        sample_text = sample['text']
        pred_prob = sample['pred_prob']

        second_class = sample['second_class']
        second_prob = sample['second_prob']

        third_class = sample['third_class']
        third_prob = sample['third_prob']

        pred_class = fix_list(pred_class)

        n_tokens = len(offset_mapping)
        # The text has more tokens than the model produced predictions for
        # (the original comment mentions 4096): pad with 'O' / probability 0.
        if len(pred_class) < n_tokens - 1:
            pred_class += ['O'] * (n_tokens - len(pred_class))
            second_class += ['O'] * (n_tokens - len(second_class))
            third_class += ['O'] * (n_tokens - len(third_class))

            pred_prob += [0] * (n_tokens - len(pred_prob))
            second_prob += [0] * (n_tokens - len(second_prob))
            third_prob += [0] * (n_tokens - len(third_prob))

        # Group consecutive tokens into phrases: a phrase opens at any token
        # and extends over the following tokens tagged 'I-<label>' (or 'O'
        # for an 'O' phrase).
        # NOTE(review): the loop stops one short of len(offset_mapping), so
        # the last offset entry never opens or extends a phrase - confirm
        # that this is intended.
        idx = 0
        phrase_preds = []
        while idx < n_tokens - 1:
            start, end = offset_mapping[idx]
            label = _strip_bio_prefix(pred_class[idx])
            second_label = _strip_bio_prefix(second_class[idx])
            third_label = _strip_bio_prefix(third_class[idx])

            phrase_probs = [pred_prob[idx]]
            phrase_second_probs = [second_prob[idx]]
            phrase_third_probs = [third_prob[idx]]
            idx += 1

            matching_label = f'I-{label}' if label != 'O' else 'O'
            while idx < n_tokens - 1 and pred_class[idx] == matching_label:
                _, end = offset_mapping[idx]
                phrase_probs.append(pred_prob[idx])
                phrase_second_probs.append(second_prob[idx])
                phrase_third_probs.append(third_prob[idx])
                idx += 1

            phrase_preds.append((start, end, label, second_label, third_label,
                                 phrase_probs, phrase_second_probs, phrase_third_probs))

        rows = []
        total_words = len(sample_text.split())
        for (start, end, label, second_label, third_label,
             phrase_probs, phrase_second_probs, phrase_third_probs) in phrase_preds:
            if label == 'O':
                continue
            # Translate character offsets into whitespace-separated word
            # indices, clipped to the number of words in the text.
            word_start = len(sample_text[:start].split())
            word_end = word_start + len(sample_text[start:end].split())
            word_end = min(word_end, total_words)
            ps = " ".join([str(x) for x in range(word_start, word_end)])

            n_probs = len(phrase_probs)
            rows.append((
                sample_id, label, second_label, third_label, ps,
                sum(phrase_probs) / n_probs,
                sum(phrase_second_probs) / n_probs,
                sum(phrase_third_probs) / n_probs,
            ))
        sub.append(pd.DataFrame(rows, columns=columns))

    if not sub:
        return pd.DataFrame(columns=columns + ['len'])

    sub = pd.concat(sub).reset_index(drop=True)
    sub['len'] = sub['predictionstring'].apply(lambda x: len(x.split()))
    sub = sub[sub['len'] > 0]
    return sub

def fix_list(pred_list):
    """Fill short interruptions inside runs of the same token label.

    For each label in turn ("I-<class>" labels first, then "O"): whenever a
    run of more than two consecutive tokens of that label is interrupted by a
    token carrying one of the *other* listed labels, look ahead up to a
    per-label number of tokens.  If the run resumes there (two consecutive
    tokens of the label), the tokens in between are relabelled.  The largest
    allowed gap is tried first.  "B-" tags never count as an interruption.

    Args:
        pred_list: List of per-token label strings.  Modified in place.

    Returns:
        The same list object, after the fixes.
    """
    class_list = ["I-Lead", "I-Position", "I-Evidence", "I-Claim",
                  "I-Concluding Statement", "I-Counterclaim", "I-Rebuttal", "O"]

    # Largest gap (in tokens) that may be filled, per label.
    fix_thresholds = {
        "I-Lead": 2,
        "I-Concluding Statement": 2,
        "I-Evidence": 1,
        "I-Position": 2,
        "I-Claim": 1,
        "I-Counterclaim": 5,
        "I-Rebuttal": 7,
        "O": 1,
    }

    for class_ in class_list:
        # Previously `col not in class_`, a substring test on a string that
        # only happened to work for these label names.
        other_classes = {col for col in class_list if col != class_}

        # Pass 1: find the positions where a run (> 2 tokens) is interrupted.
        flagged = []
        run_length = 0
        for token_id, token in enumerate(pred_list):
            if run_length > 2 and token in other_classes:
                flagged.append(token_id)
            run_length = run_length + 1 if token == class_ else 0

        # Pass 2: fill each interruption if the run resumes soon enough.
        max_gap = fix_thresholds[class_]
        for ind in flagged:
            # The whole look-ahead window must fit in the list; otherwise the
            # position is skipped, even for smaller gaps.
            if ind + max_gap + 1 >= len(pred_list):
                continue
            for gap in range(max_gap, 0, -1):
                if pred_list[ind + gap] == class_ and pred_list[ind + gap + 1] == class_:
                    for offset in range(gap):
                        pred_list[ind + offset] = class_
                    break

    return pred_list

def fix_list_(pred_list):
    """Fill single-token interruptions inside runs of one "I-" label.

    Simpler variant of the gap fixer: when a run of more than two tokens of an
    "I-<class>" label is followed by a token of a different "I-" label, and
    the two tokens after that carry the original label again, the interrupting
    token is relabelled.  "O" and "B-" tokens are never relabelled.

    Args:
        pred_list: List of per-token label strings.  Modified in place.

    Returns:
        The same list object, after the fixes.
    """
    inside_tags = ["I-Lead", "I-Position", "I-Evidence", "I-Claim",
                   "I-Concluding Statement", "I-Counterclaim", "I-Rebuttal"]

    for tag in inside_tags:
        other_tags = set(inside_tags) - {tag}

        # Collect positions where a run of more than two `tag` tokens is
        # interrupted by another "I-" tag.
        candidates = []
        run_length = 0
        for position, token in enumerate(pred_list):
            if run_length > 2 and token in other_tags:
                candidates.append(position)
            run_length = run_length + 1 if token == tag else 0

        # Relabel the interruption when the run resumes right after it.
        for position in candidates:
            if position + 2 >= len(pred_list):
                continue
            if pred_list[position + 1] == tag and pred_list[position + 2] == tag:
                pred_list[position] = tag

    return pred_list


def jn(pst, start, end):
    """Join ``pst[start:end]`` into a space-separated string, skipping -1 entries."""
    kept = filter(lambda value: value != -1, pst[start:end])
    return " ".join(map(str, kept))


def link_class(oof, discourse_type, thresh2):
    """Join neighbouring predictions of one discourse type within a document.

    Consecutive predictions of ``discourse_type`` are merged into a single
    prediction unless the next one starts more than one word after the
    previous one ends, or the merged span would reach more than ``thresh2``
    words past its first word.

    Args:
        oof: Prediction DataFrame with at least the columns ``id``, ``class``
            and ``predictionstring`` (space-separated word indices).
        discourse_type: Class whose predictions are joined.
        thresh2: Maximum distance between the first word of a joined span and
            the first word of the next prediction added to it.

    Returns:
        DataFrame holding the joined rows of ``discourse_type`` (only ``id``,
        ``class`` and ``predictionstring`` are filled; other columns are NaN)
        plus all rows of the other classes, ordered per document by the first
        word index.  An empty ``oof`` is returned unchanged.

    Changes compared to the previous version: the unused
    ``eoof.index = eoof[['id', 'class']]`` assignment (a 2-D frame used as an
    index) was removed, rows are selected with a boolean mask instead of an
    f-string ``query``, and the temporary sort column is added via ``assign``
    so no column is written onto a filtered slice.
    """
    if not len(oof):
        return oof

    thresh = 1
    doc_ids = oof['id'].unique()
    eoof = oof[oof['class'] == f"{discourse_type}"]
    neoof = oof[oof['class'] != f"{discourse_type}"]

    retval = []
    for idv in doc_ids:
        q = eoof[eoof['id'] == idv]
        if not len(q):
            continue
        # Flatten all predictions of this document into one list of word
        # indices; -1 marks the boundary in front of each prediction.
        pst = []
        for r in q.itertuples():
            pst = [*pst, -1, *[int(x) for x in r.predictionstring.split()]]
        start, end = 1, 1
        for i in range(2, len(pst)):
            cur = pst[i]
            end = i
            # At a boundary, split when the gap to the next prediction or the
            # total span is too large; otherwise keep extending the span.
            if (
                (cur == -1) and
                ((pst[i + 1] > pst[end - 1] + thresh) or (pst[i + 1] - pst[start] > thresh2))
            ):
                retval.append((idv, discourse_type, jn(pst, start, end)))
                start = i + 1
        retval.append((idv, discourse_type, jn(pst, start, end + 1)))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    roof = roof.merge(neoof, how='outer')

    # Restore reading order inside each document.
    dfs = []
    for doc_id in doc_ids:
        doc_rows = roof[roof['id'] == doc_id]
        first_word = doc_rows['predictionstring'].apply(lambda x: int(x.split(' ')[0]))
        dfs.append(
            doc_rows.assign(start=first_word).sort_values('start').drop('start', axis=1)
        )
    return pd.concat(dfs, axis=0)


def thresh_prob(df, cfg):
    """Drop predictions whose probability is below the per-class threshold.

    Claim predictions of length exactly 2 are penalised by 0.1 before the
    thresholds are applied.

    Args:
        df: Prediction DataFrame with the columns ``class``, ``len`` and
            ``prob``.  It is not modified.
        cfg: Object with a ``proba_thresh`` mapping of class name to minimum
            probability.

    Returns:
        A new DataFrame, sorted by index, without the rejected rows.  The
        ``prob`` column of the penalised rows carries the reduced value.

    The penalty is applied with a boolean mask on a copy; the earlier
    split/modify/concat approach assigned into a filtered slice and triggered
    pandas' SettingWithCopy warning.
    """
    short_claim = (df['class'] == 'Claim') & (df['len'] == 2)
    df = df.copy()
    df.loc[short_claim, 'prob'] -= 0.1
    df = df.sort_index()

    for class_name, min_prob in cfg.proba_thresh.items():
        rejected = df.index[(df['class'] == class_name) & (df['prob'] < min_prob)]
        df = df.drop(rejected)
    return df

# add
def thresh_second_prob(df, cfg):
    """Drop predictions whose second-choice probability is too high.

    Args:
        df: Prediction DataFrame with the columns ``class`` and
            ``second_prob``.
        cfg: Object with a ``second_proba_thresh`` mapping of class name to
            the maximum allowed ``second_prob``.

    Returns:
        The DataFrame sorted by index, without rows of a listed class whose
        ``second_prob`` exceeds that class's limit.
    """
    result = df.sort_index()
    for class_name, limit in cfg.second_proba_thresh.items():
        too_confident = (result['class'] == class_name) & (result['second_prob'] > limit)
        result = result.drop(result.index[too_confident])
    return result

# add
def thresh_third_prob(df, cfg):
    """Drop rows whose ``third_prob`` exceeds the limit set for their class.

    Args:
        df: DataFrame with ``class`` and ``third_prob`` columns.
        cfg: Object exposing ``second_proba_thresh``, a mapping from class
            name to the largest ``third_prob`` a row may have and be kept.

    Returns:
        The index-sorted frame with the offending rows removed.
    """
    # NOTE(review): the limits come from ``cfg.second_proba_thresh`` even
    # though they are applied to ``third_prob`` -- confirm this is intended
    # and not a copy-paste of thresh_second_prob.
    result = df.sort_index()
    for cls_name, limit in cfg.second_proba_thresh.items():
        over_limit = (result['class'] == cls_name) & (result['third_prob'] > limit)
        result = result.drop(result.index[over_limit])
    return result


def thresh_min_token(df, cfg):
    """Drop predictions that span fewer tokens than their class requires.

    A ``len`` column (number of space-separated entries in
    ``predictionstring``) is written onto ``df`` itself before filtering.

    Args:
        df: DataFrame with ``class`` and ``predictionstring`` columns.
        cfg: Object exposing ``min_token_thresh``, a mapping from class name
            to the minimum token count a row of that class needs to be kept.

    Returns:
        The frame with too-short rows removed.
    """
    def _count_tokens(prediction):
        return len(prediction.split(' '))

    df['len'] = df['predictionstring'].apply(_count_tokens)
    for cls_name, min_tokens in cfg.min_token_thresh.items():
        too_short = (df['class'] == cls_name) & (df['len'] < min_tokens)
        df = df.drop(df.index[too_short])
    return df


def get_max_prob(sub):
    """Keep only the most probable span for classes that occur once per essay.

    For ``Lead``, ``Position`` and ``Concluding Statement`` a single row per
    (``id``, ``class``) pair survives: the one with the highest ``prob``.
    Rows of every other class are passed through unchanged.

    Args:
        sub: DataFrame with ``id``, ``class`` and ``prob`` columns. Its
            ``prob`` column is converted to float in place.

    Returns:
        A DataFrame with the selected single-occurrence rows first, followed
        by all rows of the remaining classes.
    """
    sub['prob'] = sub['prob'].astype(float)
    unique_class = ['Lead', 'Position', 'Concluding Statement']
    is_unique = sub['class'].isin(unique_class)
    sub_in_unique = sub[is_unique]
    sub_not_in_unique = sub[~is_unique]
    # idxmax yields the index label of the best row in each (id, class) group.
    best_rows = sub_in_unique.groupby(['id', 'class'])['prob'].idxmax()
    sub_in_unique = sub_in_unique.loc[best_rows, :]
    return pd.concat([sub_in_unique, sub_not_in_unique])


def post_process_sub(sub):
    """Run the full rule-based post-processing chain on a prediction frame.

    Applies, in order: probability thresholds, minimum-token thresholds,
    best-span selection for single-occurrence classes, and span linking for
    Evidence, Counterclaim and Rebuttal. All settings are read from the
    module-level ``cfg``.

    Args:
        sub: Prediction DataFrame accepted by ``thresh_prob``.

    Returns:
        A re-indexed DataFrame restricted to the columns
        ``id``, ``class``, ``predictionstring``, ``prob`` and ``len``.
    """
    filtered = thresh_prob(sub, cfg)
    filtered = thresh_min_token(filtered, cfg)
    linked = get_max_prob(filtered)
    for discourse_type in ('Evidence', 'Counterclaim', 'Rebuttal'):
        linked = link_class(linked, discourse_type, cfg.link[discourse_type])
    output_columns = ['id', 'class', 'predictionstring', 'prob', 'len']
    return linked.reset_index(drop=True)[output_columns]


# main

In [None]:
# Load the sample submission; only its `id` column is used to enumerate the
# test essays.
test_df = pd.read_csv(os.path.join("../input/feedback-prize-2021/", "sample_submission.csv"))

# With fewer than 100 rows a placeholder submission is written up front.
# NOTE(review): execution continues past this point and submission.csv is
# written again at the end of the notebook -- confirm that is intended.
if len(test_df) < 100:
    test_df['class'] = 'Claim'
    test_df.to_csv('submission.csv', index=False)

#test_df = pd.read_csv(os.path.join("../input/feedback-prize-2021/", "train.csv"))
#test_df = test_df[test_df['id'].isin(test_df['id'].unique()[:1000])].drop_duplicates(subset=['id'])
#test_df = test_df.drop_duplicates(subset=['id'])

test_ids = test_df['id'].unique()
# Explicitly exclude this one document id from inference.
test_ids = test_ids[~(test_ids=='AD005493F9BF')]
# for debug
# test_ids = test_ids[:100]
test_df = test_df[test_df['id'].isin(test_ids)]

# Tokenize once with the Longformer tokenizer purely to measure token counts.
tokenizer = AutoTokenizer.from_pretrained('../input/longformerlarge4096/longformer-large-4096')
test_samples = prepare_test_data(test_df, tokenizer, cfg.num_jobs)

# get test token_len
# assumes prepare_test_data returns samples in the same order as the rows of
# test_df -- TODO confirm
token_len_list = [len(test_samples[i]['input_ids']) for i in range(len(test_samples))]
test_df['token_len'] = token_len_list

# The samples are rebuilt per length bucket later, so free them here.
del test_samples

# sort test data and rebuild test_df
test_df = test_df.sort_values('token_len')
test_df.reset_index(drop=True, inplace = True)

#test_df_A = test_df.query('token_len < 600')
#test_df_B = test_df.query('token_len >= 600')

#test_ids_A = test_df_A['id'].unique()
#test_ids_B = test_df_B['id'].unique()

In [None]:
test_df

In [None]:
#inference(test_df_A, cfg, 32, 'short')
#inference(test_df_B, cfg, 8, 'long')

In [None]:
def create_id_df(samples):
    """Build a long-format token table with a fixed number of rows per sample.

    Every sample contributes exactly ``cfg.max_len_test_longformer`` rows.
    Positions beyond the sample's own ``offset_mapping`` are filled with the
    sentinel offset ``(9999, 9999)``.

    Args:
        samples: Indexable collection of dicts, each with ``id`` and
            ``offset_mapping`` (a list) entries.

    Returns:
        DataFrame with columns ``offset``, ``id``, ``token_num`` (position
        within the sample) and ``token_len`` (unpadded mapping length).
    """
    filler = (9999, 9999)
    offsets, doc_ids, positions, lengths = [], [], [], []

    for idx in tqdm(range(len(samples))):
        width = cfg.max_len_test_longformer
        mapping = samples[idx]['offset_mapping']
        n_tokens = len(mapping)
        offsets.extend(mapping + [filler] * (width - n_tokens))
        doc_ids.extend([samples[idx]['id']] * width)
        positions.extend(range(width))
        lengths.extend([n_tokens] * width)

    id_df = pd.DataFrame()
    id_df['offset'] = offsets
    id_df['id'] = doc_ids
    id_df['token_num'] = positions
    id_df['token_len'] = lengths
    return id_df

In [None]:
def add_max_class(oof, samples, tokenizer, cfg, fold_num=0):
    """Attach the three most probable labels per token to every sample.

    For each sample ``i`` the keys ``pred_class``/``pred_prob``,
    ``second_class``/``second_prob`` and ``third_class``/``third_prob`` are
    written onto the sample dict, holding the best, second-best and
    third-best label (via ``id_target_map``) and its score for every token
    position except position 0.

    Args:
        oof: Array indexed as [sample, token, class] with at least three
            classes on the last axis.
        samples: List of sample dicts; the dicts are updated in place.
        tokenizer: Unused.
        cfg: Unused.
        fold_num: Unused.

    Returns:
        ``samples``, whose dicts now carry the added keys.
    """
    # A shallow copy shares the sample dicts, so writing through ``annotated``
    # also updates the entries of ``samples``.
    annotated = samples.copy()
    order = np.argsort(oof, axis=-1)
    ranked = np.sort(oof, axis=-1)

    # (class key, prob key, label ids, scores) for rank 1, 2 and 3; the last
    # axis is sorted ascending, so the best entry sits at position -1.
    tiers = [
        ('pred_class', 'pred_prob', order[:, :, -1], ranked[:, :, -1]),
        ('second_class', 'second_prob', order[:, :, -2], ranked[:, :, -2]),
        ('third_class', 'third_prob', order[:, :, -3], ranked[:, :, -3]),
    ]

    for i in tqdm(range(len(samples))):
        for class_key, prob_key, label_ids, scores in tiers:
            # Position 0 is a special token, so it is excluded.
            annotated[i][class_key] = [id_target_map[c] for c in list(label_ids[i, 1:])]
            annotated[i][prob_key] = list(scores[i, 1:])

    return samples

In [None]:
import re

def word_list(text):
    """Split ``text`` into words that each keep their trailing whitespace.

    Concatenating the returned pieces reproduces ``text`` exactly. Text that
    starts with whitespace yields a first piece made of that whitespace only.

    Args:
        text: The string to split.

    Returns:
        List of strings, one per word, each followed by the whitespace run
        that came after it in ``text`` (the last word may have none).
    """
    # Raw string: '\s' in a normal literal is an invalid escape sequence.
    parts = re.split(r'(\s+)', text)
    # With a capturing group the result alternates word, separator, word, ...
    # and always ends on a word slot (empty if text ends with whitespace).
    words = [word + sep for word, sep in zip(parts[::2], parts[1::2])]
    # Keep the final word when the text does not end with whitespace; the
    # previous pairing logic silently dropped it.
    if parts[-1]:
        words.append(parts[-1])
    return words

def discourse_offset(text, lst, start, end):
    """Locate the character span of words ``start``..``end`` inside ``text``.

    The pieces ``lst[start:end + 1]`` are concatenated and searched for in
    ``text``; the first occurrence is used.

    Args:
        text: Full document text.
        lst: Word pieces of the document.
        start: Index of the first word piece (inclusive).
        end: Index of the last word piece (inclusive).

    Returns:
        Tuple ``(first_char, last_char)`` with an inclusive end. When the
        concatenated pieces are not found, ``first_char`` is -1.
    """
    span = ''.join(lst[start:end + 1])
    begin = text.find(span)
    return (begin, begin + len(span) - 1)

def add_offset(sub):
    """Add character offsets and a per-essay running number to ``sub``.

    The frame is modified in place: ``text`` is looked up from the
    module-level ``txt_dict``, ``offset`` holds the character span returned
    by ``discourse_offset``, ``start``/``end`` end up as the two components
    of that span (replacing the word indices first stored there), and
    ``discourse_num`` numbers the rows within each ``id``.

    Args:
        sub: DataFrame with ``id``, ``class`` and ``predictionstring``.

    Returns:
        The same DataFrame with the extra columns.
    """
    def _first_index(prediction):
        return int(prediction.split()[0])

    def _last_index(prediction):
        return int(prediction.split()[-1])

    def _char_span(row):
        pieces = word_list(row['text'])
        return discourse_offset(row['text'], pieces, row['start'], row['end'])

    # Word indices first; they are overwritten by character offsets below.
    sub['start'] = sub['predictionstring'].map(_first_index)
    sub['end'] = sub['predictionstring'].map(_last_index)

    sub['text'] = sub['id'].map(txt_dict)
    sub['offset'] = sub.apply(_char_span, axis=1)

    sub['start'] = sub['offset'].map(lambda span: span[0])
    sub['end'] = sub['offset'].map(lambda span: span[1])
    sub['discourse_num'] = sub.groupby('id')['class'].cumcount()

    return sub

In [None]:
def create_sub_and_id_df(oof, id_df, test_samples):
    """Build span predictions and the token rows that fall inside each span.

    Args:
        oof: Per-token class scores passed to ``add_max_class``.
        id_df: Token table from ``create_id_df`` with an ``offset_start``
            column added.
        test_samples: Samples matching ``oof``.

    Returns:
        Tuple ``(sub, id_df_2)``: ``sub`` is the span frame from
        ``create_submission`` enriched by ``add_offset``; ``id_df_2`` holds
        the rows of ``id_df`` whose ``offset_start`` lies within a span's
        ``start``..``end``. Both carry a ``discourse_key`` column of the form
        ``"<id>_<discourse_num>"``.
    """
    sample1 = add_max_class(oof, test_samples, tokenizer, cfg, fold_num=0)
    sub = create_submission(sample1)
    sub = add_offset(sub)

    # The merge multiplies every token row by the number of spans in its
    # essay, so it is done 300 essays at a time to bound the intermediate.
    docs_per_chunk = 300
    rows_per_chunk = docs_per_chunk * cfg.max_len_test_longformer
    spans = sub[['id', 'discourse_num', 'start', 'end']]

    # Collect the pieces and concatenate once: DataFrame.append was removed
    # in pandas 2.0 and re-copied the accumulated frame on every iteration.
    chunks = []
    # NOTE(review): the chunk count comes from the global ``test_ids`` rather
    # than from ``id_df``; surplus iterations just produce empty chunks.
    for i in tqdm(range(len(test_ids) // docs_per_chunk + 1)):
        a = id_df.iloc[i * rows_per_chunk:(i + 1) * rows_per_chunk, :].merge(spans, on=['id'], how='left').reset_index(drop=True)
        a = a[(a['offset_start'] >= a['start']) & (a['offset_start'] <= a['end'])]
        chunks.append(a)
    id_df_2 = pd.concat(chunks)

    id_df_2['discourse_key'] = id_df_2['id'].astype(str) + '_' + id_df_2['discourse_num'].astype(str)
    sub['discourse_key'] = sub['id'].astype(str) + '_' + sub['discourse_num'].astype(str)

    return sub, id_df_2

In [None]:
def create_prob_agg_df(oof):
    """Aggregate per-token label scores into per-discourse statistics.

    Args:
        oof: DataFrame with a ``discourse_key`` column and one score column
            per key of ``target_id_map`` (``PAD`` excluded).

    Returns:
        float16 DataFrame indexed by ``discourse_key`` whose columns are
        named ``"<label>_<stat>"``; the stats are the three aggregations
        below plus the 0.2 and 0.8 quantiles.
    """
    value_list = [k for k in list(target_id_map.keys()) if k != 'PAD']

    dfg = oof.groupby('discourse_key')[value_list]

    # The aggregation callables are left untouched: the column suffixes are
    # derived from them, and those names feed the downstream feature list.
    dfg_stats = dfg.agg([np.mean, np.max, np.min]).stack()
    dfg_quantiles = dfg.quantile([0.2, 0.8])

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    dfg_stats = pd.concat([dfg_stats, dfg_quantiles]).sort_index()
    prob_agg_df = dfg_stats.unstack().astype(np.float16)

    prob_agg_df.columns = [col[0] + '_' + str(col[1]) for col in prob_agg_df.columns.values]
    return prob_agg_df

In [None]:
def create_oof(model_type):
    """Load the pickled predictions of ``model_type`` for the current bucket.

    The file name is ``pred_<model_type>_<mode>.pkl`` where ``mode`` is the
    module-level variable of that name, which must be set before calling.

    Args:
        model_type: Model name used in the prediction file name.

    Returns:
        The unpickled object.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # presumably only reads files written earlier in this notebook -- confirm.
    # The context manager closes the handle, which the bare open() leaked.
    with open(f'pred_{model_type}_{mode}.pkl', 'rb') as f:
        return pickle.load(f)

In [None]:
def create_feats(model_type, id_df, id_df_2):
    """Compute per-discourse score statistics for one model.

    Args:
        model_type: Model whose predictions are loaded via ``create_oof``.
        id_df: Token table providing a ``token_key`` per flattened token row.
        id_df_2: Token rows inside predicted spans, with ``token_key`` and
            ``discourse_key`` columns.

    Returns:
        The frame from ``create_prob_agg_df`` with every column name
        prefixed by ``"<model_type>_"``.
    """
    scores = create_oof(model_type)
    # Flatten the first two axes so that each row is one token position.
    n_rows = scores.shape[0] * scores.shape[1]
    oof = pd.DataFrame(scores.reshape(n_rows, scores.shape[2]))
    oof.columns = oof.columns.map(id_target_map)

    # Tag each row with its token key, keep only tokens that belong to a
    # predicted span, then translate token keys into discourse keys.
    oof['token_key'] = id_df['token_key']
    in_span = oof['token_key'].isin(id_df_2['token_key'])
    oof = oof[in_span]
    token_to_discourse = dict(zip(id_df_2['token_key'], id_df_2['discourse_key']))
    oof['discourse_key'] = oof['token_key'].map(token_to_discourse)

    prob_agg_df = create_prob_agg_df(oof)
    prob_agg_df.columns = [model_type + '_' + name for name in prob_agg_df.columns]
    return prob_agg_df

In [None]:
def create_sub(model_type, test_samples, id_df, model_paths):
    """Produce the span predictions of a single model.

    Args:
        model_type: Model whose predictions are loaded via ``create_oof``.
        test_samples: Samples matching the loaded predictions.
        id_df: Token table forwarded to ``create_sub_and_id_df``.
        model_paths: Unused (only referenced by a disabled per-model
            feature-merging step).

    Returns:
        DataFrame with columns ``id``, ``class``, ``predictionstring``,
        ``start`` and ``end``.
    """
    predictions = create_oof(model_type)
    # The token-level frame is not needed here; only the spans are returned.
    sub, _ = create_sub_and_id_df(predictions, id_df, test_samples)
    wanted = ['id', 'class', 'predictionstring', 'start', 'end']
    return sub[wanted]

In [None]:
def create_ensemble_sub_and_feats(models, test_samples, id_df, weight):
    """Blend all models, extract spans and attach per-model span features.

    Args:
        models: Model names understood by ``create_oof``.
        test_samples: Samples matching the stored predictions.
        id_df: Token table from ``create_id_df``.
        weight: One blending weight per entry of ``models``; weights are
            normalised by their sum.

    Returns:
        The span frame of the blended prediction with a ``text_len`` column
        and, merged on ``discourse_key``, the feature columns of every model.
    """
    # The blend buffer is hard-wired to 15 classes per token.
    blended = np.zeros([len(test_samples), cfg.max_len_test_longformer , 15])
    weight_total = np.sum(weight)
    for name, w in zip(models, weight):
        blended += create_oof(name) * w / weight_total

    sub, id_df_2 = create_sub_and_id_df(blended, id_df, test_samples)

    per_model_feats = [create_feats(name, id_df, id_df_2) for name in models]
    feats = pd.concat(per_model_feats, axis=1)
    del id_df_2

    sub['text_len'] = sub['text'].map(len)

    return sub.merge(feats.reset_index(), on='discourse_key', how='left')

In [None]:
def create_result(df, bs, mode):
    """Run inference on one length bucket and build its prediction frames.

    Args:
        df: Test rows of the bucket.
        bs: Batch size forwarded to ``inference``.
        mode: Bucket name forwarded to ``inference``.

    Returns:
        Tuple ``(all_sub, ensemble_sub)``: the stacked per-model span
        predictions (tagged with ``model_name``) and the ensemble span frame
        with features. Uses the module-level ``models`` and ``weight``.
    """
    # NOTE(review): create_oof reads the module-level ``mode`` variable, not
    # this parameter, so callers must keep the two in sync.
    test_samples = prepare_test_data(df, tokenizer, cfg.num_jobs)
    inference(df, cfg, bs, mode)
    id_df = create_id_df(test_samples)
    id_df['offset_start'] = id_df['offset'].map(lambda x: x[0])

    id_df['token_key'] = id_df['id'].astype(str) + '_' + id_df['token_num'].astype(str)

    # Gather the per-model frames and concatenate once; DataFrame.append was
    # removed in pandas 2.0 and copied the growing frame on every call.
    per_model = []
    for model_type in models:
        # reset_index returns a fresh frame, so the column assignment below
        # no longer writes into a column slice of create_sub's result.
        sub = create_sub(model_type, test_samples, id_df, models).reset_index(drop=True)
        sub['model_name'] = model_type
        per_model.append(sub)
    all_sub = pd.concat(per_model) if per_model else pd.DataFrame()

    ensemble_sub = create_ensemble_sub_and_feats(models, test_samples, id_df, weight)

    return all_sub, ensemble_sub

In [None]:
test_df.query('token_len >= 450 and token_len < 600')

In [None]:
import gc

# Cache the raw text of every test essay; add_offset looks texts up here.
txt_dict = {}
for ids in tqdm(test_ids):
    txt_dict[ids] = get_test_text(ids)

# Distinct model names (5th field of each checkpoint tuple) and, per name,
# the summed weight (6th field) of all its checkpoints.
# NOTE(review): iteration order of a set is not guaranteed to be stable
# across runs; `weight` is built from the same list, so the two stay aligned.
models = list(set([a[4] for a in cfg.model_ckp_path]))
weight = [np.sum([w for k, _, _, _, m, w in cfg.model_ckp_path if m == model_type]) for model_type in models]    

# Essays are processed in three token-length buckets with decreasing batch
# size. The global `mode` is set before each call because create_oof reads it
# to pick the prediction file.
test_df_A = test_df.query('token_len < 460')
mode = 'short'
all_sub_A, ensemble_sub_A = create_result(test_df_A, 40, 'short')
gc.collect()

test_df_B = test_df.query('token_len >= 460 and token_len < 600')
mode = 'medium'
all_sub_B, ensemble_sub_B = create_result(test_df_B, 32, 'medium')
gc.collect()

test_df_C = test_df.query('token_len >= 600')
mode = 'long'
all_sub_C, ensemble_sub_C = create_result(test_df_C, 8, 'long')
gc.collect()

# Stitch the buckets back together and release the per-bucket frames.
all_sub = pd.concat([all_sub_A, all_sub_B, all_sub_C], axis=0)
del all_sub_A, all_sub_B, all_sub_C 
ensemble_sub = pd.concat([ensemble_sub_A, ensemble_sub_B, ensemble_sub_C], axis=0)
del ensemble_sub_A, ensemble_sub_B, ensemble_sub_C
gc.collect()

In [None]:
# Agreement features: for each (id, class, value) the fraction of models that
# produced exactly that predictionstring / start / end.
for col in ['predictionstring', 'start', 'end']:
    dic = (all_sub.groupby(['id', 'class', col])['model_name'].nunique() / all_sub['model_name'].nunique()).to_dict()
    all_sub[f'dupli_{col}'] = all_sub.set_index(['id', 'class', col]).index.map(dic)
    # Ensemble spans that no single model produced get 0.
    ensemble_sub[f'dupli_{col}'] = ensemble_sub.set_index(['id', 'class', col]).index.map(dic).fillna(0)

In [None]:
# Candidate feature columns: everything except identifiers, raw text and
# bookkeeping columns. This list is replaced by the model's own feature names
# once the LightGBM model is loaded below.
feats = [col for col in ensemble_sub.columns if not col in ['id', 'class', 'predictionstring', 'text', 'offset', 'fold', 'is_tp', 'discourse_key', 'model_name']]

In [None]:
all_sub

# lgbm postprocess

In [None]:
# Load the fold-0 LightGBM post-processing model to obtain the feature names
# (and their order) it expects.
# NOTE(review): pickle.load runs arbitrary code from the file and the handle
# from open() is never closed.

model = pickle.load(open('../input/lgbm-for-postprocess/lgb_fold0.pkl', 'rb'))
feats = model.feature_name_



In [None]:
# Integer encoding of the seven discourse classes (alphabetical order).
class_dict = dict(zip(['Claim', 'Concluding Statement', 'Counterclaim', 'Evidence',
       'Lead', 'Position', 'Rebuttal'], range(7)))

class_dict_inv = {v: k for k, v in class_dict.items()}

# Accumulator for the fold-averaged probability, one entry per span.
disc_prob = np.zeros(len(ensemble_sub))

# Column names containing 'Concluding Statement' are rewritten with an
# underscore in place of the space.
ensemble_sub.columns = [col.replace('Concluding Statement','Concluding_Statement') for col in ensemble_sub.columns]

# `sub` keeps the human-readable prediction columns; `ensemble_sub` is cut
# down to exactly the model's feature columns.
sub = ensemble_sub[['id', 'class', 'predictionstring', 'prob', 'second_prob', 'len']].reset_index(drop=True)
ensemble_sub = ensemble_sub[feats]

# Encode class-valued features as integers; labels missing from class_dict
# become 7.
if 'class' in feats:
    ensemble_sub['class'] = ensemble_sub['class'].map(class_dict).fillna(7)
    
if 'second_class' in feats:
    ensemble_sub['second_class'] = ensemble_sub['second_class'].map(class_dict).fillna(7)
    
if 'third_class' in feats:
    ensemble_sub['third_class'] = ensemble_sub['third_class'].map(class_dict).fillna(7)

In [None]:
# Average the positive-class probability of the five fold models.
for fold in range(5):
    model = pickle.load(open(f'../input/lgbm-for-postprocess/lgb_fold{fold}.pkl', 'rb'))
    disc_prob += model.predict_proba(ensemble_sub)[:, 1] / 5
    #sub[f'disc_prob_{fold}'] = model.predict_proba(ensemble_sub)[:, 1]
    
# Row order of `sub` and `ensemble_sub` is the same, so the array lines up.
sub['disc_prob'] = disc_prob

In [None]:
sub

In [None]:
# postprocess

# Minimum span probability per class (read by thresh_prob).
cfg.proba_thresh = {
        "Lead": 0.49, # 0.7
        "Position": 0.29, # 0.55
        "Evidence": 0.54,
        "Claim": 0.44,
        "Concluding Statement": 0.49, # 0.7
        "Counterclaim": 0.48,
        "Rebuttal": 0.44,
    }
# Minimum span length in tokens per class (read by thresh_min_token).
cfg.min_token_thresh = {
        "Lead": 4, # 5
        "Position": 3, # 4
        "Evidence": 11, # 14
        "Claim": 1, # 2
        "Concluding Statement": 7, # 7
        "Counterclaim": 4, # 6
        "Rebuttal": 3, # 4
    }
# Per-class value passed as the third argument of link_class.
cfg.link = {
        'Evidence': 40,
        'Counterclaim': 200,
        'Rebuttal': 200,
    }

# NOTE(review): `proba_thresh` is filled by mapping class names through
# target_id_map, not cfg.proba_thresh. If target_id_map has no entries for the
# plain class names this column is all NaN and the first OR-branch below never
# fires -- confirm which mapping was intended.
sub['proba_thresh'] = sub['class'].map(target_id_map)
# Keep a span only if disc_prob > 0.19 AND at least one of:
#   prob exceeds proba_thresh by more than 0.1, disc_prob > 0.28,
#   Counterclaim with prob - second_prob > 0.25, or class is Rebuttal.
sub = sub[((sub['disc_prob'] > 0.19) & ((sub['prob'] - sub['proba_thresh'] > 0.1) |(sub['disc_prob'] > 0.28) | ((sub['class'].isin(['Counterclaim'])) & (sub['prob'] - sub['second_prob'] > 0.25)) | (sub['class'].isin(['Rebuttal'])) ))][['id', 'class', 'predictionstring', 'prob', 'len']]

In [None]:
# Apply the rule-based post-processing chain, then remove exact duplicate rows.
sub = post_process_sub(sub[['id', 'class', 'predictionstring', 'prob', 'len']]).drop_duplicates()

In [None]:
sub

In [None]:
len(feats)

# submission

In [None]:
%%time
# Keep only the three submission columns and write the final file.
sub = sub[['id', 'class', 'predictionstring']]
sub.to_csv('submission.csv', index=False)

# Visualize

In [None]:
class Color:
    """ANSI escape sequences used to colour terminal output.

    The class-named attributes give the foreground colour used for each
    discourse class; ``END`` resets all attributes.
    """
    BLACK = '\033[30m'
    Lead = '\033[31m'
    Position = '\033[32m'
    Claim = '\033[33m'
    Counterclaim = '\033[34m'
    Rebuttal = '\033[35m'
    Evidence = '\033[36m'
    ConcludingStatement = '\033[37m'
    END = '\033[0m'
    # Was '\038[1m': '\03' followed by a literal '8' is not the ESC
    # character, so the sequence was printed as garbage instead of bold.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    INVISIBLE = '\033[08m'
    REVERCE = '\033[07m'
    # Correctly spelled alias; the misspelled name is kept for existing users.
    REVERSE = REVERCE


def visualize(sub, doc_id):
    """Print an essay with each predicted span coloured by discourse class.

    Words are taken from ``get_test_text(doc_id)`` split on whitespace and
    printed separated by single spaces. Words not covered by any prediction,
    or covered by a class without a colour, are printed uncoloured. When
    spans overlap, the later row of ``sub`` wins.

    Args:
        sub: DataFrame with ``id``, ``class`` and ``predictionstring``.
        doc_id: Id of the essay to print.
    """
    class_colors = {
        'Lead': Color.Lead,
        'Position': Color.Position,
        'Claim': Color.Claim,
        'Counterclaim': Color.Counterclaim,
        'Rebuttal': Color.Rebuttal,
        'Evidence': Color.Evidence,
        'Concluding Statement': Color.ConcludingStatement,
    }

    # Word index -> class. A dict replaces the fixed 100000-slot list, so
    # any word index is accepted and nothing is allocated up front.
    word_type = {}
    df = sub.query(f' id == "{doc_id}" ')
    for _cls, _pred in zip(df['class'], df['predictionstring']):
        for j in _pred.split():
            word_type[int(j)] = _cls

    text = get_test_text(doc_id)
    for i, word in enumerate(text.split()):
        color = class_colors.get(word_type.get(i))
        if color is None:
            print(word, end=' ')
        else:
            print(f'{color}{word}{Color.END}', end=' ')

In [None]:
visualize(sub, '18409261F5C2')

In [None]:
visualize(sub, 'D46BCB48440A')

In [None]:
visualize(sub, '0FB0700DAF44')

In [None]:
visualize(sub, 'D72CB1C11673')

In [None]:
visualize(sub, 'DF920E0A7337')