# Packages

In [None]:
# Install the transformers library, pinned to version 4.19.2, from a local wheel
!pip install -qq ../input/transformers-4-19-2/transformers-4.19.2-py3-none-any.whl

# Imports

In [None]:
# basics
import os
import gc
import sys
import json
from copy import deepcopy
from dataclasses import dataclass
from itertools import chain

# Processing
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# huggingface
from datasets import Dataset
from accelerate import Accelerator
from transformers import AutoConfig, AutoModel, AutoTokenizer, DataCollatorWithPadding
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout

# misc
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# ipython
from IPython.display import display
from IPython.core.debugger import set_trace
from tokenizers import AddedToken

# Enable other models

In [None]:
# Switches for this notebook: each `use_exp*` flag names one experiment
# (the short label above it is the experiment's nickname).
# debv3-l 8 fold
use_exp1 = True

# debv3-l multihead lstm
use_exp3 = False

# debv3-l resolved data
use_exp4 = False

# dexl
use_exp6 = False

# debl kd from dexl
use_exp8 = False

# debv3-l revisit
use_exp10 = False

# debv3-l uda
use_exp11 = False

# debv3-l 10 fold
use_exp16 = True

# v3L+b
use_exp102 = False

# debv3-l prompt 8 fold
use_exp205 = False

# debv3-l prompt 10 fold LB 0.565
use_exp209 = True

# longformer
use_exp212 = True

# EXP 214 - Debv2-XL + Aug data
use_exp214 = False

use_full_data_models = True

# enable for running sampled (3k) train vs test
debug = False

In [None]:
import pickle
from textblob import TextBlob

# functions for separating the POS Tags
def adjectives(text):
    """Count the words in `text` whose POS tag is exactly 'JJ'."""
    tagged_words = TextBlob(text).tags
    return sum(1 for _word, pos_tag in tagged_words if pos_tag == 'JJ')
def verbs(text):
    """Count the words in `text` whose POS tag starts with 'VB'."""
    tagged_words = TextBlob(text).tags
    return sum(1 for _word, pos_tag in tagged_words if pos_tag.startswith('VB'))
def adverbs(text):
    """Count the words in `text` whose POS tag starts with 'RB'."""
    tagged_words = TextBlob(text).tags
    return sum(1 for _word, pos_tag in tagged_words if pos_tag.startswith('RB'))
def nouns(text):
    """Count the words in `text` whose POS tag starts with 'NN'."""
    tagged_words = TextBlob(text).tags
    return sum(1 for _word, pos_tag in tagged_words if pos_tag.startswith('NN'))

# Load Data

In [None]:
# Read in test data and assign uid for tracking discourse elements.
# In debug mode a random 3000-row sample of the training annotations is used
# in place of the test annotations.
if debug:
    test_df = pd.read_csv("../input/feedback-prize-effectiveness/train.csv")
    test_df = test_df.sample(n=3000).reset_index(drop=True)
else:
    test_df = pd.read_csv("../input/feedback-prize-effectiveness/test.csv")


# Map every distinct discourse_id to an integer uid (its first-seen position),
# keep the inverse mapping, and store the uid on each annotation row.
all_ids = test_df["discourse_id"].unique().tolist()
discourse2idx = {discourse: pos for pos, discourse in enumerate(all_ids)}
idx2discourse = {v:k for k, v in discourse2idx.items()}
test_df["uid"] = test_df["discourse_id"].map(discourse2idx)

# Load test essays
def _load_essay(essay_id):
    """Read one essay from disk and return `[essay_id, text]`.

    The essay is looked up in the train folder when the module-level `debug`
    flag is set, otherwise in the test folder.
    """
    essay_dir = (
        "../input/feedback-prize-effectiveness/train"
        if debug
        else "../input/feedback-prize-effectiveness/test"
    )
    essay_path = os.path.join(essay_dir, f"{essay_id}.txt")
    with open(essay_path, "r") as handle:
        return [essay_id, handle.read()]

def read_essays(essay_ids, num_jobs=12):
    """Load the given essays in parallel into a two-column dataframe.

    :param essay_ids: ids of the essays to read
    :param num_jobs: number of joblib workers, defaults to 12
    :return: dataframe with columns ``essay_id`` and ``essay_text``
    """
    load_one = delayed(_load_essay)
    loaded = Parallel(n_jobs=num_jobs, verbose=1)(load_one(essay_id) for essay_id in essay_ids)

    # later entries win if an id is repeated, exactly like repeated dict assignment
    text_by_id = {loaded_id: loaded_text for loaded_id, loaded_text in loaded}

    essay_df = pd.Series(text_by_id).reset_index()
    essay_df.columns = ["essay_id", "essay_text"]
    return essay_df

# Load the text of every essay referenced by the annotation rows.
essay_ids = test_df["essay_id"].unique().tolist()
essay_df = read_essays(essay_ids)

# Display sample test data
display(test_df.sample())

# Topics

In [None]:
import sys
#sys.path.append('../input/k/trushk/feedback-topics-identification-with-bertopic/site-packages/site-packages/')
sys.path.append('../input/k/lextoumbourou/feedback-topics-identification-with-bertopic/site-packages/')
from bertopic import BERTopic
import glob, pandas as pd, numpy as np, re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

from tqdm import tqdm

# Load the fitted BERTopic model together with its per-topic metadata.
topic_model = BERTopic.load("../input/feedback-topics-identification-with-bertopic/feedback_2021_topic_model")
topic_meta_df = pd.read_csv('../input/feedback-topics-identification-with-bertopic/topic_model_metadata.csv')

#topic_model = BERTopic.load("../input/fdbk-topic-model/feedback_2021_topic_model")
#topic_meta_df = pd.read_csv('../input/fdbk-topic-model/topic_model_metadata.csv')


# Normalise the metadata: keep `topic` / `topic_name`, and drop the leading
# "<number>_" prefix from each topic name, joining the remaining parts with spaces.
topic_meta_df = topic_meta_df.rename(columns={'Topic': 'topic', 'Name': 'topic_name'}).drop(columns=['Count'])
topic_meta_df.topic_name = topic_meta_df.topic_name.apply(lambda n: ' '.join(n.split('_')[1:]))

sws = stopwords.words("english") + ["n't",  "'s", "'ve"]
sws_lookup = set(sws)  # O(1) membership test inside the token loop

# Read essays from the same split that `test_df` / `_load_essay` use: in debug
# mode the annotations come from train, so the essays must come from train too,
# otherwise no debug essay receives a topic (and hence no prompt).
essay_dir = (
    "../input/feedback-prize-effectiveness/train"
    if debug
    else "../input/feedback-prize-effectiveness/test"
)
fls = glob.glob(os.path.join(essay_dir, "*.txt"))
docs = []
for fl in tqdm(fls):
    with open(fl) as f:
        txt = f.read()
    # strip stop words before handing the document to the topic model
    word_tokens = word_tokenize(txt)
    txt = " ".join([w for w in word_tokens if w.lower() not in sws_lookup])
    docs.append(txt)

topics, probs = topic_model.transform(docs)

# One row per essay file: essay id (file name without extension) and predicted topic.
# `probs` is kept as a variable but is not stored in the frame.
dids = [fl.split("/")[-1].split(".")[0] for fl in fls]
pred_topics = pd.DataFrame()
pred_topics["essay_id"] = dids
pred_topics["topic"] = topics

# Attach the readable topic name; afterwards `topic_num` holds the numeric topic
# and `topic` holds the topic name.
pred_topics = pred_topics.merge(topic_meta_df, left_on='topic', right_on='topic', how='left')
pred_topics.rename(columns={'topic': 'topic_num', 'topic_name': 'topic'}, inplace=True)
pred_topics

In [None]:
# Maps a topic name (as stored in pred_topics['topic']) to the prompt text
# that is attached to each essay below.
topic_map = {
	'seagoing luke animals cowboys': 'Should you join the Seagoing Cowboys program?',
	'driving phone phones cell' :  'Should drivers be allowed to use cell phones while driving?',
	 'phones cell cell phones school': 'Should students be allowed to use cell phones in school?',
	 'straights state welfare wa' : ' State welfare' ,
	 'summer students project projects': 'Should school summer projects be designed by students or teachers?',
	 'students online school classes': 'Is distance learning or online schooling beneficial to students?',
	 'car cars usage pollution': 'Should car usage be limited to help reduce pollution?',
	 'cars driverless car driverless cars': 'Are driverless cars going to be helpful?',
	 'emotions technology facial computer' : 'Should computers read the emotional expressions of students in a classroom?',
	 'community service community service help': 'Should community service be mandatory for all students?',
	 'sports average school students' : 'Should students be allowed to participate in sports  unless they have at least a grade B average?',
	 'advice people ask multiple': 'Should you ask multiple people for advice?',
	 'extracurricular activities activity students': 'Should all students participate in at least one extracurricular activity?',
	 'electoral college electoral college vote':  'Should the electoral college be abolished in favor of popular vote?' ,
	 'electoral vote college electoral college' : 'Should the electoral college be abolished in favor of popular vote?' ,
	 'face mars landform aliens' : 'Is the face on Mars  a natural landform or made by Aliens?',
     'venus planet author earth': 'Is Studying Venus a worthy pursuit?',
}
# Attach the predicted topic to every essay, then derive its prompt.
# Topic names missing from topic_map yield NaN in the 'prompt' column.
essay_df = essay_df.merge(pred_topics, on='essay_id', how='left')
essay_df['prompt'] = essay_df['topic'].map(topic_map)

essay_df.head()

# EXP 214 - Debv2-XL + Aug data

In [None]:
# Experiment configuration, written as a JSON document and parsed into a dict.
# Keys read by the code in this section: base_model_path, add_new_tokens,
# model_dir, max_length, stride, dropout and infer_bs.
config = """{
    "debug": false,

    "base_model_path": "../input/deberta-v2-xlarge/",
    "add_new_tokens": true,
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

In [None]:
#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer found at ``config["base_model_path"]``.

    Prints the vocabulary size and the tokenization of two probe strings
    as a quick visual check, then returns the tokenizer.
    """
    separator = "=="*40

    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    print(separator)
    print(f"tokenizer len: {len(tokenizer)}")
    probe_texts = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )
    for probe in probe_texts:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")

    print(separator)
    return tokenizer


#--------------- Processing ---------------------------------------------#


# Marker tokens that open a discourse span, one per discourse type.
# Their token ids are what the dataset code searches for as span heads.
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]"
]

# For each discourse type: [text inserted before the span, text inserted after the span].
# process_essay looks entries up by the annotation's discourse_type.
TOKEN_MAP = {
    "topic": ["Topic [TOPIC]", "[TOPIC END]"],
    "Lead": ["Lead [LEAD]", "[LEAD END]"],
    "Position": ["Position [POSITION]", "[POSITION END]"],
    "Claim": ["Claim [CLAIM]", "[CLAIM END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT END]"]
}


# Marker tokens that close a discourse span, in the same order as DISCOURSE_START_TOKENS.
# Their token ids are what the dataset code searches for as span tails.
DISCOURSE_END_TOKENS = [
    "[LEAD END]",
    "[POSITION END]",
    "[CLAIM END]",
    "[COUNTER_CLAIM END]",
    "[REBUTTAL END]",
    "[EVIDENCE END]",
    "[CONCLUDING_STATEMENT END]",
]



def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """
    Returns substring's span from the given text with the certain precision.

    The full substring is searched first. While it is not found, it is cut down
    to its leading ``int(len * fraction)`` characters and searched again, until
    either a prefix is found or the prefix would be shorter than ``min_length``.

    The search is iterative: with the default ``fraction`` the prefix shrinks by
    one character per step, so a recursive formulation exceeds Python's
    recursion limit for long unmatched substrings.

    :param text: text to search in
    :type text: str
    :param substring: text to look for
    :type substring: str
    :param min_length: shortest prefix that is still searched for, defaults to 2
    :type min_length: int, optional
    :param fraction: share of the current prefix kept after a failed search
    :type fraction: float, optional
    :return: ``[start, end]`` of the matched prefix, or ``[-1, 0]`` when nothing matched
    :rtype: list
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        shorter = candidate[:int(len(candidate) * fraction)]
        if len(shorter) >= len(candidate):
            # fraction >= 1 would never shrink the prefix; force progress
            shorter = candidate[:-1]
        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Locate each discourse text inside the essay, scanning left to right.

    For every entry of `discourse_list` (in order) the first literal match that
    ends after the previously accepted match is recorded. Texts that end up with
    no span at all fall back to `relaxed_search`.

    :param discourse_list: discourse texts of one essay, in annotation order
    :param essay_text: full essay text
    :return: dict mapping discourse text -> list of (start, end) spans
    """
    cursor = 0  # end offset of the most recently accepted match
    span_map = {}

    for discourse in discourse_list:
        found_spans = span_map.setdefault(discourse, [])
        literal = re.compile(re.escape(r'{}'.format(discourse)))
        first_hit = next(
            (m.span() for m in literal.finditer(essay_text) if m.end() > cursor),
            None,
        )
        if first_hit is not None:
            found_spans.append(first_hit)
            cursor = first_hit[1]

    # post process: anything still without a span gets an approximate match
    for discourse in discourse_list:
        if span_map[discourse]:
            continue
        print("resorting to relaxed search...")
        span_map[discourse] = [relaxed_search(essay_text, discourse)]
    return span_map


def get_substring_span(texts, mapping):
    """Return one span per text, consuming spans from `mapping` front to back.

    Each lookup removes the first remaining span of that text from `mapping`,
    so repeated texts receive their spans in order.
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, topic, prompt, anno_df):
    """Wrap every annotated span of one essay in its discourse markers.

    Spans are processed in order of `discourse_start`; the markers come from
    TOKEN_MAP. The result is prefixed with the prompt block and closed with the
    end-of-essay marker. `topic` is accepted but not used.
    """
    essay_rows = anno_df.loc[anno_df["essay_id"] == essay_id].sort_values(by="discourse_start")

    shift = 0  # characters inserted so far; later offsets move right by this much
    columns = (essay_rows["discourse_start"], essay_rows["discourse_end"], essay_rows["discourse_type"])
    for raw_start, raw_end, d_type in zip(*columns):
        start = int(raw_start) + shift
        end = int(raw_end) + shift
        open_marker, close_marker = TOKEN_MAP[d_type]
        segments = [essay_text[:start], open_marker, essay_text[start:end], close_marker, essay_text[end:]]
        essay_text = " ".join(segments)
        # two markers plus the four joining spaces
        shift += len(open_marker) + len(close_marker) + 4

    header = "[SOE]" + " [TOPIC] " + prompt + " [TOPIC END] "
    return header + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """pre-process input dataframes into one row per essay

    Locates every discourse text inside its essay, rewrites the essay text with
    marker tokens around each discourse span, and groups the annotations by essay.

    :param anno_df: annotation dataframe; the columns discourse_id, essay_id,
        discourse_text, discourse_type and uid are read, plus
        discourse_effectiveness when present
    :type anno_df: pd.DataFrame
    :param notes_df: essay dataframe; the columns essay_id, essay_text, topic
        and prompt are read
    :type notes_df: pd.DataFrame
    :return: dataframe with one row per essay holding the lists `uids` and
        `discourse_type` (and `discourse_effectiveness` when present) together
        with the processed `essay_text`
    :rtype: pd.DataFrame
    """
    # work on copies so the caller's dataframes are left untouched
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # per essay: collect the discourse ids/texts as lists and find each text's character span
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda x: build_span_map(x[0], x[1]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda x: get_substring_span(x[0], x[1]), axis=1)

    # flatten the per-essay lists back into one (discourse_id, start, end) row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")
    # anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    # process_essay needs discourse_start/discourse_end, so this must run before they are dropped
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text", "topic", "prompt"]].apply(
        lambda x: process_essay(x[0], x[1], x[2], x[3], anno_df), axis=1
    )
    print("=="*40)

    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    # one row per essay; the per-discourse fields become lists
    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#
# Marker tokens added to the tokenizer vocabulary when config["add_new_tokens"] is set.
# NOTE(review): the trailing numbers look like token ids but are inconsistent
# (12808 vs 128018) — confirm against the actual tokenizer before relying on them.
NEW_TOKENS = [
        "[LEAD]",
        "[POSITION]",
        "[CLAIM]",
        "[COUNTER_CLAIM]",
        "[REBUTTAL]",
        "[EVIDENCE]",
        "[CONCLUDING_STATEMENT]",
        "[TOPIC]",  # 12808
        "[SOE]",  # 12809
        "[EOE]",  # 12810
        "[LEAD END]",
        "[POSITION END]",
        "[CLAIM END]",
        "[COUNTER_CLAIM END]",
        "[REBUTTAL END]",
        "[EVIDENCE END]",
        "[CONCLUDING_STATEMENT END]",
        "[TOPIC END]",  # 128018
    ]
    

class AuxFeedbackDataset:
    """Dataset class for feedback prize effectiveness task

    Tokenizes marker-annotated essays (without truncation) and records where the
    discourse start/end marker tokens sit, so every discourse element can be
    addressed by its span.
    """

    def __init__(self, config):
        """Store the config, build the label/type lookup tables and load the tokenizer.

        :param config: experiment configuration; this class reads the keys
            ``add_new_tokens`` and ``model_dir`` and passes the config on to
            ``get_tokenizer``
        :type config: dict
        """
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """load tokenizer as per config

        Optionally registers NEW_TOKENS, then caches the token ids of the
        discourse start markers, end markers, and their union.
        """
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        if self.config["add_new_tokens"]:
            print("adding new tokens...")
            tokens_to_add = [AddedToken(this_tok, lstrip=True, rstrip=False) for this_tok in NEW_TOKENS]
            self.tokenizer.add_tokens(tokens_to_add)
        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize a batch of essay texts without padding or truncation, keeping offsets."""
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find the discourse start/end marker tokens in each tokenized example.

        :param examples: batch with ``input_ids`` and ``offset_mapping``
        :return: dict of per-example lists: token positions of the start markers
            (``span_head_idxs``) and end markers (``span_tail_idxs``), plus the
            character offset where each start marker begins and each end marker ends
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_token_ids
            ]
            example_span_tail_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_end_ids
            ]

            span_head_char_start_idxs.append([example_offset_mapping[pos][0] for pos in example_span_head_idxs])
            span_tail_char_end_idxs.append([example_offset_mapping[pos][1] for pos in example_span_tail_idxs])

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert each example's effectiveness strings to integer ids via ``label2id``."""
        labels = [
            [self.label2id[label] for label in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert each example's discourse type strings to integer ids via ``discourse_type2id``."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the token count of every example."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up and every tail lies at least two tokens after its head."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs), f"head idxs: {head_idxs}, tail idxs {tail_idxs}"
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    def sanity_check_head_labels(self, examples):
        """Assert that every example has exactly one label per span head."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels), f"head idxs: {head_idxs}, labels {head_labels}"

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed per-essay dataframe and the created dataset
        :rtype: tuple
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        # keep only essays whose start and end markers pair up
        task_dataset = task_dataset.filter(
            lambda example: len(example['span_head_idxs']) == len(example['span_tail_idxs'])
        )
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # Dataset.from_pandas may carry the dataframe index along as an extra
        # column; drop it only when it is actually there instead of swallowing
        # every exception the removal could raise.
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    Builds the untruncated dataset first, then re-tokenizes every essay into
    windows of at most ``config["max_length"]`` tokens (overlap controlled by
    ``config["stride"]``) and re-derives spans, uids, discourse type ids and
    (outside infer mode) labels for each window.

    :param config: configuration dict; ``max_length`` and ``stride`` are read here
    :param df: annotation dataframe, forwarded to ``AuxFeedbackDataset.get_dataset``
    :param essay_df: essay dataframe, forwarded to ``AuxFeedbackDataset.get_dataset``
    :param mode: labels are recomputed unless this is ``"infer"``
    :return: dict with keys ``dataset`` (windowed), ``original_dataset``
        (untruncated) and ``tokenizer``
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # the untruncated dataset is the reference used to map windows back to uids/labels
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids
    GLOBAL_IDS = dataset_creator.global_tokens

    def tokenize_with_truncation(examples):
        # return_overflowing_tokens splits a long essay into several windows
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        return tz

    def process_span(examples):
        # Locate span heads/tails inside each window and repair the pairs that
        # truncation broke (see the numbered edge cases below).
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def enforce_alignment(examples):
        # Recover the uids of the heads present in each window: the character
        # offset of a head marker is the key shared with the untruncated example.
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    def recompute_labels(examples):
        # Same offset-keyed lookup as enforce_alignment, applied to the labels.
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    def recompute_discourse_type_ids(examples):
        # Same offset-keyed lookup as enforce_alignment, applied to the discourse type ids.
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(examples):
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        # every window must have paired heads/tails, with each tail at least two tokens after its head
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # Re-tokenize in a single batch and drop all existing columns; the per-window
    # fields are rebuilt by the steps below.
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

# Build the EXP 214 inference dataset only when that experiment is enabled.
if use_exp214:
    # get_dataset writes a sample csv into model_dir, so the folder must exist first
    os.makedirs(config["model_dir"], exist_ok=True)

    print("creating the inference datasets...")
    infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
    tokenizer = infer_ds_dict["tokenizer"]
    infer_dataset = infer_ds_dict["dataset"]
    print(infer_dataset)

    # the model reads this value to resize its token embeddings
    config["len_tokenizer"] = len(tokenizer)

    # order the examples by token count
    infer_dataset = infer_dataset.sort("input_length")

    # expose only the columns the collator consumes
    infer_dataset.set_format(
        type=None,
        columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
                 'span_tail_idxs', 'discourse_type_ids', 'uids']
    )

# %% [markdown]
# ## Data Loader

from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    data collector for seq classification

    Pads the token-level inputs through ``tokenizer.pad`` and pads the per-span
    fields (span indices, uids, discourse type ids, labels) to the largest span
    count found in the batch.
    """

    # NOTE(review): these assignments have no type annotations, so `dataclass`
    # does not register them as fields of this class. If the parent class declares
    # fields with the same names, the values set by the generated __init__ take
    # precedence on instances — confirm that pad_to_multiple_of=512 is really in effect.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    def __call__(self, features):
        """Collate a list of examples into a dict of tensors.

        :param features: list of examples, each providing ``uids``,
            ``discourse_type_ids``, ``span_head_idxs``, ``span_tail_idxs``, the
            tokenizer inputs and optionally ``labels``
        :return: dict of ``torch.int64`` tensors; ``multitask_labels`` (present
            only with labels) is ``torch.float32``
        :raises ValueError: if a label id is not one of 0, 1, 2 or -1
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0]:
            labels = [feature["labels"] for feature in features]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(head_idxs) for head_idxs in span_head_idxs)
        max_len = len(batch["input_ids"][0])

        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        def _pad_rows(rows, fill_value):
            # right-pad every row with fill_value up to the batch-wide span count
            return [row + [fill_value] * (b_max - len(row)) for row in rows]

        batch["span_head_idxs"] = _pad_rows(span_head_idxs, default_head_idx)
        batch["uids"] = _pad_rows(uids, -1)
        batch["discourse_type_ids"] = _pad_rows(discourse_type_ids, 0)
        batch["span_tail_idxs"] = _pad_rows(span_tail_idxs, default_tail_idx)
        batch["span_attention_mask"] = _pad_rows(span_attention_mask, 0)

        if labels is not None:
            batch["labels"] = _pad_rows(labels, -1)

            # multitask labels: a two-element target per label, -1 marks padding
            multitask_targets = {
                0: [0, 0],
                1: [1, 0],
                2: [1, 1],
                -1: [-1, -1],
            }

            def _get_additional_labels(label_id):
                try:
                    return list(multitask_targets[label_id])
                except KeyError:
                    raise ValueError(f"unexpected label id: {label_id!r}") from None

            batch["multitask_labels"] = [
                [_get_additional_labels(el) for el in ex_labels] for ex_labels in batch["labels"]
            ]

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch


# Create the inference data loader for EXP 214 when that experiment is enabled.
if use_exp214:
    data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

    # shuffle=False keeps the dataset's existing example order
    infer_dl = DataLoader(
        infer_dataset,
        batch_size=config["infer_bs"],
        shuffle=False,
        collate_fn=data_collector
    )

# %% [markdown]
# ## Model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention

from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span-level classifier on top of a pretrained transformer.

    Pipeline (see ``forward``): base transformer -> bidirectional LSTM ->
    mean pooling of the tokens between each span's head and tail markers ->
    layer norm -> self-attention across the spans of one example ->
    dropout -> linear classifier.
    """

    def __init__(self, config):
        """Build all sub-modules.

        Keys read from ``config``: "base_model_path", "add_new_tokens",
        "len_tokenizer" (only when "add_new_tokens" is truthy), "dropout"
        and "num_labels".
        """
        super(FeedbackModel, self).__init__()
        self.config = config

        # base transformer
        base_config = AutoConfig.from_pretrained(self.config["base_model_path"])
        base_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(self.config["base_model_path"], config=base_config)

        # resize model embeddings
        if config["add_new_tokens"]:
            print("resizing model embeddings...")
            print(f"tokenizer length = {config['len_tokenizer']}")
            self.base_model.resize_token_embeddings(config["len_tokenizer"])

        # dropouts
        self.dropout = StableDropout(self.config["dropout"])

        # multi-head attention over span vectors; reuses the base model's config
        # with relative attention switched off
        attention_config = deepcopy(self.base_model.config)
        attention_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(attention_config)

        # classification
        hidden_size = self.base_model.config.hidden_size
        feature_size = hidden_size
        self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)

        # LSTM head: bidirectional with hidden_size//2 units per direction, so the
        # concatenated output width is hidden_size again when hidden_size is even
        self.fpe_lstm_layer = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size//2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(feature_size, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return per-span logits.

        ``span_head_idxs`` / ``span_tail_idxs`` are indexed per example and
        iterated pairwise; ``span_attention_mask`` is presumably
        (batch, num_spans) with 1 for real spans and 0 for padding - TODO
        confirm against the collator. Extra batch keys are accepted through
        ``**kwargs`` and ignored.
        """

        bs = input_ids.shape[0]  # batch size
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        encoder_layer = outputs[0]  # first element of the base model output

        self.fpe_lstm_layer.flatten_parameters()
        encoder_layer = self.fpe_lstm_layer(encoder_layer)[0]  # keep the output sequence, drop (h_n, c_n)

        mean_feature_vector = []
        for i in range(bs):  # TODO: vectorize
            span_vec_i = []

            for head, tail in zip(span_head_idxs[i], span_tail_idxs[i]):
                # span feature: mean over the tokens strictly between the head
                # and tail positions (head+1 .. tail-1)
                tmp = torch.mean(encoder_layer[i, head+1:tail], dim=0)  # [h]
                span_vec_i.append(tmp)
            span_vec_i = torch.stack(span_vec_i)  # (num_discourse, h)
            mean_feature_vector.append(span_vec_i)

        mean_feature_vector = torch.stack(mean_feature_vector)  # (bs, num_discourse, h)
        mean_feature_vector = self.layer_norm(mean_feature_vector)

        # attend to other features
        # build a pairwise mask: entry (q, k) is non-zero only when both span q and
        # span k are unmasked (outer product of the span mask with itself)
        extended_span_attention_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        span_attention_mask = extended_span_attention_mask * extended_span_attention_mask.squeeze(-2).unsqueeze(-1)
        span_attention_mask = span_attention_mask.byte()
        feature_vector = self.fpe_span_attention(mean_feature_vector, span_attention_mask)

        feature_vector = self.dropout(feature_vector)

        logits = self.classifier(feature_vector)

        return logits

# %% [markdown]
# ## Inference

# exp214 checkpoints, one file per fold (folds 0-3); each is loaded in turn by the
# inference loop further down.
checkpoints = [
    "../input/exp214-debv2-xl-prompt/fpe_model_fold_0_best.pth.tar",
    "../input/exp214-debv2-xl-prompt/fpe_model_fold_1_best.pth.tar",
    "../input/exp214-debv2-xl-prompt/fpe_model_fold_2_best.pth.tar",
    "../input/exp214-debv2-xl-prompt/fpe_model_fold_3_best.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """Run one model over the inference loader and write its probabilities to CSV.

    Softmax probabilities are collected for every span of every batch, padded
    spans (uid < 0) are dropped, span uids are mapped to discourse ids through
    the module-level ``idx2discourse`` and the result is saved as
    ``exp214_model_preds_{model_id}.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            prob_batches.append(F.softmax(logits, dim=-1))
            uid_batches.append(batch["uids"])

    def _flatten_twice(tensors):
        # batches -> examples -> individual spans
        as_lists = [t.to('cpu').detach().numpy().tolist() for t in tensors]
        return list(chain.from_iterable(chain.from_iterable(as_lists)))

    flat_preds = _flatten_twice(prob_batches)
    flat_uids = _flatten_twice(uid_batches)

    label_cols = ["Ineffective", "Adequate", "Effective"]
    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = label_cols
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS

    # padded span slots carry a negative uid; keep real spans only
    is_real_span = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_span].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id"] + label_cols].copy()
    preds_df.to_csv(f"exp214_model_preds_{model_id}.csv", index=False)
    
# exp214 only: run every fold checkpoint, then average the per-fold CSV predictions
# into exp214_df.
if use_exp214:
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)  # writes exp214_model_preds_{model_id}.csv


    # release the last model and the exp214 data objects before the next experiment
    del model
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
    torch.cuda.empty_cache()

    import glob
    import pandas as pd

    # picks up every matching file in the working directory
    csvs = glob.glob("exp214_model_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # NOTE(review): rows are aligned across files purely by this sort, which
        # assumes every CSV holds the same discourse_id rows - confirm
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds  # element-wise running sum over folds

    preds = preds / len(csvs)  # mean probability over the fold files

    exp214_df = pd.DataFrame()
    exp214_df["discourse_id"] = idx
    exp214_df["Ineffective"]  = preds[:, 0]
    exp214_df["Adequate"]     = preds[:, 1]
    exp214_df["Effective"]    = preds[:, 2]

    # collapse repeated discourse_ids to a single averaged row (presumably the same
    # discourse can be predicted in more than one overlapping window - verify)
    exp214_df = exp214_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()



In [None]:
# Preview the averaged exp214 predictions when that experiment is enabled.
if use_exp214:
    print(exp214_df.head())

# EXP 19 Model - DEXL

In [None]:
# Inference settings for this experiment, written as JSON text and parsed into a
# dict right below. "max_length" and "stride" drive the truncating tokenizer call,
# "infer_bs" is the DataLoader batch size.
config = """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-dexl",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer at ``config["base_model_path"]`` and print a smoke test.

    The smoke test prints the vocabulary size and how two sample strings
    containing the discourse marker tokens are tokenized.
    """
    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    banner = "=="*40
    probes = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )

    print(banner)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probes:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")
    print(banner)

    return tokenizer


#--------------- Processing ---------------------------------------------#

# discourse type -> [text inserted before the span, text inserted after the span]
TOKEN_MAP = {
    "Lead": ["Lead [LEAD]", "[LEAD_END]"],
    "Position": ["Position [POSITION]", "[POSITION_END]"],
    "Claim": ["Claim [CLAIM]", "[CLAIM_END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM_END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL_END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE_END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT_END]"],
}

# Opening markers: the bracketed token that ends each start text, in TOKEN_MAP order.
DISCOURSE_START_TOKENS = [start_text.rsplit(" ", 1)[-1] for start_text, _ in TOKEN_MAP.values()]

# Closing markers, in TOKEN_MAP order.
DISCOURSE_END_TOKENS = [end_text for _, end_text in TOKEN_MAP.values()]

# NEW_TOKENS = [
#     "[LEAD]",
#     "[POSITION]",
#     "[CLAIM]",
#     "[COUNTER_CLAIM]",
#     "[REBUTTAL]",
#     "[EVIDENCE]",
#     "[CONCLUDING_STATEMENT]",
#     "[LEAD_END]",
#     "[POSITION_END]",
#     "[CLAIM_END]",
#     "[COUNTER_CLAIM_END]",
#     "[REBUTTAL_END]",
#     "[EVIDENCE_END]",
#     "[CONCLUDING_STATEMENT_END]",
#     "[SOE]",
#     "[EOE]",
# ]


def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, falling back to ever shorter prefixes.

    If the full substring is not found, it is repeatedly cut down to its first
    ``int(len * fraction)`` characters and searched again, until a prefix
    matches or the prefix becomes shorter than ``min_length``.

    The search is iterative: the previous recursive form used one stack frame
    per shortening step and hit ``RecursionError`` for long unmatched
    substrings.

    :param text: text to search in
    :param substring: text to look for
    :param min_length: shortest prefix that is still worth searching for
    :param fraction: share of the current prefix kept at each shortening step
    :return: ``[start, end]`` of the match (``end`` exclusive), or ``[-1, 0]``
        when no acceptable prefix occurs in ``text``
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        # Drop at least one character per step so the loop always terminates,
        # even for fraction >= 1.
        shorter_length = min(int(len(candidate) * fraction), len(candidate) - 1)
        candidate = candidate[:shorter_length]
        if len(candidate) < min_length:
            return [-1, 0]


def build_span_map(discourse_list, essay_text):
    """Map each discourse text to the character spans where it occurs in the essay.

    Discourses are processed in the given order. For each one, the first exact
    occurrence that ends after the previously accepted span is recorded, so
    repeated discourse texts pick up successive occurrences. A discourse text
    that ends up with no span at all gets a single span from ``relaxed_search``.

    :return: dict of discourse text -> list of spans
    """
    span_map = {}
    cursor = 0  # end offset of the most recently accepted span

    for discourse in discourse_list:
        spans = span_map.setdefault(discourse, [])
        pattern = re.escape(r'{}'.format(discourse))
        hit = next(
            (m.span() for m in re.finditer(pattern, essay_text) if m.span()[1] > cursor),
            None,
        )
        if hit is not None:
            spans.append(hit)
            cursor = hit[1]

    # post process: fall back to approximate matching for texts that never matched
    for discourse in discourse_list:
        if span_map[discourse]:
            continue
        print("resorting to relaxed search...")
        span_map[discourse] = [relaxed_search(essay_text, discourse)]

    return span_map


