<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

device = torch.device("cpu")

if torch.cuda.is_available():
   print("Training on GPU")
   device = torch.device("cuda:0")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """wrap in PyTorch Dataset"""
    def __init__(self, examples):
        super(dataset, self).__init__()
        self.examples = examples

    def __getitem__(self, idx):
        return self.examples[idx]

    def __len__(self):
        return len(self.examples)


def collate_fn(examples):
    ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    ids_sent2 = torch.tensor(ids_sent2, dtype=torch.long)
    segs_sent2 = torch.tensor(segs_sent2, dtype=torch.long)
    att_mask_sent2 = torch.tensor(att_mask_sent2, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels

def collate_fn_concatenated(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

def collate_fn_concatenated_adv(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    #labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:

  def __init__(self,):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _get_examples(self, dataset, dataset_type="train"):
    examples = []

    for row in dataset:
      id, sentence1, sentence2, label = row

      """
      for the first sentence
      """

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = [0] * len(ids_sent1)
      segs_sent1[1:-1] = [1] * (len(ids_sent1)-2)

      """
      for the second sentence
      """

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = [0] * len(ids_sent2)
      segs_sent2[1:-1] = [1] * (len(ids_sent2)-2)

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len

      if len(ids_sent2) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent2)
        att_mask_sent2 = [1] * len(ids_sent2) + [0] * res
        ids_sent2 += [pad_id] * res
        segs_sent2 += [0] * res
      else:
        ids_sent2 = ids_sent2[:self.max_sent_len]
        segs_sent2 = segs_sent2[:self.max_sent_len]
        att_mask_sent2 = [1] * self.max_sent_len

      example = [ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    examples = []

    for row in tqdm(dataset, desc="tokenizing..."):
      id, sentence1, sentence2, label = row

      """
      for the first sentence
      """

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))
      #sentence1 += " </s> "+sentence2

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      segs_sent1 = [0] * sentence1_length + [1] * (sentence2_length)
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
        position_sep += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len
        position_sep = position_sep[:self.max_sent_len]

      example = [ids_sent1, segs_sent1, att_mask_sent1, position_sep, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    #https://pdf.sciencedirectassets.com/271806/1-s2.0-S0378216600X00549/1-s2.0-S0378216698001015/main.pdf?X-Amz-Security-Token=IQoJb3JpZ2luX2VjEGgaCXVzLWVhc3QtMSJIMEYCIQCiYMlVmna%2BTaXH5hqdwfhEBWd2VPRNoAHlQLGxzvNEqAIhAO3TVTA51qn13kKQp2bTlzGkaKnf6NhMYtr7laU%2Byy0vKrwFCMH%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEQBRoMMDU5MDAzNTQ2ODY1Igzyz%2F2NAMoW0RbAZ%2BMqkAWG017si0y%2FOokz5T44gGpNBL07jup8MAQjv8iwoi4XGALwCP0nf%2FgHD1ZE%2B%2BQGuaLPuShgLg7Y3%2Fcsv2VjkbfrNBSdZPYhqpzpAClSmP2Zs0DszX0zXdmnx4uFyls6d9jCG4TQkhqTsNCGsnKjU89G7z9NMutpaWqEGcUWT6MVMXpxILGQfeu5zLM0ILcft20VXs2dnMMIjWXA5jd0pG8HnAXdils2AmfgTqt%2B9cHn5BXhv%2FaSXX9a7lwR7EbIoUqZVLo%2BDJR2JLtaLYdoZR01FI3FhNAk7Hx1ZLd3RSWWQrRy3ovGKbKnTYC8Jn%2Bs1w1tkF4OJzCy7EZg578HFrPsvxQrUGwtkXfY1BIralzc9JmYZ%2FS1VPIVSvZSM6E3sUUIND14uQDKhQyTQh6WBbG1djkU8M9bW%2ByVDRj8CKEoWdN4ofK3WuRD87QQEAJQ8jwnl0rCtVIYecZyfQzTnpdO0jafDlritW%2BlfSDqyd8ob%2F%2BkljgtN1m8IFKNQ9lopVjvwCzDa5R%2F0WvchF%2BqNMzImVtUHTgXgJOcGC6y9OSVqRGFgQtPhy6W26WodWQxaFsBMTn49dM6rzsyNhd301U4SYL5vTLrLhjmm3%2Ft5JqKHS7JaAbmKYa4DvabWH4Qs2WHsZMxVd8L3KU%2FIeyaQwATOf3TZVCVPWUriUg%2FAKFcuceC1AaUE5MKWB8Qe2Cb5%2FpagPPYTztfNluPar21xLpY7cayKABv%2FkyIa2N9MsaPm8VEvSb90Sl1EkJAxXP3kVU2XTtZqcYPuHgdSyUwh%2FDC%2F0Y1FlLyZW%2BLrnVmL9sqtORiZcZU20jVgXM8HoLIG2vvo0er4qyok9ZxzykuzhClN6ZULz%2FTja1y%2FdhF2UR89jCk%2BuOxBjqwARYSDjyJE7HksxP39FMsgAM0RH2Us2vj22eV6lkbG1n1%2BZm%2B4a4UeUfzibr4B6BdF%2BB3i%2FHsJ3QF1AnxdSS%2F0x5HnBmGCct1etAdyP60bbBH8p1dCgNQL7kb%2BqKINd78nYfrM0D0a4U%2Fxm2FUNln3swIdVpXtLtz0qY2QSaHbc6Ir6BCR8Kqm0FKQyhv1JSMkKfIdFQ9pYCVy8VAr%2BLBA9uSXfFDz6N67ruhh%2BzJWFn1&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20240506T173714Z&X-Amz-SignedHeaders=host&X-Amz-Expires=299&X-Amz-Credential=ASIAQ3PHCVTYUU54IYV3%2F20240506%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=7303563def203481e9802185fa8eab7dd21d58c841cd92621bbbdb18e253f595&hash=62697dff1cc869850b2d38305ff2ad1a35ad33938dbb92466663bdb1bf67d069&host=68042c943591013ac2b2430a89b270f6af2c76d8dfd086a07176afe7c76c2c61&pii=S0378216698001015&tid=spdf-c2371d97-52db-456d-a333-31d65dde09f3&sid=c8e761cb3c79f3499a2a90d699ea19871d47gxrqb&type=client&tsoh=d3d3LnNjaWVuY2VkaXJlY3QuY29t&ua=091359520307070d01&rr=87fabcafde0b0e1b&cc=it
    # TODO: refactor
    self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }

    """self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 0,
      "although": 1,
      "and": 0,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 1,
      "by_comparison": 1,
      "by_contrast": 1,
      "consequently": 0,
      "conversely": 0,
      "especially": 0,
      "further": 0,
      "furthermore": 0,
      "hence": 0,
      "however": 1,
      "in_contrast": 1,
      "instead": 1,
      "likewise": 0,
      "moreover": 0,
      "namely": 0,
      "nevertheless": 1,
      "nonetheless": 1,
      "on_the_contrary": 1,
      "on_the_other_hand": 1,
      "otherwise": 0,
      "rather": 1,
      "similarly": 0,
      "so": 0,
      "still": 1,
      "then": 0,
      "therefore": 0,
      "though": 1,
      "thus": 0,
      "well": 0,
      "yet": 1
    }"""

    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    result = []
    new_dataset = []

    for sample in dataset:
      if self.id_to_word[sample["label"]] not in self.mapping.keys():
        continue

      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[self.id_to_word[sample["label"]]]])

    one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    labels = []

    for i, sample in tqdm(enumerate(new_dataset), desc="processing labels..."):
      labels.append([sample[-1]])

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results..."):
      result.append([f"{name}_{i}", sample[0], sample[1], [], [], [], label])

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-5] != name:
                continue

              story_id = row[0]
              sent = row[2].strip()
              target = row[3].strip()

              label = row[-6]

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 1:
                l = [1,0]
              elif label == 0:
                l = [0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class DebateProcessor(DataProcessor):

  def __init__(self,):
    super(DebateProcessor,self).__init__()


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue

              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              label = row[3].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGProcessor(DataProcessor):

  def __init__(self):
    super(MARGProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue

              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              label = row[3].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1,0]
              elif label == 'neither':
                l = [0,0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class NKProcessor(DataProcessor):

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path, sep="\t")
      for i,row in df.iterrows():
              id_sample = row[0]
              label = row[2]

              sent = row[3].strip()
              target = row[4].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(id_sample)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1,0]
              elif label == 'no_relation':
                l = [0,0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class DebateWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1,0]
        elif label == 'neither':
          l = [0,0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return ctx.lmbd * grad_input.neg(), None

class DoubleAdversarialNet(torch.nn.Module):
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      labels_std = labels[:batch_size // 2, :]

      emb_attack = samples[labels_std == [0,1]]
      emb_support = samples[labels_std == [1,0]]

      samples_adv = H_sent[batch_size // 2:, ]
      labels_adv = labels[batch_size // 2:, :]

      emb_caus = samples_adv[labels_adv == [1,0,0]]
      emb_other = samples_adv[labels_adv == [0,1,0] or labels_adv == [0,0,1]]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl)
      support_prediction = self.support_linear(mean_grl)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if visualize:
      return H_sent
    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)

      return predictions, predictions_adv, task_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  def __init__(self):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)
    #H_sent = att_output[0][:,0,:]

    predictions = self.linear_layer(H_sent)

    return predictions


class BaselineModel(torch.nn.Module):
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (position_sep == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)

    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    tar_mask_sent1 = (segs_sent1 == 1).long()
    tar_mask_sent2 = (segs_sent2 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent2)

    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    set random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="macro")
    recall = recall_score(labels, preds, average="macro")
    f1 = f1_score(labels, preds, average="macro")

    print("{:15}{:<.3f}".format('accuracy:', accuracy))
    print("{:15}{:<.3f}".format('precision:', precision))
    print("{:15}{:<.3f}".format('recall:', recall))
    print("{:15}{:<.3f}".format('f1:', f1))

    return accuracy, precision, recall, f1

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def __iter__(self):
        self.epoch += 1
        batch = []
        indices1 = self.indices1.copy()
        indices2 = self.indices2.copy()

        indices1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        indices2 = torch.randperm(len(self.dataset2)) +len(indices1)
        indices2 = indices2.cpu().tolist()

        for i in range(self.total_size // self.batch_size):
            batch_size1 = min(self.batch_size // 2, len(indices1))
            if batch_size1 < (self.batch_size // 2):
              break
            batch_size2 = self.batch_size - batch_size1
            batch.extend([indices1.pop() for _ in range(batch_size1)])
            batch.extend([indices2.pop() for _ in range(batch_size2)])

            yield batch
            batch = []
            if len(indices1) == 0:
              break

    def __len__(self):
        return (self.total_size + self.batch_size - 1) // self.batch_size

In [10]:
import datetime
import json
from pathlib import Path

def save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed=None, discovery_weight=None, adv_weight=None):
  json_dict = config | args | {
      "discovery_weight": discovery_weight,
      "adv_weight": adv_weight,
      "accuracy": best_test_acc,
      "precision": best_test_pre,
      "recall": best_test_rec,
      "f1": best_test_f1,
      "time": datetime.datetime.now()
  }

  file_path = f"results/{config['dataset']}"
  if config["grid_search"]:
    file_path += "/grid_search"
  else:
    file_path += "/standard"

  if config["double_adversarial"]:
    file_path += "/double_adversarial"
  elif config["adversarial"]:
    file_path += "/adversarial"
  else:
    file_path += "/standard"

  if config["injection"]:
    file_path += "/injection"
  else:
    file_path += "/standard"

  if discovery_weight is not None and adv_weight is not None:
    file_path += f"/discovery_{discovery_weight}/adv_{adv_weight}"

  if seed is not None:
    file_path += f"/{seed}"

  Path(file_path).mkdir(parents=True, exist_ok=True)

  file_path += "/result.json"

  with open(file_path, "w") as outfile:
    json.dump(json_dict, outfile, default=str)

In [None]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

args = {
    "model_name": "roberta-base",
    "num_classes": 2, #3, #2,
    "num_classes_adv": 3, #174,
    "embed_size": 768,
    "first_last_avg": True,
    "seed": [0,1,2],
    "batch_size": 64,
    "epochs": 30,
    "class_weight": 10, #[9.375, 30, 1], #10, [2.071, 1.933, 1]
    "lr": 1e-5
}

config = {
    "dataset": "student_essay", #"student_essay", #debate, m-arg, nk
    "adversarial": True,
    "double_adversarial": False,
    "dataset_from_saved": True,
    "injection": False,
    "grid_search": False,
    "visualize": True,
    "train": True,
}

def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for step, batch in enumerate(tqdm(train_loader, desc='Iteration')):
        if config["adversarial"]:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        else:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch) #tuple(t.to(device) for t in batch)

        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
          pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
          pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)
            attack_target = targets[targets == [0,1]]
            support_target = targets[targets == [1,0]]

            cause_target = targets_adv[targets_adv == [1,0,0]]
            other_target = targets_adv[targets_adv == [0,1,0] or targets_adv == [0,0,1]]
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss4 = loss_fn2(attack_pred, attack_target.float())
          loss5 = loss_fn2(support_pred, support_target.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
          out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          if isinstance(labels, list):
            labels = torch.tensor(np.array(labels)).to(device)
          loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            preds = torch.max(out.data, 1)[1].cpu().numpy().tolist()
            loss = loss_fn(out, labels.float())
            val_loss += loss.item()

            labels = labels.cpu().numpy().tolist()

            val_labels.extend(labels)
            if len(labels[0]) != 2:
              for pred in preds:
                if pred == 0:
                  val_preds.append([1,0,0])
                elif pred == 1:
                  val_preds.append([0,1,0])
                else:
                  val_preds.append([0,0,1])
            else:
              val_preds.extend([[1,0] if pred == 0 else [0,1] for pred in preds])

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1, val_loss

def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2):
  if not config["adversarial"]:
    return

  model.eval()

  loss_fn = nn.CrossEntropyLoss()

  tot_labels = None
  embeddings = None

  tot_labels_adv = None
  embeddings_adv = None

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.argmax(labels, dim=-1)
    if tot_labels is None:
      tot_labels = labels
    else:
      tot_labels = torch.cat([tot_labels, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings is None:
        embeddings = out
      else:
        embeddings = torch.cat([embeddings, out], dim=0)

  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i == 20: break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    labels = torch.argmax(labels, dim=-1)+2

    if tot_labels_adv is None:
      tot_labels_adv = labels
    else:
      tot_labels_adv = torch.cat([tot_labels_adv, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings_adv is None:
        embeddings_adv = out
      else:
        embeddings_adv = torch.cat([embeddings_adv, out], dim=0)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  fig1, ax1 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax1, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()

  fig2, ax2 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne_adv, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax2, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()

def run():
  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/student_essay.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/debate.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay' and 'debate')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")
  if config["dataset"] == "nk":
    data_dev = data_train[:len(data_train) // 10]
    data_test = data_train[-(len(data_train) // 10):]
    data_train = data_train[(len(data_train) // 10) : -(len(data_train) // 10)]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  df = datasets.load_dataset("discovery","discovery")
  adv_processor = DiscourseMarkerProcessor()
  if not config["dataset_from_saved"]:
    print("processing discourse marker dataset...")
    train_adv = adv_processor.process_dataset(df["train"])
    with open("./adv_dataset.pkl", "wb") as writer:
      pickle.dump(train_adv, writer)
  else:
    with open("./adv_dataset.pkl", "rb") as reader:
      train_adv = pickle.load(reader)
  train_set_adv = dataset(train_adv)

  if config["adversarial"]:
    data_train_tot = data_train + train_adv
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if not config["adversarial"]:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv) #loading adv for visualization
    model = BaselineModel()
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

  if config["dataset"] == "m-arg" or config["dataset"] == "nk":
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_acc = -1
  best_pre = -1
  best_rec = -1
  best_f1 = -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)

    for discovery_weight in range_disc:
      for adv_weight in range_adv:
        avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []
        for seed in args["seed"]:
          set_random_seeds(seed)
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}, seed = {seed}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
              #save model

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          avg_acc.append(best_test_acc)
          avg_pre.append(best_test_pre)
          avg_rec.append(best_test_rec)
          avg_f1.append(best_test_f1)

          save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          model = AdversarialNet()
          model = model.to(device)

          optimizer_grouped_parameters = [
            {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": 0.01,
            },
            {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0
            },
          ]
          optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

          best_acc = -1
          best_pre = -1
          best_rec = -1
          best_f1 = -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

        avg_acc_score = sum(avg_acc) / len(avg_acc)
        avg_pre_score = sum(avg_pre) / len(avg_pre)
        avg_rec_score = sum(avg_rec) / len(avg_rec)
        avg_f1_score = sum(avg_f1) / len(avg_f1)
        print()
        print('avg result:')
        print(avg_acc_score)
        print(avg_pre_score)
        print(avg_rec_score)
        print(avg_f1_score)
        save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score, discovery_weight=discovery_weight, adv_weight=adv_weight)

  else:
    avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []

    for seed in args["seed"]:
      set_random_seeds(seed)

      #scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, min_lr=1e-7)
      for epoch in range(args["epochs"]):
        if config["train"]:
          print('===== Start training: epoch {}, seed = {} ====='.format(epoch + 1, seed))
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
          dev_a, dev_p, dev_r, dev_f1, dev_loss = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1, _ = val(model, test_dataloader)
          #scheduler.step(dev_loss)
          if dev_f1 > best_dev_f1:
            best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
            torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

      if config["visualize"] and config["adversarial"]:
        model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
        visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

          #save model

      print('best result:')
      print(best_test_acc)
      print(best_test_pre)
      print(best_test_rec)
      print(best_test_f1)

      avg_acc.append(best_test_acc)
      avg_pre.append(best_test_pre)
      avg_rec.append(best_test_rec)
      avg_f1.append(best_test_f1)

      if not config["adversarial"] and not config["double_adversarial"]:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed)
      else:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)

      result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

      del model
      del optimizer

      model = BaselineModel()
      model = model.to(device)

      optimizer_grouped_parameters = [
        {
          "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
          "weight_decay": 0.01,
        },
        {
          "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
          "weight_decay": 0.0
        },
      ]
      optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

      best_acc = -1
      best_pre = -1
      best_rec = -1
      best_f1 = -1
      best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  avg_acc_score = sum(avg_acc) / len(avg_acc)
  avg_pre_score = sum(avg_pre) / len(avg_pre)
  avg_rec_score = sum(avg_rec) / len(avg_rec)
  avg_f1_score = sum(avg_f1) / len(avg_f1)

  print()
  print('avg result:')
  print(avg_acc_score)
  print(avg_pre_score)
  print(avg_rec_score)
  print(avg_f1_score)
  save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score)

if __name__ == "__main__":
  run()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2552.63it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2788.32it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2687.59it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:08<00:00,  5.93it/s]


Timing: 8.095559358596802, Epoch: 1, training loss: 59.573150873184204, current learning rate 1e-05
val loss: 11.291436076164246
accuracy:      0.836
precision:     0.579
recall:        0.572
f1:            0.576
val loss: 11.119580626487732
accuracy:      0.860
precision:     0.565
recall:        0.574
f1:            0.569
===== Start training: epoch 2, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.49it/s]


