<a href="https://colab.research.google.com/github/lucacontalbo/ArgumentRelation/blob/main/ArgumentRelation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Attach Google Drive at /content/drive. force_remount=True remounts even when a
# mount from an earlier session is still present, so re-running this cell is safe.
drive.mount(mountpoint="/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
!pip install datasets



In [3]:
# Work from the project folder on the mounted Drive; later cells presumably
# read data / write artifacts via paths relative to this directory (TODO confirm).
%cd /content/drive/MyDrive/AttackSupport/

/content/drive/.shortcut-targets-by-id/1anU-7aAYPfQ-YWIA0AJDKz_AqkD19g_x/AttackSupport


In [4]:
import torch

# Run on the first CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    print("Training on GPU")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Training on GPU


In [5]:
import torch
from torch.utils.data import Dataset


class dataset(Dataset):
    """Map-style PyTorch Dataset over an in-memory, indexable collection of examples.

    Items are returned exactly as stored; no tensor conversion happens here.
    """

    def __init__(self, examples):
        super().__init__()
        # Keeps a reference to the caller's collection (no copy is made).
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate_fn(examples):
    """Stack a batch of separately-encoded sentence-pair examples into int64 tensors.

    Each example is a 7-item sequence
    ``(ids_1, segs_1, mask_1, ids_2, segs_2, mask_2, label)``. Every field is
    gathered across the batch and converted with ``torch.tensor(..., dtype=torch.long)``.

    Returns:
        A 7-tuple of tensors in the same field order.

    Raises:
        ValueError: if the examples do not have exactly 7 fields (including an empty batch).
    """
    (first_ids, first_segs, first_mask,
     second_ids, second_segs, second_mask,
     targets) = zip(*examples)

    def to_long(column):
        return torch.tensor(list(column), dtype=torch.long)

    fields = (first_ids, first_segs, first_mask, second_ids, second_segs, second_mask, targets)
    return tuple(to_long(column) for column in fields)

def collate_fn_concatenated(examples):
    """Stack a batch of concatenated-pair examples into int64 tensors.

    Each example is ``(ids, segs, mask, position_sep, label)``; each field is
    gathered across the batch and converted with ``torch.tensor(..., dtype=torch.long)``.

    Returns:
        A 5-tuple of tensors in the same field order.

    Raises:
        ValueError: if the examples do not have exactly 5 fields (including an empty batch).
    """
    ids, segs, masks, sep_positions, targets = zip(*examples)
    return tuple(
        torch.tensor(list(field), dtype=torch.long)
        for field in (ids, segs, masks, sep_positions, targets)
    )

def collate_fn_concatenated_adv(examples):
    """Like ``collate_fn_concatenated`` but leaves the labels untouched.

    The first four fields ``(ids, segs, mask, position_sep)`` become int64
    tensors; the labels are returned as a plain Python list, one entry per example.

    Returns:
        ``(ids, segs, mask, position_sep, labels_list)``.

    Raises:
        ValueError: if the examples do not have exactly 5 fields (including an empty batch).
    """
    ids, segs, masks, sep_positions, targets = zip(*examples)
    tensor_fields = [
        torch.tensor(list(field), dtype=torch.long)
        for field in (ids, segs, masks, sep_positions)
    ]
    return (*tensor_fields, list(targets))

In [6]:
import torch
import collections
import codecs
from transformers import AutoTokenizer, pipeline
from sklearn.preprocessing import OneHotEncoder
from transformers import pipeline
import pandas as pd

class DataProcessor:
  """Shared tokenization and fixed-length padding for the dataset processors below.

  NOTE(review): the methods read two names that no earlier cell defines -- the
  global ``args`` dict (``args["model_name"]`` selects the Hugging Face
  tokenizer) and ``tqdm`` (progress bar in ``_get_examples_concatenated``).
  Both must exist before a processor is constructed / used.
  """

  def __init__(self):
    self.tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    # Every encoded sequence is padded or truncated to exactly this many tokens.
    self.max_sent_len = 150

  def __str__(self):
    pattern = """General data processor: \n\n Tokenizer: {}\n\nMax sentence length: {}""".format(args["model_name"], self.max_sent_len)
    return pattern

  def _resolve_pad_id(self):
    """Return the id of the tokenizer's pad token (first id of its encoding, no special tokens)."""
    return self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0]

  @staticmethod
  def _single_sentence_segments(length):
    """Segment ids for one encoded sentence of ``length`` tokens.

    0 on the first and last position, 1 everywhere in between (presumably the
    ends are the special tokens ``encode`` adds). Lengths below 3 contain no 1s.
    """
    segs = [0] * length
    segs[1:-1] = [1] * (length - 2)
    return segs

  def _pad_or_truncate(self, ids, pad_id, *aligned):
    """Bring ``ids`` and the parallel lists in ``aligned`` to exactly ``self.max_sent_len``.

    Shorter inputs are right-padded: ``ids`` with ``pad_id`` and every ``aligned``
    list with 0. Longer inputs are cut to their first ``max_sent_len`` entries
    (this drops any trailing special token).

    Returns:
      ``(ids, attention_mask, *aligned)``; the mask is 1 on kept tokens, 0 on padding.
    """
    length = len(ids)
    if length < self.max_sent_len:
      missing = self.max_sent_len - length
      attention_mask = [1] * length + [0] * missing
      ids = ids + [pad_id] * missing
      aligned = [seq + [0] * missing for seq in aligned]
    else:
      attention_mask = [1] * self.max_sent_len
      ids = ids[:self.max_sent_len]
      aligned = [seq[:self.max_sent_len] for seq in aligned]
    return (ids, attention_mask, *aligned)

  def _get_examples(self, dataset, dataset_type="train"):
    """Encode the two sentences of every pair separately and pad/truncate each.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``;
        only the two sentences and the label are used.
      dataset_type: split name, used only in the completion message.

    Returns:
      list of ``[ids_1, segs_1, mask_1, ids_2, segs_2, mask_2, label]`` lists
      (the 7-field layout ``collate_fn`` unpacks).
    """
    examples = []
    pad_id = None

    for row in dataset:
      _, sentence1, sentence2, _, _, _, label = row

      ids_sent1 = self.tokenizer.encode(sentence1)
      segs_sent1 = self._single_sentence_segments(len(ids_sent1))

      ids_sent2 = self.tokenizer.encode(sentence2)
      segs_sent2 = self._single_sentence_segments(len(ids_sent2))

      assert len(ids_sent1) == len(segs_sent1)
      assert len(ids_sent2) == len(segs_sent2)

      if pad_id is None:
        # The pad id never changes: look it up once per call, not once per row.
        pad_id = self._resolve_pad_id()

      ids_sent1, att_mask_sent1, segs_sent1 = self._pad_or_truncate(ids_sent1, pad_id, segs_sent1)
      ids_sent2, att_mask_sent2, segs_sent2 = self._pad_or_truncate(ids_sent2, pad_id, segs_sent2)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

  def _get_examples_concatenated(self, dataset, dataset_type="train"):
    """Encode each (sentence1, sentence2) pair as one sequence and pad/truncate it.

    Args:
      dataset: iterable of 7-item rows ``(id, sentence1, sentence2, _, _, _, label)``;
        only the two sentences and the label are used.
      dataset_type: split name, used only in the completion message.

    Returns:
      list of ``[ids, segs, mask, position_sep, label]`` lists (the 5-field
      layout ``collate_fn_concatenated`` unpacks).

    NOTE(review): the segment ids assume ``len(encode(s1, s2)) ==
    len(encode(s1)) + len(encode(s2))``, i.e. a pair gets as many special
    tokens as two single sentences (RoBERTa-style ``<s> A </s></s> B </s>``).
    Other tokenizers trip the length assert below.
    """
    examples = []
    pad_id = None

    for row in tqdm(dataset, desc="tokenizing..."):
      _, sentence1, sentence2, _, _, _, label = row

      sentence1_length = len(self.tokenizer.encode(sentence1))
      sentence2_length = len(self.tokenizer.encode(sentence2))

      ids_sent1 = self.tokenizer.encode(sentence1, sentence2)
      # 0 over the first sentence's span, 1 over the second's.
      segs_sent1 = [0] * sentence1_length + [1] * sentence2_length

      # NOTE(review): position_sep starts as all ones, so setting index
      # sentence1_length and index 1 to 1 changes nothing; the net result is 1
      # on every real token except index 0 (and 0 on padding). Confirm this is
      # the intended marker -- initializing with zeros may have been meant.
      position_sep = [1] * len(ids_sent1)
      position_sep[sentence1_length] = 1
      position_sep[0] = 0
      position_sep[1] = 1

      assert len(ids_sent1) == len(position_sep)
      assert len(ids_sent1) == len(segs_sent1)

      if pad_id is None:
        # The pad id never changes: look it up once per call, not once per row.
        pad_id = self._resolve_pad_id()

      ids_sent1, att_mask_sent1, segs_sent1, position_sep = self._pad_or_truncate(
          ids_sent1, pad_id, segs_sent1, position_sep)

      examples.append([ids_sent1, segs_sent1, att_mask_sent1, position_sep, label])

    print(f"finished preprocessing examples in {dataset_type}")

    return examples

