In [None]:
# | default_exp project.models.bert.infer
# %matplotlib inline

In [None]:
import collections
import json

# | export
import os
from random import (
    shuffle,
)

import torch
import torch.nn.functional as F
from project.models.bert.model import (
    BERT,
)
from project.utils.plot import (
    LivePlot,
)
from torch import (
    nn,
)
from torch.utils.data import (
    DataLoader,
    TensorDataset,
)
from transformers import (
    AdamW,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
)


class BERTInference:
    """Fill mask tokens in a sentence using a saved ``BertForMaskedLM`` checkpoint.

    The constructor builds a ``BertForMaskedLM`` from the given hyperparameters,
    loads weights from ``model_file`` and puts the model in eval mode.
    ``inference`` then replaces every mask position of a sentence with the
    highest-scoring vocabulary entry and returns the decoded text.
    """

    def __init__(
        self,
        model_file,
        tokenizer,
        vocab_size,
        num_heads=6,
        n_layers=6,
        d_model=1000,
        max_length=1024,
        intermediate_size=2048,
        device=None,
    ):
        """Build the model, load its weights and store the tokenizer.

        Args:
            model_file: Path (or file-like object) passed to ``torch.load``;
                expected to contain a state dict for ``BertForMaskedLM``.
            tokenizer: Callable tokenizer exposing ``mask_token_id`` and
                ``decode``; called with ``return_tensors="pt"``.
            vocab_size: Vocabulary size used for the model configuration.
            num_heads: Number of attention heads. Must divide ``d_model``.
            n_layers: Number of hidden layers.
            d_model: Hidden size of the model.
            max_length: Used both as ``max_length`` and as
                ``max_position_embeddings`` in the configuration.
            intermediate_size: Feed-forward layer size (default 2048, the
                value that was previously hard-coded).
            device: Device to run on (string or ``torch.device``). When
                ``None``, CUDA is used if available, otherwise the CPU.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``num_heads``.
        """
        # Fail early with a readable message instead of deep inside the
        # attention layer construction. Note that the default pair
        # (d_model=1000, num_heads=6) is NOT valid, so callers must override
        # at least one of them.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        config = BertConfig(
            vocab_size=vocab_size,
            hidden_size=d_model,
            num_hidden_layers=n_layers,
            max_length=max_length,
            num_attention_heads=num_heads,
            intermediate_size=intermediate_size,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=max_length,
            initializer_range=0.0001,
        )

        model = BertForMaskedLM(config)

        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine.
        # NOTE(security): torch.load unpickles data; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(model_file, map_location=self.device)
        model.load_state_dict(state_dict)

        self.model = model.to(self.device)
        # Eval mode disables dropout so predictions are deterministic.
        self.model.eval()
        self.tokenizer = tokenizer

    def inference(
        self,
        sentence,
        mask_token="[MASK]",
    ):
        """Return ``sentence`` with every mask position replaced by the model's top prediction.

        Only the first sequence of the tokenized batch is processed.

        Args:
            sentence: Text passed to the tokenizer.
            mask_token: Kept for backward compatibility; currently unused.
                Mask positions are located via ``self.tokenizer.mask_token_id``.

        Returns:
            The result of ``self.tokenizer.decode`` applied to the token ids
            of the first sequence after the mask positions were filled in.
        """
        encoded = self.tokenizer(
            sentence,
            return_tensors="pt",
        )

        # Put every input tensor on the same device as the model.
        inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        input_ids = inputs["input_ids"]

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Look for masks in the first sequence only, so that positions from
        # other rows of a batch can never be applied to row 0.
        mask_positions = torch.where(input_ids[0] == self.tokenizer.mask_token_id)[0]

        # Highest-scoring vocabulary id at each masked position.
        predicted_token_ids = torch.argmax(logits[0, mask_positions], dim=1)

        # Work on a copy so the tensor fed to the model is left untouched.
        filled_ids = input_ids[0].clone()
        filled_ids[mask_positions] = predicted_token_ids

        return self.tokenizer.decode(filled_ids)

In [None]:
# Model hyperparameters handed to BERTInference below; presumably the values
# the checkpoint was trained with (verify against the training notebook).
max_length = 1024
d_model = 1000
num_heads = 10
# Not referenced in this cell; kept in case other cells rely on it.
gradient_accumulation_steps = 2

# Directory holding both the vocabulary file and the saved weights.
training_dir = os.path.join(
    "../.tmp/data",
    "training",
)

tokenizer = BertTokenizer(
    vocab_file=os.path.join(
        training_dir,
        "vocab.txt",
    )
)
tokenizer.eos_token = "[SEP]"
tokenizer.bos_token = "[CLS]"
tokenizer.mask_token = "[MASK]"
# The special-token attribute is `unk_token`; the previous `unknown_token`
# assignment only created an unrelated attribute and had no effect.
tokenizer.unk_token = "[UNK]"
tokenizer.pad_token = "[PAD]"

inference = BERTInference(
    os.path.join(
        training_dir,
        "bert.pth",
    ),
    tokenizer,
    vocab_size=tokenizer.vocab_size,
    num_heads=num_heads,
    d_model=d_model,
    max_length=max_length,
)

# Smoke test: the last expression is displayed as the cell output.
inference.inference("1. p10 [MASK] p11 [MASK]")