<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only setup: mount Google Drive at /content/drive so the project folder is reachable.
# force_remount=True asks Colab to mount again even when a mount is already present.
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
!pip install datasets



In [None]:
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [None]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("Training on GPU")

Training on GPU


In [None]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` adapter around an indexable collection of examples."""

    def __init__(self, examples):
        # Only a reference is stored; the examples are not copied.
        super().__init__()
        self.examples = examples

    def __len__(self):
        """Return how many examples are stored."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.examples[idx]


def collate_fn(examples):
    """Collate 7-field examples (two encoded sentences plus labels) into long tensors.

    Each example is ``(ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2,
    att_mask_sent2, label)``. The fields are stacked across the batch and returned
    as a tuple of seven ``torch.long`` tensors in the same order.
    """
    columns = [list(column) for column in zip(*examples)]
    # Unpacking keeps the "exactly seven fields" requirement explicit.
    first_ids, first_segs, first_mask, second_ids, second_segs, second_mask, targets = columns

    ordered = (first_ids, first_segs, first_mask, second_ids, second_segs, second_mask, targets)
    return tuple(torch.tensor(column, dtype=torch.long) for column in ordered)

def collate_fn_concatenated(examples):
    """Collate 5-field examples of a concatenated sentence pair into long tensors.

    Each example is ``(ids, segs, att_mask, position_sep, label)``. The fields are
    stacked across the batch and returned as a tuple of five ``torch.long`` tensors.
    """
    columns = [list(column) for column in zip(*examples)]
    # Unpacking keeps the "exactly five fields" requirement explicit.
    token_ids, segment_ids, attention_mask, sep_positions, targets = columns

    ordered = (token_ids, segment_ids, attention_mask, sep_positions, targets)
    return tuple(torch.tensor(column, dtype=torch.long) for column in ordered)

def collate_fn_concatenated_adv(examples):
    """Collate 5-field examples like ``collate_fn_concatenated`` but leave labels untouched.

    The first four fields (ids, segs, att_mask, position_sep) become ``torch.long``
    tensors; the labels are returned as a plain Python list, not a tensor.
    """
    token_ids, segment_ids, attention_mask, sep_positions, labels = (
        list(column) for column in zip(*examples)
    )

    tensors = tuple(
        torch.tensor(column, dtype=torch.long)
        for column in (token_ids, segment_ids, attention_mask, sep_positions)
    )
    return tensors + (labels,)

In [None]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base processor that turns (sentence1, sentence2, label) rows into fixed-length model inputs.

  The tokenizer is loaded from the notebook-level ``args["model_name"]`` and every
  produced sequence is padded or truncated to ``self.max_sent_len`` items.
  """

  def __init__(self,):
    # `args` is a dict defined elsewhere in the notebook, not in this cell.
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _pad_token_id(self):
    """Return the id obtained by encoding the tokenizer's pad token without special tokens."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_max_len(self, sequences, pad_values):
    """Pad or truncate parallel lists to exactly ``self.max_sent_len`` items.

    Args:
      sequences: list of equally long lists (e.g. token ids and segment ids).
      pad_values: one filler value per entry of ``sequences``.

    Returns:
      ``(fitted, att_mask)``: ``fitted`` holds new lists in the same order as
      ``sequences``; ``att_mask`` is 1 for kept original positions and 0 for padding.
    """
    length = len(sequences[0])
    if length < self.max_sent_len:
      res = self.max_sent_len - length
      att_mask = [1] * length + [0] * res
      fitted = [seq + [pad] * res for seq, pad in zip(sequences, pad_values)]
    else:
      # Too long (or exact): keep the first max_sent_len items, all of them attended.
      att_mask = [1] * self.max_sent_len
      fitted = [seq[:self.max_sent_len] for seq in sequences]
    return fitted, att_mask

  def _encode_single(self, sentence, pad_id):
    """Encode one sentence and return ``(ids, segs, att_mask)`` at fixed length."""
    ids = self.tokenizer.encode(sentence)
    # Segment ids are 0 at the first and last position (presumably the special
    # tokens added by `encode`) and 1 everywhere in between.
    segs = [0] * len(ids)
    segs[1:-1] = [1] * (len(ids) - 2)
    assert len(ids) == len(segs)

    (ids, segs), att_mask = self._fit_to_max_len([ids, segs], [pad_id, 0])
    return ids, segs, att_mask

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode both sentences of every row separately.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``;
        only the two sentences and the label are used.
      dataset_type: name used in the completion message only.

    Returns:
      A list of ``[ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2,
      att_mask_sent2, label]`` entries.
    """
    pad_id = self._pad_token_id()  # loop-invariant, so looked up once
    examples = []

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      ids_sent1, segs_sent1, att_mask_sent1 = self._encode_single(sentence1, pad_id)
      ids_sent2, segs_sent2, att_mask_sent2 = self._encode_single(sentence2, pad_id)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode every row as one sentence-pair sequence.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``;
        only the two sentences and the label are used.
      dataset_type: name used in the completion message only.

    Returns:
      A list of ``[ids, segs, att_mask, position_sep, label]`` entries.

    Raises:
      AssertionError: if the pair encoding is not exactly as long as the two
        single-sentence encodings together (the segment ids rely on that).
    """
    pad_id = self._pad_token_id()  # loop-invariant, so looked up once
    examples = []

    # NOTE(review): `tqdm` is not imported in this cell; it has to be in the notebook namespace already.
    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      # Segment 0 spans the length of sentence1 encoded alone, segment 1 the length of sentence2 encoded alone.
      segs_sent1 = [0] * sentence1_length + [1] * sentence2_length
      # 0 at the first position, 1 everywhere else. (Earlier code also re-assigned 1 at
      # index 1 and at index `sentence1_length`, which changed nothing.)
      position_sep = [1] * len(ids_sent1)
      position_sep[0] = 0

      assert len(ids_sent1) == len(segs_sent1), \
        "pair encoding length differs from the sum of the single-sentence encodings; segment ids would be misaligned"

      (ids_sent1, segs_sent1, position_sep), att_mask_sent1 = self._fit_to_max_len(
        [ids_sent1, segs_sent1, position_sep], [pad_id, 0, 0])

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Processor that relabels discourse-marker samples with a coarse two-class scheme.

  ``id_to_word`` translates the integer label of a sample into a connective and
  ``mapping`` assigns a class (0 or 1) to the connectives that are kept; samples
  whose connective is not in ``mapping`` are dropped.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Reference: https://www.sciencedirect.com/science/article/pii/S0378216698001015
    # TODO: refactor
    #
    # An earlier three-class scheme used the same connectives with these classes:
    #   1 for also, and, especially, further, furthermore, likewise, moreover,
    #     namely, otherwise, similarly, well;
    #   2 for every connective that is class 1 below;
    #   0 for the remaining connectives.
    # The current scheme merges the old classes 0 and 1 into class 0.
    self.mapping = {
      "accordingly": 0,
      "also": 0,
      "although": 1,
      "and": 0,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 1,
      "by_comparison": 1,
      "by_contrast": 1,
      "consequently": 0,
      "conversely": 0,
      "especially": 0,
      "further": 0,
      "furthermore": 0,
      "hence": 0,
      "however": 1,
      "in_contrast": 1,
      "instead": 1,
      "likewise": 0,
      "moreover": 0,
      "namely": 0,
      "nevertheless": 1,
      "nonetheless": 1,
      "on_the_contrary": 1,
      "on_the_other_hand": 1,
      "otherwise": 0,
      "rather": 1,
      "similarly": 0,
      "so": 0,
      "still": 1,
      "then": 0,
      "therefore": 0,
      "though": 1,
      "thus": 0,
      "well": 0,
      "yet": 1
    }

    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    """Filter, relabel and tokenize a discourse-marker dataset split.

    Args:
      dataset: iterable of samples indexable by ``"sentence1"``, ``"sentence2"``
        and ``"label"`` (an integer key of ``self.id_to_word``).
      name: split name; used as the id prefix (``f"{name}_{i}"``) and passed on
        to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``; each label is a
      one-hot row with one column per distinct class in ``self.mapping``.
    """
    # Keep only the samples whose connective has a class in `self.mapping`.
    new_dataset = []
    for sample in dataset:
      word = self.id_to_word[sample["label"]]
      if word not in self.mapping:
        continue
      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[word]])

    if not new_dataset:
      # Nothing survived the filter: the encoder cannot be fitted on zero rows.
      return self._get_examples_concatenated([], name)

    # Pin the categories so every split gets the same one-hot width and column
    # order, even when a split happens to contain only one of the classes.
    categories = [sorted(set(self.mapping.values()))]
    one_hot_encoder = OneHotEncoder(categories=categories, handle_unknown="ignore", sparse_output=False)

    labels = [[sample[-1]] for sample in tqdm(new_dataset, desc="processing labels...")]

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    pairs = tqdm(zip(new_dataset, labels), total=len(new_dataset), desc="creating results...")
    result = [
      [f"{name}_{i}", sample[0], sample[1], [], [], [], label]
      for i, (sample, label) in enumerate(pairs)
    ]

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):
  """Reads tab-separated argument-pair files and builds two-class examples.

  Labels become ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for
  anything else.
  """

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target, source_sentiment,
        target_sentiment, knowledge, label_distribution)``; the padded fields
        must be accepted by ``torch.tensor``.
      maxlen: unused; kept for interface compatibility.

    Returns:
      A list with one ``(sentences, knowledge, target, label_distribution)``
      tuple per row; the first three items are rows of the padded tensors.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(target), label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If ``max_batch_size`` is positive it is the maximum number of sentences per
    batch. Otherwise batches are sized dynamically so that each holds roughly
    ``abs(max_batch_size)`` words, with at least one sentence per batch.

    Args:
      sentences: sequence of sized items (only ``len`` of each item is used).
      batch_equal_size: when equal to ``True``, only sentences of the same
        length share a batch; otherwise batches follow the input order.
      max_batch_size: see above.

    Returns:
      A list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Guard against empty sentences (division by zero) and against a word
          # budget smaller than one sentence (a step of 0 would break range()).
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file and return tokenized examples.

    Per line, column 0 is the id, column 1 the source sentence, column 3 the
    target sentence, the second-to-last column the facts and the last column the
    relation label. Blank lines are skipped.

    Args:
      file_path: path of a single ISO-8859-1 encoded file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name passed on to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        fields = raw_line.replace("\n", "").split("\t")

        # Skip blank lines, including the lone "\r" that CRLF files leave behind.
        if len(fields) == 1 and not fields[0].strip():
          continue

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()

        # Strip underscores, square brackets and parentheses from the facts.
        facts = fields[-2].strip()
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in ('supports', 'support', 'because'):
          distribution = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          distribution = [0, 1]
        else:
          distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples

class DebateProcessor(DataProcessor):
  """Reads tab-separated argument-pair files and builds two-class examples.

  Labels become ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for
  anything else.
  """

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target, source_sentiment,
        target_sentiment, knowledge, label_distribution)``; the padded fields
        must be accepted by ``torch.tensor``.
      maxlen: unused; kept for interface compatibility.

    Returns:
      A list with one ``(sentences, knowledge, target, label_distribution)``
      tuple per row; the first three items are rows of the padded tensors.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(target), label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If ``max_batch_size`` is positive it is the maximum number of sentences per
    batch. Otherwise batches are sized dynamically so that each holds roughly
    ``abs(max_batch_size)`` words, with at least one sentence per batch.

    Args:
      sentences: sequence of sized items (only ``len`` of each item is used).
      batch_equal_size: when equal to ``True``, only sentences of the same
        length share a batch; otherwise batches follow the input order.
      max_batch_size: see above.

    Returns:
      A list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Guard against empty sentences (division by zero) and against a word
          # budget smaller than one sentence (a step of 0 would break range()).
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file and return tokenized examples.

    Per line, column 0 is the id, column 1 the source sentence, column 3 the
    target sentence, the second-to-last column the facts and the last column the
    relation label. Blank lines are skipped.

    Args:
      file_path: path of a single ISO-8859-1 encoded file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name passed on to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        fields = raw_line.replace("\n", "").split("\t")

        # Skip blank lines, including the lone "\r" that CRLF files leave behind.
        if len(fields) == 1 and not fields[0].strip():
          continue

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()

        # Strip underscores, square brackets and parentheses from the facts.
        facts = fields[-2].strip()
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in ('supports', 'support', 'because'):
          distribution = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          distribution = [0, 1]
        else:
          distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples


class MARGProcessor(DataProcessor):
  """Reads a CSV of argument pairs and prepends a predicted discourse marker to each target.

  Labels become ``[1, 0, 0]`` for support, ``[0, 1, 0]`` for ``neither``,
  ``[0, 0, 1]`` for attack and ``[0, 0, 0]`` for anything else.
  """

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the rows of one split from a CSV file and return tokenized examples.

    Columns are addressed by position: 0 is the id, 1 the source sentence, 2 the
    target sentence, 3 the relation label, the third-to-last the facts. The last
    column is compared with ``name`` (presumably the split name) and rows that do
    not match are skipped.

    Args:
      file_path: path of a CSV file readable by ``pd.read_csv`` with defaults.
      max_sent_length: unused; kept for interface compatibility.
      name: value the last column must equal; also passed on to
        ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    df = pd.read_csv(file_path)
    # Plain tuples give true positional access; integer keys on the Series
    # yielded by iterrows() rely on a deprecated positional fallback.
    for row in df.itertuples(index=False, name=None):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = row[2].strip()

      # Predict a connective for the pair and prepend it to the target, e.g.
      # "on_the_other_hand" -> "On the other hand <target with lower-cased first letter>".
      ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
      ds_marker = ds_marker.replace("_", " ")
      # Slices instead of [0] so empty strings do not raise IndexError.
      ds_marker = ds_marker[:1].upper() + ds_marker[1:]
      target = target[:1].lower() + target[1:]
      target = ds_marker + " " + target

      label = row[3].strip()

      # Strip underscores, square brackets and parentheses from the facts.
      facts = row[-3].strip()
      for char in "_[]()":
        facts = facts.replace(char, "")

      if label in ('supports', 'support', 'because'):
        distribution = [1, 0, 0]
      elif label in ('attacks', 'attack', 'but'):
        distribution = [0, 0, 1]
      elif label == 'neither':
        distribution = [0, 1, 0]
      else:
        distribution = [0, 0, 0]

      result.append([story_id, sent, target, [], [], facts, distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples

class NKProcessor(DataProcessor):
  """Reads a tab-separated file of argument pairs and builds three-class examples.

  Labels become ``[1, 0, 0]`` for support, ``[0, 1, 0]`` for ``no_relation``,
  ``[0, 0, 1]`` for attack and ``[0, 0, 0]`` for anything else.
  """

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read one tab-separated file and return tokenized examples.

    Columns are addressed by position: 0 is the id, 2 the relation label, 3 the
    source sentence and 4 the target sentence.

    Args:
      file_path: path of a file readable by ``pd.read_csv(..., sep="\\t")``.
      max_sent_length: unused; kept for interface compatibility.
      name: split name passed on to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    df = pd.read_csv(file_path, sep="\t")
    # Plain tuples give true positional access; integer keys on the Series
    # yielded by iterrows() rely on a deprecated positional fallback.
    for row in df.itertuples(index=False, name=None):
      id_sample = row[0]
      label = row[2]

      sent = row[3].strip()
      target = row[4].strip()

      if label in ('supports', 'support', 'because'):
        distribution = [1, 0, 0]
      elif label in ('attacks', 'attack', 'but'):
        distribution = [0, 0, 1]
      elif label == 'no_relation':
        distribution = [0, 1, 0]
      else:
        distribution = [0, 0, 0]

      result.append([id_sample, sent, target, [], [], [], distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Reads tab-separated argument pairs and prepends a predicted discourse marker to each target.

  Labels become ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for
  anything else.
  """

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target, source_sentiment,
        target_sentiment, knowledge, label_distribution)``; the padded fields
        must be accepted by ``torch.tensor``.
      maxlen: unused; kept for interface compatibility.

    Returns:
      A list with one ``(sentences, knowledge, target, label_distribution)``
      tuple per row; the first three items are rows of the padded tensors.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(target), label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If ``max_batch_size`` is positive it is the maximum number of sentences per
    batch. Otherwise batches are sized dynamically so that each holds roughly
    ``abs(max_batch_size)`` words, with at least one sentence per batch.

    Args:
      sentences: sequence of sized items (only ``len`` of each item is used).
      batch_equal_size: when equal to ``True``, only sentences of the same
        length share a batch; otherwise batches follow the input order.
      max_batch_size: see above.

    Returns:
      A list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Guard against empty sentences (division by zero) and against a word
          # budget smaller than one sentence (a step of 0 would break range()).
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file, inject discourse markers and return tokenized examples.

    Per line, column 0 is the id, column 1 the source sentence, column 3 the
    target sentence, the second-to-last column the facts and the last column the
    relation label. Blank lines are skipped. The connective predicted by
    ``self.pipe`` for the pair is prepended to the target sentence.

    Args:
      file_path: path of a single ISO-8859-1 encoded file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name passed on to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        fields = raw_line.replace("\n", "").split("\t")

        # Skip blank lines, including the lone "\r" that CRLF files leave behind.
        if len(fields) == 1 and not fields[0].strip():
          continue

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        # Predict a connective for the pair and prepend it to the target, e.g.
        # "on_the_other_hand" -> "On the other hand <target with lower-cased first letter>".
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        # Slices instead of [0] so empty strings do not raise IndexError.
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        target = target[:1].lower() + target[1:]
        target = ds_marker + " " + target

        label = fields[-1].strip()

        # Strip underscores, square brackets and parentheses from the facts.
        facts = fields[-2].strip()
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in ('supports', 'support', 'because'):
          distribution = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          distribution = [0, 1]
        else:
          distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Debate-dataset processor that injects a predicted discourse marker.

  For every (source sentence, target sentence) pair a text-classification
  pipeline predicts a discourse marker, which is prepended to the target
  sentence before the examples are built.
  """

  # Raw label strings mapped to the two-way one-hot encoding used downstream.
  _SUPPORT_LABELS = frozenset({'supports', 'support', 'because'})
  _ATTACK_LABELS = frozenset({'attacks', 'attack', 'but'})

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """
    Zero-pad the sentence, knowledge and target sequences to a common length.

    :param input: iterable of 7-tuples
        (id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution)
    :param maxlen: unused; kept for interface compatibility
    :return: list of (sentences, knowledge, target, label_distribution) tuples
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      # pad_sequence pads every sequence up to the longest one in the group
      return torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=0
      )

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups together sentences into batches.

    If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
    If max_batch_size is not positive, batches are created dynamically so that each batch
    contains roughly abs(max_batch_size) words (always at least one sentence).

    :param sentences: sequence of sentences; only ``len(sentence)`` is used
    :param batch_equal_size: if True, every batch only holds sentences of identical length
    :param max_batch_size: sentence limit (positive) or negated word budget (non-positive)
    :return: list of lists with sentence ids
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      # Bucket the ids by sentence length, preserving first-seen order of lengths.
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Word budget: clamp to >= 1 so that a sentence longer than the budget
          # (or an empty sentence) cannot produce a zero step / zero division.
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        batch_is_full = (
          (max_batch_size > 0 and len(current_batch) >= max_batch_size)
          or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size))
        )
        if batch_is_full:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one input file in tab-separated format and builds the examples.

    Columns used: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts, last = relation label.

    :param file_path: path of a single ISO-8859-1 encoded TSV file
    :param max_sentence_length: unused; kept for interface compatibility
    :param name: split name forwarded to ``_get_examples_concatenated``
    :return: whatever ``self._get_examples_concatenated(rows, name)`` returns
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        fields = raw_line.replace("\n", "").split("\t")

        # Skip blank lines (covers both "\n" and "\r\n" line endings).
        if not "".join(fields).strip():
          continue

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        # Predict the discourse marker linking source and target, then
        # prepend it (capitalised) to the de-capitalised target sentence.
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        # Slicing instead of indexing keeps an empty target from raising.
        target = ds_marker + " " + target[:1].lower() + target[1:]

        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip underscores and every bracket/parenthesis; the original chain
        # removed "(" twice and therefore left ")" behind.
        for unwanted in ("_", "[", "]", "(", ")"):
          facts = facts.replace(unwanted, "")

        if label in self._SUPPORT_LABELS:
          label_distribution = [1,0]
        elif label in self._ATTACK_LABELS:
          label_distribution = [0,1]
        else:
          # Unknown labels keep the all-zero encoding.
          label_distribution = [0,0]

        # The two empty lists are placeholders for source / target sentiment.
        result.append([story_id, sent, target, [], [], facts, label_distribution])

    examples = self._get_examples_concatenated(result, name)

    return examples

In [None]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient-reversal layer.

    The forward pass returns the input unchanged; the backward pass
    multiplies the incoming gradient by ``-lmbd``.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Keep the scaling factor around for the backward pass.
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.clone().neg()
        scaled_grad = ctx.lmbd * reversed_grad
        # Second slot is the (non-existent) gradient w.r.t. ``lmbd``.
        return scaled_grad, None

class DoubleAdversarialNet(torch.nn.Module):
  """Cross-attention classifier with three gradient-reversed auxiliary heads.

  In training mode the batch is split in half: the first half feeds the main
  classifier, the second half feeds the auxiliary ("adv") classifier. Three
  further heads (task / attack / support) receive mean-pooled encoder states
  through a gradient-reversal layer.
  """

  def __init__(self):
    super().__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    plm_config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    plm_config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      plm_config.type_vocab_size, plm_config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The whole encoder is fine-tuned.
    self.plm.requires_grad_(True)

    hidden = self.embed_size
    self.linear_layer = nn.Linear(hidden, self.num_classes)
    self.linear_layer_adv = nn.Linear(hidden, self.num_classes_adv)
    self.task_linear = nn.Linear(hidden, self.num_classes)
    self.attack_linear = nn.Linear(hidden, self.num_classes)
    self.support_linear = nn.Linear(hidden, self.num_classes)

    self.multi_head_att = nn.MultiheadAttention(hidden, 8, batch_first=True)
    self.Q = nn.Linear(hidden, hidden)
    self.K = nn.Linear(hidden, hidden)
    self.V = nn.Linear(hidden, hidden)

    # Order matters for seeded runs: it is the order the RNG is consumed in.
    for module in (
      self.linear_layer,
      self.linear_layer_adv,
      self.Q,
      self.K,
      self.V,
      self.multi_head_att,
      self.task_linear,
      self.attack_linear,
      self.support_linear,
    ):
      self._init_weights(module)

  def _init_weights(self, module):
    """Re-initialise ``module`` in place; other module types are left untouched."""
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """Return five logit tensors in training mode, one in eval mode.

    ``labels`` is only read in training mode; ``position_sep`` is not used.
    """
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    last_sent1 = encoded.hidden_states[-1]
    first_sent1 = encoded.hidden_states[1]

    embed_sent1 = torch.div(last_sent1 + first_sent1, 2) if self.first_last_avg else last_sent1

    # Zero out everything outside segment 0 / segment 1 respectively.
    seg0_mask = (segs_sent1 == 0).long().unsqueeze(2)
    seg1_mask = (segs_sent1 == 1).long().unsqueeze(2)
    H_sent1 = torch.mul(seg0_mask, embed_sent1)
    H_sent2 = torch.mul(seg1_mask, embed_sent1)

    # Segment-1 tokens query the segment-0 tokens.
    attended, _ = self.multi_head_att(self.Q(H_sent2), self.K(H_sent1), self.V(H_sent1))
    H_sent = torch.mean(attended, dim=1)

    if not self.training:
      return self.linear_layer(H_sent)

    def rows_equal(onehots, pattern):
      # Boolean mask of the rows that equal ``pattern`` exactly.
      return (onehots == torch.tensor(pattern).to(device)).all(dim=1)

    half = H_sent.shape[0] // 2

    # First half of the batch: main-task samples.
    samples = H_sent[:half, :]
    embed_std = embed_sent1[:half, :, :]
    labels_std = torch.tensor(labels[:half]).to(device)

    if self.num_classes == 2:
      emb_attack = embed_std[rows_equal(labels_std, [0,1])]
      emb_support = embed_std[rows_equal(labels_std, [1,0])]
    else:
      emb_attack = embed_std[rows_equal(labels_std, [0,1,0])]
      emb_support = embed_std[rows_equal(labels_std, [1,0,0]) | rows_equal(labels_std, [0,0,1])]

    # Second half of the batch: auxiliary-task samples.
    samples_adv = H_sent[half:, :]
    embed_adv = embed_sent1[half:, :, :]
    labels_adv = torch.tensor(labels[half:]).to(device)

    emb_contr = embed_adv[rows_equal(labels_adv, [0,0,1])]
    emb_other = embed_adv[rows_equal(labels_adv, [0,1,0]) | rows_equal(labels_adv, [1,0,0])]

    predictions = self.linear_layer(samples)
    predictions_adv = self.linear_layer_adv(samples_adv)

    # Mean-pool over tokens, then reverse gradients before the auxiliary heads.
    mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
    mean_grl_attack = GRLayer.apply(torch.mean(torch.cat([emb_attack, emb_contr], dim=0), dim=1), .01)
    mean_grl_support = GRLayer.apply(torch.mean(torch.cat([emb_support, emb_other], dim=0), dim=1), .01)

    task_prediction = self.task_linear(mean_grl)
    attack_prediction = self.attack_linear(mean_grl_attack)
    support_prediction = self.support_linear(mean_grl_support)

    return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction

class AdversarialNet(torch.nn.Module):
  """Cross-attention classifier with one gradient-reversed task head.

  In training mode the first half of the batch goes to the main classifier
  and the second half to the auxiliary ("adv") classifier; a third head sees
  the mean-pooled encoder states of the whole batch through a
  gradient-reversal layer.
  """

  def __init__(self):
    super().__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    plm_config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    plm_config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      plm_config.type_vocab_size, plm_config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The whole encoder is fine-tuned.
    self.plm.requires_grad_(True)

    hidden = self.embed_size
    self.linear_layer = nn.Linear(hidden, self.num_classes)
    self.linear_layer_adv = nn.Linear(hidden, self.num_classes_adv)
    self.task_linear = nn.Linear(hidden, 2)

    self.multi_head_att = nn.MultiheadAttention(hidden, 8, batch_first=True)
    self.Q = nn.Linear(hidden, hidden)
    self.K = nn.Linear(hidden, hidden)
    self.V = nn.Linear(hidden, hidden)

    # Order matters for seeded runs: it is the order the RNG is consumed in.
    for module in (
      self.linear_layer,
      self.linear_layer_adv,
      self.Q,
      self.K,
      self.V,
      self.multi_head_att,
      self.task_linear,
    ):
      self._init_weights(module)

  def _init_weights(self, module):
    """Re-initialise ``module`` in place; other module types are left untouched."""
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """Return the pooled representation, three logit tensors, or one logit tensor.

    ``visualize=True`` short-circuits and returns the pooled attention output;
    otherwise training mode yields ``(predictions, predictions_adv,
    task_prediction)`` and eval mode yields ``predictions`` only.
    ``position_sep`` is not used.
    """
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    last_sent1 = encoded.hidden_states[-1]
    first_sent1 = encoded.hidden_states[1]

    embed_sent1 = torch.div(last_sent1 + first_sent1, 2) if self.first_last_avg else last_sent1

    # Zero out everything outside segment 0 / segment 1 respectively.
    seg0_mask = (segs_sent1 == 0).long().unsqueeze(2)
    seg1_mask = (segs_sent1 == 1).long().unsqueeze(2)
    H_sent1 = torch.mul(seg0_mask, embed_sent1)
    H_sent2 = torch.mul(seg1_mask, embed_sent1)

    # Segment-1 tokens query the segment-0 tokens.
    attended, _ = self.multi_head_att(self.Q(H_sent2), self.K(H_sent1), self.V(H_sent1))
    H_sent = torch.mean(attended, dim=1)

    if visualize:
      return H_sent

    if not self.training:
      return self.linear_layer(H_sent)

    half = H_sent.shape[0] // 2
    predictions = self.linear_layer(H_sent[:half, :])
    predictions_adv = self.linear_layer_adv(H_sent[half:])

    # Mean-pool over tokens, then reverse gradients before the task head.
    mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
    task_prediction = self.task_linear(mean_grl)

    return predictions, predictions_adv, task_prediction

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  """Classifier over three token embeddings: first token, cue token, last token.

  The three embeddings are concatenated (``3 * embed_size`` features) and fed
  to ``linear_layer``.

  NOTE(review): the previous ``forward`` gated the cue embedding through
  ``self.linear_initial_sent`` / ``self.linear_end_sent``. Those layers are
  never created in ``__init__`` (they were disabled there), so every call
  raised ``AttributeError``; the gated output also had ``embed_size`` features
  while ``linear_layer`` expects ``3 * embed_size``. ``forward`` now matches
  the layers ``__init__`` actually builds. Confirm that concatenation, not
  gating, is the intended design.
  """

  def __init__(self, attention):
    """
    :param attention: if truthy, run cross-attention (segment 1 queries segment 0)
        before picking the three token embeddings
    """
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    config.type_vocab_size = 4
    self.attention = attention
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # Input is the concatenation of three embed_size-wide token embeddings.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*3, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    # Parameter-free; no longer used by forward but kept so the attribute still exists.
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for each input pair. ``position_sep`` is not used."""
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    batch_size, seq_len = att_output.shape[0], att_output.shape[1]
    # Build index tensors on the inputs' own device instead of relying on a global.
    batch_index = torch.arange(batch_size, device=att_output.device)

    # Weight segment ids by a descending ramp so argmax picks the earliest
    # position with the largest weighted id (the first segment-1 token when
    # ids are 0/1 -- presumably the injected cue; verify against the tokenizer).
    descending = torch.arange(seq_len, 0, -1, device=segs_sent1.device)
    cue_position = torch.argmax(segs_sent1 * descending, dim=-1)
    # Index of the last attended token according to the attention mask.
    last_position = torch.sum(att_mask_sent1, dim=-1) - 1

    initial_sent1 = att_output[:, 0, :]
    initial_sent2 = att_output[batch_index, cue_position]
    end_sent2 = att_output[batch_index, last_position]

    # 3 * embed_size features, matching linear_layer's in_features.
    final_emb = torch.cat((initial_sent1, initial_sent2, end_sent2), dim=-1)

    predictions = self.linear_layer(final_emb)

    return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Encoder plus linear classifier, with optional cross-attention between segments."""

  def __init__(self, attention):
    """
    :param attention: if truthy, segment-1 tokens attend over segment-0 tokens
        before pooling; otherwise the encoder output is mean-pooled directly
    """
    super().__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    plm_config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    plm_config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      plm_config.type_vocab_size, plm_config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention

    self.first_last_avg = args["first_last_avg"]

    # The whole encoder is fine-tuned.
    self.plm.requires_grad_(True)

    hidden = self.embed_size
    self.linear_layer = nn.Linear(hidden, args["num_classes"])
    self.multi_head_att = nn.MultiheadAttention(hidden, 8, batch_first=True)
    self.Q = nn.Linear(hidden, hidden)
    self.K = nn.Linear(hidden, hidden)
    self.V = nn.Linear(hidden, hidden)

    # Order matters for seeded runs: it is the order the RNG is consumed in.
    for module in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
      self._init_weights(module)

  def _init_weights(self, module):
    """Re-initialise ``module`` in place; other module types are left untouched."""
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for each input pair. ``position_sep`` is not used."""
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    last_sent1 = encoded.hidden_states[-1]
    first_sent1 = encoded.hidden_states[1]

    embed_sent1 = torch.div(last_sent1 + first_sent1, 2) if self.first_last_avg else last_sent1

    if not self.attention:
      # Plain mean over all token positions.
      return self.linear_layer(torch.mean(embed_sent1, dim=1))

    # Zero out everything outside segment 0 / segment 1 respectively.
    seg0_mask = (segs_sent1 == 0).long().unsqueeze(2)
    seg1_mask = (segs_sent1 == 1).long().unsqueeze(2)
    H_sent1 = torch.mul(seg0_mask, embed_sent1)
    H_sent2 = torch.mul(seg1_mask, embed_sent1)

    # Segment-1 tokens query the segment-0 tokens.
    attended, _ = self.multi_head_att(self.Q(H_sent2), self.K(H_sent1), self.V(H_sent1))
    H_sent = torch.mean(attended, dim=1)

    return self.linear_layer(H_sent)


class BaselineModel(torch.nn.Module):
  """Encoder followed by mean pooling over all tokens and a linear classifier."""

  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits from the mean of all token embeddings.

    ``position_sep`` is accepted for interface compatibility but not used.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): a (position_sep == 1)-masked copy of the embeddings used to
    # be computed here and then discarded; it was removed because it never
    # reached the output. Pooling has always been over the unmasked embeddings
    # (padding included) -- confirm that this is intended.
    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  """Shared encoder applied to two inputs; their mean-pooled embeddings are concatenated and classified."""

  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised 4-row one.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # Two pooled embeddings are concatenated, hence 2 * embed_size inputs.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """Encode both inputs with the same encoder and return class logits."""
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    # NOTE(review): segment-masked copies of both embeddings used to be
    # computed here and then discarded; they were removed because they never
    # reached the output. Pooling has always been over the unmasked embeddings
    # (padding included) -- confirm that this is intended.
    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic.

    :param seed: integer seed applied to every generator
    """
    # Imported locally: the notebook only imports ``random`` in a later cell,
    # so relying on the global name breaks on a fresh kernel.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        # Seed every visible GPU, not only the current one.
        torch.cuda.manual_seed_all(seed)

def output_metrics(labels, preds):
    """
    Compute, print and return accuracy plus macro-averaged precision, recall and F1.

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: (accuracy, precision, recall, f1)
    """
    # All four scores are computed first, then printed in the same order.
    scores = {
        'accuracy:': accuracy_score(labels, preds),
        'precision:': precision_score(labels, preds, average="macro"),
        'recall:': recall_score(labels, preds, average="macro"),
        'f1:': f1_score(labels, preds, average="macro"),
    }

    for tag, value in scores.items():
        print("{:15}{:<.3f}".format(tag, value))

    return tuple(scores.values())

In [None]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler drawing half of every batch from each of two datasets.

    Indices ``0 .. len(dataset1)-1`` refer to ``dataset1``; indices of
    ``dataset2`` are offset by ``len(dataset1)``, i.e. they address the
    concatenation of the two datasets. Every batch holds
    ``batch_size // 2`` indices of ``dataset1`` followed by the remaining
    ``batch_size - batch_size // 2`` indices of ``dataset2``. Iteration stops
    as soon as either dataset cannot fill its share of a batch, so incomplete
    batches are never produced.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        """
        :param dataset1: sized dataset filling the first half of each batch
        :param dataset2: sized dataset filling the second half of each batch
        :param batch_size: total number of indices per batch
        """
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _shares(self):
        """Return how many indices each dataset contributes to one batch."""
        share1 = self.batch_size // 2
        return share1, self.batch_size - share1

    def _num_batches(self):
        """Number of complete batches one pass can produce."""
        share1, share2 = self._shares()
        limit = self.total_size // self.batch_size
        if share1 > 0:
            limit = min(limit, len(self.dataset1) // share1)
        if share2 > 0:
            # Previously unchecked: popping from an exhausted dataset2 raised IndexError.
            limit = min(limit, len(self.dataset2) // share2)
        return limit

    def __iter__(self):
        self.epoch += 1
        share1, share2 = self._shares()

        # Fresh permutation of each dataset on every pass.
        shuffled1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        shuffled2 = (torch.randperm(len(self.dataset2)) + len(self.dataset1)).cpu().tolist()

        for _ in range(self._num_batches()):
            batch = [shuffled1.pop() for _ in range(share1)]
            batch.extend(shuffled2.pop() for _ in range(share2))
            yield batch

    def __len__(self):
        # Matches the number of batches __iter__ actually yields.
        return self._num_batches()

In [None]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

# Hyper-parameters read by the model classes (and, presumably, by the
# training driver further down the notebook).
args = {
    # Checkpoint name handed to AutoModel.from_pretrained in every model class.
    "model_name": "roberta-base",
    # Output width of the main classification head.
    "num_classes": 2, #3, #2,
    # Output width of the auxiliary ("adv") head of the adversarial models.
    "num_classes_adv": 3, #174,
    # Feature width of the heads / attention layers; presumably must equal the
    # encoder's hidden size -- confirm when changing model_name.
    "embed_size": 768,
    # If True, models average hidden_states[-1] with hidden_states[1];
    # otherwise only the last hidden state is used.
    "first_last_avg": False,
    # The keys below are not read in this part of the notebook; their exact
    # use lives in the driver code -- verify there before changing them.
    "seed": [0,1,2],
    "batch_size": 64,
    "epochs": 30,
    "class_weight": 10, #[9.375, 1, 30], #10 [2.071, 1, 1.933]
    "lr": 1e-5
}

# Experiment switches. "adversarial" and "double_adversarial" select the
# branch taken in train()/val()/visualize(); the remaining flags are consumed
# elsewhere in the notebook -- verify their meaning against that code.
config = {
    "dataset": "student_essay", #"student_essay", #debate, m-arg
    "adversarial": False,
    "double_adversarial": False,
    "dataset_from_saved": False,
    "injection": True,
    "finetuning_discovery": True,
    "grid_search": False,
    "visualize": False,
    "train": True,
    "scheduler": False,
    "attention": False,
    "cue_gating": False
}

def _to_device_tensor(values):
    """Return ``values`` as a tensor on ``device`` (lists go through NumPy first)."""
    if torch.is_tensor(values):
        return values.to(device)
    return torch.tensor(np.array(values)).to(device)


def _split_adversarial_targets(labels):
    """Build the targets for a half/half adversarial batch.

    :return: (targets, targets_adv, targets_task) where the first two are the
        labels of the first / second half of the batch and ``targets_task``
        marks which half every sample came from ([0,1] = first, [1,0] = second)
    """
    half_batch_size = len(labels) // 2
    # The second half is one longer for odd batch sizes; the task targets must
    # cover every sample in the batch.
    rest_size = len(labels) - half_batch_size
    targets = _to_device_tensor(labels[:half_batch_size])
    targets_adv = _to_device_tensor(labels[half_batch_size:])
    targets_task = _to_device_tensor([[0, 1]] * half_batch_size + [[1, 0]] * rest_size)
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3, rl=0):
    """Run one training epoch and print timing, summed loss and learning rate.

    The branch taken depends on the global ``config``:
    ``adversarial`` -> three losses, ``double_adversarial`` -> five losses,
    otherwise a single loss on the model output.

    :param epoch: zero-based epoch index (only used in the log line)
    :param model: model to train
    :param loss_fn: loss applied to the main predictions
    :param optimizer: optimizer stepped once per batch
    :param train_loader: iterable of 5-tuples
        (ids, segments, attention mask, position_sep, labels)
    :param scheduler: optional LR scheduler stepped once per batch
    :param discovery_weight: weight of the auxiliary ("adv") loss
    :param adv_weight: weight of the gradient-reversed task loss
    :param rl: weight of the attack / support losses (double adversarial only)
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for step, batch in enumerate(tqdm(train_loader, desc='Iteration')):
        # Labels may arrive as a plain list; those are converted further down.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)

        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        # Target construction is deliberately not wrapped in try/except: a
        # failure here used to print "error" and then continue with undefined
        # or stale targets from the previous step.
        if config["adversarial"]:
          pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          targets, targets_adv, targets_task = _split_adversarial_targets(labels)

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
          pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
          targets, targets_adv, targets_task = _split_adversarial_targets(labels)

          # Count the samples of each group; the model concatenates them in the
          # same order (attack + contr, support + other).
          attack_len = torch.sum((targets == torch.tensor([0,1]).to(device)).all(dim=1)).item()
          support_len = torch.sum((targets == torch.tensor([1,0]).to(device)).all(dim=1)).item()
          contr_len = torch.sum((targets_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)).item()
          other_len = torch.sum((targets_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (targets_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)).item()

          attack_target = _to_device_tensor([[0,1]] * attack_len + [[1,0]] * contr_len)
          support_target = _to_device_tensor([[0,1]] * support_len + [[1,0]] * other_len)

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss4 = loss_fn2(attack_pred, attack_target.float())
          loss5 = loss_fn2(support_pred, support_target.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + rl*loss4 + rl*loss5
        else:
          out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          if isinstance(labels, list):
            labels = _to_device_tensor(labels)
          loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate `model` on every batch of `val_loader` and report classification metrics.

    Each batch is expected to unpack into
    (ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels). The labels are
    per-example vectors; they are cast to float and used as the CrossEntropyLoss target.
    When config["double_adversarial"] is set, the labels are also passed to the model.

    Prints the loss summed over all batches and returns the 4-tuple
    (accuracy, precision, recall, f1) computed by `output_metrics` from the
    gold label vectors and the one-hot encoded predictions.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    gold_vectors = []
    pred_vectors = []

    for raw_batch in val_loader:
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = tuple(
            t.to(device) for t in raw_batch
        )

        with torch.no_grad():
            # The double-adversarial model takes the labels as an extra positional input.
            model_inputs = (ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if config["double_adversarial"]:
                model_inputs = model_inputs + (labels,)
            out = model(*model_inputs)

            # Index of the highest score along dim 1 is the predicted class.
            _, winners = torch.max(out.data, 1)
            batch_preds = winners.cpu().numpy().tolist()

            total_loss += criterion(out, labels.float()).item()

            batch_labels = labels.cpu().numpy().tolist()
            gold_vectors.extend(batch_labels)

            # Re-encode predicted indices as one-hot vectors matching the label width.
            if len(batch_labels[0]) == 2:
                encoded = [[1,0] if p == 0 else [0,1] for p in batch_preds]
            else:
                encoded = [
                    [1,0,0] if p == 0 else ([0,1,0] if p == 1 else [0,0,1])
                    for p in batch_preds
                ]
            pred_vectors.extend(encoded)

    print(f"val loss: {total_loss}")

    accuracy, precision, recall, f1 = output_metrics(gold_vectors, pred_vectors)
    return accuracy, precision, recall, f1

def _plot_tsne_scatter(df_points, discovery_weight, adv_weight):
  """Render and show one scatter plot of 2-D points coloured by their `label` column."""
  _, ax = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  # Draw on the axes created above explicitly rather than on pyplot's implicit current axes.
  sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}')
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches = 20):
  """Plot t-SNE projections of the model's embeddings for the test data and the adversarial data.

  Does nothing unless config["adversarial"] is set. Otherwise the model is run with
  `visualize=True` on every batch of `test_dataloader` and on the first
  `max_adv_batches` batches of `train_adv_dataloader`; each collection of outputs is
  projected to 2-D with its own t-SNE fit and shown as a separate scatter plot.

  Args:
    epoch: unused; kept so existing callers keep working.
    model: network accepting (ids, segs, att_mask, position_sep, visualize=True).
    test_dataloader: yields batches of five tensors, the last being label vectors.
    train_adv_dataloader: yields batches of five items whose labels may be a plain list.
    discovery_weight: only used in the plot titles.
    adv_weight: only used in the plot titles.
    max_adv_batches: number of adversarial batches to embed (None means all of them).

  Raises:
    ValueError: if either data loader yields no batch at all.
  """
  from itertools import islice

  if not config["adversarial"]:
    return

  model.eval()

  # Per-batch results are collected in lists and concatenated once at the end,
  # instead of re-allocating a growing tensor on every batch.
  test_labels, test_embeddings = [], []
  adv_labels, adv_embeddings = [], []

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    test_labels.append(torch.argmax(labels, dim=-1))

    with torch.no_grad():
      test_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  # islice stops after max_adv_batches without fetching an extra batch first.
  for batch in tqdm(islice(train_adv_dataloader, max_adv_batches)):
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    # Class ids are shifted by 2 -- presumably so they do not collide with the
    # task's own class ids in the plot legend (TODO confirm).
    adv_labels.append(torch.argmax(labels, dim=-1)+2)

    with torch.no_grad():
      adv_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  if not test_embeddings or not adv_embeddings:
    raise ValueError("visualize() needs at least one batch from both test_dataloader and train_adv_dataloader")

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(torch.cat(test_embeddings, dim=0).cpu().numpy())
  tsne_results_adv = tsne.fit_transform(torch.cat(adv_embeddings, dim=0).cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = torch.cat(test_labels, dim=0).cpu().numpy()
  df_tsne_adv["label"] = torch.cat(adv_labels, dim=0).cpu().numpy()

  print(df_tsne_adv["label"].unique())

  _plot_tsne_scatter(df_tsne, discovery_weight, adv_weight)
  _plot_tsne_scatter(df_tsne_adv, discovery_weight, adv_weight)





def _make_optimizer(model):
  """Build the AdamW optimizer used by `run`: weight decay 0.01, except 0.0 for biases and LayerNorm weights."""
  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(optimizer_grouped_parameters, lr=args["lr"])


def run(seed):
  """Train and evaluate one model for the given random seed, driven by `config` and `args`.

  Steps: seed the RNGs, pick the processor and data files for config["dataset"],
  optionally add the discourse-marker ("discovery") data, build the data loaders,
  model, optimizer and loss, then train for args["epochs"] epochs. After every
  epoch the model is evaluated on dev and test; the test metrics of the epoch with
  the best dev F1 are the reported result.

  With config["grid_search"] the whole training is repeated for every combination
  of local / discovery / adversarial loss weights, using a fresh DoubleAdversarialNet
  after the first combination.

  Args:
    seed: value passed to `set_random_seeds`.

  Returns:
    [accuracy, precision, recall, f1] on the test set at the best-dev-F1 epoch.
    In grid-search mode this is the result of the first weight combination only.
    All four values are -1 if no epoch was trained or dev F1 never improved.

  Raises:
    ValueError: if config["dataset"] is not one of the supported dataset names.
  """
  set_random_seeds(seed)

  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/train_essay.txt"
    path_dev = "./data/student_essay/dev_essay.txt"
    path_test = "./data/student_essay/test_essay.txt"
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/train_debate_concept.txt"
    path_dev = "./data/debate/dev_debate_concept.txt"
    path_test = "./data/debate/test_debate_concept.txt"
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    # A single file; the processor is told which split to produce via `name`.
    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKWithDiscourseInjectionProcessor()
    else:
      processor = NKProcessor()

    # Only one file; dev and test are carved out of it below.
    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid dataset name (choose from 'student_essay', 'debate', 'm-arg' and 'nk')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    # First tenth -> dev, last tenth -> test, the rest -> train. Explicit end indices
    # are used because a negative-zero slice ([-0:] / [k:-0]) would put everything
    # into test and nothing into train when there are fewer than ten examples.
    n_total = len(data_train)
    fold = n_total // 10
    data_dev = data_train[:fold]
    data_test = data_train[n_total - fold:]
    data_train = data_train[fold:n_total - fold]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"] or config["double_adversarial"] or config["finetuning_discovery"]:
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      # The raw corpus is only needed when the processed cache has to be rebuilt.
      df = datasets.load_dataset("discovery","discovery", trust_remote_code=True)
      print("processing discourse marker dataset...")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE: unpickling can execute arbitrary code; only load a cache written by this notebook.
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if config["double_adversarial"]:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = DoubleAdversarialNet()

  elif not config["adversarial"]:
    if config["finetuning_discovery"]:
      sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
      train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    else:
      train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

    if not config["cue_gating"]:
      model = BaselineModelWithSentenceComparison(attention=config["attention"])
    else:
      model = BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  # NOTE(review): the evaluation loaders are shuffled; kept unchanged so that runs stay
  # comparable with previously recorded results.
  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _make_optimizer(model)

  if config["dataset"] in ["m-arg","nk"]:
    # For these datasets args["class_weight"] is used directly as the full weight vector.
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    # Here args["class_weight"] is a single number: the weight of the second class.
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  # Initialised up front so the reporting below never hits an undefined name
  # (e.g. when config["train"] is false or dev F1 never improves).
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
  best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1

  result_metrics = []
  checkpoint_path = f"./{config['dataset']}_model.pt"

  if config["grid_search"]:
    range_disc = np.arange(0.8,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)
    range_local = np.arange(0.2, 1, 0.2)

    for rl in reversed(range_local):
      print(f"rl {rl}")
      for discovery_weight in reversed(range_disc):
        for adv_weight in range_adv:
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight, rl=rl)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          # Start the next weight combination from a freshly seeded model and optimizer.
          set_random_seeds(seed)
          model = DoubleAdversarialNet()
          model = model.to(device)
          optimizer = _make_optimizer(model)

          # Reset the trackers so results do not leak from one combination to the next.
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
  else:
    if config["scheduler"]:
      scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-2, total_iters = 30)
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if config["scheduler"]:
          # The learning-rate schedule advances once per epoch here.
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          # Keep the weights of the best-dev-F1 epoch for the visualization below.
          torch.save(model.state_dict(), checkpoint_path)

    if config["visualize"] and config["adversarial"]:
      model.load_state_dict(torch.load(checkpoint_path))
      visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  return result_metrics[0]

if __name__ == "__main__":
  # One complete run per configured seed; each run yields [accuracy, precision, recall, f1].
  results = []
  for seed in args["seed"]:
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results += [result_metrics]
  # Metric-wise mean over all seeds.
  avg = torch.tensor(results).mean(dim=0)
  print(avg)

**** trying with seed 0 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2611.86it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2589.91it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2748.89it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 356700.24it/s]


one hot encoding...


creating results...: 341840it [00:02, 168034.00it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:43<00:00, 2088.83it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 95/5390 [00:15<14:44,  5.99it/s]


Timing: 15.876251935958862, Epoch: 1, training loss: 150.9464988708496, current learning rate 1e-05
val loss: 12.76092380285263
accuracy:      0.508
precision:     0.539
recall:        0.595
f1:            0.441
val loss: 12.771366119384766
accuracy:      0.488
precision:     0.526
recall:        0.586
f1:            0.406
===== Start training: epoch 2 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:49,  6.39it/s]


Timing: 14.883723735809326, Epoch: 2, training loss: 127.77315616607666, current learning rate 1e-05
val loss: 7.655857503414154
accuracy:      0.817
precision:     0.632
recall:        0.709
f1:            0.653
val loss: 7.677141904830933
accuracy:      0.840
precision:     0.614
recall:        0.713
f1:            0.638
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:50,  6.38it/s]


Timing: 14.90134859085083, Epoch: 3, training loss: 112.87254571914673, current learning rate 1e-05
val loss: 12.284650981426239
accuracy:      0.685
precision:     0.593
recall:        0.715
f1:            0.572
val loss: 10.780708402395248
accuracy:      0.708
precision:     0.585
recall:        0.751
f1:            0.564
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:46,  6.40it/s]


Timing: 14.841840267181396, Epoch: 4, training loss: 102.66222274303436, current learning rate 1e-05
val loss: 7.422014504671097
accuracy:      0.841
precision:     0.653
recall:        0.716
f1:            0.674
val loss: 6.059508219361305
accuracy:      0.864
precision:     0.635
recall:        0.721
f1:            0.662
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:50,  6.37it/s]


Timing: 14.910173177719116, Epoch: 5, training loss: 93.86647111177444, current learning rate 1e-05
val loss: 9.604003638029099
accuracy:      0.787
precision:     0.614
recall:        0.702
f1:            0.630
val loss: 8.090518683195114
accuracy:      0.832
precision:     0.621
recall:        0.748
f1:            0.646
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:47,  6.40it/s]


Timing: 14.857859134674072, Epoch: 6, training loss: 86.95515930652618, current learning rate 1e-05
val loss: 7.307601824402809
accuracy:      0.860
precision:     0.659
recall:        0.670
f1:            0.664
val loss: 6.119189061224461
accuracy:      0.879
precision:     0.642
recall:        0.689
f1:            0.660
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:47,  6.40it/s]


Timing: 14.858104228973389, Epoch: 7, training loss: 85.58084058761597, current learning rate 1e-05
val loss: 8.318109080195427
accuracy:      0.856
precision:     0.653
recall:        0.670
f1:            0.661
val loss: 7.095708519220352
accuracy:      0.880
precision:     0.657
recall:        0.735
f1:            0.684
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:46,  6.41it/s]


Timing: 14.82907247543335, Epoch: 8, training loss: 80.69181430339813, current learning rate 1e-05
val loss: 8.812437653541565
accuracy:      0.902
precision:     0.788
recall:        0.633
f1:            0.672
val loss: 6.010969966650009
accuracy:      0.925
precision:     0.758
recall:        0.659
f1:            0.694
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:46,  6.40it/s]


Timing: 14.838223934173584, Epoch: 9, training loss: 71.80607751011848, current learning rate 1e-05
val loss: 9.4593755453825
accuracy:      0.898
precision:     0.759
recall:        0.634
f1:            0.669
val loss: 6.404832694679499
accuracy:      0.919
precision:     0.730
recall:        0.671
f1:            0.695
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.810510396957397, Epoch: 10, training loss: 71.51240155100822, current learning rate 1e-05
val loss: 9.713210582733154
accuracy:      0.893
precision:     0.735
recall:        0.651
f1:            0.680
val loss: 6.29538282006979
accuracy:      0.916
precision:     0.720
recall:        0.684
f1:            0.700
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:48,  6.39it/s]


Timing: 14.86477255821228, Epoch: 11, training loss: 68.66000989079475, current learning rate 1e-05
val loss: 10.401883393526077
accuracy:      0.898
precision:     0.767
recall:        0.628
f1:            0.664
val loss: 7.82975409924984
accuracy:      0.922
precision:     0.744
recall:        0.642
f1:            0.676
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.804808378219604, Epoch: 12, training loss: 72.72030627727509, current learning rate 1e-05
val loss: 9.177199512720108
accuracy:      0.886
precision:     0.715
recall:        0.698
f1:            0.706
val loss: 8.048134721815586
accuracy:      0.895
precision:     0.679
recall:        0.733
f1:            0.701
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:46,  6.40it/s]


Timing: 14.841634273529053, Epoch: 13, training loss: 67.62393149733543, current learning rate 1e-05
val loss: 11.75826609134674
accuracy:      0.867
precision:     0.678
recall:        0.697
f1:            0.687
val loss: 10.154047951102257
accuracy:      0.884
precision:     0.665
recall:        0.747
f1:            0.694
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.807340383529663, Epoch: 14, training loss: 67.7711111009121, current learning rate 1e-05
val loss: 9.419311076402664
accuracy:      0.904
precision:     0.794
recall:        0.641
f1:            0.681
val loss: 6.9188534170389175
accuracy:      0.919
precision:     0.729
recall:        0.656
f1:            0.684
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.810283422470093, Epoch: 15, training loss: 67.15634950995445, current learning rate 1e-05
val loss: 9.879152536392212
accuracy:      0.892
precision:     0.731
recall:        0.664
f1:            0.689
val loss: 7.579069584608078
accuracy:      0.905
precision:     0.690
recall:        0.694
f1:            0.692
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.809823036193848, Epoch: 16, training loss: 66.65035888552666, current learning rate 1e-05
val loss: 10.90168908238411
accuracy:      0.900
precision:     0.799
recall:        0.605
f1:            0.642
val loss: 7.579465351998806
accuracy:      0.928
precision:     0.788
recall:        0.646
f1:            0.688
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.816314458847046, Epoch: 17, training loss: 65.70685294270515, current learning rate 1e-05
val loss: 11.044442057609558
accuracy:      0.898
precision:     0.761
recall:        0.641
f1:            0.676
val loss: 9.160956298932433
accuracy:      0.915
precision:     0.716
recall:        0.674
f1:            0.692
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:43,  6.43it/s]


Timing: 14.780370712280273, Epoch: 18, training loss: 65.50196185708046, current learning rate 1e-05
val loss: 12.125016629695892
accuracy:      0.898
precision:     0.775
recall:        0.614
f1:            0.650
val loss: 9.222573705017567
accuracy:      0.926
precision:     0.769
recall:        0.660
f1:            0.697
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.796000003814697, Epoch: 19, training loss: 65.48205897212029, current learning rate 1e-05
val loss: 12.410035818815231
accuracy:      0.888
precision:     0.716
recall:        0.662
f1:            0.683
val loss: 9.791502445936203
accuracy:      0.906
precision:     0.698
recall:        0.719
f1:            0.708
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.803081750869751, Epoch: 20, training loss: 65.56891238689423, current learning rate 1e-05
val loss: 12.443947970867157
accuracy:      0.901
precision:     0.774
recall:        0.646
f1:            0.683
val loss: 9.389081627130508
accuracy:      0.918
precision:     0.725
recall:        0.650
f1:            0.678
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.801534652709961, Epoch: 21, training loss: 65.54334595799446, current learning rate 1e-05
val loss: 12.174232602119446
accuracy:      0.889
precision:     0.717
recall:        0.636
f1:            0.663
val loss: 8.919702850282192
accuracy:      0.910
precision:     0.695
recall:        0.661
f1:            0.676
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:43,  6.43it/s]


Timing: 14.782433032989502, Epoch: 22, training loss: 66.08615601062775, current learning rate 1e-05
val loss: 12.569361060857773
accuracy:      0.897
precision:     0.769
recall:        0.603
f1:            0.637
val loss: 8.249371901154518
accuracy:      0.928
precision:     0.794
recall:        0.636
f1:            0.679
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.814512491226196, Epoch: 23, training loss: 67.26995718479156, current learning rate 1e-05
val loss: 11.851400926709175
accuracy:      0.897
precision:     0.747
recall:        0.667
f1:            0.696
val loss: 8.526175877079368
accuracy:      0.922
precision:     0.742
recall:        0.697
f1:            0.717
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.42it/s]


Timing: 14.80164623260498, Epoch: 24, training loss: 62.78454276919365, current learning rate 1e-05
val loss: 11.860674768686295
accuracy:      0.899
precision:     0.785
recall:        0.611
f1:            0.648
val loss: 8.158947845338844
accuracy:      0.919
precision:     0.729
recall:        0.626
f1:            0.658
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:45,  6.42it/s]


Timing: 14.809626817703247, Epoch: 25, training loss: 63.44243574142456, current learning rate 1e-05
val loss: 12.734661819413304
accuracy:      0.898
precision:     0.761
recall:        0.641
f1:            0.676
val loss: 9.478027865290642
accuracy:      0.919
precision:     0.730
recall:        0.661
f1:            0.687
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:47,  6.40it/s]


Timing: 14.84651255607605, Epoch: 26, training loss: 65.15125197172165, current learning rate 1e-05
val loss: 15.213513225317001
accuracy:      0.862
precision:     0.663
recall:        0.674
f1:            0.668
val loss: 11.524922624230385
accuracy:      0.884
precision:     0.660
recall:        0.727
f1:            0.684
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:43,  6.43it/s]


Timing: 14.782958269119263, Epoch: 27, training loss: 59.92863065004349, current learning rate 1e-05
val loss: 12.544994294643402
accuracy:      0.896
precision:     0.765
recall:        0.599
f1:            0.632
val loss: 9.49543958902359
accuracy:      0.924
precision:     0.758
recall:        0.633
f1:            0.671
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:43,  6.43it/s]


Timing: 14.77909255027771, Epoch: 28, training loss: 62.898434937000275, current learning rate 1e-05
val loss: 12.990433409810066
accuracy:      0.898
precision:     0.753
recall:        0.654
f1:            0.686
val loss: 9.923605099320412
accuracy:      0.921
precision:     0.738
recall:        0.682
f1:            0.705
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.43it/s]


Timing: 14.79203462600708, Epoch: 29, training loss: 65.37443107366562, current learning rate 1e-05
val loss: 13.364997312426567
accuracy:      0.897
precision:     0.751
recall:        0.643
f1:            0.677
val loss: 9.29461932182312
accuracy:      0.919
precision:     0.730
recall:        0.676
f1:            0.698
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:44,  6.43it/s]


Timing: 14.790428161621094, Epoch: 30, training loss: 61.417923986911774, current learning rate 1e-05
val loss: 12.973185658454895
accuracy:      0.903
precision:     0.815
recall:        0.613
f1:            0.653
val loss: 8.860819804191124
accuracy:      0.923
precision:     0.751
recall:        0.638
f1:            0.674
best result:
0.8954545454545455
0.6793472794931108
0.7330563391019288
0.7011088507221824
[[0.8954545454545455, 0.6793472794931108, 0.7330563391019288, 0.7011088507221824]]
**** trying with seed 1 ****


tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2473.30it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2612.42it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2725.39it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 596707.57it/s] 


one hot encoding...


creating results...: 341840it [00:01, 171547.12it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:44<00:00, 2077.32it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 14.987873077392578, Epoch: 1, training loss: 159.0615634918213, current learning rate 1e-05
val loss: 15.17016077041626
accuracy:      0.114
precision:     0.057
recall:        0.500
f1:            0.102


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 15.196088194847107
accuracy:      0.083
precision:     0.041
recall:        0.500
f1:            0.076


  _warn_prf(average, modifier, msg_start, len(result))


===== Start training: epoch 2 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 15.001513957977295, Epoch: 2, training loss: 137.32079219818115, current learning rate 1e-05
val loss: 15.917571246623993
accuracy:      0.431
precision:     0.542
recall:        0.595
f1:            0.394
val loss: 16.046040892601013
accuracy:      0.429
precision:     0.535
recall:        0.609
f1:            0.376
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 95/5390 [00:15<13:56,  6.33it/s]


Timing: 15.017078399658203, Epoch: 3, training loss: 129.56168270111084, current learning rate 1e-05
val loss: 8.253778457641602
accuracy:      0.831
precision:     0.626
recall:        0.667
f1:            0.641
val loss: 7.8842573165893555
accuracy:      0.836
precision:     0.589
recall:        0.651
f1:            0.605
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 95/5390 [00:15<13:56,  6.33it/s]


Timing: 15.008034706115723, Epoch: 4, training loss: 117.84793996810913, current learning rate 1e-05
val loss: 6.267931312322617
accuracy:      0.874
precision:     0.661
recall:        0.604
f1:            0.622
val loss: 5.1739542707800865
accuracy:      0.889
precision:     0.632
recall:        0.630
f1:            0.631
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.965025663375854, Epoch: 5, training loss: 104.40108436346054, current learning rate 1e-05
val loss: 12.84192556142807
accuracy:      0.687
precision:     0.589
recall:        0.706
f1:            0.570
val loss: 13.102476119995117
accuracy:      0.695
precision:     0.575
recall:        0.724
f1:            0.549
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 15.000738620758057, Epoch: 6, training loss: 96.69857442378998, current learning rate 1e-05
val loss: 8.33682256937027
accuracy:      0.858
precision:     0.635
recall:        0.618
f1:            0.626
val loss: 6.4792593121528625
accuracy:      0.876
precision:     0.619
recall:        0.643
f1:            0.629
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.34it/s]


Timing: 14.980412006378174, Epoch: 7, training loss: 92.96796107292175, current learning rate 1e-05
val loss: 9.789389610290527
accuracy:      0.799
precision:     0.601
recall:        0.656
f1:            0.615
val loss: 8.556258767843246
accuracy:      0.820
precision:     0.594
recall:        0.687
f1:            0.611
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.947750568389893, Epoch: 8, training loss: 90.44561365246773, current learning rate 1e-05
val loss: 9.116103053092957
accuracy:      0.871
precision:     0.669
recall:        0.642
f1:            0.654
val loss: 7.474031453486532
accuracy:      0.880
precision:     0.627
recall:        0.650
f1:            0.637
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.34it/s]


Timing: 14.986447811126709, Epoch: 9, training loss: 81.07779884338379, current learning rate 1e-05
val loss: 9.731638759374619
accuracy:      0.824
precision:     0.635
recall:        0.703
f1:            0.655
val loss: 9.119096115231514
accuracy:      0.824
precision:     0.605
recall:        0.714
f1:            0.626
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 14.996875047683716, Epoch: 10, training loss: 81.45955249667168, current learning rate 1e-05
val loss: 8.34014019370079
accuracy:      0.867
precision:     0.665
recall:        0.653
f1:            0.659
val loss: 7.5256321132183075
accuracy:      0.876
precision:     0.632
recall:        0.673
f1:            0.648
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.974445819854736, Epoch: 11, training loss: 74.69286504387856, current learning rate 1e-05
val loss: 10.635081857442856
accuracy:      0.855
precision:     0.656
recall:        0.680
f1:            0.666
val loss: 10.37672808766365
accuracy:      0.868
precision:     0.636
recall:        0.708
f1:            0.660
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.970444440841675, Epoch: 12, training loss: 73.4967727959156, current learning rate 1e-05
val loss: 10.11563329398632
accuracy:      0.889
precision:     0.717
recall:        0.605
f1:            0.634
val loss: 8.133613228797913
accuracy:      0.906
precision:     0.677
recall:        0.639
f1:            0.655
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.942755699157715, Epoch: 13, training loss: 71.62140420079231, current learning rate 1e-05
val loss: 9.8019700050354
accuracy:      0.874
precision:     0.683
recall:        0.671
f1:            0.677
val loss: 7.862205605022609
accuracy:      0.885
precision:     0.645
recall:        0.677
f1:            0.658
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.964881420135498, Epoch: 14, training loss: 68.30782276391983, current learning rate 1e-05
val loss: 10.887401044368744
accuracy:      0.886
precision:     0.705
recall:        0.611
f1:            0.637
val loss: 7.746800924418494
accuracy:      0.915
precision:     0.709
recall:        0.648
f1:            0.672
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:51,  6.37it/s]


Timing: 14.930608749389648, Epoch: 15, training loss: 66.98246994614601, current learning rate 1e-05
val loss: 11.758288562297821
accuracy:      0.878
precision:     0.685
recall:        0.640
f1:            0.657
val loss: 9.080216273665428
accuracy:      0.902
precision:     0.679
recall:        0.687
f1:            0.683
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:51,  6.37it/s]


Timing: 14.930643320083618, Epoch: 16, training loss: 67.78478211164474, current learning rate 1e-05
val loss: 10.827847689390182
accuracy:      0.884
precision:     0.690
recall:        0.592
f1:            0.616
val loss: 9.289036966860294
accuracy:      0.908
precision:     0.678
recall:        0.625
f1:            0.645
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:50,  6.37it/s]


Timing: 14.910489320755005, Epoch: 17, training loss: 68.70098182559013, current learning rate 1e-05
val loss: 11.639507621526718
accuracy:      0.890
precision:     0.728
recall:        0.569
f1:            0.591
val loss: 8.578164177946746
accuracy:      0.914
precision:     0.694
recall:        0.598
f1:            0.625
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:51,  6.37it/s]


Timing: 14.923145532608032, Epoch: 18, training loss: 65.13284707069397, current learning rate 1e-05
val loss: 12.069988071918488
accuracy:      0.894
precision:     0.783
recall:        0.565
f1:            0.586
val loss: 9.850180357694626
accuracy:      0.922
precision:     0.752
recall:        0.602
f1:            0.638
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.36it/s]


Timing: 14.95387601852417, Epoch: 19, training loss: 65.8318375647068, current learning rate 1e-05
val loss: 12.355957269668579
accuracy:      0.876
precision:     0.683
recall:        0.658
f1:            0.669
val loss: 10.249228924512863
accuracy:      0.901
precision:     0.677
recall:        0.686
f1:            0.681
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:51,  6.36it/s]


Timing: 14.932335615158081, Epoch: 20, training loss: 65.82565608620644, current learning rate 1e-05
val loss: 13.009421467781067
accuracy:      0.842
precision:     0.649
recall:        0.703
f1:            0.669
val loss: 10.201337575912476
accuracy:      0.878
precision:     0.653
recall:        0.729
f1:            0.680
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.936057329177856, Epoch: 21, training loss: 66.67470797896385, current learning rate 1e-05
val loss: 12.380649447441101
accuracy:      0.878
precision:     0.691
recall:        0.663
f1:            0.675
val loss: 8.981074020266533
accuracy:      0.905
precision:     0.691
recall:        0.699
f1:            0.695
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.943199396133423, Epoch: 22, training loss: 64.4259817302227, current learning rate 1e-05
val loss: 11.81115285307169
accuracy:      0.884
precision:     0.696
recall:        0.603
f1:            0.628
val loss: 8.78561718761921
accuracy:      0.920
precision:     0.734
recall:        0.641
f1:            0.673
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.977908372879028, Epoch: 23, training loss: 64.40333938598633, current learning rate 1e-05
val loss: 11.586619436740875
accuracy:      0.891
precision:     0.735
recall:        0.590
f1:            0.618
val loss: 8.584087178111076
accuracy:      0.918
precision:     0.723
recall:        0.600
f1:            0.632
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.941396236419678, Epoch: 24, training loss: 64.98578980565071, current learning rate 1e-05
val loss: 11.436928123235703
accuracy:      0.881
precision:     0.687
recall:        0.621
f1:            0.643
val loss: 8.04466242599301
accuracy:      0.917
precision:     0.723
recall:        0.675
f1:            0.695
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.942144632339478, Epoch: 25, training loss: 62.07431757450104, current learning rate 1e-05
val loss: 12.30281913280487
accuracy:      0.881
precision:     0.690
recall:        0.631
f1:            0.652
val loss: 9.040312990546227
accuracy:      0.912
precision:     0.701
recall:        0.662
f1:            0.679
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.935915470123291, Epoch: 26, training loss: 63.07877251505852, current learning rate 1e-05
val loss: 11.234161548316479
accuracy:      0.889
precision:     0.717
recall:        0.599
f1:            0.627
val loss: 8.765177071094513
accuracy:      0.920
precision:     0.734
recall:        0.631
f1:            0.664
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.942296504974365, Epoch: 27, training loss: 65.39891719818115, current learning rate 1e-05
val loss: 11.721087455749512
accuracy:      0.885
precision:     0.704
recall:        0.627
f1:            0.652
val loss: 8.338267615064979
accuracy:      0.914
precision:     0.706
recall:        0.653
f1:            0.674
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.36it/s]


Timing: 14.953381299972534, Epoch: 28, training loss: 65.79686054587364, current learning rate 1e-05
val loss: 11.867193758487701
accuracy:      0.884
precision:     0.705
recall:        0.653
f1:            0.673
val loss: 9.62272784113884
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.956310033798218, Epoch: 29, training loss: 64.22579285502434, current learning rate 1e-05
val loss: 11.950492590665817
accuracy:      0.887
precision:     0.708
recall:        0.604
f1:            0.631
val loss: 8.455747872591019
accuracy:      0.914
precision:     0.700
recall:        0.623
f1:            0.649
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.961649417877197, Epoch: 30, training loss: 64.20027175545692, current learning rate 1e-05
val loss: 9.373592793941498
accuracy:      0.885
precision:     0.704
recall:        0.627
f1:            0.652
val loss: 7.72110316157341
accuracy:      0.907
precision:     0.689
recall:        0.670
f1:            0.678
best result:
0.8845454545454545
0.6446834632219494
0.6771202038793713
0.6584143285041872
[[0.8845454545454545, 0.6446834632219494, 0.6771202038793713, 0.6584143285041872]]
**** trying with seed 2 ****


tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2545.29it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2476.92it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2663.49it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 557199.99it/s]


one hot encoding...


creating results...: 341840it [00:01, 174423.80it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:46<00:00, 2049.35it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 14.99788784980774, Epoch: 1, training loss: 154.35326957702637, current learning rate 1e-05
val loss: 10.074627816677094
accuracy:      0.831
precision:     0.578
recall:        0.576
f1:            0.577
val loss: 9.945144057273865
accuracy:      0.848
precision:     0.562
recall:        0.582
f1:            0.570
===== Start training: epoch 2 =====


Iteration:   2%|▏         | 95/5390 [00:15<13:56,  6.33it/s]


Timing: 15.018234729766846, Epoch: 2, training loss: 129.9287829399109, current learning rate 1e-05
val loss: 8.09757849574089
accuracy:      0.813
precision:     0.626
recall:        0.697
f1:            0.645
val loss: 7.2295843958854675
accuracy:      0.843
precision:     0.613
recall:        0.704
f1:            0.636
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 95/5390 [00:15<13:57,  6.32it/s]


Timing: 15.031432628631592, Epoch: 3, training loss: 116.37761998176575, current learning rate 1e-05
val loss: 9.484829127788544
accuracy:      0.764
precision:     0.609
recall:        0.712
f1:            0.619
val loss: 8.93992018699646
accuracy:      0.771
precision:     0.589
recall:        0.725
f1:            0.594
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.33it/s]


Timing: 15.004732608795166, Epoch: 4, training loss: 103.98394644260406, current learning rate 1e-05
val loss: 8.517603665590286
accuracy:      0.802
precision:     0.612
recall:        0.681
f1:            0.629
val loss: 7.40632164478302
accuracy:      0.837
precision:     0.619
recall:        0.731
f1:            0.643
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.34it/s]


Timing: 14.982292175292969, Epoch: 5, training loss: 92.52786588668823, current learning rate 1e-05
val loss: 6.49713322520256
accuracy:      0.891
precision:     0.730
recall:        0.614
f1:            0.644
val loss: 4.811421405524015
accuracy:      0.914
precision:     0.698
recall:        0.613
f1:            0.640
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.97658371925354, Epoch: 6, training loss: 86.39452385902405, current learning rate 1e-05
val loss: 7.647403448820114
accuracy:      0.883
precision:     0.701
recall:        0.662
f1:            0.679
val loss: 6.1765428483486176
accuracy:      0.899
precision:     0.665
recall:        0.660
f1:            0.662
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 95/5390 [00:15<13:56,  6.33it/s]


Timing: 15.012154817581177, Epoch: 7, training loss: 81.27321547269821, current learning rate 1e-05
val loss: 7.909253492951393
accuracy:      0.863
precision:     0.672
recall:        0.698
f1:            0.683
val loss: 7.2119584530591965
accuracy:      0.870
precision:     0.636
recall:        0.704
f1:            0.659
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.979225397109985, Epoch: 8, training loss: 77.42655086517334, current learning rate 1e-05
val loss: 9.19505162537098
accuracy:      0.866
precision:     0.669
recall:        0.670
f1:            0.669
val loss: 7.811859756708145
accuracy:      0.885
precision:     0.650
recall:        0.692
f1:            0.667
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.955780744552612, Epoch: 9, training loss: 75.33553159236908, current learning rate 1e-05
val loss: 9.449019618332386
accuracy:      0.885
precision:     0.705
recall:        0.634
f1:            0.658
val loss: 7.150707423686981
accuracy:      0.911
precision:     0.699
recall:        0.666
f1:            0.681
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.96714735031128, Epoch: 10, training loss: 73.40307071805, current learning rate 1e-05
val loss: 8.958689734339714
accuracy:      0.883
precision:     0.699
recall:        0.649
f1:            0.668
val loss: 7.532649835105985
accuracy:      0.894
precision:     0.649
recall:        0.647
f1:            0.648
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.969286680221558, Epoch: 11, training loss: 71.91310900449753, current learning rate 1e-05
val loss: 11.207918509840965
accuracy:      0.888
precision:     0.712
recall:        0.585
f1:            0.610
val loss: 8.752194195985794
accuracy:      0.912
precision:     0.692
recall:        0.622
f1:            0.646
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.948330402374268, Epoch: 12, training loss: 71.92647475004196, current learning rate 1e-05
val loss: 9.736115738749504
accuracy:      0.879
precision:     0.681
recall:        0.617
f1:            0.638
val loss: 8.13346728682518
accuracy:      0.892
precision:     0.647
recall:        0.651
f1:            0.649
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.933910369873047, Epoch: 13, training loss: 62.521889090538025, current learning rate 1e-05
val loss: 12.259908884763718
accuracy:      0.863
precision:     0.661
recall:        0.661
f1:            0.661
val loss: 11.262651026248932
accuracy:      0.886
precision:     0.653
recall:        0.693
f1:            0.670
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.941821813583374, Epoch: 14, training loss: 69.4032379090786, current learning rate 1e-05
val loss: 10.385044127702713
accuracy:      0.893
precision:     0.743
recall:        0.605
f1:            0.636
val loss: 8.213078051805496
accuracy:      0.915
precision:     0.706
recall:        0.633
f1:            0.660
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.96064567565918, Epoch: 15, training loss: 65.94618913531303, current learning rate 1e-05
val loss: 10.539205119013786
accuracy:      0.896
precision:     0.776
recall:        0.586
f1:            0.616
val loss: 8.527230516076088
accuracy:      0.920
precision:     0.735
recall:        0.616
f1:            0.650
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.951199531555176, Epoch: 16, training loss: 69.79974290728569, current learning rate 1e-05
val loss: 11.1146881878376
accuracy:      0.870
precision:     0.667
recall:        0.642
f1:            0.653
val loss: 8.685752754448913
accuracy:      0.894
precision:     0.660
recall:        0.677
f1:            0.668
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.96899127960205, Epoch: 17, training loss: 67.66500088572502, current learning rate 1e-05
val loss: 9.705354124307632
accuracy:      0.897
precision:     0.765
recall:        0.610
f1:            0.644
val loss: 8.460539504885674
accuracy:      0.913
precision:     0.697
recall:        0.627
f1:            0.652
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.36it/s]


Timing: 14.952572584152222, Epoch: 18, training loss: 66.56602251529694, current learning rate 1e-05
val loss: 12.152133375406265
accuracy:      0.891
precision:     0.747
recall:        0.567
f1:            0.588
val loss: 8.365506321191788
accuracy:      0.927
precision:     0.806
recall:        0.610
f1:            0.653
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.949109554290771, Epoch: 19, training loss: 67.52537137269974, current learning rate 1e-05
val loss: 10.186021864414215
accuracy:      0.888
precision:     0.714
recall:        0.628
f1:            0.656
val loss: 7.243452280759811
accuracy:      0.914
precision:     0.707
recall:        0.658
f1:            0.678
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:55,  6.34it/s]


Timing: 14.99424695968628, Epoch: 20, training loss: 63.165514558553696, current learning rate 1e-05
val loss: 11.359315946698189
accuracy:      0.882
precision:     0.698
recall:        0.658
f1:            0.675
val loss: 8.627693235874176
accuracy:      0.896
precision:     0.664
recall:        0.674
f1:            0.668
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.973899126052856, Epoch: 21, training loss: 66.71120396256447, current learning rate 1e-05
val loss: 9.928361237049103
accuracy:      0.890
precision:     0.721
recall:        0.616
f1:            0.645
val loss: 7.615862473845482
accuracy:      0.911
precision:     0.696
recall:        0.651
f1:            0.670
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.964300394058228, Epoch: 22, training loss: 63.43852159380913, current learning rate 1e-05
val loss: 10.297458909451962
accuracy:      0.890
precision:     0.721
recall:        0.626
f1:            0.655
val loss: 7.961976483464241
accuracy:      0.911
precision:     0.698
recall:        0.661
f1:            0.677
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.968461990356445, Epoch: 23, training loss: 63.65582889318466, current learning rate 1e-05
val loss: 11.515232414007187
accuracy:      0.900
precision:     0.783
recall:        0.622
f1:            0.660
val loss: 8.267764836549759
accuracy:      0.921
precision:     0.740
recall:        0.637
f1:            0.670
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:52,  6.36it/s]


Timing: 14.94305968284607, Epoch: 24, training loss: 68.35607093572617, current learning rate 1e-05
val loss: 11.009116977453232
accuracy:      0.891
precision:     0.729
recall:        0.630
f1:            0.660
val loss: 8.70960883796215
accuracy:      0.914
precision:     0.703
recall:        0.638
f1:            0.662
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.958280086517334, Epoch: 25, training loss: 61.83197122812271, current learning rate 1e-05
val loss: 11.441120684146881
accuracy:      0.879
precision:     0.694
recall:        0.670
f1:            0.681
val loss: 8.786233162507415
accuracy:      0.899
precision:     0.680
recall:        0.710
f1:            0.693
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.36it/s]


Timing: 14.954594373703003, Epoch: 26, training loss: 60.20278522372246, current learning rate 1e-05
val loss: 10.260380938649178
accuracy:      0.900
precision:     0.774
recall:        0.635
f1:            0.673
val loss: 7.463360438821837
accuracy:      0.920
precision:     0.735
recall:        0.626
f1:            0.660
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.964948415756226, Epoch: 27, training loss: 61.815477669239044, current learning rate 1e-05
val loss: 10.829384915530682
accuracy:      0.898
precision:     0.761
recall:        0.641
f1:            0.676
val loss: 7.945683341473341
accuracy:      0.918
precision:     0.724
recall:        0.640
f1:            0.670
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:54,  6.35it/s]


Timing: 14.971467733383179, Epoch: 28, training loss: 64.46495661139488, current learning rate 1e-05
val loss: 9.93121612071991
accuracy:      0.891
precision:     0.728
recall:        0.647
f1:            0.675
val loss: 7.202801614999771
accuracy:      0.916
precision:     0.720
recall:        0.684
f1:            0.700
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.966679573059082, Epoch: 29, training loss: 62.07310065627098, current learning rate 1e-05
val loss: 10.15686671435833
accuracy:      0.891
precision:     0.724
recall:        0.647
f1:            0.674
val loss: 7.5486591558437794
accuracy:      0.912
precision:     0.703
recall:        0.672
f1:            0.686
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 95/5390 [00:14<13:53,  6.35it/s]


Timing: 14.968964099884033, Epoch: 30, training loss: 59.29863676428795, current learning rate 1e-05
val loss: 10.02195779979229
accuracy:      0.899
precision:     0.766
recall:        0.641
f1:            0.678
val loss: 7.7739994172006845
accuracy:      0.922
precision:     0.743
recall:        0.657
f1:            0.689
best result:
0.87
0.6361866573686656
0.7041843191496313
0.6593579124032266
[[0.87, 0.6361866573686656, 0.7041843191496313, 0.6593579124032266]]


In [None]:
# Final cell: release the Colab runtime once every cell above has finished.
# NOTE(review): per the google.colab API, unassign() disconnects and deletes the
# backing VM, so in-memory state and files outside the mounted Drive are
# presumably lost — confirm all results are persisted before this cell runs.
from google.colab import runtime
runtime.unassign()