<a href="https://colab.research.google.com/github/demelin/ai_reimplementations/blob/main/PET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install libraries

!pip install datasets

In [None]:
# Import libraries

import os
import json
import torch
import pickle
import random
import numpy as np

from tqdm import trange, tqdm
from torch.utils.data import (
    Dataset, DataLoader, RandomSampler, SequentialSampler)
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForPreTraining,
    AutoModelForSequenceClassification, get_linear_schedule_with_warmup)
from datasets import load_dataset
import matplotlib.pyplot as plt

In [None]:
# Load and sample data

class DataInitializer(object):

  """ Downloads a dataset from the HuggingFace Hub and keeps its raw texts
  and labels in memory.

  After construction the instance exposes `train_text` / `train_labels`
  (from the "train" split) and `eval_text` / `eval_labels` (from the
  "test" split).
  """

  def __init__(self, dataset_name):

    self.dataset_name = dataset_name
    fetched_columns = self._fetch_data()
    (self.train_text, self.train_labels,
     self.eval_text, self.eval_labels) = fetched_columns

  def _fetch_data(self):

    """ Loads the "train" and "test" splits and returns, in this order, the
    train texts, train labels, test texts and test labels. """

    columns = list()
    for split_name in ("train", "test"):
      split_data = load_dataset(self.dataset_name, split=split_name)
      columns.append(split_data["text"][:])
      columns.append(split_data["label"][:])
    return columns


In [None]:
# Define PVP for Yelp data

class YelpPVPVerbalizer(object):

  """ Pattern-verbalizer pairs (PVPs) for the Yelp review-rating task.

  Fills a review text into one of four cloze patterns, each containing a
  single mask token, and maps the five rating labels (0-4) to single-token
  verbalizations.
  """

  def __init__(
      self,
      tokenizer
      ):

    """ Pre-tokenizes the patterns and validates the label verbalizations.

    Args:
      tokenizer: HuggingFace-style tokenizer providing `encode`,
        `convert_tokens_to_ids`, `num_special_tokens_to_add`,
        `build_inputs_with_special_tokens`, `mask_token_id` and
        `unk_token_id`.
    """

    self.tokenizer = tokenizer

    # Declare PVPs to be used with the data. Each entry holds the first and
    # second text segment of the pattern (an empty second segment means the
    # pattern is a single sequence), followed by the (segment, position) of
    # the "<|TEXT|>" placeholder
    self.pattern_dict = {
        0: [["It was ", "<|MASK|>", ". ", "<|TEXT|>"], [], (0, 3)],
        1: [["Just ", "<|MASK|>", "! "], ["<|TEXT|>"], (1, 0)],
        2: [["<|TEXT|>", "All in all, it was ", "<|MASK|>", "."], [], (0, 0)],
        3: [["<|TEXT|>"], ["In summary, the restaurant is ", "<|MASK|>", "."],
         (0, 0)]
        }

    # Pre-tokenize pattern segments and count, per pattern, the number of
    # non-text tokens (pattern words + mask) needed for truncation
    self.pattern_segment_dict = dict()
    self.pattern_text_length = dict()

    for key, pattern in self.pattern_dict.items():
      segments = list()
      pattern_length = 0
      for lst in pattern[:-1]:
        token_list = list()
        for seg in lst:
          if seg == "<|TEXT|>":
            # Placeholder, replaced by the sample tokens at encoding time
            token_list.append(seg)
          elif seg == "<|MASK|>":
            token_list.append([self.tokenizer.mask_token_id])
            pattern_length += 1
          else:
            phrase_tokens = self.tokenizer.encode(
                seg, add_special_tokens=False)
            token_list.append(phrase_tokens)
            pattern_length += len(phrase_tokens)
        segments.append(token_list)
      self.pattern_segment_dict[key] = segments
      self.pattern_text_length[key] = pattern_length

    # Declare labels and their corresponding verbalizations
    # NOTE: "terrible" and "okay", the original verbalizations of the "1" and
    # "3" labels, are split in two tokens making them unsuitable as PET labels;
    # using "grim" and "ok" instead
    self.label_dict = {
        0: "grim",
        1: "bad",
        2: "ok",
        3: "good",
        4: "great"
        }

    # Sanity check
    for lbl in self.label_dict.values():
      lbl_id = self.tokenizer.convert_tokens_to_ids(lbl)
      assert lbl_id != self.tokenizer.unk_token_id, "{:s} does not map to a single token!".format(lbl)

    # Token ids of the label verbalizations, indexed by label id, for
    # evaluating model predictions. Derived from label_dict so both stay in
    # sync (a hard-coded copy previously used "okay" instead of "ok")
    self.label_token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(
        [self.label_dict[lbl] for lbl in sorted(self.label_dict)]))

    # Number of special tokens the tokenizer adds to one / two sequences
    self.num_special_tokens_added_single = self.tokenizer.num_special_tokens_to_add(pair=False)
    self.num_special_tokens_added_pair = self.tokenizer.num_special_tokens_to_add(pair=True)

  def verbalize_and_encode_sample(
      self,
      sample,
      pattern_id,
      max_seq_len
      ):

    """ Transforms a Yelp sample by filling the specified pattern.

    Args:
      sample: dict with at least "text" (str) and "label".
      pattern_id: key of `pattern_dict` selecting the pattern.
      max_seq_len: maximum length of the encoded sequence, special tokens
        included. The sample text is truncated to fit.

    Returns:
      dict with "token_ids" (list of ints, including the tokenizer's special
      tokens) and the unchanged "label".
    """

    segments = self.pattern_segment_dict[pattern_id]
    text_segment_id, text_position = self.pattern_dict[pattern_id][-1]
    # Patterns with two non-empty segments are encoded as a sequence pair
    is_pair = all(len(seg) > 0 for seg in segments)

    # Leave room for the special tokens and the pattern's own tokens
    num_special_tokens = (self.num_special_tokens_added_pair if is_pair
                          else self.num_special_tokens_added_single)
    max_text_len = max(0, max_seq_len - num_special_tokens -
                       self.pattern_text_length[pattern_id])
    token_ids = self.tokenizer.encode(
        sample["text"], add_special_tokens=False)[:max_text_len]

    # Merge each segment into a flat token list, substituting the sample
    # tokens for the placeholder. Works on fresh lists so the cached
    # pre-tokenized pattern is never modified
    merged_segments = list()
    for seg_id, parts in enumerate(segments):
      merged = list()
      for part_id, part in enumerate(parts):
        if (seg_id, part_id) == (text_segment_id, text_position):
          merged += token_ids
        else:
          merged += part
      if len(parts) > 0:
        merged_segments.append(merged)

    # Let the tokenizer add its own single / pair special-token layout, which
    # matches num_special_tokens_to_add used for the length budget above
    encoded_filled_pattern = self.tokenizer.build_inputs_with_special_tokens(
        *merged_segments)

    return {"token_ids": encoded_filled_pattern, "label": sample["label"]}


In [None]:
class NewDataset(Dataset):

  """ Map-style dataset over already tensorized model inputs.

  Every item is a dict with the keys "input_ids", "label", "attention_mask"
  and "labels_array"; a missing (None) labels array is reported as [-1].
  """

  def __init__(self, input_ids, attention_masks, labels, labels_arrays):

    self.input_ids = input_ids
    self.attention_masks = attention_masks
    self.labels = labels
    self.labels_arrays = labels_arrays

  def __len__(self):

    """ Number of samples, taken from the label container. """

    return len(self.labels)

  def __getitem__(self, item_id):

    """ Returns the inputs and targets of the sample at `item_id`. """

    labels_array = self.labels_arrays[item_id]
    if labels_array is None:
      labels_array = [-1]

    return dict(
        input_ids=self.input_ids[item_id],
        label=self.labels[item_id],
        attention_mask=self.attention_masks[item_id],
        labels_array=labels_array)


