<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# force_remount=True asks for a fresh mount even when one is already present.
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
# Install the `datasets` package into the runtime (no version pin, so the latest release is used).
!pip install datasets



In [None]:
# Change the working directory to the AttackSupport project folder on the mounted Drive.
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [None]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

if device.type == "cuda":
   print("Training on GPU")

Training on GPU


In [None]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Minimal ``torch.utils.data.Dataset`` view over an indexable collection of examples."""

    def __init__(self, examples):
        super().__init__()
        # The collection is stored as-is; no copy or conversion is made.
        self.examples = examples

    def __len__(self):
        """Number of wrapped examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx`` unchanged."""
        item = self.examples[idx]
        return item


def collate_fn(examples):
    """Batch 7-field examples into seven ``torch.long`` tensors.

    Each example is read positionally as (ids_sent1, segs_sent1, att_mask_sent1,
    ids_sent2, segs_sent2, att_mask_sent2, label). The i-th returned tensor holds
    the i-th field of every example, in batch order.
    """
    # Transpose the batch: one list per field instead of one tuple per example.
    columns = [list(column) for column in zip(*examples)]
    ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels = columns

    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (
        as_long(ids_sent1),
        as_long(segs_sent1),
        as_long(att_mask_sent1),
        as_long(ids_sent2),
        as_long(segs_sent2),
        as_long(att_mask_sent2),
        as_long(labels),
    )

def collate_fn_concatenated(examples):
    """Batch 5-field examples into five ``torch.long`` tensors.

    Each example is read positionally as (ids_sent1, segs_sent1, att_mask_sent1,
    position_sep, label); every field is stacked across the batch.
    """
    # Transpose the batch so each field becomes one list.
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = (
        list(column) for column in zip(*examples)
    )

    fields = (ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
    ids_t, segs_t, mask_t, sep_t, labels_t = (
        torch.tensor(field, dtype=torch.long) for field in fields
    )
    return ids_t, segs_t, mask_t, sep_t, labels_t

def collate_fn_concatenated_adv(examples):
    """Batch 5-field examples, tensorizing every field except the labels.

    Each example is read positionally as (ids_sent1, segs_sent1, att_mask_sent1,
    position_sep, label). The first four fields become ``torch.long`` tensors;
    the labels are returned as a plain Python list, one entry per example.
    """
    columns = [list(column) for column in zip(*examples)]
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = columns

    tensor_fields = [
        torch.tensor(field, dtype=torch.long)
        for field in (ids_sent1, segs_sent1, att_mask_sent1, position_sep)
    ]

    # Labels are deliberately left untouched (no tensor conversion).
    return tensor_fields[0], tensor_fields[1], tensor_fields[2], tensor_fields[3], labels

In [None]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base processor that tokenizes sentence pairs into fixed-length examples.

  Instantiation reads ``args["model_name"]`` from a module-level ``args``
  mapping, which must be defined before the class is instantiated.
  """

  def __init__(self,):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    # Every encoded sequence is padded or truncated to exactly this many ids.
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _pad_token_id(self):
    """Return the id the tokenizer assigns to its padding token."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_max_len(self, pad_id, ids, *aligned):
    """Pad or truncate ``ids`` and its aligned sequences to ``self.max_sent_len``.

    Args:
      pad_id: token id appended to ``ids`` when it is too short.
      ids: list of token ids.
      *aligned: lists of the same length as ``ids`` (segment ids, masks, ...);
        they are padded with 0.

    Returns:
      A tuple ``(ids, attention_mask, *aligned)`` where every list has length
      ``self.max_sent_len``. The attention mask is 1 on real tokens and 0 on
      padding.
    """
    limit = self.max_sent_len
    length = len(ids)

    if length < limit:
      fill = limit - length
      att_mask = [1] * length + [0] * fill
      ids = ids + [pad_id] * fill
      aligned = [seq + [0] * fill for seq in aligned]
    else:
      # Plain head truncation: everything after position ``limit`` is cut,
      # including a closing special token if the tokenizer appended one.
      att_mask = [1] * limit
      ids = ids[:limit]
      aligned = [seq[:limit] for seq in aligned]

    return (ids, att_mask, *aligned)

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode the two sentences of every row separately.

    Args:
      dataset: iterable of 7-field rows
        ``(id, sentence1, sentence2, _, _, _, label)``; only the two sentences
        and the label are used.
      dataset_type: split name, used only in the progress message.

    Returns:
      A list of examples
      ``[ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label]``.
    """
    examples = []
    # The padding id does not depend on the row, so look it up once.
    pad_id = self._pad_token_id()

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      fields = []
      for sentence in (sentence1, sentence2):
        ids = self.tokenizer.encode(sentence)
        # Segment ids: 0 on the first and last position, 1 in between.
        segs = [0] * len(ids)
        segs[1:-1] = [1] * (len(ids) - 2)
        assert len(ids) == len(segs)

        ids, att_mask, segs = self._fit_to_max_len(pad_id, ids, segs)
        fields += [ids, segs, att_mask]

      examples.append(fields + [label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode the two sentences of every row as one joint sequence.

    Args:
      dataset: iterable of 7-field rows
        ``(id, sentence1, sentence2, _, _, _, label)``; only the two sentences
        and the label are used.
      dataset_type: split name, used only in the progress message.

    Returns:
      A list of examples ``[ids, segs, att_mask, position_sep, label]``.

    Raises:
      AssertionError: if the joint encoding is not exactly as long as the two
        individual encodings together (the segment ids are built on that
        assumption).
    """
    examples = []
    # The padding id does not depend on the row, so look it up once.
    pad_id = self._pad_token_id()

    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids = self.tokenizer.encode(sentence1, sentence2)
      # Segment 0 covers as many positions as sentence1 encoded alone,
      # segment 1 as many as sentence2 encoded alone.
      segs = [0] * sentence1_length + [1] * sentence2_length
      # 0 on the first position and 1 everywhere else (padding gets 0 below).
      position_sep = [0] + [1] * (len(ids) - 1)

      assert len(ids) == len(position_sep)
      assert len(ids) == len(segs)

      ids, att_mask, segs, position_sep = self._fit_to_max_len(pad_id, ids, segs, position_sep)

      examples.append([ids, segs, att_mask, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Turns discourse-marker samples into two-class sentence-pair examples.

  Each sample carries an integer marker id; the id is resolved to a marker
  string through ``id_to_word`` and then to a class through ``mapping``.
  Samples whose marker has no entry in ``mapping`` are dropped.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Reference used for the grouping below:
    # https://www.sciencedirect.com/science/article/pii/S0378216698001015
    # TODO: refactor

    # Marker -> class id. Class 1 holds the contrast-type markers listed here
    # (but, however, although, ...); every other listed marker is class 0.
    # NOTE(review): "conversely" is mapped to 0 while "by_contrast" and
    # "in_contrast" are mapped to 1 -- confirm this is intended.
    self.mapping = {
      "accordingly": 0,
      "also": 0,
      "although": 1,
      "and": 0,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 1,
      "by_comparison": 1,
      "by_contrast": 1,
      "consequently": 0,
      "conversely": 0,
      "especially": 0,
      "further": 0,
      "furthermore": 0,
      "hence": 0,
      "however": 1,
      "in_contrast": 1,
      "instead": 1,
      "likewise": 0,
      "moreover": 0,
      "namely": 0,
      "nevertheless": 1,
      "nonetheless": 1,
      "on_the_contrary": 1,
      "on_the_other_hand": 1,
      "otherwise": 0,
      "rather": 1,
      "similarly": 0,
      "so": 0,
      "still": 1,
      "then": 0,
      "therefore": 0,
      "though": 1,
      "thus": 0,
      "well": 0,
      "yet": 1
    }

    # Integer label of a sample -> marker string.
    self.id_to_word = {
      0: 'no-conn', 1: 'absolutely', 2: 'accordingly', 3: 'actually',
      4: 'additionally', 5: 'admittedly', 6: 'afterward', 7: 'again',
      8: 'already', 9: 'also', 10: 'alternately', 11: 'alternatively',
      12: 'although', 13: 'altogether', 14: 'amazingly', 15: 'and',
      16: 'anyway', 17: 'apparently', 18: 'arguably', 19: 'as_a_result',
      20: 'basically', 21: 'because_of_that', 22: 'because_of_this', 23: 'besides',
      24: 'but', 25: 'by_comparison', 26: 'by_contrast', 27: 'by_doing_this',
      28: 'by_then', 29: 'certainly', 30: 'clearly', 31: 'coincidentally',
      32: 'collectively', 33: 'consequently', 34: 'conversely', 35: 'curiously',
      36: 'currently', 37: 'elsewhere', 38: 'especially', 39: 'essentially',
      40: 'eventually', 41: 'evidently', 42: 'finally', 43: 'first',
      44: 'firstly', 45: 'for_example', 46: 'for_instance', 47: 'fortunately',
      48: 'frankly', 49: 'frequently', 50: 'further', 51: 'furthermore',
      52: 'generally', 53: 'gradually', 54: 'happily', 55: 'hence',
      56: 'here', 57: 'historically', 58: 'honestly', 59: 'hopefully',
      60: 'however', 61: 'ideally', 62: 'immediately', 63: 'importantly',
      64: 'in_contrast', 65: 'in_fact', 66: 'in_other_words', 67: 'in_particular',
      68: 'in_short', 69: 'in_sum', 70: 'in_the_end', 71: 'in_the_meantime',
      72: 'in_turn', 73: 'incidentally', 74: 'increasingly', 75: 'indeed',
      76: 'inevitably', 77: 'initially', 78: 'instead', 79: 'interestingly',
      80: 'ironically', 81: 'lastly', 82: 'lately', 83: 'later',
      84: 'likewise', 85: 'locally', 86: 'luckily', 87: 'maybe',
      88: 'meaning', 89: 'meantime', 90: 'meanwhile', 91: 'moreover',
      92: 'mostly', 93: 'namely', 94: 'nationally', 95: 'naturally',
      96: 'nevertheless', 97: 'next', 98: 'nonetheless', 99: 'normally',
      100: 'notably', 101: 'now', 102: 'obviously', 103: 'occasionally',
      104: 'oddly', 105: 'often', 106: 'on_the_contrary', 107: 'on_the_other_hand',
      108: 'once', 109: 'only', 110: 'optionally', 111: 'or',
      112: 'originally', 113: 'otherwise', 114: 'overall', 115: 'particularly',
      116: 'perhaps', 117: 'personally', 118: 'plus', 119: 'preferably',
      120: 'presently', 121: 'presumably', 122: 'previously', 123: 'probably',
      124: 'rather', 125: 'realistically', 126: 'really', 127: 'recently',
      128: 'regardless', 129: 'remarkably', 130: 'sadly', 131: 'second',
      132: 'secondly', 133: 'separately', 134: 'seriously', 135: 'significantly',
      136: 'similarly', 137: 'simultaneously', 138: 'slowly', 139: 'so',
      140: 'sometimes', 141: 'soon', 142: 'specifically', 143: 'still',
      144: 'strangely', 145: 'subsequently', 146: 'suddenly', 147: 'supposedly',
      148: 'surely', 149: 'surprisingly', 150: 'technically', 151: 'thankfully',
      152: 'then', 153: 'theoretically', 154: 'thereafter', 155: 'thereby',
      156: 'therefore', 157: 'third', 158: 'thirdly', 159: 'this',
      160: 'though', 161: 'thus', 162: 'together', 163: 'traditionally',
      164: 'truly', 165: 'truthfully', 166: 'typically', 167: 'ultimately',
      168: 'undoubtedly', 169: 'unfortunately', 170: 'unsurprisingly', 171: 'usually',
      172: 'well', 173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    """Filter, relabel and tokenize a discourse-marker dataset.

    Args:
      dataset: iterable of samples indexable by ``"sentence1"``,
        ``"sentence2"`` and ``"label"`` (an integer key of ``id_to_word``).
      name: split name; used as id prefix and passed on to the tokenization
        step.

    Returns:
      The examples produced by ``_get_examples_concatenated``; each label is a
      one-hot row with one column per class id used in ``mapping``.

    Raises:
      KeyError: if a sample label is not a key of ``id_to_word``.
    """
    new_dataset = []

    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker not in self.mapping:
        continue

      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[marker]])

    # Pin the category list so every split is encoded with the same columns,
    # even when a split happens to contain only one of the classes.
    class_ids = sorted(set(self.mapping.values()))
    one_hot_encoder = OneHotEncoder(categories=[class_ids], handle_unknown="ignore", sparse_output=False)

    labels = [[sample[-1]] for sample in tqdm(new_dataset, desc="processing labels...")]

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    result = []
    for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results..."):
      # Row layout expected by _get_examples_concatenated:
      # (id, sentence1, sentence2, unused, unused, unused, label).
      result.append([f"{name}_{i}", sample[0], sample[1], [], [], [], label])

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):
  """Reads a tab-separated support/attack file and tokenizes its sentence pairs."""

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def padding(self, input, maxlen):
    """
    Zero-pad the sequence fields of a batch of 7-field rows.

    Each row is unpacked as (id, sentences, target, source_sentiment,
    target_sentiment, knowledge, label_distribution). ``sentences``,
    ``knowledge`` and ``target`` are turned into tensors and padded with 0 to
    the longest sequence of their own column. ``maxlen`` is accepted but not
    used.

    Returns a list of (sentences, knowledge, target, label_distribution) tuples.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups together sentences into batches
    If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
    If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
    A batch always holds at least one sentence, even when that sentence alone exceeds the word budget.
    Returns a list of lists with sentences ids.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      # Bucket sentence indices by length so every batch is length-homogeneous.
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp both operands: an empty sentence must not divide by zero and
          # an over-long one must not produce a zero step for range().
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        # Close the batch once the sentence count (positive limit) or the
        # padded word count (non-positive limit) reaches the budget.
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one tab-separated input file and returns its tokenized examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second to last = facts, last = relation label. Empty lines are skipped.
    Labels supports/support/because become [1, 0], attacks/attack/but become
    [0, 1] and anything else [0, 0]. ``max_sentence_length`` is accepted but
    not used.

    Raises IndexError for a non-empty line with fewer than four columns.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        # Skip empty lines regardless of the line-ending style.
        if line.strip("\r\n") == "":
          continue

        fields = line.replace("\n", "").split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip underscores and both kinds of brackets from the facts string.
        for char in "_[]()":
          facts = facts.replace(char, "")

        l = [0, 0]
        if label in ('supports', 'support', 'because'):
          l = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          l = [0, 1]

        result.append([story_id, sent, target, [], [], facts, l])

    examples = self._get_examples_concatenated(result, name)

    return examples

class DebateProcessor(DataProcessor):
  """Reads a tab-separated support/attack debate file and tokenizes its sentence pairs."""

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def padding(self, input, maxlen):
    """
    Zero-pad the sequence fields of a batch of 7-field rows.

    Each row is unpacked as (id, sentences, target, source_sentiment,
    target_sentiment, knowledge, label_distribution). ``sentences``,
    ``knowledge`` and ``target`` are turned into tensors and padded with 0 to
    the longest sequence of their own column. ``maxlen`` is accepted but not
    used.

    Returns a list of (sentences, knowledge, target, label_distribution) tuples.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups together sentences into batches
    If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
    If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
    A batch always holds at least one sentence, even when that sentence alone exceeds the word budget.
    Returns a list of lists with sentences ids.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      # Bucket sentence indices by length so every batch is length-homogeneous.
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp both operands: an empty sentence must not divide by zero and
          # an over-long one must not produce a zero step for range().
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        # Close the batch once the sentence count (positive limit) or the
        # padded word count (non-positive limit) reaches the budget.
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one tab-separated input file and returns its tokenized examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second to last = facts, last = relation label. Empty lines are skipped.
    Labels supports/support/because become [1, 0], attacks/attack/but become
    [0, 1] and anything else [0, 0]. ``max_sentence_length`` is accepted but
    not used.

    Raises IndexError for a non-empty line with fewer than four columns.
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        # Skip empty lines regardless of the line-ending style.
        if line.strip("\r\n") == "":
          continue

        fields = line.replace("\n", "").split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip underscores and both kinds of brackets from the facts string.
        for char in "_[]()":
          facts = facts.replace(char, "")

        l = [0, 0]
        if label in ('supports', 'support', 'because'):
          l = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          l = [0, 1]

        result.append([story_id, sent, target, [], [], facts, l])

    examples = self._get_examples_concatenated(result, name)

    return examples


class MARGProcessor(DataProcessor):
  """Reads a CSV file with a split column and three relation classes."""

  def __init__(self):
    super(MARGProcessor, self).__init__()
    # Discourse-marker classifier; the methods of this class do not call it.
    # NOTE(review): confirm whether external code uses ``processor.pipe``.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """
    Reads one CSV file (default ``pandas.read_csv`` separator) and returns the
    tokenized examples of a single split.

    Columns are addressed by position: 0 = id, 1 = source sentence,
    2 = target sentence, 3 = relation label, third from last = facts,
    last = split name. Only rows whose split name equals ``name`` are kept.
    Labels supports/support/because become [1, 0, 0], neither becomes
    [0, 1, 0], attacks/attack/but become [0, 0, 1] and anything else
    [0, 0, 0]. ``max_sent_length`` is accepted but not used.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc keeps the access positional whatever the column labels are.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = row.iloc[2].strip()
      label = row.iloc[3].strip()
      facts = row.iloc[-3].strip()

      # Strip underscores and both kinds of brackets from the facts string.
      for char in "_[]()":
        facts = facts.replace(char, "")

      l = [0, 0, 0]
      if label in ('supports', 'support', 'because'):
        l = [1, 0, 0]
      elif label in ('attacks', 'attack', 'but'):
        l = [0, 0, 1]
      elif label == 'neither':
        l = [0, 1, 0]

      result.append([story_id, sent, target, [], [], facts, l])

    examples = self._get_examples_concatenated(result, name)

    return examples

class NKProcessor(DataProcessor):
  """Reads a tab-separated file with three relation classes."""

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """
    Reads one tab-separated file with pandas and returns its tokenized examples.

    Columns are addressed by position: 0 = id, 2 = relation label,
    3 = source sentence, 4 = target sentence. Every row is used; ``name`` is
    only passed on to the tokenization step. Labels supports/support/because
    become [1, 0, 0], no_relation becomes [0, 1, 0], attacks/attack/but become
    [0, 0, 1] and anything else [0, 0, 0]. ``max_sent_length`` is accepted but
    not used.
    """
    result = []

    df = pd.read_csv(file_path, sep="\t")
    for _, row in df.iterrows():
      # .iloc keeps the access positional whatever the column labels are.
      id_sample = row.iloc[0]
      label = row.iloc[2]

      sent = row.iloc[3].strip()
      target = row.iloc[4].strip()

      l = [0, 0, 0]
      if label in ('supports', 'support', 'because'):
        l = [1, 0, 0]
      elif label in ('attacks', 'attack', 'but'):
        l = [0, 0, 1]
      elif label == 'no_relation':
        l = [0, 1, 0]

      result.append([id_sample, sent, target, [], [], [], l])

    examples = self._get_examples_concatenated(result, name)

    return examples

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Like the plain tab-separated reader, but prefixes each target sentence
  with the discourse marker predicted for the (source, target) pair."""

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    # Classifier whose top label is used as the injected discourse marker.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """
    Zero-pad the sequence fields of a batch of 7-field rows.

    Each row is unpacked as (id, sentences, target, source_sentiment,
    target_sentiment, knowledge, label_distribution). ``sentences``,
    ``knowledge`` and ``target`` are turned into tensors and padded with 0 to
    the longest sequence of their own column. ``maxlen`` is accepted but not
    used.

    Returns a list of (sentences, knowledge, target, label_distribution) tuples.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups together sentences into batches
    If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
    If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
    A batch always holds at least one sentence, even when that sentence alone exceeds the word budget.
    Returns a list of lists with sentences ids.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      # Bucket sentence indices by length so every batch is length-homogeneous.
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp both operands: an empty sentence must not divide by zero and
          # an over-long one must not produce a zero step for range().
          batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        # Close the batch once the sentence count (positive limit) or the
        # padded word count (non-positive limit) reaches the budget.
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one tab-separated input file and returns its tokenized examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second to last = facts, last = relation label. Empty lines are skipped.
    For every line the classifier in ``self.pipe`` is run on
    "<source></s></s><target>"; its top label (underscores turned into spaces,
    first letter upper-cased) is prepended to the target, whose own first
    letter is lower-cased. Labels supports/support/because become [1, 0],
    attacks/attack/but become [0, 1] and anything else [0, 0].
    ``max_sentence_length`` is accepted but not used.

    Raises IndexError for a non-empty line with fewer than four columns.
    """

    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for line in f:
        # Skip empty lines regardless of the line-ending style.
        if line.strip("\r\n") == "":
          continue

        fields = line.replace("\n", "").split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        # Slices instead of [0] so an empty marker or target cannot raise.
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        target = target[:1].lower() + target[1:]
        target = ds_marker + " " + target

        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip underscores and both kinds of brackets from the facts string.
        for char in "_[]()":
          facts = facts.replace(char, "")

        l = [0, 0]
        if label in ('supports', 'support', 'because'):
          l = [1, 0]
        elif label in ('attacks', 'attack', 'but'):
          l = [0, 1]

        result.append([story_id, sent, target, [], [], facts, l])

    examples = self._get_examples_concatenated(result, name)

    return examples


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """
  Data processor for the tab-separated debate corpus with discourse-marker injection.

  A text-classification pipeline predicts a discourse marker for every
  (source, target) pair; the marker is prepended to the target sentence before
  the pairs are handed to `self._get_examples_concatenated`.
  """

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    # Predicts the discourse marker that links a source sentence to its target.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
      """
      Pads sentences, knowledge and targets with zeros up to the longest sequence of each group.

      :param input: iterable of 7-item records
                    (id, sentence, target, source_sentiment, target_sentiment, knowledge, label_distribution)
      :param maxlen: not used; kept so existing callers keep working
      :return: list of (sentence, knowledge, target, label_distribution) tuples
      """
      _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

      def pad(sequences):
          return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

      sentences = pad(sentences)
      knowledge = pad(knowledge)
      target = pad(target)

      return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
      """
      Groups together sentences into batches
      If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
      If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
      Returns a list of lists with sentences ids.

      :param sentences: sequence of sentences; only len(sentences[i]) is used
      :param batch_equal_size: when exactly True, every batch only holds sentences of the same length
      :param max_batch_size: sentence budget (positive) or word budget (negative) per batch
      """
      batches_of_sentence_ids = []
      if batch_equal_size == True:
          # Bucket sentence ids by length, keeping first-seen order of the lengths.
          sentence_ids_by_length = collections.OrderedDict()
          for i in range(len(sentences)):
              sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

          for sentence_length, sentence_ids in sentence_ids_by_length.items():
              if max_batch_size > 0:
                  batch_size = max_batch_size
              else:
                  # Word budget: a sentence longer than the budget still gets its own
                  # batch, and empty sentences must not divide by zero. Without the
                  # clamps range() below would be called with a step of 0.
                  batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

              for start in range(0, len(sentence_ids), batch_size):
                  batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
      else:
          current_batch = []
          max_sentence_length = 0
          for i in range(len(sentences)):
              current_batch.append(i)
              max_sentence_length = max(max_sentence_length, len(sentences[i]))
              # Close the batch once the sentence budget or the padded word budget is reached.
              if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
                  batches_of_sentence_ids.append(current_batch)
                  current_batch = []
                  max_sentence_length = 0
          if len(current_batch) > 0:
              batches_of_sentence_ids.append(current_batch)
      return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads one tab-separated input file (ISO-8859-1 encoded) and builds the examples.

      Fields read from every line: 0 = id, 1 = source sentence, 3 = target sentence,
      second-to-last = facts, last = label. Blank lines are skipped.

      :param file_path: path of the file to read
      :param max_sentence_length: not used; kept so existing callers keep working
      :param name: split name forwarded to `self._get_examples_concatenated`
      :return: whatever `self._get_examples_concatenated` returns for the parsed records
      """
      support_labels = ('supports', 'support', 'because')
      attack_labels = ('attacks', 'attack', 'but')

      result = []
      with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
          for line in f:
              # Covers the '\r' left over from CRLF files as well as plain empty lines.
              if not line.strip():
                  continue
              fields = line.replace("\n", "").split("\t")

              story_id = fields[0]
              sent = fields[1].strip()
              target = fields[3].strip()

              # Predicted marker, e.g. "on_the_other_hand" -> "On the other hand".
              ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
              ds_marker = ds_marker.replace("_", " ")
              ds_marker = ds_marker[:1].upper() + ds_marker[1:]
              # Slicing (instead of indexing) keeps an empty target from raising IndexError.
              target = ds_marker + " " + target[:1].lower() + target[1:]

              label = fields[-1].strip()
              facts = fields[-2].strip()
              # Strip underscores and both kinds of brackets; the closing parenthesis
              # used to be missed because "(" was replaced twice.
              for noise in "_[]()":
                  facts = facts.replace(noise, "")

              # One-hot target [support, attack]; unknown labels stay [0, 0].
              if label in support_labels:
                  l = [1, 0]
              elif label in attack_labels:
                  l = [0, 1]
              else:
                  l = [0, 0]

              # The two empty lists are the unused source/target sentiment slots.
              result.append([story_id, sent, target, [], [], facts, l])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGWithDiscourseInjectionProcessor(DataProcessor):
  """
  Data processor for the M-ARG CSV file with discourse-marker injection.

  A text-classification pipeline predicts a discourse marker for every
  (source, target) pair; the marker is prepended to the target sentence before
  the pairs are handed to `self._get_examples_concatenated`.
  """

  def __init__(self,):
    super(MARGWithDiscourseInjectionProcessor,self).__init__()
    # Predicts the discourse marker that links a source sentence to its target.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads a comma-separated file and builds the examples of one split.

      Columns are addressed by position: 0 = id, 1 = source sentence, 2 = target
      sentence, 3 = label, third-to-last = facts, last = split name. Only rows whose
      split name equals `name` are kept.

      :param file_path: path of the CSV file
      :param max_sent_length: not used; kept so existing callers keep working
      :param name: split to keep; also forwarded to `self._get_examples_concatenated`
      :return: whatever `self._get_examples_concatenated` returns for the parsed records
      """

      # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf

      support_labels = ('supports', 'support', 'because')
      attack_labels = ('attacks', 'attack', 'but')

      result = []
      df = pd.read_csv(file_path)
      for _, row in df.iterrows():
              # .iloc makes the positional access explicit; integer keys on a
              # label-indexed Series are deprecated in pandas.
              if row.iloc[-1] != name:
                continue

              story_id = row.iloc[0]
              sent = row.iloc[1].strip()
              target = row.iloc[2].strip()

              # Predicted marker, e.g. "on_the_other_hand" -> "On the other hand".
              ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
              ds_marker = ds_marker.replace("_", " ")
              ds_marker = ds_marker[:1].upper() + ds_marker[1:]
              # Slicing (instead of indexing) keeps an empty target from raising IndexError.
              target = ds_marker + " " + target[:1].lower() + target[1:]

              label = row.iloc[3].strip()
              facts = row.iloc[-3].strip()
              # Strip underscores and both kinds of brackets; the closing parenthesis
              # used to be missed because "(" was replaced twice.
              for noise in "_[]()":
                facts = facts.replace(noise, "")

              # One-hot target [support, neither, attack]; unknown labels stay all-zero.
              if label in support_labels:
                l = [1,0,0]
              elif label in attack_labels:
                l = [0,0,1]
              elif label == 'neither':
                l = [0,1,0]
              else:
                l = [0,0,0]

              # The two empty lists are the unused source/target sentiment slots.
              result.append([story_id, sent, target, [], [], facts, l])

      examples = self._get_examples_concatenated(result, name)

      return examples

class NKWithDiscourseInjectionProcessor(DataProcessor):
  """
  Data processor for the tab-separated "nk" dataset with discourse-marker injection.

  A text-classification pipeline predicts a discourse marker for every
  (source, target) pair; the marker is prepended to the target sentence before
  the pairs are handed to `self._get_examples_concatenated`. Injection is
  best-effort: when it fails for a row, the raw target is kept.
  """

  def __init__(self,):
    super(NKWithDiscourseInjectionProcessor,self).__init__()
    # Predicts the discourse marker that links a source sentence to its target.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads a tab-separated file and builds the examples.

      Columns are addressed by position: 0 = id, 2 = label, 3 = source sentence,
      4 = target sentence.

      :param file_path: path of the TSV file
      :param max_sent_length: not used; kept so existing callers keep working
      :param name: split name forwarded to `self._get_examples_concatenated`
      :return: whatever `self._get_examples_concatenated` returns for the parsed records
      """
      support_labels = ('supports', 'support', 'because')
      attack_labels = ('attacks', 'attack', 'but')

      result = []
      failed_injections = 0

      df = pd.read_csv(file_path, sep="\t")
      for _, row in df.iterrows():
              # .iloc makes the positional access explicit; integer keys on a
              # label-indexed Series are deprecated in pandas.
              id_sample = row.iloc[0]
              label = row.iloc[2]

              sent = row.iloc[3].strip()
              target = row.iloc[4].strip()

              try:
                ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
                ds_marker = ds_marker.replace("_", " ")
                ds_marker = ds_marker[0].upper() + ds_marker[1:]
                target = ds_marker + " " + target[0].lower() + target[1:]
              except Exception:
                # Best effort, as before: keep the raw target when the pipeline (or an
                # empty string) fails. Catching Exception instead of everything lets
                # KeyboardInterrupt and SystemExit through.
                failed_injections += 1

              # One-hot target [support, no_relation, attack]; unknown labels stay all-zero.
              if label in support_labels:
                l = [1,0,0]
              elif label in attack_labels:
                l = [0,0,1]
              elif label == 'no_relation':
                l = [0,1,0]
              else:
                l = [0,0,0]

              # Sentiment and knowledge slots are not available for this dataset.
              result.append([id_sample, sent, target, [], [], [], l])

      if failed_injections:
        print(f"Discourse marker injection failed for {failed_injections} of {len(result)} rows; original targets kept.")

      examples = self._get_examples_concatenated(result, name)

      return examples

In [None]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, negated and scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Remember the scaling factor for the backward pass.
        ctx.lmbd = torch.tensor(lmbd)
        passthrough = x.reshape_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        flipped = torch.neg(grad_output.clone())
        # One entry per forward input: the reversed gradient for `x`, nothing for `lmbd`.
        return torch.mul(ctx.lmbd, flipped), None

class DoubleAdversarialNet(torch.nn.Module):
  """
  Pre-trained encoder with cross-attention pooling, a main classifier, an
  auxiliary classifier and three heads fed through gradient reversal.

  In training mode `forward` splits the batch in two halves: the first half is
  scored by `linear_layer`, the second half by `linear_layer_adv`. In eval mode
  only `linear_layer` is applied to the whole batch.
  Reads its hyper-parameters from the global `args` dict.
  """
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Replace the token-type embedding table with a freshly initialised one of 4 rows.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    # When True, forward averages hidden_states[-1] and hidden_states[1].
    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the heads.
    for param in self.plm.parameters():
      param.requires_grad = True

    # linear_layer / linear_layer_adv score the two batch halves; the remaining
    # three heads receive gradient-reversed mean embeddings in forward.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # NOTE(review): MultiheadAttention matches none of the isinstance checks in
    # _init_weights, so this call leaves the module's own initialisation untouched.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights"""
    # Linear/Embedding weights: normal(0, initializer_range of the encoder config).
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """
    Encodes the concatenated sentence pair and returns class scores.

    `position_sep` is accepted but not used. `labels` is only read in training
    mode, where it is sliced per batch half and converted with `torch.tensor`.

    Returns, in training mode, the tuple
    (predictions, predictions_adv, task_prediction, attack_prediction, support_prediction);
    in eval mode only `predictions` for the whole batch.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # Token masks per segment id: 0 selects the first sentence, 1 the second.
    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    # Zero out the embeddings of the tokens that belong to the other segment.
    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    # Second-sentence queries attend over first-sentence keys/values.
    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    # Mean over the sequence dimension of the attention output.
    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      batch_size = H_sent.shape[0]
      # First half of the batch -> main classifier.
      # presumably the caller batches main-task samples first and auxiliary samples second — TODO confirm
      samples = H_sent[:batch_size // 2, :]
      embed_sent1_std = embed_sent1[:batch_size // 2, :, :]
      labels_std = torch.tensor(labels[:batch_size // 2]).to(device)

      # Select first-half rows by exact one-hot label match.
      # NOTE(review): the processors in this file encode 'attack' as [0,0,1] and
      # 'neither'/'no_relation' as [0,1,0]; in the 3-class branch `emb_attack`
      # therefore selects the [0,1,0] rows — confirm this is intended.
      if self.num_classes == 2:
        emb_attack = embed_sent1_std[(labels_std == torch.tensor([0,1]).to(device)).all(dim=1)]
        emb_support = embed_sent1_std[(labels_std == torch.tensor([1,0]).to(device)).all(dim=1)]
      else:
        emb_attack = embed_sent1_std[(labels_std == torch.tensor([0,1,0]).to(device)).all(dim=1)]
        emb_support = embed_sent1_std[(labels_std == torch.tensor([1,0,0]).to(device)).all(dim=1) | (labels_std == torch.tensor([0,0,1]).to(device)).all(dim=1)]

      # Second half of the batch -> auxiliary classifier.
      samples_adv = H_sent[batch_size // 2:, :]
      embed_sent1_adv = embed_sent1[batch_size // 2:, :, :]
      labels_adv = torch.tensor(labels[batch_size // 2:]).to(device)

      # The second-half labels are compared against 3-wide one-hot vectors.
      emb_contr = embed_sent1_adv[(labels_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)]
      emb_other = embed_sent1_adv[(labels_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (labels_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      # Gradient reversal (factor .01) on mean token embeddings: whole batch, and
      # the two groups built from the selections of both halves.
      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      mean_grl_attack = GRLayer.apply(torch.mean(torch.cat([emb_attack, emb_contr], dim=0), dim=1), .01)
      mean_grl_support = GRLayer.apply(torch.mean(torch.cat([emb_support, emb_other], dim=0), dim=1), .01)

      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl_attack)
      support_prediction = self.support_linear(mean_grl_support)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  """
  Pre-trained encoder with cross-attention pooling, a main classifier, an
  auxiliary classifier and a gradient-reversed two-way task head.

  Hyper-parameters are read from the global `args` dict.
  """
  def __init__(self):
    super(AdversarialNet, self).__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    plm_config = encoder.config
    # Swap in a fresh 4-row token-type embedding table.
    plm_config.type_vocab_size = 4
    type_embeddings = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
    encoder.embeddings.token_type_embeddings = type_embeddings
    encoder._init_weights(type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the heads.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = torch.nn.Linear(width, self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(width, self.num_classes_adv)
    self.task_linear = torch.nn.Linear(width, 2)

    self.multi_head_att = torch.nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = torch.nn.Linear(width, width)
    self.K = torch.nn.Linear(width, width)
    self.V = torch.nn.Linear(width, width)

    # Same initialisation order as the layers were always initialised in.
    for head in (
      self.linear_layer,
      self.linear_layer_adv,
      self.Q,
      self.K,
      self.V,
      self.multi_head_att,
      self.task_linear,
    ):
      self._init_weights(head)

  def _init_weights(self, module):
    """Initialize the weights"""
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """
    Encodes the concatenated sentence pair (`position_sep` is not used).

    Returns the pooled representation when `visualize` is set; otherwise
    (predictions, predictions_adv, task_prediction) in training mode, where the
    two classifiers each score one half of the batch, or `predictions` for the
    whole batch in eval mode.
    """
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    states = encoded.hidden_states

    embeddings = states[-1]
    if self.first_last_avg:
      embeddings = (states[-1] + states[1]) / 2

    # Keep only the tokens of one segment each (ids 0 and 1), zeroing the rest.
    first_segment = (segs_sent1 == 0).long().unsqueeze(2) * embeddings
    second_segment = (segs_sent1 == 1).long().unsqueeze(2) * embeddings

    keys = self.K(first_segment)
    values = self.V(first_segment)
    queries = self.Q(second_segment)

    # Second-sentence queries attend over first-sentence keys/values.
    attended = self.multi_head_att(queries, keys, values)[0]
    pooled = attended.mean(dim=1)

    if visualize:
      return pooled

    if not self.training:
      return self.linear_layer(pooled)

    half = pooled.shape[0] // 2
    predictions = self.linear_layer(pooled[:half, :])
    predictions_adv = self.linear_layer_adv(pooled[half:])

    # The task head sees the mean token embedding through gradient reversal.
    reversed_mean = GRLayer.apply(embeddings.mean(dim=1), .01)
    task_prediction = self.task_linear(reversed_mean)

    return predictions, predictions_adv, task_prediction

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  """
  Pre-trained encoder whose classifier input is the embedding of the first
  second-sentence token, gated by the first token and the last attended token.

  With `attention=True` the token embeddings first go through a cross-attention
  layer (second-sentence queries over first-sentence keys/values).
  Hyper-parameters are read from the global `args` dict.
  """
  def __init__(self, attention):
    """
    :param attention: when truthy, forward applies the cross-attention layer
                      before picking the cue positions
    """
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Swap in a fresh 4-row token-type embedding table.
    config.type_vocab_size = 4
    self.attention = attention
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the heads.
    for param in self.plm.parameters():
      param.requires_grad = True

    # The gated embedding produced by forward is embed_size wide, so the classifier
    # has to accept embed_size features (it used to expect embed_size*3, which
    # cannot be multiplied with that embedding).
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    # Projections feeding the gate; forward reads both, so they must exist
    # (they were previously disabled, which made forward raise AttributeError).
    self.linear_initial_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_end_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.linear_initial_sent)
    self._init_weights(self.linear_end_sent)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """
    Returns class scores for the concatenated sentence pair.

    `position_sep` is accepted but not used.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      # Token masks per segment id: 0 selects the first sentence, 1 the second.
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    # Index tensors live on the same device as the data instead of a global device.
    batch_index = torch.arange(att_output.shape[0], device=att_output.device)
    # Descending weights make argmax return the first position with a non-zero
    # segment id (assumes segment ids are 0/1 — TODO confirm).
    descending = torch.arange(att_output.shape[1], 0, -1, device=segs_sent1.device)
    first_sent2_position = torch.argmax(segs_sent1 * descending, dim=-1)
    # Number of attended tokens minus one, i.e. the last position with mask 1
    # when the mask is a prefix of ones.
    last_position = torch.sum(att_mask_sent1, dim=-1) - 1

    initial_sent1 = att_output[:, 0, :]
    initial_sent2 = att_output[batch_index, first_sent2_position]
    end_sent2 = att_output[batch_index, last_position]

    initial_sent1 = self.linear_initial_sent(initial_sent1)
    end_sent2 = self.linear_end_sent(end_sent2)
    gate = self.sigmoid(initial_sent1 + end_sent2)

    final_emb = initial_sent2 * gate

    predictions = self.linear_layer(final_emb)

    return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """
  Pre-trained encoder followed by mean pooling and a linear classifier.

  With `attention=True` the pooled vector comes from a cross-attention layer in
  which second-sentence queries attend over first-sentence keys/values;
  otherwise the token embeddings are averaged directly.
  Hyper-parameters are read from the global `args` dict.
  """
  def __init__(self, attention):
    super(BaselineModelWithSentenceComparison, self).__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    plm_config = encoder.config
    # Swap in a fresh 4-row token-type embedding table.
    plm_config.type_vocab_size = 4
    type_embeddings = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
    encoder.embeddings.token_type_embeddings = type_embeddings
    encoder._init_weights(type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the head.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = torch.nn.Linear(width, args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = torch.nn.Linear(width, width)
    self.K = torch.nn.Linear(width, width)
    self.V = torch.nn.Linear(width, width)

    # Same initialisation order as the layers were always initialised in.
    for layer in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
      self._init_weights(layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    is_linear = isinstance(module, nn.Linear)
    if is_linear or isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if is_linear and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """Returns class scores for the concatenated sentence pair (`position_sep` is not used)."""
    encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    states = encoded.hidden_states

    embeddings = states[-1]
    if self.first_last_avg:
      embeddings = (states[-1] + states[1]) / 2

    if not self.attention:
      return self.linear_layer(embeddings.mean(dim=1))

    # Keep only the tokens of one segment each (ids 0 and 1), zeroing the rest.
    first_segment = (segs_sent1 == 0).long().unsqueeze(2) * embeddings
    second_segment = (segs_sent1 == 1).long().unsqueeze(2) * embeddings

    keys = self.K(first_segment)
    values = self.V(first_segment)
    queries = self.Q(second_segment)

    attended = self.multi_head_att(queries, keys, values)[0]

    return self.linear_layer(attended.mean(dim=1))


class BaselineModel(torch.nn.Module):
  """
  Pre-trained encoder followed by mean pooling over all tokens and a linear classifier.

  Hyper-parameters are read from the global `args` dict.
  """
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Swap in a fresh 4-row token-type embedding table.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    # When True, forward averages hidden_states[-1] and hidden_states[1].
    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the head.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """
    Returns class scores computed from the mean of all token embeddings.

    `position_sep` is accepted for signature compatibility with the other models
    and is not used: the masked tensor that used to be derived from it never
    reached the output and required `position_sep` to be a tensor.
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # Mean over the sequence dimension (padding positions included).
    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  """
  Encodes two sentences separately with the same pre-trained encoder, mean-pools
  each one and classifies the concatenation of the two pooled vectors.

  Hyper-parameters are read from the global `args` dict.
  """
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    # Swap in a fresh 4-row token-type embedding table.
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    # When True, forward averages hidden_states[-1] and hidden_states[1].
    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the head.
    for param in self.plm.parameters():
      param.requires_grad = True

    # Two pooled sentence vectors are concatenated, hence embed_size*2 inputs.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  def _embed(self, ids, segs, att_mask):
    """Runs the shared encoder and returns the token embeddings selected by `first_last_avg`."""
    hidden_states = self.plm(ids, token_type_ids=segs, attention_mask=att_mask, output_hidden_states=True).hidden_states
    if self.first_last_avg:
      return torch.div((hidden_states[-1] + hidden_states[1]), 2)
    return hidden_states[-1]

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """
    Returns class scores for a pair of separately encoded sentences.

    The segment-masked copies of the embeddings that used to be built here were
    never used for the prediction and have been dropped.
    """
    embed_sent1 = self._embed(ids_sent1, segs_sent1, att_mask_sent1)
    embed_sent2 = self._embed(ids_sent2, segs_sent2, att_mask_sent2)

    # Mean over the sequence dimension (padding positions included).
    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    Seeds the Python, NumPy and PyTorch random generators with the same value.

    Also switches cuDNN to deterministic mode and, when CUDA is available, seeds
    the current CUDA device.

    :param seed: integer seed
    """
    # Each library keeps its own generator state, so every one is seeded in turn.
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)
    torch.backends.cudnn.deterministic = True
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """
    Computes, prints and returns accuracy plus macro-averaged precision, recall and F1.

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    # The three remaining metrics share the same macro-averaged call signature.
    precision, recall, f1 = [
        scorer(labels, preds, average="macro")
        for scorer in (precision_score, recall_score, f1_score)
    ]

    report = (
        ('accuracy:', accuracy),
        ('precision:', precision),
        ('recall:', recall),
        ('f1:', f1),
    )
    for caption, value in report:
        print("{:15}{:<.3f}".format(caption, value))

    return accuracy, precision, recall, f1

In [None]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """
    Batch sampler that fills every batch half from `dataset1` and half from `dataset2`.

    Yielded indices address the concatenation of the two datasets: indices of
    `dataset2` are shifted by `len(dataset1)`. An epoch ends as soon as either
    dataset cannot fill its share of another batch, or after
    `total_size // batch_size` batches.
    """
    def __init__(self, dataset1, dataset2, batch_size):
        """
        :param dataset1: first dataset; contributes `batch_size // 2` indices per batch
        :param dataset2: second dataset; contributes the remaining indices of each batch
        :param batch_size: number of indices per yielded batch
        """
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _shares(self):
        """Returns how many indices each dataset contributes to one batch."""
        share1 = self.batch_size // 2
        return share1, self.batch_size - share1

    def __iter__(self):
        self.epoch += 1
        share1, share2 = self._shares()

        # Fresh random order every epoch; dataset2 indices are offset past dataset1.
        indices1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        indices2 = (torch.randperm(len(self.dataset2)) + len(indices1)).cpu().tolist()

        for _ in range(self.total_size // self.batch_size):
            # Stop instead of popping from an exhausted list (dataset2 running
            # out used to raise IndexError).
            if len(indices1) < share1 or len(indices2) < share2:
                break
            batch = [indices1.pop() for _ in range(share1)]
            batch.extend(indices2.pop() for _ in range(share2))

            yield batch
            if len(indices1) == 0:
                break

    def __len__(self):
        """Number of batches one pass of `__iter__` yields."""
        share1, share2 = self._shares()
        remaining1, remaining2 = len(self.dataset1), len(self.dataset2)
        batches = 0
        # Mirrors the stopping rules of __iter__ without drawing any indices.
        for _ in range(self.total_size // self.batch_size):
            if remaining1 < share1 or remaining2 < share2:
                break
            remaining1 -= share1
            remaining2 -= share2
            batches += 1
            if remaining1 == 0:
                break
        return batches

In [10]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

# Hyper-parameters. Keys that are not read in this cell (model_name,
# num_classes, num_classes_adv, embed_size, first_last_avg) are presumably
# consumed by the model/processor definitions in other cells -- TODO confirm.
args = {
    "model_name": "roberta-base",
    "num_classes": 3, #3, #2,
    "num_classes_adv": 3, #174,
    "embed_size": 768,
    "first_last_avg": False,
    # One full run() is executed per seed; the per-seed metrics are averaged.
    "seed": [0,1,2],
    # Batch size of every DataLoader and of the BalancedSampler.
    "batch_size": 64,
    # Number of training epochs per run (and per grid-search combination).
    "epochs": 30,
    # Class weights for the main CrossEntropyLoss. For "m-arg"/"nk" the list is
    # used as-is; for the other datasets run() builds [1, class_weight], so a
    # single number is expected there.
    "class_weight": [2.071, 1, 1.933], #[9.375, 1, 30], #10 [2.071, 1, 1.933]
    # Learning rate passed to AdamW.
    "lr": 1e-5
}

# Experiment switches read by train(), val(), visualize() and run().
config = {
    # Selects the processor and data files in run():
    # "student_essay", "debate", "m-arg" or "nk".
    "dataset": "nk", #"student_essay", #debate, m-arg
    # Train AdversarialNet on half/half batches built by BalancedSampler.
    "adversarial": False,
    # Train DoubleAdversarialNet (takes precedence over "adversarial" in run()).
    "double_adversarial": False,
    # Load the processed discovery dataset from ./adv_dataset.pkl instead of
    # re-processing it (and re-writing that file).
    "dataset_from_saved": False,
    # Use the *WithDiscourseInjectionProcessor variant of the dataset processor.
    "injection": True,
    # With the baseline model: also train on the discovery data via BalancedSampler.
    "finetuning_discovery": False,
    # Sweep rl / discovery_weight / adv_weight instead of doing a single training.
    "grid_search": False,
    # After training, reload the saved checkpoint and draw t-SNE plots
    # (only honoured together with "adversarial").
    "visualize": False,
    # When False, the epoch loop in run() performs no training or evaluation.
    "train": True,
    # Step a LinearLR scheduler once per epoch.
    "scheduler": False,
    # Forwarded as `attention=` to the baseline model constructor.
    "attention": False,
    # Use BaselineModelWithSentenceComparisonAndCue instead of
    # BaselineModelWithSentenceComparison.
    "cue_gating": False
}

def _split_adversarial_targets(labels):
    """Build the three target tensors for a half/half adversarial batch.

    The first half of ``labels`` becomes the main-task targets and the second
    half the auxiliary (discovery) targets. The third tensor marks which half
    each example came from: ``[0, 1]`` for the first half, ``[1, 0]`` for the
    second. All tensors are moved to ``device``.
    """
    half_batch_size = len(labels) // 2
    targets = torch.tensor(np.array(labels[:half_batch_size])).to(device)
    targets_adv = torch.tensor(np.array(labels[half_batch_size:])).to(device)
    targets_task = torch.tensor(
        np.array([[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size)
    ).to(device)
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3, rl=0):
    """Run one training epoch and print the summed loss.

    Args:
        epoch: zero-based epoch index, used only in the summary line.
        model: network to train; its call signature depends on ``config``.
        loss_fn: loss for the main task prediction.
        optimizer: optimizer stepped once per batch.
        train_loader: yields ``(ids, segs, att_mask, position_sep, labels)``.
        scheduler: optional LR scheduler, stepped once per batch when given.
        discovery_weight: weight of the auxiliary (discovery) loss.
        adv_weight: weight of the first-half/second-half discrimination loss.
        rl: weight of the two extra losses used with ``double_adversarial``.

    Errors raised while building the targets propagate to the caller. They
    used to be swallowed, after which the loss was computed with undefined
    names or with the previous batch's targets.
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Elements that are plain lists (e.g. labels in the adversarial
        # collate) stay on the host; tensors go to the device.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            # Count rows by one-hot pattern to size the targets of the two
            # extra heads.
            attack_len = torch.sum((targets == torch.tensor([0,1]).to(device)).all(dim=1)).item()
            support_len = torch.sum((targets == torch.tensor([1,0]).to(device)).all(dim=1)).item()
            contr_len = torch.sum((targets_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)).item()
            other_len = torch.sum((targets_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (targets_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)).item()

            attack_target = [[0,1]] * attack_len + [[1,0]] * contr_len
            support_target = [[0,1]] * support_len + [[1,0]] * other_len
            attack_target = torch.tensor(np.array(attack_target)).to(device)
            support_target = torch.tensor(np.array(support_target)).to(device)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = loss_fn2(attack_pred, attack_target.float())
            loss5 = loss_fn2(support_pred, support_target.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + rl*loss4 + rl*loss5
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if isinstance(labels, list):
                labels = torch.tensor(np.array(labels)).to(device)
            loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader``.

    Prints the summed (unweighted) cross-entropy loss, converts the arg-max
    predictions to one-hot rows matching the width of the labels, and returns
    ``(accuracy, precision, recall, f1)`` as computed by ``output_metrics``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    all_preds, all_labels = [], []

    for raw_batch in val_loader:
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = tuple(
            t.to(device) for t in raw_batch
        )

        with torch.no_grad():
            if config["double_adversarial"]:
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            else:
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            pred_classes = torch.max(out.data, 1)[1].cpu().numpy().tolist()
            total_loss += criterion(out, labels.float()).item()

        gold = labels.cpu().numpy().tolist()
        all_labels.extend(gold)

        if len(gold[0]) == 2:
            # Binary task: class 0 -> [1, 0], anything else -> [0, 1].
            all_preds.extend([0, 1] if cls != 0 else [1, 0] for cls in pred_classes)
        else:
            # Three-way task: classes 0 and 1 map to their own slot, every
            # other class index falls into the last slot.
            for cls in pred_classes:
                one_hot = [0, 0, 0]
                one_hot[min(cls, 2)] = 1
                all_preds.append(one_hot)

    print(f"val loss: {total_loss}")

    return output_metrics(all_labels, all_preds)

def _show_tsne_scatter(df_points, discovery_weight, adv_weight):
  """Draw one t-SNE scatter plot (columns x, y, label) in its own figure and show it."""
  fig, ax = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}')
  plt.xlabel('x')
  plt.ylabel('y')
  plt.axis('equal')
  plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2):
  """Plot 2-D t-SNE projections of the model output for both data sources.

  Does nothing unless ``config["adversarial"]`` is set. The model is called
  with ``visualize=True`` on every batch of ``test_dataloader`` and on the
  first 20 batches of ``train_adv_dataloader``; each collection is projected
  with its own t-SNE fit and shown as a scatter plot coloured by label.

  Args:
    epoch: unused; kept for call compatibility.
    model: trained network accepting the ``visualize=True`` keyword.
    test_dataloader: batches whose labels arrive as tensors.
    train_adv_dataloader: batches whose labels may arrive as Python lists.
    discovery_weight: shown in the plot title only.
    adv_weight: shown in the plot title only.
  """
  if not config["adversarial"]:
    return

  model.eval()

  # Cap on auxiliary batches so the second projection stays small.
  max_adv_batches = 20

  # Collect per-batch tensors and concatenate once at the end instead of
  # re-copying the growing tensor on every batch.
  label_chunks, embedding_chunks = [], []
  label_chunks_adv, embedding_chunks_adv = [], []

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    label_chunks.append(torch.argmax(labels, dim=-1))

    with torch.no_grad():
      embedding_chunks.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i == max_adv_batches: break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    # Offset the class ids by 2 -- presumably so they do not collide with the
    # main-task label ids; confirm against the number of task classes.
    label_chunks_adv.append(torch.argmax(labels, dim=-1)+2)

    with torch.no_grad():
      embedding_chunks_adv.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  tot_labels = torch.cat(label_chunks, dim=0)
  embeddings = torch.cat(embedding_chunks, dim=0)
  tot_labels_adv = torch.cat(label_chunks_adv, dim=0)
  embeddings_adv = torch.cat(embedding_chunks_adv, dim=0)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  _show_tsne_scatter(df_tsne, discovery_weight, adv_weight)
  _show_tsne_scatter(df_tsne_adv, discovery_weight, adv_weight)





def _build_optimizer(model):
  """Create AdamW for ``model`` with weight decay disabled on biases and LayerNorm weights."""
  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(optimizer_grouped_parameters, lr=args["lr"])


def run(seed):
  """Train and evaluate one model for the given random seed.

  The dataset, model variant and training mode are selected through the
  module-level ``config``; hyper-parameters come from ``args``. After every
  epoch the model is evaluated on the dev and test sets, and the test metrics
  of the epoch with the best dev F1 are kept. Outside grid search, the model
  weights of that epoch are also saved to ``./<dataset>_model.pt``.

  Args:
    seed: value passed to ``set_random_seeds`` before anything else is built.

  Returns:
    ``[accuracy, precision, recall, f1]`` on the test set for the best dev-F1
    epoch. In grid-search mode only the first combination's entry is returned.
    All values are -1 if no epoch was evaluated.

  Raises:
    ValueError: if ``config["dataset"]`` is not a known dataset name.
  """
  set_random_seeds(seed)

  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/train_essay.txt"
    path_dev = "./data/student_essay/dev_essay.txt"
    path_test = "./data/student_essay/test_essay.txt"
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/train_debate_concept.txt"
    path_dev = "./data/debate/dev_debate_concept.txt"
    path_test = "./data/debate/test_debate_concept.txt"
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    # A single file; the processor is told which split to return via `name`.
    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKWithDiscourseInjectionProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid dataset name (choose from 'student_essay', 'debate', 'm-arg' and 'nk')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    # Single file: first 10% -> dev, last 10% -> test, the middle -> train.
    # Explicit end offsets keep the split correct even when 10% rounds to 0.
    n_total = len(data_train)
    n_held_out = n_total // 10
    data_dev = data_train[:n_held_out]
    data_test = data_train[n_total - n_held_out:]
    data_train = data_train[n_held_out:n_total - n_held_out]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"] or config["double_adversarial"] or config["finetuning_discovery"]:
    df = datasets.load_dataset("discovery","discovery", trust_remote_code=True)
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE(review): pickle.load executes arbitrary code; only load a cache
      # file that this notebook wrote itself.
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    # Auxiliary examples are appended after the task examples; BalancedSampler
    # relies on this ordering when it offsets the second dataset's indices.
    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if config["double_adversarial"]:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = DoubleAdversarialNet()

  elif not config["adversarial"]:
    if config["finetuning_discovery"]:
      sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
      train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    else:
      train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

    if not config["cue_gating"]:
      model = BaselineModelWithSentenceComparison(attention=config["attention"])
    else:
      model = BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _build_optimizer(model)

  if config["dataset"] in ["m-arg","nk"]:
    # Multi-class datasets: one weight per class.
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    # Binary datasets: class_weight is a single number for the second class.
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  # -1 sentinels: any real dev F1 replaces them, and the final report stays
  # defined even when no epoch is evaluated.
  best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0.8,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)
    range_local = np.arange(0.2, 1, 0.2)

    for rl in reversed(range_local):
      print(f"rl {rl}")
      for discovery_weight in reversed(range_disc):
        for adv_weight in range_adv:
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight, rl=rl)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          # Start the next combination from identically seeded fresh weights.
          set_random_seeds(seed)
          model = DoubleAdversarialNet()
          model = model.to(device)

          optimizer = _build_optimizer(model)

          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
  else:
    # Loss weights used for training and shown in the visualisation titles.
    discovery_weight = 0.6
    adv_weight = 0.6

    if config["scheduler"]:
      # Linear decay to 1% of the base LR over the whole training run.
      scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-2, total_iters=args["epochs"])
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if config["scheduler"]:
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          # Keep the weights of the best dev-F1 epoch.
          torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

    if config["visualize"] and config["adversarial"]:
      model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
      visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight, adv_weight)

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  return result_metrics[0]

if __name__ == "__main__":
  # One complete train/evaluate cycle per configured seed; the report is the
  # element-wise mean of the per-seed [accuracy, precision, recall, f1] rows.
  results = []
  for seed in args["seed"]:
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results += [result_metrics]
  avg = torch.tensor(results).mean(dim=0)
  print(avg)

**** trying with seed 0 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...:  30%|███       | 442/1462 [00:00<00:01, 673.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...: 100%|██████████| 1462/1462 [00:02<00:00, 702.80it/s]


finished preprocessing examples in train


model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 19/19 [00:04<00:00,  4.24it/s]


Timing: 4.493077278137207, Epoch: 1, training loss: 31.204197883605957, current learning rate 1e-05
val loss: 3.4139972925186157
accuracy:      0.089
precision:     0.359
recall:        0.344
f1:            0.068
val loss: 3.3999083042144775
accuracy:      0.158
precision:     0.549
recall:        0.344
f1:            0.106
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.865694761276245, Epoch: 2, training loss: 30.483917236328125, current learning rate 1e-05
val loss: 3.2962663173675537
accuracy:      0.260
precision:     0.454
recall:        0.363
f1:            0.260
val loss: 3.3227115869522095
accuracy:      0.247
precision:     0.417
recall:        0.385
f1:            0.244
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.863015651702881, Epoch: 3, training loss: 29.66497027873993, current learning rate 1e-05
val loss: 3.089264750480652
accuracy:      0.500
precision:     0.437
recall:        0.442
f1:            0.410
val loss: 3.0555203557014465
accuracy:      0.514
precision:     0.482
recall:        0.446
f1:            0.430
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8555736541748047, Epoch: 4, training loss: 27.75483238697052, current learning rate 1e-05
val loss: 3.5622901916503906
accuracy:      0.363
precision:     0.471
recall:        0.492
f1:            0.354
val loss: 3.660959482192993
accuracy:      0.370
precision:     0.431
recall:        0.436
f1:            0.342
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8506925106048584, Epoch: 5, training loss: 26.24552023410797, current learning rate 1e-05
val loss: 3.7545593976974487
accuracy:      0.384
precision:     0.463
recall:        0.507
f1:            0.369
val loss: 3.407336115837097
accuracy:      0.411
precision:     0.459
recall:        0.455
f1:            0.372
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.853180408477783, Epoch: 6, training loss: 24.887107133865356, current learning rate 1e-05
val loss: 2.9812681674957275
accuracy:      0.555
precision:     0.502
recall:        0.551
f1:            0.493
val loss: 3.21249258518219
accuracy:      0.514
precision:     0.477
recall:        0.490
f1:            0.474
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8614933490753174, Epoch: 7, training loss: 23.398633182048798, current learning rate 1e-05
val loss: 3.4661415815353394
accuracy:      0.438
precision:     0.493
recall:        0.519
f1:            0.414
val loss: 3.4384571313858032
accuracy:      0.404
precision:     0.472
recall:        0.448
f1:            0.396
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.855450391769409, Epoch: 8, training loss: 20.568612813949585, current learning rate 1e-05
val loss: 3.962704658508301
accuracy:      0.411
precision:     0.494
recall:        0.500
f1:            0.395
val loss: 4.008760452270508
accuracy:      0.397
precision:     0.461
recall:        0.450
f1:            0.382
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8645567893981934, Epoch: 9, training loss: 18.98569357395172, current learning rate 1e-05
val loss: 3.749747633934021
accuracy:      0.493
precision:     0.476
recall:        0.532
f1:            0.449
val loss: 4.008379817008972
accuracy:      0.493
precision:     0.507
recall:        0.504
f1:            0.467
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8711330890655518, Epoch: 10, training loss: 16.848671197891235, current learning rate 1e-05
val loss: 4.50736665725708
accuracy:      0.404
precision:     0.454
recall:        0.493
f1:            0.382
val loss: 4.786748290061951
accuracy:      0.397
precision:     0.472
recall:        0.432
f1:            0.390
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.64it/s]


Timing: 2.867959976196289, Epoch: 11, training loss: 15.882253348827362, current learning rate 1e-05
val loss: 4.082983374595642
accuracy:      0.452
precision:     0.472
recall:        0.528
f1:            0.421
val loss: 3.719148099422455
accuracy:      0.479
precision:     0.489
recall:        0.493
f1:            0.452
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8602287769317627, Epoch: 12, training loss: 14.866149306297302, current learning rate 1e-05
val loss: 4.195038080215454
accuracy:      0.527
precision:     0.521
recall:        0.580
f1:            0.476
val loss: 4.135080099105835
accuracy:      0.479
precision:     0.481
recall:        0.486
f1:            0.455
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Timing: 2.8509817123413086, Epoch: 13, training loss: 13.64595079421997, current learning rate 1e-05
val loss: 4.3637189865112305
accuracy:      0.521
precision:     0.491
recall:        0.528
f1:            0.466
val loss: 3.9105217456817627
accuracy:      0.555
precision:     0.526
recall:        0.529
f1:            0.511
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.86100697517395, Epoch: 14, training loss: 11.668573141098022, current learning rate 1e-05
val loss: 4.197120785713196
accuracy:      0.507
precision:     0.489
recall:        0.544
f1:            0.462
val loss: 3.790217638015747
accuracy:      0.527
precision:     0.520
recall:        0.516
f1:            0.483
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.856924533843994, Epoch: 15, training loss: 10.00363776087761, current learning rate 1e-05
val loss: 4.600186109542847
accuracy:      0.479
precision:     0.452
recall:        0.449
f1:            0.415
val loss: 4.703103065490723
accuracy:      0.534
precision:     0.492
recall:        0.494
f1:            0.461
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.857687473297119, Epoch: 16, training loss: 9.522124499082565, current learning rate 1e-05
val loss: 5.797942638397217
accuracy:      0.459
precision:     0.487
recall:        0.506
f1:            0.421
val loss: 5.6668864488601685
accuracy:      0.493
precision:     0.521
recall:        0.506
f1:            0.472
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.857222557067871, Epoch: 17, training loss: 8.882134050130844, current learning rate 1e-05
val loss: 6.507970213890076
accuracy:      0.404
precision:     0.440
recall:        0.472
f1:            0.372
val loss: 6.296023368835449
accuracy:      0.418
precision:     0.498
recall:        0.441
f1:            0.384
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.860247850418091, Epoch: 18, training loss: 8.424689710140228, current learning rate 1e-05
val loss: 5.552797198295593
accuracy:      0.486
precision:     0.473
recall:        0.503
f1:            0.438
val loss: 4.986376523971558
accuracy:      0.514
precision:     0.510
recall:        0.507
f1:            0.475
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8564693927764893, Epoch: 19, training loss: 7.38496907055378, current learning rate 1e-05
val loss: 4.703657746315002
accuracy:      0.507
precision:     0.479
recall:        0.517
f1:            0.454
val loss: 5.082331895828247
accuracy:      0.521
precision:     0.491
recall:        0.504
f1:            0.481
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8602843284606934, Epoch: 20, training loss: 6.563810333609581, current learning rate 1e-05
val loss: 5.977556943893433
accuracy:      0.514
precision:     0.490
recall:        0.519
f1:            0.451
val loss: 5.515820860862732
accuracy:      0.514
precision:     0.503
recall:        0.501
f1:            0.480
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.864121675491333, Epoch: 21, training loss: 7.238286130130291, current learning rate 1e-05
val loss: 7.861491918563843
accuracy:      0.397
precision:     0.460
recall:        0.465
f1:            0.376
val loss: 6.911310434341431
accuracy:      0.459
precision:     0.521
recall:        0.485
f1:            0.432
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8773128986358643, Epoch: 22, training loss: 6.272215694189072, current learning rate 1e-05
val loss: 7.293327808380127
accuracy:      0.425
precision:     0.456
recall:        0.482
f1:            0.393
val loss: 6.207888722419739
accuracy:      0.459
precision:     0.493
recall:        0.490
f1:            0.443
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8554861545562744, Epoch: 23, training loss: 5.499895885586739, current learning rate 1e-05
val loss: 6.363942742347717
accuracy:      0.500
precision:     0.504
recall:        0.513
f1:            0.452
val loss: 5.164745450019836
accuracy:      0.507
precision:     0.509
recall:        0.501
f1:            0.468
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8563425540924072, Epoch: 24, training loss: 5.414123676717281, current learning rate 1e-05
val loss: 6.0269235372543335
accuracy:      0.486
precision:     0.477
recall:        0.502
f1:            0.439
val loss: 5.370041370391846
accuracy:      0.521
precision:     0.511
recall:        0.509
f1:            0.472
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.856452465057373, Epoch: 25, training loss: 4.4660076424479485, current learning rate 1e-05
val loss: 6.357742667198181
accuracy:      0.486
precision:     0.460
recall:        0.476
f1:            0.429
val loss: 5.522213459014893
accuracy:      0.575
precision:     0.539
recall:        0.555
f1:            0.531
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.862321138381958, Epoch: 26, training loss: 4.563352487981319, current learning rate 1e-05
val loss: 8.618504285812378
accuracy:      0.370
precision:     0.435
recall:        0.421
f1:            0.344
val loss: 7.725826740264893
accuracy:      0.432
precision:     0.504
recall:        0.463
f1:            0.401
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Timing: 2.850426197052002, Epoch: 27, training loss: 5.373469531536102, current learning rate 1e-05
val loss: 6.456541061401367
accuracy:      0.493
precision:     0.464
recall:        0.505
f1:            0.437
val loss: 6.204436540603638
accuracy:      0.541
precision:     0.515
recall:        0.533
f1:            0.509
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Timing: 2.8499302864074707, Epoch: 28, training loss: 4.67329365760088, current learning rate 1e-05
val loss: 6.699783444404602
accuracy:      0.466
precision:     0.489
recall:        0.489
f1:            0.425
val loss: 5.662330269813538
accuracy:      0.507
precision:     0.532
recall:        0.521
f1:            0.468
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8527631759643555, Epoch: 29, training loss: 4.234860330820084, current learning rate 1e-05
val loss: 7.160452127456665
accuracy:      0.507
precision:     0.487
recall:        0.517
f1:            0.455
val loss: 6.345511198043823
accuracy:      0.527
precision:     0.513
recall:        0.515
f1:            0.480
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.862112045288086, Epoch: 30, training loss: 4.5181670263409615, current learning rate 1e-05
val loss: 7.4642064571380615
accuracy:      0.466
precision:     0.456
recall:        0.485
f1:            0.415
val loss: 6.822123050689697
accuracy:      0.493
precision:     0.507
recall:        0.496
f1:            0.468
best result:
0.5136986301369864
0.47656091691995944
0.4900363654233623
0.47446997446997435
[[0.5136986301369864, 0.47656091691995944, 0.4900363654233623, 0.47446997446997435]]
**** trying with seed 1 ****


Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...:  29%|██▉       | 423/1462 [00:00<00:01, 697.59it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...: 100%|██████████| 1462/1462 [00:02<00:00, 718.90it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 2.8846399784088135, Epoch: 1, training loss: 31.19434130191803, current learning rate 1e-05
val loss: 3.5073734521865845
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047
val loss: 3.4650505781173706
accuracy:      0.144
precision:     0.048
recall:        0.333
f1:            0.084


  _warn_prf(average, modifier, msg_start, len(result))


===== Start training: epoch 2 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.854595899581909, Epoch: 2, training loss: 30.714690685272217, current learning rate 1e-05
val loss: 3.4281336069107056
accuracy:      0.075
precision:     0.026
recall:        0.333
f1:            0.048
val loss: 3.400838255882263
accuracy:      0.158
precision:     0.549
recall:        0.344
f1:            0.106
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.853154182434082, Epoch: 3, training loss: 30.291821479797363, current learning rate 1e-05
val loss: 3.410481810569763
accuracy:      0.171
precision:     0.369
recall:        0.375
f1:            0.165
val loss: 3.38311767578125
accuracy:      0.253
precision:     0.564
recall:        0.385
f1:            0.247
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8705861568450928, Epoch: 4, training loss: 29.02133560180664, current learning rate 1e-05
val loss: 3.3657480478286743
accuracy:      0.397
precision:     0.466
recall:        0.466
f1:            0.376
val loss: 3.3381508588790894
accuracy:      0.411
precision:     0.447
recall:        0.452
f1:            0.397
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.61it/s]


Timing: 2.879105806350708, Epoch: 5, training loss: 26.791606664657593, current learning rate 1e-05
val loss: 3.3705999851226807
accuracy:      0.466
precision:     0.485
recall:        0.516
f1:            0.428
val loss: 3.3862005472183228
accuracy:      0.459
precision:     0.457
recall:        0.484
f1:            0.427
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Timing: 2.8517515659332275, Epoch: 6, training loss: 25.065077781677246, current learning rate 1e-05
val loss: 3.425723910331726
accuracy:      0.479
precision:     0.500
recall:        0.524
f1:            0.442
val loss: 3.0756271481513977
accuracy:      0.514
precision:     0.522
recall:        0.539
f1:            0.486
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8618032932281494, Epoch: 7, training loss: 23.40069568157196, current learning rate 1e-05
val loss: 4.0068124532699585
accuracy:      0.425
precision:     0.477
recall:        0.459
f1:            0.396
val loss: 3.5462833642959595
accuracy:      0.466
precision:     0.510
recall:        0.494
f1:            0.447
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8715600967407227, Epoch: 8, training loss: 21.837164103984833, current learning rate 1e-05
val loss: 3.4674789905548096
accuracy:      0.473
precision:     0.476
recall:        0.465
f1:            0.416
val loss: 3.4705435037612915
accuracy:      0.486
precision:     0.472
recall:        0.484
f1:            0.460
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.60it/s]


Timing: 2.8847196102142334, Epoch: 9, training loss: 18.991918325424194, current learning rate 1e-05
val loss: 3.1695239543914795
accuracy:      0.541
precision:     0.480
recall:        0.518
f1:            0.476
val loss: 3.432681918144226
accuracy:      0.527
precision:     0.480
recall:        0.484
f1:            0.470
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.872606039047241, Epoch: 10, training loss: 17.198066890239716, current learning rate 1e-05
val loss: 4.193138003349304
accuracy:      0.452
precision:     0.477
recall:        0.473
f1:            0.386
val loss: 4.082877993583679
accuracy:      0.473
precision:     0.470
recall:        0.475
f1:            0.449
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.861100912094116, Epoch: 11, training loss: 16.141872704029083, current learning rate 1e-05
val loss: 3.8933013677597046
accuracy:      0.452
precision:     0.470
recall:        0.451
f1:            0.405
val loss: 3.951214909553528
accuracy:      0.507
precision:     0.504
recall:        0.529
f1:            0.487
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8602023124694824, Epoch: 12, training loss: 14.266010403633118, current learning rate 1e-05
val loss: 3.865340232849121
accuracy:      0.514
precision:     0.446
recall:        0.471
f1:            0.443
val loss: 3.8521193265914917
accuracy:      0.527
precision:     0.480
recall:        0.487
f1:            0.475
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8660361766815186, Epoch: 13, training loss: 13.543948352336884, current learning rate 1e-05
val loss: 3.489720344543457
accuracy:      0.562
precision:     0.484
recall:        0.504
f1:            0.476
val loss: 3.71255886554718
accuracy:      0.568
precision:     0.525
recall:        0.537
f1:            0.527
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8752200603485107, Epoch: 14, training loss: 12.076824575662613, current learning rate 1e-05
val loss: 4.6901843547821045
accuracy:      0.486
precision:     0.491
recall:        0.475
f1:            0.431
val loss: 4.52799379825592
accuracy:      0.521
precision:     0.504
recall:        0.528
f1:            0.494
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.61it/s]