class DiscourseMarkerProcessor(DataProcessor):
  """Turns a discourse-marker classification dataset into relation examples.

  Each sample's integer label is mapped to a marker word via ``id_to_word``;
  samples whose marker appears in ``mapping`` are kept and relabelled with the
  mapped class id, all others are dropped.
  """

  def __init__(self):
    super(DiscourseMarkerProcessor, self).__init__()
    # Marker grouping reference: Journal of Pragmatics article, ScienceDirect PII
    # S0378216698001015 (https://doi.org/10.1016/S0378-2166(98)00101-5).
    # This replaces a pre-signed ScienceDirect/S3 download URL that embedded
    # temporary AWS credentials and had been wrapped onto a non-comment line.
    # TODO: refactor
    # Relation class per discourse marker. The ids are not named in the code;
    # judging by the members, 2 groups contrastive markers (but, however, yet, ...),
    # 1 mostly additive ones (and, also, moreover, ...) and 0 mostly
    # result/consequence ones (so, therefore, thus, ...) -- confirm before relying on it.
    # A binary variant used to sit here as a dead string literal: it assigned 1
    # to exactly the markers labelled 2 below and 0 to every other marker.
    self.mapping = {
      "accordingly": 0,
      "also": 1,
      "although": 2,
      "and": 1,
      "as_a_result": 0,
      "because_of_that": 0,
      "because_of_this": 0,
      "besides": 0,
      "but": 2,
      "by_comparison": 2,
      "by_contrast": 2,
      "consequently": 0,
      "conversely": 0,
      "especially": 1,
      "further": 1,
      "furthermore": 1,
      "hence": 0,
      "however": 2,
      "in_contrast": 2,
      "instead": 2,
      "likewise": 1,
      "moreover": 1,
      "namely": 1,
      "nevertheless": 2,
      "nonetheless": 2,
      "on_the_contrary": 2,
      "on_the_other_hand": 2,
      "otherwise": 1,
      "rather": 2,
      "similarly": 1,
      "so": 0,
      "still": 2,
      "then": 0,
      "therefore": 0,
      "though": 2,
      "thus": 0,
      "well": 1,
      "yet": 2
    }

    # Integer label of the source dataset -> marker word.
    self.id_to_word = {
      0: 'no-conn',
      1: 'absolutely',
      2: 'accordingly',
      3: 'actually',
      4: 'additionally',
      5: 'admittedly',
      6: 'afterward',
      7: 'again',
      8: 'already',
      9: 'also',
      10: 'alternately',
      11: 'alternatively',
      12: 'although',
      13: 'altogether',
      14: 'amazingly',
      15: 'and',
      16: 'anyway',
      17: 'apparently',
      18: 'arguably',
      19: 'as_a_result',
      20: 'basically',
      21: 'because_of_that',
      22: 'because_of_this',
      23: 'besides',
      24: 'but',
      25: 'by_comparison',
      26: 'by_contrast',
      27: 'by_doing_this',
      28: 'by_then',
      29: 'certainly',
      30: 'clearly',
      31: 'coincidentally',
      32: 'collectively',
      33: 'consequently',
      34: 'conversely',
      35: 'curiously',
      36: 'currently',
      37: 'elsewhere',
      38: 'especially',
      39: 'essentially',
      40: 'eventually',
      41: 'evidently',
      42: 'finally',
      43: 'first',
      44: 'firstly',
      45: 'for_example',
      46: 'for_instance',
      47: 'fortunately',
      48: 'frankly',
      49: 'frequently',
      50: 'further',
      51: 'furthermore',
      52: 'generally',
      53: 'gradually',
      54: 'happily',
      55: 'hence',
      56: 'here',
      57: 'historically',
      58: 'honestly',
      59: 'hopefully',
      60: 'however',
      61: 'ideally',
      62: 'immediately',
      63: 'importantly',
      64: 'in_contrast',
      65: 'in_fact',
      66: 'in_other_words',
      67: 'in_particular',
      68: 'in_short',
      69: 'in_sum',
      70: 'in_the_end',
      71: 'in_the_meantime',
      72: 'in_turn',
      73: 'incidentally',
      74: 'increasingly',
      75: 'indeed',
      76: 'inevitably',
      77: 'initially',
      78: 'instead',
      79: 'interestingly',
      80: 'ironically',
      81: 'lastly',
      82: 'lately',
      83: 'later',
      84: 'likewise',
      85: 'locally',
      86: 'luckily',
      87: 'maybe',
      88: 'meaning',
      89: 'meantime',
      90: 'meanwhile',
      91: 'moreover',
      92: 'mostly',
      93: 'namely',
      94: 'nationally',
      95: 'naturally',
      96: 'nevertheless',
      97: 'next',
      98: 'nonetheless',
      99: 'normally',
      100: 'notably',
      101: 'now',
      102: 'obviously',
      103: 'occasionally',
      104: 'oddly',
      105: 'often',
      106: 'on_the_contrary',
      107: 'on_the_other_hand',
      108: 'once',
      109: 'only',
      110: 'optionally',
      111: 'or',
      112: 'originally',
      113: 'otherwise',
      114: 'overall',
      115: 'particularly',
      116: 'perhaps',
      117: 'personally',
      118: 'plus',
      119: 'preferably',
      120: 'presently',
      121: 'presumably',
      122: 'previously',
      123: 'probably',
      124: 'rather',
      125: 'realistically',
      126: 'really',
      127: 'recently',
      128: 'regardless',
      129: 'remarkably',
      130: 'sadly',
      131: 'second',
      132: 'secondly',
      133: 'separately',
      134: 'seriously',
      135: 'significantly',
      136: 'similarly',
      137: 'simultaneously',
      138: 'slowly',
      139: 'so',
      140: 'sometimes',
      141: 'soon',
      142: 'specifically',
      143: 'still',
      144: 'strangely',
      145: 'subsequently',
      146: 'suddenly',
      147: 'supposedly',
      148: 'surely',
      149: 'surprisingly',
      150: 'technically',
      151: 'thankfully',
      152: 'then',
      153: 'theoretically',
      154: 'thereafter',
      155: 'thereby',
      156: 'therefore',
      157: 'third',
      158: 'thirdly',
      159: 'this',
      160: 'though',
      161: 'thus',
      162: 'together',
      163: 'traditionally',
      164: 'truly',
      165: 'truthfully',
      166: 'typically',
      167: 'ultimately',
      168: 'undoubtedly',
      169: 'unfortunately',
      170: 'unsurprisingly',
      171: 'usually',
      172: 'well',
      173: 'yet'
    }


  def process_dataset(self, dataset, name="train"):
    """Keep samples with a mapped marker, one-hot encode their class and tokenize them.

    Args:
      dataset: iterable of dict-like samples with ``"sentence1"``, ``"sentence2"``
        and an integer ``"label"`` that indexes ``id_to_word``.
      name: split name; prefixes the example ids and appears in log messages.

    Returns:
      the examples produced by ``_get_examples_concatenated`` for the kept samples
      (an empty list when no sample is kept).
    """
    new_dataset = []
    for sample in dataset:
      marker = self.id_to_word[sample["label"]]
      if marker not in self.mapping:
        continue
      new_dataset.append([sample["sentence1"], sample["sentence2"], self.mapping[marker]])

    # Fix the category list up front. Fitting without it infers the categories
    # from whichever classes occur in this split, so a split missing a class
    # would get fewer / shifted one-hot columns than the other splits.
    categories = [sorted(set(self.mapping.values()))]
    one_hot_encoder = OneHotEncoder(categories=categories, handle_unknown="ignore", sparse_output=False)
    labels = [[sample[-1]] for sample in tqdm(new_dataset, desc="processing labels...")]

    print("one hot encoding...")
    # sklearn refuses to fit on zero samples; an empty split simply yields no examples.
    labels = one_hot_encoder.fit_transform(labels) if labels else []

    result = []
    for i, (sample, label) in tqdm(enumerate(zip(new_dataset, labels)), desc="creating results...", total=len(new_dataset)):
      result.append([f"{name}_{i}", sample[0], sample[1], [], [], [], label])

    return self._get_examples_concatenated(result, name)