Timing: 7.398643732070923, Epoch: 2, training loss: 57.78188347816467, current learning rate 1e-05
val loss: 15.393638610839844
accuracy:      0.350
precision:     0.516
recall:        0.533
f1:            0.330
val loss: 15.537602126598358
accuracy:      0.366
precision:     0.526
recall:        0.575
f1:            0.331
===== Start training: epoch 3, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.46it/s]


Timing: 7.438832759857178, Epoch: 3, training loss: 46.841943204402924, current learning rate 1e-05
val loss: 7.381483674049377
accuracy:      0.825
precision:     0.625
recall:        0.677
f1:            0.642
val loss: 7.252090781927109
accuracy:      0.815
precision:     0.589
recall:        0.679
f1:            0.605
===== Start training: epoch 4, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.396424770355225, Epoch: 4, training loss: 37.79619562625885, current learning rate 1e-05
val loss: 11.780457139015198
accuracy:      0.680
precision:     0.591
recall:        0.712
f1:            0.568
val loss: 12.18034228682518
accuracy:      0.680
precision:     0.571
recall:        0.716
f1:            0.538
===== Start training: epoch 5, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.47it/s]


Timing: 7.4338219165802, Epoch: 5, training loss: 27.242274403572083, current learning rate 1e-05
val loss: 7.593372732400894
accuracy:      0.830
precision:     0.638
recall:        0.700
f1:            0.658
val loss: 7.980362832546234
accuracy:      0.819
precision:     0.609
recall:        0.731
f1:            0.629
===== Start training: epoch 6, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.49it/s]