Timing: 2.8799500465393066, Epoch: 15, training loss: 10.017003357410431, current learning rate 1e-05
val loss: 4.573856830596924
accuracy:      0.486
precision:     0.438
recall:        0.430
f1:            0.406
val loss: 5.141450047492981
accuracy:      0.527
precision:     0.508
recall:        0.478
f1:            0.452
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8683834075927734, Epoch: 16, training loss: 10.537838906049728, current learning rate 1e-05
val loss: 5.869100570678711
accuracy:      0.445
precision:     0.491
recall:        0.496
f1:            0.407
val loss: 5.477409482002258
accuracy:      0.466
precision:     0.495
recall:        0.497
f1:            0.451
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8753929138183594, Epoch: 17, training loss: 9.224888637661934, current learning rate 1e-05
val loss: 5.122044086456299
accuracy:      0.507
precision:     0.467
recall:        0.466
f1:            0.438
val loss: 4.418789207935333
accuracy:      0.527
precision:     0.492
recall:        0.507
f1:            0.485
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8626675605773926, Epoch: 18, training loss: 7.976049870252609, current learning rate 1e-05
val loss: 6.340992569923401
accuracy:      0.500
precision:     0.478
recall:        0.486
f1:            0.441
val loss: 5.7108234167099
accuracy:      0.521
precision:     0.505
recall:        0.516
f1:            0.488
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8734664916992188, Epoch: 19, training loss: 7.299920558929443, current learning rate 1e-05
val loss: 4.990623712539673
accuracy:      0.514
precision:     0.457
recall:        0.445
f1:            0.432
val loss: 5.551807284355164
accuracy:      0.548
precision:     0.503
recall:        0.522
f1:            0.499
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.870440721511841, Epoch: 20, training loss: 7.217814236879349, current learning rate 1e-05
val loss: 6.370915651321411
accuracy:      0.500
precision:     0.478
recall:        0.486
f1:            0.441
val loss: 5.963855743408203
accuracy:      0.493
precision:     0.484
recall:        0.493
f1:            0.462
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8558197021484375, Epoch: 21, training loss: 5.727564975619316, current learning rate 1e-05
val loss: 5.403678059577942
accuracy:      0.534
precision:     0.449
recall:        0.460
f1:            0.444
val loss: 5.767418622970581
accuracy:      0.555
precision:     0.501
recall:        0.508
f1:            0.500
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.854736089706421, Epoch: 22, training loss: 5.961778439581394, current learning rate 1e-05
val loss: 6.41021990776062
accuracy:      0.493
precision:     0.467
recall:        0.456
f1:            0.429
val loss: 6.333541035652161
accuracy:      0.514
precision:     0.490
recall:        0.505
f1:            0.471
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.859945058822632, Epoch: 23, training loss: 6.289074406027794, current learning rate 1e-05
val loss: 5.86121129989624
accuracy:      0.527
precision:     0.476
recall:        0.481
f1:            0.453
val loss: 5.58498227596283
accuracy:      0.555
precision:     0.506
recall:        0.519
f1:            0.504
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8632588386535645, Epoch: 24, training loss: 4.883528679609299, current learning rate 1e-05
val loss: 5.791061043739319
accuracy:      0.507
precision:     0.459
recall:        0.466
f1:            0.436
val loss: 5.454177975654602
accuracy:      0.534
precision:     0.499
recall:        0.523
f1:            0.495
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8591294288635254, Epoch: 25, training loss: 4.9701531156897545, current learning rate 1e-05
val loss: 5.170906186103821
accuracy:      0.555
precision:     0.465
recall:        0.476
f1:            0.463
val loss: 5.539413571357727
accuracy:      0.562
precision:     0.507
recall:        0.501
f1:            0.498
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8708343505859375, Epoch: 26, training loss: 5.172488480806351, current learning rate 1e-05
val loss: 6.624522566795349
accuracy:      0.473
precision:     0.464
recall:        0.442
f1:            0.416
val loss: 6.633464097976685
accuracy:      0.493
precision:     0.501
recall:        0.500
f1:            0.454
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.86195969581604, Epoch: 27, training loss: 4.3296588733792305, current learning rate 1e-05
val loss: 6.4394543170928955
accuracy:      0.527
precision:     0.462
recall:        0.480
f1:            0.451
val loss: 5.243935585021973
accuracy:      0.541
precision:     0.487
recall:        0.497
f1:            0.488
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8719189167022705, Epoch: 28, training loss: 4.542445212602615, current learning rate 1e-05
val loss: 7.97172737121582
accuracy:      0.459
precision:     0.430
recall:        0.435
f1:            0.397
val loss: 6.854867219924927
accuracy:      0.521
precision:     0.531
recall:        0.496
f1:            0.461
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.86370849609375, Epoch: 29, training loss: 5.2145635187625885, current learning rate 1e-05
val loss: 6.569410085678101
accuracy:      0.534
precision:     0.464
recall:        0.483
f1:            0.449
val loss: 5.8240262269973755
accuracy:      0.575
precision:     0.525
recall:        0.532
f1:            0.525
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8595004081726074, Epoch: 30, training loss: 4.2445923909544945, current learning rate 1e-05
val loss: 6.366873860359192
accuracy:      0.527
precision:     0.450
recall:        0.456
f1:            0.441
val loss: 5.79164445400238
accuracy:      0.534
precision:     0.476
recall:        0.479
f1:            0.470
best result:
0.5273972602739726
0.47961027229319914
0.48413926974298493
0.4699904908199839
[[0.5273972602739726, 0.47961027229319914, 0.48413926974298493, 0.4699904908199839]]
**** trying with seed 2 ****


Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...:  29%|██▉       | 431/1462 [00:00<00:01, 719.78it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...: 100%|██████████| 1462/1462 [00:02<00:00, 730.54it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 2.8804070949554443, Epoch: 1, training loss: 31.2089684009552, current learning rate 1e-05
val loss: 3.4913746118545532
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047
val loss: 3.433014750480652
accuracy:      0.144
precision:     0.048
recall:        0.333
f1:            0.084


  _warn_prf(average, modifier, msg_start, len(result))


===== Start training: epoch 2 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8529164791107178, Epoch: 2, training loss: 30.69534742832184, current learning rate 1e-05
val loss: 3.377427816390991
accuracy:      0.260
precision:     0.511
recall:        0.370
f1:            0.225
val loss: 3.27402663230896
accuracy:      0.295
precision:     0.532
recall:        0.310
f1:            0.234
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.64it/s]


Timing: 2.866852045059204, Epoch: 3, training loss: 29.950788140296936, current learning rate 1e-05
val loss: 3.569762349128723
accuracy:      0.144
precision:     0.250
recall:        0.381
f1:            0.130
val loss: 3.5061988830566406
accuracy:      0.171
precision:     0.234
recall:        0.347
f1:            0.135
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.60it/s]