class StudentEssayProcessor(DataProcessor):
  """Loads the tab-separated student-essay support/attack files as model examples."""

  def __init__(self):
    super(StudentEssayProcessor, self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the id sequences of a batch of 7-field rows.

    Args:
      input: rows ``(id, sentences, target, source_sentiment, target_sentiment,
        knowledge, label_distribution)``. ``sentences``, ``target`` and
        ``knowledge`` must be convertible by ``torch.tensor`` -- presumably lists
        of token ids (TODO confirm; nothing in this cell calls this method).
      maxlen: unused; kept for interface compatibility.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples; each
      of the three tensor fields is padded with 0 to its longest entry.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
    knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
    target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    With ``batch_equal_size == True`` only sentences of identical length share a
    batch. A positive ``max_batch_size`` caps the number of sentences per batch;
    a non-positive one is a token budget: a batch closes once
    ``len(batch) * longest_sentence >= -max_batch_size``.

    Returns:
      list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp to >= 1: a zero step made range() raise, and empty sentences
          # divided by zero. A budget smaller than one sentence now yields
          # one-sentence batches, matching the unequal-size branch below.
          batch_size = max(1, int((-1 * max_batch_size) / max(1, sentence_length)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        batch_full = max_batch_size > 0 and len(current_batch) >= max_batch_size
        over_budget = max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size)
        if batch_full or over_budget:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  @staticmethod
  def _label_to_one_hot(label):
    """Map a relation label string to a fresh ``[support, attack]`` list; unknown labels give ``[0, 0]``."""
    if label in ("supports", "support", "because"):
      return [1, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1]
    return [0, 0]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated, ISO-8859-1 encoded file and tokenize its rows.

    Columns used: 0 = example id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. Empty lines are skipped.
    Adapted from the code of https://aclanthology.org/2023.eacl-main.182.pdf.

    Args:
      file_path: path to a single TSV file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name, forwarded to ``_get_examples_concatenated``.

    Returns:
      the examples produced by ``_get_examples_concatenated``; labels are
      ``[support, attack]`` one-hot lists (``[0, 0]`` for any other label).
    """
    result = []

    # Builtin open() splits lines only on \n, \r and \r\n. Iterating a
    # codecs.open() reader uses str.splitlines(), which also breaks on \x85,
    # \x0b, \x0c, ... -- bytes that occur in Latin-1/cp1252 text -- and would
    # silently split one row into several.
    with open(file_path, encoding="ISO-8859-1") as f:
      for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
          continue
        fields = line.split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()

        facts = fields[-2].strip()
        # Strip underscores and brackets ("(" was previously removed twice and ")" never).
        for char in "_[]()":
          facts = facts.replace(char, "")

        result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

    return self._get_examples_concatenated(result, name)

class DebateProcessor(DataProcessor):
  """Loads the tab-separated debate support/attack files as model examples."""

  def __init__(self):
    super(DebateProcessor, self).__init__()

  def padding(self, input, maxlen):
    """Zero-pad the id sequences of a batch of 7-field rows.

    Args:
      input: rows ``(id, sentences, target, source_sentiment, target_sentiment,
        knowledge, label_distribution)``. ``sentences``, ``target`` and
        ``knowledge`` must be convertible by ``torch.tensor`` -- presumably lists
        of token ids (TODO confirm; nothing in this cell calls this method).
      maxlen: unused; kept for interface compatibility.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples; each
      of the three tensor fields is padded with 0 to its longest entry.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
    knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
    target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    With ``batch_equal_size == True`` only sentences of identical length share a
    batch. A positive ``max_batch_size`` caps the number of sentences per batch;
    a non-positive one is a token budget: a batch closes once
    ``len(batch) * longest_sentence >= -max_batch_size``.

    Returns:
      list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp to >= 1: a zero step made range() raise, and empty sentences
          # divided by zero. A budget smaller than one sentence now yields
          # one-sentence batches, matching the unequal-size branch below.
          batch_size = max(1, int((-1 * max_batch_size) / max(1, sentence_length)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        batch_full = max_batch_size > 0 and len(current_batch) >= max_batch_size
        over_budget = max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size)
        if batch_full or over_budget:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  @staticmethod
  def _label_to_one_hot(label):
    """Map a relation label string to a fresh ``[support, attack]`` list; unknown labels give ``[0, 0]``."""
    if label in ("supports", "support", "because"):
      return [1, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1]
    return [0, 0]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one tab-separated, ISO-8859-1 encoded file and tokenize its rows.

    Columns used: 0 = example id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. Empty lines are skipped.

    Args:
      file_path: path to a single TSV file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name, forwarded to ``_get_examples_concatenated``.

    Returns:
      the examples produced by ``_get_examples_concatenated``; labels are
      ``[support, attack]`` one-hot lists (``[0, 0]`` for any other label).
    """
    result = []

    # Builtin open() splits lines only on \n, \r and \r\n. Iterating a
    # codecs.open() reader uses str.splitlines(), which also breaks on \x85,
    # \x0b, \x0c, ... -- bytes that occur in Latin-1/cp1252 text -- and would
    # silently split one row into several.
    with open(file_path, encoding="ISO-8859-1") as f:
      for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
          continue
        fields = line.split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()
        label = fields[-1].strip()

        facts = fields[-2].strip()
        # Strip underscores and brackets ("(" was previously removed twice and ")" never).
        for char in "_[]()":
          facts = facts.replace(char, "")

        result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

    return self._get_examples_concatenated(result, name)


class MARGProcessor(DataProcessor):
  """Loads the MARG CSV (one file holding every split) as 3-way relation examples."""

  def __init__(self):
    super(MARGProcessor, self).__init__()

  @staticmethod
  def _label_to_one_hot(label):
    """Map a label string to a fresh ``[support, attack, neither]`` list; unknown labels give ``[0, 0, 0]``."""
    if label in ("supports", "support", "because"):
      return [1, 0, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1, 0]
    if label == "neither":
      return [0, 0, 1]
    return [0, 0, 0]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read the rows of the CSV whose split column equals ``name`` and tokenize them.

    Columns are read by position: 0 = id, 1 = source sentence, 2 = target
    sentence, 3 = label, third-to-last = facts string, last = split name.
    Adapted from the code of https://aclanthology.org/2023.eacl-main.182.pdf.

    Args:
      file_path: path to the CSV file.
      max_sent_length: unused; kept for interface compatibility.
      name: split to keep; also forwarded to ``_get_examples_concatenated``.

    Returns:
      the examples produced by ``_get_examples_concatenated``; labels are
      ``[support, attack, neither]`` one-hot lists (``[0, 0, 0]`` otherwise).
    """
    result = []

    df = pd.read_csv(file_path)
    for _, row in df.iterrows():
      # .iloc makes the positional access explicit: plain row[k] on a Series
      # indexed by column names only works via pandas' deprecated positional fallback.
      if row.iloc[-1] != name:
        continue

      story_id = row.iloc[0]
      sent = row.iloc[1].strip()
      target = row.iloc[2].strip()
      label = row.iloc[3].strip()

      facts = row.iloc[-3].strip()
      # Strip underscores and brackets ("(" was previously removed twice and ")" never).
      for char in "_[]()":
        facts = facts.replace(char, "")

      result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

    return self._get_examples_concatenated(result, name)


class NKProcessor(DataProcessor):
  """Loads the tab-separated NK file as 3-way relation examples."""

  def __init__(self):
    super(NKProcessor, self).__init__()

  @staticmethod
  def _label_to_one_hot(label):
    """Map a label to a fresh ``[support, attack, no_relation]`` list; unknown labels give ``[0, 0, 0]``."""
    if label in ("supports", "support", "because"):
      return [1, 0, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1, 0]
    if label == "no_relation":
      return [0, 0, 1]
    return [0, 0, 0]

  def read_input_files(self, file_path, max_sent_length=-1, name="train"):
    """Read every row of a tab-separated file and tokenize it.

    Columns are read by position: 0 = id, 2 = label (compared unstripped),
    3 = source sentence, 4 = target sentence. ``name`` does not filter rows; it
    only reaches ``_get_examples_concatenated``.

    Discourse-marker injection into the target (see the *WithDiscourseInjection
    processors) was tried here and left disabled; this class has no ``self.pipe``.

    Args:
      file_path: path to the TSV file (first line is read as the header).
      max_sent_length: unused; kept for interface compatibility.
      name: split name, forwarded to ``_get_examples_concatenated``.

    Returns:
      the examples produced by ``_get_examples_concatenated``; labels are
      ``[support, attack, no_relation]`` one-hot lists (``[0, 0, 0]`` otherwise).
    """
    result = []

    df = pd.read_csv(file_path, sep="\t")
    for _, row in df.iterrows():
      # .iloc makes the positional access explicit: plain row[k] on a Series
      # indexed by column names only works via pandas' deprecated positional fallback.
      id_sample = row.iloc[0]
      label = row.iloc[2]
      sent = row.iloc[3].strip()
      target = row.iloc[4].strip()

      result.append([id_sample, sent, target, [], [], [], self._label_to_one_hot(label)])

    return self._get_examples_concatenated(result, name)


class StudentEssayWithDiscourseInjectionProcessor(DataProcessor):
  """Student-essay loader that prepends a predicted discourse marker to each target sentence."""

  def __init__(self):
    super(StudentEssayWithDiscourseInjectionProcessor, self).__init__()
    # Hugging Face text-classification pipeline that predicts a discourse marker for a sentence pair.
    self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

  def padding(self, input, maxlen):
    """Zero-pad the id sequences of a batch of 7-field rows.

    Args:
      input: rows ``(id, sentences, target, source_sentiment, target_sentiment,
        knowledge, label_distribution)``. ``sentences``, ``target`` and
        ``knowledge`` must be convertible by ``torch.tensor`` -- presumably lists
        of token ids (TODO confirm; nothing in this cell calls this method).
      maxlen: unused; kept for interface compatibility.

    Returns:
      list of ``(sentence, knowledge, target, label_distribution)`` tuples; each
      of the three tensor fields is padded with 0 to its longest entry.
    """
    _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

    sentences = torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in sentences], batch_first=True, padding_value=0)
    knowledge = torch.nn.utils.rnn.pad_sequence([torch.tensor(k) for k in knowledge], batch_first=True, padding_value=0)
    target = torch.nn.utils.rnn.pad_sequence([torch.tensor(t) for t in target], batch_first=True, padding_value=0)

    return list(zip(sentences, knowledge, target, label_distribution))


  def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
    """Group sentence indices into batches.

    With ``batch_equal_size == True`` only sentences of identical length share a
    batch. A positive ``max_batch_size`` caps the number of sentences per batch;
    a non-positive one is a token budget: a batch closes once
    ``len(batch) * longest_sentence >= -max_batch_size``.

    Returns:
      list of lists of indices into ``sentences``.
    """
    batches_of_sentence_ids = []
    if batch_equal_size == True:
      sentence_ids_by_length = collections.OrderedDict()
      for i, sentence in enumerate(sentences):
        sentence_ids_by_length.setdefault(len(sentence), []).append(i)

      for sentence_length, ids in sentence_ids_by_length.items():
        if max_batch_size > 0:
          batch_size = max_batch_size
        else:
          # Clamp to >= 1: a zero step made range() raise, and empty sentences
          # divided by zero. A budget smaller than one sentence now yields
          # one-sentence batches, matching the unequal-size branch below.
          batch_size = max(1, int((-1 * max_batch_size) / max(1, sentence_length)))
        for start in range(0, len(ids), batch_size):
          batches_of_sentence_ids.append(ids[start:start + batch_size])
    else:
      current_batch = []
      max_sentence_length = 0
      for i, sentence in enumerate(sentences):
        current_batch.append(i)
        max_sentence_length = max(max_sentence_length, len(sentence))
        batch_full = max_batch_size > 0 and len(current_batch) >= max_batch_size
        over_budget = max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size)
        if batch_full or over_budget:
          batches_of_sentence_ids.append(current_batch)
          current_batch = []
          max_sentence_length = 0
      if current_batch:
        batches_of_sentence_ids.append(current_batch)
    return batches_of_sentence_ids

  @staticmethod
  def _label_to_one_hot(label):
    """Map a relation label string to a fresh ``[support, attack]`` list; unknown labels give ``[0, 0]``."""
    if label in ("supports", "support", "because"):
      return [1, 0]
    if label in ("attacks", "attack", "but"):
      return [0, 1]
    return [0, 0]

  def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
    """Read one TSV file, prepend a predicted discourse marker to each target, and tokenize.

    Columns used: 0 = example id, 1 = source sentence, 3 = target sentence,
    second-to-last = facts string, last = relation label. Empty lines are skipped.
    The target becomes ``"<Marker> <target with lowercased first char>"``.
    Adapted from the code of https://aclanthology.org/2023.eacl-main.182.pdf.

    Args:
      file_path: path to a single ISO-8859-1 encoded TSV file.
      max_sentence_length: unused; kept for interface compatibility.
      name: split name, forwarded to ``_get_examples_concatenated``.

    Returns:
      the examples produced by ``_get_examples_concatenated``; labels are
      ``[support, attack]`` one-hot lists (``[0, 0]`` for any other label).
    """
    result = []

    # Builtin open() splits lines only on \n, \r and \r\n. Iterating a
    # codecs.open() reader uses str.splitlines(), which also breaks on \x85,
    # \x0b, \x0c, ... -- bytes that occur in Latin-1/cp1252 text -- and would
    # silently split one row into several.
    with open(file_path, encoding="ISO-8859-1") as f:
      for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
          continue
        fields = line.split("\t")

        story_id = fields[0]
        sent = fields[1].strip()
        target = fields[3].strip()

        # Predict the marker for the pair; "</s></s>" presumably mirrors the
        # RoBERTa pair separator expected by the classifier.
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        # Slices instead of [0] so an empty string does not raise IndexError.
        ds_marker = ds_marker[:1].upper() + ds_marker[1:]
        target = target[:1].lower() + target[1:]
        target = ds_marker + " " + target

        label = fields[-1].strip()

        facts = fields[-2].strip()
        # Strip underscores and brackets ("(" was previously removed twice and ")" never).
        for char in "_[]()":
          facts = facts.replace(char, "")

        result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

    return self._get_examples_concatenated(result, name)


