<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Mount Google Drive under /content/drive so later cells can read the project folder.
# NOTE(review): force_remount=True presumably re-mounts even if a mount already exists — confirm.
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

# Use the first CUDA device when one is visible to torch; otherwise stay on the CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Expose an in-memory, indexable collection of examples as a PyTorch ``Dataset``."""

    def __init__(self, examples):
        super().__init__()
        # Stored as-is; indexing and length are delegated straight to this object.
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(examples):
    """Collate two-sentence examples into a tuple of seven ``torch.long`` tensors.

    Every example must hold exactly seven fields, in this order:
    ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label.
    The fields are regrouped column-wise and each column becomes one tensor.
    """
    columns = [list(column) for column in zip(*examples)]
    # Unpacking keeps the "exactly seven fields" requirement explicit.
    first_ids, first_segs, first_mask, second_ids, second_segs, second_mask, targets = columns

    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (
        as_long(first_ids),
        as_long(first_segs),
        as_long(first_mask),
        as_long(second_ids),
        as_long(second_segs),
        as_long(second_mask),
        as_long(targets),
    )

def collate_fn_concatenated(examples):
    """Collate concatenated-pair examples into a tuple of five ``torch.long`` tensors.

    Every example must hold exactly five fields, in this order:
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, label.
    """
    columns = [list(column) for column in zip(*examples)]
    # Unpacking keeps the "exactly five fields" requirement explicit.
    token_ids, segment_ids, attention_mask, separators, targets = columns

    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (
        as_long(token_ids),
        as_long(segment_ids),
        as_long(attention_mask),
        as_long(separators),
        as_long(targets),
    )

def collate_fn_concatenated_adv(examples):
    """Collate concatenated-pair examples, leaving the labels as a plain Python list.

    Every example must hold exactly five fields, in this order:
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, label.
    The first four columns become ``torch.long`` tensors; the label column is
    returned untouched (unlike ``collate_fn_concatenated`` it is not converted).
    """
    columns = [list(column) for column in zip(*examples)]
    token_ids, segment_ids, attention_mask, separators, labels = columns

    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    return (
        as_long(token_ids),
        as_long(segment_ids),
        as_long(attention_mask),
        as_long(separators),
        labels,
    )

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Base class that turns sentence pairs into fixed-length token-id examples.

  The tokenizer is loaded from the notebook-level ``args["model_name"]``, and
  ``_get_examples_concatenated`` calls a ``tqdm`` callable; neither name is
  defined in this cell, so both must already exist in the notebook namespace.
  """

  def __init__(self, max_sent_len=150):
    """Load the tokenizer and remember the target sequence length.

    Args:
      max_sent_len: number of positions every encoded sequence is padded or
        truncated to. Defaults to 150, the value that was previously hard-coded.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = max_sent_len

  def __str__(self,):
    return """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)

  def _pad_token_id(self):
    """Return the id the tokenizer assigns to its pad token (encoded without special tokens)."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  def _fit_to_max_len(self, ids, pad_id, *aligned):
    """Pad or truncate ``ids`` and its parallel lists to ``self.max_sent_len``.

    Args:
      ids: token ids of one encoded sequence.
      pad_id: id appended to ``ids`` when padding.
      *aligned: lists with one entry per token (segment ids, separator flags);
        they are padded with 0 or cut at the same position as ``ids``.

    Returns:
      A tuple ``(ids, att_mask, *aligned)`` whose members all have length
      ``self.max_sent_len``. ``att_mask`` is 1 on kept tokens and 0 on padding.
    """
    limit = self.max_sent_len
    missing = limit - len(ids)
    if missing > 0:
      att_mask = [1] * len(ids) + [0] * missing
      ids = ids + [pad_id] * missing
      aligned = [seq + [0] * missing for seq in aligned]
    else:
      # Plain tail truncation: whatever sits beyond `limit` (including any
      # closing token the tokenizer added) is dropped.
      att_mask = [1] * limit
      ids = ids[:limit]
      aligned = [seq[:limit] for seq in aligned]
    return (ids, att_mask, *aligned)

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode the two sentences of every row separately.

    Args:
      dataset: iterable of 7-item rows
        ``(id, sentence1, sentence2, _, _, _, label)``; only the two sentences
        and the label are used.
      dataset_type: split name, used only in the completion message.

    Returns:
      A list with one entry per row:
      ``[ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label]``.
    """
    examples = []
    # The pad id does not depend on the row, so it is looked up once.
    pad_id = self._pad_token_id()

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      encoded = []
      for sentence in (sentence1, sentence2):
        ids = self.tokenizer.encode(sentence)
        # Segment ids: 0 on the first and last position, 1 on everything between.
        segs = [0] * len(ids)
        segs[1:-1] = [1] * (len(ids) - 2)
        ids, att_mask, segs = self._fit_to_max_len(ids, pad_id, segs)
        encoded += [ids, segs, att_mask]

      examples.append(encoded + [label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode every row as one sequence holding both sentences.

    Args:
      dataset: iterable of 7-item rows
        ``(id, sentence1, sentence2, _, _, _, label)``; only the two sentences
        and the label are used.
      dataset_type: split name, used only in the completion message.

    Returns:
      A list with one entry per row:
      ``[ids, segs, att_mask, position_sep, label]``.

    Raises:
      AssertionError: if the pair encoding is not exactly as long as the two
        single-sentence encodings together (the segment ids are built from
        those two lengths).
    """
    examples = []
    # The pad id does not depend on the row, so it is looked up once.
    pad_id = self._pad_token_id()

    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids = self.tokenizer.encode(sentence1, sentence2)
      segs = [0] * sentence1_length + [1] * sentence2_length

      # 0 on the first token, 1 on every other real token (padding adds 0s later).
      # The previous code also wrote 1 to index 1 and index `sentence1_length`,
      # which changed nothing because the list was initialised with 1s.
      position_sep = [1] * len(ids)
      position_sep[0] = 0

      assert len(ids) == len(segs), (
        "pair encoding has {} tokens but the two single encodings have {} + {}; "
        "the tokenizer's pair format does not match the segment-id construction"
      ).format(len(ids), sentence1_length, sentence2_length)

      ids, att_mask, segs, position_sep = self._fit_to_max_len(ids, pad_id, segs, position_sep)

      examples.append([ids, segs, att_mask, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Turns a discourse-marker dataset into two-class examples.

  Samples are dicts with ``"sentence1"``, ``"sentence2"`` and an integer
  ``"label"``. The label is translated to a marker word through
  ``id_to_word``; samples whose marker is not listed in ``mapping`` are
  dropped, the rest are relabelled with the class stored in ``mapping``.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Reference kept from the original comment (ScienceDirect article, PII S0378216698001015):
    # https://www.sciencedirect.com/science/article/pii/S0378216698001015
    # The expired pre-signed download URL that used to be pasted here was replaced
    # by the permanent landing page of the same article.
    # TODO: refactor

    # Marker word -> class id. After one-hot encoding, class 0 occupies the first
    # slot and class 1 the second slot of the label vector.
    # An earlier three-class variant (no longer active) kept the class-1 markers
    # below as class 2 and split today's class 0 into
    #   0: accordingly, as_a_result, because_of_that, because_of_this, besides,
    #      consequently, conversely, hence, so, then, therefore, thus
    #   1: also, and, especially, further, furthermore, likewise, moreover,
    #      namely, otherwise, similarly, well
    # NOTE(review): "conversely" sits in class 0 although the other contrast
    # markers are in class 1 — confirm this is intended.
    self.mapping = {
      "accordingly": 0, "also": 0, "although": 1, "and": 0,
      "as_a_result": 0, "because_of_that": 0, "because_of_this": 0, "besides": 0,
      "but": 1, "by_comparison": 1, "by_contrast": 1, "consequently": 0,
      "conversely": 0, "especially": 0, "further": 0, "furthermore": 0,
      "hence": 0, "however": 1, "in_contrast": 1, "instead": 1,
      "likewise": 0, "moreover": 0, "namely": 0, "nevertheless": 1,
      "nonetheless": 1, "on_the_contrary": 1, "on_the_other_hand": 1, "otherwise": 0,
      "rather": 1, "similarly": 0, "so": 0, "still": 1,
      "then": 0, "therefore": 0, "though": 1, "thus": 0,
      "well": 0, "yet": 1,
    }

    # Integer label -> marker word.
    # NOTE(review): presumably the label vocabulary of the dataset passed to
    # process_dataset — confirm against the data loader.
    self.id_to_word = {
      0: 'no-conn', 1: 'absolutely', 2: 'accordingly', 3: 'actually',
      4: 'additionally', 5: 'admittedly', 6: 'afterward', 7: 'again',
      8: 'already', 9: 'also', 10: 'alternately', 11: 'alternatively',
      12: 'although', 13: 'altogether', 14: 'amazingly', 15: 'and',
      16: 'anyway', 17: 'apparently', 18: 'arguably', 19: 'as_a_result',
      20: 'basically', 21: 'because_of_that', 22: 'because_of_this', 23: 'besides',
      24: 'but', 25: 'by_comparison', 26: 'by_contrast', 27: 'by_doing_this',
      28: 'by_then', 29: 'certainly', 30: 'clearly', 31: 'coincidentally',
      32: 'collectively', 33: 'consequently', 34: 'conversely', 35: 'curiously',
      36: 'currently', 37: 'elsewhere', 38: 'especially', 39: 'essentially',
      40: 'eventually', 41: 'evidently', 42: 'finally', 43: 'first',
      44: 'firstly', 45: 'for_example', 46: 'for_instance', 47: 'fortunately',
      48: 'frankly', 49: 'frequently', 50: 'further', 51: 'furthermore',
      52: 'generally', 53: 'gradually', 54: 'happily', 55: 'hence',
      56: 'here', 57: 'historically', 58: 'honestly', 59: 'hopefully',
      60: 'however', 61: 'ideally', 62: 'immediately', 63: 'importantly',
      64: 'in_contrast', 65: 'in_fact', 66: 'in_other_words', 67: 'in_particular',
      68: 'in_short', 69: 'in_sum', 70: 'in_the_end', 71: 'in_the_meantime',
      72: 'in_turn', 73: 'incidentally', 74: 'increasingly', 75: 'indeed',
      76: 'inevitably', 77: 'initially', 78: 'instead', 79: 'interestingly',
      80: 'ironically', 81: 'lastly', 82: 'lately', 83: 'later',
      84: 'likewise', 85: 'locally', 86: 'luckily', 87: 'maybe',
      88: 'meaning', 89: 'meantime', 90: 'meanwhile', 91: 'moreover',
      92: 'mostly', 93: 'namely', 94: 'nationally', 95: 'naturally',
      96: 'nevertheless', 97: 'next', 98: 'nonetheless', 99: 'normally',
      100: 'notably', 101: 'now', 102: 'obviously', 103: 'occasionally',
      104: 'oddly', 105: 'often', 106: 'on_the_contrary', 107: 'on_the_other_hand',
      108: 'once', 109: 'only', 110: 'optionally', 111: 'or',
      112: 'originally', 113: 'otherwise', 114: 'overall', 115: 'particularly',
      116: 'perhaps', 117: 'personally', 118: 'plus', 119: 'preferably',
      120: 'presently', 121: 'presumably', 122: 'previously', 123: 'probably',
      124: 'rather', 125: 'realistically', 126: 'really', 127: 'recently',
      128: 'regardless', 129: 'remarkably', 130: 'sadly', 131: 'second',
      132: 'secondly', 133: 'separately', 134: 'seriously', 135: 'significantly',
      136: 'similarly', 137: 'simultaneously', 138: 'slowly', 139: 'so',
      140: 'sometimes', 141: 'soon', 142: 'specifically', 143: 'still',
      144: 'strangely', 145: 'subsequently', 146: 'suddenly', 147: 'supposedly',
      148: 'surely', 149: 'surprisingly', 150: 'technically', 151: 'thankfully',
      152: 'then', 153: 'theoretically', 154: 'thereafter', 155: 'thereby',
      156: 'therefore', 157: 'third', 158: 'thirdly', 159: 'this',
      160: 'though', 161: 'thus', 162: 'together', 163: 'traditionally',
      164: 'truly', 165: 'truthfully', 166: 'typically', 167: 'ultimately',
      168: 'undoubtedly', 169: 'unfortunately', 170: 'unsurprisingly', 171: 'usually',
      172: 'well', 173: 'yet',
    }


  def process_dataset(self, dataset, name="train"):
    """Filter, relabel and tokenize a discourse-marker dataset.

    Args:
      dataset: iterable of dicts with keys ``"sentence1"``, ``"sentence2"`` and
        ``"label"`` (an integer key of ``id_to_word``).
      name: split name; used as id prefix and in progress messages.

    Returns:
      The examples produced by ``_get_examples_concatenated`` for the kept
      samples, each carrying a one-hot label row.

    Raises:
      KeyError: if a sample's label is not a key of ``id_to_word``.
    """
    kept = []
    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker not in self.mapping:
        continue
      kept.append((sample["sentence1"], sample["sentence2"], self.mapping[marker]))

    # Pin the categories to every class the mapping can produce. Letting the
    # encoder infer them from the data made the label width depend on the
    # split: a split containing a single class produced 1-wide labels.
    classes = sorted(set(self.mapping.values()))
    one_hot_encoder = OneHotEncoder(categories=[classes], handle_unknown="ignore", sparse_output=False)

    print("one hot encoding...")
    # The encoder rejects an empty input, so an empty selection skips it.
    labels = one_hot_encoder.fit_transform([[cls] for _, _, cls in kept]) if kept else []

    result = []
    for i, ((sentence1, sentence2, _), label) in tqdm(enumerate(zip(kept, labels)), total=len(kept), desc="creating results..."):
      result.append([f"{name}_{i}", sentence1, sentence2, [], [], [], label])

    return self._get_examples_concatenated(result, name)


class StudentEssayProcessor(DataProcessor):
  """Reads tab-separated student-essay files and encodes them as sentence pairs."""

  # Raw label spellings accepted for each relation.
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target,
        source_sentiment, target_sentiment, knowledge, label_distribution)``.
        NOTE(review): the three padded fields are passed to ``torch.tensor``,
        so they are assumed to be numeric id sequences — confirm with callers.
      maxlen: accepted for interface compatibility; not used.

    Returns:
      A list of ``(sentences, knowledge, target, label_distribution)`` tuples,
      one per row.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If max_batch_size is positive, it is the maximum number of sentences per
    batch. Otherwise batches are sized so that each holds roughly
    abs(max_batch_size) words.

    Args:
      sentences: sequence of sentences (anything supporting ``len``).
      batch_equal_size: when equal to ``True``, every batch only contains
        sentences of the same length; otherwise batches follow input order.
      max_batch_size: see above.

    Returns:
      A list of lists of sentence indices.
    """
    batches_of_sentence_ids = []
    # `== True` is kept on purpose: truthy values other than True/1 take the
    # sequential branch, exactly as before.
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        elif sentence_length == 0:
          # Empty sentences cost no words; keep them together instead of
          # dividing by zero.
          batch_size = max(1, len(ids))
        else:
          # At least one sentence per batch: a sentence longer than the word
          # budget used to yield a step of 0, which range() rejects.
          batch_size = max(1, int((-1 * max_batch_size) / sentence_length))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file and return its concatenated-pair examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts, last = label. Blank lines are skipped.

    Args:
      file_path: path of a single file, decoded as ISO-8859-1.
      max_sentence_length: accepted for interface compatibility; not used.
      name: split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The list produced by ``_get_examples_concatenated``. Labels are
      ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for any
      other label text.

    Raises:
      IndexError: if a non-blank line has fewer than four tab-separated fields.
    """

    # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        # Covers "\r\n", "\n" and whitespace-only lines; previously only a
        # bare "\r" was recognised as blank.
        if not raw_line.strip():
          continue

        fields = raw_line.replace("\n", "").split("\t")
        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip markup characters; ")" was previously missed because "(" was
        # replaced twice.
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in self._SUPPORT_LABELS:
          label_distribution = [1, 0]
        elif label in self._ATTACK_LABELS:
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)

