In [1]:
import transformers
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# Summarization checkpoints compared in this notebook
# (see ``examples/summarization/bart/run_eval.py`` for a longer example).
bart_cnn = BartForConditionalGeneration.from_pretrained(
    'facebook/bart-large-cnn',
)
bart_xsum = BartForConditionalGeneration.from_pretrained(
    'facebook/bart-large-xsum',
)

# The tokenizer comes from the CNN checkpoint and is passed to every model call below.
tokenizer = BartTokenizer.from_pretrained(
    'facebook/bart-large-cnn',
)

# ``facebook/bart-large`` is the checkpoint handed to ``run_lm`` for mask filling.
lm_model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    force_bos_token_to_be_generated=True,
)


In [2]:


def run_lm(model, tokenizer, device, sum_prefix="", topk=10):
    """Predict the token that follows ``sum_prefix`` via mask filling.

    The prefix is stripped, ``<mask>`` is appended directly after it, and the
    model's probability distribution at the masked position is returned.

    Args:
        model: Called as ``model(input_ids, return_dict=True)``; the result must
            expose a ``'logits'`` entry indexed as ``[batch, position, vocab]``.
        tokenizer: Called as ``tokenizer([text], return_tensors='pt')``; must also
            provide ``mask_token_id`` and ``decode``.
        device: Device the input ids are moved to (presumably the model's device).
        sum_prefix: Text placed immediately before the mask.
        topk: Number of candidates to rank. It is clamped to ``[1, vocab size]``;
            only the best candidate is returned as text.

    Returns:
        ``(top_token, probs)`` where ``top_token`` is the decoded, whitespace-stripped
        most likely token and ``probs`` is the 1-D distribution over the vocabulary
        at the masked position.

    Raises:
        ValueError: If the tokenized text does not contain exactly one mask token.
    """
    sum_prefix = sum_prefix.strip()
    # Mask filling only works for bart-large
    # NOTE(review): the original comment here read "we basically remove all of the
    # last step cases" -- its intent is unclear from this file.
    TXT = f"{sum_prefix}<mask> "
    input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'].to(device)
    logits = model(input_ids, return_dict=True)['logits']

    # Explicit ``as_tuple`` avoids the deprecated bare ``nonzero()`` overload.
    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=False).flatten()
    if mask_positions.numel() != 1:
        raise ValueError(
            f"Expected exactly one mask token in the input, found {mask_positions.numel()}")
    masked_index = int(mask_positions[0])

    probs = logits[0, masked_index].softmax(dim=0)

    # Never ask for more candidates than the vocabulary holds.
    k = max(1, min(topk, probs.size(0)))
    _, predictions = probs.topk(k)

    # Decode only the best id: decoding all top-k ids into one string and splitting
    # on whitespace glues together candidates that carry no leading space.
    pieces = tokenizer.decode(predictions[:1].tolist()).split()
    top_token = pieces[0] if pieces else ""
    return top_token, probs


