In [None]:
import gc
import glob
import yaml
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModel, AutoConfig

# IPython line magic (not plain Python): sets the TOKENIZERS_PARALLELISM
# environment variable to "true" for this kernel.
%env TOKENIZERS_PARALLELISM = true

# Pick the CUDA device when PyTorch reports one available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Debug switch: True reads a 200-file sample from the train folder,
# False reads every file from the test folder.
IS_DEBUG = False

# Same value in both modes.
MAX_NUM_MODELS = 1

if IS_DEBUG:
    TEST_ROOT, NUM_SAMPLES = "../input/feedback-prize-2021/train/", 200
else:
    TEST_ROOT, NUM_SAMPLES = "../input/feedback-prize-2021/test/", None

print(f"TEST_ROOT: {TEST_ROOT}\nMAX_NUM_MODELS: {MAX_NUM_MODELS}\nNUM_SAMPLES: {NUM_SAMPLES}")

# Crodoc

In [None]:
import re
import pandas as pd

# =======================================================================================================#
def from_text_to_span(preds,b,i,tolb='',otol=0):
    """Find the runs of one discourse class inside a string of per-token letters.

    Args:
        preds: string holding one character per token (the token's class letter).
        b: character standing for the "begin" label of the class.
        i: character standing for the "inside" label of the class (may equal ``b``).
        tolb: upper bound spliced into the ``{1,tolb}`` quantifier applied to
            ``b``; the default ``''`` leaves the run of ``b`` unbounded.
        otol: maximum number of consecutive characters that are neither ``b``
            nor ``i`` which may be bridged between two runs of ``i``.
            ``0`` (int or str) disables bridging.

    Returns:
        List of ``(start, end)`` tuples (end exclusive) indexing into ``preds``,
        in left-to-right order. Empty when nothing matches.
    """
    begin_run = f'{b}{{1,{tolb}}}'
    inside_run = f'{i}{{1,}}'

    if str(otol)!="0":
        # One or more "<short gap><inside run>" continuations after the first run.
        bridged = f'([^{i}{b}]{{1,{otol}}}{inside_run})+'
        pat = (
            f'({begin_run}{inside_run}{bridged})'
            f'|({inside_run}{bridged})'
            f'|({begin_run}{inside_run})'
            f'|({inside_run})'
            f'|({begin_run})'
        )
    else:
        pat = f'({begin_run}{inside_run})|({inside_run})|({begin_run})'

    # A single pass is enough: finditer simply yields nothing when there is no match.
    return [match.span() for match in re.finditer(pat, preds)]


# =======================================================================================================#
def from_token_proba_to_span_df(text,idx,token_proba,dicos,offset,tolb='',otol=0):
    """Turn per-token class probabilities of one essay into a frame of predicted spans.

    Args:
        text: raw essay text the character offsets refer to.
        idx: essay id written to the ``id`` column.
        token_proba: array indexed as ``[token, class]``; argmax / max are taken
            over the last axis.
        dicos: mapping ``class name -> (begin label id, inside label id)``.
        offset: per-token ``(char_start, char_end)`` pairs.
        tolb, otol: forwarded to ``from_text_to_span``.

    Returns:
        DataFrame with columns ``id, class, discourse_start, discourse_end,
        score, predictionstring, discourse_type, start, end``. ``score`` is the
        ``str()`` of the list of per-token max probabilities inside the span and
        ``predictionstring`` holds space-separated word indices.
    """
    ID_TO_TXT = {i:j for i,j in zip(range(15),'abcdefghijklmnopqrstuvwxyz')}
    token_class = token_proba.argmax(-1)
    token_score = token_proba.max(-1)

    # Keep only tokens whose offset differs from (0, 0)
    # (presumably special / padding tokens carry (0, 0) -- confirm for the tokenizer in use).
    indices = np.unique(np.where(np.asarray(offset)!=(0,0))[0])
    offset = np.asarray(offset)[indices].tolist()
    token_class = token_class[indices]
    # Spans found below index the *filtered* token sequence, so the scores must
    # be filtered the same way; otherwise every score is read from a shifted position.
    token_score = token_score[indices]

    token_class_text = [ID_TO_TXT[i] for i in token_class]
    class_string = ''.join(token_class_text)

    output = []
    for c,(b,i) in dicos.items():
        spans = from_text_to_span(class_string,ID_TO_TXT[b],ID_TO_TXT[i],tolb=tolb,otol=otol)
        for s,e in spans:

            start,_ = offset[s]
            _,end = offset[e-1]
            if (start+end)>0:
                # Convert character offsets to whitespace-word indices.
                word_start = len(text[:start].split())
                word_end = word_start + len(text[start:end].split())
                word_end = min(word_end, len(text.split()))
                ps = " ".join([str(x) for x in range(word_start, word_end)])

                output.append([idx,c,start,end,str(token_score[s:e].tolist()),ps])


    output = pd.DataFrame(output, columns=['id','class','discourse_start',
                                               'discourse_end','score','predictionstring'])

    output['discourse_type'] = output['class']
    output[["start","end"]] = output[["discourse_start","discourse_end"]]
    return output
