In [1]:
# Get transfer accuracy and fluency

In [2]:
!pip install fairseq

Collecting fairseq
  Downloading fairseq-0.12.2.tar.gz (9.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core<1.1,>=1.0.7 (from fairseq)
  Downloading hydra_core-1.0.7-py3-none-any.whl (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.8/123.8 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting omegaconf<2.1 (from fairseq)
  Downloading omegaconf-2.0.6-py3-none-any.whl (36 kB)
Collecting sacrebleu>=1.4.12 (from fairseq)
  Downloading sacrebleu-2.3.2-py3-none-any.whl (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.7/119.7 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
Collecting bitarray (from fairse

In [3]:
from fairseq.models.roberta import RobertaModel
from fairseq.data.data_utils import collate_tokens
import torch
import tqdm

In [30]:
# --- Configuration -----------------------------------------------------------

# Directory that holds the style-transferred sentence files.
style_transfer_dir = "/content/drive/MyDrive/Models/STYLE_TRANSFERRED"

# Checkpoint directories for the two classifiers used for evaluation.
acc_classifier_dir = "/content/drive/MyDrive/Models/CLASSIFIER/SHAKESPEARE"
fluency_classifier_dir = "/content/drive/MyDrive/Models/CLASSIFIER/FLUENCY/"

# Target style -> classifier label that counts as "in the target style".
accepted_labels = {"shakespeare": "original"}

# Classifier label that counts as a fluent sentence.
accepted_fluency_label = "acceptable"

In [4]:
def label_fn(label, roberta):
    """Convert a predicted class index into its label string.

    Args:
        label: Integer class index (e.g. the argmax over the classifier output).
        roberta: Loaded model object exposing ``task.label_dictionary`` and
            ``task.target_dictionary``.

    Returns:
        Whatever ``task.label_dictionary.string`` produces for the shifted index.
    """
    task = roberta.task
    # Shift the class index by ``nspecial`` before the dictionary lookup
    # (presumably the first ``nspecial`` entries are reserved symbols — confirm).
    dictionary_index = label + task.target_dictionary.nspecial
    return task.label_dictionary.string([dictionary_index])

In [5]:
def detokenize(x):
    """Undo space-separated punctuation in a tokenized sentence.

    Removes the space before ``.``, ``,``, ``!``, ``?`` and ``)``, and the space
    after ``(``. The replacements are applied one after another in the order
    listed below.

    Args:
        x: Input string.

    Returns:
        The string with the listed spacing patterns collapsed.
    """
    # (tokenized form, detokenized form), applied sequentially.
    rewrites = (
        (" .", "."),
        (" ,", ","),
        (" !", "!"),
        (" ?", "?"),
        (" )", ")"),
        ("( ", "("),
    )
    for spaced, joined in rewrites:
        x = x.replace(spaced, joined)
    return x

In [38]:
def get_acc(classifier_dir, bin_datadir, sents_list, batch_size=10, return_all=False):
  """Label sentences with a fine-tuned fairseq RoBERTa classifier.

  Each sentence is lower-cased, detokenized, BPE-encoded and passed through the
  model's ``classification_head``; the argmax class is mapped to its label
  string with ``label_fn``.

  Args:
    classifier_dir: Directory containing ``checkpoint_best.pt``.
    bin_datadir: Data directory passed to the model as ``data_name_or_path``.
    sents_list: Non-empty list of raw sentence strings.
    batch_size: Sentences per forward pass. Defaults to 10, the value that
      used to be hard-coded.
    return_all: If False (default), return only the label of the first
      sentence, which is what the single-sentence callers rely on. If True,
      return the labels of all sentences, in input order.

  Returns:
    A label string, or a list of label strings when ``return_all`` is True.

  Raises:
    ValueError: If ``sents_list`` is empty or ``batch_size`` is not positive.
  """
  # Fail fast, before the expensive checkpoint load, instead of hitting an
  # unbound ``prediction_labels`` after the loop.
  if len(sents_list) == 0:
    raise ValueError("sents_list must contain at least one sentence")
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer")

  roberta = RobertaModel.from_pretrained(
    classifier_dir,
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=bin_datadir
  )
  roberta.eval()
  # Use the GPU when there is one; otherwise stay on CPU rather than crash.
  if torch.cuda.is_available():
    roberta.cuda()

  # Accumulate across batches; previously each batch overwrote the last one.
  prediction_labels = []

  # tqdm takes the number of batches from len(range(...)), so no manual total.
  for start in tqdm.tqdm(range(0, len(sents_list), batch_size)):
    sds = [
      roberta.bpe.encode(detokenize(sd.lower()))
      for sd in sents_list[start:start + batch_size]
    ]

    # pad_idx=1 is kept from the original call (presumably the dictionary's
    # pad index — TODO confirm).
    batch = collate_tokens(
      [roberta.task.source_dictionary.encode_line("<s> " + sd + " </s>", append_eos=False) for sd in sds],
      pad_idx=1
    )

    # Truncate to 512 tokens (presumably the model's maximum input length — TODO confirm).
    batch = batch[:, :512]

    with torch.no_grad():
      predictions = roberta.predict('classification_head', batch.long())

    prediction_labels.extend(
      label_fn(x.argmax(axis=0).item(), roberta) for x in predictions
    )

  if return_all:
    return prediction_labels
  return prediction_labels[0]

In [23]:
# Data directories handed to get_acc as `bin_datadir`:
# one for the style (accuracy) classifier ...
acc_bin_datadir = '/content/drive/MyDrive/IRE_Project/style_transfer_paraphrase/datasets/shakespeare-bin/'
# ... and one for the fluency classifier.
fluency_bin_datadir = '/content/drive/MyDrive/IRE_Project/style_transfer_paraphrase/datasets/cola-bin/'

In [24]:
# Read the sentence pair to evaluate.
# The first line of the file is used as the original sentence and the second
# line as the style-transferred sentence.
from_style = "bible"
to_style = "shakespeare"
fname = f"{style_transfer_dir}/sent_{from_style}_to_{to_style}.txt"

# Explicit encoding so decoding does not depend on the platform's locale default
# (assumes the file was written as UTF-8 — confirm against the producer).
with open(fname, "r", encoding="utf-8") as f:
  data = f.read().strip().split("\n")

# Report a short or empty file clearly instead of an IndexError below.
if len(data) < 2:
  raise ValueError(f"Expected at least 2 lines in {fname}, found {len(data)}")

# NOTE: `orig_seng` (sic) is kept as-is because other cells may read it.
orig_seng = data[0]
transferred_sent = data[1]

print(f"Orig Sent[{from_style}]- ", orig_seng)
print(f"Transferred Sent[{to_style}]- ", transferred_sent)

Orig Sent[bible]-  And they caught him, and beat him, and sent him away empty.
Transferred Sent[shakespeare]-  He was caught, throw him out.


In [22]:
# Transfer accuracy: classify the transferred sentence with the style
# classifier and compare against the label accepted for the target style.
pred_label = (
  get_acc(acc_classifier_dir, acc_bin_datadir, [transferred_sent])
  .strip()
  .lower()
)
print("\nPredicted label- ", pred_label)

if pred_label == accepted_labels[to_style]:
  print("In target style")
  print("Accuracy : 1")
else:
  print("Not in target style")
  print("Accuracy : 0")

1it [00:00, 57.07it/s]


Predicted label-  modern
Not in target style





In [36]:
# Fluency: classify the transferred sentence with the fluency classifier and
# compare the normalized (stripped, lower-cased) label with the accepted one.
pred_label = get_acc(fluency_classifier_dir, fluency_bin_datadir, [transferred_sent]).strip().lower()
print("\nPredicted label- ", pred_label)

if accepted_fluency_label != pred_label:
  print("Not Fluent")
  print("Fluency : 0")
else:
  print("Fluent")
  print("Fluency : 1")

1it [00:00, 28.84it/s]


Predicted label-  acceptable
Fleunt
Fluency : 1





In [40]:
# Runs the fluency classifier on the transferred sentence again; the label is
# stored in `pred_label` but not displayed by this cell.
pred_label = (
  get_acc(fluency_classifier_dir, fluency_bin_datadir, [transferred_sent])
  .strip()
  .lower()
)

1it [00:00, 66.80it/s]
