<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

device = torch.device("cpu")

if torch.cuda.is_available():
   print("Training on GPU")
   device = torch.device("cuda:0")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """wrap in PyTorch Dataset"""
    def __init__(self, examples):
        super(dataset, self).__init__()
        self.examples = examples

    def __getitem__(self, idx):
        return self.examples[idx]

    def __len__(self):
        return len(self.examples)


def collate_fn(examples):
    ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    ids_sent2 = torch.tensor(ids_sent2, dtype=torch.long)
    segs_sent2 = torch.tensor(segs_sent2, dtype=torch.long)
    att_mask_sent2 = torch.tensor(att_mask_sent2, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels

def collate_fn_concatenated(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

def collate_fn_concatenated_adv(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    #labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:

  def __init__(self,):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _get_examples(self, dataset, dataset_type="train"):
    examples = []

    for row in dataset:
      id, sentence1, sentence2, label = row

      """
      for the first sentence
      """

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = [0] * len(ids_sent1)
      segs_sent1[1:-1] = [1] * (len(ids_sent1)-2)

      """
      for the second sentence
      """

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = [0] * len(ids_sent2)
      segs_sent2[1:-1] = [1] * (len(ids_sent2)-2)

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len

      if len(ids_sent2) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent2)
        att_mask_sent2 = [1] * len(ids_sent2) + [0] * res
        ids_sent2 += [pad_id] * res
        segs_sent2 += [0] * res
      else:
        ids_sent2 = ids_sent2[:self.max_sent_len]
        segs_sent2 = segs_sent2[:self.max_sent_len]
        att_mask_sent2 = [1] * self.max_sent_len

      example = [ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    examples = []

    for row in tqdm(dataset, desc="tokenizing..."):
      id, sentence1, sentence2, label = row

      """
      for the first sentence
      """

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))
      #sentence1 += " </s> "+sentence2

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      segs_sent1 = [0] * sentence1_length + [1] * (sentence2_length)
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
        position_sep += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len
        position_sep = position_sep[:self.max_sent_len]

      example = [ids_sent1, segs_sent1, att_mask_sent1, position_sep, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    #https://pdf.sciencedirectassets.com/271806/1-s2.0-S0378216600X00549/1-s2.0-S0378216698001015/main.pdf?X-Amz-Security-Token=IQoJb3JpZ2luX2VjEGgaCXVzLWVhc3QtMSJIMEYCIQCiYMlVmna%2BTaXH5hqdwfhEBWd2VPRNoAHlQLGxzvNEqAIhAO3TVTA51qn13kKQp2bTlzGkaKnf6NhMYtr7laU%2Byy0vKrwFCMH%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEQBRoMMDU5MDAzNTQ2ODY1Igzyz%2F2NAMoW0RbAZ%2BMqkAWG017si0y%2FOokz5T44gGpNBL07jup8MAQjv8iwoi4XGALwCP0nf%2FgHD1ZE%2B%2BQGuaLPuShgLg7Y3%2Fcsv2VjkbfrNBSdZPYhqpzpAClSmP2Zs0DszX0zXdmnx4uFyls6d9jCG4TQkhqTsNCGsnKjU89G7z9NMutpaWqEGcUWT6MVMXpxILGQfeu5zLM0ILcft20VXs2dnMMIjWXA5jd0pG8HnAXdils2AmfgTqt%2B9cHn5BXhv%2FaSXX9a7lwR7EbIoUqZVLo%2BDJR2JLtaLYdoZR01FI3FhNAk7Hx1ZLd3RSWWQrRy3ovGKbKnTYC8Jn%2Bs1w1tkF4OJzCy7EZg578HFrPsvxQrUGwtkXfY1BIralzc9JmYZ%2FS1VPIVSvZSM6E3sUUIND14uQDKhQyTQh6WBbG1djkU8M9bW%2ByVDRj8CKEoWdN4ofK3WuRD87QQEAJQ8jwnl0rCtVIYecZyfQzTnpdO0jafDlritW%2BlfSDqyd8ob%2F%2BkljgtN1m8IFKNQ9lopVjvwCzDa5R%2F0WvchF%2BqNMzImVtUHTgXgJOcGC6y9OSVqRGFgQtPhy6W26WodWQxaFsBMTn49dM6rzsyNhd301U4SYL5vTLrLhjmm3%2Ft5JqKHS7JaAbmKYa4DvabWH4Qs2WHsZMxVd8L3KU%2FIeyaQwATOf3TZVCVPWUriUg%2FAKFcuceC1AaUE5MKWB8Qe2Cb5%2FpagPPYTztfNluPar21xLpY7cayKABv%2FkyIa2N9MsaPm8VEvSb90Sl1EkJAxXP3kVU2XTtZqcYPuHgdSyUwh%2FDC%2F0Y1FlLyZW%2BLrnVmL9sqtORiZcZU20jVgXM8HoLIG2vvo0er4qyok9ZxzykuzhClN6ZULz%2FTja1y%2FdhF2UR89jCk%2BuOxBjqwARYSDjyJE7HksxP39FMsgAM0RH2Us2vj22eV6lkbG1n1%2BZm%2B4a4UeUfzibr4B6BdF%2BB3i%2FHsJ3QF1AnxdSS%2F0x5HnBmGCct1etAdyP60bbBH8p1dCgNQL7kb%2BqKINd78nYfrM0D0a4U%2Fxm2FUNln3swIdVpXtLtz0qY2QSaHbc6Ir6BCR8Kqm0FKQyhv1JSMkKfIdFQ9pYCVy8VAr%2BLBA9uSXfFDz6N67ruhh%2BzJWFn1&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20240506T173714Z&X-Amz-SignedHeaders=host&X-Amz-Expires=299&X-Amz-Credential=ASIAQ3PHCVTYUU54IYV3%2F20240506%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=7303563def203481e9802185fa8eab7dd21d58c841cd92621bbbdb18e253f595&hash=62697dff1cc869850b2d38305ff2ad1a35ad33938dbb92466663bdb1bf67d069&host=68042c943591013ac2b2430a89b270f6af2c76d8dfd086a07176afe7c76c2c61&pii=S0378216698001015&tid=spdf-c2371d97-52db-456d-a333-31d65dde09f3&sid=c8e761cb3c79f3499a2a90d699ea19871d47gxrqb&type=client&tsoh=d3d3LnNjaWVuY2VkaXJlY3QuY29t&ua=091359520307070d01&rr=87fabcafde0b0e1b&cc=it
    # TODO: refactor
    self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }

    """self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 0,
      "although": 1,
      "and": 0,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 1,
      "by_comparison": 1,
      "by_contrast": 1,
      "consequently": 0,
      "conversely": 0,
      "especially": 0,
      "further": 0,
      "furthermore": 0,
      "hence": 0,
      "however": 1,
      "in_contrast": 1,
      "instead": 1,
      "likewise": 0,
      "moreover": 0,
      "namely": 0,
      "nevertheless": 1,
      "nonetheless": 1,
      "on_the_contrary": 1,
      "on_the_other_hand": 1,
      "otherwise": 0,
      "rather": 1,
      "similarly": 0,
      "so": 0,
      "still": 1,
      "then": 0,
      "therefore": 0,
      "though": 1,
      "thus": 0,
      "well": 0,
      "yet": 1
    }"""

    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    result = []
    new_dataset = []

    for sample in dataset:
      if self.id_to_word[sample["label"]] not in self.mapping.keys():
        continue

      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[self.id_to_word[sample["label"]]]])

    one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    labels = []

    for i, sample in tqdm(enumerate(new_dataset), desc="processing labels..."):
      labels.append([sample[-1]])

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results..."):
      result.append([f"{name}_{i}", sample[0], sample[1], [], [], [], label])

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-5] != name:
                continue

              story_id = row[0]
              sent = row[2].strip()
              target = row[3].strip()

              label = row[-6]

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 1:
                l = [1,0]
              elif label == 0:
                l = [0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class DebateProcessor(DataProcessor):

  def __init__(self,):
    super(DebateProcessor,self).__init__()


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue

              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              label = row[3].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGProcessor(DataProcessor):

  def __init__(self):
    super(MARGProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue

              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              label = row[3].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1,0]
              elif label == 'neither':
                l = [0,0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class NKProcessor(DataProcessor):

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path, sep="\t")
      for i,row in df.iterrows():
              id_sample = row[0]
              label = row[2]

              sent = row[3].strip()
              target = row[4].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(id_sample)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,1,0]
              elif label == 'no_relation':
                l = [0,0,1]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class DebateWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
        if row[-1] != name:
          continue

        story_id = row[0]
        sent = row[1].strip()
        target = row[2].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        target = target[0].lower() + target[1:]
        target = ds_marker + " " + target

        label = row[3].strip()

        sentences.append(sent)
        target_sentences.append(target)
        id.append(story_id)

        l=[0,0,0]
        if label == 'supports' or label == 'support' or label == 'because':
          l = [1,0,0]
        elif label == 'attacks' or label == 'attack' or label == 'but':
          l = [0,1,0]
        elif label == 'neither':
          l = [0,0,1]

        label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return ctx.lmbd * grad_input.neg(), None

