<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Mount Google Drive inside the Colab VM at /content/drive.
# force_remount=True asks Colab to mount again even if a mount already exists
# (useful when the cell is re-run in the same session).
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
# Switch the working directory to the project folder on the mounted Drive.
# NOTE(review): hard-coded path; it must exist in the Drive of whoever runs this.
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

# Pick the compute device once: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Expose an indexable collection of examples as a PyTorch ``Dataset``.

    The examples are stored as given (no copy) and returned unchanged, so any
    batching/tensor conversion is left to the ``collate_fn`` of the DataLoader.
    """

    def __init__(self, examples):
        """Keep a reference to ``examples`` (anything supporting ``len`` and indexing)."""
        super().__init__()
        self.examples = examples

    def __len__(self):
        """Number of stored examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.examples[idx]


def collate_fn(examples):
    """Batch two-sentence examples into seven ``torch.long`` tensors.

    Args:
        examples: Iterable of 7-item examples ordered as
            (ids_sent1, segs_sent1, att_mask_sent1,
             ids_sent2, segs_sent2, att_mask_sent2, label).

    Returns:
        A 7-tuple of tensors in the same field order, each stacking that field
        over the batch.

    Raises:
        ValueError: if the examples do not have exactly seven fields
            (this includes an empty batch).
    """
    # Transpose "list of examples" into "one list per field"; unpacking into
    # seven names enforces the expected example layout.
    (ids_sent1, segs_sent1, att_mask_sent1,
     ids_sent2, segs_sent2, att_mask_sent2,
     labels) = (list(field) for field in zip(*examples))

    fields = (ids_sent1, segs_sent1, att_mask_sent1,
              ids_sent2, segs_sent2, att_mask_sent2, labels)
    return tuple(torch.tensor(field, dtype=torch.long) for field in fields)

def collate_fn_concatenated(examples):
    """Batch concatenated-pair examples into five ``torch.long`` tensors.

    Args:
        examples: Iterable of 5-item examples ordered as
            (ids, segment_ids, attention_mask, position_sep, label).

    Returns:
        A 5-tuple of tensors in the same field order, each stacking that field
        over the batch.

    Raises:
        ValueError: if the examples do not have exactly five fields
            (this includes an empty batch).
    """
    # One list per field; the five-name unpacking checks the example layout.
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = (
        list(field) for field in zip(*examples)
    )

    fields = (ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
    return tuple(torch.tensor(field, dtype=torch.long) for field in fields)

def collate_fn_concatenated_adv(examples):
    """Batch concatenated-pair examples, leaving the labels as a Python list.

    Same layout as ``collate_fn_concatenated`` input
    (ids, segment_ids, attention_mask, position_sep, label), but only the first
    four fields are converted to ``torch.long`` tensors; the labels are
    returned as the plain list collected from the batch.

    Raises:
        ValueError: if the examples do not have exactly five fields
            (this includes an empty batch).
    """
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = (
        list(field) for field in zip(*examples)
    )

    as_long = [
        torch.tensor(field, dtype=torch.long)
        for field in (ids_sent1, segs_sent1, att_mask_sent1, position_sep)
    ]
    # Labels are deliberately passed through without tensor conversion.
    return as_long[0], as_long[1], as_long[2], as_long[3], labels

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base class turning sentence pairs into fixed-length token-id examples.

  Every input row is a 7-item sequence
  ``(id, sentence1, sentence2, _, _, _, label)``; only the two sentences and
  the label are used by the methods below.
  """

  def __init__(self, model_name=None, max_sent_len=150):
    """Load the tokenizer and remember the target sequence length.

    Args:
      model_name: Name or path given to ``AutoTokenizer.from_pretrained``.
        When ``None`` (the default) the notebook-level ``args["model_name"]``
        is used, which is the original behaviour.
      max_sent_len: Number of positions every sequence is padded or truncated
        to (default 150).
    """
    # ``args`` is a global defined elsewhere in the notebook.
    self.model_name = args["model_name"] if model_name is None else model_name
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
    self.max_sent_len = max_sent_len

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(self.model_name, self.max_sent_len)
    return pattern

  def _pad_token_id(self):
    """Return the id the tokenizer assigns to its padding token."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_length(self, values, pad_value):
    """Return a copy of ``values`` that is exactly ``self.max_sent_len`` long.

    Shorter lists are right-padded with ``pad_value``; longer ones are cut.
    """
    missing = self.max_sent_len - len(values)
    if missing > 0:
      return values + [pad_value] * missing
    return values[:self.max_sent_len]

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode the two sentences of every row separately.

    Args:
      dataset: Iterable of 7-item rows (see class docstring).
      dataset_type: Split name, only used in the progress message.

    Returns:
      A list of examples
      ``[ids1, segs1, att_mask1, ids2, segs2, att_mask2, label]`` where every
      list except ``label`` has length ``self.max_sent_len``.
    """
    pad_id = self._pad_token_id()  # loop-invariant, so looked up once
    examples = []

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      example = []
      for sentence in (sentence1, sentence2):
        ids = self.tokenizer.encode(sentence)
        # Segment id 1 for every token except the first and the last one.
        segs = [0] * len(ids)
        segs[1:-1] = [1] * (len(ids) - 2)
        att_mask = [1] * len(ids)

        example.append(self._fit_to_length(ids, pad_id))
        example.append(self._fit_to_length(segs, 0))
        example.append(self._fit_to_length(att_mask, 0))

      example.append(label)
      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode both sentences of every row as one concatenated sequence.

    Args:
      dataset: Iterable of 7-item rows (see class docstring).
      dataset_type: Split name, only used in the progress message.

    Returns:
      A list of examples ``[ids, segs, att_mask, position_sep, label]`` where
      every list except ``label`` has length ``self.max_sent_len``.

    Raises:
      AssertionError: if the pair encoding is not as long as the two single
        encodings together (the segment ids are built from those lengths).
    """
    pad_id = self._pad_token_id()  # loop-invariant, so looked up once
    examples = []

    # NOTE(review): ``tqdm`` is not imported in this cell; it has to be
    # imported somewhere before this method is called.
    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids = self.tokenizer.encode(sentence1, sentence2)
      segs = [0] * sentence1_length + [1] * sentence2_length

      # NOTE(review): the list starts as all ones, so only the assignment to
      # index 0 changes anything; the other two are no-ops. If the intent was
      # to mark separator positions only, the initial value should be 0.
      # Left as is to keep the produced features unchanged.
      position_sep = [1] * len(ids)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids) == len(position_sep)
      assert len(ids) == len(segs)

      att_mask = [1] * len(ids)
      examples.append([
        self._fit_to_length(ids, pad_id),
        self._fit_to_length(segs, 0),
        self._fit_to_length(att_mask, 0),
        self._fit_to_length(position_sep, 0),
        label,
      ])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Prepare a discourse-marker prediction dataset as a 3-class task.

  Samples carry an integer ``label`` that indexes ``self.id_to_word``. Only
  the markers listed in ``self.mapping`` are kept, and each is replaced by its
  coarse class id before one-hot encoding.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Source of the marker grouping:
    # https://www.sciencedirect.com/science/article/pii/S0378216698001015
    # (stable article URL; the expired pre-signed download link that used to
    # be pasted here contained temporary credential tokens and was removed).
    # TODO: refactor

    # Discourse marker -> coarse class id (0, 1 or 2).
    # presumably 0 = inferential/causal, 1 = elaborative, 2 = contrastive
    # -- TODO confirm against the reference above.
    # A two-class variant used before merged classes 0 and 1 into 0 and
    # mapped class 2 to 1.
    # NOTE(review): "conversely" and "besides" are in class 0, which looks
    # inconsistent with their neighbours ("by_contrast" is 2, "moreover" is 1).
    # Values are left untouched here -- please confirm.
    self.mapping = {
      "accordingly": 0, "also": 1, "although": 2, "and": 1,
      "as_a_result": 0, "because_of_that": 0, "because_of_this": 0,
      "besides": 0, "but": 2, "by_comparison": 2, "by_contrast": 2,
      "consequently": 0, "conversely": 0, "especially": 1, "further": 1,
      "furthermore": 1, "hence": 0, "however": 2, "in_contrast": 2,
      "instead": 2, "likewise": 1, "moreover": 1, "namely": 1,
      "nevertheless": 2, "nonetheless": 2, "on_the_contrary": 2,
      "on_the_other_hand": 2, "otherwise": 1, "rather": 2, "similarly": 1,
      "so": 0, "still": 2, "then": 0, "therefore": 0, "though": 2,
      "thus": 0, "well": 1, "yet": 2,
    }

    # Dataset label id -> marker string; the position in the tuple is the id.
    self.id_to_word = dict(enumerate((
      "no-conn", "absolutely", "accordingly", "actually", "additionally", "admittedly", "afterward", "again", "already", "also",  # 0-9
      "alternately", "alternatively", "although", "altogether", "amazingly", "and", "anyway", "apparently", "arguably", "as_a_result",  # 10-19
      "basically", "because_of_that", "because_of_this", "besides", "but", "by_comparison", "by_contrast", "by_doing_this", "by_then", "certainly",  # 20-29
      "clearly", "coincidentally", "collectively", "consequently", "conversely", "curiously", "currently", "elsewhere", "especially", "essentially",  # 30-39
      "eventually", "evidently", "finally", "first", "firstly", "for_example", "for_instance", "fortunately", "frankly", "frequently",  # 40-49
      "further", "furthermore", "generally", "gradually", "happily", "hence", "here", "historically", "honestly", "hopefully",  # 50-59
      "however", "ideally", "immediately", "importantly", "in_contrast", "in_fact", "in_other_words", "in_particular", "in_short", "in_sum",  # 60-69
      "in_the_end", "in_the_meantime", "in_turn", "incidentally", "increasingly", "indeed", "inevitably", "initially", "instead", "interestingly",  # 70-79
      "ironically", "lastly", "lately", "later", "likewise", "locally", "luckily", "maybe", "meaning", "meantime",  # 80-89
      "meanwhile", "moreover", "mostly", "namely", "nationally", "naturally", "nevertheless", "next", "nonetheless", "normally",  # 90-99
      "notably", "now", "obviously", "occasionally", "oddly", "often", "on_the_contrary", "on_the_other_hand", "once", "only",  # 100-109
      "optionally", "or", "originally", "otherwise", "overall", "particularly", "perhaps", "personally", "plus", "preferably",  # 110-119
      "presently", "presumably", "previously", "probably", "rather", "realistically", "really", "recently", "regardless", "remarkably",  # 120-129
      "sadly", "second", "secondly", "separately", "seriously", "significantly", "similarly", "simultaneously", "slowly", "so",  # 130-139
      "sometimes", "soon", "specifically", "still", "strangely", "subsequently", "suddenly", "supposedly", "surely", "surprisingly",  # 140-149
      "technically", "thankfully", "then", "theoretically", "thereafter", "thereby", "therefore", "third", "thirdly", "this",  # 150-159
      "though", "thus", "together", "traditionally", "truly", "truthfully", "typically", "ultimately", "undoubtedly", "unfortunately",  # 160-169
      "unsurprisingly", "usually", "well", "yet",  # 170-173
    )))

  def process_dataset(self, dataset, name="train"):
    """Filter, relabel and tokenise one split.

    Args:
      dataset: Iterable of samples indexable by ``"sentence1"``,
        ``"sentence2"`` and ``"label"`` (an id present in ``self.id_to_word``).
      name: Split name; used as prefix of the generated example ids and in
        progress messages.

    Returns:
      The examples produced by ``_get_examples_concatenated``; each label is
      a one-hot row with one column per class id found in ``self.mapping``.
      An empty list is returned when no sample survives the filtering.

    Raises:
      KeyError: if a sample label is not a key of ``self.id_to_word``.
    """
    pairs = []
    class_ids = []
    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker not in self.mapping:
        continue
      pairs.append((sample["sentence1"], sample["sentence2"]))
      class_ids.append([self.mapping[marker]])

    if not pairs:
      # Nothing to encode; the encoder cannot be fitted on an empty input.
      return self._get_examples_concatenated([], name)

    print("one hot encoding...")
    # Pin the categories to the full class set so every split gets the same
    # number of columns in the same order, even when a class is absent from
    # this particular split.
    one_hot_encoder = OneHotEncoder(
      categories=[sorted(set(self.mapping.values()))],
      handle_unknown="ignore",
      sparse_output=False,
    )
    labels = one_hot_encoder.fit_transform(class_ids)

    result = []
    progress = tqdm(zip(pairs, labels), total=len(pairs), desc="creating results...")
    for i, ((sentence1, sentence2), label) in enumerate(progress):
      result.append([f"{name}_{i}", sentence1, sentence2, [], [], [], label])

    return self._get_examples_concatenated(result, name)


