In [8]:
import csv
import json
import os
import time

import pandas as pd
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

In [5]:
# Point the Hugging Face and conda caches at the shared project work directory
# (presumably to keep large downloads off the home quota -- confirm).
# NOTE(review): transformers/huggingface_hub generally read HF_HOME and
# TRANSFORMERS_CACHE when they are first imported. `transformers` is imported
# in the cell above, so these settings may not take effect unless this cell
# runs before that import.
_cache_root = '/work/pi_dhruveshpate_umass_edu/abaranwal_umass_edu'
for _var in ('TRANSFORMERS_CACHE', 'HF_DATASETS_CACHE', 'CONDA_ENVS_PATH',
             'CONDA_PKGS_DIRS', 'HF_HOME'):
    os.environ[_var] = _cache_root

In [None]:
# Name of the first CUDA device, or "cpu" when no GPU is visible.
# torch.cuda.get_device_name raises without CUDA, which would stop a Run All.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
device_name

In [None]:
model_name = "Universal-NER/UniNER-7B-all"

tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Model max length - {tokenizer.model_max_length}")

# device_map="auto" lets accelerate place the weights (on the GPU when one is
# available). Do NOT call model.cuda() afterwards: for a dispatched model it is
# redundant on a single GPU, and it can break the placement (or raise) when
# layers are split across devices or offloaded.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): this executes any custom code shipped in the model repo.
    # UniNER is LLaMA-based and presumably does not need it -- confirm and drop.
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

## Prep Data

In [None]:
# Chunked MC1 documents to annotate; `data['Documents']` is iterated by the
# prediction loop. Distinct names keep this cell from clobbering `data_path`,
# which get_entities_prompt reads at call time to locate the Re-DocRED CSV.
mc1_data_dir = "/work/pi_dhruveshpate_umass_edu/project_19/astha/696DS-named-entity-extraction-and-linking-for-KG-construction/code/mc1/mc1_preprocess/"
mc1_chunks_file = "mc1_chunked_data.json"

# Explicit UTF-8: JSON is UTF-8, and the cluster's default locale may not be.
with open(os.path.join(mc1_data_dir, mc1_chunks_file), 'r', encoding='utf-8') as f:
    data = json.load(f)

## Prompt

In [None]:
# Re-DocRED training split that supplies the few-shot examples.
# get_entities_prompt reads these two globals at call time, so rebinding
# `data_path` later changes where it looks.
data_path, file_v2 = (
    "/work/pi_dhruveshpate_umass_edu/project_19/ReDocREDPreprocessing/Re-DocRED/processed/",
    "Re-DocRED_Processed_Train.csv",
)
    
def get_entity_example(idx, datatmp):
    """Return the text of row `idx` and its gold entities joined by '; '.

    Entities are the first and third fields of every "head | relation | tail"
    line in the row's "Triplets" value; lines that do not split into exactly
    three parts on " | " are ignored. Duplicates are dropped while keeping
    first-seen order, so the resulting few-shot prompt is identical across
    kernel restarts (iterating a ``set`` of strings is not, because string
    hashing is randomized per process).

    Args:
        idx: Key used to index ``datatmp["Text"]`` and ``datatmp["Triplets"]``.
        datatmp: DataFrame (or mapping of sequences) with "Text" and "Triplets".

    Returns:
        ``(text, entities)`` where ``entities`` is a '; '-joined string.
    """
    triplets = datatmp["Triplets"][idx]
    # A missing value (NaN when loaded from CSV) yields no entities instead of
    # raising AttributeError on .split.
    lines = triplets.split("\n") if isinstance(triplets, str) else []

    entities = {}  # dict used as an insertion-ordered set
    for line in lines:
        parts = line.split(" | ")
        if len(parts) == 3:
            entities.setdefault(parts[0], None)
            entities.setdefault(parts[2], None)

    return datatmp["Text"][idx], '; '.join(entities)
    