def get_substring_span(texts, mapping):
    """Return one span per text, consuming spans from ``mapping`` front to back.

    ``mapping`` is mutated: each lookup pops the first remaining span of that
    text, so repeated texts receive successive spans.
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, anno_df):
    """Wrap every annotated discourse of one essay in its start/end marker text.

    Rows of ``anno_df`` belonging to ``essay_id`` are visited in order of
    ``discourse_start``. Each insertion lengthens the text, so later offsets
    are shifted by the number of characters added so far. The result is
    enclosed in "[SOE]" ... "[EOE]".
    """
    essay_rows = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_rows = essay_rows.sort_values(by="discourse_start")

    shift = 0  # characters inserted so far
    for _, row in essay_rows.iterrows():
        start = int(row["discourse_start"]) + shift
        end = int(row["discourse_end"]) + shift
        open_text, close_text = TOKEN_MAP[row["discourse_type"]]

        pieces = [essay_text[:start], open_text, essay_text[start:end], close_text, essay_text[end:]]
        essay_text = " ".join(pieces)

        # two markers plus the four joining spaces
        shift += len(open_text) + len(close_text) + 4

    return "[SOE]" + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """Pre-process annotations and essays into one row per essay.

    Steps: strip discourse texts, locate each discourse inside its essay,
    insert the discourse marker tokens into the essay text, then group the
    annotations by essay.

    :param anno_df: discourse annotations; needs "discourse_id", "essay_id",
        "discourse_text", "discourse_type", "uid" and optionally
        "discourse_effectiveness"
    :type anno_df: pd.DataFrame
    :param notes_df: essays; needs "essay_id" and "essay_text"
    :type notes_df: pd.DataFrame
    :return: one row per essay with "essay_id", "uids", "discourse_type",
        "essay_text" (markers inserted) and, when available,
        "discourse_effectiveness"; the per-discourse columns hold lists
    :rtype: pd.DataFrame
    """
    # work on copies so the caller's frames are left untouched
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # per essay: the list of discourse texts and where each one sits in the essay
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    # rows are accessed by column label; integer keys on a label-indexed Series
    # rely on a positional fallback that newer pandas deprecates
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda x: build_span_map(x["discourse_text"], x["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda x: get_substring_span(x["discourse_text"], x["span_map"]), axis=1)

    # flatten back to one (discourse_id, start, end) row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text"]].apply(
        lambda x: process_essay(x["essay_id"], x["essay_text"], anno_df), axis=1
    )
    print("=="*40)

    # the offsets refer to the unmodified essay text and are no longer needed
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Builds the untruncated, essay-level dataset for the feedback effectiveness task.

    Each essay is tokenized in full, the positions of the discourse start/end
    marker tokens are recorded as span heads/tails, and discourse types (and,
    outside inference, effectiveness labels) are converted to integer ids.
    """

    def __init__(self, config):
        """Store ``config``, set up the label / discourse-type id maps and load the tokenizer."""
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and cache the marker token ids.

        Sets ``self.tokenizer``, ``self.discourse_token_ids`` (start markers),
        ``self.discourse_end_ids`` (end markers) and ``self.global_tokens``
        (their union).
        """
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        # the marker tokens are looked up as-is; nothing is added to the tokenizer here
        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize the essay texts of a batch without padding or truncation."""
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find the token and character positions of the span markers in a batch.

        :return: dict with, per example, the token positions of start markers
            ("span_head_idxs") and end markers ("span_tail_idxs"), plus the
            character offset where each start marker begins and where each end
            marker ends
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_token_ids]
            example_span_tail_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_end_ids]

            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in example_span_head_idxs]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in example_span_tail_idxs]

            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert the effectiveness label names of a batch to integer ids."""
        labels = [
            [self.label2id[label] for label in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of a batch to integer ids."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the token count of every example in the batch."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every span head has exactly one label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        Also writes a sample of up to 16 processed rows to
        ``<model_dir>/<mode>_df_processed.csv`` for manual inspection.

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed dataframe and the created dataset
        :rtype: tuple[pd.DataFrame, Dataset]
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        # keep only essays whose start- and end-marker counts agree
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # Drop the index column if it is present (presumably left behind by
        # Dataset.from_pandas); checking first avoids hiding unrelated errors.
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    The essays are first processed untruncated (``AuxFeedbackDataset``), then
    re-tokenized into windows of ``config["max_length"]`` tokens with
    ``config["stride"]`` overlap. Span heads/tails are recomputed per window and
    uids, discourse type ids and (outside "infer" mode) labels are carried over
    from the untruncated dataset by matching the character offset of each span
    head.

    :return: dict with "dataset" (windowed dataset), "original_dataset"
        (untruncated dataset) and "tokenizer"
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # untruncated copy, used below as the lookup source for uids / labels / types
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids
    GLOBAL_IDS = dataset_creator.global_tokens  # not used inside this function

    def tokenize_with_truncation(examples):
        """Tokenize into overlapping windows (overflowing tokens become extra rows)."""
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        return tz

    def process_span(examples):
        """Recompute span heads/tails for every window and repair truncation effects."""
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                # full-length window: ignore heads that start inside the trailing buffer
                # (both operands are plain bools, so `&` acts as a logical and here)
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def enforce_alignment(examples):
        """Look up the uid of every span in a window via its head's character offset."""
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            # the untruncated essay this window was cut from
            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    def recompute_labels(examples):
        """Look up the label of every span in a window via its head's character offset."""
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    def recompute_discourse_type_ids(examples):
        """Look up the discourse type id of every span in a window via its head's character offset."""
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(examples):
        """Return the token count of every window in the batch."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # The whole dataset goes through the tokenizer as a single batch; presumably this
    # keeps "overflow_to_sample_mapping" usable as a row index into original_dataset
    # - TODO confirm. The original columns are dropped because the row count changes.
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    # labels only exist outside inference
    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
# Reuse for exp7,6,8
# make sure the output directory exists (a sample CSV is written there during
# dataset creation)
os.makedirs(config["model_dir"], exist_ok=True)

print("creating the inference datasets...")
# mode="infer": no effectiveness labels are generated
infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
tokenizer = infer_ds_dict["tokenizer"]
infer_dataset = infer_ds_dict["dataset"]
print(infer_dataset)

In [None]:
# record the tokenizer size in the config
config["len_tokenizer"] = len(tokenizer)

# order windows by token count (presumably so each batch needs little padding)
infer_dataset = infer_dataset.sort("input_length")

# expose only the columns listed here when rows are read from the dataset
infer_dataset.set_format(
    type=None,
    columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
             'span_tail_idxs', 'discourse_type_ids', 'uids']
)

In [None]:
from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator that pads the token inputs and the per-span fields of a batch.

    Token-level inputs are padded by the tokenizer. Span-level lists are padded
    to the largest number of spans in the batch: head/tail indices with
    positions near the end of the padded sequence, uids and labels with -1,
    discourse type ids and the span attention mask with 0.
    """

    # NOTE(review): these are plain, un-annotated class attributes, so they are
    # presumably not dataclass fields; the values set by the inherited __init__
    # likely win on instances. Confirm whether pad_to_multiple_of = 512 is
    # actually in effect.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of tensors.

        :param features: examples with "uids", "discourse_type_ids",
            "span_head_idxs", "span_tail_idxs", the tokenizer inputs and
            optionally "labels"
        :return: dict of ``torch.int64`` tensors; "multitask_labels" (only
            present when labels are given) is ``torch.float32``
        :raises ValueError: if a label id is not one of 0, 1, 2 or -1
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(idxs) for idxs in span_head_idxs)  # most spans in any example
        max_len = len(batch["input_ids"][0])

        # dummy span used to fill up examples with fewer spans; tail > head + 1
        # whenever the padded sequence is long enough
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        batch["span_head_idxs"] = [
            ex_span_head_idxs + [default_head_idx] * (b_max - len(ex_span_head_idxs)) for ex_span_head_idxs in span_head_idxs
        ]

        batch["uids"] = [ex_uids + [-1] * (b_max - len(ex_uids)) for ex_uids in uids]
        batch["discourse_type_ids"] = [ex_discourse_type_ids + [0] *
                                       (b_max - len(ex_discourse_type_ids)) for ex_discourse_type_ids in discourse_type_ids]

        batch["span_tail_idxs"] = [
            ex_span_tail_idxs + [default_tail_idx] * (b_max - len(ex_span_tail_idxs)) for ex_span_tail_idxs in span_tail_idxs
        ]

        batch["span_attention_mask"] = [
            ex_discourse_masks + [0] * (b_max - len(ex_discourse_masks)) for ex_discourse_masks in span_attention_mask
        ]

        if labels is not None:
            batch["labels"] = [ex_labels + [-1] * (b_max - len(ex_labels)) for ex_labels in labels]

        # multitask labels
        def _get_additional_labels(label_id):
            """Map a label id to its two-element auxiliary target vector."""
            if label_id == 0:
                vec = [0, 0]
            elif label_id == 1:
                vec = [1, 0]
            elif label_id == 2:
                vec = [1, 1]
            elif label_id == -1:
                vec = [-1, -1]  # padded span
            else:
                # a bare `raise` here has no active exception to re-raise
                raise ValueError(f"unexpected label id: {label_id!r}")
            return vec

        if labels is not None:
            batch["multitask_labels"] = [
                [_get_additional_labels(el) for el in ex_labels]
                for ex_labels in batch["labels"]
            ]

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# Padding collator and DataLoader for inference.
data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

