In [1]:
import torch
print(torch.version.cuda)

11.7


  from .autonotebook import tqdm as notebook_tqdm


In [1]:
import numpy
import torch
import torch.nn as nn

import os
import random
import copy

from tqdm import tqdm
import prettytable

class GraphDataset:
    def __init__(self):
        self.data_path = os.path.abspath('')
        self.dataset_size = ''
        self.num_entities = 0
        self.num_relations = 0
        self.unique_entities = set()
        self.unique_relations = set()
        self.unique_triples = set()

    def load_data(self):
        # Load the data from the specified path
        pass

    def preprocess_data(self):
        # Preprocess the data for training
        pass

    def get_train_data(self):
        # Return training data
        pass

    def get_valid_data(self):
        # Return validation data
        pass

    def get_test_data(self):
        # Return test data
        pass

    # 일단 codex-S부터 분석 
    def load_codex_data(self, size, which):
        '''
        size : s, m, l
        which : train, valid, test
        '''

        self.dataset_size = size

        degree_per_entity = {}
        frequency_per_relation = {}

        file_path = os.path.join(os.path.abspath(''), 'triples', f'codex-{self.dataset_size}', f'{which}.txt')
        with open(file_path, 'r') as f:
            for line in f:
                h, r, t = line.strip().split()
                self.unique_entities.add(h)
                self.unique_entities.add(t)
                self.unique_relations.add(r)
                if h not in degree_per_entity:
                    degree_per_entity[h] = 0
                else:
                    degree_per_entity[h] += 1
                if t not in degree_per_entity:
                    degree_per_entity[t] = 0
                else:
                    degree_per_entity[t] += 1
                if r not in frequency_per_relation:
                    frequency_per_relation[r] = 0
                else:
                    frequency_per_relation[r] += 1
                if (h, r, t) not in self.unique_triples:
                    self.unique_triples.add((h, r, t))
        degree_per_entity = {k: v for k, v in sorted(degree_per_entity.items(), key=lambda item: item[1], reverse=True)}
        frequency_per_relation = {k: v for k, v in sorted(frequency_per_relation.items(), key=lambda item: item[1], reverse=True)}

        print(f'Unique entities: {len(self.unique_entities)}')
        print(f'Unique relations: {len(self.unique_relations)}')
        print(f'Degree per entity: {degree_per_entity}')
        print(f'Frequency per relation: {frequency_per_relation}')
        print(f'Average degree per entity: {sum(degree_per_entity.values()) / len(degree_per_entity)}')

    def construct_element_centric(self, element, num_snapshots):
        '''
        element : 'entity', 'relation', 'fact', 'hybrid'
        num_snapshots : Number of snapshots to be constructed
        '''
        sample_triples = copy.deepcopy(self.unique_triples)
        snapshot_entities = [set()]
        snapshot_relations = [set()]
        snapshot_triples = [set()]
        for i in tqdm(range(num_snapshots)):
            if i == 0:
                seed_triples = []
                for _ in range(10):
                    seed_triple = random.sample(sample_triples, 1)[0]
                    seed_triples.append(seed_triple)
                    sample_triples.remove(seed_triples[-1])
                    h, r, t = seed_triple
                    snapshot_entities[0].add(h)
                    snapshot_entities[0].add(t)
                    snapshot_relations[0].add(r)
                    snapshot_triples[0].add(seed_triple)
            while True:
                if element == 'entity' and len(snapshot_entities[i]) >= len(self.unique_entities) * (i + 1) / num_snapshots:
                    break
                elif element == 'relation' and len(snapshot_relations[i]) >= len(self.unique_relations) * (i + 1) / num_snapshots:
                    break
                elif element == 'fact' and len(snapshot_triples[i]) >= len(self.unique_triples) * (i + 1) / num_snapshots:
                    break
                sample_triple = random.sample(sample_triples, 1)[0]
                h, r, t = sample_triple
                if h not in snapshot_entities[i] and t not in snapshot_entities[i]:
                    continue
                snapshot_entities[i].add(h)
                snapshot_entities[i].add(t)
                snapshot_relations[i].add(r)
                snapshot_triples[i].add(sample_triple)
                sample_triples.remove(sample_triple)
            print(f'Snapshot {i}: {len(snapshot_entities[i])} entities, {len(snapshot_relations[i])} relations, {len(snapshot_triples[i])} triples')

            file_path = os.path.join(os.path.abspath(''), 'triples', f'codex-{self.dataset_size}', f'ENTITY_snapshot_{i}.txt')

            with open(file_path, 'w') as f:
                for h, r, t in snapshot_triples[i]:
                    f.write(f'{h}\t{r}\t{t}\n')
            i += 1
            snapshot_entities.append(copy.deepcopy(snapshot_entities[i - 1]))
            snapshot_relations.append(copy.deepcopy(snapshot_relations[i - 1]))
            snapshot_triples.append(copy.deepcopy(snapshot_triples[i - 1]))

        # print results
        t = prettytable.PrettyTable(['Snapshot', 'Entities', 'Relations', 'Triples'])
        for i in range(num_snapshots):
            t.add_row([i, len(snapshot_entities[i]), len(snapshot_relations[i]), len(snapshot_triples[i])])
        print(t)


In [2]:
g = GraphDataset()

In [3]:
g.load_codex_data('s', 'train')

