<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive at the standard Colab location; force_remount=True remounts
# even when Drive is already mounted in this runtime.
drive.mount(mountpoint="/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
!pip install datasets



In [None]:
# Switch the working directory to the project folder on the mounted Drive. This path only
# exists on a Drive that contains the AttackSupport folder; adjust it when reusing the notebook.
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [None]:
import torch

# Use the first CUDA device when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [None]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Minimal `torch.utils.data.Dataset` over an in-memory sequence of examples.

    `len()` and indexing are forwarded unchanged to the wrapped `examples`
    object, so they behave exactly as that sequence does.
    """

    def __init__(self, examples):
        super().__init__()
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(examples):
    """Stack a batch of two-sentence examples into `torch.long` tensors.

    Each example is a 7-item sequence
    ``(ids_1, segs_1, mask_1, ids_2, segs_2, mask_2, label)`` such as the lists
    built by `DataProcessor._get_examples`. Returns a 7-tuple of tensors in the
    same order, with the batch as the first dimension.
    """
    (ids_a, segs_a, mask_a,
     ids_b, segs_b, mask_b, label_col) = [list(column) for column in zip(*examples)]

    def to_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (to_long(ids_a), to_long(segs_a), to_long(mask_a),
            to_long(ids_b), to_long(segs_b), to_long(mask_b),
            to_long(label_col))

def collate_fn_concatenated(examples):
    """Stack a batch of concatenated-pair examples into `torch.long` tensors.

    Each example is a 5-item sequence ``(ids, segs, att_mask, position_sep, label)``,
    e.g. as built by `DataProcessor._get_examples_concatenated`. Returns a
    5-tuple of tensors in that order, with the batch as the first dimension.
    """
    fields = [list(column) for column in zip(*examples)]
    ids, segs, att_mask, position_sep, labels = fields
    return tuple(
        torch.tensor(field, dtype=torch.long)
        for field in (ids, segs, att_mask, position_sep, labels)
    )

def collate_fn_concatenated_adv(examples):
    """Like `collate_fn_concatenated`, but leaves the labels as a plain list.

    Each example is a 5-item sequence ``(ids, segs, att_mask, position_sep, label)``.
    The first four fields become `torch.long` tensors (batch first); the label
    column is returned untouched as a Python list.
    """
    columns = [list(column) for column in zip(*examples)]
    ids, segs, att_mask, position_sep, labels = columns
    ids, segs, att_mask, position_sep = (
        torch.tensor(column, dtype=torch.long)
        for column in (ids, segs, att_mask, position_sep)
    )
    return ids, segs, att_mask, position_sep, labels

In [None]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base helper that tokenizes sentence pairs into fixed-length id lists.

  Subclasses build rows of the form ``(id, sentence1, sentence2, _, _, _, label)``
  and pass them to `_get_examples` (each sentence encoded separately) or to
  `_get_examples_concatenated` (both sentences encoded as one pair). Only the
  two sentences and the label are read from a row. Every per-token list is
  right-padded or truncated to `self.max_sent_len` entries.
  """

  def __init__(self,):
    # `args` is a notebook-level config dict that is not defined in this cell;
    # it must exist before any processor is instantiated.
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _pad_token_id(self):
    """Return the first id of the tokenizer's pad token, encoded without special tokens."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_max_len(self, ids, pad_id, *parallel):
    """Right-pad or truncate `ids` and its parallel per-token lists to `max_sent_len`.

    Args:
      ids: token ids; padded with `pad_id`.
      pad_id: id appended to `ids` as padding.
      *parallel: lists aligned with `ids` (segment ids, separator markers, ...);
        padded with 0.

    Returns:
      ``(ids, att_mask, *parallel)``, where `att_mask` is 1 on real tokens and 0
      on padding. Over-long inputs are simply cut at `max_sent_len`, which also
      drops any trailing special token.
    """
    max_len = self.max_sent_len
    missing = max_len - len(ids)
    if missing > 0:
      att_mask = [1] * len(ids) + [0] * missing
      ids = ids + [pad_id] * missing
      parallel = [seq + [0] * missing for seq in parallel]
    else:
      att_mask = [1] * max_len
      ids = ids[:max_len]
      parallel = [seq[:max_len] for seq in parallel]
    return (ids, att_mask, *parallel)

  def _get_examples(self, dataset, dataset_type="train"):
    """Tokenize the two sentences of every row separately.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``.
      dataset_type: split name, used only in the completion message.

    Returns:
      list of ``[ids_1, segs_1, mask_1, ids_2, segs_2, mask_2, label]``. Segment
      ids are 0 on the first and last token of each encoding and 1 in between;
      padding positions get segment 0 and mask 0.
    """
    examples = []
    # The pad id is the same for every row: look it up once, on first use.
    pad_id = None

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = [0] * len(ids_sent1)
      segs_sent1[1:-1] = [1] * (len(ids_sent1) - 2)

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = [0] * len(ids_sent2)
      segs_sent2[1:-1] = [1] * (len(ids_sent2) - 2)

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      if pad_id is None:
        pad_id = self._pad_token_id()

      ids_sent1, att_mask_sent1, segs_sent1 = self._fit_to_max_len(ids_sent1, pad_id, segs_sent1)
      ids_sent2, att_mask_sent2, segs_sent2 = self._fit_to_max_len(ids_sent2, pad_id, segs_sent2)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Tokenize the two sentences of every row as a single paired input.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``.
      dataset_type: split name, used only in the completion message.

    Returns:
      list of ``[ids, segs, att_mask, position_sep, label]``. `segs` is 0 for the
      first ``len(encode(sentence1))`` positions and 1 for the rest; padding
      positions get 0 in `segs`, `att_mask` and `position_sep`.
    """
    examples = []
    # The pad id is the same for every row: look it up once, on first use.
    pad_id = None

    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      # The segment ids assume the pair encoding is exactly as long as the two
      # single-sentence encodings together (presumably true for RoBERTa-style
      # tokenizers that wrap each segment in two special tokens — TODO confirm for
      # the configured model). The length assert below fails when it is not.
      segs_sent1 = [0] * sentence1_length + [1] * sentence2_length
      # NOTE(review): the assignments at `sentence1_length` and 1 only overwrite
      # values that are already 1, so before padding this is 1 everywhere except
      # index 0. If it was meant to mark separator positions, the initial fill
      # should probably be 0 — confirm before changing, as it alters model inputs.
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      if pad_id is None:
        pad_id = self._pad_token_id()

      ids_sent1, att_mask_sent1, segs_sent1, position_sep = self._fit_to_max_len(
          ids_sent1, pad_id, segs_sent1, position_sep)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Builds relation-classification examples from a discourse-marker dataset.

  Input samples are mappings with "sentence1", "sentence2" and an integer
  "label" that indexes `id_to_word`. Samples whose marker is not a key of
  `mapping` are dropped; the rest are labelled with the marker's class id
  (one-hot encoded) and tokenized as concatenated sentence pairs.
  """

  def __init__(self, binary=False):
    """
    Args:
      binary: if True, collapse the three marker classes into two. Markers of
        class 2 become 1 and every other marker becomes 0. This reproduces the
        two-class mapping the notebook previously kept as a dead string literal.
    """
    super(DiscourseMarkerProcessor, self).__init__()
    # Marker classes follow https://www.sciencedirect.com/science/article/pii/S0378216698001015
    # The ids presumably denote 0 = inferential, 1 = elaborative, 2 = contrastive
    # markers — TODO confirm against the paper.
    self.mapping = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }
    if binary:
      self.mapping = {marker: int(cls == 2) for marker, cls in self.mapping.items()}

    # Integer dataset label -> marker word (index 0 means "no connective").
    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }

  def process_dataset(self, dataset, name="train"):
    """Filter, label and tokenize one split of the discourse-marker dataset.

    Args:
      dataset: iterable of mappings with "sentence1", "sentence2" and integer "label" keys.
      name: split name, used for example ids (f"{name}_{i}") and log messages.

    Returns:
      the list produced by `_get_examples_concatenated`, one entry per kept sample.
    """
    kept = []
    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker in self.mapping:
        kept.append([sample["sentence1"], sample["sentence2"], self.mapping[marker]])

    print("one hot encoding...")
    if kept:
      # Fix the categories to every class id in `mapping` instead of letting the
      # encoder infer them from this split: a split that lacks a class would
      # otherwise get narrower label vectors than the other splits.
      one_hot_encoder = OneHotEncoder(
          categories=[sorted(set(self.mapping.values()))],
          handle_unknown="ignore",
          sparse_output=False,
      )
      labels = one_hot_encoder.fit_transform([[sample[-1]] for sample in kept])
    else:
      labels = []

    result = [
        [f"{name}_{i}", sentence1, sentence2, [], [], [], label]
        for i, ((sentence1, sentence2, _), label) in enumerate(zip(kept, labels))
    ]

    return self._get_examples_concatenated(result, name)


class StudentEssayProcessor(DataProcessor):
  """Loads tab-separated student-essay relation files as two-way labelled pairs."""

  # Characters deleted from the facts column: underscores, brackets and both parentheses.
  _FACTS_JUNK = str.maketrans("", "", "_[]()")

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target fields of a batch of rows.

    `input` is an iterable of 7-field rows
    ``(id, sentence, target, source_sentiment, target_sentiment, knowledge, label_distribution)``.
    Each of the three sequence fields is converted with `torch.tensor` and
    right-padded with 0 to the longest entry in the batch. `maxlen` is accepted
    but unused.

    NOTE(review): the 7-field rows assembled in `read_input_files` carry strings
    in these fields, which `torch.tensor` cannot convert. This looks like unused
    legacy code — confirm before relying on it.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples.
    """
    _, sentences, targets, _, _, knowledge, label_distribution = zip(*input)

    def pad(seqs):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(targets), label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If `max_batch_size` is positive, it is the maximum number of sentences per
    batch. Otherwise `-max_batch_size` is a token budget per batch.
    When `batch_equal_size` is truthy, only sentences of identical length share
    a batch. Every batch holds at least one sentence.

    Returns:
      list of lists of indices into `sentences`.
    """
    batches_of_sentence_ids = []
    if batch_equal_size:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Token budget: keep at least one sentence per batch so range() never
          # gets a zero step, and avoid dividing by a zero sentence length.
          batch_size = max(1, int(-max_batch_size / max(sentence_length, 1)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
            or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= -max_batch_size):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated relation file and tokenize it.

    Columns by position: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. The file is decoded as
    ISO-8859-1 and blank lines are skipped.

    Label mapping:
      - supports/support/because -> [1, 0]
      - attacks/attack/but -> [0, 1]
      - anything else -> [0, 0]

    `max_sentence_length` is accepted for interface compatibility but unused.
    `name` is forwarded to `_get_examples_concatenated`.

    Returns:
      the list produced by `_get_examples_concatenated`.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        fields = line.replace("\n", "").split("\t")
        # Blank line: "\r" remains for CRLF files, "" for LF files.
        if len(fields) == 1 and not fields[0].strip():
          continue

        label = fields[-1].strip()
        if label in ("supports", "support", "because"):
          label_distribution = [1, 0]
        elif label in ("attacks", "attack", "but"):
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        result.append([
            fields[0],
            fields[1].strip(),
            fields[3].strip(),
            [],
            [],
            fields[-2].strip().translate(self._FACTS_JUNK),
            label_distribution,
        ])

    return self._get_examples_concatenated(result, name)

class DebateProcessor(DataProcessor):
  """Loads tab-separated debate relation files as two-way labelled pairs."""

  # Characters deleted from the facts column: underscores, brackets and both parentheses.
  _FACTS_JUNK = str.maketrans("", "", "_[]()")

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target fields of a batch of rows.

    `input` is an iterable of 7-field rows
    ``(id, sentence, target, source_sentiment, target_sentiment, knowledge, label_distribution)``.
    Each of the three sequence fields is converted with `torch.tensor` and
    right-padded with 0 to the longest entry in the batch. `maxlen` is accepted
    but unused.

    NOTE(review): the 7-field rows assembled in `read_input_files` carry strings
    in these fields, which `torch.tensor` cannot convert. This looks like unused
    legacy code — confirm before relying on it.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples.
    """
    _, sentences, targets, _, _, knowledge, label_distribution = zip(*input)

    def pad(seqs):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(targets), label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If `max_batch_size` is positive, it is the maximum number of sentences per
    batch. Otherwise `-max_batch_size` is a token budget per batch.
    When `batch_equal_size` is truthy, only sentences of identical length share
    a batch. Every batch holds at least one sentence.

    Returns:
      list of lists of indices into `sentences`.
    """
    batches_of_sentence_ids = []
    if batch_equal_size:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Token budget: keep at least one sentence per batch so range() never
          # gets a zero step, and avoid dividing by a zero sentence length.
          batch_size = max(1, int(-max_batch_size / max(sentence_length, 1)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
            or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= -max_batch_size):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated relation file and tokenize it.

    Columns by position: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. The file is decoded as
    ISO-8859-1 and blank lines are skipped.

    Label mapping:
      - supports/support/because -> [1, 0]
      - attacks/attack/but -> [0, 1]
      - anything else -> [0, 0]

    `max_sentence_length` is accepted for interface compatibility but unused.
    `name` is forwarded to `_get_examples_concatenated`.

    Returns:
      the list produced by `_get_examples_concatenated`.
    """
    result = []
    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        fields = line.replace("\n", "").split("\t")
        # Blank line: "\r" remains for CRLF files, "" for LF files.
        if len(fields) == 1 and not fields[0].strip():
          continue

        label = fields[-1].strip()
        if label in ("supports", "support", "because"):
          label_distribution = [1, 0]
        elif label in ("attacks", "attack", "but"):
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        result.append([
            fields[0],
            fields[1].strip(),
            fields[3].strip(),
            [],
            [],
            fields[-2].strip().translate(self._FACTS_JUNK),
            label_distribution,
        ])

    return self._get_examples_concatenated(result, name)


class MARGProcessor(DataProcessor):
  """Loads the MARG CSV with three-way labels, injecting a predicted discourse marker."""

  # Characters deleted from the facts column: underscores, brackets and both parentheses.
  _FACTS_JUNK = str.maketrans("", "", "_[]()")

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _prepend_marker(self, sent, target):
    """Return `target` prefixed with the discourse marker `self.pipe` predicts for (sent, target).

    Underscores in the predicted label become spaces and its first letter is
    upper-cased; the first letter of `target` is lower-cased. Slicing (instead
    of `[0]`) keeps an empty marker or target from raising IndexError.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"].replace("_", " ")
    return ds_marker[:1].upper() + ds_marker[1:] + " " + target[:1].lower() + target[1:]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the rows of split `name` from a CSV file and tokenize them.

    Columns by position: 0 = id, 1 = source sentence, 2 = target sentence,
    3 = relation label, third-to-last = facts string, last = split name. Rows
    whose last column differs from `name` are skipped.

    The target sentence is prefixed with the discourse marker predicted for the
    pair.

    Label mapping:
      - supports/support/because -> [1, 0, 0]
      - attacks/attack/but -> [0, 1, 0]
      - neither -> [0, 0, 1]
      - anything else -> [0, 0, 0]

    `max_sent_length` is accepted for interface compatibility but unused.

    Returns:
      the list produced by `_get_examples_concatenated`.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # Address columns by position explicitly with `.iloc`: integer keys on a
      # Series with string column labels rely on pandas' deprecated positional fallback.
      if row.iloc[-1] != name:
        continue
      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = self._prepend_marker(sent, row.iloc[2].strip())

      label = row.iloc[3].strip()
      if label in ("supports", "support", "because"):
        label_distribution = [1, 0, 0]
      elif label in ("attacks", "attack", "but"):
        label_distribution = [0, 1, 0]
      elif label == "neither":
        label_distribution = [0, 0, 1]
      else:
        label_distribution = [0, 0, 0]

      facts = row.iloc[-3].strip().translate(self._FACTS_JUNK)
      result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)

class NKProcessor(DataProcessor):
  """Loads a tab-separated relation file with three-way labels."""

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read every row of a tab-separated file (first line = header) and tokenize it.

    Columns by position: 0 = id, 2 = relation label, 3 = source sentence,
    4 = target sentence.

    Label mapping:
      - supports/support/because -> [1, 0, 0]
      - attacks/attack/but -> [0, 1, 0]
      - no_relation -> [0, 0, 1]
      - anything else -> [0, 0, 0]

    `max_sent_length` is accepted for interface compatibility but unused.
    `name` is forwarded to `_get_examples_concatenated`.

    Returns:
      the list produced by `_get_examples_concatenated`.
    """
    result = []
    df = pd.read_csv(file_path, sep="\t")
    for _, row in df.iterrows():
      # `.iloc` makes the positional access explicit; integer keys on a Series
      # with string column labels rely on pandas' deprecated positional fallback.
      id_sample = row.iloc[0]
      label = row.iloc[2]
      sent = row.iloc[3].strip()
      target = row.iloc[4].strip()

      if label in ("supports", "support", "because"):
        label_distribution = [1, 0, 0]
      elif label in ("attacks", "attack", "but"):
        label_distribution = [0, 1, 0]
      elif label == "no_relation":
        label_distribution = [0, 0, 1]
      else:
        label_distribution = [0, 0, 0]

      result.append([id_sample, sent, target, [], [], [], label_distribution])

    return self._get_examples_concatenated(result, name)

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Student-essay loader that prefixes each target with a predicted discourse marker."""

  # Characters deleted from the facts column: underscores, brackets and both parentheses.
  _FACTS_JUNK = str.maketrans("", "", "_[]()")

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target fields of a batch of rows.

    `input` is an iterable of 7-field rows
    ``(id, sentence, target, source_sentiment, target_sentiment, knowledge, label_distribution)``.
    Each of the three sequence fields is converted with `torch.tensor` and
    right-padded with 0 to the longest entry in the batch. `maxlen` is accepted
    but unused.

    NOTE(review): the 7-field rows assembled in `read_input_files` carry strings
    in these fields, which `torch.tensor` cannot convert. This looks like unused
    legacy code — confirm before relying on it.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples.
    """
    _, sentences, targets, _, _, knowledge, label_distribution = zip(*input)

    def pad(seqs):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(targets), label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If `max_batch_size` is positive, it is the maximum number of sentences per
    batch. Otherwise `-max_batch_size` is a token budget per batch.
    When `batch_equal_size` is truthy, only sentences of identical length share
    a batch. Every batch holds at least one sentence.

    Returns:
      list of lists of indices into `sentences`.
    """
    batches_of_sentence_ids = []
    if batch_equal_size:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Token budget: keep at least one sentence per batch so range() never
          # gets a zero step, and avoid dividing by a zero sentence length.
          batch_size = max(1, int(-max_batch_size / max(sentence_length, 1)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
            or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= -max_batch_size):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def _prepend_marker(self, sent, target):
    """Return `target` prefixed with the discourse marker `self.pipe` predicts for (sent, target).

    Underscores in the predicted label become spaces and its first letter is
    upper-cased; the first letter of `target` is lower-cased. Slicing (instead
    of `[0]`) keeps an empty marker or target from raising IndexError.
    """
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"].replace("_", " ")
    return ds_marker[:1].upper() + ds_marker[1:] + " " + target[:1].lower() + target[1:]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated relation file, inject discourse markers, and tokenize it.

    Columns by position: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. The file is decoded as
    ISO-8859-1 and blank lines are skipped.

    The target sentence is prefixed with the discourse marker predicted for the
    pair.

    Label mapping:
      - supports/support/because -> [1, 0]
      - attacks/attack/but -> [0, 1]
      - anything else -> [0, 0]

    `max_sentence_length` is accepted for interface compatibility but unused.

    Returns:
      the list produced by `_get_examples_concatenated`.
    """
    # Adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []
    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        fields = line.replace("\n", "").split("\t")
        # Blank line: "\r" remains for CRLF files, "" for LF files.
        if len(fields) == 1 and not fields[0].strip():
          continue

        story_id = fields[0]
        sent = fields[1].strip()
        target = self._prepend_marker(sent, fields[3].strip())

        label = fields[-1].strip()
        if label in ("supports", "support", "because"):
          label_distribution = [1, 0]
        elif label in ("attacks", "attack", "but"):
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        facts = fields[-2].strip().translate(self._FACTS_JUNK)
        result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Debate-data processor that prepends a predicted discourse marker to every target sentence.

  A text-classification pipeline (``sileod/roberta-base-discourse-marker-prediction``) is run on
  each (source, target) pair; the predicted marker (underscores -> spaces, first letter
  capitalised) is prepended to the target, whose own first letter is lower-cased.
  """

  # Raw label strings mapped to the 2-wide one-hot encoding; any other label becomes [0, 0].
  _SUPPORT_LABELS = ('supports', 'support', 'because')
  _ATTACK_LABELS = ('attacks', 'attack', 'but')

  def __init__(self):
    super(DebateWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Pad the id sequences of a batch of records to a common length.

    :param input: iterable of 7-field records ``(id, sentence, target, source_sentiment,
        target_sentiment, knowledge, label_distribution)``; sentence, target and knowledge must
        be convertible with ``torch.tensor`` (presumably lists of token ids).
    :param maxlen: unused.
    :return: list of ``(sentence, knowledge, target, label_distribution)`` tuples; each tensor
        field is right-padded with 0 to the longest entry of that field (batch-first).
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(seqs):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=0)

    return list(zip(pad(sentences), pad(knowledge), pad(target), label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If ``max_batch_size`` is positive it is the maximum number of sentences per batch. If it is
    zero or negative, ``abs(max_batch_size)`` is a token budget: batches are sized so that
    ``len(batch) * sentence_length`` stays around that budget.

    :param sentences: sequence of sentences (anything with ``len``).
    :param batch_equal_size: when True, every batch only contains sentences of one length.
    :return: list of lists of sentence indices.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      ids_by_length = {}
      for i, sentence in enumerate(sentences):
        ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        elif sentence_length == 0:
          # Zero-length sentences cost no tokens: keep them in a single batch.
          batch_size = len(ids)
        else:
          # At least one sentence per batch, otherwise range() below gets a step of 0
          # when a single sentence already exceeds the token budget.
          batch_size = max(1, int(-max_batch_size / sentence_length))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def _inject_discourse_marker(self, sent, target):
    """Return ``target`` prefixed by the discourse marker predicted for ``(sent, target)``."""
    # The pair is given to the classifier as one string joined by RoBERTa's pair separator.
    ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
    ds_marker = ds_marker.replace("_", " ")
    # Slicing instead of [0] keeps empty strings from raising IndexError.
    ds_marker = ds_marker[:1].upper() + ds_marker[1:]
    return ds_marker + " " + target[:1].lower() + target[1:]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated debate file and build concatenated examples from it.

    Columns read per line: 0 = story id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts/knowledge string, last = relation label. Blank lines are skipped.

    :param file_path: path of a single ISO-8859-1 encoded file.
    :param max_sentence_length: unused.
    :param name: split name forwarded to ``_get_examples_concatenated``.
    :return: the result of ``self._get_examples_concatenated`` on the records
        ``[story_id, source, marker + target, [], [], facts, one_hot_label]``.
    """
    records = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        # Skip blank lines, including the lone "\r" left behind by CRLF files.
        if not raw_line.strip():
          continue
        fields = raw_line.replace("\n", "").split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = self._inject_discourse_marker(sent, fields[3].strip())
        label = fields[-1].strip()

        # Strip markup characters from the knowledge column (both parentheses, not just "(").
        facts = fields[-2].strip()
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in self._SUPPORT_LABELS:
          one_hot = [1, 0]
        elif label in self._ATTACK_LABELS:
          one_hot = [0, 1]
        else:
          one_hot = [0, 0]

        records.append([story_id, sent, target, [], [], facts, one_hot])

    return self._get_examples_concatenated(records, name)

In [None]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal layer.

    ``GRLayer.apply(x, lmbd)`` returns ``x`` unchanged; during backpropagation the gradient
    that reaches ``x`` is multiplied by ``-lmbd``. ``lmbd`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Keep the scale as a 0-dim tensor for the backward pass.
        ctx.lmbd = torch.tensor(lmbd)
        # Identity: a same-shape view of the input.
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output.clone()
        # Second entry: no gradient for the lmbd argument.
        return ctx.lmbd * reversed_grad, None

class DoubleAdversarialNet(torch.nn.Module):
  """Cross-attention relation classifier with an auxiliary head and three gradient-reversal discriminators.

  Sentence-2 tokens (segment id 1) attend over sentence-1 tokens (segment id 0) of one
  concatenated input; the mean of the attention output feeds the classification heads.

  In training mode the batch is expected in two halves (the layout ``BalancedSampler`` yields):
  the first ``B // 2`` rows are main-task examples with 2-wide one-hot labels, the remaining rows
  auxiliary examples with 3-wide one-hot labels. Three binary discriminators see
  gradient-reversed, mean-pooled token embeddings:

  * ``task_linear``    - every row (which half it came from);
  * ``attack_linear``  - main-task ``[0, 1]`` rows followed by auxiliary ``[1, 0, 0]`` rows;
  * ``support_linear`` - main-task ``[1, 0]`` rows followed by auxiliary ``[0, 1, 0]``/``[0, 0, 1]`` rows.

  Hyper-parameters are read from the global ``args`` dict.
  """
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    # The discriminators are binary: their training targets are always 2-wide one-hot rows,
    # so their width must not follow args["num_classes"].
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    # MultiheadAttention keeps PyTorch's default initialisation (_init_weights does not handle it).
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @staticmethod
  def _rows_equal(labels, *rows):
    """Boolean mask of the rows of 2-D ``labels`` equal to any of the given one-hot ``rows``."""
    mask = torch.zeros(labels.shape[0], dtype=torch.bool, device=labels.device)
    for row in rows:
      mask |= (labels == labels.new_tensor(row)).all(dim=1)
    return mask

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """Run the encoder and the heads.

    :param position_sep: unused; kept for a uniform call signature.
    :param labels: one-hot labels (list of rows or tensor); only read in training mode.
    :return: eval mode - main-task logits; training mode - ``(predictions, predictions_adv,
        task_prediction, attack_prediction, support_prediction)``.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # Zero out everything but sentence 1 (segment 0) resp. sentence 2 (segment 1).
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    # Sentence-2 queries attend over sentence-1 keys/values.
    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if not self.training:
      return self.linear_layer(H_sent)

    half = H_sent.shape[0] // 2
    # Build label tensors on the activations' device instead of the notebook-global one.
    labels_std = torch.as_tensor(labels[:half], device=H_sent.device)
    labels_adv = torch.as_tensor(labels[half:], device=H_sent.device)

    embed_sent1_std = embed_sent1[:half]
    embed_sent1_adv = embed_sent1[half:]

    emb_attack = embed_sent1_std[self._rows_equal(labels_std, [0, 1])]
    emb_support = embed_sent1_std[self._rows_equal(labels_std, [1, 0])]
    emb_contr = embed_sent1_adv[self._rows_equal(labels_adv, [1, 0, 0])]
    emb_other = embed_sent1_adv[self._rows_equal(labels_adv, [0, 1, 0], [0, 0, 1])]

    predictions = self.linear_layer(H_sent[:half])
    predictions_adv = self.linear_layer_adv(H_sent[half:])

    # Gradient reversal (scale 0.01) on mean-pooled token embeddings. Row order of the
    # concatenations (main-task group first, auxiliary group second) must match the
    # discriminator targets built by the training loop.
    mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
    mean_grl_attack = GRLayer.apply(torch.mean(torch.cat([emb_attack, emb_contr], dim=0), dim=1), .01)
    mean_grl_support = GRLayer.apply(torch.mean(torch.cat([emb_support, emb_other], dim=0), dim=1), .01)

    task_prediction = self.task_linear(mean_grl)
    attack_prediction = self.attack_linear(mean_grl_attack)
    support_prediction = self.support_linear(mean_grl_support)

    return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction

class AdversarialNet(torch.nn.Module):
  """Cross-attention relation classifier with an auxiliary head and one gradient-reversal discriminator.

  Sentence-2 tokens (segment id 1) attend over sentence-1 tokens (segment id 0) of one
  concatenated input; the mean of the attention output feeds the heads. In training mode the
  first ``B // 2`` rows go to ``linear_layer`` (main task), the rest to ``linear_layer_adv``
  (auxiliary task), and ``task_linear`` is a 2-way discriminator over gradient-reversed,
  mean-pooled token embeddings of every row. Hyper-parameters come from the global ``args``.
  """
  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # NOTE: creation and init order below determines the seeded initial weights.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is none of the types _init_weights handles.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """Run the encoder and heads.

    :param position_sep: unused; kept for a uniform call signature.
    :param visualize: when True, return the pooled sentence representation ``H_sent`` only.
    :return: ``H_sent`` if ``visualize``; in training mode ``(predictions, predictions_adv,
        task_prediction)``; otherwise main-task logits for every row.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # Zero out everything but sentence 1 (segment 0) resp. sentence 2 (segment 1).
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    # Sentence-2 queries attend over sentence-1 keys/values.
    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # NOTE(review): mean over all positions, padding included.
    H_sent = torch.mean(att_output[0], dim=1)

    if visualize:
      return H_sent
    if self.training:
      # First half of the batch: main task; second half: auxiliary task.
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      # Discriminator on gradient-reversed (scale 0.01) mean-pooled token embeddings.
      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)

      return predictions, predictions_adv, task_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  """Classifier on the first sentence-2 token embedding, gated by the first and last tokens.

  With ``attention=True`` sentence-2 tokens first attend over sentence-1 tokens; otherwise the
  PLM embeddings are used directly. The gate is ``sigmoid(W1 e_0 + W2 e_last)`` where ``e_0`` is
  position 0 and ``e_last`` the last position covered by the attention mask; the gated embedding
  at the first segment-1 position is classified. Hyper-parameters come from the global ``args``.
  """
  def __init__(self, attention):
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.attention = attention
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # forward() classifies a single gated embedding, so the head takes embed_size features
    # (embed_size * 3 made every forward pass fail with a shape mismatch).
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    # Gate projections read by forward(); they must exist or forward() raises AttributeError.
    self.linear_initial_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_end_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.linear_initial_sent)
    self._init_weights(self.linear_end_sent)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for a batch of concatenated sentence pairs.

    :param position_sep: unused; kept for a uniform call signature.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    batch_size, seq_len = att_output.shape[0], att_output.shape[1]
    rows = torch.arange(batch_size, device=att_output.device)

    initial_sent1 = att_output[:, 0, :]
    # Descending weights seq_len..1 make argmax pick the FIRST position with a nonzero segment id
    # (assumes segment ids are 0/1; falls back to position 0 if there is none).
    descending = torch.arange(seq_len, 0, -1, device=segs_sent1.device)
    first_seg2_pos = torch.argmax(segs_sent1 * descending, dim=-1)
    initial_sent2 = att_output[rows, first_seg2_pos]
    # Last position covered by the attention mask (assumes right padding).
    end_sent2 = att_output[rows, torch.sum(att_mask_sent1, dim=-1) - 1]

    gate = self.sigmoid(self.linear_initial_sent(initial_sent1) + self.linear_end_sent(end_sent2))

    final_emb = initial_sent2 * gate

    return self.linear_layer(final_emb)


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Concatenated-input classifier with optional sentence-2 -> sentence-1 cross-attention.

  With ``attention=True`` tokens of segment 1 attend over tokens of segment 0 and the attention
  output is mean-pooled; otherwise the PLM token embeddings are mean-pooled directly. A linear
  layer maps the pooled vector to ``args["num_classes"]`` logits.
  """
  def __init__(self, attention):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # NOTE: creation and init order below determines the seeded initial weights.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is none of the types _init_weights handles.
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for a batch of concatenated sentence pairs.

    :param position_sep: unused; kept for a uniform call signature.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      # Zero out everything but sentence 1 (segment 0) resp. sentence 2 (segment 1).
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      # Sentence-2 queries attend over sentence-1 keys/values.
      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

      H_sent = torch.mean(att_output[0], dim=1)
    else:
      # NOTE(review): mean over all positions, padding included.
      H_sent = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_sent)

    return predictions


class BaselineModel(torch.nn.Module):
  """Concatenated-input baseline: mean-pool the PLM token embeddings and apply a linear classifier.

  Hyper-parameters come from the global ``args`` dict.
  """
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Return class logits for a batch of concatenated inputs.

    :param position_sep: unused; kept so all concatenated-input models share one call signature.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # NOTE(review): padding positions are included in this mean (att_mask_sent1 is not applied).
    H_mean1 = torch.mean(embed_sent1, dim=1)

    return self.linear_layer(H_mean1)


class SiameseBaselineModel(torch.nn.Module):
  """Siamese baseline: encode both sentences with one shared PLM, mean-pool each, classify the concatenation.

  Hyper-parameters come from the global ``args`` dict.
  """
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type table with a freshly initialised one holding 4 segment ids.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    # Input is the two pooled sentence vectors side by side.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights ~ N(0, initializer_range) and zero Linear biases; LayerNorm -> (1, 0)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def _pooled(self, ids, segs, att_mask):
    """Mean over positions of the (optionally first/last-averaged) hidden states of one sentence batch."""
    out = self.plm(ids, token_type_ids=segs, attention_mask=att_mask, output_hidden_states=True)
    last, first = out.hidden_states[-1], out.hidden_states[1]
    embed = torch.div((last + first), 2) if self.first_last_avg else last
    # NOTE(review): padding positions are included in this mean (att_mask is not applied).
    return torch.mean(embed, dim=1)

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """Return class logits for a batch of (sentence 1, sentence 2) pairs encoded separately."""
    H_mean1 = self._pooled(ids_sent1, segs_sent1, att_mask_sent1)
    H_mean2 = self._pooled(ids_sent2, segs_sent2, att_mask_sent2)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    return self.linear_layer(H_cat)

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """Seed every random generator the notebook uses so that runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices), and
    makes cuDNN use deterministic kernels without autotuning.

    :param seed: integer seed applied to every generator.
    """
    # Local import: ``random`` is only imported in a later notebook cell, so relying on the
    # global name breaks when this function is called before that cell has run.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning can pick different algorithms from run to run.
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def output_metrics(labels, preds):
    """Print and return accuracy plus macro-averaged precision, recall and F1.

    :param labels: ground-truth labels in any format scikit-learn's metric functions accept
        (``val()`` passes one-hot rows)
    :param preds: predicted labels, same format as ``labels``
    :return: tuple ``(accuracy, precision, recall, f1)``
    """
    scores = (
        accuracy_score(labels, preds),
        precision_score(labels, preds, average="macro"),
        recall_score(labels, preds, average="macro"),
        f1_score(labels, preds, average="macro"),
    )

    for caption, value in zip(('accuracy:', 'precision:', 'recall:', 'f1:'), scores):
        print("{:15}{:<.3f}".format(caption, value))

    return scores

In [None]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler whose every batch is half ``dataset1`` and half ``dataset2``.

    Meant for a ``ConcatDataset(dataset1, dataset2)``: indices of ``dataset2`` are offset by
    ``len(dataset1)``. Each batch lists ``batch_size // 2`` shuffled ``dataset1`` indices followed
    by ``batch_size - batch_size // 2`` shuffled ``dataset2`` indices. Only full batches are
    produced; an epoch ends as soon as either dataset cannot supply its share (or after
    ``total_size // batch_size`` batches), and leftover examples are skipped for that epoch.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _num_batches(self):
        """Number of full batches one pass yields (also what ``__len__`` reports)."""
        half = self.batch_size // 2
        rest = self.batch_size - half
        n_batches = self.total_size // self.batch_size
        if half > 0:
            n_batches = min(n_batches, len(self.dataset1) // half)
        if rest > 0:
            # Without this bound, popping from an exhausted dataset2 raised IndexError.
            n_batches = min(n_batches, len(self.dataset2) // rest)
        return n_batches

    def __iter__(self):
        self.epoch += 1
        half = self.batch_size // 2
        rest = self.batch_size - half

        # Fresh shuffles every epoch from torch's global RNG (dataset1 drawn first).
        indices1 = torch.randperm(len(self.dataset1)).tolist()
        indices2 = (torch.randperm(len(self.dataset2)) + len(self.dataset1)).tolist()

        for _ in range(self._num_batches()):
            batch = [indices1.pop() for _ in range(half)]
            batch.extend(indices2.pop() for _ in range(rest))
            yield batch

    def __len__(self):
        # Must match what __iter__ yields; the old ceil(total / batch_size) overstated it.
        return self._num_batches()

In [None]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

# Model / optimisation hyper-parameters, read by the model classes as args["..."].
args = dict(
    model_name="roberta-base",  # passed to AutoModel.from_pretrained
    num_classes=2,              # width of the main classification head (3 was also tried)
    num_classes_adv=3,          # width of the auxiliary head (174 was also tried)
    embed_size=768,             # feature size the heads expect; must equal the PLM hidden size
    first_last_avg=True,        # average hidden_states[1] and hidden_states[-1] instead of using only the last
    seed=[0, 1, 2],
    batch_size=64,
    epochs=30,
    class_weight=1,             # values also tried: [9.375, 30, 1]; 10 [2.071, 1.933, 1]
    lr=1e-5,
)

# Experiment switches.
config = dict(
    dataset="debate",           # alternatives: "student_essay", "m-arg"
    adversarial=False,          # train(): three-output adversarial objective
    double_adversarial=True,    # train()/val(): five-output objective (checked after "adversarial")
    dataset_from_saved=True,
    injection=False,
    grid_search=True,
    visualize=True,             # visualize() returns early unless "adversarial" is True
    train=True,
    scheduler=False,
    attention=True,
    cue_gating=False,
)

def _labels_to_device(values):
    """Convert a label batch (Python list of rows, or tensor) to a tensor on the global ``device``."""
    if isinstance(values, list):
        values = np.array(values)
    return torch.as_tensor(values, device=device)


def _count_label_rows(labels, *rows):
    """Count the rows of 2-D ``labels`` equal to any of the given one-hot ``rows``."""
    if labels.numel() == 0:
        return 0
    mask = torch.zeros(labels.shape[0], dtype=torch.bool, device=labels.device)
    for row in rows:
        mask |= (labels == labels.new_tensor(row)).all(dim=1)
    return int(mask.sum().item())


def _split_adversarial_targets(labels):
    """Split a mixed batch into main-task, auxiliary and half-discriminator targets.

    The first ``len(labels) // 2`` rows are main-task labels and the rest auxiliary labels
    (the layout ``BalancedSampler`` produces). The discriminator target is ``[0, 1]`` for every
    main-task row and ``[1, 0]`` for every auxiliary row.
    """
    half = len(labels) // 2
    targets = _labels_to_device(labels[:half])
    targets_adv = _labels_to_device(labels[half:])
    # Second half sized from the actual remainder so odd batch sizes still line up with the model output.
    targets_task = _labels_to_device([[0, 1]] * half + [[1, 0]] * (len(labels) - half))
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    """Train ``model`` for one epoch; print timing, summed loss and the current learning rate.

    The objective depends on the global ``config``:

    * ``config["adversarial"]`` - the model returns ``(pred, pred_adv, task_pred)``; loss is
      ``loss_fn(pred) + discovery_weight * CE(pred_adv) + adv_weight * CE(task_pred)``.
    * ``config["double_adversarial"]`` - the model returns two more discriminator outputs
      (attack, support) whose cross-entropy terms are added unweighted. A term is skipped when
      its group is empty in the batch: cross-entropy over zero rows is not a valid loss and
      would poison the weights.
    * otherwise - ``loss_fn(model_output, labels)``.

    :param epoch: zero-based epoch index (printed as ``epoch + 1``).
    :param loss_fn: loss for the main task, called with float one-hot targets.
    :param scheduler: optional LR scheduler, stepped after every optimizer step.
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Tensors go to the device; list-valued fields (labels of mixed width) stay lists.
        batch = tuple(t if isinstance(t, list) else t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            loss = loss_fn(pred, targets.float()) \
                + discovery_weight * loss_fn2(pred_adv, targets_adv.float()) \
                + adv_weight * loss_fn2(task_pred, targets_task.float())
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            # Row order must match the model's concatenation: main-task group first, auxiliary second.
            attack_len = _count_label_rows(targets, [0, 1])
            support_len = _count_label_rows(targets, [1, 0])
            contr_len = _count_label_rows(targets_adv, [1, 0, 0])
            other_len = _count_label_rows(targets_adv, [0, 1, 0], [0, 0, 1])
            attack_target = _labels_to_device([[0, 1]] * attack_len + [[1, 0]] * contr_len)
            support_target = _labels_to_device([[0, 1]] * support_len + [[1, 0]] * other_len)

            loss = loss_fn(pred, targets.float()) \
                + discovery_weight * loss_fn2(pred_adv, targets_adv.float()) \
                + adv_weight * loss_fn2(task_pred, targets_task.float())
            if attack_target.shape[0] > 0:
                loss = loss + loss_fn2(attack_pred, attack_target.float())
            if support_target.shape[0] > 0:
                loss = loss + loss_fn2(support_pred, support_target.float())
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            loss = loss_fn(out, _labels_to_device(labels).float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return its classification metrics.

    Each batch is moved to ``device`` and unpacked as
    (ids, segment ids, attention mask, separator position, labels), where
    ``labels`` has one row per example whose width is the number of classes
    (presumably one-hot / soft targets — TODO confirm against the processors).

    The *summed* (not averaged) cross-entropy over all batches is printed as
    "val loss". Predictions are converted to one-hot rows of the same width as
    the labels so ``output_metrics`` receives matching formats for any number
    of classes.

    Args:
        model: network to evaluate; switched to ``eval()`` mode and left there.
        val_loader: iterable of batches as described above.

    Returns:
        (accuracy, precision, recall, f1) as returned by ``output_metrics``.
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
            if config["double_adversarial"]:
                # the double-adversarial network's forward also receives the labels
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            else:
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            loss = loss_fn(out, labels.float())
            val_loss += loss.item()

            num_classes = labels.shape[-1]
            # An argmax index beyond the label width is mapped to the last class;
            # then one-hot encode so predictions have exactly the labels' shape.
            pred_idx = out.argmax(dim=1).clamp(max=num_classes - 1)
            preds = nn.functional.one_hot(pred_idx, num_classes=num_classes)

        val_labels.extend(labels.cpu().numpy().tolist())
        val_preds.extend(preds.cpu().numpy().tolist())

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1

def _plot_tsne_scatter(df_points, title):
  """Show one scatter plot of a t-SNE frame with columns ``x``, ``y`` and ``label`` (used as hue)."""
  fig, ax = plt.subplots(figsize=(8,6))
  sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(title)
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches = 20):
  """Plot 2-D t-SNE projections of the model's embeddings (only when config["adversarial"]).

  Two figures are produced: one for every batch of ``test_dataloader`` (hue = argmax
  of the label rows) and one for the first ``max_adv_batches`` batches of
  ``train_adv_dataloader`` (hue = argmax of the label rows + 2).

  Args:
    epoch: unused; kept for call compatibility.
    model: network whose forward accepts ``visualize=True`` and then returns embeddings.
    test_dataloader: batches of (ids, segs, mask, sep position, labels) tensors.
    train_adv_dataloader: same layout; labels may arrive as a Python list.
    discovery_weight, adv_weight: only shown in the plot titles.
    max_adv_batches: cap on adversarial batches to embed (``None`` = all); default 20.
  """
  if not config["adversarial"]:
    return

  model.eval()

  # Collect per-batch tensors in lists and concatenate once (avoids quadratic torch.cat).
  test_labels, test_embeddings = [], []
  adv_labels, adv_embeddings = [], []

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    test_labels.append(torch.argmax(labels, dim=-1))
    with torch.no_grad():
      test_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if max_adv_batches is not None and i >= max_adv_batches:
      break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    # labels may be a list (from the adv collate fn) or an already-moved tensor
    if isinstance(labels, list):
      labels = torch.tensor(np.array(labels))
    labels = labels.to(device)
    # +2 presumably keeps these ids distinct from the task label ids — TODO confirm intent
    adv_labels.append(torch.argmax(labels, dim=-1) + 2)
    with torch.no_grad():
      adv_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  if not test_embeddings or not adv_embeddings:
    print("Nothing to visualize: a dataloader produced no batches")
    return

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(torch.cat(test_embeddings, dim=0).cpu().numpy())
  tsne_results_adv = tsne.fit_transform(torch.cat(adv_embeddings, dim=0).cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = torch.cat(test_labels, dim=0).cpu().numpy()
  df_tsne_adv["label"] = torch.cat(adv_labels, dim=0).cpu().numpy()

  print(df_tsne_adv["label"].unique())

  # Set the style BEFORE creating any figure: it only affects axes created afterwards.
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  title = f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}'
  _plot_tsne_scatter(df_tsne, title)
  _plot_tsne_scatter(df_tsne_adv, title)

def _make_processor():
  """Return (processor, path_train, path_dev, path_test) for config["dataset"].

  For "nk" the dev/test paths are None: that dataset is split from the training file.
  Raises ValueError for an unknown dataset name.
  """
  dataset_name = config["dataset"]
  if dataset_name == "student_essay":
    processor = StudentEssayWithDiscourseInjectionProcessor() if config["injection"] else StudentEssayProcessor()
    return (processor,
            "./data/student_essay/train_essay.txt",
            "./data/student_essay/dev_essay.txt",
            "./data/student_essay/test_essay.txt")
  if dataset_name == "debate":
    processor = DebateWithDiscourseInjectionProcessor() if config["injection"] else DebateProcessor()
    return (processor,
            "./data/debate/train_debate_concept.txt",
            "./data/debate/dev_debate_concept.txt",
            "./data/debate/test_debate_concept.txt")
  if dataset_name == "m-arg":
    processor = MARGWithDiscourseInjectionProcessor() if config["injection"] else MARGProcessor()
    # NOTE(review): train/dev/test all read the same CSV; presumably the processor splits
    # on `name` — confirm, otherwise dev/test overlap with train.
    path = "./data/m-arg/presidential_final.csv"
    return processor, path, path, path
  if dataset_name == "nk":
    processor = NKWithDiscourseInjectionProcessor() if config["injection"] else NKProcessor()
    return processor, "./data/nk/balanced_dataset.tsv", None, None
  raise ValueError(f"{dataset_name} is not a valid dataset name (choose from 'student_essay', 'debate', 'm-arg' and 'nk')")


def _split_holdout(examples):
  """Split a list into (train, dev, test): first 10% -> dev, last 10% -> test, rest -> train.

  Uses explicit end indices so fewer than 10 examples does not turn `[-0:]` into
  "the whole list" (which made test == everything and train empty).
  """
  n = len(examples)
  n_hold = n // 10
  return examples[n_hold:n - n_hold], examples[:n_hold], examples[n - n_hold:]


def _build_model():
  """Instantiate the network selected by config (double-adversarial, baseline, or adversarial)."""
  if config["double_adversarial"]:
    return DoubleAdversarialNet()
  if not config["adversarial"]:
    if config["cue_gating"]:
      return BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
    return BaselineModelWithSentenceComparison(attention=config["attention"])
  return AdversarialNet()


def _build_optimizer(model):
  """AdamW (lr=args["lr"]) with weight decay 0.01 except on biases and LayerNorm weights."""
  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(optimizer_grouped_parameters, lr=args["lr"])


def run(seed):
  """Run one full experiment with `seed` and return the selected test metrics.

  Loads the dataset chosen by config["dataset"] (plus the "discovery" discourse-marker
  data when an adversarial mode is on), builds the configured model and trains it.
  For every epoch dev and test are evaluated; the test metrics of the epoch with the
  best dev F1 are kept. With config["grid_search"] one result per (discovery_weight,
  adv_weight) pair is collected and the model is re-created between pairs.

  Returns:
    [accuracy, precision, recall, f1] of the first collected result
    (-1 entries if no epoch was evaluated).
  """
  set_random_seeds(seed)

  processor, path_train, path_dev, path_test = _make_processor()

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    data_train, data_dev, data_test = _split_holdout(data_train)
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  use_adv_data = config["adversarial"] or config["double_adversarial"]
  if use_adv_data:
    df = datasets.load_dataset("discovery","discovery")
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE: pickle.load can execute arbitrary code; only load a cache this script wrote.
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  train_adv_dataloader = None
  if use_adv_data:
    # Both adversarial variants mix task and discourse-marker examples via BalancedSampler.
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)
  else:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  model = _build_model()
  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _build_optimizer(model)

  if config["dataset"] in ["m-arg","nk"]:
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  # Initialised so the report below never hits an unbound name (e.g. epochs == 0 or train disabled).
  best_dev_f1 = -1
  best_test_acc = best_test_pre = best_test_rec = best_test_f1 = -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0.4,1.2,0.2)
    adv_weights = [1]  # full sweep would be np.arange(0, 1.2, 0.2)

    for discovery_weight in range_disc:
      for adv_weight in adv_weights:
        for epoch in range(args["epochs"]):
          print('===== Start training: epoch {} ====='.format(epoch + 1))
          print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
          train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
          dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
          test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
          if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1

        print('best result:')
        print(best_test_acc)
        print(best_test_pre)
        print(best_test_rec)
        print(best_test_f1)
        result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
        del model
        del optimizer

        # Fresh, identically seeded model of the *configured* type for the next pair.
        set_random_seeds(seed)
        model = _build_model()
        model = model.to(device)
        optimizer = _build_optimizer(model)

        best_dev_f1 = -1
        best_test_acc = best_test_pre = best_test_rec = best_test_f1 = -1
  else:
    scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-1, total_iters = 30) if config["scheduler"] else None
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if scheduler is not None:
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_f1 = dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

    if config["visualize"] and config["adversarial"]:
      model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
      # last epoch index (0 when no epochs ran, instead of an unbound `epoch`)
      visualize(max(args["epochs"] - 1, 0), model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  return result_metrics[0]

if __name__ == "__main__":
  # Run the full experiment once per configured seed (seed 0 is skipped) and print the
  # element-wise mean of the returned [accuracy, precision, recall, f1] lists.
  results = []
  for seed in args["seed"]:
    if seed == 0: continue
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results.append(result_metrics)
  if results:
    avg = torch.mean(torch.tensor(results), dim=0)
    print(avg)
  else:
    # torch.mean over an empty tensor would silently print nan
    print("no runs completed (seed list empty or only seed 0); nothing to average")

**** trying with seed 1 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 6486/6486 [00:02<00:00, 2904.36it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 2164/2164 [00:00<00:00, 2676.66it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 2162/2162 [00:00<00:00, 2839.80it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


  labels_adv = torch.tensor(labels[batch_size // 2:]).to(device)
Iteration:   4%|▎         | 202/5443 [00:35<15:19,  5.70it/s]


Timing: 35.4311318397522, Epoch: 1, training loss: 672.6889321804047, current learning rate 1e-05
val loss: 23.51508218050003
accuracy:      0.544
precision:     0.558
recall:        0.543
f1:            0.511
val loss: 23.55283659696579
accuracy:      0.532
precision:     0.568
recall:        0.545
f1:            0.496
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57761740684509, Epoch: 2, training loss: 982.3146138191223, current learning rate 1e-05
val loss: 20.68849152326584
accuracy:      0.678
precision:     0.704
recall:        0.677
f1:            0.666
val loss: 20.99238896369934
accuracy:      0.670
precision:     0.703
recall:        0.678
f1:            0.661
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59173560142517, Epoch: 3, training loss: 1100.918616771698, current learning rate 1e-05
val loss: 18.50782334804535
accuracy:      0.743
precision:     0.752
recall:        0.743
f1:            0.741
val loss: 18.974529027938843
accuracy:      0.718
precision:     0.732
recall:        0.723
f1:            0.716
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.64956521987915, Epoch: 4, training loss: 1121.9929451942444, current learning rate 1e-05
val loss: 18.352797627449036
accuracy:      0.754
precision:     0.755
recall:        0.754
f1:            0.754
val loss: 18.061638951301575
accuracy:      0.751
precision:     0.751
recall:        0.749
f1:            0.749
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.61758542060852, Epoch: 5, training loss: 1095.024887561798, current learning rate 1e-05
val loss: 20.466246634721756
accuracy:      0.745
precision:     0.746
recall:        0.745
f1:            0.745
val loss: 19.703871190547943
accuracy:      0.747
precision:     0.748
recall:        0.746
f1:            0.746
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.72006678581238, Epoch: 6, training loss: 1011.033688545227, current learning rate 1e-05
val loss: 21.96000063419342
accuracy:      0.751
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 22.160606563091278
accuracy:      0.747
precision:     0.747
recall:        0.746
f1:            0.746
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:01,  5.82it/s]


Timing: 34.739649295806885, Epoch: 7, training loss: 877.6682722568512, current learning rate 1e-05
val loss: 23.866338044404984
accuracy:      0.754
precision:     0.756
recall:        0.754
f1:            0.754
val loss: 23.922076731920242
accuracy:      0.752
precision:     0.755
recall:        0.754
f1:            0.752
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.61404895782471, Epoch: 8, training loss: 766.6637165546417, current learning rate 1e-05
val loss: 22.264346450567245
accuracy:      0.759
precision:     0.761
recall:        0.759
f1:            0.759
val loss: 22.349476873874664
accuracy:      0.766
precision:     0.768
recall:        0.768
f1:            0.766
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.569072008132935, Epoch: 9, training loss: 654.9602193832397, current learning rate 1e-05
val loss: 26.423242688179016
accuracy:      0.744
precision:     0.745
recall:        0.745
f1:            0.744
val loss: 26.577793806791306
accuracy:      0.744
precision:     0.745
recall:        0.742
f1:            0.743
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54888319969177, Epoch: 10, training loss: 564.7366554737091, current learning rate 1e-05
val loss: 31.007622331380844
accuracy:      0.755
precision:     0.759
recall:        0.755
f1:            0.754
val loss: 31.121280014514923
accuracy:      0.751
precision:     0.757
recall:        0.754
f1:            0.751
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59877300262451, Epoch: 11, training loss: 495.88457584381104, current learning rate 1e-05
val loss: 36.70407724380493
accuracy:      0.748
precision:     0.750
recall:        0.748
f1:            0.747
val loss: 36.75815498828888
accuracy:      0.743
precision:     0.744
recall:        0.741
f1:            0.741
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.56117630004883, Epoch: 12, training loss: 428.68018186092377, current learning rate 1e-05
val loss: 36.872279822826385
accuracy:      0.750
precision:     0.751
recall:        0.751
f1:            0.751
val loss: 36.41526639461517
accuracy:      0.759
precision:     0.760
recall:        0.760
f1:            0.759
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.583630084991455, Epoch: 13, training loss: 385.3619467020035, current learning rate 1e-05
val loss: 34.48030507564545
accuracy:      0.752
precision:     0.754
recall:        0.752
f1:            0.752
val loss: 34.88252955675125
accuracy:      0.751
precision:     0.754
recall:        0.753
f1:            0.751
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.48661732673645, Epoch: 14, training loss: 355.3525849580765, current learning rate 1e-05
val loss: 44.13727152347565
accuracy:      0.745
precision:     0.746
recall:        0.745
f1:            0.745
val loss: 45.34925550222397
accuracy:      0.743
precision:     0.746
recall:        0.745
f1:            0.743
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.86it/s]


Timing: 34.46142888069153, Epoch: 15, training loss: 334.74223470687866, current learning rate 1e-05
val loss: 45.63799250125885
accuracy:      0.748
precision:     0.748
recall:        0.748
f1:            0.748
val loss: 48.04848635196686
accuracy:      0.742
precision:     0.744
recall:        0.743
f1:            0.742
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.47974371910095, Epoch: 16, training loss: 322.48500871658325, current learning rate 1e-05
val loss: 45.65902614593506
accuracy:      0.744
precision:     0.745
recall:        0.744
f1:            0.744
val loss: 46.375114262104034
accuracy:      0.750
precision:     0.752
recall:        0.751
f1:            0.750
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58361625671387, Epoch: 17, training loss: 340.20488250255585, current learning rate 1e-05
val loss: 49.28996330499649
accuracy:      0.746
precision:     0.746
recall:        0.747
f1:            0.746
val loss: 48.404049932956696
accuracy:      0.749
precision:     0.749
recall:        0.749
f1:            0.749
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.476654052734375, Epoch: 18, training loss: 364.84057080745697, current learning rate 1e-05
val loss: 49.93456047773361
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 50.1800891160965
accuracy:      0.753
precision:     0.752
recall:        0.752
f1:            0.752
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.47538614273071, Epoch: 19, training loss: 374.5475867986679, current learning rate 1e-05
val loss: 45.9968495965004
accuracy:      0.755
precision:     0.755
recall:        0.755
f1:            0.755
val loss: 45.51885658502579
accuracy:      0.758
precision:     0.757
recall:        0.758
f1:            0.757
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.49738001823425, Epoch: 20, training loss: 375.7521468400955, current learning rate 1e-05
val loss: 43.010222136974335
accuracy:      0.748
precision:     0.753
recall:        0.748
f1:            0.746
val loss: 43.02838450670242
accuracy:      0.746
precision:     0.752
recall:        0.749
f1:            0.745
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.4958381652832, Epoch: 21, training loss: 371.3248063325882, current learning rate 1e-05
val loss: 42.63000249862671
accuracy:      0.762
precision:     0.764
recall:        0.763
f1:            0.762
val loss: 42.35555648803711
accuracy:      0.763
precision:     0.764
recall:        0.764
f1:            0.763
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.46528601646423, Epoch: 22, training loss: 369.55856013298035, current learning rate 1e-05
val loss: 50.03130483627319
accuracy:      0.762
precision:     0.765
recall:        0.762
f1:            0.761
val loss: 49.256685972213745
accuracy:      0.759
precision:     0.763
recall:        0.761
f1:            0.759
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.5636088848114, Epoch: 23, training loss: 365.93786013126373, current learning rate 1e-05
val loss: 51.584200382232666
accuracy:      0.759
precision:     0.760
recall:        0.759
f1:            0.759
val loss: 51.701335430145264
accuracy:      0.757
precision:     0.757
recall:        0.758
f1:            0.757
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54482102394104, Epoch: 24, training loss: 356.4238328933716, current learning rate 1e-05
val loss: 63.10585165023804
accuracy:      0.754
precision:     0.755
recall:        0.754
f1:            0.754
val loss: 61.15479397773743
accuracy:      0.757
precision:     0.759
recall:        0.758
f1:            0.757
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.86it/s]


Timing: 34.50326371192932, Epoch: 25, training loss: 331.5799437761307, current learning rate 1e-05
val loss: 64.25281369686127
accuracy:      0.757
precision:     0.757
recall:        0.758
f1:            0.758
val loss: 63.599738121032715
accuracy:      0.761
precision:     0.760
recall:        0.761
f1:            0.761
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.47772669792175, Epoch: 26, training loss: 313.77783596515656, current learning rate 1e-05
val loss: 53.95468312501907
accuracy:      0.760
precision:     0.763
recall:        0.760
f1:            0.759
val loss: 53.26941227912903
accuracy:      0.753
precision:     0.757
recall:        0.755
f1:            0.753
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57344365119934, Epoch: 27, training loss: 318.087229013443, current learning rate 1e-05
val loss: 55.18224322795868
accuracy:      0.758
precision:     0.758
recall:        0.759
f1:            0.758
val loss: 54.569834649562836
accuracy:      0.759
precision:     0.759
recall:        0.759
f1:            0.758
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54012417793274, Epoch: 28, training loss: 304.65256452560425, current learning rate 1e-05
val loss: 62.304373264312744
accuracy:      0.760
precision:     0.760
recall:        0.761
f1:            0.760
val loss: 61.65096306800842
accuracy:      0.766
precision:     0.766
recall:        0.766
f1:            0.766
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52547740936279, Epoch: 29, training loss: 302.49224627017975, current learning rate 1e-05
val loss: 49.65616816282272
accuracy:      0.756
precision:     0.757
recall:        0.756
f1:            0.755
val loss: 50.4651620388031
accuracy:      0.749
precision:     0.752
recall:        0.751
f1:            0.749
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.496989011764526, Epoch: 30, training loss: 308.8113968372345, current learning rate 1e-05
val loss: 51.2683539390564
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
val loss: 50.76712042093277
accuracy:      0.761
precision:     0.760
recall:        0.760
f1:            0.760
best result:
0.7608695652173914
0.7604402483364203
0.7603766144138815
0.7604069605306847


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.526283979415894, Epoch: 1, training loss: 732.4561681747437, current learning rate 1e-05
val loss: 23.460149347782135
accuracy:      0.559
precision:     0.576
recall:        0.558
f1:            0.530
val loss: 23.503870248794556
accuracy:      0.543
precision:     0.577
recall:        0.555
f1:            0.513
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.62705755233765, Epoch: 2, training loss: 1085.897569179535, current learning rate 1e-05
val loss: 20.774281442165375
accuracy:      0.675
precision:     0.707
recall:        0.674
f1:            0.661
val loss: 21.00206595659256
accuracy:      0.669
precision:     0.709
recall:        0.677
f1:            0.659
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.68555927276611, Epoch: 3, training loss: 1206.4622478485107, current learning rate 1e-05
val loss: 18.338292062282562
accuracy:      0.725
precision:     0.731
recall:        0.725
f1:            0.723
val loss: 18.561100363731384
accuracy:      0.713
precision:     0.722
recall:        0.717
f1:            0.712
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.68954634666443, Epoch: 4, training loss: 1218.7142930030823, current learning rate 1e-05
val loss: 18.281585097312927
accuracy:      0.750
precision:     0.752
recall:        0.751
f1:            0.750
val loss: 17.830446183681488
accuracy:      0.748
precision:     0.749
recall:        0.750
f1:            0.748
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.61598300933838, Epoch: 5, training loss: 1142.3589525222778, current learning rate 1e-05
val loss: 19.208124190568924
accuracy:      0.750
precision:     0.751
recall:        0.751
f1:            0.750
val loss: 18.478179544210434
accuracy:      0.749
precision:     0.750
recall:        0.748
f1:            0.748
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.61184215545654, Epoch: 6, training loss: 1048.4127535820007, current learning rate 1e-05
val loss: 23.781272411346436
accuracy:      0.753
precision:     0.753
recall:        0.753
f1:            0.753
val loss: 22.153336495161057
accuracy:      0.753
precision:     0.753
recall:        0.753
f1:            0.753
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.61789345741272, Epoch: 7, training loss: 911.7971301078796, current learning rate 1e-05
val loss: 25.335809379816055
accuracy:      0.754
precision:     0.756
recall:        0.754
f1:            0.753
val loss: 24.291603833436966
accuracy:      0.751
precision:     0.754
recall:        0.753
f1:            0.751
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.61613583564758, Epoch: 8, training loss: 788.6530814170837, current learning rate 1e-05
val loss: 25.309872657060623
accuracy:      0.749
precision:     0.752
recall:        0.749
f1:            0.748
val loss: 24.36504316329956
accuracy:      0.750
precision:     0.754
recall:        0.753
f1:            0.750
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.61848592758179, Epoch: 9, training loss: 672.277321100235, current learning rate 1e-05
val loss: 30.570971071720123
accuracy:      0.748
precision:     0.749
recall:        0.749
f1:            0.748
val loss: 28.837051272392273
accuracy:      0.755
precision:     0.755
recall:        0.755
f1:            0.755
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57712125778198, Epoch: 10, training loss: 587.9115326404572, current learning rate 1e-05
val loss: 35.56189000606537
accuracy:      0.748
precision:     0.759
recall:        0.747
f1:            0.745
val loss: 35.61653959751129
accuracy:      0.742
precision:     0.757
recall:        0.747
f1:            0.741
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58150291442871, Epoch: 11, training loss: 513.3814961910248, current learning rate 1e-05
val loss: 30.61510530114174
accuracy:      0.752
precision:     0.753
recall:        0.752
f1:            0.752
val loss: 29.706574648618698
accuracy:      0.763
precision:     0.763
recall:        0.764
f1:            0.763
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.640477657318115, Epoch: 12, training loss: 456.94525718688965, current learning rate 1e-05
val loss: 34.3598957657814
accuracy:      0.750
precision:     0.752
recall:        0.751
f1:            0.750
val loss: 34.86931562423706
accuracy:      0.746
precision:     0.747
recall:        0.747
f1:            0.746
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60959076881409, Epoch: 13, training loss: 411.2253565788269, current learning rate 1e-05
val loss: 41.44793528318405
accuracy:      0.753
precision:     0.758
recall:        0.753
f1:            0.752
val loss: 40.58051782846451
accuracy:      0.754
precision:     0.763
recall:        0.758
f1:            0.754
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59309363365173, Epoch: 14, training loss: 377.0542186498642, current learning rate 1e-05
val loss: 40.88938879966736
accuracy:      0.755
precision:     0.758
recall:        0.755
f1:            0.754
val loss: 39.765770852565765
accuracy:      0.758
precision:     0.762
recall:        0.760
f1:            0.757
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.53240752220154, Epoch: 15, training loss: 354.1879014968872, current learning rate 1e-05
val loss: 37.44071412086487
accuracy:      0.750
precision:     0.752
recall:        0.751
f1:            0.750
val loss: 37.550378143787384
accuracy:      0.746
precision:     0.749
recall:        0.748
f1:            0.746
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52478647232056, Epoch: 16, training loss: 341.34894037246704, current learning rate 1e-05
val loss: 37.905123591423035
accuracy:      0.754
precision:     0.754
recall:        0.754
f1:            0.754
val loss: 37.55784910917282
accuracy:      0.749
precision:     0.749
recall:        0.748
f1:            0.749
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.5142776966095, Epoch: 17, training loss: 332.8878153562546, current learning rate 1e-05
val loss: 42.80442136526108
accuracy:      0.746
precision:     0.747
recall:        0.746
f1:            0.746
val loss: 41.814244985580444
accuracy:      0.750
precision:     0.752
recall:        0.752
f1:            0.750
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.573376417160034, Epoch: 18, training loss: 326.821212887764, current learning rate 1e-05
val loss: 43.45241552591324
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
val loss: 43.64768436551094
accuracy:      0.738
precision:     0.737
recall:        0.737
f1:            0.737
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55420470237732, Epoch: 19, training loss: 330.5501048564911, current learning rate 1e-05
val loss: 48.13708174228668
accuracy:      0.756
precision:     0.758
recall:        0.756
f1:            0.755
val loss: 49.72427171468735
accuracy:      0.747
precision:     0.751
recall:        0.749
f1:            0.747
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.4693865776062, Epoch: 20, training loss: 326.54101836681366, current learning rate 1e-05
val loss: 47.12648344039917
accuracy:      0.745
precision:     0.750
recall:        0.745
f1:            0.744
val loss: 47.26224994659424
accuracy:      0.737
precision:     0.746
recall:        0.741
f1:            0.736
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.474848985672, Epoch: 21, training loss: 317.9687613248825, current learning rate 1e-05
val loss: 46.982950925827026
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.762
val loss: 48.49547106027603
accuracy:      0.751
precision:     0.751
recall:        0.751
f1:            0.751
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52881121635437, Epoch: 22, training loss: 316.19618713855743, current learning rate 1e-05
val loss: 44.20498675107956
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
val loss: 46.79210251569748
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.48931169509888, Epoch: 23, training loss: 316.9275028705597, current learning rate 1e-05
val loss: 49.52585142850876
accuracy:      0.750
precision:     0.753
recall:        0.750
f1:            0.749
val loss: 49.37129884958267
accuracy:      0.750
precision:     0.754
recall:        0.752
f1:            0.750
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.51950407028198, Epoch: 24, training loss: 320.89205181598663, current learning rate 1e-05
val loss: 46.60332763195038
accuracy:      0.753
precision:     0.753
recall:        0.754
f1:            0.753
val loss: 47.22244119644165
accuracy:      0.747
precision:     0.746
recall:        0.746
f1:            0.746
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.57902455329895, Epoch: 25, training loss: 314.952379822731, current learning rate 1e-05
val loss: 47.814370691776276
accuracy:      0.759
precision:     0.759
recall:        0.760
f1:            0.759
val loss: 48.17445206642151
accuracy:      0.753
precision:     0.754
recall:        0.754
f1:            0.753
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58941030502319, Epoch: 26, training loss: 320.1766735315323, current learning rate 1e-05
val loss: 44.97420734167099
accuracy:      0.754
precision:     0.758
recall:        0.754
f1:            0.753
val loss: 45.22207826375961
accuracy:      0.755
precision:     0.761
recall:        0.758
f1:            0.755
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.6443612575531, Epoch: 27, training loss: 318.1576546430588, current learning rate 1e-05
val loss: 49.38844442367554
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 49.04218250513077
accuracy:      0.753
precision:     0.753
recall:        0.753
f1:            0.753
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.56680393218994, Epoch: 28, training loss: 311.8295202255249, current learning rate 1e-05
val loss: 43.45180368423462
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.758
val loss: 42.295942813158035
accuracy:      0.764
precision:     0.763
recall:        0.763
f1:            0.763
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.578532218933105, Epoch: 29, training loss: 312.7532684803009, current learning rate 1e-05
val loss: 38.11527842283249
accuracy:      0.748
precision:     0.749
recall:        0.748
f1:            0.748
val loss: 36.9763218164444
accuracy:      0.761
precision:     0.763
recall:        0.763
f1:            0.761
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.618794441223145, Epoch: 30, training loss: 313.5445946455002, current learning rate 1e-05
val loss: 38.244065165519714
accuracy:      0.768
precision:     0.768
recall:        0.768
f1:            0.767
val loss: 37.43093055486679
accuracy:      0.759
precision:     0.761
recall:        0.761
f1:            0.759
best result:
0.759481961147086
0.7612853679987988
0.761056886522725
0.7594768154340132


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.57866454124451, Epoch: 1, training loss: 772.203812122345, current learning rate 1e-05
val loss: 22.91184765100479
accuracy:      0.621
precision:     0.622
recall:        0.621
f1:            0.619
val loss: 22.896551489830017
accuracy:      0.631
precision:     0.636
recall:        0.634
f1:            0.631
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.640846490859985, Epoch: 2, training loss: 989.2513203620911, current learning rate 1e-05
val loss: 20.101846039295197
accuracy:      0.687
precision:     0.722
recall:        0.686
f1:            0.674
val loss: 20.2287078499794
accuracy:      0.678
precision:     0.722
recall:        0.687
f1:            0.668
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.64835500717163, Epoch: 3, training loss: 1076.518002986908, current learning rate 1e-05
val loss: 17.70043158531189
accuracy:      0.742
precision:     0.747
recall:        0.742
f1:            0.741
val loss: 17.90121364593506
accuracy:      0.730
precision:     0.737
recall:        0.733
f1:            0.730
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.71414923667908, Epoch: 4, training loss: 1106.012276172638, current learning rate 1e-05
val loss: 18.40463352203369
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 17.478313267230988
accuracy:      0.753
precision:     0.753
recall:        0.752
f1:            0.752
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.68737030029297, Epoch: 5, training loss: 1082.2291860580444, current learning rate 1e-05
val loss: 19.399372160434723
accuracy:      0.751
precision:     0.752
recall:        0.751
f1:            0.751
val loss: 18.451939672231674
accuracy:      0.753
precision:     0.753
recall:        0.752
f1:            0.752
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.65886688232422, Epoch: 6, training loss: 1003.9428586959839, current learning rate 1e-05
val loss: 22.056057423353195
accuracy:      0.759
precision:     0.759
recall:        0.759
f1:            0.759
val loss: 20.227288961410522
accuracy:      0.761
precision:     0.761
recall:        0.761
f1:            0.761
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.649025678634644, Epoch: 7, training loss: 894.9763128757477, current learning rate 1e-05
val loss: 24.07169735431671
accuracy:      0.757
precision:     0.760
recall:        0.757
f1:            0.756
val loss: 22.931541591882706
accuracy:      0.750
precision:     0.755
recall:        0.752
f1:            0.750
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.669694662094116, Epoch: 8, training loss: 800.3005976676941, current learning rate 1e-05
val loss: 26.072477102279663
accuracy:      0.748
precision:     0.755
recall:        0.748
f1:            0.746
val loss: 24.44672918319702
accuracy:      0.760
precision:     0.770
recall:        0.764
f1:            0.760
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60068726539612, Epoch: 9, training loss: 700.1094121932983, current learning rate 1e-05
val loss: 23.85662803053856
accuracy:      0.756
precision:     0.756
recall:        0.756
f1:            0.756
val loss: 22.98112717270851
accuracy:      0.755
precision:     0.755
recall:        0.755
f1:            0.755
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.72255992889404, Epoch: 10, training loss: 620.8901209831238, current learning rate 1e-05
val loss: 32.450878620147705
accuracy:      0.755
precision:     0.765
recall:        0.755
f1:            0.753
val loss: 32.29084002971649
accuracy:      0.753
precision:     0.764
recall:        0.757
f1:            0.752
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.69342064857483, Epoch: 11, training loss: 546.0693151950836, current learning rate 1e-05
val loss: 30.62426632642746
accuracy:      0.766
precision:     0.767
recall:        0.766
f1:            0.766
val loss: 29.14056959748268
accuracy:      0.771
precision:     0.771
recall:        0.771
f1:            0.771
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.71356773376465, Epoch: 12, training loss: 491.5405652523041, current learning rate 1e-05
val loss: 34.36124235391617
accuracy:      0.755
precision:     0.756
recall:        0.755
f1:            0.754
val loss: 34.525192856788635
accuracy:      0.757
precision:     0.759
recall:        0.758
f1:            0.757
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.670390605926514, Epoch: 13, training loss: 456.98223650455475, current learning rate 1e-05
val loss: 34.9645978808403
accuracy:      0.761
precision:     0.763
recall:        0.761
f1:            0.761
val loss: 32.92070156335831
accuracy:      0.774
precision:     0.776
recall:        0.776
f1:            0.774
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:01,  5.82it/s]


Timing: 34.73521852493286, Epoch: 14, training loss: 428.32722651958466, current learning rate 1e-05
val loss: 38.92536002397537
accuracy:      0.762
precision:     0.767
recall:        0.762
f1:            0.761
val loss: 37.12834721803665
accuracy:      0.772
precision:     0.777
recall:        0.775
f1:            0.772
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.67143654823303, Epoch: 15, training loss: 400.6337547302246, current learning rate 1e-05
val loss: 40.464176177978516
accuracy:      0.766
precision:     0.766
recall:        0.767
f1:            0.766
val loss: 38.54677152633667
accuracy:      0.775
precision:     0.775
recall:        0.775
f1:            0.775
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.65561366081238, Epoch: 16, training loss: 383.23870968818665, current learning rate 1e-05
val loss: 36.59586971998215
accuracy:      0.765
precision:     0.766
recall:        0.765
f1:            0.765
val loss: 35.184418857097626
accuracy:      0.762
precision:     0.763
recall:        0.761
f1:            0.761
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59074139595032, Epoch: 17, training loss: 373.42096734046936, current learning rate 1e-05
val loss: 38.37720489501953
accuracy:      0.756
precision:     0.757
recall:        0.757
f1:            0.756
val loss: 36.63338214159012
accuracy:      0.757
precision:     0.757
recall:        0.758
f1:            0.757
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.597347259521484, Epoch: 18, training loss: 364.9994820356369, current learning rate 1e-05
val loss: 40.35134291648865
accuracy:      0.766
precision:     0.767
recall:        0.766
f1:            0.766
val loss: 40.636245906353
accuracy:      0.756
precision:     0.759
recall:        0.758
f1:            0.756
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60066890716553, Epoch: 19, training loss: 356.5056085586548, current learning rate 1e-05
val loss: 39.25174903869629
accuracy:      0.763
precision:     0.764
recall:        0.764
f1:            0.763
val loss: 38.69210457801819
accuracy:      0.755
precision:     0.756
recall:        0.756
f1:            0.755
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.5972101688385, Epoch: 20, training loss: 345.39518773555756, current learning rate 1e-05
val loss: 43.37619423866272
accuracy:      0.743
precision:     0.754
recall:        0.742
f1:            0.739
val loss: 44.146648645401
accuracy:      0.737
precision:     0.754
recall:        0.742
f1:            0.735
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.61787819862366, Epoch: 21, training loss: 335.91384506225586, current learning rate 1e-05
val loss: 38.445125848054886
accuracy:      0.763
precision:     0.763
recall:        0.764
f1:            0.764
val loss: 38.214216470718384
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.572211027145386, Epoch: 22, training loss: 321.77089607715607, current learning rate 1e-05
val loss: 43.24205666780472
accuracy:      0.756
precision:     0.758
recall:        0.757
f1:            0.756
val loss: 41.88534218072891
accuracy:      0.759
precision:     0.762
recall:        0.761
f1:            0.759
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.625913858413696, Epoch: 23, training loss: 319.68824446201324, current learning rate 1e-05
val loss: 42.200169026851654
accuracy:      0.760
precision:     0.760
recall:        0.760
f1:            0.760
val loss: 41.66834253072739
accuracy:      0.760
precision:     0.760
recall:        0.761
f1:            0.760
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.654130935668945, Epoch: 24, training loss: 319.0410907268524, current learning rate 1e-05
val loss: 44.88030084967613
accuracy:      0.764
precision:     0.765
recall:        0.765
f1:            0.764
val loss: 43.78637009859085
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.63444423675537, Epoch: 25, training loss: 313.3033918142319, current learning rate 1e-05
val loss: 45.88456970453262
accuracy:      0.767
precision:     0.767
recall:        0.767
f1:            0.767
val loss: 45.8519486784935
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.585463762283325, Epoch: 26, training loss: 316.18401491642, current learning rate 1e-05
val loss: 46.79458677768707
accuracy:      0.763
precision:     0.771
recall:        0.763
f1:            0.761
val loss: 46.16529858112335
accuracy:      0.757
precision:     0.765
recall:        0.761
f1:            0.757
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.580267667770386, Epoch: 27, training loss: 310.2706184387207, current learning rate 1e-05
val loss: 43.55328053236008
accuracy:      0.766
precision:     0.766
recall:        0.766
f1:            0.766
val loss: 45.109029829502106
accuracy:      0.750
precision:     0.750
recall:        0.749
f1:            0.749
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.587430238723755, Epoch: 28, training loss: 307.51822674274445, current learning rate 1e-05
val loss: 41.37980890274048
accuracy:      0.768
precision:     0.768
recall:        0.768
f1:            0.768
val loss: 40.59495687484741
accuracy:      0.765
precision:     0.764
recall:        0.763
f1:            0.764
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.50858998298645, Epoch: 29, training loss: 305.7791658639908, current learning rate 1e-05
val loss: 40.97696566581726
accuracy:      0.759
precision:     0.759
recall:        0.760
f1:            0.759
val loss: 40.310474932193756
accuracy:      0.753
precision:     0.752
recall:        0.753
f1:            0.752
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.62999749183655, Epoch: 30, training loss: 306.76674485206604, current learning rate 1e-05
val loss: 44.52905374765396
accuracy:      0.761
precision:     0.762
recall:        0.761
f1:            0.761
val loss: 43.034017622470856
accuracy:      0.765
precision:     0.766
recall:        0.766
f1:            0.765
best result:
0.7645698427382054
0.7644954858169948
0.7634131913635019
0.7637286818992517


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.62360715866089, Epoch: 1, training loss: 811.1482427120209, current learning rate 1e-05
val loss: 22.85732215642929
accuracy:      0.620
precision:     0.620
recall:        0.620
f1:            0.620
val loss: 22.812467336654663
accuracy:      0.642
precision:     0.644
recall:        0.643
f1:            0.641
===== Start training: epoch 2 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.66127252578735, Epoch: 2, training loss: 951.4759593009949, current learning rate 1e-05
val loss: 19.867391526699066
accuracy:      0.686
precision:     0.716
recall:        0.685
f1:            0.674
val loss: 19.959649205207825
accuracy:      0.683
precision:     0.718
recall:        0.691
f1:            0.675
===== Start training: epoch 3 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.70248079299927, Epoch: 3, training loss: 1007.9672727584839, current learning rate 1e-05
val loss: 17.655198454856873
accuracy:      0.749
precision:     0.752
recall:        0.749
f1:            0.748
val loss: 17.84778791666031
accuracy:      0.737
precision:     0.741
recall:        0.739
f1:            0.737
===== Start training: epoch 4 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:01,  5.81it/s]


Timing: 34.74469709396362, Epoch: 4, training loss: 1032.1879215240479, current learning rate 1e-05
val loss: 17.403364598751068
accuracy:      0.756
precision:     0.756
recall:        0.756
f1:            0.756
val loss: 16.889259696006775
accuracy:      0.759
precision:     0.759
recall:        0.759
f1:            0.759
===== Start training: epoch 5 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:02,  5.81it/s]


Timing: 34.78382086753845, Epoch: 5, training loss: 1019.5258502960205, current learning rate 1e-05
val loss: 20.19579803943634
accuracy:      0.756
precision:     0.756
recall:        0.756
f1:            0.756
val loss: 19.17756888270378
accuracy:      0.759
precision:     0.759
recall:        0.758
f1:            0.759
===== Start training: epoch 6 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.684914350509644, Epoch: 6, training loss: 972.3054351806641, current learning rate 1e-05
val loss: 21.917697995901108
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.762
val loss: 20.505915015935898
accuracy:      0.764
precision:     0.764
recall:        0.764
f1:            0.764
===== Start training: epoch 7 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.68102765083313, Epoch: 7, training loss: 888.1503219604492, current learning rate 1e-05
val loss: 21.714040845632553
accuracy:      0.753
precision:     0.753
recall:        0.753
f1:            0.753
val loss: 20.579463124275208
accuracy:      0.759
precision:     0.758
recall:        0.759
f1:            0.758
===== Start training: epoch 8 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.70303392410278, Epoch: 8, training loss: 812.1069021224976, current learning rate 1e-05
val loss: 24.112975597381592
accuracy:      0.756
precision:     0.765
recall:        0.756
f1:            0.754
val loss: 22.590749353170395
accuracy:      0.750
precision:     0.761
recall:        0.754
f1:            0.749
===== Start training: epoch 9 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.695316791534424, Epoch: 9, training loss: 731.164303779602, current learning rate 1e-05
val loss: 26.472876101732254
accuracy:      0.748
precision:     0.749
recall:        0.748
f1:            0.748
val loss: 24.678420692682266
accuracy:      0.772
precision:     0.772
recall:        0.770
f1:            0.770
===== Start training: epoch 10 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.631792068481445, Epoch: 10, training loss: 659.6049907207489, current learning rate 1e-05
val loss: 34.16384148597717
accuracy:      0.748
precision:     0.760
recall:        0.748
f1:            0.745
val loss: 33.31488049030304
accuracy:      0.740
precision:     0.753
recall:        0.744
f1:            0.739
===== Start training: epoch 11 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:01,  5.82it/s]


Timing: 34.741564989089966, Epoch: 11, training loss: 589.9639053344727, current learning rate 1e-05
val loss: 31.79163607954979
accuracy:      0.745
precision:     0.748
recall:        0.745
f1:            0.745
val loss: 30.468541234731674
accuracy:      0.762
precision:     0.765
recall:        0.764
f1:            0.762
===== Start training: epoch 12 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58071947097778, Epoch: 12, training loss: 541.2521417140961, current learning rate 1e-05
val loss: 33.748093605041504
accuracy:      0.762
precision:     0.763
recall:        0.762
f1:            0.762
val loss: 33.04828733205795
accuracy:      0.762
precision:     0.764
recall:        0.764
f1:            0.762
===== Start training: epoch 13 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54750204086304, Epoch: 13, training loss: 493.9668390750885, current learning rate 1e-05
val loss: 36.86989298462868
accuracy:      0.761
precision:     0.762
recall:        0.761
f1:            0.761
val loss: 35.466664493083954
accuracy:      0.765
precision:     0.765
recall:        0.766
f1:            0.765
===== Start training: epoch 14 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54572057723999, Epoch: 14, training loss: 465.50605821609497, current learning rate 1e-05
val loss: 39.71781814098358
accuracy:      0.755
precision:     0.761
recall:        0.754
f1:            0.753
val loss: 39.43659728765488
accuracy:      0.751
precision:     0.759
recall:        0.754
f1:            0.750
===== Start training: epoch 15 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59292984008789, Epoch: 15, training loss: 436.2304470539093, current learning rate 1e-05
val loss: 40.31965547800064
accuracy:      0.763
precision:     0.763
recall:        0.763
f1:            0.763
val loss: 39.186802446842194
accuracy:      0.765
precision:     0.764
recall:        0.765
f1:            0.764
===== Start training: epoch 16 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59267067909241, Epoch: 16, training loss: 417.58328092098236, current learning rate 1e-05
val loss: 35.72589027881622
accuracy:      0.756
precision:     0.758
recall:        0.757
f1:            0.756
val loss: 34.19767755270004
accuracy:      0.763
precision:     0.764
recall:        0.764
f1:            0.763
===== Start training: epoch 17 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52698278427124, Epoch: 17, training loss: 406.86275708675385, current learning rate 1e-05
val loss: 42.1263707280159
accuracy:      0.759
precision:     0.760
recall:        0.759
f1:            0.759
val loss: 40.34744572639465
accuracy:      0.762
precision:     0.763
recall:        0.763
f1:            0.762
===== Start training: epoch 18 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.51439356803894, Epoch: 18, training loss: 396.91197526454926, current learning rate 1e-05
val loss: 43.892225563526154
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.757
val loss: 42.1363088786602
accuracy:      0.763
precision:     0.763
recall:        0.763
f1:            0.763
===== Start training: epoch 19 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.608644008636475, Epoch: 19, training loss: 390.1326709985733, current learning rate 1e-05
val loss: 46.62251687049866
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.761
val loss: 45.28393203020096
accuracy:      0.761
precision:     0.762
recall:        0.762
f1:            0.761
===== Start training: epoch 20 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.584322929382324, Epoch: 20, training loss: 386.56783521175385, current learning rate 1e-05
val loss: 44.749148935079575
accuracy:      0.756
precision:     0.759
recall:        0.757
f1:            0.756
val loss: 45.24710848927498
accuracy:      0.752
precision:     0.754
recall:        0.754
f1:            0.752
===== Start training: epoch 21 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.49628281593323, Epoch: 21, training loss: 377.90308606624603, current learning rate 1e-05
val loss: 46.392323076725006
accuracy:      0.755
precision:     0.756
recall:        0.755
f1:            0.754
val loss: 45.92004728317261
accuracy:      0.755
precision:     0.757
recall:        0.757
f1:            0.755
===== Start training: epoch 22 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.42891001701355, Epoch: 22, training loss: 368.62976944446564, current learning rate 1e-05
val loss: 46.78062093257904
accuracy:      0.753
precision:     0.757
recall:        0.753
f1:            0.752
val loss: 45.67578774690628
accuracy:      0.754
precision:     0.757
recall:        0.756
f1:            0.754
===== Start training: epoch 23 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.40089201927185, Epoch: 23, training loss: 359.63607013225555, current learning rate 1e-05
val loss: 38.629136085510254
accuracy:      0.754
precision:     0.755
recall:        0.754
f1:            0.754
val loss: 37.70195943117142
accuracy:      0.755
precision:     0.756
recall:        0.756
f1:            0.755
===== Start training: epoch 24 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.47203350067139, Epoch: 24, training loss: 355.053062081337, current learning rate 1e-05
val loss: 41.20994907617569
accuracy:      0.757
precision:     0.758
recall:        0.757
f1:            0.757
val loss: 40.722437500953674
accuracy:      0.760
precision:     0.762
recall:        0.762
f1:            0.760
===== Start training: epoch 25 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55545425415039, Epoch: 25, training loss: 351.1856609582901, current learning rate 1e-05
val loss: 34.50590378046036
accuracy:      0.759
precision:     0.760
recall:        0.759
f1:            0.759
val loss: 34.1767058968544
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
===== Start training: epoch 26 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.5746066570282, Epoch: 26, training loss: 359.9869477748871, current learning rate 1e-05
val loss: 42.2105268239975
accuracy:      0.759
precision:     0.761
recall:        0.759
f1:            0.759
val loss: 42.059221267700195
accuracy:      0.756
precision:     0.759
recall:        0.758
f1:            0.756
===== Start training: epoch 27 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52546453475952, Epoch: 27, training loss: 360.94418263435364, current learning rate 1e-05
val loss: 40.40218299627304
accuracy:      0.770
precision:     0.770
recall:        0.771
f1:            0.770
val loss: 41.139342963695526
accuracy:      0.756
precision:     0.756
recall:        0.756
f1:            0.756
===== Start training: epoch 28 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.86it/s]


Timing: 34.46143102645874, Epoch: 28, training loss: 358.7315813302994, current learning rate 1e-05
val loss: 42.12876099348068
accuracy:      0.774
precision:     0.774
recall:        0.775
f1:            0.775
val loss: 40.20694172382355
accuracy:      0.772
precision:     0.772
recall:        0.772
f1:            0.772
===== Start training: epoch 29 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.514307260513306, Epoch: 29, training loss: 366.57688653469086, current learning rate 1e-05
val loss: 35.99532163143158
accuracy:      0.758
precision:     0.758
recall:        0.759
f1:            0.758
val loss: 33.97372841835022
accuracy:      0.769
precision:     0.768
recall:        0.768
f1:            0.768
===== Start training: epoch 30 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.40933871269226, Epoch: 30, training loss: 364.7486324310303, current learning rate 1e-05
val loss: 37.93311661481857
accuracy:      0.762
precision:     0.763
recall:        0.763
f1:            0.763
val loss: 36.832941472530365
accuracy:      0.767
precision:     0.766
recall:        0.767
f1:            0.767
best result:
0.7719703977798335
0.7716084162790537
0.7719313812481514
0.771710129655391


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[[0.7608695652173914, 0.7604402483364203, 0.7603766144138815, 0.7604069605306847], [0.759481961147086, 0.7612853679987988, 0.761056886522725, 0.7594768154340132], [0.7645698427382054, 0.7644954858169948, 0.7634131913635019, 0.7637286818992517], [0.7719703977798335, 0.7716084162790537, 0.7719313812481514, 0.771710129655391]]
**** trying with seed 2 ****


tokenizing...: 100%|██████████| 6486/6486 [00:02<00:00, 3124.87it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 2164/2164 [00:00<00:00, 3082.15it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 2162/2162 [00:00<00:00, 2957.16it/s]


finished preprocessing examples in test


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.86it/s]


Timing: 34.45226693153381, Epoch: 1, training loss: 723.1144616603851, current learning rate 1e-05
val loss: 23.43153029680252
accuracy:      0.517
precision:     0.540
recall:        0.519
f1:            0.446
val loss: 23.426265835762024
accuracy:      0.533
precision:     0.536
recall:        0.518
f1:            0.458
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.86it/s]


Timing: 34.45590925216675, Epoch: 2, training loss: 980.6032948493958, current learning rate 1e-05
val loss: 19.68999344110489
accuracy:      0.695
precision:     0.704
recall:        0.695
f1:            0.692
val loss: 19.962743371725082
accuracy:      0.676
precision:     0.690
recall:        0.681
f1:            0.673
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.4922981262207, Epoch: 3, training loss: 1016.289936542511, current learning rate 1e-05
val loss: 19.49106377363205
accuracy:      0.729
precision:     0.745
recall:        0.728
f1:            0.724
val loss: 19.90785801410675
accuracy:      0.712
precision:     0.733
recall:        0.718
f1:            0.708
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.582791805267334, Epoch: 4, training loss: 960.7883770465851, current learning rate 1e-05
val loss: 19.425093173980713
accuracy:      0.741
precision:     0.742
recall:        0.741
f1:            0.741
val loss: 19.531935393810272
accuracy:      0.724
precision:     0.726
recall:        0.726
f1:            0.724
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54076170921326, Epoch: 5, training loss: 908.2974536418915, current learning rate 1e-05
val loss: 20.62224844098091
accuracy:      0.729
precision:     0.733
recall:        0.729
f1:            0.728
val loss: 20.327835500240326
accuracy:      0.731
precision:     0.732
recall:        0.729
f1:            0.729
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.447041034698486, Epoch: 6, training loss: 845.6714596748352, current learning rate 1e-05
val loss: 21.66748684644699
accuracy:      0.749
precision:     0.751
recall:        0.749
f1:            0.749
val loss: 20.85141521692276
accuracy:      0.747
precision:     0.750
recall:        0.749
f1:            0.747
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.43399262428284, Epoch: 7, training loss: 768.6342272758484, current learning rate 1e-05
val loss: 23.380544662475586
accuracy:      0.745
precision:     0.745
recall:        0.745
f1:            0.745
val loss: 23.30519476532936
accuracy:      0.742
precision:     0.742
recall:        0.743
f1:            0.742
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.62172341346741, Epoch: 8, training loss: 704.3740491867065, current learning rate 1e-05
val loss: 30.512280464172363
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 30.310295462608337
accuracy:      0.755
precision:     0.755
recall:        0.756
f1:            0.755
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.56570911407471, Epoch: 9, training loss: 638.2518496513367, current learning rate 1e-05
val loss: 32.648853063583374
accuracy:      0.748
precision:     0.748
recall:        0.749
f1:            0.748
val loss: 32.61694806814194
accuracy:      0.739
precision:     0.739
recall:        0.739
f1:            0.739
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.64691781997681, Epoch: 10, training loss: 575.6952698230743, current learning rate 1e-05
val loss: 35.119977831840515
accuracy:      0.760
precision:     0.760
recall:        0.761
f1:            0.760
val loss: 34.4806404709816
accuracy:      0.758
precision:     0.758
recall:        0.758
f1:            0.758
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.62559151649475, Epoch: 11, training loss: 508.7538652420044, current learning rate 1e-05
val loss: 34.26965069770813
accuracy:      0.751
precision:     0.756
recall:        0.751
f1:            0.750
val loss: 34.43992477655411
accuracy:      0.747
precision:     0.756
recall:        0.751
f1:            0.747
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57737898826599, Epoch: 12, training loss: 445.08671247959137, current learning rate 1e-05
val loss: 35.0702810883522
accuracy:      0.755
precision:     0.756
recall:        0.755
f1:            0.754
val loss: 34.533514857292175
accuracy:      0.753
precision:     0.758
recall:        0.756
f1:            0.753
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57058024406433, Epoch: 13, training loss: 393.90957856178284, current learning rate 1e-05
val loss: 29.954997897148132
accuracy:      0.751
precision:     0.754
recall:        0.751
f1:            0.751
val loss: 30.03352963924408
accuracy:      0.748
precision:     0.753
recall:        0.751
f1:            0.748
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.473817586898804, Epoch: 14, training loss: 367.03764164447784, current learning rate 1e-05
val loss: 38.595010459423065
accuracy:      0.760
precision:     0.760
recall:        0.760
f1:            0.760
val loss: 38.18238294124603
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55721426010132, Epoch: 15, training loss: 360.7500454187393, current learning rate 1e-05
val loss: 41.617550015449524
accuracy:      0.753
precision:     0.754
recall:        0.754
f1:            0.753
val loss: 41.34032815694809
accuracy:      0.752
precision:     0.751
recall:        0.751
f1:            0.751
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52171874046326, Epoch: 16, training loss: 351.3586235046387, current learning rate 1e-05
val loss: 50.650167405605316
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.762
val loss: 50.5934756398201
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.47733449935913, Epoch: 17, training loss: 345.0229104757309, current learning rate 1e-05
val loss: 48.04312074184418
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
val loss: 49.0815372467041
accuracy:      0.757
precision:     0.759
recall:        0.759
f1:            0.757
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.477282762527466, Epoch: 18, training loss: 362.92742908000946, current learning rate 1e-05
val loss: 52.688533902168274
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
val loss: 50.77815538644791
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52867031097412, Epoch: 19, training loss: 367.243399977684, current learning rate 1e-05
val loss: 61.44211196899414
accuracy:      0.755
precision:     0.757
recall:        0.755
f1:            0.754
val loss: 61.92078483104706
accuracy:      0.745
precision:     0.751
recall:        0.748
f1:            0.745
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.53798246383667, Epoch: 20, training loss: 390.48461830616, current learning rate 1e-05
val loss: 60.569376051425934
accuracy:      0.750
precision:     0.750
recall:        0.751
f1:            0.751
val loss: 57.42693477869034
accuracy:      0.753
precision:     0.753
recall:        0.754
f1:            0.753
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54728436470032, Epoch: 21, training loss: 415.2327892780304, current learning rate 1e-05
val loss: 65.35932910442352
accuracy:      0.758
precision:     0.764
recall:        0.758
f1:            0.757
val loss: 69.15567195415497
accuracy:      0.751
precision:     0.763
recall:        0.755
f1:            0.750
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.566463232040405, Epoch: 22, training loss: 401.2509434223175, current learning rate 1e-05
val loss: 60.837488770484924
accuracy:      0.750
precision:     0.750
recall:        0.750
f1:            0.750
val loss: 61.03949296474457
accuracy:      0.749
precision:     0.749
recall:        0.749
f1:            0.749
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.63506317138672, Epoch: 23, training loss: 364.1362683773041, current learning rate 1e-05
val loss: 62.33855438232422
accuracy:      0.750
precision:     0.752
recall:        0.750
f1:            0.749
val loss: 65.07115435600281
accuracy:      0.740
precision:     0.743
recall:        0.742
f1:            0.740
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.490641355514526, Epoch: 24, training loss: 346.7735129594803, current learning rate 1e-05
val loss: 63.071951508522034
accuracy:      0.750
precision:     0.750
recall:        0.750
f1:            0.749
val loss: 62.64676773548126
accuracy:      0.744
precision:     0.747
recall:        0.746
f1:            0.744
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.51219701766968, Epoch: 25, training loss: 336.52222216129303, current learning rate 1e-05
val loss: 70.7518869638443
accuracy:      0.750
precision:     0.751
recall:        0.751
f1:            0.750
val loss: 71.82979321479797
accuracy:      0.749
precision:     0.751
recall:        0.751
f1:            0.749
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.49013876914978, Epoch: 26, training loss: 323.4825506210327, current learning rate 1e-05
val loss: 66.03355073928833
accuracy:      0.755
precision:     0.755
recall:        0.756
f1:            0.755
val loss: 62.145106077194214
accuracy:      0.745
precision:     0.744
recall:        0.744
f1:            0.744
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.49454975128174, Epoch: 27, training loss: 320.59063494205475, current learning rate 1e-05
val loss: 80.49789929389954
accuracy:      0.750
precision:     0.751
recall:        0.750
f1:            0.749
val loss: 79.70471000671387
accuracy:      0.749
precision:     0.751
recall:        0.751
f1:            0.749
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.50009083747864, Epoch: 28, training loss: 318.0156384706497, current learning rate 1e-05
val loss: 77.02517688274384
accuracy:      0.746
precision:     0.746
recall:        0.747
f1:            0.746
val loss: 71.8615357875824
accuracy:      0.739
precision:     0.739
recall:        0.739
f1:            0.739
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.535706996917725, Epoch: 29, training loss: 317.09832322597504, current learning rate 1e-05
val loss: 61.580339550971985
accuracy:      0.752
precision:     0.752
recall:        0.752
f1:            0.752
val loss: 59.21432065963745
accuracy:      0.751
precision:     0.751
recall:        0.752
f1:            0.751
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.4, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.53923940658569, Epoch: 30, training loss: 315.1682758331299, current learning rate 1e-05
val loss: 69.18211448192596
accuracy:      0.752
precision:     0.756
recall:        0.752
f1:            0.751
val loss: 72.67264592647552
accuracy:      0.739
precision:     0.749
recall:        0.743
f1:            0.738
best result:
0.7571692876965772
0.7587490136882913
0.7586414275855269
0.7571679889196463


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.5846049785614, Epoch: 1, training loss: 745.6758553981781, current learning rate 1e-05
val loss: 23.39602839946747
accuracy:      0.604
precision:     0.606
recall:        0.603
f1:            0.601
val loss: 23.4333193898201
accuracy:      0.562
precision:     0.570
recall:        0.567
f1:            0.560
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.62767434120178, Epoch: 2, training loss: 985.0405230522156, current learning rate 1e-05
val loss: 19.117981255054474
accuracy:      0.713
precision:     0.720
recall:        0.712
f1:            0.710
val loss: 19.61071425676346
accuracy:      0.692
precision:     0.699
recall:        0.696
f1:            0.692
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.704472064971924, Epoch: 3, training loss: 1017.4648056030273, current learning rate 1e-05
val loss: 19.171209752559662
accuracy:      0.739
precision:     0.750
recall:        0.739
f1:            0.736
val loss: 19.61482399702072
accuracy:      0.718
precision:     0.729
recall:        0.722
f1:            0.717
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.66987109184265, Epoch: 4, training loss: 983.6247086524963, current learning rate 1e-05
val loss: 20.296704351902008
accuracy:      0.755
precision:     0.755
recall:        0.755
f1:            0.755
val loss: 20.79804766178131
accuracy:      0.733
precision:     0.733
recall:        0.733
f1:            0.733
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.682297229766846, Epoch: 5, training loss: 938.0372519493103, current learning rate 1e-05
val loss: 20.243574619293213
accuracy:      0.756
precision:     0.757
recall:        0.757
f1:            0.757
val loss: 20.665845811367035
accuracy:      0.745
precision:     0.744
recall:        0.744
f1:            0.744
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.64846420288086, Epoch: 6, training loss: 876.5712282657623, current learning rate 1e-05
val loss: 21.312838450074196
accuracy:      0.755
precision:     0.758
recall:        0.755
f1:            0.754
val loss: 21.713946133852005
accuracy:      0.741
precision:     0.743
recall:        0.742
f1:            0.740
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.65790915489197, Epoch: 7, training loss: 804.2242949008942, current learning rate 1e-05
val loss: 23.647664487361908
accuracy:      0.763
precision:     0.763
recall:        0.763
f1:            0.763
val loss: 23.858400106430054
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60695433616638, Epoch: 8, training loss: 721.084876537323, current learning rate 1e-05
val loss: 30.583832383155823
accuracy:      0.758
precision:     0.758
recall:        0.759
f1:            0.758
val loss: 31.288708716630936
accuracy:      0.764
precision:     0.764
recall:        0.764
f1:            0.764
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58468151092529, Epoch: 9, training loss: 650.5548386573792, current learning rate 1e-05
val loss: 31.899199783802032
accuracy:      0.756
precision:     0.757
recall:        0.757
f1:            0.756
val loss: 33.152389883995056
accuracy:      0.751
precision:     0.752
recall:        0.752
f1:            0.751
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.68052959442139, Epoch: 10, training loss: 588.0866384506226, current learning rate 1e-05
val loss: 32.839112401008606
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.757
val loss: 33.900289714336395
accuracy:      0.750
precision:     0.750
recall:        0.750
f1:            0.750
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.664698123931885, Epoch: 11, training loss: 521.5217685699463, current learning rate 1e-05
val loss: 30.856267541646957
accuracy:      0.752
precision:     0.754
recall:        0.752
f1:            0.752
val loss: 31.12409943342209
accuracy:      0.746
precision:     0.749
recall:        0.748
f1:            0.746
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52035117149353, Epoch: 12, training loss: 465.64895951747894, current learning rate 1e-05
val loss: 37.30590897798538
accuracy:      0.758
precision:     0.759
recall:        0.759
f1:            0.758
val loss: 38.173861145973206
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55959677696228, Epoch: 13, training loss: 415.83609199523926, current learning rate 1e-05
val loss: 36.159988939762115
accuracy:      0.760
precision:     0.764
recall:        0.760
f1:            0.759
val loss: 39.690974205732346
accuracy:      0.742
precision:     0.747
recall:        0.745
f1:            0.742
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.552778482437134, Epoch: 14, training loss: 381.202236533165, current learning rate 1e-05
val loss: 39.15243625640869
accuracy:      0.752
precision:     0.758
recall:        0.752
f1:            0.750
val loss: 41.5619056224823
accuracy:      0.739
precision:     0.747
recall:        0.743
f1:            0.739
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54235243797302, Epoch: 15, training loss: 378.2116963863373, current learning rate 1e-05
val loss: 40.5145189166069
accuracy:      0.751
precision:     0.758
recall:        0.752
f1:            0.750
val loss: 40.6754384636879
accuracy:      0.753
precision:     0.758
recall:        0.750
f1:            0.750
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54120063781738, Epoch: 16, training loss: 359.0336960554123, current learning rate 1e-05
val loss: 37.955035507678986
accuracy:      0.768
precision:     0.769
recall:        0.768
f1:            0.767
val loss: 40.54343235492706
accuracy:      0.747
precision:     0.750
recall:        0.749
f1:            0.747
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54241991043091, Epoch: 17, training loss: 344.66686618328094, current learning rate 1e-05
val loss: 38.969040274620056
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
val loss: 40.51101487874985
accuracy:      0.753
precision:     0.752
recall:        0.753
f1:            0.752
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.53993272781372, Epoch: 18, training loss: 338.5660058259964, current learning rate 1e-05
val loss: 44.3079269528389
accuracy:      0.769
precision:     0.774
recall:        0.769
f1:            0.768
val loss: 48.19252812862396
accuracy:      0.744
precision:     0.751
recall:        0.747
f1:            0.744
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58129024505615, Epoch: 19, training loss: 343.4204765558243, current learning rate 1e-05
val loss: 43.2140976190567
accuracy:      0.760
precision:     0.760
recall:        0.760
f1:            0.760
val loss: 44.50768458843231
accuracy:      0.745
precision:     0.745
recall:        0.744
f1:            0.744
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.5781466960907, Epoch: 20, training loss: 342.3836201429367, current learning rate 1e-05
val loss: 58.62011784315109
accuracy:      0.756
precision:     0.756
recall:        0.757
f1:            0.757
val loss: 59.82228875160217
accuracy:      0.744
precision:     0.744
recall:        0.744
f1:            0.744
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.86it/s]


Timing: 34.50552010536194, Epoch: 21, training loss: 345.23149251937866, current learning rate 1e-05
val loss: 55.78522968292236
accuracy:      0.774
precision:     0.775
recall:        0.774
f1:            0.774
val loss: 58.82238733768463
accuracy:      0.746
precision:     0.746
recall:        0.747
f1:            0.746
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.495704650878906, Epoch: 22, training loss: 359.29779171943665, current learning rate 1e-05
val loss: 55.01399791240692
accuracy:      0.760
precision:     0.760
recall:        0.760
f1:            0.760
val loss: 56.268001556396484
accuracy:      0.754
precision:     0.754
recall:        0.752
f1:            0.753
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59161639213562, Epoch: 23, training loss: 376.4508892297745, current learning rate 1e-05
val loss: 52.2637796998024
accuracy:      0.767
precision:     0.769
recall:        0.767
f1:            0.766
val loss: 54.476704359054565
accuracy:      0.748
precision:     0.750
recall:        0.750
f1:            0.748
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58967447280884, Epoch: 24, training loss: 373.6571537256241, current learning rate 1e-05
val loss: 56.71202754974365
accuracy:      0.767
precision:     0.769
recall:        0.767
f1:            0.766
val loss: 58.203116059303284
accuracy:      0.756
precision:     0.758
recall:        0.758
f1:            0.756
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.582351207733154, Epoch: 25, training loss: 381.0236862897873, current learning rate 1e-05
val loss: 59.80287754535675
accuracy:      0.771
precision:     0.772
recall:        0.772
f1:            0.771
val loss: 60.913376450538635
accuracy:      0.757
precision:     0.757
recall:        0.756
f1:            0.756
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.541598320007324, Epoch: 26, training loss: 402.4168554544449, current learning rate 1e-05
val loss: 48.80085128545761
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.757
val loss: 50.693869948387146
accuracy:      0.750
precision:     0.750
recall:        0.751
f1:            0.750
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58156967163086, Epoch: 27, training loss: 414.2177355289459, current learning rate 1e-05
val loss: 60.77896320819855
accuracy:      0.760
precision:     0.762
recall:        0.760
f1:            0.760
val loss: 61.96638739109039
accuracy:      0.746
precision:     0.747
recall:        0.747
f1:            0.746
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.542906045913696, Epoch: 28, training loss: 392.68401062488556, current learning rate 1e-05
val loss: 64.7345050573349
accuracy:      0.756
precision:     0.759
recall:        0.756
f1:            0.755
val loss: 67.86134707927704
accuracy:      0.748
precision:     0.751
recall:        0.750
f1:            0.748
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52761483192444, Epoch: 29, training loss: 381.28628838062286, current learning rate 1e-05
val loss: 54.628896594047546
accuracy:      0.759
precision:     0.759
recall:        0.759
f1:            0.759
val loss: 56.52535080909729
accuracy:      0.746
precision:     0.746
recall:        0.745
f1:            0.745
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.6000000000000001, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58952569961548, Epoch: 30, training loss: 360.41795325279236, current learning rate 1e-05
val loss: 82.63597333431244
accuracy:      0.762
precision:     0.767
recall:        0.762
f1:            0.762
val loss: 84.5760428905487
accuracy:      0.748
precision:     0.752
recall:        0.750
f1:            0.748
best result:
0.7460684551341351
0.7464640083816098
0.7468500443655723
0.7460227590045199


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58169484138489, Epoch: 1, training loss: 776.9551615715027, current learning rate 1e-05
val loss: 23.27283775806427
accuracy:      0.609
precision:     0.613
recall:        0.609
f1:            0.605
val loss: 23.29669064283371
accuracy:      0.608
precision:     0.621
recall:        0.613
f1:            0.603
===== Start training: epoch 2 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.63593816757202, Epoch: 2, training loss: 981.4992480278015, current learning rate 1e-05
val loss: 19.548210561275482
accuracy:      0.695
precision:     0.700
recall:        0.695
f1:            0.693
val loss: 19.67784383893013
accuracy:      0.690
precision:     0.696
recall:        0.693
f1:            0.689
===== Start training: epoch 3 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.72159123420715, Epoch: 3, training loss: 1010.9114928245544, current learning rate 1e-05
val loss: 19.789527356624603
accuracy:      0.729
precision:     0.737
recall:        0.729
f1:            0.727
val loss: 19.532495617866516
accuracy:      0.724
precision:     0.734
recall:        0.728
f1:            0.723
===== Start training: epoch 4 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:01,  5.82it/s]


Timing: 34.73681139945984, Epoch: 4, training loss: 989.4476599693298, current learning rate 1e-05
val loss: 20.183599889278412
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
val loss: 20.078083336353302
accuracy:      0.737
precision:     0.737
recall:        0.737
f1:            0.737
===== Start training: epoch 5 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<15:00,  5.82it/s]


Timing: 34.69969081878662, Epoch: 5, training loss: 964.7888898849487, current learning rate 1e-05
val loss: 19.030315458774567
accuracy:      0.755
precision:     0.759
recall:        0.755
f1:            0.754
val loss: 19.40459167957306
accuracy:      0.735
precision:     0.741
recall:        0.738
f1:            0.735
===== Start training: epoch 6 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.82it/s]


Timing: 34.69162034988403, Epoch: 6, training loss: 913.4136323928833, current learning rate 1e-05
val loss: 20.643157571554184
accuracy:      0.760
precision:     0.762
recall:        0.760
f1:            0.759
val loss: 20.536016702651978
accuracy:      0.741
precision:     0.745
recall:        0.743
f1:            0.740
===== Start training: epoch 7 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59757852554321, Epoch: 7, training loss: 847.3762743473053, current learning rate 1e-05
val loss: 22.779261469841003
accuracy:      0.753
precision:     0.754
recall:        0.753
f1:            0.753
val loss: 22.25959664583206
accuracy:      0.746
precision:     0.747
recall:        0.747
f1:            0.746
===== Start training: epoch 8 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.611034631729126, Epoch: 8, training loss: 762.1324162483215, current learning rate 1e-05
val loss: 28.55449253320694
accuracy:      0.767
precision:     0.767
recall:        0.767
f1:            0.767
val loss: 28.146027147769928
accuracy:      0.755
precision:     0.756
recall:        0.756
f1:            0.755
===== Start training: epoch 9 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60307312011719, Epoch: 9, training loss: 685.8607938289642, current learning rate 1e-05
val loss: 29.965814739465714
accuracy:      0.751
precision:     0.753
recall:        0.752
f1:            0.751
val loss: 29.269800543785095
accuracy:      0.752
precision:     0.752
recall:        0.750
f1:            0.751
===== Start training: epoch 10 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.6269314289093, Epoch: 10, training loss: 617.3899350166321, current learning rate 1e-05
val loss: 29.09792497754097
accuracy:      0.764
precision:     0.765
recall:        0.765
f1:            0.764
val loss: 29.830738455057144
accuracy:      0.750
precision:     0.751
recall:        0.751
f1:            0.750
===== Start training: epoch 11 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58712029457092, Epoch: 11, training loss: 564.1342163085938, current learning rate 1e-05
val loss: 30.99210387468338
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.762
val loss: 31.22626042366028
accuracy:      0.755
precision:     0.755
recall:        0.756
f1:            0.755
===== Start training: epoch 12 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.83it/s]


Timing: 34.6434121131897, Epoch: 12, training loss: 508.6970615386963, current learning rate 1e-05
val loss: 31.83509075641632
accuracy:      0.757
precision:     0.758
recall:        0.757
f1:            0.757
val loss: 32.98519027233124
accuracy:      0.747
precision:     0.749
recall:        0.749
f1:            0.747
===== Start training: epoch 13 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.600069522857666, Epoch: 13, training loss: 470.1034218072891, current learning rate 1e-05
val loss: 31.732024013996124
accuracy:      0.763
precision:     0.766
recall:        0.763
f1:            0.762
val loss: 34.225687474012375
accuracy:      0.748
precision:     0.753
recall:        0.751
f1:            0.748
===== Start training: epoch 14 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57576036453247, Epoch: 14, training loss: 436.12249505519867, current learning rate 1e-05
val loss: 39.909330666065216
accuracy:      0.756
precision:     0.758
recall:        0.756
f1:            0.756
val loss: 39.96632182598114
accuracy:      0.755
precision:     0.758
recall:        0.757
f1:            0.755
===== Start training: epoch 15 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57497429847717, Epoch: 15, training loss: 415.3737713098526, current learning rate 1e-05
val loss: 40.354135036468506
accuracy:      0.753
precision:     0.754
recall:        0.754
f1:            0.753
val loss: 40.1488521695137
accuracy:      0.761
precision:     0.761
recall:        0.760
f1:            0.760
===== Start training: epoch 16 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.50029635429382, Epoch: 16, training loss: 393.7576028108597, current learning rate 1e-05
val loss: 39.664475321769714
accuracy:      0.772
precision:     0.773
recall:        0.772
f1:            0.772
val loss: 40.964826822280884
accuracy:      0.763
precision:     0.764
recall:        0.764
f1:            0.763
===== Start training: epoch 17 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.4033887386322, Epoch: 17, training loss: 377.3322925567627, current learning rate 1e-05
val loss: 41.80260193347931
accuracy:      0.781
precision:     0.782
recall:        0.782
f1:            0.781
val loss: 44.65494638681412
accuracy:      0.756
precision:     0.758
recall:        0.758
f1:            0.756
===== Start training: epoch 18 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.422611474990845, Epoch: 18, training loss: 357.48032307624817, current learning rate 1e-05
val loss: 41.54589915275574
accuracy:      0.761
precision:     0.765
recall:        0.761
f1:            0.760
val loss: 42.241420209407806
accuracy:      0.751
precision:     0.755
recall:        0.753
f1:            0.751
===== Start training: epoch 19 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.42708897590637, Epoch: 19, training loss: 354.85506641864777, current learning rate 1e-05
val loss: 37.29274374246597
accuracy:      0.761
precision:     0.761
recall:        0.761
f1:            0.761
val loss: 37.16743981838226
accuracy:      0.753
precision:     0.752
recall:        0.753
f1:            0.752
===== Start training: epoch 20 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.42570781707764, Epoch: 20, training loss: 346.72759318351746, current learning rate 1e-05
val loss: 43.395164132118225
accuracy:      0.766
precision:     0.766
recall:        0.767
f1:            0.766
val loss: 44.212180614471436
accuracy:      0.753
precision:     0.752
recall:        0.753
f1:            0.752
===== Start training: epoch 21 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.400700092315674, Epoch: 21, training loss: 334.65778827667236, current learning rate 1e-05
val loss: 53.66778516769409
accuracy:      0.760
precision:     0.772
recall:        0.759
f1:            0.757
val loss: 57.35310173034668
accuracy:      0.744
precision:     0.759
recall:        0.749
f1:            0.743
===== Start training: epoch 22 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:51,  5.88it/s]


Timing: 34.36580491065979, Epoch: 22, training loss: 334.57115495204926, current learning rate 1e-05
val loss: 45.61753636598587
accuracy:      0.772
precision:     0.772
recall:        0.773
f1:            0.772
val loss: 47.277304887771606
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
===== Start training: epoch 23 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.393845558166504, Epoch: 23, training loss: 333.21394753456116, current learning rate 1e-05
val loss: 50.298639476299286
accuracy:      0.763
precision:     0.766
recall:        0.763
f1:            0.763
val loss: 51.40104657411575
accuracy:      0.751
precision:     0.754
recall:        0.753
f1:            0.751
===== Start training: epoch 24 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.41515016555786, Epoch: 24, training loss: 325.81412649154663, current learning rate 1e-05
val loss: 50.23270732164383
accuracy:      0.761
precision:     0.766
recall:        0.761
f1:            0.760
val loss: 51.73282665014267
accuracy:      0.751
precision:     0.758
recall:        0.754
f1:            0.751
===== Start training: epoch 25 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.46918535232544, Epoch: 25, training loss: 323.4536807537079, current learning rate 1e-05
val loss: 51.77542692422867
accuracy:      0.772
precision:     0.773
recall:        0.772
f1:            0.772
val loss: 53.767977356910706
accuracy:      0.754
precision:     0.756
recall:        0.756
f1:            0.754
===== Start training: epoch 26 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.513314962387085, Epoch: 26, training loss: 317.16034412384033, current learning rate 1e-05
val loss: 50.107664465904236
accuracy:      0.745
precision:     0.745
recall:        0.746
f1:            0.746
val loss: 50.49868094921112
accuracy:      0.755
precision:     0.755
recall:        0.755
f1:            0.755
===== Start training: epoch 27 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54192519187927, Epoch: 27, training loss: 312.75253760814667, current learning rate 1e-05
val loss: 54.057975232601166
accuracy:      0.766
precision:     0.766
recall:        0.766
f1:            0.766
val loss: 54.615280747413635
accuracy:      0.760
precision:     0.760
recall:        0.760
f1:            0.760
===== Start training: epoch 28 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.44005608558655, Epoch: 28, training loss: 309.3173403739929, current learning rate 1e-05
val loss: 55.42140477895737
accuracy:      0.766
precision:     0.770
recall:        0.766
f1:            0.765
val loss: 56.883325815200806
accuracy:      0.753
precision:     0.756
recall:        0.755
f1:            0.752
===== Start training: epoch 29 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.48869180679321, Epoch: 29, training loss: 309.7609544992447, current learning rate 1e-05
val loss: 51.87077462673187
accuracy:      0.771
precision:     0.772
recall:        0.771
f1:            0.771
val loss: 54.74948972463608
accuracy:      0.751
precision:     0.752
recall:        0.752
f1:            0.751
===== Start training: epoch 30 =====
*** trying with discovery_weight = 0.8000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:52,  5.87it/s]


Timing: 34.42312955856323, Epoch: 30, training loss: 306.4924006462097, current learning rate 1e-05
val loss: 54.7752240896225
accuracy:      0.765
precision:     0.772
recall:        0.765
f1:            0.764
val loss: 59.41555881500244
accuracy:      0.749
precision:     0.757
recall:        0.753
f1:            0.749
best result:
0.7562442183163737
0.7578218806820132
0.7577146800749285
0.7562429145917211


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52583289146423, Epoch: 1, training loss: 806.9848856925964, current learning rate 1e-05
val loss: 23.363463819026947
accuracy:      0.581
precision:     0.582
recall:        0.581
f1:            0.580
val loss: 23.413636445999146
accuracy:      0.553
precision:     0.557
recall:        0.556
f1:            0.552
===== Start training: epoch 2 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.551971197128296, Epoch: 2, training loss: 984.9677476882935, current learning rate 1e-05
val loss: 19.70379811525345
accuracy:      0.695
precision:     0.711
recall:        0.694
f1:            0.688
val loss: 20.208464920520782
accuracy:      0.679
precision:     0.700
recall:        0.685
f1:            0.675
===== Start training: epoch 3 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.58060646057129, Epoch: 3, training loss: 1024.1378173828125, current learning rate 1e-05
val loss: 18.917897582054138
accuracy:      0.731
precision:     0.746
recall:        0.731
f1:            0.727
val loss: 19.295509040355682
accuracy:      0.721
precision:     0.739
recall:        0.726
f1:            0.719
===== Start training: epoch 4 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.59348106384277, Epoch: 4, training loss: 1011.7829337120056, current learning rate 1e-05
val loss: 19.363670885562897
accuracy:      0.749
precision:     0.751
recall:        0.749
f1:            0.748
val loss: 19.734048187732697
accuracy:      0.733
precision:     0.736
recall:        0.735
f1:            0.733
===== Start training: epoch 5 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.5975444316864, Epoch: 5, training loss: 976.0372748374939, current learning rate 1e-05
val loss: 18.875591158866882
accuracy:      0.744
precision:     0.744
recall:        0.745
f1:            0.745
val loss: 19.218562841415405
accuracy:      0.738
precision:     0.737
recall:        0.737
f1:            0.737
===== Start training: epoch 6 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:59,  5.83it/s]


Timing: 34.67339515686035, Epoch: 6, training loss: 924.4159553050995, current learning rate 1e-05
val loss: 20.523599684238434
accuracy:      0.758
precision:     0.761
recall:        0.758
f1:            0.757
val loss: 21.441590309143066
accuracy:      0.743
precision:     0.747
recall:        0.745
f1:            0.743
===== Start training: epoch 7 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:57,  5.84it/s]


Timing: 34.60766530036926, Epoch: 7, training loss: 856.9075198173523, current learning rate 1e-05
val loss: 20.111783266067505
accuracy:      0.751
precision:     0.751
recall:        0.751
f1:            0.751
val loss: 20.468111366033554
accuracy:      0.745
precision:     0.745
recall:        0.746
f1:            0.745
===== Start training: epoch 8 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.56163454055786, Epoch: 8, training loss: 786.2003753185272, current learning rate 1e-05
val loss: 24.722757399082184
accuracy:      0.767
precision:     0.767
recall:        0.767
f1:            0.767
val loss: 24.689374685287476
accuracy:      0.760
precision:     0.760
recall:        0.761
f1:            0.760
===== Start training: epoch 9 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.56877541542053, Epoch: 9, training loss: 711.1799986362457, current learning rate 1e-05
val loss: 26.955925434827805
accuracy:      0.750
precision:     0.751
recall:        0.750
f1:            0.750
val loss: 26.859581351280212
accuracy:      0.755
precision:     0.756
recall:        0.754
f1:            0.754
===== Start training: epoch 10 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.570103883743286, Epoch: 10, training loss: 641.4690961837769, current learning rate 1e-05
val loss: 28.953776955604553
accuracy:      0.763
precision:     0.763
recall:        0.764
f1:            0.764
val loss: 29.166455179452896
accuracy:      0.750
precision:     0.750
recall:        0.750
f1:            0.750
===== Start training: epoch 11 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.51372814178467, Epoch: 11, training loss: 589.4551401138306, current learning rate 1e-05
val loss: 29.516445070505142
accuracy:      0.753
precision:     0.756
recall:        0.753
f1:            0.753
val loss: 29.256931722164154
accuracy:      0.754
precision:     0.758
recall:        0.757
f1:            0.754
===== Start training: epoch 12 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55319118499756, Epoch: 12, training loss: 533.6349632740021, current learning rate 1e-05
val loss: 33.46657782793045
accuracy:      0.762
precision:     0.763
recall:        0.762
f1:            0.761
val loss: 33.78527891635895
accuracy:      0.754
precision:     0.755
recall:        0.755
f1:            0.754
===== Start training: epoch 13 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.55090069770813, Epoch: 13, training loss: 502.6188452243805, current learning rate 1e-05
val loss: 34.785595297813416
accuracy:      0.762
precision:     0.774
recall:        0.761
f1:            0.759
val loss: 37.449611842632294
accuracy:      0.736
precision:     0.750
recall:        0.741
f1:            0.735
===== Start training: epoch 14 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.43342137336731, Epoch: 14, training loss: 473.3060041666031, current learning rate 1e-05
val loss: 36.73263669013977
accuracy:      0.764
precision:     0.765
recall:        0.765
f1:            0.764
val loss: 37.3768065571785
accuracy:      0.759
precision:     0.759
recall:        0.760
f1:            0.759
===== Start training: epoch 15 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:58,  5.84it/s]


Timing: 34.625982999801636, Epoch: 15, training loss: 451.92617678642273, current learning rate 1e-05
val loss: 39.20532160997391
accuracy:      0.762
precision:     0.762
recall:        0.762
f1:            0.762
val loss: 38.945604741573334
accuracy:      0.758
precision:     0.758
recall:        0.758
f1:            0.758
===== Start training: epoch 16 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.57334518432617, Epoch: 16, training loss: 430.03714776039124, current learning rate 1e-05
val loss: 37.523957550525665
accuracy:      0.771
precision:     0.772
recall:        0.771
f1:            0.771
val loss: 38.48769557476044
accuracy:      0.756
precision:     0.757
recall:        0.758
f1:            0.756
===== Start training: epoch 17 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.56658983230591, Epoch: 17, training loss: 425.41096663475037, current learning rate 1e-05
val loss: 40.81888508796692
accuracy:      0.762
precision:     0.762
recall:        0.763
f1:            0.763
val loss: 40.9998260140419
accuracy:      0.761
precision:     0.761
recall:        0.761
f1:            0.761
===== Start training: epoch 18 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.495816230773926, Epoch: 18, training loss: 405.98467242717743, current learning rate 1e-05
val loss: 42.83191579580307
accuracy:      0.757
precision:     0.757
recall:        0.757
f1:            0.757
val loss: 43.07012528181076
accuracy:      0.746
precision:     0.746
recall:        0.746
f1:            0.746
===== Start training: epoch 19 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.467386960983276, Epoch: 19, training loss: 398.9103444814682, current learning rate 1e-05
val loss: 37.92085200548172
accuracy:      0.766
precision:     0.766
recall:        0.766
f1:            0.766
val loss: 38.376400113105774
accuracy:      0.760
precision:     0.760
recall:        0.759
f1:            0.759
===== Start training: epoch 20 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.84it/s]


Timing: 34.577805042266846, Epoch: 20, training loss: 389.75825130939484, current learning rate 1e-05
val loss: 42.240751564502716
accuracy:      0.764
precision:     0.764
recall:        0.765
f1:            0.764
val loss: 44.13083863258362
accuracy:      0.747
precision:     0.747
recall:        0.747
f1:            0.747
===== Start training: epoch 21 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.46473932266235, Epoch: 21, training loss: 383.72288036346436, current learning rate 1e-05
val loss: 48.76641380786896
accuracy:      0.773
precision:     0.777
recall:        0.773
f1:            0.772
val loss: 52.71045571565628
accuracy:      0.752
precision:     0.759
recall:        0.755
f1:            0.752
===== Start training: epoch 22 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.5427622795105, Epoch: 22, training loss: 384.9201135635376, current learning rate 1e-05
val loss: 43.73174077272415
accuracy:      0.773
precision:     0.775
recall:        0.773
f1:            0.772
val loss: 45.35509133338928
accuracy:      0.757
precision:     0.758
recall:        0.756
f1:            0.756
===== Start training: epoch 23 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.52764821052551, Epoch: 23, training loss: 384.32727801799774, current learning rate 1e-05
val loss: 40.28625202178955
accuracy:      0.769
precision:     0.770
recall:        0.770
f1:            0.769
val loss: 41.83689108490944
accuracy:      0.757
precision:     0.758
recall:        0.758
f1:            0.757
===== Start training: epoch 24 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:55,  5.85it/s]


Timing: 34.50801420211792, Epoch: 24, training loss: 378.70354425907135, current learning rate 1e-05
val loss: 38.315489679574966
accuracy:      0.770
precision:     0.773
recall:        0.770
f1:            0.769
val loss: 39.42047417163849
accuracy:      0.754
precision:     0.757
recall:        0.756
f1:            0.754
===== Start training: epoch 25 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:54,  5.86it/s]


Timing: 34.474698305130005, Epoch: 25, training loss: 363.5864671468735, current learning rate 1e-05
val loss: 42.65453678369522
accuracy:      0.768
precision:     0.769
recall:        0.768
f1:            0.768
val loss: 44.783642411231995
accuracy:      0.752
precision:     0.754
recall:        0.754
f1:            0.752
===== Start training: epoch 26 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.54640293121338, Epoch: 26, training loss: 350.4692507982254, current learning rate 1e-05
val loss: 40.161914706230164
accuracy:      0.761
precision:     0.761
recall:        0.761
f1:            0.761
val loss: 41.678820848464966
accuracy:      0.758
precision:     0.758
recall:        0.758
f1:            0.758
===== Start training: epoch 27 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:56,  5.85it/s]


Timing: 34.56043529510498, Epoch: 27, training loss: 354.1883648633957, current learning rate 1e-05
val loss: 43.42564219236374
accuracy:      0.768
precision:     0.768
recall:        0.768
f1:            0.768
val loss: 45.65847051143646
accuracy:      0.755
precision:     0.755
recall:        0.756
f1:            0.755
===== Start training: epoch 28 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.86it/s]


Timing: 34.451722621917725, Epoch: 28, training loss: 357.8649271726608, current learning rate 1e-05
val loss: 45.54308992624283
accuracy:      0.766
precision:     0.766
recall:        0.767
f1:            0.766
val loss: 48.61475372314453
accuracy:      0.750
precision:     0.750
recall:        0.751
f1:            0.750
===== Start training: epoch 29 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.4326491355896, Epoch: 29, training loss: 367.76402175426483, current learning rate 1e-05
val loss: 45.41219866275787
accuracy:      0.765
precision:     0.766
recall:        0.765
f1:            0.765
val loss: 48.90185070037842
accuracy:      0.753
precision:     0.756
recall:        0.755
f1:            0.753
===== Start training: epoch 30 =====
*** trying with discovery_weight = 1.0000000000000002, adv_weight = 1


Iteration:   4%|▎         | 202/5443 [00:34<14:53,  5.87it/s]


Timing: 34.44231986999512, Epoch: 30, training loss: 368.35818445682526, current learning rate 1e-05
val loss: 46.75881069898605
accuracy:      0.765
precision:     0.765
recall:        0.765
f1:            0.765
val loss: 50.52084118127823
accuracy:      0.752
precision:     0.753
recall:        0.753
f1:            0.752
best result:
0.7571692876965772
0.7575869924432552
0.7555259785073449
0.7559470503683747


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[[0.7571692876965772, 0.7587490136882913, 0.7586414275855269, 0.7571679889196463], [0.7460684551341351, 0.7464640083816098, 0.7468500443655723, 0.7460227590045199], [0.7562442183163737, 0.7578218806820132, 0.7577146800749285, 0.7562429145917211], [0.7571692876965772, 0.7575869924432552, 0.7555259785073449, 0.7559470503683747]]
tensor([0.7590, 0.7596, 0.7595, 0.7588], dtype=torch.float64)


In [None]:
from google.colab import runtime
runtime.unassign()