In [None]:
import sys
sys.path.insert(0, '../input/feedbackdebertav2tokenizer')
import os
import torch
import torch as t
import random
torch.autograd.set_grad_enabled(False)
from tqdm import tqdm
from glob import glob
import pandas as pd
import numpy as np
import re
import pickle
from transformers import (DebertaV2Model, DebertaV2TokenizerFast,
                          RobertaTokenizerFast, DebertaModel, DebertaV2Config, AutoModel)
from tqdm.notebook import tqdm
from torch.nn import functional as F
from time import sleep
import gc

In [None]:
# --- Ensembling switches -------------------------------------------------
# True: collect_ps averages log-softmax outputs over the folds of one model;
# False: it averages softmax outputs instead.
average_folds_logits = True
# Together with average_folds_logits this decides whether the per-model
# outputs are passed through np.exp / np.log before models are combined.
add_models_logits = False

# --- Post-processing thresholds -------------------------------------------
# One entry per category index 1..7 (filter_ps zips both lists with range(1, 8)).
# Minimum span length, in tokens, for a predicted entity to be kept.
token_len_filters = [6, 0, 15, 0, 7, 7, 7]
# Minimum length-adjusted score for a predicted entity to be kept.
score_filters = [
    3.12230473,
    3.16391113,
    4.88337086,
    2.95137306,
    2.62062365,
    1.80732187,
    1.85651329,
]

# Number of tokens appended to the end of a span by extend_tokens, looked up by
# max(0, span_size - 2); the last entry is used for spans longer than 15 tokens.
exts = [0, 1, 0, 2, 3, 4, 5, 6, 5, 7, 8, 8, 8, 8, 3]

In [None]:
class Debertav3(torch.nn.Module):
    """DeBERTa-v3-large backbone with three parallel 1D convolutions and a 15-way token head.

    The transformer output is passed through kernel-size 1, 3 and 5 convolutions
    (all length-preserving), the ReLU-activated results are concatenated along the
    channel axis and projected to 15 values per token.
    """

    def __init__(self):
        super().__init__()
        backbone = DebertaV2Model.from_pretrained('../input/deberta-v3-large/deberta-v3-large')
        backbone.pooler = None
        backbone.train()
        self.feats = backbone

        hidden = 1024
        # Attribute names are part of the checkpoint's state-dict keys; keep them stable.
        self.conv1d_layer1 = torch.nn.Conv1d(hidden, hidden, kernel_size=1)
        self.conv1d_layer3 = torch.nn.Conv1d(hidden, hidden, kernel_size=3, padding=1)
        self.conv1d_layer5 = torch.nn.Conv1d(hidden, hidden, kernel_size=5, padding=2)

        self.class_projector = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden * 3),
            torch.nn.Linear(hidden * 3, 15),
        )

    def forward(self, tokens, mask):
        """Return the per-token output of the classification head for `tokens` / `mask`."""
        hidden_states = self.feats(tokens, mask, return_dict=False)[0]
        # Conv1d convolves over the last axis, so move the sequence axis there.
        channels_first = hidden_states.transpose(1, 2)

        branches = [
            F.relu(conv(channels_first))
            for conv in (self.conv1d_layer1, self.conv1d_layer3, self.conv1d_layer5)
        ]

        # Stack the three branches on the channel axis, then restore the original axis order.
        stacked = torch.cat(branches, dim=1).transpose(1, 2)
        return self.class_projector(stacked)
    
class Debertav1Large(t.nn.Module):
    """DeBERTa (v1) backbone followed by LayerNorm + Linear producing 15 values per token."""

    def __init__(self):
        super().__init__()
        backbone = DebertaModel.from_pretrained(
            '../input/feedbackdebertav3dirgy/debertav1_pretrained_model')
        backbone.pooler = None
        self.feats = backbone

        width = 1024
        self.class_projector = t.nn.Sequential(
            t.nn.LayerNorm(width),
            t.nn.Linear(width, 15),
        )

    def forward(self, tokens, mask):
        """Run the backbone and project its first output to 15 values per token."""
        hidden_states = self.feats(tokens, mask, return_dict=False)[0]
        return self.class_projector(hidden_states)
                                    