class DoubleAdversarialNet(torch.nn.Module):
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      labels_std = labels[:batch_size // 2, :]

      emb_attack = samples[labels_std == [0,1]]
      emb_support = samples[labels_std == [1,0]]

      samples_adv = H_sent[batch_size // 2:, ]
      labels_adv = labels[batch_size // 2:, :]

      emb_caus = samples_adv[labels_adv == [1,0,0]]
      emb_other = samples_adv[labels_adv == [0,1,0] or labels_adv == [0,0,1]]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl)
      support_prediction = self.support_linear(mean_grl)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if visualize:
      return H_sent
    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)

      return predictions, predictions_adv, task_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  def __init__(self):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)
    #H_sent = att_output[0][:,0,:]

    predictions = self.linear_layer(H_sent)

    return predictions


class BaselineModel(torch.nn.Module):
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (position_sep == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)

    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    tar_mask_sent1 = (segs_sent1 == 1).long()
    tar_mask_sent2 = (segs_sent2 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent2)

    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    set random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="macro")
    recall = recall_score(labels, preds, average="macro")
    f1 = f1_score(labels, preds, average="macro")

    print("{:15}{:<.3f}".format('accuracy:', accuracy))
    print("{:15}{:<.3f}".format('precision:', precision))
    print("{:15}{:<.3f}".format('recall:', recall))
    print("{:15}{:<.3f}".format('f1:', f1))

    return accuracy, precision, recall, f1

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def __iter__(self):
        self.epoch += 1
        batch = []
        indices1 = self.indices1.copy()
        indices2 = self.indices2.copy()

        indices1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        indices2 = torch.randperm(len(self.dataset2)) +len(indices1)
        indices2 = indices2.cpu().tolist()

        for i in range(self.total_size // self.batch_size):
            batch_size1 = min(self.batch_size // 2, len(indices1))
            if batch_size1 < (self.batch_size // 2):
              break
            batch_size2 = self.batch_size - batch_size1
            batch.extend([indices1.pop() for _ in range(batch_size1)])
            batch.extend([indices2.pop() for _ in range(batch_size2)])

            yield batch
            batch = []
            if len(indices1) == 0:
              break

    def __len__(self):
        return (self.total_size + self.batch_size - 1) // self.batch_size

In [10]:
import datetime
import json
from pathlib import Path

def save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed=None, discovery_weight=None, adv_weight=None):
  json_dict = config | args | {
      "discovery_weight": discovery_weight,
      "adv_weight": adv_weight,
      "accuracy": best_test_acc,
      "precision": best_test_pre,
      "recall": best_test_rec,
      "f1": best_test_f1,
      "time": datetime.datetime.now()
  }

  file_path = f"results/{config['dataset']}"
  if config["grid_search"]:
    file_path += "/grid_search"
  else:
    file_path += "/standard"

  if config["double_adversarial"]:
    file_path += "/double_adversarial"
  elif config["adversarial"]:
    file_path += "/adversarial"
  else:
    file_path += "/standard"

  if config["injection"]:
    file_path += "/injection"
  else:
    file_path += "/standard"

  if discovery_weight is not None and adv_weight is not None:
    file_path += f"/discovery_{discovery_weight}/adv_{adv_weight}"

  if seed is not None:
    file_path += f"/{seed}"

  Path(file_path).mkdir(parents=True, exist_ok=True)

  file_path += "/result.json"

  with open(file_path, "w") as outfile:
    json.dump(json_dict, outfile, default=str)

In [11]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

args = {
    "model_name": "roberta-base",
    "num_classes": 2, #3, #2,
    "num_classes_adv": 3, #174,
    "embed_size": 768,
    "first_last_avg": True,
    "seed": [0,1,2],
    "batch_size": 64,
    "epochs": 30,
    "class_weight": 10, #[9.375, 30, 1], #10, [2.071, 1.933, 1]
    "lr": 1e-5
}

config = {
    "dataset": "student_essay", #"student_essay", #debate, m-arg, nk
    "adversarial": False,
    "double_adversarial": False,
    "dataset_from_saved": True,
    "injection": False,
    "grid_search": False,
    "visualize": True,
    "train": True,
}

def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for step, batch in enumerate(tqdm(train_loader, desc='Iteration')):
        if config["adversarial"]:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        else:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch) #tuple(t.to(device) for t in batch)

        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
          pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
          pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)
            attack_target = targets[targets == [0,1]]
            support_target = targets[targets == [1,0]]

            cause_target = targets_adv[targets_adv == [1,0,0]]
            other_target = targets_adv[targets_adv == [0,1,0] or targets_adv == [0,0,1]]
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss4 = loss_fn2(attack_pred, attack_target.float())
          loss5 = loss_fn2(support_pred, support_target.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
          out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          if isinstance(labels, list):
            labels = torch.tensor(np.array(labels)).to(device)
          loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            preds = torch.max(out.data, 1)[1].cpu().numpy().tolist()
            loss = loss_fn(out, labels.float())
            val_loss += loss.item()

            labels = labels.cpu().numpy().tolist()

            val_labels.extend(labels)
            if len(labels[0]) != 2:
              for pred in preds:
                if pred == 0:
                  val_preds.append([1,0,0])
                elif pred == 1:
                  val_preds.append([0,1,0])
                else:
                  val_preds.append([0,0,1])
            else:
              val_preds.extend([[1,0] if pred == 0 else [0,1] for pred in preds])

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1, val_loss

def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2):
  if not config["adversarial"]:
    return

  model.eval()

  loss_fn = nn.CrossEntropyLoss()

  tot_labels = None
  embeddings = None

  tot_labels_adv = None
  embeddings_adv = None

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.argmax(labels, dim=-1)
    if tot_labels is None:
      tot_labels = labels
    else:
      tot_labels = torch.cat([tot_labels, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings is None:
        embeddings = out
      else:
        embeddings = torch.cat([embeddings, out], dim=0)

  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i == 20: break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    labels = torch.argmax(labels, dim=-1)+2

    if tot_labels_adv is None:
      tot_labels_adv = labels
    else:
      tot_labels_adv = torch.cat([tot_labels_adv, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings_adv is None:
        embeddings_adv = out
      else:
        embeddings_adv = torch.cat([embeddings_adv, out], dim=0)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  fig1, ax1 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax1, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()

  fig2, ax2 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne_adv, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax2, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()

def run():
  set_random_seeds(0)
  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/student_essay.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/debate.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay' and 'debate')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")
  if config["dataset"] == "nk":
    data_dev = data_train[:len(data_train) // 10]
    data_test = data_train[-(len(data_train) // 10):]
    data_train = data_train[(len(data_train) // 10) : -(len(data_train) // 10)]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  df = datasets.load_dataset("discovery","discovery")
  adv_processor = DiscourseMarkerProcessor()
  if not config["dataset_from_saved"]:
    print("processing discourse marker dataset...")
    train_adv = adv_processor.process_dataset(df["train"])
    with open("./adv_dataset.pkl", "wb") as writer:
      pickle.dump(train_adv, writer)
  else:
    with open("./adv_dataset.pkl", "rb") as reader:
      train_adv = pickle.load(reader)
  train_set_adv = dataset(train_adv)

  if config["adversarial"]:
    data_train_tot = data_train + train_adv
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if not config["adversarial"]:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv) #loading adv for visualization
    model = BaselineModelWithSentenceComparison()
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=False, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=False, collate_fn=collate_fn_concatenated)

  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"])

  if config["dataset"] == "m-arg" or config["dataset"] == "nk":
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_acc = -1
  best_pre = -1
  best_rec = -1
  best_f1 = -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)

    for discovery_weight in range_disc:
      for adv_weight in range_adv:
        avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []
        for seed in args["seed"]:
          set_random_seeds(seed)
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}, seed = {seed}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
              #save model

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          avg_acc.append(best_test_acc)
          avg_pre.append(best_test_pre)
          avg_rec.append(best_test_rec)
          avg_f1.append(best_test_f1)

          save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          model = BaselineModelWithSentenceComparison()
          model = model.to(device)

          optimizer_grouped_parameters = [
            {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": 0.01,
            },
            {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0
            },
          ]
          optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

          best_acc = -1
          best_pre = -1
          best_rec = -1
          best_f1 = -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

        avg_acc_score = sum(avg_acc) / len(avg_acc)
        avg_pre_score = sum(avg_pre) / len(avg_pre)
        avg_rec_score = sum(avg_rec) / len(avg_rec)
        avg_f1_score = sum(avg_f1) / len(avg_f1)
        print()
        print('avg result:')
        print(avg_acc_score)
        print(avg_pre_score)
        print(avg_rec_score)
        print(avg_f1_score)
        save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score, discovery_weight=discovery_weight, adv_weight=adv_weight)

  else:
    avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []

    for seed in args["seed"]:
      set_random_seeds(seed)
      #scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, min_lr=1e-7)
      for epoch in range(args["epochs"]):
        if config["train"]:
          print('===== Start training: epoch {}, seed = {} ====='.format(epoch + 1, seed))
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.4, adv_weight=1)
          dev_a, dev_p, dev_r, dev_f1, dev_loss = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1, _ = val(model, test_dataloader)
          #scheduler.step(dev_loss)
          if dev_f1 > best_dev_f1:
            best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
            torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

      if config["visualize"] and config["adversarial"]:
        model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
        visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.4, 1)

          #save model

      print('best result:')
      print(best_test_acc)
      print(best_test_pre)
      print(best_test_rec)
      print(best_test_f1)

      avg_acc.append(best_test_acc)
      avg_pre.append(best_test_pre)
      avg_rec.append(best_test_rec)
      avg_f1.append(best_test_f1)

      if not config["adversarial"] and not config["double_adversarial"]:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed)
      else:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, 0.4, 1)

      result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

      del model
      del optimizer

      model = BaselineModelWithSentenceComparison()
      model = model.to(device)

      optimizer_grouped_parameters = [
        {
          "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
          "weight_decay": 0.01,
        },
        {
          "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
          "weight_decay": 0.0
        },
      ]
      optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"])

      best_acc = -1
      best_pre = -1
      best_rec = -1
      best_f1 = -1
      best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  avg_acc_score = sum(avg_acc) / len(avg_acc)
  avg_pre_score = sum(avg_pre) / len(avg_pre)
  avg_rec_score = sum(avg_rec) / len(avg_rec)
  avg_f1_score = sum(avg_f1) / len(avg_f1)

  print()
  print('avg result:')
  print(avg_acc_score)
  print(avg_pre_score)
  print(avg_rec_score)
  print(avg_f1_score)
  save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score)