def get_entities_prompt(text, examples_df=None):
    """Build a two-shot entity-extraction prompt for `text`.

    The in-context examples are rows 1 and 2 (via get_entity_example) of a
    slice of the Re-DocRED CSV at ``os.path.join(data_path, file_v2)``: file
    lines 1-1999 are skipped (the header is kept) and the next 500 rows are read.

    Args:
        text: Input text whose entities should be extracted.
        examples_df: Optional pre-loaded frame with "Text" and "Triplets"
            columns. Pass it when building many prompts to avoid re-reading the
            CSV on every call; when omitted the CSV slice is read as before.

    Returns:
        The prompt string, ending with "Entities Output:".
    """
    if examples_df is None:
        examples_df = pd.read_csv(os.path.join(data_path, file_v2),
                                  skiprows=range(1, 2000), nrows=500)

    ex1, exout1 = get_entity_example(1, examples_df)
    ex2, exout2 = get_entity_example(2, examples_df)

    # The trailing backslashes join lines 2-4 of the instruction into one line.
    prompt = f'''Task: Please detect all the entities from the given input Text.
Entities could be people, organization, places, concepts, dates or any other proper nouns present in the text. \
Use the following examples as reference to understand the task. \
Give the output in the same format as given in the Example Entities Output, i.e., separated by a semicolon, ';'.

Example Text 1: {ex1}
Example Entities Output 1: {exout1}

Example Text 2: {ex2}
Example Entities Output 2: {exout2}

Text: {text}
Entities Output:'''

    return prompt

def few_shot_prompt_universalNER(entity_type, input_text):
    """Return an instruction-style extraction prompt for a single entity type.

    Despite the name, the prompt contains no examples (it is zero-shot). The
    result starts with a newline and ends with "Output: " (trailing space).
    """
    lines = [
        "",
        "Given a Text, your task is to extract all entities based on the given Entity type.",
        "",
        f"Text: {input_text}",
        "",
        f"Entity Type: {entity_type}",
        "",
        "Output: ",
    ]
    return "\n".join(lines)

def few_shot_prompt_universalNER_v2(entity_type, input_text):
    """Return the bare UniNER-style query: the text, a blank line, then the question.

    NOTE(review): the UniNER model card wraps this in a conversation template
    ("A virtual assistant answers questions ... USER: ... ASSISTANT: ...");
    this form omits it -- confirm that is intentional.
    """
    header = f"Text: {input_text}"
    question = f"What describes {entity_type} in the text?"
    return header + "\n\n" + question

def few_shot_prompt_universalNER_v3(entity_type, input_text):
    """Return the UniNER conversation as an ordered list of chat messages.

    A plain dict cannot hold two "user" turns (the second key overwrites the
    first, dropping the text), so the turns are returned as a list of
    ``{"role": ..., "content": ...}`` messages. This is the message format
    accepted by ``tokenizer.apply_chat_template``.

    Args:
        entity_type: Entity type to ask about, e.g. "person".
        input_text: Passage to extract entities from.

    Returns:
        ``[user: text, assistant: acknowledgement, user: question]``.
    """
    return [
        {"role": "user", "content": f"Text: {input_text}"},
        {"role": "assistant", "content": "I've read this text."},
        {"role": "user", "content": f"What describes {entity_type} in the text?"},
    ]

In [None]:
sample_prompt = get_entities_prompt("text")
# print() returns None, so print first and leave the rough space-delimited
# word count as the cell's displayed value.
print(sample_prompt)
len(sample_prompt.split(" "))

## Predictions

In [None]:
# Settings shared by every model.generate call in the prediction loop.

# Upper bound on newly generated tokens per call (passed as max_new_tokens).
max_tokens = 256

# Sampling temperature: lower values sharpen the distribution, so outputs are
# closer to greedy (more deterministic).
temperature = 0.1

# Nucleus sampling: draw only from the smallest set of tokens whose cumulative
# probability reaches top_p.
top_p = 0.9

# Generation uses do_sample=True, so fix the RNG to make a Restart & Run All
# reproduce the same outputs.
seed = 42
torch.manual_seed(seed)

In [None]:
# CSV (relative to the working directory) that the prediction loop rewrites with
# one [doc_key, chunk_index, per-type outputs] row per processed chunk.
output_file_name: str = "universalner_mc1_entities.csv"

In [None]:
# Stop generation once the sequence ends with the token ids of any stop word
# (presumably to keep the model from starting a new "Text:"/"Entity ..." section).
stop_list = ["Example", "Entity", "Text:", "##Your task", "Entities", "Output", "entity"]

# One (1, k) id tensor per stop word, tokenized without special tokens.
# NOTE(review): with SentencePiece (LLaMA-style) tokenizers a word encoded on
# its own usually starts with a leading-space piece, so a stop word that follows
# a newline in the generated text may be tokenized differently and not match.
# Newer transformers versions support
# model.generate(..., stop_strings=stop_list, tokenizer=tokenizer) instead.
_stop_device = "cuda" if torch.cuda.is_available() else "cpu"
stop_token_ids = [
    tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'].to(_stop_device)
    for x in stop_list
]