class Debertav1XLarge(t.nn.Module):
    """DeBERTa (v1) xlarge backbone followed by LayerNorm + Linear producing 15 values per token."""

    def __init__(self):
        super().__init__()
        backbone = DebertaModel.from_pretrained(
            '../input/debertav1xlarge/deberta_xlarge')
        backbone.pooler = None
        self.feats = backbone

        width = 1024
        self.class_projector = t.nn.Sequential(
            t.nn.LayerNorm(width),
            t.nn.Linear(width, 15),
        )

    def forward(self, tokens, mask):
        """Run the backbone and project its first output to 15 values per token."""
        hidden_states = self.feats(tokens, mask, return_dict=False)[0]
        return self.class_projector(hidden_states)

class Debertav2(t.nn.Module):
    """DeBERTa-v2 backbone built from a config file, with a 1536-wide 15-way token head.

    The backbone is instantiated from the config only (no pretrained weights are
    read here); weights are expected to arrive via `load_state_dict`.
    """

    def __init__(self):
        super().__init__()
        backbone_config = DebertaV2Config.from_pretrained(
            '../input/feedbackdebertav2stuff/model_config/config.json')
        backbone = AutoModel.from_config(backbone_config)
        backbone.pooler = None
        self.feats = backbone

        width = 1536
        self.class_projector = t.nn.Sequential(
            t.nn.LayerNorm(width),
            t.nn.Linear(width, 15),
        )

    def forward(self, tokens, mask):
        """Run the backbone and project its first output to 15 values per token."""
        hidden_states = self.feats(tokens, mask, return_dict=False)[0]
        return self.class_projector(hidden_states)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Serves pre-tokenized samples as fixed-size (2048) zero-padded numpy arrays.

    Samples are visited in the order given by `sorted_index`; every other
    constructor argument is a per-sample sequence indexed by the original
    sample position.
    """

    def __init__(self, all_tokens, all_masks, all_bounds, all_index_maps,
                 sample_ids, sorted_index, fix_cls_token):
        self.all_tokens = all_tokens
        self.all_masks = all_masks
        self.all_bounds = all_bounds
        self.all_index_maps = all_index_maps
        self.sample_ids = sample_ids
        self.sorted_index = sorted_index
        self.fix_cls_token = fix_cls_token

    def __len__(self):
        return len(self.sorted_index)

    def __getitem__(self, ix):
        """Return (tokens, mask, offsets, num_tokens, index_map, sample_id) for position `ix`."""
        sample = self.sorted_index[ix]
        sample_tokens = self.all_tokens[sample]
        num_tokens = len(sample_tokens)

        tokens = np.zeros(2048, 'i8')
        tokens[:num_tokens] = sample_tokens
        if self.fix_cls_token:
            # Overwrite the first token id with 1 when requested by the caller.
            tokens[0] = 1

        mask = np.zeros(2048, 'f4')
        mask[:num_tokens] = self.all_masks[sample]

        offsets = np.zeros((2048, 2), 'i4')
        offsets[:num_tokens] = self.all_bounds[sample]

        return (tokens, mask, offsets, num_tokens,
                self.all_index_maps[sample], self.sample_ids[sample])
    
def collate_fn(ins):
    """Batch `Dataset` items, cropping the padded arrays to a multiple of 8.

    Returns `((tokens, mask, offsets), lengths, (index_maps, sample_ids))` where
    the first tuple holds torch tensors cropped to the longest sample length in
    the batch rounded up to a multiple of 8, `lengths` is a tuple of the
    per-sample token counts, and the last tuple holds two plain lists.
    """
    lengths = tuple(sample[3] for sample in ins)
    padded_len = (max(lengths) + 7) // 8 * 8

    batch = tuple(
        torch.from_numpy(np.stack([sample[field][:padded_len] for sample in ins]))
        for field in range(3)
    )

    index_maps = [sample[-2] for sample in ins]
    sample_ids = [sample[-1] for sample in ins]
    return batch, lengths, (index_maps, sample_ids)

In [None]:
def _read_stripped(path):
    """Read a text file, strip surrounding whitespace, and close the handle."""
    with open(path) as handle:
        return handle.read().strip()


def _build_word_index_map(text, space_regex):
    """Map every character position of `text` to the index of the word it belongs to.

    The word counter advances on the first non-blank character after a run of
    blank characters; blank characters themselves carry the index of the word
    that precedes them.
    """
    index_map = []
    current_word = 0
    blank = False
    for char in text:
        if space_regex.match(char) is not None:
            blank = True
        elif blank:
            current_word += 1
            blank = False
        index_map.append(current_word)
    return index_map


def make_dataset(version, batch_size, fix_cls_token=False):
    """Tokenize every test essay and wrap the result in a DataLoader.

    Args:
        version: 1 selects the RoBERTa tokenizer (used with the DeBERTa-v1 models
            in this notebook); 2 selects the DeBERTa-v2 tokenizer; any other value
            selects the DeBERTa-v3 tokenizer. For versions other than 1, newlines
            are replaced by '‽' before tokenization and token start offsets are
            adjusted to skip the leading '▁' marker.
        batch_size: batch size handed to the DataLoader.
        fix_cls_token: forwarded to `Dataset`; when true the first token id of
            every sample is overwritten with 1.

    Returns:
        A `DataLoader` over `Dataset`, iterating samples from longest to shortest
        (by token count) and batching with `collate_fn`.
    """
    filenames = glob('../input/feedback-prize-2021/test/*.txt')
    if version == 1:
        tokenizer = RobertaTokenizerFast.from_pretrained(
            '../input/feedbackdebertav3dirgy/roberta_tokenizer')
        texts = [_read_stripped(x) for x in filenames]
    else:
        if version == 2:
            tokenizer = DebertaV2TokenizerFast.from_pretrained(
                '../input/feedbackdebertav2stuff/tokenizer')
        else:
            tokenizer = DebertaV2TokenizerFast.from_pretrained(
                '../input/deberta-v3-large/deberta-v3-large')
        texts = [_read_stripped(x).replace('\n', '‽') for x in filenames]

    tokenizer.model_max_length = 2048
    # NOTE(review): truncation is not requested here, while Dataset pads into
    # fixed 2048-long arrays — confirm that no essay exceeds 2048 tokens.
    all_tokenizer_outs = tokenizer(texts, return_offsets_mapping=True)
    all_tokens = [all_tokenizer_outs[ix].ids for ix in range(len(texts))]
    all_bounds = [all_tokenizer_outs[ix].offsets for ix in range(len(texts))]
    all_masks = [all_tokenizer_outs[ix].attention_mask for ix in range(len(texts))]
    del all_tokenizer_outs

    if version != 1:
        # Remap one tokenizer-specific id to 128000 (presumably the id of the '‽'
        # newline placeholder in each vocabulary — verify against the tokenizers).
        if version == 2:
            all_tokens = [[x if x != 126599 else 128000
                           for x in sample_tokens] for sample_tokens in all_tokens]
        elif version == 3:
            all_tokens = [[x if x != 126861 else 128000
                           for x in sample_tokens] for sample_tokens in all_tokens]

        inverse_vocab = {y: x for x, y in tokenizer.vocab.items()}
        for sample_ix in range(len(texts)):
            tokens = all_tokens[sample_ix]
            default_offset_mappings = all_bounds[sample_ix]
            num_tokens = len(tokens)
            # First and last positions (the special tokens) get an empty (0, 0) span.
            offset_mappings = [(0, 0)]
            for ix in range(1, num_tokens - 1):
                a, b = default_offset_mappings[ix]
                token = inverse_vocab[tokens[ix]]
                # Multi-character pieces starting with '▁' have their start moved
                # one character right, except for the first real token.
                if len(token) > 1 and token[0] == '▁' and ix != 1:
                    a += 1
                offset_mappings.append((a, b))
            offset_mappings.append((0, 0))
            all_bounds[sample_ix] = offset_mappings

    # Raw string: '\s' in a normal string literal is an invalid escape sequence.
    space_regex = re.compile(r'[\s\n‽]')
    all_index_maps = [_build_word_index_map(text, space_regex) for text in texts]

    # Longest samples first, so batches contain similarly sized sequences.
    sorted_index = sorted(range(len(texts)), key=lambda x: len(all_tokens[x]),
                          reverse=True)
    sample_ids = [x.split('/')[-1].split('.')[0] for x in filenames]
    return t.utils.data.DataLoader(Dataset(all_tokens, all_masks, all_bounds, all_index_maps,
                                           sample_ids, sorted_index, fix_cls_token),
                                   batch_size=batch_size, collate_fn=collate_fn)
        

In [None]:
def collect_ps(model_class, checkpoints, dataset, num_files):
    """Run the first five checkpoints of one architecture over `dataset` and average them.

    For every fold the model output is passed through log-softmax (when the
    module-level `average_folds_logits` is true) or softmax, and accumulated with
    weight 0.2. The first and last token of every sample are dropped.

    Returns:
        (outputs, bounds, token_counts, word_index_maps, sample_ids) where
        `outputs` is a float32 array of shape (num_files, 2048, 15), `bounds` an
        int32 array of shape (num_files, 2048, 2), `token_counts` the number of
        retained tokens per sample, and the last two are lists collected during
        the first fold, in dataset iteration order.
    """
    fold_mean = np.zeros((num_files, 2048, 15), 'f4')
    token_bounds = np.zeros((num_files, 2048, 2), 'i4')
    token_counts = np.zeros((num_files,), 'i4')
    word_index_maps = []
    collected_ids = []

    activation = t.log_softmax if average_folds_logits else t.softmax
    net = model_class().eval().cuda()
    for fold in range(5):
        first_fold = fold == 0
        state = t.load(checkpoints[fold], map_location='cuda:0')
        net.load_state_dict(state, strict=False)

        row = 0  # position of the current sample in the output arrays
        for (tokens, mask, offsets), num_tokens, (word_indices, sample_ids) in tqdm(dataset):
            preds = activation(net(tokens.cuda(), mask.cuda()), -1).cpu().numpy()
            if first_fold:
                # Metadata does not depend on the fold, so record it only once.
                word_index_maps.extend(word_indices)
                collected_ids.extend(sample_ids)
            for in_batch in range(len(sample_ids)):
                length = num_tokens[in_batch]
                inner = length - 2  # token count without the first and last token
                if first_fold:
                    token_counts[row] = inner
                    token_bounds[row, :inner] = offsets[in_batch, 1: length - 1]
                fold_mean[row, :inner] += .2 * preds[in_batch, 1: length - 1]
                row += 1
    del net
    return fold_mean, token_bounds, token_counts, word_index_maps, collected_ids
    

In [None]:
def calc_entity_score(span, ps, c):
    """Mean score of category `c` over the inclusive token span `(start, end)`.

    The first token contributes column `2*c - 1` of `ps`, every following token
    of the span contributes column `2*c`; the sum is divided by the span length.
    """
    start, end = span
    head = ps[start, c * 2 - 1]
    tail = ps[start + 1: end + 1, c * 2].sum()
    return (head + tail) / (end - start + 1)

In [None]:
def first_token_merge(gather_a, gather_b, bounds_a, bounds_b, logits_a, logits_b, new_bounds):
    """Collapse an aligned token walk into one row per merged token and sum the logits.

    `gather_a[i]` / `gather_b[i]` are the token indices of the two tokenizations
    for the i-th aligned piece. A piece is dropped when it merely repeats a token
    index of one side, unless the other side was the one repeating immediately
    before and the index has not been emitted yet.

    The `new_bounds` argument is accepted for call compatibility but ignored:
    bounds are recomputed from `bounds_a` / `bounds_b`.

    Returns:
        (summed_logits, bounds) where bounds is an array whose rows are
        (min start, max end) of the two tokens paired in each kept piece.
    """
    kept_a, kept_b = [0], [0]
    repeated_a_before = repeated_b_before = False

    for pos in range(1, len(gather_a)):
        idx_a, idx_b = gather_a[pos], gather_b[pos]
        repeats_a = idx_a == gather_a[pos - 1]
        repeats_b = idx_b == gather_b[pos - 1]

        if repeats_a:
            # Both sides can never stay on the same token in consecutive pieces.
            assert not repeats_b
            keep = repeated_b_before and not repeated_a_before and kept_a[-1] != idx_a
        elif repeats_b:
            keep = repeated_a_before and not repeated_b_before and kept_b[-1] != idx_b
        else:
            keep = True

        if keep:
            kept_a.append(idx_a)
            kept_b.append(idx_b)
        repeated_a_before, repeated_b_before = repeats_a, repeats_b

    merged_bounds = [
        (min(bounds_a[i, 0], bounds_b[j, 0]), max(bounds_a[i, 1], bounds_b[j, 1]))
        for i, j in zip(kept_a, kept_b)
    ]
    summed_logits = logits_a[kept_a] + logits_b[kept_b]
    return summed_logits, np.array(merged_bounds)

def merge_ps(logits_a, logits_b, bounds_a, bounds_b):
    """Align two tokenizations of the same text by character offsets and sum their outputs.

    Both bound sequences are walked in parallel. Each step records the current
    token index of either side, then advances whichever side ends first (or, on
    equal ends, the side whose next token starts first; both on a tie). The walk
    stops when the ends coincide and either side is on its last token. The
    recorded index pairs are reduced and combined by `first_token_merge`.

    Returns:
        Whatever `first_token_merge` returns: (summed_logits, merged_bounds).
    """
    pos_a = pos_b = 0
    last_a = len(bounds_a) - 1
    last_b = len(bounds_b) - 1
    start_a, end_a = bounds_a[pos_a]
    start_b, end_b = bounds_b[pos_b]

    gather_a, gather_b, piece_bounds = [], [], []
    piece_start = 0

    while True:
        # Every step emits one piece that ends where the earlier token ends.
        piece_end = end_a if end_a < end_b else end_b
        piece_bounds.append((piece_start, piece_end))
        gather_a.append(pos_a)
        gather_b.append(pos_b)

        if end_a == end_b:
            if pos_a == last_a or pos_b == last_b:
                break
            following_a = bounds_a[pos_a + 1][0]
            following_b = bounds_b[pos_b + 1][0]
            step_a = following_a <= following_b
            step_b = following_b <= following_a
        else:
            step_a = end_a < end_b
            step_b = not step_a

        if step_a:
            pos_a += 1
            start_a, end_a = bounds_a[pos_a]
        if step_b:
            pos_b += 1
            start_b, end_b = bounds_b[pos_b]
        # When both sides advance, side A's start is used.
        piece_start = start_a if step_a else start_b

    return first_token_merge(gather_a, gather_b, bounds_a, bounds_b, logits_a, logits_b, piece_bounds)
    


In [None]:
# Ensemble driver: run every architecture over the test set, average its folds
# (inside collect_ps) and merge the per-token outputs of all architectures into
# all_outs / all_bounds / all_token_nums.
num_files = len(glob('../input/feedback-prize-2021/test/*.txt'))
# Running ensemble state; populated by the first model, merged into by the rest.
all_outs = None
all_bounds = None
all_token_nums = None
all_sample_ids = None
all_word_indices = None
# Checkpoint lists are ordered by the last character of one '_'-separated field
# of the file name, parsed as an int (presumably the fold number — verify
# against the actual file names).
v1xl_checkpoints = sorted(glob('../input/debertav1xlarge/clean*'), 
                        key=lambda x: int(x.split('/')[-1].split('_')[1][-1]))
v1l_checkpoints = sorted(glob('../input/feedbackv2/debertav1/*'), 
                        key=lambda x: int(x.split('/')[-1].split('_')[0][-1]))
v2_checkpoints = sorted(glob('../input/feedbackdebertav2xlargeweights/*'), 
                        key=lambda x: int(x.split('/')[-1].split('_')[0][-1]))
v3_checkpoints = sorted(glob('../input/feedback-result/clean_nomalv3_scheduler_1dcnn_unk100_0.8_fold*'), 
                        key=lambda x: int(x.split('/')[-1].split('_')[6][-1]))
# The five tuples below are parallel: tokenizer version, checkpoints, model
# class, batch size and fix_cls_token flag for each ensemble member.
for dataset_version, checkpoints, model, batch_size, fix_cls_token in zip((1, 2, 3, 1),
                                                           (v1xl_checkpoints, v2_checkpoints,
                                                            v3_checkpoints, v1l_checkpoints),
                                                           (Debertav1XLarge, Debertav2,
                                                            Debertav3, Debertav1Large),
                                                           ( 4, 4, 8, 8),
                                                           (True, False, False, False)):
    dataset = make_dataset(version=dataset_version, batch_size=batch_size,
                           fix_cls_token=fix_cls_token)
    (new_outs, new_bounds, new_token_nums,
    new_word_indices, new_sample_ids) = collect_ps(model, checkpoints, 
                                                   dataset, num_files)
    # collect_ps averaged log-softmax outputs when average_folds_logits is set;
    # convert to the space in which models are summed (exp -> probabilities,
    # log -> log-probabilities) when the two switches disagree.
    if average_folds_logits and not add_models_logits:
        new_outs = np.exp(new_outs)
    elif add_models_logits and not average_folds_logits:
        new_outs = np.log(new_outs)
    
        
    if all_outs is None:
        # First model: its outputs, bounds and word maps seed the ensemble.
        all_outs = new_outs
        all_bounds = new_bounds
        all_token_nums = new_token_nums
        all_word_indices = new_word_indices
        all_sample_ids = new_sample_ids
    else:
        # Later models may return samples in a different order (each dataset is
        # sorted by its own token counts), so align them by sample id first.
        sample_id_to_ix = {x: ix for ix, x in enumerate(new_sample_ids)}
        alignment_index = [sample_id_to_ix[x] for x in all_sample_ids] 
        merged_outs = []
        merged_bounds = []
        merged_token_nums = []
        for sample_ix in range(len(all_outs)):
            aligned_ix = alignment_index[sample_ix]
            # Only the first *_token_nums rows of each sample are real tokens.
            logits, bounds = merge_ps(all_outs[sample_ix][:all_token_nums[sample_ix]],
                                      new_outs[aligned_ix][:new_token_nums[aligned_ix]],
                                      all_bounds[sample_ix][:all_token_nums[sample_ix]],
                                      new_bounds[aligned_ix][:new_token_nums[aligned_ix]])
            merged_outs.append(logits)
            merged_bounds.append(bounds)
            merged_token_nums.append(len(logits))
        # From here on the ensemble state is a list of per-sample arrays rather
        # than one padded ndarray.
        all_outs = merged_outs
        all_bounds = merged_bounds
        all_token_nums = merged_token_nums
    
        del new_outs
        del new_bounds
        del new_token_nums
        del new_word_indices
        del new_sample_ids
        
    # Release references and cached GPU memory before the next model is built.
    del dataset
    gc.collect()
    t.cuda.empty_cache()
    # NOTE(review): purpose of this pause is not evident from the code — confirm it is needed.
    sleep(10)

In [None]:
# When true, an entity may start on an I-label if the matching B-label is the
# second most likely class of that token.
START_WITH_I = True
# Kept for backward compatibility only: extract_entities decides look-ahead per
# entity with a local variable and never reads this module-level value.
LOOK_AHEAD = True


def extract_entities(ps, n):
    """Decode per-token class scores into entity spans.

    Class 0 is background; odd classes `2c - 1` are B-labels and even classes
    `2c` are I-labels of category `c`.

    Args:
        ps: 2-D array of per-token class scores (tokens x classes); only the
            ranking of classes within each row is used.
        n: number of leading rows of `ps` to decode.

    Returns:
        Dict mapping category -> list of inclusive `(start, end)` token spans, in
        the order the spans were closed.
    """
    ranked = ps.argsort(-1)[..., ::-1]
    top1 = ranked[:, 0]  # most likely class per token
    top2 = ranked[:, 1]  # second most likely class per token

    all_entities = {}
    current_cat = current_start = current_end = None
    look_ahead = True  # re-evaluated every time an entity is opened

    for ix in range(n):
        label = top1[ix]

        if current_cat is None:
            # Not inside an entity: decide whether this token opens one.
            if label % 2 == 1:
                opened_cat = (label + 1) // 2
            elif label != 0 and START_WITH_I and label == top2[ix] + 1:
                opened_cat = label // 2
            else:
                continue
            current_cat = opened_cat
            current_start = current_end = ix
            # Categories 6 and 7 never bridge over a disagreeing token.
            look_ahead = current_cat not in (6, 7)
            continue

        inside_label = current_cat * 2
        if label == inside_label:
            current_end = ix
            continue

        # A single disagreeing token is absorbed when the entity's I-label is the
        # runner-up here and the winner on the next token. A B-label of the
        # current category always splits instead.
        bridges = (label != inside_label - 1
                   and look_ahead
                   and ix < n - 1
                   and top1[ix + 1] == inside_label
                   and top2[ix] == inside_label)
        if bridges:
            current_end = ix
            continue

        all_entities.setdefault(current_cat, []).append((current_start, current_end))
        if label % 2 == 1:
            # Any B-label immediately opens the next entity.
            current_cat = (label + 1) // 2
            current_start = current_end = ix
            look_ahead = current_cat not in (6, 7)
        else:
            current_cat = current_start = current_end = None

    # Flush an entity that is still open at the end of the sequence.
    if current_cat is not None:
        all_entities.setdefault(current_cat, []).append((current_start, current_end))

    return all_entities

def filter_ps(all_entities, ps):
    """Drop weak entity spans per category and keep a single span for categories 1, 2 and 5.

    A span of category `c` (1..7) survives when it has at least
    `token_len_filters[c - 1]` tokens and its `calc_entity_score` multiplied by
    `length ** 0.2` exceeds `score_filters[c - 1]`. For categories 1, 2 and 5
    only the surviving span with the highest `calc_entity_score` is kept (the
    earliest one on ties).

    `all_entities` is updated in place; its `.items()` view is returned.
    """
    for cat_ix, min_num_tokens, min_score in zip(range(1, 8), token_len_filters, score_filters):
        if cat_ix not in all_entities:
            continue

        survivors = []
        for span in all_entities[cat_ix]:
            length = span[1] - span[0] + 1
            if length >= min_num_tokens and calc_entity_score(span, ps, cat_ix) * length ** .2 > min_score:
                survivors.append(span)

        if cat_ix in (1, 2, 5) and len(survivors) > 1:
            # max() returns the first maximal element, matching a strict '>' scan,
            # and cannot leave the winner unbound or carried over from another
            # category when no score beats a sentinel value.
            best = max(survivors, key=lambda span: calc_entity_score(span, ps, cat_ix))
            survivors = [best]

        all_entities[cat_ix] = survivors

    return all_entities.items()

#clean
def extend_tokens(ent, n):
    """Push the end of the inclusive span `ent` to the right by a size-dependent amount.

    The extension is read from the module-level `exts` table: the last entry for
    spans longer than 15 tokens, otherwise the entry at `max(0, size - 2)`. The
    new end never exceeds `n - 1`. Returns `[start, new_end]`.
    """
    first, last = ent[0], ent[1]
    size = last - first + 1
    table_ix = -1 if size > 15 else max(0, size - 2)
    return [first, min(n - 1, last + exts[table_ix])]


In [None]:
def map_span_to_word_indices(span, index_map, bounds):
    """Convert an inclusive token span into an inclusive (first_word, last_word) pair.

    `bounds[token]` holds the (start, end) character offsets of a token, with the
    end exclusive, and `index_map[char]` gives the word index of a character.
    """
    first_token, last_token = span[0], span[1]
    first_char = bounds[first_token, 0]
    last_char = bounds[last_token, 1] - 1  # end offset is exclusive
    return (index_map[first_char], index_map[last_char])

# Category index -> class name used in the submission file.
label_names = ['None', 'Lead', 'Position', 'Evidence', 'Claim',
               'Concluding Statement', 'Counterclaim', 'Rebuttal']

# Column buffers for the submission frame (one entry per predicted entity).
sub_sample_ids, sub_cat_names, sub_spans, sub_scores = [], [], [], []

for sample_ix in tqdm(range(len(all_token_nums)), leave=False):
    sample_ps = all_outs[sample_ix]
    sample_num_tokens = all_token_nums[sample_ix]
    kept_entities = filter_ps(extract_entities(sample_ps, sample_num_tokens), sample_ps)

    predicted_spans = {}
    for cat_ix, token_spans in kept_entities:
        predicted_spans[cat_ix] = {
            # Word-level span of the entity after extending its end token.
            'entity': [
                map_span_to_word_indices(extend_tokens(span, sample_num_tokens),
                                         all_word_indices[sample_ix], all_bounds[sample_ix])
                for span in token_spans],
            # Score of the original (un-extended) token span.
            'scores': [calc_entity_score(span, sample_ps, cat_ix) for span in token_spans],
        }

    for cat_ix, prediction in predicted_spans.items():
        for entity in prediction['entity']:
            first_word, last_word = entity[0], entity[1]
            sub_sample_ids.append(all_sample_ids[sample_ix])
            sub_cat_names.append(label_names[cat_ix])
            sub_spans.append(' '.join(str(word) for word in range(first_word, last_word + 1)))
        sub_scores.extend(prediction['scores'])

sub = pd.DataFrame({
    'id': sub_sample_ids,
    'class': sub_cat_names,
    'predictionstring': sub_spans,
    'scores': sub_scores,
})

# The scores column is kept in `sub` for inspection but not written to the file.
sub.drop('scores', axis=1).to_csv('submission.csv', index=False)

In [None]:
# Show the prediction frame, including the 'scores' column that is not written to submission.csv.
sub