class DebateWithDiscourseInjectionProcessor(DataProcessor):
    """Reader for the debate TSV files that prepends a predicted discourse marker to each target.

    A RoBERTa discourse-marker classifier is run on every ``"<source></s></s><target>"`` pair.
    The predicted label has its underscores replaced by spaces and its first character
    upper-cased. It is then prepended to the target sentence, whose own first character is
    lower-cased.
    """

    # Label strings mapped to each position of the 2-way one-hot label; anything else -> [0, 0].
    _SUPPORT_LABELS = ("supports", "support", "because")
    _ATTACK_LABELS = ("attacks", "attack", "but")
    # Characters stripped from the knowledge ("facts") column.
    _FACTS_STRIP = str.maketrans("", "", "_[]()")

    def __init__(self,):
        super(DebateWithDiscourseInjectionProcessor, self).__init__()
        self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

    def padding(self, input, maxlen):
        """Pad the token fields of a list of examples to a common length.

        :param input: sequence of 7-tuples
            ``(id, sentences, target, source_sentiment, target_sentiment, knowledge, label_distribution)``
            whose ``sentences``/``target``/``knowledge`` are sequences of token ids.
        :param maxlen: unused; sequences are padded (with 0) to the longest one in ``input``.
        :return: list of ``(sentence, knowledge, target, label_distribution)`` tuples.
        """
        _, sentences, target, _, _, knowledge, label_distribution = zip(*input)

        def pad(seqs):
            return torch.nn.utils.rnn.pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=0)

        sentences = pad(sentences)
        knowledge = pad(knowledge)
        target = pad(target)
        return list(zip(sentences, knowledge, target, label_distribution))

    def create_batches_of_sentence_ids(self, sentences, batch_equal_size, max_batch_size):
        """Group sentence indices into batches.

        If ``max_batch_size`` is positive it is the maximum number of sentences per batch.
        If it is zero or negative, ``abs(max_batch_size)`` is a word budget per batch.
        With ``batch_equal_size``, sentences are first grouped by length so that each batch
        contains sentences of a single length.

        :return: list of lists of sentence indices.
        """
        batches_of_sentence_ids = []
        if batch_equal_size:
            sentence_ids_by_length = collections.OrderedDict()
            for i, sentence in enumerate(sentences):
                sentence_ids_by_length.setdefault(len(sentence), []).append(i)

            for sentence_length, ids in sentence_ids_by_length.items():
                if max_batch_size > 0:
                    batch_size = max_batch_size
                else:
                    # Clamp to >= 1: a sentence longer than the word budget (or of length 0)
                    # previously produced a zero range step / ZeroDivisionError.
                    batch_size = max(1, int((-1 * max_batch_size) / max(sentence_length, 1)))

                for start in range(0, len(ids), batch_size):
                    batches_of_sentence_ids.append(ids[start:start + batch_size])
        else:
            current_batch = []
            max_sentence_length = 0
            for i, sentence in enumerate(sentences):
                current_batch.append(i)
                max_sentence_length = max(max_sentence_length, len(sentence))
                batch_full = (max_batch_size > 0 and len(current_batch) >= max_batch_size) \
                    or (max_batch_size <= 0 and len(current_batch) * max_sentence_length >= (-1 * max_batch_size))
                if batch_full:
                    batches_of_sentence_ids.append(current_batch)
                    current_batch = []
                    max_sentence_length = 0
            if current_batch:
                batches_of_sentence_ids.append(current_batch)
        return batches_of_sentence_ids

    def _prepend_discourse_marker(self, sent, target):
        """Return ``target`` prefixed with the discourse marker predicted for ``(sent, target)``."""
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        # Slice instead of target[0] so an empty target does not raise IndexError.
        return ds_marker + " " + target[:1].lower() + target[1:]

    @classmethod
    def _label_to_one_hot(cls, label):
        """Map a relation label string to ``[support, attack]`` one-hot (``[0, 0]`` if unknown)."""
        if label in cls._SUPPORT_LABELS:
            return [1, 0]
        if label in cls._ATTACK_LABELS:
            return [0, 1]
        return [0, 0]

    def read_input_files(self, file_path, max_sentence_length=-1, name="train"):
        """Read one tab-separated debate file and build concatenated examples.

        Field usage per line: 0 = story id, 1 = source sentence, 3 = target sentence,
        second-to-last = facts/knowledge, last = relation label.
        Blank lines (including a lone ``'\\r'``) are skipped.

        :param file_path: path of a single ISO-8859-1 encoded file.
        :param max_sentence_length: unused, kept for interface compatibility.
        :param name: split name forwarded to ``_get_examples_concatenated``.
        :return: whatever ``_get_examples_concatenated`` returns for the parsed rows.
        """
        result = []
        with codecs.open(file_path, encoding="ISO-8859-1", mode="r") as f:
            for raw_line in f:
                line = raw_line.replace("\n", "")
                if not line.strip():
                    continue
                fields = line.split("\t")

                story_id = fields[0]
                sent = fields[1].strip()
                target = self._prepend_discourse_marker(sent, fields[3].strip())
                label = fields[-1].strip()
                # Bug fix: ")" was never stripped (the chain removed "(" twice).
                facts = fields[-2].strip().translate(self._FACTS_STRIP)

                result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

        return self._get_examples_concatenated(result, name)


class MARGWithDiscourseInjectionProcessor(DataProcessor):
    """Reader for the M-Arg CSV that prepends a predicted discourse marker to each target.

    Rows are filtered by split name (last column). A discourse-marker classifier is run on
    ``"<source></s></s><target>"``. Its label, with underscores replaced by spaces and the first
    character upper-cased, is prepended to the target sentence.
    """

    # Characters stripped from the knowledge ("facts") column.
    _FACTS_STRIP = str.maketrans("", "", "_[]()")

    def __init__(self):
        # Bug fix: this used super(MARGProcessor, self), which raises TypeError because this
        # class does not inherit from MARGProcessor.
        super(MARGWithDiscourseInjectionProcessor, self).__init__()
        self.pipe = pipeline("text-classification", model="sileod/roberta-base-discourse-marker-prediction")

    def _prepend_discourse_marker(self, sent, target):
        """Return ``target`` prefixed with the discourse marker predicted for ``(sent, target)``."""
        ds_marker = self.pipe(f"{sent}</s></s>{target}")[0]["label"]
        ds_marker = ds_marker.replace("_", " ")
        ds_marker = ds_marker[0].upper() + ds_marker[1:]
        # Slice instead of target[0] so an empty target does not raise IndexError.
        return ds_marker + " " + target[:1].lower() + target[1:]

    @staticmethod
    def _label_to_one_hot(label):
        """Map a relation label to ``[support, attack, neither]`` one-hot (all zeros if unknown)."""
        if label in ('supports', 'support', 'because'):
            return [1, 0, 0]
        if label in ('attacks', 'attack', 'but'):
            return [0, 1, 0]
        if label == 'neither':
            return [0, 0, 1]
        return [0, 0, 0]

    def read_input_files(self, file_path, max_sent_length=-1, name="train"):
        """Read the M-Arg CSV and build concatenated examples for split ``name``.

        Column usage (by position): 0 = id, 1 = source sentence, 2 = target sentence,
        3 = relation label, third-to-last = facts/knowledge, last = split name.

        :param file_path: path of the CSV file (all splits in one file).
        :param max_sent_length: unused, kept for interface compatibility.
        :param name: split to keep; also forwarded to ``_get_examples_concatenated``.
        :return: whatever ``_get_examples_concatenated`` returns for the parsed rows.
        """
        # Code adapted from https://aclanthology.org/2023.eacl-main.182.pdf
        df = pd.read_csv(file_path)
        result = []
        for _, row in df.iterrows():
            # .iloc: integer keys on a label-indexed Series are deprecated as positions.
            if row.iloc[-1] != name:
                continue

            story_id = row.iloc[0]
            sent = row.iloc[1].strip()
            target = self._prepend_discourse_marker(sent, row.iloc[2].strip())
            label = row.iloc[3].strip()
            # Bug fix: ")" was never stripped (the chain removed "(" twice).
            facts = row.iloc[-3].strip().translate(self._FACTS_STRIP)

            result.append([story_id, sent, target, [], [], facts, self._label_to_one_hot(label)])

        return self._get_examples_concatenated(result, name)

In [7]:
from transformers import AutoModel
from torch import nn

class GRLayer(torch.autograd.Function):
    """Gradient reversal layer.

    The forward pass is the identity. The backward pass negates the incoming gradient and
    scales it by ``lmbd``. ``lmbd`` itself receives no gradient.
    Usage: ``GRLayer.apply(x, lmbd)``.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Stored as a 0-dim tensor and reused as the scale in backward.
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output
        return flipped.mul(ctx.lmbd), None

class DoubleAdversarialNet(torch.nn.Module):
    """Attention-pooled PLM classifier with an auxiliary head and three gradient-reversed heads.

    All hyper-parameters come from the module-level ``args`` dict. In training mode the batch
    is split in two: the first half feeds ``linear_layer`` (main task) and the second half feeds
    ``linear_layer_adv``. The task / attack / support heads score a gradient-reversed mean of
    every sample's token states.
    """

    def __init__(self):
        super(DoubleAdversarialNet, self).__init__()

        self.plm = AutoModel.from_pretrained(args["model_name"])
        config = self.plm.config
        # Enlarge the token-type table to 4 segment ids and re-initialise it.
        config.type_vocab_size = 4
        self.plm.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

        self.num_classes = args["num_classes"]
        self.num_classes_adv = args["num_classes_adv"]
        self.embed_size = args["embed_size"]

        # If true, token states are the mean of hidden_states[1] and hidden_states[-1].
        self.first_last_avg = args["first_last_avg"]

        for param in self.plm.parameters():
            param.requires_grad = True

        # Construction and init order are kept as-is: both draw from the global RNG.
        self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
        self.linear_layer_adv = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes_adv)
        self.task_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
        self.attack_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)
        self.support_linear = torch.nn.Linear(in_features=self.embed_size, out_features=self.num_classes)

        self.multi_head_att = torch.nn.MultiheadAttention(self.embed_size, 8, batch_first=True)
        self.Q = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
        self.K = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)
        self.V = torch.nn.Linear(in_features=self.embed_size, out_features=self.embed_size)

        self._init_weights(self.linear_layer)
        self._init_weights(self.linear_layer_adv)
        self._init_weights(self.Q)
        self._init_weights(self.K)
        self._init_weights(self.V)
        # No-op: MultiheadAttention is not a Linear/Embedding/LayerNorm, so it keeps its own init.
        self._init_weights(self.multi_head_att)
        self._init_weights(self.task_linear)
        self._init_weights(self.attack_linear)
        self._init_weights(self.support_linear)

    def _init_weights(self, module):
        """Re-initialise Linear/Embedding weights ~ N(0, initializer_range), zero Linear biases,
        and reset LayerNorm to weight=1, bias=0. Other module types are left untouched."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @torch.autocast(device_type="cuda")
    def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels):
        """Encode the concatenated pair and classify.

        Tokens with segment id 0 are keys/values and tokens with segment id 1 are queries of a
        multi-head attention. Its output is mean-pooled over the sequence.

        :param position_sep: unused.
        :param labels: accepted for interface compatibility; not used. The previous body sliced it
            as a 2-D tensor, compared tensors with Python lists and combined masks with ``or``
            (all of which raise), and never used the resulting tensors.
        :return: in training mode ``(predictions, predictions_adv, task_prediction,
            attack_prediction, support_prediction)``; in eval mode the main-task logits.
        """
        out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

        last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

        if self.first_last_avg:
            embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
        else:
            embed_sent1 = last_sent1

        # Zero out the tokens of the other segment.
        tar_mask_sent1 = (segs_sent1 == 0).long()
        tar_mask_sent2 = (segs_sent1 == 1).long()

        H_sent1 = torch.mul(tar_mask_sent1.unsqueeze(2), embed_sent1)
        H_sent2 = torch.mul(tar_mask_sent2.unsqueeze(2), embed_sent1)

        K_sent1 = self.K(H_sent1)
        V_sent1 = self.V(H_sent1)
        Q_sent2 = self.Q(H_sent2)

        att_output = self.multi_head_att(Q_sent2, K_sent1, V_sent1)

        H_sent = torch.mean(att_output[0], dim=1)

        if not self.training:
            return self.linear_layer(H_sent)

        half_batch_size = H_sent.shape[0] // 2
        predictions = self.linear_layer(H_sent[:half_batch_size, :])
        predictions_adv = self.linear_layer_adv(H_sent[half_batch_size:, :])

        # Gradient reversal on the mean token state of every sample in the batch.
        mean_grl = GRLayer.apply(torch.mean(embed_sent1, dim=1), .01)
        task_prediction = self.task_linear(mean_grl)
        attack_prediction = self.attack_linear(mean_grl)
        support_prediction = self.support_linear(mean_grl)

        return predictions, predictions_adv, task_prediction, attack_prediction, support_prediction