Timing: 2.8814101219177246, Epoch: 4, training loss: 28.924482107162476, current learning rate 1e-05
val loss: 3.3367576599121094
accuracy:      0.370
precision:     0.451
recall:        0.468
f1:            0.355
val loss: 3.141453206539154
accuracy:      0.432
precision:     0.456
recall:        0.480
f1:            0.421
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8654255867004395, Epoch: 5, training loss: 26.929117441177368, current learning rate 1e-05
val loss: 3.0497796535491943
accuracy:      0.555
precision:     0.486
recall:        0.527
f1:            0.487
val loss: 3.037655472755432
accuracy:      0.521
precision:     0.482
recall:        0.486
f1:            0.483
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.64it/s]


Timing: 2.8676414489746094, Epoch: 6, training loss: 24.85631275177002, current learning rate 1e-05
val loss: 2.9531906247138977
accuracy:      0.527
precision:     0.469
recall:        0.507
f1:            0.464
val loss: 3.006523370742798
accuracy:      0.527
precision:     0.489
recall:        0.502
f1:            0.492
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8716700077056885, Epoch: 7, training loss: 22.20030528306961, current learning rate 1e-05
val loss: 3.532490611076355
accuracy:      0.466
precision:     0.468
recall:        0.485
f1:            0.416
val loss: 3.3589879274368286
accuracy:      0.514
precision:     0.504
recall:        0.522
f1:            0.487
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.60it/s]


