In [None]:
# © Copyright IBM Corporation 2022.
#
# LICENSE: Apache License 2.0 (Apache-2.0)
# http://www.apache.org/licenses/LICENSE-2.0

# Loading Libraries

In [None]:
import logging
import os
import random
import shutil
import uuid
import io
import html
import re
import urllib.request
import tarfile
import zipfile

from filelock import FileLock

from argparse import ArgumentParser
from collections import defaultdict, Counter
from enum import Enum
from typing import Dict, List

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')


import numpy as np
import pandas as pd
import torch


from sklearn.metrics import classification_report
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split


from datasets import Dataset
from tqdm.auto import tqdm
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase,
                          InputFeatures, Trainer, TrainingArguments, RobertaConfig, BartConfig, DebertaConfig, pipeline)
from transformers.pipelines.pt_utils import KeyDataset
from dataclasses import dataclass, field
from typing import Dict, List

In [None]:
def set_seed(seed: int):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG and the CUDA RNGs (current device and
    all devices), then exports ``PYTHONHASHSEED``.
    NOTE(review): changing PYTHONHASHSEED at runtime presumably only affects processes started afterwards.

    Args:
        seed: the seed value shared by all generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


def get_root_dir():
    """Return the absolute directory under which ``datasets`` and ``output`` are located.

    When this code is executed from a file, that is the directory containing the file. Inside a Jupyter notebook
    the module-level ``__file__`` name is normally not defined (referencing it directly raises NameError), so the
    current working directory is used instead.

    Returns:
        An absolute directory path.
    """
    module_file = globals().get('__file__')
    if module_file is None:
        # e.g. running inside a notebook kernel, where there is no backing source file
        return os.path.abspath(os.getcwd())
    return os.path.abspath(os.path.join(module_file, os.pardir))

# Data Structures and Class Names

In [None]:
# Data Structure & Class Names
@dataclass()
class Predictions:
    """Zero-shot classification results for a sequence of texts; entry ``i`` of each list describes text ``i``."""
    # The predicted (top-ranked) class name of each text
    predicted_labels: List[str]
    # All candidate class names of each text, best first (as returned by the zero-shot pipeline)
    ranked_classes: List[List[str]]
    # For each text, a mapping from candidate class name to its score
    class_name_to_score: List[Dict[str, float]]


@dataclass
class SelfTrainingSet:
    """Pseudo-labeled entailment training data stored as three parallel lists (one entry per training pair)."""
    # Premise texts taken from the unlabeled pool
    texts: List[str] = field(default_factory=list)
    # Class name that is filled into the hypothesis template for the matching text
    class_names: List[str] = field(default_factory=list)
    # Entailment label name for the pair, e.g. 'ENTAILMENT' or 'CONTRADICTION'
    entailment_labels: List[str] = field(default_factory=list)
# Natural-language class names per dataset, keyed by the raw label as it appears in the source data.
# get_label_name() looks raw labels up here (after str()); datasets without an entry (e.g. isear, yahoo_answers)
# fall back to the lower-cased raw label. These names become the class names used in the entailment hypotheses.
DATASET_TO_CLASS_NAME_MAPPING = {
    # fetch_20newsgroups integer targets 0..19, listed in target order
    '20_newsgroup': {
        str(target): class_name for target, class_name in enumerate([
            'atheism', 'computer graphics', 'microsoft windows', 'pc hardware', 'mac hardware',
            'windows x', 'for sale', 'cars', 'motorcycles', 'baseball',
            'hockey', 'cryptography', 'electronics', 'medicine', 'space',
            'christianity', 'guns', 'middle east', 'politics', 'religion',
        ])
    },
    # names as listed in the AG News classes.txt
    'ag_news': {'Business': 'business', 'Sci/Tech': 'science and technology', 'Sports': 'sports', 'World': 'world'},
    # names as listed in the DBpedia classes.txt
    'dbpedia': dict(
        Album='album', Animal='animal', Artist='artist', Athlete='athlete', Building='building',
        Company='company', EducationalInstitution='educational institution', Film='film',
        MeanOfTransportation='mean of transportation', NaturalPlace='natural place',
        OfficeHolder='office holder', Plant='plant', Village='village', WrittenWork='written work',
    ),
    # IMDB review sub-directory names
    'imdb': dict(pos='good', neg='bad'),
}


# Rank Candidate Predictions

