<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

device = torch.device("cpu")

if torch.cuda.is_available():
   print("Training on GPU")
   device = torch.device("cuda:0")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """wrap in PyTorch Dataset"""
    def __init__(self, examples):
        super(dataset, self).__init__()
        self.examples = examples

    def __getitem__(self, idx):
        return self.examples[idx]

    def __len__(self):
        return len(self.examples)


def collate_fn(examples):
    ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    ids_sent2 = torch.tensor(ids_sent2, dtype=torch.long)
    segs_sent2 = torch.tensor(segs_sent2, dtype=torch.long)
    att_mask_sent2 = torch.tensor(att_mask_sent2, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, labels

def collate_fn_concatenated(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

def collate_fn_concatenated_adv(examples):
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = map(list, zip(*examples))

    ids_sent1 = torch.tensor(ids_sent1, dtype=torch.long)
    segs_sent1 = torch.tensor(segs_sent1, dtype=torch.long)
    att_mask_sent1 = torch.tensor(att_mask_sent1, dtype=torch.long)
    position_sep = torch.tensor(position_sep, dtype=torch.long)
    #labels = torch.tensor(labels, dtype=torch.long)

    return ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:

  def __init__(self,):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    self.max_sent_len = 150

  def __str__(self,):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _get_examples(self, dataset, dataset_type="train"):
    examples = []

    for row in dataset:
      id, sentence1, sentence2, _, _, _, label = row

      """
      for the first sentence
      """

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = [0] * len(ids_sent1)
      segs_sent1[1:-1] = [1] * (len(ids_sent1)-2)

      """
      for the second sentence
      """

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = [0] * len(ids_sent2)
      segs_sent2[1:-1] = [1] * (len(ids_sent2)-2)

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len

      if len(ids_sent2) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent2)
        att_mask_sent2 = [1] * len(ids_sent2) + [0] * res
        ids_sent2 += [pad_id] * res
        segs_sent2 += [0] * res
      else:
        ids_sent2 = ids_sent2[:self.max_sent_len]
        segs_sent2 = segs_sent2[:self.max_sent_len]
        att_mask_sent2 = [1] * self.max_sent_len

      example = [ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    examples = []

    for row in tqdm(dataset, desc="tokenizing..."):
      id, sentence1, sentence2, _, _, _, label = row

      """
      for the first sentence
      """

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))
      #sentence1 += " </s> "+sentence2

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      segs_sent1 = [0] * sentence1_length + [1] * (sentence2_length)
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

      if len(ids_sent1) < self.max_sent_len:
        res = self.max_sent_len - len(ids_sent1)
        att_mask_sent1 = [1] * len(ids_sent1) + [0] * res
        ids_sent1 += [pad_id] * res
        segs_sent1 += [0] * res
        position_sep += [0] * res
      else:
        ids_sent1 = ids_sent1[:self.max_sent_len]
        segs_sent1 = segs_sent1[:self.max_sent_len]
        att_mask_sent1 = [1] * self.max_sent_len
        position_sep = position_sep[:self.max_sent_len]

      example = [ids_sent1, segs_sent1, att_mask_sent1, position_sep, label]

      examples.append(example)

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    #https://pdf.sciencedirectassets.com/271806/1-s2.0-S0378216600X00549/1-s2.0-S0378216698001015/main.pdf?X-Amz-Security-Token=IQoJb3JpZ2luX2VjEGgaCXVzLWVhc3QtMSJIMEYCIQCiYMlVmna%2BTaXH5hqdwfhEBWd2VPRNoAHlQLGxzvNEqAIhAO3TVTA51qn13kKQp2bTlzGkaKnf6NhMYtr7laU%2Byy0vKrwFCMH%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEQBRoMMDU5MDAzNTQ2ODY1Igzyz%2F2NAMoW0RbAZ%2BMqkAWG017si0y%2FOokz5T44gGpNBL07jup8MAQjv8iwoi4XGALwCP0nf%2FgHD1ZE%2B%2BQGuaLPuShgLg7Y3%2Fcsv2VjkbfrNBSdZPYhqpzpAClSmP2Zs0DszX0zXdmnx4uFyls6d9jCG4TQkhqTsNCGsnKjU89G7z9NMutpaWqEGcUWT6MVMXpxILGQfeu5zLM0ILcft20VXs2dnMMIjWXA5jd0pG8HnAXdils2AmfgTqt%2B9cHn5BXhv%2FaSXX9a7lwR7EbIoUqZVLo%2BDJR2JLtaLYdoZR01FI3FhNAk7Hx1ZLd3RSWWQrRy3ovGKbKnTYC8Jn%2Bs1w1tkF4OJzCy7EZg578HFrPsvxQrUGwtkXfY1BIralzc9JmYZ%2FS1VPIVSvZSM6E3sUUIND14uQDKhQyTQh6WBbG1djkU8M9bW%2ByVDRj8CKEoWdN4ofK3WuRD87QQEAJQ8jwnl0rCtVIYecZyfQzTnpdO0jafDlritW%2BlfSDqyd8ob%2F%2BkljgtN1m8IFKNQ9lopVjvwCzDa5R%2F0WvchF%2BqNMzImVtUHTgXgJOcGC6y9OSVqRGFgQtPhy6W26WodWQxaFsBMTn49dM6rzsyNhd301U4SYL5vTLrLhjmm3%2Ft5JqKHS7JaAbmKYa4DvabWH4Qs2WHsZMxVd8L3KU%2FIeyaQwATOf3TZVCVPWUriUg%2FAKFcuceC1AaUE5MKWB8Qe2Cb5%2FpagPPYTztfNluPar21xLpY7cayKABv%2FkyIa2N9MsaPm8VEvSb90Sl1EkJAxXP3kVU2XTtZqcYPuHgdSyUwh%2FDC%2F0Y1FlLyZW%2BLrnVmL9sqtORiZcZU20jVgXM8HoLIG2vvo0er4qyok9ZxzykuzhClN6ZULz%2FTja1y%2FdhF2UR89jCk%2BuOxBjqwARYSDjyJE7HksxP39FMsgAM0RH2Us2vj22eV6lkbG1n1%2BZm%2B4a4UeUfzibr4B6BdF%2BB3i%2FHsJ3QF1AnxdSS%2F0x5HnBmGCct1etAdyP60bbBH8p1dCgNQL7kb%2BqKINd78nYfrM0D0a4U%2Fxm2FUNln3swIdVpXtLtz0qY2QSaHbc6Ir6BCR8Kqm0FKQyhv1JSMkKfIdFQ9pYCVy8VAr%2BLBA9uSXfFDz6N67ruhh%2BzJWFn1&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20240506T173714Z&X-Amz-SignedHeaders=host&X-Amz-Expires=299&X-Amz-Credential=ASIAQ3PHCVTYUU54IYV3%2F20240506%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=7303563def203481e9802185fa8eab7dd21d58c841cd92621bbbdb18e253f595&hash=62697dff1cc869850b2d38305ff2ad1a35ad33938dbb92466663bdb1bf67d069&host=68042c943591013ac2b2430a89b270f6af2c76d8dfd086a07176afe7c76c2c61&pii=S0378216698001015&tid=spdf-c2371d97-52db-456d-a333-31d65dde09f3&sid=c8e761cb3c79f3499a2a90d699ea19871d47gxrqb&type=client&tsoh=d3d3LnNjaWVuY2VkaXJlY3QuY29t&ua=091359520307070d01&rr=87fabcafde0b0e1b&cc=it
    # TODO: refactor
    self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }

    """self.mapping = elements_dict = {
      "accordingly": 0,
      "also": 0,
      "although": 1,
      "and": 0,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 1,
      "by_comparison": 1,
      "by_contrast": 1,
      "consequently": 0,
      "conversely": 0,
      "especially": 0,
      "further": 0,
      "furthermore": 0,
      "hence": 0,
      "however": 1,
      "in_contrast": 1,
      "instead": 1,
      "likewise": 0,
      "moreover": 0,
      "namely": 0,
      "nevertheless": 1,
      "nonetheless": 1,
      "on_the_contrary": 1,
      "on_the_other_hand": 1,
      "otherwise": 0,
      "rather": 1,
      "similarly": 0,
      "so": 0,
      "still": 1,
      "then": 0,
      "therefore": 0,
      "though": 1,
      "thus": 0,
      "well": 0,
      "yet": 1
    }"""

    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    result = []
    new_dataset = []

    for sample in dataset:
      if self.id_to_word[sample["label"]] not in self.mapping.keys():
        continue

      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[self.id_to_word[sample["label"]]]])

    one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    labels = []

    for i, sample in tqdm(enumerate(new_dataset), desc="processing labels..."):
      labels.append([sample[-1]])

    print("one hot encoding...")
    labels = one_hot_encoder.fit_transform(labels)

    for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results..."):
      result.append([f"{name}_{i}", sample[0], sample[1], [], [], [], label])

    examples = self._get_examples_concatenated(result, name)
    return examples