Timing: 2.8839643001556396, Epoch: 8, training loss: 20.158899128437042, current learning rate 1e-05
val loss: 3.230968713760376
accuracy:      0.555
precision:     0.460
recall:        0.445
f1:            0.426
val loss: 3.1475512981414795
accuracy:      0.521
precision:     0.499
recall:        0.454
f1:            0.453
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.59it/s]


Timing: 2.889244318008423, Epoch: 9, training loss: 18.844865322113037, current learning rate 1e-05
val loss: 4.056120157241821
accuracy:      0.493
precision:     0.484
recall:        0.529
f1:            0.436
val loss: 3.7888846397399902
accuracy:      0.514
precision:     0.488
recall:        0.506
f1:            0.485
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.60it/s]


Timing: 2.8843607902526855, Epoch: 10, training loss: 17.03865772485733, current learning rate 1e-05
val loss: 3.2381234765052795
accuracy:      0.548
precision:     0.478
recall:        0.518
f1:            0.474
val loss: 3.044162631034851
accuracy:      0.596
precision:     0.545
recall:        0.536
f1:            0.540
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.59it/s]


Timing: 2.8910789489746094, Epoch: 11, training loss: 15.252784371376038, current learning rate 1e-05
val loss: 3.3165851831436157
accuracy:      0.562
precision:     0.467
recall:        0.477
f1:            0.457
val loss: 3.9167991876602173
accuracy:      0.575
precision:     0.540
recall:        0.505
f1:            0.508
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.860077142715454, Epoch: 12, training loss: 13.616906613111496, current learning rate 1e-05
val loss: 3.916303038597107
accuracy:      0.541
precision:     0.471
recall:        0.488
f1:            0.452
val loss: 3.4926775693893433
accuracy:      0.568
precision:     0.525
recall:        0.535
f1:            0.529
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.862355947494507, Epoch: 13, training loss: 12.896424114704132, current learning rate 1e-05
val loss: 4.224985837936401
accuracy:      0.514
precision:     0.457
recall:        0.469
f1:            0.436
val loss: 3.7973233461380005
accuracy:      0.562
precision:     0.514
recall:        0.527
f1:            0.515
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8585493564605713, Epoch: 14, training loss: 11.077171206474304, current learning rate 1e-05
val loss: 5.699563503265381
accuracy:      0.486
precision:     0.473
recall:        0.475
f1:            0.427
val loss: 4.915755033493042
accuracy:      0.486
precision:     0.466
recall:        0.477
f1:            0.450
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.64it/s]