# =======================================================================================================#
def run_token_to_span_ac(tokens_probas,ids,test_df,offsets,dicos,tolb='',otol=0):
    """Build the span frame of every essay in ``ids`` and stack them.

    For position ``k`` the essay text is the first ``test_df`` row whose ``id``
    equals ``ids[k]``; ``tokens_probas[k]`` and ``offsets[k]`` are passed along
    with it to ``from_token_proba_to_span_df``. The per-essay frames are
    concatenated without resetting the index.
    """
    per_essay_frames = []

    for pos in range(len(ids)):
        essay_id = ids[pos]
        essay_text = test_df.loc[test_df.id == essay_id, 'text'].iloc[0]
        per_essay_frames.append(
            from_token_proba_to_span_df(
                essay_text, essay_id, tokens_probas[pos], dicos, offsets[pos], tolb, otol
            )
        )

    return pd.concat(per_essay_frames)

In [None]:
class TextDataset(Dataset):
    """Whole-essay dataset: each item is the tokenized form of one full text.

    Every text in ``df['text']`` is tokenized once at construction time
    (truncated to ``self.max_length`` tokens). The tokenized tensors are kept
    in ``self.x`` and the matching character offsets in
    ``self.offset_mappings``; ``self.ids`` mirrors ``df['id']``.

    Args:
        df: frame with ``id`` and ``text`` columns.
        tokenizer: callable tokenizer that supports ``return_offsets_mapping``.
        cfg: accepted for signature parity with ``CutTextDataset``; not read here.
    """

    def __init__(self, df, tokenizer, cfg):

        self.tokenizer = tokenizer
        self.max_length = 4096

        self.texts = df['text'].values.tolist()
        self.ids = df['id'].values.tolist()

        self.x, self.offset_mappings = [], []

        for text in self.texts:
            x, offset_mapping = self.make_item(text)
            self.x.append(x)
            self.offset_mappings.append(offset_mapping)

    def get_offset_mapping(self, text):
        """Tokenize ``text`` and return ``(offset_mapping, skip_indices)``.

        ``skip_indices`` are the token positions whose sequence id is not 0.
        """
        tokenized = self.tokenizer(
            text,
            add_special_tokens = True,
            max_length = self.max_length,
            truncation=True,
            return_offsets_mapping = True,
        )

        offset_mapping = tokenized['offset_mapping']
        skip_indices = np.where(np.array(tokenized.sequence_ids()) != 0)[0]

        return offset_mapping, skip_indices

    def make_item(self, text):
        """Tokenize ``text`` once and return ``(tensor fields, offset_mapping)``.

        The offsets come from the same tokenizer call as the model inputs (the
        approach ``CutTextDataset.make_item`` uses), instead of running the
        tokenizer a second time just to obtain them.
        """
        tokenized = self.tokenizer(
            text,
            add_special_tokens = True,
            max_length = self.max_length,
            truncation=True,
            return_offsets_mapping = True,
        )

        # Offsets are bookkeeping, not a model input: pull them out before the
        # remaining fields are converted to tensors.
        offset_mapping = tokenized['offset_mapping']
        del tokenized['offset_mapping']

        for k, v in tokenized.items():
            tokenized[k] = torch.tensor(v, dtype=torch.long)

        return tokenized, offset_mapping

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.x[idx]

class CustomCollator():
    """Batch collator that delegates padding to ``DataCollatorWithPadding``."""

    def __init__(self, tokenizer):
        self.data_collator = DataCollatorWithPadding(tokenizer)

    def __call__(self, batch):
        # The wrapped collator receives the items as a plain list.
        return self.data_collator(list(batch))


class TextDataModule(pl.LightningDataModule):
    """Prediction-only data module wrapping an already-built test dataset."""

    def __init__(
        self,
        test_df,
        tokenizer,
        cfg,
        test_dataset
    ):
        super().__init__()
        self.test_df, self.tokenizer, self.cfg = test_df, tokenizer, cfg
        self.test_dataset = test_dataset

    def setup(self, stage):
        # Nothing to prepare: the dataset is supplied ready-made.
        pass

    def predict_dataloader(self):
        # Loader options come from cfg["val_loader"]; padding is done per batch.
        return DataLoader(
            self.test_dataset,
            **self.cfg["val_loader"],
            collate_fn=CustomCollator(self.tokenizer),
        )