class AdversarialNet(torch.nn.Module):
    """PLM encoder with cross-segment attention pooling, a main head, an auxiliary head and a
    gradient-reversed 2-way head (``task_linear``). Hyper-parameters come from ``args``.
    """

    def __init__(self):
        super(AdversarialNet, self).__init__()

        self.plm = AutoModel.from_pretrained(args["model_name"])
        plm_config = self.plm.config
        plm_config.type_vocab_size = 4
        # Fresh 4-entry token-type table, initialised the same way the PLM initialises its own.
        fresh_token_types = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
        self.plm.embeddings.token_type_embeddings = fresh_token_types
        self.plm._init_weights(fresh_token_types)

        self.num_classes = args["num_classes"]
        self.num_classes_adv = args["num_classes_adv"]
        self.embed_size = args["embed_size"]
        self.first_last_avg = args["first_last_avg"]

        # Fine-tune the whole encoder.
        self.plm.requires_grad_(True)

        width = self.embed_size
        # NOTE: creation and initialisation order both consume the global RNG; keep them stable.
        self.linear_layer = nn.Linear(width, self.num_classes)
        self.linear_layer_adv = nn.Linear(width, self.num_classes_adv)
        self.task_linear = nn.Linear(width, 2)
        self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
        self.Q = nn.Linear(width, width)
        self.K = nn.Linear(width, width)
        self.V = nn.Linear(width, width)

        for head in (self.linear_layer, self.linear_layer_adv, self.Q, self.K, self.V,
                     self.multi_head_att, self.task_linear):
            self._init_weights(head)

    def _init_weights(self, module):
        """Normal(0, initializer_range) for Linear/Embedding weights with zeroed Linear biases;
        identity reset for LayerNorm; any other module is left as-is."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @torch.autocast(device_type="cuda")
    def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=False):
        """Return pooled embeddings (``visualize=True``), ``(main, aux, reversed-task)`` logits in
        training mode, or main-task logits otherwise. ``position_sep`` is unused."""
        encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
        layers = encoded.hidden_states
        tokens = torch.div(layers[-1] + layers[1], 2) if self.first_last_avg else layers[-1]

        source_tokens = torch.mul((segs_sent1 == 0).long().unsqueeze(2), tokens)
        target_tokens = torch.mul((segs_sent1 == 1).long().unsqueeze(2), tokens)

        keys = self.K(source_tokens)
        values = self.V(source_tokens)
        queries = self.Q(target_tokens)
        attended, _ = self.multi_head_att(queries, keys, values)
        pooled = attended.mean(dim=1)

        if visualize:
            return pooled
        if not self.training:
            return self.linear_layer(pooled)

        split = pooled.shape[0] // 2
        main_logits = self.linear_layer(pooled[:split, :])
        aux_logits = self.linear_layer_adv(pooled[split:, ])
        reversed_mean = GRLayer.apply(tokens.mean(dim=1), .01)
        return main_logits, aux_logits, self.task_linear(reversed_mean)


class BaselineModelWithSentenceComparison(torch.nn.Module):
    """PLM encoder whose segment-1 tokens attend (multi-head) over segment-0 tokens; the
    attention output is mean-pooled and fed to a linear classifier. Configured from ``args``."""

    def __init__(self):
        super(BaselineModelWithSentenceComparison, self).__init__()

        self.plm = AutoModel.from_pretrained(args["model_name"])
        plm_config = self.plm.config
        plm_config.type_vocab_size = 4
        fresh_token_types = nn.Embedding(plm_config.type_vocab_size, plm_config.hidden_size)
        self.plm.embeddings.token_type_embeddings = fresh_token_types
        self.plm._init_weights(fresh_token_types)

        self.num_classes = args["num_classes"]
        self.embed_size = args["embed_size"]
        self.first_last_avg = args["first_last_avg"]

        self.plm.requires_grad_(True)

        width = self.embed_size
        # Keep creation / init order: both consume the global RNG.
        self.linear_layer = nn.Linear(width, args["num_classes"])
        self.multi_head_att = nn.MultiheadAttention(width, 8, batch_first=True)
        self.Q = nn.Linear(width, width)
        self.K = nn.Linear(width, width)
        self.V = nn.Linear(width, width)

        for layer in (self.linear_layer, self.Q, self.K, self.V, self.multi_head_att):
            self._init_weights(layer)

    def _init_weights(self, module):
        """Normal(0, initializer_range) for Linear/Embedding weights with zeroed Linear biases;
        identity reset for LayerNorm; any other module is left as-is."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @torch.autocast(device_type="cuda")
    def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
        """Return class logits for a concatenated pair; ``position_sep`` is unused."""
        encoded = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)
        layers = encoded.hidden_states
        tokens = torch.div(layers[-1] + layers[1], 2) if self.first_last_avg else layers[-1]

        source_tokens = torch.mul((segs_sent1 == 0).long().unsqueeze(2), tokens)
        target_tokens = torch.mul((segs_sent1 == 1).long().unsqueeze(2), tokens)

        keys = self.K(source_tokens)
        values = self.V(source_tokens)
        queries = self.Q(target_tokens)
        attended, _ = self.multi_head_att(queries, keys, values)

        return self.linear_layer(attended.mean(dim=1))


class BaselineModel(torch.nn.Module):
    """PLM encoder + mean pooling over all token positions + linear classifier (from ``args``)."""

    def __init__(self):
        super(BaselineModel, self).__init__()

        self.plm = AutoModel.from_pretrained(args["model_name"])
        config = self.plm.config
        # Enlarge the token-type table to 4 segment ids and re-initialise it.
        config.type_vocab_size = 4
        self.plm.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

        self.num_classes = args["num_classes"]
        self.embed_size = args["embed_size"]

        # If true, token states are the mean of hidden_states[1] and hidden_states[-1].
        self.first_last_avg = args["first_last_avg"]

        for param in self.plm.parameters():
            param.requires_grad = True

        self.linear_layer = torch.nn.Linear(in_features=self.embed_size, out_features=args["num_classes"])

        self._init_weights(self.linear_layer)

    def _init_weights(self, module):
        """Re-initialise Linear/Embedding weights ~ N(0, initializer_range), zero Linear biases,
        and reset LayerNorm to weight=1, bias=0. Other module types are left untouched."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @torch.autocast(device_type="cuda")
    def forward(self, ids_sent1, segs_sent1, att_mask_sent1, position_sep):
        """Return class logits from the mean token state (padding positions included).

        :param position_sep: accepted for interface compatibility; unused. A masked tensor was
            previously derived from it and then discarded, so that dead computation is removed.
        """
        out_sent1 = self.plm(ids_sent1, token_type_ids=segs_sent1, attention_mask=att_mask_sent1, output_hidden_states=True)

        last_sent1, first_sent1 = out_sent1.hidden_states[-1], out_sent1.hidden_states[1]

        if self.first_last_avg:
            embed_sent1 = torch.div((last_sent1 + first_sent1), 2)
        else:
            embed_sent1 = last_sent1

        H_mean1 = torch.mean(embed_sent1, dim=1)

        return self.linear_layer(H_mean1)


class SiameseBaselineModel(torch.nn.Module):
    """Shared PLM encoder applied to two separately tokenised sentences. The mean token state of
    each sentence is concatenated (width ``2 * embed_size``) and classified linearly."""

    def __init__(self):
        super(SiameseBaselineModel, self).__init__()

        self.plm = AutoModel.from_pretrained(args["model_name"])
        config = self.plm.config
        # Enlarge the token-type table to 4 segment ids and re-initialise it.
        config.type_vocab_size = 4
        self.plm.embeddings.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.plm._init_weights(self.plm.embeddings.token_type_embeddings)

        self.num_classes = args["num_classes"]
        self.embed_size = args["embed_size"]

        # If true, token states are the mean of hidden_states[1] and hidden_states[-1].
        self.first_last_avg = args["first_last_avg"]

        for param in self.plm.parameters():
            param.requires_grad = True

        self.linear_layer = torch.nn.Linear(in_features=self.embed_size*2, out_features=args["num_classes"])

        self._init_weights(self.linear_layer)

    def _init_weights(self, module):
        """Re-initialise Linear/Embedding weights ~ N(0, initializer_range), zero Linear biases,
        and reset LayerNorm to weight=1, bias=0. Other module types are left untouched."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.plm.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _sentence_embedding(self, ids, segs, att_mask):
        """Mean (over all positions, padding included) of the optionally first/last-averaged states."""
        out = self.plm(ids, token_type_ids=segs, attention_mask=att_mask, output_hidden_states=True)
        last, first = out.hidden_states[-1], out.hidden_states[1]
        embed = torch.div((last + first), 2) if self.first_last_avg else last
        return torch.mean(embed, dim=1)

    @torch.autocast(device_type="cuda")
    def forward(self, ids_sent1, segs_sent1, att_mask_sent1, ids_sent2, segs_sent2, att_mask_sent2):
        """Return class logits for a sentence pair.

        Dead code removed: segment masks and masked tensors were computed for both sentences
        but never used.
        """
        H_mean1 = self._sentence_embedding(ids_sent1, segs_sent1, att_mask_sent1)
        H_mean2 = self._sentence_embedding(ids_sent2, segs_sent2, att_mask_sent2)

        H_cat = torch.cat((H_mean1, H_mean2), dim=-1)

        return self.linear_layer(H_cat)