Timing: 2.8674468994140625, Epoch: 15, training loss: 9.897199034690857, current learning rate 1e-05
val loss: 4.610163569450378
accuracy:      0.548
precision:     0.479
recall:        0.468
f1:            0.450
val loss: 4.348693490028381
accuracy:      0.575
precision:     0.529
recall:        0.540
f1:            0.533
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8533554077148438, Epoch: 16, training loss: 8.904334396123886, current learning rate 1e-05
val loss: 5.245084047317505
accuracy:      0.534
precision:     0.483
recall:        0.484
f1:            0.453
val loss: 4.526586890220642
accuracy:      0.541
precision:     0.493
recall:        0.500
f1:            0.490
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.8620643615722656, Epoch: 17, training loss: 8.728659957647324, current learning rate 1e-05
val loss: 5.3413180112838745
accuracy:      0.507
precision:     0.478
recall:        0.462
f1:            0.423
val loss: 4.619990229606628
accuracy:      0.541
precision:     0.505
recall:        0.515
f1:            0.502
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8735928535461426, Epoch: 18, training loss: 7.428610093891621, current learning rate 1e-05
val loss: 4.691231191158295
accuracy:      0.534
precision:     0.472
recall:        0.457
f1:            0.434
val loss: 4.960529565811157
accuracy:      0.562
precision:     0.516
recall:        0.522
f1:            0.514
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.62it/s]