Timing: 7.400205612182617, Epoch: 6, training loss: 19.997706294059753, current learning rate 1e-05
val loss: 7.055647403001785
accuracy:      0.867
precision:     0.668
recall:        0.663
f1:            0.666
val loss: 7.485230386257172
accuracy:      0.860
precision:     0.615
recall:        0.674
f1:            0.635
===== Start training: epoch 7, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.3752288818359375, Epoch: 7, training loss: 14.979232132434845, current learning rate 1e-05
val loss: 8.254609495401382
accuracy:      0.859
precision:     0.667
recall:        0.696
f1:            0.679
val loss: 8.946076452732086
accuracy:      0.837
precision:     0.600
recall:        0.676
f1:            0.618
===== Start training: epoch 8, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.372246503829956, Epoch: 8, training loss: 12.520982027053833, current learning rate 1e-05
val loss: 7.8066079914569855
accuracy:      0.880
precision:     0.682
recall:        0.611
f1:            0.633
val loss: 6.022060722112656
accuracy:      0.901
precision:     0.662
recall:        0.641
f1:            0.651
===== Start training: epoch 9, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.3834569454193115, Epoch: 9, training loss: 8.308980584144592, current learning rate 1e-05
val loss: 8.480720043182373
accuracy:      0.878
precision:     0.685
recall:        0.640
f1:            0.657
val loss: 7.4231381714344025
accuracy:      0.895
precision:     0.658
recall:        0.663
f1:            0.661
===== Start training: epoch 10, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.372807502746582, Epoch: 10, training loss: 8.819473668932915, current learning rate 1e-05
val loss: 9.670847088098526
accuracy:      0.863
precision:     0.668
recall:        0.682
f1:            0.674
val loss: 8.358307592570782
accuracy:      0.875
precision:     0.647
recall:        0.717
f1:            0.671
===== Start training: epoch 11, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.358363389968872, Epoch: 11, training loss: 6.010360203683376, current learning rate 1e-05
val loss: 9.666304931044579
accuracy:      0.883
precision:     0.699
recall:        0.649
f1:            0.668
val loss: 8.5699354223907
accuracy:      0.887
precision:     0.651
recall:        0.684
f1:            0.665
===== Start training: epoch 12, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.352493524551392, Epoch: 12, training loss: 6.677553576417267, current learning rate 1e-05
val loss: 9.6936115026474
accuracy:      0.890
precision:     0.724
recall:        0.586
f1:            0.612
val loss: 6.358695328235626
accuracy:      0.922
precision:     0.744
recall:        0.642
f1:            0.676
===== Start training: epoch 13, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.341654539108276, Epoch: 13, training loss: 4.642550200223923, current learning rate 1e-05
val loss: 9.987887382507324
accuracy:      0.894
precision:     0.757
recall:        0.588
f1:            0.618
val loss: 6.77973152208142
accuracy:      0.925
precision:     0.768
recall:        0.624
f1:            0.663
===== Start training: epoch 14, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.348804712295532, Epoch: 14, training loss: 7.834432631731033, current learning rate 1e-05
val loss: 11.005691677331924
accuracy:      0.880
precision:     0.688
recall:        0.634
f1:            0.654
val loss: 8.648662947118282
accuracy:      0.895
precision:     0.653
recall:        0.653
f1:            0.653
===== Start training: epoch 15, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.359900951385498, Epoch: 15, training loss: 5.470532298088074, current learning rate 1e-05
val loss: 9.101867586374283
accuracy:      0.889
precision:     0.717
recall:        0.619
f1:            0.647
val loss: 7.169467180967331
accuracy:      0.910
precision:     0.691
recall:        0.646
f1:            0.664
===== Start training: epoch 16, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.351094722747803, Epoch: 16, training loss: 5.742247104644775, current learning rate 1e-05
val loss: 9.399920776486397
accuracy:      0.889
precision:     0.718
recall:        0.585
f1:            0.611
val loss: 6.616904973983765
accuracy:      0.919
precision:     0.729
recall:        0.616
f1:            0.649
===== Start training: epoch 17, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.339711666107178, Epoch: 17, training loss: 3.6078698756173253, current learning rate 1e-05
val loss: 9.007636904716492
accuracy:      0.890
precision:     0.722
recall:        0.596
f1:            0.624
val loss: 6.703364407643676
accuracy:      0.915
precision:     0.710
recall:        0.629
f1:            0.657
===== Start training: epoch 18, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.35329532623291, Epoch: 18, training loss: 5.618996784090996, current learning rate 1e-05
val loss: 9.860153079032898
accuracy:      0.891
precision:     0.738
recall:        0.566
f1:            0.587
val loss: 7.101140275597572
accuracy:      0.919
precision:     0.730
recall:        0.601
f1:            0.634
===== Start training: epoch 19, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.3390586376190186, Epoch: 19, training loss: 2.6571412216871977, current learning rate 1e-05
val loss: 10.39650909602642
accuracy:      0.891
precision:     0.727
recall:        0.603
f1:            0.632
val loss: 7.4747214913368225
accuracy:      0.915
precision:     0.707
recall:        0.614
f1:            0.643
===== Start training: epoch 20, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.3309972286224365, Epoch: 20, training loss: 2.2898114239796996, current learning rate 1e-05
val loss: 9.860980659723282
accuracy:      0.894
precision:     0.748
recall:        0.605
f1:            0.637
val loss: 7.920550798997283
accuracy:      0.914
precision:     0.700
recall:        0.623
f1:            0.649
===== Start training: epoch 21, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.341458559036255, Epoch: 21, training loss: 3.10576768592, current learning rate 1e-05
val loss: 9.666728526353836
accuracy:      0.892
precision:     0.737
recall:        0.604
f1:            0.635
val loss: 7.892533451318741
accuracy:      0.911
precision:     0.688
recall:        0.622
f1:            0.645
===== Start training: epoch 22, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.354079961776733, Epoch: 22, training loss: 3.0030037560500205, current learning rate 1e-05
val loss: 10.173995673656464
accuracy:      0.890
precision:     0.725
recall:        0.579
f1:            0.604
val loss: 7.447586640715599
accuracy:      0.918
precision:     0.723
recall:        0.610
f1:            0.642
===== Start training: epoch 23, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.343997240066528, Epoch: 23, training loss: 2.0895181111991405, current learning rate 1e-05
val loss: 10.609608799219131
accuracy:      0.892
precision:     0.756
recall:        0.567
f1:            0.589
val loss: 7.299684584140778
accuracy:      0.924
precision:     0.773
recall:        0.598
f1:            0.636
===== Start training: epoch 24, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.33092737197876, Epoch: 24, training loss: 2.609223233535886, current learning rate 1e-05
val loss: 10.671315878629684
accuracy:      0.893
precision:     0.736
recall:        0.635
f1:            0.666
val loss: 8.238296389579773
accuracy:      0.912
precision:     0.698
recall:        0.647
f1:            0.667
===== Start training: epoch 25, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.32959246635437, Epoch: 25, training loss: 1.8098332968074828, current learning rate 1e-05
val loss: 10.858701318502426
accuracy:      0.895
precision:     0.748
recall:        0.622
f1:            0.656
val loss: 8.792569503188133
accuracy:      0.915
precision:     0.712
recall:        0.644
f1:            0.669
===== Start training: epoch 26, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.3439037799835205, Epoch: 26, training loss: 0.4270404330454767, current learning rate 1e-05
val loss: 11.58550637960434
accuracy:      0.886
precision:     0.707
recall:        0.631
f1:            0.656
val loss: 10.032312393188477
accuracy:      0.902
precision:     0.665
recall:        0.642
f1:            0.652
===== Start training: epoch 27, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.344937562942505, Epoch: 27, training loss: 1.0634809951297939, current learning rate 1e-05
val loss: 10.645322322845459
accuracy:      0.892
precision:     0.734
recall:        0.621
f1:            0.652
val loss: 9.000612825155258
accuracy:      0.913
precision:     0.699
recall:        0.637
f1:            0.661
===== Start training: epoch 28, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.350749731063843, Epoch: 28, training loss: 1.658750890288502, current learning rate 1e-05
val loss: 12.165869683027267
accuracy:      0.886
precision:     0.699
recall:        0.560
f1:            0.578
val loss: 7.92567753046751
accuracy:      0.924
precision:     0.759
recall:        0.628
f1:            0.666
===== Start training: epoch 29, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.3459248542785645, Epoch: 29, training loss: 1.5872902169357985, current learning rate 1e-05
val loss: 11.61447536945343
accuracy:      0.893
precision:     0.737
recall:        0.628
f1:            0.660
val loss: 8.956350952386856
accuracy:      0.915
precision:     0.709
recall:        0.653
f1:            0.676
===== Start training: epoch 30, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.341107368469238, Epoch: 30, training loss: 0.6943790970835835, current learning rate 1e-05
val loss: 11.968859136104584
accuracy:      0.889
precision:     0.717
recall:        0.636
f1:            0.663
val loss: 10.021594166755676
accuracy:      0.901
precision:     0.666
recall:        0.651
f1:            0.658
best result:
0.8372727272727273
0.5995670995670996
0.6763469434430782
0.6184929210964308


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.49it/s]