In [8]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def set_random_seeds(seed):
    """Seed Python ``random``, NumPy and PyTorch, then force deterministic cuDNN kernels.

    The current CUDA device is also seeded when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def output_metrics(labels, preds):
    """Print and return accuracy plus macro-averaged precision, recall and F1.

    :param labels: ground-truth labels (any format accepted by sklearn, e.g. one-hot rows)
    :param preds: predicted labels in the same format as ``labels``
    :return: ``(accuracy, precision, recall, f1)``
    """
    scores = {
        'accuracy:': accuracy_score(labels, preds),
        'precision:': precision_score(labels, preds, average="macro"),
        'recall:': recall_score(labels, preds, average="macro"),
        'f1:': f1_score(labels, preds, average="macro"),
    }

    for caption, value in scores.items():
        print("{:15}{:<.3f}".format(caption, value))

    return tuple(scores.values())

In [9]:
from torch.utils.data import Sampler

class BalancedSampler(Sampler):
    """Batch sampler that mixes two datasets half-and-half in every batch.

    Each yielded batch holds ``batch_size // 2`` indices of ``dataset1`` (``0 .. len(dataset1)-1``),
    followed by ``batch_size - batch_size // 2`` indices of ``dataset2``. The ``dataset2`` indices
    are offset by ``len(dataset1)``, presumably for indexing a concatenation of the two datasets.
    Both index sets are reshuffled with ``torch.randperm`` every epoch. Iteration stops as soon as
    either dataset cannot fill its half of another batch, and ``__len__`` reports exactly that count.
    """

    def __init__(self, dataset1, dataset2, batch_size):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.batch_size = batch_size
        self.total_size = len(dataset1) + len(dataset2)
        self.indices1 = list(range(len(dataset1)))
        self.indices2 = list(range(len(dataset2)))
        self.epoch = 0

    def _num_batches(self):
        """Number of complete batches one epoch yields, bounded by both datasets."""
        half1 = self.batch_size // 2
        half2 = self.batch_size - half1
        limit = self.total_size // self.batch_size
        if half1 > 0:
            limit = min(limit, len(self.dataset1) // half1)
        if half2 > 0:
            # Previously unchecked: a short dataset2 made indices2.pop() raise IndexError.
            limit = min(limit, len(self.dataset2) // half2)
        return limit

    def __iter__(self):
        self.epoch += 1
        offset = len(self.dataset1)
        half1 = self.batch_size // 2
        half2 = self.batch_size - half1

        # RNG call order preserved: permute dataset1 first, then dataset2.
        indices1 = torch.randperm(offset).cpu().tolist()
        indices2 = (torch.randperm(len(self.dataset2)) + offset).cpu().tolist()

        for _ in range(self._num_batches()):
            batch = [indices1.pop() for _ in range(half1)]
            batch.extend(indices2.pop() for _ in range(half2))
            yield batch

    def __len__(self):
        # Must match __iter__ so len(DataLoader) (progress bars, LR schedules) is accurate.
        return self._num_batches()

In [10]:
import random
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
import time
import datasets
import pickle

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

# Hyper-parameters. The model classes read this dict as a module-level global
# (model_name, num_classes, num_classes_adv, embed_size, first_last_avg); run() reads "seed".
args = {
    "model_name": "roberta-base", # checkpoint loaded with AutoModel.from_pretrained
    "num_classes": 3, # output width of the main head (linear_layer); values used: 3 or 2
    "num_classes_adv": 3, # output width of the auxiliary head (linear_layer_adv); previously 174
    "embed_size": 768, # input width of all heads; must equal the PLM hidden size
    "first_last_avg": True, # average hidden_states[1] and hidden_states[-1] instead of using only the last layer
    "seed": 1, # passed to set_random_seeds() in run()
    "batch_size": 64, # NOTE(review): consumer not visible in this section
    "epochs": 30, # NOTE(review): consumer not visible in this section
    "class_weight": [2.071, 1.933, 1], # presumably per-class loss weights (one per main class); other values tried: [9.375, 30, 1], 10
    "lr": 1e-5 # NOTE(review): presumably the optimizer learning rate; consumer not visible here
}

# Experiment switches read by train(), visualize() and run().
config = {
    "dataset": "nk", # one of "student_essay", "debate", "m-arg", "nk" (see run())
    "adversarial": False, # train() expects (main, aux, task) outputs; visualize() only runs when True
    "double_adversarial": False, # adds attack/support heads in train(); ignored when "adversarial" is True
    "dataset_from_saved": True, # NOTE(review): checked in run() for the adversarial data; effect not visible here
    "injection": False, # use the *WithDiscourseInjectionProcessor readers (no effect for "nk")
    "grid_search": False, # NOTE(review): consumer not visible in this section
    "visualize": True, # NOTE(review): consumer not visible here (visualize() itself only checks "adversarial")
    "train": True, # NOTE(review): consumer not visible in this section
}

def _labels_to_tensor(labels):
    """Return ``labels`` as a tensor on the global ``device``.

    Accepts a tensor or a list of equally long one-hot rows. Labels may arrive as a plain list,
    e.g. when a batch mixes label widths. ``np.array`` is never applied to a (possibly CUDA) tensor.
    """
    if isinstance(labels, torch.Tensor):
        return labels.to(device)
    return torch.tensor(np.array(labels)).to(device)


def _split_adversarial_targets(labels):
    """Split a mixed batch's labels into ``(main, auxiliary, domain)`` target tensors.

    The first ``len(labels) // 2`` rows are main-task labels and the rest are auxiliary-task labels.
    Domain targets are ``[0, 1]`` for main rows and ``[1, 0]`` for auxiliary rows, one per sample.
    """
    half_batch_size = len(labels) // 2
    targets = _labels_to_tensor(labels[:half_batch_size])
    targets_adv = _labels_to_tensor(labels[half_batch_size:])
    # Use the real size of the second half so an odd batch still matches the model's rows.
    targets_task = torch.tensor(
        [[0, 1]] * half_batch_size + [[1, 0]] * (len(labels) - half_batch_size)
    ).to(device)
    return targets, targets_adv, targets_task


def _masked_cross_entropy(loss_fn, logits, targets, mask):
    """Apply ``loss_fn`` to the rows selected by ``mask``; return 0 when none are selected (avoids NaN)."""
    if not bool(mask.any()):
        return logits.sum() * 0.0
    return loss_fn(logits[mask], targets[mask].float())


def train(epoch, model, loss_fn, optimizer, train_loader, scheduler=None, discovery_weight=0.3, adv_weight=0.3):
    """Run one training epoch, then print timing, summed loss and current learning rate.

    The mode comes from the module-level ``config``:

    * ``config["adversarial"]``: loss = ``loss_fn(main) + discovery_weight * CE(aux)
      + adv_weight * CE(domain)``. Main rows are the first half of the batch and auxiliary
      rows the second half.
    * ``config["double_adversarial"]``: same, plus ``0.2 * CE`` for the attack and support heads.
    * otherwise: ``loss_fn(model(...), labels)``.

    :param epoch: zero-based epoch index (only used in the log line)
    :param scheduler: optional LR scheduler stepped after every optimizer step
    """
    epoch_start_time = time.time()
    model.train()
    tr_loss = 0
    loss_fn2 = nn.CrossEntropyLoss()

    for batch in tqdm(train_loader, desc='Iteration'):
        # Only tensors are moved here; list-valued labels are converted later, per half.
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch

        if config["adversarial"]:
            pred, pred_adv, task_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            # Previously wrapped in a bare `except: print("error")`, which hid the real error
            # and then failed with NameError on `targets`.
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3
        elif config["double_adversarial"]:
            pred, pred_adv, task_pred, attack_pred, support_pred = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels)
            targets, targets_adv, targets_task = _split_adversarial_targets(labels)

            # Row-wise masks over the main-task half. Assumes 2-wide one-hot labels:
            # [0, 1] = attack, [1, 0] = support. The previous `targets == [0, 1]` / `... or ...`
            # did not build per-row masks and raised on tensors.
            attack_mask = (targets == targets.new_tensor([0, 1])).all(dim=-1)
            support_mask = (targets == targets.new_tensor([1, 0])).all(dim=-1)
            # NOTE(review): the attack/support heads score every sample. Only main-task rows with
            # the matching label are supervised here (the only shape-consistent reading). Confirm intent.
            num_main = targets.shape[0]

            loss1 = loss_fn(pred, targets.float())
            loss2 = loss_fn2(pred_adv, targets_adv.float())
            loss3 = loss_fn2(task_pred, targets_task.float())
            loss4 = _masked_cross_entropy(loss_fn2, attack_pred[:num_main], targets, attack_mask)
            loss5 = _masked_cross_entropy(loss_fn2, support_pred[:num_main], targets, support_mask)
            loss = loss1 + discovery_weight*loss2 + adv_weight*loss3 + .2*loss4 + .2*loss5
        else:
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            loss = loss_fn(out, _labels_to_tensor(labels).float())

        tr_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    timing = time.time() - epoch_start_time
    cur_lr = optimizer.param_groups[0]["lr"]
    print(f"Timing: {timing}, Epoch: {epoch + 1}, training loss: {tr_loss}, current learning rate {cur_lr}")

def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader``, printing the summed cross-entropy loss and the metrics.

    Predictions are the argmax of the logits, re-encoded as one-hot rows as wide as the labels.
    They are compared with the one-hot labels by ``output_metrics``.

    :return: ``(accuracy, precision, recall, f1)`` from ``output_metrics``
    """
    model.eval()

    loss_fn = nn.CrossEntropyLoss()

    val_loss = 0
    val_preds = []
    val_labels = []
    for batch in val_loader:
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            out = model(ids_sent1, segs_sent1, att_mask_sent1, position_sep)
            val_loss += loss_fn(out, labels.float()).item()

        preds = out.argmax(dim=1).tolist()
        # The width follows the labels. Previously any non-binary task was forced to width 3,
        # which broke for more than 3 classes.
        num_classes = labels.shape[-1]
        val_labels.extend(labels.tolist())
        val_preds.extend([int(j == pred) for j in range(num_classes)] for pred in preds)

    print(f"val loss: {val_loss}")

    val_acc, val_prec, val_recall, val_f1 = output_metrics(val_labels, val_preds)
    return val_acc, val_prec, val_recall, val_f1

def _plot_tsne(df_points, discovery_weight, adv_weight):
    """Scatter-plot t-SNE points (columns ``x``, ``y``, ``label``) coloured by label and show the figure."""
    # Set the seaborn style *before* creating the axes. Axes that already exist keep their old
    # style, so previously the first figure of the first call missed the darkgrid style.
    sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
    fig, ax = plt.subplots(figsize=(8,6))
    sns.scatterplot(data=df_points, x='x', y='y', hue='label', palette='deep', ax=ax)
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    ax.set_title(f'Scatter plot of embeddings trained with α = {discovery_weight} and μ = {adv_weight}')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.axis('equal')
    plt.show()


def visualize(epoch, model, test_dataloader, train_adv_dataloader, discovery_weight = 0.2, adv_weight = 0.2, max_adv_batches=20):
    """Show t-SNE plots of the pooled embeddings (``model(..., visualize=True)``).

    Two plots are drawn: one for the test set, labelled with the argmax of the one-hot labels,
    and one for up to ``max_adv_batches`` batches of ``train_adv_dataloader``, labelled with
    argmax + 2. Does nothing unless ``config["adversarial"]`` is set.

    :param epoch: unused, kept for interface compatibility
    :param discovery_weight: only shown in the plot titles
    :param adv_weight: only shown in the plot titles
    :param max_adv_batches: number of auxiliary batches to embed (previously hard-coded to 20)
    """
    if not config["adversarial"]:
        return

    model.eval()
    print("Visualizing...")

    # Collect per-batch tensors and concatenate once (repeated torch.cat was quadratic).
    test_embeddings, test_labels = [], []
    for batch in tqdm(test_dataloader):
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = tuple(t.to(device) for t in batch)
        test_labels.append(torch.argmax(labels, dim=-1))
        with torch.no_grad():
            test_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

    adv_embeddings, adv_labels = [], []
    for i, batch in tqdm(enumerate(train_adv_dataloader)):
        if i == max_adv_batches:
            break
        batch = tuple(t.to(device) if not isinstance(t, list) else t for t in batch)
        ids_sent1, segs_sent1, att_mask_sent1, position_sep, labels = batch
        # Labels may already be a (device) tensor; np.array() on a CUDA tensor would fail.
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(np.array(labels))
        # Shifted by 2, presumably so auxiliary ids read differently from main-task ids in the legend.
        adv_labels.append(torch.argmax(labels.to(device), dim=-1) + 2)
        with torch.no_grad():
            adv_embeddings.append(model(ids_sent1, segs_sent1, att_mask_sent1, position_sep, visualize=True))

    if not test_embeddings or not adv_embeddings:
        print("Nothing to visualize: empty dataloader")
        return

    tsne = TSNE(random_state=0)
    tsne_results = tsne.fit_transform(torch.cat(test_embeddings, dim=0).cpu().numpy())
    tsne_results_adv = tsne.fit_transform(torch.cat(adv_embeddings, dim=0).cpu().numpy())

    df_tsne = pd.DataFrame(tsne_results, columns=["x","y"])
    df_tsne_adv = pd.DataFrame(tsne_results_adv, columns=["x","y"])

    df_tsne["label"] = torch.cat(test_labels, dim=0).cpu().numpy()
    df_tsne_adv["label"] = torch.cat(adv_labels, dim=0).cpu().numpy()

    print(df_tsne_adv["label"].unique())

    _plot_tsne(df_tsne, discovery_weight, adv_weight)
    _plot_tsne(df_tsne_adv, discovery_weight, adv_weight)