class DebateProcessor(DataProcessor):
  """Reads tab-separated debate files and encodes them as sentence pairs."""

  # Raw label spellings accepted for each relation.
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target,
        source_sentiment, target_sentiment, knowledge, label_distribution)``.
        NOTE(review): the three padded fields are passed to ``torch.tensor``,
        so they are assumed to be numeric id sequences — confirm with callers.
      maxlen: accepted for interface compatibility; not used.

    Returns:
      A list of ``(sentences, knowledge, target, label_distribution)`` tuples,
      one per row.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If max_batch_size is positive, it is the maximum number of sentences per
    batch. Otherwise batches are sized so that each holds roughly
    abs(max_batch_size) words.

    Args:
      sentences: sequence of sentences (anything supporting ``len``).
      batch_equal_size: when equal to ``True``, every batch only contains
        sentences of the same length; otherwise batches follow input order.
      max_batch_size: see above.

    Returns:
      A list of lists of sentence indices.
    """
    batches_of_sentence_ids = []
    # `== True` is kept on purpose: truthy values other than True/1 take the
    # sequential branch, exactly as before.
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        elif sentence_length == 0:
          # Empty sentences cost no words; keep them together instead of
          # dividing by zero.
          batch_size = max(1, len(ids))
        else:
          # At least one sentence per batch: a sentence longer than the word
          # budget used to yield a step of 0, which range() rejects.
          batch_size = max(1, int((-1 * max_batch_size) / sentence_length))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file and return its concatenated-pair examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts, last = label. Blank lines are skipped.

    Args:
      file_path: path of a single file, decoded as ISO-8859-1.
      max_sentence_length: accepted for interface compatibility; not used.
      name: split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The list produced by ``_get_examples_concatenated``. Labels are
      ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for any
      other label text.

    Raises:
      IndexError: if a non-blank line has fewer than four tab-separated fields.
    """
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        # Covers "\r\n", "\n" and whitespace-only lines; previously only a
        # bare "\r" was recognised as blank.
        if not raw_line.strip():
          continue

        fields = raw_line.replace("\n", "").split("\t")
        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip markup characters; ")" was previously missed because "(" was
        # replaced twice.
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in self._SUPPORT_LABELS:
          label_distribution = [1, 0]
        elif label in self._ATTACK_LABELS:
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)