class DataProcessor(object):

  """ Prepares data for use with the MLM and the sequence classifiers.

  Groups the raw samples of a DataInitializer by label, draws labeled and
  unlabeled subsets, fills them into PET patterns via the verbalizer (or
  encodes them without a pattern), pads them and wraps them in DataLoaders.
  """

  def __init__(
      self,
      config,
      data_initializer,
      tokenizer,
      verbalizer
      ):

    self.config = config
    self.data_initializer = data_initializer
    self.tokenizer = tokenizer
    self.verbalizer = verbalizer
    self.labels_list = list(self.verbalizer.label_dict.keys())
    self.num_labels = len(self.labels_list)

    # Group the train / test splits by label
    self.train_data_per_label = self._transform_data(
        self.data_initializer.train_text, self.data_initializer.train_labels)
    self.eval_data_per_label = self._transform_data(
        self.data_initializer.eval_text, self.data_initializer.eval_labels)

  def _transform_data(
      self,
      inputs,
      labels
      ):

    """ Groups raw samples into a {label: [sample_dict, ...]} dictionary.

    Args:
      inputs: sequence of raw sample texts.
      labels: sequence of integer labels aligned with `inputs`.

    Returns:
      dict mapping each label to a list of sample dicts with the keys
      "sample_id" ("<label>_<index>"), "text" (pre-processed) and "label".
    """

    def _preprocess_text(text):

      """ Replaces escaped newlines (a literal backslash followed by "n")
      and any remaining backslashes with spaces. """

      return text.replace("\\n", " ").replace("\\", " ")

    data_per_label = dict()
    print("Transforming samples ...")

    for sample_idx, (text, lbl) in enumerate(zip(inputs, labels)):
      sample_dict = {
          "sample_id": "{:d}_{:d}".format(lbl, sample_idx),
          "text": _preprocess_text(text),
          "label": lbl
          }
      data_per_label.setdefault(lbl, []).append(sample_dict)

    print("Done!")

    return data_per_label

  def _get_samples(
      self,
      data_per_label,
      num_samples_per_label,
      sample_ids_to_exclude=None
      ):

    """ Randomly draws up to `num_samples_per_label` samples for each label.

    NOTE: shuffles the per-label lists of `data_per_label` in place.

    Args:
      data_per_label: {label: [sample_dict, ...]} as built by _transform_data.
      num_samples_per_label: number of samples to draw for each label.
      sample_ids_to_exclude: optional iterable of "sample_id" values that must
        not be drawn (keeps labeled and unlabeled sets disjoint).

    Returns:
      {label: [sample_dict, ...]} with one entry per label in labels_list.
      A list is shorter than requested only if too few eligible samples exist.
    """

    excluded = set(sample_ids_to_exclude or ())
    collected_samples = dict()

    for lbl in self.labels_list:
      candidates = data_per_label.get(lbl, [])
      random.shuffle(candidates)

      # Skip excluded samples but keep drawing until the quota is met
      picked = list()
      for smp in candidates:
        if len(picked) >= num_samples_per_label:
          break
        if smp["sample_id"] not in excluded:
          picked.append(smp)
      collected_samples[lbl] = picked

    return collected_samples

  def sample_data(
      self,
      data_per_label,
      num_samples_per_label,
      is_eval=False
      ):

    """ Samples data to be used with models.

    Returns:
      (labeled_samples, unlabeled_samples). Both are {label: [sample, ...]}
      dicts; unlabeled_samples is an empty list when `is_eval` is True.
      Unlabeled samples never overlap with the labeled ones.
    """

    labeled_samples = self._get_samples(data_per_label, num_samples_per_label)
    unlabeled_samples = list()

    if not is_eval:
      labeled_sample_ids = [
          smp["sample_id"]
          for lbl_samples in labeled_samples.values()
          for smp in lbl_samples]
      unlabeled_samples = self._get_samples(
          data_per_label, self.config.num_unlabeled_sampes_per_label,
          labeled_sample_ids)

    return labeled_samples, unlabeled_samples

  def _verbalize_and_encode_data(
      self,
      pattern_id,
      labeled_samples,
      unlabeled_samples
      ):

    """ Fills labeled (and unlabeled) samples into the given PET pattern.

    Returns:
      (encoded_labeled_samples, encoded_unlabeled_samples) as flat lists of
      the verbalizer's {"token_ids", "label"} dicts, ordered by label.
    """

    encoded_labeled_samples = list()
    encoded_unlabeled_samples = list()

    sample_sets = [
        (labeled_samples, encoded_labeled_samples,
         "Verbalizing PET samples ...")]
    if len(unlabeled_samples) > 0:
      sample_sets.append(
          (unlabeled_samples, encoded_unlabeled_samples,
           'Verbalizing unlabeled samples ...'))

    for samples, encoded_samples, message in sample_sets:
      print(message)
      samples_seen = 0
      # Iterate labels in sorted order so that sequential dataloaders always
      # see the same data sequence
      for label in sorted(samples.keys()):
        for sample in samples[label]:
          encoded_samples.append(self.verbalizer.verbalize_and_encode_sample(
              sample, pattern_id, self.config.max_seq_len))
          samples_seen += 1
          if samples_seen % 1000 == 0:
            print("\tProcessed {:d} samples".format(samples_seen))

    return encoded_labeled_samples, encoded_unlabeled_samples

  def _encode_data(
      self,
      samples_per_label
      ):

    """ Encodes non-verbalized data (no pattern, plain special tokens).

    Args:
      samples_per_label: {label: [sample_dict, ...]}.

    Returns:
      flat list of {"token_ids", "label"} dicts covering every label, in
      sorted label order.
    """

    encoded_labeled_samples = list()
    # Leave room for the special tokens added around a single sequence
    max_text_len = max(0, self.config.max_seq_len -
                       self.verbalizer.num_special_tokens_added_single)

    for label in sorted(samples_per_label.keys()):
      for sample in samples_per_label[label]:
        token_ids = self.tokenizer.encode(
            sample["text"], add_special_tokens=False)[:max_text_len]
        encoded_labeled_samples.append({
            "token_ids": self.tokenizer.build_inputs_with_special_tokens(
                token_ids),
            "label": sample["label"]
            })

    return encoded_labeled_samples

  def _add_masks_and_padding(
      self,
      samples,
      is_pet_sample=True
      ):

    """ Adds attention masks, padding and label arrays (modifies `samples`
    in place and returns it).

    Sequences shorter than config.max_seq_len are right-padded with the pad
    token. PET samples get a one-hot "labels_array"; other samples get [].
    """

    for smp in samples:
      seq_len = len(smp["token_ids"])
      smp["attention_mask"] = [1] * seq_len

      pad_size = self.config.max_seq_len - seq_len
      if pad_size > 0:
        smp["token_ids"] += [self.tokenizer.pad_token_id] * pad_size
        smp["attention_mask"] += [0] * pad_size

      if is_pet_sample:
        pet_labels = [0] * self.num_labels
        pet_labels[smp["label"]] = 1
        smp["labels_array"] = pet_labels
      else:
        smp["labels_array"] = []

    return samples

  def _convert_sample_to_model_inputs(
      self,
      samples
      ):

    """ Converts encoded samples into a NewDataset of long tensors.

    Each sample is a dict with the keys "token_ids", "attention_mask",
    "label" (int) and "labels_array" (list, may be empty).
    """

    feature_dict = {
        "input_ids": torch.tensor(
            [s["token_ids"] for s in samples], dtype=torch.long),
        "attention_masks": torch.tensor(
            [s["attention_mask"] for s in samples], dtype=torch.long),
        "labels": torch.tensor(
            [s["label"] for s in samples], dtype=torch.long),
        "labels_arrays": torch.tensor(
            [s["labels_array"] for s in samples], dtype=torch.long)}

    return NewDataset(feature_dict["input_ids"],
                      feature_dict["attention_masks"],
                      feature_dict["labels"],
                      feature_dict["labels_arrays"])

  def get_dataloader(
      self,
      pattern_id,
      labeled_samples,
      unlabeled_samples,
      train_batch_size,
      eval_batch_size,
      use_mlm_training=False,
      is_eval=False,
      get_sequential_unlabeled_data=False,
      no_verbalization=False
      ):

    """ Creates dataloaders for labeled (and optionally unlabeled) samples.

    Args:
      pattern_id: id of the PET pattern used for verbalization.
      labeled_samples: {label: [sample_dict, ...]} from sample_data.
      unlabeled_samples: {label: [sample_dict, ...]}, [] or None.
      train_batch_size: batch size of the random labeled dataloader.
      eval_batch_size: batch size of the sequential dataloaders.
      use_mlm_training: also build a random unlabeled dataloader (batch size
        config.mlm_batch_size) for auxiliary MLM training.
      is_eval: build a sequential labeled dataloader for evaluation instead.
      get_sequential_unlabeled_data: also build a sequential unlabeled
        dataloader.
      no_verbalization: encode labeled samples without a pattern (supervised
        classifier); unlabeled samples are then left unencoded.

    Returns:
      (labeled_dataloader, unlabeled_dataloader, seq_unlabeled_dataloader);
      dataloaders that were not requested are None.
    """

    unlabeled_samples = [] if unlabeled_samples is None else unlabeled_samples
    labeled_dataloader, unlabeled_dataloader, seq_unlabeled_dataloader = (
        None, None, None)

    if is_eval:
      # Sequential dataloader for evaluation
      if no_verbalization:
        labeled_samples = self._encode_data(labeled_samples)
      else:
        labeled_samples, _ = self._verbalize_and_encode_data(
            pattern_id, labeled_samples, [])

      padded_labeled_samples = self._add_masks_and_padding(
          labeled_samples, is_pet_sample=(not no_verbalization))
      labeled_dataset = self._convert_sample_to_model_inputs(
          padded_labeled_samples)
      labeled_dataloader = DataLoader(labeled_dataset,
                                      sampler=SequentialSampler(labeled_dataset),
                                      batch_size=eval_batch_size)
      return labeled_dataloader, unlabeled_dataloader, seq_unlabeled_dataloader

    # Random dataloader for training
    if no_verbalization:
      # Required for training supervised models
      labeled_samples = self._encode_data(labeled_samples)
    else:
      labeled_samples, unlabeled_samples = self._verbalize_and_encode_data(
          pattern_id, labeled_samples, unlabeled_samples)

    if len(labeled_samples) > 0:
      padded_labeled_samples = self._add_masks_and_padding(
          labeled_samples, is_pet_sample=(not no_verbalization))
      labeled_dataset = self._convert_sample_to_model_inputs(
          padded_labeled_samples)
      labeled_dataloader = DataLoader(labeled_dataset,
                                      sampler=RandomSampler(labeled_dataset),
                                      batch_size=train_batch_size)

    if use_mlm_training or get_sequential_unlabeled_data:
      if use_mlm_training:
        # An empty unlabeled set cannot feed the random MLM sampler
        assert len(unlabeled_samples) > 0, "Unlabeled samples required for auxiliary MLM training!"
      padded_unlabeled_samples = self._add_masks_and_padding(
          unlabeled_samples, is_pet_sample=False)
      unlabeled_dataset = self._convert_sample_to_model_inputs(
          padded_unlabeled_samples)

      if use_mlm_training:
        unlabeled_dataloader = DataLoader(
            unlabeled_dataset,
            sampler=RandomSampler(unlabeled_dataset),
            batch_size=self.config.mlm_batch_size)
      if get_sequential_unlabeled_data:
        seq_unlabeled_dataloader = DataLoader(
            unlabeled_dataset,
            sampler=SequentialSampler(unlabeled_dataset),
            batch_size=eval_batch_size)

    return labeled_dataloader, unlabeled_dataloader, seq_unlabeled_dataloader


In [None]:
# Define experiment configuration