def _make_processor_and_paths():
  """Select the dataset processor and data-file paths from ``config``.

  Reads config["dataset"] and, for datasets that have an injection variant,
  config["injection"].

  Returns:
    (processor, path_train, path_dev, path_test). For "nk" only a single file
    exists, so path_dev and path_test are None and the caller carves the
    dev/test splits out of the training data.

  Raises:
    ValueError: if config["dataset"] is not one of the supported names.
  """
  name = config["dataset"]

  if name == "student_essay":
    processor = (StudentEssayWithDiscourseInjectionProcessor() if config["injection"]
                 else StudentEssayProcessor())
    return (processor,
            "./data/student_essay/train_essay_wordnet_concept.txt",
            "./data/student_essay/dev_essay_wordnet_concept.txt",
            "./data/student_essay/test_essay_wordnet_concept.txt")

  if name == "debate":
    processor = (DebateWithDiscourseInjectionProcessor() if config["injection"]
                 else DebateProcessor())
    return (processor,
            "./data/debate/train_debate_concept.txt",
            "./data/debate/dev_debate_concept.txt",
            "./data/debate/test_debate_concept.txt")

  if name == "m-arg":
    processor = (MARGWithDiscourseInjectionProcessor() if config["injection"]
                 else MARGProcessor())
    # M-Arg ships as one CSV; the processor is handed the same file for every split.
    path = "./data/m-arg/presidential_final.csv"
    return processor, path, path, path

  if name == "nk":
    # NK has no discourse-injection variant: the same processor is used either way.
    return NKProcessor(), "./data/nk/balanced_dataset.tsv", None, None

  raise ValueError(
    f"{name} is not a valid dataset name "
    "(choose from 'student_essay', 'debate', 'm-arg' and 'nk')")


def _make_model():
  """Instantiate the model selected by config["adversarial"] and move it to `device`."""
  model = AdversarialNet() if config["adversarial"] else BaselineModelWithSentenceComparison()
  return model.to(device)


def _make_optimizer(model, no_decay=("bias", "LayerNorm.weight")):
  """Build AdamW with weight decay 0.01, exempting parameters whose name contains any of `no_decay`."""
  grouped_parameters = [
    {
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": 0.01,
    },
    {
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0
    },
  ]
  return AdamW(grouped_parameters, lr=args["lr"], weight_decay=1e-2)


def _fit_and_select(model, loss_fn, optimizer, train_dataloader, dev_dataloader, test_dataloader,
                    discovery_weight, adv_weight, checkpoint_path=None, announce_weights=False):
  """Train for args["epochs"] epochs, keeping test metrics from the best-dev-F1 epoch.

  Args:
    checkpoint_path: if given, the model state_dict is saved there every time
      the dev F1 improves.
    announce_weights: also print the loss weights at the start of each epoch
      (used by the grid search).

  Returns:
    [accuracy, precision, recall, f1] on the test set at the epoch with the
    highest dev F1, or [None, None, None, None] if no epoch was run.
  """
  best_dev_f1 = -1
  best_test = [None, None, None, None]
  for epoch in range(args["epochs"]):
    print('===== Start training: epoch {} ====='.format(epoch + 1))
    if announce_weights:
      print(f"*** trying with discovery_weight = {discovery_weight}, adv_weight = {adv_weight}")
    train(epoch, model, loss_fn, optimizer, train_dataloader,
          discovery_weight=discovery_weight, adv_weight=adv_weight)
    _, _, _, dev_f1 = val(model, dev_dataloader)
    test_a, test_p, test_r, test_f1 = val(model, test_dataloader)
    # Model selection is on dev F1; the test metrics of that epoch are reported.
    if dev_f1 > best_dev_f1:
      best_dev_f1 = dev_f1
      best_test = [test_a, test_p, test_r, test_f1]
      if checkpoint_path is not None:
        torch.save(model.state_dict(), checkpoint_path)
  return best_test


def _report_best(best_test):
  """Print the selected test metrics, one value per line."""
  print('best result:')
  for value in best_test:
    print(value)


def run(discovery_weight=0.6, adv_weight=0.6):
  """Load data, build the model, then train/evaluate (or grid-search) and report.

  Behaviour is driven by the global `config` (dataset, injection, adversarial,
  dataset_from_saved, grid_search, train, visualize) and `args` (seed,
  batch_size, lr, class_weight, epochs) dictionaries.

  Args:
    discovery_weight, adv_weight: loss weights forwarded to `train` and
      `visualize` in the non-grid-search path (the grid search sweeps its own
      values). Defaults match the previously hard-coded 0.6 / 0.6.
  """
  set_random_seeds(args["seed"])

  processor, path_train, path_dev, path_test = _make_processor_and_paths()

  max_sent_length = -1
  data_train = processor.read_input_files(path_train, max_sent_length, name="train")
  if config["dataset"] == "nk":
    # Single file: first 10% -> dev, last 10% -> test, the rest -> train.
    # Explicit end indices avoid the `[-0:]` / `[:-0]` pitfall when the
    # hold-out size rounds down to 0.
    n_total = len(data_train)
    n_holdout = n_total // 10
    data_dev = data_train[:n_holdout]
    data_test = data_train[n_total - n_holdout:]
    data_train = data_train[n_holdout:n_total - n_holdout]
  else:
    data_dev = processor.read_input_files(path_dev, max_sent_length, name="dev")
    data_test = processor.read_input_files(path_test, max_sent_length, name="test")

  if config["adversarial"]:
    adv_cache_path = "./adv_dataset.pkl"
    if not config["dataset_from_saved"]:
      print("processing discourse marker dataset...")
      # Only download the discovery corpus when the cache is going to be rebuilt.
      discovery = datasets.load_dataset("discovery", "discovery")
      train_adv = DiscourseMarkerProcessor().process_dataset(discovery["train"])
      with open(adv_cache_path, "wb") as writer:
        pickle.dump(train_adv, writer)
    else:
      # NOTE: pickle.load can execute arbitrary code; only load a cache this script wrote.
      with open(adv_cache_path, "rb") as reader:
        train_adv = pickle.load(reader)

    data_train_tot = data_train + train_adv
    train_set_adv = dataset(train_adv)
  else:
    data_train_tot = data_train

  train_set = dataset(data_train_tot)
  dev_set = dataset(data_dev)
  test_set = dataset(data_test)

  train_adv_dataloader = None
  if config["adversarial"]:
    sampler_train = BalancedSampler(data_train, train_adv, args["batch_size"])
    train_dataloader = DataLoader(train_set, batch_sampler=sampler_train, collate_fn=collate_fn_concatenated_adv)
    train_adv_dataloader = DataLoader(train_set_adv, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated_adv)
  else:
    train_dataloader = DataLoader(train_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  model = _make_model()

  dev_dataloader = DataLoader(dev_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)
  test_dataloader = DataLoader(test_set, batch_size=args["batch_size"], shuffle=True, collate_fn=collate_fn_concatenated)

  optimizer = _make_optimizer(model)

  if config["dataset"] in ("m-arg", "nk"):
    # CrossEntropyLoss requires a floating-point weight tensor; force float in
    # case class_weight is given as a list of ints.
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(args["class_weight"], dtype=torch.float).to(device))
  else:
    loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1, args["class_weight"]]).to(device))

  result_metrics = []

  if config["grid_search"]:
    range_disc = np.arange(0, 1.2, 0.2)
    range_adv = np.arange(0, 1.2, 0.2)

    for disc_w in range_disc:
      for adv_w in range_adv:
        best_test = _fit_and_select(model, loss_fn, optimizer, train_dataloader, dev_dataloader,
                                    test_dataloader, disc_w, adv_w, announce_weights=True)
        _report_best(best_test)
        result_metrics.append(best_test)

        # Start the next grid point from a freshly seeded, untrained model of
        # the same kind (baseline or adversarial) as the one built above.
        del model
        del optimizer
        set_random_seeds(args["seed"])
        model = _make_model()
        optimizer = _make_optimizer(model)
  else:
    checkpoint_path = f"./{config['dataset']}_model.pt"
    # Stays [None]*4 when training is disabled (e.g. visualize-only runs), so
    # the report below no longer hits an undefined name.
    best_test = [None, None, None, None]
    if config["train"]:
      best_test = _fit_and_select(model, loss_fn, optimizer, train_dataloader, dev_dataloader,
                                  test_dataloader, discovery_weight, adv_weight,
                                  checkpoint_path=checkpoint_path)

    if config["visualize"] and config["adversarial"]:
      # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
      model.load_state_dict(torch.load(checkpoint_path, map_location=device))
      visualize(args["epochs"] - 1, model, test_dataloader, train_adv_dataloader, discovery_weight, adv_weight)

    _report_best(best_test)
    result_metrics.append(best_test)

  print(result_metrics)

# Entry point: inside a Jupyter/Colab kernel the executing module is named
# "__main__", so this launches the full pipeline in run() when the cell runs.
if __name__ == "__main__":
  run()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