class StudentEssayProcessor(DataProcessor):

  def __init__(self,):
    super(StudentEssayProcessor,self).__init__()

  def padding(self, input, maxlen):
      """
      Padding the input sequence
      """

      id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution = zip(*input)

      sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
      knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
      target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

      return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
      """
      Groups together sentences into batches
      If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
      If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
      Returns a list of lists with sentences ids.
      """
      batches_of_sentence_ids = []
      if batch_equal_size == True:
          sentence_ids_by_length = collections.OrderedDict()
          sentence_length_sum = 0.0
          for i in range(len(sentences)):
              length = len(sentences[i])
              if length not in sentence_ids_by_length:
                  sentence_ids_by_length[length] = []
              sentence_ids_by_length[length].append(i)

          for sentence_length in sentence_ids_by_length:
              if max_batch_size > 0:
                  batch_size = max_batch_size
              else:
                  batch_size = int((-1 * max_batch_size) / sentence_length)

              for i in range(0, len(sentence_ids_by_length[sentence_length]), batch_size):
                  batches_of_sentence_ids.append(sentence_ids_by_length[sentence_length][i:i + batch_size])
      else:
          current_batch = []
          max_sentence_length = 0
          for i in range(len(sentences)):
              current_batch.append(i)
              if len(sentences[i]) > max_sentence_length:
                  max_sentence_length = len(sentences[i])
              if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
                  batches_of_sentence_ids.append(current_batch)
                  current_batch = []
                  max_sentence_length = 0
          if len(current_batch) > 0:
              batches_of_sentence_ids.append(current_batch)
      return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
        for line in f:
              know =[]
              #print(line)
              count +=1
              #print(count)
              line = line.replace("\n","")
              line = line.split("\t")

              if line == ['\r']:
                      continue
              count +=1
              story_id = line[0]
              sent = line[1].strip()
              target = line[3].strip()
              #print(target)
              label = line[-1].strip()
              facts = line[-2].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                    l=[1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                    l=[0,1]
              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class DebateProcessor(DataProcessor):

  def __init__(self,):
    super(DebateProcessor,self).__init__()

  def padding(self, input, maxlen):
      """
      Padding the input sequence.....
      """

      id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution = zip(*input)

      sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
      knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
      target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

      return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
      """
      Groups together sentences into batches
      If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
      If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
      Returns a list of lists with sentences ids.
      """
      batches_of_sentence_ids = []
      if batch_equal_size == True:
          sentence_ids_by_length = collections.OrderedDict()
          sentence_length_sum = 0.0
          for i in range(len(sentences)):
              length = len(sentences[i])
              if length not in sentence_ids_by_length:
                  sentence_ids_by_length[length] = []
              sentence_ids_by_length[length].append(i)

          for sentence_length in sentence_ids_by_length:
              if max_batch_size > 0:
                  batch_size = max_batch_size
              else:
                  batch_size = int((-1 * max_batch_size) / sentence_length)

              for i in range(0, len(sentence_ids_by_length[sentence_length]), batch_size):
                  batches_of_sentence_ids.append(sentence_ids_by_length[sentence_length][i:i + batch_size])
      else:
          current_batch = []
          max_sentence_length = 0
          for i in range(len(sentences)):
              current_batch.append(i)
              if len(sentences[i]) > max_sentence_length:
                  max_sentence_length = len(sentences[i])
              if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
                  batches_of_sentence_ids.append(current_batch)
                  current_batch = []
                  max_sentence_length = 0
          if len(current_batch) > 0:
              batches_of_sentence_ids.append(current_batch)
      return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """
      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
        for line in f:
              know =[]
              #print(line)
              count +=1
              #print(count)
              line = line.replace("\n","")
              line = line.split("\t")

              if line == ['\r']:
                      continue
              count +=1
              story_id = line[0]
              sent = line[1].strip()
              target = line[3].strip()
              #print(target)
              label = line[-1].strip()
              facts = line[-2].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                    l=[1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                    l=[0,1]
              label_distribution.append(l)
              #print(label_distribution)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGProcessor(DataProcessor):

  def __init__(self):
    super(MARGProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue
              know =[]
              #print(line)
              count +=1
              #print(count)

              count +=1
              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              #print(target)
              label = row[3].strip()
              facts = row[-3].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,0,1]
              elif label == 'neither':
                l = [0,1,0]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class NKProcessor(DataProcessor):

  def __init__(self):
    super(NKProcessor, self).__init__()

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path, sep="\t")
      for i,row in df.iterrows():
              id_sample = row[0]
              label = row[2]

              sent = row[3].strip()
              target = row[4].strip()

              sentences.append(sent)
              target_sentences.append(target)
              id.append(id_sample)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,0,1]
              elif label == 'no_relation':
                l = [0,1,0]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], [], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
      """
      Padding the input sequence
      """

      id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution = zip(*input)

      sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
      knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
      target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

      return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
      """
      Groups together sentences into batches
      If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
      If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
      Returns a list of lists with sentences ids.
      """
      batches_of_sentence_ids = []
      if batch_equal_size == True:
          sentence_ids_by_length = collections.OrderedDict()
          sentence_length_sum = 0.0
          for i in range(len(sentences)):
              length = len(sentences[i])
              if length not in sentence_ids_by_length:
                  sentence_ids_by_length[length] = []
              sentence_ids_by_length[length].append(i)

          for sentence_length in sentence_ids_by_length:
              if max_batch_size > 0:
                  batch_size = max_batch_size
              else:
                  batch_size = int((-1 * max_batch_size) / sentence_length)

              for i in range(0, len(sentence_ids_by_length[sentence_length]), batch_size):
                  batches_of_sentence_ids.append(sentence_ids_by_length[sentence_length][i:i + batch_size])
      else:
          current_batch = []
          max_sentence_length = 0
          for i in range(len(sentences)):
              current_batch.append(i)
              if len(sentences[i]) > max_sentence_length:
                  max_sentence_length = len(sentences[i])
              if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
                  batches_of_sentence_ids.append(current_batch)
                  current_batch = []
                  max_sentence_length = 0
          if len(current_batch) > 0:
              batches_of_sentence_ids.append(current_batch)
      return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
        for line in f:
              know =[]
              #print(line)
              count +=1
              #print(count)
              line = line.replace("\n","")
              line = line.split("\t")

              if line == ['\r']:
                      continue
              count +=1
              story_id = line[0]
              sent = line[1].strip()
              target = line[3].strip()
              ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
              ds_marker = ds_marker.replace("_", " ")
              ds_marker = ds_marker[0].upper() + ds_marker[1:]
              target = target[0].lower() + target[1:]
              target = ds_marker + " " + target

              #print(target)
              label = line[-1].strip()
              facts = line[-2].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                    l=[1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                    l=[0,1]
              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class DebateWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self,):
    super(DebateWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
      """
      Padding the input sequence.....
      """

      id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution = zip(*input)

      sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
      knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
      target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

      return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
      """
      Groups together sentences into batches
      If max_batch_size is positive, this value determines the maximum number of sentences in each batch.
      If max_batch_size has a negative value, the function dynamically creates the batches such that each batch contains abs(max_batch_size) words.
      Returns a list of lists with sentences ids.
      """
      batches_of_sentence_ids = []
      if batch_equal_size == True:
          sentence_ids_by_length = collections.OrderedDict()
          sentence_length_sum = 0.0
          for i in range(len(sentences)):
              length = len(sentences[i])
              if length not in sentence_ids_by_length:
                  sentence_ids_by_length[length] = []
              sentence_ids_by_length[length].append(i)

          for sentence_length in sentence_ids_by_length:
              if max_batch_size > 0:
                  batch_size = max_batch_size
              else:
                  batch_size = int((-1 * max_batch_size) / sentence_length)

              for i in range(0, len(sentence_ids_by_length[sentence_length]), batch_size):
                  batches_of_sentence_ids.append(sentence_ids_by_length[sentence_length][i:i + batch_size])
      else:
          current_batch = []
          max_sentence_length = 0
          for i in range(len(sentences)):
              current_batch.append(i)
              if len(sentences[i]) > max_sentence_length:
                  max_sentence_length = len(sentences[i])
              if (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                or (max_batch_size <= 0 and len(current_batch)*max_sentence_length >= (-1 * max_batch_size)):
                  batches_of_sentence_ids.append(current_batch)
                  current_batch = []
                  max_sentence_length = 0
          if len(current_batch) > 0:
              batches_of_sentence_ids.append(current_batch)
      return batches_of_sentence_ids


  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """
      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
        for line in f:
              know =[]
              #print(line)
              count +=1
              #print(count)
              line = line.replace("\n","")
              line = line.split("\t")

              if line == ['\r']:
                      continue
              count +=1
              story_id = line[0]
              sent = line[1].strip()
              target = line[3].strip()
              ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
              ds_marker = ds_marker.replace("_", " ")
              ds_marker = ds_marker[0].upper() + ds_marker[1:]
              target = target[0].lower() + target[1:]
              target = ds_marker + " " + target
              #print(target)
              label = line[-1].strip()
              facts = line[-2].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                    l=[1,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                    l=[0,1]
              label_distribution.append(l)
              #print(label_distribution)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples


class MARGWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self,):
    super(MARGWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      # Code copied from https://aclanthology.org/2023.eacl-main.182.pdf
      # TODO: refactor, several objects are not needed

      sentences = []
      labels = []
      label_distribution=[]
      target = []
      knowledge = []
      story_id_know=[]
      lst2=[]
      target_sentences = []
      source_senti = []
      target_senti = []
      id=[]
      count = 0

      df = pd.read_csv(file_path)
      for i,row in df.iterrows():
              if row[-1] != name:
                continue
              know =[]
              #print(line)
              count +=1
              #print(count)

              count +=1
              story_id = row[0]
              sent = row[1].strip()
              target = row[2].strip()

              ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
              ds_marker = ds_marker.replace("_", " ")
              ds_marker = ds_marker[0].upper() + ds_marker[1:]
              target = target[0].lower() + target[1:]
              target = ds_marker + " " + target

              #print(target)
              label = row[3].strip()
              facts = row[-3].strip()

              facts = facts.replace("_", "").replace("[", "").replace("]", "").replace("(", "").replace("(", "")

              lst2.append(facts)

              sentences.append(sent)
              target_sentences.append(target)
              id.append(story_id)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,0,1]
              elif label == 'neither':
                l = [0,1,0]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], lst2[i], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

