In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import sys
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pad_sequence
import torch
from transformers import LongformerModel,LongformerConfig,LongformerTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss
from typing import Tuple, List
import datetime
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
import argparse
import logging
import random
import matplotlib.pyplot as plt
import time
import os
import sys
import torch.optim as optim
import json

# Experiment configurations

In [2]:
def _str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean flags need this parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_config(argv=None):
    """Build the experiment configuration namespace.

    Args:
        argv: optional list of command-line tokens to parse. ``None`` (the
            default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with all hyper-parameters. Unrecognised arguments
        are ignored (``parse_known_args``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--save-path", default="./saved_models", type=str)

    # training
    parser.add_argument("--device", default='1', type=str)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=2, type=int)
    parser.add_argument("--showtime", default=2000, type=int)
    parser.add_argument("--base-encoder-lr", default=1e-5, type=float)
    parser.add_argument("--longformer-base-encoder-lr", default=1e-5, type=float)
    parser.add_argument("--finetune-lr", default=1e-3, type=float)
    parser.add_argument("--warm-up", default=5e-2, type=float)
    parser.add_argument("--weight-decay", default=1e-5, type=float)
    parser.add_argument("--early-num", default=5, type=int)
    parser.add_argument("--num-tags", default=5, type=int)
    parser.add_argument("--threshold", default=0.3, type=float)

    parser.add_argument("--hidden-size", default=128, type=int)
    parser.add_argument("--layers", default=2, type=int)
    # Explicit parser: type=bool would turn "--is-bi False" into True.
    parser.add_argument("--is-bi", default=False, type=_str2bool)
    parser.add_argument("--bert-output-size", default=768, type=int)
    parser.add_argument("--mlp-size", default=512, type=int)
    parser.add_argument("--scale-factor", default=2, type=int)
    parser.add_argument("--dropout", default=0.7, type=float)
    parser.add_argument("--max-grad-norm", default=1.0, type=float)

    parser.add_argument("--num-heads", default=4, type=int)
    parser.add_argument("--att_dropout", default=0.1, type=float)

    parser.add_argument("--mrc_dropout", type=float, default=0.7,
                        help="mrc dropout rate")
    parser.add_argument("--lstm_dropout", type=float, default=0.4,
                        help="lstm dropout rate")
    parser.add_argument("--classifier_act_func", type=str, default="gelu")
    parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=128)
    parser.add_argument("--weight_start", type=float, default=1.0)
    parser.add_argument("--weight_end", type=float, default=1.0)
    parser.add_argument("--weight_span", type=float, default=0.1)

    # parse_known_args ignores unrecognised arguments instead of exiting.
    config, _ = parser.parse_known_args(argv)

    return config

In [3]:
class SingleLinearClassifier(nn.Module):
    """One affine projection from ``hidden_size`` features to ``num_label`` logits."""

    def __init__(self, hidden_size, num_label):
        super().__init__()
        self.num_label = num_label
        self.classifier = nn.Linear(hidden_size, num_label)

    def forward(self, input_features):
        # nn.Linear acts on the last dimension, so any leading shape is kept.
        return self.classifier(input_features)


class MultiNonLinearClassifier(nn.Module):
    """Two-layer MLP head: Linear -> activation -> Dropout -> Linear.

    Args:
        hidden_size: size of the last dimension of the input features.
        num_label: number of output logits.
        dropout_rate: dropout probability applied after the activation.
        act_func: one of "gelu", "relu", "tanh", "leakyrelu" (checked in forward).
        intermediate_hidden_size: hidden-layer width; defaults to ``hidden_size``.
    """

    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        """Return logits with the last dimension mapped to ``num_label``.

        Raises:
            ValueError: if ``act_func`` is not a supported activation name.
        """
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            features_output1 = torch.tanh(features_output1)
        elif self.act_func == "leakyrelu":
            # In-place is fine here: the input is the fresh output of classifier1.
            features_output1 = F.leaky_relu(features_output1, 0.2, inplace=True)
        else:
            raise ValueError(
                f"Unsupported act_func {self.act_func!r}; "
                "expected one of 'gelu', 'relu', 'tanh', 'leakyrelu'")
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2



class ThreeNonLinearClassifier(nn.Module):
    """Three-layer MLP head: (Linear -> act -> Dropout) twice, then a final Linear.

    Args:
        hidden_size: size of the last dimension of the input features.
        num_label: number of output logits.
        dropout_rate: dropout probability used after both hidden activations.
        intermediate_hidden_size_0: width of the first hidden layer.
        intermediate_hidden_size_1: width of the second hidden layer.
        act_func: one of "gelu", "relu", "tanh" (checked in forward).
    """

    def __init__(self, hidden_size, num_label, dropout_rate,intermediate_hidden_size_0,intermediate_hidden_size_1, act_func="gelu"):
        super(ThreeNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size_0 = intermediate_hidden_size_0
        self.intermediate_hidden_size_1 = intermediate_hidden_size_1

        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size_0)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size_0, self.intermediate_hidden_size_1)
        self.classifier3 = nn.Linear(self.intermediate_hidden_size_1, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def _activate(self, features):
        """Apply the configured activation.

        Raises:
            ValueError: if ``act_func`` is not "gelu", "relu" or "tanh".
        """
        if self.act_func == "gelu":
            return F.gelu(features)
        if self.act_func == "relu":
            return F.relu(features)
        if self.act_func == "tanh":
            return torch.tanh(features)
        raise ValueError(
            f"Unsupported act_func {self.act_func!r}; expected one of 'gelu', 'relu', 'tanh'")

    def forward(self, input_features):
        """Return logits with the last dimension mapped to ``num_label``."""
        # The same dropout module is applied after each hidden layer.
        hidden = self.dropout(self._activate(self.classifier1(input_features)))
        hidden = self.dropout(self._activate(self.classifier2(hidden)))
        return self.classifier3(hidden)

class BERTTaggerClassifier(nn.Module):
    """Two-layer tagging head: Linear -> activation -> Dropout -> Linear.

    Args:
        hidden_size: size of the last dimension of the input features.
        num_label: number of output logits per position.
        dropout_rate: dropout probability applied after the activation.
        act_func: one of "gelu", "relu", "tanh" (checked in forward).
        intermediate_hidden_size: hidden-layer width; defaults to ``hidden_size``.
    """

    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(BERTTaggerClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        """Return logits with the last dimension mapped to ``num_label``.

        Raises:
            ValueError: if ``act_func`` is not a supported activation name.
        """
        features_output1 = self.classifier1(input_features)
        if self.act_func == "gelu":
            features_output1 = F.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = F.relu(features_output1)
        elif self.act_func == "tanh":
            features_output1 = torch.tanh(features_output1)
        else:
            raise ValueError(
                f"Unsupported act_func {self.act_func!r}; expected one of 'gelu', 'relu', 'tanh'")
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2


# Decoder: turn BIOES tag sequences and start/end/match predictions into argument spans

In [4]:
# BIOES label ids. Entity types are collapsed: "B-Review", "B-Reply" and the
# bare "B" all map to 1 (likewise I=2, E=3, S=4); "O" is 0.
tags2id = {
    'O': 0,
    **{
        f"{prefix}{suffix}": tag_id
        for suffix in ('-Review', '-Reply', '')
        for tag_id, prefix in enumerate('BIES', start=1)
    },
}
def spans_to_tags(spans, seq_len):
    """Convert inclusive (start, end) spans into a BIOES id list of length ``seq_len``.

    A single-position span becomes "S"; a longer span becomes "B", "I"..., "E".
    Positions outside every span are "O". Where spans overlap, later spans
    overwrite earlier ones.

    Args:
        spans: iterable of (start, end) pairs with an inclusive end index.
        seq_len: length of the returned tag list.
    Returns:
        list of ``seq_len`` ids from ``tags2id``.
    Raises:
        ValueError: if a span does not satisfy ``0 <= start <= end < seq_len``.
            Without this check, a slice assignment past the end would silently
            make the result longer than ``seq_len``.
    """
    tags = [tags2id['O']] * seq_len
    for span in spans:
        start, end = span[0], span[1]
        if not 0 <= start <= end < seq_len:
            raise ValueError(f"span {tuple(span)!r} is out of range for seq_len={seq_len}")
        if start == end:
            tags[start] = tags2id['S']
        else:
            tags[start] = tags2id['B']
            tags[start + 1:end] = [tags2id['I']] * (end - start - 1)
            tags[end] = tags2id['E']
    return tags


def get_arg_span(bioes_tags):
    """Decode a BIOES id sequence (O=0, B=1, I=2, E=3, S=4) into inclusive (start, end) spans.

    A span is emitted only for a "B" later closed by an "E", or for an "S".
    An open "B" is discarded by an "O", replaced by a later "B", or abandoned
    when an "S" appears. Stray "I"/"E" ids outside a span are ignored, and a
    "B" still open at the end of the sequence produces nothing.
    """
    spans = []
    open_start = None  # index of the pending "B", or None when outside a span
    for pos, tag in enumerate(bioes_tags):
        if tag == 4:  # S: always a complete one-position span
            spans.append((pos, pos))
            open_start = None
        elif tag == 1:  # B: (re)open a span here
            open_start = pos
        elif open_start is None:
            continue  # outside a span, O/I/E carry no information
        elif tag == 0:  # O: abandon the open span
            open_start = None
        elif tag == 3:  # E: close the open span
            spans.append((open_start, pos))
            open_start = None
    return spans




def extract_arguments(bioes_list):
    """Decode every BIOES id sequence in ``bioes_list`` into its list of (start, end) spans."""
    return [get_arg_span(pred_tags) for pred_tags in bioes_list]

def extract_span_arguments_yi(match_labels,start_labels,end_labels):
    """Apply extract_flat_spans_yi to each (start, end, match) prediction triple.

    Note the argument order differs between this wrapper (match, start, end)
    and the decoder it calls (start, end, match).
    """
    return [
        extract_flat_spans_yi(start_l, end_l, match_l)
        for match_l, start_l, end_l in zip(match_labels, start_labels, end_labels)
    ]
def extract_span_arguments(match_labels,start_labels,end_labels):
    """Apply extract_flat_spans to each (start, end, match) prediction triple.

    Note the argument order differs between this wrapper (match, start, end)
    and the decoder it calls (start, end, match).
    """
    return [
        extract_flat_spans(start_l, end_l, match_l)
        for match_l, start_l, end_l in zip(match_labels, start_labels, end_labels)
    ]

def extract_span_arguments_nested(match_labels,start_labels,end_labels):
    """Apply extract_flat_spans_nested to each (start, end, match) prediction triple.

    Note the argument order differs between this wrapper (match, start, end)
    and the decoder it calls (start, end, match).
    """
    return [
        extract_flat_spans_nested(start_l, end_l, match_l)
        for match_l, start_l, end_l in zip(match_labels, start_labels, end_labels)
    ]

class Tag:
    """A decoded entity: surface ``term``, entity type ``tag`` and its ``begin``/``end`` positions."""

    def __init__(self, term, tag, begin, end):
        self.term = term
        self.tag = tag
        self.begin = begin
        self.end = end

    def to_tuple(self):
        """Return (term, begin, end); the entity type is not included."""
        return (self.term, self.begin, self.end)

    def __repr__(self):
        # Render the instance attributes as a plain dict literal.
        return str(dict(self.__dict__))

    __str__ = __repr__


def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]:
    """
    Decode a sequence of (token, BMES label) pairs into entity tags.

    Only the first character of each label ("B", "M", "E", "S", "O") drives
    decoding; the entity type is ``label[2:]`` of the label that starts the
    entity (e.g. "B-LOC" -> "LOC").

    - "S" yields a one-token entity.
    - "B" followed by zero or more "M" and then "E" yields one entity that
      includes the "E".
    - "B" not closed by "E" still yields an entity covering the "B" and the
      following "M"s, excluding the label that stopped the scan.
      NOTE(review): an "M" in the final position is also excluded in this case.
    - A "B" in the last position is treated as "S".
    - Any other label (e.g. a stray "M" or "E") is skipped.

    Args:
        char_label_list: list of tuple (word, bmes-tag)
    Returns:
        tags: list of Tag whose ``end`` is exclusive (one past the last token)
    Examples:
        >>> x = [("Hi", "O"), ("Beijing", "S-LOC")]
        >>> bmes_decode(x)
        [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}]
    """
    idx = 0
    length = len(char_label_list)
    tags = []
    while idx < length:
        term, label = char_label_list[idx]
        current_label = label[0]

        # correct labels: a "B" in the last position cannot be closed, treat it as "S"
        if idx + 1 == length and current_label == "B":
            current_label = "S"

        # merge chars
        if current_label == "O":
            idx += 1
            continue
        if current_label == "S":
            tags.append(Tag(term, label[2:], idx, idx + 1))
            idx += 1
            continue
        if current_label == "B":
            # Scan forward over "M" labels; stop before the last position so
            # char_label_list[end] below is always a valid index.
            end = idx + 1
            while end + 1 < length and char_label_list[end][1][0] == "M":
                end += 1
            if char_label_list[end][1][0] == "E":  # end with E
                entity = "".join(char_label_list[i][0] for i in range(idx, end + 1))
                tags.append(Tag(entity, label[2:], idx, end + 1))
                idx = end + 1
            else:  # end with M/B: the entity stops before position `end`
                entity = "".join(char_label_list[i][0] for i in range(idx, end))
                tags.append(Tag(entity, label[2:], idx, end))
                idx = end
            continue
        else:
            # Unexpected label (a stray "M"/"E" or unknown prefix): skip it
            # instead of raising.
            idx += 1
            continue
    return tags

def extract_flat_spans_nested(start_pred, end_pred, match_pred, pseudo_tag="TAG"):
    """Return every (start, end) pair allowed by the start/end/match predictions.

    Unlike the flat decoders, overlapping and nested spans are all kept. A pair
    (i, j) is returned when ``start_pred[i]``, ``end_pred[j]`` and
    ``match_pred[i][j]`` are all non-zero and ``i <= j``.

    The computation runs on ``match_pred``'s device. It no longer needs CUDA,
    and no longer needs numpy, which this file never imports.

    Args:
        start_pred: 1-D tensor of length seq_len; non-zero marks a start.
        end_pred: 1-D tensor of length seq_len; non-zero marks an end.
        match_pred: (seq_len, seq_len) tensor; non-zero at [i, j] marks a match.
        pseudo_tag: unused; kept for signature compatibility.
    Returns:
        list of (start, end) int tuples in row-major (start, then end) order.
    """
    seq_len = start_pred.size(0)
    device = match_pred.device

    start_preds = start_pred.bool().to(device)
    end_preds = end_pred.bool().to(device)

    # Upper-triangular mask: a span's start must not come after its end.
    positions = torch.arange(seq_len, device=device)
    upper = positions.unsqueeze(-1) <= positions.unsqueeze(0)

    match_preds = (match_pred.bool()
                   & start_preds.unsqueeze(-1)
                   & end_preds.unsqueeze(0)
                   & upper)
    return [(start, end) for start, end in match_preds.nonzero().tolist()]

def extract_flat_spans(start_pred, end_pred, match_pred,  pseudo_tag = "TAG"):
    """
    Extract flat (non-overlapping) spans from start/end/match predictions.

    Every predicted start is labelled "B" and every predicted end "E"; an end
    overwrites a start at the same position. Then, for each start in order,
    the nearest end at or after it is taken. If ``match_pred[start][end]`` is
    truthy, the positions strictly between them become "M", or the position
    becomes "S" when start == end. Later writes overwrite earlier labels.
    The resulting BMES sequence is decoded with ``bmes_decode``.

    Args:
        start_pred: [seq_len], truthy for start, falsy for non-start
        end_pred: [seq_len], truthy for end, falsy for non-end
        match_pred: [seq_len, seq_len], truthy at [i][j] when (i, j) is a span
        pseudo_tag: entity type used for the intermediate BMES labels
    Returns:
        list of (start, end) tuples with an inclusive end index
    Examples:
        >>> start_pred = [0, 1]
        >>> end_pred = [0, 1]
        >>> match_pred = [[0, 0], [0, 1]]
        >>> extract_flat_spans(start_pred, end_pred, match_pred)
        [(1, 1)]
    """
    pseudo_input = "a"

    bmes_labels = ["O"] * len(start_pred)
    start_positions = [idx for idx, flag in enumerate(start_pred) if flag]
    end_positions = [idx for idx, flag in enumerate(end_pred) if flag]

    for start_item in start_positions:
        bmes_labels[start_item] = f"B-{pseudo_tag}"
    for end_item in end_positions:
        bmes_labels[end_item] = f"E-{pseudo_tag}"

    for tmp_start in start_positions:
        # end_positions is ascending, so the first end >= start is the nearest one.
        tmp_end = next((pos for pos in end_positions if pos >= tmp_start), None)
        if tmp_end is None:
            continue
        if match_pred[tmp_start][tmp_end]:
            if tmp_start != tmp_end:
                for i in range(tmp_start + 1, tmp_end):
                    bmes_labels[i] = f"M-{pseudo_tag}"
            else:
                bmes_labels[tmp_end] = f"S-{pseudo_tag}"

    tags = bmes_decode([(pseudo_input, label) for label in bmes_labels])

    # bmes_decode returns an exclusive end; convert it to inclusive.
    return [(entity.begin, entity.end - 1) for entity in tags]


def extract_flat_spans_yi(start_pred, end_pred, match_pred):
    """
    Extract flat (non-overlapping) spans from start/end/match predictions.

    Every predicted start is labelled "B" and every predicted end "E"; an end
    overwrites a start at the same position. Then, for each start in order,
    the nearest end at or after it is taken. If ``match_pred[start][end]`` is
    truthy, the positions strictly between them become "I", or the position
    becomes "S" when start == end. The labels are mapped through ``tags2id``
    and decoded with ``get_arg_span``.

    Args:
        start_pred: [seq_len], truthy for start, falsy for non-start
        end_pred: [seq_len], truthy for end, falsy for non-end
        match_pred: [seq_len, seq_len], truthy at [i][j] when (i, j) is a span
    Returns:
        list of (start, end) tuples with an inclusive end index
    Examples:
        >>> start_pred = [0, 1]
        >>> end_pred = [0, 1]
        >>> match_pred = [[0, 0], [0, 1]]
        >>> extract_flat_spans_yi(start_pred, end_pred, match_pred)
        [(1, 1)]
    """
    bmes_labels = ["O"] * len(start_pred)
    start_positions = [idx for idx, flag in enumerate(start_pred) if flag]
    end_positions = [idx for idx, flag in enumerate(end_pred) if flag]

    for start_item in start_positions:
        bmes_labels[start_item] = "B"
    for end_item in end_positions:
        bmes_labels[end_item] = "E"

    for tmp_start in start_positions:
        # end_positions is ascending, so the first end >= start is the nearest one.
        tmp_end = next((pos for pos in end_positions if pos >= tmp_start), None)
        if tmp_end is None:
            continue
        if match_pred[tmp_start][tmp_end]:
            if tmp_start != tmp_end:
                for i in range(tmp_start + 1, tmp_end):
                    bmes_labels[i] = "I"
            else:
                bmes_labels[tmp_end] = "S"

    return get_arg_span([tags2id[label] for label in bmes_labels])

# Model Architecture

In [5]:
class BERT_BiLSTM_CRF(nn.Module):

    def __init__(self, config):
        """Build the Longformer encoder, the sentence-level LSTM and the span heads.

        Args:
            config: namespace such as the one returned by get_config(). Reads
                layers, hidden_size, mlp_size, dropout, scale_factor,
                bert_output_size, is_bi, mrc_dropout, weight_start, weight_end
                and weight_span.
        """
        super().__init__()
        # NOTE(review): stored, but the LSTM below is always built with num_layers=1.
        self.layers = config.layers
        self.hidden_size = config.hidden_size
        self.mlp_size = config.mlp_size
        self.dropout = nn.Dropout(p=config.dropout)
        self.scale_factor = config.scale_factor

        self.special_tokens_yi = ['[TAB]', '[LINE]',
                                  '[EQU]', '[URL]', '[NUM]',
                                  '[SPE]', '<sep>', '[q]']
        self.special_tokens_dict_yi = {'additional_special_tokens': self.special_tokens_yi}

        # Register the extra special tokens; the embedding matrix is resized below to match.
        self.longtokenizer = LongformerTokenizer.from_pretrained(
            'allenai/longformer-base-4096')
        self.longtokenizer.add_special_tokens(self.special_tokens_dict_yi)

        self.longformerconfig = LongformerConfig.from_pretrained(
            'allenai/longformer-base-4096')
        # NOTE(review): `attention_mode` may not be read by the transformers
        # LongformerModel — confirm it has any effect.
        self.longformerconfig.attention_mode = 'sliding_chunks'
        # Per-layer attention window sizes (one entry per encoder layer).
        self.longformerconfig.attention_window = [8] * 12
        self.attentionwindow = self.longformerconfig.attention_window[0]

        self.longformer = LongformerModel.from_pretrained(
            'allenai/longformer-base-4096', config=self.longformerconfig)

        self.longformer.resize_token_embeddings(len(self.longtokenizer))

        # A bidirectional LSTM concatenates both directions, doubling its
        # output width; the heads below must be sized accordingly.
        lstm_output_size = config.hidden_size * (2 if config.is_bi else 1)

        self.am_bilstm = nn.LSTM(config.bert_output_size, config.hidden_size,
                                 num_layers=1, bidirectional=config.is_bi, batch_first=True)

        self.start_outputs = nn.Linear(lstm_output_size, 1)
        self.end_outputs = nn.Linear(lstm_output_size, 1)
        # A (start, end) pair is represented by concatenating both LSTM states.
        self.span_embedding = MultiNonLinearClassifier(lstm_output_size * 2, 1, config.mrc_dropout,
                                                       intermediate_hidden_size=128)

        # Which (start, end) pairs contribute to the match loss; alternatives
        # handled in compute_loss: "pred_and_gold", "pred_gold_random", "gold".
        self.span_loss_candidates = "all"

        self.bce_loss = BCEWithLogitsLoss(reduction="none")

        self.weight_start = config.weight_start
        self.weight_end = config.weight_end
        self.weight_span = config.weight_span


    def compute_loss(self, start_logits, end_logits, span_logits,
                     start_labels, end_labels, match_labels):

        batch_size, seq_len = start_logits.size()

        start_l_mask = torch.ones([batch_size, seq_len],dtype=torch.long)
        end_l_mask =  torch.ones([batch_size, seq_len],dtype=torch.long)

        start_label_mask = torch.LongTensor(start_l_mask).cuda()
        end_label_mask = torch.LongTensor(end_l_mask).cuda()

        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
        match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
        match_label_mask = match_label_row_mask & match_label_col_mask
        match_label_mask = torch.triu(match_label_mask, 0)  

        if self.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size, -1).float()
        else:
            start_preds = start_logits > 0
            end_preds = end_logits > 0
            if self.span_loss_candidates == "gold":
                match_candidates = ((start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
            elif self.span_loss_candidates == "pred_gold_random":
                gold_and_pred = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
                )
                data_generator = torch.Generator()
                data_generator.manual_seed(0)
                random_matrix = torch.empty(batch_size, seq_len, seq_len).uniform_(0, 1)
                random_matrix = torch.bernoulli(random_matrix, generator=data_generator).long()
                random_matrix = random_matrix.cuda()
                match_candidates = torch.logical_or(
                    gold_and_pred, random_matrix
                )
            else:
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
                )
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size, -1).float()

        start_loss = self.bce_loss(start_logits.view(-1), start_labels.view(-1).float())
        start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()
        end_loss = self.bce_loss(end_logits.view(-1), end_labels.view(-1).float())
        end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()
        match_loss = self.bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())
        match_loss = match_loss * float_match_label_mask
        match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)

        return start_loss, end_loss, match_loss

    def bert_emb_for_task1(self, para_tokens_list):
        """Encode paragraphs with Longformer and mean-pool one vector per sentence.

        The encoder input is ``<cls> [q] <sep>`` followed by
        ``<cls> sentence words... <sep>``. Words come from splitting each
        sentence on single spaces and are mapped directly with
        ``convert_tokens_to_ids``.

        Inputs are placed on the Longformer's device. Encoder errors propagate
        unchanged rather than being printed and swallowed.

        Args:
            para_tokens_list: list of paragraphs, each a list of space-separated
                sentence strings. All sentences are concatenated into one sequence.
        Returns:
            Tensor of shape (total number of sentences, hidden) after dropout.
        """
        device = next(self.longformer.parameters()).device

        question_tokens = '[q]'
        question_tokens_cls_sep = self.longtokenizer.cls_token + ' ' + question_tokens + ' ' + self.longtokenizer.sep_token
        question_ids = self.longtokenizer.convert_tokens_to_ids(question_tokens_cls_sep.split(' '))
        question_length = len(question_ids)

        sent_tokens_list = [sent for para in para_tokens_list for sent in para]
        # Number of words (and therefore ids) contributed by each sentence.
        sent_length_list = [len(sent.split(' ')) for sent in sent_tokens_list]
        passage_tokens = ' '.join(sent_tokens_list)
        passage_tokens_cls_sep = self.longtokenizer.cls_token + ' ' + passage_tokens + ' ' + self.longtokenizer.sep_token

        sent_ids = self.longtokenizer.convert_tokens_to_ids(passage_tokens_cls_sep.split(' '))

        question_sents_ids = [question_ids + sent_ids]  # a batch of one sequence

        qs_ids_padding_list, qs_mask_list, max_len = self.padding_and_mask_with_return(question_sents_ids)
        qs_ids_padding_tensor = torch.tensor(qs_ids_padding_list, device=device)
        qs_mask_tensor = torch.tensor(qs_mask_list, device=device)

        # Presumably marks the question positions for global attention (helper not shown here).
        _, global_att_mask = self.padding_and_mask_to_max_lenth([question_ids], max_len)
        global_att_mask_tensor = torch.tensor(global_att_mask, device=device)

        longformer_outputs = self.longformer(qs_ids_padding_tensor,
                                             attention_mask=qs_mask_tensor,
                                             global_attention_mask=global_att_mask_tensor)

        last_hidden_state = longformer_outputs[0]
        # Skip the question block and the passage's leading <cls>; drop the last
        # position (the passage <sep>, assuming no padding was added — TODO confirm).
        last_hidden_state_for_s = last_hidden_state[:, (question_length + 1):-1, :]

        # Mean-pool the hidden states of each sentence's words.
        sen_emb_list = []
        start_index = 0
        for sent_len in sent_length_list:
            end_index = start_index + sent_len
            sen_emb_list.append(last_hidden_state_for_s[:, start_index:end_index, :].mean(dim=-2))
            start_index = end_index

        sent_emb = torch.cat(sen_emb_list, dim=0)
        return self.dropout(sent_emb)

    def bert_emb_for_task2(self, para_tokens_list,argument_para_tokens_list):
        """Encode paragraphs conditioned on an argument; mean-pool one vector per sentence.

        The encoder input is ``<cls> argument words <sep>`` followed by
        ``<sep> sentence words... <sep>``. Words come from splitting on single
        spaces and are mapped directly with ``convert_tokens_to_ids``.

        Inputs are placed on the Longformer's device. Encoder errors propagate
        unchanged rather than being printed and swallowed.

        Args:
            para_tokens_list: list of paragraphs, each a list of sentence strings.
            argument_para_tokens_list: list of sentence strings forming the argument.
        Returns:
            Tensor of shape (total number of sentences, hidden) after dropout.
        """
        device = next(self.longformer.parameters()).device

        argument_tokens = ' '.join(argument_para_tokens_list)
        argument_tokens_cls_sep = self.longtokenizer.cls_token + ' ' + argument_tokens + ' ' + self.longtokenizer.sep_token
        argument_ids = self.longtokenizer.convert_tokens_to_ids(argument_tokens_cls_sep.split(' '))
        argument_length = len(argument_ids)

        sent_tokens_list = [sent for para in para_tokens_list for sent in para]
        # Number of words (and therefore ids) contributed by each sentence.
        sent_length_list = [len(sent.split(' ')) for sent in sent_tokens_list]
        passage_tokens = ' '.join(sent_tokens_list)
        # The second segment opens with <sep> (unlike task 1, which uses <cls>),
        # presumably mirroring RoBERTa's pair layout.
        passage_tokens_cls_sep = self.longtokenizer.sep_token + ' ' + passage_tokens + ' ' + self.longtokenizer.sep_token
        sent_ids = self.longtokenizer.convert_tokens_to_ids(passage_tokens_cls_sep.split(' '))

        pair_ids = [argument_ids + sent_ids]  # a batch of one sequence

        pair_ids_padding_list, pair_mask_list, max_len = self.padding_and_mask_with_return(pair_ids)
        pair_ids_padding_tensor = torch.tensor(pair_ids_padding_list, device=device)
        pair_mask_tensor = torch.tensor(pair_mask_list, device=device)

        # Presumably marks the argument positions for global attention (helper not shown here).
        _, global_att_mask = self.padding_and_mask_to_max_lenth([argument_ids], max_len)
        global_att_mask_tensor = torch.tensor(global_att_mask, device=device)

        longformer_outputs = self.longformer(pair_ids_padding_tensor,
                                             attention_mask=pair_mask_tensor,
                                             global_attention_mask=global_att_mask_tensor)

        last_hidden_state = longformer_outputs[0]
        # Skip the argument block and the passage's leading <sep>; drop the last
        # position (the passage <sep>, assuming no padding was added — TODO confirm).
        last_hidden_state_for_s = last_hidden_state[:, (argument_length + 1):-1, :]

        # Mean-pool the hidden states of each sentence's words.
        sen_emb_list = []
        start_index = 0
        for sent_len in sent_length_list:
            end_index = start_index + sent_len
            sen_emb_list.append(last_hidden_state_for_s[:, start_index:end_index, :].mean(dim=-2))
            start_index = end_index

        sent_emb = torch.cat(sen_emb_list, dim=0)
        return self.dropout(sent_emb)




    def am_tagging_span_for_task1(self, para_tokens_list, mode):





        sent_num_list = [len(para) for para in para_tokens_list]
        sent_emb = self.bert_emb_for_task1(para_tokens_list) 

        para_emb = torch.split(sent_emb, sent_num_list, 0)

        para_emb_packed = pack_sequence(para_emb, enforce_sorted=False)
        para_lstm_out_packed, (h, c) = self.am_bilstm(para_emb_packed)
        para_lstm_out_padded, _ = pad_packed_sequence(para_lstm_out_packed, batch_first=True)  


        para_lstm_out = para_lstm_out_padded
        batch_size, seq_len, hid_size = para_lstm_out.size()

        start_logits = self.start_outputs(para_lstm_out).squeeze(
            -1)  
        end_logits = self.end_outputs(para_lstm_out).squeeze(-1)  

        start_extend = para_lstm_out.unsqueeze(2).expand(-1, -1, seq_len, -1)

        end_extend = para_lstm_out.unsqueeze(1).expand(-1, seq_len, -1, -1)

        span_matrix = torch.cat([start_extend, end_extend], 3)


        span_logits = self.span_embedding(span_matrix).squeeze(-1) 



        return start_logits,end_logits,span_logits

    def am_tagging_span_for_task2(self, rev_para_tokens_list,rep_para_tokens_list, arg_pair_sems_list,mode='train'):


        sent_num_list = [len(para) for para in rep_para_tokens_list]

        arg_num_list = []


        temp_arg_list = []
        for batch_i, pred_arguments_labeldict in enumerate(arg_pair_sems_list):  
            for rev_arg_span, label_dict in pred_arguments_labeldict.items():



                temp_argu_o = rev_para_tokens_list[batch_i][rev_arg_span[0]:rev_arg_span[1] + 1] 
                temp_arg_list.append(temp_argu_o)
                arg_num_list.append(sent_num_list[0])

        para_lstm_out_list = []

        for arg in temp_arg_list:
            sent_emb = self.bert_emb_for_task2(rep_para_tokens_list, arg)
            para_emb = torch.split(sent_emb, sent_num_list, 0)
            para_emb_packed = pack_sequence(para_emb, enforce_sorted=False)

            para_lstm_out_packed, (h, c) = self.am_bilstm(para_emb_packed)
            para_lstm_out_padded, _ = pad_packed_sequence(para_lstm_out_packed, batch_first=True)
            para_lstm_out_list.append(para_lstm_out_padded)

        try:
            lstm_out_cat = torch.cat(para_lstm_out_list, dim=0)
        except:
            import traceback
            traceback.print_exc()

        para_lstm_out = lstm_out_cat 
        batch_size, seq_len, hid_size = para_lstm_out.size()


        start_logits = self.start_outputs(para_lstm_out).squeeze(
            -1) 
        end_logits = self.end_outputs(para_lstm_out).squeeze(-1) 

        start_extend = para_lstm_out.unsqueeze(2).expand(-1, -1, seq_len, -1)

        end_extend = para_lstm_out.unsqueeze(1).expand(-1, seq_len, -1, -1)

        span_matrix = torch.cat([start_extend, end_extend], 3)

        span_logits = self.span_embedding(span_matrix).squeeze(-1) 
        return start_logits, end_logits, span_logits


    def am_tagging_with_task1_for_task2(self, rev_para_tokens_list, rep_para_tokens_list, arg_list_from_task1,
                              mode='train'):
        """Score paired-argument spans in the target passages for each task-1 argument.

        For every argument span predicted by task 1 on a source passage, the
        argument's sentences are encoded together with the target passages by
        ``self.bert_emb_for_task2``, run through ``self.am_bilstm`` and scored by
        the start / end / span heads.

        Args:
            rev_para_tokens_list: source passages; ``rev_para_tokens_list[b]`` is
                the list of sentences of batch item ``b``.
            rep_para_tokens_list: target passages (one list of sentences each);
                all of them are encoded for every argument.
            arg_list_from_task1: ``arg_list_from_task1[b]`` holds the inclusive
                ``(start, end)`` sentence spans predicted on source passage ``b``.
            mode: unused; kept for interface compatibility.

        Returns:
            (start_logits, end_logits, span_logits). The leading dimension is
            (number of arguments) x (number of target passages);
            NOTE(review): callers in this file only use batch size 1, giving
            one row per argument.

        Raises:
            ValueError: if ``arg_list_from_task1`` contains no argument span.
        """
        sent_num_list = [len(para) for para in rep_para_tokens_list]

        # Slice every argument out of the passage it was predicted on (indexing
        # passage 0 is only correct when the batch holds a single passage).
        arg_sentences_list = [
            rev_para_tokens_list[batch_i][span[0]:span[1] + 1]
            for batch_i, pred_arguments in enumerate(arg_list_from_task1)
            for span in pred_arguments
        ]
        if not arg_sentences_list:
            raise ValueError("am_tagging_with_task1_for_task2 needs at least one argument span")

        para_lstm_out_list = []
        for arg_sentences in arg_sentences_list:
            sent_emb = self.bert_emb_for_task2(rep_para_tokens_list, arg_sentences)

            para_emb = torch.split(sent_emb, sent_num_list, 0)
            para_emb_packed = pack_sequence(para_emb, enforce_sorted=False)

            para_lstm_out_packed, _ = self.am_bilstm(para_emb_packed)
            para_lstm_out_padded, _ = pad_packed_sequence(para_lstm_out_packed, batch_first=True)
            para_lstm_out_list.append(para_lstm_out_padded)

        # Every argument is encoded against the same target passages, so all
        # padded outputs share one sequence length and concatenate on dim 0.
        para_lstm_out = torch.cat(para_lstm_out_list, dim=0)
        seq_len = para_lstm_out.size(1)

        start_logits = self.start_outputs(para_lstm_out).squeeze(-1)
        end_logits = self.end_outputs(para_lstm_out).squeeze(-1)

        # Pair every start position i with every end position j:
        # span_matrix[b, i, j] = concat(h_i, h_j).
        start_extend = para_lstm_out.unsqueeze(2).expand(-1, -1, seq_len, -1)
        end_extend = para_lstm_out.unsqueeze(1).expand(-1, seq_len, -1, -1)
        span_matrix = torch.cat([start_extend, end_extend], 3)
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits

    def forward(self, para_tokens_list_o, para_tokens_list_for_2_o, rr_arg_pair_list_o,
                match_labels_o, start_labels_o, end_labels_o, tag_list_o):
        """Mean weighted span-extraction loss over a mixed-task batch.

        Samples are processed one at a time and ``tag_list_o[k]`` selects the head:
        "task1_review" / "task1_reply" call ``am_tagging_span_for_task1`` on
        ``para_tokens_list_o[k]``. "task2_review" calls
        ``am_tagging_span_for_task2`` with the two passages in the given order.
        "task2_reply" calls it with the two passages swapped.

        Returns:
            The mean over samples of
            ``weight_start * start_loss + weight_end * end_loss + weight_span * match_loss``.

        Raises:
            ValueError: on an unknown sample tag.
        """
        total_loss_all = 0
        batch = zip(para_tokens_list_o, para_tokens_list_for_2_o, rr_arg_pair_list_o, tag_list_o,
                    start_labels_o, end_labels_o, match_labels_o)
        for para_tokens, para_tokens_2, rr_arg_pair, tag, start_labels, end_labels, match_labels in batch:
            para_tokens_list = [para_tokens]
            para_tokens_list_for_2 = [para_tokens_2]
            rr_arg_pair_list = [rr_arg_pair]

            if tag in ("task1_review", "task1_reply"):
                start_logits, end_logits, span_logits = self.am_tagging_span_for_task1(
                    para_tokens_list, mode="train")
            elif tag == "task2_review":
                start_logits, end_logits, span_logits = self.am_tagging_span_for_task2(
                    para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list, mode="train")
            elif tag == "task2_reply":
                start_logits, end_logits, span_logits = self.am_tagging_span_for_task2(
                    para_tokens_list_for_2, para_tokens_list, rr_arg_pair_list, mode="train")
            else:
                # Otherwise the previous sample's logits would be silently reused.
                raise ValueError(f"unknown sample tag: {tag!r}")

            # Labels are unbatched (the loaders build them as CPU tensors): add a
            # batch dim and put them on the logits' device instead of forcing CUDA.
            device = start_logits.device
            start_loss, end_loss, match_loss = self.compute_loss(
                start_logits, end_logits, span_logits,
                start_labels.expand(1, -1).to(device),
                end_labels.expand(1, -1).to(device),
                match_labels.expand(1, -1, -1).to(device))
            total_loss_all = total_loss_all + (self.weight_start * start_loss
                                               + self.weight_end * end_loss
                                               + self.weight_span * match_loss)

        return total_loss_all / len(para_tokens_list_o)

    def predict_span(self, review_para_tokens_list, review_tags_list,
                     reply_para_tokens_list, reply_tags_list, threshold=0.5):
        """Predict arguments (task 1) and their cross-passage pairs (task 2).

        Args:
            review_para_tokens_list: review passages, one list of sentences each.
            review_tags_list: unused; kept for interface compatibility.
            reply_para_tokens_list: reply passages, one list of sentences each.
            reply_tags_list: unused; kept for interface compatibility.
            threshold: a logit counts as positive when sigmoid(logit) > threshold
                (default 0.5, the previously hard-coded value).

        Returns:
            pred_rev_args_dict: boolean 'review_start_preds' / 'review_end_preds'
                / 'review_span_preds' tensors for the review passages.
            pred_rep_args_dict: the same for the reply passages ('reply_*' keys).
            pred_args_pair_dict_list: per review passage, a dict mapping each
                predicted review argument to (spans extracted from the pairing
                head for it, [1] * number of those spans).
            pred_args_pair_dict_2_list: the same for reply arguments.
        """
        # Task 1: argument spans in each passage.
        review_start_preds, review_end_preds, review_span_preds = self._binarize_span_logits(
            self.am_tagging_span_for_task1(review_para_tokens_list, mode="test"), threshold)
        pred_rev_args_dict = {
            'review_start_preds': review_start_preds,
            'review_end_preds': review_end_preds,
            'review_span_preds': review_span_preds,
        }

        reply_start_preds, reply_end_preds, reply_span_preds = self._binarize_span_logits(
            self.am_tagging_span_for_task1(reply_para_tokens_list, mode="test"), threshold)
        pred_rep_args_dict = {
            'reply_start_preds': reply_start_preds,
            'reply_end_preds': reply_end_preds,
            'reply_span_preds': reply_span_preds,
        }

        pred_rev_args_list = extract_span_arguments_yi(review_span_preds, review_start_preds, review_end_preds)
        pred_rep_args_list = extract_span_arguments_yi(reply_span_preds, reply_start_preds, reply_end_preds)

        # Task 2: pair every predicted argument with spans in the other passage.
        pred_args_pair_dict_list = self._predict_arg_pairs_for_direction(
            review_para_tokens_list, reply_para_tokens_list, pred_rev_args_list, threshold)
        pred_args_pair_dict_2_list = self._predict_arg_pairs_for_direction(
            reply_para_tokens_list, review_para_tokens_list, pred_rep_args_list, threshold)

        return pred_rev_args_dict, pred_rep_args_dict, pred_args_pair_dict_list, pred_args_pair_dict_2_list

    @staticmethod
    def _binarize_span_logits(logits, threshold=0.5):
        """Turn a (start, end, span) logits tuple into boolean predictions (sigmoid > threshold)."""
        return tuple(torch.sigmoid(t) > threshold for t in logits)

    def _predict_arg_pairs_for_direction(self, src_para_tokens_list, tgt_para_tokens_list,
                                         src_args_list, threshold=0.5):
        """Predict, for every source argument, its paired spans in the target passage.

        Returns one dict per source passage mapping each argument span to
        (list of paired target spans, [1] * number of paired spans).
        """
        num_args = sum(1 for args in src_args_list for _ in args)
        if num_args == 0:
            return [{} for _ in src_args_list]

        try:
            logits = self.am_tagging_with_task1_for_task2(
                src_para_tokens_list, tgt_para_tokens_list, src_args_list, mode="test")
        except Exception:
            # Best effort: report the failure and predict no pairs, rather than
            # falling through to stale task-1 predictions (meaningless pairs).
            import traceback
            traceback.print_exc()
            return [{arg: ([], []) for arg in args} for args in src_args_list]

        start_preds, end_preds, span_preds = self._binarize_span_logits(logits, threshold)
        # Assumes one output row per source argument, in flattened order
        # (true for batch size 1, which is how this is called).
        paired_spans_list = extract_span_arguments_yi(list(span_preds), list(start_preds), list(end_preds))

        pair_dict_list = []
        flat_i = 0
        for args in src_args_list:
            pair_dict = {}
            for arg in args:
                paired_spans = paired_spans_list[flat_i]
                pair_dict[arg] = (paired_spans, [1] * len(paired_spans))
                flat_i += 1
            pair_dict_list.append(pair_dict)
        return pair_dict_list

    def predict_span_for_task1(self, review_para_tokens_list, review_tags_list,
                               reply_para_tokens_list, reply_tags_list):
        """Run only the task-1 argument tagger on the review and reply passages.

        The tag lists are accepted for interface compatibility but not used.

        Returns:
            (pred_rev_args_dict, pred_rep_args_dict): boolean start / end / span
            prediction tensors (sigmoid(logit) > 0.5), keyed 'review_*' and
            'reply_*' respectively.
        """
        def to_predictions(para_tokens_list):
            logits = self.am_tagging_span_for_task1(para_tokens_list, mode="test")
            return [torch.sigmoid(t) > 0.5 for t in logits]

        rev_start, rev_end, rev_span = to_predictions(review_para_tokens_list)
        rep_start, rep_end, rep_span = to_predictions(reply_para_tokens_list)

        pred_rev_args_dict = {'review_start_preds': rev_start,
                              'review_end_preds': rev_end,
                              'review_span_preds': rev_span}
        pred_rep_args_dict = {'reply_start_preds': rep_start,
                              'reply_end_preds': rep_end,
                              'reply_span_preds': rep_span}
        return pred_rev_args_dict, pred_rep_args_dict




    def padding_and_mask(self, ids_list, pad_id=0):
        """Right-pad id lists to the longest one and build float attention masks.

        Args:
            ids_list: list of token-id lists (not modified).
            pad_id: id written into padding positions (default 0, as before).

        Returns:
            (ids_padding_list, mask_list): padded id lists, and masks with 1. for
            real tokens and 0. for padding. An empty ``ids_list`` yields ([], [])
            instead of raising on ``max`` of an empty sequence.
        """
        max_len = max((len(ids) for ids in ids_list), default=0)
        ids_padding_list = []
        mask_list = []
        for ids in ids_list:
            n_pad = max_len - len(ids)
            mask_list.append([1.] * len(ids) + [0.] * n_pad)
            ids_padding_list.append(ids + [pad_id] * n_pad)
        return ids_padding_list, mask_list
    def padding_and_mask_with_return(self, ids_list, pad_id=0):
        """Right-pad id lists to the longest one; also return that length.

        Args:
            ids_list: list of token-id lists (not modified).
            pad_id: id written into padding positions (default 0, as before).

        Returns:
            (ids_padding_list, mask_list, max_len): padded id lists, float masks
            (1. real token, 0. padding) and the padded length. An empty
            ``ids_list`` yields ([], [], 0).
        """
        max_len = max((len(ids) for ids in ids_list), default=0)
        ids_padding_list = []
        mask_list = []
        for ids in ids_list:
            n_pad = max_len - len(ids)
            mask_list.append([1.] * len(ids) + [0.] * n_pad)
            ids_padding_list.append(ids + [pad_id] * n_pad)
        return ids_padding_list, mask_list, max_len

    def padding_and_mask_to_max_lenth(self, ids_list, max_len):
        """Right-pad every id list with 0 up to ``max_len`` and build float masks.

        Lists already longer than ``max_len`` are returned untruncated, with an
        all-ones mask of their own length.
        """
        rows = list(ids_list)
        gaps = [max_len - len(seq) for seq in rows]
        padded = [seq + [0] * gap for seq, gap in zip(rows, gaps)]
        masks = [[1.] * len(seq) + [0.] * gap for seq, gap in zip(rows, gaps)]
        return padded, masks


    def padding_matrix(self, matrix_tensor_list):
        """Zero-pad label matrices to the largest row count and stack them.

        Assumes (presumably) square seq x seq matrices, giving a result of shape
        (len(matrix_tensor_list), max_seq, max_seq). Added rows are built as
        ``torch.long`` zeros.
        """
        max_seq_len = max(matrix.size()[0] for matrix in matrix_tensor_list)

        def pad_one(matrix):
            missing_rows = max_seq_len - matrix.size()[0]
            if missing_rows <= 0:
                return matrix
            rows = list(matrix.unbind(0))
            rows.extend(torch.zeros(max_seq_len, dtype=torch.long) for _ in range(missing_rows))
            # pad_sequence right-pads every row to the longest one (max_seq_len).
            return pad_sequence(rows, batch_first=True)

        return torch.stack([pad_one(matrix) for matrix in matrix_tensor_list])


# Data Pre-processing

In [6]:
# Sentence-level B/I/E/S/O tags -> ids. Review-, reply- and unprefixed tags
# share the same ids, so both passages use one 5-class label space.
tags2id = {
    'O': 0,
    # review passage
    'B-Review': 1, 'I-Review': 2, 'E-Review': 3, 'S-Review': 4,
    # reply passage
    'B-Reply': 1, 'I-Reply': 2, 'E-Reply': 3, 'S-Reply': 4,
    # unprefixed
    'B': 1, 'I': 2, 'E': 3, 'S': 4,
}

def _read_rr_pair_blocks(file_path):
    """Split a data file into (review_text, reply_text) pairs.

    Consecutive review/reply pairs are separated by two blank lines; a review
    is separated from its reply by a single blank line.

    Raises:
        ValueError: if a pair block does not split into exactly two passages.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as fp:
        rr_pair_list = fp.read().split('\n\n\n')
    pairs = []
    for rr_pair in rr_pair_list:
        if rr_pair == '':
            continue
        review, reply = rr_pair.split('\n\n')
        pairs.append((review, reply))
    return pairs


def _rr_span_labels(spans, seq_len):
    """Build span-extraction targets for ``spans`` over ``seq_len`` sentences.

    Returns:
        (match_labels, start_labels, end_labels): an int64 ``seq_len x seq_len``
        tensor with 1 at ``[start, end]`` of each span, and int64 length-``seq_len``
        tensors with 1 at each span start / end.
    """
    match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
    start_labels = torch.zeros(seq_len, dtype=torch.long)
    end_labels = torch.zeros(seq_len, dtype=torch.long)
    for start, end in spans:
        # Start and end are marked independently: a span whose end lies past the
        # passage is left out of match_labels but still marks its in-range start.
        if 0 <= start < seq_len:
            start_labels[start] = 1
        if 0 <= end < seq_len:
            end_labels[end] = 1
        if start < seq_len and end < seq_len:
            match_labels[start, end] = 1
    return match_labels, start_labels, end_labels


def _parse_rr_passage(passage_text):
    """Parse one review or reply passage and attach its task-1 span labels.

    Each line holds five tab-separated fields: sentence, BIO-style tag, pair
    tag, text type and sub id.

    Returns:
        dict with 'sentences', 'bio_tags', 'pair_tags' (one entry per sentence),
        'text_type', 'sub_ids', 'arg_spans' (from ``get_arg_span``) and
        'match_labels' / 'start_labels' / 'end_labels' for those spans.
    """
    sample = {'sentences': [], 'bio_tags': [],
              'pair_tags': [], 'text_type': None, 'sub_ids': [], 'arg_spans': []}
    for line in passage_text.strip().split('\n'):
        sent, bio_tag, pair_tag, text_type, sub_id = line.strip().split('\t')
        sample['sentences'].append(sent)
        sample['bio_tags'].append(bio_tag)
        sample['pair_tags'].append(pair_tag)
        sample['text_type'] = text_type
        # NOTE(review): this replaces the initial list, so only the last line's
        # sub id is kept. Preserved as-is; confirm whether a per-sentence list
        # was intended.
        sample['sub_ids'] = sub_id
    sample['arg_spans'] = get_arg_span([tags2id[t] for t in sample['bio_tags']])
    (sample['match_labels'], sample['start_labels'],
     sample['end_labels']) = _rr_span_labels(sample['arg_spans'], len(sample['bio_tags']))
    return sample


def _rr_pair_id(passage, span):
    """Pair id of an argument span: the integer after the last '-' in its first sentence's pair tag."""
    return int(passage['pair_tags'][span[0]].split('-')[-1])


def _rr_link_arguments(src, tgt):
    """Map every argument span of ``src`` to the spans of ``tgt`` that share its pair id."""
    linked = {}
    for src_span in src['arg_spans']:
        src_id = _rr_pair_id(src, src_span)
        linked[src_span] = [tgt_span for tgt_span in tgt['arg_spans']
                            if _rr_pair_id(tgt, tgt_span) == src_id]
    return linked


def _rr_pair_label_dict(arg_links, target_len):
    """Per source argument, the match/start/end labels of its linked target spans."""
    label_dict = {}
    for src_span, tgt_spans in arg_links.items():
        match_labels, start_labels, end_labels = _rr_span_labels(tgt_spans, target_len)
        label_dict[src_span] = {'match_labels': match_labels,
                                'start_labels': start_labels,
                                'end_labels': end_labels}
    return label_dict


def _rr_direction_samples(arg_links, sample_review, sample_reply, target_len, tag):
    """Build one task-2 training sample per source argument in ``arg_links``."""
    samples = []
    for src_span, tgt_spans in arg_links.items():
        match_labels, start_labels, end_labels = _rr_span_labels(tgt_spans, target_len)
        samples.append({
            'review_sentences': sample_review['sentences'],
            'reply_sentences': sample_reply['sentences'],
            'match_labels': match_labels,
            'start_labels': start_labels,
            'end_labels': end_labels,
            'tag': tag,
            'rr_arg_dict': {src_span: spans_to_tags(tgt_spans, target_len)},
        })
    return samples


def load_data_new_sample(file_path):
    """Load a review/reply file as flat per-task samples for multi-task training.

    Returns:
        (task1_review, task1_reply, task2_review_dir, task2_reply_dir):
          * task-1 samples are parsed passages tagged "task1_review" /
            "task1_reply" with labels for their own argument spans. Review
            samples also carry 'rev_arg_2_rep_arg_dict', and reply samples
            carry 'rep_arg_2_rev_arg_dict' (argument span -> paired spans).
          * task-2 samples (tags "task2_review" / "task2_reply") are one per
            source argument. Each holds both passages' sentences, labels for
            the paired spans in the other passage, and 'rr_arg_dict'
            ({source span: spans_to_tags(paired spans, target length)}).
    """
    sample_list_task1_for_review = []
    sample_list_task1_for_reply = []
    sample_list_task2_for_review_dir = []
    sample_list_task2_for_reply_dir = []

    for review_text, reply_text in _read_rr_pair_blocks(file_path):
        sample_review = _parse_rr_passage(review_text)
        sample_review['tag'] = "task1_review"
        sample_list_task1_for_review.append(sample_review)

        sample_reply = _parse_rr_passage(reply_text)
        sample_reply['tag'] = "task1_reply"
        sample_list_task1_for_reply.append(sample_reply)

        # Review argument -> paired reply spans; labels are over the reply.
        rev_arg_2_rep_arg_dict = _rr_link_arguments(sample_review, sample_reply)
        sample_review['rev_arg_2_rep_arg_dict'] = rev_arg_2_rep_arg_dict
        sample_list_task2_for_review_dir.extend(_rr_direction_samples(
            rev_arg_2_rep_arg_dict, sample_review, sample_reply,
            len(sample_reply['bio_tags']), "task2_review"))

        # Reply argument -> paired review spans; labels are over the review.
        rep_arg_2_rev_arg_dict = _rr_link_arguments(sample_reply, sample_review)
        sample_reply['rep_arg_2_rev_arg_dict'] = rep_arg_2_rev_arg_dict
        sample_list_task2_for_reply_dir.extend(_rr_direction_samples(
            rep_arg_2_rev_arg_dict, sample_review, sample_reply,
            len(sample_review['bio_tags']), "task2_reply"))

    return (sample_list_task1_for_review, sample_list_task1_for_reply,
            sample_list_task2_for_review_dir, sample_list_task2_for_reply_dir)


def load_data(file_path):
    """Load a review/reply file as one {'review': ..., 'reply': ...} dict per pair.

    Each passage dict holds the parsed fields and task-1 labels, plus, for the
    review (resp. reply):
      * 'rev_arg_2_rep_arg_dict' / 'rep_arg_2_rev_arg_dict': argument span ->
        list of paired spans in the other passage;
      * the '..._sem' dict: argument span -> {'match_labels', 'start_labels',
        'end_labels'} for those paired spans;
      * the '..._tags_dict': argument span -> spans_to_tags(paired spans,
        length of the other passage).
    """
    sample_list = []
    for review_text, reply_text in _read_rr_pair_blocks(file_path):
        sample_review = _parse_rr_passage(review_text)
        sample_reply = _parse_rr_passage(reply_text)
        rev_seq_len = len(sample_review['bio_tags'])
        rep_seq_len = len(sample_reply['bio_tags'])

        rev_arg_2_rep_arg_dict = _rr_link_arguments(sample_review, sample_reply)
        sample_review['rev_arg_2_rep_arg_dict'] = rev_arg_2_rep_arg_dict
        sample_review['rev_arg_2_rep_arg_dict_sem'] = _rr_pair_label_dict(rev_arg_2_rep_arg_dict, rep_seq_len)
        sample_review['rev_arg_2_rep_arg_tags_dict'] = {
            rev_arg_span: spans_to_tags(rep_arg_spans, rep_seq_len)
            for rev_arg_span, rep_arg_spans in rev_arg_2_rep_arg_dict.items()}

        rep_arg_2_rev_arg_dict = _rr_link_arguments(sample_reply, sample_review)
        sample_reply['rep_arg_2_rev_arg_dict'] = rep_arg_2_rev_arg_dict
        sample_reply['rep_arg_2_rev_arg_dict_sem'] = _rr_pair_label_dict(rep_arg_2_rev_arg_dict, rev_seq_len)
        sample_reply['rep_arg_2_rev_arg_tags_dict'] = {
            rep_arg_span: spans_to_tags(rev_arg_spans, rev_seq_len)
            for rep_arg_span, rev_arg_spans in rep_arg_2_rev_arg_dict.items()}

        sample_list.append({'review': sample_review,
                            'reply': sample_reply})
    return sample_list

# Evaluation

In [7]:
def args_metric(true_args_list, pred_args_list):
    """Micro-averaged precision / recall / F1 of predicted vs. gold argument spans.

    Args:
        true_args_list: per sample, the gold argument spans (hashable, no duplicates).
        pred_args_list: per sample, the predicted argument spans (no duplicates).

    Returns:
        dict with 'pre', 'rec', 'f1' and 'acc'. Ratios with a zero denominator
        are reported as 0.0. True negatives are not counted for span sets, so
        'acc' equals tp / (tp + fp + fn).
    """
    tp = fp = fn = 0
    for true_args, pred_args in zip(true_args_list, pred_args_list):
        true_args_set = set(true_args)
        pred_args_set = set(pred_args)
        # set() would silently collapse duplicated spans; reject them instead.
        assert len(true_args_set) == len(true_args)
        assert len(pred_args_set) == len(pred_args)
        tp += len(true_args_set & pred_args_set)
        fp += len(pred_args_set - true_args_set)
        fn += len(true_args_set - pred_args_set)

    def ratio(num, den):
        return num / den if den else 0.0

    pre = ratio(tp, tp + fp)
    rec = ratio(tp, tp + fn)
    f1 = ratio(2 * pre * rec, pre + rec)
    # Guarded: with no gold and no predicted arguments this used to divide by zero.
    acc = ratio(tp, tp + fp + fn)
    return {'pre': pre, 'rec': rec, 'f1': f1, 'acc': acc}


def _true_arg_pairs(sample):
    """Return the gold (review_arg, reply_arg) pairs of one sample.

    Built from ``sample['review']['rev_arg_2_rep_arg_dict']``, which maps each
    review argument to an iterable of aligned reply arguments.
    """
    return [(rev_arg, rep_arg)
            for rev_arg, rep_args in sample['review']['rev_arg_2_rep_arg_dict'].items()
            for rep_arg in rep_args]


def _pred_pairs_review_to_reply(pred_rev_2_rep):
    """Flatten one review->reply prediction dict into (review_arg, reply_arg) pairs.

    Each value is indexed as ``value[0]`` (candidate reply args) and ``value[1]``
    (presumably their probabilities; unused here). ``zip`` keeps only as many
    candidates as there are entries in ``value[1]``.
    """
    return [(rev_arg, rep_arg)
            for rev_arg, rep_args in pred_rev_2_rep.items()
            for rep_arg, _ in zip(rep_args[0], rep_args[1])]


def _pred_pairs_reply_to_review(pred_rep_2_rev):
    """Flatten one reply->review prediction dict into (review_arg, reply_arg) pairs.

    Mirror of ``_pred_pairs_review_to_reply``. Keys are reply args and
    ``value[0]`` holds candidate review args. Pairs are emitted in
    (review, reply) order so they are comparable with the gold pairs.
    """
    return [(rev_arg, rep_arg)
            for rep_arg, rev_args in pred_rep_2_rev.items()
            for rev_arg, _ in zip(rev_args[0], rev_args[1])]


def evaluate(model, data_list):
    """Compute argument-pair precision / recall / F1 of ``model`` on ``data_list``.

    Each sample is run through ``model.predict_span`` on its own (batch size 1)
    under ``torch.no_grad()``. A sample's predicted pairs are the de-duplicated
    union of the review->reply and reply->review predictions. Gold pairs come
    from ``sample['review']['rev_arg_2_rep_arg_dict']``. The model's previous
    train/eval mode is restored on exit, even if prediction raises.

    Args:
        model: module exposing ``predict_span(review_tokens, review_tags,
            reply_tokens, reply_tags)``, which returns four values. The last two
            are lists of per-sample pair-prediction dicts.
        data_list: list of samples, each with ``'review'`` and ``'reply'``
            sub-dicts holding ``'sentences'`` and ``'bio_tags'``.

    Returns:
        dict with keys ``'pre'``, ``'rec'``, ``'f1'``, ``'acc'`` (see ``args_metric``).
    """
    was_training = model.training
    model.eval()

    all_true_arg_pairs_list = []
    all_pred_arg_pairs_list = []
    try:
        for sample in tqdm(data_list):
            # predict_span takes batched lists; each batch here holds exactly one sample.
            review_para_tokens_list = [sample['review']['sentences']]
            review_tags_list = [[tags2id[tag] for tag in sample['review']['bio_tags']]]
            reply_para_tokens_list = [sample['reply']['sentences']]
            reply_tags_list = [[tags2id[tag] for tag in sample['reply']['bio_tags']]]

            with torch.no_grad():
                # The span-level predictions (first two outputs) are not scored by this function.
                _, _, pred_pair_args_list, pred_pair_args_2_list = model.predict_span(
                    review_para_tokens_list, review_tags_list,
                    reply_para_tokens_list, reply_tags_list)

            all_true_arg_pairs_list.append(_true_arg_pairs(sample))
            all_pred_arg_pairs_list.extend(
                list(set(_pred_pairs_review_to_reply(from_rev) + _pred_pairs_reply_to_review(from_rep)))
                for from_rev, from_rep in zip(pred_pair_args_list, pred_pair_args_2_list))
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return args_metric(all_true_arg_pairs_list, all_pred_arg_pairs_list)

# Training

In [8]:
config = get_config()

# Seed the RNGs used below so the shuffle order and model initialisation
# are reproducible for a given --seed.
random.seed(config.seed)
torch.manual_seed(config.seed)

now = datetime.datetime.now()
now_time_string = "{:0>4d}{:0>2d}{:0>2d}_{:0>2d}{:0>2d}{:0>2d}_{:0>5d}".format(
    now.year, now.month, now.day, now.hour, now.minute, now.second, config.seed)
# Every run writes into its own timestamped sub-directory of --save-path.
save_path = os.path.join(config.save_path, now_time_string)
if os.path.exists(save_path):
    # Raise rather than exit(): exit() inside a notebook kernel shuts the session down.
    raise FileExistsError("save_path exists!! ({})".format(save_path))
os.makedirs(save_path)
with open(os.path.join(save_path, "config.json"), "w") as fp:
    json.dump(vars(config), fp)
logger = logging.getLogger()

logger.warning('> training arguments:')
for arg in vars(config):
    logger.warning('>>> {0}: {1}'.format(arg, getattr(config, arg)))


train_list_task1_review, train_list_task1_reply, train_list_task2_review, train_list_task2_reply = \
    load_data_new_sample('./data/processed/train.txt.bioes')

dev_list = load_data('./data/processed/dev.txt.bioes')
test_list = load_data('./data/processed/test.txt.bioes')

train_list = train_list_task1_review + train_list_task1_reply + train_list_task2_review + train_list_task2_reply

train_len = len(train_list)
# Mini-batches per epoch (ceiling division).
train_iter_len = (train_len + config.batch_size - 1) // config.batch_size
num_training_steps = train_iter_len * config.epochs
# NOTE(review): num_training_steps / num_warmup_steps are computed (and
# get_linear_schedule_with_warmup is imported) but no LR scheduler is created
# in this cell, so --warm-up currently has no effect. Confirm whether one was intended.
num_warmup_steps = int(num_training_steps * config.warm_up)
logger.warning('Data loaded.')

logger.warning('Initializing model...')
model = BERT_BiLSTM_CRF(config)
model.cuda()
logger.warning('Model initialized.')


# Three parameter families with different learning rates: the Longformer
# encoder, the BiLSTM, and everything else (the task heads).
longformer_model_para = list(model.longformer.parameters())
lstm_para = list(model.am_bilstm.parameters())
_grouped_ids = {id(p) for p in longformer_model_para} | {id(p) for p in lstm_para}
# Keep model.parameters() order so parameter-group ordering is deterministic.
other_model_para = [p for p in model.parameters() if id(p) not in _grouped_ids]

lstm_para_lr = 1e-3  # no CLI flag exists for the BiLSTM learning rate

optimizer_grouped_parameters = [
    # Weight matrices of the heads get weight decay ...
    {'params': [p for p in other_model_para if p.dim() > 1], 'weight_decay': config.weight_decay},
    # ... biases / norm weights do not. 0-dim scalars are included here so they
    # are not left out of the optimizer.
    {'params': [p for p in other_model_para if p.dim() <= 1], 'weight_decay': 0.0},
    {'params': longformer_model_para, 'lr': config.longformer_base_encoder_lr},
    {'params': lstm_para, 'lr': lstm_para_lr},
]

optimizer = AdamW(optimizer_grouped_parameters, config.finetune_lr)

total_batch, early_stop = 0, 0
best_batch, best_f1 = 0, 0.0


def _collate_train_batch(train_batch):
    """Split training samples into the parallel per-field lists the model's forward takes.

    task1 samples contribute ``sentences`` plus an empty second paragraph and an
    empty pair dict. task2 samples contribute ``review_sentences``,
    ``reply_sentences`` and ``rr_arg_dict``. Both contribute their match/start/end labels.

    Returns:
        (para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list,
         match_labels, start_labels, end_labels, sample_tags_list)
    """
    para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list = [], [], []
    match_labels, start_labels, end_labels = [], [], []
    sample_tags_list = []
    for sample in train_batch:
        tag = sample['tag']
        # NOTE(review): the tag is recorded even when it matches neither task,
        # which would leave sample_tags_list longer than the other lists.
        # Confirm every tag contains "task1" or "task2".
        sample_tags_list.append(tag)
        if "task1" in tag:
            para_tokens_list.append(sample['sentences'])
            para_tokens_list_for_2.append([])
            rr_arg_pair_list.append({})
        elif "task2" in tag:
            para_tokens_list.append(sample['review_sentences'])
            para_tokens_list_for_2.append(sample['reply_sentences'])
            rr_arg_pair_list.append(sample['rr_arg_dict'])
        else:
            continue
        match_labels.append(sample['match_labels'])
        start_labels.append(sample['start_labels'])
        end_labels.append(sample['end_labels'])
    return (para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list,
            match_labels, start_labels, end_labels, sample_tags_list)


random.shuffle(train_list)


for epoch_i in range(config.epochs):
    logger.warning("Running epoch: {}".format(epoch_i))

    for batch_i in tqdm(range(train_iter_len)):
        train_batch = train_list[batch_i * config.batch_size:(batch_i + 1) * config.batch_size]
        # Single-sample (trailing) batches are skipped.
        if len(train_batch) <= 1:
            continue
        try:
            model.train()
            (para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list,
             match_labels, start_labels, end_labels, sample_tags_list) = _collate_train_batch(train_batch)

            # Clear gradients from the previous step; otherwise they accumulate across batches.
            optimizer.zero_grad()
            loss = model(para_tokens_list, para_tokens_list_for_2, rr_arg_pair_list,
                         match_labels, start_labels, end_labels, tag_list_o=sample_tags_list)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()

            total_batch += 1
        except UnboundLocalError as err:
            # Best-effort: skip the offending batch, but leave a trace of which one it was.
            logger.warning('Bad batch {} in epoch {}: {!r}'.format(batch_i, epoch_i, err))
            continue

    # evaluate
    t_start = time.time()

    dev_args_pair_dict = evaluate(model, dev_list)

    t_end = time.time()
    total_f1 = dev_args_pair_dict['f1']
    if total_f1 > best_f1:
        early_stop = 0
        best_f1 = total_f1
        torch.save(model.state_dict(), os.path.join(save_path, 'best_model.mdl'))
        logger.warning('*' * 20 + 'best' + '*' * 20)
        best_batch = total_batch
        logger.warning('*' * 20 + 'the performance in valid set...' + '*' * 20)
        logger.warning('running time: {}'.format(t_end - t_start))
        logger.warning('total batch: {}'.format(total_batch))
        logger.warning('total pair f1:\t{:.4f}'.format(
            dev_args_pair_dict['f1']))

        test_args_pair_dict = evaluate(
            model, test_list)
        logger.warning('*' * 20 + 'the performance in test set...' + '*' * 20)
        logger.warning('total pair f1:\t{:.4f}'.format(
            test_args_pair_dict['f1']))
    else:
        # Stop once dev F1 has failed to improve for --early-num consecutive epochs.
        early_stop += 1
        if early_stop >= config.early_num:
            logger.warning('Early stopping: no dev improvement for {} epochs.'.format(early_stop))
            break


> training arguments:
>>> save_path: ./saved_models
>>> device: 1
>>> seed: 42
>>> batch_size: 4
>>> epochs: 2
>>> showtime: 2000
>>> base_encoder_lr: 1e-05
>>> longformer_base_encoder_lr: 1e-05
>>> finetune_lr: 0.001
>>> warm_up: 0.05
>>> weight_decay: 1e-05
>>> early_num: 5
>>> num_tags: 5
>>> threshold: 0.3
>>> hidden_size: 128
>>> layers: 2
>>> is_bi: False
>>> bert_output_size: 768
>>> mlp_size: 512
>>> scale_factor: 2
>>> dropout: 0.7
>>> max_grad_norm: 1.0
>>> num_heads: 4
>>> att_dropout: 0.1
>>> mrc_dropout: 0.7
>>> lstm_dropout: 0.4
>>> classifier_act_func: gelu
>>> classifier_intermediate_hidden_size: 128
>>> weight_start: 1.0
>>> weight_end: 1.0
>>> weight_span: 0.1
Data loaded.
Initializing model...
Model initialized.
Running epoch: 0
100%|█████████████████████████████████████| 9971/9971 [1:33:13<00:00,  1.78it/s]
100%|█████████████████████████████████████████| 473/473 [01:45<00:00,  4.46it/s]
********************best********************
********************the performance

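Please note that the best-performing model (by dev pair F1) is saved as `best_model.mdl` in a timestamped sub-folder of the 'saved_models' folder. That sub-folder also contains a `config.json` file with the experiment configuration (the training arguments) used for the run.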