In [None]:
def rank_candidate_indices_per_class(all_class_names, predictions: 'Predictions') -> Dict[str, List[int]]:
    """Rank, for every class, the examples predicted as that class by prediction confidence.

    Confidence is the margin between the score of an example's best class (``ranked_classes[i][0]``) and its
    second-best class (``ranked_classes[i][1]``). When only one candidate class exists there is no runner-up and
    the margin is the best score itself.

    Args:
        all_class_names: class names to rank for; each gets an entry, possibly an empty list.
        predictions: zero-shot predictions whose ``predicted_labels``, ``ranked_classes`` and
            ``class_name_to_score`` lists are aligned by example index.

    Returns:
        Mapping from class name to the indices of the examples predicted as that class, sorted by decreasing
        margin. Examples with equal margins keep their original relative order. Predicted labels that are not in
        ``all_class_names`` are ignored.
    """
    margins = []
    for class_name_to_score, ranked_classes in zip(predictions.class_name_to_score, predictions.ranked_classes):
        best_score = class_name_to_score[ranked_classes[0]]
        second_score = class_name_to_score[ranked_classes[1]] if len(ranked_classes) > 1 else 0.0
        margins.append(best_score - second_score)

    # Sort every example once (Python's sort is stable, also with reverse=True) and bucket by predicted class,
    # instead of re-sorting the whole prediction list separately for each class.
    by_decreasing_margin = sorted(enumerate(zip(predictions.predicted_labels, margins)),
                                  key=lambda item: item[1][1], reverse=True)

    class_name_to_sorted_candidate_idxs = {class_name: [] for class_name in all_class_names}
    for idx, (predicted_class, _) in by_decreasing_margin:
        if predicted_class in class_name_to_sorted_candidate_idxs:
            class_name_to_sorted_candidate_idxs[predicted_class].append(idx)

    return class_name_to_sorted_candidate_idxs

# Negative Examples

In [None]:
class NegativeSamplingStrategy(Enum):
    """Which lower-ranked classes get_negative_examples pairs with a positive example as contradictions."""
    # every class ranked below the first one
    TAKE_ALL = 0
    # one class drawn with random.choice from the classes ranked below the first one
    TAKE_RANDOM = 1
    # the class ranked second
    TAKE_SECOND = 2
    # the class ranked last
    TAKE_LAST = 3

In [None]:
def get_negative_examples(predictions: Predictions, class_name_to_chosen_pos_idxs: Dict[str, List[int]],
                          negative_sampling_strategy: NegativeSamplingStrategy) -> Dict[str, List[int]]:
    """Choose contradiction (negative) classes for every chosen positive example.

    For each positive example index, the classes ranked after the first one in ``predictions.ranked_classes`` are
    candidates, and ``negative_sampling_strategy`` decides which of them are used (all of them, one at random, the
    second, or the last).

    Args:
        predictions: zero-shot predictions of the unlabeled texts.
        class_name_to_chosen_pos_idxs: class name -> indices selected as positive examples of that class.
        negative_sampling_strategy: a NegativeSamplingStrategy member.

    Returns:
        A ``defaultdict(list)`` mapping each negative class name to the example indices paired with it.

    Raises:
        ValueError: for an unknown strategy (only detected once at least one positive example is processed).
    """
    def pick_negative_classes(ranked_classes):
        # Position 0 holds the example's best class, so negatives are drawn from positions 1 onwards
        if negative_sampling_strategy == NegativeSamplingStrategy.TAKE_SECOND:
            return [ranked_classes[1]]
        if negative_sampling_strategy == NegativeSamplingStrategy.TAKE_ALL:
            return ranked_classes[1:]
        if negative_sampling_strategy == NegativeSamplingStrategy.TAKE_LAST:
            return [ranked_classes[-1]]
        if negative_sampling_strategy == NegativeSamplingStrategy.TAKE_RANDOM:
            return [random.choice(ranked_classes[1:])]
        raise ValueError(f"Unknown negative sampling strategy {negative_sampling_strategy}")

    negatives_by_class = defaultdict(list)
    for positive_idxs in class_name_to_chosen_pos_idxs.values():
        for example_idx in positive_idxs:
            example_ranked_classes = predictions.ranked_classes[example_idx]
            for negative_class in pick_negative_classes(example_ranked_classes):
                negatives_by_class[negative_class].append(example_idx)

    return negatives_by_class

# Evaluation

In [None]:
def safe_to_csv(df, path, save_index_col=False):
    """Write ``df`` to ``path`` as a UTF-8 CSV file, serializing the whole frame before the file is opened.

    Serializing first means an error inside ``to_csv`` cannot leave ``path`` truncated or half-written.
    The file is opened with ``newline=''`` because pandas already emits its own line terminators; letting text
    mode translate them a second time would double the carriage returns on Windows. UTF-8 is what
    ``pd.read_csv`` assumes by default when the file is read back.

    Args:
        df: DataFrame to write.
        path: destination path; an existing file is overwritten.
        save_index_col: whether to write the DataFrame index as the first column.
    """
    # With no path argument, to_csv returns the CSV content as a string
    csv_text = df.to_csv(index=save_index_col)
    with open(path, "w", encoding="utf-8", newline="") as text_file:
        text_file.write(csv_text)
        text_file.flush()