In [None]:
class CutTextDataset(Dataset):
    """Sliding-window dataset: each item is one fixed-size window over a tokenized essay.

    Every text is tokenized once without truncation; windows of at most
    ``cfg['max_length_valid']`` tokens are then cut every ``cfg['stride']``
    tokens. The last window of a text is anchored to the end of the text.

    Attributes:
        x: full tokenized essays (dict-like of 1-D tensors), one per text.
        offset_mappings: per-essay list of ``(char_start, char_end)`` token offsets.
        x_cut: the windows, in the order they are served by ``__getitem__``.
        text_indexes: ``(text_index, start_token)`` for each entry of ``x_cut``.
    """

    def __init__(self, df, tokenizer, cfg):

        self.tokenizer = tokenizer
        self.max_length = cfg['max_length_valid']

        self.texts = df['text'].values.tolist()
        self.ids = df['id'].values.tolist()
        self.stride = cfg['stride']

        self.x, self.x_cut, self.offset_mappings, self.text_indexes = [], [], [], []

        for text_index, text in enumerate(self.texts):
            x, offset_mapping = self.make_item(text)

            self.x.append(x)
            self.offset_mappings.append(offset_mapping)

            start = 0
            total_tokens = len(offset_mapping)

            reached_end = False

            while start < total_tokens and not reached_end:

                # `>=` (not `>`): a window ending exactly on the last token is
                # already the final one. With `>` the loop advanced once more
                # and emitted the same end-anchored window a second time.
                if start + self.max_length >= total_tokens:
                    start = max(0, total_tokens - self.max_length)
                    reached_end = True

                x_cut, _ = self.get_cut_item(x, offset_mapping, start)

                self.x_cut.append(x_cut)
                self.text_indexes.append((text_index, start))

                start += self.stride

    def get_cut_element(self, tokenized_element, start, length, is_list=False):
        """Return ``tokenized_element[start:start+length]``; tensors are cloned, lists are not."""
        new_tokenized_element = tokenized_element[start:start+length]
        if not is_list:
            new_tokenized_element = new_tokenized_element.clone()

        return new_tokenized_element

    def get_cut_item(self, tokenized, offset_mapping, start):
        """Slice every field of ``tokenized`` (and the offsets) to one window starting at ``start``."""
        cut_length = min(self.max_length, len(offset_mapping))

        new_tokenized = {}

        for k in tokenized:
            new_tokenized[k] = self.get_cut_element(tokenized[k], start, cut_length)

        if offset_mapping is not None:
            offset_mapping = self.get_cut_element(offset_mapping, start, cut_length, is_list=True)

        return new_tokenized, offset_mapping

    def make_item(self, text):
        """Tokenize ``text`` without truncation and return ``(tensor fields, offset_mapping)``."""
        tokenized = self.tokenizer(
            text,
            add_special_tokens = True,
            return_offsets_mapping = True,
        )

        offset_mapping = tokenized['offset_mapping']
        del tokenized['offset_mapping']

        for k, v in tokenized.items():
            tokenized[k] = torch.tensor(v, dtype=torch.long)

        return tokenized, offset_mapping

    def __len__(self):
        return len(self.x_cut)

    def __getitem__(self, idx):
        return self.x_cut[idx]

In [None]:
class TextModel(pl.LightningModule):
    """Token classifier: backbone encoder, averaged dropout branches, linear head."""

    def __init__(self, cfg, config_path=None):
        super().__init__()

        self.cfg = cfg
        self.num_labels = cfg['model']['num_labels']

        # The backbone is built from a serialized config object; its weights
        # are expected to come from the checkpoint that is loaded afterwards.
        self.config = torch.load(config_path)
        self.backbone = AutoModel.from_config(self.config)

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        # dropout1 .. dropout5 with rates 0.1 .. 0.5, registered in that order.
        for branch, rate in enumerate((0.1, 0.2, 0.3, 0.4, 0.5), start=1):
            setattr(self, f'dropout{branch}', nn.Dropout(rate))

        self.fc = nn.Linear(self.config.hidden_size, self.num_labels)

    def forward(self, x):
        hidden = self.dropout(self.backbone(**x).last_hidden_state)

        branches = [
            layer(hidden)
            for layer in (self.dropout1, self.dropout2, self.dropout3, self.dropout4, self.dropout5)
        ]
        averaged = (branches[0] + branches[1] + branches[2] + branches[3] + branches[4]) / 5.0

        return self.fc(averaged)

    def predict_step(self, batch, batch_idx):
        # Per-token class probabilities, moved to the CPU.
        return self(batch).softmax(dim=-1).detach().cpu()

    def configure_optimizers(self):
        return None