def run_full_model(model, tokenizer, input_text, sum_prefix, encoder_outputs=None, device='cuda:0', output_attentions=False, output_dec_hid=False):
    """Run one decoder step of a summarizer conditioned on ``sum_prefix``.

    The source text is encoded (unless ``encoder_outputs`` is supplied), the
    stripped prefix is fed to the decoder without its trailing EOS token, and
    either the next-token distribution or a cross-attention based distribution
    is returned.

    Args:
        model: Seq2seq model exposing ``model.model.encoder`` and callable with
            ``encoder_outputs`` / ``decoder_input_ids`` keyword arguments.
        tokenizer: Tokenizer providing ``__call__``, ``encode``, ``decode`` and
            ``vocab_size``.
        input_text: List of source documents; a single ``str`` is treated as a
            batch of one.
        sum_prefix: Summary prefix decoded so far.
        encoder_outputs: Optional precomputed encoder outputs. When given, the
            encoder is not run and no padding mask is passed to the model.
            Assumed to come from ``input_text`` tokenized with the same settings
            as here -- only the sequence length is checked.
        device: Device the input tensors are moved to.
        output_attentions: If true, return a distribution built from the last
            layer's cross attention (first batch element only).
        output_dec_hid: Forwarded to the model as ``output_hidden_states``.

    Returns:
        If ``output_attentions``: ``(token, p)`` where ``token`` is the decoded
        source token with the highest head-averaged attention at the last decoder
        position and ``p`` is a normalized tensor of length ``tokenizer.vocab_size``.
        Otherwise: ``(outputs, prob)`` where ``outputs`` is the model output with an
        added ``'output'`` entry (decoded argmax token per batch element) and
        ``prob`` is the softmax over the last position's logits.

    Raises:
        ValueError: In attention mode, if the attention length differs from the
            number of tokens obtained by tokenizing ``input_text``.
    """
    # A bare string would otherwise be counted character by character below.
    if isinstance(input_text, str):
        input_text = [input_text]
    batch_size = len(input_text)

    def _tokenize_source():
        # Single place for the tokenization settings so both uses agree.
        return tokenizer(input_text, max_length=300,
                         return_tensors='pt', truncation=True, padding=True)

    inputs = None
    attention_mask = None
    if encoder_outputs is None:
        inputs = _tokenize_source()
        # Padding is enabled above, so padded positions must be masked out.
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        encoder_outputs = model.model.encoder(
            inputs['input_ids'].to(device), attention_mask=attention_mask, return_dict=True)

    sum_prefix = sum_prefix.strip()
    decoder_input_ids = torch.as_tensor(
        tokenizer.encode(sum_prefix, return_tensors='pt'), dtype=torch.long).to(device)
    decoder_input_ids = decoder_input_ids.expand(
        (batch_size, decoder_input_ids.size()[-1]))
    # ATTN: remove the EOS token from the prefix!
    decoder_input_ids = decoder_input_ids[:, :-1]

    # Kept from the original code. NOTE(review): whether the model reads this
    # attribute is not visible here; the keyword argument below is what is passed.
    model.output_attentions = output_attentions
    model_inputs = {"input_ids": None,
                    "past_key_values": None,
                    "encoder_outputs": encoder_outputs,
                    "decoder_input_ids": decoder_input_ids,
                    }
    if attention_mask is not None:
        # Lets cross attention ignore padded source positions.
        model_inputs["attention_mask"] = attention_mask
    outputs = model(**model_inputs, output_attentions=output_attentions, output_hidden_states=output_dec_hid,
                    use_cache=False, return_dict=True)

    if output_attentions:
        # use cross attention as the distribution
        # last layer.   batch=1, head, dec len, enc len
        # by default we use the last layer of attention
        cross_attns = outputs['cross_attentions'][-1]
        attn = cross_attns[0, :, -1, :]    # head, enc len

        mean_attn = torch.mean(attn, dim=0)
        assert len(mean_attn.size()) == 1

        if inputs is None:
            # The caller supplied encoder_outputs; the source ids are still needed
            # to translate attention positions back into vocabulary ids.
            inputs = _tokenize_source()
        input_ids_list = inputs['input_ids'][0].tolist()  # batch=1, enc len
        if len(input_ids_list) != mean_attn.size()[0]:
            raise ValueError(
                f"Attention covers {mean_attn.size()[0]} source positions but input_text "
                f"tokenizes to {len(input_ids_list)} tokens")

        topk = min(30, mean_attn.size()[0])

        values, indices = torch.topk(mean_attn, k=topk)
        values = values.detach().cpu().tolist()
        indices = indices.detach().cpu().tolist()
        output = tokenizer.decode(
            int(input_ids_list[indices[0]]))
        logging.info(f"Most attention: {output}")
        # Accumulate attention mass per vocabulary id (a token may occur repeatedly).
        p_list = [0.0 for _ in range(tokenizer.vocab_size)]
        for v, i in zip(values, indices):
            p_list[input_ids_list[i]] += v
        p = torch.as_tensor(p_list, device=device)
        p = p / p.sum()
        return output, p
    else:
        next_token_logits = outputs.logits[:, -1, :]
        prob = next_token_logits.softmax(dim=-1)
        next_token = torch.argmax(next_token_logits, dim=-1)
        next_token = next_token.tolist()

        output = [tokenizer.decode(tk) for tk in next_token]
        logging.info(f"Next token: {output}")
        outputs['output'] = output
        return outputs, prob


In [3]:
import logging

import torch

# Device handed to the model helper functions in the cells below.
device = torch.device('cpu')

logger = logging.getLogger('sum')

# Reuse a stream handler that is already attached: re-running this cell would
# otherwise stack one more handler each time and duplicate every log line.
ch = next((h for h in logger.handlers if type(h) is logging.StreamHandler), None)
if ch is None:
    ch = logging.StreamHandler()
    logger.addHandler(ch)
ch.setLevel(logging.DEBUG)
# NOTE(review): only the handler level is set; the 'sum' logger itself has no level,
# so it inherits its effective level from the root logger. The helper functions in
# this notebook log through the root ``logging`` module, not through ``logger``.

def pnum(num):
    """Render ``num`` as fixed-point text with two digits after the decimal point."""
    return f"{num:.2f}"

def show_top_k(prob, prefix, name, tokenizer, k=5):
    """Print the ``k`` highest-probability tokens of ``prob``.

    Size-1 dimensions of ``prob`` are squeezed away first. The output is a
    ``Type: <name>`` header followed by one line per candidate showing its rank,
    its probability with two decimals, and ``prefix`` followed by the decoded token.
    """
    scores, token_ids = torch.topk(prob.squeeze(), k=k)
    # Decode every candidate before anything is printed.
    decoded = [tokenizer.decode(token_id) for token_id in token_ids.tolist()]
    print(f"Type: {name}")
    for rank, (score, text) in enumerate(zip(scores.tolist(), decoded)):
        print(f"{rank}: {pnum(score)} {prefix}{text}")


In [4]:
# Summary prefix that is passed to every model in the next cell.
prefix = "Bri"

