<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Mount Google Drive for this Colab session (force_remount=True).
MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
# Work inside the project folder on the mounted Drive; presumably later cells read
# data files through paths relative to it — confirm.
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Map-style PyTorch Dataset backed by an in-memory sequence of examples.

    Indexing and length are delegated directly to the wrapped sequence.
    """

    def __init__(self, examples):
        super().__init__()
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(examples):
    """Batch 7-field examples into a tuple of seven ``torch.long`` tensors.

    Each example is (ids_sent1, segs_sent1, att_mask_sent1, ids_sent2,
    segs_sent2, att_mask_sent2, label); the returned tensors follow that order.
    Raises ValueError if the examples do not have exactly seven fields.
    """
    (ids_sent1, segs_sent1, att_mask_sent1,
     ids_sent2, segs_sent2, att_mask_sent2, labels) = (list(column) for column in zip(*examples))

    fields = (ids_sent1, segs_sent1, att_mask_sent1,
              ids_sent2, segs_sent2, att_mask_sent2, labels)
    return tuple(torch.tensor(field, dtype=torch.long) for field in fields)

def collate_fn_concatenated(examples):
    """Batch 5-field concatenated-pair examples into five ``torch.long`` tensors.

    Each example is (ids, segs, att_mask, position_sep, label) as produced by
    DataProcessor._get_examples_concatenated; the output keeps that order.
    Raises ValueError if the examples do not have exactly five fields.
    """
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = (
        list(column) for column in zip(*examples)
    )

    fields = (ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
    return tuple(torch.tensor(field, dtype=torch.long) for field in fields)

def collate_fn_concatenated_adv(examples):
    """Batch concatenated-pair examples, leaving the labels as a Python list.

    The first four fields (ids, segs, att_mask, position_sep) become
    ``torch.long`` tensors. Labels stay a list of per-example label lists —
    presumably because main-task (2-class) and adversarial (3-class) examples
    share a batch and their labels differ in length; confirm with the caller.
    """
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = (
        list(column) for column in zip(*examples)
    )

    tensors = tuple(
        torch.tensor(field, dtype=torch.long)
        for field in (ids_sent1, segs_sent1, att_mask_sent1, position_sep)
    )
    return (*tensors, labels)

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Tokenization and padding shared by the dataset-specific processors.

  Subclasses read their own file formats and pass rows of the form
  (id, sentence1, sentence2, label) to `_get_examples` (each sentence encoded
  separately) or `_get_examples_concatenated` (the pair encoded together).
  The tokenizer is chosen by the global `args["model_name"]`.
  """

  def __init__(self, max_sent_len=150):
    """Load the tokenizer.

    Args:
      max_sent_len: every encoded sequence is padded or truncated to exactly
        this many positions (150 was the previously hard-coded value).
    """
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = max_sent_len

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _pad_id(self):
    """Return the id of the tokenizer's pad token (encoded without special tokens)."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  @staticmethod
  def _inner_token_segments(n):
    """Segment ids of length n: 0 on the first and last token, 1 in between."""
    segs = [0] * n
    segs[1:-1] = [1] * (n - 2)
    return segs

  def _pad_or_truncate(self, ids, segs, pad_id, *extra):
    """Fit one tokenized sequence to exactly self.max_sent_len positions.

    Shorter sequences are right-padded (pad_id for `ids`, 0 for `segs` and for
    every list in `extra`); longer ones keep only their first max_sent_len
    positions.

    Returns:
      (ids, segs, att_mask, *extra), where att_mask is 1 on real tokens and 0
      on padding.
    """
    max_len = self.max_sent_len
    missing = max_len - len(ids)
    if missing > 0:
      att_mask = [1] * len(ids) + [0] * missing
      return (ids + [pad_id] * missing,
              segs + [0] * missing,
              att_mask,
              *(seq + [0] * missing for seq in extra))
    return (ids[:max_len],
            segs[:max_len],
            [1] * max_len,
            *(seq[:max_len] for seq in extra))

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode both sentences of each row separately.

    Args:
      dataset: iterable of (id, sentence1, sentence2, label) rows.
      dataset_type: split name used in the progress message.

    Returns:
      List of [ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2,
      att_mask_sent2, label], each token list of length max_sent_len.
    """
    examples = []
    # Loop-invariant: look the pad id up once instead of once per row.
    pad_id = self._pad_id()

    for row in dataset:
      _, sentence1, sentence2, label = row

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = self._inner_token_segments(len(ids_sent1))

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = self._inner_token_segments(len(ids_sent2))

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      ids_sent1, segs_sent1, att_mask_sent1 = self._pad_or_truncate(ids_sent1, segs_sent1, pad_id)
      ids_sent2, segs_sent2, att_mask_sent2 = self._pad_or_truncate(ids_sent2, segs_sent2, pad_id)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode each (sentence1, sentence2) pair as a single sequence.

    Segment ids are 0 for the first `len(encode(sentence1))` positions and 1 for
    the next `len(encode(sentence2))`. This only lines up with the encoded
    pair when len(encode(a, b)) == len(encode(a)) + len(encode(b)), as with a
    RoBERTa-style "<s> A </s></s> B </s>" layout; otherwise the assert fails.

    NOTE(review): truncation keeps the first max_sent_len positions, so a long
    pair can lose its closing special token and part or all of sentence 2.

    Args:
      dataset: iterable of (id, sentence1, sentence2, label) rows.
      dataset_type: split name used in the progress message.

    Returns:
      List of [ids, segs, att_mask, position_sep, label], each token list of
      length max_sent_len.
    """
    examples = []
    # Loop-invariant: look the pad id up once instead of once per row.
    pad_id = self._pad_id()

    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      segs_sent1 = [0] * sentence1_length + [1] * sentence2_length
      # 0 at the first position and 1 everywhere else. The previous code started
      # from all ones and re-set positions 1 and sentence1_length to 1 (no-ops).
      # NOTE(review): the intent may have been to mark only separator positions;
      # DoubleAdversarialNet.forward in this notebook does not read it — confirm.
      position_sep = [0] + [1] * (len(ids_sent1) - 1)

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      ids_sent1, segs_sent1, att_mask_sent1, position_sep = self._pad_or_truncate(
        ids_sent1, segs_sent1, pad_id, position_sep
      )

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Turns a discourse-marker prediction dataset into 3-class pair examples.

  Each sample's integer label indexes `id_to_word` (a connective); connectives
  listed in `mapping` are kept and mapped to a class id in {0, 1, 2}, the rest
  are dropped.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Connective grouping reference: Elsevier article PII S0378216698001015,
    # https://doi.org/10.1016/S0378-2166(98)00101-5
    # TODO: refactor
    self.mapping = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }
    # A binary variant used earlier is exactly:
    #   {k: int(v == 2) for k, v in self.mapping.items()}

    # Label id -> connective, in the label order of the source dataset.
    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }

  def process_dataset(self, dataset, name="train"):
    """Convert one split of the discourse-marker dataset into model-ready examples.

    Args:
      dataset: iterable of dict-like samples with "sentence1", "sentence2" and
        an integer "label" that indexes self.id_to_word.
      name: split name, used as the example-id prefix and in progress messages.

    Returns:
      The output of self._get_examples_concatenated; labels are one-hot vectors
      over the class ids in self.mapping.
    """
    new_dataset = []
    for sample in dataset:
      connective = self.id_to_word[sample["label"]]
      if connective not in self.mapping:
        continue
      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[connective]])

    if not new_dataset:
      # OneHotEncoder rejects empty input; nothing to encode.
      return self._get_examples_concatenated([], name)

    # Fixed categories keep the one-hot width and column order identical across
    # splits, even when a split is missing one of the classes.
    categories = [sorted(set(self.mapping.values()))]
    one_hot_encoder = OneHotEncoder(categories=categories, handle_unknown="ignore", sparse_output=False)

    labels = [[sample[-1]] for sample in tqdm(new_dataset, desc="processing labels...")]

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    # Rows must be (id, sentence1, sentence2, label): that is exactly what
    # _get_examples_concatenated unpacks (the old 7-field rows raised ValueError).
    result = [
      [f"{name}_{i}", sample[0], sample[1], label]
      for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results...")
    ]

    return self._get_examples_concatenated(result, name)


class StudentEssayProcessor(DataProcessor):
  """Reader for the student-essay argument-relation CSV (2 classes)."""

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one split of a comma-separated file and return tokenized examples.

    Columns are taken by position: 0 = id, 2 = source sentence, 3 = target
    sentence, -6 = label (1 -> [1, 0], 0 -> [0, 1], anything else -> [0, 0]),
    -5 = split name.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sentence_length: unused; kept for interface compatibility.
      name: only rows whose split column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-5] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[2].strip()
      target = row.iloc[3].strip()
      label = row.iloc[-6]

      if label == 1:
        one_hot = [1, 0]
      elif label == 0:
        one_hot = [0, 1]
      else:
        one_hot = [0, 0]

      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)

class DebateProcessor(DataProcessor):
  """Reader for the debate argument-relation CSV (2 classes: support / attack)."""

  # Raw label -> one-hot [support, attack]; unknown labels become [0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0), 'support': (1, 0), 'because': (1, 0),
    'attacks': (0, 1), 'attack': (0, 1), 'but': (0, 1),
  }

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one split of a comma-separated file and return tokenized examples.

    Columns are taken by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = relation label (string), last = split name.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sentence_length: unused; kept for interface compatibility.
      name: only rows whose last column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = row.iloc[2].strip()
      label = row.iloc[3].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))
      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)


class MARGProcessor(DataProcessor):
  """Reader for the MARG argument-relation CSV (3 classes: support / attack / neither)."""

  # Raw label -> one-hot [support, attack, neither]; unknown labels become [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0, 0), 'support': (1, 0, 0), 'because': (1, 0, 0),
    'attacks': (0, 1, 0), 'attack': (0, 1, 0), 'but': (0, 1, 0),
    'neither': (0, 0, 1),
  }

  def __init__(self):
    super(MARGProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read one split of a comma-separated file and return tokenized examples.

    Columns are taken by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = relation label (string), last = split name.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sent_length: unused; kept for interface compatibility.
      name: only rows whose last column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = row.iloc[2].strip()
      label = row.iloc[3].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))
      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)


class NKProcessor(DataProcessor):
  """Reader for the NK tab-separated relation file (3 classes: support / attack / no_relation)."""

  # Raw label -> one-hot [support, attack, no_relation]; unknown labels become [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0, 0), 'support': (1, 0, 0), 'because': (1, 0, 0),
    'attacks': (0, 1, 0), 'attack': (0, 1, 0), 'but': (0, 1, 0),
    'no_relation': (0, 0, 1),
  }

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read a whole tab-separated file and return tokenized examples.

    Columns are taken by position: 0 = id, 2 = relation label (used as-is, not
    stripped), 3 = source sentence, 4 = target sentence. No split filtering is
    done; every row is kept.

    Args:
      file_path: path to a single TSV file with a header row.
      max_sent_length: unused; kept for interface compatibility.
      name: split name, only used in the progress message.

    Returns:
      The output of self._get_examples_concatenated for all rows.
    """
    result = []
    df = pd.read_csv(file_path, sep="\t")
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      id_sample = row.iloc[0]
      label = row.iloc[2]

      sent = row.iloc[3].strip()
      target = row.iloc[4].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))
      result.append([id_sample, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)


class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Student-essay reader that prefixes each target with a predicted discourse marker (2 classes)."""

  # Raw label -> one-hot [support, attack]; unknown labels become [0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0), 'support': (1, 0), 'because': (1, 0),
    'attacks': (0, 1), 'attack': (0, 1), 'but': (0, 1),
  }

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Prefix `target` with the marker the classifier predicts for (sent, target).

    Underscores in the predicted label become spaces and its first letter is
    upper-cased; the first letter of a non-empty `target` is lower-cased.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[0].upper() + ds_marker[1:]
    if target:
      target = target[0].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one split of a comma-separated file, inject markers, and tokenize.

    Columns are taken by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = relation label (string), last = split name.
    NOTE(review): this layout differs from StudentEssayProcessor (split at -5,
    numeric label at -6) — confirm it matches the file this is used with.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sentence_length: unused; kept for interface compatibility.
      name: only rows whose last column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = self._inject_marker(sent, row.iloc[2].strip())
      label = row.iloc[3].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))
      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Debate reader that prefixes each target with a predicted discourse marker (2 classes)."""

  # Raw label -> one-hot [support, attack]; unknown labels become [0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0), 'support': (1, 0), 'because': (1, 0),
    'attacks': (0, 1), 'attack': (0, 1), 'but': (0, 1),
  }

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Prefix `target` with the marker the classifier predicts for (sent, target).

    Underscores in the predicted label become spaces and its first letter is
    upper-cased; the first letter of a non-empty `target` is lower-cased.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[0].upper() + ds_marker[1:]
    if target:
      target = target[0].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one split of a comma-separated file, inject markers, and tokenize.

    Columns are taken by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = relation label (string), last = split name.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sentence_length: unused; kept for interface compatibility.
      name: only rows whose last column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = self._inject_marker(sent, row.iloc[2].strip())
      label = row.iloc[3].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))
      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)