Timing: 2.8734281063079834, Epoch: 19, training loss: 7.509811267256737, current learning rate 1e-05
val loss: 5.442756414413452
accuracy:      0.521
precision:     0.459
recall:        0.449
f1:            0.434
val loss: 4.613005995750427
accuracy:      0.589
precision:     0.544
recall:        0.556
f1:            0.538
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.64it/s]


Timing: 2.8654470443725586, Epoch: 20, training loss: 6.246814429759979, current learning rate 1e-05
val loss: 5.649431228637695
accuracy:      0.514
precision:     0.465
recall:        0.469
f1:            0.438
val loss: 4.858753323554993
accuracy:      0.575
precision:     0.536
recall:        0.556
f1:            0.534
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.8592445850372314, Epoch: 21, training loss: 5.473095878958702, current learning rate 1e-05
val loss: 5.549524307250977
accuracy:      0.562
precision:     0.465
recall:        0.479
f1:            0.466
val loss: 4.660139381885529
accuracy:      0.589
precision:     0.527
recall:        0.503
f1:            0.506
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.857072353363037, Epoch: 22, training loss: 5.5331635400652885, current learning rate 1e-05
val loss: 4.977753162384033
accuracy:      0.521
precision:     0.435
recall:        0.423
f1:            0.414
val loss: 5.3882986307144165
accuracy:      0.589
precision:     0.525
recall:        0.528
f1:            0.526
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Timing: 2.860971212387085, Epoch: 23, training loss: 5.170472212135792, current learning rate 1e-05
val loss: 4.7204365730285645
accuracy:      0.541
precision:     0.451
recall:        0.463
f1:            0.446
val loss: 4.831705689430237
accuracy:      0.603
precision:     0.537
recall:        0.527
f1:            0.529
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8539586067199707, Epoch: 24, training loss: 4.995389312505722, current learning rate 1e-05
val loss: 5.6801438331604
accuracy:      0.562
precision:     0.467
recall:        0.478
f1:            0.470
val loss: 5.959479451179504
accuracy:      0.575
precision:     0.523
recall:        0.486
f1:            0.491
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.857699155807495, Epoch: 25, training loss: 5.287065714597702, current learning rate 1e-05
val loss: 6.321275472640991
accuracy:      0.527
precision:     0.495
recall:        0.502
f1:            0.447
val loss: 5.526328444480896
accuracy:      0.575
precision:     0.533
recall:        0.544
f1:            0.532
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Timing: 2.8694908618927, Epoch: 26, training loss: 4.489842489361763, current learning rate 1e-05
val loss: 6.200607776641846
accuracy:      0.548
precision:     0.453
recall:        0.469
f1:            0.457
val loss: 5.261836886405945
accuracy:      0.555
precision:     0.499
recall:        0.480
f1:            0.485
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.67it/s]


