### Practice: text style transfer

Credits: this notebook is heavily based on the [YSDA NLP course notebook](https://github.com/yandexdataschool/nlp_course/tree/2022/week10_style)

Hello, sitzen class A.412C!

Based on your browser search history, we conclude that you have an above average skill in natural language processing. In our benevolence, we give you a chance to contribute your skills to upholding the happiest society in the universe. Are you up to the task?

As you know, our most recent breakthrough was replacing 97% of restaurant workers with BFGHQBERT+++ autonomous food dispensers.

Yet some radical elements fail to recognize the greater good that we have brought them. They mistakenly voice their ignorant opinions about our new INGSOC-approved restaurants, bringing dangerous doubt to the minds of our loyal citizens.

Surely you cannot tolerate such infidelity! Our loyal citizens demand that you rectify their mistake. _You must build a model that will automatically improve their ignorant thoughts and replace them with the thoughts they should actually have._

Attached below are the INGSOC-approved datasets of ignorant and correct thoughts. The scientific terms for wrong opinions and correct opinions are "negative" and "positive", respectively.

Respond within 7 days or you will lose 3.7629 citizenship points.

![img](https://ih1.redbubble.net/image.1254830934.9884/poster,504x498,f8f8f8-pad,400x240,f8f8f8.jpg)

In [None]:
# !pip install -q transformers
!wget -q https://github.com/shentianxiao/language-style-transfer/raw/master/data/yelp/sentiment.train.0 -O train_negative
!wget -q https://github.com/shentianxiao/language-style-transfer/raw/master/data/yelp/sentiment.train.1 -O train_positive
!wget -q https://github.com/shentianxiao/language-style-transfer/raw/master/data/yelp/sentiment.dev.0 -O dev_negative
!wget -q https://github.com/shentianxiao/language-style-transfer/raw/master/data/yelp/sentiment.dev.1 -O dev_positive

In [None]:
!head -n 5 ./dev_positive
!echo
!head -n 5 ./dev_negative
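The label convention follows the file names used in the download cell: files ending in `.0` hold negative reviews and files ending in `.1` hold positive ones, one sentence per line. As a minimal sketch (the helper `load_labeled_sentences` is hypothetical, not part of the course code), the downloaded files can be collected into `(text, label)` pairs with 0 = negative and 1 = positive:

```python
def load_labeled_sentences(negative_path, positive_path):
    """Read one-sentence-per-line files into (text, label) pairs.

    Label 0 marks negative sentences and label 1 marks positive ones,
    matching the sentiment.*.0 / sentiment.*.1 naming of the source files.
    Blank lines are skipped and surrounding whitespace is stripped.
    """
    pairs = []
    for path, label in [(negative_path, 0), (positive_path, 1)]:
        with open(path, encoding="utf-8") as f:
            pairs.extend((line.strip(), label) for line in f if line.strip())
    return pairs

# Usage with the files downloaded above:
# dev_pairs = load_labeled_sentences("./dev_negative", "./dev_positive")
```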

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cpu':
    print("Fine-tuning BERT without an accelerator is not party-approved.")

### Part 1: Masked language model

Attached below you can find the INGSOC-compliant training code that fine-tunes a BERT model for Masked Language Modeling.

You shall use this model to generate positive replacements for negative tokens.
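For intuition about what `DataCollatorForLanguageModeling(mlm=True, mlm_probability=0.15)` in the training cell does, here is a minimal NumPy sketch of the standard BERT masking recipe: about 15% of non-special tokens become prediction targets; of those, 80% are replaced with `[MASK]`, 10% with a random token, and 10% are left unchanged. The function name and its arguments are illustrative only, not the library's API.

```python
import numpy as np

def bert_style_mask(input_ids, mask_token_id, vocab_size, special_mask,
                    mlm_probability=0.15, rng=None):
    """Sketch of BERT-style MLM corruption (80% [MASK], 10% random, 10% unchanged).

    Returns (corrupted_ids, labels); labels is -100 at positions that are not
    prediction targets, so a cross-entropy loss with ignore_index=-100 skips them.
    """
    rng = np.random.default_rng() if rng is None else rng
    input_ids = np.array(input_ids)
    labels = input_ids.copy()
    # Special tokens ([CLS], [SEP], padding) are never selected as targets
    probs = np.full(input_ids.shape, mlm_probability)
    probs[np.asarray(special_mask, dtype=bool)] = 0.0
    selected = rng.random(input_ids.shape) < probs
    labels[~selected] = -100
    corrupted = input_ids.copy()
    # 80% of the selected positions become [MASK]
    to_mask = selected & (rng.random(input_ids.shape) < 0.8)
    corrupted[to_mask] = mask_token_id
    # Half of the remaining 20% get a random token; the rest stay unchanged
    to_random = selected & ~to_mask & (rng.random(input_ids.shape) < 0.5)
    corrupted[to_random] = rng.integers(vocab_size, size=int(to_random.sum()))
    return corrupted, labels
```

The model is then trained to recover the original token at every position whose label is not `-100`, which is what lets the fine-tuned MLM propose in-style replacements later.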

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(True)

In [None]:
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

print("Preparing the training data...")
dataset = LineByLineTextDataset(
    file_path="./train_positive", tokenizer=tokenizer, block_size=128)

print("Dataset ready!")

trainer = Trainer(
    model=bert_mlm_positive,
    train_dataset=dataset, 
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15),
    args=TrainingArguments(
        output_dir="./bert_mlm_positive", overwrite_output_dir=True,
        num_train_epochs=1, per_device_train_batch_size=32,
        save_steps=10_000, save_total_limit=2, report_to="none"),
)

trainer.train()

In [None]:
# Build and train an MLM for incorrect (negative) opinions

bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(True)

In [None]:
print("Preparing the training data...")
negative_dataset = LineByLineTextDataset(
    file_path="./train_negative", tokenizer=tokenizer, block_size=128)

print("Negative dataset ready!")

In [None]:
trainer = Trainer(
    model=bert_mlm_negative,
    train_dataset=negative_dataset, 
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15),
    args=TrainingArguments(
        output_dir="./bert_mlm_negative", overwrite_output_dir=True,
        num_train_epochs=1, per_device_train_batch_size=32,
        save_steps=10_000, save_total_limit=2, report_to="none"),
)

trainer.train()

### Part 2: Replace tokens

You can now use the two masked language models to align user opinions by following these steps:

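1. Find the tokens for which the ratio $(P_{positive}(x) + \epsilon) / (P_{negative}(x) + \epsilon)$ is the smallest, where $P(x)$ is the probability a model assigns to token $x$ at its position in the sentence. These are the tokens the negative model finds much more likely than the positive one.
2. Replace those tokens with one of the $k$ most likely tokens according to $P_{positive}(x)$.
3. Rinse and repeat.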

You can find the full procedure at https://arxiv.org/abs/2010.01054
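Steps 1 and 2 can be illustrated on toy numbers before touching BERT. In this sketch the per-token probabilities and the four-word vocabulary are made up for illustration; in the notebook they would come from the two fine-tuned MLMs. The helper names are hypothetical.

```python
import numpy as np

def lowest_ratio_positions(p_pos, p_neg, num_tokens, epsilon=1e-3):
    """Step 1: indices of the num_tokens positions with the smallest
    (P_positive + eps) / (P_negative + eps), i.e. the most "negative" tokens."""
    ratio = (np.asarray(p_pos) + epsilon) / (np.asarray(p_neg) + epsilon)
    return np.argsort(ratio)[:num_tokens]

def top_k_tokens(probs, vocab, k_best):
    """Step 2: the k_best most likely tokens under a distribution over vocab."""
    return [vocab[i] for i in np.argsort(-np.asarray(probs))[:k_best]]

# Made-up probabilities each model assigns to the tokens of "the staff is horrible"
tokens = ["the", "staff", "is", "horrible"]
p_pos = [0.30, 0.10, 0.40, 0.001]
p_neg = [0.30, 0.12, 0.40, 0.200]
worst = lowest_ratio_positions(p_pos, p_neg, num_tokens=1)
print([tokens[i] for i in worst])  # ['horrible']

# Made-up P_positive distribution over a tiny vocabulary at that position
vocab = ["great", "awful", "awesome", "slow"]
print(top_k_tokens([0.50, 0.05, 0.30, 0.15], vocab, k_best=2))  # ['great', 'awesome']
```

The $\epsilon$ term keeps the ratio finite when a model assigns a near-zero probability, and it also damps the ratio for tokens that are rare under both models.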

In [None]:
idx_to_token = {idx: token for token, idx in tokenizer.vocab.items()}

In [None]:
sentence = f'great wings and decent drinks but the wait staff is {tokenizer.mask_token} !'
batch = tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')

In [None]:
tokenizer.decode(batch['input_ids'][0].data.cpu().numpy())

In [None]:
batch = {key: value.to(device) for key, value in batch.items()}

In [None]:
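bert_mlm_positive.eval()  # disable dropout for inference
bert_mlm_negative.eval()
with torch.no_grad():
    logits_positive = bert_mlm_positive(**batch)['logits']
    logits_negative = bert_mlm_negative(**batch)['logits']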

In [None]:
# Position -3 is [MASK]: the tokenized sequence ends with "[MASK] ! [SEP]"
mask_logits_positive = logits_positive[0, -3].cpu().data.numpy()

In [None]:
for index in np.argsort(-mask_logits_positive)[:5]:
    print(idx_to_token[index])

In [None]:
# Same [MASK] position, scored by the negative model
mask_logits_negative = logits_negative[0, -3].cpu().data.numpy()

In [None]:
for index in np.argsort(-mask_logits_negative)[:5]:
    print(idx_to_token[index])
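The cells above rank candidate tokens directly by logits, which is enough for a top-k list. The ratio in step 1, however, needs probabilities: a softmax over the vocabulary, read off at the token actually present at each position. A minimal NumPy sketch, assuming a `(seq_len, vocab_size)` logits array such as `logits_positive[0].cpu().data.numpy()` (the helper name is hypothetical):

```python
import numpy as np

def token_probabilities(logits, input_ids):
    """Return P(x_i) for every position i.

    logits: array of shape (seq_len, vocab_size) from an MLM.
    input_ids: array of shape (seq_len,) with the tokens actually present.
    """
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoids overflow in exp
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    return probs[np.arange(len(input_ids)), np.asarray(input_ids)]

# Toy example: 2 positions, vocabulary of 2 tokens
print(token_probabilities([[0.0, 0.0], [np.log(3.0), 0.0]], [0, 0]))
```

Subtracting the row maximum does not change the softmax but keeps `exp` from overflowing on large logits.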

In [None]:
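def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - find :num_tokens: tokens with the lowest ratio (see above)
    - replace them with :k_best: words according to bert_mlm_positive
    :return: a list of all possible strings (up to k_best * num_tokens)
    """
    bert_mlm_positive.eval(); bert_mlm_negative.eval()  # disable dropout
    input_ids = tokenizer([sentence], truncation=True, return_tensors='pt')['input_ids'][0].to(device)
    positions = torch.arange(1, len(input_ids) - 1, device=device)  # skip [CLS] and [SEP]
    rows = torch.arange(len(positions), device=device)
    masked = input_ids.repeat(len(positions), 1)  # one copy per position, with that position masked
    masked[rows, positions] = tokenizer.mask_token_id
    with torch.no_grad():
        probs_pos = bert_mlm_positive(input_ids=masked)['logits'][rows, positions].softmax(-1)
        probs_neg = bert_mlm_negative(input_ids=masked)['logits'][rows, positions].softmax(-1)
    ratio = (probs_pos[rows, input_ids[positions]] + epsilon) / (probs_neg[rows, input_ids[positions]] + epsilon)
    results = []
    for row in ratio.argsort()[:num_tokens].tolist():
        for token_id in probs_pos[row].topk(k_best).indices.tolist():
            new_ids = input_ids.clone()
            new_ids[positions[row]] = token_id
            results.append(tokenizer.decode(new_ids[1:-1].tolist(), clean_up_tokenization_spaces=False))
    return results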

In [None]:
with open('./dev_negative') as f:
    dev_data = [line.strip() for line in f]

In [None]:
dev_data[500:505]

In [None]:
get_replacements("great wings and decent drinks but the wait staff is horrible !",
                 num_tokens=1, k_best=2)
# Expected output resembles (exact words depend on the fine-tuned models):
# ["great wings and decent drinks but the wait staff is great !", "great wings and decent drinks but the wait staff is awesome !"]

__Final task__: build a procedure that iteratively applies replacements, and demonstrate its effectiveness on at least 10 examples to satisfy INGSOC.
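One possible shape for the iterative procedure is a greedy loop: at each step, generate candidates, keep the one with the best positivity score, and stop when no candidate improves on the current sentence. This is a sketch under assumptions: `score` is a hypothetical callable you would have to supply (for example, the summed log-ratio of the two MLMs' token probabilities, or an external sentiment classifier), and the replacement function is passed in as a parameter so the loop can be tested with toy stand-ins instead of BERT. In the notebook you would pass the `get_replacements` defined above.

```python
from typing import Callable, List

def iterative_transfer(
    sentence: str,
    get_replacements: Callable[..., List[str]],
    score: Callable[[str], float],
    max_steps: int = 5,
    num_tokens: int = 1,
    k_best: int = 3,
) -> str:
    """Greedily apply replacements while the (hypothetical) positivity score improves."""
    current = sentence
    for _ in range(max_steps):
        candidates = get_replacements(current, num_tokens=num_tokens, k_best=k_best)
        if not candidates:
            break
        best = max(candidates, key=score)
        if score(best) <= score(current):
            break  # no candidate is more positive than the current sentence
        current = best
    return current

# Toy stand-ins, for illustration only
POSITIVE = {"great", "friendly"}
SWAPS = {"horrible": "great", "rude": "friendly"}

def toy_replacements(s, num_tokens, k_best):
    words = s.split()
    out = [" ".join(words[:i] + [SWAPS[w]] + words[i + 1:])
           for i, w in enumerate(words) if w in SWAPS]
    return out[: num_tokens * k_best]

def toy_score(s):
    return sum(w in POSITIVE for w in s.split())

print(iterative_transfer("the food is horrible and the staff is rude", toy_replacements, toy_score))
# the food is great and the staff is friendly
```

Stopping as soon as the score stops improving keeps the loop from drifting into edits that change the content rather than the sentiment; `max_steps` is an extra safety bound.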