class MARGProcessor(DataProcessor):
  """Reads a CSV with a split column and three-way relation labels."""

  # Raw label spellings accepted for each relation.
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self):
    super(MARGProcessor, self).__init__()
    # NOTE(review): no method of this class uses self.pipe; it is kept because
    # code outside this cell may read the attribute. Drop it if nothing does.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the rows of one split from a CSV file and encode them as sentence pairs.

    Columns are addressed by position: 0 = id, 1 = source sentence,
    2 = target sentence, 3 = label, third-from-last = facts, last = split
    name. Only rows whose last column equals ``name`` are used.

    Args:
      file_path: path of a single CSV file readable by ``pd.read_csv``.
      max_sent_length: accepted for interface compatibility; not used.
      name: split to select; also forwarded to ``_get_examples_concatenated``.

    Returns:
      The list produced by ``_get_examples_concatenated``. Labels are
      ``[1, 0, 0]`` for support, ``[0, 0, 1]`` for attack, ``[0, 1, 0]`` for
      ``neither`` and ``[0, 0, 0]`` for any other label text.
    """

    # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc makes the positional access explicit; plain row[i] with an
      # integer on a label-indexed Series is deprecated in pandas.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = row.iloc[2].strip()
      label = row.iloc[3].strip()
      facts = row.iloc[-3].strip()

      # Strip markup characters; ")" was previously missed because "(" was
      # replaced twice.
      for char in "_[]()":
        facts = facts.replace(char, "")

      if label in self._SUPPORT_LABELS:
        label_distribution = [1, 0, 0]
      elif label in self._ATTACK_LABELS:
        label_distribution = [0, 0, 1]
      elif label == 'neither':
        label_distribution = [0, 1, 0]
      else:
        label_distribution = [0, 0, 0]

      result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)

class NKProcessor(DataProcessor):
  """Reads a tab-separated table with three-way relation labels."""

  # Raw label spellings accepted for each relation.
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read every row of a tab-separated file and encode it as a sentence pair.

    Columns are addressed by position: 0 = id, 2 = label,
    3 = source sentence, 4 = target sentence.

    Args:
      file_path: path of a single file readable by ``pd.read_csv(sep="\\t")``.
      max_sent_length: accepted for interface compatibility; not used.
      name: split name forwarded to ``_get_examples_concatenated``; rows are
        not filtered by it.

    Returns:
      The list produced by ``_get_examples_concatenated``. Labels are
      ``[1, 0, 0]`` for support, ``[0, 0, 1]`` for attack, ``[0, 1, 0]`` for
      ``no_relation`` and ``[0, 0, 0]`` for any other label value.
    """
    result = []

    df = pd.read_csv(file_path, sep="\t")
    for _, row in df.iterrows():
      # .iloc makes the positional access explicit; plain row[i] with an
      # integer on a label-indexed Series is deprecated in pandas.
      id_sample = row.iloc[0]
      label = row.iloc[2]
      sent = row.iloc[3].strip()
      target = row.iloc[4].strip()

      if label in self._SUPPORT_LABELS:
        label_distribution = [1, 0, 0]
      elif label in self._ATTACK_LABELS:
        label_distribution = [0, 0, 1]
      elif label == 'no_relation':
        label_distribution = [0, 1, 0]
      else:
        label_distribution = [0, 0, 0]

      result.append([id_sample, sent, target, [], [], [], label_distribution])

    return self._get_examples_concatenated(result, name)

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Student-essay reader that prefixes each target with a predicted discourse marker."""

  # Raw label spellings accepted for each relation.
  _SUPPORT_LABELS = ("supports", "support", "because")
  _ATTACK_LABELS = ("attacks", "attack", "but")

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    # Text-classification pipeline whose predicted label is used as the marker.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Zero-pad the sentence, knowledge and target sequences of a batch.

    Args:
      input: iterable of 7-item rows ``(id, sentences, target,
        source_sentiment, target_sentiment, knowledge, label_distribution)``.
        NOTE(review): the three padded fields are passed to ``torch.tensor``,
        so they are assumed to be numeric id sequences — confirm with callers.
      maxlen: accepted for interface compatibility; not used.

    Returns:
      A list of ``(sentences, knowledge, target, label_distribution)`` tuples,
      one per row.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sequences], batch_first=True, padding_value=0)

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    If max_batch_size is positive, it is the maximum number of sentences per
    batch. Otherwise batches are sized so that each holds roughly
    abs(max_batch_size) words.

    Args:
      sentences: sequence of sentences (anything supporting ``len``).
      batch_equal_size: when equal to ``True``, every batch only contains
        sentences of the same length; otherwise batches follow input order.
      max_batch_size: see above.

    Returns:
      A list of lists of sentence indices.
    """
    batches_of_sentence_ids = []
    # `== True` is kept on purpose: truthy values other than True/1 take the
    # sequential branch, exactly as before.
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i in range(len(sentences)):
        sentence_ids_by_length.setdefault(len(sentences[i]), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        elif sentence_length == 0:
          # Empty sentences cost no words; keep them together instead of
          # dividing by zero.
          batch_size = max(1, len(ids))
        else:
          # At least one sentence per batch: a sentence longer than the word
          # budget used to yield a step of 0, which range() rejects.
          batch_size = max(1, int((-1 * max_batch_size) / sentence_length))

        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i in range(len(sentences)):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentences[i]))
        if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
          or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if len(current_batch) > 0:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated file, inject discourse markers, return pair examples.

    Columns used per line: 0 = id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts, last = label. Blank lines are skipped. For every
    line ``self.pipe`` is called on ``"{source}</s></s>{target}"``; its top
    label (underscores turned into spaces, first letter upper-cased) is
    prepended to the target, whose own first letter is lower-cased.

    Args:
      file_path: path of a single file, decoded as ISO-8859-1.
      max_sentence_length: accepted for interface compatibility; not used.
      name: split name forwarded to ``_get_examples_concatenated``.

    Returns:
      The list produced by ``_get_examples_concatenated``. Labels are
      ``[1, 0]`` for support, ``[0, 1]`` for attack and ``[0, 0]`` for any
      other label text.

    Raises:
      IndexError: if a non-blank line has fewer than four tab-separated fields.
    """

    # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf

    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        # Covers "\r\n", "\n" and whitespace-only lines; previously only a
        # bare "\r" was recognised as blank.
        if not raw_line.strip():
          continue

        fields = raw_line.replace("\n", "").split("\t")
        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        # [:1] instead of [0] so an empty marker or target does not raise.
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        target = ds_marker + " " + target[:1].lower() + target[1:]

        label = fields[-1].strip()
        facts = fields[-2].strip()

        # Strip markup characters; ")" was previously missed because "(" was
        # replaced twice.
        for char in "_[]()":
          facts = facts.replace(char, "")

        if label in self._SUPPORT_LABELS:
          label_distribution = [1, 0]
        elif label in self._ATTACK_LABELS:
          label_distribution = [0, 1]
        else:
          label_distribution = [0, 0]

        result.append([story_id, sent, target, [], [], facts, label_distribution])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
  """Processor for tab-separated argument-pair files that injects a discourse marker.

  For every (source, target) pair the
  ``sileod/roberta-base-discourse-marker-prediction`` text-classification
  pipeline predicts a discourse marker, which is prepended to the target
  sentence before the examples are built.
  """

  # Deletes the characters "_", "[", "]", "(" and ")" from the facts column.
  _FACT_NOISE = str.maketrans("", "", "_[]()")

  # Relation label -> one-hot [support, attack]; any other label becomes [0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0), "support": (1, 0), "because": (1, 0),
    "attacks": (0, 1), "attack": (0, 1), "but": (0, 1),
  }

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    # The discourse-marker classifier is loaded once per processor instance.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """
    Zero-pads the tokenised fields of ``input`` to a common length.

    :param input: iterable of 7-tuples
        (id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution)
    :param maxlen: unused; every field is padded to the longest sequence in ``input``
    :return: list of (sentence, knowledge, target, label_distribution) tuples
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    def pad(sequences):
      return torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=0
      )

    sentences = pad(sentences)
    knowledge = pad(knowledge)
    target = pad(target)

    return list(zip(sentences, knowledge, target, label_distribution))

  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """
    Groups sentence indices into batches.

    If max_batch_size is positive, it is the maximum number of sentences per batch.
    Otherwise abs(max_batch_size) is a word budget: batches are sized so that
    (number of sentences) * (sentence length) stays around that budget.
    With batch_equal_size set, a batch only ever holds sentences of one length.

    :return: list of lists of sentence ids
    """
    batches_of_sentence_ids = []

    if batch_equal_size:
      sentence_ids_by_length = {}
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, sentence_ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Always keep at least one sentence per batch: a sentence longer than
          # the word budget used to give batch_size == 0 (range() step error),
          # and an empty sentence a ZeroDivisionError.
          batch_size = max(1, int(-max_batch_size / max(1, sentence_length)))

        for start in range(0, len(sentence_ids), batch_size):
          batches_of_sentence_ids.append(sentence_ids[start:start + batch_size])
      return batches_of_sentence_ids

    current_batch = []
    max_sentence_length = 0
    for i in range(len(sentences)):
      current_batch.append(i)
      max_sentence_length = max(max_sentence_length, len(sentences[i]))
      batch_is_full = (
        (max_batch_size > 0 and len(current_batch) >= max_batch_size)
        or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= -max_batch_size)
      )
      if batch_is_full:
        batches_of_sentence_ids.append(current_batch)
        current_batch = []
        max_sentence_length = 0
    if current_batch:
      batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  def _inject_discourse_marker(self, sent, target):
    """Returns ``target`` lower-cased at its first letter and prefixed with the predicted marker."""
    marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"].replace("_", " ")
    # Slices instead of [0] so an empty marker/target does not raise IndexError.
    marker = marker[:1].upper() + marker[1:]
    return marker + " " + target[:1].lower() + target[1:]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """
    Reads one tab-separated input file (ISO-8859-1 encoded).

    Column 0 is the pair id, column 1 the source sentence, column 3 the target
    sentence, the second-to-last column the facts string and the last column
    the relation label. Blank lines are skipped.

    :param file_path: path of the file to read
    :param max_sentence_length: unused, kept for interface compatibility
    :param name: split name forwarded to ``self._get_examples_concatenated``
    :return: whatever ``self._get_examples_concatenated`` builds from the rows
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []

    with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
      for raw_line in f:
        if not raw_line.strip():
          continue  # blank line ("\r\n" as before, plus any whitespace-only line)

        fields = raw_line.replace("\n", "").split("\t")
        story_id = fields[0]
        sent = fields[1].strip()
        target = self._inject_discourse_marker(sent, fields[3].strip())
        label = fields[-1].strip()
        # The original chain removed "(" twice and never removed ")".
        facts = fields[-2].strip().translate(self._FACT_NOISE)

        one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0)))
        # The two empty lists are the (unused) source/target sentiment slots.
        result.append([story_id, sent, target, [], [], facts, one_hot])

    return self._get_examples_concatenated(result, name)


class MARGWithDiscourseInjectionProcessor(DataProcessor):
  """Processor for the CSV argument-pair file that injects a discourse marker.

  For every (source, target) pair the
  ``sileod/roberta-base-discourse-marker-prediction`` text-classification
  pipeline predicts a discourse marker, which is prepended to the target
  sentence before the examples are built.
  """

  # Deletes the characters "_", "[", "]", "(" and ")" from the facts column.
  _FACT_NOISE = str.maketrans("", "", "_[]()")

  # Relation label -> one-hot [support, neither, attack]; unknown labels give [0, 0, 0].
  _LABEL_TO_ONE_HOT = {
    "supports": (1, 0, 0), "support": (1, 0, 0), "because": (1, 0, 0),
    "attacks": (0, 0, 1), "attack": (0, 0, 1), "but": (0, 0, 1),
    "neither": (0, 1, 0),
  }

  def __init__(self,):
    super(MARGWithDiscourseInjectionProcessor,self).__init__()
    # The discourse-marker classifier is loaded once per processor instance.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def _inject_discourse_marker(self, sent, target):
    """Returns ``target`` lower-cased at its first letter and prefixed with the predicted marker."""
    marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"].replace("_", " ")
    # Slices instead of [0] so an empty marker/target does not raise IndexError.
    marker = marker[:1].upper() + marker[1:]
    return marker + " " + target[:1].lower() + target[1:]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """
    Reads the rows of one CSV file that belong to the split ``name``.

    Columns are addressed by position: 0 is the pair id, 1 the source sentence,
    2 the target sentence, 3 the relation label, the third-to-last the facts
    string and the last the split name.

    :param file_path: path of the CSV file
    :param max_sent_length: unused, kept for interface compatibility
    :param name: only rows whose last column equals this value are kept; it is
        also forwarded to ``self._get_examples_concatenated``
    :return: whatever ``self._get_examples_concatenated`` builds from the rows
    """
    # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
    result = []

    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc keeps the access positional; plain row[0] / row[-1] on a
      # label-indexed Series relies on a deprecated pandas fallback.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = self._inject_discourse_marker(sent, row.iloc[2].strip())
      label = row.iloc[3].strip()
      # The original chain removed "(" twice and never removed ")".
      facts = row.iloc[-3].strip().translate(self._FACT_NOISE)

      one_hot = list(self._LABEL_TO_ONE_HOT.get(label, (0, 0, 0)))
      # The two empty lists are the (unused) source/target sentiment slots.
      result.append([story_id, sent, target, [], [], facts, one_hot])

    return self._get_examples_concatenated(result, name)

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, ``-lmbd * grad`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        """Passes ``x`` through unchanged and stores the reversal strength on ``ctx``."""
        # Kept as a 0-dim tensor so backward() can multiply it into the gradient.
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Returns the negated, ``lmbd``-scaled gradient for ``x`` and ``None`` for ``lmbd``."""
        flipped = torch.neg(grad_output.clone())
        scaled = ctx.lmbd * flipped
        return scaled, None