In [None]:
def save_results(res_file_name, agg_file_name, result_dict, setting_name_col='setting_name'):
    """Append one result row to ``res_file_name`` and rewrite the per-setting aggregate in ``agg_file_name``.

    The aggregate groups all rows by ``setting_name_col``: numeric columns are averaged, and any other column
    keeps its value only when all rows of the group agree on it (NaN otherwise). It is sorted by ``dataset_name``
    and then by setting name, and is refreshed on every call, including the first one.

    Args:
        res_file_name: CSV file holding one row per evaluation (created if missing).
        agg_file_name: CSV file holding the aggregated rows (overwritten).
        result_dict: the new row; must contain ``setting_name_col`` and ``dataset_name`` keys.
        setting_name_col: column used to group rows for aggregation.
    """
    def mean_str(col: pd.Series):
        if pd.api.types.is_numeric_dtype(col):
            return col.mean()
        # np.nan (not np.NaN, which was removed in NumPy 2.0)
        return col.dropna().unique()[0] if col.nunique() == 1 else np.nan

    new_res_df = pd.DataFrame([result_dict])

    # Cross-process lock next to the result files so concurrent writers do not interleave read-modify-write cycles
    lock_path = os.path.abspath(os.path.join(res_file_name, os.pardir, 'result_csv_files.lock'))
    with FileLock(lock_path):
        if os.path.isfile(res_file_name):
            df = pd.concat([pd.read_csv(res_file_name), new_res_df])
        else:
            df = new_res_df
        df_agg = df.groupby(by=setting_name_col).agg(mean_str).sort_values(by=['dataset_name', setting_name_col])
        safe_to_csv(df_agg, agg_file_name, save_index_col=True)
        safe_to_csv(df, res_file_name)

In [None]:
def evaluate_classification_performance(predicted_labels, gold_labels, out_dir, info_dict):
    """Score predictions against gold labels, log the metrics and append them to the experiment result files.

    The recorded row starts with ``info_dict`` followed by ``accuracy``, ``evaluation_size`` and, for every class
    and for the averaged rows of scikit-learn's classification report, precision / recall / f1 columns. Class
    names are wrapped in single quotes in the column names; the '... avg' rows are not.
    Rows are written through save_results into ``out_dir/all_copies.csv`` and ``out_dir/aggregated.csv``.

    Args:
        predicted_labels: predicted class names.
        gold_labels: gold class names aligned with ``predicted_labels``.
        out_dir: experiment output directory.
        info_dict: run metadata (iteration, setting name, configuration) stored with the metrics.
    """
    hits = [gold == predicted for gold, predicted in zip(gold_labels, predicted_labels)]
    metrics_row = {'accuracy': np.mean(hits), 'evaluation_size': len(gold_labels)}

    per_category = classification_report(gold_labels, predicted_labels, output_dict=True)
    del per_category['accuracy']  # already recorded above
    for category_name, scores in per_category.items():
        column_prefix = category_name if category_name.endswith("avg") else f"'{category_name}'"
        for report_key, column_suffix in (('precision', 'precision'), ('recall', 'recall'), ('f1-score', 'f1')):
            metrics_row[f"{column_prefix} {column_suffix}"] = scores[report_key]

    evaluation_dict = {**info_dict, **metrics_row}
    logging.info(evaluation_dict)

    save_results(os.path.join(out_dir, "all_copies.csv"), os.path.join(out_dir, "aggregated.csv"), evaluation_dict)

# Entailment Models

In [None]:
def get_zero_shot_predictions(model_name, texts_to_infer, label_names, batch_size, max_length):
    """Classify texts with a Hugging Face ``zero-shot-classification`` pipeline built on an entailment model.

    Every candidate label is scored independently (``multi_label=True``). For each text, the first label of the
    pipeline output is taken as the prediction and the full label list as the ranking.

    Args:
        model_name: model id or local directory of the entailment model.
        texts_to_infer: texts to classify.
        label_names: candidate class names.
        batch_size: pipeline inference batch size.
        max_length: used as the tokenizer's ``model_max_length``.

    Returns:
        Predictions aligned with ``texts_to_infer``.
    """
    # First CUDA device when available, CPU (-1) otherwise
    device = 0 if torch.cuda.is_available() else -1

    # The tokenizer is built here so that max_length becomes its maximum sequence length
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, model_max_length=max_length)
    classifier = pipeline("zero-shot-classification", model=model_name, tokenizer=tokenizer, device=device)

    text_dataset = Dataset.from_dict({'text': texts_to_infer})
    pipeline_outputs = classifier(KeyDataset(text_dataset, 'text'), batch_size=batch_size,
                                  candidate_labels=label_names, multi_label=True)
    outputs = list(tqdm(pipeline_outputs, total=len(text_dataset), desc="zero-shot inference"))

    return Predictions(
        predicted_labels=[output['labels'][0] for output in outputs],
        ranked_classes=[output['labels'] for output in outputs],
        class_name_to_score=[dict(zip(output['labels'], output['scores'])) for output in outputs],
    )