Timing: 7.39823579788208, Epoch: 1, training loss: 59.99755418300629, current learning rate 1e-05
val loss: 12.903700351715088
accuracy:      0.263
precision:     0.537
recall:        0.550
f1:            0.261
val loss: 12.955235838890076
accuracy:      0.231
precision:     0.529
recall:        0.551
f1:            0.226
===== Start training: epoch 2, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.395608186721802, Epoch: 2, training loss: 58.5348744392395, current learning rate 1e-05
val loss: 10.907523691654205
accuracy:      0.716
precision:     0.565
recall:        0.632
f1:            0.559
val loss: 10.497526824474335
accuracy:      0.764
precision:     0.573
recall:        0.681
f1:            0.574
===== Start training: epoch 3, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.388141632080078, Epoch: 3, training loss: 50.82194983959198, current learning rate 1e-05
val loss: 11.164696335792542
accuracy:      0.635
precision:     0.564
recall:        0.653
f1:            0.525
val loss: 10.610678672790527
accuracy:      0.661
precision:     0.561
recall:        0.690
f1:            0.521
===== Start training: epoch 4, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.49it/s]


Timing: 7.4073872566223145, Epoch: 4, training loss: 41.07785642147064, current learning rate 1e-05
val loss: 6.2179189920425415
accuracy:      0.868
precision:     0.647
recall:        0.607
f1:            0.622
val loss: 5.524990081787109
accuracy:      0.885
precision:     0.623
recall:        0.627
f1:            0.625
===== Start training: epoch 5, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.387558937072754, Epoch: 5, training loss: 36.40164113044739, current learning rate 1e-05
val loss: 7.712877511978149
accuracy:      0.805
precision:     0.609
recall:        0.669
f1:            0.625
val loss: 7.3440887331962585
accuracy:      0.810
precision:     0.598
recall:        0.711
f1:            0.615
===== Start training: epoch 6, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.389358997344971, Epoch: 6, training loss: 26.085926413536072, current learning rate 1e-05
val loss: 8.208808064460754
accuracy:      0.806
precision:     0.605
recall:        0.656
f1:            0.619
val loss: 7.665512405335903
accuracy:      0.813
precision:     0.603
recall:        0.723
f1:            0.621
===== Start training: epoch 7, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.48it/s]