# Source document passed to the summarizers.
# Alternatives that can be swapped in instead:
#   input = "Brisk is a tea and juice brand managed by the Pepsi Lipton Partnership."
#   input = ""
# NOTE: the name ``input`` shadows the Python builtin; it is kept because the
# following cell reads it.
input = (
    "Whether you’re only just thinking about getting engaged "
    "(or hinting at it and looking for the perfect ring), "
    "in the throes of wedding planning, or already navigating newlywed life, "
    "Brides is here to inspire, guide, and entertain you during this exciting, "
    "and trying, time. Looking for the most gorgeous wedding dress? "
)


In [5]:
# Mask-filling language model: distribution for the token that follows the prefix.
_, lm_prob = run_lm(lm_model, tokenizer, device, sum_prefix=prefix)
show_top_k(lm_prob, prefix, name='lm', tokenizer=tokenizer)

# XSum summarizer: next-token distribution given the source document and the prefix.
_, xsum_prob = run_full_model(
    bart_xsum, tokenizer, input_text=[input], sum_prefix=prefix, device=device)
show_top_k(xsum_prob, prefix, name='xsum', tokenizer=tokenizer)

# CNN summarizer: same inputs as the XSum call above.
_, cnn_prob = run_full_model(
    bart_cnn, tokenizer, input_text=[input], sum_prefix=prefix, device=device)
show_top_k(cnn_prob, prefix, name='cnn', tokenizer=tokenizer)

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729096996/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
Type: lm
0: 0.02 Bri B
1: 0.01 Bri R
2: 0.01 Bri H
3: 0.01 Bri L
4: 0.01 Bri D
Type: xsum
0: 0.35 Brianna
1: 0.13 Briana
2: 0.02 Bribery
3: 0.02 Briar
4: 0.02 Briley
Type: cnn
0: 0.23 Bripe
1: 0.15 Brips
2: 0.09 Bri.
3: 0.04 Bripping
4: 0.03 Brivers


In [15]:

# numpy is not imported by any earlier cell shown in this notebook, so import it
# here to keep the cell runnable on a fresh kernel.
import numpy as np

# Toy check of the "redirect" idea: start from an identity matrix and make row 1
# point at column 2, so that the value at index 1 is added onto index 2.
x = np.zeros(5)
x[1] = 10
x[2] = 11

eye = np.eye(5)
eye[1, 1] = 0
eye[1, 2] = 1

print(eye)
print(x)
np.matmul(x, eye)


[[1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]]
[ 0. 10. 11.  0.  0.]


array([ 0.,  0., 21.,  0.,  0.])

In [26]:
# tokenizer.convert_tokens_to_ids(['x',' x','x'])
# Define a transition matrix that maps every space-prefixed token ("Ġx") onto its un-prefixed counterpart ("x").

def init_vocab_distb_fix(tokenizer)->torch.Tensor:
    """Build a matrix that redirects space-prefixed tokens to their bare form.

    The result starts as an identity matrix. For every vocabulary entry whose
    token starts with ``'Ġ'`` and whose remainder (the token without that first
    character) is itself a known token, the row of the prefixed token is changed
    to point at the bare token's column instead of its own. Multiplying a row
    vector over the vocabulary by this matrix therefore adds the mass of ``"Ġx"``
    onto ``"x"``.

    Args:
        tokenizer: Object providing ``vocab_size``, ``convert_ids_to_tokens`` and
            ``convert_tokens_to_ids``. Its ``unk_token_id`` identifies "no such
            token"; if the attribute is absent the legacy value 3 is used.

    Returns:
        A dense ``float64`` tensor of shape ``(vocab_size, vocab_size)``. Memory
        grows quadratically: a vocabulary of about 50k entries needs roughly 20 GB.
    """
    vocab_size = tokenizer.vocab_size
    # Ask the tokenizer for its unknown-token id rather than assuming it is 3.
    unk_id = getattr(tokenizer, 'unk_token_id', 3)

    trans_mat = torch.eye(vocab_size, dtype=torch.float64)
    # Checked once so the debug message (and its extra lookup) is only built when shown.
    debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

    src_ids = []
    dst_ids = []
    for vocab_idx in range(vocab_size):
        tok = tokenizer.convert_ids_to_tokens(vocab_idx)
        if tok is None or not tok.startswith('Ġ'):
            continue
        no_space_id = tokenizer.convert_tokens_to_ids(tok[1:])
        # Skip prefixed tokens whose bare form is not in the vocabulary.
        if no_space_id is None or no_space_id == unk_id:
            continue
        if debug_enabled:
            logging.debug(f"{vocab_idx}:{tok} -> {no_space_id}:{tokenizer.convert_ids_to_tokens(no_space_id)}")
        src_ids.append(vocab_idx)
        dst_ids.append(no_space_id)

    if src_ids:
        # Clear the diagonal first, then set the redirect, mirroring the per-row order.
        trans_mat[src_ids, src_ids] = 0
        trans_mat[src_ids, dst_ids] = 1
    logging.info(f"Lines of change: {len(src_ids)}")
    return trans_mat