In [None]:
def finetune_entailment_model(model_name, self_training_set: SelfTrainingSet, seed, learning_rate=2e-5, batch_size=32,
                              max_length=512, num_epochs=1, hypothesis_template="This example is {}."):
    """Fine-tune an entailment model on a pseudo-labeled self-training set and save it to a new directory.

    Args:
        model_name: model id or local directory of the entailment model to start from.
        self_training_set: pseudo-labeled (text, class name, entailment label) triples.
        seed: seed for the example shuffle and for the Trainer's random number generators.
        learning_rate: optimizer learning rate.
        batch_size: per-device training batch size.
        max_length: maximum tokenized sequence length.
        num_epochs: number of training epochs.
        hypothesis_template: format string with one ``{}`` placeholder for the class name.

    Returns:
        The directory (under ``<root>/output/models``) where the model and tokenizer were saved; it can be passed
        back as ``model_name`` for inference.
    """
    # A uuid keeps directories of repeated fine-tuning runs of the same base model apart
    model_id = f"{uuid.uuid4()}_fine_tuned_{model_name.replace(os.sep, '_')}"

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, model_max_length=max_length)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    out_dir = os.path.join(get_root_dir(), "output", "models", model_id)
    os.makedirs(out_dir, exist_ok=True)

    inputs = preprocess_and_tokenize(model.config, tokenizer, self_training_set, seed=seed,
                                     hypothesis_template=hypothesis_template)

    training_args = TrainingArguments(output_dir=out_dir,
                                      overwrite_output_dir=True,
                                      num_train_epochs=num_epochs,
                                      per_device_train_batch_size=batch_size,
                                      learning_rate=learning_rate,
                                      # TrainingArguments.seed defaults to 42 and is applied by the Trainer;
                                      # pass the experiment seed so that --seed actually changes training
                                      seed=seed)

    trainer = Trainer(model=model, args=training_args, train_dataset=inputs)
    trainer.train()

    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)
    return out_dir

In [None]:
def preprocess_and_tokenize(model_config: PretrainedConfig, tokenizer: PreTrainedTokenizerBase,
                            self_training_set: SelfTrainingSet, seed: int, hypothesis_template: str):
    """Encode a self-training set as shuffled (premise, hypothesis) entailment examples for the Trainer.

    Each (text, class name, entailment label) triple becomes one ``InputFeatures``: the text is the premise, the
    class name filled into ``hypothesis_template`` is the hypothesis, and the label name is mapped to the model's
    numeric label id. Sequences are padded to the tokenizer's maximum length and only the premise is truncated.

    Args:
        model_config: config of the model being fine-tuned; must be exactly a RobertaConfig, DebertaConfig or
            BartConfig.
        tokenizer: tokenizer matching the model.
        self_training_set: pseudo-labeled texts, class names and entailment label names (parallel lists).
        seed: seed for the deterministic shuffle of the produced examples.
        hypothesis_template: format string with one ``{}`` placeholder for the class name.

    Returns:
        List of ``InputFeatures`` in shuffled order.

    Raises:
        NotImplementedError: for an unsupported model config type.
        Exception: when an entailment label name is not found in ``model_config.label2id``.
    """
    if type(model_config) not in [RobertaConfig, DebertaConfig, BartConfig]:
        raise NotImplementedError(f"{model_config.architectures} model is not supported")
    # token_type_ids are passed for DeBERTa models only; RoBERTa/BART examples carry input ids and attention mask
    include_token_type_ids = type(model_config) == DebertaConfig

    def get_numeric_label(label: str, model_config: PretrainedConfig):
        """
        Different entailment classification models on Huggingface use different names and IDs for the textual entailment
        labels of entailment/neutral/contradiction. Here we convert names to the appropriate model label IDs.
        """
        if label in model_config.label2id:
            return model_config.label2id[label]
        elif label.lower() in model_config.label2id:
            return model_config.label2id[label.lower()]
        else:
            raise Exception(f'The label "{label}" is not recognized by the model, '
                            f'model labels are: {model_config.label2id.keys()}')

    numeric_entailment_labels = [get_numeric_label(label, model_config)
                                 for label in self_training_set.entailment_labels]

    tokenized = []
    for text, class_name, label \
            in zip(self_training_set.texts, self_training_set.class_names, numeric_entailment_labels):
        hypothesis = hypothesis_template.format(class_name)
        # Encode an explicit (text, text_pair) sentence pair; 'only_first' truncates the premise and keeps the
        # hypothesis intact
        inputs = tokenizer(text, hypothesis, add_special_tokens=True, padding='max_length',
                           truncation='only_first')
        tokenized.append(InputFeatures(input_ids=inputs['input_ids'],
                                       attention_mask=inputs['attention_mask'],
                                       token_type_ids=inputs['token_type_ids'] if include_token_type_ids else None,
                                       label=label))

    random.Random(seed).shuffle(tokenized)
    return tokenized

# Download Datasets

In [None]:
# Processed datasets (unlabeled.csv / test.csv / class_names.txt per dataset) are written under this directory,
# which is relative to the current working directory.
# NOTE(review): the experiment cells read datasets from os.path.join(get_root_dir(), 'datasets'); the two locations
# only coincide when the working directory is the directory returned by get_root_dir().
OUT_DIR = './datasets'
# Downloaded raw files and extracted archives are stored in a sub-directory of OUT_DIR
RAW_DIR = os.path.join(OUT_DIR, 'raw')