if __name__ == "__main__":
  run()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2470.87it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2717.89it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2756.04it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Generating train split:   0%|          | 0/1566000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/87000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/87000 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:09<00:00,  5.23it/s]


Timing: 9.17670464515686, Epoch: 1, training loss: 60.01724708080292, current learning rate 1e-05
val loss: 11.90853601694107
accuracy:      0.887
precision:     0.744
recall:        0.511
f1:            0.492
val loss: 11.864756286144257
accuracy:      0.915
precision:     0.602
recall:        0.509
f1:            0.498
===== Start training: epoch 2, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.762176036834717, Epoch: 2, training loss: 58.754374265670776, current learning rate 1e-05
val loss: 13.899328172206879
accuracy:      0.212
precision:     0.549
recall:        0.545
f1:            0.212
val loss: 14.064173400402069
accuracy:      0.225
precision:     0.548
recall:        0.578
f1:            0.223
===== Start training: epoch 3, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.777334451675415, Epoch: 3, training loss: 52.498668015003204, current learning rate 1e-05
val loss: 9.223101526498795
accuracy:      0.785
precision:     0.619
recall:        0.718
f1:            0.634
val loss: 9.088250517845154
accuracy:      0.796
precision:     0.610
recall:        0.764
f1:            0.625
===== Start training: epoch 4, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.7769012451171875, Epoch: 4, training loss: 40.22534966468811, current learning rate 1e-05
val loss: 6.72183221578598
accuracy:      0.847
precision:     0.643
recall:        0.672
f1:            0.655
val loss: 6.475513160228729
accuracy:      0.855
precision:     0.631
recall:        0.731
f1:            0.658
===== Start training: epoch 5, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.778073787689209, Epoch: 5, training loss: 30.788128197193146, current learning rate 1e-05
val loss: 5.775025829672813
accuracy:      0.887
precision:     0.717
recall:        0.692
f1:            0.703
val loss: 5.598540157079697
accuracy:      0.877
precision:     0.641
recall:        0.693
f1:            0.660
===== Start training: epoch 6, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.7817652225494385, Epoch: 6, training loss: 22.456720747053623, current learning rate 1e-05
val loss: 5.817975670099258
accuracy:      0.890
precision:     0.721
recall:        0.649
f1:            0.675
val loss: 4.9467597007751465
accuracy:      0.903
precision:     0.682
recall:        0.687
f1:            0.684
===== Start training: epoch 7, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.779169797897339, Epoch: 7, training loss: 17.76686802506447, current learning rate 1e-05
val loss: 6.37243165075779
accuracy:      0.888
precision:     0.717
recall:        0.669
f1:            0.688
val loss: 5.941023916006088
accuracy:      0.891
precision:     0.666
recall:        0.711
f1:            0.684
===== Start training: epoch 8, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758659362792969, Epoch: 8, training loss: 16.690007582306862, current learning rate 1e-05
val loss: 6.839805327355862
accuracy:      0.894
precision:     0.746
recall:        0.612
f1:            0.644
val loss: 5.701629340648651
accuracy:      0.905
precision:     0.680
recall:        0.659
f1:            0.669
===== Start training: epoch 9, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.74373197555542, Epoch: 9, training loss: 12.165999036282301, current learning rate 1e-05
val loss: 8.130916431546211
accuracy:      0.882
precision:     0.705
recall:        0.695
f1:            0.700
val loss: 8.432803094387054
accuracy:      0.876
precision:     0.651
recall:        0.728
f1:            0.677
===== Start training: epoch 10, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.743674993515015, Epoch: 10, training loss: 9.816847160458565, current learning rate 1e-05
val loss: 9.640310227870941
accuracy:      0.881
precision:     0.702
recall:        0.691
f1:            0.697
val loss: 10.012924954295158
accuracy:      0.874
precision:     0.644
recall:        0.716
f1:            0.669
===== Start training: epoch 11, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.73263955116272, Epoch: 11, training loss: 9.220235079526901, current learning rate 1e-05
val loss: 9.34712141752243
accuracy:      0.883
precision:     0.705
recall:        0.686
f1:            0.695
val loss: 9.57349842786789
accuracy:      0.881
precision:     0.648
recall:        0.700
f1:            0.668
===== Start training: epoch 12, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.753880023956299, Epoch: 12, training loss: 7.761382818222046, current learning rate 1e-05
val loss: 9.43740302324295
accuracy:      0.892
precision:     0.731
recall:        0.648
f1:            0.676
val loss: 8.3152349665761
accuracy:      0.908
precision:     0.692
recall:        0.675
f1:            0.683
===== Start training: epoch 13, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.725041627883911, Epoch: 13, training loss: 5.027443270199001, current learning rate 1e-05
val loss: 9.308263301849365
accuracy:      0.897
precision:     0.755
recall:        0.630
f1:            0.665
val loss: 8.051190588623285
accuracy:      0.914
precision:     0.706
recall:        0.653
f1:            0.674
===== Start training: epoch 14, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.729045152664185, Epoch: 14, training loss: 6.633202098542824, current learning rate 1e-05
val loss: 9.513127014040947
accuracy:      0.895
precision:     0.745
recall:        0.632
f1:            0.665
val loss: 7.968125209212303
accuracy:      0.912
precision:     0.702
recall:        0.667
f1:            0.682
===== Start training: epoch 15, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.726718902587891, Epoch: 15, training loss: 5.152314165607095, current learning rate 1e-05
val loss: 10.266323924064636
accuracy:      0.895
precision:     0.739
recall:        0.686
f1:            0.708
val loss: 9.529293078929186
accuracy:      0.902
precision:     0.681
recall:        0.692
f1:            0.686
===== Start training: epoch 16, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.730913162231445, Epoch: 16, training loss: 3.13763412181288, current learning rate 1e-05
val loss: 10.8578782081604
accuracy:      0.892
precision:     0.731
recall:        0.668
f1:            0.692
val loss: 10.191220354288816
accuracy:      0.900
precision:     0.675
recall:        0.686
f1:            0.680
===== Start training: epoch 17, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.738606929779053, Epoch: 17, training loss: 2.2505381698720157, current learning rate 1e-05
val loss: 11.424101710319519
accuracy:      0.896
precision:     0.753
recall:        0.623
f1:            0.657
val loss: 9.007446881383657
accuracy:      0.918
precision:     0.725
recall:        0.645
f1:            0.674
===== Start training: epoch 18, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.727647304534912, Epoch: 18, training loss: 4.972065823618323, current learning rate 1e-05
val loss: 12.254212886095047
accuracy:      0.883
precision:     0.706
recall:        0.689
f1:            0.697
val loss: 12.716914966702461
accuracy:      0.878
precision:     0.647
recall:        0.709
f1:            0.670
===== Start training: epoch 19, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.7027294635772705, Epoch: 19, training loss: 3.212195174768567, current learning rate 1e-05
val loss: 12.48027616739273
accuracy:      0.890
precision:     0.722
recall:        0.660
f1:            0.683
val loss: 11.299418315291405
accuracy:      0.905
precision:     0.691
recall:        0.699
f1:            0.695
===== Start training: epoch 20, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.724979639053345, Epoch: 20, training loss: 5.2303051352500916, current learning rate 1e-05
val loss: 11.467631816864014
accuracy:      0.896
precision:     0.751
recall:        0.629
f1:            0.663
val loss: 9.743906116113067
accuracy:      0.912
precision:     0.699
recall:        0.652
f1:            0.671
===== Start training: epoch 21, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.718445301055908, Epoch: 21, training loss: 3.763099327683449, current learning rate 1e-05
val loss: 10.803494282066822
accuracy:      0.900
precision:     0.780
recall:        0.625
f1:            0.663
val loss: 8.010510389693081
accuracy:      0.921
precision:     0.739
recall:        0.647
f1:            0.679
===== Start training: epoch 22, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.729457139968872, Epoch: 22, training loss: 2.095219122245908, current learning rate 1e-05
val loss: 12.628704324364662
accuracy:      0.886
precision:     0.711
recall:        0.664
f1:            0.683
val loss: 11.495533972978592
accuracy:      0.903
precision:     0.684
recall:        0.697
f1:            0.690
===== Start training: epoch 23, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.703278541564941, Epoch: 23, training loss: 4.017430038074963, current learning rate 1e-05
val loss: 14.19418577849865
accuracy:      0.887
precision:     0.710
recall:        0.621
f1:            0.648
val loss: 10.767554878606461
accuracy:      0.917
precision:     0.721
recall:        0.655
f1:            0.680
===== Start training: epoch 24, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.734595775604248, Epoch: 24, training loss: 2.3103056233376265, current learning rate 1e-05
val loss: 13.373404279351234
accuracy:      0.889
precision:     0.719
recall:        0.666
f1:            0.687
val loss: 12.931408017873764
accuracy:      0.895
precision:     0.662
recall:        0.673
f1:            0.667
===== Start training: epoch 25, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.720234155654907, Epoch: 25, training loss: 1.9665655562421307, current learning rate 1e-05
val loss: 12.820876106619835
accuracy:      0.889
precision:     0.719
recall:        0.662
f1:            0.684
val loss: 11.2440116815269
accuracy:      0.905
precision:     0.687
recall:        0.684
f1:            0.685
===== Start training: epoch 26, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.7071754932403564, Epoch: 26, training loss: 1.8055163575336337, current learning rate 1e-05
val loss: 14.053208693861961
accuracy:      0.891
precision:     0.728
recall:        0.647
f1:            0.675
val loss: 11.757286624982953
accuracy:      0.913
precision:     0.704
recall:        0.662
f1:            0.680
===== Start training: epoch 27, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.711486577987671, Epoch: 27, training loss: 0.7126195803284645, current learning rate 1e-05
val loss: 13.661476716399193
accuracy:      0.894
precision:     0.739
recall:        0.645
f1:            0.676
val loss: 11.513310767710209
accuracy:      0.906
precision:     0.679
recall:        0.644
f1:            0.659
===== Start training: epoch 28, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.709423303604126, Epoch: 28, training loss: 1.0881318121682853, current learning rate 1e-05
val loss: 13.579483479261398
accuracy:      0.894
precision:     0.738
recall:        0.652
f1:            0.681
val loss: 11.496953392401338
accuracy:      0.911
precision:     0.702
recall:        0.681
f1:            0.691
===== Start training: epoch 29, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.708827972412109, Epoch: 29, training loss: 0.5157859933678992, current learning rate 1e-05
val loss: 14.194737374782562
accuracy:      0.897
precision:     0.772
recall:        0.600
f1:            0.633
val loss: 10.484975457191467
accuracy:      0.917
precision:     0.718
recall:        0.615
f1:            0.646
===== Start training: epoch 30, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.719444274902344, Epoch: 30, training loss: 1.241069829557091, current learning rate 1e-05
val loss: 13.405533701181412
accuracy:      0.891
precision:     0.729
recall:        0.620
f1:            0.651
val loss: 11.722107745707035
accuracy:      0.905
precision:     0.674
recall:        0.639
f1:            0.653
best result:
0.9018181818181819
0.68076183819675
0.6915344318713992
0.6859269912440252


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.764256238937378, Epoch: 1, training loss: 59.885387539863586, current learning rate 1e-05
val loss: 13.638976216316223
accuracy:      0.294
precision:     0.481
recall:        0.464
f1:            0.280
val loss: 13.593404173851013
accuracy:      0.319
precision:     0.527
recall:        0.569
f1:            0.297
===== Start training: epoch 2, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.769012212753296, Epoch: 2, training loss: 59.28627419471741, current learning rate 1e-05
val loss: 13.179257571697235
accuracy:      0.255
precision:     0.538
recall:        0.549
f1:            0.253
val loss: 13.23134994506836
accuracy:      0.259
precision:     0.544
recall:        0.586
f1:            0.252
===== Start training: epoch 3, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.785459041595459, Epoch: 3, training loss: 54.77906596660614, current learning rate 1e-05
val loss: 10.990331292152405
accuracy:      0.660
precision:     0.569
recall:        0.661
f1:            0.541
val loss: 10.687146544456482
accuracy:      0.697
precision:     0.571
recall:        0.710
f1:            0.546
===== Start training: epoch 4, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.762756824493408, Epoch: 4, training loss: 44.80218905210495, current learning rate 1e-05
val loss: 12.243514478206635
accuracy:      0.622
precision:     0.575
recall:        0.683
f1:            0.527
val loss: 11.89810848236084
accuracy:      0.654
precision:     0.575
recall:        0.736
f1:            0.528
===== Start training: epoch 5, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.774267196655273, Epoch: 5, training loss: 38.41899585723877, current learning rate 1e-05
val loss: 7.918611973524094
accuracy:      0.805
precision:     0.603
recall:        0.655
f1:            0.617
val loss: 7.719325423240662
accuracy:      0.811
precision:     0.599
recall:        0.712
f1:            0.616
===== Start training: epoch 6, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7670979499816895, Epoch: 6, training loss: 26.836534202098846, current learning rate 1e-05
val loss: 7.2093715369701385
accuracy:      0.839
precision:     0.626
recall:        0.651
f1:            0.637
val loss: 7.323096752166748
accuracy:      0.840
precision:     0.611
recall:        0.703
f1:            0.633
===== Start training: epoch 7, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.741836071014404, Epoch: 7, training loss: 19.503698587417603, current learning rate 1e-05
val loss: 7.964150041341782
accuracy:      0.840
precision:     0.633
recall:        0.665
f1:            0.646
val loss: 8.420934617519379
accuracy:      0.830
precision:     0.610
recall:        0.717
f1:            0.632
===== Start training: epoch 8, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.754521131515503, Epoch: 8, training loss: 14.402792751789093, current learning rate 1e-05
val loss: 7.087577477097511
accuracy:      0.885
precision:     0.703
recall:        0.620
f1:            0.646
val loss: 6.128156125545502
accuracy:      0.895
precision:     0.653
recall:        0.653
f1:            0.653
===== Start training: epoch 9, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.773741245269775, Epoch: 9, training loss: 12.780022509396076, current learning rate 1e-05
val loss: 8.968096226453781
accuracy:      0.855
precision:     0.648
recall:        0.660
f1:            0.654
val loss: 8.906738460063934
accuracy:      0.864
precision:     0.635
recall:        0.721
f1:            0.662
===== Start training: epoch 10, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.732035398483276, Epoch: 10, training loss: 9.447206109762192, current learning rate 1e-05
val loss: 8.700771763920784
accuracy:      0.881
precision:     0.688
recall:        0.624
f1:            0.646
val loss: 7.257208496332169
accuracy:      0.894
precision:     0.656
recall:        0.667
f1:            0.661
===== Start training: epoch 11, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.7374818325042725, Epoch: 11, training loss: 10.509851498529315, current learning rate 1e-05
val loss: 9.315982267260551
accuracy:      0.877
precision:     0.676
recall:        0.625
f1:            0.644
val loss: 8.102752923965454
accuracy:      0.888
precision:     0.647
recall:        0.669
f1:            0.657
===== Start training: epoch 12, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.731402635574341, Epoch: 12, training loss: 7.472968373447657, current learning rate 1e-05
val loss: 10.78529404103756
accuracy:      0.870
precision:     0.667
recall:        0.642
f1:            0.653
val loss: 9.378426879644394
accuracy:      0.887
precision:     0.655
recall:        0.694
f1:            0.671
===== Start training: epoch 13, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.72523045539856, Epoch: 13, training loss: 6.94464928470552, current learning rate 1e-05
val loss: 8.84047432243824
accuracy:      0.887
precision:     0.707
recall:        0.591
f1:            0.617
val loss: 6.580734422430396
accuracy:      0.914
precision:     0.700
recall:        0.623
f1:            0.649
===== Start training: epoch 14, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.753635406494141, Epoch: 14, training loss: 3.7508119605481625, current learning rate 1e-05
val loss: 11.997075982391834
accuracy:      0.890
precision:     0.723
recall:        0.592
f1:            0.620
val loss: 9.124500252306461
accuracy:      0.912
precision:     0.687
recall:        0.607
f1:            0.632
===== Start training: epoch 15, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.738162279129028, Epoch: 15, training loss: 7.072962507605553, current learning rate 1e-05
val loss: 11.795991338789463
accuracy:      0.880
precision:     0.685
recall:        0.621
f1:            0.642
val loss: 10.606243088841438
accuracy:      0.899
precision:     0.670
recall:        0.675
f1:            0.672
===== Start training: epoch 16, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.735608816146851, Epoch: 16, training loss: 2.0462809903547168, current learning rate 1e-05
val loss: 10.783344954252243
accuracy:      0.883
precision:     0.688
recall:        0.599
f1:            0.623
val loss: 8.765832774341106
accuracy:      0.913
precision:     0.702
recall:        0.652
f1:            0.673
===== Start training: epoch 17, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.752053737640381, Epoch: 17, training loss: 5.871404294855893, current learning rate 1e-05
val loss: 12.611715510487556
accuracy:      0.874
precision:     0.669
recall:        0.624
f1:            0.641
val loss: 9.67492775991559
accuracy:      0.904
precision:     0.680
recall:        0.673
f1:            0.676
===== Start training: epoch 18, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.750597238540649, Epoch: 18, training loss: 2.644122362136841, current learning rate 1e-05
val loss: 12.072107546031475
accuracy:      0.881
precision:     0.682
recall:        0.601
f1:            0.624
val loss: 9.856130072847009
accuracy:      0.905
precision:     0.677
recall:        0.649
f1:            0.661
===== Start training: epoch 19, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.719074010848999, Epoch: 19, training loss: 4.606418888317421, current learning rate 1e-05
val loss: 13.604284420609474
accuracy:      0.881
precision:     0.692
recall:        0.641
f1:            0.661
val loss: 14.07308772020042
accuracy:      0.882
precision:     0.638
recall:        0.671
f1:            0.652
===== Start training: epoch 20, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.7238450050354, Epoch: 20, training loss: 2.610179563984275, current learning rate 1e-05
val loss: 11.6947802901268
accuracy:      0.889
precision:     0.717
recall:        0.612
f1:            0.641
val loss: 9.442595601081848
accuracy:      0.906
precision:     0.670
recall:        0.619
f1:            0.638
===== Start training: epoch 21, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.744377374649048, Epoch: 21, training loss: 2.843957256525755, current learning rate 1e-05
val loss: 12.603010296821594
accuracy:      0.878
precision:     0.690
recall:        0.660
f1:            0.673
val loss: 14.638784795999527
accuracy:      0.863
precision:     0.618
recall:        0.675
f1:            0.638
===== Start training: epoch 22, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.722565174102783, Epoch: 22, training loss: 1.9094976554624736, current learning rate 1e-05
val loss: 10.965507373213768
accuracy:      0.884
precision:     0.695
recall:        0.612
f1:            0.637
val loss: 10.027244675904512
accuracy:      0.906
precision:     0.685
recall:        0.664
f1:            0.673
===== Start training: epoch 23, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.734625339508057, Epoch: 23, training loss: 1.9186947080306709, current learning rate 1e-05
val loss: 14.191590517759323
accuracy:      0.874
precision:     0.672
recall:        0.634
f1:            0.649
val loss: 11.965357145294547
accuracy:      0.899
precision:     0.672
recall:        0.680
f1:            0.676
===== Start training: epoch 24, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.738866329193115, Epoch: 24, training loss: 1.9118014462292194, current learning rate 1e-05
val loss: 12.201076209545135
accuracy:      0.883
precision:     0.689
recall:        0.602
f1:            0.626
val loss: 9.663659229874611
accuracy:      0.910
precision:     0.691
recall:        0.646
f1:            0.664
===== Start training: epoch 25, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.718841791152954, Epoch: 25, training loss: 3.415658882353455, current learning rate 1e-05
val loss: 9.483583822846413
accuracy:      0.889
precision:     0.723
recall:        0.558
f1:            0.576
val loss: 6.8143296567723155
accuracy:      0.917
precision:     0.714
recall:        0.570
f1:            0.596
===== Start training: epoch 26, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.717548370361328, Epoch: 26, training loss: 4.293001828715205, current learning rate 1e-05
val loss: 10.553818419575691
accuracy:      0.889
precision:     0.717
recall:        0.599
f1:            0.627
val loss: 7.748755969107151
accuracy:      0.912
precision:     0.695
recall:        0.637
f1:            0.659
===== Start training: epoch 27, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.71935510635376, Epoch: 27, training loss: 1.6412732250755653, current learning rate 1e-05
val loss: 12.355930916965008
accuracy:      0.888
precision:     0.713
recall:        0.615
f1:            0.643
val loss: 11.000373482704163
accuracy:      0.904
precision:     0.672
recall:        0.648
f1:            0.658
===== Start training: epoch 28, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.714265584945679, Epoch: 28, training loss: 4.8877968937158585, current learning rate 1e-05
val loss: 14.133381262421608
accuracy:      0.880
precision:     0.694
recall:        0.657
f1:            0.673
val loss: 13.468455657362938
accuracy:      0.892
precision:     0.658
recall:        0.681
f1:            0.668
===== Start training: epoch 29, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.7300074100494385, Epoch: 29, training loss: 3.8571314576547593, current learning rate 1e-05
val loss: 12.651291228830814
accuracy:      0.884
precision:     0.698
recall:        0.629
f1:            0.653
val loss: 14.510691955685616
accuracy:      0.888
precision:     0.649
recall:        0.674
f1:            0.660
===== Start training: epoch 30, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.736285448074341, Epoch: 30, training loss: 0.9687041104771197, current learning rate 1e-05
val loss: 11.692759402096272
accuracy:      0.885
precision:     0.696
recall:        0.577
f1:            0.599
val loss: 7.439525252208114
accuracy:      0.919
precision:     0.730
recall:        0.601
f1:            0.634
best result:
0.8627272727272727
0.6182701637801759
0.6752251712608501
0.6375823403758583


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.78272008895874, Epoch: 1, training loss: 59.65850508213043, current learning rate 1e-05
val loss: 13.012920498847961
accuracy:      0.211
precision:     0.521
recall:        0.521
f1:            0.211
val loss: 13.0477956533432
accuracy:      0.213
precision:     0.529
recall:        0.546
f1:            0.210
===== Start training: epoch 2, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.78315544128418, Epoch: 2, training loss: 57.470030188560486, current learning rate 1e-05
val loss: 12.92749035358429
accuracy:      0.414
precision:     0.534
recall:        0.576
f1:            0.380
val loss: 12.642467498779297
accuracy:      0.467
precision:     0.544
recall:        0.640
f1:            0.403
===== Start training: epoch 3, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.17it/s]


