In [None]:
from transformers import BertForTokenClassification, BertTokenizer

def load_model_and_tokenizer_for_inference(
    model_path="/content/drive/My Drive/SDE_bert_model",
    tokenizer_path="/content/drive/My Drive/SDE_bert_tokenizer",
):
  """Load the fine-tuned token-classification model and its tokenizer.

  The model is moved to CUDA when available (CPU otherwise) and switched to
  evaluation mode before being returned.

  Args:
    model_path: Directory (or identifier) passed to
      ``BertForTokenClassification.from_pretrained``.
    tokenizer_path: Directory (or identifier) passed to
      ``BertTokenizer.from_pretrained``.

  Returns:
    A ``(model, tokenizer)`` tuple.
  """
  # Imported here because no import cell in the notebook brings torch into
  # scope; without it the device check below fails on a fresh kernel.
  import torch

  model = BertForTokenClassification.from_pretrained(model_path)
  tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  model.eval()

  print("Model and tokenizer loaded successfully for inference!")
  return model, tokenizer

In [None]:
def create_mappings(corrected_mapping):
  """Build the lookup tables and zeroed counters used during evaluation.

  Args:
    corrected_mapping: Dict whose items are inverted (value -> key) to form
      the first returned mapping.

  Returns:
    A 5-tuple of:
      * the inverted ``corrected_mapping`` (value -> key),
      * base entity name -> integer id,
      * integer id -> base entity name,
      * an ``(n, n)`` zero matrix for misclassification counts,
      * a length-``n`` zero vector for correct-classification counts,
    where ``n`` is the number of base entities.
  """
  # Invert the supplied mapping; on duplicate values the last key wins.
  reverse_entity_mapping = {
      mapped_value: original_key
      for original_key, mapped_value in corrected_mapping.items()
  }

  # The position in this list is the entity's id.
  base_entity_names = [
      "BUYER NAME",
      "SELLER NAME",
      "BUYER ADDRESS",
      "BUYER ORG",
      "SELLER ORG",
      "SELLER ADDRESS",
      "APN",
  ]
  base_entity_to_id = {name: position for position, name in enumerate(base_entity_names)}
  reverse_base_entity_to_id = dict(enumerate(base_entity_names))

  entity_count = len(base_entity_to_id)
  base_entity_misclassifications = np.zeros((entity_count, entity_count))
  correct_classification = np.zeros(entity_count)

  return (
      reverse_entity_mapping,
      base_entity_to_id,
      reverse_base_entity_to_id,
      base_entity_misclassifications,
      correct_classification,
  )


In [None]:
import numpy as np


def find_entities_in_text(labels):
    """Locate contiguous entity spans in a sequence of tag strings.

    A tag equal to ``'O'`` is treated as "no entity". For any other tag the
    first two characters are stripped to obtain the entity type. A new span
    starts wherever a non-``'O'`` tag has a different type than the tag
    before it (or is the very first tag), and it extends while the following
    tags keep that same type.

    Args:
        labels: Sequence of tag strings, one per token.

    Returns:
        A ``(masks, entity_names)`` tuple. ``masks`` is ``np.array`` of one
        0/1 integer row per span (each row has ``len(labels)`` entries, 1
        marking the span's tokens); it is an empty array when no span is
        found. ``entity_names`` is a list of single-element lists holding the
        type of each span, in the same order as the rows of ``masks``.
    """
    masks = []
    entity_names = []
    total = len(labels)

    for start, label in enumerate(labels):
        if label == 'O':
            continue

        entity_type = label[2:]
        # The first token has no predecessor; treat it like it follows 'O'
        # so that an entity starting at position 0 is not dropped.
        previous_type = labels[start - 1][2:] if start > 0 else ''
        if entity_type == previous_type:
            continue

        # Stop at 'O' *or* at a change of type, so that a directly adjacent
        # entity of another type is not absorbed into this span.
        end = start
        while end < total and labels[end] != 'O' and labels[end][2:] == entity_type:
            end += 1

        mask = np.zeros(total, dtype=int)
        mask[start:end] = 1

        masks.append(mask)
        entity_names.append([entity_type])

    return np.array(masks), entity_names



