In [33]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import dgl.function as fn
import dgl

from tqdm.notebook import tqdm

In [8]:
data = pd.read_csv("/Users/rkim0927/Python/Data/VB/dating_suggestions.csv")

test_period_start, test_period_end = '2022-02-01', '2022-02-15'
train_interaction = data[(data['created_at'] > '2021-11-30') & (data['created_at'] < test_period_start)]
test_interaction = data[(data['created_at'] > test_period_start) & (data['created_at'] < test_period_end)]

unique_users = list(set(pd.concat([train_interaction["source_id"], train_interaction["user_id"]])))
test_interaction = test_interaction[(test_interaction["user_id"].isin(unique_users)) & (test_interaction["source_id"].isin(unique_users))]
print(len(unique_users))

user_to_idx, idx_to_user = {}, {}
for i in tqdm(range(len(unique_users))):
  user_to_idx[unique_users[i]] = i
  idx_to_user[i] = unique_users[i]

21392


  0%|          | 0/21392 [00:00<?, ?it/s]

In [22]:
embeddings_user = pd.read_csv('/Users/rkim0927/Python/Data/VB/Embeddings/embeddings_user_mf_rp.csv')
embeddings_item = pd.read_csv('/Users/rkim0927/Python/Data/VB/Embeddings/embeddings_item_mf_rp.csv')
embeddings_user.set_index("Unnamed: 0", drop = True, inplace = True)
embeddings_item.set_index("Unnamed: 0", drop = True, inplace = True)

user_feats, item_feats = [], []
for i in tqdm(range(len(unique_users))):
  u = unique_users[i]
  if u in embeddings_user.index:
    user_feats.append(embeddings_user.loc[u].values)
    item_feats.append(embeddings_item.loc[u].values)
  else:
    user_feats.append(np.ones(50))
    item_feats.append(np.ones(50))

user_feats = np.array(user_feats)
item_feats = np.array(item_feats)

  0%|          | 0/21392 [00:00<?, ?it/s]

In [16]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
        # self.pred = MLPPredictor(out_features, 30, 1)
    def forward(self, g, pos_g, neg_g, x1, x2, etype1, etype2):
        h_src = self.sage(g, x1)     
        h_dst = self.sage(g, x2)       
        return self.pred(pos_g, h_src, h_dst, etype1), self.pred(neg_g, h_src,h_dst, etype2)

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h_src, h_dst, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h_src'] = h_src['user']
            graph.ndata['h_dst'] = h_dst['user']
            graph.apply_edges(fn.u_dot_v('h_src', 'h_dst', 'score'), etype=etype)
            return graph.edges[etype].data['score']

def construct_negative_graph(graph, etype1, etype2, device, num_edges = 100000):
    src, dst = graph.edges(etype=etype1)
    indices = torch.LongTensor(np.random.choice(src.size(0), num_edges)).to(device)
    pos_graph = dgl.heterograph(
        {etype1: (src[indices], dst[indices])},
        num_nodes_dict={ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes})
    neg_src, neg_dst = graph.edges(etype=etype2)
    indices = torch.LongTensor(np.random.choice(neg_src.size(0), num_edges)).to(device)
    neg_graph = dgl.heterograph(
        {etype2: (neg_src[indices], neg_dst[indices])},
        num_nodes_dict={ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes})
    return pos_graph, neg_graph

def compute_loss(pos_score, neg_score, delta):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (delta - neg_score + pos_score).clamp(min=0).mean()

In [41]:
NODE_FEATURE_DIM = 50
HIDDEN_FEATURE_DIM = 100
EMBEDDING_DIM = 50
MLP_HIDDEN_DIM = 30
NUM_EPOCHS = 1000
k1, k2 = 100000, 30000
alpha = .8

# device = "mps" if torch.backends.mps.is_built() else "gpu" if torch.cuda.is_available() else "cpu"
device = "cpu"
print("Device: ", device)

Device:  cpu


In [42]:
torch.manual_seed(777)

train_accepted = train_interaction[train_interaction["accepted"] == 1]
train_rejected = train_interaction[train_interaction["accepted"] == 0]
acc_src, acc_dst = torch.LongTensor(train_accepted['user_id'].map(user_to_idx).values).to(device), torch.LongTensor(train_accepted['source_id'].map(user_to_idx).values).to(device)
rej_src, rej_dst = torch.LongTensor(train_rejected['user_id'].map(user_to_idx).values).to(device), torch.LongTensor(train_rejected['source_id'].map(user_to_idx).values).to(device)

graph = dgl.heterograph({
    ('user', 'accepts', 'user'): (acc_src, acc_dst),
    ('user', 'rejects', 'user'): (rej_src, rej_dst),
},  num_nodes_dict = {"user": len(idx_to_user)}).to(device)

# graph.nodes['user'].data['src_feature'] = torch.randn(graph.nodes("user").size(0), NODE_FEATURE_DIM).to(device)
# graph.nodes['user'].data['dst_feature'] = torch.randn(graph.nodes("user").size(0), NODE_FEATURE_DIM).to(device)

graph.nodes['user'].data['src_feature'] = torch.FloatTensor(user_feats).to(device)
graph.nodes['user'].data['dst_feature'] = torch.FloatTensor(item_feats).to(device)
graph

