In [24]:
import pandas as pd
import spacy

In [1]:

nlp = spacy.load('en_core_web_sm')

source = 'The city was brought to a standstill on 15 December last year when a gunman held 18 hostages for 17 hours. Family members of victims Tori Johnson and Katrina Dawson were in attendance. Images of the floral tributes that filled the city centre in the wake of the siege were projected on to the cafe and surrounding buildings in an emotional twilight ceremony. Prime Minister Malcolm Turnbull gave an address saying a "whole nation resolved to answer hatred with love". "Testament to the spirit of Australians is that with such unnecessary, thoughtless tragedy, an amazing birth of mateship, unity and love occurs. Proud to be Australian," he said. How the Sydney siege unfolded. New South Wales Premier Mike Baird has also announced plans for a permanent memorial to be built into the pavement in Martin Place. Clear cubes containing flowers will be embedded into the concrete and will shine with specialised lighting. It is a project inspired by the massive floral tributes that were left in the days after the siege. "Something remarkable happened here. As a city we were drawn to Martin Place. We came in shock and in sorrow but every step we took was with purpose," he said on Tuesday.'
prediction = 'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.'

entities = nlp(prediction).to_json()['ents']
ent_text = [prediction[e['start']: e['end']] for e in entities]

print('entitites:')
print(type(entities))

print('ent text:')
print(type(ent_text))
print(ent_text)

  from .autonotebook import tqdm as notebook_tqdm


entitites:
<class 'list'>
ent text:
<class 'list'>
['Sydney', 'first', 'Waverley', 'two', 'Australian']


In [4]:
def prepare_clm_inputs(source, target, ent_parts=None):
    """For Masked Language Model. For BART only."""
    if ent_parts is None:
        ent_parts = nlp(target).to_json()['ents']
    
    inputs, targets = [], []
    positions, entities = [], []

    for e in ent_parts:
        inputs.append(target[0: e['start']] + '<mask>')
        targets.append(target[:e['end']])
        entities.append(target[e['start']: e['end']])
        positions.append((e['start'], e['end']))
    
    return inputs, targets, positions, entities

inputs = prepare_clm_inputs(source, prediction, ent_parts=entities)
inputs

(['<mask>',
  'Sydney has marked the <mask>',
  'Sydney has marked the first anniversary of the siege at the <mask>',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which <mask>',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the <mask>'],
 ['Sydney',
  'Sydney has marked the first',
  'Sydney has marked the first anniversary of the siege at the Waverley',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two',
  'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian'],
 [(0, 6), (22, 27), (60, 68), (83, 86), (124, 134)],
 ['Sydney', 'first', 'Waverley', 'two', 'Australian'])

In [70]:
masked_input = inputs[0][2]
print(masked_input)
target = inputs[1][2]
print(target)

Sydney has marked the first anniversary of the siege at the <mask>
Sydney has marked the first anniversary of the siege at the Waverley


In [11]:
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [71]:

input_ids = tokenizer(masked_input, return_tensors="pt")["input_ids"]
logits = model(input_ids).logits
print(logits.shape)


torch.Size([1, 16, 50265])


In [72]:
print('tokenized masked input shape:', input_ids.shape)

print('token ID -> token:')

for token_id in input_ids[0]:
    print(f"{token_id} | '{tokenizer.decode(token_id)}'")

tokenized masked input shape: torch.Size([1, 16])
token ID -> token:
0 | '<s>'
104 | 'S'
9611 | 'yd'
2596 | 'ney'
34 | ' has'
4760 | ' marked'
5 | ' the'
78 | ' first'
4038 | ' anniversary'
9 | ' of'
5 | ' the'
19951 | ' siege'
23 | ' at'
5 | ' the'
50264 | '<mask>'
2 | '</s>'


In [73]:
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
print('mask position in input sequence:', masked_index)

probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)

predicted_tokens = [token.replace(' ', '_') for token in tokenizer.decode(predictions).split()]
pd.DataFrame({'predicted_token_id': predictions, 'predicted_token': predicted_tokens, 'values': values.detach().numpy(),})

mask position in input sequence: 14


Unnamed: 0,predicted_token_id,predicted_token,values
0,4290,Sydney,0.049893
1,2059,Australian,0.039405
2,1082,site,0.027545
3,276,same,0.016303
4,3062,airport,0.014825


In [104]:
target_tokens = tokenizer.encode(target, return_tensors='pt')
target_tokens

tensor([[    0,   104,  9611,  2596,    34,  4760,     5,    78,  4038,     9,
             5, 19951,    23,     5,   305,  9903,   607,     2]])

In [106]:
target_tokens = tokenizer.encode(target, return_tensors='pt')
target_encoded_decoded = tokenizer.decode(target_tokens[0])
print('starting target:', target)
print('target tokens:', target_tokens)
print('[check] encoded-decoded target:', target_encoded_decoded)

starting target: Sydney has marked the first anniversary of the siege at the Waverley
target tokens: tensor([[    0,   104,  9611,  2596,    34,  4760,     5,    78,  4038,     9,
             5, 19951,    23,     5,   305,  9903,   607,     2]])
[check] encoded-decoded target: <s>Sydney has marked the first anniversary of the siege at the Waverley</s>


In [107]:
input_ids

tensor([[    0,   104,  9611,  2596,    34,  4760,     5,    78,  4038,     9,
             5, 19951,    23,     5, 50264,     2]])

In [108]:
target_tokens

tensor([[    0,   104,  9611,  2596,    34,  4760,     5,    78,  4038,     9,
             5, 19951,    23,     5,   305,  9903,   607,     2]])