class NKWithDiscourseInjectionProcessor(DataProcessor):

  def __init__(self,):
    super(NKWithDiscourseInjectionProcessor,self).__init__()
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
      """
      Reads input files in tab-separated format.
      Will split file_paths on comma, reading from multiple files.
      """

      sentences = []
      label_distribution=[]
      target = []
      target_sentences = []
      id=[]

      df = pd.read_csv(file_path, sep="\t")
      for i,row in df.iterrows():
              id_sample = row[0]
              label = row[2]

              sent = row[3].strip()
              target = row[4].strip()

              try:
                ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
                ds_marker = ds_marker.replace("_", " ")
                ds_marker = ds_marker[0].upper() + ds_marker[1:]
                target = target[0].lower() + target[1:]
                target = ds_marker + " " + target
              except:
                pass

              sentences.append(sent)
              target_sentences.append(target)
              id.append(id_sample)

              l=[0,0,0]
              if label == 'supports' or label == 'support' or label == 'because':
                l = [1,0,0]
              elif label == 'attacks' or label == 'attack' or label == 'but':
                l = [0,0,1]
              elif label == 'no_relation':
                l = [0,1,0]

              label_distribution.append(l)

      result = []
      for i in range(len(label_distribution)):
        result.append([id[i],sentences[i],target_sentences[i], [], [], [], label_distribution[i]])

      examples = self._get_examples_concatenated(result, name)

      return examples

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return ctx.lmbd * grad_input.neg(), None

