In [1]:
import time
import re

import numpy as np
import transformers
import torch

DEVICE = 'mps'
BASE_MODEL_NAME = 'gpt2-medium'
MASK_FILLING_MODEL_NAME = 't5-large'
PROMPT_TOKENS = 30

In [2]:
real_text = \
''''Maj Richard Scott, 40, is accused of driving at speeds of up to 95mph (153km/h) in bad weather before the smash 
on a B-road in Wiltshire. Gareth Hicks, 24, suffered fatal injuries when the van he was asleep in was hit by Mr Scott\'s 
Audi A6. Maj Scott denies a charge of causing death by careless driving. Prosecutor Charles Gabb alleged the defendant, 
from Green Lane in Shepperton, Surrey, had crossed the carriageway of the 60mph-limit B390 in Shrewton near Amesbury. 
The weather was "awful" and there was strong wind and rain, he told jurors. He said Mr Scott\'s car was described as 
"twitching" and "may have been aquaplaning" before striking the first vehicle; a BMW driven by Craig Reed. Mr Scott\'s 
Audi then returned to his side of the road but crossed the carriageway again before colliding head-on with a Ford Transit 
van in which Mr Hicks was a passenger, the court was told. "There is no doubt that when the Audi smashed into the panel 
van he was on completely the wrong side of the road," Mr Gabb said. Mr Hicks, from Bath in Somerset, was asleep in the 
van being driven to a construction site in Salisbury by fellow DR Groundworks colleague, Patrick Gilleece. The jury was 
told the Maj Scott suffered "substantial injuries" and could not recall the crash, which happened shortly after 07:00 GMT 
on 6 October, 2014. He does not accept the charge and suggests it was in fact Mr Reed who had crossed the carriageway, 
causing the collision, Mr Gabb told the court. The trial continues.'''

In [9]:
def load_base_model():
    print('MOVING BASE MODEL TO GPU...', end='', flush=True)
    start = time.time()
    mask_model.cpu()
    base_model.to(DEVICE)
    print(f'DONE ({time.time() - start:.2f}s)')

def load_mask_model():
    print('MOVING MASK MODEL TO GPU...', end='', flush=True)
    start = time.time()
    mask_model.to(DEVICE)
    base_model.cpu()
    print(f'DONE ({time.time() - start:.2f}s)')

def trim_to_shorter_length(texta, textb):
    # truncate to shorter of o and s
    shorter_length = min(len(texta.split(' ')), len(textb.split(' ')))
    texta = ' '.join(texta.split(' ')[:shorter_length])
    textb = ' '.join(textb.split(' ')[:shorter_length])
    return texta, textb

def get_ll(text):
    with torch.no_grad():
        tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
        labels = tokenized.input_ids
        return -base_model(**tokenized, labels=labels).loss.item()


def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
    tokens = text.split(' ')
    mask_string = '<<<mask>>>'

    n_spans = pct * len(tokens) / (span_length + 1 * 2)
    if ceil_pct:
        n_spans = np.ceil(n_spans)
    n_spans = int(n_spans)

    n_masks = 0
    while n_masks < n_spans:
        start = np.random.randint(0, len(tokens) - span_length)
        end = start + span_length
        search_start = max(0, start - 1)
        search_end = min(len(tokens), end + 1)
        if mask_string not in tokens[search_start:search_end]:
            tokens[start:end] = [mask_string]
            n_masks += 1
    
    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
    num_filled = 0
    for idx, token in enumerate(tokens):
        if token == mask_string:
            tokens[idx] = f'<extra_id_{num_filled}>'
            num_filled += 1
    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
    text = ' '.join(tokens)
    return text

def count_masks(texts):
    return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]

def replace_masks(texts):
    n_expected = count_masks(texts)
    stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
    tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
    outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=1, num_return_sequences=1, eos_token_id=stop_id)
    return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)

def extract_fills(texts, pattern=re.compile(r"<extra_id_\d+>")):
    # remove <pad> from beginning of each text
    texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]

    # return the text in between each matched mask token
    extracted_fills = [pattern.split(x)[1:-1] for x in texts]

    # remove whitespace around each fill
    extracted_fills = [[y.strip() for y in x] for x in extracted_fills]

    return extracted_fills

def apply_extracted_fills(masked_texts, extracted_fills):
    # split masked text into tokens, only splitting on spaces (not newlines)
    tokens = [x.split(' ') for x in masked_texts]

    n_expected = count_masks(masked_texts)

    # replace each mask token with the corresponding fill
    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
        if len(fills) < n:
            tokens[idx] = []
        else:
            for fill_idx in range(n):
                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]

    # join tokens back into text
    texts = [" ".join(x) for x in tokens]
    return texts

def generate_perturbed_texts(text, n):
    perturbed_texts = []
    for _ in range(n):
        masked_sample = tokenize_and_mask(text, span_length=2, pct=0.3)
        raw_fills = replace_masks([masked_sample])
        extracted_fills = extract_fills(raw_fills)
        perturbed_text = apply_extracted_fills([masked_sample], extracted_fills)
        perturbed_texts.append(perturbed_text)
        pass
    return perturbed_texts

In [4]:
base_model = transformers.AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME)
base_tokenizer = transformers.AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
base_tokenizer.pad_token = base_tokenizer.eos_token # throws error if not set

mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(MASK_FILLING_MODEL_NAME)
mask_tokenizer = transformers.AutoTokenizer.from_pretrained(MASK_FILLING_MODEL_NAME)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [5]:
load_base_model()

# tokenize text and slice the first PROMPT_TOKENS
all_encoded = base_tokenizer(real_text, return_tensors="pt", padding=True).to(DEVICE)
all_encoded = {key: value[:, :PROMPT_TOKENS] for key, value in all_encoded.items()}

MOVING BASE MODEL TO GPU...DONE (0.87s)


In [6]:
sampling_kwargs = dict()
sampling_kwargs['top_p'] = 0.96
sampling_kwargs['top_k'] = 40
min_length = 50
max_length = 200

outputs = base_model.generate(**all_encoded, 
                              min_length=min_length, 
                              max_length=max_length, 
                              do_sample=True, 
                              **sampling_kwargs, 
                              pad_token_id=base_tokenizer.eos_token_id, 
                              eos_token_id=base_tokenizer.eos_token_id)
sampled_text = base_tokenizer.batch_decode(outputs, skip_special_tokens=True)
sampled_text = sampled_text[0]

In [7]:
data = {}
data['real'], data['sampled'] = trim_to_shorter_length(real_text, sampled_text)

In [8]:
# get likelihood of each text under base model
ll_real = get_ll(data['real'])
ll_sampled = get_ll(data['sampled'])

(ll_real, ll_sampled)

(-3.108694314956665, -1.8114635944366455)

In [13]:
load_mask_model()

MOVING MASK MODEL TO GPU...DONE (2.11s)


In [18]:
perturbed_texts = generate_perturbed_texts(data['sampled'], 10)

In [19]:
load_base_model()

MOVING BASE MODEL TO GPU...DONE (4.49s)


In [22]:
lls_perturbed = torch.tensor([get_ll(perturbed_text) for perturbed_text in perturbed_texts])

MU = (1/10) * torch.sum(lls_perturbed)

In [25]:
variance_norm = torch.sqrt((1/(10 - 1)) * torch.square(torch.sum(lls_perturbed - MU)))

In [26]:
variance_norm

tensor(8.3447e-07)

In [27]:
(ll_real - MU)/variance_norm

tensor(-1380807.2500)