Timing: 2.8533124923706055, Epoch: 27, training loss: 4.28056214004755, current learning rate 1e-05
val loss: 5.4391350746154785
accuracy:      0.534
precision:     0.442
recall:        0.458
f1:            0.444
val loss: 5.412554740905762
accuracy:      0.589
precision:     0.531
recall:        0.506
f1:            0.511
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.69it/s]


Timing: 2.8436081409454346, Epoch: 28, training loss: 4.045589316636324, current learning rate 1e-05
val loss: 6.94526481628418
accuracy:      0.500
precision:     0.475
recall:        0.460
f1:            0.431
val loss: 5.399985074996948
accuracy:      0.541
precision:     0.506
recall:        0.517
f1:            0.492
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Timing: 2.862004280090332, Epoch: 29, training loss: 4.274469114840031, current learning rate 1e-05
val loss: 6.812833666801453
accuracy:      0.521
precision:     0.462
recall:        0.474
f1:            0.443
val loss: 5.858015418052673
accuracy:      0.575
precision:     0.521
recall:        0.527
f1:            0.521
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Timing: 2.8484904766082764, Epoch: 30, training loss: 4.247377675026655, current learning rate 1e-05
val loss: 6.225400447845459
accuracy:      0.555
precision:     0.463
recall:        0.472
f1:            0.457
val loss: 5.854234933853149
accuracy:      0.541
precision:     0.516
recall:        0.473
f1:            0.483
best result:
0.5205479452054794
0.48222743203561613
0.485859255983095
0.4831168831168831
[[0.5205479452054794, 0.48222743203561613, 0.485859255983095, 0.4831168831168831]]
tensor([0.5205, 0.4795, 0.4867, 0.4759], dtype=torch.float64)


In [None]:
# Final cell: hand the Colab runtime back once every cell above has finished.
# NOTE(review): per the Colab docs, runtime.unassign() is understood to
# disconnect and delete the backing VM, so any in-memory state (models,
# metrics) that was not written out beforehand would be lost -- confirm that
# the results of the runs above are persisted before relying on this cell.
from google.colab import runtime
runtime.unassign()