In [5]:
# %pip install "transformers[torch]" datasets torch "accelerate>=0.26.0"

In [6]:
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

In [7]:
# Load the fine-tuned tokenizer and model
def load_model_and_tokenizer(model_path):
    """Load the DistilBERT tokenizer and token-classification model.

    Args:
        model_path: Location passed unchanged to both ``from_pretrained``
            calls (the notebook uses a local directory).

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    # The tokenizer is created first, then the model, both from the same path.
    return (
        DistilBertTokenizerFast.from_pretrained(model_path),
        DistilBertForTokenClassification.from_pretrained(model_path),
    )

# Label names indexed by predicted class id. The order has to match the one
# used by the training script, otherwise predictions map to the wrong tags.
label_list = ["O", "B-MOVIE", "I-MOVIE"]

# Directory holding the fine-tuned checkpoint, and the objects loaded from it.
model_path = "./custom_ner_model"
tokenizer, model = load_model_and_tokenizer(model_path)

# Define the prediction function
def predict(input_text, tokenizer, model, label_list):
    """Predict a label for every token of ``input_text``.

    Args:
        input_text: The sentence to analyse, as a single string.
        tokenizer: Callable that encodes the text into a mapping containing
            ``input_ids`` and that provides ``convert_ids_to_tokens``.
        model: Token-classification model; calling it with the encoded
            inputs must return an object with a ``logits`` attribute whose
            last axis (dim 2) holds one score per label.
        label_list: Label names indexed by predicted class id.

    Returns:
        A tuple ``(results, output_text)``:
        ``results`` is a list of ``(token, label)`` pairs with the special
        tokens ``[CLS]``, ``[SEP]`` and ``[PAD]`` removed; ``output_text``
        is the same sequence rendered as one string, where every token whose
        label is not ``"O"`` appears as ``[token (label)]``. Each rendered
        token is followed by a single space, so a non-empty string ends with
        a trailing space.

    Raises:
        IndexError: If a predicted class id has no entry in ``label_list``.
    """
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")

    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, is_split_into_words=False)

    # Keep the inputs on the same device as the model; without this a model
    # that was moved to GPU fails with a device-mismatch error.
    device = getattr(model, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    # Run the model without tracking gradients (inference only)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Highest-scoring class id per token of the first (only) sequence
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()

    # Convert token IDs back to token strings (plain ints, not tensors)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # Map the class ids to label names
    predicted_labels = [label_list[class_id] for class_id in predicted_ids]

    # Drop special tokens once; both return values are built from this list
    results = [
        (token, label)
        for token, label in zip(tokens, predicted_labels)
        if token not in special_tokens
    ]

    # Formatted output: entity tokens are bracketed together with their label
    output_text = "".join(
        f"[{token} ({label})] " if label != "O" else f"{token} "
        for token, label in results
    )

    return results, output_text

In [8]:
# Demo inputs for the tagger.
sentences = [
    "Who is the director of STAR WARS",
    "Who directed star wars",
    "Who is the screenwriter of the movie Inception",
    "Who wrote the screenplay for Inception",
    # Given examples (kept verbatim, including trailing spaces)
    "Who is the director of Star Wars: Episode VI - Return of the Jedi ",
    "Who is the screenwriter of The Masked Gang: Cyprus ",
    "When was 'The Godfather' released ",
    "Recommend movies like Nightmare on Elm Street, Friday the 13th, and Halloween",
]

# Print one formatted line per sentence. `results` still holds the
# per-token (token, label) pairs of the last sentence for closer inspection.
for sentence in sentences:
    results, formatted_output = predict(sentence, tokenizer, model, label_list)
    print(formatted_output)


who is the director of [star (B-MOVIE)] [wars (I-MOVIE)] 
who directed [star (B-MOVIE)] [wars (I-MOVIE)] 
who is the screenwriter of [the (B-MOVIE)] [movie (I-MOVIE)] [inception (I-MOVIE)] 
who wrote the screenplay for [inception (B-MOVIE)] 
who is [the (I-MOVIE)] director [of (I-MOVIE)] [star (B-MOVIE)] [wars (I-MOVIE)] [: (I-MOVIE)] [episode (I-MOVIE)] [vi (I-MOVIE)] [- (I-MOVIE)] [return (I-MOVIE)] [of (I-MOVIE)] [the (I-MOVIE)] [jedi (I-MOVIE)] 
who is the screenwriter of [the (B-MOVIE)] [masked (I-MOVIE)] [gang (I-MOVIE)] [: (I-MOVIE)] [cyprus (I-MOVIE)] 
when was [' (B-MOVIE)] [the (B-MOVIE)] [godfather (I-MOVIE)] [' (I-MOVIE)] released 
recommend movies like [nightmare (B-MOVIE)] [on (I-MOVIE)] [elm (I-MOVIE)] [street (I-MOVIE)] [, (I-MOVIE)] [friday (B-MOVIE)] [the (I-MOVIE)] [13th (I-MOVIE)] [, (I-MOVIE)] and [halloween (B-MOVIE)] 