def show_sample(input_ids, masks, entity_names):
    """Turn each entity mask back into readable text.

    For every row of ``masks`` the token ids selected by that row are
    converted to tokens with the notebook-level ``tokenizer``, joined with
    spaces, and every ``" ##"`` sequence is removed so that pieces prefixed
    with ``##`` are glued onto the preceding token.

    Args:
        input_ids: Token ids of the sequence (anything ``np.array`` accepts).
        masks: 2-D array with one 0/1 row per entity.
        entity_names: Entity name entries, indexed in step with ``masks``.

    Returns:
        ``(all_entities, reconstructed_text)``: the entity name entries and
        the corresponding reconstructed strings, in row order.
    """
    id_array = np.array(input_ids)
    all_entities, reconstructed_text = [], []

    for row in range(len(masks)):
        all_entities.append(entity_names[row])

        selector = np.array(masks[row, :], dtype=bool)
        pieces = tokenizer.convert_ids_to_tokens(id_array[selector])
        joined = " ".join(pieces)
        reconstructed_text.append(joined.replace(" ##", ""))

    return all_entities, reconstructed_text



In [None]:
import pandas as pd
from sklearn.metrics import classification_report
from difflib import SequenceMatcher


def add_to_list(data, entity_map, entity_name, value):
  """Record ``value`` for ``entity_name``, merging with any earlier entry.

  Both ``data`` and ``entity_map`` are modified in place and also returned.

  Args:
    data: List of ``(entity_name, value)`` tuples.
    entity_map: Dict from entity name to that entity's index in ``data``.
    entity_name: Name of the entity being recorded.
    value: Value to store; when the entity already exists it is appended to
      the stored value with ``+`` (no separator is inserted).

  Returns:
    The ``(data, entity_map)`` pair.
  """
  if entity_name in entity_map:
    # Known entity: rebuild its tuple with the new value added on the end.
    slot = entity_map[entity_name]
    accumulated = data[slot][1] + value
    data[slot] = (entity_name, accumulated)
    return data, entity_map

  # First occurrence: append and remember where it landed.
  data.append((entity_name, value))
  entity_map[entity_name] = len(data) - 1
  return data, entity_map




