In [45]:
import torch
from transformers import AutoTokenizer, BartModel, BartForConditionalGeneration

In [46]:
# Load the tokenizer and the BartModel weights from the same checkpoint,
# naming the checkpoint once instead of repeating the literal.
BART_CHECKPOINT = "facebook/bart-base"

tokenizer = AutoTokenizer.from_pretrained(BART_CHECKPOINT)
model = BartModel.from_pretrained(BART_CHECKPOINT)

In [47]:
# Encode one sentence as PyTorch tensors and run it through the model.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Inference only: disable autograd so the forward pass does not record a
# computation graph that nothing in this notebook ever back-propagates through.
with torch.no_grad():
    outputs = model(**inputs)

# Sequence of hidden-states at the output of the last layer of the decoder of the model.
last_hidden_states = outputs.last_hidden_state
print(last_hidden_states.shape)  # batch_size, sequence_length, hidden_size

torch.Size([1, 8, 768])


In [48]:
# Conditional-generation variant of BART, loaded from the same checkpoint id.
model_cg = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
# Encode as a batch of one; anything beyond 1024 tokens is truncated.
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt", truncation=True)

# Generate Summary
# Forward the whole encoding (not just input_ids) so that the attention mask
# produced by the tokenizer reaches generate(); this keeps the call correct if
# the batch is ever extended with padded sequences, and mirrors model(**inputs).
summary_ids = model_cg.generate(**inputs, num_beams=2, min_length=0, max_length=20)

# NOTE(review): the recorded output is just the opening words of the article,
# presumably cut off by max_length=20 -- confirm whether this checkpoint is
# expected to produce a real summary.
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

'PG&E stated it scheduled the blackouts in response to forecasts for high winds amid'

In [49]:
# Mask-filling example: one sentence containing a single <mask> placeholder.
TXT = "My friends are <mask> but they eat too many carbs."
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]

# Inference only: disable autograd so the forward pass does not build a graph.
# No labels are passed to the model here.
with torch.no_grad():
    model_result = model_cg(input_ids)  # labels is None
print(model_result.loss)  # None (labels is None)

# Prediction scores of the language modeling head
# (scores for each vocabulary token before SoftMax).
logits = model_result.logits
print(logits.shape)  # batch_size, sequence_length, config.vocab_size

None
torch.Size([1, 13, 50265])


In [50]:
# Locate the <mask> token in the (single) encoded sentence. The example only
# makes sense with exactly one mask, so fail with a readable message instead of
# the opaque error .item() raises on a tensor with zero or several elements.
mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero().flatten()
assert mask_positions.numel() == 1, "expected exactly one <mask> token in TXT"
masked_index = mask_positions.item()
print(input_ids[0])
print(tokenizer.mask_token_id)
print(masked_index)  # 4

# Turn the vocabulary scores at the masked position into probabilities and
# keep the five most likely token ids.
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
print(predictions)  # indices of the 5 largest elements

# Decode every candidate id on its own. Decoding all five ids as one string and
# splitting on whitespace would fuse a candidate that has no leading space with
# its neighbour and return fewer than five entries.
[tokenizer.decode([token_id]).strip() for token_id in predictions.tolist()]

tensor([    0,  2387,   964,    32, 50264,    53,    51,  3529,   350,   171,
        33237,     4,     2])
50264
4
tensor([  45,  205, 2245,  372,  182])


['not', 'good', 'healthy', 'great', 'very']