<a href="https://colab.research.google.com/github/matthewleechen/woodcroft_patents/blob/main/ner/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is designed to run on Google Colab. It can also be run locally, but you will need to check the dependencies carefully.

It runs inference with the weights of any fine-tuned model, loaded from a directory that contains the model weights and configuration files.

In [1]:
%%capture
!pip install transformers

In [2]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
import numpy as np
import csv
import json
import os
from tqdm import tqdm

If the files written by the `save_pretrained` method after fine-tuning are not already in Google Drive, you will need to upload them to a directory that this notebook can access.
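
Before loading the model, it can help to confirm that the directory actually holds the expected files. The following is a minimal sketch, not part of the original workflow: the helper name `missing_model_files` is hypothetical, and the file names (`config.json` plus either `model.safetensors` or `pytorch_model.bin`) are an assumption based on what `save_pretrained` commonly writes. Sharded checkpoints and tokenizer files are not checked.

```python
import os

# Assumed file names: a config file plus one weights file, as commonly
# written by save_pretrained (sharded checkpoints are not covered here).
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def missing_model_files(model_dir):
    """Return a list describing the expected files absent from model_dir."""
    present = set(os.listdir(model_dir))
    missing = []
    if "config.json" not in present:
        missing.append("config.json")
    if not any(name in present for name in WEIGHT_FILES):
        missing.append(" or ".join(WEIGHT_FILES))
    return missing
```

Calling `missing_model_files(model_path)` before `from_pretrained` returns an empty list when nothing is missing.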

In [4]:
# Load the saved model weights and configuration
model_path = "/path/to/model/weights"
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

The code below runs inference on the merged text boxes (each is referred to as a "sentence") using the Hugging Face Pipelines API; see the [documentation](https://huggingface.co/docs/transformers/main_classes/pipelines).

This assumes that you have an input directory of .txt files and an output directory to which the .csv files will be exported. One .csv file is written per input file, with the labelled classes as columns and each patent (sentence) recorded as a row (observation). As written, each "PER" entity goes to its own numbered column (PER_1, PER_2, ...), while multiple entities of any other class are joined with "& " in a single column; to split "LOC" and "OCC" entities in the same way, add them to the two `["PER"]` lists in the code. You may need to modify this code depending on your desired output format.
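
To make the row layout concrete, here is a minimal sketch of how the aggregated entities for one sentence map onto a single CSV row. It mirrors the grouping logic of the loop further down, where only the groups listed for splitting (here "PER") receive numbered columns. The function name `entities_to_row`, the `split_groups` parameter and the sample entities are hypothetical and exist only for illustration. Real pipeline output also carries `score`, `start` and `end` keys, which are omitted here.

```python
def entities_to_row(entities, split_groups=("PER",)):
    """Turn one sentence's aggregated entities into a flat dict (one CSV row)."""
    # Collect the words found for each entity class, in order of appearance
    grouped = {}
    for entity in entities:
        grouped.setdefault(entity["entity_group"], []).append(entity["word"])

    row = {}
    for group, words in grouped.items():
        if group in split_groups:
            # One numbered column per entity: PER_1, PER_2, ...
            for i, word in enumerate(words, start=1):
                row[f"{group}_{i}"] = word
        else:
            # All entities of any other class share a single column
            row[group] = "& ".join(words)
    return row


# Hypothetical entities, shaped like the output of a token-classification
# pipeline with aggregation_strategy="simple"
sample = [
    {"entity_group": "PER", "word": "John Smith"},
    {"entity_group": "PER", "word": "William Brown"},
    {"entity_group": "LOC", "word": "Manchester"},
    {"entity_group": "LOC", "word": "Leeds"},
]
row = entities_to_row(sample)
print(row)
# {'PER_1': 'John Smith', 'PER_2': 'William Brown', 'LOC': 'Manchester& Leeds'}
```

Passing `split_groups=("PER", "LOC", "OCC")` would give each of those classes its own numbered columns instead.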

Running this code on a GPU is strongly recommended, and a cheap one is sufficient. An NVIDIA Tesla T4 GPU (provided on the Colab free plan) is orders of magnitude faster than the CPU: inference on 1,000 patents takes approximately 15-20 seconds on the T4, but several hours on the Colab CPU.

In [None]:
# Set input and output directories
input_dir = "/path/to/input/dir"
output_dir = "/path/to/output/dir"

In [None]:
# Move model to GPU if one is available, otherwise keep it on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

In [None]:
# Deploy Pipeline API: device=0 selects the first GPU, device=-1 (the default) selects the CPU
pipe = pipeline(
    task="token-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # fall back to CPU when no GPU is present
    aggregation_strategy="simple",
)

In [None]:
# Create the output directory if it does not exist, then loop over all .txt files in the input directory
os.makedirs(output_dir, exist_ok=True)
for filename in tqdm(os.listdir(input_dir)):
    if filename.endswith(".txt"):
        # Specify file paths
        input_path = os.path.join(input_dir, filename)
        output_path = os.path.join(output_dir, filename[:-4] + ".csv")  # Remove .txt extension and add .csv extension

        # Read sentences from text file
        with open(input_path, "r", encoding="utf-8") as f:
            sentences = f.read().split("\n\n")

        # Create list of dictionaries to store entities for each sentence
        all_entities = []
        fieldnames = ["NUM", "PER", "DATE", "LOC", "COMM", "OCC", "MISC", "INFO"]  # Specify the fieldnames

        for sentence in sentences:
            # Extract entities
            combined_entities = {}
            for entity in pipe(sentence):
                entity_group = entity['entity_group']
                word = entity['word']
                if entity_group not in combined_entities:
                    combined_entities[entity_group] = []
                combined_entities[entity_group].append(word)

            # Create a new dictionary to store the updated entities
            updated_entities = {}
            for entity_group, words in combined_entities.items():
                if entity_group in ["PER"]:
                    # Give each PER entity its own numbered column (PER_1, PER_2, ...)
                    for i, word in enumerate(words, start=1):
                        updated_entities[f"{entity_group}_{i}"] = word
                else:
                    updated_entities[entity_group] = '& '.join(words)

            # Add updated entities for this sentence to list
            all_entities.append(updated_entities)

        # Update fieldnames to include the numbered PER columns (PER_1, PER_2, ...)
        max_columns = {key: 0 for key in ["PER"]}
        for entity_dict in all_entities:
            for key in max_columns.keys():
                max_columns[key] = max(max_columns[key], len([k for k in entity_dict.keys() if k.startswith(key)]))

        for key, count in max_columns.items():
            for i in range(count):
                column_name = f"{key}_{i + 1}"
                if column_name not in fieldnames:
                    fieldnames.append(column_name)

        # Write entities to CSV file
        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(all_entities)
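
The loop above splits each file on a single blank line (`"\n\n"`), so a file that ends with a blank line or contains several blank lines in a row produces empty or whitespace-padded "sentences", each of which still becomes a row. A minimal sketch of a stricter splitter follows, assuming such chunks should be dropped. The helper name `split_sentences` is hypothetical, and dropping chunks changes the row count, so it may not suit every input.

```python
import re


def split_sentences(text):
    """Split text on one or more blank lines and drop empty chunks."""
    chunks = re.split(r"\n\s*\n", text)
    return [chunk.strip() for chunk in chunks if chunk.strip()]
```

In the loop, this could replace `f.read().split("\n\n")` with `split_sentences(f.read())`.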
