In [1]:
import os
import csv
import time
import pandas as pd
from tqdm import tqdm

In [None]:
os.environ['HF_DATASETS_CACHE'] = '/work/pi_dhruveshpate_umass_edu/amalgonde_umass_edu'
os.environ['CONDA_PKGS_DIRS'] = '/work/pi_dhruveshpate_umass_edu/amalgonde_umass_edu'
os.environ['CONDA_ENVS_PATH'] = '/work/pi_dhruveshpate_umass_edu/amalgonde_umass_edu'
os.environ['TRANSFORMERS_CACHE'] = '/work/pi_dhruveshpate_umass_edu/amalgonde_umass_edu'
os.environ['HF_HOME'] = '/work/pi_dhruveshpate_umass_edu/amalgonde_umass_edu'

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

In [None]:
torch.cuda.get_device_name(0)

In [None]:
model_name = "microsoft/phi-2"
# model_name = "google/gemma-2b"

# config = AutoConfig.from_pretrained(model_name)
# config.max_seq_len = 4096
# config.max_answer_len= 512

tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Model max length - {tokenizer.model_max_length}")
# tokenizer.model_max_length = 4096

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             # max_length=4096,
                                             trust_remote_code=True,
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             # load_in_8bit=True
                                            )
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    model.cuda()

In [None]:
# # TEST

# input_text = "Write me a poem about Machine Learning."
# input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# # input_ids = tokenizer(input_text, return_tensors="pt", max_length=1024, padding='max_length').to("cuda")

# outputs = model.generate(**input_ids, max_new_tokens=1024)
# print(type(tokenizer.decode(outputs[0])), tokenizer.decode(outputs[0])) #.replace('<pad>', ''))

## Prep Data

In [4]:
data_path = "/work/pi_dhruveshpate_umass_edu/project_19/ReDocREDPreprocessing/Re-DocRED/processed/"
file = "Re-DocRED_Processed_Dev_EntitiesIncluded.csv"

data = pd.read_csv(os.path.join(data_path, file))
# data = data.rename(columns={'Unnamed: 0': 'original_index', 'index': 'original_index'})
print(data.shape)
data.head()

(500, 6)