class MARGWithDiscourseInjectionProcessor(DataProcessor):
  """MARG reader that prefixes each target with a predicted discourse marker (3 classes)."""

  # Raw label -> one-hot [support, attack, neither]; unknown labels become [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    'supports': (1, 0, 0), 'support': (1, 0, 0), 'because': (1, 0, 0),
    'attacks': (0, 1, 0), 'attack': (0, 1, 0), 'but': (0, 1, 0),
    'neither': (0, 0, 1),
  }

  def __init__(self):
    # This class does not inherit from MARGProcessor, so super(MARGProcessor, self)
    # raised TypeError; it must name itself.
    super(MARGWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Prefix `target` with the marker the classifier predicts for (sent, target).

    Underscores in the predicted label become spaces and its first letter is
    upper-cased; the first letter of a non-empty `target` is lower-cased.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[0].upper() + ds_marker[1:]
    if target:
      target = target[0].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read one split of a comma-separated file, inject markers, and tokenize.

    Columns are taken by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = relation label (string), last = split name.

    Args:
      file_path: path to a single CSV file with a header row.
      max_sent_length: unused; kept for interface compatibility.
      name: only rows whose last column equals this value are kept.

    Returns:
      The output of self._get_examples_concatenated for the selected rows.
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc: positional access; plain row[k] on a label-indexed row is
      # deprecated in pandas 2.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = self._inject_marker(sent, row.iloc[2].strip())
      label = row.iloc[3].strip()

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))
      result.append([story_id, sent, target, one_hot])

    return self._get_examples_concatenated(result, name)

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal layer.

    The forward pass leaves `x` unchanged; the backward pass returns the
    incoming gradient negated and scaled by `lmbd`. The trailing None is the
    (absent) gradient for `lmbd` itself.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Keep the scale factor around for the backward pass.
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output
        return flipped * ctx.lmbd, None