In [None]:
def get_test_df():
    """Load the essays under ``TEST_ROOT`` into a frame with ``id`` and ``text`` columns.

    At most ``NUM_SAMPLES`` files are read (all of them when it is ``None``).
    No-break spaces are replaced by plain spaces and NEL (U+0085) by ``\\n``.
    Rows are ordered from the longest text to the shortest.
    """
    import os  # local import: this cell runs before the notebook's `import os`

    # os.path.join works whether TEST_ROOT is a str with a trailing slash or a
    # pathlib.Path (a later cell rebinds it to one). Sorting makes the
    # NUM_SAMPLES slice and the order of equal-length texts reproducible,
    # since glob returns files in arbitrary order.
    test_files = sorted(glob.glob(os.path.join(str(TEST_ROOT), '*.txt')))[:NUM_SAMPLES]

    values = []

    for test_file in test_files:
        test_id = os.path.splitext(os.path.basename(test_file))[0]

        with open(test_file, 'r', encoding='utf-8') as f:
            text = f.read()

        # no-break space
        text = text.replace(u'\xa0', u' ')
        # next line
        text = text.replace(u'\x85', u'\n')

        values.append((test_id, text))

    # Longest first (stable, so ties keep the sorted-filename order).
    values.sort(key=lambda x: -len(x[1]))

    test_df = pd.DataFrame(values, columns=['id','text'])
    return test_df

In [None]:
def text_to_words(text):
    """Split ``text`` on whitespace and locate every word in the original string.

    Returns:
        ``(words, offsets)`` where ``words`` is ``text.split()`` and
        ``offsets[k]`` is the ``(start, end)`` character span (end exclusive)
        of ``words[k]`` in ``text``.

    Raises:
        NotImplementedError: if a word cannot be found after the previous one.
    """
    word = text.split()
    word_offset = []

    cursor = 0
    for w in word:
        # Search from the cursor instead of slicing `text[cursor:]`, which
        # would copy the remainder of the text for every word.
        found = text.find(w, cursor)
        if found == -1:
            raise NotImplementedError

        cursor = found + len(w)
        word_offset.append((found, cursor))

    return word, word_offset

def word_probability_to_predict_df(text_to_word_probability, id):
    """Group consecutive word-level labels into discourse spans for one essay.

    Args:
        text_to_word_probability: array indexed as ``[word, label]``; argmax and
            max are taken over the last axis.
        id: essay id copied into every output row.

    Returns:
        DataFrame with columns ``id, class, predictionstring, score``;
        ``predictionstring`` holds space-separated word indices and ``score``
        is the ``str()`` of the list of per-word max probabilities in the span.

    Relies on the module-level ``discourse_marker_to_label`` and
    ``label_to_discourse_marker`` tables.

    NOTE(review): every exit test is ``t == len_word-1``, so the last word is
    never part of a span, and inputs with fewer than two rows look like they
    raise IndexError -- confirm before reusing this helper.
    """
    
    len_word = len(text_to_word_probability)
    word_predict = text_to_word_probability.argmax(-1)
    word_score   = text_to_word_probability.max(-1)
    predict_df = []

    t = 0
    while 1:
        # Any label other than 'O' opens a span at word t.
        if word_predict[t] not in [
            discourse_marker_to_label['O'],
        ]:
            start = t
            b_marker_label = word_predict[t]
        else:
            t = t+1
            if t== len_word-1:break
            continue

        t = t+1
        # NOTE(review): leaving here discards the span that was just opened.
        if t== len_word-1: break

        # A 'B-*' label is continued by the next label id (its 'I-*' partner);
        # every other label is continued by itself.
        if   label_to_discourse_marker[b_marker_label][0]=='B':
            i_marker_label = b_marker_label+1
        else:
            i_marker_label = b_marker_label

        # Extend the span while the continuation label repeats.
        while 1:
            if (word_predict[t] != i_marker_label) or (t ==len_word-1):
                end = t  # exclusive
                prediction_string = ' '.join([str(i) for i in range(start,end)])
                # Drop the two-character 'B-' / 'I-' / 'X-' prefix.
                discourse_type = label_to_discourse_marker[b_marker_label][2:]
                discourse_score = word_score[start:end].tolist()
                predict_df.append((id, discourse_type, prediction_string, str(discourse_score)))
                break
            else:
                t = t+1
                continue
        # t is not advanced here: the word that ended the span is re-examined
        # as a possible start on the next outer iteration.
        if t== len_word-1: break

    predict_df = pd.DataFrame(predict_df, columns=['id', 'class', 'predictionstring', 'score'])
    return predict_df


# def do_threshold(submit_df, use=['length','probability']):
#     df = submit_df.copy()
#     df = df.fillna('')

#     if 'length' in use:
#         df['l'] = df.predictionstring.apply(lambda x: len(x.split()))
#         for key, value in min_thresh.items():
#             #value=3
#             index = df.loc[df['class'] == key].query('l<%d'%value).index
#             df.drop(index, inplace=True)

#     if 'probability' in use:
#         df['s'] = df.score.apply(lambda x: np.mean(eval(x)))
#         for key, value in proba_thresh.items():
#             index = df.loc[df['class'] == key].query('s<%f'%value).index
#             df.drop(index, inplace=True)