# shuffle=False: batches are drawn in the dataset's current order
infer_dl = DataLoader(
    infer_dataset,
    batch_size=config["infer_bs"],
    shuffle=False,
    collate_fn=data_collector
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span-level effectiveness classifier built on a pretrained transformer.

    Token states from the backbone go through a bidirectional LSTM, are
    mean-pooled per discourse span, layer-normalised, mixed across spans by a
    BERT-style self-attention block and finally classified.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # pretrained backbone
        model_path = self.config["base_model_path"]
        base_config = AutoConfig.from_pretrained(model_path)
        self.base_model = AutoModel.from_pretrained(model_path, config=base_config)
        backbone_cfg = self.base_model.config

        self.dropout = nn.Dropout(self.config["dropout"])

        # self-attention across the pooled span vectors, sized like the backbone
        span_attention_settings = {
            "num_attention_heads": backbone_cfg.num_attention_heads,
            "hidden_size": backbone_cfg.hidden_size,
            "attention_probs_dropout_prob": backbone_cfg.attention_probs_dropout_prob,
            "is_decoder": False,
        }
        attention_config = BertConfig()
        attention_config.update(span_attention_settings)
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        width = backbone_cfg.hidden_size
        self.layer_norm = LayerNorm(width, backbone_cfg.layer_norm_eps)

        # bidirectional LSTM head; the two directions concatenate back to `width`
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three classifier logits for every span."""
        n_examples = input_ids.shape[0]

        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]  # LSTM outputs

        # mean-pool the tokens strictly between each span's head and tail index
        per_example = []
        for row in range(n_examples):
            pooled = [
                token_states[row, start + 1:end].mean(dim=0)  # [h]
                for start, end in zip(span_head_idxs[row], span_tail_idxs[row])
            ]
            per_example.append(torch.stack(pooled))  # (num_discourse, h)
        span_features = self.layer_norm(torch.stack(per_example))  # (bs, num_discourse, h)

        # additive mask: 0 where the span mask is 1, -10000 where it is 0
        additive_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = (1.0 - additive_mask) * -10000.0
        attended = self.fpe_span_attention(span_features, additive_mask)[0]

        logits = self.classifier(self.dropout(attended))

        # keep only the main logits
        return logits[..., :3]

In [None]:
# Fold checkpoints (folds 0-7) for the exp19 DEXL model, stored across two
# input datasets (part 1: folds 0-3, part 2: folds 4-7).
checkpoints = [
    "../input/exp19-dexl-dataset-part-1/exp-19-dexl-revisit-part-1/fpe_model_fold_0_best.pth.tar",
    "../input/exp19-dexl-dataset-part-1/exp-19-dexl-revisit-part-1/fpe_model_fold_1_best.pth.tar",
    "../input/exp19-dexl-dataset-part-1/exp-19-dexl-revisit-part-1/fpe_model_fold_2_best.pth.tar",
    "../input/exp19-dexl-dataset-part-1/exp-19-dexl-revisit-part-1/fpe_model_fold_3_best.pth.tar",
    "../input/exp19-dexl-revisit-dataset-part-2/exp-19-dexl-revisit-part-2/fpe_model_fold_4_best.pth.tar",
    "../input/exp19-dexl-revisit-dataset-part-2/exp-19-dexl-revisit-part-2/fpe_model_fold_5_best.pth.tar",
    "../input/exp19-dexl-revisit-dataset-part-2/exp-19-dexl-revisit-part-2/fpe_model_fold_6_best.pth.tar",
    "../input/exp19-dexl-revisit-dataset-part-2/exp-19-dexl-revisit-part-2/fpe_model_fold_7_best.pth.tar",
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over ``infer_dl`` and write its span probabilities to CSV.

    Softmax probabilities are collected per batch, flattened to one row per
    span, padded spans (uid < 0) are dropped, span uids are mapped to
    discourse ids through ``idx2discourse`` and the table is saved as
    ``exp19_dexl_model_preds_{model_id}.csv``.
    """

    def _flatten_to_spans(tensors):
        # (batch, example, span) nesting -> flat list with one entry per span
        as_lists = (t.to('cpu').detach().numpy().tolist() for t in tensors)
        per_example = chain.from_iterable(as_lists)
        return list(chain.from_iterable(per_example))

    prob_batches, uid_batches = [], []
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    preds_df = pd.DataFrame(_flatten_to_spans(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = _flatten_to_spans(uid_batches)

    # negative uids mark padding added by the collator
    preds_df = preds_df.loc[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp19_dexl_model_preds_{model_id}.csv", index=False)
    

# Run every exp19 checkpoint through inference; each call writes one CSV of
# per-span probabilities named after its position in `checkpoints`.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    ckpt = torch.load(checkpoint)
    # the checkpoint dict stores the recorded loss under 'loss' and the weights under 'state_dict'
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
# drop the last model and release cached GPU memory before the next experiment
del model
# del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
gc.collect()
torch.cuda.empty_cache()

In [None]:
import glob
import pandas as pd

# Average the per-checkpoint prediction files into a single exp19 table.
csvs = glob.glob("exp19_dexl_model_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # every file is sorted the same way so rows line up across checkpoints
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # the first file provides the id column and the accumulator array
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        # element-wise running sum over checkpoints
        preds += temp_preds

# running sum -> mean over checkpoints
preds = preds / len(csvs)

exp19_df = pd.DataFrame()
exp19_df["discourse_id"] = idx
exp19_df["Ineffective"]  = preds[:, 0]
exp19_df["Adequate"]     = preds[:, 1]
exp19_df["Effective"]    = preds[:, 2]

# rows sharing a discourse_id are averaged into one row
exp19_df = exp19_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()



In [None]:
# Preview the first rows of the averaged exp19 predictions.
exp19_df.head()

### DEXL all data

In [None]:
# Two exp19f checkpoints from the "all data" dataset, differing by seed (464 and 446).
checkpoints = [
    "../input/exp-19f-dexl-revisit-all-data/fpe_model_all_data_seed_464.pth.tar",
    "../input/exp-19f-dexl-revisit-all-data/fpe_model_all_data_seed_446.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """Run one model over ``infer_dl`` and write its span probabilities to CSV.

    Softmax probabilities are collected per batch, flattened to one row per
    span, padded spans (uid < 0) are dropped, span uids are mapped to
    discourse ids through ``idx2discourse`` and the table is saved as
    ``19f_dexl_model_preds_{model_id}.csv``.
    """

    def _flatten_to_spans(tensors):
        # (batch, example, span) nesting -> flat list with one entry per span
        as_lists = (t.to('cpu').detach().numpy().tolist() for t in tensors)
        per_example = chain.from_iterable(as_lists)
        return list(chain.from_iterable(per_example))

    prob_batches, uid_batches = [], []
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    preds_df = pd.DataFrame(_flatten_to_spans(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = _flatten_to_spans(uid_batches)

    # negative uids mark padding added by the collator
    preds_df = preds_df.loc[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"19f_dexl_model_preds_{model_id}.csv", index=False)
    

# Run each all-data checkpoint through inference; one CSV is written per checkpoint.
# NOTE(review): the log line says "fold", but model_id is just the position of
# the checkpoint in the list above.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    ckpt = torch.load(checkpoint)
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
# drop the last model and release cached GPU memory
del model
gc.collect()
torch.cuda.empty_cache()


import glob
import pandas as pd

# Average the per-checkpoint prediction files into a single exp19f table.
csvs = glob.glob("19f_dexl_model_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # every file is sorted the same way so rows line up across checkpoints
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # the first file provides the id column and the accumulator array
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        # element-wise running sum over checkpoints
        preds += temp_preds

# running sum -> mean over checkpoints
preds = preds / len(csvs)

exp19f_df = pd.DataFrame()
exp19f_df["discourse_id"] = idx
exp19f_df["Ineffective"]  = preds[:, 0]
exp19f_df["Adequate"]     = preds[:, 1]
exp19f_df["Effective"]    = preds[:, 2]

# rows sharing a discourse_id are averaged into one row
exp19f_df = exp19f_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

exp19f_df.head()

### DEL KD All Data

In [None]:
# Settings for the DEL KD all-data model, written as JSON text and parsed into a dict.
# With use_multitask enabled the model adds num_additional_labels extra classifier outputs.
config = """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-del-wiki",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8,
    
    "use_multitask": true,
    "num_additional_labels": 2
}
"""
config = json.loads(config)
# the model resizes its token embeddings to this length
config["len_tokenizer"] = len(tokenizer)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


#-------- Model ------------------------------------------------------------------#

class FeedbackModel(nn.Module):
    """Span-level effectiveness classifier for the fast (distilled) approach.

    Same pipeline as the other span models in this notebook (backbone ->
    BiLSTM -> per-span mean pooling -> layer norm -> span self-attention ->
    linear classifier), with two additions: the backbone's token embeddings
    are resized to ``config["len_tokenizer"]`` and, when
    ``config["use_multitask"]`` is set, the classifier gets
    ``config["num_additional_labels"]`` extra outputs.
    """

    def __init__(self, config):
        print("=="*40)
        print("initializing the feedback model...")

        super().__init__()
        self.config = config

        # pretrained backbone
        model_path = self.config["base_model_path"]
        base_config = AutoConfig.from_pretrained(model_path)
        self.base_model = AutoModel.from_pretrained(model_path, config=base_config)

        # match the embedding matrix to the tokenizer size
        print("resizing model embeddings...")
        print(f"tokenizer length = {config['len_tokenizer']}")
        self.base_model.resize_token_embeddings(config["len_tokenizer"])

        self.dropout = nn.Dropout(self.config["dropout"])

        # main label count, optionally widened by the auxiliary multitask outputs
        self.num_labels = self.num_original_labels = self.config["num_labels"]
        if self.config["use_multitask"]:
            print("using multi-task approach...")
            self.num_labels += self.config["num_additional_labels"]

        # self-attention across the pooled span vectors, sized like the backbone
        backbone_cfg = self.base_model.config
        span_attention_settings = {
            "num_attention_heads": backbone_cfg.num_attention_heads,
            "hidden_size": backbone_cfg.hidden_size,
            "attention_probs_dropout_prob": backbone_cfg.attention_probs_dropout_prob,
            "is_decoder": False,
        }
        attention_config = BertConfig()
        attention_config.update(span_attention_settings)
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        width = backbone_cfg.hidden_size
        self.layer_norm = LayerNorm(width, backbone_cfg.layer_norm_eps)

        # bidirectional LSTM head; the two directions concatenate back to `width`
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.classifier = nn.Linear(width, self.num_labels)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        span_head_idxs,
        span_tail_idxs,
        span_attention_mask,
        labels=None,
        multitask_labels=None,
        **kwargs
    ):
        """Return the first three classifier logits for every span.

        ``labels`` and ``multitask_labels`` are accepted but not used here.
        """
        n_examples = input_ids.shape[0]

        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]  # LSTM outputs

        # mean-pool the tokens strictly between each span's head and tail index
        per_example = []
        for row in range(n_examples):
            pooled = [
                token_states[row, start + 1:end].mean(dim=0)  # [h]
                for start, end in zip(span_head_idxs[row], span_tail_idxs[row])
            ]
            per_example.append(torch.stack(pooled))  # (num_discourse, h)
        span_features = self.layer_norm(torch.stack(per_example))  # (bs, num_discourse, h)

        # additive mask: 0 where the span mask is 1, -10000 where it is 0
        additive_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = (1.0 - additive_mask) * -10000.0
        attended = self.fpe_span_attention(span_features, additive_mask)[0]

        logits = self.classifier(self.dropout(attended))

        # drop the auxiliary multitask outputs, keep the main logits
        return logits[..., :3]
    
# exp20 KD checkpoints; only the seed_1 checkpoint is active, seed_2 is commented out.
checkpoints = [
    "../input/exp-20-del-kd-all-data-train/fpe_model_kd_seed_1.pth.tar",
#     "../input/exp-20-del-kd-all-data-train/fpe_model_kd_seed_2.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """Run one model over ``infer_dl`` and write its span probabilities to CSV.

    Softmax probabilities are collected per batch, flattened to one row per
    span, padded spans (uid < 0) are dropped, span uids are mapped to
    discourse ids through ``idx2discourse`` and the table is saved as
    ``exp20_del_kd_model_preds_{model_id}.csv``.
    """

    def _flatten_to_spans(tensors):
        # (batch, example, span) nesting -> flat list with one entry per span
        as_lists = (t.to('cpu').detach().numpy().tolist() for t in tensors)
        per_example = chain.from_iterable(as_lists)
        return list(chain.from_iterable(per_example))

    prob_batches, uid_batches = [], []
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    preds_df = pd.DataFrame(_flatten_to_spans(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = _flatten_to_spans(uid_batches)

    # negative uids mark padding added by the collator
    preds_df = preds_df.loc[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp20_del_kd_model_preds_{model_id}.csv", index=False)
    

# Run each exp20 KD checkpoint through inference; one CSV is written per checkpoint.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    ckpt = torch.load(checkpoint)
    # the checkpoint dict stores the recorded loss under 'loss' and the weights under 'state_dict'
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
# drop the last model and release cached GPU memory
del model
gc.collect()
torch.cuda.empty_cache()

import glob
import pandas as pd

# Average the per-checkpoint prediction files into a single exp20 table.
csvs = glob.glob("exp20_del_kd_model_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # every file is sorted the same way so rows line up across checkpoints
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # the first file provides the id column and the accumulator array
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        # element-wise running sum over checkpoints
        preds += temp_preds

# running sum -> mean over checkpoints
preds = preds / len(csvs)

exp20f_df = pd.DataFrame()
exp20f_df["discourse_id"] = idx
exp20f_df["Ineffective"]  = preds[:, 0]
exp20f_df["Adequate"]     = preds[:, 1]
exp20f_df["Effective"]    = preds[:, 2]

# rows sharing a discourse_id are averaged into one row
exp20f_df = exp20f_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

exp20f_df.head()

# EXP 6 - Model - DEXL

## Config

In [None]:
# Settings for the exp6 DEXL model, written as JSON text and parsed into a dict.
# num_labels sizes the classifier (5 outputs); the model's forward keeps only the first 3.
config = """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-dexl",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span-level effectiveness classifier built on a pretrained transformer.

    Token states from the backbone go through a bidirectional LSTM, are
    mean-pooled per discourse span and layer-normalised. A span
    self-attention block is evaluated as well, but in this variant the
    classifier is fed the layer-normalised span vectors directly.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # pretrained backbone
        model_path = self.config["base_model_path"]
        base_config = AutoConfig.from_pretrained(model_path)
        self.base_model = AutoModel.from_pretrained(model_path, config=base_config)
        backbone_cfg = self.base_model.config

        self.dropout = nn.Dropout(self.config["dropout"])

        # self-attention across the pooled span vectors, sized like the backbone
        span_attention_settings = {
            "num_attention_heads": backbone_cfg.num_attention_heads,
            "hidden_size": backbone_cfg.hidden_size,
            "attention_probs_dropout_prob": backbone_cfg.attention_probs_dropout_prob,
            "is_decoder": False,
        }
        attention_config = BertConfig()
        attention_config.update(span_attention_settings)
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        width = backbone_cfg.hidden_size
        self.layer_norm = LayerNorm(width, backbone_cfg.layer_norm_eps)

        # bidirectional LSTM head; the two directions concatenate back to `width`
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three classifier logits for every span."""
        n_examples = input_ids.shape[0]

        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]  # LSTM outputs

        # mean-pool the tokens strictly between each span's head and tail index
        per_example = []
        for row in range(n_examples):
            pooled = [
                token_states[row, start + 1:end].mean(dim=0)  # [h]
                for start, end in zip(span_head_idxs[row], span_tail_idxs[row])
            ]
            per_example.append(torch.stack(pooled))  # (num_discourse, h)
        span_features = self.layer_norm(torch.stack(per_example))  # (bs, num_discourse, h)

        # additive mask: 0 where the span mask is 1, -10000 where it is 0
        additive_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = (1.0 - additive_mask) * -10000.0

        # NOTE(review): the attention output is computed but never used below;
        # the classifier sees the pre-attention span vectors. Presumably this
        # matches how these checkpoints were trained - confirm before changing.
        _attended = self.fpe_span_attention(span_features, additive_mask)[0]

        logits = self.classifier(self.dropout(span_features))

        # keep only the main logits
        return logits[..., :3]

## Inference

In [None]:
# Fold checkpoints (folds 0-3) for the exp6 DEXL model.
checkpoints = [
    "../input/01-a-prod-fpe-dexl-4fold-dataset/a-prod-fpe-dexl/fpe_model_fold_0_best.pth.tar",
    "../input/01-a-prod-fpe-dexl-4fold-dataset/a-prod-fpe-dexl/fpe_model_fold_1_best.pth.tar",
    "../input/01-a-prod-fpe-dexl-4fold-dataset/a-prod-fpe-dexl/fpe_model_fold_2_best.pth.tar",
    "../input/01-a-prod-fpe-dexl-4fold-dataset/a-prod-fpe-dexl/fpe_model_fold_3_best.pth.tar",
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over ``infer_dl`` and write its span probabilities to CSV.

    Softmax probabilities are collected per batch, flattened to one row per
    span, padded spans (uid < 0) are dropped, span uids are mapped to
    discourse ids through ``idx2discourse`` and the table is saved as
    ``exp06_dexl_model_preds_{model_id}.csv``.
    """

    def _flatten_to_spans(tensors):
        # (batch, example, span) nesting -> flat list with one entry per span
        as_lists = (t.to('cpu').detach().numpy().tolist() for t in tensors)
        per_example = chain.from_iterable(as_lists)
        return list(chain.from_iterable(per_example))

    prob_batches, uid_batches = [], []
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    preds_df = pd.DataFrame(_flatten_to_spans(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = _flatten_to_spans(uid_batches)

    # negative uids mark padding added by the collator
    preds_df = preds_df.loc[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp06_dexl_model_preds_{model_id}.csv", index=False)
    

# Only runs when the exp6 model is enabled via the use_exp6 flag.
if use_exp6:
    # one prediction CSV is written per fold checkpoint
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)
    
    # drop the last model and release cached GPU memory
    del model
    # del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Average the exp6 per-fold prediction files into one table (only when use_exp6 is set).
if use_exp6:
    import glob
    import pandas as pd

    csvs = glob.glob("exp06_dexl_model_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # every file is sorted the same way so rows line up across folds
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            # the first file provides the id column and the accumulator array
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            # element-wise running sum over folds
            preds += temp_preds

    # running sum -> mean over folds
    preds = preds / len(csvs)

    exp06_df = pd.DataFrame()
    exp06_df["discourse_id"] = idx
    exp06_df["Ineffective"]  = preds[:, 0]
    exp06_df["Adequate"]     = preds[:, 1]
    exp06_df["Effective"]    = preds[:, 2]

    # rows sharing a discourse_id are averaged into one row
    exp06_df = exp06_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# Preview the averaged exp6 predictions. A bare expression inside an `if`
# body is not rendered as cell output, so display() is called explicitly
# (same as the exp8 preview cell).
if use_exp6:
    display(exp06_df.head())

# EXP8 - Fast Model - Distilled Deberta Large

In [None]:
# Settings for the exp8 distilled DeBERTa-large model, written as JSON text and parsed into a dict.
# num_labels sizes the classifier (5 outputs); the model's forward keeps only the first 3.
config = """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-del-wiki",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span-level effectiveness classifier built on a pretrained transformer.

    Token states from the backbone go through a bidirectional LSTM, are
    mean-pooled per discourse span, layer-normalised, mixed across spans by a
    BERT-style self-attention block and finally classified.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # pretrained backbone
        model_path = self.config["base_model_path"]
        base_config = AutoConfig.from_pretrained(model_path)
        self.base_model = AutoModel.from_pretrained(model_path, config=base_config)
        backbone_cfg = self.base_model.config

        self.dropout = nn.Dropout(self.config["dropout"])

        # self-attention across the pooled span vectors, sized like the backbone
        span_attention_settings = {
            "num_attention_heads": backbone_cfg.num_attention_heads,
            "hidden_size": backbone_cfg.hidden_size,
            "attention_probs_dropout_prob": backbone_cfg.attention_probs_dropout_prob,
            "is_decoder": False,
        }
        attention_config = BertConfig()
        attention_config.update(span_attention_settings)
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        width = backbone_cfg.hidden_size
        self.layer_norm = LayerNorm(width, backbone_cfg.layer_norm_eps)

        # bidirectional LSTM head; the two directions concatenate back to `width`
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three classifier logits for every span."""
        n_examples = input_ids.shape[0]

        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]  # LSTM outputs

        # mean-pool the tokens strictly between each span's head and tail index
        per_example = []
        for row in range(n_examples):
            pooled = [
                token_states[row, start + 1:end].mean(dim=0)  # [h]
                for start, end in zip(span_head_idxs[row], span_tail_idxs[row])
            ]
            per_example.append(torch.stack(pooled))  # (num_discourse, h)
        span_features = self.layer_norm(torch.stack(per_example))  # (bs, num_discourse, h)

        # additive mask: 0 where the span mask is 1, -10000 where it is 0
        additive_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = (1.0 - additive_mask) * -10000.0
        attended = self.fpe_span_attention(span_features, additive_mask)[0]

        logits = self.classifier(self.dropout(attended))

        # keep only the main logits
        return logits[..., :3]

## Inference

In [None]:
# Fold checkpoints (folds 0-7) for the exp8 KD model, stored across two
# input datasets (part 1: folds 0-3, part 2: folds 4-7).
checkpoints = [
    "../input/exp08-del-dataset-part-1/exp-08-del-8folds-kd-part-1/fpe_model_fold_0_best.pth.tar",
    "../input/exp08-del-dataset-part-1/exp-08-del-8folds-kd-part-1/fpe_model_fold_1_best.pth.tar",
    "../input/exp08-del-dataset-part-1/exp-08-del-8folds-kd-part-1/fpe_model_fold_2_best.pth.tar",
    "../input/exp08-del-dataset-part-1/exp-08-del-8folds-kd-part-1/fpe_model_fold_3_best.pth.tar",
    "../input/exp08-del-dataset-part-2/exp-08-del-8folds-kd-part-2/fpe_model_fold_4_best.pth.tar",
    "../input/exp08-del-dataset-part-2/exp-08-del-8folds-kd-part-2/fpe_model_fold_5_best.pth.tar",
    "../input/exp08-del-dataset-part-2/exp-08-del-8folds-kd-part-2/fpe_model_fold_6_best.pth.tar",
    "../input/exp08-del-dataset-part-2/exp-08-del-8folds-kd-part-2/fpe_model_fold_7_best.pth.tar",
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over ``infer_dl`` and write its span probabilities to CSV.

    Softmax probabilities are collected per batch, flattened to one row per
    span, padded spans (uid < 0) are dropped, span uids are mapped to
    discourse ids through ``idx2discourse`` and the table is saved as
    ``exp08_del_kd_model_preds_{model_id}.csv``.
    """

    def _flatten_to_spans(tensors):
        # (batch, example, span) nesting -> flat list with one entry per span
        as_lists = (t.to('cpu').detach().numpy().tolist() for t in tensors)
        per_example = chain.from_iterable(as_lists)
        return list(chain.from_iterable(per_example))

    prob_batches, uid_batches = [], []
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    preds_df = pd.DataFrame(_flatten_to_spans(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = _flatten_to_spans(uid_batches)

    # negative uids mark padding added by the collator
    preds_df = preds_df.loc[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp08_del_kd_model_preds_{model_id}.csv", index=False)

# Only runs when the exp8 model is enabled via the use_exp8 flag.
if use_exp8:

    # one prediction CSV is written per fold checkpoint
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    # drop the last model and release cached GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Average the exp8 per-fold prediction files into one table (only when use_exp8 is set).
if use_exp8:

    import glob
    import pandas as pd

    csvs = glob.glob("exp08_del_kd_model_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # every file is sorted the same way so rows line up across folds
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            # the first file provides the id column and the accumulator array
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            # element-wise running sum over folds
            preds += temp_preds

    # running sum -> mean over folds
    preds = preds / len(csvs)

    exp_08_df = pd.DataFrame()
    exp_08_df["discourse_id"] = idx
    exp_08_df["Ineffective"]  = preds[:, 0]
    exp_08_df["Adequate"]     = preds[:, 1]
    exp_08_df["Effective"]    = preds[:, 2]

    # rows sharing a discourse_id are averaged into one row
    exp_08_df = exp_08_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# Preview the averaged exp8 predictions; display() is needed because the
# expression sits inside an `if` body.
if use_exp8:
    display(exp_08_df.head())

# EXP212 - Fast Model - Longformer

In [None]:
# Settings for the exp212 Longformer model, written as JSON text and parsed into a dict.
config = """{
    "debug": false,

    "base_model_path": "../input/exp212-longformer-l-prompt-mlm50/mlm_model",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

## Dataset

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer found at ``config["base_model_path"]``.

    Besides loading, the vocabulary size and the tokenization of two probe
    strings are printed so the handling of bracketed marker text can be
    checked in the log.

    :param config: run configuration; only ``base_model_path`` is read
    :type config: dict
    :return: the loaded tokenizer
    """
    separator = "=="*40
    # NOTE(review): the second probe uses "[LEAD_END]"-style spellings, while the
    # end markers defined in this notebook section are spelled "[LEAD END]".
    probe_texts = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )

    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    print(separator)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probe_texts:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")
    print(separator)

    return tokenizer


#--------------- Processing ---------------------------------------------#


# Marker tokens that open a discourse span. AuxFeedbackDataset.load_tokenizer
# converts them to vocabulary ids, which are then used to find span heads.
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]"
]

# discourse type -> [text inserted before the span, text inserted after the span].
# process_essay looks entries up by the annotation's discourse_type; the "topic"
# entry is not looked up there (the topic markers are written inline instead).
TOKEN_MAP = {
    "topic": ["Topic [TOPIC]", "[TOPIC END]"],
    "Lead": ["Lead [LEAD]", "[LEAD END]"],
    "Position": ["Position [POSITION]", "[POSITION END]"],
    "Claim": ["Claim [CLAIM]", "[CLAIM END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT END]"]
}


# Marker tokens that close a discourse span. AuxFeedbackDataset.load_tokenizer
# converts them to vocabulary ids, which are then used to find span tails.
DISCOURSE_END_TOKENS = [
    "[LEAD END]",
    "[POSITION END]",
    "[CLAIM END]",
    "[COUNTER_CLAIM END]",
    "[REBUTTAL END]",
    "[EVIDENCE END]",
    "[CONCLUDING_STATEMENT END]",
]






def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, shortening it from the end until it matches.

    The substring is searched as-is first. While it is not found it is cut to
    its first ``int(len * fraction)`` characters and searched again; with the
    default ``fraction`` this drops one trailing character per attempt.

    The search is iterative, so long unmatched substrings cannot exhaust the
    interpreter's recursion limit, and a ``fraction`` that would not shorten the
    candidate (e.g. ``1.0``) still makes progress by dropping one character.

    :param text: text to search in
    :type text: str
    :param substring: text to look for
    :type substring: str
    :param min_length: shortest prefix that is still searched for
    :type min_length: int
    :param fraction: share of the candidate kept after a failed attempt
    :type fraction: float
    :return: ``[start, end]`` of the first occurrence of the longest matching
        prefix, or ``[-1, 0]`` when no prefix of at least ``min_length``
        characters occurs in ``text``
    :rtype: list
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        shorter = candidate[:int(len(candidate) * fraction)]
        if len(shorter) >= len(candidate):
            # The fraction did not shorten the candidate; force progress so the
            # loop always terminates.
            shorter = candidate[:-1]
        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Locate every discourse text of one essay, in annotation order.

    For each entry of ``discourse_list`` the first literal occurrence that ends
    after the previously accepted occurrence is recorded, so repeated texts
    are assigned to successive occurrences. When no such occurrence exists, the
    span falls back to ``relaxed_search``.

    Exactly one span is recorded per entry of ``discourse_list`` (including
    repeated texts), so consumers that take one span per entry never run out
    of spans.

    :param discourse_list: discourse texts of one essay
    :type discourse_list: list
    :param essay_text: full essay text
    :type essay_text: str
    :return: mapping from discourse text to the list of its spans, in the order
        in which the text occurs in ``discourse_list``; exact matches are
        ``(start, end)`` tuples, fallback results are whatever
        ``relaxed_search`` returns
    :rtype: dict
    """
    reading_head = 0  # end offset of the last exact match accepted so far
    to_return = dict()

    for cur_discourse in discourse_list:
        spans = to_return.setdefault(cur_discourse, [])

        found = None
        for match in re.finditer(re.escape(str(cur_discourse)), essay_text):
            # Skip occurrences that end at or before the last accepted match.
            if match.end() > reading_head:
                found = match.span()
                break

        if found is None:
            # No usable exact occurrence for this entry: fall back right away
            # so this entry still receives its own span.
            print("resorting to relaxed search...")
            spans.append(relaxed_search(essay_text, cur_discourse))
        else:
            spans.append(found)
            reading_head = found[1]

    return to_return


def get_substring_span(texts, mapping):
    """Take, for each text in order, the next unused span stored in ``mapping``.

    The span lists inside ``mapping`` are consumed from the front, i.e. the
    mapping is modified in place.

    :param texts: discourse texts, in annotation order
    :param mapping: discourse text -> list of spans
    :return: one span per entry of ``texts``
    :rtype: list
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, prompt, anno_df):
    """Surround each annotated span of one essay with its marker text.

    The annotations of ``essay_id`` are visited in order of ``discourse_start``;
    for each one the start/end marker text from ``TOKEN_MAP`` is inserted
    (space-separated) around the span. Finally the prompt is prepended between
    topic markers and the whole text is wrapped in ``[SOE]`` / ``[EOE]``.

    :param essay_id: id of the essay to process
    :param essay_text: raw essay text
    :param prompt: prompt text placed in front of the essay
    :param anno_df: annotations with ``essay_id``, ``discourse_start``,
        ``discourse_end`` and ``discourse_type`` columns
    :return: the augmented essay text
    :rtype: str
    """
    essay_rows = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_rows = essay_rows.sort_values(by="discourse_start")

    # Offsets refer to the untouched text, so track how many characters have
    # been inserted in front of the span currently being processed.
    shift = 0
    for raw_start, raw_end, d_type in zip(
        essay_rows["discourse_start"], essay_rows["discourse_end"], essay_rows["discourse_type"]
    ):
        start = int(raw_start) + shift
        end = int(raw_end) + shift
        open_marker, close_marker = TOKEN_MAP[d_type]
        pieces = [essay_text[:start], open_marker, essay_text[start:end], close_marker, essay_text[end:]]
        essay_text = " ".join(pieces)
        # two markers plus the four joining spaces
        shift += len(open_marker) + len(close_marker) + 4

    return "[SOE]" + " [TOPIC] " + prompt + " [TOPIC END] " + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """Turn per-discourse annotations and raw essays into one row per essay.

    Every discourse text is located in its essay (``build_span_map`` /
    ``get_substring_span``), the essay text is rewritten with marker text around
    each located span (``process_essay``), and the annotations are grouped per
    essay. Both inputs are deep-copied first, so the caller's frames are not
    modified.

    :param anno_df: annotation dataframe; must provide ``discourse_id``,
        ``essay_id``, ``discourse_text``, ``discourse_type`` and ``uid``
        (``discourse_effectiveness`` is carried along when present)
    :type anno_df: pd.DataFrame
    :param notes_df: essay dataframe; must provide ``essay_id``,
        ``essay_text`` and ``prompt``
    :type notes_df: pd.DataFrame
    :return: one row per essay with ``essay_id``, the list columns ``uids`` and
        ``discourse_type`` (plus ``discourse_effectiveness`` when available)
        and the marker-augmented ``essay_text``
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    #------------------- Locate Each Discourse In Its Essay ---------------#
    # Rows are addressed by column label below: integer keys on a label-indexed
    # row are a deprecated positional fallback in recent pandas versions.
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda row: build_span_map(row["discourse_text"], row["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda row: get_substring_span(row["discourse_text"], row["span_map"]), axis=1)

    # Flatten the per-essay lists back to one (discourse_id, span) pair per row.
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")

    #------------------- Insert Marker Text -------------------------------#
    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text", "prompt"]].apply(
        lambda row: process_essay(row["essay_id"], row["essay_text"], row["prompt"], anno_df), axis=1
    )
    print("=="*40)

    # The character offsets were only needed for the insertion step.
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    #------------------- Group Annotations Per Essay ----------------------#
    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Builds the essay-level (untruncated) dataset for the feedback prize effectiveness task.

    Each essay is tokenized as a whole; the positions of the discourse start and
    end marker tokens are recorded as span heads / tails, and discourse types
    (and, outside inference mode, effectiveness labels) are converted to ids.
    """

    def __init__(self, config):
        """Store the config, define the id vocabularies and load the tokenizer.

        :param config: run configuration (``base_model_path`` and ``model_dir``
            are read by this class and the helpers it calls)
        :type config: dict
        """
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        # 0 is left unused here; the data collator pads discourse type ids with 0.
        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and resolve the marker token ids.

        No tokens are added here; the marker tokens are looked up in the loaded
        vocabulary as-is (presumably they were added when the base model was
        prepared -- confirm against the training code).
        """
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        # union of start and end marker ids
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize the essay texts of a batch without padding or truncation."""
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find span head / tail token positions and their character offsets.

        :param examples: batch with ``input_ids`` and ``offset_mapping``
        :return: dict with ``span_head_idxs``, ``span_tail_idxs`` (token
            positions of start / end markers), ``span_head_char_start_idxs``
            (start offset of each head token) and ``span_tail_char_end_idxs``
            (end offset of each tail token), one list per example
        :rtype: dict
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_token_ids
            ]
            example_span_tail_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_end_ids
            ]

            span_head_char_start_idxs.append([example_offset_mapping[pos][0] for pos in example_span_head_idxs])
            span_tail_char_end_idxs.append([example_offset_mapping[pos][1] for pos in example_span_tail_idxs])

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert the effectiveness label names of a batch to label ids."""
        labels = [
            [self.label2id[label] for label in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of a batch to discourse type ids."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the number of tokens of every example in the batch."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every span head has a label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        A sample (at most 16 rows) of the processed dataframe is written to
        ``<model_dir>/<mode>_df_processed.csv`` for manual inspection.

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed (essay-level) dataframe and the created dataset
        :rtype: tuple
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        # keep only essays whose start and end markers pair up
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))  # no need to run on empty set
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # Drop the pandas index column when the conversion produced one; checking
        # explicitly avoids hiding unrelated failures behind a blanket except.
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    The essay-level dataset is first built with ``AuxFeedbackDataset`` (no
    truncation) and kept as a reference copy. The essays are then re-tokenized
    into windows of at most ``config["max_length"]`` tokens with
    ``config["stride"]`` and overflowing tokens returned, and for every window
    the span heads / tails are recomputed. uids, discourse type ids and (outside
    inference mode) labels are re-attached to each window by matching the
    character start offset of every head against the reference copy.

    :param config: run configuration; ``max_length`` and ``stride`` are read
        here, further keys are read by ``AuxFeedbackDataset``
    :type config: dict
    :param df: input annotation dataframe
    :param essay_df: dataframe with essay texts
    :param mode: labels are recomputed unless this is ``"infer"``
    :type mode: str, optional
    :return: dict with ``dataset`` (windowed dataset), ``original_dataset``
        (essay-level reference copy) and ``tokenizer``
    :rtype: dict
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # Untruncated reference used below to look up uids / labels / type ids.
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids
    GLOBAL_IDS = dataset_creator.global_tokens

    def tokenize_with_truncation(examples):
        """Tokenize essays into truncated windows, returning overflowing tokens."""
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        return tz

    def process_span(examples):
        """Recompute span head / tail positions for every window of the batch."""
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                # window is full: ignore heads that start inside the trailing buffer region
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def get_global_attention_mask(examples):
        """Mark every start / end marker token with 1 and all other tokens with 0."""
        global_attention_mask = []
        for example_input_ids in examples["input_ids"]:
            global_attention_mask.append([1 if iid in GLOBAL_IDS else 0 for iid in example_input_ids])
        return {"global_attention_mask": global_attention_mask}

    def enforce_alignment(examples):
        """Attach to every window the uids of the spans whose heads it contains."""
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            # overflow_to_sample_mapping is used as a row index into the reference dataset
            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            # head character offset -> uid, as recorded on the untruncated essay
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    def recompute_labels(examples):
        """Attach to every window the labels of the spans whose heads it contains."""
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            # head character offset -> label id, as recorded on the untruncated essay
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    def recompute_discourse_type_ids(examples):
        """Attach to every window the discourse type ids of the spans whose heads it contains."""
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            # head character offset -> discourse type id, as recorded on the untruncated essay
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(examples):
        """Return the number of tokens of every window in the batch."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # Replace the essay-level columns by the windowed tokenization. The whole
    # dataset is passed as one batch -- presumably so that
    # overflow_to_sample_mapping indexes rows of original_dataset (TODO confirm).
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(get_global_attention_mask, batched=True)

    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
if use_exp212: 
    # get_dataset writes a sample CSV into model_dir, so it has to exist.
    os.makedirs(config["model_dir"], exist_ok=True)

    print("creating the inference datasets...")
    # Build the windowed inference dataset (no labels in "infer" mode).
    infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
    tokenizer = infer_ds_dict["tokenizer"]
    infer_dataset = infer_ds_dict["dataset"]
    print(infer_dataset)

In [None]:
if use_exp212: 

    # Record the vocabulary size in the config.
    config["len_tokenizer"] = len(tokenizer)

    # Order windows by token count so each batch holds windows of similar length.
    infer_dataset = infer_dataset.sort("input_length")

    # Expose only the columns that the data collator and the model consume.
    infer_dataset.set_format(
        type=None,
        columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs', 'global_attention_mask',
                 'span_tail_idxs', 'discourse_type_ids', 'uids']
    )

## Data Loader

In [None]:
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator for span-level sequence classification.

    Token-level fields are padded by the tokenizer; the per-span fields
    (heads, tails, uids, discourse type ids, labels) are padded to the largest
    number of spans in the batch, and ``span_attention_mask`` marks which span
    slots are real (1) or padding (0).
    """

    # NOTE(review): these attributes carry no type annotation, so @dataclass does
    # not treat them as field overrides; the defaults inherited from
    # DataCollatorWithPadding presumably win in the generated __init__ (i.e.
    # pad_to_multiple_of=512 may never be applied). Left as-is on purpose, since
    # changing it would change the padded lengths -- confirm against training.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    @staticmethod
    def _multitask_vector(label_id):
        """Return the two-element auxiliary target for a label id (-1 = padding).

        :raises ValueError: if ``label_id`` is not one of 0, 1, 2 or -1
        """
        if label_id == 0:
            return [0, 0]
        if label_id == 1:
            return [1, 0]
        if label_id == 2:
            return [1, 1]
        if label_id == -1:
            return [-1, -1]
        raise ValueError(f"unexpected label id: {label_id!r}")

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of tensors.

        :param features: per-window feature dicts with ``uids``,
            ``discourse_type_ids``, ``span_head_idxs``, ``span_tail_idxs``,
            ``global_attention_mask``, the tokenizer fields and optionally
            ``labels``
        :type features: list
        :return: batch of int64 tensors (``multitask_labels`` is float32)
        :rtype: dict
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]
        global_attention_mask = [feature["global_attention_mask"] for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        # Pad the tokenizer fields; plain lists are kept so the span fields can
        # be padded by hand before everything is converted to tensors at once.
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(heads) for heads in span_head_idxs)  # most spans in any example
        max_len = len(batch["input_ids"][0])  # padded sequence length

        # Padded span slots point at a short dummy range near the end of the
        # sequence, so head < tail also holds for them.
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        batch["span_head_idxs"] = [
            ex_span_head_idxs + [default_head_idx] * (b_max - len(ex_span_head_idxs)) for ex_span_head_idxs in span_head_idxs
        ]

        batch["uids"] = [ex_uids + [-1] * (b_max - len(ex_uids)) for ex_uids in uids]
        batch["discourse_type_ids"] = [ex_discourse_type_ids + [0] *
                                       (b_max - len(ex_discourse_type_ids)) for ex_discourse_type_ids in discourse_type_ids]

        batch["span_tail_idxs"] = [
            ex_span_tail_idxs + [default_tail_idx] * (b_max - len(ex_span_tail_idxs)) for ex_span_tail_idxs in span_tail_idxs
        ]

        batch["span_attention_mask"] = [
            ex_discourse_masks + [0] * (b_max - len(ex_discourse_masks)) for ex_discourse_masks in span_attention_mask
        ]

        batch["global_attention_mask"] = [
            ex_global_attention_mask + [0] * (max_len - len(ex_global_attention_mask)) for ex_global_attention_mask in global_attention_mask
        ]

        if labels is not None:
            batch["labels"] = [ex_labels + [-1] * (b_max - len(ex_labels)) for ex_labels in labels]

            # multitask labels, derived from the padded labels
            batch["multitask_labels"] = [
                [self._multitask_vector(label_id) for label_id in ex_labels] for ex_labels in batch["labels"]
            ]

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
if use_exp212: 
    data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

    # No shuffling: the dataset was sorted by input length beforehand.
    infer_dl = DataLoader(
        infer_dataset,
        batch_size=config["infer_bs"],
        shuffle=False,
        collate_fn=data_collector
    )

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


class FeedbackModel(nn.Module):
    """
    Span classification model for the fast approach.

    A pretrained encoder is followed by a bidirectional LSTM; every discourse
    span is mean-pooled between its head and tail tokens, the span vectors
    attend to each other, and a linear layer produces the class logits.
    """

    def __init__(self, config):
        """Build the encoder and the span-level head.

        :param config: run configuration; ``base_model_path``, ``dropout`` and
            ``num_labels`` are read
        :type config: dict
        """
        print("=="*40)
        print("initializing the feedback model...")

        super().__init__()
        self.config = config

        # pretrained encoder
        model_path = self.config["base_model_path"]
        encoder_config = AutoConfig.from_pretrained(model_path)
        self.base_model = AutoModel.from_pretrained(model_path, config=encoder_config)
        backbone_config = self.base_model.config
        hidden_size = backbone_config.hidden_size

        self.dropout = nn.Dropout(self.config["dropout"])
        self.num_labels = self.config["num_labels"]

        # self-attention across the span vectors of one example, sized like the encoder
        attention_config = BertConfig()
        attention_config.update(
            {
                "num_attention_heads": backbone_config.num_attention_heads,
                "hidden_size": hidden_size,
                "attention_probs_dropout_prob": backbone_config.attention_probs_dropout_prob,
                "is_decoder": False,
            }
        )
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        self.layer_norm = LayerNorm(hidden_size, backbone_config.layer_norm_eps)

        # bidirectional LSTM over the token states; the two directions together
        # give hidden_size features again
        self.fpe_lstm_layer = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size//2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        span_head_idxs,
        span_tail_idxs,
        span_attention_mask,
        global_attention_mask,
        labels=None,
        multitask_labels=None,
        **kwargs
    ):
        """Return the logits for every span slot of every example.

        ``labels``, ``multitask_labels`` and any extra keyword arguments are
        accepted but not used.

        :return: logits of shape (batch, num_spans, num_labels)
        """
        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            global_attention_mask=global_attention_mask,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # mean over the tokens strictly between head and tail marker -> (bs, num_spans, h)
        span_features = torch.stack([
            torch.stack([
                torch.mean(example_states[head+1:tail], dim=0)
                for head, tail in zip(example_heads, example_tails)
            ])
            for example_states, example_heads, example_tails in zip(token_states, span_head_idxs, span_tail_idxs)
        ])
        span_features = self.layer_norm(span_features)

        # additive attention mask: 0 for real spans, -10000 for padded span slots
        additive_mask = (1.0 - span_attention_mask[:, None, None, :]) * -10000.0
        attended = self.fpe_span_attention(span_features, additive_mask)[0]

        return self.classifier(self.dropout(attended))

## Inference

In [None]:
# Best checkpoint of each of the eight folds (fold 0 .. fold 7).
checkpoints = [
    f"../input/exp212-longformer-l-prompt-mlm50/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(8)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Predict with one checkpoint and write the per-span probabilities to disk.

    The model and dataloader are prepared with a fresh ``Accelerator``; for every
    batch the logits are turned into softmax probabilities. Rows whose span uid
    is negative are dropped, the remaining uids are mapped to discourse ids
    through the module-level ``idx2discourse`` and the table is saved as
    ``exp212_longformer_model_preds_{model_id}.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            prob_batches.append(F.softmax(logits, dim=-1))
            uid_batches.append(batch["uids"])

    # collapse the (batch, example, span) nesting down to one entry per span
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example_probs in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example_probs
    ]
    flat_uids = [
        span_uid
        for uids in uid_batches
        for example_uids in uids.to('cpu').detach().numpy().tolist()
        for span_uid in example_uids
    ]

    label_columns = ["Ineffective", "Adequate", "Effective"]
    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = label_columns
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS

    has_valid_uid = preds_df["span_uid"] >= 0
    preds_df = preds_df[has_valid_uid].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id"] + label_columns].copy()
    preds_df.to_csv(f"exp212_longformer_model_preds_{model_id}.csv", index=False)
    
# Run every exp212 fold checkpoint over the inference dataloader, then free the
# model and the exp212-specific data objects.
if use_exp212:
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        # Load onto the CPU: the model is still on the CPU at this point, and a
        # checkpoint restored to the GPU would occupy device memory for the
        # whole fold in addition to the prepared model.
        ckpt = torch.load(checkpoint, map_location="cpu")
        print(f"model performance on validation set = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        del ckpt  # weights are copied into the model; drop the extra copy
        inference_fn(model, infer_dl, model_id)

    del model
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Average the per-fold exp212 prediction CSVs into a single `exp212_df` with one
# row per discourse_id.
if use_exp212: 
    import glob
    import pandas as pd

    # every file written by inference_fn for this experiment
    # NOTE(review): any stale file matching the pattern in the working directory
    # would be averaged in as well; an empty match makes the division below fail.
    csvs = glob.glob("exp212_longformer_model_preds_*.csv")

    idx = []
    preds = []

    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # sort so rows are ordered by discourse_id in every file; the element-wise
        # sum below presumably relies on all files containing the same ids
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        # probability columns only (Ineffective, Adequate, Effective)
        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            # the first file provides the id order and starts the running sum
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds

    # running sum -> mean over files
    preds = preds / len(csvs)

    exp212_df = pd.DataFrame()
    exp212_df["discourse_id"]  = idx
    exp212_df["Ineffective"]   = preds[:, 0]
    exp212_df["Adequate"]      = preds[:, 1]
    exp212_df["Effective"]     = preds[:, 2]

    # a discourse_id can occur on several rows; collapse them to their mean
    exp212_df = exp212_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# Show the first rows of the averaged exp212 predictions as a visual check.
if use_exp212: 
    display(exp212_df.head())

# EXP3 & EXP4 - delv3 Fast Model - SPAN MLM 20%

Model trained with span mlm 20% backbone

## Config

In [None]:
# Inference configuration for the EXP3/EXP4 models, kept as JSON text and parsed
# into a dict right below. `max_length` and `stride` drive the sliding-window
# tokenization and `infer_bs` is the dataloader batch size. `num_labels` sizes
# the classifier head (5), of which the model's forward keeps the first 3 logits.
config = """{
    "debug": false,

    "base_model_path": "../input/fpe-tapt-delv3-span-mlm-pos-1024",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

## Dataset

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ----------------------------------------------#

def get_tokenizer(config):
    """Load and return the tokenizer found at ``config["base_model_path"]``.

    The tokenizer length and a sample tokenization are printed between two
    divider lines as a quick visual check.
    """
    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    divider = "==" * 40
    print(divider)
    vocab_size = len(tokenizer)
    print(f"tokenizer len: {vocab_size}")
    sample_tokens = tokenizer.tokenize('This [B-SPAN] is useful [E-SPAN]')
    print(f"tokenizer test: {sample_tokens}")
    print(divider)
    return tokenizer


#--------------- Processing ---------------------------------------------#

# discourse type -> [opening text, closing text] inserted around each annotated
# span by process_essay. The opening text is the type name followed by its
# marker token; every type shares the same "[END]" closing marker.
TOKEN_MAP = {
    "Lead": ["Lead [LEAD]", "[END]"],
    "Position": ["Position [POSITION]", "[END]"],
    "Claim": ["Claim [CLAIM]", "[END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[END]"]
}


# Marker tokens that open a discourse span; their token ids are used to locate
# span heads in the tokenized essay.
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]"
]

# Marker token that closes a discourse span; its token id locates span tails.
DISCOURSE_END_TOKENS = [
    "[END]",
]

#--------------- Span Detection ---------------------------------------------#

def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, falling back to ever shorter prefixes.

    The full substring is tried first. While there is no match, the candidate is
    cut down to its leading ``int(len(candidate) * fraction)`` characters and
    searched again; the search gives up once the candidate would become shorter
    than ``min_length``.

    The search is iterative, so long substrings that need many shortening steps
    cannot exhaust the recursion limit, and every step shortens the candidate by
    at least one character, so ``fraction >= 1`` still terminates.

    :param text: text to search in
    :type text: str
    :param substring: text to look for
    :type substring: str
    :param min_length: shortest prefix that is still searched for, defaults to 2
    :type min_length: int, optional
    :param fraction: share of the candidate kept at each shortening step
    :type fraction: float, optional
    :return: ``[start, end]`` of the first match of the longest matching prefix,
        or ``[-1, 0]`` when no prefix of at least ``min_length`` characters matches
    :rtype: list
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        shorter = candidate[:int(len(candidate) * fraction)]
        if len(shorter) >= len(candidate):
            # the fraction did not shrink the candidate: force progress
            shorter = candidate[:-1]
        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Map each discourse text to the character spans where it occurs in the essay.

    Discourses are processed in list order. For each one, the first exact match
    that ends after the end of the previously accepted match is recorded, so a
    repeated discourse text picks up a later occurrence. A discourse text that
    ends up with no span at all gets a single ``relaxed_search`` result instead.

    :param discourse_list: discourse texts of one essay
    :param essay_text: full essay text
    :return: dict from discourse text to a list of spans
    """
    span_map = dict()
    cursor = 0  # end offset of the most recently accepted match

    for discourse in discourse_list:
        spans = span_map.setdefault(discourse, [])
        pattern = re.escape(r'{}'.format(discourse))
        first_hit = next(
            (found.span() for found in re.finditer(pattern, essay_text) if found.span()[1] > cursor),
            None,
        )
        if first_hit is not None:
            spans.append(first_hit)
            cursor = first_hit[1]

    # fallback for discourse texts without any exact match
    for discourse in discourse_list:
        if span_map[discourse]:
            continue
        print("resorting to relaxed search...")
        span_map[discourse] = [relaxed_search(essay_text, discourse)]
    return span_map


def get_substring_span(texts, mapping):
    """Return one span per text, consuming the spans stored in ``mapping``.

    For every text, in order, the first remaining span in ``mapping[text]`` is
    removed and returned, so repeated texts receive successive spans. Note that
    ``mapping`` is modified in place.
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, anno_df):
    """Wrap each annotated discourse of one essay in its marker tokens.

    The rows of ``anno_df`` belonging to ``essay_id`` are visited in order of
    ``discourse_start``. Around every span the opening and closing texts from
    ``TOKEN_MAP`` are inserted, separated by single spaces. The finished text is
    enclosed in ``[SOE]`` / ``[EOE]`` and returned.
    """
    essay_rows = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_rows = essay_rows.sort_values(by="discourse_start")

    shift = 0  # number of characters inserted before the current span so far
    span_columns = zip(
        essay_rows["discourse_start"], essay_rows["discourse_end"], essay_rows["discourse_type"]
    )
    for raw_start, raw_end, d_type in span_columns:
        s = int(raw_start) + shift
        e = int(raw_end) + shift
        s_tok, e_tok = TOKEN_MAP[d_type]
        pieces = [essay_text[:s], s_tok, essay_text[s:e], e_tok, essay_text[e:]]
        essay_text = " ".join(pieces)
        # both marker texts plus the four joining spaces
        shift += len(s_tok) + len(e_tok) + 4

    return "[SOE]" + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """Pre-process the annotation and essay dataframes into one row per essay.

    The character span of every discourse is located in its essay, the essay
    text is rewritten with marker tokens around each span, and the annotations
    are grouped per essay. Both inputs are deep-copied and left untouched.

    :param anno_df: discourse annotations; must provide ``discourse_id``,
        ``essay_id``, ``discourse_text``, ``discourse_type`` and ``uid``
        (``discourse_effectiveness`` is carried along when present)
    :type anno_df: pd.DataFrame
    :param notes_df: essay texts with ``essay_id`` and ``essay_text`` columns
    :type notes_df: pd.DataFrame
    :return: one row per essay with ``essay_id``, ``uids``, ``discourse_type``,
        ``essay_text`` (marker tokens inserted) and, when available,
        ``discourse_effectiveness``; the per-discourse columns hold lists
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # locate every discourse inside its essay; rows are addressed by column
    # label rather than by position (positional access on a labelled Series is
    # deprecated in pandas)
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda row: build_span_map(row["discourse_text"], row["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda row: get_substring_span(row["discourse_text"], row["span_map"]), axis=1)

    # flatten the per-essay lists back to one row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text"]].apply(
        lambda row: process_essay(row["essay_id"], row["essay_text"], anno_df), axis=1
    )
    print("=="*40)

    # the span offsets refer to the unmodified essay text and are no longer needed
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Dataset builder for the feedback prize effectiveness task.

    Turns the annotation and essay dataframes into a tokenized (untruncated)
    ``Dataset`` with one example per essay, carrying the token and character
    positions of every discourse marker.
    """

    def __init__(self, config):
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        # 0 is left unused here (the collator pads discourse type ids with 0)
        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and cache the marker token ids."""
        self.tokenizer = get_tokenizer(self.config)
        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))

    def tokenize_function(self, examples):
        """Tokenize a batch of essays without padding or truncation."""
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Locate span head/tail markers as token positions and character offsets."""
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_token_ids
            ]
            example_span_tail_idxs = [
                pos for pos, this_id in enumerate(example_input_ids) if this_id in self.discourse_end_ids
            ]

            # first character of each head marker / last character of each tail marker
            span_head_char_start_idxs.append([example_offset_mapping[pos][0] for pos in example_span_head_idxs])
            span_tail_char_end_idxs.append([example_offset_mapping[pos][1] for pos in example_span_tail_idxs])

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert effectiveness label names to their integer ids."""
        labels = [
            [self.label2id[l] for l in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert discourse type names to their integer ids."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Record the number of tokens of every example."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up with at least one token between them."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every span head has exactly one label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        A sample of up to 16 processed rows is also written to
        ``<model_dir>/<mode>_df_processed.csv`` for manual inspection.

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed dataframe and the created dataset
        :rtype: tuple(pd.DataFrame, Dataset)
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        # task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # Dataset.from_pandas may add this column for a non-default index; drop
        # it only when it exists so that unrelated errors are not silenced
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Build the dataset for the fast approach with truncation & sliding window.

    Essays are first tokenized without truncation (through
    ``AuxFeedbackDataset``) to obtain reference span positions. They are then
    re-tokenized with truncation to ``config["max_length"]`` tokens, using
    ``config["stride"]`` and returning the overflowing windows. For every window
    the discourse head/tail positions are recomputed, and the uids, discourse
    type ids and (in any mode other than ``"infer"``) labels are re-attached by
    matching head character offsets against the untruncated essay. Windows
    without any span head are dropped.

    :param config: configuration dict (``max_length``, ``stride`` and the keys
        required by ``AuxFeedbackDataset``)
    :type config: dict
    :param df: input annotation dataframe
    :type df: pd.DataFrame
    :param essay_df: dataframe with essay texts
    :type essay_df: pd.DataFrame
    :param mode: ``"infer"`` skips everything label related, defaults to "train"
    :type mode: str, optional
    :return: dict with the windowed ``"dataset"``, the untruncated
        ``"original_dataset"`` and the ``"tokenizer"``
    :rtype: dict
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids

    def tokenize_with_truncation(examples):
        """Tokenize essays into windows of at most ``max_length`` tokens."""
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
        )
        return tz

    def process_span(examples):
        """Find head/tail positions per window and repair truncation artefacts."""
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                last_head_pos = config["max_length"] - buffer
                head_candidate = [
                    pos for pos, this_id in enumerate(example_input_ids)
                    if this_id in START_IDS and pos <= last_head_pos
                ]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if tail_candidate and head_candidate and tail_candidate[0] < head_candidate[0]:
                tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append([example_offset_mapping[pos][0] for pos in head_candidate])
            span_tail_char_end_idxs.append([example_offset_mapping[pos][1] for pos in tail_candidate])

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def _realign(examples, column):
        """For each window, pick the ``column`` values of the spans whose head it contains."""
        realigned = []
        for head_char_idxs, sample_idx in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):
            original_example = original_dataset[sample_idx]
            # head character offset in the untruncated essay -> value of that span
            char2value = dict(zip(original_example["span_head_char_start_idxs"], original_example[column]))
            realigned.append([char2value[char_idx] for char_idx in head_char_idxs])
        return realigned

    def enforce_alignment(examples):
        """Re-attach the uids of the spans present in each window."""
        return {"uids": _realign(examples, "uids")}

    def recompute_labels(examples):
        """Re-attach the labels of the spans present in each window."""
        return {"labels": _realign(examples, "labels")}

    def recompute_discourse_type_ids(examples):
        """Re-attach the discourse type ids of the spans present in each window."""
        return {"discourse_type_ids": _realign(examples, "discourse_type_ids")}

    def compute_input_length(examples):
        """Record the number of tokens of every window."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        """Assert that heads and tails pair up with at least one token between them."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # The whole dataset goes through as one batch; the re-alignment helpers use
    # `overflow_to_sample_mapping` to index rows of `original_dataset`.
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    if mode != "infer":
        # labels were removed together with the other original columns above;
        # restore them per window (the untruncated dataset only carries labels
        # outside "infer" mode)
        task_dataset = task_dataset.map(recompute_labels, batched=True)
    # task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
# Build the windowed inference dataset shared by the EXP3 and EXP4 models.
if use_exp3 or use_exp4:
    # get_dataset writes a sample CSV into this directory, so it must exist
    os.makedirs(config["model_dir"], exist_ok=True)

    print("creating the inference datasets...")
    infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
    tokenizer = infer_ds_dict["tokenizer"]
    infer_dataset = infer_ds_dict["dataset"]
    print(infer_dataset)

In [None]:
# Finalize the inference dataset: record the tokenizer size, order the windows
# by length and restrict the visible columns.
if use_exp3 or use_exp4:
    config["len_tokenizer"] = len(tokenizer)

    # order windows by token count (presumably to reduce padding per batch)
    infer_dataset = infer_dataset.sort("input_length")

    # keep python objects (type=None) and expose only the listed columns
    infer_dataset.set_format(
        type=None,
        columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
                 'span_tail_idxs', 'discourse_type_ids', 'uids']
    )

## DataLoader

In [None]:
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator that pads both the token inputs and the per-span fields.

    Token-level inputs are padded by the tokenizer. The span-level lists (head
    and tail positions, uids, discourse type ids, labels) are padded to the
    largest number of spans in the batch, and a ``span_attention_mask`` marks
    real spans with 1 and padded ones with 0. When labels are present, a
    two-value ``multitask_labels`` vector is derived from each label.
    """

    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = None
    return_tensors = "pt"

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of tensors.

        All tensors are int64 except ``multitask_labels``, which is float32.

        :raises ValueError: if a label is not one of 0, 1, 2 or -1
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(idxs) for idxs in span_head_idxs)  # most spans in any example
        max_len = len(batch["input_ids"][0])  # padded sequence length

        # padded spans point at a slice near the end of the sequence so that
        # head+1:tail is never empty when the sequence is long enough
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        def _pad_rows(rows, fill):
            """Right-pad every row with ``fill`` up to ``b_max`` entries."""
            return [row + [fill] * (b_max - len(row)) for row in rows]

        batch["span_head_idxs"] = _pad_rows(span_head_idxs, default_head_idx)
        batch["uids"] = _pad_rows(uids, -1)
        batch["discourse_type_ids"] = _pad_rows(discourse_type_ids, 0)
        batch["span_tail_idxs"] = _pad_rows(span_tail_idxs, default_tail_idx)
        batch["span_attention_mask"] = _pad_rows(span_attention_mask, 0)

        if labels is not None:
            batch["labels"] = _pad_rows(labels, -1)

            # label id -> auxiliary two-value target; -1 marks padded spans
            multitask_vectors = {
                0: (0, 0),
                1: (1, 0),
                2: (1, 1),
                -1: (-1, -1),
            }

            def _get_additional_labels(label_id):
                try:
                    return list(multitask_vectors[label_id])
                except KeyError:
                    raise ValueError(f"unexpected label id: {label_id!r}") from None

            batch["multitask_labels"] = [
                [_get_additional_labels(el) for el in ex_labels] for ex_labels in batch["labels"]
            ]

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# Dataloader over the inference windows for the EXP3/EXP4 models.
if use_exp3 or use_exp4:
    data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

    infer_dl = DataLoader(
        infer_dataset,
        batch_size=config["infer_bs"],
        shuffle=False,  # keep the dataset order
        collate_fn=data_collector
    )

## Model

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Pipeline: transformer backbone -> unidirectional LSTM over the token states
    -> mean pooling of the tokens between each span's head and tail markers ->
    layer norm -> attention across the spans of an essay -> dropout -> linear
    classifier. Only the first three logits of the classifier are returned.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # base transformer
        backbone_config = AutoConfig.from_pretrained(config["base_model_path"])
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(config["base_model_path"], config=backbone_config)

        # dropouts
        self.dropout = StableDropout(config["dropout"])

        # multi-head attention across span vectors (relative attention switched off)
        span_attention_config = deepcopy(self.base_model.config)
        span_attention_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attention_config)

        # classification
        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # LSTM head
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width,
            num_layers=1,
            batch_first=True,
            bidirectional=False,
        )

        self.num_labels = config["num_labels"]
        self.classifier = nn.Linear(width, config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the main logits, shaped (bs, num_discourse, 3)."""
        n_examples = input_ids.shape[0]  # batch size
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one vector per span: mean of the tokens strictly between head and tail
        span_vectors = torch.stack([
            torch.stack([
                torch.mean(token_states[i, head+1:tail], dim=0)  # [h]
                for head, tail in zip(span_head_idxs[i], span_tail_idxs[i])
            ])  # (num_discourse, h)
            for i in range(n_examples)
        ])  # (bs, num_discourse, h)
        span_vectors = self.layer_norm(span_vectors)

        # attend to other features: pairwise mask that is 1 only where both spans are real
        broadcast_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pairwise_mask = broadcast_mask * broadcast_mask.squeeze(-2).unsqueeze(-1)
        pairwise_mask = pairwise_mask.byte()
        attended = self.fpe_span_attention(span_vectors, pairwise_mask)

        attended = self.dropout(attended)

        logits = self.classifier(attended)
        return logits[:, :, :3]  # main logits

## Inference

In [None]:

# Fold checkpoints (folds 0-3) for the EXP3 LSTM + multi-head attention model.
# NOTE(review): the fold 0 file name carries a "_v1" infix the other folds do
# not have — confirm it matches the file in the input dataset.
checkpoints = [
    "../input/a-delv3-prod-lstm-multihead-attention/lstm_multihead_v1_fpe_model_fold_0_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/lstm_multihead_fpe_model_fold_1_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/lstm_multihead_fpe_model_fold_2_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/lstm_multihead_fpe_model_fold_3_best.pth.tar",
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over the inference loader and write its span predictions to CSV.

    Softmax probabilities are computed for every span of every batch, rows
    with a negative span uid are dropped (presumably padding spans -- confirm
    against the collator), the remaining uids are mapped to discourse ids via
    the notebook-level ``idx2discourse`` mapping, and the result is saved as
    ``exp03_delv3_mlm20_model_preds_{model_id}.csv``.

    :param model: model whose call returns logits of shape (b, nd, 3)
    :param infer_dl: data loader yielding batches that contain a ``uids`` tensor
    :param model_id: index used in the output file name
    :return: None; the predictions are only written to disk
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    flat_preds = []  # one [Ineffective, Adequate, Effective] row per span
    flat_uids = []   # span uid for each row of flat_preds

    with torch.no_grad():
        for batch in tqdm(infer_dl, total=len(infer_dl)):
            logits = model(**batch)  # (b, nd, 3)
            batch_preds = F.softmax(logits, dim=-1)

            # Move the results to the CPU and flatten them right away so that
            # device memory is not held for every batch until the loop ends.
            flat_preds.extend(batch_preds.detach().cpu().reshape(-1, batch_preds.shape[-1]).tolist())
            flat_uids.extend(batch["uids"].detach().cpu().reshape(-1).tolist())

    preds_df = pd.DataFrame(flat_preds, columns=["Ineffective", "Adequate", "Effective"])
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    preds_df.to_csv(f"exp03_delv3_mlm20_model_preds_{model_id}.csv", index=False)
    
# EXP3 inference: for every fold checkpoint build a fresh FeedbackModel, load the
# saved weights, report the loss stored in the checkpoint and write that fold's
# predictions to CSV through inference_fn.
# `config`, `FeedbackModel` and `infer_dl` come from earlier notebook cells.
if use_exp3:
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        print(f"model performance on validation set = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    # drop the experiment's objects and release cached GPU memory before the next experiment
    del model
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Checkpoint files of the "resolved" model, one per fold (0-3).
# They are consumed by the `use_exp4` inference loop; this rebinds the
# `checkpoints` name used by the previous experiment.
# NOTE(review): the fold 2 file name carries a `_v1` infix that the other folds
# do not have -- confirm this matches the uploaded dataset.
checkpoints = [
    "../input/a-delv3-prod-lstm-multihead-attention/resolved_fpe_model_fold_0_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/resolved_fpe_model_fold_1_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/resolved_v1_fpe_model_fold_2_best.pth.tar",
    "../input/a-delv3-prod-lstm-multihead-attention/resolved_fpe_model_fold_3_best.pth.tar",
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over the inference loader and write its span predictions to CSV.

    Softmax probabilities are computed for every span of every batch, rows
    with a negative span uid are dropped (presumably padding spans -- confirm
    against the collator), the remaining uids are mapped to discourse ids via
    the notebook-level ``idx2discourse`` mapping, and the result is saved as
    ``exp04_delv3_mlm20_resolved_model_preds_{model_id}.csv``.

    :param model: model whose call returns logits of shape (b, nd, 3)
    :param infer_dl: data loader yielding batches that contain a ``uids`` tensor
    :param model_id: index used in the output file name
    :return: None; the predictions are only written to disk
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    flat_preds = []  # one [Ineffective, Adequate, Effective] row per span
    flat_uids = []   # span uid for each row of flat_preds

    with torch.no_grad():
        for batch in tqdm(infer_dl, total=len(infer_dl)):
            logits = model(**batch)  # (b, nd, 3)
            batch_preds = F.softmax(logits, dim=-1)

            # Move the results to the CPU and flatten them right away so that
            # device memory is not held for every batch until the loop ends.
            flat_preds.extend(batch_preds.detach().cpu().reshape(-1, batch_preds.shape[-1]).tolist())
            flat_uids.extend(batch["uids"].detach().cpu().reshape(-1).tolist())

    preds_df = pd.DataFrame(flat_preds, columns=["Ineffective", "Adequate", "Effective"])
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    preds_df.to_csv(f"exp04_delv3_mlm20_resolved_model_preds_{model_id}.csv", index=False)
    
# EXP4 inference: for every fold checkpoint build a fresh FeedbackModel, load the
# saved weights, report the loss stored in the checkpoint and write that fold's
# predictions to CSV through inference_fn.
# `config`, `FeedbackModel` and `infer_dl` come from earlier notebook cells.
if use_exp4:
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        print(f"model performance on validation set = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    # drop the experiment's objects and release cached GPU memory before the next experiment
    del model
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl

    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Average the per-fold EXP3 prediction files into a single frame `exp03_df`.
if use_exp3:
    import glob
    import pandas as pd

    # every CSV written by the exp03 inference_fn in the working directory
    csvs = glob.glob("exp03_delv3_mlm20_model_preds_*.csv")

    idx = []
    preds = []

    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # rows of different files are lined up purely by this sort
        # NOTE(review): assumes every file holds the same discourse ids with the
        # same multiplicity -- confirm
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            # in-place accumulation into the first file's array
            preds += temp_preds

    # mean over the fold files
    preds = preds / len(csvs)

    exp03_df = pd.DataFrame()
    exp03_df["discourse_id"]  = idx
    exp03_df["Ineffective"]   = preds[:, 0]
    exp03_df["Adequate"]      = preds[:, 1]
    exp03_df["Effective"]     = preds[:, 2]

    # collapse repeated discourse ids into one row by averaging their probabilities
    exp03_df = exp03_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# preview the averaged EXP3 predictions
if use_exp3:
    display(exp03_df.head())

In [None]:
# Average the per-fold EXP4 prediction files into a single frame `exp04_df`.
if use_exp4:
    import glob
    import pandas as pd

    # every CSV written by the exp04 inference_fn in the working directory
    csvs = glob.glob("exp04_delv3_mlm20_resolved_model_preds_*.csv")

    idx = []
    preds = []

    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # rows of different files are lined up purely by this sort
        # NOTE(review): assumes every file holds the same discourse ids with the
        # same multiplicity -- confirm
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            # in-place accumulation into the first file's array
            preds += temp_preds

    # mean over the fold files
    preds = preds / len(csvs)

    exp04_df = pd.DataFrame()
    exp04_df["discourse_id"]  = idx
    exp04_df["Ineffective"]   = preds[:, 0]
    exp04_df["Adequate"]      = preds[:, 1]
    exp04_df["Effective"]     = preds[:, 2]

    # collapse repeated discourse ids into one row by averaging their probabilities
    exp04_df = exp04_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# preview the averaged EXP4 predictions
if use_exp4:
    display(exp04_df.head())

# EXP1 - delv3  8 fold Fast Model - SPAN MLM 40%

## Config

In [None]:
# EXP1 settings, written as a JSON document and parsed into a dict.
#   base_model_path : directory the tokenizer and base transformer are loaded from
#   model_dir       : output directory (created before the inference dataset is built)
#   max_length      : window length used when tokenizing with truncation
#   stride          : overlap passed to the tokenizer for overflowing windows
#   num_labels      : width of the classifier output layer
#   dropout         : dropout probability of the model head
#   infer_bs        : batch size of the inference data loader
config = """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-delv3-span-mlm-04",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
config = json.loads(config)

## Dataset

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer stored at ``config["base_model_path"]``.

    Prints the vocabulary size and the tokenization of two probe strings
    containing the custom essay / discourse marker tokens, then returns the
    tokenizer.
    """
    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    separator = "==" * 40
    probe_texts = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION]',
        'Starts: [EOE] [LEAD_END] [POSITION_END]',
    )

    print(separator)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probe_texts:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")
    print(separator)

    return tokenizer


#--------------- Processing ---------------------------------------------#
# discourse type -> [opening text, closing text] that process_essay wraps around
# the discourse span; the opening text is the type name followed by its start marker
TOKEN_MAP = {
    "Lead": ["Lead [LEAD]", "[LEAD_END]"],
    "Position": ["Position [POSITION]", "[POSITION_END]"],
    "Claim": ["Claim [CLAIM]", "[CLAIM_END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM_END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL_END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE_END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT_END]"]
}

# start markers; their token ids identify span heads in the tokenized essay
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]"
]

# end markers; their token ids identify span tails in the tokenized essay
DISCOURSE_END_TOKENS = [
    "[LEAD_END]",
    "[POSITION_END]",
    "[CLAIM_END]",
    "[COUNTER_CLAIM_END]",
    "[REBUTTAL_END]",
    "[EVIDENCE_END]",
    "[CONCLUDING_STATEMENT_END]"
]


def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, falling back to ever shorter prefixes.

    If the full substring is not found, it is repeatedly cut down to its
    first ``int(len * fraction)`` characters until a prefix is found or the
    prefix becomes shorter than ``min_length``.

    The search is iterative: with the default ``fraction`` only one character
    is dropped per step, so a recursive formulation exceeds Python's recursion
    limit for long unmatched substrings. Each step is also forced to shrink by
    at least one character so that ``fraction >= 1`` cannot loop forever.

    :param text: text to search in
    :param substring: text to look for
    :param min_length: shortest prefix that is still searched for
    :param fraction: share of the current prefix kept after a failed search
    :return: ``[start, end]`` of the first occurrence of the matched prefix,
        or ``[-1, 0]`` when no prefix of at least ``min_length`` characters matches
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        shorter = candidate[:int(len(candidate) * fraction)]
        if len(shorter) >= len(candidate):
            # fraction did not shorten the prefix; drop one character instead
            shorter = candidate[:-1]

        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Map every discourse text to the character spans it occupies in the essay.

    The discourses are searched in list order; each one takes the first exact
    match that ends after the previous match (tracked by a reading head), so
    a repeated text is assigned successive occurrences.

    Every occurrence in ``discourse_list`` gets exactly one span. An
    occurrence that could not be matched exactly receives a span from
    ``relaxed_search``. This also covers a repeated text that appears more
    often in ``discourse_list`` than it can be located, which previously left
    the list short and made the later ``pop(0)`` in ``get_substring_span``
    fail with ``IndexError``.

    :param discourse_list: discourse texts of one essay
    :param essay_text: full essay text
    :return: dict from discourse text to a list of spans, one per occurrence
        (exact matches are ``(start, end)`` tuples, fallbacks are the
        ``[start, end]`` lists returned by ``relaxed_search``)
    """
    reading_head = 0
    span_map = {}
    occurrences = {}  # how many spans each discourse text needs

    for cur_discourse in discourse_list:
        found = span_map.setdefault(cur_discourse, [])
        occurrences[cur_discourse] = occurrences.get(cur_discourse, 0) + 1

        pattern = re.escape(r'{}'.format(cur_discourse))
        for match in re.finditer(pattern, essay_text):
            span_end = match.span()[1]
            if span_end <= reading_head:
                continue  # lies inside the part of the essay already consumed
            found.append(match.span())
            reading_head = span_end
            break

    # post process: one fallback span for every occurrence without an exact match
    for cur_discourse, found in span_map.items():
        for _ in range(occurrences[cur_discourse] - len(found)):
            print("resorting to relaxed search...")
            found.append(relaxed_search(essay_text, cur_discourse))
    return span_map


def get_substring_span(texts, mapping):
    """Return one span per text, consuming the spans stored in ``mapping``.

    For each text, in order, the first remaining span of ``mapping[text]`` is
    removed from the mapping and returned, so repeated texts receive
    successive spans. ``mapping`` is modified in place.
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, anno_df):
    """Wrap every annotated discourse of one essay in its marker tokens.

    The annotations of ``essay_id`` are visited in order of ``discourse_start``;
    each span is surrounded by the opening / closing text from ``TOKEN_MAP``,
    and the whole essay is finally enclosed in ``[SOE]`` ... ``[EOE]``.
    """
    essay_rows = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_rows = essay_rows.sort_values(by="discourse_start")

    shift = 0  # characters inserted so far; later offsets move by this amount
    columns = zip(
        essay_rows["discourse_start"],
        essay_rows["discourse_end"],
        essay_rows["discourse_type"],
    )
    for raw_start, raw_end, d_type in columns:
        start = int(raw_start) + shift
        end = int(raw_end) + shift
        open_text, close_text = TOKEN_MAP[d_type]
        pieces = [essay_text[:start], open_text, essay_text[start:end], close_text, essay_text[end:]]
        essay_text = " ".join(pieces)
        # two markers plus the four separating spaces added by the join
        shift += len(open_text) + len(close_text) + 4

    return "[SOE]" + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """Build one row per essay with marker-annotated text and per-discourse lists.

    Steps: strip the discourse texts, locate each discourse in its essay,
    insert the marker tokens into the essay text at those spans, and group
    the annotations by essay. The inputs are deep-copied and not modified.

    :param anno_df: annotation dataframe; the columns ``discourse_id``,
        ``essay_id``, ``discourse_text``, ``discourse_type`` and ``uid`` are
        used, plus ``discourse_effectiveness`` when present
    :type anno_df: pd.DataFrame
    :param notes_df: essay dataframe with ``essay_id`` and ``essay_text``
    :type notes_df: pd.DataFrame
    :return: dataframe with one row per essay holding ``essay_id``, the lists
        ``uids`` and ``discourse_type`` (and ``discourse_effectiveness`` when
        available) and the processed ``essay_text``
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # anno_df["discourse_span"] = anno_df[["essay_id", "discourse_text"]].apply(
    #     lambda x: get_substring_span(
    #         notes_df[notes_df["essay_id"] == x[0]].iloc[0].essay_text,
    #         x[1]
    #     ), axis=1
    # )

    # anno_df["discourse_start"] = anno_df["discourse_span"].apply(lambda x: x[0])
    # anno_df["discourse_end"] = anno_df["discourse_span"].apply(lambda x: x[1])

    # one row per essay carrying the lists of its discourse ids and texts plus the essay text
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
#     set_trace()
    # span_map: discourse text -> spans; span: one span per discourse, in list order
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda x: build_span_map(x[0], x[1]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda x: get_substring_span(x[0], x[1]), axis=1)

    # flatten the per-essay lists back into one row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")
    # anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text"]].apply(
        lambda x: process_essay(x[0], x[1], anno_df), axis=1
    )
    print("=="*40)

    # the character offsets refer to the raw text and are not needed once the markers are inserted
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    # collect the per-discourse fields of every essay into lists
    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Builds the (untruncated) tokenized dataset for the feedback prize effectiveness task.
    """

    def __init__(self, config):
        self.config = config

        # effectiveness label <-> class index
        label_names = ["Ineffective", "Adequate", "Effective"]
        self.label2id = {name: position for position, name in enumerate(label_names)}

        # discourse type -> id, numbered from 1
        discourse_names = [
            "Lead",
            "Position",
            "Claim",
            "Counterclaim",
            "Rebuttal",
            "Evidence",
            "Concluding Statement",
        ]
        self.discourse_type2id = {name: position for position, name in enumerate(discourse_names, start=1)}

        self.id2label = {class_id: name for name, class_id in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer named in the config and cache the marker token ids.
        """
        self.tokenizer = get_tokenizer(self.config)
        banner = "==" * 40
        print(banner)
        print("token maps...")
        print(TOKEN_MAP)
        print(banner)

        to_ids = self.tokenizer.convert_tokens_to_ids
        self.discourse_token_ids = set(to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(to_ids(DISCOURSE_END_TOKENS))

    def tokenize_function(self, examples):
        """Tokenize the essay texts without padding or truncation, keeping offsets."""
        return self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )

    def process_spans(self, examples):
        """Find the token positions and character offsets of all span markers."""
        result = {
            "span_head_idxs": [],
            "span_tail_idxs": [],
            "span_head_char_start_idxs": [],
            "span_tail_char_end_idxs": [],
        }

        rows = zip(examples["input_ids"], examples["offset_mapping"], examples["uids"])
        for token_ids, offsets, _ in rows:
            heads, tails = [], []
            for position, token_id in enumerate(token_ids):
                if token_id in self.discourse_token_ids:
                    heads.append(position)
            for position, token_id in enumerate(token_ids):
                if token_id in self.discourse_end_ids:
                    tails.append(position)

            result["span_head_idxs"].append(heads)
            result["span_tail_idxs"].append(tails)
            # character where each start marker begins / each end marker ends
            result["span_head_char_start_idxs"].append([offsets[position][0] for position in heads])
            result["span_tail_char_end_idxs"].append([offsets[position][1] for position in tails])

        return result

    def generate_labels(self, examples):
        """Convert the effectiveness names of every example into class indices."""
        encoded = []
        for names, _ in zip(examples["discourse_effectiveness"], examples["uids"]):
            encoded.append([self.label2id[name] for name in names])
        return {"labels": encoded}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of every example into type ids."""
        encoded = [
            [self.discourse_type2id[name] for name in type_names]
            for type_names in examples["discourse_type"]
        ]
        return {"discourse_type_ids": encoded}

    def compute_input_length(self, examples):
        """Number of tokens of every example."""
        return {"input_length": list(map(len, examples["input_ids"]))}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up and every span holds at least one token."""
        for heads, tails in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(heads) == len(tails)
            assert all(tail > head + 1 for head, tail in zip(heads, tails))

    def sanity_check_head_labels(self, examples):
        """Assert that every span head has a label."""
        for heads, labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(heads) == len(labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: 'infer' skips label generation, defaults to 'train'
        :type mode: str, optional
        :return: the processed dataframe and the created dataset
        :rtype: tuple
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_path = os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv")
        df.sample(min(16, len(df))).to_csv(sample_path, index=False)

        task_dataset = Dataset.from_pandas(df)
        for step in (self.tokenize_function, self.compute_input_length, self.process_spans):
            task_dataset = task_dataset.map(step, batched=True)
        print(task_dataset)

        # todo check edge cases
        # keep only the examples whose start and end markers pair up
        task_dataset = task_dataset.filter(
            lambda example: len(example['span_head_idxs']) == len(example['span_tail_idxs']))
        print(task_dataset)

        follow_up_steps = [self.generate_discourse_type_ids, self.sanity_check_head_tail]
        if mode != "infer":
            follow_up_steps += [self.generate_labels, self.sanity_check_head_labels]
        for step in follow_up_steps:
            task_dataset = task_dataset.map(step, batched=True)

        try:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        except Exception:
            # best effort: the helper column is simply not removed when this fails
            pass
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    The essays are first tokenized without truncation through
    AuxFeedbackDataset; that dataset is kept as the reference. They are then
    re-tokenized with ``config["max_length"]`` / ``config["stride"]`` so long
    essays become several overlapping windows. For every window the span
    heads and tails are recomputed, and uids, discourse type ids and (unless
    ``mode == "infer"``) labels are looked up in the reference dataset by the
    character offset of each span head.

    :param config: dict providing at least ``max_length`` and ``stride`` plus
        whatever AuxFeedbackDataset reads
    :param df: annotation dataframe
    :param essay_df: dataframe with the essay texts
    :param mode: ``"infer"`` skips the label columns
    :return: dict with the windowed ``dataset``, the untruncated
        ``original_dataset`` and the ``tokenizer``
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # untruncated reference used below to recover uids / labels / type ids per window
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids

    def tokenize_with_truncation(examples):
        # truncating tokenization that also returns the overflowing windows
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
        )
        return tz

    def process_span(examples):
        # recompute span heads / tails per window and repair truncation artefacts
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                # full-length window: ignore heads that start inside the trailing buffer
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def enforce_alignment(examples):
        # uids of the spans kept in each window, found via the head's character offset
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            # row of the untruncated dataset this window was cut from
            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    def recompute_labels(examples):
        # labels of the spans kept in each window, found via the head's character offset
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    def recompute_discourse_type_ids(examples):
        # discourse type ids of the spans kept in each window, found via the head's character offset
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(examples):
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        # every head needs a tail and at least one token between them
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # The old columns are dropped and the whole dataset is tokenized as a single batch.
    # NOTE(review): presumably so that overflow_to_sample_mapping indexes rows of
    # original_dataset rather than positions inside a sub-batch -- confirm
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
# Build the windowed inference dataset once for the experiments that use this
# tokenizer / config (exp1, exp10, exp11, exp16).
# `test_df` and `essay_df` come from earlier notebook cells.
if use_exp1 or use_exp10 or use_exp11 or use_exp16:

    # get_dataset writes a sample CSV into this directory
    os.makedirs(config["model_dir"], exist_ok=True)

    print("creating the inference datasets...")
    infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
    tokenizer = infer_ds_dict["tokenizer"]
    infer_dataset = infer_ds_dict["dataset"]
    print(infer_dataset)

In [None]:

# Record the tokenizer size in the config, order the windows by token count and
# expose only the columns the collator / model read.
# NOTE(review): unlike the dataset-creation cell this one is not gated on
# use_exp16 -- confirm exp16 formats its dataset elsewhere.
if use_exp1 or use_exp10 or use_exp11:
    config["len_tokenizer"] = len(tokenizer)
    infer_dataset = infer_dataset.sort("input_length")
    infer_dataset.set_format(
        type=None,
        columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
                 'span_tail_idxs', 'discourse_type_ids', 'uids']
    )

## Data Loader

In [None]:
from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator for span-level sequence classification.

    Token fields are padded by the tokenizer; the per-span fields are padded
    here to the largest number of spans in the batch:

    - ``span_head_idxs`` / ``span_tail_idxs`` with positions near the end of
      the padded sequence,
    - ``uids`` and ``labels`` with -1, ``discourse_type_ids`` with 0,
    - ``span_attention_mask`` holds 1 for real spans and 0 for padded ones.

    When labels are present, ``multitask_labels`` adds a two-element target
    per span. Everything is returned as int64 tensors except
    ``multitask_labels``, which is float32.
    """

    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = None
    return_tensors = "pt"

    @staticmethod
    def _get_additional_labels(label_id):
        """Return the two-element auxiliary target for one label id.

        0 -> [0, 0], 1 -> [1, 0], 2 -> [1, 1]; the padding label -1 maps to
        [-1, -1].

        :raises ValueError: for any other label id (a bare ``raise`` here
            would only produce an unrelated "no active exception" RuntimeError)
        """
        if label_id == 0:
            return [0, 0]
        if label_id == 1:
            return [1, 0]
        if label_id == 2:
            return [1, 1]
        if label_id == -1:
            return [-1, -1]
        raise ValueError(f"unexpected label id: {label_id!r}")

    def __call__(self, features):
        """Collate a list of dataset rows into a dict of padded tensors."""
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1] * len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        # pads the token-level fields; kept as python lists until the final conversion
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(heads) for heads in span_head_idxs)  # most spans in any example
        max_len = len(batch["input_ids"][0])  # padded sequence length

        # placeholder span used for padded entries
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        def pad_spans(rows, fill):
            # right-pad every per-span list to b_max entries
            return [row + [fill] * (b_max - len(row)) for row in rows]

        batch["span_head_idxs"] = pad_spans(span_head_idxs, default_head_idx)
        batch["uids"] = pad_spans(uids, -1)
        batch["discourse_type_ids"] = pad_spans(discourse_type_ids, 0)
        batch["span_tail_idxs"] = pad_spans(span_tail_idxs, default_tail_idx)
        batch["span_attention_mask"] = pad_spans(span_attention_mask, 0)

        if labels is not None:
            batch["labels"] = pad_spans(labels, -1)
            batch["multitask_labels"] = [
                [self._get_additional_labels(label_id) for label_id in ex_labels]
                for ex_labels in batch["labels"]
            ]

        # batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# Inference data loader: fixed order (no shuffling), batches assembled by the
# custom collator defined above.
if use_exp1 or use_exp10 or use_exp11:
    data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

    infer_dl = DataLoader(
        infer_dataset,
        batch_size=config["infer_bs"],
        shuffle=False,
        collate_fn=data_collector
    )

## Model

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Token states from a pretrained transformer are passed through an LSTM,
    mean-pooled per span, layer-normalised, mixed across spans with an
    attention layer and finally projected to ``config["num_labels"]`` logits,
    of which only the first three are returned.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Keys read from ``config``: ``"base_model_path"``, ``"dropout"`` and
        ``"num_labels"``.
        """
        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # backbone; its config is patched before the weights are loaded
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        self.dropout = StableDropout(self.config["dropout"])

        # span-to-span attention, built from a copy of the backbone config
        # with relative attention switched off
        span_attn_config = deepcopy(self.base_model.config)
        span_attn_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attn_config)

        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # single-layer, unidirectional LSTM that keeps the feature width
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width,
            num_layers=1,
            batch_first=True,
            bidirectional=False,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three logits for every span.

        Args:
            input_ids, attention_mask: forwarded to the base transformer.
            span_head_idxs, span_tail_idxs: per-example span boundaries; the
                tokens strictly between head and tail are averaged.
            span_attention_mask: per-example span mask (assumed to be 1 for
                real spans and 0 for padding -- confirm against the collator).
            **kwargs: extra batch entries, ignored.
        """
        n_examples = input_ids.shape[0]
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one mean vector per span -> (bs, num_discourse, h)
        pooled = torch.stack([
            torch.stack([
                token_states[ex, start + 1:end].mean(dim=0)
                for start, end in zip(span_head_idxs[ex], span_tail_idxs[ex])
            ])
            for ex in range(n_examples)
        ])
        pooled = self.layer_norm(pooled)

        # pairwise mask: outer product of the span mask with itself
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pair_mask = (key_mask * key_mask.squeeze(-2).unsqueeze(-1)).byte()

        attended = self.dropout(self.fpe_span_attention(pooled, pair_mask))

        # only the first three classifier outputs are the main logits
        return self.classifier(attended)[:, :, :3]

In [None]:
def process_swa_checkpoint(checkpoint_path):
    """Load an SWA checkpoint and return a state dict with plain parameter names.

    The ``"n_averaged"`` entry is dropped and a leading ``"module."`` is
    removed from every remaining key. Keys that do not carry the prefix are
    kept unchanged instead of being truncated.

    Args:
        checkpoint_path: file readable by ``torch.load`` that holds a dict
            with a ``"state_dict"`` entry.

    Returns:
        ``{"state_dict": OrderedDict}`` with the renamed entries.
    """
    # map_location="cpu": do not restore tensors onto the device they were
    # saved from; load_state_dict copies them into the model afterwards.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    print("processing ckpt...")
    print("removing module from keys...")
    state_dict = ckpt['state_dict']
    new_state_dict = OrderedDict()

    prefix = "module."
    for k, v in state_dict.items():
        if k == "n_averaged":
            continue
        # strip the prefix only when it is really there
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    processed_state = {"state_dict": new_state_dict}

    # drop the references to the original containers
    del state_dict, ckpt
    gc.collect()

    return processed_state

In [None]:
# best checkpoint of each of the 8 folds (fold 0 .. fold 7)
checkpoints = [
    f"../input/a-delv3-prod-8-folds/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(8)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Predict span-level class probabilities and dump them to a CSV file.

    The model and dataloader are wrapped with a fresh ``Accelerator``, every
    batch is scored without gradients and the softmax of the logits is kept
    per span. Rows whose ``uids`` entry is negative are dropped (presumably
    padded span slots -- confirm against the collator), the remaining uids
    are translated through the module-level ``idx2discourse`` mapping and
    the table is written to ``exp01_delv3_8folds_model_preds_{model_id}.csv``.

    Args:
        model: network called as ``model(**batch)``; returns logits.
        infer_dl: dataloader whose batches contain a ``"uids"`` tensor.
        model_id: value interpolated into the output file name.

    Returns:
        None.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            prob_batches.append(F.softmax(model(**batch), dim=-1))  # (b, nd, 3)
        uid_batches.append(batch["uids"])

    # drop the batch and example nesting so that every span is one row
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example
    ]
    flat_uids = [
        uid
        for uids in uid_batches
        for example in uids.to('cpu').detach().numpy().tolist()
        for uid in example
    ]

    class_cols = ["Ineffective", "Adequate", "Effective"]
    frame = pd.DataFrame(flat_preds)
    frame.columns = class_cols
    frame["span_uid"] = flat_uids
    frame = frame.loc[frame["span_uid"] >= 0].copy()
    frame["discourse_id"] = frame["span_uid"].map(idx2discourse)
    frame[["discourse_id"] + class_cols].to_csv(
        f"exp01_delv3_8folds_model_preds_{model_id}.csv", index=False
    )
    
if use_exp1:
    # Score the inference set with every exp1 checkpoint; each call to
    # inference_fn writes one prediction CSV.
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        if "swa" in checkpoint:
            ckpt = process_swa_checkpoint(checkpoint)
        else:
            # map_location="cpu": without it the tensors are restored to the
            # device they were saved from, so a checkpoint saved from a GPU
            # would occupy GPU memory for the whole inference pass.
            ckpt = torch.load(checkpoint, map_location="cpu")
            print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    del model
    # release cached GPU blocks as the other experiment loops do
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
if use_exp1:
    import glob
    import pandas as pd

    # Average the prediction files of all exp1 checkpoints.
    csvs = glob.glob("exp01_delv3_8folds_model_preds_*.csv")

    idx, preds = [], []

    for csv_idx, csv in enumerate(csvs):
        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv).sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(columns=["discourse_id"]).values
        if csv_idx:
            # files are summed row by row, relying on the sort above
            preds += temp_preds
        else:
            idx = list(df["discourse_id"])
            preds = temp_preds

    preds = preds / len(csvs)

    exp01_df = pd.DataFrame({
        "discourse_id": idx,
        "Ineffective": preds[:, 0],
        "Adequate": preds[:, 1],
        "Effective": preds[:, 2],
    })

    # collapse repeated discourse ids into a single averaged row
    exp01_df = (
        exp01_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]]
        .agg(np.mean)
        .reset_index()
    )

In [None]:
if use_exp1:
    # An expression nested inside `if` is not auto-rendered by the notebook,
    # so the preview has to be displayed explicitly.
    display(exp01_df.head())

# Exp 10: Revisited DELV3 8 Folds

In [None]:
# Settings for the exp10 models (classifier with 5 outputs).
config = {
    "debug": False,

    "base_model_path": "../input/tapt-fpe-delv3-span-mlm-04",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 5,
    "dropout": 0.1,
    "infer_bs": 8,
}

import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Token states from a pretrained transformer are passed through a
    bidirectional LSTM, mean-pooled per span, layer-normalised, mixed across
    spans with an attention layer and finally projected to
    ``config["num_labels"]`` logits, of which only the first three are
    returned.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Keys read from ``config``: ``"base_model_path"``, ``"dropout"`` and
        ``"num_labels"``.
        """
        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # backbone; its config is patched before the weights are loaded
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        self.dropout = StableDropout(self.config["dropout"])

        # span-to-span attention, built from a copy of the backbone config
        # with relative attention switched off
        span_attn_config = deepcopy(self.base_model.config)
        span_attn_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attn_config)

        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # bidirectional LSTM: two directions of width // 2 each
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three logits for every span.

        Args:
            input_ids, attention_mask: forwarded to the base transformer.
            span_head_idxs, span_tail_idxs: per-example span boundaries; the
                tokens strictly between head and tail are averaged.
            span_attention_mask: per-example span mask (assumed to be 1 for
                real spans and 0 for padding -- confirm against the collator).
            **kwargs: extra batch entries, ignored.
        """
        n_examples = input_ids.shape[0]
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one mean vector per span -> (bs, num_discourse, h)
        pooled = torch.stack([
            torch.stack([
                token_states[ex, start + 1:end].mean(dim=0)
                for start, end in zip(span_head_idxs[ex], span_tail_idxs[ex])
            ])
            for ex in range(n_examples)
        ])
        pooled = self.layer_norm(pooled)

        # pairwise mask: outer product of the span mask with itself
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pair_mask = (key_mask * key_mask.squeeze(-2).unsqueeze(-1)).byte()

        attended = self.dropout(self.fpe_span_attention(pooled, pair_mask))

        # only the first three classifier outputs are the main logits
        return self.classifier(attended)[:, :, :3]

In [None]:
# best checkpoint of each of the 8 revisited folds (fold 0 .. fold 7)
checkpoints = [
    f"../input/exp-10-delv3-revisited-8-folds/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(8)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Predict span-level class probabilities and dump them to a CSV file.

    The model and dataloader are wrapped with a fresh ``Accelerator``, every
    batch is scored without gradients and the softmax of the logits is kept
    per span. Rows whose ``uids`` entry is negative are dropped (presumably
    padded span slots -- confirm against the collator), the remaining uids
    are translated through the module-level ``idx2discourse`` mapping and
    the table is written to
    ``exp10_delv3_8folds_revisited_preds_{model_id}.csv``.

    Args:
        model: network called as ``model(**batch)``; returns logits.
        infer_dl: dataloader whose batches contain a ``"uids"`` tensor.
        model_id: value interpolated into the output file name.

    Returns:
        None.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            prob_batches.append(F.softmax(model(**batch), dim=-1))  # (b, nd, 3)
        uid_batches.append(batch["uids"])

    # drop the batch and example nesting so that every span is one row
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example
    ]
    flat_uids = [
        uid
        for uids in uid_batches
        for example in uids.to('cpu').detach().numpy().tolist()
        for uid in example
    ]

    class_cols = ["Ineffective", "Adequate", "Effective"]
    frame = pd.DataFrame(flat_preds)
    frame.columns = class_cols
    frame["span_uid"] = flat_uids
    frame = frame.loc[frame["span_uid"] >= 0].copy()
    frame["discourse_id"] = frame["span_uid"].map(idx2discourse)
    frame[["discourse_id"] + class_cols].to_csv(
        f"exp10_delv3_8folds_revisited_preds_{model_id}.csv", index=False
    )
    

if use_exp10:
    # Score the inference set with every exp10 checkpoint; each call to
    # inference_fn writes one prediction CSV.
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        if "swa" in checkpoint:
            ckpt = process_swa_checkpoint(checkpoint)
        else:
            # map_location="cpu": without it the tensors are restored to the
            # device they were saved from, so a checkpoint saved from a GPU
            # would occupy GPU memory for the whole inference pass.
            ckpt = torch.load(checkpoint, map_location="cpu")
            print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    del model
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
if use_exp10:
    import glob
    import pandas as pd

    # Average the prediction files of all exp10 checkpoints.
    csvs = glob.glob("exp10_delv3_8folds_revisited_preds_*.csv")

    idx, preds = [], []

    for csv_idx, csv in enumerate(csvs):
        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv).sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(columns=["discourse_id"]).values
        if csv_idx:
            # files are summed row by row, relying on the sort above
            preds += temp_preds
        else:
            idx = list(df["discourse_id"])
            preds = temp_preds

    preds = preds / len(csvs)

    exp10_df = pd.DataFrame({
        "discourse_id": idx,
        "Ineffective": preds[:, 0],
        "Adequate": preds[:, 1],
        "Effective": preds[:, 2],
    })

    # collapse repeated discourse ids into a single averaged row
    exp10_df = (
        exp10_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]]
        .agg(np.mean)
        .reset_index()
    )