In [120]:
input_ids[0][len(input_ids[0]) - 2 ]

tensor(50264)

In [121]:
target_tokens[0][len(input_ids[0]) - 2 ]

tensor(305)

In [93]:
len(input_ids[0])

16

In [91]:
tokenizer.decode(target_tokens[len(input_ids[0]) - 1])

'aver'

In [136]:
input_ids[0].shape

torch.Size([16])

In [132]:
target_tokens.shape

torch.Size([1, 18])

In [147]:
tokenizer.encode(' Waverly', add_special_tokens=False)

[305, 9903, 352]

In [148]:
tokenizer.decode(305)

' W'

In [238]:
len(inputs[0])

5

In [246]:
target_tokens[0].tolist()

[0,
 104,
 9611,
 2596,
 34,
 4760,
 5,
 78,
 4038,
 9,
 5,
 19951,
 23,
 5,
 305,
 9903,
 607,
 16381,
 11,
 61,
 80,
 2]

In [259]:
import torch

for named_entity_index in range(len(inputs[0])):
    masked_input = inputs[0][named_entity_index]
    print('masked input:', masked_input)
    target = inputs[1][named_entity_index]
    print('print target:', target)
    target_tokens = tokenizer.encode(target, return_tensors='pt')

    input_ids = tokenizer(masked_input, return_tensors="pt")["input_ids"]
    print('len input ids:', input_ids.shape[1])

    def prefix_allowed_tokens_fn(_batch_id, input_ids):
        current_step = len(input_ids) - 1
        return target_tokens[0, current_step].tolist()
        
    model_output = model.generate(
        input_ids,
        num_beams=1,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )

    print('len of target tokens:', len(target_tokens[0].tolist()))
    print('len of the model output scores:', len(model_output.scores))
    print('forced output:', tokenizer.batch_decode(model_output.sequences))

    for target_token, token_score in zip(target_tokens[0].tolist(), model_output.scores):
        target_parts = tokenizer.decode(target_token, add_special_tokens=False)
        token_score = token_score.softmax(dim=1)[0, target_token]
        print(f"{target_token} | '{target_parts}' | {token_score}")
    
    print('---------')


masked input: <mask>
print target: Sydney
len input ids: 3
len of target tokens: 5
len of the model output scores: 5
forced output: ['</s><s>Sydney</s>']
0 | '<s>' | 0.9999822378158569
104 | 'S' | 4.914681994705461e-06
9611 | 'yd' | 0.0021797779481858015
2596 | 'ney' | 0.9180679321289062
2 | '</s>' | 0.5756633877754211
---------
masked input: Sydney has marked the <mask>
print target: Sydney has marked the first
len input ids: 9
len of target tokens: 9
len of the model output scores: 9
forced output: ['</s><s>Sydney has marked the first</s>']
0 | '<s>' | 0.9999856948852539
104 | 'S' | 0.9999973773956299
9611 | 'yd' | 0.9999998807907104
2596 | 'ney' | 0.9999958276748657
34 | ' has' | 0.9999608993530273
4760 | ' marked' | 0.9995445609092712
5 | ' the' | 0.9999163150787354
78 | ' first' | 0.013117797672748566
2 | '</s>' | 0.003610642161220312
---------
masked input: Sydney has marked the first anniversary of the siege at the <mask>
print target: Sydney has marked the first anniversary of 

In [255]:
tokenizer.batch_decode(model_output.sequences)

['</s><s>Sydney</s>']

In [None]:
print('prob of Waverley:', 0.001133607467636466 * 0.08252919465303421 * 0.9853810667991638)

prob of Waverley: 9.218802666182273e-05


In [173]:
len(model_output.sequences[0])

19

In [169]:
len(model_output.scores)

18

In [48]:
tokenizer.batch_decode(input_ids)

['<s>Sydney has marked the first anniversary of the siege at the<mask></s>']

In [179]:
max_scores = [score.argmax().item() for score in model_output.scores]
print(max_scores)

[0, 104, 9611, 2596, 34, 4760, 5, 78, 4038, 9, 5, 19951, 23, 5, 4290, 5416, 607, 4679]


In [184]:
tokenizer.decode(max_scores)

'<s>Sydney has marked the first anniversary of the siege at the Sydneyadiley Bridge'

In [53]:
tokenizer.batch_decode(model_output.sequences)

['</s><s>Sydney has marked the first anniversary of the siege at the Sydney Opera House.</s>']

In [194]:
tokenizer.mask_token_id

50264

## Tokenization with `<mask>`

Prefixing `<mask>` with whitespace doesn't impact tokenization, it seems..

In [202]:
foo = "Sydney has marked the first anniversary of the siege at the <mask>"
encoded = tokenizer.encode(foo, add_special_tokens=False)
print(foo)
print(encoded)
print(f"--{tokenizer.decode(encoded, clean_up_tokenization_spaces=True)}--")

Sydney has marked the first anniversary of the siege at the <mask>
[104, 9611, 2596, 34, 4760, 5, 78, 4038, 9, 5, 19951, 23, 5, 50264]
--Sydney has marked the first anniversary of the siege at the<mask>--


In [203]:
foo = "Sydney has marked the first anniversary of the siege at the<mask>"
encoded = tokenizer.encode(foo, add_special_tokens=False)
print(foo)
print(encoded)
print(f"--{tokenizer.decode(encoded, clean_up_tokenization_spaces=True)}--")

Sydney has marked the first anniversary of the siege at the<mask>
[104, 9611, 2596, 34, 4760, 5, 78, 4038, 9, 5, 19951, 23, 5, 50264]
--Sydney has marked the first anniversary of the siege at the<mask>--
