In [1]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

We'll use the stock multilingual MBart50 model, with a decoder. Languages covered are listed in Table 6 here: https://arxiv.org/pdf/2008.00401.pdf 

Note that `ar`, `fa`, and `ur` are present.

In [2]:
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

In [3]:
from transformers import AutoTokenizer

In [4]:
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")

We're doing a slightly odd thing here: "translating" from English to English.

In [5]:
tokenizer.src_lang = "en-XX"
tokenizer.tgt_lang = "en-XX"

In [6]:
len(tokenizer.vocab)

250054

In [7]:
tokenizer.tokenize("This record-breaking year for anti-transgender legislation would affect minors the most.")

['▁This',
 '▁record',
 '-',
 'break',
 'ing',
 '▁year',
 '▁for',
 '▁anti',
 '-',
 'trans',
 'gende',
 'r',
 '▁legisla',
 'tion',
 '▁would',
 '▁affect',
 '▁minor',
 's',
 '▁the',
 '▁most',
 '.']

In [8]:
tokenizer.tokenize('steusatheuåø¨ˆ∂˙†˜ßتيممسيااضعخثخهعخهع'    )

['▁ste',
 'usa',
 'the',
 'u',
 'å',
 'ø',
 '▁̈',
 'ˆ',
 '∂',
 '▁',
 '̇',
 '†',
 '▁',
 '̃',
 'ß',
 'ت',
 'يم',
 'مس',
 'يا',
 'ا',
 'ضع',
 'خ',
 'ث',
 'خه',
 'ع',
 'خه',
 'ع']

Now we load the training and development data sets from the RDD corpus of Dong and Smith 2018.  Only 10% of each is used here, i.e. around 700,000 exemplars in the training set.

In [9]:
import datasets

In [10]:
ds = datasets.load_dataset('csv', data_files={'train': '../test_data/rdd/train_sm.txt',
                                     'validation': '../test_data/rdd/dev_sm.txt'}, 
                  delimiter='\t', quoting=3)

Using custom data configuration default-094e2dc7e61ed095
Reusing dataset csv (/home/jds/.cache/huggingface/datasets/csv/default-094e2dc7e61ed095/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)


In [11]:
ds

DatasetDict({
    train: Dataset({
        features: ['en', 'en.1'],
        num_rows: 700000
    })
    validation: Dataset({
        features: ['en', 'en.1'],
        num_rows: 150000
    })
})

The boring code to tokenize the input and retrieve the vocabulary IDs.

In [12]:
max_input_length = 128
max_target_length = max_input_length
source_lang = "en"
target_lang = "en.1"

def preprocess(examples):
    inputs = examples[source_lang]
    targets = examples[target_lang]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding='max_length')
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True, padding='max_length')

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [13]:
preprocess(ds['train'][:2])