Unique entities: 2034
Unique relations: 42
Degree per entity: {'Q30': 1007, 'Q1860': 731, 'Q36180': 573, 'Q33999': 559, 'Q17172850': 550, 'Q177220': 541, 'Q183': 479, 'Q639669': 439, 'Q36834': 413, 'Q488205': 378, 'Q10800557': 377, 'Q865': 353, 'Q159': 344, 'Q145': 337, 'Q142': 324, 'Q28389': 312, 'Q10798782': 302, 'Q148': 281, 'Q6607': 275, 'Q408': 253, 'Q49757': 246, 'Q5994': 241, 'Q150': 234, 'Q753110': 233, 'Q463303': 231, 'Q16': 230, 'Q183945': 224, 'Q37073': 217, 'Q38': 214, 'Q1930187': 213, 'Q855091': 213, 'Q188': 209, 'Q35': 209, 'Q3282637': 206, 'Q96': 206, 'Q668': 204, 'Q902': 202, 'Q28': 199, 'Q6625963': 196, 'Q1622272': 194, 'Q230': 193, 'Q486748': 188, 'Q1065': 182, 'Q403': 182, 'Q842490': 178, 'Q8475': 178, 'Q843': 174, 'Q41': 173, 'Q833': 169, 'Q801': 169, 'Q82955': 167, 'Q376150': 166, 'Q191384': 165, 'Q7809': 165, 'Q656801': 164, 'Q155': 164, 'Q4964182': 163, 'Q1043527': 162, 'Q17': 160, 'Q17495': 159, 'Q827525': 155, 'Q2405480': 153, 'Q2526255': 151, 'Q252': 150, 'Q21

In [4]:
g.construct_element_centric('entity', 5)

since Python 3.9 and will be removed in a subsequent version.
  seed_triple = random.sample(sample_triples, 1)[0]
since Python 3.9 and will be removed in a subsequent version.
  sample_triple = random.sample(sample_triples, 1)[0]
 20%|██        | 1/5 [00:01<00:04,  1.24s/it]

Snapshot 0: 407 entities, 18 relations, 490 triples


 40%|████      | 2/5 [00:01<00:02,  1.06it/s]

Snapshot 1: 814 entities, 28 relations, 1168 triples


 60%|██████    | 3/5 [00:02<00:01,  1.20it/s]

Snapshot 2: 1221 entities, 30 relations, 2023 triples


 80%|████████  | 4/5 [00:03<00:00,  1.06it/s]

Snapshot 3: 1628 entities, 36 relations, 3481 triples


100%|██████████| 5/5 [00:10<00:00,  2.09s/it]

Snapshot 4: 2034 entities, 40 relations, 12895 triples
+----------+----------+-----------+---------+
| Snapshot | Entities | Relations | Triples |
+----------+----------+-----------+---------+
|    0     |   407    |     18    |   490   |
|    1     |   814    |     28    |   1168  |
|    2     |   1221   |     30    |   2023  |
|    3     |   1628   |     36    |   3481  |
|    4     |   2034   |     40    |  12895  |
+----------+----------+-----------+---------+





In [5]:
g.construct_element_centric('relation', 5)

since Python 3.9 and will be removed in a subsequent version.
  seed_triple = random.sample(sample_triples, 1)[0]
since Python 3.9 and will be removed in a subsequent version.
  sample_triple = random.sample(sample_triples, 1)[0]
 20%|██        | 1/5 [00:00<00:00,  5.74it/s]

Snapshot 0: 31 entities, 9 relations, 21 triples


 40%|████      | 2/5 [00:00<00:01,  2.64it/s]

Snapshot 1: 140 entities, 17 relations, 140 triples


 60%|██████    | 3/5 [00:01<00:00,  2.64it/s]

Snapshot 2: 311 entities, 26 relations, 356 triples


 80%|████████  | 4/5 [00:02<00:00,  1.39it/s]

Snapshot 3: 1012 entities, 34 relations, 1614 triples


100%|██████████| 5/5 [00:20<00:00,  4.09s/it]

Snapshot 4: 2034 entities, 42 relations, 29279 triples
+----------+----------+-----------+---------+
| Snapshot | Entities | Relations | Triples |
+----------+----------+-----------+---------+
|    0     |    31    |     9     |    21   |
|    1     |   140    |     17    |   140   |
|    2     |   311    |     26    |   356   |
|    3     |   1012   |     34    |   1614  |
|    4     |   2034   |     42    |  29279  |
+----------+----------+-----------+---------+





In [6]:
g.construct_element_centric('fact', 5)

since Python 3.9 and will be removed in a subsequent version.
  seed_triple = random.sample(sample_triples, 1)[0]
since Python 3.9 and will be removed in a subsequent version.
  sample_triple = random.sample(sample_triples, 1)[0]
 20%|██        | 1/5 [00:05<00:23,  5.89s/it]

Snapshot 0: 1934 entities, 38 relations, 6578 triples


 40%|████      | 2/5 [00:10<00:15,  5.12s/it]

Snapshot 1: 2032 entities, 39 relations, 13156 triples


 60%|██████    | 3/5 [00:14<00:09,  4.76s/it]

Snapshot 2: 2034 entities, 42 relations, 19733 triples


 80%|████████  | 4/5 [00:18<00:04,  4.46s/it]

Snapshot 3: 2034 entities, 42 relations, 26311 triples


100%|██████████| 5/5 [00:22<00:00,  4.51s/it]

Snapshot 4: 2034 entities, 42 relations, 32888 triples
+----------+----------+-----------+---------+
| Snapshot | Entities | Relations | Triples |
+----------+----------+-----------+---------+
|    0     |   1934   |     38    |   6578  |
|    1     |   2032   |     39    |  13156  |
|    2     |   2034   |     42    |  19733  |
|    3     |   2034   |     42    |  26311  |
|    4     |   2034   |     42    |  32888  |
+----------+----------+-----------+---------+



