In [103]:
from config import Param

# Build the run configuration from the project's Param wrapper.
param = Param()
args = param.args

# The sampler code reads `task_name`; keep it in sync with the dataset name.
args.task_name = args.dataname

# Number of relations per task: 8 for FewRel, 4 for every other dataset.
if args.dataname == "FewRel":
    args.rel_per_task = 8
else:
    args.rel_per_task = 4

In [104]:
# Echo the full configuration namespace, including the task_name / rel_per_task
# values assigned in the previous cell, so the run settings are recorded in the output.
print(args)

Namespace(gpu=0, dataname='TACRED', task_name='TACRED', device='cuda', batch_size=64, num_tasks=10, rel_per_task=4, pattern='entity_marker', max_length=192, encoder_output_size=768, vocab_size=30522, marker_size=4, num_workers=0, save_checkpoint='./checkpoint/', classifier_lr=0.01, encoder_lr=0.001, prompt_pool_lr=0.001, sgd_momentum=0.1, gmm_num_components=1, pull_constraint_coeff=0.1, classifier_epochs=10, encoder_epochs=10, prompt_pool_epochs=10, replay_s_e_e=256, replay_epochs=100, seed=2021, max_grad_norm=10, data_path='./datasets', bert_path='bert-base-uncased', cov_mat=True, max_num_models=10, sample_freq=5, prompt_length=1, prompt_embed_dim=768, prompt_pool_size=80, prompt_top_k=8, prompt_init='uniform', prompt_key_init='uniform', drop_p=0.1, gradient_accumulation_steps=4, total_round=6, drop_out=0.5, use_gpu=True, hidden_size=768, rank_lora=8, bge_model='BAAI/bge-m3', description_path='./description/all-2.json', type_similar='colbert')


In [105]:
import pickle
import random
import json, os
from transformers import BertTokenizer
import numpy as np  


def get_tokenizer(args):
    """Load the BERT tokenizer at ``args.bert_path`` with the entity-marker tokens registered.

    Args:
        args: Configuration object; only ``args.bert_path`` is read.

    Returns:
        The tokenizer returned by ``BertTokenizer.from_pretrained``, created with
        ``[E11]``, ``[E12]``, ``[E21]`` and ``[E22]`` as additional special tokens.
    """
    entity_markers = ["[E11]", "[E12]", "[E21]", "[E22]"]
    return BertTokenizer.from_pretrained(
        args.bert_path,
        additional_special_tokens=entity_markers,
    )


