## import statements

In [None]:
import ast
import json
import re
from pprint import pprint
from typing import Any, Dict, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

## model and tokenizer loading

In [None]:
# Checkpoint id passed to from_pretrained; "nis12ram/HindiNER-4B-v0.1" can be used instead.
model_name = "nis12ram/HindiNER-4B-v0.0"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# "auto" leaves dtype selection and device placement to transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)

## Prompt templates

In [None]:
# NER instruction placed in the user turn. Literal braces in the JSON example are
# doubled ({{ / }}) because this template is filled with str.format(text=...);
# `{text}` is the only placeholder.
ner_user_msg = '''You are a Hindi language expert who specializes in extracting entities from text. Given a piece of text, extract all crucial entities along with their respective context-aware entity types. Ensure that entity type is in Hindi. The output should be in JSON format.

## Output format:
```json
{{
  "entities": [
    {{
      "type": "_",
      "value": ["_", "_"]
    }},
    {{
      "type": "_",
      "value": ["_"]
    }}
  ]
}}
```

## Text:
""" {text} """'''


# Chat template: an empty System turn, a User turn holding `{user_msg}`, and a
# trailing Assistant header so the model's continuation is the assistant reply.
# NOTE(review): presumably mirrors the model's training chat format — confirm
# against the model card.
prompt_format = """<extra_id_0>System

<extra_id_1>User
{user_msg}
<extra_id_1>Assistant
"""

## Utilities: stopping criterion, JSON extraction, inference

In [None]:
## "nis12ram/HindiNER-4B-v0.0" and "nis12ram/HindiNER-4B-v0.1" need a stopping criterion to work
## efficiently: generation is cut at the next turn marker instead of running to max_new_tokens.
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the generated sequence ends with any stop sequence.

    Stop sequences are compared as token-id suffixes. This detects both
    single-token stops (e.g. a special token) and stops that tokenize into
    several tokens. The comparison runs on plain Python ints, so it works on
    whatever device the model uses.

    Args:
        stops: Iterable of stop sequences. Each item may be a tensor of token
            ids (0-d or 1-d, e.g. ``tokenizer(...)["input_ids"].squeeze()``)
            or a sequence of ints. Empty sequences are ignored.
        encounters: Accepted for backward compatibility; currently unused.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()

        # Flatten every stop to a list of ints; reshape(-1) turns the 0-d tensor
        # produced by .squeeze() on a single-token stop into a 1-element one.
        normalized = (
            torch.as_tensor(stop).reshape(-1).tolist()
            for stop in (stops if stops is not None else [])
        )
        # An empty stop sequence would match at every step, so drop it.
        self.stops = [ids for ids in normalized if ids]
        self.encounters = encounters  # kept for API compatibility; not used

        # Only this many trailing tokens ever need to be inspected per step.
        self._max_stop_len = max((len(ids) for ids in self.stops), default=0)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True once every sequence in the batch ends with a stop sequence.

        Args:
            input_ids: (batch, seq_len) token ids generated so far.
            scores: Prediction scores; unused.

        Returns:
            A single Python bool for the whole batch.
        """
        if not self._max_stop_len:
            return False

        # Pull only the tail of each row to the host; comparing Python lists
        # avoids any device placement assumptions.
        tails = input_ids[:, -self._max_stop_len:].tolist()

        # If a tail is shorter than a stop, the slice returns the whole tail,
        # which then cannot compare equal, so no extra length check is needed.
        return all(
            any(tail[-len(stop):] == stop for stop in self.stops)
            for tail in tails
        )



# "<extra_id_1>" opens the User/Assistant turns in `prompt_format`, so generating
# it is treated as the end of the assistant's reply.
stop_words = ["<extra_id_1>"]

# Token ids of each stop word, encoded without added special tokens; `.squeeze()`
# drops the batch dimension introduced by return_tensors="pt".
stop_words_ids = [
    tokenizer(word, add_special_tokens=False, return_tensors="pt").input_ids.squeeze()
    for word in stop_words
]

stopping_criteria = StoppingCriteriaList(
    [StoppingCriteriaSub(stops=stop_words_ids)]
)

In [None]:
# Sentinel returned by _parse_json_like when parsing fails, so a block that
# legitimately parses to JSON `null` (Python None) is not mistaken for a failure.
_UNPARSEABLE = object()


def _parse_json_like(json_str: str) -> Any:
    """Parse ``json_str`` as JSON, falling back to a Python literal; return ``_UNPARSEABLE`` on failure."""
    try:
        # Well-formed JSON.
        return json.loads(json_str)
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        pass

    try:
        # Malformed JSON that is still a valid Python literal (single quotes,
        # tuples, True/False/None). literal_eval only accepts literals, so it is
        # safe to run on model output, unlike eval().
        python_literal: Any = ast.literal_eval(json_str)
        # Round-trip through JSON: tuples become lists, int keys become strings,
        # and non-JSON types (sets, bytes, complex) raise TypeError.
        return json.loads(json.dumps(python_literal, ensure_ascii=False))
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return _UNPARSEABLE


def extract_json(text: Optional[str]) -> Any:
    """Extract and parse the first ```json fenced block in ``text``.

    Strict JSON is tried first. If that fails, ``ast.literal_eval`` is tried
    for Python-literal-style output, and its result is normalised through a
    JSON round-trip.

    Args:
        text: Model output expected to contain a ```json ... ``` block.
            ``None`` is treated like text without such a block.

    Returns:
        The parsed data. Returns ``None`` after printing a diagnostic when no
        fenced block is found or its contents cannot be parsed.
    """
    # Non-greedy: captures up to the first closing fence after ```json.
    match = re.search(r"```json\s*([\s\S]*?)\s*```", text or "")

    if match:
        parsed = _parse_json_like(match.group(1).strip())
        if parsed is not _UNPARSEABLE:
            return parsed

    print(f"NOT ABLE TO EXTRACT JSON DATA FROM TEXT: {text}")
    return None

