<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive (Colab only) so the project folder below is reachable.
import google.colab.drive as drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
# Work from the project folder on Drive; presumably later cells read data
# files via paths relative to it — confirm.
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

# Use the first CUDA device when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Map-style PyTorch Dataset over an in-memory sequence of examples."""

    def __init__(self, examples):
        super().__init__()
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(examples):
    """Batch two-tower examples into ``torch.long`` tensors.

    ``examples`` is a sequence of 7-item records ``(ids_sent1, segs_sent1,
    att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label)``. The
    result is a 7-tuple of stacked tensors in the same field order.
    """
    (ids_sent1, segs_sent1, att_mask_sent1,
     ids_sent2, segs_sent2, att_mask_sent2, labels) = (
        torch.tensor(list(field), dtype=torch.long) for field in zip(*examples)
    )
    return (ids_sent1, segs_sent1, att_mask_sent1,
            ids_sent2, segs_sent2, att_mask_sent2, labels)

def collate_fn_concatenated(examples):
    """Batch single-sequence (concatenated pair) examples into long tensors.

    Each example is ``(ids, segs, att_mask, position_sep, label)``. The result
    is a 5-tuple of ``torch.long`` tensors in that order.
    """
    fields = [list(field) for field in zip(*examples)]
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = fields
    return tuple(
        torch.tensor(values, dtype=torch.long)
        for values in (ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
    )

def collate_fn_concatenated_adv(examples):
    """Like ``collate_fn_concatenated`` but leaves the labels as a plain list.

    Labels are not stacked into a tensor, so examples in one batch may carry
    label vectors of different lengths. Returns ``(ids, segs, att_mask,
    position_sep, labels)`` with the first four as ``torch.long`` tensors.
    """
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))
    tensors = [
        torch.tensor(column, dtype=torch.long)
        for column in (ids_sent1, segs_sent1, att_mask_sent1, position_sep)
    ]
    return (*tensors, labels)

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base class: tokenizes ``(id, sentence1, sentence2, label)`` rows into
  fixed-length model inputs.

  Reads ``args["model_name"]`` from a global ``args`` dict defined elsewhere
  in the notebook. Every token sequence is padded or truncated to
  ``self.max_sent_len`` (150) tokens.
  """

  def __init__(self,):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _get_pad_id(self):
    """Return the vocabulary id of the tokenizer's pad token."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_length(self, ids, pad_id, *aligned):
    """Pad or truncate ``ids`` (and equally long ``aligned`` lists) to ``max_sent_len``.

    ``ids`` is padded with ``pad_id``; every ``aligned`` list is padded with 0.
    Returns ``(ids, att_mask, *aligned)`` where ``att_mask`` is 1 for kept
    tokens and 0 for padding. Inputs of ``max_sent_len`` tokens or more keep
    only their first ``max_sent_len`` tokens, which drops any closing special
    token.
    """
    max_len = self.max_sent_len
    if len(ids) < max_len:
      res = max_len - len(ids)
      att_mask = [1] * len(ids) + [0] * res
      ids = ids + [pad_id] * res
      aligned = [values + [0] * res for values in aligned]
    else:
      att_mask = [1] * max_len
      ids = ids[:max_len]
      aligned = [values[:max_len] for values in aligned]
    return (ids, att_mask, *aligned)

  def _encode_single(self, sentence):
    """Encode one sentence; segment id 0 on the first and last token, 1 in between."""
    ids = self.tokenizer.encode(sentence)
    segs = [0] * len(ids)
    segs[1:-1] = [1] * (len(ids) - 2)
    assert len(ids) == len(segs)
    return ids, segs

  def _get_examples(self, dataset, dataset_type="train"):
    """Tokenize each sentence of every pair on its own (two separate sequences).

    Args:
      dataset: iterable of ``(id, sentence1, sentence2, label)`` rows.
      dataset_type: split name, only used in the progress message.

    Returns:
      A list of ``[ids_sent1, segs_sent1, att_mask_sent1, ids_sent2,
      segs_sent2, att_mask_sent2, label]`` lists, each token list
      padded/truncated to ``self.max_sent_len``.
    """
    examples = []
    # The pad id does not depend on the row, so look it up once.
    pad_id = self._get_pad_id()

    for row in dataset:
      _, sentence1, sentence2, label = row

      ids_sent1, segs_sent1 = self._encode_single(sentence1)
      ids_sent2, segs_sent2 = self._encode_single(sentence2)

      ids_sent1, att_mask_sent1, segs_sent1 = self._fit_to_length(ids_sent1, pad_id, segs_sent1)
      ids_sent2, att_mask_sent2, segs_sent2 = self._fit_to_length(ids_sent2, pad_id, segs_sent2)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode each pair as one ``tokenizer.encode(sentence1, sentence2)`` sequence.

    Args:
      dataset: iterable of ``(id, sentence1, sentence2, label)`` rows.
      dataset_type: split name, only used in the progress message.

    Returns:
      A list of ``[ids, segs, att_mask, position_sep, label]`` lists, each
      token list padded/truncated to ``self.max_sent_len``. ``segs`` is 0
      over sentence 1's share of the sequence and 1 over sentence 2's.
    """
    examples = []
    # The pad id does not depend on the row, so look it up once.
    pad_id = self._get_pad_id()

    # NOTE(review): `tqdm` is not imported in this cell; it must already be in
    # scope from elsewhere in the notebook when this runs.
    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      # The segment split uses the lengths of the stand-alone encodings, so
      # the length assert below only holds when the pair encoding is exactly
      # as long as both together (e.g. a RoBERTa-style `<s> a </s></s> b </s>`).
      segs_sent1 = [0] * sentence1_length + [1] * (sentence2_length)

      # NOTE(review): the list starts as all ones, so only index 0 ends up 0;
      # the other two writes change nothing (they can still raise IndexError
      # on very short sequences). Confirm which positions were meant to be set.
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      ids_sent1, att_mask_sent1, segs_sent1, position_sep = self._fit_to_length(
        ids_sent1, pad_id, segs_sent1, position_sep
      )

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Turns a discourse-marker dataset into relation-classification examples.

  Each sample's integer label is mapped to a marker word via ``id_to_word``
  and then to a relation class via ``mapping``. Samples whose marker has no
  class in ``mapping`` are dropped.
  """

  # Marker vocabulary in label-id order: index i is the word for label id i.
  _MARKER_WORDS = (
    # 0-9
    "no-conn", "absolutely", "accordingly", "actually", "additionally",
    "admittedly", "afterward", "again", "already", "also",
    # 10-19
    "alternately", "alternatively", "although", "altogether", "amazingly",
    "and", "anyway", "apparently", "arguably", "as_a_result",
    # 20-29
    "basically", "because_of_that", "because_of_this", "besides", "but",
    "by_comparison", "by_contrast", "by_doing_this", "by_then", "certainly",
    # 30-39
    "clearly", "coincidentally", "collectively", "consequently", "conversely",
    "curiously", "currently", "elsewhere", "especially", "essentially",
    # 40-49
    "eventually", "evidently", "finally", "first", "firstly",
    "for_example", "for_instance", "fortunately", "frankly", "frequently",
    # 50-59
    "further", "furthermore", "generally", "gradually", "happily",
    "hence", "here", "historically", "honestly", "hopefully",
    # 60-69
    "however", "ideally", "immediately", "importantly", "in_contrast",
    "in_fact", "in_other_words", "in_particular", "in_short", "in_sum",
    # 70-79
    "in_the_end", "in_the_meantime", "in_turn", "incidentally", "increasingly",
    "indeed", "inevitably", "initially", "instead", "interestingly",
    # 80-89
    "ironically", "lastly", "lately", "later", "likewise",
    "locally", "luckily", "maybe", "meaning", "meantime",
    # 90-99
    "meanwhile", "moreover", "mostly", "namely", "nationally",
    "naturally", "nevertheless", "next", "nonetheless", "normally",
    # 100-109
    "notably", "now", "obviously", "occasionally", "oddly",
    "often", "on_the_contrary", "on_the_other_hand", "once", "only",
    # 110-119
    "optionally", "or", "originally", "otherwise", "overall",
    "particularly", "perhaps", "personally", "plus", "preferably",
    # 120-129
    "presently", "presumably", "previously", "probably", "rather",
    "realistically", "really", "recently", "regardless", "remarkably",
    # 130-139
    "sadly", "second", "secondly", "separately", "seriously",
    "significantly", "similarly", "simultaneously", "slowly", "so",
    # 140-149
    "sometimes", "soon", "specifically", "still", "strangely",
    "subsequently", "suddenly", "supposedly", "surely", "surprisingly",
    # 150-159
    "technically", "thankfully", "then", "theoretically", "thereafter",
    "thereby", "therefore", "third", "thirdly", "this",
    # 160-169
    "though", "thus", "together", "traditionally", "truly",
    "truthfully", "typically", "ultimately", "undoubtedly", "unfortunately",
    # 170-173
    "unsurprisingly", "usually", "well", "yet",
  )

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Reference cited for the marker grouping:
    # https://doi.org/10.1016/S0378-2166(98)00101-5
    # Judging by its members, class 0 appears to hold result/inference markers
    # ("hence", "therefore", "thus"), class 1 additive/elaborative ones
    # ("also", "and", "moreover") and class 2 contrastive ones ("but",
    # "however", "yet") — confirm with the author.
    # An earlier binary variant mapped the class-2 markers to 1 and every
    # other marker to 0.
    # TODO: refactor
    self.mapping = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }

    # Label id -> marker word (0 is "no-conn", 173 is "yet").
    self.id_to_word = dict(enumerate(self._MARKER_WORDS))

  def process_dataset(self, dataset, name="train"):
    """Keep samples whose marker has a class, one-hot the class and tokenize.

    Args:
      dataset: iterable of mappings with ``"sentence1"``, ``"sentence2"`` and
        an integer ``"label"`` (a key of ``self.id_to_word``).
      name: split name; used as the example-id prefix (``f"{name}_{i}"``)
        and in progress messages.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    new_dataset = []
    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker not in self.mapping:
        continue
      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[marker]])

    labels = [[sample[-1]] for sample in tqdm(new_dataset, desc="processing labels...")]

    print("one hot encoding...")
    if labels:
      # Pin the category list so every split gets one column per class in
      # `self.mapping`, even when a class is absent from this split.
      one_hot_encoder = OneHotEncoder(
        categories=[sorted(set(self.mapping.values()))],
        handle_unknown="ignore",
        sparse_output=False,
      )
      labels = one_hot_encoder.fit_transform(labels)

    # Rows must have exactly the (id, sentence1, sentence2, label) shape that
    # `_get_examples_concatenated` unpacks.
    result = [
      [f"{name}_{i}", sentence1, sentence2, label]
      for i, ((sentence1, sentence2, _), label) in enumerate(
        tqdm(zip(new_dataset, labels), desc="creating results...")
      )
    ]

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):
  """Reads the student-essay relation CSV and tokenizes the rows of one split.

  Columns are read by position: 0 = id, 2 = source sentence, 3 = target
  sentence, 6th from last = label, 5th from last = split name.
  """

  # Two-way one-hot targets keyed by the raw label value; anything else
  # becomes [0, 0].
  _LABEL_TO_ONE_HOT = {1: (1, 0), 0: (0, 1)}

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file and tokenize it.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sentence_length: unused; kept for interface compatibility.
      name: split to keep, compared against the 5th-from-last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-5] != name:
        continue

      story_id = row[0]
      sent = row[2].strip()
      target = row[3].strip()
      label = row[-6]

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))])

    return self._get_examples_concatenated(result, name)

class DebateProcessor(DataProcessor):
  """Reads a debate support/attack CSV and tokenizes the rows of one split.

  Columns are read by position: 0 = id, 1 = source sentence, 2 = target
  sentence, 3 = relation label, last = split name.
  """

  # Two-way one-hot targets; any other label becomes [0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0), "support": (1, 0), "because": (1, 0),
    "attacks": (0, 1), "attack": (0, 1), "but": (0, 1),
  }

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file and tokenize it.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sentence_length: unused; kept for interface compatibility.
      name: split to keep, compared against the last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = row[2].strip()
      label = row[3].strip()

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))])

    return self._get_examples_concatenated(result, name)


class MARGProcessor(DataProcessor):
  """Reads a support/attack/neither CSV and tokenizes the rows of one split.

  Columns are read by position: 0 = id, 1 = source sentence, 2 = target
  sentence, 3 = relation label, last = split name.
  """

  # Three-way one-hot targets; any other label becomes [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0, 0), "support": (1, 0, 0), "because": (1, 0, 0),
    "attacks": (0, 1, 0), "attack": (0, 1, 0), "but": (0, 1, 0),
    "neither": (0, 0, 1),
  }

  def __init__(self):
    super(MARGProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file and tokenize it.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sent_length: unused; kept for interface compatibility.
      name: split to keep, compared against the last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = row[2].strip()
      label = row[3].strip()

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))])

    return self._get_examples_concatenated(result, name)


class NKProcessor(DataProcessor):
  """Reads a tab-separated support/attack/no-relation file and tokenizes it.

  Every row is used (there is no split-column filter); ``name`` is only
  passed on as the split name for the progress message. Columns are read by
  position: 0 = id, 2 = label, 3 = source sentence, 4 = target sentence.
  """

  # Three-way one-hot targets; any other label becomes [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0, 0), "support": (1, 0, 0), "because": (1, 0, 0),
    "attacks": (0, 1, 0), "attack": (0, 1, 0), "but": (0, 1, 0),
    "no_relation": (0, 0, 1),
  }

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read every row of a tab-separated file and tokenize it.

    Args:
      file_path: path to a single TSV file.
      max_sent_length: unused; kept for interface compatibility.
      name: split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    df = pd.read_csv(file_path, sep="\t")

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      id_sample = row[0]
      label = row[2]
      sent = row[3].strip()
      target = row[4].strip()

      result.append([id_sample, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))])

    return self._get_examples_concatenated(result, name)


class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Reads a support/attack CSV and prepends a predicted discourse marker to
  each target sentence before tokenizing.

  Columns are read by position: 0 = id, 1 = source sentence, 2 = target
  sentence, 3 = relation label, last = split name.
  """

  # Two-way one-hot targets; any other label becomes [0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0), "support": (1, 0), "because": (1, 0),
    "attacks": (0, 1), "attack": (0, 1), "but": (0, 1),
  }

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Predict a marker linking ``sent`` to ``target`` and prepend it to ``target``.

    Underscores in the predicted label become spaces, the label's first letter
    is upper-cased and ``target``'s first letter lower-cased: label
    ``"as_a_result"`` with target ``"The claim holds"`` gives
    ``"As a result the claim holds"``. Empty strings do not raise.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    target = target[:1].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file, inject markers, tokenize.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sentence_length: unused; kept for interface compatibility.
      name: split to keep, compared against the last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = self._inject_marker(sent, row[2].strip())
      label = row[3].strip()

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Reads a debate support/attack CSV and prepends a predicted discourse
  marker to each target sentence before tokenizing.

  Columns are read by position: 0 = id, 1 = source sentence, 2 = target
  sentence, 3 = relation label, last = split name.
  """

  # Two-way one-hot targets; any other label becomes [0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0), "support": (1, 0), "because": (1, 0),
    "attacks": (0, 1), "attack": (0, 1), "but": (0, 1),
  }

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Predict a marker linking ``sent`` to ``target`` and prepend it to ``target``.

    Underscores in the predicted label become spaces, the label's first letter
    is upper-cased and ``target``'s first letter lower-cased: label
    ``"as_a_result"`` with target ``"The claim holds"`` gives
    ``"As a result the claim holds"``. Empty strings do not raise.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    target = target[:1].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file, inject markers, tokenize.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sentence_length: unused; kept for interface compatibility.
      name: split to keep, compared against the last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = self._inject_marker(sent, row[2].strip())
      label = row[3].strip()

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))])

    return self._get_examples_concatenated(result, name)


class MARGWithDiscourseInjectionProcessor(DataProcessor):
  """Reads a support/attack/neither CSV and prepends a predicted discourse
  marker to each target sentence before tokenizing.

  Columns are read by position: 0 = id, 1 = source sentence, 2 = target
  sentence, 3 = relation label, last = split name.
  """

  # Three-way one-hot targets; any other label becomes [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0, 0), "support": (1, 0, 0), "because": (1, 0, 0),
    "attacks": (0, 1, 0), "attack": (0, 1, 0), "but": (0, 1, 0),
    "neither": (0, 0, 1),
  }

  def __init__(self):
    # This class derives from DataProcessor, not MARGProcessor, so it must
    # name itself here; `super(MARGProcessor, self)` raises TypeError.
    super(MARGWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_marker(self, sent, target):
    """Predict a marker linking ``sent`` to ``target`` and prepend it to ``target``.

    Underscores in the predicted label become spaces, the label's first letter
    is upper-cased and ``target``'s first letter lower-cased: label
    ``"as_a_result"`` with target ``"The claim holds"`` gives
    ``"As a result the claim holds"``. Empty strings do not raise.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    target = target[:1].lower() + target[1:]
    return ds_marker + " " + target

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the ``name`` split from a comma-separated file, inject markers, tokenize.

    Args:
      file_path: path to a single CSV file readable by ``pd.read_csv``.
      max_sent_length: unused; kept for interface compatibility.
      name: split to keep, compared against the last column.

    Returns:
      The examples built by ``_get_examples_concatenated``.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    df = pd.read_csv(file_path)

    result = []
    # itertuples rows support positional indexing directly; `row[int]` on the
    # Series yielded by iterrows is deprecated in pandas 2.x.
    for row in df.itertuples(index=False):
      if row[-1] != name:
        continue

      story_id = row[0]
      sent = row[1].strip()
      target = self._inject_marker(sent, row[2].strip())
      label = row[3].strip()

      result.append([story_id, sent, target, list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))])

    return self._get_examples_concatenated(result, name)

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal layer.

    The forward pass is the identity. The backward pass multiplies the incoming
    gradient by ``-lmbd``; ``lmbd`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Stash the scale for backward.
        ctx.lmbd = torch.tensor(lmbd)
        # Return a new view rather than `x` itself.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.lmbd
        return reversed_grad, None

class DoubleAdversarialNet(torch.nn.Module):
  """Pretrained-LM pair classifier with cross-attention pooling.

  Hyper-parameters come from a global ``args`` dict defined elsewhere in the
  notebook (``model_name``, ``num_classes``, ``num_classes_adv``,
  ``embed_size``, ``first_last_avg``). In training mode the batch is split in
  two halves scored by two heads, and three extra heads read a
  gradient-reversed mean token embedding.
  """
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Swap in a freshly initialized token-type table with 4 rows (ids 0-3).
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # Fine-tune every encoder parameter.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    # Keep this order: it fixes the RNG draws (initial weights) for a seed.
    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # NOTE(review): MultiheadAttention is not Linear/Embedding/LayerNorm, so
    # `_init_weights` leaves it untouched.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights (Linear/Embedding ~ N(0, initializer_range),
    Linear bias 0, LayerNorm weight 1 / bias 0); other modules are skipped."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """Encode concatenated pairs and score them.

    Args:
      ids_sent1, segs_sent1, att_mask_sent1: token ids, segment ids (0 for
        sentence 1, 1 for sentence 2) and attention mask of the concatenated
        pair; presumably (batch, seq) long tensors — confirm.
      position_sep, labels: accepted for interface compatibility; not used.

    Returns:
      In training mode ``(predictions, predictions_adv, task_prediction,
      attack_prediction, support_prediction)``; otherwise ``predictions``
      for the whole batch.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): the data processors pad `segs` with 0, so padding falls in
    # the sentence-1 mask, and no key_padding_mask reaches the attention
    # below — confirm this is intended.
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    # Queries come from the sentence-2-masked embeddings, keys/values from
    # the sentence-1-masked ones.
    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # Average the attention output over the sequence axis: one vector per pair.
    H_sent = torch.mean(att_output[0], dim=1)

    if not self.training:
      return self.linear_layer(H_sent)

    # NOTE(review): assumes the first half of the batch holds main-task
    # samples and the second half auxiliary-task samples — confirm against
    # the training loop.
    half = H_sent.shape[0] // 2
    predictions = self.linear_layer(H_sent[:half])
    predictions_adv = self.linear_layer_adv(H_sent[half:])

    mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
    task_prediction = self.task_linear(mean_grl)
    attack_prediction = self.attack_linear(mean_grl)
    support_prediction = self.support_linear(mean_grl)

    return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction

class AdversarialNet(torch.nn.Module):
  """Pair classifier with an auxiliary (adversarial) head and a domain head.

  The pretrained encoder ``args["model_name"]`` embeds the concatenated pair. Queries
  are projected from the token-type-1 positions and keys/values from the token-type-0
  positions; every other position is zeroed before projection. The multi-head attention
  output (8 heads) is mean-pooled into one vector per example.

  In training mode:
    * the first half of the batch goes to ``linear_layer``;
    * the second half goes to ``linear_layer_adv``;
    * ``task_linear`` scores the GRLayer-transformed mean token embedding of the whole batch.
  """

  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    plm_config = self.plm.config
    # Widen the token-type vocabulary to 4 ids with a freshly initialised embedding table.
    plm_config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    self.plm.requires_grad_(True)

    # Construction order matters: each layer draws its default init from the global RNG.
    width = self.embed_size
    self.linear_layer = torch.nn.Linear(width, self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(width, self.num_classes_adv)
    self.task_linear = torch.nn.Linear(width, 2)

    self.multi_head_att = torch.nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = torch.nn.Linear(width, width)
    self.K = torch.nn.Linear(width, width)
    self.V = torch.nn.Linear(width, width)

    # MultiheadAttention matches none of _init_weights' types, so that call is a no-op.
    for module in (self.linear_layer, self.linear_layer_adv, self.Q, self.K, self.V, self.multi_head_att, self.task_linear):
      self._init_weights(module)

  def _init_weights(self, module):
    """Re-initialise ``module`` in place.

    Linear/Embedding weights are drawn from N(0, initializer_range) and Linear biases
    are zeroed. LayerNorm gets weight 1 and bias 0. Other module types are left as-is.
    """
    if isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
      return
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      return
    module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """Return pooled embeddings (``visualize``), eval logits, or the three training heads.

    ``position_sep`` is accepted for interface compatibility and is not used.
    """
    hidden_states = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True).hidden_states

    if self.first_last_avg:
      tokens = torch.div((hidden_states[-1] + hidden_states[1]), 2)
    else:
      tokens = hidden_states[-1]

    first_part = torch.mul((segs_sent1 == 0).long().unsqueeze(2), tokens)
    second_part = torch.mul((segs_sent1 == 1).long().unsqueeze(2), tokens)

    keys = self.K(first_part)
    values = self.V(first_part)
    queries = self.Q(second_part)
    attended, _ = self.multi_head_att(queries, keys, values)

    pooled = torch.mean(attended, dim=1)

    if visualize:
      return pooled
    if not self.training:
      return self.linear_layer(pooled)

    split = pooled.shape[0] // 2
    predictions = self.linear_layer(pooled[:split, :])
    predictions_adv = self.linear_layer_adv(pooled[split:, ])

    reversed_mean = GRLayer.apply(torch.mean(tokens, dim=1), .01)
    task_prediction = self.task_linear(reversed_mean)

    return predictions, predictions_adv, task_prediction


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Pair classifier that lets the second sentence attend to the first.

  The pretrained encoder ``args["model_name"]`` embeds the concatenated pair. Queries
  are projected from the token-type-1 positions and keys/values from the token-type-0
  positions; every other position is zeroed before projection. The multi-head attention
  output (8 heads) is mean-pooled and fed to a linear classifier.
  """

  def __init__(self):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    plm_config = self.plm.config
    # Widen the token-type vocabulary to 4 ids with a freshly initialised embedding table.
    plm_config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune the whole encoder.
    self.plm.requires_grad_(True)

    # Construction order matters: each layer draws its default init from the global RNG.
    width = self.embed_size
    self.linear_layer = torch.nn.Linear(width, args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = torch.nn.Linear(width, width)
    self.K = torch.nn.Linear(width, width)
    self.V = torch.nn.Linear(width, width)

    # MultiheadAttention matches none of _init_weights' types, so that call is a no-op.
    for module in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
      self._init_weights(module)

  def _init_weights(self, module):
    """Re-initialise ``module`` in place.

    Linear/Embedding weights are drawn from N(0, initializer_range) and Linear biases
    are zeroed. LayerNorm gets weight 1 and bias 0. Other module types are left as-is.
    """
    if isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
      return
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      return
    module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for each pair; ``position_sep`` is accepted but unused."""
    hidden_states = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True).hidden_states

    if self.first_last_avg:
      tokens = torch.div((hidden_states[-1] + hidden_states[1]), 2)
    else:
      tokens = hidden_states[-1]

    first_part = torch.mul((segs_sent1 == 0).long().unsqueeze(2), tokens)
    second_part = torch.mul((segs_sent1 == 1).long().unsqueeze(2), tokens)

    keys = self.K(first_part)
    values = self.V(first_part)
    queries = self.Q(second_part)
    attended, _ = self.multi_head_att(queries, keys, values)

    # Mean over all positions of the attention output.
    pooled = torch.mean(attended, dim=1)
    return self.linear_layer(pooled)


class BaselineModel(torch.nn.Module):
  """Plain classifier: mean-pooled token embeddings of the encoder -> linear layer.

  The encoder ``args["model_name"]`` gets a freshly initialised token-type embedding
  table with 4 ids. The sequence representation is either the last hidden layer or,
  with ``args["first_last_avg"]``, the average of the last and first layer outputs.
  """

  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Presumably widened so the tokenisation can emit segment ids 0-3 — TODO confirm.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights: N(0, initializer_range) for Linear/Embedding, 1/0 for LayerNorm, zero Linear bias."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for each input sequence.

    ``position_sep`` is accepted for interface compatibility with the other models and
    is not used. (The previous version built an unused mask from it.)
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # Unmasked mean over every position (padding included).
    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  """Siamese classifier: one shared encoder embeds each sentence separately.

  Each sentence is mean-pooled, the two vectors are concatenated
  (``2 * embed_size`` features), and a linear layer produces the class logits.
  """

  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Presumably widened so the tokenisation can emit segment ids 0-3 — TODO confirm.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights: N(0, initializer_range) for Linear/Embedding, 1/0 for LayerNorm, zero Linear bias."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """Return class logits from the concatenated mean embeddings of both sentences."""
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    # Unmasked means over every position (padding included). The token-type masks the
    # previous version computed were never used and have been dropped.
    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Also asks cuDNN for deterministic kernels.

    :param seed: integer seed applied to every generator
    """
    # Local import: the visible `import random` lives in a later notebook cell.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per input shape and may pick non-deterministic
    # algorithms, which would defeat the flag above.
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def output_metrics(labels, preds):
    """Score ``preds`` against ``labels`` and print the results.

    Precision, recall and F1 are macro-averaged. Each metric is printed as a
    15-character name column followed by the value to three decimals.

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1 = (
        scorer(labels, preds, average="macro")
        for scorer in (precision_score, recall_score, f1_score)
    )

    report = (("accuracy:", accuracy), ("precision:", precision), ("recall:", recall), ("f1:", f1))
    for name, value in report:
        print("{:15}{:<.3f}".format(name, value))

    return accuracy, precision, recall, f1

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler whose batches mix ``dataset1`` and ``dataset2`` half and half.

    Yielded indices address the concatenation ``dataset1 + dataset2``:
      * the first ``batch_size // 2`` entries of each batch are ``dataset1`` positions;
      * the remaining ``batch_size - batch_size // 2`` entries are ``dataset2`` positions,
        offset by ``len(dataset1)``.

    Every call to ``__iter__`` draws fresh ``torch.randperm`` permutations.
    ``__len__`` reports exactly the number of batches ``__iter__`` yields. Iteration
    stops as soon as either dataset can no longer fill its share, or after
    ``(len(dataset1) + len(dataset2)) // batch_size`` batches.

    :raises ValueError: if ``batch_size < 2`` (a batch could not contain both sources)
    """

    def __init__(self, dataset1, dataset2, batch_size):
        if batch_size < 2:
            raise ValueError(f"batch_size must be at least 2 to draw from both datasets, got {batch_size}")
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _num_batches(self):
        """Number of full mixed batches one pass of ``__iter__`` yields."""
        half = self.batch_size // 2
        rest = self.batch_size - half
        return min(
            self.total_size // self.batch_size,
            len(self.dataset1) // half,
            len(self.dataset2) // rest,
        )

    def __iter__(self):
        self.epoch += 1
        # Permute dataset1 first, then dataset2 (same RNG draw order as before), and shift
        # dataset2 past dataset1's block of the concatenated dataset.
        indices1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        indices2 = (torch.randperm(len(self.dataset2)) + len(self.dataset1)).cpu().tolist()

        half = self.batch_size // 2
        rest = self.batch_size - half
        for _ in range(self._num_batches()):
            # pop() consumes each permutation from its end.
            batch = [indices1.pop() for _ in range(half)]
            batch.extend(indices2.pop() for _ in range(rest))
            yield batch

    def __len__(self):
        return self._num_batches()

In [10]:
import datetime
import json
from pathlib import Path

def save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed=None, discovery_weight=None, adv_weight=None):
  """Dump the run configuration and test metrics to a ``result.json`` file.

  Payload: ``config`` merged with ``args`` (``args`` wins on key clashes), then the two
  loss weights, the four metrics and the current timestamp (serialised via ``str``).

  Location:
  ``results/<dataset>/<grid_search|standard>/<double_adversarial|adversarial|standard>/<injection|standard>``.
  ``/discovery_<w>/adv_<w>`` is appended when both weights are given, and ``/<seed>``
  when a seed is given. Missing directories are created.
  """
  payload = {**config, **args}
  payload.update({
      "discovery_weight": discovery_weight,
      "adv_weight": adv_weight,
      "accuracy": best_test_acc,
      "precision": best_test_pre,
      "recall": best_test_rec,
      "f1": best_test_f1,
      "time": datetime.datetime.now(),
  })

  segments = [f"results/{config['dataset']}"]
  segments.append("grid_search" if config["grid_search"] else "standard")
  if config["double_adversarial"]:
    segments.append("double_adversarial")
  elif config["adversarial"]:
    segments.append("adversarial")
  else:
    segments.append("standard")
  segments.append("injection" if config["injection"] else "standard")

  if discovery_weight is not None and adv_weight is not None:
    segments.extend([f"discovery_{discovery_weight}", f"adv_{adv_weight}"])
  if seed is not None:
    segments.append(f"{seed}")

  result_dir = "/".join(segments)
  Path(result_dir).mkdir(parents=True, exist_ok=True)

  with open(result_dir + "/result.json", "w") as outfile:
    json.dump(payload, outfile, default=str)

In [11]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tqdm import tqdm

# Model / optimisation hyper-parameters.
args = dict(
    model_name="roberta-base",
    num_classes=2,          # main-task classes (3 was also tried)
    num_classes_adv=3,      # output width of the adversarial head (174 was also tried)
    embed_size=768,         # must equal the encoder's hidden size
    first_last_avg=True,    # average last and first hidden layers instead of last only
    seed=[0, 1, 2],         # one full training run per seed
    batch_size=64,
    epochs=30,              # per seed
    # Non-m-arg/nk datasets use weight [1, class_weight]. m-arg and nk pass this value
    # straight to CrossEntropyLoss, so it must be a per-class list there
    # (e.g. [9.375, 30, 1] or [2.071, 1.933, 1]).
    class_weight=10,
    lr=1e-5,
)

# Experiment switches.
config = dict(
    dataset="student_essay",   # one of: student_essay, debate, m-arg, nk
    adversarial=False,         # AdversarialNet + BalancedSampler mixing in discourse-marker data
    double_adversarial=False,  # extra attack/support heads in the training loss
    dataset_from_saved=True,   # load ./adv_dataset.pkl instead of re-processing discovery
    injection=False,           # use the *WithDiscourseInjection processors
    grid_search=False,         # sweep discovery/adversarial loss weights
    visualize=True,            # t-SNE plots after each seed (adversarial runs only)
    train=True,                # run the training epochs
)

def _adversarial_targets(labels):
    """Turn the labels of a mixed batch into (main, adversarial, domain) target tensors.

    The first ``len(labels) // 2`` rows are treated as main-task examples and the rest
    as discourse-marker examples, mirroring how the adversarial models split the batch.

    Domain targets are ``[0, 1]`` for the first group and ``[1, 0]`` for the second. The
    second group has ``len(labels) - half`` rows, so odd batch sizes still match the
    model's per-example domain predictions.
    """
    half_batch_size = len(labels) // 2
    domain = [[0, 1]] * half_batch_size + [[1, 0]] * (len(labels) - half_batch_size)
    targets = torch.tensor(np.array(labels[:half_batch_size])).to(device)
    targets_adv = torch.tensor(np.array(labels[half_batch_size:])).to(device)
    targets_task = torch.tensor(np.array(domain)).to(device)
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    """Train ``model`` for one epoch, then print timing, summed loss and learning rate.

    The objective depends on the global ``config``:
      * ``adversarial``: ``loss_fn`` on the main-task half, plus ``discovery_weight`` x
        cross-entropy of the adversarial head, plus ``adv_weight`` x cross-entropy of the
        domain head.
      * ``double_adversarial``: as above, plus 0.2 x each of the attack/support head losses.
      * otherwise: ``loss_fn`` on the whole batch.

    Gradients are clipped to norm 1.0 before every optimizer step. ``scheduler`` (if
    given) is stepped after every optimizer step. ``epoch`` is used only for logging.
    Errors while building targets now propagate instead of being printed and ignored.
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Tensors go to the device; fields that arrive as Python lists (e.g. ragged labels) stay lists.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            targets, targets_adv, targets_task = _adversarial_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _adversarial_targets(labels)
            # NOTE(review): comparing a tensor with a Python list does not obviously give a
            # per-row mask, so these selections may not pick the attack/support rows as
            # intended, and their sizes may not match attack_pred/support_pred.
            # TODO confirm the intended targets for the attack/support heads.
            attack_target = targets[targets == [0,1]]
            support_target = targets[targets == [1,0]]

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = loss_fn2(attack_pred, attack_target.float())
            loss5 = loss_fn2(support_pred, support_target.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if isinstance(labels, list):
                labels = torch.tensor(np.array(labels)).to(device)
            loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader``; print the loss and metrics.

    Labels are expected as one-hot tensors. Predictions (argmax over the logits) are
    converted to one-hot lists of the same width as the labels, so any number of
    classes is supported (previously only 2 or 3).

    Returns:
        ``(accuracy, precision, recall, f1, val_loss)``, where ``val_loss`` is the sum of
        the per-batch unweighted cross-entropy losses.
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            preds = out.argmax(dim=1).cpu().tolist()
            val_loss += loss_fn(out, labels.float()).item()

        labels = labels.cpu().tolist()
        val_labels.extend(labels)

        num_classes = len(labels[0])
        val_preds.extend([[int(cls == pred) for cls in range(num_classes)] for pred in preds])

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1, val_loss

def _plot_tsne(df_points, discovery_weight, adv_weight):
  """Scatter-plot 2-D t-SNE points from ``df_points`` (columns x, y, label), coloured by label."""
  # Style first, so it also applies to the axes created just below.
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  fig, ax = plt.subplots(figsize=(8,6))
  sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}')
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches=20):
  """Plot t-SNE projections of pooled pair embeddings.

  Two figures are produced: one for the whole test set and one for the first
  ``max_adv_batches`` discourse-marker batches. Nothing happens unless
  ``config["adversarial"]`` is set.

  Requires ``model(..., visualize=True)`` to return one embedding per example, as
  ``AdversarialNet`` does. ``epoch`` is unused and kept for call compatibility; the
  weights only appear in the plot titles.
  """
  if not config["adversarial"]:
    return

  # Local import so this cell does not rely on pandas being imported elsewhere.
  import pandas as pd

  model.eval()

  print("Visualizing...")
  test_labels, test_embeddings = [], []
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    test_labels.append(torch.argmax(labels, dim=-1))
    with torch.no_grad():
      test_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  adv_labels, adv_embeddings = [], []
  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i >= max_adv_batches:
      break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    # Offset by 2, presumably to keep these ids distinct from the main-task ids 0/1.
    adv_labels.append(torch.argmax(labels, dim=-1) + 2)
    with torch.no_grad():
      adv_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  # Concatenate once instead of growing a tensor batch by batch.
  embeddings = torch.cat(test_embeddings, dim=0)
  embeddings_adv = torch.cat(adv_embeddings, dim=0)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = torch.cat(test_labels, dim=0).cpu().numpy()
  df_tsne_adv["label"] = torch.cat(adv_labels, dim=0).cpu().numpy()

  print(df_tsne_adv["label"].unique())

  _plot_tsne(df_tsne, discovery_weight, adv_weight)
  _plot_tsne(df_tsne_adv, discovery_weight, adv_weight)

def run():
  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/student_essay.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/debate.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay' and 'debate')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")
  if config["dataset"] == "nk":
    data_dev = data_train[:len(data_train) // 10]
    data_test = data_train[-(len(data_train) // 10):]
    data_train = data_train[(len(data_train) // 10) : -(len(data_train) // 10)]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  df = datasets.load_dataset("discovery","discovery")
  adv_processor = DiscourseMarkerProcessor()
  if not config["dataset_from_saved"]:
    print("processing discourse marker dataset...")
    train_adv = adv_processor.process_dataset(df["train"])
    with open("./adv_dataset.pkl", "wb") as writer:
      pickle.dump(train_adv, writer)
  else:
    with open("./adv_dataset.pkl", "rb") as reader:
      train_adv = pickle.load(reader)
  train_set_adv = dataset(train_adv)

  if config["adversarial"]:
    data_train_tot = data_train + train_adv
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if not config["adversarial"]:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv) #loading adv for visualization
    model = BaselineModelWithSentenceComparison()
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

  if config["dataset"] == "m-arg" or config["dataset"] == "nk":
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_acc = -1
  best_pre = -1
  best_rec = -1
  best_f1 = -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)

    for discovery_weight in range_disc:
      for adv_weight in range_adv:
        avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []
        for seed in args["seed"]:
          set_random_seeds(seed)
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}, seed = {seed}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
              #save model

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          avg_acc.append(best_test_acc)
          avg_pre.append(best_test_pre)
          avg_rec.append(best_test_rec)
          avg_f1.append(best_test_f1)

          save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          model = AdversarialNet()
          model = model.to(device)

          optimizer_grouped_parameters = [
            {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": 0.01,
            },
            {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0
            },
          ]
          optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

          best_acc = -1
          best_pre = -1
          best_rec = -1
          best_f1 = -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

        avg_acc_score = sum(avg_acc) / len(avg_acc)
        avg_pre_score = sum(avg_pre) / len(avg_pre)
        avg_rec_score = sum(avg_rec) / len(avg_rec)
        avg_f1_score = sum(avg_f1) / len(avg_f1)
        print()
        print('avg result:')
        print(avg_acc_score)
        print(avg_pre_score)
        print(avg_rec_score)
        print(avg_f1_score)
        save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score, discovery_weight=discovery_weight, adv_weight=adv_weight)

  else:
    avg_acc, avg_pre, avg_rec, avg_f1 = [], [], [], []

    for seed in args["seed"]:
      set_random_seeds(seed)

      #scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, min_lr=1e-7)
      for epoch in range(args["epochs"]):
        if config["train"]:
          print('===== Start training: epoch {}, seed = {} ====='.format(epoch + 1, seed))
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
          dev_a, dev_p, dev_r, dev_f1, dev_loss = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1, _ = val(model, test_dataloader)
          #scheduler.step(dev_loss)
          if dev_f1 > best_dev_f1:
            best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
            torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

      if config["visualize"] and config["adversarial"]:
        model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
        visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

          #save model

      print('best result:')
      print(best_test_acc)
      print(best_test_pre)
      print(best_test_rec)
      print(best_test_f1)

      avg_acc.append(best_test_acc)
      avg_pre.append(best_test_pre)
      avg_rec.append(best_test_rec)
      avg_f1.append(best_test_f1)

      if not config["adversarial"] and not config["double_adversarial"]:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed)
      else:
        save_result(config, args, best_test_acc, best_test_pre, best_test_rec, best_test_f1, seed, discovery_weight, adv_weight)

      result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

      del model
      del optimizer

      model = BaselineModelWithSentenceComparison()
      model = model.to(device)

      optimizer_grouped_parameters = [
        {
          "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
          "weight_decay": 0.01,
        },
        {
          "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
          "weight_decay": 0.0
        },
      ]
      optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"], weight_decay=1e-2)

      best_acc = -1
      best_pre = -1
      best_rec = -1
      best_f1 = -1
      best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  avg_acc_score = sum(avg_acc) / len(avg_acc)
  avg_pre_score = sum(avg_pre) / len(avg_pre)
  avg_rec_score = sum(avg_rec) / len(avg_rec)
  avg_f1_score = sum(avg_f1) / len(avg_f1)

  print()
  print('avg result:')
  print(avg_acc_score)
  print(avg_pre_score)
  print(avg_rec_score)
  print(avg_f1_score)
  save_result(config, args, avg_acc_score, avg_pre_score, avg_rec_score, avg_f1_score)

# Script entry point: launch the multi-seed training / evaluation loop.
# When this cell runs inside the notebook kernel the guard is satisfied
# (the training log captured after this cell is produced by this call);
# importing the code as a module instead would skip the run.
_is_entry_point = "__main__" == __name__
if _is_entry_point:
  run()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3070/3070 [00:01<00:00, 2484.39it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 1142/1142 [00:00<00:00, 2465.84it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 1100/1100 [00:00<00:00, 2972.63it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:08<00:00,  5.61it/s]


Timing: 8.556735515594482, Epoch: 1, training loss: 59.95030331611633, current learning rate 1e-05
val loss: 11.744002997875214
accuracy:      0.895
precision:     0.947
recall:        0.538
f1:            0.543
val loss: 11.659599304199219
accuracy:      0.917
precision:     0.710
recall:        0.520
f1:            0.519
===== Start training: epoch 2, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.727200508117676, Epoch: 2, training loss: 57.683807015419006, current learning rate 1e-05
val loss: 12.326111972332
accuracy:      0.543
precision:     0.559
recall:        0.645
f1:            0.472
val loss: 12.033414244651794
accuracy:      0.580
precision:     0.557
recall:        0.686
f1:            0.476
===== Start training: epoch 3, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.742282390594482, Epoch: 3, training loss: 48.357213109731674, current learning rate 1e-05
val loss: 7.3124812841415405
accuracy:      0.843
precision:     0.632
recall:        0.653
f1:            0.641
val loss: 7.281821578741074
accuracy:      0.839
precision:     0.587
recall:        0.642
f1:            0.602
===== Start training: epoch 4, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.7208027839660645, Epoch: 4, training loss: 38.52971750497818, current learning rate 1e-05
val loss: 11.740790963172913
accuracy:      0.683
precision:     0.596
recall:        0.724
f1:            0.574
val loss: 11.439718246459961
accuracy:      0.673
precision:     0.576
recall:        0.737
f1:            0.539
===== Start training: epoch 5, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.73745322227478, Epoch: 5, training loss: 31.06693947315216, current learning rate 1e-05
val loss: 7.062482863664627
accuracy:      0.838
precision:     0.642
recall:        0.691
f1:            0.659
val loss: 7.806798458099365
accuracy:      0.817
precision:     0.605
recall:        0.720
f1:            0.624
===== Start training: epoch 6, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.723234415054321, Epoch: 6, training loss: 26.377491652965546, current learning rate 1e-05
val loss: 7.22581684589386
accuracy:      0.849
precision:     0.660
recall:        0.711
f1:            0.679
val loss: 7.827081859111786
accuracy:      0.838
precision:     0.618
recall:        0.727
f1:            0.642
===== Start training: epoch 7, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.723796606063843, Epoch: 7, training loss: 18.255187302827835, current learning rate 1e-05
val loss: 8.092838644981384
accuracy:      0.832
precision:     0.640
recall:        0.701
f1:            0.660
val loss: 9.088633716106415
accuracy:      0.825
precision:     0.613
recall:        0.734
f1:            0.635
===== Start training: epoch 8, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.734256744384766, Epoch: 8, training loss: 18.095314145088196, current learning rate 1e-05
val loss: 7.0003355741500854
accuracy:      0.884
precision:     0.697
recall:        0.606
f1:            0.632
val loss: 5.893680542707443
accuracy:      0.901
precision:     0.658
recall:        0.631
f1:            0.643
===== Start training: epoch 9, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.685365676879883, Epoch: 9, training loss: 11.810851141810417, current learning rate 1e-05
val loss: 7.948335826396942
accuracy:      0.888
precision:     0.715
recall:        0.648
f1:            0.673
val loss: 7.790300220251083
accuracy:      0.895
precision:     0.657
recall:        0.658
f1:            0.657
===== Start training: epoch 10, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.6974968910217285, Epoch: 10, training loss: 12.425726853311062, current learning rate 1e-05
val loss: 9.884997606277466
accuracy:      0.845
precision:     0.656
recall:        0.711
f1:            0.676
val loss: 10.47197937965393
accuracy:      0.842
precision:     0.609
recall:        0.694
f1:            0.630
===== Start training: epoch 11, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.697664260864258, Epoch: 11, training loss: 12.85847869515419, current learning rate 1e-05
val loss: 9.056925624608994
accuracy:      0.886
precision:     0.709
recall:        0.644
f1:            0.668
val loss: 8.923342660069466
accuracy:      0.885
precision:     0.634
recall:        0.648
f1:            0.641
===== Start training: epoch 12, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.693908452987671, Epoch: 12, training loss: 8.609022948890924, current learning rate 1e-05
val loss: 8.590531587600708
accuracy:      0.877
precision:     0.687
recall:        0.656
f1:            0.669
val loss: 7.697080602869391
accuracy:      0.888
precision:     0.645
recall:        0.664
f1:            0.654
===== Start training: epoch 13, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.703744888305664, Epoch: 13, training loss: 9.209410443902016, current learning rate 1e-05
val loss: 10.153603509068489
accuracy:      0.883
precision:     0.696
recall:        0.635
f1:            0.657
val loss: 8.811857752501965
accuracy:      0.895
precision:     0.651
recall:        0.648
f1:            0.649
===== Start training: epoch 14, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.684369325637817, Epoch: 14, training loss: 5.282062360085547, current learning rate 1e-05
val loss: 8.775612279772758
accuracy:      0.896
precision:     0.755
recall:        0.616
f1:            0.650
val loss: 8.03864273428917
accuracy:      0.905
precision:     0.667
recall:        0.619
f1:            0.637
===== Start training: epoch 15, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.689729928970337, Epoch: 15, training loss: 6.187498518731445, current learning rate 1e-05
val loss: 11.222885340452194
accuracy:      0.876
precision:     0.683
recall:        0.655
f1:            0.667
val loss: 10.934126734733582
accuracy:      0.881
precision:     0.633
recall:        0.660
f1:            0.645
===== Start training: epoch 16, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.698151350021362, Epoch: 16, training loss: 8.489319146610796, current learning rate 1e-05
val loss: 10.099225893616676
accuracy:      0.889
precision:     0.718
recall:        0.649
f1:            0.674
val loss: 9.403144290205091
accuracy:      0.888
precision:     0.633
recall:        0.634
f1:            0.633
===== Start training: epoch 17, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.70399022102356, Epoch: 17, training loss: 3.387490284163505, current learning rate 1e-05
val loss: 12.026533484458923
accuracy:      0.891
precision:     0.730
recall:        0.586
f1:            0.613
val loss: 9.488055985420942
accuracy:      0.912
precision:     0.692
recall:        0.622
f1:            0.646
===== Start training: epoch 18, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.694307804107666, Epoch: 18, training loss: 6.0368375547695905, current learning rate 1e-05
val loss: 10.991535663604736
accuracy:      0.896
precision:     0.760
recall:        0.606
f1:            0.639
val loss: 10.201605916023254
accuracy:      0.906
precision:     0.668
recall:        0.614
f1:            0.634
===== Start training: epoch 19, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.6893861293792725, Epoch: 19, training loss: 2.6920968850608915, current learning rate 1e-05
val loss: 12.012793064117432
accuracy:      0.892
precision:     0.736
recall:        0.607
f1:            0.638
val loss: 10.054275046102703
accuracy:      0.904
precision:     0.667
recall:        0.633
f1:            0.647
===== Start training: epoch 20, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.70074200630188, Epoch: 20, training loss: 3.5189525093883276, current learning rate 1e-05
val loss: 12.713888198137283
accuracy:      0.898
precision:     0.791
recall:        0.597
f1:            0.631
val loss: 9.238514361903071
accuracy:      0.918
precision:     0.723
recall:        0.610
f1:            0.642
===== Start training: epoch 21, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.6812052726745605, Epoch: 21, training loss: 2.1873852912103757, current learning rate 1e-05
val loss: 12.370889544487
accuracy:      0.897
precision:     0.785
recall:        0.586
f1:            0.617
val loss: 9.611821383237839
accuracy:      0.916
precision:     0.712
recall:        0.609
f1:            0.639
===== Start training: epoch 22, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.6942925453186035, Epoch: 22, training loss: 1.7599547858117148, current learning rate 1e-05
val loss: 12.629266172647476
accuracy:      0.893
precision:     0.752
recall:        0.584
f1:            0.612
val loss: 9.658394530415535
accuracy:      0.911
precision:     0.679
recall:        0.597
f1:            0.621
===== Start training: epoch 23, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.697084665298462, Epoch: 23, training loss: 2.3760770589578897, current learning rate 1e-05
val loss: 14.77583372592926
accuracy:      0.876
precision:     0.679
recall:        0.642
f1:            0.657
val loss: 13.564703702926636
accuracy:      0.884
precision:     0.633
recall:        0.652
f1:            0.642
===== Start training: epoch 24, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.684121370315552, Epoch: 24, training loss: 1.0189657504670322, current learning rate 1e-05
val loss: 13.253511041402817
accuracy:      0.896
precision:     0.770
recall:        0.593
f1:            0.624
val loss: 10.972075372934341
accuracy:      0.910
precision:     0.679
recall:        0.606
f1:            0.630
===== Start training: epoch 25, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.703039884567261, Epoch: 25, training loss: 0.39214353740680963, current learning rate 1e-05
val loss: 14.457119762897491
accuracy:      0.883
precision:     0.697
recall:        0.639
f1:            0.660
val loss: 13.094673991203308
accuracy:      0.895
precision:     0.656
recall:        0.663
f1:            0.659
===== Start training: epoch 26, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.708649158477783, Epoch: 26, training loss: 2.570742965035606, current learning rate 1e-05
val loss: 14.036019504070282
accuracy:      0.893
precision:     0.745
recall:        0.598
f1:            0.628
val loss: 11.715161204338074
accuracy:      0.908
precision:     0.678
recall:        0.625
f1:            0.645
===== Start training: epoch 27, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.710158348083496, Epoch: 27, training loss: 2.3150955359451473, current learning rate 1e-05
val loss: 13.637508660554886
accuracy:      0.887
precision:     0.713
recall:        0.655
f1:            0.677
val loss: 13.143441826105118
accuracy:      0.892
precision:     0.647
recall:        0.651
f1:            0.649
===== Start training: epoch 28, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.680978536605835, Epoch: 28, training loss: 1.2498081609955989, current learning rate 1e-05
val loss: 15.866443663835526
accuracy:      0.877
precision:     0.683
recall:        0.643
f1:            0.659
val loss: 13.666047394275665
accuracy:      0.890
precision:     0.651
recall:        0.670
f1:            0.660
===== Start training: epoch 29, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.692036151885986, Epoch: 29, training loss: 3.0258051897981204, current learning rate 1e-05
val loss: 14.390502452850342
accuracy:      0.890
precision:     0.722
recall:        0.606
f1:            0.635
val loss: 10.97569364309311
accuracy:      0.908
precision:     0.677
recall:        0.620
f1:            0.641
===== Start training: epoch 30, seed = 0 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.6962244510650635, Epoch: 30, training loss: 0.7562798744766042, current learning rate 1e-05
val loss: 15.124243408441544
accuracy:      0.895
precision:     0.760
recall:        0.595
f1:            0.627
val loss: 12.113340586423874
accuracy:      0.913
precision:     0.693
recall:        0.612
f1:            0.639
best result:
0.8381818181818181
0.6180195495112621
0.726832137139372
0.6423117112772285


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.716917514801025, Epoch: 1, training loss: 59.961461901664734, current learning rate 1e-05
val loss: 12.841822564601898
accuracy:      0.186
precision:     0.532
recall:        0.524
f1:            0.185
val loss: 12.878092467784882
accuracy:      0.155
precision:     0.532
recall:        0.530
f1:            0.155
===== Start training: epoch 2, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.74356746673584, Epoch: 2, training loss: 58.41093587875366, current learning rate 1e-05
val loss: 11.169476687908173
accuracy:      0.768
precision:     0.611
recall:        0.715
f1:            0.622
val loss: 11.052759826183319
accuracy:      0.775
precision:     0.589
recall:        0.722
f1:            0.595
===== Start training: epoch 3, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.20it/s]


Timing: 7.744747877120972, Epoch: 3, training loss: 49.38457828760147, current learning rate 1e-05
val loss: 10.500047504901886
accuracy:      0.731
precision:     0.606
recall:        0.728
f1:            0.604
val loss: 10.516607284545898
accuracy:      0.721
precision:     0.584
recall:        0.743
f1:            0.569
===== Start training: epoch 4, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.735445499420166, Epoch: 4, training loss: 38.060938239097595, current learning rate 1e-05
val loss: 6.308591455221176
accuracy:      0.874
precision:     0.691
recall:        0.701
f1:            0.696
val loss: 6.255006790161133
accuracy:      0.875
precision:     0.650
recall:        0.727
f1:            0.676
===== Start training: epoch 5, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.72409725189209, Epoch: 5, training loss: 28.75890803337097, current learning rate 1e-05
val loss: 6.310093104839325
accuracy:      0.872
precision:     0.690
recall:        0.710
f1:            0.699
val loss: 6.0371406972408295
accuracy:      0.870
precision:     0.651
recall:        0.749
f1:            0.681
===== Start training: epoch 6, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.722593784332275, Epoch: 6, training loss: 20.91024887561798, current learning rate 1e-05
val loss: 6.176548033952713
accuracy:      0.881
precision:     0.688
recall:        0.624
f1:            0.646
val loss: 4.922819934785366
accuracy:      0.901
precision:     0.674
recall:        0.676
f1:            0.675
===== Start training: epoch 7, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.756627321243286, Epoch: 7, training loss: 16.877649396657944, current learning rate 1e-05
val loss: 6.440134793519974
accuracy:      0.884
precision:     0.704
recall:        0.647
f1:            0.668
val loss: 5.581154108047485
accuracy:      0.896
precision:     0.660
recall:        0.664
f1:            0.662
===== Start training: epoch 8, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.725560426712036, Epoch: 8, training loss: 14.852174557745457, current learning rate 1e-05
val loss: 6.863581031560898
accuracy:      0.891
precision:     0.728
recall:        0.671
f1:            0.693
val loss: 5.936223708093166
accuracy:      0.900
precision:     0.678
recall:        0.696
f1:            0.686
===== Start training: epoch 9, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.693654537200928, Epoch: 9, training loss: 11.219647321850061, current learning rate 1e-05
val loss: 8.531626984477043
accuracy:      0.884
precision:     0.704
recall:        0.663
f1:            0.680
val loss: 7.817716628313065
accuracy:      0.893
precision:     0.656
recall:        0.672
f1:            0.663
===== Start training: epoch 10, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.688486337661743, Epoch: 10, training loss: 11.036381177604198, current learning rate 1e-05
val loss: 9.33897790312767
accuracy:      0.882
precision:     0.704
recall:        0.692
f1:            0.698
val loss: 9.042755782604218
accuracy:      0.891
precision:     0.673
recall:        0.736
f1:            0.697
===== Start training: epoch 11, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.700124740600586, Epoch: 11, training loss: 15.120513156056404, current learning rate 1e-05
val loss: 7.548760533332825
accuracy:      0.888
precision:     0.718
recall:        0.679
f1:            0.695
val loss: 7.459205806255341
accuracy:      0.904
precision:     0.696
recall:        0.738
f1:            0.714
===== Start training: epoch 12, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.702867031097412, Epoch: 12, training loss: 7.282527816481888, current learning rate 1e-05
val loss: 8.91434109210968
accuracy:      0.891
precision:     0.728
recall:        0.644
f1:            0.672
val loss: 7.945309271570295
accuracy:      0.893
precision:     0.652
recall:        0.662
f1:            0.657
===== Start training: epoch 13, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.701987981796265, Epoch: 13, training loss: 7.8712184354662895, current learning rate 1e-05
val loss: 9.16835467517376
accuracy:      0.895
precision:     0.749
recall:        0.619
f1:            0.652
val loss: 7.411096359603107
accuracy:      0.911
precision:     0.692
recall:        0.637
f1:            0.658
===== Start training: epoch 14, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.696494817733765, Epoch: 14, training loss: 5.341078754980117, current learning rate 1e-05
val loss: 11.006024971604347
accuracy:      0.891
precision:     0.730
recall:        0.617
f1:            0.647
val loss: 9.068214426515624
accuracy:      0.908
precision:     0.682
recall:        0.635
f1:            0.654
===== Start training: epoch 15, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.697000026702881, Epoch: 15, training loss: 4.076659634942189, current learning rate 1e-05
val loss: 11.681326180696487
accuracy:      0.892
precision:     0.735
recall:        0.614
f1:            0.645
val loss: 9.200408294796944
accuracy:      0.914
precision:     0.703
recall:        0.638
f1:            0.662
===== Start training: epoch 16, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.71556282043457, Epoch: 16, training loss: 4.653936922550201, current learning rate 1e-05
val loss: 11.884214170277119
accuracy:      0.897
precision:     0.767
recall:        0.606
f1:            0.641
val loss: 8.180310621857643
accuracy:      0.923
precision:     0.753
recall:        0.628
f1:            0.665
===== Start training: epoch 17, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.695068120956421, Epoch: 17, training loss: 4.300648250267841, current learning rate 1e-05
val loss: 13.063272595405579
accuracy:      0.891
precision:     0.735
recall:        0.590
f1:            0.618
val loss: 9.470454253256321
accuracy:      0.922
precision:     0.752
recall:        0.602
f1:            0.638
===== Start training: epoch 18, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.68514609336853, Epoch: 18, training loss: 3.3393042515963316, current learning rate 1e-05
val loss: 12.877902418375015
accuracy:      0.886
precision:     0.705
recall:        0.614
f1:            0.641
val loss: 11.357068300247192
accuracy:      0.905
precision:     0.669
recall:        0.624
f1:            0.641
===== Start training: epoch 19, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.686536073684692, Epoch: 19, training loss: 4.714541700202972, current learning rate 1e-05
val loss: 13.679459273815155
accuracy:      0.879
precision:     0.688
recall:        0.644
f1:            0.661
val loss: 11.017396837472916
accuracy:      0.898
precision:     0.675
recall:        0.700
f1:            0.686
===== Start training: epoch 20, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.699193954467773, Epoch: 20, training loss: 3.2921473383903503, current learning rate 1e-05
val loss: 13.890815913677216
accuracy:      0.875
precision:     0.679
recall:        0.648
f1:            0.661
val loss: 12.970828473567963
accuracy:      0.891
precision:     0.666
recall:        0.711
f1:            0.684
===== Start training: epoch 21, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.711379051208496, Epoch: 21, training loss: 4.117350492626429, current learning rate 1e-05
val loss: 8.685220777988434
accuracy:      0.895
precision:     0.758
recall:        0.599
f1:            0.631
val loss: 7.927597358822823
accuracy:      0.915
precision:     0.708
recall:        0.643
f1:            0.668
===== Start training: epoch 22, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.722419261932373, Epoch: 22, training loss: 4.300974436104298, current learning rate 1e-05
val loss: 11.979470431804657
accuracy:      0.876
precision:     0.684
recall:        0.662
f1:            0.672
val loss: 10.43505810201168
accuracy:      0.894
precision:     0.673
recall:        0.722
f1:            0.693
===== Start training: epoch 23, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.695085048675537, Epoch: 23, training loss: 1.8241410015616566, current learning rate 1e-05
val loss: 12.225078910589218
accuracy:      0.896
precision:     0.753
recall:        0.623
f1:            0.657
val loss: 9.31223051249981
accuracy:      0.917
precision:     0.722
recall:        0.670
f1:            0.692
===== Start training: epoch 24, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.684108018875122, Epoch: 24, training loss: 1.2807340420549735, current learning rate 1e-05
val loss: 11.991771847009659
accuracy:      0.896
precision:     0.755
recall:        0.616
f1:            0.650
val loss: 9.04475075751543
accuracy:      0.917
precision:     0.721
recall:        0.650
f1:            0.676
===== Start training: epoch 25, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.689240217208862, Epoch: 25, training loss: 0.6732206288725138, current learning rate 1e-05
val loss: 16.19272205233574
accuracy:      0.872
precision:     0.685
recall:        0.690
f1:            0.687
val loss: 16.41990077495575
accuracy:      0.883
precision:     0.658
recall:        0.726
f1:            0.683
===== Start training: epoch 26, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.6945178508758545, Epoch: 26, training loss: 2.593610944226384, current learning rate 1e-05
val loss: 12.816258579492569
accuracy:      0.898
precision:     0.762
recall:        0.627
f1:            0.663
val loss: 10.437245647131931
accuracy:      0.910
precision:     0.694
recall:        0.656
f1:            0.672
===== Start training: epoch 27, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.706537961959839, Epoch: 27, training loss: 0.5965137557941489, current learning rate 1e-05
val loss: 13.114595264196396
accuracy:      0.894
precision:     0.753
recall:        0.595
f1:            0.626
val loss: 10.876455157995224
accuracy:      0.916
precision:     0.714
recall:        0.624
f1:            0.654
===== Start training: epoch 28, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.720470190048218, Epoch: 28, training loss: 2.0011768016265705, current learning rate 1e-05
val loss: 14.203746110200882
accuracy:      0.893
precision:     0.750
recall:        0.588
f1:            0.616
val loss: 10.550014987587929
accuracy:      0.918
precision:     0.723
recall:        0.610
f1:            0.642
===== Start training: epoch 29, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.692503929138184, Epoch: 29, training loss: 4.8592114094644785, current learning rate 1e-05
val loss: 9.828946113586426
accuracy:      0.892
precision:     0.731
recall:        0.668
f1:            0.692
val loss: 10.16977545619011
accuracy:      0.900
precision:     0.678
recall:        0.696
f1:            0.686
===== Start training: epoch 30, seed = 1 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.689003229141235, Epoch: 30, training loss: 0.9079455728642642, current learning rate 1e-05
val loss: 11.602097749710083
accuracy:      0.895
precision:     0.743
recall:        0.646
f1:            0.677
val loss: 9.24197646486573
accuracy:      0.914
precision:     0.707
recall:        0.658
f1:            0.678
best result:
0.87
0.650563320961331
0.7491750073514196
0.6806681405060182


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.730283737182617, Epoch: 1, training loss: 59.71612370014191, current learning rate 1e-05
val loss: 13.370318949222565
accuracy:      0.132
precision:     0.558
recall:        0.510
f1:            0.124
val loss: 13.427289366722107
accuracy:      0.105
precision:     0.523
recall:        0.507
f1:            0.101
===== Start training: epoch 2, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.7326836585998535, Epoch: 2, training loss: 57.930721282958984, current learning rate 1e-05
val loss: 10.547777116298676
accuracy:      0.755
precision:     0.565
recall:        0.614
f1:            0.569
val loss: 10.215123534202576
accuracy:      0.785
precision:     0.562
recall:        0.633
f1:            0.566
===== Start training: epoch 3, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.732553243637085, Epoch: 3, training loss: 50.318631529808044, current learning rate 1e-05
val loss: 11.75025600194931
accuracy:      0.616
precision:     0.566
recall:        0.660
f1:            0.517
val loss: 10.983502089977264
accuracy:      0.668
precision:     0.575
recall:        0.734
f1:            0.536
===== Start training: epoch 4, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.19it/s]


Timing: 7.7544286251068115, Epoch: 4, training loss: 40.944580078125, current learning rate 1e-05
val loss: 9.220919400453568
accuracy:      0.749
precision:     0.586
recall:        0.667
f1:            0.591
val loss: 8.651955306529999
accuracy:      0.775
precision:     0.590
recall:        0.723
f1:            0.596
===== Start training: epoch 5, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.7186973094940186, Epoch: 5, training loss: 33.73008757829666, current learning rate 1e-05
val loss: 8.734386533498764
accuracy:      0.790
precision:     0.591
recall:        0.643
f1:            0.602
val loss: 8.019364207983017
accuracy:      0.814
precision:     0.601
recall:        0.713
f1:            0.618
===== Start training: epoch 6, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.72138786315918, Epoch: 6, training loss: 26.364069283008575, current learning rate 1e-05
val loss: 8.3420649766922
accuracy:      0.812
precision:     0.598
recall:        0.636
f1:            0.610
val loss: 7.766218513250351
accuracy:      0.833
precision:     0.603
recall:        0.694
f1:            0.624
===== Start training: epoch 7, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.72388482093811, Epoch: 7, training loss: 25.736235737800598, current learning rate 1e-05
val loss: 10.373031795024872
accuracy:      0.779
precision:     0.598
recall:        0.671
f1:            0.610
val loss: 10.146125555038452
accuracy:      0.786
precision:     0.590
recall:        0.714
f1:            0.600
===== Start training: epoch 8, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.723655700683594, Epoch: 8, training loss: 18.418657064437866, current learning rate 1e-05
val loss: 7.8890784084796906
accuracy:      0.855
precision:     0.644
recall:        0.650
f1:            0.647
val loss: 7.94120454788208
accuracy:      0.845
precision:     0.603
recall:        0.671
f1:            0.622
===== Start training: epoch 9, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.718005657196045, Epoch: 9, training loss: 16.53694073855877, current learning rate 1e-05
val loss: 7.274102509021759
accuracy:      0.872
precision:     0.670
recall:        0.640
f1:            0.652
val loss: 6.6642268896102905
accuracy:      0.880
precision:     0.637
recall:        0.675
f1:            0.653
===== Start training: epoch 10, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.717062711715698, Epoch: 10, training loss: 13.735440015792847, current learning rate 1e-05
val loss: 8.147696658968925
accuracy:      0.880
precision:     0.675
recall:        0.590
f1:            0.612
val loss: 6.726620118133724
accuracy:      0.896
precision:     0.639
recall:        0.614
f1:            0.624
===== Start training: epoch 11, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.73175048828125, Epoch: 11, training loss: 19.290636718273163, current learning rate 1e-05
val loss: 9.091435432434082
accuracy:      0.871
precision:     0.655
recall:        0.606
f1:            0.623
val loss: 8.601748526096344
accuracy:      0.879
precision:     0.614
recall:        0.624
f1:            0.619
===== Start training: epoch 12, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.689990043640137, Epoch: 12, training loss: 10.97086375951767, current learning rate 1e-05
val loss: 8.326085895299911
accuracy:      0.865
precision:     0.662
recall:        0.656
f1:            0.659
val loss: 7.976104959845543
accuracy:      0.850
precision:     0.599
recall:        0.653
f1:            0.616
===== Start training: epoch 13, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.706490516662598, Epoch: 13, training loss: 11.371284311637282, current learning rate 1e-05
val loss: 8.5391815751791
accuracy:      0.892
precision:     0.751
recall:        0.574
f1:            0.598
val loss: 7.241278976202011
accuracy:      0.904
precision:     0.638
recall:        0.578
f1:            0.596
===== Start training: epoch 14, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.739982843399048, Epoch: 14, training loss: 7.754710346460342, current learning rate 1e-05
val loss: 8.648883730173111
accuracy:      0.890
precision:     0.726
recall:        0.576
f1:            0.600
val loss: 7.592859506607056
accuracy:      0.905
precision:     0.645
recall:        0.583
f1:            0.602
===== Start training: epoch 15, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.698061466217041, Epoch: 15, training loss: 9.067884683609009, current learning rate 1e-05
val loss: 11.044493824243546
accuracy:      0.878
precision:     0.681
recall:        0.626
f1:            0.646
val loss: 11.734688222408295
accuracy:      0.882
precision:     0.623
recall:        0.636
f1:            0.629
===== Start training: epoch 16, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.699013948440552, Epoch: 16, training loss: 7.369189955294132, current learning rate 1e-05
val loss: 9.36564975976944
accuracy:      0.887
precision:     0.708
recall:        0.598
f1:            0.624
val loss: 8.066081255674362
accuracy:      0.897
precision:     0.646
recall:        0.624
f1:            0.634
===== Start training: epoch 17, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.697543382644653, Epoch: 17, training loss: 6.694837514311075, current learning rate 1e-05
val loss: 12.090335965156555
accuracy:      0.882
precision:     0.685
recall:        0.601
f1:            0.625
val loss: 10.973941504955292
accuracy:      0.898
precision:     0.653
recall:        0.635
f1:            0.643
===== Start training: epoch 18, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.687117338180542, Epoch: 18, training loss: 7.435824032872915, current learning rate 1e-05
val loss: 11.518149733543396
accuracy:      0.885
precision:     0.699
recall:        0.593
f1:            0.618
val loss: 9.71060162782669
accuracy:      0.910
precision:     0.685
recall:        0.626
f1:            0.648
===== Start training: epoch 19, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.69270133972168, Epoch: 19, training loss: 5.689043193589896, current learning rate 1e-05
val loss: 11.725553154945374
accuracy:      0.891
precision:     0.733
recall:        0.576
f1:            0.601
val loss: 9.18349552154541
accuracy:      0.915
precision:     0.705
recall:        0.599
f1:            0.628
===== Start training: epoch 20, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.702974319458008, Epoch: 20, training loss: 3.90428898204118, current learning rate 1e-05
val loss: 11.371925681829453
accuracy:      0.889
precision:     0.718
recall:        0.585
f1:            0.611
val loss: 9.66515551507473
accuracy:      0.908
precision:     0.673
recall:        0.610
f1:            0.632
===== Start training: epoch 21, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.692950248718262, Epoch: 21, training loss: 3.612947096582502, current learning rate 1e-05
val loss: 12.09056881070137
accuracy:      0.874
precision:     0.669
recall:        0.624
f1:            0.641
val loss: 11.472577795386314
accuracy:      0.896
precision:     0.657
recall:        0.654
f1:            0.655
===== Start training: epoch 22, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.703023433685303, Epoch: 22, training loss: 3.671443661674857, current learning rate 1e-05
val loss: 11.945662274956703
accuracy:      0.885
precision:     0.705
recall:        0.640
f1:            0.664
val loss: 11.891250118613243
accuracy:      0.893
precision:     0.656
recall:        0.672
f1:            0.663
===== Start training: epoch 23, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.23it/s]


Timing: 7.712263822555542, Epoch: 23, training loss: 5.041971133323386, current learning rate 1e-05
val loss: 12.067609563469887
accuracy:      0.890
precision:     0.724
recall:        0.586
f1:            0.612
val loss: 9.183409377932549
accuracy:      0.913
precision:     0.690
recall:        0.603
f1:            0.629
===== Start training: epoch 24, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.22it/s]


Timing: 7.721780300140381, Epoch: 24, training loss: 5.216594937257469, current learning rate 1e-05
val loss: 11.627293318510056
accuracy:      0.877
precision:     0.685
recall:        0.649
f1:            0.664
val loss: 10.047703370451927
accuracy:      0.895
precision:     0.662
recall:        0.673
f1:            0.667
===== Start training: epoch 25, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.698939323425293, Epoch: 25, training loss: 1.6746757961809635, current learning rate 1e-05
val loss: 15.463958203792572
accuracy:      0.857
precision:     0.652
recall:        0.661
f1:            0.657
val loss: 16.02627682685852
accuracy:      0.867
precision:     0.629
recall:        0.693
f1:            0.651
===== Start training: epoch 26, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.21it/s]


Timing: 7.730634450912476, Epoch: 26, training loss: 3.5152393023017794, current learning rate 1e-05
val loss: 12.726582288742065
accuracy:      0.884
precision:     0.701
recall:        0.630
f1:            0.654
val loss: 11.619423672556877
accuracy:      0.899
precision:     0.661
recall:        0.650
f1:            0.655
===== Start training: epoch 27, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.685948848724365, Epoch: 27, training loss: 5.42972413636744, current learning rate 1e-05
val loss: 12.106289148330688
accuracy:      0.890
precision:     0.722
recall:        0.596
f1:            0.624
val loss: 9.476920480607077
accuracy:      0.911
precision:     0.688
recall:        0.622
f1:            0.645
===== Start training: epoch 28, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.26it/s]


Timing: 7.671627759933472, Epoch: 28, training loss: 1.989691661670804, current learning rate 1e-05
val loss: 12.212948709726334
accuracy:      0.884
precision:     0.694
recall:        0.593
f1:            0.617
val loss: 9.042428597807884
accuracy:      0.912
precision:     0.692
recall:        0.622
f1:            0.646
===== Start training: epoch 29, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.25it/s]


Timing: 7.6886069774627686, Epoch: 29, training loss: 1.6201764446450397, current learning rate 1e-05
val loss: 12.023453325033188
accuracy:      0.895
precision:     0.762
recall:        0.592
f1:            0.623
val loss: 9.345368231413886
accuracy:      0.915
precision:     0.708
recall:        0.619
f1:            0.648
===== Start training: epoch 30, seed = 2 =====


Iteration: 100%|██████████| 48/48 [00:07<00:00,  6.24it/s]


Timing: 7.69218897819519, Epoch: 30, training loss: 3.513718207599595, current learning rate 1e-05
val loss: 12.029789954423904
accuracy:      0.891
precision:     0.733
recall:        0.600
f1:            0.630
val loss: 10.999764189124107
accuracy:      0.907
precision:     0.675
recall:        0.625
f1:            0.644
best result:
0.8954545454545455
0.6618294024196505
0.6730687548328778
0.6671744558368129


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



avg result:
0.8678787878787878
0.6434707576307478
0.7163586331078898
0.6633847692066865


In [12]:
# Final cell: release the Colab runtime once all seeds/epochs above have run,
# presumably so the GPU VM is not left allocated after the experiment finishes.
# NOTE(review): this ends the session, so every in-memory object (models,
# metrics) is discarded. The best/avg results above are only printed, so
# persist them (e.g. to the mounted Drive) before this cell if they are needed later.
from google.colab import runtime
runtime.unassign()