{'input_ids': [[3, 6, 5256, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 2223, 136, 4163, 57616, 214, 23, 70, 352, 95972, 4, 7440, 1836, 1902, 47, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [14]:
token_ds = ds.map(preprocess, batched=True)

HBox(children=(FloatProgress(value=0.0, max=700.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=150.0), HTML(value='')))




Huggingface's new `Trainer` is very welcome, as it allowed me to cut down all the detailed steps in the training loop.

In [15]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

A batch size of 12 is all I could fit on a 24GB graphics card.  We will run just 2 epochs (typical of finetuning).

In [16]:
training_args = Seq2SeqTrainingArguments(
    output_dir = './results',          # output directory
    num_train_epochs = 2,              # total # of training epochs
    per_device_train_batch_size = 12,  # batch size per device during training
    per_device_eval_batch_size = 12,   # batch size for evaluation
    warmup_steps = 500,                # number of warmup steps for learning rate scheduler
    weight_decay = 0.01,               # strength of weight decay
    logging_dir = './logs',            # directory for storing logs
)

Code off the web to add more metrics for training, in addition to just the loss.

In [17]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [18]:
trainer = Seq2SeqTrainer(
    model = model,                         
    args = training_args,                  
    train_dataset = token_ds['train'],         
    eval_dataset = token_ds['validation'],
    compute_metrics = compute_metrics
)

In [None]:
trainer.train()

Note that in this run I stopped it before two whole epochs had passed.

I ran out of GPU memory when trying to evaluate, so insteal let's take a look directly.

In [28]:
model.eval()

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): Embedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): Embedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=T

In [29]:
fastwer.score(corrected_strings, test[:1000])
def correct(sent):
    tokens = tokenizer(sent, max_length=max_input_length, 
                  truncation=True, padding='max_length', return_tensors='pt').to(device='cuda')
    with torch.no_grad():
        gen_tokens = model.generate(**tokens)
    
    return ' '.join(tokenizer.tokenize(sent)), tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

Here's a line from the RDD _test_ set.

In [30]:
correct('make her orifons, for rhe furthering of her purpofes ; and yet the aire, che water ,')

('▁make ▁her ▁ori fon s , ▁for ▁rhe ▁further ing ▁of ▁her ▁pur po fes ▁; ▁and ▁yet ▁the ▁aire , ▁che ▁water ▁ ,',
 ['make her orisons, for the furthering of her purposes; and yet the aire, the water,'])

It fixes spelling errors, space punctuation.  What about spelling errors in general?

In [35]:
correct('It has often been taken for gran ted that, especially in che xeat of the civil woar, the :Xpanish anarchists were republicans.')

('▁It ▁has ▁often ▁been ▁taken ▁for ▁gran ▁ ted ▁that , ▁especially ▁in ▁che ▁xe at ▁of ▁the ▁civil ▁wo ar , ▁the ▁: X pan ish ▁anar ch ists ▁were ▁republica ns .',
 ['It has often been taken for granted that, especially in the great of the civil war, the Spanish anarchists were republicans.'])

Not terrible, but we see `xeat` -> `great` instead of `heat`, which is somewhat less grammatical.  Let's mangle something a lot more.

In [42]:
correct('an archism dizPays an ont__logiical opp;zition to all formz of gggovern, a thE reupblic is evthe.,.ntly one form of t##he same,')

('▁an ▁archi s m ▁diz Pay s ▁an ▁ont __ logi ical ▁opp ; zi tion ▁to ▁all ▁form z ▁of ▁ gg govern , ▁a ▁th E ▁re up bli c ▁is ▁ev the . , . nt ly ▁one ▁form ▁of ▁t # # he ▁same ,',
 ['anarchical dizpays an outward political opposition to all forms of government, athe republic is e\xad ther, notly one form of the same,'])

Eh well, it's not God-AI quite yet ...  So let's run over the first 1000 rows of the RDD test set.

In [44]:
with open('../test_data/rdd/test.x.txt') as f:
    test = f.read().splitlines()

In [68]:
with open('../test_data/rdd/test.y.txt') as f:
    gold = f.read().splitlines()

In [45]:
%%time
corrected = [correct(line) for line in test[:1000]]

CPU times: user 7min 48s, sys: 50.6 s, total: 8min 39s
Wall time: 8min 39s


In [53]:
f"{(8 * 60 + 39)/1000} secs/correction"

'0.519 secs/correction'

In [54]:
import fastwer

In [58]:
corrected_strings = [corr[0] for _, corr in corrected]

In [69]:
f"WER: {fastwer.score(corrected_strings, gold[:1000])}"

'WER: 14.7459'

In [70]:
f"CER: {fastwer.score(corrected_strings, gold[:1000], char_level=True)}"

'CER: 6.3847'

For comparison, the CER before correction:
    

In [72]:
f"CER: {fastwer.score(test[:1000], gold[:1000], char_level=True)}"

'CER: 11.1861'

In [75]:
f"CER Improvement: {round((11.1861-6.3847)/11.1861 * 100, 2)}%"

'CER Improvement: 42.92%'