In [None]:
def get_label_name(dataset_name, csv_label_name):
    """Translate a raw dataset label into the natural-language class name used by the experiments.

    Datasets that have an entry in DATASET_TO_CLASS_NAME_MAPPING are looked up by ``str(csv_label_name)``, which
    raises KeyError for an unknown label. For every other dataset the raw label (a string) is lower-cased.
    """
    if dataset_name not in DATASET_TO_CLASS_NAME_MAPPING:
        return csv_label_name.lower()
    raw_label_key = str(csv_label_name)
    return DATASET_TO_CLASS_NAME_MAPPING[dataset_name][raw_label_key]

In [None]:
def load_20_newsgroup():
    """Fetch 20 Newsgroups with scikit-learn and write the processed files under ``OUT_DIR/20_newsgroup``.

    Writes ``unlabeled.csv`` (train subset) and ``test.csv`` (test subset), both with ``text`` and ``label``
    columns, plus ``class_names.txt`` holding the sorted class names of the test subset, one per line.
    """
    def clean_text(x):
        # Raw strings keep the regex escapes literal (plain literals with '\S', '\w', '\g', '\$' are invalid
        # escape sequences and raise SyntaxWarning on Python 3.12+).
        x = re.sub(r'#\S+;', r'&\g<0>', x)             # '#39;' -> '&#39;' (entity without its '&')
        x = re.sub(r'(\w+)\\(\w+)', r'\g<1> \g<2>', x)  # 'word\word' -> 'word word'
        x = x.replace('quot;', '&quot;')
        x = x.replace('amp;', '&amp;')
        x = x.replace(r'\$', '$')
        x = x.replace("\r\n", " ").replace("\n", " ")
        # Trim surrounding whitespace, then drop any trailing backslashes
        x = x.strip().rstrip('\\')
        return html.unescape(x)

    dataset_name = "20_newsgroup"
    dataset_out_dir = os.path.join(OUT_DIR, dataset_name)
    os.makedirs(dataset_out_dir, exist_ok=True)

    def build_subset_df(subset):
        # One row per post: cleaned text plus the class name of its integer target
        newsgroups = fetch_20newsgroups(subset=subset)
        subset_df = pd.DataFrame({"text": newsgroups["data"], "label": newsgroups["target"]})
        subset_df["text"] = subset_df["text"].apply(clean_text)
        subset_df["label"] = subset_df["label"].apply(lambda target: get_label_name(dataset_name, target))
        return subset_df

    train_df = build_subset_df('train')
    train_df.to_csv(os.path.join(dataset_out_dir, "unlabeled.csv"), index=False)
    logging.info(f"20_newsgroup unlabeled file created with {len(train_df)} samples")

    test_df = build_subset_df('test')
    test_df.to_csv(os.path.join(dataset_out_dir, "test.csv"), index=False)

    with open(os.path.join(dataset_out_dir, 'class_names.txt'), 'w') as f:
        f.writelines([class_name + '\n' for class_name in sorted(test_df["label"].unique())])

    logging.info(f"20_newsgroup test file created with {len(test_df)} samples")

In [None]:
def load_ag_news_dbpedia_yahoo():
    """Convert the raw AG News, DBpedia and Yahoo! Answers CSV dumps under RAW_DIR into the processed format.

    For each dataset, ``train.csv`` becomes ``unlabeled.csv`` and ``test.csv`` stays ``test.csv`` (both with
    ``text`` and ``label`` columns), and ``class_names.txt`` lists the sorted class names of the test split.
    Raw labels are 1-based positions in the dataset's ``classes.txt``.
    """
    def clean_text(x):
        # Raw strings keep the regex escapes literal (plain literals with '\S', '\w', '\g', '\$' are invalid
        # escape sequences and raise SyntaxWarning on Python 3.12+).
        x = re.sub(r'#\S+;', r'&\g<0>', x)             # '#39;' -> '&#39;' (entity without its '&')
        x = re.sub(r'(\w+)\\(\w+)', r'\g<1> \g<2>', x)  # 'word\word' -> 'word word'
        x = x.replace('quot;', '&quot;')
        x = x.replace('amp;', '&amp;')
        x = x.replace(r'\$', '$')
        x = ' '.join(x.split())  # collapse every whitespace run (including newlines) into one space
        return html.unescape(x.rstrip('\\'))

    def build_text_column(dataset, part_df):
        # Combine the dataset-specific columns into a single 'text' column, working on a copy of the frame
        if dataset == 'yahoo_answers':
            part_df = part_df[part_df['answer'].notna()].copy()
            # A missing question title/content becomes an empty string instead of the literal text "nan"
            question_title = part_df['question_title'].fillna('').astype(str)
            question_content = part_df['question_content'].fillna('').astype(str)
            part_df['text'] = question_title + ' ' + question_content + ' ' + part_df['answer'].astype(str)
        elif dataset == 'ag_news':
            part_df = part_df.copy()
            part_df['text'] = part_df['title'].astype(str) + '. ' + part_df['text'].astype(str)
        return part_df

    dataset_to_columns = {'ag_news': ["label", "title", "text"],
                          'dbpedia': ["label", "title", "text"],
                          'yahoo_answers': ['label', 'question_title', 'question_content', 'answer']}

    for dataset, column_names in dataset_to_columns.items():
        logging.info(f'processing {dataset} csv files')
        raw_path = os.path.join(RAW_DIR, dataset, f'{dataset}_csv')
        with open(os.path.join(raw_path, 'classes.txt'), 'r') as f:
            # One class per line; blank lines (e.g. a trailing empty line) are skipped
            idx_to_class_name = dict(enumerate([get_label_name(dataset, row.strip())
                                                for row in f if row.strip()]))

        dataset_out_dir = os.path.join(OUT_DIR, dataset)
        os.makedirs(dataset_out_dir, exist_ok=True)

        for dataset_part in ["train", "test"]:
            part_df = pd.read_csv(os.path.join(raw_path, f'{dataset_part}.csv'), header=None)
            part_df.columns = column_names

            part_df = build_text_column(dataset, part_df)
            part_df = part_df[part_df['text'].notna()].copy()
            part_df['text'] = part_df['text'].apply(clean_text)
            part_df['label'] = part_df['label'].apply(lambda x: idx_to_class_name[x - 1])

            if dataset_part == 'test':
                part_df.to_csv(os.path.join(dataset_out_dir, 'test.csv'), index=False)
                with open(os.path.join(dataset_out_dir, 'class_names.txt'), 'w') as f:
                    f.writelines([class_name + '\n' for class_name in sorted(part_df["label"].unique())])
            else:
                part_df.to_csv(os.path.join(dataset_out_dir, 'unlabeled.csv'), index=False)