class Config(object):

  """ Default hyper-parameters and settings for the PET language models, the
  sequence classifiers and iPET.

  Every constructor argument is stored verbatim as an attribute of the same
  name. The "cls_" prefix marks sequence-classifier settings; the "pet_" and
  "ipet_" prefixes mark settings specific to PET and iPET.
  """

  def __init__(
      self,
      max_seq_len=256,
      train_batch_size=1,
      eval_batch_size=8,
      mlm_batch_size=3,
      gradient_accumulation_steps=4,
      num_train_steps=250,
      learning_rate=1e-5,
      weight_decay=0.01,
      adam_epsilon=1e-8,
      max_grad_norm=1.0,
      num_unlabeled_sampes_per_label=10000,
      cls_max_seq_len=256,
      cls_train_batch_size=64,
      cls_eval_batch_size=128,
      cls_gradient_accumulation_steps=4,
      cls_num_train_steps=1000,
      cls_learning_rate=1e-5,
      cls_weight_decay=0.01,
      cls_adam_epsilon=1e-8,
      cls_max_grad_norm=1.0,
      pet_cls_train_batch_size=32,
      pet_cls_eval_batch_size=64,
      pet_models_per_pattern=3,
      pet_alpha=10**-4,
      pet_temperature=2,
      ipet_num_generations=3,
      ipet_models_fraction=0.25,
      ipet_train_scale_factor=5,
      ipet_num_first_gen_samples=10,
      no_cuda=False,
      random_seed=42,
      logging_steps=50,
      checkpoint="FacebookAI/roberta-large",
      dataset="yelp_review_full"
      ):

    # (attribute name, value) pairs, grouped by the component using them;
    # attributes are set in exactly this order
    settings = [
        # MLM
        ("max_seq_len", max_seq_len),
        ("train_batch_size", train_batch_size),
        ("eval_batch_size", eval_batch_size),
        ("mlm_batch_size", mlm_batch_size),
        ("gradient_accumulation_steps", gradient_accumulation_steps),
        ("num_train_steps", num_train_steps),
        ("learning_rate", learning_rate),
        ("weight_decay", weight_decay),
        ("adam_epsilon", adam_epsilon),
        ("max_grad_norm", max_grad_norm),
        # Sequence classifier
        ("cls_max_seq_len", cls_max_seq_len),
        ("cls_train_batch_size", cls_train_batch_size),
        ("cls_eval_batch_size", cls_eval_batch_size),
        ("cls_gradient_accumulation_steps", cls_gradient_accumulation_steps),
        ("cls_num_train_steps", cls_num_train_steps),
        ("cls_learning_rate", cls_learning_rate),
        ("cls_weight_decay", cls_weight_decay),
        ("cls_adam_epsilon", cls_adam_epsilon),
        ("cls_max_grad_norm", cls_max_grad_norm),
        # PET
        ("pet_cls_train_batch_size", pet_cls_train_batch_size),
        ("pet_cls_eval_batch_size", pet_cls_eval_batch_size),
        ("pet_models_per_pattern", pet_models_per_pattern),
        ("pet_alpha", pet_alpha),
        ("pet_temperature", pet_temperature),
        # iPET
        ("ipet_num_generations", ipet_num_generations),
        ("ipet_models_fraction", ipet_models_fraction),
        ("ipet_train_scale_factor", ipet_train_scale_factor),
        ("ipet_num_first_gen_samples", ipet_num_first_gen_samples),
        # Other
        ("num_unlabeled_sampes_per_label", num_unlabeled_sampes_per_label),
        ("checkpoint", checkpoint),
        ("dataset", dataset),
        ("random_seed", random_seed),
        ("no_cuda", no_cuda),
        ("logging_steps", logging_steps),
        ]

    for attr_name, attr_value in settings:
      setattr(self, attr_name, attr_value)


In [None]:
# Shared utility function
def compute_accuracy(
    preds,
    labels
    ):

  """ Returns the fraction of rows of `preds` whose arg-max over the last
  dimension equals the matching entry of `labels`, as a Python float. """

  predicted_labels = preds.argmax(dim=-1)
  hits = predicted_labels.eq(labels)
  return hits.float().mean().item()


In [None]:
# Define model trainers and evaluators