#     df = df[['id', 'class', 'predictionstring']]
#     return df


def do_threshold(submit_df, use=['length','probability']):
    """Filter predicted spans by per-class minimum length and mean probability.

    Args:
        submit_df: frame with ``id``, ``class``, ``predictionstring`` and
            ``score`` columns, where ``score`` is the string form of a list of
            floats. Not modified. Its index should be unique (rows are dropped
            by index label).
        use: which filters to apply: ``'length'`` uses the module-level
            ``min_thresh`` table and ``'probability'`` uses ``proba_thresh``.

    Returns:
        Frame with columns ``id, class, predictionstring, num_tokens, score,
        start, end``. ``score`` is the mean probability when ``'probability'``
        is in ``use`` (otherwise the original string); ``start`` / ``end`` are
        the first word index and the last word index + 1.

    NOTE: a later cell of this notebook defines another ``do_threshold`` with a
    different signature; this is the version ``get_sub_crodoc`` calls.
    """
    import ast  # local import: `ast` is not imported at the top of the notebook

    df = submit_df.copy()
    df = df.fillna('')

    # Always computed: the result exposes it as `num_tokens` even when the
    # length filter is switched off (previously that raised a KeyError).
    df['l'] = df.predictionstring.apply(lambda x: len(x.split()))

    if 'length' in use:
        for key, value in min_thresh.items():
            index = df.loc[df['class'] == key].query('l<%d'%value).index
            df.drop(index, inplace=True)

    if 'probability' in use:
        # literal_eval only parses the list literal; unlike eval it cannot run code.
        df['score'] = df.score.apply(lambda x: np.mean(ast.literal_eval(x)))
        for key, value in proba_thresh.items():
            index = df.loc[df['class'] == key].query('score<%f'%value).index
            df.drop(index, inplace=True)

    # Spans without any word cannot be submitted and have no start / end.
    df.drop(df.index[df['l'] == 0], inplace=True)

    df.rename(columns={"l": "num_tokens"}, inplace=True)
    df["start"] = df.predictionstring.apply(lambda x: int(x.split()[0]))
    df["end"] = df.predictionstring.apply(lambda x: int(x.split()[-1]) + 1)
    df = df[['id', 'class', 'predictionstring', "num_tokens", "score", "start", "end"]]
    return df

In [None]:
def reset_crodoc():
    """Reload the test essays and rebuild the module-level per-essay buffers.

    Rebinds ``test_df``, ``num_labels`` and the parallel lists ``text_ids``,
    ``text_lenghts`` (character counts), ``text_words``, ``text_word_offsets``
    and ``text_word_preds`` (zero-filled float32 arrays, one row per word).
    """
    global text_words, text_word_offsets, text_word_preds, text_ids, text_lenghts, test_df, num_labels

    test_df = get_test_df()
    num_labels = 10

    text_words, text_word_offsets, text_word_preds, text_ids, text_lenghts = [], [], [], [], []

    for essay_id, essay_text in zip(test_df['id'], test_df['text']):
        words, offsets = text_to_words(essay_text)

        text_ids.append(essay_id)
        text_lenghts.append(len(essay_text))
        text_words.append(words)
        text_word_offsets.append(offsets)
        text_word_preds.append(np.zeros((len(words), num_labels), dtype=np.float32))

In [None]:
def update_word_preds(model_preds, offset_mappings, coef):
    """Accumulate token-level predictions into the module-level word-level buffers.

    For each essay the token rows (scaled by ``coef``) are painted onto a
    per-character grid using the token offsets; each word then receives the
    mean of the characters it covers, added in place to ``text_word_preds``.
    """
    for essay_pos, essay_preds in enumerate(model_preds):

        char_scores = np.zeros((text_lenghts[essay_pos], num_labels), dtype=np.float32)

        for token_pos, (char_from, char_to) in enumerate(offset_mappings[essay_pos]):
            char_scores[char_from:char_to] = essay_preds[token_pos] * coef

        word_scores = text_word_preds[essay_pos]
        for word_pos, (char_from, char_to) in enumerate(text_word_offsets[essay_pos]):
            word_scores[word_pos] += char_scores[char_from:char_to].mean(0)

