In [1]:
import pandas as pd
from collections import Counter
from typing import Dict, Tuple

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
# Load the triple table (S / P / O columns plus a combined SPO text column);
# index_col=0 turns the first CSV column into the row index.
df = pd.read_csv(
    filepath_or_buffer='./data/music/mo_triples.csv',
    index_col=0,
)

In [3]:
# Display the loaded triple table (columns S, P, O plus the combined "SPO" text column).
df

Unnamed: 0,S,P,O,SPO
0,j.0:Unit,rdf:type,owl:Class,"Unit , type , Class"
1,skmo:20080410_Mnet_엠카운트다운,skmo:rankBy,skmo:엠카운트다운,"20080410_Mnet_엠카운트다운 , rankBy , 엠카운트다운"
2,skmo:이주형,skmo:hasMusicActivity,skmo:composer,"이주형 , hasMusicActivity , composer"
3,skmo:T.O.P,foaf:name,T.O.P,"T.O.P , name , T.O.P"
4,skmo:뱅뱅뱅_앨범,skmo:hasTrack,skmo:Loser,"뱅뱅뱅_앨범 , hasTrack , Loser"
...,...,...,...,...
3383,skmo:한국,rdf:type,schema:Country,"한국 , type , Country"
3384,skmo:Carlos_Adaamick_Mendoza,rdfs:label,카를로스 아담믹 멘도자,"Carlos_Adaamick_Mendoza , label , 카를로스 아담믹 멘도자"
3385,skmo:소녀시대-앨범,skmo:hasTrack,skmo:7989_(강타&태연),"소녀시대-앨범 , hasTrack , 7989_(강타&태연)"
3386,skmo:trackReleaseType,rdfs:range,mo:Track,"trackReleaseType , range , Track"


In [4]:
# Pull the subject / predicate / object columns out as arrays: head, relation, tail.
heads = df['S'].values
relations = df['P'].values
tails = df['O'].values

In [5]:
# Sanity check: the three parallel arrays must all have the same length.
tuple(len(column) for column in (heads, relations, tails))

(3388, 3388, 3388)

In [6]:
# Combine the three parallel arrays into a list of [head, relation, tail] rows.
triples = [[heads[idx], relations[idx], tails[idx]] for idx in range(len(heads))]

In [7]:
# Number of triples built: one per DataFrame row.
len(triples)

3388

In [8]:
# Spot-check a single triple: [head, relation, tail].
triples[2]

['skmo:이주형', 'skmo:hasMusicActivity', 'skmo:composer']

In [9]:
# Vocabulary type: maps an entity/relation string to its integer id.
Mapping = Dict[str,int]
def create_mapping(triples: list) -> Tuple[Mapping,Mapping]:
    """Build entity and relation vocabularies from (head, relation, tail) triples.

    Ids are assigned by descending frequency, so the most common entity or
    relation gets id 0. Items with equal counts keep first-seen order, because
    ``Counter.most_common()`` sorts stably over insertion order.

    Args:
        triples: iterable of 3-item sequences ``(head, relation, tail)``.

    Returns:
        ``(entity2id, relation2id)``. Heads and tails share a single entity
        vocabulary; relations get their own.
    """
    entity_counter = Counter()
    relation_counter = Counter()
    for head, relation, tail in triples:
        # Heads and tails are both entities, so they are counted together.
        entity_counter.update((head, tail))
        relation_counter[relation] += 1
    entity2id = {entity: idx for idx, (entity, _) in enumerate(entity_counter.most_common())}
    relation2id = {rel: idx for idx, (rel, _) in enumerate(relation_counter.most_common())}
    return entity2id, relation2id

In [10]:
entity2id, relation2id = create_mapping(triples)

In [11]:
entity2id['skmo:AOA']

24

In [12]:
from sklearn.model_selection import train_test_split

# Fixed seed so the 70/30 split, and every id file written from it later in the
# notebook, is reproducible across kernel restarts.
SPLIT_SEED = 42
T_train, T_test = train_test_split(triples, test_size=0.3, shuffle=True, random_state=SPLIT_SEED)
len(T_train), len(T_test)

(2371, 1017)

In [13]:
# Split the 30% hold-out: its first int(20% of ALL triples) rows become test,
# and the remainder (~10% of all triples) becomes valid.
# NOTE(review): this rebinds T_test from itself, so re-running this cell alone
# shrinks both sets. Re-run the train_test_split cell first.
n_test = int(len(triples) * 0.2)
T_test, T_valid = T_test[:n_test], T_test[n_test:]

In [14]:
# Resulting test / valid sizes (about 20% / 10% of all triples).
len(T_test),len(T_valid)