class SingleModelTrainer(object):

  """ Trains and evaluates a single PET model. """

  def __init__(
      self,
      config,
      tokenizer,
      verbalizer,
      pattern_id,
      model_id,
      model_path=None,
      train_pet_classifier=False,
      train_sup_classifier=False
      ):

    """ Loads the model, selects the device and picks the hyper-parameters
    for one training run.

    Args:
      config: Config instance holding all hyper-parameters.
      tokenizer: tokenizer shared with the data pipeline.
      verbalizer: verbalizer providing the label set (`label_dict`).
      pattern_id: PET pattern id, used only to build the model name.
      model_id: model index for that pattern, used only for the model name.
      model_path: optional location to load the pre-training model from;
        only consulted when neither classifier flag is set.
      train_pet_classifier: train a sequence classifier named
        "PET_CLASSIFIER" (uses the "cls_" and "pet_cls_" settings).
      train_sup_classifier: train a supervised sequence classifier (uses the
        "cls_" settings).
    """

    self.config = config
    self.tokenizer = tokenizer
    self.verbalizer = verbalizer
    self.train_pet_classifier = train_pet_classifier
    self.train_sup_classifier = train_sup_classifier
    is_classifier = train_pet_classifier or train_sup_classifier

    # Load the model to fine-tune
    if is_classifier:
      self.model = AutoModelForSequenceClassification.from_pretrained(
          self.config.checkpoint,
          num_labels=len(self.verbalizer.label_dict.keys()))
    else:
      model_source = self.config.checkpoint if model_path is None else model_path
      self.model = AutoModelForPreTraining.from_pretrained(model_source)

    # Use the GPU unless it is missing or disabled in the config
    use_gpu = torch.cuda.is_available() and not self.config.no_cuda
    self.device = "cuda" if use_gpu else "cpu"
    if use_gpu:
      print("CUDA is available, using GPU for model training and evaluation.")
    else:
      print("CUDA is unavailable, using CPU for model training and evaluation.")
    self.model.to(self.device)

    # Classifiers read the "cls_"-prefixed optimisation settings, the PET
    # language model the unprefixed ones
    setting_prefix = "cls_" if is_classifier else ""
    for setting in ("max_seq_len", "gradient_accumulation_steps",
                    "num_train_steps", "learning_rate", "weight_decay",
                    "adam_epsilon", "max_grad_norm"):
      setattr(self, setting, getattr(self.config, setting_prefix + setting))

    # Batch sizes have a dedicated prefix for the PET classifier
    if train_pet_classifier:
      batch_prefix = "pet_cls_"
    elif train_sup_classifier:
      batch_prefix = "cls_"
    else:
      batch_prefix = ""
    self.train_batch_size = getattr(self.config, batch_prefix + "train_batch_size")
    self.eval_batch_size = getattr(self.config, batch_prefix + "eval_batch_size")

    # Define model name
    if train_pet_classifier:
      self.model_name = "PET_CLASSIFIER"
    elif train_sup_classifier:
      self.model_name = "SUPERVISED_CLASSIFIER"
    elif pattern_id is None or model_id is None:
      self.model_name = "PET_LM"
    else:
      self.model_name = "PET_LM_pattern_{:d}_model_{:d}".format(
          pattern_id, model_id)


  def train(
      self,
      labeled_dataloader,
      unlabeled_dataloader,
      use_mlm_training=False
      ):

    """ Trains the specified model.

    Args:
      labeled_dataloader: DataLoader over labeled samples; batches are dicts
        with at least "input_ids", "attention_mask" and "label".
      unlabeled_dataloader: DataLoader over unlabeled samples for the
        auxiliary MLM objective, or None. Must be None for classifiers.
      use_mlm_training: whether to add the auxiliary MLM loss.

    Returns:
      Tuple (global_step, mean_training_loss): the number of optimizer
      updates performed and the mean training loss per update.
    """

    # Number of epochs needed to reach num_train_steps optimizer updates. The
    # accumulation counter runs across epochs, so count batches rather than
    # whole updates per epoch (ceiling division, at least one epoch)
    batches_needed = self.num_train_steps * self.gradient_accumulation_steps
    num_train_epochs = max(
        1, -(-batches_needed // max(1, len(labeled_dataloader))))

    # !! NOTE: This code below is *partially* copied from the paper's codebase,
    # since it's boilerplate model training code !!

    # Prepare optimizer and schedule (linear warmup and decay); use the weight
    # decay selected for this model type rather than the LM-only setting
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in self.model.named_parameters() if not any(
            nd in n for nd in no_decay)],
         'weight_decay': self.weight_decay},
        {'params': [p for n, p in self.model.named_parameters() if any(
            nd in n for nd in no_decay)],
         'weight_decay': 0.0}]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                  lr=self.learning_rate,
                                  eps=self.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=self.num_train_steps)

    # Initialize training variables
    step, global_step = 0, 0
    tr_loss, last_reported_loss = 0.0, 0.0
    is_classifier = self.train_pet_classifier or self.train_sup_classifier
    # Reset model gradient
    self.model.zero_grad()

    # Unlabeled data is only used for auxiliary MLM training of PET models
    if is_classifier:
      assert unlabeled_dataloader is None, "Unlabeled data is not compatible with classifier training!"
    if unlabeled_dataloader is None:
      assert use_mlm_training is False, "Can't execute auxiliary MLM training if no unlabeled data is provided!"
    else:
      unlabeled_iter = iter(unlabeled_dataloader)

    print("Starting training model: {:s}".format(self.model_name))

    # Report the hyper-parameters actually used for this model type
    print('Hyperparameters:')
    print("max_seq_len: {}".format(self.max_seq_len))
    print("gradient_accumulation_steps: {}".format(
        self.gradient_accumulation_steps))
    print("num_train_steps: {}".format(self.num_train_steps))
    print("learning_rate: {}".format(self.learning_rate))
    print("weight_decay: {}".format(self.weight_decay))
    print("adam_epsilon: {}".format(self.adam_epsilon))
    print("max_grad_norm: {}".format(self.max_grad_norm))
    print("train_batch_size: {}".format(self.train_batch_size))
    print("eval_batch_size: {}".format(self.eval_batch_size))

    stop_training = False
    train_iterator = trange(num_train_epochs, desc="Epoch")
    for _ in train_iterator:
      epoch_iterator = tqdm(labeled_dataloader, desc="Iteration")
      for batch in epoch_iterator:

        # Training mode
        self.model.train()

        # Pass only real model arguments; the batch also carries entries such
        # as "label" / "labels_array" that the model does not accept
        labeled_batch_inputs = {
            "input_ids": batch["input_ids"],
            "attention_mask": batch["attention_mask"]
            }
        if self.train_sup_classifier:
          labeled_batch_inputs["labels"] = batch["label"]
        labeled_batch_inputs = {
            k: t.to(self.device) for k, t in labeled_batch_inputs.items()}
        outputs = self.model(**labeled_batch_inputs)

        # Compute model loss
        if not is_classifier:
          # PET loss: cross-entropy between the per-sample logits of the label
          # verbalizations (assumed (batch, num_labels), as in eval's accuracy
          # computation) and the gold label index
          cls_logits = self._get_logits_at_pet_labels(
              labeled_batch_inputs["input_ids"], outputs[0])
          loss = torch.nn.CrossEntropyLoss()(
              cls_logits.view(-1, cls_logits.size(-1)),
              batch["label"].to(self.device).view(-1))
        elif self.train_pet_classifier:
          # Temperature-scaled KL-divergence loss for the PET classifier
          # NOTE(review): NewDataset batches expose "label" / "labels_array",
          # not "labels"; confirm which dataset feeds the PET classifier
          targets = batch["labels"].float().to(self.device)
          model_log_probs = torch.nn.functional.log_softmax(
              outputs[0] / self.config.pet_temperature, dim=-1)
          loss = torch.nn.functional.kl_div(
              model_log_probs, targets, reduction="sum") * (
                  self.config.pet_temperature ** 2) / model_log_probs.shape[0]
        else:
          # Obtain supervised classifier loss
          loss = outputs.loss

        if use_mlm_training:
          # Fetch the next unlabeled batch, restarting the iterator when the
          # unlabeled data is exhausted
          try:
            unlabeled_batch = next(unlabeled_iter)
          except StopIteration:
            unlabeled_iter = iter(unlabeled_dataloader)
            unlabeled_batch = next(unlabeled_iter)
          mlm_input_ids, mlm_labels = self._prepare_mlm_data(
              unlabeled_batch["input_ids"])

          # Auxiliary MLM loss, using the unlabeled batch's own attention mask
          unlabeled_batch_inputs = {
              "input_ids": mlm_input_ids,
              "attention_mask": unlabeled_batch["attention_mask"],
              "labels": mlm_labels
              }
          unlabeled_batch_inputs = {
              k: t.to(self.device) for k, t in unlabeled_batch_inputs.items()}
          lm_loss = self.model(**unlabeled_batch_inputs)[0]
          # Combine losses as in the PET paper: (1 - alpha) * L_CE + alpha * L_MLM
          loss = (1 - self.config.pet_alpha) * loss + (
              self.config.pet_alpha * lm_loss)

        if self.gradient_accumulation_steps > 1:
          loss = loss / self.gradient_accumulation_steps

        # Track training loss
        tr_loss += loss.item()

        # Get gradients
        loss.backward()
        step += 1

        # Update model after every N accumulated batches
        if step % self.gradient_accumulation_steps == 0:
          torch.nn.utils.clip_grad_norm_(
              self.model.parameters(), self.max_grad_norm)
          optimizer.step()
          scheduler.step()
          self.model.zero_grad()
          global_step += 1

          if (self.config.logging_steps > 0 and
              global_step % self.config.logging_steps == 0):
            curr_mean_loss = (
                tr_loss - last_reported_loss) / self.config.logging_steps
            curr_learning_rate = scheduler.get_last_lr()[0]
            last_reported_loss = tr_loss
            print("\tGlobal step: {:d} | LR: {:.2e} | Avg. loss: {:.3f}".format(
                global_step, curr_learning_rate, curr_mean_loss))

          # Stop after the configured number of optimizer updates
          if 0 < self.num_train_steps <= global_step:
            stop_training = True
            epoch_iterator.close()
            break

      if stop_training:
        train_iterator.close()
        break

    # Report mean training loss (per optimizer update) for weighted logit
    # combination
    mean_training_loss = tr_loss / global_step if global_step > 0 else 0.0
    print("== Finished training model: {:s} ==".format(self.model_name))
    print("\tMean training loss: {:.3f}".format(mean_training_loss))

    return global_step, mean_training_loss


  def eval(
      self,
      eval_dataloader,
      is_unlabeled,
      return_logits,
      pet_samples=False
      ):

    """ Evaluates the specified model / annotates data with logits.

    Args:
      eval_dataloader: yields dict batches holding "input_ids",
        "attention_mask" and "label" entries.
      is_unlabeled: if True, no accuracy is computed (annotation mode).
      return_logits: if True, logits and labels of every batch are collected
        and returned as numpy arrays concatenated along the first axis.
      pet_samples: if True, the model output is reduced to the logits of the
        verbalizer's label tokens at the MASK position.

    Returns:
      dict with "logits" and "labels" (numpy arrays when return_logits is True
      and at least one batch was seen, empty lists otherwise) and "eval_acc",
      the mean over batches of compute_accuracy's result (0. in annotation
      mode or when the dataloader is empty).
    """

    # !! NOTE: This code below is *partially* copied from the paper's codebase,
    # since it's boilerplate model evaluation code !!

    # Collect for evaluation
    all_logits = list()
    all_labels = list()

    eval_step = 0
    eval_acc = 0.

    # Evaluation mode; set once for the whole pass
    self.model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating / Annotating"):
      with torch.no_grad():
        batch_inputs = {
            "input_ids": batch["input_ids"].to(self.device),
            "attention_mask": batch["attention_mask"].to(self.device)
            }
        labels = batch["label"].to(self.device)
        # First element of the model output holds the logits
        logits = self.model(**batch_inputs)[0]
        if pet_samples:
          logits = self._get_logits_at_pet_labels(
              batch_inputs["input_ids"], logits)

      # Compute model accuracy (unless used for annotation of unlabeled data)
      if not is_unlabeled:
        eval_acc += compute_accuracy(logits, labels)

      if return_logits:
        all_logits.append(logits.detach().cpu().numpy())
        all_labels.append(labels.detach().cpu().numpy())
      eval_step += 1

    # np.concatenate rejects an empty list, so only merge if batches were seen
    if return_logits and all_logits:
      all_logits = np.concatenate(all_logits, axis=0)
      all_labels = np.concatenate(all_labels, axis=0)

    # Report model accuracy on the evaluation set
    print("== Finished evaluating model: {:s} ==".format(self.model_name))
    mean_eval_acc = 0.
    if not is_unlabeled and eval_step > 0:
      # NOTE: mean of per-batch results; an uneven last batch is weighted
      # like a full one
      mean_eval_acc = eval_acc / eval_step
      print("\tMean evaluation acc: {:.3f}".format(mean_eval_acc))

    return {
        'logits': all_logits,
        'labels': all_labels,
        'eval_acc': mean_eval_acc
        }


  def _get_logits_at_pet_labels(
      self,
      input_ids,
      model_logits
      ):

    """ Gathers model logits corresponding to verbalizations of task labels. """

    # Assumes input of shape [batch_size, time_steps, vocab_size]
    # Gather model predictions at the masked positions -> [batch x vocab_size]
    logits_at_mask = model_logits[input_ids == self.tokenizer.mask_token_id]
    logits_at_labels_list = list()
    for batch_item_logits in logits_at_mask:
      logits_at_labels_list.append(
          batch_item_logits[self.verbalizer.label_token_ids])

    # Stack gathered logits to be same shape as label -> [batch x num_labels]
    return torch.stack(logits_at_labels_list, dim=0)


  def _prepare_mlm_data(
      self,
      input_ids,
      ):

    """ Transforms unlabeled data for use with the MLM training objective;
    Transformations are done per sample, not batch. """

    # Create a label tensor for the MLM objective
    mlm_labels = torch.ones_like(input_ids) * -100

    # Select positions to be transformed
    input_shape = torch.tensor(input_ids.shape)

    # 15% of tokens per line are modified
    num_tokens_per_line = input_shape[1]
    num_samples = int(np.round(num_tokens_per_line * 0.15))
    for b in range(input_shape[0]):
      # Exclude masked positions from transformations
      valid_positions = list()
      for t in range(num_tokens_per_line):
          if input_ids[b, t] != self.tokenizer.mask_token_id:
            valid_positions.append(t)

      # Sample positions to be transformed
      sampled_token_positions = random.sample(valid_positions, num_samples)

      # Modify sampled token
      for t in sampled_token_positions:
        # Designate as a target within the label
        mlm_labels[b, t] = input_ids[b, t]

        # Decide on transformation
        flip = random.uniform(0., 1.)
        if flip < 0.8:
          # Replace by a MASK token in 80% of cases
          input_ids[b, t] = self.tokenizer.mask_token_id
        else:
          if flip < 0.9:
            # Replace with a random token in 10% of cases
            input_ids[b, t] = random.randint(0, len(self.tokenizer) - 1)
          else:
            # Leave unchanged in 10% of cases
            continue

    return input_ids, mlm_labels

