# Style Transfer Intensity (STI)

Style transfer strength is often evaluated by training a classifier on the labeled dataset and counting how many outputs it classifies as having the target style.

[This paper](https://arxiv.org/pdf/1904.02295.pdf) proposes an alternative method.

Rather than counting how many output texts achieve the target style, we can capture more nuanced differences between the style distributions of the input text x and the output text x' using Earth Mover’s Distance (EMD). EMD is the minimum “cost” of turning one distribution into the other, which we read as how “intense” the transfer is. Because the distributions can have any number of values (styles), EMD handles both binary and non-binary datasets.
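A minimal sketch of the idea, assuming each style distribution is a probability vector over the same set of styles and that moving probability mass between any two different styles costs 1 (a unit ground distance, which is our assumption here, not necessarily the paper's choice). Under that cost, EMD reduces to half the L1 distance between the two vectors. The function names are hypothetical, and the sign flip in `signed_intensity` is an illustrative assumption, not a confirmed part of the paper's method.

```python
import numpy as np


def emd_unit_cost(p, q):
    """EMD between two discrete style distributions under a unit ground distance.

    Moving probability mass between two different styles costs 1 and leaving it
    in place costs 0, so the optimal plan keeps min(p_i, q_i) where it is and
    moves the rest: EMD = 1 - sum(min(p, q)) = 0.5 * sum(|p - q|).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if p.shape != q.shape:
        raise ValueError("distributions must cover the same styles")
    if not (np.isclose(p.sum(), 1.0) and np.isclose(q.sum(), 1.0)):
        raise ValueError("each distribution must sum to 1")
    return float(0.5 * np.abs(p - q).sum())


def signed_intensity(p, q, target_idx):
    """Hypothetical direction correction (our assumption): negate the EMD when
    x' puts less probability on the target style than x does."""
    sign = 1.0 if q[target_idx] >= p[target_idx] else -1.0
    return sign * emd_unit_cost(p, q)


# [P(subjective), P(neutral)] for an input x and its rewrite x'
x_dist = [0.9, 0.1]
x_prime_dist = [0.2, 0.8]
print(round(emd_unit_cost(x_dist, x_prime_dist), 4))  # 0.7
print(round(signed_intensity(x_dist, x_prime_dist, target_idx=1), 4))  # 0.7

# Three styles work the same way
print(round(emd_unit_cost([0.5, 0.3, 0.2], [0.2, 0.3, 0.5]), 4))  # 0.3
```

If the styles were ordered (for example graded sentiment levels), a ground distance that grows with the gap between styles would give a different EMD; that case is outside this sketch.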



## Prepare WNC for style classification

In [1]:
import os
import numpy as np
from collections import defaultdict
from datasets import (
    load_dataset,
    load_from_disk,
    load_metric,
    Dataset,
    Features,
    Value,
    ClassLabel,
    DatasetDict,
)

%load_ext lab_black

In [2]:
def build_classification_dataset(path: str) -> DatasetDict:
    """
    Formats the translation-task version of WNC as a classification dataset.

    Dataset splits remain the same, but the number of records in each split is doubled
    because we create an individual record for both the "source_text" and "target_text" fields:
    "source_text" is assigned the label "subjective" and "target_text" is assigned the label
    "neutral". Records are not shuffled: each split holds all "subjective" records first,
    followed by the corresponding "neutral" records in the same order.

    Args:
        path (str): path to HuggingFace dataset

    Returns:
        DatasetDict

    """
    datasets = load_from_disk(path)
    dataset_dict = defaultdict(dict)

    SPLITS = ["train", "test", "validation"]
    LABEL_MAPPING = {"source_text": "subjective", "target_text": "neutral"}
    FEATURES = Features(
        {
            "text": Value("string"),
            "label": ClassLabel(num_classes=2, names=["subjective", "neutral"]),
        }
    )

    for split in SPLITS:
        df = datasets[split].to_pandas()
        split_dict = defaultdict(list)

        for column, label in LABEL_MAPPING.items():
            split_dict["text"].extend(df[column].tolist())
            split_dict["label"].extend([label] * len(df))

        dataset_dict[split] = Dataset.from_dict(
            split_dict, features=FEATURES
        )  # .shuffle(seed=42)

    return DatasetDict(dataset_dict)

In [3]:
DATASETS_PATH = "/home/cdsw/data/processed/WNC_full"
wnc_classification = build_classification_dataset(DATASETS_PATH)

In [25]:
wnc_classification

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 308394
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 17154
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 17214
    })
})

In [34]:
wnc_classification["train"][:10]

{'text': ["while for long nearly only women where shown as sex objects, increasing tolerance, more tempered censorship, emancipatory developments and increasing buying power of previously neglected appreciative target groups in rich markets (mainly in the west) have lead to a marked increase in the share of attractive male flesh 'on display'.",
  "following the end of kenneth kaunda's repressive dictatorship , chiluba won the country's multi-party presidential elections.",
  'a brilliant quarterback with the university of illinois, haller was signed by the giants as an amateur free agent in 1958. he made his debut on april 11, 1961 as a platoon catcher.',
  'traitor to his people adam yahiye gadahn (born september 1, 1978) is an american-born man who is suspected of being a member of the al qaeda organization.',
  'a funny thing happened on the way to the moon is a 2001 documentary written, produced, and directed by nashville, tennessee-based filmmaker and investigative journalist bart

In [33]:
second_half = len(wnc_classification["train"]) // 2
wnc_classification["train"][second_half : second_half + 10]

{'text': ["increased tolerance, more tempered censorship, emancipatory developments and increasing buying power of previously neglected appreciative target groups in rich markets (mainly in the west) have lead to a marked increase in the share of attractive male flesh 'on display'.",
  "following the end of kenneth kaunda's presidency , chiluba won the country's multi-party presidential elections.",
  'a quarterback with the university of illinois, haller was signed by the giants as an amateur free agent in 1958. he made his debut on april 11, 1961 as a platoon catcher.',
  'adam yahiye gadahn (born september 1, 1978) is an american-born man who is suspected of being a member of the al qaeda organization.',
  'a funny thing happened on the way to the moon is a 2001 documentary written, produced, and directed by nashville, tennessee-based filmmaker and investigative journalist bart winfield sibrel, a critic of the united states space program and proponent of the theory that the six apol

In [29]:
len(wnc_classification["train"]) / 2

154197.0

In [31]:
154197 * 2

308394
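Because the splits are not shuffled, record `i` in the first half and record `i + N/2` in the second half come from the same WNC pair, as the two printouts above show. A pure-Python sketch of that layout, using toy strings in place of WNC rows; `build_pairs` and `paired_index` are hypothetical helpers, and labels follow the `ClassLabel` order `["subjective", "neutral"]` (0 and 1).

```python
def build_pairs(source_texts, target_texts):
    """Mimic the unshuffled layout: all source texts first (label 0, "subjective"),
    then all target texts (label 1, "neutral")."""
    texts = list(source_texts) + list(target_texts)
    labels = [0] * len(source_texts) + [1] * len(target_texts)
    return texts, labels


def paired_index(i, n_records):
    """Index of the other half of the WNC pair that record i belongs to."""
    half = n_records // 2
    return i + half if i < half else i - half


texts, labels = build_pairs(["biased a", "biased b"], ["neutral a", "neutral b"])
print(texts[paired_index(0, len(texts))])  # neutral a
```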

In [5]:
# save dataset
CLS_DATASET_PATH = "/home/cdsw/data/processed/WNC_full_cls"
# os.makedirs(CLS_DATASET_PATH)
# wnc_classification.save_to_disk(CLS_DATASET_PATH)

### Testing dataset

In [15]:
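test_wnc_classification = DatasetDict(
    {
        # Shuffle before selecting: each split is ordered (all "subjective" rows, then
        # all "neutral" rows), so the first 1,000 rows would contain a single class.
        "train": wnc_classification["train"].shuffle(seed=42).select(range(1000)),
        "test": wnc_classification["test"].shuffle(seed=42).select(range(1000)),
        "validation": wnc_classification["validation"].shuffle(seed=42).select(range(1000)),
    }
)

TEST_CLS_DATASET_PATH = "/home/cdsw/data/processed/WNC_full_cls_TEST"
os.makedirs(TEST_CLS_DATASET_PATH, exist_ok=True)
test_wnc_classification.save_to_disk(TEST_CLS_DATASET_PATH)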

In [16]:
test_wnc_classification

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
})
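Since `build_classification_dataset` keeps records in order, taking `select(range(1000))` from an unshuffled split yields only "subjective" examples. A stdlib-only illustration with a toy label list (the sizes are illustrative, not WNC's):

```python
import random

n_pairs = 5000
# Unshuffled layout: all "subjective" (0) labels, then all "neutral" (1) labels.
labels = [0] * n_pairs + [1] * n_pairs

head = labels[:1000]
print(set(head))  # {0}: a single class

rng = random.Random(42)
shuffled = labels[:]
rng.shuffle(shuffled)
head_shuffled = shuffled[:1000]
share_neutral = sum(head_shuffled) / len(head_shuffled)
print(share_neutral)  # close to 0.5
```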

## Train a classifier

In [4]:
# Points at the 1,000-row test subset saved above, not the full classification dataset
CLS_DATASET_PATH = "/home/cdsw/data/processed/WNC_full_cls_TEST"
wnc_full_cls = load_from_disk(CLS_DATASET_PATH)

In [5]:
wnc_full_cls

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
})

In [8]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
from transformers.integrations import MLflowCallback
from datasets import load_metric

In [None]:
from transformers.trainer_utils import IntervalStrategy

In [7]:
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    return tokenizer(example["text"], truncation=True)


tokenized_datasets = wnc_full_cls.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Loading cached processed dataset at /home/cdsw/data/processed/WNC_full_cls_TEST/train/cache-11a8b3faaf6e3c0a.arrow
Loading cached processed dataset at /home/cdsw/data/processed/WNC_full_cls_TEST/test/cache-5c62f87956a8bec2.arrow
Loading cached processed dataset at /home/cdsw/data/processed/WNC_full_cls_TEST/validation/cache-8db993456d6202ea.arrow
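`DataCollatorWithPadding` pads each batch only up to its longest sequence (dynamic padding) instead of a fixed maximum length. A simplified pure-Python sketch of that idea: `pad_batch` is hypothetical, `pad_id=0` is an assumption (it is the `[PAD]` id in the `bert-base-uncased` vocabulary), the token ids are illustrative, and the real collator also handles other fields and returns tensors.

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Right-pad token-id lists to the longest sequence in this batch and
    build matching attention masks (1 = real token, 0 = padding)."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])
print(batch["input_ids"][0])       # [101, 7592, 102, 0, 0]
print(batch["attention_mask"][0])  # [1, 1, 1, 0, 0]
```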


In [8]:
MODEL_NAME = "bert-cls-full"
MODEL_DIR = "/home/cdsw/models"

training_args = TrainingArguments(
    output_dir=os.path.join(MODEL_DIR, MODEL_NAME),
    learning_rate=5e-05,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    logging_dir=os.path.join(MODEL_DIR, "logs", MODEL_NAME),
    logging_steps=50,
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    eval_steps=100,
    save_total_limit=5,
    save_steps=100,
    # metric_for_best_model="f1",
    # greater_is_better=True,
)
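With these settings, the run length and evaluation schedule follow from simple arithmetic. A sketch assuming a single device and no gradient accumulation on the 1,000-row training subset; it reproduces the 315 optimization steps and the three evaluations reported in the training log.

```python
import math

num_examples = 1000
batch_size = 16   # per_device_train_batch_size, single device assumed
num_epochs = 5
eval_steps = 100  # save_steps uses the same interval

steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch is partial
total_steps = steps_per_epoch * num_epochs
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

print(steps_per_epoch, total_steps)  # 63 315
print(eval_points)                   # [100, 200, 300]
```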

In [9]:
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [10]:
def compute_metrics(eval_preds):
    metric = load_metric("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
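`compute_metrics` turns each row of logits into a predicted label with `argmax` and then scores the predictions. To make the two reported numbers easier to interpret, here is a NumPy-only sketch of accuracy and of F1 for the positive class, label 1 (here "neutral"). We assume this matches the binary F1 that the `glue`/`mrpc` metric reports; `accuracy_and_f1` is a hypothetical helper.

```python
import numpy as np


def accuracy_and_f1(logits, labels, positive_label=1):
    """Accuracy plus F1 for `positive_label`, computed from raw logits."""
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    accuracy = float((preds == labels).mean())

    tp = np.sum((preds == positive_label) & (labels == positive_label))
    fp = np.sum((preds == positive_label) & (labels != positive_label))
    fn = np.sum((preds != positive_label) & (labels == positive_label))
    # F1 = 2TP / (2TP + FP + FN); define it as 0.0 when there are no positives at all.
    denom = 2 * tp + fp + fn
    f1 = float(2 * tp / denom) if denom else 0.0
    return {"accuracy": accuracy, "f1": f1}


logits = np.array([[2.0, -1.0], [-0.5, 1.5], [0.3, 0.1], [-1.0, 2.0]])
labels = [0, 1, 1, 1]
print(accuracy_and_f1(logits, labels))  # {'accuracy': 0.75, 'f1': 0.8}
```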

In [11]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.remove_callback(MLflowCallback)

In [12]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 315


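| Step | Training Loss | Validation Loss | Accuracy | F1 |
|-----:|--------------:|----------------:|---------:|---:|
| 100 | 0.6853 | 0.685976 | 0.581 | 0.527621 |
| 200 | 0.3937 | 0.983388 | 0.593 | 0.657696 |
| 300 | 0.0667 | 1.293017 | 0.622 | 0.623506 |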


The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16
Saving model checkpoint to /home/cdsw/models/bert-cls-full/checkpoint-100
Configuration saved in /home/cdsw/models/bert-cls-full/checkpoint-100/config.json
Model weights saved in /home/cdsw/models/bert-cls-full/checkpoint-100/pytorch_model.bin
tokenizer config file saved in /home/cdsw/models/bert-cls-full/checkpoint-100/tokenizer_config.json
Special tokens file saved in /home/cdsw/models/bert-cls-full/checkpoint-100/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you

TrainOutput(global_step=315, training_loss=0.4158816258112589, metrics={'train_runtime': 65.9758, 'train_samples_per_second': 75.785, 'train_steps_per_second': 4.774, 'total_flos': 194422625470080.0, 'train_loss': 0.4158816258112589, 'epoch': 5.0})

### Step Through Eval

In [13]:
predictions = trainer.predict(tokenized_datasets["validation"])

The following columns in the test set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 1000
  Batch size = 16


In [14]:
predictions.predictions

array([[-0.3496061, -0.6291089],
       [ 1.2544787, -2.2091932],
       [ 1.4202304, -2.272884 ],
       ...,
       [ 1.9655778, -2.7637262],
       [-2.3133612,  2.5391092],
       [ 1.3082958, -2.4287748]], dtype=float32)

In [15]:
predictions.predictions.shape

(1000, 2)
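For STI we need a style distribution per text rather than a hard label. One way, assumed here rather than taken from the paper, is a row-wise softmax over the logits. A NumPy sketch using the first two rows printed above:

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


logits = np.array([[-0.3496061, -0.6291089],
                   [1.2544787, -2.2091932]])
probs = softmax(logits)  # columns: [P(label 0), P(label 1)]
print(probs.round(3))
print(probs.sum(axis=-1))  # each row sums to 1, up to floating-point rounding
```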

In [16]:
preds = np.argmax(predictions.predictions, axis=-1)

In [18]:
metric = load_metric("glue", "mrpc")

In [19]:
metric.compute(predictions=preds, references=predictions.label_ids)

{'accuracy': 0.626, 'f1': 0.625250501002004}

In [24]:
compute_metrics(predictions)

ValueError: too many values to unpack (expected 2)
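This fails because `trainer.predict` returns a `PredictionOutput`, a named tuple with three fields (`predictions`, `label_ids`, `metrics`), while `compute_metrics` unpacks exactly two values, as it would from the `EvalPrediction` the `Trainer` passes during evaluation. Passing the two relevant fields explicitly, e.g. `compute_metrics((predictions.predictions, predictions.label_ids))`, should avoid the error. A stdlib-only reproduction; `FakePredictionOutput` is a hypothetical stand-in with the same field layout, not the transformers class.

```python
from typing import NamedTuple, Optional


class FakePredictionOutput(NamedTuple):
    # Same three fields as transformers' PredictionOutput (stand-in for illustration).
    predictions: list
    label_ids: list
    metrics: Optional[dict]


def unpack_two(eval_preds):
    logits, labels = eval_preds  # same unpacking as compute_metrics
    return logits, labels


out = FakePredictionOutput(predictions=[[0.1, 0.9]], label_ids=[1], metrics={})

raised = False
try:
    unpack_two(out)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Passing only the two fields that compute_metrics expects works.
logits, labels = unpack_two((out.predictions, out.label_ids))
```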

#### manual

In [20]:
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")

In [22]:
accuracy_metric.compute(predictions=preds, references=predictions.label_ids)

{'accuracy': 0.626}

In [23]:
f1_metric.compute(predictions=preds, references=predictions.label_ids)

{'f1': 0.625250501002004}

## Load a trained model

In [7]:
from transformers import pipeline, set_seed

In [9]:
MODEL_PATH = "/home/cdsw/models/bert-cls-full3/checkpoint-96000/"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [10]:
classifier = pipeline(task="text-classification", model=model, tokenizer=tokenizer)

In [35]:
classifier(
    "following the end of kenneth kaunda's repressive dictatorship , chiluba won the country's multi-party presidential elections."
)

[{'label': 'LABEL_0', 'score': 0.9783034920692444}]

In [36]:
classifier(
    "following the end of kenneth kaunda's repressive presidency , chiluba won the country's multi-party presidential elections."
)

[{'label': 'LABEL_0', 'score': 0.981116771697998}]
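The pipeline reports generic `LABEL_0`/`LABEL_1` names, which suggests the checkpoint's config uses the default label names. Given the `ClassLabel` order used to build the dataset (`["subjective", "neutral"]`), `LABEL_0` corresponds to "subjective". A small hypothetical helper to translate pipeline output; passing `id2label` and `label2id` when loading the model is another option.

```python
LABEL_NAMES = ["subjective", "neutral"]  # ClassLabel order from build_classification_dataset


def readable(pipeline_output, names=LABEL_NAMES):
    """Replace 'LABEL_<i>' with the i-th class name, keeping the score."""
    readable_output = []
    for item in pipeline_output:
        idx = int(item["label"].rsplit("_", 1)[-1])
        readable_output.append({"label": names[idx], "score": item["score"]})
    return readable_output


print(readable([{"label": "LABEL_0", "score": 0.9783034920692444}]))
# [{'label': 'subjective', 'score': 0.9783034920692444}]
```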

In [12]:
classifier.task

'text-classification'

In [37]:
import datasets

In [38]:
datasets.list_metrics()

['accuracy',
 'bertscore',
 'bleu',
 'bleurt',
 'cer',
 'chrf',
 'code_eval',
 'comet',
 'competition_math',
 'coval',
 'cuad',
 'exact_match',
 'f1',
 'frugalscore',
 'glue',
 'google_bleu',
 'indic_glue',
 'mae',
 'mahalanobis',
 'matthews_correlation',
 'mauve',
 'mean_iou',
 'meteor',
 'mse',
 'pearsonr',
 'perplexity',
 'precision',
 'recall',
 'roc_auc',
 'rouge',
 'sacrebleu',
 'sari',
 'seqeval',
 'spearmanr',
 'squad',
 'squad_v2',
 'super_glue',
 'ter',
 'wer',
 'wiki_split',
 'xnli',
 'xtreme_s']

In [27]:
type(training_args)

transformers.training_args.TrainingArguments

In [23]:
import os

In [12]:
from dataclasses import dataclass, field
from typing import Optional
from transformers.trainer_utils import IntervalStrategy

@dataclass
class StiArguments:
    """
    Arguments for the STI classifier training job: the model path plus the subset of `transformers.TrainingArguments` **which relate to the training loop itself**.
    Using [`HfArgumentParser`] we can turn this class into [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line.
    """
    
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
    logging_strategy: IntervalStrategy = field(
        default="steps",
        metadata={"help": "The logging strategy to use."},
    )
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    evaluation_strategy: IntervalStrategy = field(
        default="no",
        metadata={"help": "The evaluation strategy to use."},
    )
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    save_total_limit: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Limit the total amount of checkpoints. "
                "Deletes the older checkpoints in the output_dir. Default is unlimited checkpoints"
            )
        },
    )
    load_best_model_at_end: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to load the best model found during training at the end of training."},
    )
    metric_for_best_model: Optional[str] = field(
        default=None, metadata={"help": "The metric to use to compare two different models."}
    )
    greater_is_better: Optional[bool] = field(
        default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
    )
    

In [16]:
from transformers import HfArgumentParser

parser = HfArgumentParser(StiArguments)

In [19]:
parser

HfArgumentParser(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
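`HfArgumentParser` builds one command-line flag per dataclass field from the field's type, default and `help` metadata, and fields without a default become required. A deliberately simplified, stdlib-only sketch of that idea; `ToyArgs`, `dataclass_to_parser` and `parse_into` are hypothetical, and the real parser handles many more cases (booleans, enums such as `IntervalStrategy`, `Optional` types).

```python
import argparse
import dataclasses
from dataclasses import dataclass, field


@dataclass
class ToyArgs:
    model_name_or_path: str = field(metadata={"help": "Model path or hub id."})
    output_dir: str = field(metadata={"help": "Where to write checkpoints."})
    learning_rate: float = field(default=5e-5, metadata={"help": "Initial learning rate."})
    num_train_epochs: int = field(default=3, metadata={"help": "Training epochs."})


def dataclass_to_parser(cls):
    """One --flag per field; fields with no default are required."""
    parser = argparse.ArgumentParser()
    for f in dataclasses.fields(cls):
        kwargs = {"type": f.type, "help": f.metadata.get("help")}
        if f.default is dataclasses.MISSING:
            kwargs["required"] = True
        else:
            kwargs["default"] = f.default
        parser.add_argument(f"--{f.name}", **kwargs)
    return parser


def parse_into(cls, argv):
    namespace = dataclass_to_parser(cls).parse_args(argv)
    return cls(**vars(namespace))


args = parse_into(ToyArgs, ["--model_name_or_path", "bert-base-uncased",
                            "--output_dir", "/tmp/sti", "--num_train_epochs", "5"])
print(args)
```

With the real class, `parser.parse_args_into_dataclasses(args=[...])` returns a tuple holding one populated `StiArguments` instance.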

In [22]:
dir(parser)

['__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_action_groups',
 '_actions',
 '_add_action',
 '_add_container_actions',
 '_add_dataclass_arguments',
 '_check_conflict',
 '_check_value',
 '_defaults',
 '_get_args',
 '_get_formatter',
 '_get_handler',
 '_get_kwargs',
 '_get_nargs_pattern',
 '_get_option_tuples',
 '_get_optional_actions',
 '_get_optional_kwargs',
 '_get_positional_actions',
 '_get_positional_kwargs',
 '_get_value',
 '_get_values',
 '_handle_conflict_error',
 '_handle_conflict_resolve',
 '_has_negative_number_optionals',
 '_match_argument',
 '_match_arguments_partial',
 '_mutually_exclusive_groups',
 '_negative_number_matcher',
 '_opt

In [25]:
import argparse

In [26]:
parser = argparse.ArgumentParser(description="Script to run train job for seq2seq (TST) or classifier (STI) models.")

In [27]:
parser.add_argument('task', type=str, help='Select which task to run: seq2seq or classifier.')

_StoreAction(option_strings=[], dest='task', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, help='Select which task to run: seq2seq or classifier.', metavar=None)

In [28]:
parser

ArgumentParser(prog='ipykernel_launcher.py', usage=None, description='Script to run train job for seq2seq (TST) or classifier (STI) models.', formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)
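Since only two tasks are supported, argparse's `choices` can reject anything else at parse time. A sketch of that variant of the positional `task` argument:

```python
import argparse

parser = argparse.ArgumentParser(
    description="Script to run train job for seq2seq (TST) or classifier (STI) models."
)
# `choices` makes argparse reject any task name other than the two supported ones.
parser.add_argument(
    "task",
    type=str,
    choices=["seq2seq", "classifier"],
    help="Select which task to run: seq2seq or classifier.",
)

args = parser.parse_args(["classifier"])
print(args.task)  # classifier
```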