In [None]:
def load_isear(random_state=0):
    """Convert the raw ISEAR CSV into processed ``unlabeled.csv`` / ``test.csv`` files and ``class_names.txt``.

    The rows are split 80% / 20% into unlabeled and test sets. The split is seeded so that repeated runs produce
    the same test set.

    Args:
        random_state: seed for ``train_test_split``.
    """
    dataset_name = 'isear'
    dataset_out_dir = os.path.join(OUT_DIR, dataset_name)
    os.makedirs(dataset_out_dir, exist_ok=True)

    logging.info(f'processing {dataset_name} csv files')
    raw_df = pd.read_csv(os.path.join(RAW_DIR, 'isear', 'isear.csv'), sep='|', quotechar='"', on_bad_lines='warn')

    # Keep the text (SIT) and label (Field1) columns; rows missing either cannot be used
    df = raw_df[['SIT', 'Field1']].rename(columns={'SIT': 'text', 'Field1': 'label'})
    df = df.dropna(subset=['text', 'label'])
    df = df.assign(
        # NOTE(review): 'á ' is presumably an encoding artifact of the raw file; every occurrence is removed
        text=df['text'].apply(lambda x: x.replace('á ', '')),
        label=df['label'].apply(lambda x: get_label_name(dataset_name, x)),
    )

    unlabeled_df, test_df = train_test_split(df, test_size=0.2, random_state=random_state)
    unlabeled_df.to_csv(os.path.join(dataset_out_dir, 'unlabeled.csv'), index=False)
    test_df.to_csv(os.path.join(dataset_out_dir, 'test.csv'), index=False)

    with open(os.path.join(dataset_out_dir, 'class_names.txt'), 'w') as f:
        f.writelines([class_name + '\n' for class_name in sorted(test_df["label"].unique())])

In [None]:
def load_imdb():
    """Convert the extracted aclImdb reviews into processed ``unlabeled.csv`` / ``test.csv`` and ``class_names.txt``.

    The unlabeled set holds the train reviews of the pos, neg and unsup folders (unsup rows get an empty label);
    the test set holds the test pos and neg reviews. ``<br />`` tags are replaced by spaces.
    """
    dataset_name = 'imdb'
    dataset_out_dir = os.path.join(OUT_DIR, dataset_name)
    os.makedirs(dataset_out_dir, exist_ok=True)

    logging.info(f'processing {dataset_name} csv files')
    raw_dir = os.path.join(RAW_DIR, 'imdb', 'aclImdb')

    def read_reviews(split, label):
        # One review per file. Files are visited in sorted order so the row order (and therefore the seeded
        # sampling by row index done later in the experiment) does not depend on the filesystem's listing order.
        label_dir = os.path.join(raw_dir, split, label)
        class_name = get_label_name(dataset_name, label) if label != 'unsup' else ''
        rows = []
        for file_name in sorted(os.listdir(label_dir)):
            # Close each file promptly and decode explicitly instead of relying on the platform default encoding
            with open(os.path.join(label_dir, file_name), encoding='utf-8') as review_file:
                rows.append({'text': review_file.read().replace('<br />', ' '), 'label': class_name})
        return rows

    unlabeled_df = pd.DataFrame([row for label in ['pos', 'neg', 'unsup'] for row in read_reviews('train', label)])
    test_df = pd.DataFrame([row for label in ['pos', 'neg'] for row in read_reviews('test', label)])

    unlabeled_df.to_csv(os.path.join(dataset_out_dir, 'unlabeled.csv'), index=False)
    test_df.to_csv(os.path.join(dataset_out_dir, 'test.csv'), index=False)

    with open(os.path.join(dataset_out_dir, 'class_names.txt'), 'w') as f:
        f.writelines([class_name + '\n' for class_name in sorted(test_df["label"].unique())])