Timing: 7.4086594581604, Epoch: 7, training loss: 18.66192239522934, current learning rate 1e-05
val loss: 7.549612045288086
accuracy:      0.848
precision:     0.632
recall:        0.643
f1:            0.637
val loss: 6.632387101650238
accuracy:      0.864
precision:     0.629
recall:        0.701
f1:            0.652
===== Start training: epoch 8, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.381200551986694, Epoch: 8, training loss: 13.584522902965546, current learning rate 1e-05
val loss: 7.59341499209404
accuracy:      0.878
precision:     0.685
recall:        0.640
f1:            0.657
val loss: 6.731537739746273
accuracy:      0.869
precision:     0.604
recall:        0.629
f1:            0.614
===== Start training: epoch 9, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.367283344268799, Epoch: 9, training loss: 15.359511405229568, current learning rate 1e-05
val loss: 8.378478646278381
accuracy:      0.851
precision:     0.644
recall:        0.661
f1:            0.652
val loss: 7.843461871147156
accuracy:      0.853
precision:     0.619
recall:        0.700
f1:            0.642
===== Start training: epoch 10, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.386398077011108, Epoch: 10, training loss: 14.752210408449173, current learning rate 1e-05
val loss: 8.796237826347351
accuracy:      0.866
precision:     0.660
recall:        0.646
f1:            0.653
val loss: 8.199280798435211
accuracy:      0.881
precision:     0.644
recall:        0.690
f1:            0.662
===== Start training: epoch 11, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.359185457229614, Epoch: 11, training loss: 10.343205690383911, current learning rate 1e-05
val loss: 8.771783024072647
accuracy:      0.886
precision:     0.701
recall:        0.574
f1:            0.596
val loss: 6.612375319004059
accuracy:      0.915
precision:     0.707
recall:        0.614
f1:            0.643
===== Start training: epoch 12, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.352245807647705, Epoch: 12, training loss: 11.18217259645462, current learning rate 1e-05
val loss: 8.865550726652145
accuracy:      0.883
precision:     0.695
recall:        0.629
f1:            0.652
val loss: 7.057498127222061
accuracy:      0.903
precision:     0.680
recall:        0.682
f1:            0.681
===== Start training: epoch 13, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.342028379440308, Epoch: 13, training loss: 8.733450911939144, current learning rate 1e-05
val loss: 8.792468786239624
accuracy:      0.891
precision:     0.730
recall:        0.614
f1:            0.644
val loss: 6.23713019490242
accuracy:      0.915
precision:     0.705
recall:        0.628
f1:            0.655
===== Start training: epoch 14, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.338519096374512, Epoch: 14, training loss: 8.400933710858226, current learning rate 1e-05
val loss: 9.281030982732773
accuracy:      0.885
precision:     0.698
recall:        0.590
f1:            0.615
val loss: 6.754274666309357
accuracy:      0.907
precision:     0.673
recall:        0.620
f1:            0.640
===== Start training: epoch 15, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.360597610473633, Epoch: 15, training loss: 5.9810367822647095, current learning rate 1e-05
val loss: 10.140758380293846
accuracy:      0.879
precision:     0.679
recall:        0.610
f1:            0.632
val loss: 8.757422983646393
accuracy:      0.894
precision:     0.651
recall:        0.652
f1:            0.651
===== Start training: epoch 16, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.352608680725098, Epoch: 16, training loss: 4.883612632751465, current learning rate 1e-05
val loss: 9.358159214258194
accuracy:      0.879
precision:     0.682
recall:        0.620
f1:            0.641
val loss: 7.28677324950695
accuracy:      0.906
precision:     0.688
recall:        0.679
f1:            0.684
===== Start training: epoch 17, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.352936744689941, Epoch: 17, training loss: 4.895330756902695, current learning rate 1e-05
val loss: 9.80639111995697
accuracy:      0.890
precision:     0.740
recall:        0.549
f1:            0.561
val loss: 7.103065133094788
accuracy:      0.921
precision:     0.753
recall:        0.577
f1:            0.607
===== Start training: epoch 18, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.339359998703003, Epoch: 18, training loss: 2.716000471729785, current learning rate 1e-05
val loss: 10.871251940727234
accuracy:      0.883
precision:     0.684
recall:        0.585
f1:            0.608
val loss: 8.824747532606125
accuracy:      0.912
precision:     0.699
recall:        0.652
f1:            0.671
===== Start training: epoch 19, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.351074457168579, Epoch: 19, training loss: 1.7366961534135044, current learning rate 1e-05
val loss: 12.573444902896881
accuracy:      0.873
precision:     0.666
recall:        0.623
f1:            0.640
val loss: 9.784969449043274
accuracy:      0.904
precision:     0.689
recall:        0.708
f1:            0.698
===== Start training: epoch 20, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.3531224727630615, Epoch: 20, training loss: 3.064236082136631, current learning rate 1e-05
val loss: 10.331205666065216
accuracy:      0.882
precision:     0.686
recall:        0.605
f1:            0.628
val loss: 7.865400269627571
accuracy:      0.909
precision:     0.686
recall:        0.641
f1:            0.659
===== Start training: epoch 21, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.335529804229736, Epoch: 21, training loss: 1.6490229666233063, current learning rate 1e-05
val loss: 10.891981184482574
accuracy:      0.887
precision:     0.707
recall:        0.571
f1:            0.592
val loss: 8.182148396968842
accuracy:      0.919
precision:     0.730
recall:        0.601
f1:            0.634
===== Start training: epoch 22, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.344709634780884, Epoch: 22, training loss: 1.273352786898613, current learning rate 1e-05
val loss: 12.291044622659683
accuracy:      0.883
precision:     0.689
recall:        0.602
f1:            0.626
val loss: 9.165459806798026
accuracy:      0.905
precision:     0.683
recall:        0.669
f1:            0.675
===== Start training: epoch 23, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.342726945877075, Epoch: 23, training loss: 2.9045537021011114, current learning rate 1e-05
val loss: 11.409742951393127
accuracy:      0.886
precision:     0.702
recall:        0.584
f1:            0.608
val loss: 8.837278291583061
accuracy:      0.909
precision:     0.680
recall:        0.621
f1:            0.642
===== Start training: epoch 24, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.376620531082153, Epoch: 24, training loss: 4.025071274489164, current learning rate 1e-05
val loss: 10.803626626729965
accuracy:      0.883
precision:     0.685
recall:        0.589
f1:            0.612
val loss: 8.881586581468582
accuracy:      0.907
precision:     0.678
recall:        0.635
f1:            0.652
===== Start training: epoch 25, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.33046817779541, Epoch: 25, training loss: 1.6890405882149935, current learning rate 1e-05
val loss: 11.429611921310425
accuracy:      0.884
precision:     0.693
recall:        0.586
f1:            0.610
val loss: 9.42281660437584
accuracy:      0.908
precision:     0.675
recall:        0.615
f1:            0.636
===== Start training: epoch 26, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.345279693603516, Epoch: 26, training loss: 1.6919486750848591, current learning rate 1e-05
val loss: 11.26708248257637
accuracy:      0.891
precision:     0.736
recall:        0.587
f1:            0.614
val loss: 8.415158540010452
accuracy:      0.914
precision:     0.691
recall:        0.588
f1:            0.614
===== Start training: epoch 27, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.355716943740845, Epoch: 27, training loss: 1.0005896836519241, current learning rate 1e-05
val loss: 12.652095556259155
accuracy:      0.876
precision:     0.669
recall:        0.611
f1:            0.631
val loss: 9.758234485983849
accuracy:      0.906
precision:     0.683
recall:        0.659
f1:            0.670
===== Start training: epoch 28, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.343024730682373, Epoch: 28, training loss: 0.679008633363992, current learning rate 1e-05
val loss: 11.758273124694824
accuracy:      0.893
precision:     0.754
recall:        0.581
f1:            0.608
val loss: 8.416924764751457
accuracy:      0.920
precision:     0.736
recall:        0.611
f1:            0.645
===== Start training: epoch 29, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.342275381088257, Epoch: 29, training loss: 2.9885802632197738, current learning rate 1e-05
val loss: 12.97371569275856
accuracy:      0.880
precision:     0.685
recall:        0.621
f1:            0.642
val loss: 10.976171046495438
accuracy:      0.898
precision:     0.666
recall:        0.670
f1:            0.668
===== Start training: epoch 30, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.343865871429443, Epoch: 30, training loss: 1.473057295428589, current learning rate 1e-05
val loss: 12.997841536998749
accuracy:      0.881
precision:     0.687
recall:        0.618
f1:            0.640
val loss: 10.33516064286232
accuracy:      0.901
precision:     0.671
recall:        0.666
f1:            0.669
best result:
0.8690909090909091
0.604325755903774
0.6287042986745663
0.6143771119182792


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.395845174789429, Epoch: 1, training loss: 59.687726736068726, current learning rate 1e-05
val loss: 12.973401129245758
accuracy:      0.412
precision:     0.531
recall:        0.571
f1:            0.377
val loss: 12.997074127197266
accuracy:      0.433
precision:     0.537
recall:        0.616
f1:            0.379
===== Start training: epoch 2, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.48it/s]