(677, 340)

In [15]:
class GraphDataset(Dataset):
    """Map-style dataset that turns (head, relation, tail) string triples into id triples.

    Args:
        triple_data: indexable sequence of 3-item ``(head, relation, tail)`` rows.
        entity2id: entity string -> id lookup, shared by heads and tails.
        relation2id: relation string -> id lookup.
    """

    def __init__(self, triple_data, entity2id: Mapping, relation2id: Mapping):
        super().__init__()
        # Raw (head, relation, tail) rows; they are converted to ids lazily in __getitem__.
        self.data = triple_data
        self.entity2id = entity2id
        self.relation2id = relation2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        '''Fetch row ``index`` and return it as ``(head_id, relation_id, tail_id)``.'''
        head, relation, tail = self.data[index]
        return (
            self.to_idx(head, self.entity2id),
            self.to_idx(relation, self.relation2id),
            self.to_idx(tail, self.entity2id),
        )

    # Static method: callable directly on the class, no instance required.
    @staticmethod
    def to_idx(key: str, mapping: Mapping) -> int:
        """Return ``mapping[key]``, or ``len(mapping)`` when the key is missing.

        The fallback is a dedicated out-of-vocabulary id: when ids run 0..len-1
        it is one past the largest id, so an unknown key maps there instead of raising.
        """
        try:
            found = mapping[key]
        except KeyError:
            found = len(mapping)
        return found
        



In [16]:
# Alternative (disabled): build the vocabularies from the training split only.
# Left off because entity2id / relation2id were built from *all* triples, so
# valid/test entities get real ids instead of GraphDataset's OOV fallback (len(mapping)).
#entity2id, relation2id = create_mapping(T_train)

In [17]:
# Report vocabulary sizes: number of entities, then number of relations.
print('entity 총 길이:{}, relation 총 길이{}'.format(len(entity2id), len(relation2id)))

entity 총 길이:1754, relation 총 길이55


In [18]:
batch_size = 128
num_worker = 6  # DataLoader worker processes, shared by every loader below

# Only the training loader shuffles; valid/test keep a fixed order so that
# evaluation runs are comparable between epochs.
train_set = GraphDataset(T_train, entity2id=entity2id, relation2id=relation2id)
trainloader = DataLoader(train_set, batch_size=batch_size, num_workers=num_worker, shuffle=True)
valid_set = GraphDataset(T_valid, entity2id=entity2id, relation2id=relation2id)
validloader = DataLoader(valid_set, batch_size=batch_size, num_workers=num_worker, shuffle=False)
test_set = GraphDataset(T_test, entity2id=entity2id, relation2id=relation2id)
testloader = DataLoader(test_set, batch_size=batch_size, num_workers=num_worker, shuffle=False)

In [19]:
len(train_set)

2371

In [20]:
# OpenKE triple file: the triple count on the first line, then one
# "head tail relation" id line per triple (OpenKE's entity1 entity2 relation order).
with open('./data/music/train2id.txt', 'w', encoding='UTF-8') as out_file:
    out_file.write(f'{len(train_set)}\n')
    out_file.writelines(f'{h} {t} {r}\n' for h, r, t in train_set)

In [21]:
# Same OpenKE layout as train2id.txt: count line, then "head tail relation" id lines.
with open('./data/music/valid2id.txt', 'w', encoding='UTF-8') as out_file:
    out_file.write(f'{len(valid_set)}\n')
    out_file.writelines(f'{h} {t} {r}\n' for h, r, t in valid_set)

In [22]:
# Same OpenKE layout as train2id.txt: count line, then "head tail relation" id lines.
with open('./data/music/test2id.txt', 'w', encoding='UTF-8') as out_file:
    out_file.write(f'{len(test_set)}\n')
    out_file.writelines(f'{h} {t} {r}\n' for h, r, t in test_set)

In [22]:
def _write_id_mapping(path: str, mapping: dict) -> None:
    """Write an id file: a count line, then one "name<TAB>id" line per entry.

    Tab-separated (as in OpenKE's benchmark id files) because names can contain
    spaces. For example, literal objects such as rdfs:label values in the O column
    do, and a space-separated line would then be ambiguous to parse.
    """
    with open(path, 'w', encoding='UTF-8') as file:
        file.write(f'{len(mapping)}\n')
        for name, idx in mapping.items():
            file.write(f'{name}\t{idx}\n')


_write_id_mapping('./data/music/entity2id.txt', entity2id)
_write_id_mapping('./data/music/relation2id.txt', relation2id)

In [23]:
# Entity vocabulary size (the same count written on the first line of entity2id.txt).
len(entity2id)

1754