In [1]:
import torch

from allennlp.training.metrics import Perplexity
from allennlp.data.iterators import BasicIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import SimpleLanguageModelingDatasetReader
from allennlp.predictors import SimpleSeq2SeqPredictor

from adat.dataset import WhitespaceTokenizer
from adat.utils import load_weights
from adat.lm import get_basic_lm

## Load adversarial data

In [2]:
max_examples = 20000

original = []
adversarial = []

with open('val.12') as file:
    for i, line in enumerate(file):
        if i % 2 == 0:
            original.append(line.strip())
        else:
            adversarial.append(line.strip())
            
        if i > 2 * max_examples:
            break

In [3]:
len(original), len(adversarial)

(20001, 20001)

In [4]:
print(f'Original = {original[299]}\nAdversarial = {adversarial[299]}')

Original = a_876 a_877
Adversarial = a_776 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207 a_207


## Perplexity

In [5]:
reader = SimpleLanguageModelingDatasetReader(
    tokenizer=WhitespaceTokenizer(),
    max_sequence_length=None
)

In [6]:
vocab = Vocabulary.from_files('experiments/vocab')

model = get_basic_lm(vocab)

load_weights(model, 'experiments/model.th')

iterator = BasicIterator(batch_size=128)
iterator.index_with(vocab)

In [7]:
# predictor = SimpleSeq2SeqPredictor(model, reader)

In [8]:
original_instances = [reader.text_to_instance(t) for t in original]
adversarial_instances = [reader.text_to_instance(t) for t in adversarial]

model.training = False

In [9]:
perplexity_orig = Perplexity()

for i, x in enumerate(iterator(original_instances, num_epochs=1)):
    with torch.no_grad():
        average_loss = model(**x)['loss']
        perplexity_orig(average_loss)
    
print(f'Perplexity (Original) = {perplexity_orig.get_metric(False)}')

Perplexity (Original) = 47.74658966064453


In [10]:
perplexity = Perplexity()

for i, x in enumerate(iterator(adversarial_instances, num_epochs=1)):
    with torch.no_grad():
        average_loss = model(**x)['loss']
        perplexity(average_loss)
    
print(f'Perplexity (Adversarial) = {perplexity.get_metric(False)}')

Perplexity (Adversarial) = 2229.762939453125


## Another way


In [7]:
from adat.utils import calculate_perplexity

In [8]:
calculate_perplexity(original, model, reader, vocab)

47.63511657714844

In [9]:
calculate_perplexity(adversarial, model, reader, vocab)

2230.009521484375