Timing: 7.411743640899658, Epoch: 2, training loss: 53.080514550209045, current learning rate 1e-05
val loss: 7.910702139139175
accuracy:      0.814
precision:     0.616
recall:        0.671
f1:            0.632
val loss: 7.685356497764587
accuracy:      0.819
precision:     0.574
recall:        0.636
f1:            0.586
===== Start training: epoch 3, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.49it/s]


Timing: 7.397751331329346, Epoch: 3, training loss: 40.57607841491699, current learning rate 1e-05
val loss: 9.045182883739471
accuracy:      0.754
precision:     0.597
recall:        0.690
f1:            0.604
val loss: 8.854415595531464
accuracy:      0.755
precision:     0.581
recall:        0.711
f1:            0.579
===== Start training: epoch 4, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.48it/s]


Timing: 7.418637037277222, Epoch: 4, training loss: 29.0231232047081, current learning rate 1e-05
val loss: 7.356568694114685
accuracy:      0.835
precision:     0.628
recall:        0.665
f1:            0.642
val loss: 7.7635791301727295
accuracy:      0.824
precision:     0.595
recall:        0.684
f1:            0.612
===== Start training: epoch 5, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.375730276107788, Epoch: 5, training loss: 21.46738436818123, current learning rate 1e-05
val loss: 6.359898209571838
accuracy:      0.879
precision:     0.691
recall:        0.657
f1:            0.671
val loss: 6.012270629405975
accuracy:      0.875
precision:     0.625
recall:        0.657
f1:            0.638
===== Start training: epoch 6, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.390666961669922, Epoch: 6, training loss: 19.104024529457092, current learning rate 1e-05
val loss: 7.034723401069641
accuracy:      0.870
precision:     0.662
recall:        0.629
f1:            0.642
val loss: 6.871118545532227
accuracy:      0.879
precision:     0.636
recall:        0.674
f1:            0.651
===== Start training: epoch 7, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.50it/s]