class _TabSeparatedRelationProcessor(DataProcessor):
  """Shared implementation for tab-separated support/attack corpora.

  ``StudentEssayProcessor`` and ``DebateProcessor`` used to carry identical
  copies of these methods; they now inherit them from this private base.
  Original code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
  """

  # Characters stripped from the raw "facts" column.
  _FACT_NOISE = str.maketrans("", "", "_[]()")
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: Iterable of 7-item rows ``(id, sentences, target,
        source_sentiment, target_sentiment, knowledge, label_distribution)``.
        assumes the three padded fields are numeric sequences -- TODO confirm
      maxlen: Unused; kept for interface compatibility.

    Returns:
      A list of ``(sentences, knowledge, target, label_distribution)`` tuples,
      one per row, where the first three items are rows of the padded tensors.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    Args:
      sentences: Sequence of sized items (``len`` is taken of each).
      batch_equal_size: When ``True`` only sentences of equal length share a
        batch; otherwise batches follow the input order.
      max_batch_size: If positive, the maximum number of sentences per batch.
        Otherwise ``abs(max_batch_size)`` is a word budget per batch.

    Returns:
      A list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Always allow at least one sentence per batch: a sentence longer
          # than the budget used to give batch_size 0 (ValueError in range)
          # and an empty sentence used to divide by zero.
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        batch_is_full = (
          (max_batch_size > 0 and len(current_batch) >= max_batch_size)
          or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size))
        )
        if batch_is_full:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def _label_distribution(self, label):
    """Return a fresh 2-way one-hot list: [support, attack]; [0, 0] if unknown."""
    if label in self._SUPPORT_LABELS:
      return [1, 0]
    if label in self._ATTACK_LABELS:
      return [0, 1]
    return [0, 0]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file and tokenise it.

    Per line, column 0 is the id, column 1 the source sentence, column 3 the
    target sentence, the second-to-last column the facts and the last column
    the relation label. The file is decoded as ISO-8859-1.

    Args:
      file_path: Path of the file to read.
      max_sentence_length: Unused; kept for interface compatibility.
      name: Split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
          # Blank line, whatever the line-ending convention of the file.
          continue
        fields = line.split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()
        # Drop underscores and both kinds of brackets from the facts.
        facts = fields[-2].strip().translate(self._FACT_NOISE)

        result.append([story_id, sent, target, [], [], facts, self._label_distribution(label)])

    return self._get_examples_concatenated(result, name)


class StudentEssayProcessor(_TabSeparatedRelationProcessor):
  """Processor for the student-essay corpus (tab-separated, support/attack)."""

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()


class DebateProcessor(_TabSeparatedRelationProcessor):
  """Processor for the debate corpus (tab-separated, support/attack)."""

  def __init__(self,):
    super(DebateProcessor,self).__init__()


class MARGProcessor(DataProcessor):
  """Processor for a CSV corpus with support / attack / neither labels.

  A discourse marker predicted by a text-classification pipeline is prepended
  to every target sentence before tokenisation.
  """

  # Characters stripped from the raw "facts" column.
  _FACT_NOISE = str.maketrans("", "", "_[]()")

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_discourse_marker(self, sent, target):
    """Return ``target`` prefixed with the marker predicted for (sent, target).

    The marker has underscores turned into spaces and its first letter
    upper-cased; the first letter of ``target`` is lower-cased.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    # Slices instead of [0] so an empty string does not raise IndexError.
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    return ds_marker + " " + target[:1].lower() + target[1:]

  def _label_distribution(self, label):
    """Return a fresh 3-way one-hot list: [support, attack, neither]; zeros if unknown."""
    if label in ("supports", "support", "because"):
      return [1, 0, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1, 0]
    if label == "neither":
      return [0, 0, 1]
    return [0, 0, 0]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the rows of one split from a comma-separated file and tokenise them.

    Columns are addressed by position: 0 is the id, 1 the source sentence,
    2 the target sentence, 3 the relation label, the third-to-last the facts
    and the last one the split name. Rows whose split differs from ``name``
    are skipped.

    Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    Args:
      file_path: Path given to ``pandas.read_csv`` (default separator).
      max_sent_length: Unused; kept for interface compatibility.
      name: Split to keep; also forwarded to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    df = pd.read_csv(file_path)
    # Plain tuples give true positional access; integer keys on a row Series
    # with string column labels rely on a deprecated positional fallback.
    for row in df.itertuples(index=False, name=None):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = self._inject_discourse_marker(sent, row[2].strip())
      label = row[3].strip()
      # Drop underscores and both kinds of brackets from the facts.
      facts = row[-3].strip().translate(self._FACT_NOISE)

      result.append([story_id, sent, target, [], [], facts, self._label_distribution(label)])

    return self._get_examples_concatenated(result, name)

class NKProcessor(DataProcessor):
  """Processor for a tab-separated corpus with support / attack / no_relation labels."""

  def __init__(self):
    super(NKProcessor, self).__init__()

  def _label_distribution(self, label):
    """Return a fresh 3-way one-hot list: [support, attack, no_relation]; zeros if unknown."""
    if label in ("supports", "support", "because"):
      return [1, 0, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1, 0]
    if label == "no_relation":
      return [0, 0, 1]
    return [0, 0, 0]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read one tab-separated file and tokenise it.

    Columns are addressed by position: 0 is the id, 2 the relation label,
    3 the source sentence and 4 the target sentence.

    Args:
      file_path: Path given to ``pandas.read_csv`` with ``sep="\\t"``.
      max_sent_length: Unused; kept for interface compatibility.
      name: Split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    df = pd.read_csv(file_path, sep="\t")
    # Plain tuples give true positional access; integer keys on a row Series
    # with string column labels rely on a deprecated positional fallback.
    for row in df.itertuples(index=False, name=None):
      id_sample = row[0]
      label = row[2]
      sent = row[3].strip()
      target = row[4].strip()

      result.append([id_sample, sent, target, [], [], [], self._label_distribution(label)])

    return self._get_examples_concatenated(result, name)

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Student-essay processor that prepends a predicted discourse marker.

  Reads the same tab-separated layout as the plain student-essay processor,
  but the target sentence is prefixed with the marker predicted by a
  text-classification pipeline for the (source, target) pair.
  Original code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
  """

  # Characters stripped from the raw "facts" column.
  _FACT_NOISE = str.maketrans("", "", "_[]()")
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: Iterable of 7-item rows ``(id, sentences, target,
        source_sentiment, target_sentiment, knowledge, label_distribution)``.
        assumes the three padded fields are numeric sequences -- TODO confirm
      maxlen: Unused; kept for interface compatibility.

    Returns:
      A list of ``(sentences, knowledge, target, label_distribution)`` tuples,
      one per row, where the first three items are rows of the padded tensors.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    Args:
      sentences: Sequence of sized items (``len`` is taken of each).
      batch_equal_size: When ``True`` only sentences of equal length share a
        batch; otherwise batches follow the input order.
      max_batch_size: If positive, the maximum number of sentences per batch.
        Otherwise ``abs(max_batch_size)`` is a word budget per batch.

    Returns:
      A list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Always allow at least one sentence per batch: a sentence longer
          # than the budget used to give batch_size 0 (ValueError in range)
          # and an empty sentence used to divide by zero.
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        batch_is_full = (
          (max_batch_size > 0 and len(current_batch) >= max_batch_size)
          or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size))
        )
        if batch_is_full:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def _inject_discourse_marker(self, sent, target):
    """Return ``target`` prefixed with the marker predicted for (sent, target).

    The marker has underscores turned into spaces and its first letter
    upper-cased; the first letter of ``target`` is lower-cased.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    # Slices instead of [0] so an empty string does not raise IndexError.
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    return ds_marker + " " + target[:1].lower() + target[1:]

  def _label_distribution(self, label):
    """Return a fresh 2-way one-hot list: [support, attack]; [0, 0] if unknown."""
    if label in self._SUPPORT_LABELS:
      return [1, 0]
    if label in self._ATTACK_LABELS:
      return [0, 1]
    return [0, 0]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file, inject discourse markers and tokenise it.

    Per line, column 0 is the id, column 1 the source sentence, column 3 the
    target sentence, the second-to-last column the facts and the last column
    the relation label. The file is decoded as ISO-8859-1.

    Args:
      file_path: Path of the file to read.
      max_sentence_length: Unused; kept for interface compatibility.
      name: Split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The examples produced by ``_get_examples_concatenated``.
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
          # Blank line, whatever the line-ending convention of the file.
          continue
        fields = line.split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = self._inject_discourse_marker(sent, fields[3].strip())
        label = fields[-1].strip()
        # Drop underscores and both kinds of brackets from the facts.
        facts = fields[-2].strip().translate(self._FACT_NOISE)

        result.append([story_id, sent, target, [], [], facts, self._label_distribution(label)])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Data processor that injects a predicted discourse marker into each target sentence.

  For every row of the tab-separated input file, a discourse-marker classifier
  (``self.pipe``) is run on the source/target pair; its predicted label is
  prepended to the target sentence before the rows are turned into examples by
  ``self._get_examples_concatenated`` (not defined in this class).
  """

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    # Text-classification pipeline whose top label is used as the discourse marker.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """
    Zero-pad the sentence, knowledge and target sequences of a batch.

    :param input: iterable of 7-item records
        ``(id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution)``
    :param maxlen: unused; sequences are padded to the longest one in ``input``
    :return: list of ``(sentences, knowledge, target, label_distribution)`` tuples,
        one per record, the first three items being rows of the padded tensors
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      # Each sequence must be convertible by torch.tensor (presumably token ids — verify against caller).
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups together sentences into batches
    If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
    If max_batch_size is zero or negative, batches are built dynamically so that each batch holds
    roughly abs(max_batch_size) words; a batch always contains at least one sentence.
    If batch_equal_size is true, only sentences of identical length share a batch.
    Returns a list of lists with sentences ids.
    """
    batches_of_sentence_ids = []
    if batch_equal_size:
      # Bucket sentence indices by length, keeping first-seen order of the lengths.
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Word budget divided by sentence length. Clamped so that a sentence longer
          # than the budget (or an empty sentence) cannot produce a zero step for
          # range() or a division by zero.
          batch_size = max(1, int((-1 * max_batch_size) / max(1, sentence_length)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        if len(sentences[i]) > max_sentence_length:
          max_sentence_length = len(sentences[i])
        # Close the batch once it holds enough sentences (positive limit) or enough
        # padded words (non-positive limit).
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one input file in tab-separated format and builds the examples.

    Columns used per row: 0 = story id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts, last = relation label.

    :param file_path: path of the ISO-8859-1 encoded TSV file
    :param max_sentence_length: unused
    :param name: forwarded unchanged to ``_get_examples_concatenated``
    :return: whatever ``_get_examples_concatenated`` returns for the parsed rows
    """
    support_labels = ('supports', 'support', 'because')
    attack_labels = ('attacks', 'attack', 'but')

    result = []
    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        # Skip blank lines, whether they end in LF or CRLF.
        if not raw_line.strip():
          continue
        fields = raw_line.replace("\n", "").split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        # Predict the connective on the raw pair, then turn e.g. "in_fact" into "In fact".
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        # Slicing (instead of indexing [0]) keeps empty strings from raising IndexError.
        target = ds_marker + " " + target[:1].lower() + target[1:]

        label = fields[-1].strip()
        facts = fields[-2].strip()
        # Strip list/tuple punctuation from the facts column. The original chain
        # replaced "(" twice and therefore never removed ")".
        for char in "_[]()":
          facts = facts.replace(char, "")

        # One-hot relation label: [support, attack]; unknown labels stay [0, 0].
        one_hot = [0, 0]
        if label in support_labels:
          one_hot = [1, 0]
        elif label in attack_labels:
          one_hot = [0, 1]

        # The two empty lists fill the (unused) source/target sentiment slots.
        result.append([story_id, sent, target, [], [], facts, one_hot])

    examples = self._get_examples_concatenated(result, name)

    return examples

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal layer: identity in the forward pass, scaled negated gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Stash the scaling factor for backward; the output carries the same values as the input.
        ctx.lmbd = torch.tensor(lmbd)
        identity = x.reshape_as(x)
        return identity

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by lmbd.
        # The second return value is the (non-existent) gradient for `lmbd`.
        flipped = grad_output.neg()
        return flipped * ctx.lmbd, None