class DoubleAdversarialNet(torch.nn.Module):
  """Sentence-pair classifier on a pretrained encoder with several heads.

  Heads: a main classifier (`linear_layer`), an adversarial classifier
  (`linear_layer_adv`), and three heads (`task_linear`, `attack_linear`,
  `support_linear`) fed through a gradient reversal layer. Configuration is
  read from the global `args` dict ("model_name", "num_classes",
  "num_classes_adv", "embed_size", "first_last_avg"); "embed_size" presumably
  has to equal the encoder's hidden size — confirm.
  """
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one of size 4.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    # Same order as before, so seeded runs draw the same initial weights.
    # NOTE(review): _init_weights has no branch for nn.MultiheadAttention, so
    # that call leaves the attention module's own initialisation untouched.
    for module in (self.linear_layer, self.linear_layer_adv, self.Q, self.K, self.V,
                   self.multi_head_att, self.task_linear, self.attack_linear, self.support_linear):
      self._init_weights(module)

  def _init_weights(self, module):
    """Initialize Linear/Embedding weights from N(0, initializer_range), zero biases, reset LayerNorm."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """Encode a concatenated sentence pair and apply the heads.

    Args:
      ids_sent1, segs_sent1, att_mask_sent1: token ids, segment ids and
        attention mask of the pair (presumably (batch, seq) long tensors from
        collate_fn_concatenated*); segs_sent1 is 0 on sentence-1 positions and
        1 on sentence-2 positions.
      position_sep: accepted for interface compatibility; unused.
      labels: accepted for interface compatibility; unused.

    Returns:
      Training mode: (predictions, predictions_adv, task_prediction,
      attack_prediction, support_prediction). Eval mode: main-head predictions
      for the whole batch.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): padded positions also carry segment id 0, so they fall into
    # the sentence-1 mask; att_mask_sent1 is not applied here — confirm intended.
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    # Cross-attention: queries from sentence-2 positions, keys/values from sentence 1.
    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # batch_first=True, so dim 1 is the sequence dimension.
    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      # NOTE(review): assumes the batch is main-task examples (first half)
      # followed by adversarial examples (second half) — confirm with the loop.
      half = H_sent.shape[0] // 2
      samples = H_sent[:half, :]
      samples_adv = H_sent[half:, :]
      # Per-class embedding subsets (attack/support/causal/other) are not built:
      # they were never used, and the former tensor-vs-list comparisons joined
      # with `or` (or slicing a list-valued `labels` with [:h, :]) cannot run.

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl)
      support_prediction = self.support_linear(mean_grl)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  """Sentence-pair classifier with an auxiliary discourse-marker head and an adversarial task head.

  A pretrained encoder (``args["model_name"]``) reads the concatenated pair.
  Segment-0 token states are projected to keys/values and segment-1 token
  states to queries of an 8-head attention layer, whose output is mean-pooled
  into one vector per example. That vector feeds ``linear_layer``
  (``args["num_classes"]`` outputs) and, in training, ``linear_layer_adv``
  (``args["num_classes_adv"]`` outputs). ``task_linear`` (2 outputs) reads the
  mean token embedding through ``GRLayer`` (presumably a gradient-reversal
  layer).
  """
  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # NOTE: this local `config` is the encoder's config object, not the global
    # experiment `config` dict; the assignment below mutates it in place.
    config = self.plm.config
    # Widen the token-type vocabulary to 4 ids and replace the pretrained
    # token-type table with a freshly initialised one.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    # Must equal the encoder hidden size (Q/K/V consume encoder states) and be
    # divisible by the 8 attention heads.
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    for param in self.plm.parameters():
      param.requires_grad = True

    # NOTE: construction and _init_weights order below consumes the torch RNG and
    # fixes parameter registration order; reordering changes seeded initialisation.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: _init_weights only handles Linear/Embedding/LayerNorm instances and
    # does not recurse, so the attention layer keeps PyTorch's default init.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)

  def _init_weights(self, module):
    """Initialise a single module in place (HF-style; does not recurse into children).

    Linear/Embedding weights ~ N(0, plm.config.initializer_range); Linear biases
    are zeroed; LayerNorm gets weight 1 and bias 0. Other module types are left
    unchanged.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """Return the pooled embedding (``visualize=True``) or head outputs.

    Training mode treats the first ``batch_size // 2`` rows as argument-relation
    examples and the rest as discourse-marker examples and returns
    ``(predictions, predictions_adv, task_prediction)``. Eval mode returns
    ``linear_layer`` logits for the whole batch. ``position_sep`` is unused.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so [1] is the first encoder layer.
    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): padding positions presumably carry segment id 0, so they end up
    # on the key/value side and in the means below (no key_padding_mask is passed).
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    # Segment 1 attends to segment 0.
    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # Mean over all sequence positions -> one vector per example.
    H_sent = torch.mean(att_output[0], dim=1)

    if visualize:
      return H_sent
    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      # Task discriminator sees the mean token embedding through GRLayer with factor 0.01.
      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)

      return predictions, predictions_adv, task_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Non-adversarial sentence-pair classifier using segment-1 -> segment-0 cross-attention.

  Same encoder setup as the adversarial model (``args["model_name"]`` with a
  re-initialised 4-id token-type table), but a single ``linear_layer`` head
  applied to the attention output at sequence position 0.
  """
  def __init__(self):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # NOTE: local `config` is the encoder config (mutated in place), not the
    # global experiment config dict.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    # Must equal the encoder hidden size and be divisible by the 8 attention heads.
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    for param in self.plm.parameters():
      param.requires_grad = True

    # NOTE: construction/init order consumes the torch RNG; keep it for reproducibility.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is not a Linear/Embedding/LayerNorm and
    # _init_weights does not recurse.
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialise a single module in place (HF-style; does not recurse into children).

    Linear/Embedding weights ~ N(0, plm.config.initializer_range); Linear biases
    are zeroed; LayerNorm gets weight 1 and bias 0. Other module types are left
    unchanged.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return ``linear_layer`` logits computed from the attention output at position 0.

    ``position_sep`` is accepted for interface compatibility and is not used.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so [1] is the first encoder layer.
    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): no key_padding_mask is passed; padding positions presumably
    # carry segment id 0 and are therefore attended to.
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # Disabled alternative pooling: mean over all positions.
    #H_sent = torch.mean(att_output[0], dim=1)
    # NOTE(review): the query at position 0 is Q(H_sent2[:, 0]); unless position 0
    # has segment id 1 that input is all zeros, so the query is just Q's bias.
    # Confirm this pooling choice is intended.
    H_sent = att_output[0][:,0,:]

    predictions = self.linear_layer(H_sent)

    return predictions


class BaselineModel(torch.nn.Module):
  """Mean-pooled encoder baseline for a concatenated sentence pair.

  Encodes the pair with ``args["model_name"]`` (token-type table widened to
  4 ids and re-initialised), mean-pools the token states over the sequence and
  applies one linear layer with ``args["num_classes"]`` outputs.
  """
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Local `config` is the encoder config (mutated in place), not the global experiment config.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialise a single module in place (HF-style; does not recurse into children).

    Linear/Embedding weights ~ N(0, plm.config.initializer_range); Linear biases
    are zeroed; LayerNorm gets weight 1 and bias 0.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return ``linear_layer`` logits of the mean-pooled token states.

    ``position_sep`` is accepted for interface compatibility with the other
    models and is not used. NOTE(review): the mean runs over every position,
    including padding; ``att_mask_sent1`` is only given to the encoder.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so [1] is the first encoder layer.
    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    H_mean1 = torch.mean(embed_sent1, dim=1)

    return self.linear_layer(H_mean1)


class SiameseBaselineModel(torch.nn.Module):
  """Siamese baseline: encode two sentences separately with one shared encoder.

  Each sentence is encoded by the same ``args["model_name"]`` encoder
  (token-type table widened to 4 ids and re-initialised) and mean-pooled; the
  two pooled vectors are concatenated and fed to a linear layer with
  ``args["num_classes"]`` outputs.
  """
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Local `config` is the encoder config (mutated in place), not the global experiment config.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    for param in self.plm.parameters():
      param.requires_grad = True

    # Input is the concatenation of the two pooled sentence vectors.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialise a single module in place (HF-style; does not recurse into children).

    Linear/Embedding weights ~ N(0, plm.config.initializer_range); Linear biases
    are zeroed; LayerNorm gets weight 1 and bias 0.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def _pooled(self, ids, segs, att_mask):
    """Encode one sentence and mean-pool its (first/last-averaged) token states."""
    out = self.plm(ids, token_type_ids=segs, attention_mask=att_mask, output_hidden_states=True)
    # hidden_states[0] is the embedding output, so [1] is the first encoder layer.
    last, first = out.hidden_states[-1], out.hidden_states[1]
    embed = torch.div((last + first), 2) if self.first_last_avg else last
    # NOTE(review): the mean includes padding positions.
    return torch.mean(embed, dim=1)

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """Return ``linear_layer`` logits of ``[pool(sentence1); pool(sentence2)]``."""
    H_mean1 = self._pooled(ids_sent1, segs_sent1, att_mask_sent1)
    H_mean2 = self._pooled(ids_sent2, segs_sent2, att_mask_sent2)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    return self.linear_layer(H_cat)

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also switches cuDNN to deterministic kernels with autotuning disabled, as
    recommended by the PyTorch reproducibility notes; some ops may still be
    nondeterministic.

    :param seed: integer seed applied to every generator
    """
    import random  # not imported in this cell; keeps the function independent of cell order

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def output_metrics(labels, preds):
    """
    Compute, print and return accuracy plus macro-averaged precision/recall/F1.

    Each metric is printed on its own line as a 15-wide name column followed
    by the value to 3 decimals.

    :param labels: ground truth labels (class ids or one-hot rows)
    :param preds: prediction labels, same format as ``labels``
    :return: accuracy, precision, recall, f1
    """
    scored = (
        ('accuracy:', accuracy_score(labels, preds)),
        ('precision:', precision_score(labels, preds, average="macro")),
        ('recall:', recall_score(labels, preds, average="macro")),
        ('f1:', f1_score(labels, preds, average="macro")),
    )

    for metric_name, metric_value in scored:
        print("{:15}{:<.3f}".format(metric_name, metric_value))

    return tuple(metric_value for _, metric_value in scored)

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler that mixes two datasets half-and-half in every batch.

    Meant for a dataset where indices ``[0, len(dataset1))`` address
    ``dataset1`` and the next ``len(dataset2)`` indices address ``dataset2``
    (e.g. a plain list concatenation ``dataset1 + dataset2``).

    Every yielded batch holds ``batch_size // 2`` indices from dataset1 followed
    by ``batch_size - batch_size // 2`` indices from dataset2. Both are drawn
    without replacement from a fresh ``torch.randperm`` each epoch, so the order
    follows the torch RNG seed. Iteration stops as soon as either dataset can no
    longer fill its share of a full batch; leftovers are dropped. ``len()``
    reports exactly the number of batches produced.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _shares(self):
        """(indices taken from dataset1, indices taken from dataset2) per batch."""
        share1 = self.batch_size // 2
        return share1, self.batch_size - share1

    def _num_batches(self):
        """Number of full batches both datasets can fill in one epoch."""
        share1, share2 = self._shares()
        n_batches = self.total_size // self.batch_size
        if share1 > 0:
            n_batches = min(n_batches, len(self.dataset1) // share1)
        return min(n_batches, len(self.dataset2) // share2)

    def __iter__(self):
        self.epoch += 1
        share1, share2 = self._shares()

        # Two randperm draws in a fixed order keep batches reproducible under
        # torch.manual_seed; dataset2 indices are shifted past dataset1.
        indices1 = torch.randperm(len(self.dataset1)).tolist()
        indices2 = (torch.randperm(len(self.dataset2)) + len(self.dataset1)).tolist()

        for _ in range(self._num_batches()):
            batch = [indices1.pop() for _ in range(share1)]
            batch.extend(indices2.pop() for _ in range(share2))
            yield batch

    def __len__(self):
        return self._num_batches()

In [10]:
import datetime
import json
from pathlib import Path

def save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed=None, discovery_weight=None, adv_weight=None):
  """Write one run's configuration and test metrics to a JSON file.

  The record is ``config``, then ``args``, then the loss weights, metrics and a
  timestamp (later keys win on collisions). It is written to
  ``results/<dataset>/<grid_search|standard>/<double_adversarial|adversarial|standard>/<injection|standard>[/discovery_<d>/adv_<a>][/<seed>]/result.json``.
  The weight segment appears only when both weights are given, the seed segment
  only when ``seed`` is given. Directories are created as needed and an existing
  file is overwritten. Values JSON cannot encode (e.g. the timestamp) are
  stored via ``str``.
  """
  run_metrics = {
      "discovery_weight": discovery_weight,
      "adv_weight": adv_weight,
      "accuracy": best_test_acc,
      "precision": best_test_pre,
      "recall": best_test_rec,
      "f1": best_test_f1,
      "time": datetime.datetime.now()
  }
  record = {**config, **args, **run_metrics}

  segments = ["results", f"{config['dataset']}"]
  segments.append("grid_search" if config["grid_search"] else "standard")

  if config["double_adversarial"]:
    training_mode = "double_adversarial"
  elif config["adversarial"]:
    training_mode = "adversarial"
  else:
    training_mode = "standard"
  segments.append(training_mode)

  segments.append("injection" if config["injection"] else "standard")

  if discovery_weight is not None and adv_weight is not None:
    segments.extend([f"discovery_{discovery_weight}", f"adv_{adv_weight}"])
  if seed is not None:
    segments.append(f"{seed}")

  out_dir = "/".join(segments)
  Path(out_dir).mkdir(parents=True, exist_ok=True)

  with open(out_dir + "/result.json", "w") as outfile:
    json.dump(record, outfile, default=str)

In [11]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

# Model / optimisation hyper-parameters. Key order matters: save_result merges
# this dict into every JSON record it writes.
args = dict(
    model_name="roberta-base",
    num_classes=2,  # argument-relation classes (3 was also used)
    num_classes_adv=3,  # discourse-marker classes (174 was also used)
    embed_size=768,  # must equal the encoder hidden size
    first_last_avg=True,  # average first and last encoder layers
    seed=[0, 1, 2],  # one full training run per seed
    batch_size=64,
    epochs=30,
    # Binary datasets: weight of class 1 (class 0 gets 1). m-arg / nk expect a
    # per-class list instead, e.g. [9.375, 30, 1] or [2.071, 1.933, 1].
    class_weight=10,
    lr=1e-5,
)

# Experiment switches.
config = dict(
    dataset="student_essay",  # one of "student_essay", "debate", "m-arg", "nk"
    adversarial=False,
    double_adversarial=False,
    dataset_from_saved=True,  # load discourse-marker examples from ./adv_dataset.pkl
    injection=False,
    grid_search=False,
    visualize=True,
    train=True,
)

def _build_targets(labels):
    """Split a mixed batch's label rows into (targets, targets_adv, targets_task) on `device`.

    The first ``len(labels) // 2`` rows are argument-relation labels and the rest
    discourse-marker labels, matching how the adversarial models split their
    batch. ``targets_task`` is ``[0, 1]`` for each row of the first part and
    ``[1, 0]`` for each row of the second, so it has one row per example even
    for odd batch sizes.
    """
    half = len(labels) // 2
    task_rows = [[0, 1]] * half + [[1, 0]] * (len(labels) - half)
    return (torch.tensor(np.array(labels[:half])).to(device),
            torch.tensor(np.array(labels[half:])).to(device),
            torch.tensor(np.array(task_rows)).to(device))


def _rows_equal(rows, pattern):
    """Boolean mask over the rows of 2-D tensor `rows` that equal `pattern` exactly."""
    return (rows == rows.new_tensor(pattern)).all(dim=-1)


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    """Run one training epoch, then print wall time, summed loss and current LR.

    The objective depends on the global ``config``:

    * ``adversarial``: ``loss_fn`` on the argument-relation half of the batch,
      plus ``discovery_weight`` x CE on the discourse-marker half and
      ``adv_weight`` x CE of the task-discrimination head.
    * ``double_adversarial``: the same plus 0.2 x CE for the attack and support heads.
    * otherwise: ``loss_fn`` on the whole batch.

    Gradients are clipped to norm 1.0; ``scheduler`` (if given) is stepped after
    every optimizer step. ``epoch`` is only used in the log line.
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # The label column may come back from the collate function as a plain list
        # (presumably because rows of different widths cannot be stacked); only
        # real tensors are moved to the device.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            targets, targets_adv, targets_task = _build_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _build_targets(labels)
            # Whole-row matches on the one-hot argument-relation labels.
            attack_target = targets[_rows_equal(targets, [0, 1])]
            support_target = targets[_rows_equal(targets, [1, 0])]
            # NOTE(review): attack_pred / support_pred appear to cover the whole batch,
            # while these targets keep only the matching rows; confirm the intended
            # pairing before relying on this branch.

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = loss_fn2(attack_pred, attack_target.float())
            loss5 = loss_fn2(support_pred, support_target.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if isinstance(labels, list):
                labels = torch.tensor(np.array(labels)).to(device)
            loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate `model` on `val_loader` and print the loss and metrics.

    Labels are expected as one-hot rows. Arg-max predictions are converted to
    one-hot rows of the same width, so any number of classes is supported.
    The reported loss is the unweighted cross-entropy summed over batches.

    :return: (accuracy, precision, recall, f1, summed_loss)
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            loss = loss_fn(out, labels.float())
            val_loss += loss.item()

        preds = torch.max(out, 1)[1].cpu().tolist()
        label_rows = labels.cpu().tolist()
        val_labels.extend(label_rows)

        # One-hot encode predictions with the label width.
        num_classes = len(label_rows[0])
        val_preds.extend([int(cls == pred) for cls in range(num_classes)] for pred in preds)

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1, val_loss

def _collect_embeddings(model, loader, label_offset=0, max_batches=None):
  """Run `model` with ``visualize=True`` over `loader` and stack the outputs.

  Returns ``(embeddings, labels)``: labels are the arg-max of each one-hot
  label row plus `label_offset`. At most `max_batches` batches are read when given.
  """
  embedding_chunks, label_chunks = [], []
  for i, batch in enumerate(tqdm(loader)):
    if max_batches is not None and i >= max_batches:
      break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    if isinstance(labels, list):
      labels = torch.tensor(np.array(labels)).to(device)
    label_chunks.append(torch.argmax(labels, dim=-1) + label_offset)
    with torch.no_grad():
      embedding_chunks.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))
  # Concatenate once instead of re-growing a tensor inside the loop.
  return torch.cat(embedding_chunks, dim=0), torch.cat(label_chunks, dim=0)


def _plot_embeddings(df_points, discovery_weight, adv_weight):
  """Scatter-plot 2-D points (columns x, y) coloured by their ``label`` column."""
  # The style must be set before the axes are created to apply to them.
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  fig, ax = plt.subplots(figsize=(8,6))
  sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}')
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()
  plt.close(fig)


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches=20):
  """Show t-SNE scatter plots of the model's pooled pair embeddings.

  Does nothing unless ``config["adversarial"]`` is set, because it relies on
  the model's ``visualize=True`` forward mode. It draws two plots:

  * the whole `test_dataloader`, labelled with its argument-relation classes;
  * the first `max_adv_batches` batches of `train_adv_dataloader`, with
    discourse classes offset by 2 so they do not collide with the first plot's ids.

  `discovery_weight` / `adv_weight` only appear in the plot titles; `epoch` is unused.
  """
  if not config["adversarial"]:
    return
  import pandas as pd  # not imported in this cell; keeps the function self-contained

  model.eval()

  print("Visualizing...")
  embeddings, tot_labels = _collect_embeddings(model, test_dataloader)
  embeddings_adv, tot_labels_adv = _collect_embeddings(
      model, train_adv_dataloader, label_offset=2, max_batches=max_adv_batches)

  tsne = TSNE(random_state=0)
  df_tsne = pd.DataFrame(tsne.fit_transform(embeddings.float().cpu().numpy()), columns=["x", "y"])
  df_tsne_adv = pd.DataFrame(tsne.fit_transform(embeddings_adv.float().cpu().numpy()), columns=["x", "y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  _plot_embeddings(df_tsne, discovery_weight, adv_weight)
  _plot_embeddings(df_tsne_adv, discovery_weight, adv_weight)

def run():
  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/student_essay.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/debate.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay' and 'debate')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")
  if config["dataset"] == "nk":
    data_dev = data_train[:len(data_train) // 10]
    data_test = data_train[-(len(data_train) // 10):]
    data_train = data_train[(len(data_train) // 10) : -(len(data_train) // 10)]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  df = datasets.load_dataset("discovery","discovery")
  adv_processor = DiscourseMarkerProcessor()
  if not config["dataset_from_saved"]:
    print("processing discourse marker dataset...")
    train_adv = adv_processor.process_dataset(df["train"])
    with open("./adv_dataset.pkl", "wb") as writer:
      pickle.dump(train_adv, writer)
  else:
    with open("./adv_dataset.pkl", "rb") as reader:
      train_adv = pickle.load(reader)
  train_set_adv = dataset(train_adv)

  if config["adversarial"]:
    data_train_tot = data_train + train_adv
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if not config["adversarial"]:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv) #loading adv for visualization
    model = BaselineModelWithSentenceComparison()
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

  if config["dataset"] == "m-arg" or config["dataset"] == "nk":
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_acc = -1
  best_pre = -1
  best_rec = -1
  best_f1 = -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)

    for discovery_weight in range_disc:
      for adv_weight in range_adv:
        avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []
        for seed in args["seed"]:
          set_random_seeds(seed)
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}, seed = {seed}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
              #save model

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          avg_acc.append(best_test_acc)
          avg_pre.append(best_test_pre)
          avg_rec.append(best_test_rec)
          avg_f1.append(best_test_f1)

          save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          model = AdversarialNet()
          model = model.to(device)

          optimizer_grouped_parameters = [
            {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": 0.01,
            },
            {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0
            },
          ]
          optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

          best_acc = -1
          best_pre = -1
          best_rec = -1
          best_f1 = -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

        avg_acc_score = sum(avg_acc) / len(avg_acc)
        avg_pre_score = sum(avg_pre) / len(avg_pre)
        avg_rec_score = sum(avg_rec) / len(avg_rec)
        avg_f1_score = sum(avg_f1) / len(avg_f1)
        print()
        print('avg result:')
        print(avg_acc_score)
        print(avg_pre_score)
        print(avg_rec_score)
        print(avg_f1_score)
        save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score, discovery_weight=discovery_weight, adv_weight=adv_weight)

  else:
    avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []

    for seed in args["seed"]:
      set_random_seeds(seed)

      scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, min_lr=1e-7)
      for epoch in range(args["epochs"]):
        if config["train"]:
          print('===== Start training: epoch {} with lr = {}, seed = {} ====='.format(epoch + 1, scheduler.get_last_lr(), seed))
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
          dev_a, dev_p, dev_r, dev_f1, dev_loss = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1, _ = val(model, test_dataloader)
          scheduler.step(dev_loss)
          if dev_f1 > best_dev_f1:
            best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
            torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

      if config["visualize"] and config["adversarial"]:
        model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
        visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

          #save model

      print('best result:')
      print(best_test_acc)
      print(best_test_pre)
      print(best_test_rec)
      print(best_test_f1)

      avg_acc.append(best_test_acc)
      avg_pre.append(best_test_pre)
      avg_rec.append(best_test_rec)
      avg_f1.append(best_test_f1)

      if not config["adversarial"] and not config["double_adversarial"]:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed)
      else:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)

      result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

      del model
      del optimizer

      model = BaselineModelWithSentenceComparison()
      model = model.to(device)

      optimizer_grouped_parameters = [
        {
          "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
          "weight_decay": 0.01,
        },
        {
          "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
          "weight_decay": 0.0
        },
      ]
      optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

      best_acc = -1
      best_pre = -1
      best_rec = -1
      best_f1 = -1
      best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  avg_acc_score = sum(avg_acc) / len(avg_acc)
  avg_pre_score = sum(avg_pre) / len(avg_pre)
  avg_rec_score = sum(avg_rec) / len(avg_rec)
  avg_f1_score = sum(avg_f1) / len(avg_f1)

  print()
  print('avg result:')
  print(avg_acc_score)
  print(avg_pre_score)
  print(avg_rec_score)
  print(avg_f1_score)
  save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score, discovery_weight=discovery_weight, adv_weight=adv_weight)

# Script entry point: start the multi-seed training/evaluation driver only when
# this code runs as the top-level program (always true in a notebook kernel),
# not when it is imported as a module.
if "__main__" == __name__:
  run()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2683.93it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2417.44it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2253.09it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:08<00:00,  5.69it/s]


Timing: 8.445309400558472, Epoch: 1, training loss: 60.00935459136963, current learning rate 1e-05
val loss: 12.305994689464569
accuracy:      0.668
precision:     0.553
recall:        0.618
f1:            0.530
val loss: 12.291643798351288
accuracy:      0.681
precision:     0.554
recall:        0.661
f1:            0.523
===== Start training: epoch 2 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.751257658004761, Epoch: 2, training loss: 57.21275305747986, current learning rate 1e-05
val loss: 11.269932925701141
accuracy:      0.649
precision:     0.562
recall:        0.644
f1:            0.530
val loss: 10.801858305931091
accuracy:      0.701
precision:     0.568
recall:        0.697
f1:            0.544
===== Start training: epoch 3 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.765641927719116, Epoch: 3, training loss: 47.12387478351593, current learning rate 1e-05
val loss: 8.236640363931656
accuracy:      0.799
precision:     0.602
recall:        0.659
f1:            0.616
val loss: 7.763295501470566
accuracy:      0.813
precision:     0.593
recall:        0.693
f1:            0.609
===== Start training: epoch 4 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.772023439407349, Epoch: 4, training loss: 39.536176323890686, current learning rate 1e-05
val loss: 13.458544135093689
accuracy:      0.616
precision:     0.580
recall:        0.696
f1:            0.527
val loss: 12.526123702526093
accuracy:      0.646
precision:     0.574
recall:        0.737
f1:            0.524
===== Start training: epoch 5 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.780911207199097, Epoch: 5, training loss: 33.1135870218277, current learning rate 1e-05
val loss: 10.680159986019135
accuracy:      0.716
precision:     0.585
recall:        0.682
f1:            0.579
val loss: 10.164892852306366
accuracy:      0.746
precision:     0.589
recall:        0.742
f1:            0.584
===== Start training: epoch 6 with lr = [1e-05, 1e-05], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.749873399734497, Epoch: 6, training loss: 25.930625274777412, current learning rate 1e-05
val loss: 8.248073637485504
accuracy:      0.805
precision:     0.602
recall:        0.652
f1:            0.616
val loss: 8.237584114074707
accuracy:      0.820
precision:     0.605
recall:        0.717
f1:            0.624
===== Start training: epoch 7 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.755176067352295, Epoch: 7, training loss: 17.203529238700867, current learning rate 1.0000000000000002e-06
val loss: 7.643723636865616
accuracy:      0.827
precision:     0.613
recall:        0.644
f1:            0.624
val loss: 7.330311179161072
accuracy:      0.838
precision:     0.611
recall:        0.707
f1:            0.634
===== Start training: epoch 8 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.796815395355225, Epoch: 8, training loss: 15.995863884687424, current learning rate 1.0000000000000002e-06
val loss: 7.4814020693302155
accuracy:      0.836
precision:     0.620
recall:        0.643
f1:            0.630
val loss: 6.81953939050436
accuracy:      0.844
precision:     0.609
recall:        0.690
f1:            0.630
===== Start training: epoch 9 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.753231525421143, Epoch: 9, training loss: 14.887515291571617, current learning rate 1.0000000000000002e-06
val loss: 7.187684416770935
accuracy:      0.856
precision:     0.640
recall:        0.637
f1:            0.638
val loss: 6.479240417480469
accuracy:      0.863
precision:     0.622
recall:        0.685
f1:            0.643
===== Start training: epoch 10 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.753191232681274, Epoch: 10, training loss: 14.259269714355469, current learning rate 1.0000000000000002e-06
val loss: 7.221035569906235
accuracy:      0.856
precision:     0.640
recall:        0.634
f1:            0.637
val loss: 6.094950124621391
accuracy:      0.875
precision:     0.637
recall:        0.692
f1:            0.657
===== Start training: epoch 11 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.7895519733428955, Epoch: 11, training loss: 14.568077892065048, current learning rate 1.0000000000000002e-06
val loss: 7.945687830448151
accuracy:      0.842
precision:     0.634
recall:        0.660
f1:            0.645
val loss: 7.066792577505112
accuracy:      0.855
precision:     0.621
recall:        0.701
f1:            0.644
===== Start training: epoch 12 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.755746126174927, Epoch: 12, training loss: 12.795184552669525, current learning rate 1.0000000000000002e-06
val loss: 7.45488429069519
accuracy:      0.859
precision:     0.644
recall:        0.636
f1:            0.640
val loss: 6.035135984420776
accuracy:      0.882
precision:     0.647
recall:        0.696
f1:            0.666
===== Start training: epoch 13 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.77384877204895, Epoch: 13, training loss: 11.524100244045258, current learning rate 1.0000000000000002e-07
val loss: 7.546055465936661
accuracy:      0.858
precision:     0.644
recall:        0.638
f1:            0.641
val loss: 6.2354744374752045
accuracy:      0.881
precision:     0.648
recall:        0.700
f1:            0.668
===== Start training: epoch 14 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758304834365845, Epoch: 14, training loss: 11.092774957418442, current learning rate 1.0000000000000002e-07
val loss: 7.555141597986221
accuracy:      0.858
precision:     0.644
recall:        0.638
f1:            0.641
val loss: 6.355795085430145
accuracy:      0.880
precision:     0.646
recall:        0.700
f1:            0.667
===== Start training: epoch 15 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.759742975234985, Epoch: 15, training loss: 10.47250160574913, current learning rate 1.0000000000000002e-07
val loss: 7.541291296482086
accuracy:      0.859
precision:     0.644
recall:        0.636
f1:            0.640
val loss: 6.481644332408905
accuracy:      0.880
precision:     0.645
recall:        0.695
f1:            0.664
===== Start training: epoch 16 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.788220643997192, Epoch: 16, training loss: 10.412051163613796, current learning rate 1.0000000000000002e-07
val loss: 7.545952469110489
accuracy:      0.861
precision:     0.646
recall:        0.633
f1:            0.639
val loss: 6.072058200836182
accuracy:      0.884
precision:     0.649
recall:        0.692
f1:            0.666
===== Start training: epoch 17 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.787440776824951, Epoch: 17, training loss: 12.294969961047173, current learning rate 1.0000000000000002e-07
val loss: 7.520657002925873
accuracy:      0.862
precision:     0.648
recall:        0.634
f1:            0.640
val loss: 6.034125089645386
accuracy:      0.885
precision:     0.650
recall:        0.692
f1:            0.667
===== Start training: epoch 18 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.784400224685669, Epoch: 18, training loss: 11.736525118350983, current learning rate 1.0000000000000002e-07
val loss: 7.656487613916397
accuracy:      0.857
precision:     0.641
recall:        0.635
f1:            0.638
val loss: 6.732750058174133
accuracy:      0.881
precision:     0.646
recall:        0.695
f1:            0.665
===== Start training: epoch 19 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.763303995132446, Epoch: 19, training loss: 11.20133888721466, current learning rate 1.0000000000000002e-07
val loss: 7.592505574226379
accuracy:      0.859
precision:     0.643
recall:        0.632
f1:            0.637
val loss: 6.294331729412079
accuracy:      0.884
precision:     0.649
recall:        0.692
f1:            0.666
===== Start training: epoch 20 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.801877498626709, Epoch: 20, training loss: 11.845110446214676, current learning rate 1.0000000000000002e-07
val loss: 7.653249442577362
accuracy:      0.859
precision:     0.643
recall:        0.632
f1:            0.637
val loss: 6.608153522014618
accuracy:      0.882
precision:     0.647
recall:        0.696
f1:            0.666
===== Start training: epoch 21 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.79768443107605, Epoch: 21, training loss: 11.2114377617836, current learning rate 1.0000000000000002e-07
val loss: 7.53182016313076
accuracy:      0.863
precision:     0.651
recall:        0.635
f1:            0.642
val loss: 6.637214124202728
accuracy:      0.886
precision:     0.653
recall:        0.693
f1:            0.670
===== Start training: epoch 22 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.774156093597412, Epoch: 22, training loss: 11.469924360513687, current learning rate 1.0000000000000002e-07
val loss: 7.69637930393219
accuracy:      0.862
precision:     0.651
recall:        0.640
f1:            0.645
val loss: 6.736976325511932
accuracy:      0.883
precision:     0.649
recall:        0.696
f1:            0.668
===== Start training: epoch 23 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.773006916046143, Epoch: 23, training loss: 9.965581357479095, current learning rate 1.0000000000000002e-07
val loss: 7.7563841342926025
accuracy:      0.861
precision:     0.649
recall:        0.640
f1:            0.644
val loss: 6.461744159460068
accuracy:      0.882
precision:     0.647
recall:        0.696
f1:            0.666
===== Start training: epoch 24 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.800416469573975, Epoch: 24, training loss: 9.792370930314064, current learning rate 1.0000000000000002e-07
val loss: 7.702613681554794
accuracy:      0.862
precision:     0.649
recall:        0.637
f1:            0.643
val loss: 6.682861566543579
accuracy:      0.884
precision:     0.649
recall:        0.692
f1:            0.666
===== Start training: epoch 25 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.768461227416992, Epoch: 25, training loss: 11.403057143092155, current learning rate 1.0000000000000002e-07
val loss: 7.644466131925583
accuracy:      0.863
precision:     0.650
recall:        0.631
f1:            0.639
val loss: 6.321884095668793
accuracy:      0.886
precision:     0.653
recall:        0.693
f1:            0.670
===== Start training: epoch 26 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7599570751190186, Epoch: 26, training loss: 10.368236646056175, current learning rate 1.0000000000000002e-07
val loss: 7.771613121032715
accuracy:      0.861
precision:     0.648
recall:        0.637
f1:            0.642
val loss: 6.480249255895615
accuracy:      0.884
precision:     0.649
recall:        0.692
f1:            0.666
===== Start training: epoch 27 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.780143976211548, Epoch: 27, training loss: 10.589723393321037, current learning rate 1.0000000000000002e-07
val loss: 7.873673677444458
accuracy:      0.861
precision:     0.649
recall:        0.640
f1:            0.644
val loss: 6.686615943908691
accuracy:      0.883
precision:     0.649
recall:        0.696
f1:            0.668
===== Start training: epoch 28 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.753904819488525, Epoch: 28, training loss: 10.056262090802193, current learning rate 1.0000000000000002e-07
val loss: 7.992461621761322
accuracy:      0.860
precision:     0.650
recall:        0.646
f1:            0.648
val loss: 6.931377649307251
accuracy:      0.881
precision:     0.648
recall:        0.700
f1:            0.668
===== Start training: epoch 29 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.769378900527954, Epoch: 29, training loss: 9.390282481908798, current learning rate 1.0000000000000002e-07
val loss: 7.730876386165619
accuracy:      0.863
precision:     0.648
recall:        0.631
f1:            0.638
val loss: 6.307923957705498
accuracy:      0.885
precision:     0.651
recall:        0.693
f1:            0.668
===== Start training: epoch 30 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.15it/s]


Timing: 7.804372310638428, Epoch: 30, training loss: 10.567374929785728, current learning rate 1.0000000000000002e-07
val loss: 7.848634839057922
accuracy:      0.860
precision:     0.643
recall:        0.629
f1:            0.636
val loss: 6.46486359834671
accuracy:      0.885
precision:     0.651
recall:        0.693
f1:            0.668
best result:
0.8809090909090909
0.6476980452674896
0.7001328701031377
0.667849132973292


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.765881538391113, Epoch: 1, training loss: 59.99085474014282, current learning rate 1e-05
val loss: 12.91397887468338
accuracy:      0.153
precision:     0.529
recall:        0.512
f1:            0.149
val loss: 12.9487025141716
accuracy:      0.136
precision:     0.544
recall:        0.529
f1:            0.136
===== Start training: epoch 2 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.767238140106201, Epoch: 2, training loss: 58.620667457580566, current learning rate 1e-05
val loss: 11.750133395195007
accuracy:      0.693
precision:     0.597
recall:        0.723
f1:            0.579
val loss: 11.680445373058319
accuracy:      0.699
precision:     0.578
recall:        0.731
f1:            0.553
===== Start training: epoch 3 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.776763439178467, Epoch: 3, training loss: 49.629268527030945, current learning rate 1e-05
val loss: 11.190035939216614
accuracy:      0.716
precision:     0.607
recall:        0.739
f1:            0.599
val loss: 11.648904800415039
accuracy:      0.697
precision:     0.581
recall:        0.745
f1:            0.556
===== Start training: epoch 4 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.781184196472168, Epoch: 4, training loss: 39.967440605163574, current learning rate 1e-05
val loss: 6.975732237100601
accuracy:      0.852
precision:     0.652
recall:        0.679
f1:            0.663
val loss: 7.005119323730469
accuracy:      0.853
precision:     0.626
recall:        0.720
f1:            0.651
===== Start training: epoch 5 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.782507419586182, Epoch: 5, training loss: 31.6046741604805, current learning rate 1e-05
val loss: 6.6360883712768555
accuracy:      0.862
precision:     0.670
recall:        0.694
f1:            0.680
val loss: 6.367251396179199
accuracy:      0.854
precision:     0.627
recall:        0.720
f1:            0.652
===== Start training: epoch 6 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.7829015254974365, Epoch: 6, training loss: 24.409355401992798, current learning rate 1e-05
val loss: 6.83955055475235
accuracy:      0.866
precision:     0.672
recall:        0.680
f1:            0.676
val loss: 5.99180730432272
accuracy:      0.868
precision:     0.627
recall:        0.683
f1:            0.647
===== Start training: epoch 7 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.796367168426514, Epoch: 7, training loss: 20.45287823677063, current learning rate 1e-05
val loss: 6.470172613859177
accuracy:      0.881
precision:     0.695
recall:        0.655
f1:            0.671
val loss: 5.781941339373589
accuracy:      0.890
precision:     0.647
recall:        0.660
f1:            0.653
===== Start training: epoch 8 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7655088901519775, Epoch: 8, training loss: 15.490246564149857, current learning rate 1e-05
val loss: 7.149909943342209
accuracy:      0.878
precision:     0.702
recall:        0.717
f1:            0.709
val loss: 7.331471126526594
accuracy:      0.873
precision:     0.648
recall:        0.731
f1:            0.675
===== Start training: epoch 9 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.776979923248291, Epoch: 9, training loss: 13.978017970919609, current learning rate 1e-05
val loss: 8.048306286334991
accuracy:      0.869
precision:     0.680
recall:        0.695
f1:            0.687
val loss: 8.32839322090149
accuracy:      0.863
precision:     0.633
recall:        0.715
f1:            0.658
===== Start training: epoch 10 with lr = [1e-05, 1e-05], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.766800165176392, Epoch: 10, training loss: 9.614433027803898, current learning rate 1e-05
val loss: 9.622047156095505
accuracy:      0.864
precision:     0.675
recall:        0.699
f1:            0.686
val loss: 10.287860780954361
accuracy:      0.869
precision:     0.642
recall:        0.724
f1:            0.668
===== Start training: epoch 11 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.736496686935425, Epoch: 11, training loss: 6.263189926743507, current learning rate 1.0000000000000002e-06
val loss: 8.511482894420624
accuracy:      0.890
precision:     0.722
recall:        0.663
f1:            0.686
val loss: 7.859861731529236
accuracy:      0.895
precision:     0.665
recall:        0.683
f1:            0.673
===== Start training: epoch 12 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.731556177139282, Epoch: 12, training loss: 4.0215340768918395, current learning rate 1.0000000000000002e-06
val loss: 8.808696329593658
accuracy:      0.892
precision:     0.731
recall:        0.654
f1:            0.682
val loss: 7.099517822265625
accuracy:      0.906
precision:     0.690
recall:        0.684
f1:            0.687
===== Start training: epoch 13 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.735413551330566, Epoch: 13, training loss: 5.243657734245062, current learning rate 1.0000000000000002e-06
val loss: 9.153775557875633
accuracy:      0.894
precision:     0.737
recall:        0.662
f1:            0.689
val loss: 7.467275500297546
accuracy:      0.901
precision:     0.677
recall:        0.686
f1:            0.681
===== Start training: epoch 14 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.749819040298462, Epoch: 14, training loss: 3.0275606801733375, current learning rate 1.0000000000000002e-07
val loss: 9.216741561889648
accuracy:      0.895
precision:     0.741
recall:        0.662
f1:            0.691
val loss: 7.460218290798366
accuracy:      0.900
precision:     0.674
recall:        0.681
f1:            0.677
===== Start training: epoch 15 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.739697217941284, Epoch: 15, training loss: 3.677785724401474, current learning rate 1.0000000000000002e-07
val loss: 9.226972669363022
accuracy:      0.895
precision:     0.741
recall:        0.662
f1:            0.691
val loss: 8.26195964217186
accuracy:      0.901
precision:     0.677
recall:        0.686
f1:            0.681
===== Start training: epoch 16 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.7325215339660645, Epoch: 16, training loss: 3.6304146349430084, current learning rate 1.0000000000000002e-07
val loss: 9.26283347606659
accuracy:      0.891
precision:     0.728
recall:        0.664
f1:            0.688
val loss: 7.668174813501537
accuracy:      0.900
precision:     0.675
recall:        0.686
f1:            0.680
===== Start training: epoch 17 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.7486302852630615, Epoch: 17, training loss: 3.2307108528912067, current learning rate 1.0000000000000002e-07
val loss: 9.326364248991013
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.12369392812252
accuracy:      0.901
precision:     0.677
recall:        0.686
f1:            0.681
===== Start training: epoch 18 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.762743234634399, Epoch: 18, training loss: 4.026962094940245, current learning rate 1.0000000000000002e-07
val loss: 9.281865328550339
accuracy:      0.895
precision:     0.741
recall:        0.662
f1:            0.691
val loss: 8.753656879067421
accuracy:      0.901
precision:     0.676
recall:        0.681
f1:            0.678
===== Start training: epoch 19 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.77655816078186, Epoch: 19, training loss: 2.8931759744882584, current learning rate 1.0000000000000002e-07
val loss: 9.320363461971283
accuracy:      0.896
precision:     0.744
recall:        0.663
f1:            0.692
val loss: 7.726810112595558
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
===== Start training: epoch 20 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.755463123321533, Epoch: 20, training loss: 2.7194696441292763, current learning rate 1.0000000000000002e-07
val loss: 9.371411889791489
accuracy:      0.895
precision:     0.740
recall:        0.666
f1:            0.693
val loss: 8.326080113649368
accuracy:      0.901
precision:     0.676
recall:        0.681
f1:            0.678
===== Start training: epoch 21 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.744754791259766, Epoch: 21, training loss: 3.417749337852001, current learning rate 1.0000000000000002e-07
val loss: 9.440273404121399
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.603343173861504
accuracy:      0.904
precision:     0.683
recall:        0.683
f1:            0.683
===== Start training: epoch 22 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758184194564819, Epoch: 22, training loss: 3.325586214661598, current learning rate 1.0000000000000002e-07
val loss: 9.387285679578781
accuracy:      0.895
precision:     0.741
recall:        0.659
f1:            0.688
val loss: 7.616618399973959
accuracy:      0.903
precision:     0.676
recall:        0.667
f1:            0.671
===== Start training: epoch 23 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.764749526977539, Epoch: 23, training loss: 4.667421214282513, current learning rate 1.0000000000000002e-07
val loss: 9.510212242603302
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.163572192192078
accuracy:      0.904
precision:     0.683
recall:        0.683
f1:            0.683
===== Start training: epoch 24 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.7553582191467285, Epoch: 24, training loss: 3.7270210618153214, current learning rate 1.0000000000000002e-07
val loss: 9.469738185405731
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.51460137963295
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
===== Start training: epoch 25 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.776451587677002, Epoch: 25, training loss: 2.8174761328846216, current learning rate 1.0000000000000002e-07
val loss: 9.52645330131054
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.104516312479973
accuracy:      0.904
precision:     0.681
recall:        0.678
f1:            0.679
===== Start training: epoch 26 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.766160726547241, Epoch: 26, training loss: 4.075372219085693, current learning rate 1.0000000000000002e-07
val loss: 9.543234556913376
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 7.857948422431946
accuracy:      0.904
precision:     0.683
recall:        0.683
f1:            0.683
===== Start training: epoch 27 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.759715795516968, Epoch: 27, training loss: 3.064399051014334, current learning rate 1.0000000000000002e-07
val loss: 9.537841200828552
accuracy:      0.895
precision:     0.740
recall:        0.666
f1:            0.693
val loss: 8.265329241752625
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
===== Start training: epoch 28 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.769002199172974, Epoch: 28, training loss: 2.8988342564553022, current learning rate 1.0000000000000002e-07
val loss: 9.590536653995514
accuracy:      0.895
precision:     0.741
recall:        0.662
f1:            0.691
val loss: 7.823490749113262
accuracy:      0.904
precision:     0.681
recall:        0.678
f1:            0.679
===== Start training: epoch 29 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.778522491455078, Epoch: 29, training loss: 2.5764738265424967, current learning rate 1.0000000000000002e-07
val loss: 9.587975636124611
accuracy:      0.896
precision:     0.744
recall:        0.666
f1:            0.694
val loss: 8.235764622688293
accuracy:      0.905
precision:     0.685
recall:        0.683
f1:            0.684
===== Start training: epoch 30 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.737173795700073, Epoch: 30, training loss: 3.2803948847576976, current learning rate 1.0000000000000002e-07
val loss: 9.63218954205513
accuracy:      0.894
precision:     0.737
recall:        0.665
f1:            0.692
val loss: 7.989162181969732
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
best result:
0.8727272727272727
0.6477993565900669
0.7306657663446563
0.6749930355650479


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.788198232650757, Epoch: 1, training loss: 59.712438225746155, current learning rate 1e-05
val loss: 13.363341152668
accuracy:      0.130
precision:     0.534
recall:        0.506
f1:            0.122
val loss: 13.420677185058594
accuracy:      0.099
precision:     0.517
recall:        0.504
f1:            0.095
===== Start training: epoch 2 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.774900674819946, Epoch: 2, training loss: 57.585838317871094, current learning rate 1e-05
val loss: 10.924778401851654
accuracy:      0.698
precision:     0.555
recall:        0.615
f1:            0.543
val loss: 10.543919920921326
accuracy:      0.735
precision:     0.556
recall:        0.645
f1:            0.545
===== Start training: epoch 3 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.783419132232666, Epoch: 3, training loss: 49.721628308296204, current learning rate 1e-05
val loss: 12.291810512542725
accuracy:      0.588
precision:     0.560
recall:        0.647
f1:            0.498
val loss: 11.349880874156952
accuracy:      0.648
precision:     0.569
recall:        0.718
f1:            0.521
===== Start training: epoch 4 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.804599285125732, Epoch: 4, training loss: 40.791062116622925, current learning rate 1e-05
val loss: 7.942134082317352
accuracy:      0.799
precision:     0.597
recall:        0.648
f1:            0.610
val loss: 7.509461402893066
accuracy:      0.814
precision:     0.595
recall:        0.698
f1:            0.612
===== Start training: epoch 5 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.15it/s]


Timing: 7.818056583404541, Epoch: 5, training loss: 32.1292662024498, current learning rate 1e-05
val loss: 12.45361602306366
accuracy:      0.688
precision:     0.568
recall:        0.650
f1:            0.552
val loss: 12.517367541790009
accuracy:      0.686
precision:     0.560
recall:        0.679
f1:            0.531
===== Start training: epoch 6 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.15it/s]


Timing: 7.817283391952515, Epoch: 6, training loss: 31.21237075328827, current learning rate 1e-05
val loss: 7.3738886415958405
accuracy:      0.835
precision:     0.610
recall:        0.625
f1:            0.617
val loss: 6.43710121512413
accuracy:      0.861
precision:     0.620
recall:        0.684
f1:            0.641
===== Start training: epoch 7 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.767982721328735, Epoch: 7, training loss: 25.929246872663498, current learning rate 1e-05
val loss: 9.519695281982422
accuracy:      0.788
precision:     0.595
recall:        0.656
f1:            0.608
val loss: 8.142654031515121
accuracy:      0.822
precision:     0.606
recall:        0.718
f1:            0.626
===== Start training: epoch 8 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.785013675689697, Epoch: 8, training loss: 21.552350163459778, current learning rate 1e-05
val loss: 12.075923264026642
accuracy:      0.753
precision:     0.601
recall:        0.700
f1:            0.607
val loss: 10.967813432216644
accuracy:      0.769
precision:     0.595
recall:        0.744
f1:            0.599
===== Start training: epoch 9 with lr = [1e-05, 1e-05], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.764511823654175, Epoch: 9, training loss: 18.97311297059059, current learning rate 1e-05
val loss: 8.581005275249481
accuracy:      0.853
precision:     0.634
recall:        0.632
f1:            0.633
val loss: 7.70580780506134
accuracy:      0.863
precision:     0.614
recall:        0.665
f1:            0.632
===== Start training: epoch 10 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.793588399887085, Epoch: 10, training loss: 12.27625885605812, current learning rate 1.0000000000000002e-06
val loss: 7.940825864672661
accuracy:      0.864
precision:     0.647
recall:        0.622
f1:            0.632
val loss: 6.909284353256226
accuracy:      0.879
precision:     0.624
recall:        0.644
f1:            0.633
===== Start training: epoch 11 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.774486780166626, Epoch: 11, training loss: 10.007438763976097, current learning rate 1.0000000000000002e-06
val loss: 7.7060339748859406
accuracy:      0.884
precision:     0.695
recall:        0.612
f1:            0.637
val loss: 6.757557153701782
accuracy:      0.892
precision:     0.643
recall:        0.641
f1:            0.642
===== Start training: epoch 12 with lr = [1.0000000000000002e-06, 1.0000000000000002e-06], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.763272047042847, Epoch: 12, training loss: 10.635390937328339, current learning rate 1.0000000000000002e-06
val loss: 7.67159041762352
accuracy:      0.886
precision:     0.706
recall:        0.617
f1:            0.644
val loss: 6.524364709854126
accuracy:      0.894
precision:     0.644
recall:        0.637
f1:            0.641
===== Start training: epoch 13 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.766782999038696, Epoch: 13, training loss: 11.24746897816658, current learning rate 1.0000000000000002e-07
val loss: 7.738297522068024
accuracy:      0.884
precision:     0.699
recall:        0.616
f1:            0.642
val loss: 6.655638059601188
accuracy:      0.893
precision:     0.647
recall:        0.647
f1:            0.647
===== Start training: epoch 14 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.774160623550415, Epoch: 14, training loss: 9.743325173854828, current learning rate 1.0000000000000002e-07
val loss: 7.836011290550232
accuracy:      0.879
precision:     0.681
recall:        0.617
f1:            0.638
val loss: 6.7748085260391235
accuracy:      0.888
precision:     0.639
recall:        0.649
f1:            0.644
===== Start training: epoch 15 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7599036693573, Epoch: 15, training loss: 9.237407632172108, current learning rate 1.0000000000000002e-07
val loss: 7.840778261423111
accuracy:      0.879
precision:     0.680
recall:        0.613
f1:            0.635
val loss: 7.463591277599335
accuracy:      0.888
precision:     0.637
recall:        0.644
f1:            0.641
===== Start training: epoch 16 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.772364377975464, Epoch: 16, training loss: 8.723481517285109, current learning rate 1.0000000000000002e-07
val loss: 7.830426931381226
accuracy:      0.881
precision:     0.686
recall:        0.614
f1:            0.637
val loss: 7.263789653778076
accuracy:      0.890
precision:     0.641
recall:        0.645
f1:            0.643
===== Start training: epoch 17 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.778551340103149, Epoch: 17, training loss: 10.021027192473412, current learning rate 1.0000000000000002e-07
val loss: 7.8172705471515656
accuracy:      0.881
precision:     0.686
recall:        0.614
f1:            0.637
val loss: 7.418342471122742
accuracy:      0.891
precision:     0.643
recall:        0.646
f1:            0.644
===== Start training: epoch 18 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.76303243637085, Epoch: 18, training loss: 8.565625697374344, current learning rate 1.0000000000000002e-07
val loss: 7.856535792350769
accuracy:      0.883
precision:     0.692
recall:        0.615
f1:            0.639
val loss: 7.261568963527679
accuracy:      0.892
precision:     0.645
recall:        0.646
f1:            0.645
===== Start training: epoch 19 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7648797035217285, Epoch: 19, training loss: 9.659831535071135, current learning rate 1.0000000000000002e-07
val loss: 7.874061107635498
accuracy:      0.881
precision:     0.687
recall:        0.618
f1:            0.640
val loss: 7.098130524158478
accuracy:      0.891
precision:     0.643
recall:        0.646
f1:            0.644
===== Start training: epoch 20 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.789385795593262, Epoch: 20, training loss: 8.076348379254341, current learning rate 1.0000000000000002e-07
val loss: 7.853250235319138
accuracy:      0.885
precision:     0.702
recall:        0.617
f1:            0.643
val loss: 7.327703177928925
accuracy:      0.892
precision:     0.645
recall:        0.646
f1:            0.645
===== Start training: epoch 21 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.16it/s]


Timing: 7.794801950454712, Epoch: 21, training loss: 10.371148824691772, current learning rate 1.0000000000000002e-07
val loss: 7.910085022449493
accuracy:      0.884
precision:     0.699
recall:        0.616
f1:            0.642
val loss: 7.20275604724884
accuracy:      0.892
precision:     0.645
recall:        0.646
f1:            0.645
===== Start training: epoch 22 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.765275955200195, Epoch: 22, training loss: 8.989018023014069, current learning rate 1.0000000000000002e-07
val loss: 7.955316483974457
accuracy:      0.882
precision:     0.690
recall:        0.618
f1:            0.641
val loss: 6.870147109031677
accuracy:      0.891
precision:     0.645
recall:        0.651
f1:            0.648
===== Start training: epoch 23 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.754055976867676, Epoch: 23, training loss: 7.595847073942423, current learning rate 1.0000000000000002e-07
val loss: 8.00130420923233
accuracy:      0.881
precision:     0.687
recall:        0.618
f1:            0.640
val loss: 7.437455177307129
accuracy:      0.890
precision:     0.643
recall:        0.650
f1:            0.646
===== Start training: epoch 24 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.760149240493774, Epoch: 24, training loss: 9.083807274699211, current learning rate 1.0000000000000002e-07
val loss: 7.969609022140503
accuracy:      0.881
precision:     0.686
recall:        0.614
f1:            0.637
val loss: 7.257603049278259
accuracy:      0.891
precision:     0.643
recall:        0.646
f1:            0.644
===== Start training: epoch 25 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.767521142959595, Epoch: 25, training loss: 8.203888058662415, current learning rate 1.0000000000000002e-07
val loss: 7.938491851091385
accuracy:      0.884
precision:     0.695
recall:        0.616
f1:            0.640
val loss: 7.172608017921448
accuracy:      0.892
precision:     0.645
recall:        0.646
f1:            0.645
===== Start training: epoch 26 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.7684409618377686, Epoch: 26, training loss: 7.052554115653038, current learning rate 1.0000000000000002e-07
val loss: 7.973825454711914
accuracy:      0.886
precision:     0.706
recall:        0.617
f1:            0.644
val loss: 7.051616162061691
accuracy:      0.894
precision:     0.644
recall:        0.637
f1:            0.641
===== Start training: epoch 27 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.7668137550354, Epoch: 27, training loss: 9.073769956827164, current learning rate 1.0000000000000002e-07
val loss: 7.997150659561157
accuracy:      0.885
precision:     0.702
recall:        0.617
f1:            0.643
val loss: 6.8543269243091345
accuracy:      0.894
precision:     0.646
recall:        0.642
f1:            0.644
===== Start training: epoch 28 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.753472566604614, Epoch: 28, training loss: 8.367707520723343, current learning rate 1.0000000000000002e-07
val loss: 8.025265276432037
accuracy:      0.883
precision:     0.692
recall:        0.615
f1:            0.639
val loss: 6.894427518360317
accuracy:      0.895
precision:     0.651
recall:        0.648
f1:            0.649
===== Start training: epoch 29 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.750378370285034, Epoch: 29, training loss: 7.899267204105854, current learning rate 1.0000000000000002e-07
val loss: 8.091922923922539
accuracy:      0.883
precision:     0.692
recall:        0.615
f1:            0.639
val loss: 6.9935496263206005
accuracy:      0.890
precision:     0.641
recall:        0.645
f1:            0.643
===== Start training: epoch 30 with lr = [1.0000000000000002e-07, 1.0000000000000002e-07], seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.7520246505737305, Epoch: 30, training loss: 9.136232890188694, current learning rate 1.0000000000000002e-07
val loss: 8.137240260839462
accuracy:      0.883
precision:     0.693
recall:        0.619
f1:            0.643
val loss: 7.302424907684326
accuracy:      0.889
precision:     0.639
recall:        0.645
f1:            0.642
best result:
0.8936363636363637
0.6443397091876519
0.6370849170650954
0.6405740760910545


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



avg result:
0.8824242424242423
0.6466123703484029
0.6892945178376298
0.6611387482097981




UnboundLocalError: local variable 'discovery_weight' referenced before assignment

In [None]:
# Final cell: release the Colab runtime once the experiments above are done.
# NOTE(review): `runtime.unassign()` is assumed to disconnect and free the
# Colab VM, ending this kernel session and discarding in-memory state.
# Confirm against the google.colab docs. Anything not already written under
# the mounted /content/drive is presumably lost after this call.
# In the saved run this cell shows `In [None]`, i.e. it never executed.
# The UnboundLocalError reported in the previous cell's output likely halted
# "Run All" before reaching it, so the runtime was not released automatically.
from google.colab import runtime
runtime.unassign()