Timing: 7.395133972167969, Epoch: 7, training loss: 16.29621520638466, current learning rate 1e-05
val loss: 7.769227981567383
accuracy:      0.856
precision:     0.652
recall:        0.664
f1:            0.658
val loss: 7.908349275588989
accuracy:      0.856
precision:     0.623
recall:        0.702
f1:            0.646
===== Start training: epoch 8, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.378331899642944, Epoch: 8, training loss: 13.24424883723259, current learning rate 1e-05
val loss: 8.95829963684082
accuracy:      0.852
precision:     0.649
recall:        0.672
f1:            0.659
val loss: 10.405466616153717
accuracy:      0.839
precision:     0.603
recall:        0.682
f1:            0.623
===== Start training: epoch 9, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.371144533157349, Epoch: 9, training loss: 8.903991460800171, current learning rate 1e-05
val loss: 8.89284685254097
accuracy:      0.884
precision:     0.696
recall:        0.619
f1:            0.644
val loss: 7.288261383771896
accuracy:      0.897
precision:     0.662
recall:        0.664
f1:            0.663
===== Start training: epoch 10, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.361708402633667, Epoch: 10, training loss: 7.995360970497131, current learning rate 1e-05
val loss: 9.18278081715107
accuracy:      0.888
precision:     0.712
recall:        0.591
f1:            0.618
val loss: 6.748745759483427
accuracy:      0.905
precision:     0.671
recall:        0.629
f1:            0.645
===== Start training: epoch 11, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.51it/s]


Timing: 7.376672029495239, Epoch: 11, training loss: 7.411045926623046, current learning rate 1e-05
val loss: 10.379878342151642
accuracy:      0.891
precision:     0.732
recall:        0.580
f1:            0.605
val loss: 7.9500690549612045
accuracy:      0.914
precision:     0.699
recall:        0.618
f1:            0.645
===== Start training: epoch 12, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.350804328918457, Epoch: 12, training loss: 4.960543190129101, current learning rate 1e-05
val loss: 10.188557505607605
accuracy:      0.892
precision:     0.739
recall:        0.597
f1:            0.627
val loss: 8.686916053295135
accuracy:      0.903
precision:     0.662
recall:        0.627
f1:            0.641
===== Start training: epoch 13, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.354816675186157, Epoch: 13, training loss: 3.5286057367920876, current learning rate 1e-05
val loss: 11.034546136856079
accuracy:      0.887
precision:     0.708
recall:        0.598
f1:            0.624
val loss: 9.092156887054443
accuracy:      0.901
precision:     0.641
recall:        0.596
f1:            0.612
===== Start training: epoch 14, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.357637166976929, Epoch: 14, training loss: 5.562833547592163, current learning rate 1e-05
val loss: 10.808798924088478
accuracy:      0.891
precision:     0.729
recall:        0.627
f1:            0.657
val loss: 9.772710755467415
accuracy:      0.898
precision:     0.657
recall:        0.645
f1:            0.651
===== Start training: epoch 15, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.52it/s]