class DoubleAdversarialNet(torch.nn.Module):
  """Cross-attention pair classifier with an auxiliary head and three gradient-reversed heads.

  The pretrained encoder embeds the concatenated pair; tokens with segment id 1
  attend (multi-head attention) over tokens with segment id 0 and the mean of
  the attention output is classified. In training mode the first half of the
  batch feeds ``linear_layer`` and the second half ``linear_layer_adv``, while
  ``task_linear``, ``attack_linear`` and ``support_linear`` receive mean token
  embeddings through ``GRLayer``.
  """

  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Replace the token-type table with a freshly initialised 4-entry one.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the heads.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is not a Linear/Embedding/LayerNorm, so it keeps
    # PyTorch's default initialisation.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @staticmethod
  def _rows_matching(one_hot_labels, pattern):
    """Boolean mask of the rows of ``one_hot_labels`` that equal ``pattern`` element-wise."""
    target = torch.as_tensor(pattern, device=one_hot_labels.device)
    return (one_hot_labels == target).all(dim=1)

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    """
    :param ids_sent1: token ids of the concatenated pair
    :param segs_sent1: segment ids; 0 selects keys/values, 1 selects queries
    :param att_mask_sent1: attention mask handed to the encoder
    :param position_sep: unused
    :param labels: one-hot labels for the whole batch (first half main task,
        second half auxiliary task); only read in training mode
    :return: in training mode the tuple (predictions, predictions_adv,
        task_prediction, attack_prediction, support_prediction), otherwise the
        ``linear_layer`` predictions for the whole batch
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    # Zero out the embeddings of the tokens that belong to the other segment.
    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if not self.training:
      return self.linear_layer(H_sent)

    # Label tensors live wherever the embeddings live, so the forward pass does
    # not depend on a notebook-level ``device`` variable.
    target_device = embed_sent1.device
    half = H_sent.shape[0] // 2

    samples = H_sent[:half, :]
    embed_sent1_std = embed_sent1[:half, :, :]
    labels_std = torch.as_tensor(labels[:half], device=target_device)

    if self.num_classes == 2:
      emb_attack = embed_sent1_std[self._rows_matching(labels_std, [0, 1])]
      emb_support = embed_sent1_std[self._rows_matching(labels_std, [1, 0])]
    else:
      emb_attack = embed_sent1_std[self._rows_matching(labels_std, [0, 1, 0])]
      emb_support = embed_sent1_std[
        self._rows_matching(labels_std, [1, 0, 0]) | self._rows_matching(labels_std, [0, 0, 1])
      ]

    samples_adv = H_sent[half:, :]
    embed_sent1_adv = embed_sent1[half:, :, :]
    labels_adv = torch.as_tensor(labels[half:], device=target_device)

    # The auxiliary labels are always compared against 3-way one-hot patterns.
    emb_contr = embed_sent1_adv[self._rows_matching(labels_adv, [0, 0, 1])]
    emb_other = embed_sent1_adv[
      self._rows_matching(labels_adv, [0, 1, 0]) | self._rows_matching(labels_adv, [1, 0, 0])
    ]

    predictions = self.linear_layer(samples)
    predictions_adv = self.linear_layer_adv(samples_adv)

    # Gradient-reversed mean token embeddings: whole batch, then the two
    # label-selected groups pooled across both halves.
    mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
    mean_grl_attack = GRLayer.apply(torch.mean(torch.cat([emb_attack, emb_contr], dim=0), dim=1), .01)
    mean_grl_support = GRLayer.apply(torch.mean(torch.cat([emb_support, emb_other], dim=0), dim=1), .01)

    task_prediction = self.task_linear(mean_grl)
    attack_prediction = self.attack_linear(mean_grl_attack)
    support_prediction = self.support_linear(mean_grl_support)

    return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction

class AdversarialNet(torch.nn.Module):
  """Cross-attention pair classifier with one auxiliary head and one gradient-reversed head.

  Tokens with segment id 1 attend over tokens with segment id 0; the mean of the
  attention output is the pair representation. In training mode the first half
  of the batch is scored by ``linear_layer``, the second half by
  ``linear_layer_adv``, and ``task_linear`` scores the gradient-reversed mean
  token embedding of every sample.
  """

  def __init__(self):
    super().__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    # Fresh 4-entry token-type table, initialised with the encoder's own routine.
    encoder.config.type_vocab_size = 4
    token_type_table = nn.Embedding(encoder.config.type_vocab_size, encoder.config.hidden_size)
    encoder.embeddings.token_type_embeddings = token_type_table
    encoder._init_weights(token_type_table)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune every encoder parameter.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = nn.Linear(width, self.num_classes)
    self.linear_layer_adv = nn.Linear(width, self.num_classes_adv)
    self.task_linear = nn.Linear(width, 2)

    self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = nn.Linear(width, width)
    self.K = nn.Linear(width, width)
    self.V = nn.Linear(width, width)

    for module in (
      self.linear_layer,
      self.linear_layer_adv,
      self.Q,
      self.K,
      self.V,
      self.multi_head_att,
      self.task_linear,
    ):
      self._init_weights(module)

  def _init_weights(self, module):
    """Normal-initialise Linear/Embedding weights, reset LayerNorm, zero Linear biases."""
    if isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
      return
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      return
    module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    """
    :param ids_sent1: token ids of the concatenated pair
    :param segs_sent1: segment ids; 0 selects keys/values, 1 selects queries
    :param att_mask_sent1: attention mask handed to the encoder
    :param position_sep: unused
    :param visualize: when true, return the pooled pair representation instead of predictions
    :return: pooled representation, or (predictions, predictions_adv,
        task_prediction) in training mode, or ``linear_layer`` predictions otherwise
    """
    hidden_states = self.plm(
      ids_sent1,
      token_type_ids=segs_sent1,
      attention_mask=att_mask_sent1,
      output_hidden_states=True,
    ).hidden_states

    token_emb = hidden_states[-1]
    if self.first_last_avg:
      token_emb = (hidden_states[-1] + hidden_states[1]) / 2

    # Integer masks broadcast over the hidden dimension.
    in_first = (segs_sent1 == 0).long().unsqueeze(2)
    in_second = (segs_sent1 == 1).long().unsqueeze(2)
    first_tokens = in_first * token_emb
    second_tokens = in_second * token_emb

    keys = self.K(first_tokens)
    values = self.V(first_tokens)
    queries = self.Q(second_tokens)

    attended, _ = self.multi_head_att(queries, keys, values)
    pooled = attended.mean(dim=1)

    if visualize:
      return pooled
    if not self.training:
      return self.linear_layer(pooled)

    half = pooled.shape[0] // 2
    predictions = self.linear_layer(pooled[:half, :])
    predictions_adv = self.linear_layer_adv(pooled[half:, ])

    reversed_mean = GRLayer.apply(token_emb.mean(dim=1), .01)
    task_prediction = self.task_linear(reversed_mean)

    return predictions, predictions_adv, task_prediction

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  """Pair classifier that gates the first second-segment token with a learned cue.

  The gate is ``sigmoid(linear_initial_sent(first token) + linear_end_sent(last
  non-padded token))`` and multiplies the embedding of the first token whose
  segment id is 1. With ``attention`` enabled, the token embeddings are first
  replaced by the output of a cross-attention from segment 1 over segment 0.
  """

  def __init__(self, attention):
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Replace the token-type table with a freshly initialised 4-entry one.
    config = self.plm.config
    config.type_vocab_size = 4
    self.attention = attention
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the heads.
    for param in self.plm.parameters():
      param.requires_grad = True

    # forward() classifies the gated embedding, which has embed_size features
    # (the layer used to expect embed_size*3 and could never be applied).
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    # Gate projections used by forward(); they were previously only present
    # inside a string literal, so forward() raised AttributeError.
    self.linear_initial_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_end_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    # No-op: MultiheadAttention is not a Linear/Embedding/LayerNorm.
    self._init_weights(self.multi_head_att)
    self._init_weights(self.linear_initial_sent)
    self._init_weights(self.linear_end_sent)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """
    :param ids_sent1: token ids of the concatenated pair
    :param segs_sent1: segment ids; the first position with id 1 is the gated token
    :param att_mask_sent1: attention mask; its row sums locate the last non-padded token
    :param position_sep: unused
    :return: class scores with ``num_classes`` columns
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    # Index helpers are created on the tensors' own devices instead of relying
    # on a notebook-level ``device`` variable.
    batch_index = torch.arange(att_output.shape[0], device=att_output.device)
    # Weights L, L-1, ..., 1 make argmax pick the first position whose segment id is set.
    descending = torch.arange(att_output.shape[1], 0, -1, device=segs_sent1.device)

    initial_sent1 = att_output[:, 0, :]
    initial_sent2 = att_output[batch_index, torch.argmax(segs_sent1 * descending, dim=-1)]
    end_sent2 = att_output[batch_index, torch.sum(att_mask_sent1, dim=-1) - 1]

    initial_sent1 = self.linear_initial_sent(initial_sent1)
    end_sent2 = self.linear_end_sent(end_sent2)
    gate = self.sigmoid(initial_sent1 + end_sent2)

    final_emb = initial_sent2 * gate

    predictions = self.linear_layer(final_emb)

    return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  """Pair classifier over a mean-pooled representation, optionally cross-attended.

  With ``attention`` enabled, tokens with segment id 1 attend over tokens with
  segment id 0 and the attention output is mean-pooled; otherwise the encoder's
  token embeddings are mean-pooled directly.
  """

  def __init__(self, attention):
    super().__init__()

    encoder = AutoModel.from_pretrained(args["model_name"])
    self.plm = encoder
    # Fresh 4-entry token-type table, initialised with the encoder's own routine.
    encoder.config.type_vocab_size = 4
    token_type_table = nn.Embedding(encoder.config.type_vocab_size, encoder.config.hidden_size)
    encoder.embeddings.token_type_embeddings = token_type_table
    encoder._init_weights(token_type_table)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention
    self.first_last_avg = args["first_last_avg"]

    # Fine-tune every encoder parameter.
    for weight in encoder.parameters():
      weight.requires_grad = True

    width = self.embed_size
    self.linear_layer = nn.Linear(width, args["num_classes"])
    self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
    self.Q = nn.Linear(width, width)
    self.K = nn.Linear(width, width)
    self.V = nn.Linear(width, width)

    for module in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
      self._init_weights(module)

  def _init_weights(self, module):
    """Normal-initialise Linear/Embedding weights, reset LayerNorm, zero Linear biases."""
    if isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
      return
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      return
    module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """
    :param ids_sent1: token ids of the concatenated pair
    :param segs_sent1: segment ids; 0 selects keys/values, 1 selects queries
    :param att_mask_sent1: attention mask handed to the encoder
    :param position_sep: unused
    :return: class scores with ``num_classes`` columns
    """
    hidden_states = self.plm(
      ids_sent1,
      token_type_ids=segs_sent1,
      attention_mask=att_mask_sent1,
      output_hidden_states=True,
    ).hidden_states

    token_emb = hidden_states[-1]
    if self.first_last_avg:
      token_emb = (hidden_states[-1] + hidden_states[1]) / 2

    if not self.attention:
      return self.linear_layer(token_emb.mean(dim=1))

    # Integer masks broadcast over the hidden dimension.
    in_first = (segs_sent1 == 0).long().unsqueeze(2)
    in_second = (segs_sent1 == 1).long().unsqueeze(2)
    first_tokens = in_first * token_emb
    second_tokens = in_second * token_emb

    keys = self.K(first_tokens)
    values = self.V(first_tokens)
    queries = self.Q(second_tokens)

    attended, _ = self.multi_head_att(queries, keys, values)
    return self.linear_layer(attended.mean(dim=1))