Timing: 7.79191780090332, Epoch: 3, training loss: 49.5880309343338, current learning rate 1e-05
val loss: 12.013478755950928
accuracy:      0.588
precision:     0.560
recall:        0.647
f1:            0.498
val loss: 11.880937576293945
accuracy:      0.605
precision:     0.560
recall:        0.694
f1:            0.491
===== Start training: epoch 4, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.7681920528411865, Epoch: 4, training loss: 40.091438829898834, current learning rate 1e-05
val loss: 9.634041219949722
accuracy:      0.734
precision:     0.587
recall:        0.679
f1:            0.588
val loss: 8.897434890270233
accuracy:      0.766
precision:     0.591
recall:        0.733
f1:            0.594
===== Start training: epoch 5, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.769619941711426, Epoch: 5, training loss: 32.634342312812805, current learning rate 1e-05
val loss: 8.512747317552567
accuracy:      0.775
precision:     0.595
recall:        0.669
f1:            0.607
val loss: 8.113027274608612
accuracy:      0.805
precision:     0.605
recall:        0.738
f1:            0.621
===== Start training: epoch 6, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.766463994979858, Epoch: 6, training loss: 22.924811899662018, current learning rate 1e-05
val loss: 7.3826640248298645
accuracy:      0.844
precision:     0.635
recall:        0.657
f1:            0.644
val loss: 7.0912184715271
accuracy:      0.863
precision:     0.642
recall:        0.745
f1:            0.672
===== Start training: epoch 7, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.771495580673218, Epoch: 7, training loss: 25.39758613705635, current learning rate 1e-05
val loss: 6.2982573211193085
accuracy:      0.879
precision:     0.682
recall:        0.620
f1:            0.641
val loss: 5.450599044561386
accuracy:      0.886
precision:     0.640
recall:        0.658
f1:            0.648
===== Start training: epoch 8, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.7732672691345215, Epoch: 8, training loss: 19.465786427259445, current learning rate 1e-05
val loss: 6.891715154051781
accuracy:      0.862
precision:     0.645
recall:        0.627
f1:            0.635
val loss: 6.096097946166992
accuracy:      0.887
precision:     0.653
recall:        0.689
f1:            0.668
===== Start training: epoch 9, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.765026807785034, Epoch: 9, training loss: 15.09749236702919, current learning rate 1e-05
val loss: 7.378200531005859
accuracy:      0.860
precision:     0.659
recall:        0.670
f1:            0.664
val loss: 7.179111301898956
accuracy:      0.876
precision:     0.655
recall:        0.743
f1:            0.684
===== Start training: epoch 10, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.766384840011597, Epoch: 10, training loss: 12.803010545670986, current learning rate 1e-05
val loss: 7.649484619498253
accuracy:      0.877
precision:     0.663
recall:        0.586
f1:            0.606
val loss: 5.979686543345451
accuracy:      0.894
precision:     0.642
recall:        0.632
f1:            0.637
===== Start training: epoch 11, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.768515348434448, Epoch: 11, training loss: 13.752151280641556, current learning rate 1e-05
val loss: 7.620531618595123
accuracy:      0.887
precision:     0.710
recall:        0.628
f1:            0.654
val loss: 5.650551795959473
accuracy:      0.908
precision:     0.692
recall:        0.675
f1:            0.683
===== Start training: epoch 12, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.763516902923584, Epoch: 12, training loss: 11.753960758447647, current learning rate 1e-05
val loss: 8.36925184726715
accuracy:      0.860
precision:     0.655
recall:        0.659
f1:            0.657
val loss: 8.010673880577087
accuracy:      0.875
precision:     0.653
recall:        0.742
f1:            0.682
===== Start training: epoch 13, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.754897594451904, Epoch: 13, training loss: 10.89914059638977, current learning rate 1e-05
val loss: 9.681843843311071
accuracy:      0.889
precision:     0.720
recall:        0.572
f1:            0.594
val loss: 6.962791547179222
accuracy:      0.913
precision:     0.687
recall:        0.593
f1:            0.619
===== Start training: epoch 14, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.18it/s]


