# Try SpanBERT, then work on multitoken

## trying SpanBERT

In [1]:
from typing import List

from functional import pseq, seq
import torch

In [3]:
from transformers import (
    BertForMaskedLM,
    BertTokenizer,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
)

In [5]:
bert = BertForMaskedLM.from_pretrained('bert-base-uncased') # have to pass the directory!!! ARG

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
tok_path = 'bert-base-cased' # uses same tokenizer in bert/spanbert

tokenizer = BertTokenizer.from_pretrained(tok_path)

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [7]:
bert.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [8]:
from fitbert import FitBert

fb = FitBert(model=bert, tokenizer=tokenizer, disable_gpu=True)
# Note:
# I'm SURE I had a reason to disable_gpu... but I wish I'd left a fucking note as to why it was...

device: cpu
using custom model: ['BertForMaskedLM']


In [9]:
# please be true...
fb.bert == bert

True

In [10]:
fb.device

device(type='cpu')

In [11]:
# the way FB currently works, this just looks at the first token

fb.rank("the first Star Wars came ***mask*** 1977", ["out in", "to in", "out of the closet in"])

['to in', 'out in', 'out of the closet in']

### Rewrite rank_multi to use tensors

#### THIS IS THE MAIN FOCUS OF THIS NOTEBOOK... it happens to use spanbert, but I don't think that is important

However this implementation uses loops, and should be all tensor ops

In [12]:
# looking at the code, rank_multi is not gonna work
# so rewrite?

def new_rank_multi(self, masked_sent: str, words: List[str]):

    words_ids = [ self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(lst)) for lst in words ]

    print("word ids: ", words_ids)

    lens = [ len(x) for x in words_ids ]

    print("lengths of each list in word ids: ", lens)

    final_ranked_options = []
    final_ranked_options_prob = []

    pre, post = masked_sent.split(self.mask_token)

    if post[-1] not in [".", ",", "?", "!", ";", ":"]:
        post += "."

    if all([x == 1 for x in lens]):
        # this is just rank_single for inspiration
        tokens = ["[CLS]"] + self.tokenizer.tokenize(pre)
        target_idx = len(tokens)
        tokens += ["[MASK]"]
        tokens += self.tokenizer.tokenize(post)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        tens = torch.tensor(input_ids).unsqueeze(0)
        tens = tens.to(self.device)
        with torch.no_grad():
            preds = self.bert(tens)[0]
            probs = self.softmax(preds)

            ranked_pairs = (
                seq(words_ids)
                .map(lambda x: float(probs[0][target_idx][x].item()))
                .zip(words)
                .sorted(key=lambda x: x[0], reverse=True)
            )

            ranked_options = (seq(ranked_pairs).map(lambda x: x[1])).list()
            ranked_options_prob = (seq(ranked_pairs).map(lambda x: x[0])).list()

            del tens, preds, probs, tokens, words_ids, input_ids
            if self.device == "cuda":
                torch.cuda.empty_cache()
            return ranked_options, ranked_options_prob
    else:
        for words_idx, mask_len in enumerate(lens):
            # FUCK
            # this shouldn't be a loop, it should be one big tensor [len(word_ids), num_masked_tokens, vocab_size]
            # might need to pad so when num_masked_tokens is less than the longest mask, they all end up the same shape
            #
            # actually, it should be even bigger, because it should be batched,
            # [batch_size, len(word_ids), num_masked_tokens, vocab_size]
            print(f"mask len = {mask_len}")
            
            tokens = ["[CLS]"] + self.tokenizer.tokenize(pre)
            target_idx_start = len(tokens)
            target_idx_end = target_idx_start + mask_len
            tokens += ["[MASK]"] * mask_len
            tokens += self.tokenizer.tokenize(post)  # no [SEP] b/c SpanBERT doesn't use
            print("there are this many tokens ", len(tokens))
            print("they are ", tokens)
            
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            tens = torch.tensor(input_ids).unsqueeze(0)
            tens = tens.to(self.device)
            with torch.no_grad():
                # @todo don't [0], that assumes batch_size == 1
                preds = self.bert(tens)[0]
                # @TODO don't softmax yet maybe... ok it seems to work. Maybe softmax again at the end?
                probs = self.softmax(preds)

                # @TODO when this is all one batch instead of a loop, this will have to be matrix multiplication
                # start and end will be different depending on the mask length
                # so need to construct a sparse matrix to use to multiply out the values desired (eg the indecise that were masked)
                masked_probs = probs[0][target_idx_start : target_idx_end]

                # masked_probs has size [num_masked_tokens, vocab_size]

                print(f"the masked probs are {masked_probs} \n and its shape is {masked_probs.shape}")

                # want to pick out the probs corresponding to the word ids

                assert masked_probs.shape[0] == mask_len, "there is a row for each word id"

                a = torch.zeros_like(masked_probs)
                
                for i, word_id in enumerate(words_ids[words_idx]):
                    a[i][word_id] = 1

                a = torch.transpose(a, 0, 1)

                print("a's shape is ", a.shape)

                mm = torch.matmul(masked_probs, a)

                print("mm result: ", mm)

                # only care about the diagonal values on mm (this was confusing, but I think is right)
                word_probs = torch.diag(mm)
                # why product? because a long span can have one very likely word, which throws off max and avg too much
                span_prob = torch.prod(word_probs).item()

                print("span probs: ", span_prob, "... words: ", words[words_idx])

                final_ranked_options.append(words[words_idx])
                final_ranked_options_prob.append(span_prob)
        print(sorted(zip(final_ranked_options_prob, final_ranked_options), reverse=True))
        final_ranked_options_prob, final_ranked_options = zip(*sorted(zip(final_ranked_options_prob, final_ranked_options), reverse=True))
        return final_ranked_options, final_ranked_options_prob