Timing: 7.365381479263306, Epoch: 15, training loss: 6.056693106889725, current learning rate 1e-05
val loss: 9.81779845803976
accuracy:      0.894
precision:     0.750
recall:        0.602
f1:            0.633
val loss: 10.29490402340889
accuracy:      0.896
precision:     0.628
recall:        0.594
f1:            0.607
===== Start training: epoch 16, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.344747066497803, Epoch: 16, training loss: 3.024297535419464, current learning rate 1e-05
val loss: 10.504514381289482
accuracy:      0.894
precision:     0.751
recall:        0.598
f1:            0.629
val loss: 9.325373947620392
accuracy:      0.906
precision:     0.663
recall:        0.604
f1:            0.625
===== Start training: epoch 17, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.3552405834198, Epoch: 17, training loss: 6.462713871151209, current learning rate 1e-05
val loss: 11.08017784357071
accuracy:      0.877
precision:     0.682
recall:        0.639
f1:            0.656
val loss: 10.17990592122078
accuracy:      0.889
precision:     0.645
recall:        0.660
f1:            0.652
===== Start training: epoch 18, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.339185476303101, Epoch: 18, training loss: 2.0859606275334954, current learning rate 1e-05
val loss: 10.355099737644196
accuracy:      0.891
precision:     0.730
recall:        0.617
f1:            0.647
val loss: 10.145301595330238
accuracy:      0.905
precision:     0.669
recall:        0.633
f1:            0.648
===== Start training: epoch 19, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.342573881149292, Epoch: 19, training loss: 1.544534404296428, current learning rate 1e-05
val loss: 10.40592160820961
accuracy:      0.898
precision:     0.805
recall:        0.587
f1:            0.619
val loss: 8.910798236727715
accuracy:      0.912
precision:     0.685
recall:        0.602
f1:            0.628
===== Start training: epoch 20, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.358256578445435, Epoch: 20, training loss: 4.5707424827851355, current learning rate 1e-05
val loss: 10.98496425151825
accuracy:      0.890
precision:     0.721
recall:        0.633
f1:            0.661
val loss: 11.201931685209274
accuracy:      0.889
precision:     0.643
recall:        0.655
f1:            0.649
===== Start training: epoch 21, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.328134536743164, Epoch: 21, training loss: 1.8776260272134095, current learning rate 1e-05
val loss: 11.820014983415604
accuracy:      0.889
precision:     0.717
recall:        0.595
f1:            0.623
val loss: 10.019860982894897
accuracy:      0.907
precision:     0.677
recall:        0.630
f1:            0.648
===== Start training: epoch 22, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.344280481338501, Epoch: 22, training loss: 2.5190546163357794, current learning rate 1e-05
val loss: 11.716872051358223
accuracy:      0.886
precision:     0.707
recall:        0.631
f1:            0.656
val loss: 11.996769785881042
accuracy:      0.879
precision:     0.626
recall:        0.649
f1:            0.636
===== Start training: epoch 23, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.342290163040161, Epoch: 23, training loss: 4.116282450966537, current learning rate 1e-05
val loss: 10.98841904103756
accuracy:      0.887
precision:     0.709
recall:        0.614
f1:            0.642
val loss: 10.965478360652924
accuracy:      0.892
precision:     0.631
recall:        0.616
f1:            0.623
===== Start training: epoch 24, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.348024606704712, Epoch: 24, training loss: 0.6078533071558923, current learning rate 1e-05
val loss: 11.372129440307617
accuracy:      0.888
precision:     0.713
recall:        0.612
f1:            0.639
val loss: 10.386409163475037
accuracy:      0.895
precision:     0.639
recall:        0.618
f1:            0.627
===== Start training: epoch 25, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.338938474655151, Epoch: 25, training loss: 2.731167733669281, current learning rate 1e-05
val loss: 11.24309153854847
accuracy:      0.891
precision:     0.750
recall:        0.563
f1:            0.583
val loss: 8.552215496776626
accuracy:      0.913
precision:     0.683
recall:        0.583
f1:            0.608
===== Start training: epoch 26, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.35784125328064, Epoch: 26, training loss: 0.7344608628191054, current learning rate 1e-05
val loss: 11.33010944724083
accuracy:      0.893
precision:     0.748
recall:        0.591
f1:            0.621
val loss: 9.251364648342133
accuracy:      0.910
precision:     0.675
recall:        0.596
f1:            0.620
===== Start training: epoch 27, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.53it/s]


Timing: 7.3571617603302, Epoch: 27, training loss: 2.729122579097748, current learning rate 1e-05
val loss: 11.355056032538414
accuracy:      0.893
precision:     0.737
recall:        0.631
f1:            0.663
val loss: 11.099792510271072
accuracy:      0.897
precision:     0.659
recall:        0.654
f1:            0.656
===== Start training: epoch 28, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.54it/s]


Timing: 7.347395181655884, Epoch: 28, training loss: 2.330113711534068, current learning rate 1e-05
val loss: 10.775877639651299
accuracy:      0.898
precision:     0.769
recall:        0.614
f1:            0.649
val loss: 9.053832992911339
accuracy:      0.911
precision:     0.686
recall:        0.617
f1:            0.641
===== Start training: epoch 29, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.3326005935668945, Epoch: 29, training loss: 1.222382701933384, current learning rate 1e-05
val loss: 11.346237510442734
accuracy:      0.891
precision:     0.733
recall:        0.600
f1:            0.630
val loss: 9.242155288811773
accuracy:      0.906
precision:     0.670
recall:        0.619
f1:            0.638
===== Start training: epoch 30, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.55it/s]


Timing: 7.33480978012085, Epoch: 30, training loss: 1.3882255041971803, current learning rate 1e-05
val loss: 12.31810611486435
accuracy:      0.887
precision:     0.712
recall:        0.645
f1:            0.669
val loss: 12.835040658712387
accuracy:      0.886
precision:     0.636
recall:        0.648
f1:            0.642
best result:
0.8754545454545455
0.6245382995616003
0.6571679064245962
0.6378439821301118


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from google.colab import runtime
runtime.unassign()