Timing: 7.772839784622192, Epoch: 14, training loss: 6.481106850784272, current learning rate 1e-05
val loss: 10.784775257110596
accuracy:      0.877
precision:     0.680
recall:        0.632
f1:            0.651
val loss: 8.403300270438194
accuracy:      0.896
precision:     0.660
recall:        0.664
f1:            0.662
===== Start training: epoch 15, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.756230115890503, Epoch: 15, training loss: 15.799574166536331, current learning rate 1e-05
val loss: 8.158976666629314
accuracy:      0.883
precision:     0.687
recall:        0.595
f1:            0.619
val loss: 6.447913832962513
accuracy:      0.906
precision:     0.679
recall:        0.644
f1:            0.659
===== Start training: epoch 16, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.745354175567627, Epoch: 16, training loss: 4.069811642169952, current learning rate 1e-05
val loss: 11.239507928490639
accuracy:      0.893
precision:     0.744
recall:        0.601
f1:            0.632
val loss: 8.245389819145203
accuracy:      0.910
precision:     0.685
recall:        0.626
f1:            0.648
===== Start training: epoch 17, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758874893188477, Epoch: 17, training loss: 6.391210871748626, current learning rate 1e-05
val loss: 9.49223743379116
accuracy:      0.886
precision:     0.704
recall:        0.601
f1:            0.627
val loss: 7.267166569828987
accuracy:      0.908
precision:     0.689
recall:        0.660
f1:            0.673
===== Start training: epoch 18, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.745042324066162, Epoch: 18, training loss: 4.649138139095157, current learning rate 1e-05
val loss: 11.776720941066742
accuracy:      0.872
precision:     0.667
recall:        0.630
f1:            0.644
val loss: 10.423001572489738
accuracy:      0.889
precision:     0.649
recall:        0.670
f1:            0.658
===== Start training: epoch 19, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.742975234985352, Epoch: 19, training loss: 3.3095650819595903, current learning rate 1e-05
val loss: 11.341169685125351
accuracy:      0.886
precision:     0.702
recall:        0.580
f1:            0.604
val loss: 7.908683195710182
accuracy:      0.908
precision:     0.666
recall:        0.595
f1:            0.617
===== Start training: epoch 20, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.753062725067139, Epoch: 20, training loss: 3.986080849543214, current learning rate 1e-05
val loss: 13.537858307361603
accuracy:      0.867
precision:     0.663
recall:        0.650
f1:            0.656
val loss: 13.883938267827034
accuracy:      0.874
precision:     0.643
recall:        0.711
f1:            0.666
===== Start training: epoch 21, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.763257026672363, Epoch: 21, training loss: 3.7663358286954463, current learning rate 1e-05
val loss: 12.70845414698124
accuracy:      0.882
precision:     0.684
recall:        0.598
f1:            0.622
val loss: 9.683266550302505
accuracy:      0.906
precision:     0.679
recall:        0.644
f1:            0.659
===== Start training: epoch 22, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758371829986572, Epoch: 22, training loss: 1.1116901354398578, current learning rate 1e-05
val loss: 13.481327190995216
accuracy:      0.884
precision:     0.696
recall:        0.600
f1:            0.625
val loss: 9.817309468984604
accuracy:      0.913
precision:     0.702
recall:        0.652
f1:            0.673
===== Start training: epoch 23, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7620649337768555, Epoch: 23, training loss: 3.269559223204851, current learning rate 1e-05
val loss: 9.785929158329964
accuracy:      0.894
precision:     0.743
recall:        0.622
f1:            0.654
val loss: 7.743788838386536
accuracy:      0.910
precision:     0.692
recall:        0.651
f1:            0.668
===== Start training: epoch 24, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7581071853637695, Epoch: 24, training loss: 3.0786057002842426, current learning rate 1e-05
val loss: 11.359957695007324
accuracy:      0.890
precision:     0.721
recall:        0.616
f1:            0.645
val loss: 8.97000078856945
accuracy:      0.909
precision:     0.692
recall:        0.661
f1:            0.674
===== Start training: epoch 25, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.747344017028809, Epoch: 25, training loss: 4.687492637429386, current learning rate 1e-05
val loss: 11.332244456978515
accuracy:      0.892
precision:     0.756
recall:        0.567
f1:            0.589
val loss: 7.372523983474821
accuracy:      0.922
precision:     0.759
recall:        0.587
f1:            0.621
===== Start training: epoch 26, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.758506774902344, Epoch: 26, training loss: 2.4175800882512704, current learning rate 1e-05
val loss: 13.542196094989777
accuracy:      0.889
precision:     0.717
recall:        0.612
f1:            0.641
val loss: 10.677148414310068
accuracy:      0.908
precision:     0.691
recall:        0.670
f1:            0.680
===== Start training: epoch 27, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.762815237045288, Epoch: 27, training loss: 2.6087919435231015, current learning rate 1e-05
val loss: 13.171995863318443
accuracy:      0.887
precision:     0.708
recall:        0.598
f1:            0.624
val loss: 10.122053157072514
accuracy:      0.910
precision:     0.694
recall:        0.656
f1:            0.672
===== Start training: epoch 28, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.753958702087402, Epoch: 28, training loss: 3.9112739388947375, current learning rate 1e-05
val loss: 14.894793063402176
accuracy:      0.884
precision:     0.701
recall:        0.646
f1:            0.667
val loss: 14.055415911599994
accuracy:      0.890
precision:     0.661
recall:        0.700
f1:            0.677
===== Start training: epoch 29, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.745074272155762, Epoch: 29, training loss: 2.824489628896117, current learning rate 1e-05
val loss: 13.79205273091793
accuracy:      0.891
precision:     0.725
recall:        0.626
f1:            0.656
val loss: 11.79666456580162
accuracy:      0.905
precision:     0.682
recall:        0.664
f1:            0.672
===== Start training: epoch 30, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.751009464263916, Epoch: 30, training loss: 2.585636103525758, current learning rate 1e-05
val loss: 15.73978304862976
accuracy:      0.877
precision:     0.674
recall:        0.619
f1:            0.638
val loss: 15.559224050492048
accuracy:      0.886
precision:     0.646
recall:        0.673
f1:            0.658
best result:
0.89
0.6609545836837678
0.7000893061348958
0.6773732217350098


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



avg result:
0.8848484848484849
0.6533288618868979
0.6889496364223818
0.6669608511182977


In [12]:
from google.colab import runtime
runtime.unassign()