In [None]:
# Download (if not already present) and extract the raw files of every dataset, then build the processed files.
dataset_to_download_url = \
    {
        'isear': 'https://raw.githubusercontent.com/sinmaniphel/py_isear_dataset/master/isear.csv',
        'ag_news': 'https://docs.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
        'dbpedia': 'https://docs.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k&confirm=t',
        'yahoo_answers': 'https://docs.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU&confirm=t',
        'imdb': 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
    }

for dataset, url in dataset_to_download_url.items():
    out_dir = os.path.join(RAW_DIR, dataset)
    os.makedirs(out_dir, exist_ok=True)

    # URLs always use '/' (os.sep is '\\' on Windows). URLs without a file extension (the Google Drive links)
    # are treated as tar.gz archives.
    url_file_name = url.split('/')[-1]
    extension = '.'.join(url_file_name.split('.')[1:]) or 'tar.gz'
    target_file = os.path.join(out_dir, f'{dataset}.{extension}')

    if os.path.isfile(target_file):
        logging.info(f'{dataset} raw file already present at {target_file}, skipping download')
    else:
        logging.info(f'downloading {dataset} raw files')
        # Download under a temporary name so an interrupted download is never mistaken for a complete file
        partial_file = target_file + '.part'
        urllib.request.urlretrieve(url, partial_file)
        os.replace(partial_file, target_file)

    if extension == 'tar.gz':
        with tarfile.open(target_file) as archive:
            # The 'data' extraction filter rejects members that would land outside out_dir; it is only available
            # on Python versions that provide tarfile.data_filter
            if hasattr(tarfile, 'data_filter'):
                archive.extractall(out_dir, filter='data')
            else:
                archive.extractall(out_dir)
    elif extension == 'zip':
        with zipfile.ZipFile(target_file, 'r') as zip_ref:
            zip_ref.extractall(out_dir)

load_20_newsgroup()
load_ag_news_dbpedia_yahoo()
load_isear()
load_imdb()

# Running Experiments

In [None]:
# Command-line interface of the self-training experiment. Arguments are registered in a fixed order because that
# order becomes the key order of vars(args), from which the setting name is built.
parser = ArgumentParser()

# Identifiers of the run (all mandatory)
for _required_flag in ('--experiment_name', '--dataset_name', '--base_model'):
    parser.add_argument(_required_flag, required=True)

# Self-training loop
parser.add_argument("--num_iterations", default=2, type=int)
parser.add_argument("--dataset_subset_size", default=10000, type=int)
parser.add_argument("--sample_ratio", default=0.01, type=float)
parser.add_argument("--negative_sampling_strategy", default='take_random', type=str)

# Training / inference
parser.add_argument("--learning_rate", default=2e-5, type=float)
parser.add_argument("--train_batch_size", default=16, type=int)
parser.add_argument("--infer_batch_size", default=16, type=int)
parser.add_argument("--max_length", default=512, type=int)
parser.add_argument('--seed', default=0, type=int)

# Housekeeping
parser.add_argument("--delete_models", action='store_true')



In [None]:
# Command line used for this notebook run (parsed by `parser` in the next cell)
params = ("--experiment_name newsgroup_Lst --dataset_name 20_newsgroup "
          " --base_model roberta-large-mnli --seed 0 --dataset_subset_size 10000"
          " --train_batch_size 4 --infer_batch_size 4 --negative_sampling_strategy TAKE_LAST")

In [None]:
args = parser.parse_args(params.split())

config_dict = vars(args)
logging.info(config_dict)

# The setting name joins every configuration value except the keys below, so runs that differ only in these keys
# (most importantly the seed) share one setting name and are aggregated together.
_setting_name_excluded_keys = ('seed', 'infer_batch_size', 'delete_models')
setting_name = '_'.join(str(value) for key, value in config_dict.items() if key not in _setting_name_excluded_keys)


In [None]:
set_seed(args.seed)

# Processed datasets are read from, and results written under, the root directory
data_path = os.path.join(get_root_dir(), 'datasets', args.dataset_name)
out_dir = os.path.join(get_root_dir(), 'output', 'experiments', args.experiment_name)
os.makedirs(out_dir, exist_ok=True)

unlabeled_df = pd.read_csv(os.path.join(data_path, 'unlabeled.csv'))

# class_names.txt holds one class name per line
with open(os.path.join(data_path, 'class_names.txt')) as f:
    class_names = f.read().splitlines()

# Keep a random subset of at most --dataset_subset_size unlabeled texts to bound runtime
# (reproducible, because set_seed above seeded the random module)
all_unlabeled_texts = unlabeled_df['text']
n_unlabeled = len(all_unlabeled_texts)
subset_idxs = random.sample(range(n_unlabeled), min(args.dataset_subset_size, n_unlabeled))
unlabeled_texts = [all_unlabeled_texts[idx] for idx in subset_idxs]