class DoubleAdversarialNet(torch.nn.Module):
  """Sentence-pair classifier with an auxiliary head and three gradient-reversed heads.

  A pretrained encoder embeds the concatenated pair; the tokens with segment
  id 1 attend (multi-head attention) over the tokens with segment id 0, and
  the mean of the attention output is the pair representation.

  In training mode the batch is split in half: the first half goes through
  ``linear_layer`` and the second half through ``linear_layer_adv``. The
  mean-pooled encoder embedding of the whole batch is additionally passed
  through a gradient reversal layer into ``task_linear``, ``attack_linear``
  and ``support_linear``.
  """

  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    # Replace the encoder's token-type embedding with a fresh 4-row table and
    # initialise it with the encoder's own initialiser.
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    # When true, token embeddings are the average of hidden_states[1] and hidden_states[-1].
    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned end to end.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is none of the types handled by _init_weights,
    # so it keeps PyTorch's default initialisation.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights of a Linear, Embedding or LayerNorm module; other modules are left untouched."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """Encode a batch of concatenated sentence pairs and classify them.

    :param ids_sent1: token ids passed to the encoder
    :param segs_sent1: token-type ids; value 0 selects the key/value tokens,
        value 1 the query tokens of the cross-sentence attention
    :param att_mask_sent1: attention mask passed to the encoder
    :param position_sep: unused
    :param labels: unused; kept so existing callers keep working
    :return: in training mode the tuple ``(predictions, predictions_adv,
        task_prediction, attack_prediction, support_prediction)``, where the first
        two cover the first and second half of the batch and the last three cover
        the whole batch; in eval mode only ``predictions`` for the whole batch
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    # Zero out the embeddings of the tokens that belong to the other segment.
    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # Mean over the sequence dimension of the attention output.
    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      # The previous per-class selections (emb_attack, emb_support, emb_caus,
      # emb_other) were removed: they compared tensors with Python lists, joined
      # the comparisons with `or`, were never used afterwards, and sliced `labels`
      # in a way that fails when `labels` is a list.

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      # Gradient reversal: the heads below train normally while the encoder
      # receives the negated, scaled gradient.
      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl)
      support_prediction = self.support_linear(mean_grl)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  """Sentence-pair classifier with an auxiliary head and one gradient-reversed head.

  The tokens with segment id 1 attend over the tokens with segment id 0 and the
  mean of the attention output is the pair representation. In training mode the
  first half of the batch is scored by ``linear_layer``, the second half by
  ``linear_layer_adv``, and ``task_linear`` scores the gradient-reversed
  mean-pooled encoder embedding of the whole batch.
  """

  def __init__(self):
    super().__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    encoder.config.type_vocab_size = 4
    # Fresh 4-row token-type table, initialised by the encoder's own initialiser.
    segment_table = nn.Embedding(encoder.config.type_vocab_size, encoder.config.hidden_size)
    encoder.embeddings.token_type_embeddings = segment_table
    encoder._init_weights(segment_table)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune every encoder weight.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = nn.Linear(in_features=width, out_features=self.num_classes)
    self.linear_layer_adv = nn.Linear(in_features=width, out_features=self.num_classes_adv)
    self.task_linear = nn.Linear(in_features=width, out_features=2)

    self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = nn.Linear(in_features=width, out_features=width)
    self.K = nn.Linear(in_features=width, out_features=width)
    self.V = nn.Linear(in_features=width, out_features=width)

    # Same initialisation order as the layers were originally initialised in;
    # the call on multi_head_att matches no branch of _init_weights.
    for layer in (self.linear_layer, self.linear_layer_adv, self.Q, self.K, self.V,
                  self.multi_head_att, self.task_linear):
      self._init_weights(layer)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights from N(0, initializer_range), reset LayerNorm, zero Linear biases."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      spread = self.plm.config.initializer_range
      module.weight.data.normal_(mean=0.0, std=spread)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear):
      if module.bias is not None:
        module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """Return the pooled representation (``visualize``), the training tuple, or eval predictions.

    ``position_sep`` is accepted but not used.
    """
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    final_layer, early_layer = encoded.hidden_states[-1], encoded.hidden_states[1]
    token_states = (final_layer + early_layer) / 2 if self.first_last_avg else final_layer

    # Keep only the tokens of one segment each; the others are zeroed.
    first_segment = (segs_sent1 == 0).long().unsqueeze(2) * token_states
    second_segment = (segs_sent1 == 1).long().unsqueeze(2) * token_states

    keys = self.K(first_segment)
    values = self.V(first_segment)
    queries = self.Q(second_segment)

    attended, _ = self.multi_head_att(queries, keys, values)
    pooled = attended.mean(dim=1)

    if visualize:
      return pooled

    if not self.training:
      return self.linear_layer(pooled)

    cut = pooled.shape[0] // 2
    predictions = self.linear_layer(pooled[:cut, :])
    predictions_adv = self.linear_layer_adv(pooled[cut:, ])

    # Gradient-reversed view of the mean token embedding feeds the task head.
    reversed_mean = GRLayer.apply(token_states.mean(dim=1), .01)
    task_prediction = self.task_linear(reversed_mean)

    return predictions, predictions_adv, task_prediction

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  """Sentence-pair classifier that gates one token embedding of the second segment.

  The embedding at the first position whose segment id is non-zero is multiplied
  by a sigmoid gate computed from the first token of the sequence and the last
  attended token (per the attention mask), then classified by ``linear_layer``.
  """

  def __init__(self, attention):
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    # When true, forward() runs cross-segment multi-head attention before gating.
    self.attention = attention
    # Replace the encoder's token-type embedding with a fresh 4-row table.
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    # When true, token embeddings are the average of hidden_states[1] and hidden_states[-1].
    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_initial_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_end_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention matches none of the types handled by _init_weights.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.linear_initial_sent)
    self._init_weights(self.linear_end_sent)

  def _init_weights(self, module):
    """Initialize the weights of a Linear, Embedding or LayerNorm module; other modules are left untouched."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Classify a batch of concatenated sentence pairs.

    :param ids_sent1: token ids passed to the encoder
    :param segs_sent1: token-type ids (0 = key/value tokens, 1 = query tokens when
        attention is enabled); also used to locate the start of the second segment
    :param att_mask_sent1: attention mask; its row sums locate the last attended token
    :param position_sep: unused
    :return: output of ``linear_layer`` for every item of the batch
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    # Index helpers live on the same device as the tensors they index, instead of
    # relying on the notebook-level `device` global.
    batch_idx = torch.arange(att_output.shape[0], device=att_output.device)
    # Weights seq_len..1: the earliest position with a non-zero segment id wins the argmax.
    descending = torch.arange(att_output.shape[1], 0, -1, device=segs_sent1.device)

    initial_sent1 = att_output[:,0,:]
    initial_sent2 = att_output[batch_idx, torch.argmax(segs_sent1 * descending, dim=-1)]
    # Position of the last attended token: (number of ones in the mask) - 1.
    end_sent2 = att_output[batch_idx, torch.sum(att_mask_sent1, dim=-1)-1]

    initial_sent1 = self.linear_initial_sent(initial_sent1)
    end_sent2 = self.linear_end_sent(end_sent2)
    gate = self.sigmoid(initial_sent1 + end_sent2)

    final_emb = initial_sent2 * gate

    predictions = self.linear_layer(final_emb)

    return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Sentence-pair classifier with optional cross-segment attention pooling.

  With ``attention`` enabled, tokens with segment id 1 attend over tokens with
  segment id 0 and the attention output is mean-pooled; otherwise the encoder
  token embeddings are mean-pooled directly.
  """

  def __init__(self, attention):
    super().__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    encoder.config.type_vocab_size = 4
    # Fresh 4-row token-type table, initialised by the encoder's own initialiser.
    segment_table = nn.Embedding(encoder.config.type_vocab_size, encoder.config.hidden_size)
    encoder.embeddings.token_type_embeddings = segment_table
    encoder._init_weights(segment_table)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune every encoder weight.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = nn.Linear(in_features=width, out_features=args["num_classes"])
    self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = nn.Linear(in_features=width, out_features=width)
    self.K = nn.Linear(in_features=width, out_features=width)
    self.V = nn.Linear(in_features=width, out_features=width)

    # The call on multi_head_att matches no branch of _init_weights.
    for layer in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
      self._init_weights(layer)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights from N(0, initializer_range), reset LayerNorm, zero Linear biases."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      spread = self.plm.config.initializer_range
      module.weight.data.normal_(mean=0.0, std=spread)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear):
      if module.bias is not None:
        module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return ``linear_layer`` scores for each pair in the batch; ``position_sep`` is not used."""
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    final_layer, early_layer = encoded.hidden_states[-1], encoded.hidden_states[1]
    token_states = (final_layer + early_layer) / 2 if self.first_last_avg else final_layer

    if not self.attention:
      pooled = token_states.mean(dim=1)
    else:
      # Keep only the tokens of one segment each; the others are zeroed.
      first_segment = (segs_sent1 == 0).long().unsqueeze(2) * token_states
      second_segment = (segs_sent1 == 1).long().unsqueeze(2) * token_states

      keys = self.K(first_segment)
      values = self.V(first_segment)
      queries = self.Q(second_segment)

      attended, _ = self.multi_head_att(queries, keys, values)
      pooled = attended.mean(dim=1)

    return self.linear_layer(pooled)


class BaselineModel(torch.nn.Module):
  """Baseline classifier: mean-pooled encoder token embeddings followed by one linear layer."""

  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    # Replace the encoder's token-type embedding with a fresh 4-row table.
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    # When true, token embeddings are the average of hidden_states[1] and hidden_states[-1].
    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights of a Linear, Embedding or LayerNorm module; other modules are left untouched."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Classify a batch of concatenated sentence pairs.

    :param ids_sent1: token ids passed to the encoder
    :param segs_sent1: token-type ids passed to the encoder
    :param att_mask_sent1: attention mask passed to the encoder
    :param position_sep: unused; kept so existing callers keep working
    :return: output of ``linear_layer`` for every item of the batch
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # A `position_sep`-masked copy of the embeddings used to be built here but was
    # never consumed, so it has been dropped. The pooling below averages over every
    # position of the sequence, without applying the attention mask.
    # NOTE(review): if masked pooling was intended, that is a modelling change — confirm.
    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  """Siamese baseline: both sentences go through the same encoder and their mean-pooled embeddings are concatenated."""

  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    # Replace the encoder's token-type embedding with a fresh 4-row table.
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    # When true, token embeddings are the average of hidden_states[1] and hidden_states[-1].
    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # Input is the concatenation of the two sentence vectors, hence embed_size * 2.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights of a Linear, Embedding or LayerNorm module; other modules are left untouched."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """Classify a batch of sentence pairs encoded separately.

    :param ids_sent1, segs_sent1, att_mask_sent1: encoder inputs of the first sentence
    :param ids_sent2, segs_sent2, att_mask_sent2: encoder inputs of the second sentence
    :return: output of ``linear_layer`` for every pair of the batch
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    # Segment-masked copies of both embeddings used to be built here but were never
    # consumed, so they have been dropped. The pooling below averages over every
    # position of each sequence, without applying the attention masks.
    # NOTE(review): if masked pooling was intended, that is a modelling change — confirm.
    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    Seed Python's, NumPy's and PyTorch's random generators and request
    deterministic cuDNN kernels.

    :param seed: integer seed applied to every generator
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    torch.backends.cudnn.deterministic = True

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """
    Compute, print and return accuracy plus macro-averaged precision, recall and F1.

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    # Same evaluation order as before: precision, then recall, then F1.
    precision, recall, f1 = (
        scorer(labels, preds, average="macro")
        for scorer in (precision_score, recall_score, f1_score)
    )

    report = (
        ('accuracy:', accuracy),
        ('precision:', precision),
        ('recall:', recall),
        ('f1:', f1),
    )
    for caption, value in report:
        print("{:15}{:<.3f}".format(caption, value))

    return accuracy, precision, recall, f1

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler that fills every batch half from ``dataset1`` and half from ``dataset2``.

    Each yielded batch is a list of ``batch_size`` indices: the first
    ``batch_size // 2`` are indices into ``dataset1`` and the rest are indices
    into ``dataset2`` shifted by ``len(dataset1)`` (presumably to address a
    concatenation of the two datasets — verify against the DataLoader setup).
    Every index is used at most once per pass, and only full batches are produced.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        # Kept as public attributes for compatibility; iteration draws fresh permutations.
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        # Number of passes started so far.
        self.epoch = 0

    def _num_batches(self):
        """Number of full balanced batches that both datasets can supply in one pass."""
        from_first = self.batch_size // 2
        from_second = self.batch_size - from_first
        limit = self.total_size // self.batch_size
        if from_first > 0:
            limit = min(limit, len(self.dataset1) // from_first)
        if from_second > 0:
            # Without this bound the sampler popped from an empty list (IndexError)
            # whenever dataset2 ran out before dataset1.
            limit = min(limit, len(self.dataset2) // from_second)
        return limit

    def __iter__(self):
        self.epoch += 1
        from_first = self.batch_size // 2
        from_second = self.batch_size - from_first

        # Fresh shuffles every pass; dataset2 indices are offset past dataset1.
        order1 = torch.randperm(len(self.dataset1)).tolist()
        order2 = (torch.randperm(len(self.dataset2)) + len(self.dataset1)).tolist()

        for _ in range(self._num_batches()):
            batch = [order1.pop() for _ in range(from_first)]
            batch.extend(order2.pop() for _ in range(from_second))
            yield batch

    def __len__(self):
        # Must equal the number of batches __iter__ yields, since progress bars and
        # per-step schedulers are sized from it.
        return self._num_batches()

In [None]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

# Model and optimisation hyper-parameters.
# The model classes in this notebook read "model_name", "num_classes",
# "num_classes_adv", "embed_size" and "first_last_avg" when they are constructed:
#   - "model_name" is passed to AutoModel.from_pretrained
#   - "num_classes" / "num_classes_adv" are the output sizes of the main and
#     auxiliary classification heads
#   - "embed_size" is the input size of the heads and the attention width
#   - "first_last_avg" averages hidden_states[1] and hidden_states[-1] instead of
#     using the last hidden state alone
# The remaining keys are consumed outside the model classes.
# NOTE(review): "seed" is a list, presumably one training run per seed — confirm.
args = {
    "model_name": "roberta-base",
    "num_classes": 2, #3, #2,
    "num_classes_adv": 3, #174,
    "embed_size": 768,
    "first_last_avg": True,
    "seed": [0,1,2],
    "batch_size": 64,
    "epochs": 30,
    "class_weight": 1, #[9.375, 30, 1], #10 [2.071, 1.933, 1]
    "lr": 1e-5
}

# Experiment switches.
# "adversarial" and "double_adversarial" select the loss branch in train();
# "adversarial" is checked first, and visualize() returns immediately unless it is set.
# NOTE(review): the other switches are read by code outside train()/val();
# their exact effect should be confirmed there.
config = {
    "dataset": "debate", #"student_essay", #debate, m-arg
    "adversarial": False,
    "double_adversarial": False,
    "dataset_from_saved": False,
    "injection": False,
    "grid_search": False,
    "visualize": True,
    "train": True,
    "scheduler": True,
    "attention": True,
    "cue_gating": False
}

def _split_adversarial_targets(labels):
    """Build the target tensors for a batch whose two halves carry different label sets.

    :param labels: per-sample label vectors for the whole batch; the first half
        belongs to the main task, the second half to the auxiliary task
    :return: ``(targets, targets_adv, targets_task)`` on ``device``, where
        ``targets_task`` marks every first-half sample as ``[0, 1]`` and every
        second-half sample as ``[1, 0]``
    """
    half_batch_size = len(labels) // 2
    # The model's second slice takes the remainder, so an odd batch gives it one more sample.
    second_half_size = len(labels) - half_batch_size

    targets = torch.tensor(np.array(labels[:half_batch_size])).to(device)
    targets_adv = torch.tensor(np.array(labels[half_batch_size:])).to(device)
    targets_task = torch.tensor(
        np.array([[0, 1]] * half_batch_size + [[1, 0]] * second_half_size)
    ).to(device)
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    """Run one training epoch and print its timing, summed loss and learning rate.

    :param epoch: zero-based epoch index (only used in the printed summary)
    :param model: model to train; called with ``(ids, segs, att_mask, position_sep)``
        and, in double-adversarial mode, additionally with the labels
    :param loss_fn: loss applied to the main-task predictions
    :param optimizer: optimizer stepped after every batch
    :param train_loader: iterable of 5-item batches
        ``(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)``
    :param scheduler: optional LR scheduler, stepped after every batch
    :param discovery_weight: weight of the auxiliary-task loss
    :param adv_weight: weight of the gradient-reversed task-discrimination loss
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Tensor fields are moved to the device; list-valued fields are left as they are.
        batch = tuple(t if isinstance(t, list) else t.to(device) for t in batch)

        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            # Errors while building targets now propagate; they used to be swallowed
            # by a bare `except`, after which the loss was computed on undefined or
            # previous-batch targets.
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            # NOTE(review): these selections compare a tensor with a Python list and
            # are kept exactly as written; they do not look like a per-row class
            # match, and the resulting shapes are unlikely to fit attack_pred /
            # support_pred. The intended class-conditional targets need to be
            # defined before this branch can be relied on.
            attack_target = targets[targets == [0,1]]
            support_target = targets[targets == [1,0]]

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = loss_fn2(attack_pred, attack_target.float())
            loss5 = loss_fn2(support_pred, support_target.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if isinstance(labels, list):
                labels = torch.tensor(np.array(labels)).to(device)
            loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader``; print the summed loss and return the metrics.

    Predictions are converted to one-hot rows (2-wide when the labels are 2-wide,
    3-wide otherwise) before being scored by ``output_metrics``.

    :return: accuracy, precision, recall, f1
    """
    model.eval()

    criterion = nn.CrossEntropyLoss()

    running_loss = 0
    collected_preds, collected_labels = [], []

    for raw_batch in val_loader:
        moved = tuple(t.to(device) for t in raw_batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = moved

        with torch.no_grad():
            logits = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            _, top_class = torch.max(logits.data, 1)
            batch_preds = top_class.cpu().numpy().tolist()
            running_loss += criterion(logits, labels.float()).item()

            gold = labels.cpu().numpy().tolist()

            collected_labels.extend(gold)
            if len(gold[0]) == 2:
                collected_preds.extend([1,0] if p == 0 else [0,1] for p in batch_preds)
            else:
                collected_preds.extend(
                    [1,0,0] if p == 0 else ([0,1,0] if p == 1 else [0,0,1])
                    for p in batch_preds
                )

    print(f"val loss: {running_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(collected_labels, collected_preds)
    return val_acc, val_prec, val_recall, val_f1

def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches = 20):
  """Plot 2-D t-SNE projections of the model's embeddings.

  Two separate figures are produced: one for the embeddings of
  `test_dataloader` and one for the embeddings of the first
  `max_adv_batches` batches of `train_adv_dataloader`. Nothing is done
  unless `config["adversarial"]` is set.

  Args:
    epoch: unused; kept so existing callers keep working.
    model: called with `visualize=True`; its output is used as the
      per-example embedding matrix.
    test_dataloader: yields (ids, segments, attention mask, separator
      positions, labels) tensors.
    train_adv_dataloader: yields the same five fields; items may be tensors
      or plain lists (lists are not moved to the device).
    discovery_weight: only shown in the plot title.
    adv_weight: only shown in the plot title.
    max_adv_batches: number of batches read from `train_adv_dataloader`
      (default 20, the previously hard-coded value). `None` reads them all.
  """
  if not config["adversarial"]:
    return

  model.eval()

  print("Visualizing...")

  # Collect per-batch tensors and concatenate once at the end instead of
  # re-copying the growing tensor on every batch.
  label_chunks, embedding_chunks = [], []
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    # Label vectors -> class index.
    label_chunks.append(torch.argmax(labels, dim=-1))

    with torch.no_grad():
      embedding_chunks.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  adv_label_chunks, adv_embedding_chunks = [], []
  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i == max_adv_batches: break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    # Same guard as in train(): only list-valued labels need converting.
    # Calling np.array on a tensor that already lives on the GPU would fail.
    if isinstance(labels, list):
      labels = torch.tensor(np.array(labels)).to(device)
    # Offset by 2 so these label ids do not collide with ids 0 and 1 used in the first plot.
    adv_label_chunks.append(torch.argmax(labels, dim=-1)+2)

    with torch.no_grad():
      adv_embedding_chunks.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  tot_labels = torch.cat(label_chunks, dim=0)
  embeddings = torch.cat(embedding_chunks, dim=0)
  tot_labels_adv = torch.cat(adv_label_chunks, dim=0)
  embeddings_adv = torch.cat(adv_embedding_chunks, dim=0)

  # The two embedding sets are projected independently (two separate fits).
  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  title = f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}'
  _plot_tsne_scatter(df_tsne, title)
  _plot_tsne_scatter(df_tsne_adv, title)

def _plot_tsne_scatter(points, title):
  """Draw one scatter plot of `points` (columns x, y, label) in its own figure."""
  # The style must be set before the axes are created, otherwise it only
  # affects figures created later.
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  fig, ax = plt.subplots(figsize=(8,6))
  sns.scatterplot(data=points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(title)
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()

def run(seed):
  """Run one full experiment (data loading, training, evaluation) for one seed.

  The dataset, model variant and training mode are selected through the
  notebook-level `config` dict; hyper-parameters are read from `args`.

  Args:
    seed: random seed forwarded to `set_random_seeds`.

  Returns:
    The first collected result: [accuracy, precision, recall, f1] on the test
    set, taken at the epoch with the best dev F1. In grid-search mode this is
    the result of the first (discovery_weight, adv_weight) combination only;
    all results are printed. Entries are -1 if no training epoch was run.

  Raises:
    ValueError: if `config["dataset"]` is not a supported dataset name.
  """
  set_random_seeds(seed)

  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/train_essay.txt"
    path_dev = "./data/student_essay/dev_essay.txt"
    path_test = "./data/student_essay/test_essay.txt"
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/train_debate_concept.txt"
    path_dev = "./data/debate/dev_debate_concept.txt"
    path_test = "./data/debate/test_debate_concept.txt"
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    # All three splits are read from the same file; the `name` argument passed
    # to read_input_files below is the only thing that differs between them.
    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKWithDiscourseInjectionProcessor()
    else:
      processor = NKProcessor()

    # Only a train path: dev/test are sliced from the loaded data below.
    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(
      f"{config['dataset']} is not a valid dataset name "
      "(choose from 'student_essay', 'debate', 'm-arg' and 'nk')"
    )

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    # First tenth -> dev, last tenth -> test, everything in between -> train.
    tenth = len(data_train) // 10
    data_dev = data_train[:tenth]
    data_test = data_train[-tenth:]
    data_train = data_train[tenth : -tenth]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"]:
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      # Loaded only here: the raw dataset is not needed when the cached
      # pickle is used.
      df = datasets.load_dataset("discovery","discovery")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE: pickle.load executes arbitrary code; only use a cache file that
      # was written by this notebook.
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if not config["adversarial"]:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
    if not config["cue_gating"]:
      model = BaselineModelWithSentenceComparison(attention=config["attention"])
    else:
      model = BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _build_optimizer(model)

  if config["dataset"] in ["m-arg","nk"]:
    # args["class_weight"] is used as the full per-class weight vector here.
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    # args["class_weight"] is the weight of class 1; class 0 has weight 1.
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  # Test metrics are recorded at the epoch with the best dev F1. They start at
  # -1 so the report below is well defined even if no epoch is trained.
  best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)

    for discovery_weight in range_disc:
      for adv_weight in range_adv:
        for epoch in range(args["epochs"]):
          print('===== Start training: epoch {} ====='.format(epoch + 1))
          print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
          dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
          if dev_f1 > best_dev_f1:
            best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1

        print('best result:')
        print(best_test_acc)
        print(best_test_pre)
        print(best_test_rec)
        print(best_test_f1)
        result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

        # Start the next weight combination from a fresh model and optimizer.
        del model
        del optimizer

        # Re-seed with this run's seed (args["seed"] is the whole collection
        # of seeds iterated by the caller, not a single seed).
        set_random_seeds(seed)
        # NOTE(review): the rebuilt model is always AdversarialNet, so grid
        # search presumably expects config["adversarial"] to be set - confirm.
        model = AdversarialNet()
        model = model.to(device)
        optimizer = _build_optimizer(model)

        best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
        best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
  else:
    if config["scheduler"]:
      # Linear decay of the learning rate from 1x to 0.1x over 30 scheduler
      # steps; the scheduler is stepped once per epoch below.
      scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-1, total_iters = 30)
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if config["scheduler"]:
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          # Checkpoint the weights of the best-dev-F1 epoch.
          torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

    if config["visualize"] and config["adversarial"]:
      # Visualize the checkpointed (best dev F1) weights, not the last epoch's.
      model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
      # Index of the last epoch; equals the loop variable after the loop and
      # is also defined when the loop body never ran.
      visualize(args["epochs"] - 1, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  return result_metrics[0]

def _build_optimizer(model):
  """Build the AdamW optimizer for `model` with lr = args["lr"].

  Parameters whose name contains "bias" or "LayerNorm.weight" get no weight
  decay; all other parameters use a weight decay of 0.01.
  """
  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(optimizer_grouped_parameters, lr=args["lr"])

if __name__ == "__main__":
  # Repeat the whole experiment once per configured seed; every run yields
  # one [accuracy, precision, recall, f1] row.
  results = []
  for seed in args["seed"]:
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results += [result_metrics]
  # Column-wise mean of the four metrics across seeds.
  avg = torch.tensor(results).mean(dim=0)
  print(avg)

**** trying with seed 0 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 6486/6486 [00:01<00:00, 3590.79it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 2164/2164 [00:00<00:00, 2377.12it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 2162/2162 [00:00<00:00, 3299.74it/s]


finished preprocessing examples in test


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 102/102 [00:17<00:00,  5.78it/s]


Timing: 17.639450550079346, Epoch: 1, training loss: 70.62908869981766, current learning rate 1e-05
val loss: 23.03021949529648
accuracy:      0.624
precision:     0.624
recall:        0.624
f1:            0.624
val loss: 23.04170197248459
accuracy:      0.636
precision:     0.636
recall:        0.636
f1:            0.635
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.471906423568726, Epoch: 2, training loss: 58.65974044799805, current learning rate 9.7e-06
val loss: 19.13801297545433
accuracy:      0.709
precision:     0.727
recall:        0.709
f1:            0.703
val loss: 19.751294374465942
accuracy:      0.689
precision:     0.708
recall:        0.695
f1:            0.686
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.17it/s]


Timing: 16.54024076461792, Epoch: 3, training loss: 47.54796892404556, current learning rate 9.4e-06
val loss: 18.109326124191284
accuracy:      0.748
precision:     0.748
recall:        0.749
f1:            0.748
val loss: 18.518449366092682
accuracy:      0.731
precision:     0.730
recall:        0.730
f1:            0.730
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.498409748077393, Epoch: 4, training loss: 39.4118529856205, current learning rate 9.1e-06
val loss: 18.582179725170135
accuracy:      0.748
precision:     0.751
recall:        0.749
f1:            0.748
val loss: 18.261507391929626
accuracy:      0.741
precision:     0.742
recall:        0.738
f1:            0.739
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.484004735946655, Epoch: 5, training loss: 34.01193705201149, current learning rate 8.799999999999999e-06
val loss: 19.826998233795166
accuracy:      0.741
precision:     0.744
recall:        0.742
f1:            0.741
val loss: 19.02538162469864
accuracy:      0.744
precision:     0.744
recall:        0.742
f1:            0.742
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.489577531814575, Epoch: 6, training loss: 28.776869401335716, current learning rate 8.499999999999998e-06
val loss: 21.09551775455475
accuracy:      0.743
precision:     0.743
recall:        0.743
f1:            0.743
val loss: 20.37622308731079
accuracy:      0.739
precision:     0.741
recall:        0.741
f1:            0.739
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.47192668914795, Epoch: 7, training loss: 24.10760423541069, current learning rate 8.199999999999998e-06
val loss: 24.685041159391403
accuracy:      0.739
precision:     0.742
recall:        0.740
f1:            0.738
val loss: 22.849198549985886
accuracy:      0.751
precision:     0.752
recall:        0.749
f1:            0.749
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.453828811645508, Epoch: 8, training loss: 20.609987139701843, current learning rate 7.899999999999997e-06
val loss: 25.148744612932205
accuracy:      0.742
precision:     0.743
recall:        0.742
f1:            0.742
val loss: 23.260567247867584
accuracy:      0.753
precision:     0.753
recall:        0.752
f1:            0.753
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.462871313095093, Epoch: 9, training loss: 15.950610980391502, current learning rate 7.5999999999999975e-06
val loss: 27.808140218257904
accuracy:      0.739
precision:     0.739
recall:        0.740
f1:            0.740
val loss: 26.55336058139801
accuracy:      0.743
precision:     0.743
recall:        0.743
f1:            0.743
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.469001531600952, Epoch: 10, training loss: 13.73215551301837, current learning rate 7.299999999999998e-06
val loss: 30.824627935886383
accuracy:      0.738
precision:     0.738
recall:        0.739
f1:            0.739
val loss: 28.965221166610718
accuracy:      0.740
precision:     0.741
recall:        0.741
f1:            0.740
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.406241416931152, Epoch: 11, training loss: 10.900207474827766, current learning rate 6.999999999999997e-06
val loss: 36.668137431144714
accuracy:      0.731
precision:     0.741
recall:        0.732
f1:            0.728
val loss: 33.792907536029816
accuracy:      0.747
precision:     0.752
recall:        0.743
f1:            0.743
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.427271127700806, Epoch: 12, training loss: 9.667510196566582, current learning rate 6.699999999999998e-06
val loss: 37.23850899934769
accuracy:      0.737
precision:     0.742
recall:        0.737
f1:            0.735
val loss: 33.88367426395416
accuracy:      0.749
precision:     0.751
recall:        0.746
f1:            0.747
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.41690683364868, Epoch: 13, training loss: 8.187384650111198, current learning rate 6.399999999999998e-06
val loss: 38.77055215835571
accuracy:      0.741
precision:     0.741
recall:        0.742
f1:            0.741
val loss: 36.95136648416519
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.407721042633057, Epoch: 14, training loss: 6.7133116740733385, current learning rate 6.099999999999998e-06
val loss: 43.8273988366127
accuracy:      0.738
precision:     0.740
recall:        0.738
f1:            0.737
val loss: 41.12939536571503
accuracy:      0.745
precision:     0.744
recall:        0.744
f1:            0.744
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.433841943740845, Epoch: 15, training loss: 5.728734787553549, current learning rate 5.799999999999998e-06
val loss: 46.58969980478287
accuracy:      0.734
precision:     0.734
recall:        0.734
f1:            0.734
val loss: 44.05281841754913
accuracy:      0.737
precision:     0.737
recall:        0.738
f1:            0.737
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.408931732177734, Epoch: 16, training loss: 5.728654159232974, current learning rate 5.499999999999998e-06
val loss: 46.583144783973694
accuracy:      0.733
precision:     0.736
recall:        0.734
f1:            0.733
val loss: 43.97046595811844
accuracy:      0.735
precision:     0.734
recall:        0.733
f1:            0.734
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.380239486694336, Epoch: 17, training loss: 5.086402496905066, current learning rate 5.1999999999999985e-06
val loss: 46.86128705739975
accuracy:      0.743
precision:     0.745
recall:        0.743
f1:            0.742
val loss: 44.623723804950714
accuracy:      0.741
precision:     0.741
recall:        0.740
f1:            0.740
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40059542655945, Epoch: 18, training loss: 4.253938652575016, current learning rate 4.899999999999999e-06
val loss: 48.016992032527924
accuracy:      0.741
precision:     0.741
recall:        0.741
f1:            0.741
val loss: 46.879356265068054
accuracy:      0.741
precision:     0.742
recall:        0.742
f1:            0.741
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.422990322113037, Epoch: 19, training loss: 4.782216541469097, current learning rate 4.599999999999999e-06
val loss: 48.8008987903595
accuracy:      0.740
precision:     0.741
recall:        0.741
f1:            0.740
val loss: 48.03106528520584
accuracy:      0.738
precision:     0.740
recall:        0.740
f1:            0.738
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40911626815796, Epoch: 20, training loss: 4.050571014609886, current learning rate 4.2999999999999995e-06
val loss: 48.003440380096436
accuracy:      0.747
precision:     0.748
recall:        0.747
f1:            0.747
val loss: 46.22348088026047
accuracy:      0.747
precision:     0.747
recall:        0.746
f1:            0.746
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.423461437225342, Epoch: 21, training loss: 3.2090353465173393, current learning rate 4e-06
val loss: 53.05736839771271
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
val loss: 51.031589329242706
accuracy:      0.755
precision:     0.755
recall:        0.756
f1:            0.755
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.395845413208008, Epoch: 22, training loss: 2.756487561855465, current learning rate 3.7e-06
val loss: 58.185449957847595
accuracy:      0.745
precision:     0.746
recall:        0.745
f1:            0.745
val loss: 55.89589011669159
accuracy:      0.747
precision:     0.746
recall:        0.745
f1:            0.746
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.384093046188354, Epoch: 23, training loss: 3.221696700900793, current learning rate 3.4e-06
val loss: 56.48061007261276
accuracy:      0.753
precision:     0.753
recall:        0.753
f1:            0.753
val loss: 56.90895414352417
accuracy:      0.744
precision:     0.745
recall:        0.745
f1:            0.744
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.39818239212036, Epoch: 24, training loss: 2.0063617425912526, current learning rate 3.1e-06
val loss: 61.468313574790955
accuracy:      0.746
precision:     0.746
recall:        0.747
f1:            0.746
val loss: 60.55621147155762
accuracy:      0.749
precision:     0.750
recall:        0.750
f1:            0.749
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.36699676513672, Epoch: 25, training loss: 2.785779530007858, current learning rate 2.8e-06
val loss: 60.6560263633728
accuracy:      0.744
precision:     0.744
recall:        0.745
f1:            0.745
val loss: 59.72455143928528
accuracy:      0.750
precision:     0.751
recall:        0.751
f1:            0.750
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.375118017196655, Epoch: 26, training loss: 2.113597659394145, current learning rate 2.4999999999999998e-06
val loss: 62.1952720284462
accuracy:      0.744
precision:     0.745
recall:        0.745
f1:            0.745
val loss: 61.0376113653183
accuracy:      0.748
precision:     0.748
recall:        0.748
f1:            0.748
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.37356948852539, Epoch: 27, training loss: 2.099520522286184, current learning rate 2.1999999999999997e-06
val loss: 63.84905689954758
accuracy:      0.745
precision:     0.749
recall:        0.746
f1:            0.745
val loss: 61.83120000362396
accuracy:      0.745
precision:     0.745
recall:        0.743
f1:            0.743
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.377621173858643, Epoch: 28, training loss: 1.5822409149259329, current learning rate 1.8999999999999996e-06
val loss: 65.17192786931992
accuracy:      0.749
precision:     0.749
recall:        0.749
f1:            0.749
val loss: 64.31250250339508
accuracy:      0.750
precision:     0.750
recall:        0.751
f1:            0.750
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.372559070587158, Epoch: 29, training loss: 1.9489922188222408, current learning rate 1.5999999999999995e-06
val loss: 65.27201879024506
accuracy:      0.754
precision:     0.754
recall:        0.755
f1:            0.754
val loss: 63.917137145996094
accuracy:      0.751
precision:     0.750
recall:        0.751
f1:            0.751
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.382356643676758, Epoch: 30, training loss: 1.582927742274478, current learning rate 1.2999999999999996e-06
val loss: 66.17334008216858
accuracy:      0.749
precision:     0.749
recall:        0.750
f1:            0.749
val loss: 65.66395199298859
accuracy:      0.754
precision:     0.754
recall:        0.754
f1:            0.754
best result:
0.7506938020351527
0.7504741559339674
0.7508922409543528
0.7505080005129783
[[0.7506938020351527, 0.7504741559339674, 0.7508922409543528, 0.7505080005129783]]
**** trying with seed 1 ****


tokenizing...: 100%|██████████| 6486/6486 [00:01<00:00, 3642.51it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 2164/2164 [00:00<00:00, 3578.40it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 2162/2162 [00:00<00:00, 3541.33it/s]


finished preprocessing examples in test


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.47200632095337, Epoch: 1, training loss: 70.71951019763947, current learning rate 1e-05
val loss: 23.529753506183624
accuracy:      0.503
precision:     0.752
recall:        0.500
f1:            0.336
val loss: 23.582725524902344
accuracy:      0.479
precision:     0.489
recall:        0.500
f1:            0.325
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.512094020843506, Epoch: 2, training loss: 64.64519435167313, current learning rate 9.7e-06
val loss: 20.9668008685112
accuracy:      0.653
precision:     0.704
recall:        0.652
f1:            0.629
val loss: 20.81735336780548
accuracy:      0.649
precision:     0.714
recall:        0.660
f1:            0.629
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.17it/s]


Timing: 16.525790214538574, Epoch: 3, training loss: 54.44509673118591, current learning rate 9.4e-06
val loss: 18.98479449748993
accuracy:      0.704
precision:     0.716
recall:        0.703
f1:            0.699
val loss: 18.731791973114014
accuracy:      0.702
precision:     0.720
recall:        0.707
f1:            0.699
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.50016188621521, Epoch: 4, training loss: 45.542027443647385, current learning rate 9.1e-06
val loss: 20.037336438894272
accuracy:      0.717
precision:     0.722
recall:        0.718
f1:            0.715
val loss: 19.087523341178894
accuracy:      0.747
precision:     0.750
recall:        0.744
f1:            0.744
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.50235891342163, Epoch: 5, training loss: 39.98207226395607, current learning rate 8.799999999999999e-06
val loss: 19.159237921237946
accuracy:      0.736
precision:     0.737
recall:        0.736
f1:            0.735
val loss: 18.694948375225067
accuracy:      0.740
precision:     0.744
recall:        0.742
f1:            0.740
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.49232530593872, Epoch: 6, training loss: 33.83233422040939, current learning rate 8.499999999999998e-06
val loss: 20.789582908153534
accuracy:      0.736
precision:     0.736
recall:        0.737
f1:            0.736
val loss: 20.15214914083481
accuracy:      0.752
precision:     0.751
recall:        0.751
f1:            0.751
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.48307204246521, Epoch: 7, training loss: 29.983032912015915, current learning rate 8.199999999999998e-06
val loss: 21.270335495471954
accuracy:      0.744
precision:     0.745
recall:        0.744
f1:            0.743
val loss: 20.493967682123184
accuracy:      0.747
precision:     0.749
recall:        0.749
f1:            0.747
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.481999397277832, Epoch: 8, training loss: 25.305508732795715, current learning rate 7.899999999999997e-06
val loss: 23.409809857606888
accuracy:      0.730
precision:     0.733
recall:        0.730
f1:            0.729
val loss: 22.52302783727646
accuracy:      0.750
precision:     0.753
recall:        0.752
f1:            0.750
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.465489864349365, Epoch: 9, training loss: 22.06141358613968, current learning rate 7.5999999999999975e-06
val loss: 23.039801836013794
accuracy:      0.734
precision:     0.734
recall:        0.734
f1:            0.734
val loss: 22.131237417459488
accuracy:      0.759
precision:     0.760
recall:        0.760
f1:            0.759
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.475454807281494, Epoch: 10, training loss: 18.328323006629944, current learning rate 7.299999999999998e-06
val loss: 24.830538421869278
accuracy:      0.742
precision:     0.742
recall:        0.742
f1:            0.742
val loss: 23.81497186422348
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.757
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.459164142608643, Epoch: 11, training loss: 16.683250468224287, current learning rate 6.999999999999997e-06
val loss: 26.001722872257233
accuracy:      0.735
precision:     0.738
recall:        0.735
f1:            0.734
val loss: 24.718627512454987
accuracy:      0.750
precision:     0.756
recall:        0.753
f1:            0.750
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.46903705596924, Epoch: 12, training loss: 14.521544612944126, current learning rate 6.699999999999998e-06
val loss: 27.669943869113922
accuracy:      0.744
precision:     0.745
recall:        0.744
f1:            0.743
val loss: 27.498360961675644
accuracy:      0.747
precision:     0.749
recall:        0.748
f1:            0.746
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.45698094367981, Epoch: 13, training loss: 11.773364158347249, current learning rate 6.399999999999998e-06
val loss: 30.51791813969612
accuracy:      0.738
precision:     0.740
recall:        0.738
f1:            0.737
val loss: 29.614575922489166
accuracy:      0.747
precision:     0.750
recall:        0.749
f1:            0.747
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.448662281036377, Epoch: 14, training loss: 11.230841719545424, current learning rate 6.099999999999998e-06
val loss: 31.161913752555847
accuracy:      0.736
precision:     0.736
recall:        0.737
f1:            0.736
val loss: 30.523780465126038
accuracy:      0.750
precision:     0.749
recall:        0.749
f1:            0.749
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.43412971496582, Epoch: 15, training loss: 9.747706845402718, current learning rate 5.799999999999998e-06
val loss: 31.733217298984528
accuracy:      0.736
precision:     0.738
recall:        0.736
f1:            0.735
val loss: 31.28793090581894
accuracy:      0.746
precision:     0.750
recall:        0.749
f1:            0.746
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.451165437698364, Epoch: 16, training loss: 8.553178812377155, current learning rate 5.499999999999998e-06
val loss: 32.884805738925934
accuracy:      0.735
precision:     0.735
recall:        0.735
f1:            0.735
val loss: 31.708949744701385
accuracy:      0.749
precision:     0.749
recall:        0.748
f1:            0.748
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.45194959640503, Epoch: 17, training loss: 7.5757391303777695, current learning rate 5.1999999999999985e-06
val loss: 34.49624252319336
accuracy:      0.739
precision:     0.743
recall:        0.739
f1:            0.738
val loss: 33.71159926056862
accuracy:      0.753
precision:     0.759
recall:        0.756
f1:            0.753
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.427183389663696, Epoch: 18, training loss: 7.372830497100949, current learning rate 4.899999999999999e-06
val loss: 39.56421494483948
accuracy:      0.727
precision:     0.727
recall:        0.728
f1:            0.728
val loss: 37.579189479351044
accuracy:      0.746
precision:     0.746
recall:        0.746
f1:            0.745
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.443415880203247, Epoch: 19, training loss: 5.53084447234869, current learning rate 4.599999999999999e-06
val loss: 36.722048193216324
accuracy:      0.735
precision:     0.735
recall:        0.735
f1:            0.735
val loss: 35.555234491825104
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.43336820602417, Epoch: 20, training loss: 6.47243851236999, current learning rate 4.2999999999999995e-06
val loss: 35.98881083726883
accuracy:      0.749
precision:     0.750
recall:        0.749
f1:            0.748
val loss: 35.74434894323349
accuracy:      0.752
precision:     0.755
recall:        0.754
f1:            0.752
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.446971893310547, Epoch: 21, training loss: 5.091070473194122, current learning rate 4e-06
val loss: 38.3541299700737
accuracy:      0.745
precision:     0.746
recall:        0.745
f1:            0.745
val loss: 36.47594851255417
accuracy:      0.759
precision:     0.761
recall:        0.760
f1:            0.759
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.428997039794922, Epoch: 22, training loss: 5.296508897095919, current learning rate 3.7e-06
val loss: 40.12022662162781
accuracy:      0.737
precision:     0.737
recall:        0.737
f1:            0.737
val loss: 37.139659106731415
accuracy:      0.760
precision:     0.760
recall:        0.759
f1:            0.759
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.4079167842865, Epoch: 23, training loss: 5.644998620264232, current learning rate 3.4e-06
val loss: 34.25833302736282
accuracy:      0.744
precision:     0.745
recall:        0.744
f1:            0.743
val loss: 32.703251123428345
accuracy:      0.753
precision:     0.756
recall:        0.755
f1:            0.752
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.410200119018555, Epoch: 24, training loss: 5.01741973310709, current learning rate 3.1e-06
val loss: 37.08552926778793
accuracy:      0.741
precision:     0.742
recall:        0.742
f1:            0.741
val loss: 35.29397642612457
accuracy:      0.752
precision:     0.751
recall:        0.751
f1:            0.751
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40739369392395, Epoch: 25, training loss: 3.7490741012152284, current learning rate 2.8e-06
val loss: 44.339581072330475
accuracy:      0.741
precision:     0.741
recall:        0.741
f1:            0.741
val loss: 41.891881585121155
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.23it/s]


Timing: 16.388074159622192, Epoch: 26, training loss: 3.9780236072838306, current learning rate 2.4999999999999998e-06
val loss: 42.90142285823822
accuracy:      0.732
precision:     0.732
recall:        0.732
f1:            0.732
val loss: 39.95550149679184
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.401264905929565, Epoch: 27, training loss: 3.399266179651022, current learning rate 2.1999999999999997e-06
val loss: 43.74360662698746
accuracy:      0.730
precision:     0.730
recall:        0.730
f1:            0.730
val loss: 40.974649131298065
accuracy:      0.750
precision:     0.752
recall:        0.751
f1:            0.750
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.400355577468872, Epoch: 28, training loss: 4.168310035020113, current learning rate 1.8999999999999996e-06
val loss: 38.11446672677994
accuracy:      0.744
precision:     0.744
recall:        0.744
f1:            0.744
val loss: 36.457288324832916
accuracy:      0.755
precision:     0.756
recall:        0.756
f1:            0.755
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40245246887207, Epoch: 29, training loss: 3.120010018348694, current learning rate 1.5999999999999995e-06
val loss: 41.73283350467682
accuracy:      0.741
precision:     0.742
recall:        0.741
f1:            0.740
val loss: 40.29136371612549
accuracy:      0.753
precision:     0.755
recall:        0.754
f1:            0.752
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40755581855774, Epoch: 30, training loss: 2.695098993368447, current learning rate 1.2999999999999996e-06
val loss: 42.27817094326019
accuracy:      0.741
precision:     0.741
recall:        0.741
f1:            0.741
val loss: 40.47778445482254
accuracy:      0.754
precision:     0.755
recall:        0.755
f1:            0.754
best result:
0.7520814061054579
0.7552761288156314
0.7541950113378685
0.752004793699709
[[0.7520814061054579, 0.7552761288156314, 0.7541950113378685, 0.752004793699709]]
**** trying with seed 2 ****


tokenizing...: 100%|██████████| 6486/6486 [00:01<00:00, 3695.74it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 2164/2164 [00:00<00:00, 3598.97it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 2162/2162 [00:00<00:00, 3314.69it/s]


finished preprocessing examples in test


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.479707717895508, Epoch: 1, training loss: 70.73345673084259, current learning rate 1e-05
val loss: 23.558409512043
accuracy:      0.503
precision:     0.251
recall:        0.500
f1:            0.335


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 23.616265535354614
accuracy:      0.479
precision:     0.239
recall:        0.500
f1:            0.324


  _warn_prf(average, modifier, msg_start, len(result))


===== Start training: epoch 2 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.48936414718628, Epoch: 2, training loss: 69.60399144887924, current learning rate 9.7e-06
val loss: 21.555897057056427
accuracy:      0.657
precision:     0.657
recall:        0.657
f1:            0.657
val loss: 21.450157463550568
accuracy:      0.658
precision:     0.658
recall:        0.659
f1:            0.658
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.17it/s]


Timing: 16.528630018234253, Epoch: 3, training loss: 60.24849373102188, current learning rate 9.4e-06
val loss: 20.500821590423584
accuracy:      0.684
precision:     0.702
recall:        0.684
f1:            0.677
val loss: 20.524596214294434
accuracy:      0.684
precision:     0.707
recall:        0.690
f1:            0.679
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.16it/s]


Timing: 16.551037549972534, Epoch: 4, training loss: 52.07112738490105, current learning rate 9.1e-06
val loss: 19.105003148317337
accuracy:      0.698
precision:     0.703
recall:        0.698
f1:            0.696
val loss: 18.678772509098053
accuracy:      0.703
precision:     0.712
recall:        0.707
f1:            0.702
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.49776577949524, Epoch: 5, training loss: 44.58415314555168, current learning rate 8.799999999999999e-06
val loss: 19.82600826025009
accuracy:      0.722
precision:     0.725
recall:        0.723
f1:            0.722
val loss: 18.898369073867798
accuracy:      0.735
precision:     0.735
recall:        0.733
f1:            0.734
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.498814344406128, Epoch: 6, training loss: 38.412718057632446, current learning rate 8.499999999999998e-06
val loss: 19.58954030275345
accuracy:      0.725
precision:     0.726
recall:        0.725
f1:            0.724
val loss: 18.988386631011963
accuracy:      0.732
precision:     0.731
recall:        0.731
f1:            0.731
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.468522548675537, Epoch: 7, training loss: 33.30972281098366, current learning rate 8.199999999999998e-06
val loss: 20.044382095336914
accuracy:      0.715
precision:     0.726
recall:        0.715
f1:            0.712
val loss: 20.08046320080757
accuracy:      0.723
precision:     0.741
recall:        0.728
f1:            0.720
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.410759449005127, Epoch: 8, training loss: 30.157961666584015, current learning rate 7.899999999999997e-06
val loss: 21.297184258699417
accuracy:      0.739
precision:     0.742
recall:        0.739
f1:            0.739
val loss: 21.01954659819603
accuracy:      0.740
precision:     0.748
recall:        0.743
f1:            0.739
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.467503547668457, Epoch: 9, training loss: 26.215811610221863, current learning rate 7.5999999999999975e-06
val loss: 23.11384755373001
accuracy:      0.728
precision:     0.729
recall:        0.728
f1:            0.728
val loss: 22.964122653007507
accuracy:      0.727
precision:     0.730
recall:        0.729
f1:            0.727
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.44392466545105, Epoch: 10, training loss: 22.141953468322754, current learning rate 7.299999999999998e-06
val loss: 24.075058698654175
accuracy:      0.732
precision:     0.736
recall:        0.731
f1:            0.730
val loss: 23.76976376771927
accuracy:      0.728
precision:     0.738
recall:        0.732
f1:            0.727
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.47763156890869, Epoch: 11, training loss: 19.73345673084259, current learning rate 6.999999999999997e-06
val loss: 25.944510340690613
accuracy:      0.738
precision:     0.738
recall:        0.738
f1:            0.738
val loss: 25.690589904785156
accuracy:      0.725
precision:     0.728
recall:        0.727
f1:            0.725
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.52026128768921, Epoch: 12, training loss: 17.423914194107056, current learning rate 6.699999999999998e-06
val loss: 25.610576272010803
accuracy:      0.731
precision:     0.732
recall:        0.732
f1:            0.731
val loss: 24.243872195482254
accuracy:      0.747
precision:     0.746
recall:        0.745
f1:            0.746
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.16it/s]


Timing: 16.557037353515625, Epoch: 13, training loss: 14.671544440090656, current learning rate 6.399999999999998e-06
val loss: 28.875990360975266
accuracy:      0.736
precision:     0.737
recall:        0.736
f1:            0.736
val loss: 27.129255324602127
accuracy:      0.741
precision:     0.741
recall:        0.740
f1:            0.740
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.5133695602417, Epoch: 14, training loss: 14.220894483849406, current learning rate 6.099999999999998e-06
val loss: 29.7089501619339
accuracy:      0.739
precision:     0.740
recall:        0.740
f1:            0.739
val loss: 29.722465604543686
accuracy:      0.730
precision:     0.732
recall:        0.732
f1:            0.730
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.474483013153076, Epoch: 15, training loss: 11.75046268850565, current learning rate 5.799999999999998e-06
val loss: 30.44405233860016
accuracy:      0.727
precision:     0.728
recall:        0.728
f1:            0.727
val loss: 29.08755248785019
accuracy:      0.734
precision:     0.736
recall:        0.735
f1:            0.734
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.462703943252563, Epoch: 16, training loss: 11.457544907927513, current learning rate 5.499999999999998e-06
val loss: 30.164714515209198
accuracy:      0.726
precision:     0.726
recall:        0.726
f1:            0.726
val loss: 28.17103797197342
accuracy:      0.742
precision:     0.743
recall:        0.743
f1:            0.742
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.452203273773193, Epoch: 17, training loss: 10.056904092431068, current learning rate 5.1999999999999985e-06
val loss: 31.818958908319473
accuracy:      0.738
precision:     0.738
recall:        0.738
f1:            0.738
val loss: 30.114750266075134
accuracy:      0.746
precision:     0.748
recall:        0.747
f1:            0.746
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.456891536712646, Epoch: 18, training loss: 8.438844062387943, current learning rate 4.899999999999999e-06
val loss: 33.18272399902344
accuracy:      0.724
precision:     0.725
recall:        0.724
f1:            0.724
val loss: 29.876138508319855
accuracy:      0.745
precision:     0.745
recall:        0.744
f1:            0.744
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.43239140510559, Epoch: 19, training loss: 8.024940818548203, current learning rate 4.599999999999999e-06
val loss: 34.92576730251312
accuracy:      0.736
precision:     0.738
recall:        0.736
f1:            0.735
val loss: 32.84264898300171
accuracy:      0.744
precision:     0.744
recall:        0.742
f1:            0.743
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.448812246322632, Epoch: 20, training loss: 8.40455524995923, current learning rate 4.2999999999999995e-06
val loss: 32.62453347444534
accuracy:      0.737
precision:     0.737
recall:        0.737
f1:            0.737
val loss: 31.633906304836273
accuracy:      0.741
precision:     0.741
recall:        0.742
f1:            0.741
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.450795650482178, Epoch: 21, training loss: 6.790800577029586, current learning rate 4e-06
val loss: 38.03034871816635
accuracy:      0.733
precision:     0.738
recall:        0.733
f1:            0.732
val loss: 37.822798788547516
accuracy:      0.732
precision:     0.741
recall:        0.736
f1:            0.731
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.464406728744507, Epoch: 22, training loss: 6.94788646325469, current learning rate 3.7e-06
val loss: 37.01299500465393
accuracy:      0.730
precision:     0.730
recall:        0.730
f1:            0.730
val loss: 34.802238047122955
accuracy:      0.741
precision:     0.742
recall:        0.742
f1:            0.741
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.42499566078186, Epoch: 23, training loss: 5.422643948346376, current learning rate 3.4e-06
val loss: 41.156297981739044
accuracy:      0.724
precision:     0.724
recall:        0.724
f1:            0.724
val loss: 38.469092071056366
accuracy:      0.743
precision:     0.743
recall:        0.744
f1:            0.743
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.22it/s]


Timing: 16.40678572654724, Epoch: 24, training loss: 6.354889206588268, current learning rate 3.1e-06
val loss: 39.84070026874542
accuracy:      0.735
precision:     0.735
recall:        0.735
f1:            0.735
val loss: 37.723710000514984
accuracy:      0.742
precision:     0.742
recall:        0.743
f1:            0.742
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.42975640296936, Epoch: 25, training loss: 6.250089098699391, current learning rate 2.8e-06
val loss: 37.9993662238121
accuracy:      0.736
precision:     0.736
recall:        0.737
f1:            0.736
val loss: 36.52404075860977
accuracy:      0.739
precision:     0.739
recall:        0.740
f1:            0.739
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.463255643844604, Epoch: 26, training loss: 6.289888516999781, current learning rate 2.4999999999999998e-06
val loss: 34.71498638391495
accuracy:      0.737
precision:     0.738
recall:        0.737
f1:            0.737
val loss: 33.9159294962883
accuracy:      0.737
precision:     0.741
recall:        0.740
f1:            0.737
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.21it/s]


Timing: 16.441302061080933, Epoch: 27, training loss: 4.268028177320957, current learning rate 2.1999999999999997e-06
val loss: 41.36070820689201
accuracy:      0.734
precision:     0.735
recall:        0.734
f1:            0.734
val loss: 40.74305671453476
accuracy:      0.738
precision:     0.741
recall:        0.740
f1:            0.738
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.20it/s]


Timing: 16.45502209663391, Epoch: 28, training loss: 3.572747439146042, current learning rate 1.8999999999999996e-06
val loss: 41.57653087377548
accuracy:      0.737
precision:     0.737
recall:        0.737
f1:            0.737
val loss: 40.384471356868744
accuracy:      0.741
precision:     0.741
recall:        0.741
f1:            0.740
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.18it/s]


Timing: 16.504822969436646, Epoch: 29, training loss: 4.377536810934544, current learning rate 1.5999999999999995e-06
val loss: 36.990735590457916
accuracy:      0.733
precision:     0.734
recall:        0.734
f1:            0.733
val loss: 36.493616461753845
accuracy:      0.736
precision:     0.739
recall:        0.738
f1:            0.736
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 102/102 [00:16<00:00,  6.19it/s]


Timing: 16.48729419708252, Epoch: 30, training loss: 4.164557188749313, current learning rate 1.2999999999999996e-06
val loss: 39.420367419719696
accuracy:      0.738
precision:     0.739
recall:        0.739
f1:            0.739
val loss: 38.78415536880493
accuracy:      0.739
precision:     0.739
recall:        0.739
f1:            0.739
best result:
0.7303422756706753
0.7323715086485224
0.7320023661638568
0.7303256021934315
[[0.7303422756706753, 0.7323715086485224, 0.7320023661638568, 0.7303256021934315]]
tensor([0.7444, 0.7460, 0.7457, 0.7443], dtype=torch.float64)


In [None]:
# Final cell of the notebook, placed after the multi-seed training runs.
# NOTE(review): runtime.unassign() presumably disconnects and releases the Colab
# runtime so an unattended session stops holding the GPU once training is done.
# Confirm against the Colab documentation. If so, any in-memory state or files
# not already written to the mounted Drive would be lost once this executes.
from google.colab import runtime
runtime.unassign()