In [None]:
if use_exp10:
    # An expression nested inside `if` is not auto-rendered by the notebook,
    # so the preview has to be displayed explicitly.
    display(exp10_df.head())

# Exp 11 - UDA

In [None]:
# Settings for the exp11 (UDA) models (classifier with 3 outputs).
config = {
    "debug": False,

    "base_model_path": "../input/tapt-fpe-delv3-span-mlm-04",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8,
}

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Token states from a pretrained transformer are passed through a
    bidirectional LSTM, mean-pooled per span, layer-normalised, mixed across
    spans with an attention layer and finally projected to
    ``config["num_labels"]`` logits, of which only the first three are
    returned.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Keys read from ``config``: ``"base_model_path"``, ``"dropout"`` and
        ``"num_labels"``.
        """
        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # backbone; its config is patched before the weights are loaded
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        self.dropout = StableDropout(self.config["dropout"])

        # span-to-span attention, built from a copy of the backbone config
        # with relative attention switched off
        span_attn_config = deepcopy(self.base_model.config)
        span_attn_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attn_config)

        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # bidirectional LSTM: two directions of width // 2 each
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three logits for every span.

        Args:
            input_ids, attention_mask: forwarded to the base transformer.
            span_head_idxs, span_tail_idxs: per-example span boundaries; the
                tokens strictly between head and tail are averaged.
            span_attention_mask: per-example span mask (assumed to be 1 for
                real spans and 0 for padding -- confirm against the collator).
            **kwargs: extra batch entries, ignored.
        """
        n_examples = input_ids.shape[0]
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one mean vector per span -> (bs, num_discourse, h)
        pooled = torch.stack([
            torch.stack([
                token_states[ex, start + 1:end].mean(dim=0)
                for start, end in zip(span_head_idxs[ex], span_tail_idxs[ex])
            ])
            for ex in range(n_examples)
        ])
        pooled = self.layer_norm(pooled)

        # pairwise mask: outer product of the span mask with itself
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pair_mask = (key_mask * key_mask.squeeze(-2).unsqueeze(-1)).byte()

        attended = self.dropout(self.fpe_span_attention(pooled, pair_mask))

        # only the first three classifier outputs are the main logits
        return self.classifier(attended)[:, :, :3]

In [None]:
def process_swa_checkpoint(checkpoint_path):
    """Load an SWA checkpoint and return a state dict with plain parameter names.

    The ``"n_averaged"`` entry is reported and dropped, and a leading
    ``"module."`` is removed from every remaining key. Keys that do not carry
    the prefix are kept unchanged instead of being truncated.

    Args:
        checkpoint_path: file readable by ``torch.load`` that holds a dict
            with a ``"state_dict"`` entry.

    Returns:
        ``{"state_dict": OrderedDict}`` with the renamed entries.
    """
    # map_location="cpu": do not restore tensors onto the device they were
    # saved from; load_state_dict copies them into the model afterwards.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    print("processing ckpt...")
    print("removing module from keys...")
    state_dict = ckpt['state_dict']
    new_state_dict = OrderedDict()

    prefix = "module."
    for k, v in state_dict.items():
        if k == "n_averaged":
            print(f"# of snapshots in {checkpoint_path} = {v}")
            continue
        # strip the prefix only when it is really there
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    processed_state = {"state_dict": new_state_dict}

    # drop the references to the original containers
    del state_dict, ckpt
    gc.collect()

    return processed_state

In [None]:
# best checkpoint of each of the 4 UDA folds (fold 0 .. fold 3)
checkpoints = [
    f"../input/exp11-delv3-uda-4-folds/exp-11-delv3-uda-4-folds/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(4)
]

def inference_fn(model, infer_dl, model_id):
    """Predict span-level class probabilities and dump them to a CSV file.

    The model and dataloader are wrapped with a fresh ``Accelerator``, every
    batch is scored without gradients and the softmax of the logits is kept
    per span. Rows whose ``uids`` entry is negative are dropped (presumably
    padded span slots -- confirm against the collator), the remaining uids
    are translated through the module-level ``idx2discourse`` mapping and
    the table is written to ``exp11_preds_{model_id}.csv``.

    Args:
        model: network called as ``model(**batch)``; returns logits.
        infer_dl: dataloader whose batches contain a ``"uids"`` tensor.
        model_id: value interpolated into the output file name.

    Returns:
        None.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            prob_batches.append(F.softmax(model(**batch), dim=-1))  # (b, nd, 3)
        uid_batches.append(batch["uids"])

    # drop the batch and example nesting so that every span is one row
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example
    ]
    flat_uids = [
        uid
        for uids in uid_batches
        for example in uids.to('cpu').detach().numpy().tolist()
        for uid in example
    ]

    class_cols = ["Ineffective", "Adequate", "Effective"]
    frame = pd.DataFrame(flat_preds)
    frame.columns = class_cols
    frame["span_uid"] = flat_uids
    frame = frame.loc[frame["span_uid"] >= 0].copy()
    frame["discourse_id"] = frame["span_uid"].map(idx2discourse)
    frame[["discourse_id"] + class_cols].to_csv(
        f"exp11_preds_{model_id}.csv", index=False
    )
    
if use_exp11:
    # Score the inference set with every exp11 checkpoint; each call to
    # inference_fn writes one prediction CSV.
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        if "swa" in checkpoint:
            ckpt = process_swa_checkpoint(checkpoint)
        else:
            # map_location="cpu": without it the tensors are restored to the
            # device they were saved from, so a checkpoint saved from a GPU
            # would occupy GPU memory for the whole inference pass.
            ckpt = torch.load(checkpoint, map_location="cpu")
            print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    del model
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
if use_exp11:
    import glob
    import pandas as pd

    # Average the prediction files of all exp11 checkpoints.
    csvs = glob.glob("exp11_preds_*.csv")

    idx, preds = [], []

    for csv_idx, csv in enumerate(csvs):
        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv).sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(columns=["discourse_id"]).values
        if csv_idx:
            # files are summed row by row, relying on the sort above
            preds += temp_preds
        else:
            idx = list(df["discourse_id"])
            preds = temp_preds

    preds = preds / len(csvs)

    exp11_df = pd.DataFrame({
        "discourse_id": idx,
        "Ineffective": preds[:, 0],
        "Adequate": preds[:, 1],
        "Effective": preds[:, 2],
    })

    # collapse repeated discourse ids into a single averaged row
    exp11_df = (
        exp11_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]]
        .agg(np.mean)
        .reset_index()
    )
    # this experiment also writes its averaged table as a submission file
    exp11_df.to_csv("submission.csv", index=False)

In [None]:
if use_exp11:
    # An expression nested inside `if` is not auto-rendered by the notebook,
    # so the preview has to be displayed explicitly.
    display(exp11_df.head())

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Token states from a pretrained transformer are passed through a
    bidirectional LSTM, mean-pooled per span, layer-normalised, mixed across
    spans with an attention layer and finally projected to
    ``config["num_labels"]`` logits, of which only the first three are
    returned.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Keys read from ``config``: ``"base_model_path"``, ``"dropout"`` and
        ``"num_labels"``.
        """
        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # backbone; its config is patched before the weights are loaded
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        self.dropout = StableDropout(self.config["dropout"])

        # span-to-span attention, built from a copy of the backbone config
        # with relative attention switched off
        span_attn_config = deepcopy(self.base_model.config)
        span_attn_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attn_config)

        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # bidirectional LSTM: two directions of width // 2 each
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three logits for every span.

        Args:
            input_ids, attention_mask: forwarded to the base transformer.
            span_head_idxs, span_tail_idxs: per-example span boundaries; the
                tokens strictly between head and tail are averaged.
            span_attention_mask: per-example span mask (assumed to be 1 for
                real spans and 0 for padding -- confirm against the collator).
            **kwargs: extra batch entries, ignored.
        """
        n_examples = input_ids.shape[0]
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one mean vector per span -> (bs, num_discourse, h)
        pooled = torch.stack([
            torch.stack([
                token_states[ex, start + 1:end].mean(dim=0)
                for start, end in zip(span_head_idxs[ex], span_tail_idxs[ex])
            ])
            for ex in range(n_examples)
        ])
        pooled = self.layer_norm(pooled)

        # pairwise mask: outer product of the span mask with itself
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pair_mask = (key_mask * key_mask.squeeze(-2).unsqueeze(-1)).byte()

        attended = self.dropout(self.fpe_span_attention(pooled, pair_mask))

        # only the first three classifier outputs are the main logits
        return self.classifier(attended)[:, :, :3]

In [None]:
# 10 fold checkpoints split over two datasets: folds 0-4 live in part 1,
# folds 5-9 in part 2
checkpoints = [
    f"../input/exp-16-part-{1 if fold < 5 else 2}/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(10)
]

def inference_fn(model, infer_dl, model_id):
    """Predict span-level class probabilities and dump them to a CSV file.

    The model and dataloader are wrapped with a fresh ``Accelerator``, every
    batch is scored without gradients and the softmax of the logits is kept
    per span. Rows whose ``uids`` entry is negative are dropped (presumably
    padded span slots -- confirm against the collator), the remaining uids
    are translated through the module-level ``idx2discourse`` mapping and
    the table is written to ``exp16_preds_{model_id}.csv``.

    Args:
        model: network called as ``model(**batch)``; returns logits.
        infer_dl: dataloader whose batches contain a ``"uids"`` tensor.
        model_id: value interpolated into the output file name.

    Returns:
        None.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            prob_batches.append(F.softmax(model(**batch), dim=-1))  # (b, nd, 3)
        uid_batches.append(batch["uids"])

    # drop the batch and example nesting so that every span is one row
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example
    ]
    flat_uids = [
        uid
        for uids in uid_batches
        for example in uids.to('cpu').detach().numpy().tolist()
        for uid in example
    ]

    class_cols = ["Ineffective", "Adequate", "Effective"]
    frame = pd.DataFrame(flat_preds)
    frame.columns = class_cols
    frame["span_uid"] = flat_uids
    frame = frame.loc[frame["span_uid"] >= 0].copy()
    frame["discourse_id"] = frame["span_uid"].map(idx2discourse)
    frame[["discourse_id"] + class_cols].to_csv(
        f"exp16_preds_{model_id}.csv", index=False
    )
    

# Score the inference set with every exp16 checkpoint; each call to
# inference_fn writes one prediction CSV.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    if "swa" in checkpoint:
        ckpt = process_swa_checkpoint(checkpoint)
    else:
        # map_location="cpu": without it the tensors are restored to the
        # device they were saved from, so a checkpoint saved from a GPU
        # would occupy GPU memory for the whole inference pass.
        ckpt = torch.load(checkpoint, map_location="cpu")
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)

del model
torch.cuda.empty_cache()
gc.collect()

In [None]:
import glob
import pandas as pd

# Average the prediction files of all exp16 checkpoints.
csvs = glob.glob("exp16_preds_*.csv")

idx, preds = [], []

for csv_idx, csv in enumerate(csvs):
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv).sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)

    temp_preds = df.drop(columns=["discourse_id"]).values
    if csv_idx:
        # files are summed row by row, relying on the sort above
        preds += temp_preds
    else:
        idx = list(df["discourse_id"])
        preds = temp_preds

preds = preds / len(csvs)

exp16_df = pd.DataFrame({
    "discourse_id": idx,
    "Ineffective": preds[:, 0],
    "Adequate": preds[:, 1],
    "Effective": preds[:, 2],
})

# collapse repeated discourse ids into a single averaged row
exp16_df = (
    exp16_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]]
    .agg(np.mean)
    .reset_index()
)


In [None]:
# preview the first five averaged rows
exp16_df.head(n=5)

### ALL Data Trained Models

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Token states from a pretrained transformer are passed through a
    bidirectional LSTM, mean-pooled per span, layer-normalised, mixed across
    spans with an attention layer and finally projected to
    ``config["num_labels"]`` logits, of which only the first three are
    returned.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config``.

        Keys read from ``config``: ``"base_model_path"``, ``"dropout"`` and
        ``"num_labels"``.
        """
        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # backbone; its config is patched before the weights are loaded
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        self.dropout = StableDropout(self.config["dropout"])

        # span-to-span attention, built from a copy of the backbone config
        # with relative attention switched off
        span_attn_config = deepcopy(self.base_model.config)
        span_attn_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attn_config)

        width = self.base_model.config.hidden_size
        self.layer_norm = LayerNorm(width, self.base_model.config.layer_norm_eps)

        # bidirectional LSTM: two directions of width // 2 each
        self.fpe_lstm_layer = nn.LSTM(
            input_size=width,
            hidden_size=width // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(width, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the first three logits for every span.

        Args:
            input_ids, attention_mask: forwarded to the base transformer.
            span_head_idxs, span_tail_idxs: per-example span boundaries; the
                tokens strictly between head and tail are averaged.
            span_attention_mask: per-example span mask (assumed to be 1 for
                real spans and 0 for padding -- confirm against the collator).
            **kwargs: extra batch entries, ignored.
        """
        n_examples = input_ids.shape[0]
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(token_states)

        # one mean vector per span -> (bs, num_discourse, h)
        pooled = torch.stack([
            torch.stack([
                token_states[ex, start + 1:end].mean(dim=0)
                for start, end in zip(span_head_idxs[ex], span_tail_idxs[ex])
            ])
            for ex in range(n_examples)
        ])
        pooled = self.layer_norm(pooled)

        # pairwise mask: outer product of the span mask with itself
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pair_mask = (key_mask * key_mask.squeeze(-2).unsqueeze(-1)).byte()

        attended = self.dropout(self.fpe_span_attention(pooled, pair_mask))

        # only the first three classifier outputs are the main logits
        return self.classifier(attended)[:, :, :3]
    