tokenizing...:  26%|██▋       | 387/1462 [00:00<00:01, 766.85it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
tokenizing...: 100%|██████████| 1462/1462 [00:01<00:00, 796.83it/s]


finished preprocessing examples in train


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


===== Start training: epoch 1 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  4.97it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 3.8255715370178223, Epoch: 1, training loss: 31.32738745212555, current learning rate 1e-05
val loss: 3.482656717300415
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 3.4462162256240845
accuracy:      0.144
precision:     0.048
recall:        0.333
f1:            0.084
===== Start training: epoch 2 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.30it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 3.0212106704711914, Epoch: 2, training loss: 30.940184712409973, current learning rate 1e-05
val loss: 3.441274881362915
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 3.418251872062683
accuracy:      0.144
precision:     0.048
recall:        0.333
f1:            0.084
===== Start training: epoch 3 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.28it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 3.0323874950408936, Epoch: 3, training loss: 31.05383849143982, current learning rate 1e-05
val loss: 3.4736878871917725
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 3.470908284187317
accuracy:      0.144
precision:     0.048
recall:        0.333
f1:            0.084
===== Start training: epoch 4 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.23it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 3.059558629989624, Epoch: 4, training loss: 30.91460657119751, current learning rate 1e-05
val loss: 3.4555732011795044
accuracy:      0.075
precision:     0.025
recall:        0.333
f1:            0.047


  _warn_prf(average, modifier, msg_start, len(result))


val loss: 3.4423582553863525
accuracy:      0.158
precision:     0.382
recall:        0.343
f1:            0.104
===== Start training: epoch 5 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.24it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Timing: 3.0530998706817627, Epoch: 5, training loss: 30.103022694587708, current learning rate 1e-05
val loss: 3.418285846710205
accuracy:      0.205
precision:     0.250
recall:        0.424
f1:            0.185
val loss: 3.55084228515625
accuracy:      0.233
precision:     0.443
recall:        0.386
f1:            0.220
===== Start training: epoch 6 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0168001651763916, Epoch: 6, training loss: 29.435927987098694, current learning rate 1e-05
val loss: 3.4099375009536743
accuracy:      0.349
precision:     0.478
recall:        0.427
f1:            0.335
val loss: 3.320844054222107
accuracy:      0.315
precision:     0.416
recall:        0.405
f1:            0.318
===== Start training: epoch 7 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.30it/s]


Timing: 3.0240118503570557, Epoch: 7, training loss: 28.36738419532776, current learning rate 1e-05
val loss: 3.3038114309310913
accuracy:      0.452
precision:     0.492
recall:        0.475
f1:            0.405
val loss: 3.272351622581482
accuracy:      0.459
precision:     0.463
recall:        0.475
f1:            0.441
===== Start training: epoch 8 =====


Iteration: 100%|██████████| 19/19 [00:02<00:00,  6.34it/s]


Timing: 3.0051982402801514, Epoch: 8, training loss: 27.116580486297607, current learning rate 1e-05
val loss: 3.173014283180237
accuracy:      0.445
precision:     0.519
recall:        0.492
f1:            0.367
val loss: 3.2756237983703613
accuracy:      0.404
precision:     0.477
recall:        0.453
f1:            0.381
===== Start training: epoch 9 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.010497808456421, Epoch: 9, training loss: 25.76430857181549, current learning rate 1e-05
val loss: 2.972774922847748
accuracy:      0.500
precision:     0.485
recall:        0.483
f1:            0.427
val loss: 3.3347584009170532
accuracy:      0.527
precision:     0.500
recall:        0.507
f1:            0.491
===== Start training: epoch 10 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.28it/s]


Timing: 3.0330522060394287, Epoch: 10, training loss: 25.98775827884674, current learning rate 1e-05
val loss: 3.538685441017151
accuracy:      0.438
precision:     0.470
recall:        0.466
f1:            0.397
val loss: 3.5306856632232666
accuracy:      0.459
precision:     0.491
recall:        0.485
f1:            0.446
===== Start training: epoch 11 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.33it/s]


Timing: 3.0052804946899414, Epoch: 11, training loss: 24.837236046791077, current learning rate 1e-05
val loss: 3.625059723854065
accuracy:      0.370
precision:     0.493
recall:        0.490
f1:            0.334
val loss: 3.7067705392837524
accuracy:      0.356
precision:     0.492
recall:        0.435
f1:            0.349
===== Start training: epoch 12 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.29it/s]


Timing: 3.0273795127868652, Epoch: 12, training loss: 22.99711847305298, current learning rate 1e-05
val loss: 3.370778799057007
accuracy:      0.445
precision:     0.501
recall:        0.494
f1:            0.393
val loss: 3.4598267674446106
accuracy:      0.438
precision:     0.476
recall:        0.466
f1:            0.419
===== Start training: epoch 13 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.22it/s]


Timing: 3.0605199337005615, Epoch: 13, training loss: 22.32637047767639, current learning rate 1e-05
val loss: 3.1089184880256653
accuracy:      0.459
precision:     0.506
recall:        0.479
f1:            0.402
val loss: 3.5393375158309937
accuracy:      0.473
precision:     0.493
recall:        0.489
f1:            0.452
===== Start training: epoch 14 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.017638921737671, Epoch: 14, training loss: 21.416475415229797, current learning rate 1e-05
val loss: 3.574809432029724
accuracy:      0.452
precision:     0.494
recall:        0.474
f1:            0.401
val loss: 3.6791552305221558
accuracy:      0.445
precision:     0.473
recall:        0.480
f1:            0.430
===== Start training: epoch 15 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.0132830142974854, Epoch: 15, training loss: 20.11404538154602, current learning rate 1e-05
val loss: 3.3177731037139893
accuracy:      0.493
precision:     0.437
recall:        0.433
f1:            0.414
val loss: 3.5844134092330933
accuracy:      0.548
precision:     0.491
recall:        0.477
f1:            0.470
===== Start training: epoch 16 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0173683166503906, Epoch: 16, training loss: 20.89114022254944, current learning rate 1e-05
val loss: 4.214674472808838
accuracy:      0.411
precision:     0.471
recall:        0.473
f1:            0.387
val loss: 4.603760242462158
accuracy:      0.411
precision:     0.463
recall:        0.442
f1:            0.397
===== Start training: epoch 17 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.0097830295562744, Epoch: 17, training loss: 19.726694345474243, current learning rate 1e-05
val loss: 3.6170283555984497
accuracy:      0.507
precision:     0.470
recall:        0.492
f1:            0.446
val loss: 3.6320589184761047
accuracy:      0.527
precision:     0.485
recall:        0.498
f1:            0.484
===== Start training: epoch 18 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.013113021850586, Epoch: 18, training loss: 19.422439694404602, current learning rate 1e-05
val loss: 3.5730401277542114
accuracy:      0.575
precision:     0.501
recall:        0.514
f1:            0.485
val loss: 3.883696675300598
accuracy:      0.548
precision:     0.493
recall:        0.498
f1:            0.495
===== Start training: epoch 19 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0189719200134277, Epoch: 19, training loss: 18.42746078968048, current learning rate 1e-05
val loss: 3.3717787861824036
accuracy:      0.500
precision:     0.460
recall:        0.486
f1:            0.440
val loss: 3.820618748664856
accuracy:      0.562
precision:     0.518
recall:        0.530
f1:            0.520
===== Start training: epoch 20 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.016740083694458, Epoch: 20, training loss: 17.710244238376617, current learning rate 1e-05
val loss: 3.8579992055892944
accuracy:      0.541
precision:     0.495
recall:        0.513
f1:            0.465
val loss: 4.016133189201355
accuracy:      0.507
precision:     0.470
recall:        0.481
f1:            0.469
===== Start training: epoch 21 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0210089683532715, Epoch: 21, training loss: 15.957435309886932, current learning rate 1e-05
val loss: 3.878405213356018
accuracy:      0.479
precision:     0.467
recall:        0.470
f1:            0.420
val loss: 4.177996397018433
accuracy:      0.527
precision:     0.497
recall:        0.506
f1:            0.491
===== Start training: epoch 22 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.29it/s]


Timing: 3.0248098373413086, Epoch: 22, training loss: 15.905399739742279, current learning rate 1e-05
val loss: 4.467102289199829
accuracy:      0.445
precision:     0.466
recall:        0.470
f1:            0.398
val loss: 4.645256519317627
accuracy:      0.438
precision:     0.432
recall:        0.448
f1:            0.420
===== Start training: epoch 23 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0144741535186768, Epoch: 23, training loss: 15.361664295196533, current learning rate 1e-05
val loss: 3.9182225465774536
accuracy:      0.568
precision:     0.492
recall:        0.484
f1:            0.469
val loss: 3.665548801422119
accuracy:      0.568
precision:     0.516
recall:        0.514
f1:            0.515
===== Start training: epoch 24 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.26it/s]


Timing: 3.0386645793914795, Epoch: 24, training loss: 14.93480521440506, current learning rate 1e-05
val loss: 4.639623522758484
accuracy:      0.452
precision:     0.463
recall:        0.451
f1:            0.404
val loss: 4.313200116157532
accuracy:      0.438
precision:     0.428
recall:        0.435
f1:            0.415
===== Start training: epoch 25 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.0098681449890137, Epoch: 25, training loss: 13.565258085727692, current learning rate 1e-05
val loss: 3.7090474367141724
accuracy:      0.521
precision:     0.448
recall:        0.451
f1:            0.435
val loss: 3.9936596155166626
accuracy:      0.534
precision:     0.486
recall:        0.493
f1:            0.487
===== Start training: epoch 26 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.33it/s]


Timing: 3.006742000579834, Epoch: 26, training loss: 12.410157144069672, current learning rate 1e-05
val loss: 4.461904287338257
accuracy:      0.507
precision:     0.459
recall:        0.468
f1:            0.437
val loss: 4.588837623596191
accuracy:      0.500
precision:     0.474
recall:        0.475
f1:            0.454
===== Start training: epoch 27 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.33it/s]


Timing: 3.007519006729126, Epoch: 27, training loss: 12.401977062225342, current learning rate 1e-05
val loss: 4.060092806816101
accuracy:      0.568
precision:     0.479
recall:        0.483
f1:            0.466
val loss: 4.036096811294556
accuracy:      0.575
precision:     0.523
recall:        0.499
f1:            0.504
===== Start training: epoch 28 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.30it/s]


Timing: 3.020162582397461, Epoch: 28, training loss: 12.90627259016037, current learning rate 1e-05
val loss: 5.191008448600769
accuracy:      0.473
precision:     0.460
recall:        0.443
f1:            0.415
val loss: 4.783647418022156
accuracy:      0.466
precision:     0.476
recall:        0.483
f1:            0.443
===== Start training: epoch 29 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.31it/s]


Timing: 3.0167863368988037, Epoch: 29, training loss: 12.120983064174652, current learning rate 1e-05
val loss: 4.728011727333069
accuracy:      0.514
precision:     0.497
recall:        0.521
f1:            0.458
val loss: 4.544405579566956
accuracy:      0.466
precision:     0.451
recall:        0.465
f1:            0.440
===== Start training: epoch 30 =====


Iteration: 100%|██████████| 19/19 [00:03<00:00,  6.32it/s]


Timing: 3.010354995727539, Epoch: 30, training loss: 12.120865643024445, current learning rate 1e-05
val loss: 4.283243775367737
accuracy:      0.541
precision:     0.455
recall:        0.467
f1:            0.452
val loss: 4.499384880065918
accuracy:      0.562
precision:     0.517
recall:        0.489
f1:            0.489
best result:
0.547945205479452
0.4933698830409357
0.49828001375988995
0.49485703719113333
[[0.547945205479452, 0.4933698830409357, 0.49828001375988995, 0.49485703719113333]]


In [11]:
# Final cell: hand the Colab runtime back once the experiment above has finished.
# NOTE(review): runtime.unassign() presumably disconnects and releases the hosted
# VM (and its GPU), discarding any in-memory state; confirm against google.colab docs.
from google.colab import runtime
runtime.unassign()