In [None]:
def inference(
    text: str,
    max_new_tokens: int = 1000,
    sampling_params: Optional[Dict[str, Any]] = None,
) -> Any:
    """Generate a completion for a formatted prompt and parse the JSON in it.

    Args:
        text: Fully formatted prompt, already wrapped in the chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        sampling_params: Extra keyword arguments forwarded to
            ``model.generate`` (e.g. ``{"do_sample": False}``). ``None`` means
            no extra arguments.

    Returns:
        The data parsed by ``extract_json`` from the generated text, or
        ``None`` if no JSON could be extracted.
    """
    # Copy into a fresh dict per call instead of sharing a mutable default.
    generate_kwargs = dict(sampling_params or {})

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        **generate_kwargs,
        stopping_criteria=stopping_criteria,
    )

    # The returned sequence starts with the prompt; keep only the new tokens.
    prompt_length = model_inputs["input_ids"].shape[1]
    output_ids = generated_ids[0][prompt_length:].tolist()

    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    return extract_json(content)

## Run NER on a sample passage

In [7]:
# Sample Hindi news passage (Air India crash near Ahmedabad) used as NER input; replace with your own text.
text = """एअर इंडिया ने X पर बताया है कि गुजरात के अहमदाबाद एयरपोर्ट के पास गुरुवार को हुए विमान हादसे में फ्लाइट में सवार 242 लोगों में से 241 लोगों की मौत हो गई और सिर्फ एक शख्स जीवित बचा है। एअर इंडिया ने बताया, 'यात्रियों में 169 भारतीय नागरिक, 53 ब्रिटिश और 7 पुर्तगाली और 1 कनाडाई नागरिक था।'"""

In [None]:
# Fill the passage into the NER instruction, then wrap that instruction in the
# chat template to get the final prompt.
input_text = prompt_format.format(
    user_msg=ner_user_msg.format(text=text),
)

# do_sample=False selects greedy decoding.
json_output = inference(
    input_text,
    sampling_params=dict(do_sample=False),
    max_new_tokens=1000,
)
pprint(json_output)