In [1]:
from utils.config import get_config
from utils.dataloader import DataLoader
from tqdm import tqdm
import time

In [2]:
# Load the configuration for the 'kkbox' dataset and build the project DataLoader from it.
cfg = get_config('kkbox')
dataloader = DataLoader(cfg)
# Presumably loads the id dictionaries named in cfg (entity2id, id2type, ...) -- TODO confirm
# in utils.dataloader. Later cells read dataloader.id2type.
dataloader.load_dicts()


In [3]:
# Presumably fills dataloader.interactions, which a later cell reads -- confirm in utils.dataloader.
dataloader.load_interactions_directed()

 52%|█████▏    | 3837890/7377420 [00:07<00:06, 526687.27it/s]


In [4]:
# Presumably fills dataloader.kg_triples, which a later cell reads -- confirm in utils.dataloader.
dataloader.load_kg_triples_directed()

  0%|          | 11388/4593666 [00:00<00:14, 313610.32it/s]


In [5]:
kg_triples = dataloader.kg_triples
# Keep only the interaction records whose third field equals 1 and rewrite each
# one as a (first field, 'u', second field) triple, i.e. the same
# (head, relation, tail) layout the knowledge-graph triples are indexed with below.
# Presumably the raw records are (user, item, label) -- confirm in utils.dataloader.
interactions = [
    (record[0], 'u', record[1])
    for record in dataloader.interactions
    if record[2] == 1
]

In [6]:
# Display the loaded configuration dict for reference.
cfg

{'dataset': 'kkbox',
 'dataset_folder': 'data/kkbox/',
 'entity2id': 'entity2id.dict',
 'id2entity': 'id2entity.dict',
 'id2type': 'id2type.dict',
 'id2user': 'id2user.dict',
 'triples': 'triples.txt',
 'type2id': [{'u': 0}, {'s': 1}, {'a': 2}, {'c': 4}],
 'type_id': ['user', 'song', 'genre_ids', 'artist_name', 'composer'],
 'type_id_simple': ['u', 's', 'a', 'c'],
 'user2id': 'user2id.dict',
 'user_num': 34403}

In [7]:
# Relation vocabulary used by this notebook (defined locally; it is not read from cfg).
# Ids 0..2 are the forward relations; the '_'-suffixed names at ids 3..5 are their
# reverse directions, so reverse id == forward id + 3.
id2type = ['u', 'a', 'g', 'u_', 'a_', 'g_']
type2id = {name: idx for idx, name in enumerate(id2type)}
# Human-readable labels for the three forward relations, in id order.
type_descript = ['user_interaction_with', 'artist', 'genre_id']

In [8]:
# Flatten interactions and knowledge-graph triples into three parallel lists
# (heads, relation ids, tails). Every triple is stored twice: once as given and
# once reversed with the '_'-suffixed relation, so the graph can be walked in
# both directions. Interactions are appended first, then the KG triples.
full_h = []
full_r = []
full_t = []
for triples in (interactions, kg_triples):
    for t in triples:
        head, rel, tail = t[0], t[1], t[2]
        # forward edge: head -rel-> tail
        full_h.append(head)
        full_r.append(type2id[rel])
        full_t.append(tail)
        # reverse edge: tail -rel_-> head
        full_h.append(tail)
        full_r.append(type2id[rel + '_'])
        full_t.append(head)

In [9]:
# Build the adjacency index: head2tail[head] == [relation_ids, tail_ids], two
# parallel lists holding every outgoing edge of `head` in insertion order.
total_size = len(full_h)
head2tail = {}
for i in tqdm(range(total_size), desc='head2tail', total=total_size):
    # setdefault creates the empty [rels, tails] pair the first time a head is seen
    out_rels, out_tails = head2tail.setdefault(full_h[i], [[], []])
    out_rels.append(full_r[i])
    out_tails.append(full_t[i])

head2tail: 100%|██████████| 4382224/4382224 [00:08<00:00, 510458.64it/s]


In [10]:
def _is_song(entity_id):
    """Return True when the dataloader's id2type lookup tags the entity as 's'."""
    return dataloader.id2type[entity_id] == 's'

# Entities tagged 's' that occur as edge heads / edge tails. Duplicates are kept:
# there is one entry per edge, not per distinct entity.
head_songs = list(filter(_is_song, full_h))
tail_songs = list(filter(_is_song, full_t))

In [11]:
# Smallest and largest entity id among the 's'-typed heads. Later cells iterate
# range(songs_starts_from, songs_ends_to + 1) as the set of song ids.
# NOTE(review): that assumes the ids in between are all songs and all present in head2tail -- confirm.
songs_starts_from = min(head_songs)
songs_ends_to =  max(head_songs)

In [12]:
# Inspect the adjacency entry of entity 0: [relation ids, tail ids].
# NOTE(review): this prints both lists in full; slicing would keep the output readable.
head2tail[0]

