In [1]:
import pandas as pd
import torch
import dgl
import pickle
from utils import HGT
import random

In [2]:
graph = dgl.load_graphs('training_data/graph.dgl')
graph = graph[0][0]

with open('training_data/train.obj', 'rb') as fp:
	train = pickle.load(fp)

with open('training_data/val.obj', 'rb') as fp:
	val = pickle.load(fp)

with open('training_data/test.obj', 'rb') as fp:
	test = pickle.load(fp)

In [3]:
graph.nodes('business').shape[0]

150243

In [4]:
edges = {}
for canonical_etype in graph.canonical_etypes:
    edges[canonical_etype] = graph.edges(etype=canonical_etype)

edges[('category', 'category_to_business', 'business')] = (graph.edges(etype='business_has_category')[1], graph.edges(etype='business_has_category')[0])
edges[('business', 'business_to_review', 'review')] = (graph.edges(etype='review_to_business')[1], graph.edges(etype='review_to_business')[0])
edges[('business', 'business_to_tip', 'tip')] = (graph.edges(etype='tip_to_business')[1], graph.edges(etype='tip_to_business')[0])
edges[('review', 'review_to_user', 'user')] = (graph.edges(etype='user_to_review')[1], graph.edges(etype='user_to_review')[0])
edges[('tip', 'tip_to_user', 'user')] = (graph.edges(etype='user_to_tip')[1], graph.edges(etype='user_to_tip')[0])

num_nodes_dict = {} 
for ntype in graph.ntypes:
    num_nodes_dict[ntype] = graph.nodes(ntype).shape[0]

g = dgl.heterograph(edges, num_nodes_dict = num_nodes_dict)
g.ndata['feat'] = {k: torch.tensor(v, dtype=torch.float32) for k, v in graph.ndata['feat'].items() }
del graph

  g.ndata['feat'] = {k: torch.tensor(v, dtype=torch.float32) for k, v in graph.ndata['feat'].items() }


In [5]:
g.num_nodes