In [None]:
def merge_cut_preds(model_preds, dataset):
    """Stitch per-window predictions back into one array per essay.

    Args:
        model_preds: one ``[token, label]`` array per window, in the order of
            ``dataset.text_indexes``.
        dataset: object exposing ``text_indexes`` (``(text_index, start)`` per
            window), ``offset_mappings`` (per essay) and ``stride``.

    Returns:
        List with one ``(num_tokens, num_labels)`` array per essay. Where
        windows overlap, a later window overwrites the earlier one except for
        its first ``stride // 2`` tokens.
    """
    text_indexes = dataset.text_indexes

    overlap = dataset.stride // 2

    index = 0
    preds_tmp = []

    while index < len(model_preds):

        text_index, _ = text_indexes[index]
        offset_mapping = dataset.offset_mappings[text_index]

        # Label count is taken from the predictions rather than hard-coded.
        num_labels = np.shape(model_preds[index])[-1]
        preds = np.zeros((len(offset_mapping), num_labels))

        # Consume every consecutive window that belongs to this essay.
        while index < len(model_preds):
            curr_text_index, start = text_indexes[index]

            if curr_text_index != text_index:
                break

            curr_preds = model_preds[index]

            if start == 0:
                # The first window may be longer than the essay (batch padding).
                length = min(len(preds), len(curr_preds))
                preds[:length] = curr_preds[:length]
            elif start + len(curr_preds) > len(offset_mapping):
                preds[-len(curr_preds)+overlap:] = curr_preds[overlap:]
            else:
                preds[start+overlap:start+len(curr_preds)] = curr_preds[overlap:]

            index += 1

        preds_tmp.append(preds)

    return preds_tmp

In [None]:
# Token label ids, assigned in listing order (0 .. 9).
discourse_marker_to_label = {
    marker: label_id
    for label_id, marker in enumerate([
        'O',
        'B-Claim',
        'I-Claim',
        'B-Evidence',
        'I-Evidence',
        'X-Lead',
        'X-Position',
        'X-Counterclaim',
        'X-Rebuttal',
        'X-Concluding Statement',
    ])
}

# Per-class minimum number of words a predicted span must contain.
min_thresh = {
    "Lead": 5,
    "Position": 3,
    "Evidence": 10,
    "Claim": 2,
    "Concluding Statement": 5,
    "Counterclaim": 6,
    "Rebuttal": 5,
}

# Per-class minimum mean probability a predicted span must reach.
proba_thresh = {
    "Lead": 0.55,
    "Position": 0.55,
    "Evidence": 0.55,
    "Claim": 0.56,
    "Concluding Statement": 0.55,
    "Counterclaim": 0.56,
    "Rebuttal": 0.57,
}

# class name -> [begin label id, inside label id]
DICOS = {
    "Lead": [5, 5],
    "Position": [6, 6],
    "Claim": [1, 2],
    "Counterclaim": [7, 7],
    "Rebuttal": [8, 8],
    "Evidence": [3, 4],
    "Concluding Statement": [9, 9],
}

# Inverse lookup: label id -> marker string.
label_to_discourse_marker = dict(
    (label_id, marker) for marker, label_id in discourse_marker_to_label.items()
)

def get_sub_crodoc(token_preds,ids,test_df,offset_mappings):
    """Convert token predictions into thresholded span predictions.

    Spans are extracted with ``run_token_to_span_ac`` (class table ``DICOS``,
    ``tolb=''``, ``otol=2``), re-indexed from 0, then filtered by length and
    probability with ``do_threshold``.
    """
    spans = run_token_to_span_ac(token_preds, ids, test_df, offset_mappings, DICOS, tolb='', otol=2)
    spans = spans.reset_index(drop=True)

    return do_threshold(spans, use=['length', 'probability'])

In [None]:
# Crodoc model families and, for each, the [start:end) slice of its sorted
# checkpoint files that is ensembled.
model_names = ['cp-deberta-xlarge-v2', 'deberta-bs2']
model_weights_crodoc = [0.60, 0.40]

model_start_ends = [(0, 3), (2, 5)]

# One thresholded span frame per model family.
subs_crodoc = []


for model_name, (start, end) in zip(model_names, model_start_ends):

    # Reload test_df and rebuild the module-level per-essay buffers.
    reset_crodoc()

    with open('../input/' + model_name + '/hparams.yml', 'r') as f:
        cfg = yaml.safe_load(f)

    cfg['val_loader']['num_workers'] = 2
    cfg['val_loader']['batch_size'] = 8

    config_path = '../input/' + model_name + '/config.pth'
    # glob returns files in arbitrary order; sort so that the [start:end)
    # slice selects the same checkpoints on every run.
    model_paths = sorted(glob.glob('../input/' + model_name + '/*.ckpt'))[start:end]

    folds = len(model_paths)
    if folds == 0:
        raise FileNotFoundError(
            f"no checkpoint selected for {model_name} with slice [{start}:{end}]"
        )
    model_preds = []

    print(f"{model_name}: {folds}")

    tokenizer = AutoTokenizer.from_pretrained('../input/' + model_name + '/tokenizer/tokenizer')

    # A positive stride in the config selects sliding-window inference.
    use_windows = 'stride' in cfg and cfg['stride'] > 0
    if use_windows:
        test_dataset = CutTextDataset(test_df, tokenizer, cfg)
    else:
        test_dataset = TextDataset(test_df, tokenizer, cfg)

    for model_path in model_paths:

        datamodule = TextDataModule(test_df, tokenizer, cfg, test_dataset)
        trainer = pl.Trainer(logger=False, **cfg['trainer'])
        model = TextModel.load_from_checkpoint(checkpoint_path=model_path, cfg=cfg, config_path=config_path)

        fold_preds = trainer.predict(model, datamodule)

        # Running mean over folds: each fold contributes pred / folds.
        if not model_preds:
            for pred_batch in fold_preds:
                for pred in pred_batch:
                    model_preds.append(pred.numpy().copy()/folds)
        else:
            idx = 0
            for pred_batch in fold_preds:
                for pred in pred_batch:
                    model_preds[idx] += pred.numpy().copy()/folds
                    idx += 1

        del fold_preds
        del trainer
        del model
        del datamodule

        gc.collect()
        torch.cuda.empty_cache()

    if use_windows:
        # Stitch the window predictions back into one array per essay.
        model_preds = merge_cut_preds(model_preds, test_dataset)

    subs_crodoc.append(get_sub_crodoc(model_preds,text_ids,test_df,test_dataset.offset_mappings))

    del test_dataset
    del tokenizer
    del model_preds

    gc.collect()
    torch.cuda.empty_cache()

    print(subs_crodoc[-1].shape)