[[0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [16190,
  16304,
  15632,
  13567,
  14171,
  16299,
  13284,
  16072,
  14315,
  16257,
  16689,
  14217,
  13725,
  17266,
  16958,
  16505,
  12535,
  13955,
  14729,
  12425,
  16426,
  17142,
  13589,
  16345,
  15398,
  13745,
  13225,
  13368,
  13832,
  1

In [33]:
# Enumerate every two-hop path  head_song -rel_1-> mid -rel_2-> song  and group the
# paths by their starting song:
#     song2song[head_song] == [(rel_1, mid, rel_2, end_song), ...]
# A song that reaches no 's'-typed entity in two hops gets no key at all.
# Paths that return to the starting song itself are kept.
# NOTE(review): this is a full two-hop expansion of the graph; the recorded run
# was interrupted after 147 songs in about 36 s. Consider sampling or caching.
song2song = {}
unique_head_songs = set(head_songs)
# Size the progress bar from the collection actually iterated. The previous
# `songs_ends_to - songs_starts_from + 1` only matches when song ids are contiguous
# and every one of them appears as a head.
for head_song in tqdm(unique_head_songs, desc='song2song', total=len(unique_head_songs)):
    rels_1, tails_1 = head2tail[head_song]
    for rel_1, tail_1 in zip(rels_1, tails_1):
        # every tail is also a head because each triple was inserted with its reverse
        rels_2, tails_2 = head2tail[tail_1]
        for rel_2, tail_2 in zip(rels_2, tails_2):
            if dataloader.id2type[tail_2] == 's':
                song2song.setdefault(head_song, []).append((rel_1, tail_1, rel_2, tail_2))

song2song:   0%|          | 147/2296833 [00:36<164:08:52,  3.89it/s]

KeyboardInterrupt: 

In [39]:
def link_items_forward_2(item1:int, item2, head2tail=head2tail):
    """Find the two-hop links item1 -> mid -> item2 that share an intermediate entity.

    Args:
        item1: entity id the path starts from.
        item2: entity id the path ends at.
        head2tail: adjacency index mapping an entity id to
            ``[relation_ids, tail_ids]`` (two parallel lists).

    Returns:
        A list of ``(rel_from_item1, mid, rel_from_item2 + 3)`` tuples, one per
        edge of ``item1`` whose tail is also a tail of ``item2``, in the order
        of ``item1``'s edges.

    Raises:
        KeyError: if ``item1`` or ``item2`` is not a key of ``head2tail``.
    """
    rels_forward, mid_forward = head2tail[item1]
    rels_backward, mid_backward = head2tail[item2]

    # mid -> relation seen from item2. When item2 reaches the same mid through
    # several relations only the last one is kept.
    backward_rel_of = {}
    for idx in range(len(rels_backward)):
        backward_rel_of[mid_backward[idx]] = rels_backward[idx]

    # `+ 3` turns item2's outgoing relation id into the id of the opposite
    # direction under the notebook's vocabulary ('u' -> 'u_', 'a' -> 'a_', ...).
    # NOTE(review): for relation ids that are already >= 3 this yields 6..8,
    # outside the 0..5 vocabulary -- confirm whether that case can occur here.
    return [
        (rels_forward[pos], mid, backward_rel_of[mid] + 3)
        for pos, mid in enumerate(mid_forward)
        if mid in backward_rel_of
    ]

# Smoke test on two entities that both appear among the tails of entity 0 in the output above.
path = link_items_forward_2(16304, 16190)

In [None]:
# (placeholder cell) The unfinished bare `def` that stood here was a SyntaxError and
# stopped Restart & Run All; add the intended function here or delete the cell.

In [None]:
def link_items_backward_2(item1: int, item2, head2tail=head2tail):
    """Find the two-hop links item2 -> mid -> item1 that share an intermediate entity.

    Unlike ``link_items_forward_2`` this keeps every matching pair of edges:
    if ``item1`` reaches the same intermediate through several relations, one
    tuple is produced for each of them.

    Args:
        item1: entity id whose outgoing edges are treated as the second hop
            (taken in reverse direction).
        item2: entity id whose outgoing edges are the first hop.
        head2tail: adjacency index mapping an entity id to
            ``[relation_ids, tail_ids]`` (two parallel lists).

    Returns:
        A list of ``(rel_from_item2, mid, rel_from_item1 + 3)`` tuples, ordered
        by ``item2``'s edges and, within one edge, by ``item1``'s edges.

    Raises:
        KeyError: if ``item1`` or ``item2`` is not a key of ``head2tail``.
    """
    # The annotation used to be `item1:end`; `end` is not defined anywhere, so the
    # `def` statement itself raised NameError.
    rels_backward, mid_backward = head2tail[item1]
    rels_forward, mid_forward = head2tail[item2]

    # Index item1's relations by intermediate entity once, instead of rescanning
    # mid_backward for every forward edge. Order within each list is preserved.
    backward_rels_by_mid = {}
    for rel_b, mid in zip(rels_backward, mid_backward):
        backward_rels_by_mid.setdefault(mid, []).append(rel_b)

    inter_links = []
    for rel_f, mid in zip(rels_forward, mid_forward):
        for rel_b in backward_rels_by_mid.get(mid, ()):
            # `+ 3` maps a forward relation id to its reverse under the notebook's
            # vocabulary. NOTE(review): ids already >= 3 would leave the 0..5 range -- confirm.
            inter_links.append((rel_f, mid, rel_b + 3))
    return inter_links

In [None]:
def link_items_forward_4(item1, item2, head2tail=head2tail):
    """Naive all-pairs scan for two-hop links item2 -> mid -> item1.

    Compares every outgoing edge of ``item2`` with every outgoing edge of
    ``item1`` and emits one tuple per pair that ends in the same entity.

    Args:
        item1: entity id whose edges supply the second (reversed) hop.
        item2: entity id whose edges supply the first hop.
        head2tail: adjacency index mapping an entity id to
            ``[relation_ids, tail_ids]`` (two parallel lists).

    Returns:
        A list of ``(rel_from_item2, mid, rel_from_item1 + 3)`` tuples, ordered
        by ``item2``'s edges, then by ``item1``'s edges.
    """
    rels_backward, mid_backward = head2tail[item1]
    rels_forward, mid_forward = head2tail[item2]
    return [
        (rels_forward[fwd_pos], fwd_mid, rels_backward[bwd_pos] + 3)
        for fwd_pos, fwd_mid in enumerate(mid_forward)
        for bwd_pos, bwd_mid in enumerate(mid_backward)
        if fwd_mid == bwd_mid
    ]

In [33]:
# Number of outgoing edges (tails) recorded for entity 12163.
len(head2tail[12163][1])

94

In [38]:
# Wall-clock cost of linking one fixed entity (12123) against every id in the
# song id span with the dict-based two-hop lookup.
start = time.time()
path = []
for song_id in range(songs_starts_from, songs_ends_to + 1):
    path += link_items_forward_2(12123, song_id)
print(time.time() - start)

0.761254072189331


In [27]:
# Width of the song id span (largest minus smallest 's'-typed head id).
songs_ends_to-songs_starts_from

5660

In [43]:
# Number of two-hop song-to-song paths recorded for entity 34403.
# NOTE(review): depends on song2song from a cell whose recorded run was interrupted.
len(song2song[34403])

14357

In [21]:
def find_path(user_id, item_id, max_len=1, head2tail=head2tail, init=False):
    """Depth-first enumeration of the paths from ``user_id`` to ``item_id``.

    A path is a list of ``(head, relation_id, tail)`` edges. A branch stops as
    soon as it reaches ``item_id``; it is never extended through the target.
    Nodes may be revisited (there is no visited set), so the search grows
    quickly with ``max_len``.

    Args:
        user_id: entity id the search starts from.
        item_id: entity id every returned path must end at.
        max_len: maximum number of edges in a path; ``0`` yields no paths.
        head2tail: adjacency index mapping an entity id to
            ``[relation_ids, tail_ids]`` (two parallel lists).
        init: when true, wrap this level's edge loop in a tqdm progress bar.
            The flag is not passed to recursive calls, so only the top level
            reports progress.

    Returns:
        A list of paths. It is empty when ``max_len`` is 0, when ``user_id``
        has no entry in ``head2tail``, or when no path exists.
    """
    if max_len == 0 or user_id not in head2tail:
        return []

    rels, tails = head2tail[user_id]
    positions = range(len(rels))
    if init:
        positions = tqdm(positions, desc='find_path', total=len(rels))

    found = []
    for pos in positions:
        edge = (user_id, rels[pos], tails[pos])
        if tails[pos] == item_id:
            found.append([edge])
            continue
        # not there yet: prepend this edge to every shorter path from the tail
        for rest in find_path(tails[pos], item_id, max_len - 1, head2tail):
            found.append([edge] + rest)
    return found

# Try the search on one sample interaction triple: paths of up to 5 edges from
# its head (pair[0]) to its tail (pair[2]), with the top-level progress bar on.
# NOTE(review): the recorded run was interrupted before finishing.
pair = interactions[10001]
path = find_path(pair[0], pair[2], 5, head2tail, True)




find_path:   0%|          | 0/660 [00:00<?, ?it/s][A[A

KeyboardInterrupt: 

In [20]:
# Show the sample interaction triple picked above.
pair

(11634, 'u', 16810)

In [140]:
# Number of interaction triples kept after the `== 1` filter.
len(interactions)

3714654

In [112]:
# Scratch check: `+` concatenates two lists of tuples (how find_path prepends an edge to a sub-path).
[(1,2)]+[(3,4)]

[(1, 2), (3, 4)]

In [110]:
# Scratch check: a comprehension over an empty list gives [], so an edge with no sub-paths adds nothing in find_path.
[[1]+ i for i in []]

[]

In [28]:
# Scratch check: a set can be iterated directly in a for loop.
for i in set([1,2]):
    print(i)

1
2