#########################
from copy import deepcopy

# Full-data checkpoints consumed by the inference loop below. A path containing
# "swa" is loaded through process_swa_checkpoint; any other path is read with
# torch.load.
checkpoints = [
    "../input/delv3-all-folds/swa_fpe_all_exp_10a_mask_aug.pth.tar",
    "../input/delv3-all-folds/fpe_all_exp_16a_mixout_high_gamma_high_mask_aug.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """Run one checkpoint over the inference loader and dump span probabilities.

    The softmax of the model logits is collected for every batch, flattened to
    one row per span, restricted to rows with a non-negative span uid, mapped
    to discourse ids through the module-level ``idx2discourse`` and written to
    ``exp_rb_full_data_preds_{model_id}.csv``.

    :param model: model invoked as ``model(**batch)``
    :param infer_dl: iterable of batches; every batch must contain a ``"uids"`` tensor
    :param model_id: suffix used in the output file name
    :return: None, the predictions are written to disk
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    # unwrap the batch -> example -> span nesting into one flat list of rows
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example_probs in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example_probs
    ]
    flat_uids = [
        span_uid
        for uids in uid_batches
        for example_uids in uids.to('cpu').detach().numpy().tolist()
        for span_uid in example_uids
    ]

    label_columns = ["Ineffective", "Adequate", "Effective"]
    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = label_columns
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()  # keep rows with a non-negative uid only
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id"] + label_columns].copy()
    preds_df.to_csv(f"exp_rb_full_data_preds_{model_id}.csv", index=False)


# Run every full-data checkpoint through inference_fn; each call writes one
# prediction CSV whose name carries model_id.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    # work on a copy so the shared config dict is left untouched
    new_config = deepcopy(config)
    # classifier is built with 5 outputs for these checkpoints
    new_config["num_labels"] = 5

    model = FeedbackModel(new_config)
    if "swa" in checkpoint:
        # checkpoints with "swa" in the path go through the dedicated loader
        ckpt = process_swa_checkpoint(checkpoint)
    else:
        ckpt = torch.load(checkpoint)
        # the stored 'loss' entry is reported as the validation score
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)

# drop the last model and release cached GPU memory before the next stage
del model
# del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
torch.cuda.empty_cache()
gc.collect()



import glob
import pandas as pd

# Average the per-checkpoint prediction files written by inference_fn.
csvs = glob.glob("exp_rb_full_data_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):

    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # rows are combined positionally below, so every file is put in the same order first
    # NOTE(review): this assumes every file holds the same discourse ids -- confirm
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)

    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # the first file provides the id order and the initial running sum
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

# running sum -> mean over files
preds = preds / len(csvs)

exp99_rb_all_df = pd.DataFrame()
exp99_rb_all_df["discourse_id"] = idx
exp99_rb_all_df["Ineffective"] = preds[:, 0]
exp99_rb_all_df["Adequate"] = preds[:, 1]
exp99_rb_all_df["Effective"] = preds[:, 2]

# rows sharing a discourse_id are collapsed to their mean
exp99_rb_all_df = exp99_rb_all_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()
exp99_rb_all_df.head()

#### UDA

In [None]:
# # UDA Model
# config = """{
#     "debug": false,

#     "base_model_path": "../input/tapt-fpe-delv3-span-mlm-04",
#     "model_dir": "./outputs",

#     "max_length": 1024,
#     "stride": 256,
#     "num_labels": 3,
#     "dropout": 0.1,
#     "infer_bs": 12
# }
# """
# config = json.loads(config)
# config["len_tokenizer"] = len(tokenizer)

# import gc
# import pdb
# from collections import OrderedDict

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.nn import LayerNorm

# import torch.utils.checkpoint
# from transformers import AutoConfig, AutoModel
# from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


# #-------- Model ------------------------------------------------------------------#
# class FeedbackModel(nn.Module):
#     """The feedback prize effectiveness baseline model
#     """

#     def __init__(self, config):
#         super(FeedbackModel, self).__init__()
#         self.config = config

#         # base transformer
#         base_config = AutoConfig.from_pretrained(self.config["base_model_path"])
#         base_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
#         self.base_model = AutoModel.from_pretrained(self.config["base_model_path"], config=base_config)

#         # dropouts
#         self.dropout = StableDropout(self.config["dropout"])
        
#         # multi-head attention
#         attention_config = deepcopy(self.base_model.config)
#         attention_config.update({"relative_attention": False})
#         self.fpe_span_attention = DebertaV2Attention(attention_config)
        
#         # classification
#         hidden_size = self.base_model.config.hidden_size
#         feature_size = hidden_size
#         self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)
        
#         # # LSTM Head
#         self.fpe_lstm_layer = nn.LSTM(
#             input_size=feature_size,
#             hidden_size=hidden_size//2,
#             num_layers=1,
#             batch_first=True,
#             bidirectional=True,
#         )

#         self.num_labels = self.config["num_labels"]
#         self.classifier = nn.Linear(feature_size, self.config["num_labels"])

#     def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        
#         bs = input_ids.shape[0]  # batch size
#         outputs = self.base_model(input_ids, attention_mask=attention_mask)
#         encoder_layer = outputs[0]
        
#         self.fpe_lstm_layer.flatten_parameters()
#         encoder_layer = self.fpe_lstm_layer(encoder_layer)[0]

#         mean_feature_vector = []
#         for i in range(bs):  # TODO: vectorize
#             span_vec_i = []

#             for head, tail in zip(span_head_idxs[i], span_tail_idxs[i]):
#                 # span feature
#                 tmp = torch.mean(encoder_layer[i, head+1:tail], dim=0)  # [h]
#                 span_vec_i.append(tmp)
#             span_vec_i = torch.stack(span_vec_i)  # (num_disourse, h)
#             mean_feature_vector.append(span_vec_i)

#         mean_feature_vector = torch.stack(mean_feature_vector)  # (bs, num_disourse, h)
#         mean_feature_vector = self.layer_norm(mean_feature_vector)

#         # attend to other features
#         extended_span_attention_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
#         span_attention_mask = extended_span_attention_mask * extended_span_attention_mask.squeeze(-2).unsqueeze(-1)
#         span_attention_mask = span_attention_mask.byte()
#         feature_vector = self.fpe_span_attention(mean_feature_vector, span_attention_mask)

#         # feature_vector = mean_feature_vector
#         feature_vector = self.dropout(feature_vector)

#         logits = self.classifier(feature_vector)
#         logits = logits[:,:, :3] # main logits
#         return logits
    
# ########################
# checkpoints = [
#     "../input/exp21-uda-all-data/fpe_model_kd_seed_1.pth.tar",
#     "../input/exp21-uda-all-data/fpe_model_kd_seed_2.pth.tar",
    
# ]

# def inference_fn(model, infer_dl, model_id):
#     all_preds = []
#     all_uids = []
#     accelerator = Accelerator()
#     model, infer_dl = accelerator.prepare(model, infer_dl)
    
#     model.eval()
#     tk0 = tqdm(infer_dl, total=len(infer_dl))
    
#     for batch in tk0:
#         with torch.no_grad():
#             logits = model(**batch) # (b, nd, 3)
#             batch_preds = F.softmax(logits, dim=-1)
#             batch_uids = batch["uids"]
#         all_preds.append(batch_preds)
#         all_uids.append(batch_uids)
    
#     all_preds = [p.to('cpu').detach().numpy().tolist() for p in all_preds]
#     all_preds = list(chain(*all_preds))
#     flat_preds = list(chain(*all_preds))
    
#     all_uids = [p.to('cpu').detach().numpy().tolist() for p in all_uids]
#     all_uids = list(chain(*all_uids))
#     flat_uids = list(chain(*all_uids))    
    
#     preds_df = pd.DataFrame(flat_preds)
#     preds_df.columns = ["Ineffective", "Adequate", "Effective"]
#     preds_df["span_uid"] = flat_uids # SORTED_DISCOURSE_IDS
#     preds_df = preds_df[preds_df["span_uid"]>=0].copy()
#     preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
#     preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
#     preds_df.to_csv(f"exp21f_uda_preds_{model_id}.csv", index=False)
    
# from copy import deepcopy
# for model_id, checkpoint in enumerate(checkpoints):
#     print(f"infering from {checkpoint}")
#     new_config = deepcopy(config)
#     if "10a" in checkpoint:
#         new_config["num_labels"] = 5
        
#     model = FeedbackModel(new_config)
#     if "swa" in checkpoint:
#         ckpt = process_swa_checkpoint(checkpoint)
#     else:
#         ckpt = torch.load(checkpoint)
#         print(f"validation score for fold {model_id} = {ckpt['loss']}")
#     model.load_state_dict(ckpt['state_dict'])
#     inference_fn(model, infer_dl, model_id)
    
# del model
# gc.collect()
# torch.cuda.empty_cache()

# #####
# import glob
# import pandas as pd

# csvs = glob.glob("exp21f_uda_preds_*.csv")

# idx = []
# preds = []


# for csv_idx, csv in enumerate(csvs):
    
#     print("=="*40)
#     print(f"preds in {csv}")
#     df = pd.read_csv(csv)
#     df = df.sort_values(by=["discourse_id"])
#     print(df.head(10))
#     print("=="*40)
    
#     temp_preds = df.drop(["discourse_id"], axis=1).values
#     if csv_idx == 0:
#         idx = list(df["discourse_id"])
#         preds = temp_preds
#     else:
#         preds += temp_preds

# preds = preds / len(csvs)

# exp21f_df = pd.DataFrame()
# exp21f_df["discourse_id"] = idx
# exp21f_df["Ineffective"] = preds[:, 0]
# exp21f_df["Adequate"] = preds[:, 1]
# exp21f_df["Effective"] = preds[:, 2]

# exp21f_df = exp21f_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()
# exp21f_df.head()

In [None]:
# del model
# torch.cuda.empty_cache()
# gc.collect()

In [None]:
# Best-effort cleanup of the tokenizer / dataset objects: any exception raised
# here (for example by deleting a name that is not defined) is printed instead
# of stopping the notebook.
try:
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
except Exception as e:
    print(e)

In [None]:
# Best-effort release of the current model and of cached GPU memory: any
# exception raised here is printed instead of stopping the notebook.
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except Exception as e:
    print(e)

# Exp 102 Fast Model - SPAN MLM 40% + MSD

In [None]:
# Exp 102 settings, written as a JSON document and parsed straight into a dict.
config = json.loads(
    """{
    "debug": false,

    "base_model_path": "../input/tapt-fpe-delv3-span-mlm-04",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
)

In [None]:
import gc
import pdb
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

import torch.utils.checkpoint
from transformers import AutoConfig, AutoModel
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span-level effectiveness classifier with multi-sample dropout.

    A pretrained backbone feeds a single-layer unidirectional LSTM; each
    discourse span is mean-pooled, layer-normalised and passed through one
    attention block over the spans of the same example. The classifier is
    applied to five differently-dropped copies of the result and the five
    outputs are averaged.
    """

    def __init__(self, config):
        """Build the backbone and the span head.

        :param config: dict providing ``base_model_path``, ``dropout`` and ``num_labels``
        """
        super().__init__()
        self.config = config

        # base transformer
        backbone_path = self.config["base_model_path"]
        base_config = AutoConfig.from_pretrained(backbone_path)
        base_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(backbone_path, config=base_config)

        # dropouts: the configured rate plus four progressively larger ones
        self.dropout = StableDropout(self.config["dropout"])
        self.dropout1 = StableDropout(self.config["dropout"]+0.1)
        self.dropout2 = StableDropout(self.config["dropout"]+0.2)
        self.dropout3 = StableDropout(self.config["dropout"]+0.3)
        self.dropout4 = StableDropout(self.config["dropout"]+0.4)

        # attention across spans, configured from the backbone with relative attention off
        attention_config = deepcopy(self.base_model.config)
        attention_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(attention_config)

        # classification
        hidden_size = self.base_model.config.hidden_size
        feature_size = hidden_size
        self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)

        # LSTM head
        self.fpe_lstm_layer = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=False,
        )

        num_labels = self.config["num_labels"]
        self.num_labels = num_labels
        self.classifier = nn.Linear(feature_size, num_labels)

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the averaged multi-sample-dropout logits for every span.

        :param input_ids: token ids given to the backbone; first dimension is the batch
        :param attention_mask: attention mask forwarded to the backbone
        :param span_head_idxs: per-example positions of the span start markers
        :param span_tail_idxs: per-example positions of the span end markers
        :param span_attention_mask: per-example span mask
            (presumably 1 for real spans, 0 for padded ones -- confirm against the collator)
        :param kwargs: remaining batch fields, ignored here
        :return: mean of the five classifier outputs
        """
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]

        # TODO: vectorize
        # one mean vector per span, taken over the tokens strictly between head and tail
        pooled_spans = torch.stack([
            torch.stack([
                torch.mean(token_states[row, start + 1:stop], dim=0)  # [h]
                for start, stop in zip(span_head_idxs[row], span_tail_idxs[row])
            ])
            for row in range(input_ids.shape[0])
        ])  # (bs, num_discourse, h)
        pooled_spans = self.layer_norm(pooled_spans)

        # pairwise span mask: entry (q, k) is the product of the mask values of spans q and k
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        query_mask = key_mask.squeeze(-2).unsqueeze(-1)
        pairwise_mask = (key_mask * query_mask).byte()
        attended = self.fpe_span_attention(pooled_spans, pairwise_mask)

        # multi-sample dropout: drop first, classify second, then average in a fixed order
        dropouts = (self.dropout, self.dropout1, self.dropout2, self.dropout3, self.dropout4)
        dropped = [drop(attended) for drop in dropouts]
        sample_logits = [self.classifier(features) for features in dropped]

        logits = sample_logits[0]
        for extra in sample_logits[1:]:
            logits = logits + extra
        return logits / 5

In [None]:
# Five fold checkpoints of the Exp 102 model; they are only loaded by the
# loop guarded with use_exp102.
checkpoints = [
    "../input/v3l-msd-fast-approach/fpe_model_fold_0_best.pth.tar",
    "../input/v3l-msd-fast-approach/fpe_model_fold_1_best.pth.tar",
    "../input/v3l-msd-fast-approach/fpe_model_fold_2_best.pth.tar",
    "../input/v3l-msd-fast-approach/fpe_model_fold_3_best.pth.tar",
    "../input/v3l-msd-fast-approach/fpe_model_fold_4_best.pth.tar"]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one fold checkpoint over the inference loader and dump span probabilities.

    The softmax of the model logits is collected for every batch, flattened to
    one row per span, restricted to rows with a non-negative span uid, mapped
    to discourse ids through the module-level ``idx2discourse`` and written to
    ``exp102_delv3_msd_5folds_model_preds_{model_id}.csv``.

    :param model: model invoked as ``model(**batch)``
    :param infer_dl: iterable of batches; every batch must contain a ``"uids"`` tensor
    :param model_id: suffix used in the output file name
    :return: None, the predictions are written to disk
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            probs = F.softmax(model(**batch), dim=-1)  # (b, nd, 3)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    # unwrap the batch -> example -> span nesting into one flat list of rows
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example_probs in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example_probs
    ]
    flat_uids = [
        span_uid
        for uids in uid_batches
        for example_uids in uids.to('cpu').detach().numpy().tolist()
        for span_uid in example_uids
    ]

    label_columns = ["Ineffective", "Adequate", "Effective"]
    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = label_columns
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()  # keep rows with a non-negative uid only
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id"] + label_columns].copy()
    preds_df.to_csv(f"exp102_delv3_msd_5folds_model_preds_{model_id}.csv", index=False)
    
# Exp 102 inference, executed only when the use_exp102 switch is on.
# NOTE(review): an earlier cleanup cell deletes infer_dl (and the tokenizer /
# dataset objects) and no code between that cell and this one rebuilds them --
# confirm they are recreated before enabling this experiment.
if use_exp102:
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        if "swa" in checkpoint:
            # checkpoints with "swa" in the path go through the dedicated loader
            ckpt = process_swa_checkpoint(checkpoint)
        else:
            ckpt = torch.load(checkpoint)
            # the stored 'loss' entry is reported as the validation score
            print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    # release the last model and the data objects once all folds are done
    del model
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Average the Exp 102 per-fold prediction files into exp102_df.
if use_exp102:
    import glob
    import pandas as pd

    csvs = glob.glob("exp102_delv3_msd_5folds_model_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # rows are combined positionally below, so every file is put in the same order first
        # NOTE(review): this assumes every file holds the same discourse ids -- confirm
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            # the first file provides the id order and the initial running sum
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds

    # running sum -> mean over files
    preds = preds / len(csvs)

    exp102_df = pd.DataFrame()
    exp102_df["discourse_id"] = idx
    exp102_df["Ineffective"]  = preds[:, 0]
    exp102_df["Adequate"]     = preds[:, 1]
    exp102_df["Effective"]    = preds[:, 2]

    # rows sharing a discourse_id are collapsed to their mean
    exp102_df = exp102_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

# EXP 213 - Deberta-large 10 fold LB 0.565

## Config

In [None]:
# Exp 213 settings, written as a JSON document and parsed straight into a dict.
config = json.loads(
    """{
    "debug": false,

    "base_model_path": "../input/deberta-large-prompt-mlm40/",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8
}
"""
)

## Dataset

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer stored at ``config["base_model_path"]``.

    Besides loading, the vocabulary size and the tokenization of two probe
    strings containing marker-like tokens are printed for visual inspection.

    :param config: dict providing ``base_model_path``
    :return: the loaded tokenizer
    """
    divider = "=="*40
    probes = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )

    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    print(divider)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probes:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")

    print(divider)
    return tokenizer


#--------------- Processing ---------------------------------------------#


# Marker tokens that open a discourse span, one per discourse type.
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]",
]

# Maps a type name to the [opening text, closing text] pair inserted around a
# span: the opening text is "<label> <start token>", the closing text is the
# start token with " END" placed before its closing bracket.
TOKEN_MAP = {
    type_name: [f"{label} {start_token}", f"{start_token[:-1]} END]"]
    for type_name, label, start_token in (
        ("topic", "Topic", "[TOPIC]"),
        ("Lead", "Lead", "[LEAD]"),
        ("Position", "Position", "[POSITION]"),
        ("Claim", "Claim", "[CLAIM]"),
        ("Counterclaim", "Counterclaim", "[COUNTER_CLAIM]"),
        ("Rebuttal", "Rebuttal", "[REBUTTAL]"),
        ("Evidence", "Evidence", "[EVIDENCE]"),
        ("Concluding Statement", "Concluding Statement", "[CONCLUDING_STATEMENT]"),
    )
}


# Marker tokens that close a discourse span, in the same order as the start tokens.
DISCOURSE_END_TOKENS = [f"{start_token[:-1]} END]" for start_token in DISCOURSE_START_TOKENS]



def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, shortening it from the right until it matches.

    The full substring is tried first. While it is not found, it is cut down to
    its first ``int(len * fraction)`` characters and tried again. The search is
    iterative, so a long substring that never matches cannot exhaust the
    interpreter's recursion limit, and every round removes at least one
    character, so the loop also terminates for ``fraction >= 1``.

    :param text: text to search in
    :type text: str
    :param substring: text to look for
    :type substring: str
    :param min_length: shortest prefix that is still tried, defaults to 2
    :type min_length: int, optional
    :param fraction: share of the current candidate kept per round, defaults to 0.99999
    :type fraction: float, optional
    :return: ``[start, end]`` of the matched prefix, or ``[-1, 0]`` when no
        prefix of at least ``min_length`` characters occurs in ``text``
    :rtype: list
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        shorter = candidate[:int(len(candidate) * fraction)]
        if len(shorter) >= len(candidate):
            # the fraction did not shrink the candidate; drop one character so
            # the search always makes progress
            shorter = candidate[:-1]
        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Map every discourse text to the character spans where it occurs in the essay.

    The discourse texts are processed in order. For each one, the first exact
    occurrence that ends after the previously accepted occurrence is recorded,
    so repeated texts are assigned to successive positions. A text that ends up
    with no span at all falls back to ``relaxed_search``.

    :param discourse_list: discourse texts of one essay
    :param essay_text: full essay text
    :return: dict from discourse text to a list of spans
    """
    cursor = 0  # end offset of the most recently accepted match
    span_map = {}

    for discourse in discourse_list:
        found = span_map.setdefault(discourse, [])
        pattern = re.escape(r'{}'.format(discourse))
        hit = next(
            (match for match in re.finditer(pattern, essay_text) if match.end() > cursor),
            None,
        )
        if hit is not None:
            found.append(hit.span())
            cursor = hit.end()

    # post process: texts without any exact match get an approximate span
    for discourse in discourse_list:
        if span_map[discourse]:
            continue
        print("resorting to relaxed search...")
        span_map[discourse] = [relaxed_search(essay_text, discourse)]
    return span_map


def get_substring_span(texts, mapping):
    """Take the next unused span of every text, in the order of ``texts``.

    Each lookup removes the first span from ``mapping[text]``, so ``mapping``
    is consumed in place and a repeated text receives successive spans.

    :param texts: discourse texts in order
    :param mapping: dict from text to its list of spans (modified in place)
    :return: list with one span per entry of ``texts``
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, topic, prompt, anno_df):
    """Wrap every annotated span of one essay in its marker texts.

    The spans of ``essay_id`` are visited in order of ``discourse_start``. Each
    one is surrounded by the opening and closing text that ``TOKEN_MAP`` holds
    for its discourse type, joined with single spaces. Finally the prompt is
    placed in a ``[TOPIC] ... [TOPIC END]`` section between ``[SOE]`` and the
    essay, and ``[EOE]`` is appended.

    :param essay_id: id of the essay to process
    :param essay_text: raw essay text
    :param topic: accepted for interface compatibility; not used
    :param prompt: prompt text placed in the topic section
    :param anno_df: annotations with ``essay_id``, ``discourse_start``,
        ``discourse_end`` and ``discourse_type`` columns
    :return: the essay text with all markers inserted
    """
    essay_spans = anno_df[anno_df["essay_id"] == essay_id].sort_values(by="discourse_start")

    shift = 0  # characters inserted so far, i.e. how far later offsets have moved
    for span in essay_spans.itertuples(index=False):
        start = int(span.discourse_start) + shift
        end = int(span.discourse_end) + shift
        open_text, close_text = TOKEN_MAP[span.discourse_type]
        pieces = (essay_text[:start], open_text, essay_text[start:end], close_text, essay_text[end:])
        essay_text = " ".join(pieces)
        # two marker texts plus the four joining spaces
        shift += len(open_text) + len(close_text) + 4

    return "".join(["[SOE]", " [TOPIC] ", prompt, " [TOPIC END] ", essay_text, "[EOE]"])