class StopOnTokens(StoppingCriteria):
    """Stop when the first sequence in the batch ends with any id sequence in `stop_token_ids`."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        seq = input_ids[0]
        for stop_ids in stop_token_ids:
            # Compare on the sequence's device (no-op when they already match).
            target = stop_ids[0].to(seq.device)
            k = target.shape[-1]
            # Skip empty encodings (seq[-0:] would be the whole sequence) and
            # sequences shorter than the stop word (shapes would not line up).
            if k == 0 or seq.shape[-1] < k:
                continue
            if torch.equal(seq[-k:], target):
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

In [None]:
output_data = []
# When True: print each chunk's text and per-type outputs, and stop after the
# first chunk of the first processed document.
debug = True

# Documents at a position below this in `all_keys` are skipped
# (presumably already processed in an earlier run -- TODO confirm).
start_doc_index = 2
# For documents whose position is a multiple of this, the CSV is also
# rewritten after every chunk (in addition to the end-of-document write).
checkpoint_every = 10

entity_types = ["person", "organization", "location", "datetime", "concept", "event", "group"]


def _write_rows(path, rows):
    """Overwrite `path` with `rows` written as CSV (one row per processed chunk)."""
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerows(rows)


def _extract_entities_of_type(ent_typ, text):
    """Ask the model for entities of type `ent_typ` in `text`.

    Returns a dict with the decoded continuation ("output"), wall-clock
    generation time in seconds ("time_diff"), and the prompt ("input_tokens")
    and continuation ("output_tokens") token counts.
    """
    prompt = few_shot_prompt_universalNER_v2(ent_typ, text)
    # Send inputs to the model's device rather than a hard-coded "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    start = time.time()
    outputs = model.generate(**inputs,
                             max_new_tokens=max_tokens,
                             top_p=top_p,
                             do_sample=True,
                             temperature=temperature,
                             pad_token_id=tokenizer.eos_token_id,
                             stopping_criteria=stopping_criteria)
    time_diff = time.time() - start

    # generate() returns prompt + continuation; keep only the new ids.
    new_ids = outputs[0][prompt_len:]
    return {"output": tokenizer.decode(new_ids, skip_special_tokens=True),
            "time_diff": time_diff,
            "input_tokens": prompt_len,
            "output_tokens": len(new_ids)}


all_keys = list(data['Documents'].keys())

for i, doc_key in enumerate(tqdm(all_keys)):
    if i < start_doc_index:
        continue

    for item in data['Documents'][doc_key]:
        text = item['chunk_text']
        chunk_idx = item['chunk_index']

        if debug:
            print(f"\ntext : {text}")

        # Entity extraction: one generation per entity type for this chunk.
        tmp_output = {}
        for ent_typ in entity_types:
            tmp_output[ent_typ] = _extract_entities_of_type(ent_typ, text)
            if debug:
                # Printed only after decoding, so the output for this type exists.
                print(f"{ent_typ} : {tmp_output[ent_typ]['output']}")

        output_data.append([doc_key, chunk_idx, tmp_output])

        if i % checkpoint_every == 0:
            _write_rows(output_file_name, output_data)

        if debug:
            break

    if debug:
        break

    _write_rows(output_file_name, output_data)

In [None]:
# Inspect the per-entity-type results of the last chunk processed by the prediction loop.
tmp_output

## Combine results

In [None]:
# Concatenate the phi2 Re-DocRED dev relation-prediction runs (v1..v4) into one CSV.
# NOTE(review): these are phi2 outputs from another project folder, not the
# UniNER entity file written by this notebook -- confirm this cell belongs here.
homeFolder = "/work/pi_dhruveshpate_umass_edu/project_19/aishwarya/696DS-named-entity-extraction-and-linking-for-KG-construction/code/phi2"

run_files = [
    "phi2_relations_gt_redocred_dev.csv",
    "phi2_relations_gt_redocred_dev_v2.csv",
    "phi2_relations_gt_redocred_dev_v3.csv",
    "phi2_relations_gt_redocred_dev_v4.csv",
]

# The run files have no header row. ignore_index gives the combined frame a
# fresh 0..n-1 index instead of repeating each file's own index.
Data = pd.concat(
    [pd.read_csv(os.path.join(homeFolder, name), header=None) for name in run_files],
    axis=0,
    ignore_index=True,
).rename(columns={0: 'original_index', 1: 'entities', 2: 'latency',
                  3: 'input_tokens', 4: 'output_tokens'})

Data.to_csv('phi2_relations_gt_redocred_dev_combined.csv', index=False)

# Total shape vs. number of distinct source indices (lower if runs overlap).
print(Data.shape, Data["original_index"].nunique(dropna=False))
Data.head()