class BaselineModel(torch.nn.Module):
  """Pair classifier: a linear layer over the mean of the encoder's token embeddings."""

  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Replace the token-type table with a freshly initialised 4-entry one.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the head.
    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    """
    :param ids_sent1: token ids of the concatenated pair
    :param segs_sent1: segment ids handed to the encoder
    :param att_mask_sent1: attention mask handed to the encoder
    :param position_sep: unused; kept so every model shares one call signature
    :return: class scores with ``num_classes`` columns
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    # The separator-masked copy of the embeddings that used to be built here was
    # never read, so it is no longer computed.
    # NOTE(review): the mean runs over every position, padding included.
    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  """Siamese pair classifier: both inputs go through the same encoder, their mean
  token embeddings are concatenated and scored by one linear layer."""

  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    # Replace the token-type table with a freshly initialised 4-entry one.
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    # The encoder is fine-tuned together with the head.
    for param in self.plm.parameters():
      param.requires_grad = True

    # Twice the embedding width: the two pooled inputs are concatenated.
    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    """
    :param ids_sent1: token ids of the first input
    :param segs_sent1: segment ids of the first input
    :param att_mask_sent1: attention mask of the first input
    :param ids_sent2: token ids of the second input
    :param segs_sent2: segment ids of the second input
    :param att_mask_sent2: attention mask of the second input
    :return: class scores with ``num_classes`` columns
    """
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    # The segment-masked copies of both embeddings that used to be built here
    # were never read, so they are no longer computed.
    # NOTE(review): the means run over every position, padding included.
    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch with the same value and ask
    cuDNN for deterministic kernels.

    :param seed: integer seed shared by all generators
    """
    torch.backends.cudnn.deterministic = True

    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)

    # Only touch the CUDA generator when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """
    Compute accuracy and macro-averaged precision, recall and F1, print each
    with three decimals, and return them.

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    scores = (
        ('accuracy:', accuracy_score(labels, preds)),
        ('precision:', precision_score(labels, preds, average="macro")),
        ('recall:', recall_score(labels, preds, average="macro")),
        ('f1:', f1_score(labels, preds, average="macro")),
    )

    # Metric name padded to 15 characters, value left-aligned with 3 decimals.
    for title, value in scores:
        print("{:15}{:<.3f}".format(title, value))

    return tuple(value for _, value in scores)

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler that fills each batch half from ``dataset1`` and half from ``dataset2``.

    Every yielded batch is a list of indices into the concatenation
    ``[dataset1, dataset2]``: the first ``batch_size // 2`` entries address
    ``dataset1``, the remaining entries address ``dataset2`` (offset by
    ``len(dataset1)``). Both datasets are reshuffled on every iteration and an
    index is used at most once per epoch. Iteration stops as soon as either
    dataset cannot fill its share of another batch.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        """
        :param dataset1: sized collection providing the first part of every batch
        :param dataset2: sized collection providing the second part of every batch
        :param batch_size: total number of indices per batch (must be >= 1)
        :raises ValueError: if ``batch_size`` is smaller than 1
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _num_batches(self):
        """Number of full batches one pass can produce."""
        from_first = self.batch_size // 2
        from_second = self.batch_size - from_first
        count = min(self.total_size // self.batch_size, len(self.dataset2) // from_second)
        if from_first > 0:
            count = min(count, len(self.dataset1) // from_first)
        return count

    def __iter__(self):
        self.epoch += 1
        from_first = self.batch_size // 2
        from_second = self.batch_size - from_first

        indices1 = torch.randperm(len(self.dataset1)).tolist()
        # Shift dataset2 indices past dataset1 so they address the concatenation.
        indices2 = (torch.randperm(len(self.dataset2)) + len(indices1)).tolist()

        # Bounding the loop by _num_batches() prevents popping from an
        # exhausted dataset2 (previously an IndexError).
        for _ in range(self._num_batches()):
            batch = [indices1.pop() for _ in range(from_first)]
            batch.extend(indices2.pop() for _ in range(from_second))
            yield batch

    def __len__(self):
        # Matches the number of batches __iter__ yields.
        return self._num_batches()

In [10]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

# Model / optimisation hyper-parameters.
args = dict(
    model_name="roberta-base",     # checkpoint handed to AutoModel.from_pretrained
    num_classes=3,                 # width of the main classification heads (the original comment also lists 2)
    num_classes_adv=3,             # width of linear_layer_adv (the original comment also lists 174)
    embed_size=768,                # input width of the heads and of the attention block
    first_last_avg=False,          # average hidden_states[1] and hidden_states[-1] instead of using the last layer
    seed=[0, 1, 2],
    batch_size=64,
    epochs=30,
    class_weight=[9.375, 1, 30],   # the original comment also lists [2.071, 1, 1.933]
    lr=1e-5,
)

# Experiment switches.
config = dict(
    dataset="m-arg",               # the original comment also lists "student_essay" and "debate"
    adversarial=False,             # train() expects (pred, pred_adv, task_pred) from the model when set
    double_adversarial=False,      # train() expects five outputs and passes labels to the model when set
    dataset_from_saved=False,
    injection=True,
    finetuning_discovery=False,
    grid_search=False,
    visualize=False,
    train=True,
    scheduler=False,
    attention=False,
    cue_gating=False,
)

def _build_adversarial_targets(labels):
    """Turn the labels of a mixed batch into the three target tensors.

    The first half of ``labels`` becomes the target of the main prediction,
    the second half the target of the auxiliary (``pred_adv``) prediction.
    The third tensor tags every row of the first half with ``[0, 1]`` and
    every row of the second half with ``[1, 0]``.

    Errors while building the tensors are deliberately NOT swallowed: the
    previous bare ``except`` only printed "error" and let the caller go on
    with undefined (or previous-batch) targets.

    Returns:
        tuple: ``(targets, targets_adv, targets_task)`` on ``device``.
    """
    half_batch_size = len(labels) // 2
    origin_rows = [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size

    targets = torch.tensor(np.array(labels[:half_batch_size])).to(device)
    targets_adv = torch.tensor(np.array(labels[half_batch_size:])).to(device)
    targets_task = torch.tensor(np.array(origin_rows)).to(device)
    return targets, targets_adv, targets_task


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3, rl=0):
    """Train ``model`` for one epoch and print timing, summed loss and learning rate.

    The loss depends on the global ``config``:

    * ``double_adversarial``: main loss + ``discovery_weight`` * auxiliary loss
      + ``adv_weight`` * origin loss + ``rl`` * (attack loss + support loss).
    * ``adversarial``: main loss + ``discovery_weight`` * auxiliary loss
      + ``adv_weight`` * origin loss.
    * otherwise: ``loss_fn`` on the model output alone.

    ``double_adversarial`` is checked first so that the branch taken here
    matches the model ``run`` builds (and the call ``val`` makes) when both
    flags are set.

    Args:
        epoch: zero-based epoch index, used only in the summary line.
        model: network to optimise; switched to training mode here.
        loss_fn: criterion for the main prediction.
        optimizer: stepped after every batch, gradients clipped to norm 1.0.
        train_loader: iterable of 5-element batches
            ``(ids, segments, attention mask, position_sep, labels)``.
        scheduler: optional LR scheduler stepped after every optimizer step.
        discovery_weight: weight of the auxiliary loss.
        adv_weight: weight of the origin loss.
        rl: weight of the attack/support losses (double adversarial only).
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Batch members that are plain Python lists (labels can be one, see the
        # isinstance check in the last branch) are left where they are.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _build_adversarial_targets(labels)

            # Count the rows matching each one-hot pattern; the counts size the
            # targets of the attack/support heads.
            attack_len = torch.sum((targets == torch.tensor([0,1]).to(device)).all(dim=1)).item()
            support_len = torch.sum((targets == torch.tensor([1,0]).to(device)).all(dim=1)).item()
            contr_len = torch.sum((targets_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)).item()
            other_len = torch.sum((targets_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (targets_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)).item()

            attack_target = [[0,1]] * attack_len + [[1,0]] * contr_len
            support_target = [[0,1]] * support_len + [[1,0]] * other_len
            attack_target = torch.tensor(np.array(attack_target)).to(device)
            support_target = torch.tensor(np.array(support_target)).to(device)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = loss_fn2(attack_pred, attack_target.float())
            loss5 = loss_fn2(support_pred, support_target.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + rl*loss4 + rl*loss5
        elif config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            targets, targets_adv, targets_task = _build_adversarial_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            if isinstance(labels, list):
                labels = torch.tensor(np.array(labels)).to(device)
            loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and report the metrics.

    The model is put in eval mode and every batch is scored without gradient
    tracking. The unweighted cross-entropy is summed over batches and printed.
    Predicted class indices are expanded to one-hot rows: two entries wide
    when the label rows have two entries, three entries wide otherwise.

    Args:
        model: network to evaluate; when ``config["double_adversarial"]`` is
            set it is also handed the labels.
        val_loader: iterable of 5-tensor batches
            ``(ids, segments, attention mask, position_sep, labels)``.

    Returns:
        tuple: ``(accuracy, precision, recall, f1)`` from ``output_metrics``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    running_loss = 0
    gold_rows = []
    predicted_rows = []

    for raw_batch in val_loader:
        moved = tuple(t.to(device) for t in raw_batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = moved

        with torch.no_grad():
            if config["double_adversarial"]:
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            else:
                out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)

            _, top_idx = torch.max(out.data, 1)
            class_ids = top_idx.cpu().numpy().tolist()
            running_loss += criterion(out, labels.float()).item()

        gold = labels.cpu().numpy().tolist()
        gold_rows.extend(gold)

        # Any index past the last slot lands in the last slot, exactly like the
        # trailing else of an if/elif ladder.
        width = 2 if len(gold[0]) == 2 else 3
        for class_id in class_ids:
            hot = min(class_id, width - 1)
            predicted_rows.append([1 if slot == hot else 0 for slot in range(width)])

    print(f"val loss: {running_loss}")

    acc, prec, rec, f1 = output_metrics(gold_rows, predicted_rows)
    return acc, prec, rec, f1

def _gather_embeddings(model, loader, label_shift=0, max_batches=None):
  """Collect ``model(..., visualize=True)`` outputs and class indices from ``loader``.

  Args:
    model: network called with ``visualize=True``; its outputs are
      concatenated along dim 0.
    loader: iterable of 5-element batches whose last element holds the labels
      (a tensor or a plain list).
    label_shift: constant added to every argmax'ed label.
    max_batches: stop after this many batches (``None`` = use all of them).

  Returns:
    tuple: ``(embeddings, labels)`` tensors.

  Raises:
    ValueError: if ``loader`` produced no batch.
  """
  embedding_chunks = []
  label_chunks = []

  for i, batch in enumerate(tqdm(loader)):
    if max_batches is not None and i == max_batches:
      break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    if isinstance(labels, list):
      labels = torch.tensor(np.array(labels)).to(device)
    label_chunks.append(torch.argmax(labels, dim=-1) + label_shift)

    with torch.no_grad():
      embedding_chunks.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

  if not embedding_chunks:
    raise ValueError("cannot visualize: the data loader yielded no batches")

  # Concatenate once instead of growing a tensor batch by batch.
  return torch.cat(embedding_chunks, dim=0), torch.cat(label_chunks, dim=0)


def _plot_embedding_scatter(frame, title):
  """Draw one labelled scatter plot of the 2-D points in ``frame``."""
  # The style has to be set before the axes exist; calling it afterwards left
  # the first figure without the grid styling.
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  fig, ax = plt.subplots(figsize=(8,6))
  sns.scatterplot(data=frame, x='x', y='y', hue='label', palette='deep', ax=ax)
  sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
  ax.set_title(title)
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.axis('equal')
  plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2):
  """Show t-SNE scatter plots of the model's embeddings (adversarial setup only).

  Does nothing unless ``config["adversarial"]`` is set. Otherwise the model is
  put in eval mode and two figures are shown: one for every batch of
  ``test_dataloader`` and one for the first 20 batches of
  ``train_adv_dataloader``. Each set is projected with its own t-SNE fit.

  Args:
    epoch: unused; kept for interface compatibility.
    model: network that returns embeddings when called with ``visualize=True``.
    test_dataloader: loader of task batches.
    train_adv_dataloader: loader of auxiliary batches.
    discovery_weight: value shown as alpha in the plot titles.
    adv_weight: value shown as mu in the plot titles.
  """
  if not config["adversarial"]:
    return

  model.eval()

  print("Visualizing...")
  embeddings, tot_labels = _gather_embeddings(model, test_dataloader)
  # Auxiliary labels are shifted by 2, presumably so their ids do not clash with
  # the ids of a two-class task -- TODO confirm for three-class datasets.
  embeddings_adv, tot_labels_adv = _gather_embeddings(model, train_adv_dataloader, label_shift=2, max_batches=20)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  title = f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}'
  _plot_embedding_scatter(df_tsne, title)
  _plot_embedding_scatter(df_tsne_adv, title)





def _make_optimizer(model):
  """Build AdamW over ``model``'s parameters with learning rate ``args["lr"]``.

  Parameters whose name contains "bias" or "LayerNorm.weight" get no weight
  decay; every other parameter uses a decay of 0.01.
  """
  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(optimizer_grouped_parameters, lr=args["lr"])


def run(seed):
  """Run one complete experiment for ``seed`` as described by ``config``/``args``.

  Steps: seed the RNGs, pick the processor and file paths for
  ``config["dataset"]``, load the train/dev/test examples (plus the discourse
  marker data for the adversarial / fine-tuning variants), build loaders, model,
  optimizer and class-weighted loss, then either sweep the loss weights
  (``config["grid_search"]``) or train for ``args["epochs"]`` epochs. The test
  metrics of the epoch with the best dev F1 are kept; in the single-run path the
  corresponding weights are saved to ``./<dataset>_model.pt``.

  Args:
    seed: value forwarded to ``set_random_seeds``.

  Returns:
    list: ``[accuracy, precision, recall, f1]`` on the test split for the first
    recorded configuration (the only one outside grid search). All entries are
    -1 when no epoch was evaluated, e.g. with ``config["train"]`` disabled.

  Raises:
    ValueError: if ``config["dataset"]`` is not a supported name.
  """
  set_random_seeds(seed)

  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/train_essay.txt"
    path_dev = "./data/student_essay/dev_essay.txt"
    path_test = "./data/student_essay/test_essay.txt"
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/train_debate_concept.txt"
    path_dev = "./data/debate/dev_debate_concept.txt"
    path_test = "./data/debate/test_debate_concept.txt"
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    # All three splits are read from one file; the processor is told which
    # split it is reading through the ``name`` argument below.
    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKWithDiscourseInjectionProcessor()
    else:
      processor = NKProcessor()

    # Only a train file exists; dev/test are sliced from it further down.
    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay', 'debate', 'm-arg' and 'nk')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    # First tenth -> dev, last tenth -> test, the middle -> train.
    tenth = len(data_train) // 10
    data_dev = data_train[:tenth]
    data_test = data_train[-tenth:]
    data_train = data_train[tenth:-tenth]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"] or config["double_adversarial"] or config["finetuning_discovery"]:
    df = datasets.load_dataset("discovery","discovery", trust_remote_code=True)
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE(security): unpickling runs arbitrary code; only load a cache file
      # that this notebook wrote itself.
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    # Task examples first, auxiliary examples after them.
    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if config["double_adversarial"]:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = DoubleAdversarialNet()

  elif not config["adversarial"]:
    if config["finetuning_discovery"]:
      sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
      train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    else:
      train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

    if not config["cue_gating"]:
      model = BaselineModelWithSentenceComparison(attention=config["attention"])
    else:
      model = BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _make_optimizer(model)

  if config["dataset"] in ["m-arg","nk"]:
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    # NOTE(review): this form needs a scalar args["class_weight"]; the list
    # currently configured only fits the branch above.
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
  # Initialised up front so the summary below cannot hit a NameError when no
  # epoch is evaluated (e.g. config["train"] is False).
  best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0.8,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)
    range_local = np.arange(0.2, 1, 0.2)

    for rl in reversed(range_local):
      print(f"rl {rl}")
      for discovery_weight in reversed(range_disc):
        for adv_weight in range_adv:
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight, rl=rl)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

          # Start the next weight combination from a freshly seeded model.
          # NOTE(review): the sweep always rebuilds a DoubleAdversarialNet,
          # whatever model variant was selected above.
          del model
          del optimizer

          set_random_seeds(seed)
          model = DoubleAdversarialNet()
          model = model.to(device)
          optimizer = _make_optimizer(model)

          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = -1, -1, -1, -1
  else:
    if config["scheduler"]:
      # Stepped once per epoch below, so the decay spans the whole run.
      scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-2, total_iters=args["epochs"])
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if config["scheduler"]:
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          # Keep the weights of the best dev-F1 epoch.
          torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

    if config["visualize"] and config["adversarial"]:
      model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
      visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  # NOTE(review): after a grid search only the first combination is returned.
  return result_metrics[0]