test_df = pd.read_csv(os.path.join(data_path, 'test.csv'))
test_texts, test_gold_labels = test_df['text'], test_df['label']

# Number of pseudo-labeled positive examples to select per class in each self-training iteration
sample_size = int(len(unlabeled_texts) * args.sample_ratio)
logging.info(f'sample size per class is {sample_size}, set by a sample ratio of {args.sample_ratio}')

model_name = args.base_model

# Iteration 0: the off-the-shelf entailment model, before any self-training
logging.info(f"Evaluating base zero-shot model '{model_name}' performance on test set")
test_preds = get_zero_shot_predictions(model_name, test_texts, class_names,
                                       batch_size=args.infer_batch_size, max_length=args.max_length)
base_info_dict = {'iteration': 0, 'setting_name': f'{setting_name}_base', **config_dict}
evaluate_classification_performance(test_preds.predicted_labels, test_gold_labels, out_dir, info_dict=base_info_dict)

def _extend_self_training_set(self_training_set, texts, class_name_to_idxs, entailment_label):
    """Append ``texts[idx]`` for every chosen index to ``self_training_set`` with its class name and label."""
    for class_name, idxs in class_name_to_idxs.items():
        self_training_set.texts.extend([texts[idx] for idx in idxs])
        self_training_set.class_names.extend([class_name] * len(idxs))
        self_training_set.entailment_labels.extend([entailment_label] * len(idxs))


# Resolve the negative sampling strategy once, up front: an invalid value fails (KeyError) before any zero-shot
# inference or fine-tuning time has been spent.
negative_sampling_strategy = NegativeSamplingStrategy[args.negative_sampling_strategy.upper()]

for iter_number in range(1, args.num_iterations + 1):
    logging.info(f"Inferring with zero-shot model '{model_name}' on {len(unlabeled_texts)} unlabeled elements")
    predictions = get_zero_shot_predictions(model_name, unlabeled_texts, class_names,
                                            batch_size=args.infer_batch_size, max_length=args.max_length)
    logging.info(f"Done inferring zero-shot model on {len(unlabeled_texts)} unlabeled elements")

    # The previous iteration's fine-tuned model is not used after labeling the unlabeled set.
    # NOTE: the model produced by the last iteration is kept even with --delete_models.
    if args.delete_models and model_name != args.base_model:
        logging.info(f"deleting fine-tuned model {model_name}")
        shutil.rmtree(model_name)

    self_training_set = SelfTrainingSet()
    # For each class, rank the examples predicted as that class by model confidence
    class_name_to_sorted_idxs = rank_candidate_indices_per_class(class_names, predictions)

    # The <sample_size> most confident examples of each class become positive (entailment) examples
    class_name_to_positive_chosen_idxs = {class_name: sorted_idxs[:sample_size]
                                          for class_name, sorted_idxs in class_name_to_sorted_idxs.items()}
    _extend_self_training_set(self_training_set, unlabeled_texts, class_name_to_positive_chosen_idxs, 'ENTAILMENT')

    # The same examples, paired with lower-ranked classes, become negative (contradiction) examples
    class_name_to_negative_chosen_idxs = \
        get_negative_examples(predictions, class_name_to_positive_chosen_idxs, negative_sampling_strategy)
    _extend_self_training_set(self_training_set, unlabeled_texts, class_name_to_negative_chosen_idxs,
                              'CONTRADICTION')

    logging.info(f"Done collecting pseudo-labeled elements for self-training iteration {iter_number}: "
                 f"{Counter(self_training_set.entailment_labels)}")

    # We use the updated pseudo-labeled set from this iteration to fine-tune the *base* entailment model
    logging.info(f"Fine-tuning model '{args.base_model}' on {len(self_training_set.entailment_labels)} "
                 f"pseudo-labeled texts")
    finetuned_model_path = finetune_entailment_model(
        model_name=args.base_model, self_training_set=self_training_set, seed=args.seed,
        learning_rate=args.learning_rate, batch_size=args.train_batch_size, max_length=args.max_length,
        num_epochs=1)
    logging.info(f"Done fine-tuning. Model for self-training iteration {iter_number} "
                 f"saved to {finetuned_model_path}.")

    # The new model is evaluated now and labels the unlabeled set in the next iteration
    model_name = finetuned_model_path

    logging.info(f'iteration {iter_number}: evaluating model {model_name} performance on test set')
    test_preds = get_zero_shot_predictions(model_name, test_texts, class_names,
                                           batch_size=args.infer_batch_size, max_length=args.max_length)
    evaluate_classification_performance(test_preds.predicted_labels, test_gold_labels, out_dir,
                                        info_dict={
                                            'iteration': iter_number,
                                            'setting_name': f'{setting_name}_iter_{iter_number}',
                                            'self_training_set_size': len(self_training_set.texts),
                                            **config_dict
                                        })