def process_input_df(anno_df, notes_df):
    """Turn span annotations and essay texts into one row per essay.

    Steps: strip the discourse texts, locate every discourse text in its essay,
    insert the marker texts around the located spans, and group the per-span
    fields of every essay into lists.

    :param anno_df: span annotations; must contain ``discourse_id``, ``essay_id``,
        ``discourse_text``, ``discourse_type`` and ``uid``, and may contain
        ``discourse_effectiveness``
    :type anno_df: pd.DataFrame
    :param notes_df: essays; must contain ``essay_id``, ``essay_text``, ``topic``
        and ``prompt``
    :type notes_df: pd.DataFrame
    :return: one row per essay with ``essay_id``, ``uids``, ``discourse_type``,
        the processed ``essay_text`` and, when labels are present,
        ``discourse_effectiveness``
    :rtype: pd.DataFrame
    """
    # work on copies so the caller's frames are left untouched
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # locate every discourse text inside its essay; rows are read by column
    # label because positional access on a label-indexed Series is deprecated
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda row: build_span_map(row["discourse_text"], row["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda row: get_substring_span(row["discourse_text"], row["span_map"]), axis=1)

    # flatten the per-essay lists back to one (discourse_id, start, end) row per span
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")
    # anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text", "topic", "prompt"]].apply(
        lambda row: process_essay(
            row["essay_id"], row["essay_text"], row["topic"], row["prompt"], anno_df
        ),
        axis=1,
    )
    print("=="*40)

    # the character offsets refer to the raw text and are no longer needed
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Dataset builder for the feedback prize effectiveness task.

    Holds the label and discourse-type vocabularies, loads the tokenizer and
    turns annotation / essay dataframes into a tokenized dataset in which the
    positions of the span start and end marker tokens are recorded.
    """

    def __init__(self, config):
        """Store the config, set up the id maps and load the tokenizer.

        :param config: dict providing at least ``base_model_path`` and ``model_dir``
        """
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and cache the marker token ids.

        No tokens are added to the tokenizer here; the start and end marker
        strings are only converted to ids.
        """
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize the essay texts of a batch without padding or truncation.

        :param examples: batch with an ``essay_text`` field
        :return: tokenizer output including offset mappings
        """
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find the start / end marker tokens of every example in a batch.

        :param examples: batch with ``input_ids`` and ``offset_mapping``
        :return: dict with the token positions of the start and end markers and
            the character offsets at which they begin / end
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_token_ids]
            example_span_tail_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_end_ids]

            # character offset where each start marker begins / each end marker ends
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in example_span_head_idxs]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in example_span_tail_idxs]

            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert the effectiveness label names of a batch to ids.

        :param examples: batch with a ``discourse_effectiveness`` field
        :return: dict with a ``labels`` list of id lists
        """
        labels = [
            [self.label2id[label] for label in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of a batch to ids.

        :param examples: batch with a ``discourse_type`` field
        :return: dict with a ``discourse_type_ids`` list of id lists
        """
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the number of tokens of every example in a batch."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that start and end markers pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every start marker has exactly one label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed per-essay dataframe and the created dataset
        :rtype: tuple
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks; create the target directory if it is missing
        os.makedirs(self.config["model_dir"], exist_ok=True)
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        # keep only examples whose start and end markers pair up
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))  # no need to run on empty set
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # drop the pandas index column only when it is present, so that any
        # other failure is raised instead of being swallowed
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window.

    The full-length dataset is built first and kept as a reference. Each essay
    text is then re-tokenised with ``truncation=True``, ``max_length`` and
    ``stride`` taken from ``config``, and overflowing tokens returned as
    additional rows. For every resulting row the span heads/tails are located
    again, and uids, discourse type ids and (outside ``"infer"`` mode) labels
    are copied over from the reference dataset by matching the character
    offset of each span head.

    :param config: dict read for ``"max_length"`` and ``"stride"`` (and passed
        on to ``AuxFeedbackDataset``)
    :param df: input annotation dataframe
    :param essay_df: dataframe with essay texts
    :param mode: ``"infer"`` skips label recomputation, defaults to "train"
    :return: dict with keys ``"dataset"`` (truncated rows),
        ``"original_dataset"`` (un-truncated reference) and ``"tokenizer"``
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids

    def tokenize_with_truncation(examples):
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        return tz

    def process_span(examples):
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                head_candidate = [
                    pos for pos, this_id in enumerate(example_input_ids)
                    if this_id in START_IDS and pos <= config["max_length"] - buffer
                ]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if tail_candidate and head_candidate and tail_candidate[0] < head_candidate[0]:
                tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def _lookup_from_original(examples, column):
        """Copy `column` from the un-truncated example, keyed by span-head char offset."""
        looked_up = []
        for head_char_starts, sample_idx in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):
            original_example = original_dataset[sample_idx]
            char2value = dict(zip(original_example["span_head_char_start_idxs"], original_example[column]))
            looked_up.append([char2value[char_idx] for char_idx in head_char_starts])
        return looked_up

    def enforce_alignment(examples):
        return {"uids": _lookup_from_original(examples, "uids")}

    def recompute_labels(examples):
        return {"labels": _lookup_from_original(examples, "labels")}

    def recompute_discourse_type_ids(examples):
        return {"discourse_type_ids": _lookup_from_original(examples, "discourse_type_ids")}

    def compute_input_length(examples):
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # a single batch covering the whole dataset, with all previous columns dropped
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
# Make sure the output directory exists before the dataset is built
# (presumably the processed-sample CSV is written there -- confirm).
os.makedirs(config["model_dir"], exist_ok=True)

print("creating the inference datasets...")
# mode="infer": no labels are recomputed for the test annotations
infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
# keep the tokenizer and the truncated dataset as notebook globals for the cells below
tokenizer = infer_ds_dict["tokenizer"]
infer_dataset = infer_ds_dict["dataset"]
print(infer_dataset)

In [None]:
# record the tokenizer length in the shared config
config["len_tokenizer"] = len(tokenizer)

# order rows by token count (presumably so that each batch needs little padding)
infer_dataset = infer_dataset.sort("input_length")

# restrict the returned columns to the ones the collator and the model read
infer_dataset.set_format(
    type=None,
    columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
             'span_tail_idxs', 'discourse_type_ids', 'uids']
)

## Data Loader

In [None]:
from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator for span classification.

    Token-level fields are padded by the tokenizer. Span-level fields are
    padded to the largest number of spans in the batch, and a
    ``span_attention_mask`` marks real spans (1) versus padding (0).
    When the features carry ``labels``, a two-column ``multitask_labels``
    float tensor is derived from them.
    """

    # NOTE(review): these assignments are un-annotated, so they are not
    # dataclass fields of this subclass. The inherited __init__ presumably
    # sets the parent's defaults on each instance (e.g. pad_to_multiple_of),
    # which would shadow the values below -- confirm before relying on them.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of tensors.

        :param features: list of dicts holding the tokenizer fields plus
            ``uids``, ``discourse_type_ids``, ``span_head_idxs``,
            ``span_tail_idxs`` and optionally ``labels``
        :return: dict of ``torch.int64`` tensors (``multitask_labels`` is
            ``torch.float32``)
        :raises ValueError: if a label is not one of 0, 1, 2 or -1
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(head_idxs) for head_idxs in span_head_idxs)
        max_len = len(batch["input_ids"][0])

        # Dummy span used for padding; it sits near the end of the padded sequence.
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        def _pad_spans(sequences, fill_value):
            # right-pad every per-example list to the batch-wide span count
            return [seq + [fill_value] * (b_max - len(seq)) for seq in sequences]

        batch["span_head_idxs"] = _pad_spans(span_head_idxs, default_head_idx)
        batch["uids"] = _pad_spans(uids, -1)
        batch["discourse_type_ids"] = _pad_spans(discourse_type_ids, 0)
        batch["span_tail_idxs"] = _pad_spans(span_tail_idxs, default_tail_idx)
        batch["span_attention_mask"] = _pad_spans(span_attention_mask, 0)

        if labels is not None:
            batch["labels"] = _pad_spans(labels, -1)

            # multitask labels: one two-element vector per class id, -1 for padding
            multitask_vectors = {0: [0, 0], 1: [1, 0], 2: [1, 1], -1: [-1, -1]}

            def _get_additional_labels(label_id):
                try:
                    return list(multitask_vectors[label_id])
                except KeyError:
                    raise ValueError(f"unexpected label id: {label_id!r}") from None

            batch["multitask_labels"] = [
                [_get_additional_labels(el) for el in ex_labels] for ex_labels in batch["labels"]
            ]

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# collator that pads token fields and span fields for each batch
data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

# shuffle=False keeps the rows in the order of infer_dataset
infer_dl = DataLoader(
    infer_dataset,
    batch_size=config["infer_bs"],
    shuffle=False,
    collate_fn=data_collector
)

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Pipeline: base transformer -> bidirectional LSTM -> mean over the tokens
    strictly between each span head and tail -> LayerNorm -> self-attention
    across the spans of an example -> dropout -> linear classifier.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # base transformer
        backbone_path = self.config["base_model_path"]
        backbone_config = AutoConfig.from_pretrained(backbone_path)
        self.base_model = AutoModel.from_pretrained(backbone_path, config=backbone_config)
        backbone_config = self.base_model.config

        # dropouts
        self.dropout = nn.Dropout(self.config["dropout"])

        # multi-head attention over span representations, sized like the backbone
        span_attention_settings = {
            "num_attention_heads": backbone_config.num_attention_heads,
            "hidden_size": backbone_config.hidden_size,
            "attention_probs_dropout_prob": backbone_config.attention_probs_dropout_prob,
            "is_decoder": False,
        }
        attention_config = BertConfig()
        attention_config.update(span_attention_settings)
        self.fpe_span_attention = BertAttention(attention_config, position_embedding_type="relative_key")

        # classification
        hidden_size = backbone_config.hidden_size
        feature_size = hidden_size
        self.layer_norm = LayerNorm(feature_size, backbone_config.layer_norm_eps)

        # LSTM head: two directions of hidden_size // 2 each
        self.fpe_lstm_layer = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size//2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(feature_size, self.config["num_labels"])

    def forward(self, input_ids, token_type_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return the classifier logits for every span of every example.

        ``span_head_idxs`` / ``span_tail_idxs`` are indexed per example and per
        span; ``span_attention_mask`` is 1 for real spans and 0 for padding.
        Extra keyword arguments are accepted and ignored.
        """
        n_examples = input_ids.shape[0]

        token_states = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]  # LSTM layer outputs

        # one vector per span: mean of the tokens strictly between head and tail
        span_features = torch.stack([
            torch.stack([
                torch.mean(token_states[row, head+1:tail], dim=0)  # [h]
                for head, tail in zip(span_head_idxs[row], span_tail_idxs[row])
            ])  # (num_discourse, h)
            for row in range(n_examples)
        ])  # (bs, num_discourse, h)
        span_features = self.layer_norm(span_features)

        # attention mechanism: padded spans receive a large negative bias
        additive_mask = (1.0 - span_attention_mask[:, None, None, :]) * -10000.0
        attended = self.fpe_span_attention(span_features, additive_mask)[0]

        return self.classifier(self.dropout(attended))

## Inference

In [None]:
# best checkpoint of each of the ten exp213 folds (fold 0 .. fold 9)
checkpoints = [
    f"../input/tk-fpe-models-v5/exp213-deb-l-prompt-mlm50/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(10)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run `model` over `infer_dl` and write per-span probabilities to CSV.

    Softmax probabilities are flattened to one row per span, rows with a
    negative uid are dropped, uids are mapped to discourse ids through
    `idx2discourse`, and the result is saved as
    ``exp213_dl_model_preds_<model_id>.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()

    flat_preds, flat_uids = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            batch_preds = F.softmax(logits, dim=-1)
            batch_uids = batch["uids"]
        # flatten (b, nd, ...) into one entry per span, batch by batch
        for example_preds in batch_preds.to('cpu').detach().numpy().tolist():
            flat_preds.extend(example_preds)
        for example_uids in batch_uids.to('cpu').detach().numpy().tolist():
            flat_uids.extend(example_uids)

    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    # negative uids are dropped (the collator pads uids with -1)
    is_real_span = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_span].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    preds_df.to_csv(f"exp213_dl_model_preds_{model_id}.csv", index=False)
    

# Run every exp213 fold checkpoint over the inference loader; each fold
# writes its own prediction CSV from inside inference_fn.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    # Load onto the CPU: without map_location the tensors are restored to the
    # device they were saved from, which can leave a second copy of the
    # weights on the GPU for the whole inference run.
    ckpt = torch.load(checkpoint, map_location="cpu")
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)

# Only the model is released here; tokenizer, dataset, collator and loader
# are reused by the next experiment.
del model
gc.collect()
torch.cuda.empty_cache()

# Exp213a

In [None]:
# best checkpoint of each of the ten exp213a folds (fold 0 .. fold 9)
checkpoints = [
    f"../input/exp213a-deb-l-prompt/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(10)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run `model` over `infer_dl` and write per-span probabilities to CSV.

    Softmax probabilities are flattened to one row per span, rows with a
    negative uid are dropped, uids are mapped to discourse ids through
    `idx2discourse`, and the result is saved as
    ``exp213a_dl_model_preds_<model_id>.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()

    flat_preds, flat_uids = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            batch_preds = F.softmax(logits, dim=-1)
            batch_uids = batch["uids"]
        # flatten (b, nd, ...) into one entry per span, batch by batch
        for example_preds in batch_preds.to('cpu').detach().numpy().tolist():
            flat_preds.extend(example_preds)
        for example_uids in batch_uids.to('cpu').detach().numpy().tolist():
            flat_uids.extend(example_uids)

    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
    # negative uids are dropped (the collator pads uids with -1)
    is_real_span = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_span].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    preds_df.to_csv(f"exp213a_dl_model_preds_{model_id}.csv", index=False)
    

# Run every exp213a fold checkpoint over the inference loader; each fold
# writes its own prediction CSV from inside inference_fn.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    # Load onto the CPU: without map_location the tensors are restored to the
    # device they were saved from, which can leave a second copy of the
    # weights on the GPU for the whole inference run.
    ckpt = torch.load(checkpoint, map_location="cpu")
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)

# Only the model is released here; tokenizer, dataset, collator and loader
# are still needed by later cells.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
import glob
import pandas as pd

# Average the per-fold exp213 predictions.
# sorted() makes the accumulation order independent of the filesystem.
csvs = sorted(glob.glob("exp213_dl_model_preds_*.csv"))
if not csvs:
    # fail with a clear message instead of an opaque error at the division below
    raise FileNotFoundError("no exp213_dl_model_preds_*.csv files found; run the exp213 inference cell first")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):

    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # every fold is sorted the same way so rows line up for the element-wise sum
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)

    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

preds = preds / len(csvs)

exp213_df = pd.DataFrame()
exp213_df["discourse_id"] = idx
exp213_df["Ineffective"]  = preds[:, 0]
exp213_df["Adequate"]     = preds[:, 1]
exp213_df["Effective"]    = preds[:, 2]

# a discourse_id can appear in several rows; collapse them to their mean
exp213_df = exp213_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].mean().reset_index()

In [None]:
# quick look at the averaged exp213 predictions
exp213_df.head()

In [None]:
import glob
import pandas as pd

# Average the per-fold exp213a predictions.
# sorted() makes the accumulation order independent of the filesystem.
csvs = sorted(glob.glob("exp213a_dl_model_preds_*.csv"))
if not csvs:
    # fail with a clear message instead of an opaque error at the division below
    raise FileNotFoundError("no exp213a_dl_model_preds_*.csv files found; run the exp213a inference cell first")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):

    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # every fold is sorted the same way so rows line up for the element-wise sum
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)

    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

preds = preds / len(csvs)

exp213a_df = pd.DataFrame()
exp213a_df["discourse_id"] = idx
exp213a_df["Ineffective"]  = preds[:, 0]
exp213a_df["Adequate"]     = preds[:, 1]
exp213a_df["Effective"]    = preds[:, 2]

# a discourse_id can appear in several rows; collapse them to their mean
exp213a_df = exp213a_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].mean().reset_index()

In [None]:
# quick look at the averaged exp213a predictions
exp213a_df.head()

#### Full data models

In [None]:
# Inference with the exp213f checkpoints, only when the full-data models are enabled.
if use_full_data_models:

    checkpoints = [
        "../input/exp213f-deb-l-prompt-all/fpe_model_fold_0_best.pth.tar",
        "../input/exp213f-deb-l-prompt-all/fpe_model_fold_1_best.pth.tar",
    ]

    def inference_fn(model, infer_dl, model_id):
        """Run `model` over `infer_dl` and write per-span probabilities to CSV.

        Softmax probabilities are flattened to one row per span, rows with a
        negative uid are dropped, uids are mapped to discourse ids through
        `idx2discourse`, and the result is saved as
        ``exp213f_preds_<model_id>.csv``. Nothing is returned.
        """
        accelerator = Accelerator()
        model, infer_dl = accelerator.prepare(model, infer_dl)

        model.eval()

        flat_preds, flat_uids = [], []
        for batch in tqdm(infer_dl, total=len(infer_dl)):
            with torch.no_grad():
                logits = model(**batch)  # (b, nd, 3)
                batch_preds = F.softmax(logits, dim=-1)
                batch_uids = batch["uids"]
            # move each batch off the accelerator right away and flatten it
            # to one entry per span
            for example_preds in batch_preds.to('cpu').detach().numpy().tolist():
                flat_preds.extend(example_preds)
            for example_uids in batch_uids.to('cpu').detach().numpy().tolist():
                flat_uids.extend(example_uids)

        preds_df = pd.DataFrame(flat_preds)
        preds_df.columns = ["Ineffective", "Adequate", "Effective"]
        preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS
        # negative uids are dropped (the collator pads uids with -1)
        preds_df = preds_df[preds_df["span_uid"] >= 0].copy()
        preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
        preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
        preds_df.to_csv(f"exp213f_preds_{model_id}.csv", index=False)

    from copy import deepcopy
    for model_id, checkpoint in enumerate(checkpoints):
        # each model gets its own copy of the config
        new_config = deepcopy(config)

        model = FeedbackModel(new_config)
        if "swa" in checkpoint:
            ckpt = process_swa_checkpoint(checkpoint)
        else:
            # Load onto the CPU: without map_location the tensors are restored
            # to the device they were saved from, which can leave a second
            # copy of the weights on the GPU for the whole inference run.
            ckpt = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

In [None]:
# Average the exp213f predictions and write them out as submission.csv.
if use_full_data_models:

    import glob
    import pandas as pd

    # sorted() makes the accumulation order independent of the filesystem
    csvs = sorted(glob.glob("exp213f_preds_*.csv"))
    if not csvs:
        # fail with a clear message instead of an opaque error at the division below
        raise FileNotFoundError("no exp213f_preds_*.csv files found; run the exp213f inference cell first")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        df = pd.read_csv(csv)
        # every file is sorted the same way so rows line up for the element-wise sum
        df = df.sort_values(by=["discourse_id"])

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds

    preds = preds / len(csvs)

    exp213f_df = pd.DataFrame()
    exp213f_df["discourse_id"] = idx
    exp213f_df["Ineffective"] = preds[:, 0]
    exp213f_df["Adequate"] = preds[:, 1]
    exp213f_df["Effective"] = preds[:, 2]

    # a discourse_id can appear in several rows; collapse them to their mean
    exp213f_df = exp213f_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].mean().reset_index()
    exp213f_df.to_csv("submission.csv", index=False)

In [None]:
# exp213f_df only exists when the full-data models were run
if use_full_data_models:
    print(exp213f_df.head())

In [None]:
# Best-effort release of the exp213 data objects before the next experiment.
# A failure (e.g. a name that was never defined) is printed, not raised.
try:
    del tokenizer, infer_dataset, infer_ds_dict, data_collector, infer_dl
    gc.collect()
except Exception as e:
    print(e)

In [None]:
# Best-effort release of the last loaded model and the cached GPU memory.
# A failure (e.g. `model` was already deleted) is printed, not raised.
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except Exception as e:
    print(e)

# EXP 205 - Deberta-v3-large - MSD + prompt

## Config

In [None]:
# EXP 205 settings, kept as JSON text and parsed into a dict in one step
config = json.loads("""{
    "debug": false,

    "base_model_path": "../input/tk-fpe-models-v2/exp205-debv3-l-prompt/mlm_model/",
    "model_dir": "./outputs",

    "max_length": 1024,
    "stride": 256,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 8
}
""")

## Dataset

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer


#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer from config["base_model_path"] and print a short smoke test."""

    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    separator = "=="*40
    probe_sentences = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )

    print(separator)
    print(f"tokenizer len: {len(tokenizer)}")
    # show how sentences containing the marker tokens get split
    for probe in probe_sentences:
        print(
            f"tokenizer test: {tokenizer.tokenize(probe)}")

    print(separator)
    return tokenizer


#--------------- Processing ---------------------------------------------#


# discourse type name -> tag written inside the bracketed marker tokens;
# the three public constants below are all derived from this table
_DISCOURSE_TAGS = {
    "Lead": "LEAD",
    "Position": "POSITION",
    "Claim": "CLAIM",
    "Counterclaim": "COUNTER_CLAIM",
    "Rebuttal": "REBUTTAL",
    "Evidence": "EVIDENCE",
    "Concluding Statement": "CONCLUDING_STATEMENT",
}

# opening marker of every discourse type, e.g. "[LEAD]"
DISCOURSE_START_TOKENS = [f"[{tag}]" for tag in _DISCOURSE_TAGS.values()]

# name -> [text inserted before the span, text inserted after the span]
TOKEN_MAP = {"topic": ["Topic [TOPIC]", "[TOPIC END]"]}
TOKEN_MAP.update(
    {name: [f"{name} [{tag}]", f"[{tag} END]"] for name, tag in _DISCOURSE_TAGS.items()}
)


# closing marker of every discourse type, e.g. "[LEAD END]"
DISCOURSE_END_TOKENS = [f"[{tag} END]" for tag in _DISCOURSE_TAGS.values()]



def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """
    Returns substring's span from the given text with the certain precision.

    If `substring` is not found it is repeatedly shortened from the right to
    ``int(len(substring) * fraction)`` characters and searched again, until a
    match is found or the remaining prefix is shorter than `min_length`.

    :param text: text to search in
    :param substring: text to look for
    :param min_length: shortest prefix that is still searched for
    :param fraction: share of the current prefix kept after a failed search
    :return: ``[start, end]`` of the matched prefix, or ``[-1, 0]`` if nothing
        of at least `min_length` characters matches
    """
    # Iterative on purpose: with the default fraction one character is dropped
    # per attempt, so a recursive version exceeds the recursion limit on long
    # unmatched substrings.
    while True:
        position = text.find(substring)
        if position != -1:
            return [position, position + len(substring)]

        # always drop at least one character so the loop terminates even
        # when fraction >= 1
        keep = min(int(len(substring) * fraction), len(substring) - 1)
        substring = substring[:keep]
        if len(substring) < min_length:
            return [-1, 0]


def build_span_map(discourse_list, essay_text):
    """Map every discourse text to the list of spans it occupies in the essay.

    Discourses are matched in list order. A match is only accepted if it ends
    after the previous accepted match, so repeated texts are assigned to
    successive occurrences.

    :param discourse_list: discourse texts of one essay, in annotation order
    :param essay_text: full essay text
    :return: dict ``text -> list of (start, end)`` with one entry per
        occurrence of the text in `discourse_list`; entries that could not be
        matched exactly come from `relaxed_search`
    """
    reading_head = 0
    to_return = dict()

    for cur_discourse in discourse_list:
        spans = to_return.setdefault(cur_discourse, [])

        for match in re.finditer(re.escape(str(cur_discourse)), essay_text):
            span_start, span_end = match.span()
            if span_end <= reading_head:
                continue
            spans.append(match.span())
            reading_head = span_end
            break

    # post process: each text needs as many spans as it has occurrences in
    # discourse_list, otherwise popping one span per occurrence later on
    # would run out of entries
    required = dict()
    for cur_discourse in discourse_list:
        required[cur_discourse] = required.get(cur_discourse, 0) + 1

    for cur_discourse, n_required in required.items():
        spans = to_return[cur_discourse]
        n_missing = n_required - len(spans)
        if n_missing > 0:
            print("resorting to relaxed search...")
            fallback = relaxed_search(essay_text, cur_discourse)
            spans.extend(list(fallback) for _ in range(n_missing))
    return to_return


def get_substring_span(texts, mapping):
    """Return one span per text, taken from the front of ``mapping[text]``.

    Note: the lists inside `mapping` are consumed (each lookup removes the
    span it returns).
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, topic, prompt, anno_df):
    """Wrap every annotated span of one essay in its marker tokens.

    Annotations of `essay_id` are processed in order of `discourse_start`;
    the start/end texts from TOKEN_MAP are inserted around each span, and the
    result is prefixed with the prompt block and closed with "[EOE]".
    `topic` is accepted but not used.
    """
    essay_annos = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_annos = essay_annos.sort_values(by="discourse_start")

    shift = 0  # characters inserted so far; later offsets move right by this much
    span_columns = zip(
        essay_annos["discourse_start"], essay_annos["discourse_end"], essay_annos["discourse_type"]
    )
    for raw_start, raw_end, d_type in span_columns:
        s = int(raw_start) + shift
        e = int(raw_end) + shift
        s_tok, e_tok = TOKEN_MAP[d_type]
        pieces = [essay_text[:s], s_tok, essay_text[s:e], e_tok, essay_text[e:]]
        essay_text = " ".join(pieces)
        # two tokens plus the four joining spaces were added
        shift += len(s_tok) + len(e_tok) + 4

    return "".join(["[SOE]", " [TOPIC] ", prompt, " [TOPIC END] ", essay_text, "[EOE]"])


def process_input_df(anno_df, notes_df):
    """Build one row per essay with marker-annotated text and per-discourse lists.

    Both inputs are deep-copied first, so the caller's frames are not modified.

    :param anno_df: discourse annotations; the code reads the columns
        ``discourse_id``, ``essay_id``, ``discourse_text``, ``discourse_type``,
        ``uid`` and, when present, ``discourse_effectiveness``
    :type anno_df: pd.DataFrame
    :param notes_df: essays; the code reads the columns ``essay_id``,
        ``essay_text``, ``topic`` and ``prompt``
    :type notes_df: pd.DataFrame
    :return: dataframe grouped by ``essay_id`` with list-valued columns ``uids``,
        ``discourse_type`` (and ``discourse_effectiveness`` when available) plus
        the processed ``essay_text``
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # per essay: collect its discourse ids/texts and locate every text in the essay
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda x: build_span_map(x[0], x[1]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda x: get_substring_span(x[0], x[1]), axis=1)

    # flatten the per-essay lists back to one (discourse_id, start, end) row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")
    # anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text", "topic", "prompt"]].apply(
        lambda x: process_essay(x[0], x[1], x[2], x[3], anno_df), axis=1
    )
    print("=="*40)

    # the character offsets were only needed for the marker insertion above
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    # one row per essay; per-discourse fields become lists
    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class AuxFeedbackDataset:
    """Dataset builder for the feedback prize effectiveness task.

    Tokenizes whole essays without truncation and records, per essay, the
    token and character positions of the discourse start/end marker tokens.
    """

    def __init__(self, config):
        """Store the config, set up the label/type vocabularies and load the tokenizer.

        :param config: configuration mapping; ``model_dir`` is read by
            :meth:`get_dataset`, the rest is passed to ``get_tokenizer``
        :type config: dict
        """
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        # discourse type ids start at 1
        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and cache the marker token ids."""
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        # NOTE(review): no tokens are added here; presumably get_tokenizer
        # already returns a tokenizer that knows the marker tokens - confirm.
        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize a batch of essays in full length, keeping character offsets."""
        return self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )

    def process_spans(self, examples):
        """Find the discourse marker tokens of every example in a batch.

        :return: token positions of start/end markers (``span_head_idxs``,
            ``span_tail_idxs``) and the character offset where each start
            marker begins / each end marker ends
        :rtype: dict
        """
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        for input_ids, offsets in zip(examples["input_ids"], examples["offset_mapping"]):
            heads = [pos for pos, tok_id in enumerate(input_ids) if tok_id in self.discourse_token_ids]
            tails = [pos for pos, tok_id in enumerate(input_ids) if tok_id in self.discourse_end_ids]

            span_head_idxs.append(heads)
            span_tail_idxs.append(tails)
            span_head_char_start_idxs.append([offsets[pos][0] for pos in heads])
            span_tail_char_end_idxs.append([offsets[pos][1] for pos in tails])

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert the effectiveness names of each example to integer labels."""
        labels = [
            [self.label2id[name] for name in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of each example to integer ids."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Record the number of tokens of each example."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert heads and tails pair up and every span has at least one inner token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert there is exactly one label per span head."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        Also writes a sample of up to 16 processed rows to
        ``<model_dir>/<mode>_df_processed.csv`` for manual inspection.

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: labels are generated unless this is ``'infer'``, defaults to 'train'
        :type mode: str, optional
        :return: the processed per-essay dataframe and the created dataset
        :rtype: tuple
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        # keep only essays whose start and end markers pair up
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))  # no need to run on empty set
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # the pandas index column only exists for some input frames; drop it
        # when present instead of swallowing every exception
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- dataset with truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    First builds the untruncated dataset via ``AuxFeedbackDataset``, then
    re-tokenizes the essays with truncation and overflow so that one essay may
    yield several windows. Span positions are recomputed per window, and uids,
    discourse type ids and (unless ``mode == "infer"``) labels are carried over
    from the untruncated example by matching the character offset at which each
    span head starts.

    :param config: configuration mapping; ``max_length`` and ``stride`` are read here
    :param df: input annotation dataframe
    :param essay_df: dataframe with essay texts
    :param mode: labels are recomputed unless this is ``"infer"``
    :return: dict with keys ``dataset`` (windowed), ``original_dataset``
        (untruncated) and ``tokenizer``
    :rtype: dict
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # untruncated copy; used below as the lookup source for uids/labels/type ids
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids
    GLOBAL_IDS = dataset_creator.global_tokens  # not used in this function

    # tokenize with truncation; overflowing tokens become additional windows
    def tokenize_with_truncation(examples):
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=True,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        return tz

    # recompute head/tail marker positions for every window
    def process_span(examples):
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                # full-length window: skip heads that start inside the trailing buffer
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    # look up the uid of every span in a window via the head's character offset
    # in the untruncated example the window came from
    def enforce_alignment(examples):
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    # same lookup as enforce_alignment, for the labels
    def recompute_labels(examples):
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    # same lookup as enforce_alignment, for the discourse type ids
    def recompute_discourse_type_ids(examples):
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    # number of tokens per window
    def compute_input_length(examples):
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    # heads and tails must pair up and enclose at least one token
    def sanity_check_head_tail(examples):
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # the whole dataset is tokenized as a single batch and all previous columns
    # are dropped, since the number of rows changes (one row per window)
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return

In [None]:
# Build the windowed inference dataset once; it is shared by exp205 and exp209.
if use_exp205 or use_exp209:
    # get_dataset writes a sample csv into model_dir, so it has to exist
    os.makedirs(config["model_dir"], exist_ok=True)

    print("creating the inference datasets...")
    infer_ds_dict = get_fast_dataset(config, test_df, essay_df, mode="infer")
    tokenizer = infer_ds_dict["tokenizer"]
    infer_dataset = infer_ds_dict["dataset"]
    print(infer_dataset)

In [None]:
# Record the tokenizer size, order the windows by length and expose only the
# columns that the collator / model read.
if use_exp205 or use_exp209:
    config["len_tokenizer"] = len(tokenizer)

    # presumably sorted so that each batch holds similarly long windows (less padding)
    infer_dataset = infer_dataset.sort("input_length")

    infer_dataset.set_format(
        type=None,
        columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
                 'span_tail_idxs', 'discourse_type_ids', 'uids']
    )

## Data Loader

In [None]:
from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    data collector for seq classification

    Pads the token-level inputs through the tokenizer and pads the per-span
    lists (head/tail positions, uids, discourse type ids, labels) to the
    largest number of spans in the batch.
    """

    # NOTE(review): these are plain class attributes (no annotations), so
    # @dataclass does not turn them into fields of this class. If the parent
    # declares fields with the same names, the generated __init__ assigns the
    # parent's defaults on the instance - verify that pad_to_multiple_of=512
    # actually takes effect.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    def __call__(self, features):
        """Collate a list of examples into a dict of tensors.

        :param features: examples, each holding ``uids``, ``discourse_type_ids``,
            ``span_head_idxs``, ``span_tail_idxs`` (lists), optionally ``labels``,
            plus the tokenizer outputs
        :type features: list[dict]
        :return: int64 tensors for every field; ``multitask_labels`` (only when
            labels are present) is float32
        :rtype: dict
        :raises ValueError: if a label id is not one of -1, 0, 1, 2
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        span_attention_mask = [[1]*len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max(len(idxs) for idxs in span_head_idxs)  # most spans in any example
        max_len = len(batch["input_ids"][0])

        # padded spans point at a few tokens near the end of the padded
        # sequence; they are flagged with 0 in span_attention_mask
        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        def _pad_spans(rows, fill):
            """Right-pad every per-example list to b_max entries."""
            return [row + [fill] * (b_max - len(row)) for row in rows]

        batch["span_head_idxs"] = _pad_spans(span_head_idxs, default_head_idx)
        batch["uids"] = _pad_spans(uids, -1)
        batch["discourse_type_ids"] = _pad_spans(discourse_type_ids, 0)
        batch["span_tail_idxs"] = _pad_spans(span_tail_idxs, default_tail_idx)
        batch["span_attention_mask"] = _pad_spans(span_attention_mask, 0)

        if labels is not None:
            batch["labels"] = _pad_spans(labels, -1)

            # multitask labels: two cumulative binary targets per label id
            # (label >= 1, label >= 2); padded spans keep -1 in both slots
            multitask_targets = {0: [0, 0], 1: [1, 0], 2: [1, 1], -1: [-1, -1]}
            try:
                batch["multitask_labels"] = [
                    [multitask_targets[label_id] for label_id in ex_labels]
                    for ex_labels in batch["labels"]
                ]
            except KeyError as err:
                raise ValueError(f"unexpected label id: {err.args[0]!r}") from err

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "multitask_labels" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# Inference data loader shared by exp205 and exp209; order is kept
# (shuffle=False) so the length-sorted dataset is consumed as is.
if use_exp205 or use_exp209:
    data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

    infer_dl = DataLoader(
        infer_dataset,
        batch_size=config["infer_bs"],
        shuffle=False,
        collate_fn=data_collector
    )

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention

from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Pipeline: base transformer -> unidirectional LSTM over the tokens ->
    mean pooling of the tokens strictly between each span's head and tail
    marker -> layer norm -> attention across the spans of an essay ->
    dropout -> linear classifier.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # base transformer
        model_path = self.config["base_model_path"]
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        # dropouts
        self.dropout = StableDropout(self.config["dropout"])

        # attention across spans, configured without relative attention
        span_attention_config = deepcopy(self.base_model.config)
        span_attention_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attention_config)

        # classification
        hidden_size = self.base_model.config.hidden_size
        feature_size = hidden_size
        self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)

        # LSTM head over the token representations
        self.fpe_lstm_layer = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=False,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(feature_size, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return per-span logits; extra batch fields arrive in ``kwargs`` and are ignored."""
        n_examples = input_ids.shape[0]  # batch size
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]

        # mean over the tokens strictly between head and tail marker
        span_features = torch.stack([
            torch.stack([
                torch.mean(token_states[row, head+1:tail], dim=0)  # [h]
                for head, tail in zip(span_head_idxs[row], span_tail_idxs[row])
            ])  # (num_disourse, h)
            for row in range(n_examples)
        ])  # (bs, num_disourse, h)
        span_features = self.layer_norm(span_features)

        # attend to other features: pairwise mask, non-zero only where both spans are real
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pairwise_mask = key_mask * key_mask.squeeze(-2).unsqueeze(-1)
        pairwise_mask = pairwise_mask.byte()
        attended = self.fpe_span_attention(span_features, pairwise_mask)

        attended = self.dropout(attended)
        return self.classifier(attended)

## Inference

In [None]:
# exp205: best checkpoint of each of the 8 folds
checkpoints = [
    f"../input/exp205-debv3-l-prompt/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(8)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run one model over the inference loader and save its span probabilities.

    Writes ``exp205_model_preds_<model_id>.csv`` with one row per real span
    (rows with a negative span uid are dropped) and the columns
    ``discourse_id``, ``Ineffective``, ``Adequate``, ``Effective``.
    Returns nothing.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch) # (b, nd, 3)
            probs = F.softmax(logits, dim=-1)
        prob_batches.append(probs)
        uid_batches.append(batch["uids"])

    # flatten batches -> examples -> spans
    flat_preds = [
        span_probs
        for probs in prob_batches
        for example_probs in probs.to('cpu').detach().numpy().tolist()
        for span_probs in example_probs
    ]
    flat_uids = [
        span_uid
        for batch_uids in uid_batches
        for example_uids in batch_uids.to('cpu').detach().numpy().tolist()
        for span_uid in example_uids
    ]

    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flat_uids # SORTED_DISCOURSE_IDS
    is_real_span = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_span].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)
    preds_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    preds_df.to_csv(f"exp205_model_preds_{model_id}.csv", index=False)
    
# exp205: load every fold checkpoint in turn and write its predictions to csv,
# then release the last model and cached GPU memory.
if use_exp205:

    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")
        model = FeedbackModel(config)
        ckpt = torch.load(checkpoint)
        # the checkpoint stores its validation loss next to the weights
        print(f"validation score for fold {model_id} = {ckpt['loss']}")
        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

    del model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# exp205: average the per-fold prediction files row by row, then average again
# per discourse_id (a discourse can occur in more than one window).
if use_exp205:
    import glob
    import pandas as pd

    csvs = glob.glob("exp205_model_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # sort so that the rows of all files line up before summing
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            # the first file provides the id order and starts the running sum
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds

    preds = preds / len(csvs)

    exp205_df = pd.DataFrame()
    exp205_df["discourse_id"] = idx
    exp205_df["Ineffective"]  = preds[:, 0]
    exp205_df["Adequate"]     = preds[:, 1]
    exp205_df["Effective"]    = preds[:, 2]

    exp205_df = exp205_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
if use_exp205:
    # A bare expression inside an `if` body is not echoed by the notebook,
    # so show the preview explicitly.
    display(exp205_df.head())

# Exp209 - 10 fold debv3-l

In [None]:
# exp209: best checkpoint of each of the 10 folds
checkpoints = [
    f"../input/tk-fpe-models-v6/exp209-debv3-l-prompt/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(10)
]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention

from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler, StableDropout, DebertaV2Attention


#-------- Model ------------------------------------------------------------------#
class FeedbackModel(nn.Module):
    """Span classifier for the feedback prize effectiveness task.

    Pipeline: base transformer -> bidirectional LSTM over the tokens (two
    directions of hidden_size // 2 each) -> mean pooling of the tokens
    strictly between each span's head and tail marker -> layer norm ->
    attention across the spans of an essay -> dropout -> linear classifier.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # base transformer
        model_path = self.config["base_model_path"]
        backbone_config = AutoConfig.from_pretrained(model_path)
        backbone_config.update({"add_pooling_layer": False, "max_position_embeddings": 1024})
        self.base_model = AutoModel.from_pretrained(model_path, config=backbone_config)

        # dropouts
        self.dropout = StableDropout(self.config["dropout"])

        # attention across spans, configured without relative attention
        span_attention_config = deepcopy(self.base_model.config)
        span_attention_config.update({"relative_attention": False})
        self.fpe_span_attention = DebertaV2Attention(span_attention_config)

        # classification
        hidden_size = self.base_model.config.hidden_size
        feature_size = hidden_size
        self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)

        # LSTM head over the token representations
        self.fpe_lstm_layer = nn.LSTM(
            input_size=feature_size,
            hidden_size=hidden_size//2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.num_labels = self.config["num_labels"]
        self.classifier = nn.Linear(feature_size, self.config["num_labels"])

    def forward(self, input_ids, attention_mask, span_head_idxs, span_tail_idxs, span_attention_mask, **kwargs):
        """Return per-span logits; extra batch fields arrive in ``kwargs`` and are ignored."""
        n_examples = input_ids.shape[0]  # batch size
        token_states = self.base_model(input_ids, attention_mask=attention_mask)[0]

        self.fpe_lstm_layer.flatten_parameters()
        token_states = self.fpe_lstm_layer(token_states)[0]

        # mean over the tokens strictly between head and tail marker
        span_features = torch.stack([
            torch.stack([
                torch.mean(token_states[row, head+1:tail], dim=0)  # [h]
                for head, tail in zip(span_head_idxs[row], span_tail_idxs[row])
            ])  # (num_disourse, h)
            for row in range(n_examples)
        ])  # (bs, num_disourse, h)
        span_features = self.layer_norm(span_features)

        # attend to other features: pairwise mask, non-zero only where both spans are real
        key_mask = span_attention_mask.unsqueeze(1).unsqueeze(2)
        pairwise_mask = key_mask * key_mask.squeeze(-2).unsqueeze(-1)
        pairwise_mask = pairwise_mask.byte()
        attended = self.fpe_span_attention(span_features, pairwise_mask)

        attended = self.dropout(attended)
        return self.classifier(attended)

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Score every span with one model and save the class probabilities to CSV.

    The model and loader are wrapped by a fresh ``Accelerator``; softmax
    probabilities and span uids are gathered batch by batch, flattened to one
    row per span, stripped of rows with a negative uid (presumably the
    collator's padding value), mapped to discourse ids through the global
    ``idx2discourse`` and written to ``exp209_model_preds_{model_id}.csv``.

    :param model: model to evaluate; called as ``model(**batch)``
    :param infer_dl: inference dataloader whose batches carry a ``"uids"`` field
    :param model_id: index used in the output file name
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()
    progress = tqdm(infer_dl, total=len(infer_dl))

    prob_batches, uid_batches = [], []
    for batch in progress:
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            probs = F.softmax(logits, dim=-1)
            uids = batch["uids"]
        prob_batches.append(probs)
        uid_batches.append(uids)

    # nested lists: batches -> examples -> spans; flatten down to spans
    nested_probs = [t.to('cpu').detach().numpy().tolist() for t in prob_batches]
    flat_preds = list(chain.from_iterable(chain.from_iterable(nested_probs)))

    nested_uids = [t.to('cpu').detach().numpy().tolist() for t in uid_batches]
    flat_uids = list(chain.from_iterable(chain.from_iterable(nested_uids)))

    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flat_uids
    is_real_span = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_span].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp209_model_preds_{model_id}.csv", index=False)
    

# Run every checkpoint through inference_fn; each call writes its own prediction CSV.
# A fresh FeedbackModel is built per checkpoint and loaded from the saved state dict.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    ckpt = torch.load(checkpoint)
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
# Drop the last model and release cached GPU memory before the next stage.
# NOTE(review): `del model` raises NameError if `checkpoints` is empty and no
# earlier cell defined `model`.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
import glob
import pandas as pd

# Average the per-checkpoint prediction files written by inference_fn.
csvs = glob.glob("exp209_model_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # Sort so that rows from different files line up by position.
    # NOTE(review): the element-wise sum below assumes every file holds the
    # same discourse ids with the same multiplicity -- confirm.
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

# Turn the running sum into a mean over files.
# NOTE(review): fails if no CSV matched, since `preds` is then still the empty list.
preds = preds / len(csvs)

exp209_df = pd.DataFrame()
exp209_df["discourse_id"] = idx
exp209_df["Ineffective"]  = preds[:, 0]
exp209_df["Adequate"]     = preds[:, 1]
exp209_df["Effective"]    = preds[:, 2]

# Rows sharing a discourse_id are collapsed to their mean.
exp209_df = exp209_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# Preview the averaged exp209 predictions.
exp209_df.head()

#### All data model

In [None]:
if use_full_data_models:

    # Checkpoint(s) of the "all data" variant of this experiment.
    checkpoints = [
        "../input/exp209a-debv3-l-prompt-all/fpe_model_fold_0_best.pth.tar",
    ]

    def inference_fn(model, infer_dl, model_id):
        """Score every span with one model and save the class probabilities to CSV.

        Same flow as the fold-model variant, except that the result goes to
        ``exp209a_preds_{model_id}.csv``: collect softmax outputs and uids per
        batch, flatten to one row per span, drop rows with a negative uid and
        translate the remaining uids through the global ``idx2discourse``.

        :param model: model to evaluate; called as ``model(**batch)``
        :param infer_dl: inference dataloader whose batches carry ``"uids"``
        :param model_id: index used in the output file name
        """
        accelerator = Accelerator()
        model, infer_dl = accelerator.prepare(model, infer_dl)

        model.eval()
        progress = tqdm(infer_dl, total=len(infer_dl))

        prob_batches, uid_batches = [], []
        for batch in progress:
            with torch.no_grad():
                logits = model(**batch)  # (b, nd, 3)
                probs = F.softmax(logits, dim=-1)
                uids = batch["uids"]
            prob_batches.append(probs)
            uid_batches.append(uids)

        # nested lists: batches -> examples -> spans; flatten down to spans
        nested_probs = [t.to('cpu').detach().numpy().tolist() for t in prob_batches]
        flat_preds = list(chain.from_iterable(chain.from_iterable(nested_probs)))

        nested_uids = [t.to('cpu').detach().numpy().tolist() for t in uid_batches]
        flat_uids = list(chain.from_iterable(chain.from_iterable(nested_uids)))

        preds_df = pd.DataFrame(flat_preds)
        preds_df.columns = ["Ineffective", "Adequate", "Effective"]
        preds_df["span_uid"] = flat_uids
        is_real_span = preds_df["span_uid"] >= 0
        preds_df = preds_df[is_real_span].copy()
        preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

        output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
        preds_df = preds_df[output_columns].copy()
        preds_df.to_csv(f"exp209a_preds_{model_id}.csv", index=False)

    from copy import deepcopy
    for model_id, checkpoint in enumerate(checkpoints):
        print(f"infering from {checkpoint}")

        # Work on a copy so the shared config is never modified.
        new_config = deepcopy(config)
        if "10a" in checkpoint:
            new_config["num_labels"] = 5

        model = FeedbackModel(new_config)
        model.half()

        # Checkpoints with "swa" in their path go through the dedicated loader.
        if "swa" not in checkpoint:
            ckpt = torch.load(checkpoint)
            print(f"validation score for fold {model_id} = {ckpt['loss']}")
        else:
            ckpt = process_swa_checkpoint(checkpoint)

        model.load_state_dict(ckpt['state_dict'])
        inference_fn(model, infer_dl, model_id)

In [None]:
# Average the exp209a prediction files (only when the full-data models are enabled).
if use_full_data_models:

    import glob
    import pandas as pd

    csvs = glob.glob("exp209a_preds_*.csv")

    idx = []
    preds = []


    for csv_idx, csv in enumerate(csvs):

        print("=="*40)
        print(f"preds in {csv}")
        df = pd.read_csv(csv)
        # Sort so that rows from different files line up by position.
        # NOTE(review): the element-wise sum below assumes every file holds
        # the same discourse ids with the same multiplicity -- confirm.
        df = df.sort_values(by=["discourse_id"])
        print(df.head(10))
        print("=="*40)

        temp_preds = df.drop(["discourse_id"], axis=1).values
        if csv_idx == 0:
            idx = list(df["discourse_id"])
            preds = temp_preds
        else:
            preds += temp_preds

    # Turn the running sum into a mean over files.
    preds = preds / len(csvs)

    exp209a_df = pd.DataFrame()
    exp209a_df["discourse_id"] = idx
    exp209a_df["Ineffective"] = preds[:, 0]
    exp209a_df["Adequate"] = preds[:, 1]
    exp209a_df["Effective"] = preds[:, 2]

    # Rows sharing a discourse_id are collapsed to their mean.
    exp209a_df = exp209a_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()

In [None]:
# Preview the averaged exp209a predictions; the frame only exists when the flag is on.
if use_full_data_models:
    print(exp209a_df.head())

In [None]:
# Best-effort cleanup: `model` may not exist at this point, in which case the
# error is only printed and the notebook keeps running.
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except Exception as e:
    print(e)

# LUKE

In [None]:
# Settings for the LUKE stage, written as JSON text and parsed into a dict.
# This rebinds the notebook-level `config` name for the cells that follow.
config = """{
    "debug": false,

    "base_model_path": "../input/luke-span-mlm",
    "model_dir": "./outputs",

    "max_length": 512,
    "max_position_embeddings": 512,
    "stride": 128,
    "max_mention_length": 400,
    "max_entity_length": 24,
    "num_labels": 3,
    "dropout": 0.1,
    "infer_bs": 16
}
"""
config = json.loads(config)

In [None]:
import os
import pdb
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer, LukeTokenizer


#--------------- Tokenizer ---------------------------------------------#
def tokenizer_test(tokenizer):
    """Print the tokenizer size and how it splits the marker tokens.

    Two probe sentences (one with the start-of-essay marker, one with the
    end-of-essay marker) are tokenized and printed between two banner lines.

    :param tokenizer: object supporting ``len()`` and ``.tokenize(str)``
    """
    banner = "==" * 40
    probes = (
        'Starts: [==SOE==] [==SPAN==] [==END==]',
        'Starts: [==EOE==] [==SPAN==] [==END==]',
    )

    print(banner)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probes:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")

    print(banner)


def get_tokenizer(config):
    """Return the ``AutoTokenizer`` stored at ``config["base_model_path"]``."""
    print("using auto tokenizer")
    return AutoTokenizer.from_pretrained(config["base_model_path"])


#--------------- Additional Tokens ---------------------------------------------#

# Discourse type -> [start marker, end marker] inserted around each span by
# process_essay. Every type currently maps to the same marker pair.
TOKEN_MAP = {
    "Lead":                     ["[==SPAN==]", "[==END==]"],
    "Position":                 ["[==SPAN==]", "[==END==]"],
    "Claim":                    ["[==SPAN==]", "[==END==]"],
    "Counterclaim":             ["[==SPAN==]", "[==END==]"],
    "Rebuttal":                 ["[==SPAN==]", "[==END==]"],
    "Evidence":                 ["[==SPAN==]", "[==END==]"],
    "Concluding Statement":     ["[==SPAN==]", "[==END==]"]
}

# Markers whose token ids identify span heads.
DISCOURSE_START_TOKENS = [
    "[==SPAN==]",
]

# Markers whose token ids identify span tails.
DISCOURSE_END_TOKENS = [
    "[==END==]",
]

# All markers registered with the tokenizer (span markers plus the
# start-of-essay / end-of-essay markers added by process_essay).
NEW_TOKENS = [
    "[==SPAN==]",
    "[==END==]",
    "[==SOE==]",
    "[==EOE==]",
]

# When True, get_luke_dataset also registers NEW_TOKENS on the LUKE tokenizer.
ADD_NEW_TOKENS_IN_LUKE = False
#--------------- Data Processing ---------------------------------------------#


def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """Locate ``substring`` in ``text``, falling back to ever shorter prefixes.

    If the full substring is not found, it is repeatedly truncated to
    ``int(len * fraction)`` characters (always at least one character shorter)
    and searched again until a prefix matches or the prefix would drop below
    ``min_length``.

    The search is iterative: the previous recursive form used one stack frame
    per truncation, which hit the interpreter's recursion limit for missing
    substrings longer than roughly a thousand characters, and never terminated
    when ``fraction`` did not actually shrink the string.

    :param text: text to search in
    :param substring: text to look for
    :param min_length: shortest prefix that is still searched for
    :param fraction: length ratio applied at every truncation step
    :return: ``[start, end]`` of the matched prefix in ``text``, or ``[-1, 0]``
        when nothing matched
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        # Shrink by the requested fraction, but by one character at the very
        # least so the loop is guaranteed to make progress.
        cut = min(int(len(candidate) * fraction), len(candidate) - 1)
        shorter = candidate[:cut]
        if len(shorter) < min_length:
            return [-1, 0]
        candidate = shorter


def build_span_map(discourse_list, essay_text):
    """Map every discourse text to the character spans it occupies in the essay.

    Discourses are located in list order. For each one, the first exact match
    that ends after the previously matched span is taken, so repeated texts
    resolve to successive occurrences. A discourse with no such match falls
    back to ``relaxed_search``.

    Each entry of ``discourse_list`` contributes exactly one span to its key,
    in list order. The earlier version applied the fallback only to keys with
    no span at all, so a repeated text could end up with fewer spans than
    occurrences and break the consumer that pops one span per occurrence. It
    also trapped every exception in ``pdb.set_trace()``; errors now propagate.

    :param discourse_list: discourse texts of one essay, in reading order
    :param essay_text: full essay text
    :return: dict ``discourse text -> list of (start, end)`` spans; fallback
        entries are ``[start, end]`` lists as returned by ``relaxed_search``
    """
    reading_head = 0
    to_return = dict()

    for cur_discourse in discourse_list:
        spans = to_return.setdefault(cur_discourse, [])

        located = None
        pattern = re.escape("{}".format(cur_discourse))
        for match in re.finditer(pattern, essay_text):
            # skip occurrences that lie entirely inside the already-consumed text
            if match.end() > reading_head:
                located = match.span()
                break

        if located is None:
            print("resorting to relaxed search...")
            spans.append(relaxed_search(essay_text, cur_discourse))
        else:
            spans.append(located)
            reading_head = located[1]

    return to_return


def get_substring_span(texts, mapping):
    """Return one span per text, consuming the spans stored in ``mapping``.

    For every text, the first remaining span of ``mapping[text]`` is removed
    and returned, so repeated texts receive successive spans. ``mapping`` is
    modified in place.

    :param texts: discourse texts in reading order
    :param mapping: dict ``text -> list of spans``
    :return: list of spans aligned with ``texts``
    """
    return [mapping[text].pop(0) for text in texts]


def process_essay(essay_id, essay_text, anno_df):
    """Wrap every annotated span of one essay in its marker tokens.

    Spans are handled in order of ``discourse_start``. Each insertion joins
    five pieces with single spaces, so the text grows by both marker lengths
    plus four characters; that growth is carried forward as an offset for the
    following spans. The result is enclosed in the essay start/end markers.

    :param essay_id: id of the essay being processed
    :param essay_text: raw essay text
    :param anno_df: annotations with ``essay_id``, ``discourse_start``,
        ``discourse_end`` and ``discourse_type`` columns
    :return: essay text with all markers inserted
    """
    essay_rows = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_rows = essay_rows.sort_values(by="discourse_start")

    shift = 0  # characters added so far by earlier insertions
    for _, row in essay_rows.iterrows():
        start = int(row.discourse_start) + shift
        end = int(row.discourse_end) + shift
        open_tok, close_tok = TOKEN_MAP[row.discourse_type]

        pieces = [essay_text[:start], open_tok, essay_text[start:end], close_tok, essay_text[end:]]
        essay_text = " ".join(pieces)
        shift += len(open_tok) + len(close_tok) + 4

    return "[==SOE==]" + essay_text + "[==EOE==]"


def process_input_df(anno_df, notes_df):
    """Build one row per essay with marker-annotated text and grouped span fields.

    Steps: strip the discourse texts, drop essays that have annotations but no
    text, locate every discourse inside its essay, insert the marker tokens at
    the span boundaries and finally group the per-discourse columns into lists
    per essay. Both inputs are deep-copied and left untouched.

    :param anno_df: discourse annotations with ``discourse_id``, ``essay_id``,
        ``discourse_text``, ``discourse_type`` and ``uid`` columns;
        ``discourse_effectiveness`` is carried along when present
    :type anno_df: pd.DataFrame
    :param notes_df: essays with ``essay_id`` and ``essay_text`` columns
    :type notes_df: pd.DataFrame
    :return: dataframe with ``essay_id``, ``uids``, ``discourse_type``
        (plus ``discourse_effectiveness`` when available) as per-essay lists
        and the processed ``essay_text``
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if "discourse_effectiveness" in anno_df.columns:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")

    # essays that have annotations but no text cannot be processed
    print("--"*40)
    print("Warning! the following essay_ids are removed during processing...")
    remove_essay_ids = tmp_df[tmp_df["essay_text"].isna()].essay_id.unique()
    print(remove_essay_ids)
    tmp_df = tmp_df[~tmp_df["essay_id"].isin(remove_essay_ids)].copy()
    anno_df = anno_df[~anno_df["essay_id"].isin(remove_essay_ids)].copy()
    notes_df = notes_df[~notes_df["essay_id"].isin(remove_essay_ids)].copy()
    print("--"*40)

    # Rows are accessed by column label: positional integer access on a
    # label-indexed Series is deprecated in recent pandas releases.
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda x: build_span_map(x["discourse_text"], x["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda x: get_substring_span(x["discourse_text"], x["span_map"]), axis=1)

    # flatten the per-essay lists back to one row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text"]].apply(
        lambda x: process_essay(x["essay_id"], x["essay_text"], anno_df), axis=1
    )
    print("=="*40)

    # the character offsets refer to the raw text and are stale after insertion
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids"})

    return grouped_df


#--------------- Dataset w/o Truncation ----------------------------------------------#


class AuxFeedbackDataset:
    """Dataset builder for the feedback prize effectiveness task (no truncation).

    Tokenizes the marker-annotated essays in full and records, per essay, the
    token and character positions of every span head and tail marker.
    """

    def __init__(self, config):
        """Store the config, define the label/type vocabularies and load the tokenizer.

        :param config: dict; ``base_model_path`` is read when loading the
            tokenizer and ``model_dir`` when saving the sanity-check sample
        """
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
            "Mask": -1,
        }

        self.discourse_type2id = {
            "Lead": 0,
            "Position": 1,
            "Claim": 2,
            "Counterclaim": 3,
            "Rebuttal": 4,
            "Evidence": 5,
            "Concluding Statement": 6,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer named in the config and register the marker tokens.

        Also caches the token ids of the span start and end markers, which the
        span-processing step uses to find span boundaries.
        """
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)

        print("adding new tokens...")
        tokens_to_add = [AddedToken(this_tok, lstrip=False, rstrip=False) for this_tok in NEW_TOKENS]
        self.tokenizer.add_tokens(tokens_to_add)
        print(f"tokenizer len: {len(self.tokenizer)}")

        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))

        tokenizer_test(self.tokenizer)

    def tokenize_function(self, examples):
        """Tokenize the essay texts without padding, truncation or special tokens.

        :param examples: batch with an ``essay_text`` field
        :return: tokenizer output including the offset mapping
        """
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find the token and character positions of every span marker.

        :param examples: batch with ``input_ids`` and ``offset_mapping``
        :return: dict with the head/tail token positions, the character start
            of every head marker and the character end of every tail marker
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_token_ids]
            example_span_tail_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_end_ids]

            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in example_span_head_idxs]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in example_span_tail_idxs]

            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Translate the effectiveness labels of every essay to integer ids.

        :param examples: batch with a ``discourse_effectiveness`` field
        :return: dict with the ``labels`` lists
        """
        labels = [
            [self.label2id[l] for l in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_discourse_type_ids(self, examples):
        """Translate the discourse types of every essay to integer ids.

        :param examples: batch with a ``discourse_type`` field
        :return: dict with the ``discourse_type_ids`` lists
        """
        discourse_type_ids = []
        for example_discourse_types in examples["discourse_type"]:
            discourse_type_ids.append([self.discourse_type2id[dt] for dt in example_discourse_types])
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the token count of every example as ``input_length``."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every span head has exactly one label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, mode='train'):
        """main api for creating the Feedback dataset

        A sample of up to 16 processed rows is written to
        ``{model_dir}/{mode}_df_processed.csv`` for manual inspection. Essays
        whose head and tail marker counts differ are filtered out.

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the processed dataframe and the created dataset
        :rtype: tuple(pd.DataFrame, Dataset)
        """
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        print(task_dataset)
        # todo check edge cases
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))  # no need to run on empty set
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # The pandas index column only exists for non-default indexes; check
        # for it explicitly instead of swallowing whatever removal might raise.
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return df, task_dataset

#--------------- Dataset w truncation ---------------------------------------------#


def get_fast_dataset(config, df, essay_df, mode="train"):
    """Function to get fast approach dataset with truncation & sliding window

    The essays are first processed without truncation by ``AuxFeedbackDataset``
    (kept as ``original_dataset``), then re-tokenized into overlapping windows
    of at most ``config["max_length"]`` tokens with ``config["stride"]`` tokens
    of overlap. For every window the span markers are located again and the
    per-span fields (uids, discourse type ids, labels) are looked up from the
    original example through the character start of each head marker.

    :param config: dict; ``max_length`` and ``stride`` are read here, the rest
        is consumed by ``AuxFeedbackDataset``
    :param df: annotation dataframe
    :param essay_df: dataframe with essay texts
    :param mode: labels are recomputed unless this is ``"infer"``
    :return: dict with the windowed ``dataset``, the untruncated
        ``original_dataset`` and the ``tokenizer``
    """
    dataset_creator = AuxFeedbackDataset(config)
    _, task_dataset = dataset_creator.get_dataset(df, essay_df, mode=mode)

    # untruncated copy, used below to look up per-span fields for each window
    original_dataset = deepcopy(task_dataset)
    tokenizer = dataset_creator.tokenizer
    START_IDS = dataset_creator.discourse_token_ids
    END_IDS = dataset_creator.discourse_end_ids

    def tokenize_with_truncation(examples):
        """Tokenize into overlapping windows; overflowing windows are returned as extra rows."""
        tz = tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=True,
            add_special_tokens=False,
            return_offsets_mapping=True,
            max_length=config["max_length"],
            stride=config["stride"],
            return_overflowing_tokens=True,
        )
        return tz

    def process_span(examples):
        """Locate span heads and tails inside every window and repair truncation artefacts."""
        span_head_idxs, span_tail_idxs = [], []
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        buffer = 25  # do not include a head if it is within buffer distance away from last token

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            # ------------------- Span Heads -----------------------------------------#
            if len(example_input_ids) < config["max_length"]:  # no truncation
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in START_IDS]
            else:
                # full-length window: heads too close to the end are left out
                head_candidate = [pos for pos, this_id in enumerate(example_input_ids) if (
                    (this_id in START_IDS) & (pos <= config["max_length"]-buffer))]

            n_heads = len(head_candidate)

            # ------------------- Span Tails -----------------------------------------#
            tail_candidate = [pos for pos, this_id in enumerate(example_input_ids) if this_id in END_IDS]

            # ------------------- Edge Cases -----------------------------------------#
            # 1. A tail occurs before the first head in the sequence due to truncation
            if (len(tail_candidate) > 0) & (len(head_candidate) > 0):
                if tail_candidate[0] < head_candidate[0]:  # truncation effect
                    # print(f"check: heads: {head_candidate}, tails {tail_candidate}")
                    tail_candidate = tail_candidate[1:]  # shift by one

            # 2. Tail got chopped off due to truncation but the corresponding head is still there
            if len(tail_candidate) < n_heads:
                assert len(tail_candidate) + 1 == n_heads
                assert len(example_input_ids) == config["max_length"]  # should only happen if input text is truncated
                # NOTE(review): the tokenizer call above uses add_special_tokens=False,
                # so the [SEP] mentioned below is not added by this function -- confirm
                # that max_length-2 is still the intended substitute tail position.
                tail_candidate.append(config["max_length"]-2)  # the token before [SEP] token

            # 3. Additional tails remain in the buffer region
            if len(tail_candidate) > len(head_candidate):
                tail_candidate = tail_candidate[:len(head_candidate)]

            # ------------------- Create the fields ------------------------------------#
            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in head_candidate]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in tail_candidate]

            span_head_idxs.append(head_candidate)
            span_tail_idxs.append(tail_candidate)
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def restore_essay_text(examples):
        """Re-attach the full essay text of the source example to every window."""
        essay_text = []

        for example_overflow_to_sample_mapping in examples["overflow_to_sample_mapping"]:

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_essay_text = original_example["essay_text"]
            essay_text.append(original_example_essay_text)
        return {"essay_text": essay_text}

    def enforce_alignment(examples):
        """Recover the uid of every span in a window via its head marker's character start."""
        uids = []

        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_uids = original_example["uids"]
            char2uid = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_uids)}
            current_example_uids = [char2uid[char_idx] for char_idx in example_span_head_char_start_idxs]
            uids.append(current_example_uids)
        return {"uids": uids}

    def recompute_labels(examples):
        """Recover the label of every span in a window via its head marker's character start."""
        labels = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_labels = original_example["labels"]
            char2label = {k: v for k, v in zip(original_example_span_head_char_start_idxs, original_example_labels)}
            current_example_labels = [char2label[char_idx] for char_idx in example_span_head_char_start_idxs]
            labels.append(current_example_labels)
        return {"labels": labels}

    def recompute_discourse_type_ids(examples):
        """Recover the discourse type id of every span in a window via its head marker's character start."""
        discourse_type_ids = []
        for example_span_head_char_start_idxs, example_overflow_to_sample_mapping in zip(
                examples["span_head_char_start_idxs"], examples["overflow_to_sample_mapping"]):

            original_example = original_dataset[example_overflow_to_sample_mapping]
            original_example_span_head_char_start_idxs = original_example["span_head_char_start_idxs"]
            original_example_discourse_type_ids = original_example["discourse_type_ids"]
            char2discourse_id = {k: v for k, v in zip(
                original_example_span_head_char_start_idxs, original_example_discourse_type_ids)}
            current_example_discourse_type_ids = [char2discourse_id[char_idx]
                                                  for char_idx in example_span_head_char_start_idxs]
            discourse_type_ids.append(current_example_discourse_type_ids)
        return {"discourse_type_ids": discourse_type_ids}

    def update_head_tail_char_idx(examples):
        """Crop the essay text to the window and re-base the span character offsets onto it."""
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []

        new_texts = []

        for example_span_head_char_start_idxs, example_span_tail_char_end_idxs, example_offset_mapping, example_essay_text in zip(
                examples["span_head_char_start_idxs"], examples["span_tail_char_end_idxs"], examples["offset_mapping"], examples["essay_text"]):

            # character range covered by the first and last token of the window
            offset_start = example_offset_mapping[0][0]
            offset_end = example_offset_mapping[-1][1]

            example_essay_text = example_essay_text[offset_start:offset_end]
            new_texts.append(example_essay_text)

            example_span_head_char_start_idxs = [pos - offset_start for pos in example_span_head_char_start_idxs]
            example_span_tail_char_end_idxs = [pos - offset_start for pos in example_span_tail_char_end_idxs]
            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)
        return {"span_head_char_start_idxs": span_head_char_start_idxs, "span_tail_char_end_idxs": span_tail_char_end_idxs, "essay_text": new_texts}

    def compute_input_length(examples):
        """Return the token count of every window as ``input_length``."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(examples):
        """Assert that heads and tails pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1, f"head idxs: {head_idxs}, tail idxs {tail_idxs}"

    # All essays go through the tokenizer as a single batch and the old columns are dropped.
    task_dataset = task_dataset.map(
        tokenize_with_truncation,
        batched=True,
        remove_columns=task_dataset.column_names,
        batch_size=len(task_dataset)
    )

    # The lookups below key on the character offsets produced by process_span,
    # so they must run before update_head_tail_char_idx re-bases those offsets.
    task_dataset = task_dataset.map(process_span, batched=True)
    task_dataset = task_dataset.map(enforce_alignment, batched=True)
    task_dataset = task_dataset.map(recompute_discourse_type_ids, batched=True)
    task_dataset = task_dataset.map(sanity_check_head_tail, batched=True)

    task_dataset = task_dataset.map(restore_essay_text, batched=True)

    # no need to run on empty set
    task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) != 0)
    task_dataset = task_dataset.map(compute_input_length, batched=True)

    if mode != "infer":
        task_dataset = task_dataset.map(recompute_labels, batched=True)

    task_dataset = task_dataset.map(update_head_tail_char_idx, batched=True)

    to_return = dict()
    to_return["dataset"] = task_dataset
    to_return["original_dataset"] = original_dataset
    to_return["tokenizer"] = tokenizer
    return to_return