if __name__ == "__main__":
  # One full experiment per configured seed; every run contributes its
  # [accuracy, precision, recall, f1] row.
  results = []
  for seed in args["seed"]:
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results += [result_metrics]

  # Column-wise mean, i.e. each metric averaged over the seeds.
  avg = torch.tensor(results).mean(dim=0)
  print(avg)

**** trying with seed 0 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/8.79k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 3006.87it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 2995.20it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 3001.29it/s]


finished preprocessing examples in test


model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 52/52 [00:09<00:00,  5.57it/s]


Timing: 9.3379225730896, Epoch: 1, training loss: 147.0894811153412, current learning rate 1e-05
val loss: 7.193751692771912
accuracy:      0.461
precision:     0.376
recall:        0.420
f1:            0.302
val loss: 7.13348925113678
accuracy:      0.453
precision:     0.396
recall:        0.441
f1:            0.304
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.9294445514678955, Epoch: 2, training loss: 130.06633305549622, current learning rate 1e-05
val loss: 8.325748920440674
accuracy:      0.327
precision:     0.387
recall:        0.486
f1:            0.254
val loss: 8.471787691116333
accuracy:      0.336
precision:     0.399
recall:        0.517
f1:            0.275
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.54it/s]


Timing: 7.958324909210205, Epoch: 3, training loss: 106.42850756645203, current learning rate 1e-05
val loss: 6.873387694358826
accuracy:      0.500
precision:     0.399
recall:        0.603
f1:            0.351
val loss: 6.91042685508728
accuracy:      0.489
precision:     0.420
recall:        0.591
f1:            0.368
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.933345317840576, Epoch: 4, training loss: 88.20581924915314, current learning rate 1e-05
val loss: 9.706779479980469
accuracy:      0.368
precision:     0.407
recall:        0.612
f1:            0.293
val loss: 9.883894205093384
accuracy:      0.353
precision:     0.416
recall:        0.566
f1:            0.295
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.933913230895996, Epoch: 5, training loss: 73.16095823049545, current learning rate 1e-05
val loss: 6.904192328453064
accuracy:      0.529
precision:     0.430
recall:        0.637
f1:            0.393
val loss: 7.2453155517578125
accuracy:      0.523
precision:     0.430
recall:        0.603
f1:            0.389
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.9237823486328125, Epoch: 6, training loss: 60.268105924129486, current learning rate 1e-05
val loss: 7.731159806251526
accuracy:      0.517
precision:     0.421
recall:        0.603
f1:            0.380
val loss: 7.6553977727890015
accuracy:      0.504
precision:     0.423
recall:        0.574
f1:            0.372
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.54it/s]


Timing: 7.955434083938599, Epoch: 7, training loss: 48.24394774436951, current learning rate 1e-05
val loss: 5.6827552318573
accuracy:      0.661
precision:     0.449
recall:        0.628
f1:            0.461
val loss: 5.306075036525726
accuracy:      0.645
precision:     0.441
recall:        0.580
f1:            0.444
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.930000305175781, Epoch: 8, training loss: 37.36746799945831, current learning rate 1e-05
val loss: 4.710838496685028
accuracy:      0.722
precision:     0.478
recall:        0.614
f1:            0.505
val loss: 4.489882051944733
accuracy:      0.769
precision:     0.476
recall:        0.565
f1:            0.501
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.909263610839844, Epoch: 9, training loss: 27.686334908008575, current learning rate 1e-05
val loss: 5.122488081455231
accuracy:      0.722
precision:     0.456
recall:        0.548
f1:            0.465
val loss: 4.665467143058777
accuracy:      0.764
precision:     0.495
recall:        0.563
f1:            0.508
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.938030242919922, Epoch: 10, training loss: 23.40216751396656, current learning rate 1e-05
val loss: 6.258887887001038
accuracy:      0.690
precision:     0.467
recall:        0.624
f1:            0.488
val loss: 6.083345592021942
accuracy:      0.708
precision:     0.455
recall:        0.541
f1:            0.465
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.909978866577148, Epoch: 11, training loss: 24.067036144435406, current learning rate 1e-05
val loss: 7.699003159999847
accuracy:      0.637
precision:     0.463
recall:        0.589
f1:            0.460
val loss: 7.1305689215660095
accuracy:      0.667
precision:     0.460
recall:        0.546
f1:            0.455
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.914757013320923, Epoch: 12, training loss: 17.01983516290784, current learning rate 1e-05
val loss: 4.711904853582382
accuracy:      0.800
precision:     0.460
recall:        0.482
f1:            0.466
val loss: 4.576435476541519
accuracy:      0.842
precision:     0.560
recall:        0.558
f1:            0.557
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.91379714012146, Epoch: 13, training loss: 14.872959055006504, current learning rate 1e-05
val loss: 6.237966597080231
accuracy:      0.776
precision:     0.513
recall:        0.539
f1:            0.515
val loss: 5.47554150223732
accuracy:      0.820
precision:     0.537
recall:        0.570
f1:            0.547
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.90642523765564, Epoch: 14, training loss: 9.874190427362919, current learning rate 1e-05
val loss: 6.0590842962265015
accuracy:      0.829
precision:     0.449
recall:        0.426
f1:            0.436
val loss: 5.116461217403412
accuracy:      0.847
precision:     0.535
recall:        0.461
f1:            0.487
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.920970678329468, Epoch: 15, training loss: 12.203818073496222, current learning rate 1e-05
val loss: 6.82674914598465
accuracy:      0.756
precision:     0.492
recall:        0.546
f1:            0.501
val loss: 6.185649633407593
accuracy:      0.798
precision:     0.506
recall:        0.583
f1:            0.532
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.878312587738037, Epoch: 16, training loss: 13.474635933991522, current learning rate 1e-05
val loss: 6.804081380367279
accuracy:      0.783
precision:     0.470
recall:        0.534
f1:            0.492
val loss: 6.745380997657776
accuracy:      0.800
precision:     0.503
recall:        0.577
f1:            0.529
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.891619682312012, Epoch: 17, training loss: 11.036341436207294, current learning rate 1e-05
val loss: 7.374230086803436
accuracy:      0.795
precision:     0.520
recall:        0.575
f1:            0.538
val loss: 6.759588181972504
accuracy:      0.793
precision:     0.496
recall:        0.532
f1:            0.508
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.916619539260864, Epoch: 18, training loss: 8.620599560905248, current learning rate 1e-05
val loss: 6.3307700753211975
accuracy:      0.810
precision:     0.520
recall:        0.515
f1:            0.510
val loss: 6.90876841545105
accuracy:      0.805
precision:     0.493
recall:        0.501
f1:            0.494
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.922016382217407, Epoch: 19, training loss: 10.20542939985171, current learning rate 1e-05
val loss: 8.01700085401535
accuracy:      0.790
precision:     0.533
recall:        0.574
f1:            0.538
val loss: 7.23639515042305
accuracy:      0.793
precision:     0.499
recall:        0.539
f1:            0.512
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.884253263473511, Epoch: 20, training loss: 9.427664386108518, current learning rate 1e-05
val loss: 6.446832239627838
accuracy:      0.846
precision:     0.514
recall:        0.462
f1:            0.482
val loss: 6.839960157871246
accuracy:      0.837
precision:     0.499
recall:        0.444
f1:            0.461
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.884194850921631, Epoch: 21, training loss: 10.807841839268804, current learning rate 1e-05
val loss: 7.136853456497192
accuracy:      0.817
precision:     0.502
recall:        0.591
f1:            0.532
val loss: 7.722889840602875
accuracy:      0.774
precision:     0.460
recall:        0.517
f1:            0.478
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.886223793029785, Epoch: 22, training loss: 7.001145348884165, current learning rate 1e-05
val loss: 6.975810468196869
accuracy:      0.829
precision:     0.511
recall:        0.500
f1:            0.504
val loss: 7.619342029094696
accuracy:      0.813
precision:     0.473
recall:        0.462
f1:            0.465
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.61it/s]


Timing: 7.871165037155151, Epoch: 23, training loss: 11.194366950774565, current learning rate 1e-05
val loss: 7.611486613750458
accuracy:      0.810
precision:     0.514
recall:        0.463
f1:            0.473
val loss: 8.099049150943756
accuracy:      0.818
precision:     0.524
recall:        0.520
f1:            0.520
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.883070468902588, Epoch: 24, training loss: 7.121472723083571, current learning rate 1e-05
val loss: 8.09598022699356
accuracy:      0.822
precision:     0.457
recall:        0.453
f1:            0.455
val loss: 7.914254605770111
accuracy:      0.810
precision:     0.463
recall:        0.468
f1:            0.465
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.895994186401367, Epoch: 25, training loss: 7.585045918356627, current learning rate 1e-05
val loss: 8.901235699653625
accuracy:      0.790
precision:     0.477
recall:        0.507
f1:            0.486
val loss: 8.630625367164612
accuracy:      0.818
precision:     0.513
recall:        0.541
f1:            0.525
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.890371799468994, Epoch: 26, training loss: 8.227911744848825, current learning rate 1e-05
val loss: 7.950063347816467
accuracy:      0.822
precision:     0.475
recall:        0.519
f1:            0.491
val loss: 8.57843405008316
accuracy:      0.813
precision:     0.483
recall:        0.498
f1:            0.488
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.887990474700928, Epoch: 27, training loss: 4.743834030581638, current learning rate 1e-05
val loss: 10.005316317081451
accuracy:      0.783
precision:     0.488
recall:        0.497
f1:            0.484
val loss: 9.825084567070007
accuracy:      0.803
precision:     0.498
recall:        0.521
f1:            0.503
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.61it/s]