class data_sampler(object):
    """Iterate over relation-extraction data one task (group of relations) at a time.

    On construction the sampler fills dataset paths/sizes on ``args``, loads the
    relation vocabulary, shuffles the relation indices, and loads (or builds and
    caches) tokenized train/valid/test splits indexed by relation id. Each
    ``next()`` call returns the data for the next ``args.rel_per_task`` relations.
    """

    def __init__(self, args, seed=None):
        """Set up paths, tokenizer, relation vocabulary, task order and datasets.

        Args:
            args: Configuration namespace. ``dataname``, ``seed``, ``data_path`` and
                ``rel_per_task`` are read here; ``set_path`` writes the dataset-specific
                fields (``data_file``, ``relation_file``, ``num_of_*``) onto it.
            seed: Optional seed for Python's ``random`` module; it controls the
                relation shuffle and the per-relation sample shuffle in ``_read_data``.
        """
        self.set_path(args)
        self.args = args

        # Cache location for the tokenized splits, relative to the working directory.
        # NOTE(review): the cache file name uses ``args.seed`` while the shuffling uses
        # the ``seed`` argument; an existing cache is reused even if the two differ.
        file_name = "{}.pkl".format("-".join([str(x) for x in [args.dataname, args.seed]]))
        mid_dir = os.path.join("datasets", "_process_path")
        # makedirs(exist_ok=True) avoids the check-then-create race of exists() + mkdir().
        os.makedirs(mid_dir, exist_ok=True)
        self.save_data_path = os.path.join(mid_dir, file_name)

        # import tokenizer
        self.tokenizer = get_tokenizer(args)

        # read relation data
        self.id2rel, self.rel2id = self._read_relations(args.relation_file)

        # random sampling (stores the seed and builds self.shuffle_index)
        self.set_seed(seed)

        # regenerate data
        self.training_dataset, self.valid_dataset, self.test_dataset = self._read_data(self.args.data_file)

        # generate the task number
        self.batch = 0
        self.task_length = len(self.id2rel) // self.args.rel_per_task

        # record relations
        self.seen_relations = []
        self.history_test_data = {}

        # Reload the relation list from ``args.data_path`` and rebuild the reverse map.
        if args.dataname in ["FewRel"]:
            relation_path = os.path.join(args.data_path, "id2rel.json")
        else:
            relation_path = os.path.join(args.data_path, "id2rel_tacred.json")
        self.id2rel = self._load_json(relation_path)
        self.rel2id = {label: idx for idx, label in enumerate(self.id2rel)}

    @staticmethod
    def _load_json(path):
        """Read a UTF-8 JSON file and return the parsed object, closing the file afterwards."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def set_path(self, args):
        """Write dataset-specific file paths and split sizes onto ``args``.

        Only ``"FewRel"`` and ``"TACRED"`` are handled; for any other
        ``args.dataname`` nothing is set and ``args`` is left untouched.
        """
        use_marker = ""
        if args.dataname in ["FewRel"]:
            args.data_file = os.path.join(args.data_path, "data_with{}_marker.json".format(use_marker))
            args.relation_file = os.path.join(args.data_path, "id2rel.json")
            args.num_of_relation = 80
            args.num_of_train = 420
            args.num_of_val = 140
            args.num_of_test = 140

        elif args.dataname in ["TACRED"]:
            args.data_file = os.path.join(args.data_path, "data_with{}_marker_tacred.json".format(use_marker))
            args.relation_file = os.path.join(args.data_path, "id2rel_tacred.json")
            args.num_of_relation = 40
            args.num_of_train = 420
            args.num_of_val = 140
            args.num_of_test = 140

    def set_seed(self, seed):
        """Store ``seed``, seed ``random`` when it is not None, and reshuffle the relation order.

        ``self.shuffle_index`` becomes the argsort of a shuffled ``range(len(self.id2rel))``.
        The already-loaded datasets are not rebuilt.
        """
        self.seed = seed
        if self.seed is not None:
            random.seed(self.seed)
        self.shuffle_index = list(range(len(self.id2rel)))
        random.shuffle(self.shuffle_index)
        self.shuffle_index = np.argsort(self.shuffle_index)

    def __iter__(self):
        """Return the sampler itself; iteration state lives in ``self.batch``."""
        return self

    def __next__(self):
        """Return the data of the next task.

        Returns:
            A 6-tuple ``(train, valid, test, current_relations, history_test_data,
            seen_relations)``. The first three map relation name to that relation's
            sample list for the current task; the last two accumulate over all tasks
            yielded so far and are the sampler's own (shared, mutable) objects.

        Raises:
            StopIteration: once ``self.task_length`` tasks have been produced.
        """
        if self.batch == self.task_length:
            raise StopIteration()

        start = self.args.rel_per_task * self.batch
        indexs = self.shuffle_index[start : start + self.args.rel_per_task]
        self.batch += 1

        current_relations = []
        cur_training_data = {}
        cur_valid_data = {}
        cur_test_data = {}

        for index in indexs:
            rel_name = self.id2rel[index]
            current_relations.append(rel_name)
            self.seen_relations.append(rel_name)
            cur_training_data[rel_name] = self.training_dataset[index]
            cur_valid_data[rel_name] = self.valid_dataset[index]
            cur_test_data[rel_name] = self.test_dataset[index]
            self.history_test_data[rel_name] = self.test_dataset[index]

        return cur_training_data, cur_valid_data, cur_test_data, current_relations, self.history_test_data, self.seen_relations

    def _read_data(self, file):
        """Load the tokenized splits from the cache, or build them from ``file`` and cache them.

        Args:
            file: Path to a JSON file mapping relation name to a list of samples; each
                sample provides ``"relation"`` and ``"tokens"`` (a list of strings).

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``: three lists of length
            ``args.num_of_relation``, indexed by relation id, each holding dicts with
            keys ``"relation"`` (id), ``"text"`` and ``"tokens"`` (encoded ids).
        """
        if os.path.isfile(self.save_data_path):
            # NOTE(review): pickle.load can execute arbitrary code; only load caches
            # that this code produced itself.
            with open(self.save_data_path, "rb") as f:
                train_dataset, val_dataset, test_dataset = pickle.load(f)
            return train_dataset, val_dataset, test_dataset

        data = self._load_json(file)
        num_relations = self.args.num_of_relation
        train_dataset = [[] for _ in range(num_relations)]
        val_dataset = [[] for _ in range(num_relations)]
        test_dataset = [[] for _ in range(num_relations)]

        # Prefer ``task_name`` (as before) but fall back to ``dataname``, which the rest
        # of the class keys on, so a namespace without ``task_name`` still works.
        task_name = getattr(self.args, "task_name", None)
        if task_name is None:
            task_name = getattr(self.args, "dataname", None)
        fixed_size_split = task_name == "FewRel"

        for relation, rel_samples in data.items():
            # Re-seeding per relation makes each relation's shuffle independent of the others.
            if self.seed is not None:
                random.seed(self.seed)
            random.shuffle(rel_samples)
            test_count = 0
            train_count = 0
            for i, sample in enumerate(rel_samples):
                text = " ".join(sample["tokens"])
                tokenized_sample = {
                    "relation": self.rel2id[sample["relation"]],
                    "text": text,
                    "tokens": self.tokenizer.encode(
                        text, padding="max_length", truncation=True, max_length=self.args.max_length
                    ),
                }
                rel_id = self.rel2id[relation]

                if fixed_size_split:
                    # First num_of_train samples -> train, next num_of_val -> valid, rest -> test.
                    if i < self.args.num_of_train:
                        train_dataset[rel_id].append(tokenized_sample)
                    elif i < self.args.num_of_train + self.args.num_of_val:
                        val_dataset[rel_id].append(tokenized_sample)
                    else:
                        test_dataset[rel_id].append(tokenized_sample)
                else:
                    # First fifth of the samples -> test (the `<= 40` check admits at most
                    # 41 of them), everything else -> train, stopping once 320 training
                    # samples are collected. No validation split is produced here.
                    if i < len(rel_samples) // 5 and test_count <= 40:
                        test_count += 1
                        test_dataset[rel_id].append(tokenized_sample)
                    else:
                        train_count += 1
                        train_dataset[rel_id].append(tokenized_sample)
                        if train_count >= 320:
                            break

        with open(self.save_data_path, "wb") as f:
            pickle.dump((train_dataset, val_dataset, test_dataset), f)
        return train_dataset, val_dataset, test_dataset

    def _read_relations(self, file):
        """Load the relation list from ``file`` and build the reverse lookup.

        Returns:
            ``(id2rel, rel2id)`` where ``id2rel`` is the parsed JSON sequence and
            ``rel2id`` maps each of its entries to its position.
        """
        id2rel = self._load_json(file)
        rel2id = {name: idx for idx, name in enumerate(id2rel)}
        return id2rel, rel2id


In [106]:
import json

# NOTE(review): absolute, machine-specific path; the notebook only runs on a machine
# with this exact directory layout. Consider deriving it from args.data_path.
file = "/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/datasets/data_with_marker_tacred.json"
# Parse the raw marker-annotated data; the context manager closes the file handle
# as soon as parsing is done instead of leaving it to the garbage collector.
with open(file, "r", encoding="utf-8") as f:
    data = json.load(f)

In [107]:
import random

# Seed Python's global `random` module from the configured seed, then echo the
# seed so the value used for this run is visible in the cell output.
random.seed(args.seed)
print(args.seed)

2021


In [108]:
# Build the task sampler; the seed is passed to data_sampler, which uses it to seed
# `random` before shuffling relations into tasks.
# NOTE: this rebinds `data`, which an earlier cell used for the raw JSON dict.
data = data_sampler(args, seed=2021)
list_data = []

# Each iteration yields one task: per-relation train/valid/test dicts, the task's
# relation names, the accumulated test data and the accumulated relation names.
for steps, (training_data, valid_data, test_data, current_relations, historic_test_data, seen_relations) in enumerate(data):
    
    task_x = []
    # Disabled scratch code: would collect a {'relation', 'text'} record for every
    # training sample of the task into task_x and append that to list_data.
    # print(current_relations)
    # for relation in current_relations:
    #     for sample in training_data[relation]:
    #         task_x.append({
    #             'relation': relation,
    #             'text': sample['text']
    #         })
    # list_data.append(task_x)
    
        # for item in training_data[relation]:
        #     if item['relation'] == 0:
        #         print(item)
    # if training_data[current_relations[0]][21]['relation'] == 0:
    # Spot check: print the training sample at index 21 of the task's first relation.
    # NOTE(review): the index is hard-coded; this raises IndexError if that relation
    # has fewer than 22 training samples.
    print(training_data[current_relations[0]][21])
    
    

{'relation': 3, 'text': '[E21] lancaster county [E22] coroner dr. [E11] g. gary kirchner [E12] said the summary citations , filed by a humane society officer last week , were unwarranted .', 'tokens': [101, 30524, 10237, 2221, 30525, 22896, 2852, 1012, 30522, 1043, 1012, 5639, 11382, 11140, 3678, 30523, 2056, 1996, 12654, 22921, 1010, 6406, 2011, 1037, 23369, 2554, 2961, 2197, 2733, 1010, 2020, 4895, 9028, 17884, 2098, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
{'relation': 32, 'text': '[E11] vladimir ladyzhenskiy [E12] of russia died after she suffered [E21] a shock [E22] in the final of the spa world championship in heinola , a southern city of finland , on saturday .', 'tokens': [101, 30522, 8748, 3203, 27922, 6132, 3211, 2100, 30523, 1997, 3607

In [79]:
# list_description = []
# NOTE(review): `list_description` is only initialised in the commented-out line above,
# so the cell that appends `x` to it relies on kernel state from an earlier session and
# will fail with NameError after Restart & Run All — TODO confirm how it should be loaded.


# One hand-written description record: 'relation' is the relation name and 'text'
# is the free-text description stored for it.
x = {
    'relation': 'per:religion',
    'text': """The relation "per:religion" identifies and associates an individual with their religious affiliation or beliefs. In the given examples, this relation is used to link people to their respective religions, showcasing how this information is relevant or contextual within the narrative.