def get_luke_dataset(config, df, essay_df, mode="train"):
    """Build the LUKE inputs: windowed essays plus entity spans per discourse.

    Stage one windows the essays with a fast RoBERTa tokenizer through
    ``get_fast_dataset``, using a ``max_length`` two tokens shorter than
    configured (presumably room for the special tokens LUKE adds in stage two
    -- confirm). Stage two turns the character offsets of every span into
    entity spans and tokenizes each window with the LUKE tokenizer.

    :param config: dict; reads ``base_model_path``, ``max_length``,
        ``max_mention_length`` and ``max_entity_length``
    :param df: annotation dataframe
    :param essay_df: dataframe with essay texts
    :param mode: forwarded to ``get_fast_dataset``
    :return: dict with the final ``dataset`` and the LUKE ``tokenizer``
    """
    reserved_tokens = 2

    window_config = deepcopy(config)
    window_config["base_model_path"] = "../input/roberta-base"  # Fast Tokenizer
    window_config["max_length"] = config["max_length"] - reserved_tokens

    task_dataset = get_fast_dataset(window_config, df, essay_df, mode)["dataset"]

    def get_entity_spans(examples):
        """Pair head start and tail end offsets into (start, end) entity spans."""
        heads_and_tails = zip(examples["span_head_char_start_idxs"], examples["span_tail_char_end_idxs"])
        return {"entity_spans": [list(zip(starts, ends)) for starts, ends in heads_and_tails]}

    # prepare luke specific inputs
    task_dataset = task_dataset.map(get_entity_spans, batched=True)

    tokenizer = LukeTokenizer.from_pretrained(
        config["base_model_path"],
        task="entity_span_classification",
        max_mention_length=config["max_mention_length"],
    )

    # optionally register the marker tokens with the LUKE tokenizer too
    if ADD_NEW_TOKENS_IN_LUKE:
        print("adding new tokens...")
        tokenizer.add_tokens([AddedToken(tok, lstrip=False, rstrip=False) for tok in NEW_TOKENS])

    tokenizer_test(tokenizer)

    def tokenize_with_entity_spans(example):
        """Tokenize one window together with its entity spans."""
        spans = [tuple(span) for span in example["entity_spans"]]
        return tokenizer(
            example["essay_text"],
            entity_spans=spans,
            max_entity_length=config["max_entity_length"],
            padding=False,
            truncation=False,
            add_special_tokens=True,
        )

    task_dataset = task_dataset.map(tokenize_with_entity_spans, batched=False)

    return {"dataset": task_dataset, "tokenizer": tokenizer}

In [None]:
%%time
# The dataset builder writes a sanity-check CSV into model_dir, so it must exist.
os.makedirs(config["model_dir"], exist_ok=True)

print("creating the inference datasets...")
# mode="infer": no labels are generated for the test data
infer_ds_dict = get_luke_dataset(config, test_df, essay_df, mode="infer")
tokenizer = infer_ds_dict["tokenizer"]
infer_dataset = infer_ds_dict["dataset"]
print(infer_dataset)

In [None]:
# Record the tokenizer size in the config.
config["len_tokenizer"] = len(tokenizer)

# Sorting by length groups similar-sized windows into the same batch.
infer_dataset = infer_dataset.sort("input_length")

# Expose only the columns listed here when the dataset is iterated.
infer_dataset.set_format(
    type=None,
    columns=["input_ids", "attention_mask", "entity_ids", "entity_position_ids", "discourse_type_ids",
             "entity_attention_mask", "entity_start_positions", "entity_end_positions", "uids"]
)
#

In [None]:
from copy import deepcopy
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F
from transformers import DataCollatorWithPadding


@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Collator for the LUKE entity-span batches.

    Token/entity inputs are padded by ``tokenizer.pad``; the per-entity lists
    ("uids", "discourse_type_ids" and, when present, "labels") are right-padded
    with -1 to the largest entity count in the batch. Every value of the
    resulting batch is returned as a ``torch.int64`` tensor.
    """

    # NOTE(review): these assignments carry no annotations, so @dataclass does not
    # register them as new fields - confirm they take effect as intended.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = None
    return_tensors = "pt"

    def __call__(self, features):
        # pull out the ragged per-entity columns before the tokenizer pads the rest
        uids = [example["uids"] for example in features]
        discourse_type_ids = [example["discourse_type_ids"] for example in features]
        if "labels" in features[0].keys():
            labels = [example["labels"] for example in features]
        else:
            labels = None

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        # widest entity list in this batch
        b_max = max(len(ex_uids) for ex_uids in uids)

        def right_pad(rows):
            # extend every row with -1 up to b_max entries
            return [row + [-1] * (b_max - len(row)) for row in rows]

        batch["uids"] = right_pad(uids)
        batch["discourse_type_ids"] = right_pad(discourse_type_ids)
        if labels is not None:
            batch["labels"] = right_pad(labels)

        tensor_batch = {}
        for key, values in batch.items():
            tensor_batch[key] = torch.tensor(values, dtype=torch.int64)
        return tensor_batch

In [None]:
# Collator that pads each batch with the LUKE tokenizer.
data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

# Inference loader: no shuffling, so the length-sorted dataset order is kept.
infer_dl = DataLoader(
    infer_dataset,
    batch_size=config["infer_bs"],
    shuffle=False,
    collate_fn=data_collector
)

In [None]:
import pdb
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers import AutoConfig, AutoModel, BertConfig
from transformers.models.bert.modeling_bert import BertAttention, BertEncoder


#-------- Model ------------------------------------------------------------------#

class FeedbackModel(nn.Module):
    """
    Span-level effectiveness classifier built on the configured base transformer.

    For every entity span the classifier sees the BiLSTM state at the span start,
    the BiLSTM state at the span end and the base model's entity representation,
    concatenated along the feature axis. Logits are the average of five classifier
    passes, one per dropout rate (0.1 ... 0.5).
    """

    def __init__(self, config):
        print("=="*40)
        print("initializing the feedback model...")

        super().__init__()
        self.config = config
        model_path = self.config["base_model_path"]

        # base transformer, with the position-embedding table sized from the config (+2)
        base_config = AutoConfig.from_pretrained(model_path)
        base_config.update({"max_position_embeddings": config["max_position_embeddings"] + 2})
        self.base_model = AutoModel.from_pretrained(model_path, config=base_config)

        # match the embedding matrix to the tokenizer length stored in the config
        print("resizing model embeddings...")
        print(f"tokenizer length = {config['len_tokenizer']}")
        self.base_model.resize_token_embeddings(config["len_tokenizer"])

        self.num_labels = self.config["num_labels"]

        # bidirectional LSTM head; two directions of hidden_size//2 each
        hidden_size = self.base_model.config.hidden_size
        self.fpe_lstm_layer = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # classification heads over [start state ; end state ; entity state]
        feature_size = 3 * hidden_size
        self.classifier = nn.Linear(feature_size, self.num_labels)
        self.discourse_classifier = nn.Linear(feature_size, 7)  # 7 discourse elements

        # dropout family: one configurable rate plus dropout1..dropout5
        self.dropout = nn.Dropout(self.config["dropout"])
        for suffix, rate in enumerate((0.1, 0.2, 0.3, 0.4, 0.5), start=1):
            setattr(self, f"dropout{suffix}", nn.Dropout(rate))

        self.layer_norm = LayerNorm(feature_size, self.base_model.config.layer_norm_eps)

    def forward(
        self,
        input_ids,
        attention_mask,
        entity_ids,
        entity_attention_mask,
        entity_position_ids,
        entity_start_positions,
        entity_end_positions,
        discourse_type_ids,
        **kwargs
    ):
        # contextual token and entity representations from the base transformer
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_position_ids=entity_position_ids,
        )
        entity_states = outputs.entity_last_hidden_state

        # token states are passed through the LSTM; entity states are used as-is
        self.fpe_lstm_layer.flatten_parameters()
        token_states, _ = self.fpe_lstm_layer(outputs.last_hidden_state)

        width = outputs.last_hidden_state.size(-1)

        def states_at(positions):
            # pick the token state at each given position along the sequence axis
            index = positions.unsqueeze(-1).expand(-1, -1, width)
            return torch.gather(token_states, -2, index)

        start_states = states_at(entity_start_positions)
        end_states = states_at(entity_end_positions)
        feature_vector = torch.cat([start_states, end_states, entity_states], dim=2)

        # multi-sample dropout: same classifier applied to five dropout views
        dropout_layers = (self.dropout1, self.dropout2, self.dropout3, self.dropout4, self.dropout5)
        dropped_views = [layer(feature_vector) for layer in dropout_layers]
        branch_logits = [self.classifier(view) for view in dropped_views]

        logits = branch_logits[0]
        for extra in branch_logits[1:]:
            logits = logits + extra
        return logits / 5

In [None]:
# exp17 LUKE fold checkpoints: folds 0-3 live in the part-1 input dataset, folds 4-7 in part-2.
checkpoints = [
    "../input/exp17-luke-dataset-part-1/exp-17-luke-8folds-part-1/fpe_model_fold_0_best.pth.tar",
    "../input/exp17-luke-dataset-part-1/exp-17-luke-8folds-part-1/fpe_model_fold_1_best.pth.tar",
    "../input/exp17-luke-dataset-part-1/exp-17-luke-8folds-part-1/fpe_model_fold_2_best.pth.tar",
    "../input/exp17-luke-dataset-part-1/exp-17-luke-8folds-part-1/fpe_model_fold_3_best.pth.tar",
    "../input/exp17-luke-dataset-part-2/exp-17-luke-8folds-part-2/fpe_model_fold_4_best.pth.tar",
    "../input/exp17-luke-dataset-part-2/exp-17-luke-8folds-part-2/fpe_model_fold_5_best.pth.tar",
    "../input/exp17-luke-dataset-part-2/exp-17-luke-8folds-part-2/fpe_model_fold_6_best.pth.tar",
    "../input/exp17-luke-dataset-part-2/exp-17-luke-8folds-part-2/fpe_model_fold_7_best.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """
    Run one fold's model over the inference loader and dump its predictions.

    Softmax probabilities are gathered batch by batch and flattened to one row
    per entity slot. Slots whose uid is negative (collator padding) are dropped,
    the remaining uids are translated to discourse ids via the global
    ``idx2discourse`` and the table is written to
    ``exp17_model_preds_{model_id}.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    prob_batches, uid_batches = [], []
    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            prob_batches.append(F.softmax(logits, dim=-1))
            uid_batches.append(batch["uids"])

    def flatten_slots(tensors):
        # list of (b, nd, ...) tensors -> one python item per (example, slot)
        flat = []
        for tensor in tensors:
            for example in tensor.to('cpu').detach().numpy().tolist():
                flat.extend(example)
        return flat

    preds_df = pd.DataFrame(flatten_slots(prob_batches))
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flatten_slots(uid_batches)  # SORTED_DISCOURSE_IDS

    is_real_slot = preds_df["span_uid"] >= 0
    preds_df = preds_df[is_real_slot].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_columns = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[output_columns].copy()
    preds_df.to_csv(f"exp17_model_preds_{model_id}.csv", index=False)
    

# Run every exp17 fold: build a fresh model, load its weights, write its prediction CSV.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    # NOTE(review): torch.load is called without map_location, so tensors are restored
    # to whatever device they were saved from - confirm this is intended.
    ckpt = torch.load(checkpoint)
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
    
# Drop the last model and release cached GPU memory before the next stage.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Average the per-fold exp17 prediction files into a single frame (exp17_df).
import glob
import pandas as pd

# Every file written by inference_fn for exp17 in the working directory.
csvs = glob.glob("exp17_model_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # Sort so rows line up across files; this assumes every file holds the same
    # discourse_ids (same inference dataset) - TODO confirm.
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    # probability columns only, as a numpy array
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # the first file provides the id order and starts the running sum
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

# running sum -> mean over folds
preds = preds / len(csvs)

exp17_df = pd.DataFrame()
exp17_df["discourse_id"] = idx
exp17_df["Ineffective"] = preds[:, 0]
exp17_df["Adequate"] = preds[:, 1]
exp17_df["Effective"] = preds[:, 2]

# Collapse any repeated discourse_id rows to their mean so each id appears once.
exp17_df = exp17_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()
# exp17_df.to_csv("submission.csv", index=False)

In [None]:
# preview the averaged exp17 predictions
exp17_df.head()

#### LUKE - All data trained

In [None]:
## LUKE checkpoints from the "all data" training runs (exp17f)
checkpoints = [
    "../input/exp17f-luke-all-data-models/fpe_model_fold_0.pth.tar",
    "../input/exp17f-luke-all-data-models/fpe_model_fold_1.pth.tar",
]

def inference_fn(model, infer_dl, model_id):
    """
    Predict with one all-data LUKE model and save the result as CSV.

    Per-batch softmax outputs are flattened to one row per entity slot, padded
    slots (negative uid) are removed, uids are mapped to discourse ids with the
    global ``idx2discourse`` and the frame is saved as
    ``exp17f_luke_all_preds_{model_id}.csv``. Nothing is returned.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)
    model.eval()

    collected_probs = []
    collected_uids = []
    progress = tqdm(infer_dl, total=len(infer_dl))
    for batch in progress:
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            collected_probs.append(F.softmax(logits, dim=-1))
            collected_uids.append(batch["uids"])

    def to_slot_list(tensors):
        # move to host, then unroll the batch and example levels into one flat list
        nested = [tensor.to('cpu').detach().numpy().tolist() for tensor in tensors]
        return [slot for per_batch in nested for example in per_batch for slot in example]

    flat_preds = to_slot_list(collected_probs)
    flat_uids = to_slot_list(collected_uids)

    preds_df = pd.DataFrame(flat_preds)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = flat_uids  # SORTED_DISCOURSE_IDS

    # uid -1 marks collator padding
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    ordered = ["discourse_id", "Ineffective", "Adequate", "Effective"]
    preds_df = preds_df[ordered].copy()
    preds_df.to_csv(f"exp17f_luke_all_preds_{model_id}.csv", index=False)
    

# Run every all-data LUKE checkpoint and write one prediction CSV per model.
for model_id, checkpoint in enumerate(checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackModel(config)
    ckpt = torch.load(checkpoint)
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)
    
#### Average the all-data model predictions into exp17f_df
import glob
import pandas as pd

csvs = glob.glob("exp17f_luke_all_preds_*.csv")

idx = []
preds = []


for csv_idx, csv in enumerate(csvs):
    
    print("=="*40)
    print(f"preds in {csv}")
    df = pd.read_csv(csv)
    # Sort so rows line up across files (assumes identical discourse_ids per file - TODO confirm).
    df = df.sort_values(by=["discourse_id"])
    print(df.head(10))
    print("=="*40)
    
    temp_preds = df.drop(["discourse_id"], axis=1).values
    if csv_idx == 0:
        # first file fixes the id order and seeds the running sum
        idx = list(df["discourse_id"])
        preds = temp_preds
    else:
        preds += temp_preds

# running sum -> mean over models
preds = preds / len(csvs)

exp17f_df = pd.DataFrame()
exp17f_df["discourse_id"] = idx
exp17f_df["Ineffective"] = preds[:, 0]
exp17f_df["Adequate"] = preds[:, 1]
exp17f_df["Effective"] = preds[:, 2]

# one row per discourse_id (repeated ids are averaged)
exp17f_df = exp17f_df.groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]].agg(np.mean).reset_index()
exp17f_df.head()

In [None]:
# free the last LUKE model and cached GPU memory
del model
gc.collect()
torch.cuda.empty_cache()

# Ensemble

In [None]:
# Put every prediction frame into the same discourse_id order before combining them.
# Frames behind a use_exp* flag are only touched when that experiment is enabled;
# the others are sorted unconditionally and must therefore already exist.
if use_exp1:
    exp01_df = exp01_df.sort_values(by="discourse_id")

    
if use_exp3:
    exp03_df = exp03_df.sort_values(by="discourse_id")
    
if use_exp4:
    exp04_df = exp04_df.sort_values(by="discourse_id")
    
if use_exp6:
    exp06_df = exp06_df.sort_values(by="discourse_id")

exp19_df = exp19_df.sort_values(by="discourse_id")

if use_exp8:
    # NOTE(review): reads exp_08_df (with underscore) but assigns exp08_df - confirm the source name.
    exp08_df = exp_08_df.sort_values(by="discourse_id") # todo: fix naming as per convention
    
if use_exp10:
    exp10_df = exp10_df.sort_values(by="discourse_id") 
    
if use_exp11:
    exp11_df = exp11_df.sort_values(by="discourse_id") 

if use_exp102:
    exp102_df = exp102_df.sort_values(by="discourse_id")

# NOTE(review): not guarded by use_exp16 - confirm exp16_df is always produced.
exp16_df = exp16_df.sort_values(by="discourse_id")

exp213_df = exp213_df.sort_values(by="discourse_id")
exp213a_df = exp213a_df.sort_values(by="discourse_id")

if use_exp205:
    exp205_df = exp205_df.sort_values(by="discourse_id")

# NOTE(review): not guarded by use_exp209 - confirm exp209_df is always produced.
exp209_df = exp209_df.sort_values(by="discourse_id")

if use_exp212:
    exp212_df = exp212_df.sort_values(by="discourse_id")

if use_exp214:
    exp214_df = exp214_df.sort_values(by="discourse_id")

exp17_df = exp17_df.sort_values(by="discourse_id")


# full data models
# NOTE(review): sorted regardless of use_full_data_models, although the previews of
# these frames are guarded by that flag - confirm they always exist.
exp17f_df = exp17f_df.sort_values(by="discourse_id") # luke
exp19f_df = exp19f_df.sort_values(by="discourse_id")
exp20f_df = exp20f_df.sort_values(by="discourse_id")
# exp21f_df = exp21f_df.sort_values(by="discourse_id")

exp99_rb_all_df = exp99_rb_all_df.sort_values(by="discourse_id")
exp209a_df = exp209a_df.sort_values(by="discourse_id")
exp213f_df = exp213f_df.sort_values(by="discourse_id")

In [None]:
# luke, all-data models
exp17f_df.head()

In [None]:
# preview: exp20f (full-data model)
exp20f_df.head()

In [None]:
# preview: exp99_rb_all (full-data model)
if use_full_data_models:
    print(exp99_rb_all_df.head())

In [None]:
# preview: exp209a (full-data model)
if use_full_data_models:
    print(exp209a_df.head())

In [None]:
# preview: exp213f (full-data model)
if use_full_data_models:
    print(exp213f_df.head())

In [None]:
# delv3-mlm20
if use_exp3:
    print(exp03_df.head())

In [None]:
# delv3-mlm20-data-issue-resolved
if use_exp4:
    print(exp04_df.head())

In [None]:
# delv3-mlm40-8folds
if use_exp1:
    print(exp01_df.head())

In [None]:
# delv3-mlm40-5folds-MSD
if use_exp102:
    print(exp102_df.head())

In [None]:
# dexl-mlm40-4folds
if use_exp6:
    print(exp06_df.head())

In [None]:
# dexl-mlm40-8folds
exp19_df.head()

In [None]:
# kd
if use_exp8:
    print(exp08_df.head())

In [None]:
# deb-l 8 fold mlm40 prompt+ spanfix + msd
exp213_df.head()

In [None]:
# preview: exp213a
exp213a_df.head()

In [None]:
# preview: exp209
exp209_df.head()

In [None]:
# preview: exp10 (only when enabled)
if use_exp10:
    print(exp10_df.head())

In [None]:
# preview: exp11 (only when enabled)
if use_exp11:
    print(exp11_df.head())

In [None]:
# preview: exp16
exp16_df.head()

In [None]:
# preview: exp214 (only when enabled)
if use_exp214:
    print(exp214_df.head())

# Meta Data Prep

In [None]:
# """
# exp7_oof_dexl_0.5728.csv 0.11
# exp8_oof_del_kd_0.5720.csv 0.07
# exp10_oof_delv3_0.5729.csv 0.11
# exp16_oof_delv3_10fold.csv 0.1
# exp209_oof_debv3_l_10fold_prompt.csv 0.2
# exp212_oof_longformer_l_prompt_0.5833.csv 0.11
# exp213_oof_deb_l_prompt_10folds.csv 0.24
# exp214_oof_debv2-xl_prompt.csv 0.06
# Score: 0.55788
# Wt sum: 1.0

# """

# MODEL_WEIGHTS =  [
#      0.11, # exp7 - dexl
#      0.07, # exp8 - kd
#      0.11, # exp10 - debv3-l 8 fold
#      0.10, # exp16 - debv3-l 10 fold
#      0.20, # exp209 - debv3-l
#      0.11, # exp212 - lf
#      0.24, # exp213 - deb-l
#      0.06  # exp214 - v2-xl
#      ]


# print(f"sum of weights {np.sum(MODEL_WEIGHTS)}")

# submission_df = pd.DataFrame()

# pred_dfs  = [  
#     exp19_df,
#     exp08_df,
#     exp10_df,
#     exp16_df,
#     exp209_df,
#     exp212_df,
#     exp213_df,
#     exp214_df,
# ]


# submission_df["discourse_id"] =  pred_dfs[0]["discourse_id"].values

# for model_idx, model_preds in enumerate(pred_dfs):
#     if model_idx == 0:
#         submission_df["Ineffective"]  =  MODEL_WEIGHTS[model_idx] * model_preds["Ineffective"] 
#         submission_df["Adequate"]     =  MODEL_WEIGHTS[model_idx] * model_preds["Adequate"] 
#         submission_df["Effective"]    =  MODEL_WEIGHTS[model_idx] * model_preds["Effective"]
#     else:
#         submission_df["Ineffective"]  +=  MODEL_WEIGHTS[model_idx] * model_preds["Ineffective"] 
#         submission_df["Adequate"]     +=  MODEL_WEIGHTS[model_idx] * model_preds["Adequate"] 
#         submission_df["Effective"]    +=  MODEL_WEIGHTS[model_idx] * model_preds["Effective"]

In [None]:
# submission_df.head(10)

In [None]:
# hc_df = submission_df

 # LSTM

In [None]:
# Prediction frames whose probabilities are combined into per-discourse feature vectors.
# The position in this list becomes the model index used in the column names below.
# NOTE(review): exp01_df and exp212_df are listed unconditionally, so use_exp1 and
# use_exp212 must be enabled for this cell to run.
oof_dfs  = [
        exp01_df,
        exp19_df,
        exp16_df,
        exp17_df,
        exp209_df,
        exp212_df,
        exp213_df,
        exp213a_df,
]

In [None]:
# Give each model's probability columns a unique name: model_{index}_{label}.
# rename() returns a new frame, so the original exp*_df objects keep their column names.
pred_cols = ["Ineffective", "Adequate", "Effective"]
for model_idx in range(len(oof_dfs)):
    col_map = dict()
    for col in pred_cols:
        col_map[col] = f"model_{model_idx}_{col}"
    oof_dfs[model_idx] = oof_dfs[model_idx].rename(columns=col_map)

In [None]:
# Inner-join all frames on discourse_id, carrying over only the model_* columns
# from every frame after the first.
merged_df = oof_dfs[0]

for df in oof_dfs[1:]:
    keep_cols = ["discourse_id"] + [col for col in df.columns if col.startswith("model")]
    df = df[keep_cols].copy()
    merged_df = pd.merge(merged_df, df, on="discourse_id", how='inner')
# the inner joins must not have changed the number of rows
assert merged_df.shape[0] == oof_dfs[0].shape[0]

In [None]:
merged_df.head(3).T

In [None]:
# all stacked probability columns, in merge order
feature_names = [col for col in merged_df.columns if col.startswith("model")]
feature_names[:6]

In [None]:
# discourse_id -> array of that row's model_* values
feature_map = dict(zip(merged_df["discourse_id"], merged_df[feature_names].values))

In [None]:
import os
import re
from copy import deepcopy
from itertools import chain

import pandas as pd
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

#--------------- Tokenizer ---------------------------------------------#
def get_tokenizer(config):
    """Load the tokenizer at ``config["base_model_path"]`` and print a short self-check.

    The self-check prints the tokenizer length and how two probe strings
    containing the bracketed marker tokens are tokenized.
    """
    divider = "==" * 40
    probe_texts = (
        'Starts: [SOE] [LEAD] [CLAIM] [POSITION] [COUNTER_CLAIM]',
        'Starts: [EOE] [LEAD_END] [POSITION_END] [CLAIM_END]',
    )

    print("using auto tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])

    print(divider)
    print(f"tokenizer len: {len(tokenizer)}")
    for probe in probe_texts:
        print(f"tokenizer test: {tokenizer.tokenize(probe)}")
    print(divider)

    return tokenizer


#--------------- Processing ---------------------------------------------#
# When True, the marker definitions below are replaced by the "new" variants
# further down (end markers written with a space, plus a "topic" entry).
USE_NEW_MAP = True

# discourse type -> [text inserted before the span, text inserted after the span]
TOKEN_MAP = {
    "Lead": ["Lead [LEAD]", "[LEAD_END]"],
    "Position": ["Position [POSITION]", "[POSITION_END]"],
    "Claim": ["Claim [CLAIM]", "[CLAIM_END]"],
    "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM_END]"],
    "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL_END]"],
    "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE_END]"],
    "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT_END]"]
}