In [None]:
# Outside debug mode, release the per-essay buffers built by reset_crodoc().
if not IS_DEBUG:
    del text_words
    del text_word_offsets
    del text_word_preds
    del text_ids
    del text_lenghts
    gc.collect()
    torch.cuda.empty_cache()

# Kkiller

In [None]:
import sys, os

# Each directory is inserted at position 0, so the one listed last ends up
# first on sys.path.
for tool_dir in ("../input/fprize-kkiller-tools/fprize", "../input/weighted-boxes-fusion"):
    sys.path.insert(0, tool_dir)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
import pandas as pd, numpy as np
import torch

from tqdm.auto import tqdm
from pathlib import Path
from  datetime import datetime

import mtask_v2.src.inference as inference
import mtask_v2.src.dataset as dataset
import mtask_v2.src.configs as configs

from mtask_v2.src.dataset import read_from_id, read_train_df
from mtask_v2.src.post_processing import get_seg_from_ner
from mtask_v2.src.wbf import fusion_boxes_for_subs

In [None]:
# q_crodoc = 0.50
# iou_thr = 0.333
# skip_box_thr = 0.1
# # out_skip_box_thr = 0.25

In [None]:
# Add a numeric `class_id` column to every crodoc frame (in place), mapping the
# `class` name through configs.Discourse2ID.
# NOTE: `sub_crodoc` is a plain loop variable; after the loop it stays bound to
# the LAST frame of subs_crodoc, not to a fused result.
for sub_crodoc in subs_crodoc:
    sub_crodoc["class_id"] = sub_crodoc["class"].map(configs.Discourse2ID)

# sub_crodoc = fusion_boxes_for_subs(subs_crodoc, model_weights_crodoc, iou_thr=iou_thr, skip_box_thr=skip_box_thr)

# if not IS_DEBUG:
#     del subs_crodoc
#     gc.collect()
    
# print(sub_crodoc.shape)
# sub_crodoc = sub_crodoc.query(f"score >= {out_skip_box_thr}")
# print(sub_crodoc.shape)
# sub_crodoc.head()

In [None]:
# Device for the kkiller models: CUDA when available, CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
# From here on TEST_ROOT is a pathlib.Path (idempotent when re-run).
TEST_ROOT = Path(TEST_ROOT)

configs.TRAIN_ROOT = TEST_ROOT

# Run on every essay that was loaded into test_df. The ids used to be read from
# `sub_crodoc`, which at this point is only the loop variable left over from the
# class_id cell (the last crodoc frame), so essays for which that single model
# kept no span were silently skipped.
uuids = test_df["id"].unique()

# Longest essays first.
uuids = sorted(uuids, key=lambda uuid: -len(read_from_id(uuid, root=TEST_ROOT).split()))

print(len(uuids))
uuids[:10]

In [None]:
def _kkiller_param(model_name, batch_size, weight, assets_dir, model_paths):
    """Build one inference parameter set; only the arguments differ between models."""
    return inference.get_params(
        model_name=model_name,
        batch_size=batch_size,
        maxlen=1024,
        stride=1024,
        num_workers=2,
        weight=weight,
        config_path=assets_dir,
        tokenizer_path=assets_dir,
        is_pickle=False,
        device=DEVICE,
        model_paths=model_paths,
        root=TEST_ROOT,
        use_position_embeddings=False,
    )