def perform_inference(dataset, corrected_mapping):
    """Run the notebook-level ``model`` over ``dataset`` and score its output.

    Every sequence in every batch is pushed through the model one at a time.
    Token-level predictions and labels (positions labelled ``-100`` removed)
    are accumulated for a classification report, entity spans are extracted
    and merged per document, and predicted/actual spans are paired by
    position to fill the misclassification and correct-classification
    counters.

    Args:
        dataset: Iterable of batches. Each batch is indexed with the keys
            ``"input_ids"``, ``"attention_mask"``, ``"labels"`` and
            ``"is_last_chunk"``; the per-sequence items are used as tensors.
            A sequence whose ``is_last_chunk`` equals 1 closes the current
            document.
        corrected_mapping: Dict that is inverted by ``create_mappings`` and
            then used to translate label ids into tag strings.

    Returns:
        ``(result, base_entity_misclassifications, correct_classification,
        base_entity_to_id, reverse_base_entity_to_id)`` where ``result`` is a
        list of ``[actual_entities, predicted_entities]`` pairs, one per
        completed document.
    """
    # Imported here because no import cell in the notebook brings torch into
    # scope.
    import torch

    reverse_entity_mapping, base_entity_to_id, reverse_base_entity_to_id, base_entity_misclassifications, correct_classification = create_mappings(corrected_mapping)

    all_ground_truths = []
    all_predictions = []
    model.eval()
    # Inputs must live on the same device as the model weights; reading it
    # from the model avoids depending on a separately defined global.
    device = next(model.parameters()).device

    result = []
    actual_document_level_entities = []
    predicted_document_level_entities = []
    actual_entity_map = {}
    predicted_entity_map = {}

    with torch.no_grad():
        for batch in dataset:
            for index, (sequence_ids, attention_mask, labels) in enumerate(zip(
                batch["input_ids"], batch["attention_mask"], batch["labels"]
            )):
                sequence_ids = sequence_ids.to(device).unsqueeze(0)
                attention_mask = attention_mask.to(device).unsqueeze(0)
                labels = labels.to(device).unsqueeze(0)
                is_last_chunk = batch["is_last_chunk"][index].numpy()

                outputs = model(input_ids=sequence_ids, attention_mask=attention_mask)
                logits = outputs.logits

                token_predictions = torch.argmax(logits, dim=-1).squeeze(0)

                # Positions labelled -100 are excluded from scoring.
                valid_indices = labels.squeeze(0) != -100
                filtered_predictions = token_predictions[valid_indices].cpu().numpy()
                filtered_labels = labels.squeeze(0)[valid_indices].cpu().numpy()

                all_predictions.extend(filtered_predictions)
                all_ground_truths.extend(filtered_labels)

                predictions_mapped = [reverse_entity_mapping[label_id] for label_id in filtered_predictions]
                # NOTE(review): the ground truth is mapped without the -100
                # filter applied to the predictions above, so every label id
                # present must be a key of the inverted mapping -- confirm
                # this is intended.
                label_ids = labels.tolist()[0]
                ground_truth_mapped = [reverse_entity_mapping[label_id] for label_id in label_ids]

                token_ids = sequence_ids.squeeze(0).cpu().numpy()

                predicted_masks, predicted_entity_names = find_entities_in_text(predictions_mapped)
                predicted_entities, predicted_texts = show_sample(token_ids, predicted_masks, predicted_entity_names)

                actual_masks, actual_entity_names = find_entities_in_text(ground_truth_mapped)
                actual_entities, ground_truth = show_sample(token_ids, actual_masks, actual_entity_names)

                # Merge this chunk's entities into the running document.
                for actual_entity, gt in zip(actual_entities, ground_truth):
                    actual_document_level_entities, actual_entity_map = add_to_list(actual_document_level_entities, actual_entity_map, actual_entity[0], gt)

                for predicted_entity, prediction in zip(predicted_entities, predicted_texts):
                    predicted_document_level_entities, predicted_entity_map = add_to_list(predicted_document_level_entities, predicted_entity_map, predicted_entity[0], prediction)

                if is_last_chunk == 1:
                    # Document complete: store it and start a fresh one.
                    result.append([actual_document_level_entities, predicted_document_level_entities])
                    actual_document_level_entities = []
                    predicted_document_level_entities = []
                    actual_entity_map = {}
                    predicted_entity_map = {}

                if len(predicted_masks) > 0:
                    # Predicted and actual spans are paired by position.
                    for predicted_mask, predicted_entity_name, actual_mask, actual_entity_name in zip(predicted_masks, predicted_entity_names, actual_masks, actual_entity_names):
                        predicted_entity_id = base_entity_to_id[predicted_entity_name[0]]
                        actual_entity_id = base_entity_to_id[actual_entity_name[0]]

                        overlapping_indices = [predicted_mask[i] & actual_mask[i] for i, _ in enumerate(predicted_mask)]
                        # Count the overlapping tokens. Taking the length of
                        # this list would give the sequence length, which
                        # always passes the threshold below.
                        overlap_count = np.sum(overlapping_indices)

                        if overlap_count >= 0.66 * np.sum(predicted_mask):
                            if predicted_entity_name != actual_entity_name:
                                base_entity_misclassifications[predicted_entity_id][actual_entity_id] += 1
                            else:
                                correct_classification[predicted_entity_id] += 1
                        else:
                            # Too little overlap: counted as a miss even when
                            # the entity types agree.
                            base_entity_misclassifications[predicted_entity_id][actual_entity_id] += 1

    target_names = [reverse_entity_mapping[i] for i in range(len(reverse_entity_mapping))]

    print("\nClassification Report:")
    # Passing labels explicitly keeps target_names aligned even when some
    # class never appears in the evaluated data.
    print(classification_report(
        all_ground_truths,
        all_predictions,
        labels=list(range(len(target_names))),
        target_names=target_names,
    ))

    return result, base_entity_misclassifications, correct_classification, base_entity_to_id, reverse_base_entity_to_id