# marker tokens that open a discourse span
DISCOURSE_START_TOKENS = [
    "[LEAD]",
    "[POSITION]",
    "[CLAIM]",
    "[COUNTER_CLAIM]",
    "[REBUTTAL]",
    "[EVIDENCE]",
    "[CONCLUDING_STATEMENT]"
]

# marker tokens that close a discourse span
DISCOURSE_END_TOKENS = [
    "[LEAD_END]",
    "[POSITION_END]",
    "[CLAIM_END]",
    "[COUNTER_CLAIM_END]",
    "[REBUTTAL_END]",
    "[EVIDENCE_END]",
    "[CONCLUDING_STATEMENT_END]"
]

if USE_NEW_MAP:
    
    # Same layout as above, but the end markers use "X END" instead of "X_END"
    # and an additional lowercase "topic" key is present.
    TOKEN_MAP = {
        "topic": ["Topic [TOPIC]", "[TOPIC END]"],
        "Lead": ["Lead [LEAD]", "[LEAD END]"],
        "Position": ["Position [POSITION]", "[POSITION END]"],
        "Claim": ["Claim [CLAIM]", "[CLAIM END]"],
        "Counterclaim": ["Counterclaim [COUNTER_CLAIM]", "[COUNTER_CLAIM END]"],
        "Rebuttal": ["Rebuttal [REBUTTAL]", "[REBUTTAL END]"],
        "Evidence": ["Evidence [EVIDENCE]", "[EVIDENCE END]"],
        "Concluding Statement": ["Concluding Statement [CONCLUDING_STATEMENT]", "[CONCLUDING_STATEMENT END]"]
    }

    # start markers are unchanged from the old map
    DISCOURSE_START_TOKENS = [
        "[LEAD]",
        "[POSITION]",
        "[CLAIM]",
        "[COUNTER_CLAIM]",
        "[REBUTTAL]",
        "[EVIDENCE]",
        "[CONCLUDING_STATEMENT]"
    ]

    # end markers in the "X END" spelling
    DISCOURSE_END_TOKENS = [
        "[LEAD END]",
        "[POSITION END]",
        "[CLAIM END]",
        "[COUNTER_CLAIM END]",
        "[REBUTTAL END]",
        "[EVIDENCE END]",
        "[CONCLUDING_STATEMENT END]",
    ]
    
def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """
    Locate ``substring`` in ``text``, shortening it from the right until it matches.

    The full substring is tried first. While it is not found, it is cut to
    ``int(len * fraction)`` characters (always at least one character shorter)
    and tried again. The search is iterative, so long unmatched substrings
    cannot exhaust the recursion limit, and ``fraction >= 1`` still terminates.

    :param text: text to search in
    :param substring: text to look for
    :param min_length: give up once the shortened prefix would be shorter than this
    :param fraction: share of the current prefix kept after a failed attempt
    :return: ``[start, end]`` of the matched prefix in ``text``, or ``[-1, 0]``
        when no prefix of at least ``min_length`` characters is found
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        # shrink by the requested fraction, but always by at least one character
        shorter_length = min(int(len(candidate) * fraction), len(candidate) - 1)
        if shorter_length < min_length:
            return [-1, 0]
        candidate = candidate[:shorter_length]


def build_span_map(discourse_list, essay_text):
    """
    Map every discourse text to the character spans of its occurrences in the essay.

    Discourse texts are processed in list order. Each one is assigned the first
    exact match that ends after the previously assigned match (the "reading
    head"), so repeated texts resolve to successive positions. When no such
    match exists for an occurrence, ``relaxed_search`` provides the span for that
    occurrence. The returned lists therefore contain exactly one span per
    occurrence in ``discourse_list``, in order, which is what
    ``get_substring_span`` consumes with ``pop(0)``.

    :param discourse_list: discourse texts of one essay, in annotation order
    :param essay_text: full essay text
    :return: dict ``discourse text -> list of spans``; exact matches are
        ``(start, end)`` tuples, fallback results are ``[start, end]`` lists
    """
    reading_head = 0
    to_return = dict()

    for cur_discourse in discourse_list:
        spans = to_return.setdefault(cur_discourse, [])

        found = None
        for match in re.finditer(re.escape(cur_discourse), essay_text):
            # skip matches that lie entirely before the reading head
            if match.end() > reading_head:
                found = match.span()
                break

        if found is None:
            # no usable exact match for this occurrence: fall back right away so a
            # repeated text never ends up with fewer spans than occurrences
            print("resorting to relaxed search...")
            spans.append(relaxed_search(essay_text, cur_discourse))
        else:
            spans.append(found)
            reading_head = found[1]

    return to_return


def get_substring_span(texts, mapping):
    """
    Hand out one span per text, in order.

    For every entry of ``texts`` the first remaining span is removed from
    ``mapping[text]`` and collected, so repeated texts receive successive spans.
    ``mapping`` is modified in place.
    """
    spans = []
    for discourse_text in texts:
        remaining = mapping[discourse_text]
        spans.append(remaining.pop(0))
    return spans


def process_essay(essay_id, essay_text, anno_df):
    """Wrap each annotated span of one essay in its marker texts.

    The essay's annotations are taken from ``anno_df`` and handled in order of
    ``discourse_start``. For each one the start/end texts from ``TOKEN_MAP`` are
    inserted around the span, separated by single spaces, and a running offset
    keeps later character positions aligned with the grown text. The result is
    enclosed in ``[SOE]`` ... ``[EOE]``.
    """
    essay_annos = anno_df[anno_df["essay_id"] == essay_id].copy()
    essay_annos = essay_annos.sort_values(by="discourse_start")

    shift = 0  # characters inserted so far
    annotations = zip(
        essay_annos["discourse_start"],
        essay_annos["discourse_end"],
        essay_annos["discourse_type"],
    )
    for raw_start, raw_end, d_type in annotations:
        start = int(raw_start) + shift
        end = int(raw_end) + shift
        open_tok, close_tok = TOKEN_MAP[d_type]

        pieces = [essay_text[:start], open_tok, essay_text[start:end], close_tok, essay_text[end:]]
        essay_text = " ".join(pieces)

        # two marker texts plus the four joining spaces
        shift += len(open_tok) + len(close_tok) + 4

    return "[SOE]" + essay_text + "[EOE]"


def process_input_df(anno_df, notes_df):
    """Turn annotation rows plus essay texts into one row per essay.

    Steps: strip the discourse texts, locate every discourse in its essay
    (``build_span_map`` / ``get_substring_span``), insert the marker texts into
    the essay (``process_essay``) and finally group the annotations per essay.
    Both inputs are deep-copied first, so the caller's frames are not modified.

    :param anno_df: annotation rows; reads "discourse_id", "essay_id",
        "discourse_text", "discourse_type", "uid" and, when present,
        "discourse_effectiveness"
    :type anno_df: pd.DataFrame
    :param notes_df: essays with "essay_id" and "essay_text"
    :type notes_df: pd.DataFrame
    :return: one row per essay with list columns "uids", "discourse_ids",
        "discourse_type" (and "discourse_effectiveness" when available) plus the
        marker-augmented "essay_text"
    :rtype: pd.DataFrame
    """
    notes_df = deepcopy(notes_df)
    anno_df = deepcopy(anno_df)

    has_labels = "discourse_effectiveness" in anno_df.columns

    #------------------- Pre-Process Essay Text --------------------------#
    anno_df["discourse_text"] = anno_df["discourse_text"].apply(lambda x: x.strip())  # pre-process
    if has_labels:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                           "discourse_type", "discourse_effectiveness", "uid"]].copy()
    else:
        anno_df = anno_df[["discourse_id", "essay_id", "discourse_text", "discourse_type", "uid"]].copy()

    # one row per essay with its discourse ids/texts as lists, joined with the essay text
    tmp_df = anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]].agg(list).reset_index()
    tmp_df = pd.merge(tmp_df, notes_df, on="essay_id", how="left")

    # Row values are read by column label: positional integer access on a
    # label-indexed Series is deprecated in recent pandas releases.
    tmp_df["span_map"] = tmp_df[["discourse_text", "essay_text"]].apply(
        lambda row: build_span_map(row["discourse_text"], row["essay_text"]), axis=1)
    tmp_df["span"] = tmp_df[["discourse_text", "span_map"]].apply(
        lambda row: get_substring_span(row["discourse_text"], row["span_map"]), axis=1)

    # flatten the per-essay lists back to one (discourse_id, start, end) row per discourse
    all_discourse_ids = list(chain(*tmp_df["discourse_id"].values))
    all_discourse_spans = list(chain(*tmp_df["span"].values))
    span_df = pd.DataFrame()
    span_df["discourse_id"] = all_discourse_ids
    span_df["span"] = all_discourse_spans
    span_df["discourse_start"] = span_df["span"].apply(lambda x: x[0])
    span_df["discourse_end"] = span_df["span"].apply(lambda x: x[1])
    span_df = span_df.drop(columns="span")

    anno_df = pd.merge(anno_df, span_df, on="discourse_id", how="left")
    # anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    print("=="*40)
    print("processing essay text and inserting new tokens at span boundaries")
    notes_df["essay_text"] = notes_df[["essay_id", "essay_text"]].apply(
        lambda row: process_essay(row["essay_id"], row["essay_text"], anno_df), axis=1
    )
    print("=="*40)

    # the character offsets refer to the original text and are no longer needed
    anno_df = anno_df.drop(columns=["discourse_start", "discourse_end"])
    notes_df = notes_df.drop_duplicates(subset=["essay_id"])[["essay_id", "essay_text"]].copy()

    anno_df = pd.merge(anno_df, notes_df, on="essay_id", how="left")

    if "discourse_effectiveness" in anno_df.columns:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_id", "discourse_effectiveness", "discourse_type"]].agg(list).reset_index()
    else:
        grouped_df = anno_df.groupby("essay_id")[["uid", "discourse_id", "discourse_type"]].agg(list).reset_index()

    grouped_df = pd.merge(grouped_df, notes_df, on="essay_id", how="left")
    grouped_df = grouped_df.rename(columns={"uid": "uids", "discourse_id": "discourse_ids"})

    return grouped_df


#--------------- Dataset ----------------------------------------------#


class FeedbackDatasetMeta:
    """Dataset builder for the feedback prize effectiveness task with meta features.

    Besides tokenizing the marker-augmented essays, it attaches to every
    discourse element the feature vector found under its discourse id in the
    ``feature_map`` handed to :meth:`get_dataset`.
    """

    def __init__(self, config):
        self.config = config

        self.label2id = {
            "Ineffective": 0,
            "Adequate": 1,
            "Effective": 2,
        }

        # 0 is left unused here
        self.discourse_type2id = {
            "Lead": 1,
            "Position": 2,
            "Claim": 3,
            "Counterclaim": 4,
            "Rebuttal": 5,
            "Evidence": 6,
            "Concluding Statement": 7,
        }

        self.id2label = {v: k for k, v in self.label2id.items()}
        self.load_tokenizer()

    def load_tokenizer(self):
        """Load the tokenizer as per config and cache the marker token ids."""
        self.tokenizer = get_tokenizer(self.config)
        print("=="*40)
        print("token maps...")
        print(TOKEN_MAP)
        print("=="*40)
        print(f"tokenizer len: {len(self.tokenizer)}")

        # ids of the tokens that open / close a discourse span
        self.discourse_token_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_START_TOKENS))
        self.discourse_end_ids = set(self.tokenizer.convert_tokens_to_ids(DISCOURSE_END_TOKENS))
        self.global_tokens = self.discourse_token_ids.union(self.discourse_end_ids)

    def tokenize_function(self, examples):
        """Tokenize the essay texts without padding or truncation, keeping offsets."""
        tz = self.tokenizer(
            examples["essay_text"],
            padding=False,
            truncation=False,  # no truncation at first
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        return tz

    def process_spans(self, examples):
        """Find the token positions and character offsets of the span markers.

        :return: dict with, per example, the token indices of the start markers
            ("span_head_idxs"), of the end markers ("span_tail_idxs"), the
            character start of each start marker and the character end of each
            end marker
        """
        span_head_char_start_idxs, span_tail_char_end_idxs = [], []
        span_head_idxs, span_tail_idxs = [], []

        for example_input_ids, example_offset_mapping in zip(examples["input_ids"], examples["offset_mapping"]):
            example_span_head_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_token_ids]
            example_span_tail_idxs = [pos for pos, this_id in enumerate(
                example_input_ids) if this_id in self.discourse_end_ids]

            example_span_head_char_start_idxs = [example_offset_mapping[pos][0] for pos in example_span_head_idxs]
            example_span_tail_char_end_idxs = [example_offset_mapping[pos][1] for pos in example_span_tail_idxs]

            span_head_char_start_idxs.append(example_span_head_char_start_idxs)
            span_tail_char_end_idxs.append(example_span_tail_char_end_idxs)

            span_head_idxs.append(example_span_head_idxs)
            span_tail_idxs.append(example_span_tail_idxs)

        return {
            "span_head_idxs": span_head_idxs,
            "span_tail_idxs": span_tail_idxs,
            "span_head_char_start_idxs": span_head_char_start_idxs,
            "span_tail_char_end_idxs": span_tail_char_end_idxs,
        }

    def generate_labels(self, examples):
        """Convert the effectiveness strings of every example to label ids."""
        labels = [
            [self.label2id[label] for label in example_labels]
            for example_labels in examples["discourse_effectiveness"]
        ]
        return {"labels": labels}

    def generate_meta_features(self, examples):
        """Look up the feature vector of every discourse id in ``self.feature_map``."""
        meta_features = [
            [self.feature_map[didx] for didx in example_ids]
            for example_ids in examples["discourse_ids"]
        ]
        return {"meta_features": meta_features}

    def generate_discourse_type_ids(self, examples):
        """Convert the discourse type names of every example to integer ids."""
        discourse_type_ids = [
            [self.discourse_type2id[dt] for dt in example_discourse_types]
            for example_discourse_types in examples["discourse_type"]
        ]
        return {"discourse_type_ids": discourse_type_ids}

    def compute_input_length(self, examples):
        """Return the number of tokens of every example."""
        return {"input_length": [len(x) for x in examples["input_ids"]]}

    def sanity_check_head_tail(self, examples):
        """Assert that start and end markers pair up and enclose at least one token."""
        for head_idxs, tail_idxs in zip(examples["span_head_idxs"], examples["span_tail_idxs"]):
            assert len(head_idxs) == len(tail_idxs)
            for head, tail in zip(head_idxs, tail_idxs):
                assert tail > head + 1

    def sanity_check_head_labels(self, examples):
        """Assert that every start marker has exactly one label."""
        for head_idxs, head_labels in zip(examples["span_head_idxs"], examples["labels"]):
            assert len(head_idxs) == len(head_labels)

    def get_dataset(self, df, essay_df, feature_map, mode='train'):
        """main api for creating the Feedback dataset

        :param df: input annotation dataframe
        :type df: pd.DataFrame
        :param essay_df: dataframe with essay texts
        :type essay_df: pd.DataFrame
        :param feature_map: mapping from discourse id to its meta feature vector
        :type feature_map: dict
        :param mode: check if required for train or infer, defaults to 'train'
        :type mode: str, optional
        :return: the created dataset
        :rtype: Dataset
        """
        self.feature_map = feature_map
        df = process_input_df(df, essay_df)

        # save a sample for sanity checks
        sample_df = df.sample(min(16, len(df)))
        sample_df.to_csv(os.path.join(self.config["model_dir"], f"{mode}_df_processed.csv"), index=False)

        task_dataset = Dataset.from_pandas(df)
        task_dataset = task_dataset.map(self.tokenize_function, batched=True)
        task_dataset = task_dataset.map(self.compute_input_length, batched=True)
        task_dataset = task_dataset.map(self.process_spans, batched=True)
        task_dataset = task_dataset.map(self.generate_meta_features, batched=True)

        print(task_dataset)
        # todo check edge cases
        # keep only essays whose start and end markers pair up
        task_dataset = task_dataset.filter(lambda example: len(example['span_head_idxs']) == len(
            example['span_tail_idxs']))  # no need to run on empty set
        print(task_dataset)
        task_dataset = task_dataset.map(self.generate_discourse_type_ids, batched=True)
        task_dataset = task_dataset.map(self.sanity_check_head_tail, batched=True)

        if mode != "infer":
            task_dataset = task_dataset.map(self.generate_labels, batched=True)
            task_dataset = task_dataset.map(self.sanity_check_head_labels, batched=True)

        # Drop the pandas index column only when it exists, instead of swallowing
        # every exception raised by remove_columns.
        if "__index_level_0__" in task_dataset.column_names:
            task_dataset = task_dataset.remove_columns(column_names=["__index_level_0__"])
        return task_dataset

@dataclass
class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """
    Data collator for span-level sequence classification.

    Token-level fields are padded by ``self.tokenizer.pad``; every span-level
    list is right-padded to the largest number of spans found in the batch:

    * ``span_head_idxs`` / ``span_tail_idxs`` -> ``max(max_len - 10, 1)`` /
      ``max(max_len - 4, 1)`` where ``max_len`` is the padded token length
    * ``uids`` and ``labels`` -> ``-1``
    * ``discourse_type_ids`` and ``span_attention_mask`` -> ``0``
    * ``meta_features`` -> all-zero feature vectors

    ``meta_features`` is returned as a float32 tensor, every other key as int64.
    """
    # NOTE(review): these assignments carry no annotations, so they are plain
    # class attributes rather than dataclass fields. The inherited __init__
    # presumably still assigns the base-class defaults on each instance, which
    # would shadow e.g. pad_to_multiple_of = 512 - confirm before relying on it.
    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = 512
    return_tensors = "pt"

    @staticmethod
    def _pad_rows(rows, target_len, fill):
        """Return new lists: each row right-padded with ``fill`` to ``target_len``."""
        return [row + [fill] * (target_len - len(row)) for row in rows]

    def __call__(self, features):
        """Collate a list of per-essay feature dicts into a dict of tensors.

        :param features: list of dicts holding the tokenizer outputs plus
            ``uids``, ``discourse_type_ids``, ``span_head_idxs``,
            ``span_tail_idxs``, ``meta_features`` and optionally ``labels``
        :return: dict mapping each key to a ``torch.Tensor``
        """
        uids = [feature["uids"] for feature in features]
        discourse_type_ids = [feature["discourse_type_ids"] for feature in features]
        span_head_idxs = [feature["span_head_idxs"] for feature in features]
        span_tail_idxs = [feature["span_tail_idxs"] for feature in features]
        meta_features = [feature["meta_features"] for feature in features]

        # 1 for every real span; padded slots get 0 further down
        span_attention_mask = [[1] * len(feature["span_head_idxs"]) for feature in features]

        labels = None
        if "labels" in features[0].keys():
            labels = [feature["labels"] for feature in features]

        # return_tensors=None: keep Python lists so the span-level fields can
        # be padded by hand before the final tensor conversion
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None,
        )

        b_max = max([len(l) for l in span_head_idxs])
        max_len = len(batch["input_ids"][0])

        default_head_idx = max(max_len - 10, 1)  # for padding
        default_tail_idx = max(max_len - 4, 1)  # for padding

        batch["span_head_idxs"] = self._pad_rows(span_head_idxs, b_max, default_head_idx)
        batch["uids"] = self._pad_rows(uids, b_max, -1)
        batch["discourse_type_ids"] = self._pad_rows(discourse_type_ids, b_max, 0)
        batch["span_tail_idxs"] = self._pad_rows(span_tail_idxs, b_max, default_tail_idx)

        # Feature width is taken from the first example that has any span, so
        # an example without spans no longer raises IndexError.
        feature_dim = next((len(rows[0]) for rows in meta_features if rows), 0)
        pad_vector = [0.0] * feature_dim
        # Build fresh lists: the incoming feature dicts must not be mutated,
        # otherwise collating the same examples twice would pad them twice.
        batch["meta_features"] = [
            list(rows) + [pad_vector] * (b_max - len(rows)) for rows in meta_features
        ]

        batch["span_attention_mask"] = self._pad_rows(span_attention_mask, b_max, 0)

        if labels is not None:
            batch["labels"] = self._pad_rows(labels, b_max, -1)

        batch = {k: (torch.tensor(v, dtype=torch.int64) if k != "meta_features" else torch.tensor(
            v, dtype=torch.float32)) for k, v in batch.items()}
        return batch

In [None]:
# Settings for the meta-model inference: tokenizer/backbone location, output
# directory for sanity-check dumps, and the inference batch size.
config = dict(
    base_model_path="../input/tk-fpe-models-v2/exp205-debv3-l-prompt/mlm_model",
    model_dir="./",
    valid_bs=16,
)

In [None]:
# Build the inference dataset; mode="infer" skips label generation in get_dataset.
dataset_creator = FeedbackDatasetMeta(config)
infer_dataset = dataset_creator.get_dataset(test_df, essay_df, feature_map, mode="infer")

In [None]:
# Reuse the tokenizer held by the dataset creator for padding in the collator.
tokenizer = dataset_creator.tokenizer
config["len_tokenizer"] = len(tokenizer)
data_collector = CustomDataCollatorWithPadding(tokenizer=tokenizer)

# sort valid dataset for faster evaluation
infer_dataset = infer_dataset.sort("input_length")

# Restrict the dataset output to the columns listed here (labels are absent at inference).
infer_dataset.set_format(
    type=None,
    columns=['input_ids', 'attention_mask', 'token_type_ids', 'span_head_idxs',
                'span_tail_idxs', 'discourse_type_ids', "meta_features", 'uids']
    )

# shuffle=False keeps the length-sorted order produced above.
infer_dl = DataLoader(
    infer_dataset,
    batch_size=config["valid_bs"],
    shuffle=False,
    collate_fn=data_collector,
    pin_memory=True,
)

In [None]:
class FeedbackMetaModelResidual(nn.Module):
    """
    Feedback prize effectiveness meta model (fast approach).

    Per-span meta features are projected to a 512-wide hidden space, layer
    normalised, passed through a single-layer bidirectional LSTM whose output
    is added back to its input (residual connection), and classified.
    """

    def __init__(self, config):
        """Build the layers.

        :param config: dict providing ``num_labels``, ``num_features`` and ``dropout``
        """
        print("==" * 40)
        print("initializing the feedback model...")

        super().__init__()

        self.config = config
        self.num_labels = config["num_labels"]
        self.num_meta_features = config["num_features"]

        print(f'Num fts: {self.num_meta_features}')

        # one dropout module, applied both before and after the LSTM
        self.dropout = nn.Dropout(self.config["dropout"])

        hidden_size = 512
        self.projection = nn.Linear(self.num_meta_features, hidden_size)
        self.layer_norm = LayerNorm(hidden_size, 1e-7)

        # each direction emits hidden_size // 2 units, so the concatenated
        # output has width hidden_size and can be summed with the LSTM input
        self.meta_rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(
            self,
            meta_features,
            attention_mask,
            span_attention_mask,
            discourse_type_ids,
            labels=None,
            **kwargs
    ):
        """Return classification logits for every span slot.

        Only ``meta_features`` is used; ``attention_mask``,
        ``span_attention_mask``, ``discourse_type_ids``, ``labels`` and any
        extra keyword arguments are accepted but ignored.
        """
        # projection -> layer normalization -> dropout
        hidden = self.dropout(self.layer_norm(self.projection(meta_features)))

        # the LSTM returns (output, state); only the output sequence is needed
        rnn_out, _ = self.meta_rnn(hidden)

        # residual connection, dropout, then the classification head
        hidden = self.dropout(hidden + rnn_out)
        return self.classifier(hidden)

In [None]:
# Hyper-parameters read by FeedbackMetaModelResidual.__init__.
model_config = dict(
    num_labels=3,
    num_features=len(feature_names),
    dropout=0.15,
)
model_config

In [None]:
# One best checkpoint per fold (folds 0-7) of the LSTM meta model.
meta_checkpoints = [
    f"../input/ens57-lstm-meta-8m/fpe_model_fold_{fold}_best.pth.tar"
    for fold in range(8)
]

In [None]:
def inference_fn(model, infer_dl, model_id):
    """Run ``model`` over ``infer_dl`` and write per-discourse probabilities to CSV.

    :param model: module called as ``model(**batch)``, returning logits
    :param infer_dl: dataloader whose batches contain a ``uids`` tensor
    :param model_id: suffix of the output file ``meta_model_{model_id}.csv``
    :return: None; the result is the CSV written to the working directory

    Span slots with a negative uid are dropped before writing, and the
    remaining uids are mapped to discourse ids through the module-level
    ``idx2discourse``.
    """
    accelerator = Accelerator()
    model, infer_dl = accelerator.prepare(model, infer_dl)

    model.eval()

    span_probs = []  # one 3-element probability row per span slot
    span_uids = []   # uid of the same span slot

    for batch in tqdm(infer_dl, total=len(infer_dl)):
        with torch.no_grad():
            logits = model(**batch)  # (b, nd, 3)
            probs = F.softmax(logits, dim=-1)

        # flatten the (example, span) nesting of this batch into per-span rows
        span_probs.extend(chain.from_iterable(probs.to('cpu').detach().numpy().tolist()))
        span_uids.extend(chain.from_iterable(batch["uids"].to('cpu').detach().numpy().tolist()))

    preds_df = pd.DataFrame(span_probs)
    preds_df.columns = ["Ineffective", "Adequate", "Effective"]
    preds_df["span_uid"] = span_uids  # SORTED_DISCOURSE_IDS

    # keep real spans only (uid >= 0)
    preds_df = preds_df[preds_df["span_uid"] >= 0].copy()
    preds_df["discourse_id"] = preds_df["span_uid"].map(idx2discourse)

    output_df = preds_df[["discourse_id", "Ineffective", "Adequate", "Effective"]].copy()
    output_df.to_csv(f"meta_model_{model_id}.csv", index=False)
    

# Run every fold checkpoint of the meta model; each call to inference_fn
# writes meta_model_{model_id}.csv.
for model_id, checkpoint in enumerate(meta_checkpoints):
    print(f"infering from {checkpoint}")
    model = FeedbackMetaModelResidual(model_config)
    # map_location="cpu": deserialize onto the CPU regardless of the device the
    # checkpoint was saved from, so loading also works on a CPU-only runtime
    # and no second copy of the weights is kept on the GPU.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(checkpoint, map_location="cpu")
    print(f"validation score for fold {model_id} = {ckpt['loss']}")
    model.load_state_dict(ckpt['state_dict'])
    inference_fn(model, infer_dl, model_id)

# release the last model before the next stage
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
import glob
import pandas as pd

# Average the per-fold prediction files written by inference_fn.
csvs = glob.glob("meta_model_*.csv")

idx = []
preds = []

for fold_no, csv_path in enumerate(csvs):

    print("==" * 40)
    print(f"preds in {csv_path}")
    # sort by discourse_id so rows are accumulated position by position
    fold_df = pd.read_csv(csv_path).sort_values(by=["discourse_id"])
    print(fold_df.head(10))
    print("==" * 40)

    fold_probs = fold_df.drop(columns=["discourse_id"]).values
    if fold_no > 0:
        preds += fold_probs
        continue

    # the first file provides the id order and the accumulator array
    idx = list(fold_df["discourse_id"])
    preds = fold_probs

preds = preds / len(csvs)

meta_pred_df = pd.DataFrame({
    "discourse_id": idx,
    "Ineffective": preds[:, 0],
    "Adequate": preds[:, 1],
    "Effective": preds[:, 2],
})

# collapse repeated discourse ids into a single averaged row
meta_pred_df = (
    meta_pred_df
    .groupby("discourse_id")[["Ineffective", "Adequate", "Effective"]]
    .agg(np.mean)
    .reset_index()
)

In [None]:
# Preview the fold-averaged meta-model predictions.
meta_pred_df.head()

In [None]:
# Keep an independent copy of the meta-model predictions under the name used by the final ensemble.
submission_df = meta_pred_df.copy() # pd.DataFrame()
submission_df.head(10)

## LGB

In [None]:
# Preview merged_df, whose "model*" columns become features in get_features.
merged_df.head()

In [None]:
from textblob import TextBlob

# Attach the discourse rows and then the essay texts to the model-prediction
# frame; left joins keep every row of merged_df.
meta_df = (
    merged_df
    .merge(test_df, on="discourse_id", how="left")
    .merge(essay_df, on="essay_id", how="left")
)

def get_substring_span(text, substring, min_length=10, fraction=0.999):
    """
    Returns substring's span from the given text with the certain precision.

    The full ``substring`` is searched first. While it is not found, it is
    cut down to its first ``int(len * fraction)`` characters and searched
    again, until a prefix matches or the prefix becomes shorter than
    ``min_length``.

    The search is iterative, so a long substring that never matches cannot
    exhaust the recursion limit, and every retry is forced to drop at least
    one character so that ``fraction >= 1`` still terminates.

    :param text: text to search in
    :param substring: text to locate
    :param min_length: shortest prefix that is still searched for
    :param fraction: share of the current prefix kept for the next attempt
    :return: ``[start, end]`` of the matched prefix in ``text``, or
        ``[-1, 0]`` when nothing of at least ``min_length`` characters matches
    """
    candidate = substring
    while True:
        position = text.find(candidate)
        if position != -1:
            return [position, position + len(candidate)]

        # shrink the prefix; min() guarantees progress for any fraction
        shorter_length = min(int(len(candidate) * fraction), len(candidate) - 1)
        candidate = candidate[:shorter_length]
        if len(candidate) < min_length:
            return [-1, 0]

def tags(text):
    """Wrap ``text`` in a ``TextBlob`` so its POS tags can be read later."""
    return TextBlob(text)

def count_typ_tags(text, typ):
    """Count the ``(word, tag)`` pairs in ``text.tags`` whose tag starts with ``typ``.

    :param text: object exposing a ``tags`` iterable of ``(word, tag)`` pairs
    :param typ: tag prefix to match, e.g. ``'VB'``
    :return: number of matching pairs
    """
    return sum(1 for _word, tag in text.tags if tag.startswith(typ))

def get_features(meta_df):
    """Add hand-crafted feature columns and build the feature configuration.

    The feature list starts with every column of the module-level
    ``merged_df`` whose name starts with ``"model"`` and is then extended, in
    a fixed order, with the discourse/essay statistics computed here.

    :param meta_df: frame providing ``essay_text``, ``discourse_text``,
        ``discourse_type``, ``essay_id`` and ``discourse_id`` columns. The
        columns created before the sort are written onto the passed frame.
    :return: ``(meta_df, config)`` - the frame sorted by
        ``['essay_id', 'discourse_id']`` with the new columns, and a dict with
        ``"features"`` (ordered feature names) and ``"cat_features"``
    """
    config = dict()
    feature_names = [col for col in merged_df.columns if col.startswith("model")]

    config["features"] = feature_names
    config["cat_features"] = []

    print('Processing spans')
    # label-based row access: positional indexing (row[0]) on a Series with
    # string labels is deprecated in pandas
    meta_df["discourse_span"] = meta_df[["essay_text", "discourse_text"]].apply(
        lambda row: get_substring_span(row["essay_text"], row["discourse_text"]), axis=1)
    meta_df["discourse_start"] = meta_df["discourse_span"].apply(lambda x: x[0])
    meta_df["discourse_end"] = meta_df["discourse_span"].apply(lambda x: x[1])

    meta_df['discourse_len'] = meta_df['discourse_end'] - meta_df['discourse_start']
    meta_df['freq_of_essay_id'] = meta_df['essay_id'].map(dict(meta_df['essay_id'].value_counts()))

    # part-of-speech counts of the discourse text, one column per tag prefix
    meta_df['blob_discourse'] = meta_df['discourse_text'].apply(tags)
    pos_columns = [
        ('discourse_Adjectives', 'JJ'),
        ('discourse_Verbs', 'VB'),
        ('discourse_Adverbs', 'RB'),
        ('discourse_Nouns', 'NN'),
        ('discourse_VBP', 'VBP'),
        ('discourse_PRP', 'PRP'),
    ]
    for column, prefix in pos_columns:
        meta_df[column] = meta_df['blob_discourse'].apply(
            lambda blob, prefix=prefix: count_typ_tags(blob, prefix))
    meta_df['count_next_line_essay'] = meta_df['essay_text'].apply(lambda x: x.count("\n\n"))

    discourse_type2id = {
        "Lead": 1,
        "Position": 2,
        "Claim": 3,
        "Counterclaim": 4,
        "Rebuttal": 5,
        "Evidence": 6,
        "Concluding Statement": 7,
    }

    # binary indicator columns for selected discourse types
    new_col = []
    for unique in ['Claim']:
        meta_df['is_' + unique] = meta_df['discourse_type'].apply(lambda x: 1 if x == unique else 0)
        new_col.append('is_' + unique)

    meta_df = meta_df.sort_values(by=['essay_id', 'discourse_id']).reset_index(drop=True)

    # space-joined discourse types of each essay; must be computed while
    # discourse_type still holds the string names
    essay_discourse_list = (
        meta_df.groupby('essay_id')['discourse_type']
        .agg(" ".join)
        .rename('discourse_type_list')
        .reset_index()
    )
    meta_df = meta_df.merge(essay_discourse_list[['essay_id', 'discourse_type_list']],
                            on='essay_id', how='left')

    # integer encoding, then frequency of each encoded type over the whole frame
    meta_df["discourse_type"] = meta_df["discourse_type"].map(discourse_type2id)
    meta_df['discourse_type_fe'] = meta_df['discourse_type'].map(dict(meta_df['discourse_type'].value_counts()))

    # number of distinct discourse types per essay
    essay_discourse = (
        meta_df.groupby('essay_id')['discourse_type']
        .nunique()
        .rename('unique_discourse_type')
        .reset_index()
    )
    meta_df = meta_df.merge(essay_discourse[['essay_id', 'unique_discourse_type']],
                            on='essay_id', how='left')

    # the order of the feature names is kept exactly as before
    config["features"].extend(["discourse_type",
                               "discourse_type_fe", "discourse_len", "freq_of_essay_id",
                               "unique_discourse_type"] + new_col)
    config['features'].extend(['discourse_Adjectives', 'discourse_Verbs',
                               'discourse_Adverbs', 'discourse_Nouns',
                               'count_next_line_essay', 'discourse_VBP', 'discourse_PRP'])
    config["cat_features"].append("discourse_type")

    return meta_df, config

In [None]:
# Build the features; note that this rebinds `config` to the feature configuration dict.
meta_df, config = get_features(meta_df)
meta_df.shape, len(config['features'])

In [None]:
import lightgbm as lgbm
import joblib

# One saved LightGBM model file per fold (folds 0-7).
model_paths = [
    f"../input/meta-lgbm-model/lgbm_model_fold_{fold}.txt"
    for fold in range(8)
]

In [None]:
# Average the predictions of all LightGBM fold models on the feature frame.
for midx, mp in enumerate(model_paths):
    model = lgbm.Booster(model_file=mp)
    fold_preds = model.predict(meta_df[config["features"]], num_iteration=model.best_iteration)
    # the first fold starts the running sum, later folds are added to it
    preds = fold_preds if midx == 0 else preds + fold_preds

preds = preds / len(model_paths)

# prediction columns 0, 1, 2 are written as Ineffective, Adequate, Effective
submission_df1 = pd.DataFrame({
    "discourse_id": meta_df["discourse_id"].values,
    "Ineffective": preds[:, 0],
    "Adequate": preds[:, 1],
    "Effective": preds[:, 2],
})

# All Data

In [None]:
# Prediction frames of the full-data models, in the same order as MODEL_WEIGHTS.
df_list = [
    exp17f_df,
    exp19f_df,
    exp20f_df,
    exp99_rb_all_df,
    exp209a_df,
    exp213f_df,
]

MODEL_WEIGHTS = [
    0.08,  # [2] luke -rb
    0.18,  # [2] dexl -rb
    0.04,  # [1] del kd -rb
    0.15,  # [2] delv3 -rb
    0.20,  # [1] delv3 - tk
    0.35,  # [2] del - tk
]

print(f"sum of weights {np.sum(MODEL_WEIGHTS)}")

# ids are taken from the first frame only
# NOTE(review): the per-model columns below are added as pandas Series, which
# presumably align on the row index rather than on discourse_id - confirm that
# all frames share the same index and row order.
all_data_df = pd.DataFrame()
all_data_df["discourse_id"] = df_list[0]["discourse_id"].values

for model_idx, model_preds in enumerate(df_list):
    weight = MODEL_WEIGHTS[model_idx]
    for label in ("Ineffective", "Adequate", "Effective"):
        weighted = weight * model_preds[label]
        if model_idx == 0:
            # the first model creates the column
            all_data_df[label] = weighted
        else:
            all_data_df[label] += weighted

In [None]:
# Preview the weighted blend of the full-data models.
all_data_df.head()

# Final Ensemble

In [None]:
# Sort all three prediction frames by discourse_id; the final blend combines
# their values by row position, so they must share the same order.
lgb_df = submission_df1.sort_values(by="discourse_id")
lstm_df = submission_df.sort_values(by="discourse_id")
all_data_df = all_data_df.sort_values(by="discourse_id") # TODO: check for flag

In [None]:
# Preview the sorted LightGBM predictions.
lgb_df.head()

In [None]:
# Preview the sorted LSTM meta-model predictions.
lstm_df.head()

In [None]:
# Preview the sorted full-data blend.
all_data_df.head()

In [None]:
# 0.536321*0.25 + 0.616211*0.35 + 0.499419*0.4
# exp99_rb_all_df.head(1)
# exp209a_df.head(1)
# exp213f_df.head(1)

In [None]:
# Final blend: 60% stacked meta models (70% LSTM + 30% LightGBM) and 40%
# full-data models. Rows are combined by position, so all three frames must
# already be sorted by discourse_id.
label_cols = ["Ineffective", "Adequate", "Effective"]

lgb_vals = lgb_df[label_cols].values
lstm_vals = lstm_df[label_cols].values
all_vals = all_data_df[label_cols].values

blended = 0.60 * (0.7 * lstm_vals + 0.3 * lgb_vals) + 0.40 * all_vals

sub_df = pd.DataFrame()
sub_df["discourse_id"] = lgb_df["discourse_id"].values
for col_pos, col_name in enumerate(label_cols):
    sub_df[col_name] = blended[:, col_pos]

In [None]:
# Write the final submission file (without the index column) and preview it.
sub_df.to_csv("submission.csv", index=False)
sub_df.head()

In [None]:
# NOTE(review): identical to the preceding cell - rewrites the same submission.csv; likely redundant.
sub_df.to_csv("submission.csv", index=False)
sub_df.head()