Unnamed: 0.1,Unnamed: 0,index,Title,Text,Triplets,Entities
0,0,0,Willi Schneider (skeleton racer),"Wilfried "" Willi "" Schneider (born 13 March 19...",2002 Winter Olympics | start time | 2002\n2002...,German\nJeff Pain\nFIBT World Championships\n2...
1,1,1,Ross Alger,"Ross Patterson Alger (August 20, 1920 January...",Ross Patterson Alger | place of birth | Prelat...,Rod Sykes\nOlympic\nRoss Patterson Alger\nRoya...
2,2,2,Mess of Blues (Jeff Healey album),Mess of Blues is an album by Jeff Healey. It w...,Mess of Blues | publication date | 2008\nMess ...,Toronto\nDoc Pomus\nCanada\nStudio 92\nIslingt...
3,3,3,Ramey Idriss,Ramey Idriss (11 September 1911 5 February 19...,Wet Blanket Policy | publication date | 1948\n...,The Old Chaperone\nI 'll Wait\n11 September 19...
4,4,4,ELAM (Latin American School of Medicine) Cuba,"Escuela Latinoamericana de Medicina (ELAM), fo...",Latin American School of Medicine | country | ...,Guri\nCuba\nLatin America\nEscuela Latinoameri...


In [None]:
entityPredictions = "/work/pi_dhruveshpate_umass_edu/project_19/aishwarya/696DS-named-entity-extraction-and-linking-for-KG-construction/code/llama2/lamma2_entity_redocred_dev.csv"

dataEntity = pd.read_csv(entityPredictions, header=None)
dataEntity = dataEntity.rename(columns={0: 'original_index', 1: "Entities", 2: "InferenceTime"})
print(dataEntity.shape)
dataEntity.head()

In [None]:
completeData = pd.merge(data, dataEntity, on='original_index')
print(completeData.shape)
completeData.head()

## Prompt

In [None]:
data_path = "/work/pi_dhruveshpate_umass_edu/project_19/ReDocREDPreprocessing/Re-DocRED/processed/"
file_v2 = "Re-DocRED_Processed_Train.csv"
    
def get_example(idx, datatmp):
    allEntities = set()
    for item in datatmp["Triplets"][idx].split("\n"):
        item = item.split(" | ")
        # print(item, len(item))
        if len(item) == 3:
            allEntities.add(item[0])
            allEntities.add(item[2])

    allEntities = list(allEntities)
    # print(allEntities)
    return datatmp["Text"][idx], allEntities, datatmp["Triplets"][idx]
    # return datatmp["Text"][idx], "; ".join(allEntities), datatmp["Triplets"][idx]
    
def get_prompt(text, entities):
    
    data_v2 = pd.read_csv(os.path.join(data_path, file_v2), skiprows = range(1, 2000), nrows = 500)
    
    ex1, exent1, exout1 = get_example(100, data_v2)
    ex2, exent2, exout2 = get_example(200, data_v2)
    # ex3, exent3, exout3 = get_example(300, data_v2)
    # ex4, exent4, exout4 = get_example(400, data_v2)

    prompt=f'''Task Description:
The task is to extract Relations between the Entity List for given text, in the form of triplets. \
Extract triplets from the given Text based solely on the relationships present in the text. \
Ensure that entities are chosen directly from the provided Entity List to maintain accuracy. \
Avoid duplicating triplets in the output. Use the provided Example Text and Relations Output as references \
to understand how to identify meaningful relationships between entities from Entity List. \
Pay attention to all potential relations between all the entities and include them in the output.

Example Text 1: {ex1}
Entity List of Text 1: {exent1}
Relations Output of Text 1: {exout1}

Example Text 2: {ex2}
Entity List of Text 2: {exent2}
Relations Output of Text 2: {exout2}

Text: {text}
Entity List: {entities}
Relations Output:'''
    
#     Example Text 3: {ex3}
#     Entity List of Text 3: {exent3}
#     Relations Output of Text 3: {exout3}

#     Example Text 4: {ex4}
#     Entity List of Text 4: {exent4}
#     Relations Output of Text 4: {exout4}
    
    return prompt

In [None]:
sample_prompt = get_prompt("text", "abd; abc")
len(sample_prompt.split(" ")) #, sample_prompt

## Predictions

In [None]:
max_tokens = 512

# lower the value, deterministic result
temperature = 0.1

# a higher value increases the chance of finding a better output
top_p = 0.9

In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids[0][-1] in self.keywords:
            return True
        return False


# stop_words = ['Test', 'Test Sentence', 'Test Output']
stop_words = ["Example", "Entity", "Relations"]
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
stop_ids

In [None]:
def process_entities_v1(ent):
    ent = ent.strip()
    ent = ent.split("; ")
    ent = list(set(ent))
    ent = "; ".join(ent)
    return ent

def process_gtEntities(triplets):
    
    if not isinstance(triplets,str):
        print(f"Not string - {triplets}")
            
    triplets = triplets.strip()
    triplets = triplets.split("\n")
    
    output = set()
    for t in triplets:
        if not isinstance(t,str):
            print(f"Not string - {t}")
            continue
            
        t = t.split(" | ")
        if len(t) != 3:
            continue
            
        if len(t) != 3:
            print(f"Not len 3 - {t}")
            continue
            
        output.add(t[0])
        output.add(t[2])
        
    output = list(output)
    return output

# enty = completeData["Entities"][2]
# enty, process_entities(enty)

In [None]:
output_file_name = "phi2_relations_gt_redocred_dev_v4.csv"

In [None]:
output_data = []
debug = False

for i in tqdm(range(data.shape[0])):
    
    if i <= 40 and debug:
        continue
        
    if i < 491 or i == 492:
        continue
        
    gt = data['Triplets'][i]
    ent = process_gtEntities(data["Triplets"][i])
    
    # Two step prediction
    # prompt = get_prompt(completeData["Text"][i], process_entities_v1(completeData["Entities"][i]))
    # gt = completeData['Triplets'][i]
    # ent = completeData['Entities'][i]
    
    # Relation prediction from GT entities
    prompt = get_prompt(data["Text"][i], ent)
    

    input_ids = tokenizer(prompt, return_tensors="pt").to("cuda") #,  max_length=4096, truncation=True
    
    start = time.time()
    outputs = model.generate(**input_ids,
                             max_new_tokens=max_tokens,
                             top_p=top_p,
                             do_sample=True,
                             temperature=temperature,
                             pad_token_id=tokenizer.eos_token_id,
                             stopping_criteria=StoppingCriteriaList([stop_criteria]),
                             )
    time_diff = time.time() - start
    
    # output2 = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output2 = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
    
    if debug:
        print(output2)
        print(f"\nGT : {gt}")
        print(f"\nEntities : {ent}")
        
    output_data.append([data["original_index"][i], output2, time_diff,
                       len(input_ids[0]), len(outputs[0][len(input_ids[0]):])])
    
    if i % 10 == 0 and i > 1:
        with open(output_file_name, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerows(output_data)
            
    if debug: break
    
with open(output_file_name, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerows(output_data)

## Combine results

In [None]:
homeFolder = "/work/pi_dhruveshpate_umass_edu/project_19/aishwarya/696DS-named-entity-extraction-and-linking-for-KG-construction/code/phi2"

DataPath0 = os.path.join(homeFolder, "phi2_relations_gt_redocred_dev.csv")
DataPath1 = os.path.join(homeFolder, "phi2_relations_gt_redocred_dev_v2.csv")
DataPath2 = os.path.join(homeFolder, "phi2_relations_gt_redocred_dev_v3.csv")
DataPath3 = os.path.join(homeFolder, "phi2_relations_gt_redocred_dev_v4.csv")

Data0 = pd.read_csv(DataPath0, header=None)
Data1 = pd.read_csv(DataPath1, header=None)
Data2 = pd.read_csv(DataPath2, header=None)
Data3 = pd.read_csv(DataPath3, header=None)
Data = pd.concat([Data0, Data1, Data2, Data3], axis=0)

Data = Data.rename(columns={0: 'original_index', 1: 'entities', 2: 'latency',
                            3: 'input_tokens',  4: 'output_tokens'})

Data.to_csv('phi2_relations_gt_redocred_dev_combined.csv', index=False)

print(Data.shape, len(list(set(list(Data["original_index"])))))
Data.head()