<bound method DGLGraph.num_nodes of Graph(num_nodes={'business': 150243, 'category': 1311, 'review': 6339837, 'tip': 908878, 'user': 1987897},
      num_edges={('business', 'business_has_category', 'category'): 668592, ('business', 'business_to_review', 'review'): 6339837, ('business', 'business_to_tip', 'tip'): 908878, ('category', 'category_to_business', 'business'): 668592, ('review', 'review_to_business', 'business'): 6339837, ('review', 'review_to_user', 'user'): 6339837, ('tip', 'tip_to_business', 'business'): 908878, ('tip', 'tip_to_user', 'user'): 908878, ('user', 'user_to_review', 'review'): 6339837, ('user', 'user_to_tip', 'tip'): 908878, ('user', 'user_to_user', 'user'): 437928},
      metagraph=[('business', 'category', 'business_has_category'), ('business', 'review', 'business_to_review'), ('business', 'tip', 'business_to_tip'), ('category', 'business', 'category_to_business'), ('review', 'business', 'review_to_business'), ('review', 'user', 'review_to_user'), ('tip', 'bus

In [6]:
node_dict = { ntype: g.ntypes.index(ntype) for ntype in g.ntypes }
edge_dict = { canonical_etype: g.canonical_etypes.index(canonical_etype) for canonical_etype in g.canonical_etypes }
feature_dim_dict = { ntype: g.ndata['feat'][ntype].shape[1] for ntype in g.ntypes }

In [7]:
node_dict

{'business': 0, 'category': 1, 'review': 2, 'tip': 3, 'user': 4}

In [8]:
model = HGT(node_dict, edge_dict, feature_dim_dict, n_hid=256, n_out=128, n_layers=4, n_heads=8, use_norm=False)
opt = torch.optim.AdamW(model.parameters(), 1e-4)
sampler = dgl.dataloading.NeighborSampler([24, 24, 24, 24])
criterion = torch.nn.MarginRankingLoss(margin=0.1)
# dgl.dataloading.NeighborSampler([
#     {('user', 'follows', 'user'): 5,
#      ('user', 'plays', 'game'): 4,
#      ('game', 'played-by', 'user'): 3}] * 3)

In [9]:
train_pos_ids = list(range(train['pos'][0].shape[0]))
train_neg_ids = list(range(train['neg'][0].shape[0]))

In [10]:
len(train_neg_ids)

135108

In [11]:
g.ndata['feat']['category'].dtype

torch.float32

In [28]:
def predict(g, model, pos_ids, neg_ids, relation_tuple, sampler, batch_size):
    pos_users = torch.index_select(relation_tuple['pos'][0], 0, torch.tensor(pos_ids))
    pos_users_unique, pos_users_inverse = torch.unique(pos_users, return_inverse=True)
    pos_block_user = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
        g, {'user': pos_users_unique}, sampler,
        batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)][0]

    pos_business = torch.index_select(relation_tuple['pos'][1], 0, torch.tensor(pos_ids))
    pos_business_unique, pos_business_inverse = torch.unique(pos_business, return_inverse=True)
    pos_block_business = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
        g, {'business': pos_business_unique }, sampler,
        batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)][0]
    
    neg_users = torch.index_select(relation_tuple['neg'][0], 0, torch.tensor(neg_ids))
    neg_users_unique, neg_users_inverse = torch.unique(neg_users, return_inverse=True)
    neg_block_user = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
        g, {'user': neg_users_unique }, sampler,
        batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)][0]
    
    neg_business = torch.index_select(relation_tuple['neg'][1], 0, torch.tensor(neg_ids))
    neg_business_unique, neg_business_inverse = torch.unique(neg_business, return_inverse=True)
    neg_block_business = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
        g, {'business': neg_business_unique }, sampler,
        batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)][0]

    pos_user_logits = torch.index_select(model(pos_block_user, 'user'), 0, pos_users_inverse)
    pos_business_logits = torch.index_select(model(pos_block_business, 'business'), 0, pos_business_inverse)
    neg_user_logits = torch.index_select(model(neg_block_user, 'user'), 0, neg_users_inverse)
    neg_business_logits = torch.index_select(model(neg_block_business, 'business'), 0, neg_business_inverse)
    return pos_user_logits, pos_business_logits, neg_user_logits, neg_business_logits

In [30]:
def split(list_a, chunk_size):
    for i in range(0, len(list_a), chunk_size):
        yield list_a[i:i + chunk_size]

batch_size = 64
for epoch in range(1):
    model.train()
    random.shuffle(train_pos_ids)
    random.shuffle(train_neg_ids)
    for batch in split(list(zip(train_pos_ids, train_neg_ids)), batch_size):
        opt.zero_grad()
        pos_ids, neg_ids = list(zip(*batch))
        pos_user_logits, pos_business_logits, neg_user_logits, neg_business_logits = predict(g, model, pos_ids, neg_ids, train, sampler, batch_size)
        pos_score = torch.bmm(pos_user_logits.view(batch_size, 1, model.n_out), pos_business_logits.view(batch_size, model.n_out, 1)).squeeze()
        neg_score = torch.bmm(neg_user_logits.view(batch_size, 1, model.n_out), neg_business_logits.view(batch_size, model.n_out, 1)).squeeze()
        loss = criterion(pos_score, neg_score, torch.ones(batch_size))
        loss.backward()
        opt.step()
        print(loss.item())
    break

0.0011257501319050789
0.0009124502539634705
0.0009070197120308876
0.0006720079109072685
0.0006000311113893986
0.0002428349107503891
0.00031552277505397797
0.00036688148975372314
0.0003705797716975212
0.0003163143992424011
0.0003445069305598736
0.000242678914219141
0.0003475630655884743
0.00017621507868170738
0.0003565829247236252
0.00036633387207984924
0.0003993348218500614
0.00033133476972579956


KeyboardInterrupt: 