Timing: 7.874048471450806, Epoch: 28, training loss: 7.033442497137003, current learning rate 1e-05
val loss: 8.261012822389603
accuracy:      0.820
precision:     0.503
recall:        0.496
f1:            0.497
val loss: 9.226670861244202
accuracy:      0.815
precision:     0.511
recall:        0.505
f1:            0.507
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.89576268196106, Epoch: 29, training loss: 8.994894422125071, current learning rate 1e-05
val loss: 10.121894896030426
accuracy:      0.780
precision:     0.477
recall:        0.541
f1:            0.497
val loss: 9.852149605751038
accuracy:      0.769
precision:     0.470
recall:        0.544
f1:            0.493
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.62it/s]


Timing: 7.862269639968872, Epoch: 30, training loss: 6.736365987104364, current learning rate 1e-05
val loss: 9.264225125312805
accuracy:      0.812
precision:     0.441
recall:        0.457
f1:            0.448
val loss: 9.186829566955566
accuracy:      0.822
precision:     0.492
recall:        0.487
f1:            0.488
best result:
0.7931873479318735
0.496268406961178
0.531988162496637
0.5081638765119704
[[0.7931873479318735, 0.496268406961178, 0.531988162496637, 0.5081638765119704]]
**** trying with seed 1 ****


tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 2964.29it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 2617.61it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 2845.48it/s]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


finished preprocessing examples in test




===== Start training: epoch 1 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.936492443084717, Epoch: 1, training loss: 145.08375298976898, current learning rate 1e-05
val loss: 7.312140583992004
accuracy:      0.393
precision:     0.517
recall:        0.431
f1:            0.294
val loss: 7.322839379310608
accuracy:      0.399
precision:     0.698
recall:        0.449
f1:            0.293
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.922915935516357, Epoch: 2, training loss: 128.24933052062988, current learning rate 1e-05
val loss: 6.398347020149231
accuracy:      0.498
precision:     0.398
recall:        0.469
f1:            0.338
val loss: 6.419110715389252
accuracy:      0.484
precision:     0.446
recall:        0.582
f1:            0.385
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.9236297607421875, Epoch: 3, training loss: 109.51276749372482, current learning rate 1e-05
val loss: 6.869356393814087
accuracy:      0.493
precision:     0.411
recall:        0.579
f1:            0.351
val loss: 6.91579806804657
accuracy:      0.467
precision:     0.420
recall:        0.575
f1:            0.353
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.931408643722534, Epoch: 4, training loss: 92.37154018878937, current learning rate 1e-05
val loss: 4.143815279006958
accuracy:      0.707
precision:     0.505
recall:        0.543
f1:            0.482
val loss: 4.034764647483826
accuracy:      0.723
precision:     0.511
recall:        0.567
f1:            0.488
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.928119659423828, Epoch: 5, training loss: 77.58390688896179, current learning rate 1e-05
val loss: 6.092163801193237
accuracy:      0.580
precision:     0.423
recall:        0.568
f1:            0.399
val loss: 6.117356061935425
accuracy:      0.569
precision:     0.433
recall:        0.586
f1:            0.412
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.53it/s]


Timing: 7.9714579582214355, Epoch: 6, training loss: 55.40871560573578, current learning rate 1e-05
val loss: 4.933116972446442
accuracy:      0.717
precision:     0.435
recall:        0.583
f1:            0.456
val loss: 4.641362547874451
accuracy:      0.732
precision:     0.472
recall:        0.600
f1:            0.494
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.940473318099976, Epoch: 7, training loss: 46.299430310726166, current learning rate 1e-05
val loss: 6.354408264160156
accuracy:      0.610
precision:     0.413
recall:        0.520
f1:            0.400
val loss: 5.865332126617432
accuracy:      0.623
precision:     0.438
recall:        0.564
f1:            0.428
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.9179511070251465, Epoch: 8, training loss: 37.0026752948761, current learning rate 1e-05
val loss: 5.166388988494873
accuracy:      0.722
precision:     0.437
recall:        0.562
f1:            0.461
val loss: 4.7980934381484985
accuracy:      0.742
precision:     0.457
recall:        0.554
f1:            0.480
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.919906139373779, Epoch: 9, training loss: 34.792365193367004, current learning rate 1e-05
val loss: 6.465141654014587
accuracy:      0.661
precision:     0.408
recall:        0.525
f1:            0.413
val loss: 5.48281791806221
accuracy:      0.706
precision:     0.448
recall:        0.568
f1:            0.464
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.916241884231567, Epoch: 10, training loss: 31.37737514078617, current learning rate 1e-05
val loss: 4.206349670886993
accuracy:      0.810
precision:     0.454
recall:        0.471
f1:            0.461
val loss: 3.4891255870461464
accuracy:      0.861
precision:     0.591
recall:        0.551
f1:            0.569
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.887922525405884, Epoch: 11, training loss: 20.77421109378338, current learning rate 1e-05
val loss: 5.410552620887756
accuracy:      0.754
precision:     0.442
recall:        0.457
f1:            0.426
val loss: 4.896623492240906
accuracy:      0.771
precision:     0.480
recall:        0.544
f1:            0.485
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.915792942047119, Epoch: 12, training loss: 19.03856361657381, current learning rate 1e-05
val loss: 5.569486498832703
accuracy:      0.805
precision:     0.457
recall:        0.484
f1:            0.467
val loss: 4.377059310674667
accuracy:      0.827
precision:     0.503
recall:        0.503
f1:            0.502
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.901846647262573, Epoch: 13, training loss: 18.34792796894908, current learning rate 1e-05
val loss: 6.350742965936661
accuracy:      0.788
precision:     0.435
recall:        0.477
f1:            0.450
val loss: 5.456903696060181
accuracy:      0.796
precision:     0.490
recall:        0.561
f1:            0.514
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.890896558761597, Epoch: 14, training loss: 15.079479359090328, current learning rate 1e-05
val loss: 6.041100174188614
accuracy:      0.783
precision:     0.442
recall:        0.483
f1:            0.455
val loss: 5.878285586833954
accuracy:      0.783
precision:     0.482
recall:        0.535
f1:            0.500
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.898331165313721, Epoch: 15, training loss: 12.313224511221051, current learning rate 1e-05
val loss: 6.473724782466888
accuracy:      0.800
precision:     0.429
recall:        0.452
f1:            0.438
val loss: 5.701869606971741
accuracy:      0.825
precision:     0.516
recall:        0.545
f1:            0.525
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.908026695251465, Epoch: 16, training loss: 11.092411693185568, current learning rate 1e-05
val loss: 7.0492249727249146
accuracy:      0.807
precision:     0.464
recall:        0.433
f1:            0.434
val loss: 6.287605702877045
accuracy:      0.839
precision:     0.530
recall:        0.486
f1:            0.499
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.906275987625122, Epoch: 17, training loss: 13.977989174425602, current learning rate 1e-05
val loss: 8.404581665992737
accuracy:      0.741
precision:     0.422
recall:        0.467
f1:            0.430
val loss: 7.688180327415466
accuracy:      0.759
precision:     0.467
recall:        0.547
f1:            0.489
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.891111850738525, Epoch: 18, training loss: 9.495662368368357, current learning rate 1e-05
val loss: 6.6700257658958435
accuracy:      0.837
precision:     0.450
recall:        0.429
f1:            0.433
val loss: 7.181581556797028
accuracy:      0.803
precision:     0.427
recall:        0.431
f1:            0.412
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.895415306091309, Epoch: 19, training loss: 17.28765122126788, current learning rate 1e-05
val loss: 7.419873774051666
accuracy:      0.817
precision:     0.414
recall:        0.407
f1:            0.410
val loss: 6.367416143417358
accuracy:      0.830
precision:     0.502
recall:        0.490
f1:            0.493
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.8947179317474365, Epoch: 20, training loss: 6.674899747013114, current learning rate 1e-05
val loss: 8.352261662483215
accuracy:      0.798
precision:     0.433
recall:        0.451
f1:            0.437
val loss: 7.217733204364777
accuracy:      0.805
precision:     0.470
recall:        0.459
f1:            0.460
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.884488105773926, Epoch: 21, training loss: 10.242583242885303, current learning rate 1e-05
val loss: 8.465691447257996
accuracy:      0.793
precision:     0.449
recall:        0.486
f1:            0.463
val loss: 8.33640593290329
accuracy:      0.798
precision:     0.491
recall:        0.534
f1:            0.507
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.885831594467163, Epoch: 22, training loss: 12.277486198116094, current learning rate 1e-05
val loss: 8.908531248569489
accuracy:      0.805
precision:     0.452
recall:        0.484
f1:            0.465
val loss: 8.270353257656097
accuracy:      0.786
precision:     0.461
recall:        0.494
f1:            0.474
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.906910419464111, Epoch: 23, training loss: 6.456515496363863, current learning rate 1e-05
val loss: 8.852204740047455
accuracy:      0.812
precision:     0.443
recall:        0.457
f1:            0.447
val loss: 8.145296931266785
accuracy:      0.822
precision:     0.488
recall:        0.480
f1:            0.481
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.61it/s]


Timing: 7.8755998611450195, Epoch: 24, training loss: 7.984134722268209, current learning rate 1e-05
val loss: 7.567614942789078
accuracy:      0.820
precision:     0.438
recall:        0.430
f1:            0.432
val loss: 8.270854353904724
accuracy:      0.835
precision:     0.522
recall:        0.484
f1:            0.494
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.896192312240601, Epoch: 25, training loss: 7.035715082893148, current learning rate 1e-05
val loss: 8.30190885066986
accuracy:      0.820
precision:     0.428
recall:        0.423
f1:            0.425
val loss: 8.260942041873932
accuracy:      0.818
precision:     0.463
recall:        0.457
f1:            0.459
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.898777484893799, Epoch: 26, training loss: 5.871300583006814, current learning rate 1e-05
val loss: 9.632805109024048
accuracy:      0.827
precision:     0.440
recall:        0.447
f1:            0.439
val loss: 8.865723729133606
accuracy:      0.810
precision:     0.448
recall:        0.462
f1:            0.437
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.890581846237183, Epoch: 27, training loss: 8.789423133130185, current learning rate 1e-05
val loss: 9.572088897228241
accuracy:      0.817
precision:     0.445
recall:        0.459
f1:            0.451
val loss: 8.462374985218048
accuracy:      0.813
precision:     0.469
recall:        0.477
f1:            0.468
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.61it/s]


Timing: 7.872272253036499, Epoch: 28, training loss: 4.98331963759847, current learning rate 1e-05
val loss: 8.185907900333405
accuracy:      0.827
precision:     0.491
recall:        0.418
f1:            0.432
val loss: 8.534550905227661
accuracy:      0.818
precision:     0.470
recall:        0.450
f1:            0.457
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.894149303436279, Epoch: 29, training loss: 4.060729298449587, current learning rate 1e-05
val loss: 9.135165691375732
accuracy:      0.827
precision:     0.431
recall:        0.418
f1:            0.424
val loss: 9.299947023391724
accuracy:      0.815
precision:     0.471
recall:        0.463
f1:            0.466
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.60it/s]