Graph(num_nodes={'user': 21392},
      num_edges={('user', 'accepts', 'user'): 834278, ('user', 'rejects', 'user'): 4153220},
      metagraph=[('user', 'user', 'accepts'), ('user', 'user', 'rejects')])

In [46]:
model = Model(NODE_FEATURE_DIM, HIDDEN_FEATURE_DIM, EMBEDDING_DIM, graph.etypes).to(device)
opt = torch.optim.Adam(model.parameters())
src_feats, dst_feats = {"user": graph.nodes['user'].data['src_feature']} , {"user": graph.nodes['user'].data['dst_feature']}
etype_1, etype_2 = ('user', 'accepts', 'user'),  ('user', 'rejects', 'user')

In [47]:
%%time
for epoch in range(NUM_EPOCHS):
    pos_graph, negative_graph = construct_negative_graph(graph, etype_1, etype_2, device, k1)
    pos_score, neg_score = model(graph, pos_graph,  negative_graph, src_feats, dst_feats, etype_1, etype_2)
    loss = compute_loss(pos_score, neg_score, 1)

    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 100 == 0:
      print(epoch+1, loss.item())

1 4.540107727050781
101 0.5375457406044006
201 0.4386550486087799
301 0.40113264322280884
401 0.37233319878578186
501 0.34544888138771057
601 0.3359788656234741
701 0.32296472787857056
801 0.3087318241596222
901 0.29816579818725586
CPU times: user 12min 19s, sys: 1min 16s, total: 13min 35s
Wall time: 12min 5s


In [50]:
embeddings_src, embeddings_dst = model.sage(graph, src_feats)['user'].detach().cpu().numpy(), model.sage(graph, dst_feats)['user'].detach().cpu().numpy()
embeddings_src, embeddings_dst = pd.DataFrame(embeddings_src, index = [x[1] for x in idx_to_user.items()]), pd.DataFrame(embeddings_dst, index = [x[1] for x in idx_to_user.items()])

src_to_embedding, dst_to_embedding = dict(), dict()

for i in tqdm(range(len(embeddings_src))):
  src_to_embedding[embeddings_src.index[i]] = embeddings_src.iloc[i].values
  dst_to_embedding[embeddings_dst.index[i]] = embeddings_dst.iloc[i].values

embeddings_src.to_csv('/Users/rkim0927/Python/Data/VB/Embeddings/embeddings_user_matchsage.csv')
embeddings_dst.to_csv('/Users/rkim0927/Python/Data/VB/Embeddings/embeddings_item_matchsage.csv')

  0%|          | 0/21392 [00:00<?, ?it/s]

In [51]:
embeddings_src

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
262146,-0.247022,-0.585324,0.278355,0.977576,-0.163174,-0.149333,0.403666,-0.390639,-0.680525,-0.156883,...,0.407454,0.752831,-0.312989,0.467062,-0.176985,0.059873,-0.637069,-0.173985,0.361946,-0.644119
262148,-0.384095,0.053177,-0.289479,0.169507,0.738665,-0.146694,0.907702,0.080897,-0.097276,-0.384564,...,0.123898,0.409229,0.353526,-0.080921,-0.013004,0.048866,-0.733887,-0.161099,0.156275,0.329346
262149,-0.289618,-0.337384,0.102390,0.013831,0.232124,0.133435,0.583250,-0.023239,-0.526302,-0.112118,...,0.341626,0.104950,0.211426,0.144730,0.130463,0.100896,-0.016758,-0.274512,0.252849,-0.126965
262150,0.122197,-0.993227,0.185879,0.237607,-0.149140,-0.687928,0.437555,0.309256,-0.616605,0.277166,...,0.400098,-0.206687,0.053693,0.090392,-0.293778,-0.139575,-0.892943,0.542541,0.386484,0.122656
262151,0.076501,-0.542083,-0.270100,-0.023784,0.291219,-0.047542,0.270747,-0.035526,-0.445480,0.152017,...,0.403062,-0.034659,-0.034196,0.244297,-0.081065,-0.018950,-0.262636,0.292451,-0.133327,0.154459
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
262139,-0.356279,0.101586,-0.224972,0.202832,0.815259,-0.079751,0.770300,0.278017,0.185470,-0.250616,...,0.219999,0.262693,0.267423,-0.045921,-0.353057,0.129104,-0.791793,-0.090569,-0.094683,0.403881
262140,-0.172082,-0.328460,-0.406324,0.698802,0.648844,0.094804,0.352157,0.038808,-0.443557,-0.256785,...,0.172910,0.076585,-0.721370,-0.521600,-0.057247,0.024184,-0.397897,-0.091874,0.108164,-0.021169
262141,-0.494139,-0.371082,-0.330752,0.311240,-0.230299,-0.309441,0.285026,0.422687,-0.307018,0.023818,...,0.383108,0.477025,0.387020,0.066760,0.594881,0.369909,-0.462010,-0.087332,-0.583248,-0.068032
262142,-0.332772,-0.103500,-0.557458,0.532374,0.518756,0.051194,0.777940,-0.095660,-0.081268,0.202454,...,0.202976,0.007067,0.032891,-0.074116,-0.029475,0.059299,0.061885,0.046481,-0.154132,0.204300