class DoubleAdversarialNet(torch.nn.Module):
  def __init__(self):
    super(DoubleAdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)
    self._init_weights(self.attack_linear)
    self._init_weights(self.support_linear)


  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      embed_sent1_std = embed_sent1[:batch_size // 2, :, :]
      labels_std = torch.tensor(labels[:batch_size // 2]).to(device)

      if self.num_classes == 2:
        emb_attack = embed_sent1_std[(labels_std == torch.tensor([0,1]).to(device)).all(dim=1)]
        emb_support = embed_sent1_std[(labels_std == torch.tensor([1,0]).to(device)).all(dim=1)]
      else:
        emb_attack = embed_sent1_std[(labels_std == torch.tensor([0,1,0]).to(device)).all(dim=1)]
        emb_support = embed_sent1_std[(labels_std == torch.tensor([1,0,0]).to(device)).all(dim=1) | (labels_std == torch.tensor([0,0,1]).to(device)).all(dim=1)]

      samples_adv = H_sent[batch_size // 2:, :]
      embed_sent1_adv = embed_sent1[batch_size // 2:, :, :]
      labels_adv = torch.tensor(labels[batch_size // 2:]).to(device)

      emb_contr = embed_sent1_adv[(labels_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)]
      emb_other = embed_sent1_adv[(labels_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (labels_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      mean_grl_attack = GRLayer.apply(torch.mean(torch.cat([emb_attack, emb_contr], dim=0), dim=1), .01)
      mean_grl_support = GRLayer.apply(torch.mean(torch.cat([emb_support, emb_other], dim=0), dim=1), .01)

      task_prediction = self.task_linear(mean_grl)
      attack_prediction = self.attack_linear(mean_grl_attack)
      support_prediction = self.support_linear(mean_grl_support)

      return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class AdversarialNet(torch.nn.Module):
  def __init__(self):
    super(AdversarialNet, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.num_classes_adv = args["num_classes_adv"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
    self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
    self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=2)

    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.linear_layer_adv)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    self._init_weights(self.task_linear)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (segs_sent1 == 0).long()
    tar_mask_sent2 = (segs_sent1 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

    K_sent1 = self.K(H_sent1)
    V_sent1 = self.V(H_sent1)
    Q_sent2 = self.Q(H_sent2)

    att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

    H_sent = torch.mean(att_output[0], dim=1)

    if visualize:
      return H_sent
    if self.training:
      batch_size = H_sent.shape[0]
      samples = H_sent[:batch_size // 2, :]
      samples_adv = H_sent[batch_size // 2:, ]

      predictions = self.linear_layer(samples)
      predictions_adv = self.linear_layer_adv(samples_adv)

      mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
      task_prediction = self.task_linear(mean_grl)

      return predictions, predictions_adv, task_prediction
    else:
      predictions = self.linear_layer(H_sent)

      return predictions

class BaselineModelWithSentenceComparisonAndCue(torch.nn.Module):
  def __init__(self, attention):
    super(BaselineModelWithSentenceComparisonAndCue, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.attention = attention
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*3, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    """self.linear_initial_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.linear_end_sent = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)"""
    self.sigmoid = torch.nn.Sigmoid()

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)
    """self._init_weights(self.linear_initial_sent)
    self._init_weights(self.linear_end_sent)"""

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)[0]
    else:
      att_output = embed_sent1

    initial_sent1 = att_output[:,0,:]
    initial_sent2 = att_output[torch.arange(att_output.shape[0]), torch.argmax(segs_sent1 * torch.arange(att_output.shape[1], 0, -1).to(device), dim=-1)]
    end_sent2 = att_output[torch.arange(att_output.shape[0]), torch.sum(att_mask_sent1, dim=-1)-1]

    initial_sent1 = self.linear_initial_sent(initial_sent1)
    end_sent2 = self.linear_end_sent(end_sent2)
    gate = self.sigmoid(initial_sent1 + end_sent2)

    final_emb = initial_sent2 * gate

    predictions = self.linear_layer(final_emb)

    return predictions


class BaselineModelWithSentenceComparison(torch.nn.Module):
  def __init__(self, attention):
    super(BaselineModelWithSentenceComparison, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]
    self.attention = attention

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])
    self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
    self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
    self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

    self._init_weights(self.linear_layer)
    self._init_weights(self.Q)
    self._init_weights(self.K)
    self._init_weights(self.V)
    self._init_weights(self.multi_head_att)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    if self.attention:
      tar_mask_sent1 = (segs_sent1 == 0).long()
      tar_mask_sent2 = (segs_sent1 == 1).long()

      H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
      H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

      K_sent1 = self.K(H_sent1)
      V_sent1 = self.V(H_sent1)
      Q_sent2 = self.Q(H_sent2)

      att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

      H_sent = torch.mean(att_output[0], dim=1)
    else:
      H_sent = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_sent)

    return predictions


class BaselineModel(torch.nn.Module):
  def __init__(self):
    super(BaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
    else:
      embed_sent1 = last_sent1

    tar_mask_sent1 = (position_sep == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)

    H_mean1 = torch.mean(embed_sent1, dim=1)

    predictions = self.linear_layer(H_mean1)

    return predictions


class SiameseBaselineModel(torch.nn.Module):
  def __init__(self):
    super(SiameseBaselineModel, self).__init__()

    self.plm = AutoModel.from_pretrained(args["model_name"])
    config = self.plm.config
    config.type_vocab_size = 4
    self.plm.embeddings.token_type_embeddings = nn.Embedding(
      config.type_vocab_size, config.hidden_size
    )
    self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

    self.num_classes = args["num_classes"]
    self.embed_size = args["embed_size"]

    self.first_last_avg = args["first_last_avg"]

    for param in self.plm.parameters():
      param.requires_grad = True

    self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

    self._init_weights(self.linear_layer)

  def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  @torch.autocast(device_type="cuda")
  def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
    out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
    out_sent2 = self.plm(ids_sent2, token_type_ids=segs_sent2, attention_mask=att_mask_sent2, output_hidden_states=True)

    last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]
    last_sent2, first_sent2 = out_sent2.hidden_states[-1], out_sent2.hidden_states[1]

    if self.first_last_avg:
      embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
      embed_sent2 = torch.div((last_sent2 + first_sent2), 2)
    else:
      embed_sent1 = last_sent1
      embed_sent2 = last_sent2

    tar_mask_sent1 = (segs_sent1 == 1).long()
    tar_mask_sent2 = (segs_sent2 == 1).long()

    H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
    H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent2)

    H_mean1 = torch.mean(embed_sent1, dim=1)
    H_mean2 = torch.mean(embed_sent2, dim=1)

    H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

    predictions = self.linear_layer(H_cat)

    return predictions

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """
    set random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """

    :param labels: ground truth labels
    :param preds: prediction labels
    :return: accuracy, precision, recall, f1
    """
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="macro")
    recall = recall_score(labels, preds, average="macro")
    f1 = f1_score(labels, preds, average="macro")

    print("{:15}{:<.3f}".format('accuracy:', accuracy))
    print("{:15}{:<.3f}".format('precision:', precision))
    print("{:15}{:<.3f}".format('recall:', recall))
    print("{:15}{:<.3f}".format('f1:', f1))

    return accuracy, precision, recall, f1

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def __iter__(self):
        self.epoch += 1
        batch = []
        indices1 = self.indices1.copy()
        indices2 = self.indices2.copy()

        indices1 = torch.randperm(len(self.dataset1)).cpu().tolist()
        indices2 = torch.randperm(len(self.dataset2)) +len(indices1)
        indices2 = indices2.cpu().tolist()

        for i in range(self.total_size // self.batch_size):
            batch_size1 = min(self.batch_size // 2, len(indices1))
            if batch_size1 < (self.batch_size // 2):
              break
            batch_size2 = self.batch_size - batch_size1
            batch.extend([indices1.pop() for _ in range(batch_size1)])
            batch.extend([indices2.pop() for _ in range(batch_size2)])

            yield batch
            batch = []
            if len(indices1) == 0:
              break

    def __len__(self):
        return (self.total_size + self.batch_size - 1) // self.batch_size

In [10]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm

args = {
    "model_name": "roberta-base",
    "num_classes": 3, #3, #2,
    "num_classes_adv": 3, #174,
    "embed_size": 768,
    "first_last_avg": False,
    "seed": [0,1,2],
    "batch_size": 64,
    "epochs": 30,
    "class_weight": [9.375, 1, 30], #[9.375, 1, 30], #10 [2.071, 1, 1.933]
    "lr": 1e-5
}

config = {
    "dataset": "m-arg", #"student_essay", #debate, m-arg
    "adversarial": False,
    "double_adversarial": False,
    "dataset_from_saved": False,
    "injection": False,
    "finetuning_discovery": True,
    "grid_search": False,
    "visualize": False,
    "train": True,
    "scheduler": False,
    "attention": False,
    "cue_gating": False
}

def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3, rl=0):
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for step, batch in enumerate(tqdm(train_loader, desc='Iteration')):
        if config["adversarial"]:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        else:
          batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch) #tuple(t.to(device) for t in batch)

        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
          pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
          pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
          try:
            half_batch_size = len(labels) // 2
            targets, targets_adv, targets_task = labels[:half_batch_size], labels[half_batch_size:], [[0, 1]] * half_batch_size + [[1, 0]] * half_batch_size
            targets, targets_adv, targets_task = torch.tensor(np.array(targets)).to(device), \
                                                 torch.tensor(np.array(targets_adv)).to(device), \
                                                 torch.tensor(np.array(targets_task)).to(device)

            attack_len = torch.sum((targets == torch.tensor([0,1]).to(device)).all(dim=1)).item()
            support_len = torch.sum((targets == torch.tensor([1,0]).to(device)).all(dim=1)).item()
            contr_len = torch.sum((targets_adv == torch.tensor([0,0,1]).to(device)).all(dim=1)).item()
            other_len = torch.sum((targets_adv == torch.tensor([0,1,0]).to(device)).all(dim=1) | (targets_adv == torch.tensor([1,0,0]).to(device)).all(dim=1)).item()

            attack_target = [[0,1]] * attack_len + [[1,0]] * contr_len
            support_target = [[0,1]] * support_len + [[1,0]] * other_len
            attack_target, support_target = torch.tensor(np.array(attack_target)).to(device), torch.tensor(np.array(support_target)).to(device)
          except:
            print("error")

          loss1 = loss_fn(pred, targets.float())
          loss2 = loss_fn2(pred_adv, targets_adv.float())
          loss3 = loss_fn2(task_pred, targets_task.float())
          loss4 = loss_fn2(attack_pred, attack_target.float())
          loss5 = loss_fn2(support_pred, support_target.float())
          loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + rl*loss4 + rl*loss5
        else:
          out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          if isinstance(labels, list):
            labels = torch.tensor(np.array(labels)).to(device)
          loss = loss_fn(out, labels.float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        with torch.no_grad():
          if config["double_adversarial"]:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
          else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
          preds = torch.max(out.data, 1)[1].cpu().numpy().tolist()
          loss = loss_fn(out, labels.float())
          val_loss += loss.item()

          labels = labels.cpu().numpy().tolist()

          val_labels.extend(labels)
          if len(labels[0]) != 2:
            for pred in preds:
              if pred == 0:
                val_preds.append([1,0,0])
              elif pred == 1:
                val_preds.append([0,1,0])
              else:
                val_preds.append([0,0,1])
          else:
            val_preds.extend([[1,0] if pred == 0 else [0,1] for pred in preds])

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1

def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2):
  if not config["adversarial"]:
    return

  model.eval()

  loss_fn = nn.CrossEntropyLoss()

  tot_labels = None
  embeddings = None

  tot_labels_adv = None
  embeddings_adv = None

  print("Visualizing...")
  for batch in tqdm(test_dataloader):
    batch = tuple(t.to(device) for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.argmax(labels, dim=-1)
    if tot_labels is None:
      tot_labels = labels
    else:
      tot_labels = torch.cat([tot_labels, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings is None:
        embeddings = out
      else:
        embeddings = torch.cat([embeddings, out], dim=0)

  for i, batch in tqdm(enumerate(train_adv_dataloader)):
    if i == 20: break
    batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
    ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
    labels = torch.tensor(np.array(labels)).to(device)
    labels = torch.argmax(labels, dim=-1)+2

    if tot_labels_adv is None:
      tot_labels_adv = labels
    else:
      tot_labels_adv = torch.cat([tot_labels_adv, labels], dim=0)

    with torch.no_grad():
      out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True)
      if embeddings_adv is None:
        embeddings_adv = out
      else:
        embeddings_adv = torch.cat([embeddings_adv, out], dim=0)

  tsne = TSNE(random_state=0)
  tsne_results = tsne.fit_transform(embeddings.cpu().numpy())
  tsne_results_adv = tsne.fit_transform(embeddings_adv.cpu().numpy())

  df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
  df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

  df_tsne["label"] = tot_labels.cpu().numpy()
  df_tsne_adv["label"] = tot_labels_adv.cpu().numpy()

  print(df_tsne_adv["label"].unique())

  fig1, ax1 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax1, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()

  fig2, ax2 = plt.subplots(figsize=(8,6))
  sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
  sns.scatterplot(data=df_tsne_adv, x='x', y='y', hue='label', palette='deep')
  sns.move_legend(ax2, "upper left", bbox_to_anchor=(1, 1))
  plt.title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}');
  plt.xlabel('x');
  plt.ylabel('y');
  plt.axis('equal')
  plt.show()





def run(seed):
  set_random_seeds(seed)

  if config["dataset"] == "student_essay":
    if config["injection"]:
      processor = StudentEssayWithDiscourseInjectionProcessor()
    else:
      processor = StudentEssayProcessor()

    path_train = "./data/student_essay/train_essay.txt"
    path_dev = "./data/student_essay/dev_essay.txt"
    path_test = "./data/student_essay/test_essay.txt"
  elif config["dataset"] == "debate":
    if config["injection"]:
      processor = DebateWithDiscourseInjectionProcessor()
    else:
      processor = DebateProcessor()

    path_train = "./data/debate/train_debate_concept.txt"
    path_dev = "./data/debate/dev_debate_concept.txt"
    path_test = "./data/debate/test_debate_concept.txt"
  elif config["dataset"] == "m-arg":
    if config["injection"]:
      processor = MARGWithDiscourseInjectionProcessor()
    else:
      processor = MARGProcessor()

    path_train = "./data/m-arg/presidential_final.csv"
    path_dev = path_train
    path_test = path_train
  elif config["dataset"] == "nk":
    if config["injection"]:
      processor = NKWithDiscourseInjectionProcessor()
    else:
      processor = NKProcessor()

    path_train = "./data/nk/balanced_dataset.tsv"
  else:
    raise ValueError(f"{config['dataset']} is not a valid database name (choose from 'student_essay' and 'debate')")

  max_sent_length = -1

  data_train = processor.read_input_files(path_train, max_sent_length, name="train")

  if config["dataset"] == "nk":
    data_dev = data_train[:len(data_train) // 10]
    data_test = data_train[-(len(data_train) // 10):]
    data_train = data_train[(len(data_train) // 10) : -(len(data_train) // 10)]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"] or config["double_adversarial"] or config["finetuning_discovery"]:
    df = datasets.load_dataset("discovery","discovery", trust_remote_code=True)
    adv_processor = DiscourseMarkerProcessor()
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      train_adv = adv_processor.process_dataset(df["train"])
      with open("./adv_dataset.pkl", "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      with open("./adv_dataset.pkl", "rb") as reader:
        train_adv = pickle.load(reader)

    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  if config["double_adversarial"]:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = DoubleAdversarialNet()

  elif not config["adversarial"]:
    if config["finetuning_discovery"]:
      sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
      train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    else:
      train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

    if not config["cue_gating"]:
      model = BaselineModelWithSentenceComparison(attention=config["attention"])
    else:
      model = BaselineModelWithSentenceComparisonAndCue(attention=config["attention"])
  else:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)

    model = AdversarialNet()

  model.to(device)

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  no_decay = ["bias", "LayerNorm.weight"]
  optimizer_grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"])

  if config["dataset"] in ["m-arg","nk"]:
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"]).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  best_acc = -1
  best_pre = -1
  best_rec = -1
  best_f1 = -1
  best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0.8,1.2,0.2)
    range_adv = np.arange(0,1.2,0.2)
    range_local = np.arange(0.2, 1, 0.2)

    for rl in reversed(range_local):
      print(f"rl {rl}")
      for discovery_weight in reversed(range_disc):
        for adv_weight in range_adv:
          for epoch in range(args["epochs"]):
            print('===== Start training: epoch {} ====='.format(epoch + 1))
            print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
            train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=discovery_weight, adv_weight=adv_weight, rl=rl)
            dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
            test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
            if dev_f1 > best_dev_f1:
              best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
              best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
              #save model

          print('best result:')
          print(best_test_acc)
          print(best_test_pre)
          print(best_test_rec)
          print(best_test_f1)
          result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])
          del model
          del optimizer

          set_random_seeds(seed)
          model = DoubleAdversarialNet()
          model = model.to(device)

          optimizer_grouped_parameters = [
            {
              "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
              "weight_decay": 0.01,
            },
            {
              "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
              "weight_decay": 0.0
            },
          ]
          optimizer = AdamW(optimizer_grouped_parameters, lr=args["lr"])

          best_acc = -1
          best_pre = -1
          best_rec = -1
          best_f1 = -1
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = -1, -1, -1, -1
  else:
    if config["scheduler"]:
      scheduler = LinearLR(optimizer, start_factor=1, end_factor=1e-2, total_iters = 30)
    for epoch in range(args["epochs"]):
      if config["train"]:
        print('===== Start training: epoch {} ====='.format(epoch + 1))
        train(epoch, model, loss_fn, optimizer, train_dataloader, discovery_weight=0.6, adv_weight=0.6)
        dev_a, dev_p, dev_r, dev_f1 = val(model, dev_dataloader)
        test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
        if config["scheduler"]:
          scheduler.step()
        if dev_f1 > best_dev_f1:
          best_dev_acc, best_dev_pre, best_dev_rec, best_dev_f1 = dev_a, dev_p, dev_r, dev_f1
          best_test_acc, best_test_pre, best_test_rec, best_test_f1 = test_a, test_p, test_r, test_f1
          torch.save(model.state_dict(), f"./{config['dataset']}_model.pt")

    if config["visualize"] and config["adversarial"]:
      model.load_state_dict(torch.load(f"./{config['dataset']}_model.pt"))
      visualize(epoch, model, test_dataloader, train_adv_dataloader, 0.6, 0.6)

        #save model

    print('best result:')
    print(best_test_acc)
    print(best_test_pre)
    print(best_test_rec)
    print(best_test_f1)
    result_metrics.append([best_test_acc, best_test_pre, best_test_rec, best_test_f1])

  print(result_metrics)
  return result_metrics[0]

if __name__ == "__main__":
  results = []
  for seed in args["seed"]:
    print(f"**** trying with seed {seed} ****")
    result_metrics = run(seed)
    results.append(result_metrics)
  avg = torch.mean(torch.tensor(results), dim=0)
  print(avg)

**** trying with seed 0 ****


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 2383.71it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 3018.36it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 2878.30it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 358640.41it/s]


one hot encoding...


creating results...: 341840it [00:02, 165250.55it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:37<00:00, 2169.80it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 102/5393 [00:17<14:45,  5.98it/s]


Timing: 17.068227767944336, Epoch: 1, training loss: 692.843008518219, current learning rate 1e-05
val loss: 7.918373465538025
accuracy:      0.322
precision:     0.353
recall:        0.396
f1:            0.234
val loss: 8.129751801490784
accuracy:      0.336
precision:     0.363
recall:        0.390
f1:            0.245
===== Start training: epoch 2 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:40,  6.45it/s]


Timing: 15.826857805252075, Epoch: 2, training loss: 534.1316337585449, current learning rate 1e-05
val loss: 8.16019070148468
accuracy:      0.459
precision:     0.383
recall:        0.470
f1:            0.314
val loss: 8.265126466751099
accuracy:      0.467
precision:     0.385
recall:        0.469
f1:            0.327
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:42,  6.44it/s]


Timing: 15.856230735778809, Epoch: 3, training loss: 508.6475658416748, current learning rate 1e-05
val loss: 8.75786566734314
accuracy:      0.417
precision:     0.393
recall:        0.505
f1:            0.298
val loss: 8.61320447921753
accuracy:      0.423
precision:     0.395
recall:        0.446
f1:            0.317
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:40,  6.45it/s]


Timing: 15.826275825500488, Epoch: 4, training loss: 472.75782918930054, current learning rate 1e-05
val loss: 5.863338112831116
accuracy:      0.598
precision:     0.377
recall:        0.435
f1:            0.351
val loss: 5.8096203207969666
accuracy:      0.611
precision:     0.411
recall:        0.489
f1:            0.389
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.46it/s]


Timing: 15.791413068771362, Epoch: 5, training loss: 468.609717130661, current learning rate 1e-05
val loss: 8.77963638305664
accuracy:      0.402
precision:     0.391
recall:        0.507
f1:            0.292
val loss: 8.787793278694153
accuracy:      0.443
precision:     0.404
recall:        0.530
f1:            0.333
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.78229308128357, Epoch: 6, training loss: 443.05670142173767, current learning rate 1e-05
val loss: 6.8438557386398315
accuracy:      0.459
precision:     0.395
recall:        0.441
f1:            0.308
val loss: 6.888338029384613
accuracy:      0.457
precision:     0.393
recall:        0.450
f1:            0.307
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.779755115509033, Epoch: 7, training loss: 429.534921169281, current learning rate 1e-05
val loss: 7.6915271282196045
accuracy:      0.424
precision:     0.396
recall:        0.494
f1:            0.308
val loss: 7.642354428768158
accuracy:      0.428
precision:     0.385
recall:        0.460
f1:            0.302
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:37,  6.48it/s]


Timing: 15.756280422210693, Epoch: 8, training loss: 430.6092255115509, current learning rate 1e-05
val loss: 5.455394744873047
accuracy:      0.639
precision:     0.405
recall:        0.524
f1:            0.402
val loss: 5.417733192443848
accuracy:      0.672
precision:     0.420
recall:        0.519
f1:            0.421
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:39,  6.46it/s]


Timing: 15.804080724716187, Epoch: 9, training loss: 406.0598666667938, current learning rate 1e-05
val loss: 5.45185261964798
accuracy:      0.629
precision:     0.407
recall:        0.491
f1:            0.391
val loss: 5.379955887794495
accuracy:      0.633
precision:     0.426
recall:        0.497
f1:            0.404
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.781344413757324, Epoch: 10, training loss: 387.678914308548, current learning rate 1e-05
val loss: 5.621835291385651
accuracy:      0.666
precision:     0.418
recall:        0.541
f1:            0.418
val loss: 5.567815840244293
accuracy:      0.669
precision:     0.409
recall:        0.490
f1:            0.410
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.78496265411377, Epoch: 11, training loss: 399.86459159851074, current learning rate 1e-05
val loss: 9.75280249118805
accuracy:      0.432
precision:     0.397
recall:        0.519
f1:            0.310
val loss: 9.241310238838196
accuracy:      0.462
precision:     0.417
recall:        0.573
f1:            0.350
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:37,  6.47it/s]


Timing: 15.772393465042114, Epoch: 12, training loss: 374.1953649520874, current learning rate 1e-05
val loss: 6.143273651599884
accuracy:      0.602
precision:     0.397
recall:        0.488
f1:            0.373
val loss: 6.310262858867645
accuracy:      0.630
precision:     0.422
recall:        0.525
f1:            0.412
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:37,  6.47it/s]


Timing: 15.773054599761963, Epoch: 13, training loss: 361.7161729335785, current learning rate 1e-05
val loss: 5.072112530469894
accuracy:      0.705
precision:     0.432
recall:        0.549
f1:            0.441
val loss: 5.456564664840698
accuracy:      0.718
precision:     0.424
recall:        0.502
f1:            0.435
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.46it/s]


Timing: 15.785091638565063, Epoch: 14, training loss: 380.4343800544739, current learning rate 1e-05
val loss: 4.118050217628479
accuracy:      0.783
precision:     0.438
recall:        0.497
f1:            0.456
val loss: 4.602335631847382
accuracy:      0.764
precision:     0.420
recall:        0.443
f1:            0.427
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:39,  6.46it/s]


Timing: 15.796330690383911, Epoch: 15, training loss: 368.6648814678192, current learning rate 1e-05
val loss: 6.377052485942841
accuracy:      0.654
precision:     0.401
recall:        0.485
f1:            0.393
val loss: 6.58464503288269
accuracy:      0.657
precision:     0.396
recall:        0.451
f1:            0.393
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:39,  6.46it/s]


Timing: 15.794679641723633, Epoch: 16, training loss: 364.7504096031189, current learning rate 1e-05
val loss: 5.127331078052521
accuracy:      0.734
precision:     0.438
recall:        0.538
f1:            0.454
val loss: 5.600034296512604
accuracy:      0.713
precision:     0.416
recall:        0.465
f1:            0.420
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.46it/s]


Timing: 15.791642665863037, Epoch: 17, training loss: 363.9844436645508, current learning rate 1e-05
val loss: 5.21790736913681
accuracy:      0.754
precision:     0.420
recall:        0.479
f1:            0.433
val loss: 5.277676939964294
accuracy:      0.752
precision:     0.413
recall:        0.445
f1:            0.423
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.751367330551147, Epoch: 18, training loss: 364.0966682434082, current learning rate 1e-05
val loss: 4.7522865533828735
accuracy:      0.773
precision:     0.421
recall:        0.464
f1:            0.434
val loss: 5.383792400360107
accuracy:      0.779
precision:     0.408
recall:        0.421
f1:            0.412
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:37,  6.47it/s]


Timing: 15.759773969650269, Epoch: 19, training loss: 350.77731442451477, current learning rate 1e-05
val loss: 7.618162155151367
accuracy:      0.634
precision:     0.396
recall:        0.478
f1:            0.383
val loss: 7.3758766651153564
accuracy:      0.659
precision:     0.413
recall:        0.501
f1:            0.417
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:37,  6.47it/s]


Timing: 15.766667366027832, Epoch: 20, training loss: 347.82977414131165, current learning rate 1e-05
val loss: 6.492977023124695
accuracy:      0.720
precision:     0.409
recall:        0.473
f1:            0.413
val loss: 5.984618008136749
accuracy:      0.735
precision:     0.424
recall:        0.481
f1:            0.437
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.775810956954956, Epoch: 21, training loss: 333.633696436882, current learning rate 1e-05
val loss: 5.902500599622726
accuracy:      0.749
precision:     0.409
recall:        0.455
f1:            0.417
val loss: 5.3100007474422455
accuracy:      0.764
precision:     0.437
recall:        0.479
f1:            0.449
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.755126476287842, Epoch: 22, training loss: 330.173015832901, current learning rate 1e-05
val loss: 7.11616837978363
accuracy:      0.707
precision:     0.412
recall:        0.491
f1:            0.417
val loss: 6.209243655204773
accuracy:      0.723
precision:     0.423
recall:        0.476
f1:            0.433
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:38,  6.47it/s]


Timing: 15.780696392059326, Epoch: 23, training loss: 340.48397040367126, current learning rate 1e-05
val loss: 6.457884907722473
accuracy:      0.773
precision:     0.432
recall:        0.494
f1:            0.449
val loss: 5.716847330331802
accuracy:      0.749
precision:     0.400
recall:        0.430
f1:            0.407
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.738393545150757, Epoch: 24, training loss: 327.3530340194702, current learning rate 1e-05
val loss: 7.014116585254669
accuracy:      0.722
precision:     0.411
recall:        0.489
f1:            0.422
val loss: 6.834869563579559
accuracy:      0.730
precision:     0.416
recall:        0.472
f1:            0.426
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.745110034942627, Epoch: 25, training loss: 336.1883783340454, current learning rate 1e-05
val loss: 6.8329004645347595
accuracy:      0.727
precision:     0.419
recall:        0.506
f1:            0.432
val loss: 6.483598291873932
accuracy:      0.740
precision:     0.421
recall:        0.476
f1:            0.433
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.755666971206665, Epoch: 26, training loss: 326.8824692964554, current learning rate 1e-05
val loss: 6.522392392158508
accuracy:      0.768
precision:     0.416
recall:        0.462
f1:            0.428
val loss: 6.130404829978943
accuracy:      0.762
precision:     0.388
recall:        0.400
f1:            0.392
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.744643211364746, Epoch: 27, training loss: 325.6490306854248, current learning rate 1e-05
val loss: 7.372624218463898
accuracy:      0.751
precision:     0.420
recall:        0.485
f1:            0.434
val loss: 6.30755889415741
accuracy:      0.766
precision:     0.431
recall:        0.472
f1:            0.444
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.73680830001831, Epoch: 28, training loss: 319.57414531707764, current learning rate 1e-05
val loss: 6.602926790714264
accuracy:      0.778
precision:     0.433
recall:        0.488
f1:            0.450
val loss: 5.824430823326111
accuracy:      0.791
precision:     0.418
recall:        0.425
f1:            0.420
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:36,  6.48it/s]


Timing: 15.753520965576172, Epoch: 29, training loss: 320.83710050582886, current learning rate 1e-05
val loss: 10.028428316116333
accuracy:      0.610
precision:     0.383
recall:        0.454
f1:            0.365
val loss: 9.653366804122925
accuracy:      0.623
precision:     0.399
recall:        0.486
f1:            0.389
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:35,  6.49it/s]


Timing: 15.730623483657837, Epoch: 30, training loss: 336.8716311454773, current learning rate 1e-05
val loss: 6.216626942157745
accuracy:      0.788
precision:     0.401
recall:        0.418
f1:            0.401
val loss: 6.448566198348999
accuracy:      0.813
precision:     0.433
recall:        0.406
f1:            0.415
best result:
0.7639902676399026
0.41960912525428656
0.44315308044121604
0.4273742058289178
[[0.7639902676399026, 0.41960912525428656, 0.44315308044121604, 0.4273742058289178]]
**** trying with seed 1 ****


tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 3097.82it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 2954.26it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 3085.58it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 615403.35it/s] 


one hot encoding...


creating results...: 341840it [00:01, 180398.31it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:33<00:00, 2227.98it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.43it/s]


Timing: 15.876689672470093, Epoch: 1, training loss: 673.9034395217896, current learning rate 1e-05
val loss: 6.807088136672974
accuracy:      0.607
precision:     0.388
recall:        0.452
f1:            0.365
val loss: 7.040621638298035
accuracy:      0.640
precision:     0.412
recall:        0.473
f1:            0.402
===== Start training: epoch 2 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:45,  6.41it/s]


Timing: 15.913248062133789, Epoch: 2, training loss: 548.4848012924194, current learning rate 1e-05
val loss: 6.488198399543762
accuracy:      0.532
precision:     0.390
recall:        0.490
f1:            0.347
val loss: 6.514993190765381
accuracy:      0.557
precision:     0.413
recall:        0.539
f1:            0.383
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:41,  6.44it/s]


Timing: 15.84547233581543, Epoch: 3, training loss: 500.811110496521, current learning rate 1e-05
val loss: 7.7690269947052
accuracy:      0.405
precision:     0.399
recall:        0.516
f1:            0.297
val loss: 7.594627797603607
accuracy:      0.455
precision:     0.413
recall:        0.570
f1:            0.347
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:41,  6.44it/s]


Timing: 15.84860110282898, Epoch: 4, training loss: 465.27703166007996, current learning rate 1e-05
val loss: 5.876450061798096
accuracy:      0.576
precision:     0.409
recall:        0.551
f1:            0.383
val loss: 6.265088140964508
accuracy:      0.589
precision:     0.425
recall:        0.565
f1:            0.408
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:44,  6.42it/s]


Timing: 15.892900943756104, Epoch: 5, training loss: 447.73003673553467, current learning rate 1e-05
val loss: 7.464294672012329
accuracy:      0.512
precision:     0.419
recall:        0.593
f1:            0.362
val loss: 7.588519513607025
accuracy:      0.516
precision:     0.419
recall:        0.587
f1:            0.378
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:44,  6.42it/s]


Timing: 15.892614841461182, Epoch: 6, training loss: 438.2932140827179, current learning rate 1e-05
val loss: 4.716940939426422
accuracy:      0.676
precision:     0.396
recall:        0.464
f1:            0.389
val loss: 4.818850129842758
accuracy:      0.681
precision:     0.428
recall:        0.530
f1:            0.435
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:47,  6.40it/s]


Timing: 15.952069520950317, Epoch: 7, training loss: 415.19565176963806, current learning rate 1e-05
val loss: 5.08200877904892
accuracy:      0.663
precision:     0.411
recall:        0.518
f1:            0.411
val loss: 4.94261759519577
accuracy:      0.669
precision:     0.426
recall:        0.533
f1:            0.433
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:45,  6.41it/s]


Timing: 15.929389476776123, Epoch: 8, training loss: 405.3119456768036, current learning rate 1e-05
val loss: 6.106261134147644
accuracy:      0.578
precision:     0.413
recall:        0.552
f1:            0.385
val loss: 6.227930963039398
accuracy:      0.601
precision:     0.419
recall:        0.542
f1:            0.405
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:46,  6.40it/s]


Timing: 15.93207597732544, Epoch: 9, training loss: 384.3619270324707, current learning rate 1e-05
val loss: 6.038946092128754
accuracy:      0.600
precision:     0.410
recall:        0.538
f1:            0.389
val loss: 6.386695146560669
accuracy:      0.596
precision:     0.418
recall:        0.554
f1:            0.408
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.42it/s]


Timing: 15.886991262435913, Epoch: 10, training loss: 386.92861092090607, current learning rate 1e-05
val loss: 4.748854339122772
accuracy:      0.695
precision:     0.417
recall:        0.508
f1:            0.421
val loss: 5.242725133895874
accuracy:      0.679
precision:     0.415
recall:        0.501
f1:            0.418
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:46,  6.40it/s]


Timing: 15.938006162643433, Epoch: 11, training loss: 380.79760777950287, current learning rate 1e-05
val loss: 5.366904258728027
accuracy:      0.661
precision:     0.402
recall:        0.481
f1:            0.391
val loss: 5.86551970243454
accuracy:      0.676
precision:     0.465
recall:        0.542
f1:            0.451
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:46,  6.40it/s]


Timing: 15.947789430618286, Epoch: 12, training loss: 382.03879475593567, current learning rate 1e-05
val loss: 6.547041535377502
accuracy:      0.605
precision:     0.412
recall:        0.519
f1:            0.388
val loss: 6.632205367088318
accuracy:      0.616
precision:     0.449
recall:        0.582
f1:            0.439
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.42it/s]


Timing: 15.88282322883606, Epoch: 13, training loss: 375.8718214035034, current learning rate 1e-05
val loss: 5.83468821644783
accuracy:      0.641
precision:     0.408
recall:        0.511
f1:            0.391
val loss: 6.564043998718262
accuracy:      0.645
precision:     0.457
recall:        0.551
f1:            0.446
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:42,  6.43it/s]


Timing: 15.865489482879639, Epoch: 14, training loss: 367.34812235832214, current learning rate 1e-05
val loss: 7.542847752571106
accuracy:      0.624
precision:     0.405
recall:        0.481
f1:            0.388
val loss: 7.551764130592346
accuracy:      0.599
precision:     0.430
recall:        0.584
f1:            0.420
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.42it/s]


Timing: 15.880706787109375, Epoch: 15, training loss: 371.16470432281494, current learning rate 1e-05
val loss: 4.914816290140152
accuracy:      0.744
precision:     0.399
recall:        0.438
f1:            0.406
val loss: 5.1162154376506805
accuracy:      0.740
precision:     0.449
recall:        0.518
f1:            0.466
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.43it/s]


Timing: 15.875133991241455, Epoch: 16, training loss: 372.1435868740082, current learning rate 1e-05
val loss: 6.4542014598846436
accuracy:      0.685
precision:     0.413
recall:        0.497
f1:            0.414
val loss: 6.637304842472076
accuracy:      0.652
precision:     0.433
recall:        0.519
f1:            0.427
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.42it/s]


Timing: 15.884582757949829, Epoch: 17, training loss: 369.70632123947144, current learning rate 1e-05
val loss: 6.285579144954681
accuracy:      0.698
precision:     0.406
recall:        0.480
f1:            0.405
val loss: 6.629573166370392
accuracy:      0.684
precision:     0.474
recall:        0.531
f1:            0.455
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:42,  6.43it/s]


Timing: 15.857235670089722, Epoch: 18, training loss: 344.783390045166, current learning rate 1e-05
val loss: 5.4647433161735535
accuracy:      0.754
precision:     0.415
recall:        0.471
f1:            0.424
val loss: 5.33570921421051
accuracy:      0.762
precision:     0.441
recall:        0.491
f1:            0.457
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.42it/s]


Timing: 15.889057159423828, Epoch: 19, training loss: 360.6742845773697, current learning rate 1e-05
val loss: 7.019132494926453
accuracy:      0.668
precision:     0.400
recall:        0.483
f1:            0.397
val loss: 7.379761278629303
accuracy:      0.659
precision:     0.436
recall:        0.565
f1:            0.446
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:44,  6.42it/s]


Timing: 15.894153356552124, Epoch: 20, training loss: 356.9837449789047, current learning rate 1e-05
val loss: 5.461107432842255
accuracy:      0.746
precision:     0.408
recall:        0.461
f1:            0.419
val loss: 5.817087471485138
accuracy:      0.754
precision:     0.467
recall:        0.531
f1:            0.487
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:42,  6.43it/s]


Timing: 15.869943618774414, Epoch: 21, training loss: 351.72743940353394, current learning rate 1e-05
val loss: 5.9528719782829285
accuracy:      0.751
precision:     0.417
recall:        0.471
f1:            0.428
val loss: 5.912242770195007
accuracy:      0.740
precision:     0.445
recall:        0.526
f1:            0.467
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:45,  6.41it/s]


Timing: 15.917362451553345, Epoch: 22, training loss: 343.9210629463196, current learning rate 1e-05
val loss: 6.25072056055069
accuracy:      0.766
precision:     0.404
recall:        0.439
f1:            0.413
val loss: 6.154890060424805
accuracy:      0.774
precision:     0.466
recall:        0.517
f1:            0.484
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.43it/s]


Timing: 15.881606578826904, Epoch: 23, training loss: 337.5143804550171, current learning rate 1e-05
val loss: 6.089969456195831
accuracy:      0.744
precision:     0.399
recall:        0.446
f1:            0.408
val loss: 6.180373728275299
accuracy:      0.757
precision:     0.460
recall:        0.511
f1:            0.471
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:41,  6.44it/s]


Timing: 15.842998504638672, Epoch: 24, training loss: 331.6102417707443, current learning rate 1e-05
val loss: 5.678568720817566
accuracy:      0.785
precision:     0.408
recall:        0.425
f1:            0.412
val loss: 5.519955039024353
accuracy:      0.771
precision:     0.432
recall:        0.460
f1:            0.442
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:43,  6.43it/s]


Timing: 15.871371507644653, Epoch: 25, training loss: 322.9136484861374, current learning rate 1e-05
val loss: 6.452966332435608
accuracy:      0.771
precision:     0.413
recall:        0.456
f1:            0.425
val loss: 6.7727261781692505
accuracy:      0.747
precision:     0.447
recall:        0.500
f1:            0.460
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:44,  6.41it/s]


Timing: 15.906308650970459, Epoch: 26, training loss: 330.5826050043106, current learning rate 1e-05
val loss: 6.443395435810089
accuracy:      0.778
precision:     0.406
recall:        0.436
f1:            0.413
val loss: 6.819526851177216
accuracy:      0.754
precision:     0.433
recall:        0.482
f1:            0.448
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:44,  6.42it/s]


Timing: 15.901906251907349, Epoch: 27, training loss: 327.8499573469162, current learning rate 1e-05
val loss: 6.0687435567379
accuracy:      0.807
precision:     0.404
recall:        0.411
f1:            0.398
val loss: 5.938632667064667
accuracy:      0.808
precision:     0.483
recall:        0.489
f1:            0.475
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:46,  6.40it/s]


Timing: 15.947105407714844, Epoch: 28, training loss: 321.83840703964233, current learning rate 1e-05
val loss: 5.689608514308929
accuracy:      0.815
precision:     0.420
recall:        0.421
f1:            0.420
val loss: 6.151991695165634
accuracy:      0.822
precision:     0.504
recall:        0.494
f1:            0.499
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:46,  6.40it/s]


Timing: 15.94349980354309, Epoch: 29, training loss: 330.26648128032684, current learning rate 1e-05
val loss: 7.4855533838272095
accuracy:      0.749
precision:     0.433
recall:        0.507
f1:            0.448
val loss: 7.614287078380585
accuracy:      0.759
precision:     0.472
recall:        0.532
f1:            0.483
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:47,  6.39it/s]


Timing: 15.968406677246094, Epoch: 30, training loss: 321.6515144109726, current learning rate 1e-05
val loss: 6.0678130984306335
accuracy:      0.800
precision:     0.419
recall:        0.437
f1:            0.424
val loss: 6.233673989772797
accuracy:      0.800
precision:     0.467
recall:        0.472
f1:            0.469
best result:
0.7591240875912408
0.4724901581697698
0.5324993274145816
0.4827412839583225
[[0.7591240875912408, 0.4724901581697698, 0.5324993274145816, 0.4827412839583225]]
**** trying with seed 2 ****


tokenizing...: 100%|██████████| 3283/3283 [00:01<00:00, 3116.06it/s]


finished preprocessing examples in train


tokenizing...: 100%|██████████| 410/410 [00:00<00:00, 3080.41it/s]


finished preprocessing examples in dev


tokenizing...: 100%|██████████| 411/411 [00:00<00:00, 3070.50it/s]


finished preprocessing examples in test
processing discourse marker dataset...


processing labels...: 341840it [00:00, 2078810.98it/s]


one hot encoding...


creating results...: 341840it [00:02, 134130.51it/s]
tokenizing...: 100%|██████████| 341840/341840 [02:37<00:00, 2167.83it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.985784769058228, Epoch: 1, training loss: 637.0174446105957, current learning rate 1e-05
val loss: 10.433056116104126
accuracy:      0.051
precision:     0.314
recall:        0.313
f1:            0.036


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 10.755104184150696
accuracy:      0.068
precision:     0.346
recall:        0.346
f1:            0.048


  _warn_prf(average, modifier, msg_start, len(result))


===== Start training: epoch 2 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.986736059188843, Epoch: 2, training loss: 565.9231834411621, current learning rate 1e-05
val loss: 8.494806170463562
accuracy:      0.276
precision:     0.365
recall:        0.415
f1:            0.215
val loss: 8.725170850753784
accuracy:      0.321
precision:     0.374
recall:        0.434
f1:            0.249
===== Start training: epoch 3 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 15.99887728691101, Epoch: 3, training loss: 515.2265315055847, current learning rate 1e-05
val loss: 7.515343427658081
accuracy:      0.463
precision:     0.404
recall:        0.457
f1:            0.329
val loss: 7.451305449008942
accuracy:      0.496
precision:     0.411
recall:        0.516
f1:            0.357
===== Start training: epoch 4 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:47,  6.39it/s]


Timing: 15.96092963218689, Epoch: 4, training loss: 481.50446248054504, current learning rate 1e-05
val loss: 6.3478041887283325
accuracy:      0.490
precision:     0.428
recall:        0.548
f1:            0.369
val loss: 6.415797054767609
accuracy:      0.528
precision:     0.421
recall:        0.520
f1:            0.372
===== Start training: epoch 5 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.39it/s]


Timing: 15.980791807174683, Epoch: 5, training loss: 472.98338413238525, current learning rate 1e-05
val loss: 6.505429267883301
accuracy:      0.493
precision:     0.410
recall:        0.608
f1:            0.354
val loss: 6.5521674156188965
accuracy:      0.516
precision:     0.419
recall:        0.587
f1:            0.380
===== Start training: epoch 6 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 16.005492687225342, Epoch: 6, training loss: 451.01759600639343, current learning rate 1e-05
val loss: 9.186402797698975
accuracy:      0.346
precision:     0.410
recall:        0.552
f1:            0.281
val loss: 9.077532172203064
accuracy:      0.372
precision:     0.425
recall:        0.602
f1:            0.320
===== Start training: epoch 7 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 15.997860193252563, Epoch: 7, training loss: 444.6938843727112, current learning rate 1e-05
val loss: 5.608456909656525
accuracy:      0.602
precision:     0.413
recall:        0.547
f1:            0.391
val loss: 5.45746922492981
accuracy:      0.633
precision:     0.440
recall:        0.611
f1:            0.445
===== Start training: epoch 8 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 16.00199794769287, Epoch: 8, training loss: 419.5132715702057, current learning rate 1e-05
val loss: 9.24976909160614
accuracy:      0.380
precision:     0.389
recall:        0.506
f1:            0.282
val loss: 8.946590065956116
accuracy:      0.418
precision:     0.416
recall:        0.592
f1:            0.335
===== Start training: epoch 9 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.981544971466064, Epoch: 9, training loss: 402.02240228652954, current learning rate 1e-05
val loss: 6.692895770072937
accuracy:      0.541
precision:     0.405
recall:        0.516
f1:            0.362
val loss: 6.722748816013336
accuracy:      0.547
precision:     0.424
recall:        0.556
f1:            0.394
===== Start training: epoch 10 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.39it/s]


Timing: 15.976501703262329, Epoch: 10, training loss: 406.6651232242584, current learning rate 1e-05
val loss: 7.55213463306427
accuracy:      0.517
precision:     0.411
recall:        0.580
f1:            0.363
val loss: 7.244853496551514
accuracy:      0.543
precision:     0.426
recall:        0.590
f1:            0.393
===== Start training: epoch 11 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.983566284179688, Epoch: 11, training loss: 396.553368806839, current learning rate 1e-05
val loss: 6.610183477401733
accuracy:      0.607
precision:     0.417
recall:        0.585
f1:            0.404
val loss: 5.855854094028473
accuracy:      0.650
precision:     0.448
recall:        0.625
f1:            0.461
===== Start training: epoch 12 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.984911918640137, Epoch: 12, training loss: 386.3516447544098, current learning rate 1e-05
val loss: 5.863706588745117
accuracy:      0.632
precision:     0.411
recall:        0.521
f1:            0.399
val loss: 5.550284445285797
accuracy:      0.645
precision:     0.452
recall:        0.573
f1:            0.451
===== Start training: epoch 13 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:47,  6.39it/s]


Timing: 15.962932586669922, Epoch: 13, training loss: 381.72322630882263, current learning rate 1e-05
val loss: 5.604989409446716
accuracy:      0.680
precision:     0.417
recall:        0.525
f1:            0.420
val loss: 5.233764946460724
accuracy:      0.703
precision:     0.461
recall:        0.588
f1:            0.478
===== Start training: epoch 14 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 16.002930164337158, Epoch: 14, training loss: 369.671138048172, current learning rate 1e-05
val loss: 4.774506509304047
accuracy:      0.759
precision:     0.439
recall:        0.525
f1:            0.457
val loss: 5.1555362939834595
accuracy:      0.730
precision:     0.432
recall:        0.501
f1:            0.449
===== Start training: epoch 15 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:54,  6.34it/s]


Timing: 16.08673620223999, Epoch: 15, training loss: 376.789981842041, current learning rate 1e-05
val loss: 5.590096443891525
accuracy:      0.707
precision:     0.424
recall:        0.528
f1:            0.433
val loss: 5.86376953125
accuracy:      0.691
precision:     0.432
recall:        0.535
f1:            0.446
===== Start training: epoch 16 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:51,  6.36it/s]


Timing: 16.04320764541626, Epoch: 16, training loss: 355.57605147361755, current learning rate 1e-05
val loss: 5.845512330532074
accuracy:      0.688
precision:     0.436
recall:        0.594
f1:            0.450
val loss: 5.9844807386398315
accuracy:      0.696
precision:     0.448
recall:        0.565
f1:            0.464
===== Start training: epoch 17 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 15.997251272201538, Epoch: 17, training loss: 362.7685101032257, current learning rate 1e-05
val loss: 6.052819013595581
accuracy:      0.700
precision:     0.414
recall:        0.510
f1:            0.423
val loss: 6.585122168064117
accuracy:      0.693
precision:     0.429
recall:        0.521
f1:            0.441
===== Start training: epoch 18 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 16.002259016036987, Epoch: 18, training loss: 343.63867807388306, current learning rate 1e-05
val loss: 5.393662452697754
accuracy:      0.732
precision:     0.406
recall:        0.471
f1:            0.415
val loss: 5.306012809276581
accuracy:      0.745
precision:     0.469
recall:        0.534
f1:            0.483
===== Start training: epoch 19 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:53,  6.35it/s]


Timing: 16.064141988754272, Epoch: 19, training loss: 369.45400309562683, current learning rate 1e-05
val loss: 5.343775451183319
accuracy:      0.759
precision:     0.423
recall:        0.466
f1:            0.429
val loss: 4.837587922811508
accuracy:      0.749
precision:     0.440
recall:        0.501
f1:            0.456
===== Start training: epoch 20 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:50,  6.37it/s]


Timing: 16.016136407852173, Epoch: 20, training loss: 345.17992103099823, current learning rate 1e-05
val loss: 5.081326276063919
accuracy:      0.795
precision:     0.425
recall:        0.458
f1:            0.436
val loss: 4.926174432039261
accuracy:      0.796
precision:     0.467
recall:        0.491
f1:            0.476
===== Start training: epoch 21 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.38it/s]


Timing: 15.985038995742798, Epoch: 21, training loss: 346.2790083885193, current learning rate 1e-05
val loss: 5.143565446138382
accuracy:      0.812
precision:     0.422
recall:        0.427
f1:            0.424
val loss: 4.665368437767029
accuracy:      0.822
precision:     0.469
recall:        0.459
f1:            0.464
===== Start training: epoch 22 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 15.99645209312439, Epoch: 22, training loss: 333.81808292865753, current learning rate 1e-05
val loss: 5.612482964992523
accuracy:      0.793
precision:     0.426
recall:        0.457
f1:            0.436
val loss: 5.601747542619705
accuracy:      0.791
precision:     0.464
recall:        0.503
f1:            0.479
===== Start training: epoch 23 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:47,  6.39it/s]


Timing: 15.958356618881226, Epoch: 23, training loss: 325.7392210960388, current learning rate 1e-05
val loss: 5.987660318613052
accuracy:      0.739
precision:     0.410
recall:        0.466
f1:            0.415
val loss: 6.357705354690552
accuracy:      0.764
precision:     0.455
recall:        0.521
f1:            0.473
===== Start training: epoch 24 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:51,  6.37it/s]


Timing: 16.029468774795532, Epoch: 24, training loss: 326.2962477207184, current learning rate 1e-05
val loss: 6.274326920509338
accuracy:      0.763
precision:     0.427
recall:        0.490
f1:            0.443
val loss: 6.1867164969444275
accuracy:      0.762
precision:     0.463
recall:        0.548
f1:            0.487
===== Start training: epoch 25 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:48,  6.39it/s]


Timing: 15.973078489303589, Epoch: 25, training loss: 322.81314969062805, current learning rate 1e-05
val loss: 7.0572763085365295
accuracy:      0.744
precision:     0.401
recall:        0.446
f1:            0.408
val loss: 6.26028436422348
accuracy:      0.762
precision:     0.449
recall:        0.513
f1:            0.468
===== Start training: epoch 26 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:53,  6.34it/s]


Timing: 16.08111810684204, Epoch: 26, training loss: 333.636222243309, current learning rate 1e-05
val loss: 5.951820492744446
accuracy:      0.802
precision:     0.434
recall:        0.468
f1:            0.446
val loss: 5.267104685306549
accuracy:      0.808
precision:     0.474
recall:        0.488
f1:            0.477
===== Start training: epoch 27 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:54,  6.34it/s]


Timing: 16.1012601852417, Epoch: 27, training loss: 320.8388315439224, current learning rate 1e-05
val loss: 6.214389383792877
accuracy:      0.771
precision:     0.413
recall:        0.456
f1:            0.425
val loss: 5.997786104679108
accuracy:      0.786
precision:     0.473
recall:        0.522
f1:            0.489
===== Start training: epoch 28 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:54,  6.34it/s]


Timing: 16.091766119003296, Epoch: 28, training loss: 323.3323255777359, current learning rate 1e-05
val loss: 6.803906321525574
accuracy:      0.815
precision:     0.450
recall:        0.480
f1:            0.460
val loss: 5.800000011920929
accuracy:      0.798
precision:     0.474
recall:        0.491
f1:            0.474
===== Start training: epoch 29 =====


Iteration:   2%|▏         | 102/5393 [00:16<13:51,  6.36it/s]


Timing: 16.03339982032776, Epoch: 29, training loss: 321.4562261104584, current learning rate 1e-05
val loss: 5.750877201557159
accuracy:      0.805
precision:     0.422
recall:        0.425
f1:            0.421
val loss: 5.608910381793976
accuracy:      0.820
precision:     0.483
recall:        0.472
f1:            0.476
===== Start training: epoch 30 =====


Iteration:   2%|▏         | 102/5393 [00:15<13:49,  6.38it/s]


Timing: 15.989839315414429, Epoch: 30, training loss: 322.2549877166748, current learning rate 1e-05
val loss: 6.2258501052856445
accuracy:      0.798
precision:     0.436
recall:        0.474
f1:            0.447
val loss: 6.1854923367500305
accuracy:      0.786
precision:     0.459
recall:        0.480
f1:            0.462
best result:
0.7980535279805353
0.47350395647873395
0.49131019639494217
0.4743438094457945
[[0.7980535279805353, 0.47350395647873395, 0.49131019639494217, 0.4743438094457945]]
tensor([0.7737, 0.4552, 0.4890, 0.4615], dtype=torch.float64)


In [11]:
from google.colab import runtime
runtime.unassign()