Timing: 7.88293194770813, Epoch: 30, training loss: 9.317947376403026, current learning rate 1e-05
val loss: 8.635886192321777
accuracy:      0.810
precision:     0.417
recall:        0.412
f1:            0.413
val loss: 8.306211173534393
accuracy:      0.818
precision:     0.496
recall:        0.478
f1:            0.485
best result:
0.7226277372262774
0.5112812975426111
0.5673392520850148
0.4882186975751106
[[0.7226277372262774, 0.5112812975426111, 0.5673392520850148, 0.4882186975751106]]
**** trying with seed 2 ****


tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 2995.80it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 2855.82it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 2849.74it/s]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


finished preprocessing examples in test




===== Start training: epoch 1 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.9277873039245605, Epoch: 1, training loss: 144.06191903352737, current learning rate 1e-05
val loss: 6.949038028717041
accuracy:      0.500
precision:     0.370
recall:        0.434
f1:            0.325
val loss: 6.934732496738434
accuracy:      0.509
precision:     0.397
recall:        0.500
f1:            0.358
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.934577941894531, Epoch: 2, training loss: 125.113556265831, current learning rate 1e-05
val loss: 6.035383343696594
accuracy:      0.490
precision:     0.401
recall:        0.468
f1:            0.339
val loss: 6.0463507771492
accuracy:      0.513
precision:     0.417
recall:        0.493
f1:            0.354
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.54it/s]


Timing: 7.9562952518463135, Epoch: 3, training loss: 103.24235117435455, current learning rate 1e-05
val loss: 4.95250928401947
accuracy:      0.612
precision:     0.414
recall:        0.514
f1:            0.393
val loss: 4.92655485868454
accuracy:      0.625
precision:     0.422
recall:        0.515
f1:            0.403
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.941531419754028, Epoch: 4, training loss: 81.82595109939575, current learning rate 1e-05
val loss: 6.09727269411087
accuracy:      0.527
precision:     0.415
recall:        0.533
f1:            0.368
val loss: 6.189595580101013
accuracy:      0.535
precision:     0.413
recall:        0.523
f1:            0.369
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.53it/s]


Timing: 7.970653772354126, Epoch: 5, training loss: 64.80517649650574, current learning rate 1e-05
val loss: 5.574227511882782
accuracy:      0.600
precision:     0.429
recall:        0.553
f1:            0.409
val loss: 5.269316554069519
accuracy:      0.620
precision:     0.435
recall:        0.556
f1:            0.422
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.9440202713012695, Epoch: 6, training loss: 51.246189177036285, current learning rate 1e-05
val loss: 5.847450017929077
accuracy:      0.617
precision:     0.438
recall:        0.560
f1:            0.421
val loss: 5.395315945148468
accuracy:      0.669
precision:     0.466
recall:        0.582
f1:            0.463
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.945197343826294, Epoch: 7, training loss: 39.64119607210159, current learning rate 1e-05
val loss: 7.377878785133362
accuracy:      0.573
precision:     0.441
recall:        0.565
f1:            0.407
val loss: 6.581591606140137
accuracy:      0.616
precision:     0.477
recall:        0.575
f1:            0.447
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.943528890609741, Epoch: 8, training loss: 32.08879496157169, current learning rate 1e-05
val loss: 5.176030337810516
accuracy:      0.720
precision:     0.428
recall:        0.510
f1:            0.436
val loss: 4.169579237699509
accuracy:      0.754
precision:     0.482
recall:        0.538
f1:            0.492
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.930354118347168, Epoch: 9, training loss: 28.965826213359833, current learning rate 1e-05
val loss: 4.195457220077515
accuracy:      0.802
precision:     0.465
recall:        0.512
f1:            0.483
val loss: 3.603279799222946
accuracy:      0.827
precision:     0.515
recall:        0.510
f1:            0.511
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.925757169723511, Epoch: 10, training loss: 21.99700754880905, current learning rate 1e-05
val loss: 5.264905154705048
accuracy:      0.746
precision:     0.459
recall:        0.565
f1:            0.482
val loss: 4.485401391983032
accuracy:      0.781
precision:     0.495
recall:        0.570
f1:            0.519
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.54it/s]


Timing: 7.951111793518066, Epoch: 11, training loss: 19.033302821218967, current learning rate 1e-05
val loss: 5.422292947769165
accuracy:      0.798
precision:     0.473
recall:        0.517
f1:            0.491
val loss: 4.476852387189865
accuracy:      0.832
precision:     0.553
recall:        0.561
f1:            0.556
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.91624116897583, Epoch: 12, training loss: 14.914632190018892, current learning rate 1e-05
val loss: 6.2876346707344055
accuracy:      0.749
precision:     0.470
recall:        0.536
f1:            0.480
val loss: 5.2059133648872375
accuracy:      0.783
precision:     0.525
recall:        0.577
f1:            0.530
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.55it/s]


Timing: 7.941681623458862, Epoch: 13, training loss: 14.1896847859025, current learning rate 1e-05
val loss: 8.077635884284973
accuracy:      0.741
precision:     0.447
recall:        0.548
f1:            0.469
val loss: 6.605331718921661
accuracy:      0.749
precision:     0.471
recall:        0.557
f1:            0.493
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.912684917449951, Epoch: 14, training loss: 9.549607599154115, current learning rate 1e-05
val loss: 6.453756093978882
accuracy:      0.783
precision:     0.437
recall:        0.438
f1:            0.434
val loss: 5.4273693561553955
accuracy:      0.825
precision:     0.554
recall:        0.579
f1:            0.563
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.921332359313965, Epoch: 15, training loss: 12.183186128735542, current learning rate 1e-05
val loss: 6.720861971378326
accuracy:      0.771
precision:     0.453
recall:        0.522
f1:            0.475
val loss: 6.15987765789032
accuracy:      0.788
precision:     0.517
recall:        0.608
f1:            0.548
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.9017393589019775, Epoch: 16, training loss: 11.664992850273848, current learning rate 1e-05
val loss: 7.421781241893768
accuracy:      0.763
precision:     0.431
recall:        0.490
f1:            0.448
val loss: 6.772437870502472
accuracy:      0.788
precision:     0.498
recall:        0.579
f1:            0.524
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.9161670207977295, Epoch: 17, training loss: 9.142265952192247, current learning rate 1e-05
val loss: 7.97393262386322
accuracy:      0.763
precision:     0.461
recall:        0.527
f1:            0.481
val loss: 6.834494352340698
accuracy:      0.810
precision:     0.561
recall:        0.595
f1:            0.571
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.9053404331207275, Epoch: 18, training loss: 8.39993215445429, current learning rate 1e-05
val loss: 8.582420349121094
accuracy:      0.785
precision:     0.475
recall:        0.506
f1:            0.483
val loss: 7.348546206951141
accuracy:      0.800
precision:     0.510
recall:        0.542
f1:            0.520
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.894291639328003, Epoch: 19, training loss: 11.948474438861012, current learning rate 1e-05
val loss: 8.065061628818512
accuracy:      0.817
precision:     0.441
recall:        0.473
f1:            0.452
val loss: 6.7119638323783875
accuracy:      0.830
precision:     0.533
recall:        0.540
f1:            0.526
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.901568174362183, Epoch: 20, training loss: 9.35629534558393, current learning rate 1e-05
val loss: 8.91020393371582
accuracy:      0.795
precision:     0.428
recall:        0.428
f1:            0.427
val loss: 7.328109323978424
accuracy:      0.825
precision:     0.547
recall:        0.566
f1:            0.555
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.908454656600952, Epoch: 21, training loss: 7.744744511786848, current learning rate 1e-05
val loss: 9.560671508312225
accuracy:      0.785
precision:     0.507
recall:        0.476
f1:            0.482
val loss: 8.681246221065521
accuracy:      0.815
precision:     0.566
recall:        0.590
f1:            0.569
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.905789136886597, Epoch: 22, training loss: 9.17769454116933, current learning rate 1e-05
val loss: 7.902364820241928
accuracy:      0.824
precision:     0.476
recall:        0.476
f1:            0.476
val loss: 6.836038887500763
accuracy:      0.825
precision:     0.534
recall:        0.552
f1:            0.542
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.893726587295532, Epoch: 23, training loss: 12.29180621379055, current learning rate 1e-05
val loss: 7.459656655788422
accuracy:      0.793
precision:     0.435
recall:        0.449
f1:            0.440
val loss: 7.042297720909119
accuracy:      0.815
precision:     0.536
recall:        0.569
f1:            0.551
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.893242597579956, Epoch: 24, training loss: 10.74173534833244, current learning rate 1e-05
val loss: 9.287586688995361
accuracy:      0.817
precision:     0.518
recall:        0.503
f1:            0.506
val loss: 7.995272994041443
accuracy:      0.825
precision:     0.559
recall:        0.580
f1:            0.568
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.900292158126831, Epoch: 25, training loss: 11.772826529107988, current learning rate 1e-05
val loss: 8.362858712673187
accuracy:      0.805
precision:     0.515
recall:        0.520
f1:            0.515
val loss: 8.034910202026367
accuracy:      0.813
precision:     0.528
recall:        0.568
f1:            0.545
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.59it/s]


Timing: 7.8963165283203125, Epoch: 26, training loss: 6.679236807860434, current learning rate 1e-05
val loss: 9.853505790233612
accuracy:      0.790
precision:     0.475
recall:        0.471
f1:            0.463
val loss: 8.781692028045654
accuracy:      0.800
precision:     0.552
recall:        0.584
f1:            0.561
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.56it/s]


Timing: 7.9327311515808105, Epoch: 27, training loss: 7.255691765341908, current learning rate 1e-05
val loss: 10.737398266792297
accuracy:      0.756
precision:     0.456
recall:        0.472
f1:            0.448
val loss: 10.29874873161316
accuracy:      0.749
precision:     0.489
recall:        0.529
f1:            0.491
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.921704053878784, Epoch: 28, training loss: 3.9631901604589075, current learning rate 1e-05
val loss: 9.742090940475464
accuracy:      0.783
precision:     0.470
recall:        0.490
f1:            0.477
val loss: 8.755429148674011
accuracy:      0.798
precision:     0.515
recall:        0.569
f1:            0.536
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.57it/s]


Timing: 7.917595863342285, Epoch: 29, training loss: 11.63121859554667, current learning rate 1e-05
val loss: 9.743416547775269
accuracy:      0.805
precision:     0.491
recall:        0.513
f1:            0.500
val loss: 8.709289371967316
accuracy:      0.818
precision:     0.530
recall:        0.549
f1:            0.538
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 52/52 [00:07<00:00,  6.58it/s]


Timing: 7.9084014892578125, Epoch: 30, training loss: 3.9298977106809616, current learning rate 1e-05
val loss: 9.08470892906189
accuracy:      0.800
precision:     0.515
recall:        0.482
f1:            0.490
val loss: 7.779045939445496
accuracy:      0.825
precision:     0.552
recall:        0.551
f1:            0.546
best result:
0.8126520681265207
0.5282402269083734
0.5680925477535648
0.5453184966608294
[[0.8126520681265207, 0.5282402269083734, 0.5680925477535648, 0.5453184966608294]]
tensor([0.7762, 0.5119, 0.5558, 0.5139], dtype=torch.float64)


In [11]:
# Final cell: hand the Colab runtime back once the cells above have finished,
# so an unattended run does not keep the accelerator allocated.
# `google.colab` is only importable inside a Colab runtime.
from google.colab import runtime
# NOTE(review): unassign() is expected to disconnect and release the current
# runtime; anything not already persisted (e.g. to the mounted Drive) would be
# lost at this point — confirm that all results are saved before this cell runs.
runtime.unassign()