the "per:religion" relation is crucial for understanding how an individual's religious identity can influence, explain, or contextualize their actions, interactions, and roles within various narratives, whether these are legal, social, or political.
"""
}


In [80]:
import json

# Add the record built in the previous cell and report the new total.
# NOTE(review): re-running this cell appends the same record again.
list_description.append(x)
print(len(list_description))

# Persist all descriptions. The context manager guarantees the file is flushed and
# closed; UTF-8 is set explicitly because ensure_ascii=False writes non-ASCII
# characters verbatim, which would otherwise depend on the platform default encoding.
with open('./description.json', 'w', encoding='utf-8') as f:
    json.dump(list_description, f, ensure_ascii=False)

40


In [98]:
# Print the description text of every stored record for the 'per:charges' relation.
for item in list_description:
    if item['relation'] != 'per:charges':
        continue
    print(item['text'])

The relation name "per:charges" refers to a specific type of relationship between a person (E11, E12) and the charges or accusations (E21, E22) made against that person. This relationship is extracted from textual data where an individual's legal or criminal charges are explicitly mentioned. The examples provided illustrate various instances of how these charges can be represented in text, highlighting the diversity in the nature of the charges and the contexts in which they are presented.
The "per:charges" relation captures the legal or criminal accusations made against individuals. This relation is vital for extracting information related to legal proceedings, criminal activities, and judicial sentences from textual sources. It can include a wide array of charges, from financial crimes and violent offenses to politically motivated actions. Additionally, the relation can provide insights into the legal characterizations of the actions (e.g., hate crimes), offering a deeper understandi

In [95]:
import json 

def _load_description_file(path):
    """Parse one description JSON file, closing the file handle when done."""
    # UTF-8 is explicit because the description files are written with
    # ensure_ascii=False and may therefore contain non-ASCII characters.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# NOTE(review): absolute, machine-specific paths; these only resolve on the author's machine.
des1 = _load_description_file('/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/description/all.json')
des2 = _load_description_file('/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/description/new_description_2.json')
des3 = _load_description_file('/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/description/new_description_3.json')
des4 = _load_description_file('/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/description/new_description_4.json')

In [101]:
# des_all = des1 + des2 + des3 + des4
# len(des_all)

# json.dump(des_all, open('/home/luungoc/Thesis - 2023.2/Thesis_NgocLT/description/all.json', 'w'), ensure_ascii=False)


# Print the description text of every record in des1 for the 'per:city_of_birth' relation.
for item in des1:
    if item['relation'] != 'per:city_of_birth':
        continue
    print(item['text'])

The relation name "per:city_of_birth" refers to the specific type of relationship between a person (E11, E12) and the city or location where that person was born (E21, E22). This relationship is identified from textual contexts that explicitly mention the birthplace of an individual. The examples provided showcase instances where the city of birth is integral to the identity or background of the person mentioned, demonstrating the variety of ways this information can be presented in text.

The "per:city_of_birth" relation is fundamental for understanding the geographical and cultural origins of individuals. It connects people to specific places, offering insights into their early life influences, cultural heritage, and the socio-economic conditions of their birthplaces. This relationship is crucial for biographical studies, historical research, and sociological analyses, as it provides a foundational aspect of a person's identity. Identifying the city of birth can also be important in 

In [None]:
# Ten groups of four relation names each — presumably the per-task relation groups
# printed via `print(current_relations)` for TACRED; TODO confirm the source.
# NOTE: as code these are bare list expressions that are not assigned to anything;
# when the cell runs, only the last list is displayed as the cell result.
['per:cities_of_residence', 'per:other_family', 'org:founded', 'per:origin']
['per:cause_of_death', 'org:dissolved', 'per:employee_of', 'org:member_of']
['per:parents', 'per:alternate_names', 'org:top_members/employees', 'per:siblings']
['per:stateorprovinces_of_residence', 'org:alternate_names', 'org:country_of_headquarters', 'per:country_of_birth']
['per:children', 'per:date_of_birth', 'org:founded_by', 'per:countries_of_residence']
['per:schools_attended', 'org:subsidiaries', 'org:members', 'org:political/religious_affiliation']
['org:stateorprovince_of_headquarters', 'per:charges', 'per:stateorprovince_of_birth', 'per:title']
['per:stateorprovince_of_death', 'org:number_of_employees/members', 'per:city_of_death', 'per:spouse']
['org:website', 'per:age', 'per:city_of_birth', 'per:date_of_death']
['org:shareholders', 'org:parents', 'org:city_of_headquarters', 'per:religion']