In [None]:
class MultiModelTrainer(object):

  """ Trains and evaluates a PET ensemble (+ optional classifier), and unsupervised baselines.

  Model directories and pickled unlabeled-data logits are written below
  `root_dir`; per-model evaluation logits are kept in `self.model_records`
  for ensembling.
  """

  def __init__(
      self,
      config,
      data_initializer,
      data_processor,
      tokenizer,
      verbalizer,
      num_training_samples,
      eval_dataloaders,
      use_mlm_training,
      weighting_strategy="mean",
      root_dir="/content/"
      ):

    """ Samples the labeled and unlabeled data shared by ALL ensemble models.

    Args:
      config: experiment configuration (batch sizes, ensemble sizes, iPET and
        temperature settings, ...).
      data_initializer: stored; not used by the methods of this class.
      data_processor: provides train_data_per_label, num_labels,
        sample_data() and get_dataloader().
      tokenizer: tokenizer shared by all models.
      verbalizer: pattern-verbalizer object exposing pattern_dict.
      num_training_samples: number of labeled samples, split evenly across
        labels (0 runs zero-shot).
      eval_dataloaders: dict mapping each pattern id, and "cls", to an
        evaluation dataloader.
      use_mlm_training: whether PET models also train with the MLM objective.
      weighting_strategy: stored as self.weighting_strategy; the training and
        evaluation methods take their own strategy argument.
      root_dir: directory receiving model directories and logits.
    """

    self.root_dir = root_dir
    self.config = config
    self.weighting_strategy = weighting_strategy

    self.data_initializer = data_initializer
    self.data_processor = data_processor
    self.tokenizer = tokenizer
    self.verbalizer = verbalizer

    self.use_mlm_training = use_mlm_training
    self.num_training_samples = num_training_samples
    self.eval_dataloaders = eval_dataloaders

    # Prepare training data for all PET patterns
    self.training_data_per_label = self.data_processor.train_data_per_label

    # Sample PET and unlabeled data to be used across ALL models in the ensemble
    self.num_samples_per_label = (
        self.num_training_samples // self.data_processor.num_labels)
    self.pet_train_samples, self.unlabeled_train_samples = (
        self.data_processor.sample_data(
            self.training_data_per_label, self.num_samples_per_label,
            is_eval=False))

    # Keep track of individual models
    self.model_records = dict()


  @staticmethod
  def _save_logits(
      logits_dict,
      out_path
      ):

    """ Saves logits and labels of the soft-labeled data to disc. """

    with open(out_path, 'wb') as pf:
      pickle.dump(logits_dict, pf, protocol=pickle.HIGHEST_PROTOCOL)
    print("Saved logits to {:s}".format(out_path))


  @staticmethod
  def _load_logits(
      logits_path
      ):

    """ Loads logits and labels of soft-labeled data from disc.

    NOTE: unpickling can execute arbitrary code; only load files written by
    _save_logits.
    """

    print("Loading logits from {:s}".format(logits_path))
    with open(logits_path, 'rb') as pf:
      return pickle.load(pf)


  @staticmethod
  def _save_model(
      model,
      out_path
      ):

    """ Saves a model checkpoint to disk (unwrapping a `.module` wrapper). """

    model = model.module if hasattr(model, 'module') else model
    model.save_pretrained(out_path)


  def _unlabeled_logits_dir(
      self,
      model_id,
      pattern_id
      ):

    """ Returns the directory holding a model's unlabeled-data logits. """

    return os.path.join(
        self.root_dir, "pet_model_{}_pattern_{}/".format(model_id, pattern_id))


  def _save_unlabeled_logits(
      self,
      logits_and_labels,
      model_id,
      pattern_id
      ):

    """ Pickles a model's unlabeled-data annotations for classifier training. """

    out_dir = self._unlabeled_logits_dir(model_id, pattern_id)
    os.makedirs(out_dir, exist_ok=True)
    self._save_logits(
        logits_and_labels, os.path.join(out_dir, 'unlabeled_logits.pickle'))


  def _iter_saved_unlabeled_logits(
      self
      ):

    """ Yields every saved unlabeled-data logit record, skipping missing files. """

    for pattern_id in self.verbalizer.pattern_dict.keys():
      for model_id in range(self.config.pet_models_per_pattern):
        logits_path = os.path.join(
            self._unlabeled_logits_dir(model_id, pattern_id),
            "unlabeled_logits.pickle")
        try:
          record = self._load_logits(logits_path)
        except FileNotFoundError:
          continue
        yield record


  def _get_pet_dataloaders(
      self,
      pattern_id,
      pet_samples,
      unlabeled_samples
      ):

    """ Returns (train PET, train unlabeled, sequential unlabeled) dataloaders
    built with the given pattern. """

    return self.data_processor.get_dataloader(
        pattern_id, pet_samples, unlabeled_samples,
        self.config.train_batch_size, self.config.eval_batch_size,
        use_mlm_training=self.use_mlm_training, is_eval=False,
        get_sequential_unlabeled_data=True, no_verbalization=False)


  def _sample_annotator_models(
      self,
      num_sampled_models
      ):

    """ Draws `num_sampled_models` of the
    num_patterns * pet_models_per_pattern model slots without replacement.

    Returns:
      dict pattern_id -> number of models to use for that pattern.
    """

    pattern_candidates = (list(self.verbalizer.pattern_dict.keys()) *
                          self.config.pet_models_per_pattern)
    sampled_pattern_ids = random.sample(pattern_candidates, num_sampled_models)
    return {pattern_id: sampled_pattern_ids.count(pattern_id)
            for pattern_id in sorted(set(sampled_pattern_ids))}


  @staticmethod
  def _get_pattern_weight(
      model_trainer,
      train_pet_dataloader,
      weighting_strategy
      ):

    """ Returns a pattern's ensemble weight: for "weighted", the accuracy of
    the not-yet-trained model on the training data; 1. otherwise. """

    if weighting_strategy != "weighted":
      return 1.
    logits_and_labels = model_trainer.eval(
        train_pet_dataloader, is_unlabeled=False,
        return_logits=False, pet_samples=True)
    return logits_and_labels["eval_acc"]


  def _merge_logits(
      self,
      weighting_strategy,
      model_logits=None,
      for_classifier=False
      ):

    """ Merges logits for classifier training or PET ensemble evaluation.

    Args:
      weighting_strategy: "weighted" weights every model's logits by its
        "model_init_train_acc"; any other value averages them uniformly.
      model_logits: dict model_key -> record with "logits",
        "model_init_train_acc" and "eval_labels"; defaults to
        self.model_records. Ignored when for_classifier is True.
      for_classifier: if True, merge the unlabeled-data logits saved to disc
        instead (loaded one file at a time and summed on the fly).

    Returns:
      (log_softmax of the merged logits divided by config.pet_temperature,
      labels of the first merged record).

    Raises:
      ValueError: if there are no in-memory records to merge.
      FileNotFoundError: if for_classifier is True and no saved logits exist.
    """

    use_weights = weighting_strategy == "weighted"

    if for_classifier:
      print("Merging unlabelled logits ...")
      records = self._iter_saved_unlabeled_logits()
      labels_key = "labels"
    else:
      model_logits = self.model_records if model_logits is None else model_logits
      records = model_logits.values()
      labels_key = "eval_labels"

    weighted_sum, total_weight, num_models, labels = None, 0., 0, None
    for record in records:
      weight = record["model_init_train_acc"] if use_weights else 1.
      contribution = np.asarray(record["logits"]) * weight
      if weighted_sum is None:
        weighted_sum = contribution
        # All models annotate the same samples in the same order
        labels = record[labels_key]
      else:
        weighted_sum = weighted_sum + contribution
      total_weight += weight
      num_models += 1

    if num_models == 0:
      if for_classifier:
        raise FileNotFoundError(
            "No saved unlabeled logits found under {:s}".format(self.root_dir))
      raise ValueError("No model logits to merge.")

    # Normalize by the sum of weights, as in the PET paper (the plain mean for
    # uniform weights); fall back to the model count if all weights are 0
    denominator = total_weight if total_weight > 0 else num_models
    merged_logits = torch.tensor(weighted_sum / denominator)

    # Temperature-scaled log-probabilities
    merged_softmax = torch.nn.functional.log_softmax(
        merged_logits / self.config.pet_temperature, dim=-1)

    return merged_softmax, labels


  def train_pet_ensemble(
      self,
      weighting_strategy="uniform",
      get_unlabeled_logits=False
      ):

    """ Trains a PET model ensemble.

    Trains config.pet_models_per_pattern models per pattern (training is
    skipped when there are no labeled samples), evaluates each on its
    pattern's evaluation data and stores the results in self.model_records.
    With get_unlabeled_logits, every model also annotates the unlabeled pool
    and the logits are pickled for classifier training.

    Returns:
      dict pattern_id -> dict model_id -> evaluation accuracy.
    """

    out = dict()

    for pattern_id in self.verbalizer.pattern_dict.keys():
      out[pattern_id] = dict()
      # Ensemble weight of the pattern, computed with its first model
      pattern_weight = None
      dataloaders = None

      for model_id in range(self.config.pet_models_per_pattern):
        model_key = "pet_pattern_{:d}_model_{:d}/".format(pattern_id, model_id)
        # Make a directory for this model's outputs
        pwd = os.path.join(self.root_dir, "num_samples_{:d}_{:s}/".format(
            self.num_training_samples, model_key))
        os.makedirs(pwd, exist_ok=True)

        # Initialize model
        model_trainer = SingleModelTrainer(
            self.config, self.tokenizer, self.verbalizer, pattern_id, model_id)

        # The data only depends on the pattern: build it once per pattern
        if dataloaders is None:
          dataloaders = self._get_pet_dataloaders(
              pattern_id, self.pet_train_samples, self.unlabeled_train_samples)
        (train_pet_dataloader, train_unlabeled_dataloader,
         sequential_unlabeled_dataloader) = dataloaders

        # Obtain "pre-training" accuracy to be used for model weighting
        if pattern_weight is None:
          pattern_weight = self._get_pattern_weight(
              model_trainer, train_pet_dataloader, weighting_strategy)

        if self.num_training_samples > 0:
          model_trainer.train(
              train_pet_dataloader, train_unlabeled_dataloader,
              self.use_mlm_training)

        # Evaluate a single model (on the eval set)
        logits_and_labels = model_trainer.eval(
            self.eval_dataloaders[pattern_id], is_unlabeled=False,
            return_logits=True, pet_samples=True)
        self.model_records[model_key] = {
            "model_init_train_acc": pattern_weight,
            "logits": logits_and_labels["logits"],
            "eval_labels": logits_and_labels["labels"]}
        out[pattern_id][model_id] = logits_and_labels["eval_acc"]

        if get_unlabeled_logits:
          print("Labeling unlabeled data!")
          unlabeled_logits_and_labels = model_trainer.eval(
              sequential_unlabeled_dataloader, is_unlabeled=True,
              return_logits=True, pet_samples=True)
          unlabeled_logits_and_labels["model_init_train_acc"] = pattern_weight
          print("Saving labeled unlabeled data!")
          self._save_unlabeled_logits(
              unlabeled_logits_and_labels, model_id, pattern_id)

    print('--- Finished training the PET ensemble! --- ')

    return out


  def _get_next_ipet_data(
      self,
      merged_logits,
      num_pet_samples_per_label,
      already_sampled,
      zero_shot=False,
      unlabeled_samples=None,
      num_most_likely=100
      ):

    """ Draws ensemble-labeled samples for training the next iPET generation.

    Candidates for every label come from the WHOLE unlabeled pool (gold labels
    are never consulted) and are drawn without replacement with probability
    proportional to the ensemble's probability for that label.

    Args:
      merged_logits: [num_samples x num_labels] ensemble log-probabilities, one
        row per sample of the pool in sorted-key, then list, order (assumed to
        be the order of the sequential dataloader that produced them).
      num_pet_samples_per_label: samples to draw per label; capped at the
        number of available candidates.
      already_sampled: row ids that must not be drawn; extended in place.
      zero_shot: if True, a label's candidates are the `num_most_likely` rows
        with the highest probability for it, even if it is not their top
        label; otherwise the rows whose top label it is.
      unlabeled_samples: the annotated pool (dict key -> list of samples);
        defaults to self.unlabeled_train_samples.
      num_most_likely: zero-shot candidate pool size per label.

    Returns:
      (new_samples, new_unlabeled_samples, already_sampled): drawn samples per
      (predicted) label, and every undrawn sample under its original key.

    Raises:
      ValueError: if merged_logits does not have one row per pool sample.
    """

    pool = (self.unlabeled_train_samples if unlabeled_samples is None
            else unlabeled_samples)
    group_keys = sorted(pool.keys())
    # Row id -> (group key, sample) in annotation order
    rows = [(key, sample) for key in group_keys for sample in pool[key]]

    scores = torch.as_tensor(merged_logits).detach().cpu().numpy()
    if scores.shape[0] != len(rows):
      raise ValueError(
          "Got logits for {:d} samples, but the unlabeled pool has {:d}.".format(
              scores.shape[0], len(rows)))
    predicted_labels = scores.argmax(axis=-1)
    # Avoid sampling the same datapoint for more than one label
    excluded = set(already_sampled)

    new_samples = dict()
    for label in range(scores.shape[-1]):
      candidates = [c_id for c_id in range(len(rows)) if c_id not in excluded]
      if zero_shot:
        candidates = sorted(
            candidates, key=lambda c_id: scores[c_id, label],
            reverse=True)[:num_most_likely]
      else:
        candidates = [c_id for c_id in candidates
                      if predicted_labels[c_id] == label]

      num_draws = min(int(num_pet_samples_per_label), len(candidates))
      sampled_ids = list()
      if num_draws > 0:
        # exp of log-probabilities, shifted by the max for numerical stability
        log_probs = scores[candidates, label].astype(np.float64)
        probs = np.exp(log_probs - log_probs.max())
        sampled_ids = np.random.choice(
            candidates, size=num_draws, replace=False,
            p=probs / probs.sum()).tolist()

      excluded.update(sampled_ids)
      already_sampled += sampled_ids
      new_samples[label] = [
          {"text": rows[c_id][1]["text"], "label": label}
          for c_id in sorted(sampled_ids)]

    new_unlabeled_samples = {key: list() for key in group_keys}
    for c_id, (key, sample) in enumerate(rows):
      if c_id not in excluded:
        new_unlabeled_samples[key].append({"text": sample["text"], "label": key})

    return new_samples, new_unlabeled_samples, already_sampled


  def train_ipet_ensemble(
      self,
      weighting_strategy="uniform",
      get_unlabeled_logits=False
      ):

    """ Trains an iPET model ensemble.

    Every non-final generation trains a random subset of
    round(num_patterns * pet_models_per_pattern * ipet_models_fraction)
    models, lets them annotate the remaining unlabeled pool and adds
    ensemble-labeled samples to the training data, multiplying the per-label
    count by config.ipet_train_scale_factor. The final generation trains
    pet_models_per_pattern models per pattern, stores their evaluation results
    in self.model_records and, with get_unlabeled_logits, pickles their logits
    on the complete unlabeled set for classifier training.

    Without labeled samples, untrained models first label
    round(10 / num_labels) samples per label (zero-shot iPET).
    """

    num_sampled_models = int(np.round(
        len(self.verbalizer.pattern_dict.keys()) *
        self.config.pet_models_per_pattern *
        self.config.ipet_models_fraction))
    # Pattern weights, computed the first time a pattern is trained
    pattern_eval_weights = dict()
    # Samples per label in the growing iPET training data
    num_pet_samples_per_label = self.num_samples_per_label
    first_gen_id = 0

    if self.num_training_samples == 0:
      # -- ZERO-SHOT iPET --
      logits_from_sampled_models = dict()
      for pattern_id, num_models in self._sample_annotator_models(
          num_sampled_models).items():
        _, _, sequential_unlabeled_dataloader = self._get_pet_dataloaders(
            pattern_id, self.pet_train_samples, self.unlabeled_train_samples)
        for model_id in range(num_models):
          # Zero-shot models are not trained
          model_trainer = SingleModelTrainer(
              self.config, self.tokenizer, self.verbalizer,
              pattern_id=pattern_id, model_id=model_id)
          model_key = "ipet_gen_{:d}_pattern_{:d}_model_{:d}/".format(
              0, pattern_id, model_id)
          print("Labeling unlabeled data with zero-shot model {:s}!".format(
              model_key))
          unlabeled_logits_and_labels = model_trainer.eval(
              sequential_unlabeled_dataloader,
              is_unlabeled=True, return_logits=True, pet_samples=True)
          logits_from_sampled_models[model_key] = {
              "logits": unlabeled_logits_and_labels["logits"],
              "model_init_train_acc": 1.,  # since this is zero-shot
              "eval_labels": unlabeled_logits_and_labels["labels"]}

      merged_logits, _ = self._merge_logits(
          weighting_strategy, model_logits=logits_from_sampled_models,
          for_classifier=False)
      # Sample 10 samples across all labels
      num_pet_samples_per_label = int(
          np.round(10 / self.data_processor.num_labels))
      ipet_train_samples, ipet_unlabeled_samples, _ = self._get_next_ipet_data(
          merged_logits, num_pet_samples_per_label, list(), zero_shot=True,
          unlabeled_samples=self.unlabeled_train_samples)
      # The zero-shot generation involved no training
      first_gen_id = 1
    else:
      # Copy, so growing the iPET data does not modify self.pet_train_samples
      ipet_train_samples = {
          key: list(samples) for key, samples in self.pet_train_samples.items()}
      ipet_unlabeled_samples = self.unlabeled_train_samples

    # -- MULTI-SHOT iPET --
    for gen_id in range(first_gen_id, self.config.ipet_num_generations):
      is_final_gen = gen_id == self.config.ipet_num_generations - 1

      if is_final_gen:
        models_per_pattern = {
            pattern_id: self.config.pet_models_per_pattern
            for pattern_id in self.verbalizer.pattern_dict.keys()}
        # The classifier is distilled from the complete unlabeled set
        unlabeled_pool = self.unlabeled_train_samples
      else:
        # Only train annotator models to improve efficiency
        models_per_pattern = self._sample_annotator_models(num_sampled_models)
        unlabeled_pool = ipet_unlabeled_samples

      trained_models = list()
      model_pre_training_accuracies = dict()
      # Rebuilt each generation due to changes in data
      sequential_dataloaders = dict()

      for pattern_id, num_models in models_per_pattern.items():
        if num_models == 0:
          continue
        (train_pet_dataloader, train_unlabeled_dataloader,
         sequential_unlabeled_dataloader) = self._get_pet_dataloaders(
             pattern_id, ipet_train_samples, unlabeled_pool)
        sequential_dataloaders[pattern_id] = sequential_unlabeled_dataloader

        for model_id in range(num_models):
          model_key = "ipet_gen_{:d}_pattern_{:d}_model_{:d}/".format(
              gen_id, pattern_id, model_id)
          model_save_path = os.path.join(
              self.root_dir, "num_samples_{:d}_{:s}/".format(
                  self.num_training_samples, model_key))
          os.makedirs(model_save_path, exist_ok=True)
          trained_models.append((model_key, pattern_id, model_save_path))

          model_trainer = SingleModelTrainer(
              self.config, self.tokenizer, self.verbalizer, pattern_id, model_id)

          # Obtain "pre-training" accuracy to be used for model weighting
          if pattern_eval_weights.get(pattern_id) is None:
            pattern_eval_weights[pattern_id] = self._get_pattern_weight(
                model_trainer, train_pet_dataloader, weighting_strategy)
          model_pre_training_accuracies[model_key] = pattern_eval_weights[pattern_id]

          model_trainer.train(
              train_pet_dataloader, train_unlabeled_dataloader,
              self.use_mlm_training)
          self._save_model(model_trainer.model, model_save_path)

          if not is_final_gen:
            continue

          if get_unlabeled_logits:
            print("Labeling unlabeled data!")
            unlabeled_logits_and_labels = model_trainer.eval(
                sequential_unlabeled_dataloader, is_unlabeled=True,
                return_logits=True, pet_samples=True)
            unlabeled_logits_and_labels[
                "model_init_train_acc"] = pattern_eval_weights[pattern_id]
            print("Saving labeled unlabeled data!")
            self._save_unlabeled_logits(
                unlabeled_logits_and_labels, model_id, pattern_id)

          # Perform evaluation of all individual models
          logits_and_labels = model_trainer.eval(
              self.eval_dataloaders[pattern_id], is_unlabeled=False,
              return_logits=True, pet_samples=True)
          self.model_records[
              "pattern_{:d}_model_{:d}".format(pattern_id, model_id)] = {
                  "model_init_train_acc": pattern_eval_weights[pattern_id],
                  "eval_acc": logits_and_labels["eval_acc"],
                  "logits": logits_and_labels["logits"],
                  "eval_labels": logits_and_labels["labels"]}

      if is_final_gen:
        print('--- Finished training the iPET ensemble! --- ')
        continue

      # Expand the iPET training set
      logits_from_sampled_models = dict()
      for model_key, pattern_id, model_path in trained_models:
        model_trainer = SingleModelTrainer(
            self.config, self.tokenizer, self.verbalizer, None, None,
            model_path=model_path)
        print("Labeling data with model checkpoint {:s}!".format(model_path))
        # Annotate with data verbalized by the model's OWN pattern
        unlabeled_logits_and_labels = model_trainer.eval(
            sequential_dataloaders[pattern_id], is_unlabeled=True,
            return_logits=True, pet_samples=True)
        logits_from_sampled_models[model_key] = {
            "logits": unlabeled_logits_and_labels["logits"],
            "model_init_train_acc": model_pre_training_accuracies[model_key],
            "eval_labels": unlabeled_logits_and_labels["labels"]}

      merged_logits, _ = self._merge_logits(
          weighting_strategy, model_logits=logits_from_sampled_models,
          for_classifier=False)
      num_new_pet_samples_per_label = (
          self.config.ipet_train_scale_factor - 1) * num_pet_samples_per_label
      # Rows of merged_logits index the current pool, so ids start fresh
      new_samples, ipet_unlabeled_samples, _ = self._get_next_ipet_data(
          merged_logits, num_new_pet_samples_per_label, list(),
          zero_shot=False, unlabeled_samples=ipet_unlabeled_samples)

      for label, samples in new_samples.items():
        ipet_train_samples.setdefault(label, list()).extend(samples)
      num_pet_samples_per_label = (
          self.config.ipet_train_scale_factor * num_pet_samples_per_label)


  def eval_pet_ensemble(
      self,
      weighting_strategy
      ):

    """ Evaluates the merged PET ensemble stored in self.model_records.

    Returns:
      Mean over evaluation samples of compute_accuracy on each merged row.
    """

    merged_logits, labels = self._merge_logits(
        weighting_strategy, for_classifier=False)
    eval_acc = 0
    for l_id, lgt in enumerate(merged_logits):
      eval_acc += compute_accuracy(
          lgt.clone().detach(), torch.tensor(labels[l_id]))
    mean_eval_acc = eval_acc / len(merged_logits)
    print("\nPET eval complete; mean eval accuracy = {:.3f}".format(
        mean_eval_acc))

    return mean_eval_acc


  def train_and_eval_pet_classifier(
      self,
      weighting_strategy
      ):

    """ Trains a classifier on PET soft labels and evaluates it.

    Every sample of self.unlabeled_train_samples gets the matching row of the
    merged, saved unlabeled logits as its "label".

    Returns:
      The classifier's accuracy on the "cls" evaluation dataloader.

    Raises:
      ValueError: if the number of merged rows differs from the number of
        unlabeled samples (they are matched by position).
    """

    model_trainer = SingleModelTrainer(
        self.config, self.tokenizer, self.verbalizer, None, None,
        train_pet_classifier=True, train_sup_classifier=False)

    # Load and merge PET model logits; assumes row i annotates the i-th sample
    # in sorted-key order (the order of the sequential unlabeled dataloader)
    training_labels, _ = self._merge_logits(
        weighting_strategy, for_classifier=True)
    samples_keys = sorted(self.unlabeled_train_samples.keys())
    num_samples = sum(
        len(self.unlabeled_train_samples[sk]) for sk in samples_keys)
    if num_samples != training_labels.shape[0]:
      raise ValueError(
          "Got {:d} soft labels for {:d} unlabeled samples.".format(
              training_labels.shape[0], num_samples))

    # Replace hard labels with soft labels
    cls_train_samples = dict()
    label_id = 0
    for sk in samples_keys:
      cls_train_samples[sk] = list()
      for smp in self.unlabeled_train_samples[sk]:
        new_smp = dict(smp)
        new_smp["label"] = training_labels[label_id].tolist()
        cls_train_samples[sk].append(new_smp)
        label_id += 1

    # Get (non-verbalized) dataloaders
    train_cls_dataloader, _, _ = self.data_processor.get_dataloader(
        None, cls_train_samples, None, self.config.pet_cls_train_batch_size,
        self.config.pet_cls_eval_batch_size, use_mlm_training=False,
        is_eval=False, get_sequential_unlabeled_data=False,
        no_verbalization=True)

    model_trainer.train(train_cls_dataloader, None, use_mlm_training=False)
    logits_and_labels = model_trainer.eval(
        self.eval_dataloaders["cls"], is_unlabeled=False, return_logits=False,
        pet_samples=False)

    return logits_and_labels["eval_acc"]


  def train_and_eval_sup_classifier(
      self
      ):

    """ Trains a supervised classifier on the labeled PET samples (skipped
    when there are none) and returns its "cls" evaluation accuracy. """

    model_trainer = SingleModelTrainer(
        self.config, self.tokenizer, self.verbalizer, None, None,
        train_pet_classifier=False, train_sup_classifier=True)

    if self.num_training_samples > 0:
      train_sup_dataloader, _, _ = self.data_processor.get_dataloader(
          None, self.pet_train_samples, None, self.config.cls_train_batch_size,
          self.config.cls_eval_batch_size, use_mlm_training=False, is_eval=False,
          get_sequential_unlabeled_data=False, no_verbalization=True)
      model_trainer.train(train_sup_dataloader, None, use_mlm_training=False)

    logits_and_labels = model_trainer.eval(
        self.eval_dataloaders["cls"], is_unlabeled=False,
        return_logits=False, pet_samples=False)

    return logits_and_labels["eval_acc"]

In [None]:
# Define experiments

# TODO: Add reporting of final results

class Experiment(object):
  """ Runs the planned experiments.

  The constructor builds everything that is shared between experiments:
  config, raw data, tokenizer, verbalizer, data processor and the
  non-verbalized "cls" evaluation dataloader. Each public method then runs
  one experiment of the PET paper reproduction (supervised baselines, PET,
  PET ablation, iPET). """

  def __init__(
      self
      ):

    self.config = Config()
    # Seed the random generators; the seed used is printed and stored in
    # `self.current_seed` so the run can be repeated
    self._set_seed()
    # Downloads and preprocesses training and evaluation data
    print("Initializing data getter ...")
    self.data_initializer = DataInitializer(self.config.dataset)
    # Initialize tokenizer
    print("Initializing tokenizer ...")
    self.tokenizer = AutoTokenizer.from_pretrained(self.config.checkpoint)
    # Initialize verbalizer
    print("Initializing verbalizer ...")
    self.verbalizer = YelpPVPVerbalizer(self.tokenizer)

    # Prepare evaluation data, as it is shared across experiments
    self.data_processor = DataProcessor(
        self.config, self.data_initializer, self.tokenizer, self.verbalizer)
    self.eval_samples = self.data_processor.eval_data_per_label
    # Instantiate evaluation dataloaders; per-pattern (verbalized) loaders
    # are added lazily by the PET experiments
    self.eval_dataloaders = {"cls": None}
    cls_eval_dataloader, _, _ = self.data_processor.get_dataloader(
        None, self.eval_samples, None, self.config.cls_train_batch_size,
        self.config.cls_eval_batch_size, use_mlm_training=False,
        is_eval=True, get_sequential_unlabeled_data=False,
        no_verbalization=True)
    self.eval_dataloaders["cls"] = cls_eval_dataloader


  def _set_seed(
      self,
      seed=None
      ):

    """ Seeds Python's, NumPy's and PyTorch's (pseudo-)random generators.

    Args:
      seed: Integer seed to apply. If None, a seed in [0, 100] is drawn from
        Python's `random` module, so repeated calls seed successive runs
        differently.
    Returns:
      The seed that was applied. It is also stored in `self.current_seed`
      and printed, so any individual run can be reproduced by passing it
      back in.
    """

    if seed is None:
      seed = random.randint(0, 100)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
      torch.cuda.manual_seed_all(seed)
    self.current_seed = seed
    print("Random seed set to {:d}".format(seed))
    return seed


  def _prepare_pattern_eval_dataloaders(
      self
      ):

    """ Builds one verbalized evaluation dataloader per PET pattern.

    Each loader is cached in `self.eval_dataloaders` under its pattern id,
    next to the non-verbalized "cls" entry. Patterns that already have a
    loader are skipped, so repeated calls are cheap. """

    for pattern_id in self.verbalizer.pattern_dict.keys():
      if pattern_id in self.eval_dataloaders:
        continue
      self.eval_dataloaders[pattern_id], _, _ = self.data_processor.get_dataloader(
          pattern_id, self.eval_samples, None, self.config.train_batch_size,
          self.config.eval_batch_size, use_mlm_training=False, is_eval=True,
          get_sequential_unlabeled_data=False, no_verbalization=False)


  def _make_multi_trainer(
      self,
      num_samples,
      use_mlm
      ):

    """ Creates a MultiModelTrainer bound to this experiment's shared state. """

    return MultiModelTrainer(self.config,
                             self.data_initializer,
                             self.data_processor,
                             self.tokenizer,
                             self.verbalizer,
                             num_samples,
                             self.eval_dataloaders,
                             use_mlm_training=use_mlm)


  def _run_pet_trials(
      self,
      num_pet_samples,
      use_mlm,
      use_classifier,
      iterative,
      num_runs
      ):

    """ Trains and evaluates a PET or iPET ensemble `num_runs` times.

    Args:
      num_pet_samples: Number of labeled samples given to the trainer.
      use_mlm: Passed to the trainer as `use_mlm_training`.
      use_classifier: If True, score the final PET classifier trained via
        `train_and_eval_pet_classifier`; otherwise score the ensemble itself
        via `eval_pet_ensemble`.
      iterative: If True, train with `train_ipet_ensemble` (iPET), else
        with `train_pet_ensemble` (PET).
      num_runs: Number of independent runs; the seed is reset before each.
    Returns:
      (results, out_dict): the per-run scores as a list, and the same scores
      keyed by "run_<id>".
    """

    results = list()
    out_dict = dict()
    for run_id in range(num_runs):
      # Reset seed
      self._set_seed()
      multi_trainer = self._make_multi_trainer(num_pet_samples, use_mlm)
      train_ensemble = (multi_trainer.train_ipet_ensemble if iterative
                        else multi_trainer.train_pet_ensemble)
      # Unlabeled-set logits are requested exactly when the final PET
      # classifier is trained afterwards
      _ = train_ensemble(weighting_strategy="weighted",
                         get_unlabeled_logits=bool(use_classifier))
      if use_classifier:
        trainer_out = multi_trainer.train_and_eval_pet_classifier(
            weighting_strategy="weighted")
      else:
        trainer_out = multi_trainer.eval_pet_ensemble(
            weighting_strategy="weighted")
      out_dict["run_{:d}".format(run_id)] = trainer_out
      results.append(trainer_out)
    return results, out_dict


  @staticmethod
  def _find_extreme_models(
      model_scores
      ):

    """ Finds the best- and worst-scoring individual ensemble members.

    Args:
      model_scores: Nested dict mapping pattern id -> model id -> score.
    Returns:
      Two (score, pattern_id, model_id) tuples: (best, worst). Ties keep the
      first member encountered. If `model_scores` holds no scores, both are
      (nan, None, None).
    """

    entries = [(score, pattern_id, model_id)
               for pattern_id, per_model in model_scores.items()
               for model_id, score in per_model.items()]
    if not entries:
      empty = (float("nan"), None, None)
      return empty, empty
    best = max(entries, key=lambda entry: entry[0])
    worst = min(entries, key=lambda entry: entry[0])
    return best, worst


  def evaluate_supervised_models(
      self,
      num_training_samples,
      num_runs=3
      ):

    """ Train and evaluate supervised baselines to reproduce paper results.

    Args:
      num_training_samples: Number of labeled training samples per run.
      num_runs: Number of independent training runs (default 3).
    Returns:
      Dict mapping "run_<id>" to that run's classifier accuracy.
    """

    results = list()
    out_dict = dict()
    print("\n=== Evaluating baseline supervised classifier for {:d} training examples: ===".format(
        num_training_samples))
    for run_id in range(num_runs):
      print("! Training run {:d} !".format(run_id))

      # Reset seed
      self._set_seed()

      multi_trainer = self._make_multi_trainer(num_training_samples, False)
      trainer_out = multi_trainer.train_and_eval_sup_classifier()
      results.append(trainer_out)
      out_dict["run_{:d}".format(run_id)] = trainer_out

    print("-" * 20)
    print("Training samples: {:d} | MAX: {:.1f} | MEAN: {:.1f} (std {:.1f})".format(
        num_training_samples, np.max(results), np.mean(results),
        np.std(results)))
    print("Done!")

    return out_dict


  def evaluate_pet(
      self,
      num_pet_samples,
      use_mlm,
      use_classifier,
      num_runs=3
      ):

    """ Evaluates PET models to reproduce paper results.

    Args:
      num_pet_samples: Number of labeled samples per run.
      use_mlm: Whether the trainer uses the auxiliary MLM training.
      use_classifier: Score the final PET classifier (True) or the ensemble
        itself (False).
      num_runs: Number of independent runs (default 3).
    Returns:
      The mean score over all runs, as a float. Returning a scalar lets
      callers compare settings directly (e.g. the use_mlm=True minus
      use_mlm=False difference plotted for Figure 3).
    """

    # Prepare evaluation samples, if necessary (one per PET pattern)
    self._prepare_pattern_eval_dataloaders()

    print("\n=== Evaluating PET with {:d} PET samples: ===".format(num_pet_samples))
    results, _ = self._run_pet_trials(
        num_pet_samples, use_mlm, use_classifier, iterative=False,
        num_runs=num_runs)

    mean_score = float(np.mean(results))
    print("-" * 20)
    print("Training samples: {:d} | Use MLM: {} | MEAN: {:.3f} (std {:.3f})".format(
        num_pet_samples, use_mlm, mean_score, np.std(results)))
    print("Done!")

    return mean_score


  def ablate_pet(
      self,
      num_pet_samples=10
      ):

    """ Performs PET ablation studies, required to reproduce Table 4.

    Trains a single PET ensemble (with MLM training) on `num_pet_samples`
    labeled samples and reports its best and worst individual models, the
    ensemble without distillation, and the final classifier with uniform
    and with weighted ensemble weighting.

    Returns:
      Dict with the keys "max_model", "min_model", "no_distillation",
      "pet_uniform" and "pet_weighted".
    """

    # Prepare evaluation samples, if necessary (one per PET pattern)
    self._prepare_pattern_eval_dataloaders()

    results_p4 = dict()
    print("=== Reproducing Table 4 of the paper: ===")

    # Reset seed
    self._set_seed()

    multi_trainer = self._make_multi_trainer(num_pet_samples, True)
    # Per-member scores: pattern id -> model id -> score
    trainer_out = multi_trainer.train_pet_ensemble(
        weighting_strategy="weighted", get_unlabeled_logits=True)

    # Identify best- and worst-performing models
    max_model, min_model = self._find_extreme_models(trainer_out)

    results_p4["max_model"] = "{:.3f}, pattern {}, model {}".format(*max_model)
    results_p4["min_model"] = "{:.3f}, pattern {}, model {}".format(*min_model)
    results_p4["no_distillation"] = multi_trainer.eval_pet_ensemble(
        weighting_strategy="weighted")
    results_p4["pet_uniform"] = multi_trainer.train_and_eval_pet_classifier(
        weighting_strategy="uniform")
    results_p4["pet_weighted"] = multi_trainer.train_and_eval_pet_classifier(
        weighting_strategy="weighted")

    for key in ["max_model", "min_model", "no_distillation",
                "pet_uniform", "pet_weighted"]:
      print("\t{} : {}".format(key, results_p4[key]))

    return results_p4


  def evaluate_ipet(
      self,
      num_pet_samples,
      use_classifier,
      num_runs=3
      ):

    """ Evaluates iPET models to reproduce paper results.

    Args:
      num_pet_samples: Number of labeled samples per run.
      use_classifier: Score the final PET classifier (True) or the ensemble
        itself (False).
      num_runs: Number of independent runs (default 3).
    Returns:
      Dict mapping "run_<id>" to that run's score.
    """

    # Prepare evaluation samples, if necessary (one per PET pattern)
    self._prepare_pattern_eval_dataloaders()

    print("\n=== Evaluating iPET with {:d} PET samples: ===".format(
        num_pet_samples))
    results, out_dict = self._run_pet_trials(
        num_pet_samples, True, use_classifier, iterative=True,
        num_runs=num_runs)

    print("-" * 20)
    print("Training samples: {:d} | MEAN: {:.3f} (std {:.3f})".format(
          num_pet_samples, np.mean(results), np.std(results)))
    print("Done!")

    return out_dict



In [None]:
# Run experiments
# Build the shared experiment state once: config, seeded random generators,
# raw data, tokenizer, verbalizer, data processor and the non-verbalized
# "cls" evaluation dataloader. The experiment cells below reuse `exp`.
exp = Experiment()

In [None]:
# Supervised models: baseline classifiers for each labeled-set size
# (size 0 evaluates an untrained classifier).
training_set_sizes = [0, 10, 50, 100, 1000]
runs = {
    size: exp.evaluate_supervised_models(size)
    for size in training_set_sizes}

print("=" * 100)
for size in training_set_sizes:
  print("{} : {}".format(size, runs[size]))

In [None]:
# PET: for each labeled-set size, run PET with and without MLM training.
# Results are keyed by str(use_mlm), i.e. "True" / "False".
training_set_sizes = [10, 50, 100, 1000]
runs = dict()
for size in training_set_sizes:
  per_size = runs[size] = dict()
  for mlm_flag in (True, False):
    per_size[str(mlm_flag)] = exp.evaluate_pet(
        size, mlm_flag, use_classifier=True)

print("=" * 100)
for size in training_set_sizes:
  print("{} : {}".format(size, runs[size]))

# Plot reproduction of Figure 3: gain from MLM training at each size
mlm_gains = [runs[size]["True"] - runs[size]["False"]
             for size in training_set_sizes]
plt.plot(np.array(training_set_sizes), np.array(mlm_gains),
         color='blue', marker='o', label="Yelp")
plt.xlabel("Training Set Size")
plt.ylabel("Accuracy Improvements")
plt.legend()
plt.show()

In [None]:
# PET ablation (Table 4): trains one 10-sample PET ensemble and reports its
# best and worst single models, the ensemble without distillation, and the
# final classifier with uniform and with weighted ensemble weighting.
exp.ablate_pet()

In [None]:
# iPET: one evaluation per labeled-set size, each scored by the final
# PET classifier.
training_set_sizes = [0, 10, 50, 100]
runs = dict()
for size in training_set_sizes:
  runs[size] = exp.evaluate_ipet(size, use_classifier=True)

print("=" * 100)
for size in training_set_sizes:
  print("{} : {}".format(size, runs[size]))