In [13]:
mask_opts, mask_probs = new_rank_multi(fb, "the first Star Wars came ***mask*** 1977", ["out in", "to in", "from mars to earth in"])

word ids:  [[1149, 1107], [1106, 1107], [1121, 12477, 1733, 1106, 4033, 1107]]
lengths of each list in word ids:  [2, 2, 6]
mask len = 2
there are this many tokens  10
they are  ['[CLS]', 'the', 'first', 'Star', 'Wars', 'came', '[MASK]', '[MASK]', '1977', '.']
the masked probs are tensor([[3.0007e-07, 3.0925e-07, 2.9548e-07,  ..., 6.1910e-07, 7.8386e-07,
         4.3798e-06],
        [2.8629e-07, 2.8874e-07, 2.7992e-07,  ..., 5.0310e-07, 8.5147e-07,
         3.7964e-06]]) 
 and its shape is torch.Size([2, 30522])
a's shape is  torch.Size([30522, 2])
mm result:  tensor([[1.2060e-05, 3.8723e-06],
        [1.2404e-05, 2.1813e-06]])
span probs:  2.630639976686222e-11 ... words:  out in
mask len = 2
there are this many tokens  10
they are  ['[CLS]', 'the', 'first', 'Star', 'Wars', 'came', '[MASK]', '[MASK]', '1977', '.']
the masked probs are tensor([[3.0007e-07, 3.0925e-07, 2.9548e-07,  ..., 6.1910e-07, 7.8386e-07,
         4.3798e-06],
        [2.8629e-07, 2.8874e-07, 2.7992e-07,  ..., 5.0

In [14]:
# Not sure which is the best format for these... I think scores?

print(mask_probs)
scores = [x / max(mask_probs) for x in mask_probs ]
print(scores)

print(fb.softmax(torch.tensor(scores).unsqueeze(0)))

(2.630639976686222e-11, 2.1699878005598805e-11, 4.361758114092486e-33)
[1.0, 0.8248896921628104, 1.6580596937430137e-22]
tensor([[0.4531, 0.3803, 0.1667]])


In [15]:
mask_opts

('out in', 'to in', 'from mars to earth in')

In [16]:
rm = lambda x, y: new_rank_multi(fb, x, y)
fb.rank_multi = rm

In [17]:
fb.fitb("The first Star Wars came ***mask*** 1977", ["to from", "out in", "out of the closet in"])

word ids:  [[1149, 1107], [1106, 1121], [1149, 1104, 1103, 9369, 1107]]
lengths of each list in word ids:  [2, 2, 5]
mask len = 2
there are this many tokens  10
they are  ['[CLS]', 'The', 'first', 'Star', 'Wars', 'came', '[MASK]', '[MASK]', '1977', '.']
the masked probs are tensor([[3.8999e-07, 3.9552e-07, 3.7480e-07,  ..., 5.6233e-07, 8.2060e-07,
         5.2370e-06],
        [3.9211e-07, 3.8405e-07, 3.7379e-07,  ..., 5.3387e-07, 8.4046e-07,
         4.8882e-06]]) 
 and its shape is torch.Size([2, 30522])
a's shape is  torch.Size([30522, 2])
mm result:  tensor([[6.0911e-06, 3.5360e-06],
        [6.2990e-06, 2.8287e-06]])
span probs:  1.7230000412538082e-11 ... words:  out in
mask len = 2
there are this many tokens  10
they are  ['[CLS]', 'The', 'first', 'Star', 'Wars', 'came', '[MASK]', '[MASK]', '1977', '.']
the masked probs are tensor([[3.8999e-07, 3.9552e-07, 3.7480e-07,  ..., 5.6233e-07, 8.2060e-07,
         5.2370e-06],
        [3.9211e-07, 3.8405e-07, 3.7379e-07,  ..., 5.3387e-0

'The first Star Wars came out in 1977'

In [18]:
[fb.tokenizer.ids_to_tokens[wid] for wid in [1121, 12477, 1733, 1106, 4033, 1107]]

['from', 'ma', '##rs', 'to', 'earth', 'in']