params = [
    # deberta-xlarge: first five checkpoints (sorted by path).
    _kkiller_param(
        "microsoft/deberta-xlarge",
        14,
        .60,
        "../input/fprize-kkiller-tools/microsoft_deberta-xlarge",
        sorted(Path("../input/cp-deberta-xlarge-kkiller").glob("*.pth"))[:5],
    ),
    # deberta-large: every checkpoint from the fourth one on (sorted by path).
    _kkiller_param(
        "microsoft/deberta-large",
        24,
        .40,
        "../input/fprize-kkiller-tools/microsoft_deberta-large",
        sorted(Path("../input/gdrive-db1l-1024-v2-v11-no-pe-weights/microsoft_deberta-large_maxlen1024_clb_mtask_msd_v2_v11_no_pe/"
                               ).glob("*.pth"))[3:],
    ),
]

# The model weights must add up to 1.
S = sum(entry["weight"] for entry in params)
assert abs(S- 1.0) < 1e-6
params[0]

In [None]:
# One span frame and one weight per kkiller parameter set.
subs_kkiller = []
model_weights_kkiller = []

for param in params:
    checkpoint_paths = param["model_paths"]
    print("{}: {}".format(Path(checkpoint_paths[0]).stem,  len(checkpoint_paths)))

    preds, preds_seg  = inference.predict_from_param(
        uuids=uuids, param=param, make_sub=False, oof=False, model_bar=False
    )

    # Blend the segmentation output with the one derived from the NER output.
    preds_seg = 0.60 * preds_seg + 0.40 * get_seg_from_ner(preds)

    sub_frame = inference.make_sub_from_res(
        uuids=uuids, res=preds, res_seg=preds_seg, q=0.015, threshs=None
    )
    subs_kkiller.append(sub_frame)
    model_weights_kkiller.append(param["weight"])

    print(subs_kkiller[-1].shape)

print("model_weights_kkiller:", model_weights_kkiller)

# Raw predictions are kept only in debug mode.
if not IS_DEBUG:
    del preds, preds_seg
    gc.collect()
    
# print(sub_kkiller.shape)
# sub_kkiller = sub_kkiller.query(f"score >= {out_skip_box_thr}")
# print(sub_kkiller.shape)
# sub_kkiller.head()

In [None]:
# sub_kkiller.shape, sub_crodoc.shape

In [None]:
# sub_crodoc["class_id"] = sub_crodoc["class"].map(configs.Discourse2ID)
# sub_crodoc.head()

# Box Fusion

In [None]:
# Per-class minimum word count for fused spans.
min_thresh_for_wbf = {
    "Lead": 3,
    "Position": 4,
    "Evidence": 4,
    "Claim": 2,
    "Concluding Statement": 9,
    "Counterclaim": 5,
    "Rebuttal": 2,
}

# Per-class minimum score for fused spans.
proba_thresh_for_wbf = {
    "Lead": 0.27,
    "Position": 0.28,
    "Evidence": 0.39,
    "Claim": 0.30,
    "Concluding Statement": 0.36,
    "Counterclaim": 0.21,
    "Rebuttal": 0.20,
}

In [None]:
def do_threshold(submit_df,min_thresh,proba_thresh,use=['length','probability']):
    """Drop spans that are too short or score too low for their class.

    Works on a copy of ``submit_df`` (NaNs replaced by ``''``). With
    ``'length'`` in ``use`` a helper column ``l`` (word count of
    ``predictionstring``) is added and kept in the result; with
    ``'probability'`` the numeric ``score`` column is compared against
    ``proba_thresh``. Classes absent from a table are left untouched.
    """
    df = submit_df.copy().fillna('')

    if 'length' in use:
        df['l'] = df.predictionstring.apply(lambda x: len(x.split()))
        for key, value in min_thresh.items():
            too_short = df.loc[df['class'] == key].query('l<%d'%value).index
            df = df.drop(too_short)

    if 'probability' in use:
        for key, value in proba_thresh.items():
            too_weak = df.loc[df['class'] == key].query('score<%f'%value).index
            df = df.drop(too_weak)

    return df

In [None]:
# Fuse the span frames of every model family (kkiller first, then crodoc).
subs = subs_kkiller+subs_crodoc
if not subs:
    raise ValueError("no prediction frames to fuse")

# Equal weight per frame, derived from the actual number of frames instead of
# a hard-coded 4 so the cell keeps working when the model lists change.
weights = [1 / len(subs)] * len(subs)
iou_thr = 0.3333
skip_box_thr = 0.01
sub = fusion_boxes_for_subs(subs, weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)

# Only the probability filter is applied to the fused spans.
sub = do_threshold(sub.reset_index(drop=True),min_thresh_for_wbf,proba_thresh_for_wbf,use=['probability'])

In [None]:
# print(sub.shape)
# sub = sub.query(f"score >= {out_skip_box_thr}")
# print(sub.shape)

# Write the three submission columns, then show the first rows in the notebook.
submission_columns = ["id", "class", "predictionstring"]
sub[submission_columns].to_csv("submission.csv", index=False)

sub.head(30)