## Dataset

In [72]:
import re
import pandas as pd
import numpy as np

movies = []
with open('./dataset/movielens/movies.dat', encoding='latin1') as f:
    for l in f:
        id_, title, genres = l.strip().split('::')

        # extract year
        assert re.match(r'.*\([0-9]{4}\)$', title)
        year = title[-5:-1]
        title = title[:-6].strip()

        data = {'movie_id': int(id_), 'title': title, 'year': year, 'genre': genres}
        movies.append(data)
movies = pd.DataFrame(movies).astype({'year': 'category'})

ratings = []
with open('./dataset/movielens/ratings.dat', encoding='latin1') as f:
    for l in f:
        user_id, movie_id, rating, timestamp = [int(_) for _ in l.split('::')]
        ratings.append({
            'user_id': user_id,
            'movie_id': movie_id,
            'rating': rating,
            'timestamp': timestamp,
            })
ratings = pd.DataFrame(ratings)

distinct_movies_in_ratings = ratings['movie_id'].unique()
movies = movies[movies['movie_id'].isin(distinct_movies_in_ratings)]
movies['genre'] = movies['genre'].apply(lambda x: x.split("|")[0])
entities = movies['movie_id'].astype('category')
m_entities = entities.cat.reorder_categories(movies['movie_id'].values)

In [73]:
index_id_to_movie_id = {}
movie_id_to_index_id = {}
for idx, movie_id in enumerate(m_entities):
    index_id_to_movie_id[idx] = movie_id
    movie_id_to_index_id[movie_id] = idx

----
## Normal Query

In [74]:
from scipy import spatial

saved_npz = np.load('./multisage/' + 'h_items.npz')
h_item = saved_npz['movie_vectors']
tree = spatial.KDTree(h_item.tolist())

In [75]:
h_query = h_item[movie_id_to_index_id[3]]
index_ids = tree.query(h_query, 10)[1]
movie_ids = [index_id_to_movie_id[idx] for idx in index_ids]

In [76]:
for mid in movie_ids:
    print(movies[movies['movie_id']==mid]['title'].values)

['Grumpier Old Men']
['Incredibly True Adventure of Two Girls in Love, The']
['French Twist (Gazon maudit)']
['French Kiss']
["Pyromaniac's Love Story, A"]
['Vampire in Brooklyn']
['While You Were Sleeping']
['Rendezvous in Paris (Rendez-vous de Paris, Les)']
['Forget Paris']
['Clueless']


----
## Context Query

In [77]:
!export PYTHONPATH=/Users/A202009066/Documents/private-github/recommender-system/embedding/gnn:$PYTHONPATH

In [78]:
import os
import dgl

import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

from multisage import layers
from multisage.sampler import ItemToItemBatchSampler, NeighborSampler, PinSAGECollator


class MultiSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, ctype, hidden_dims, n_layers, gat_num_heads):
        super().__init__()
        self.nodeproj = layers.LinearProjector(full_graph, ntype, hidden_dims)
        self.contextproj = layers.LinearProjector(full_graph, ctype, hidden_dims)
        self.multisage = layers.MultiSAGENet(hidden_dims, n_layers, gat_num_heads)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks, context_blocks):
        h_item = self.get_representation(blocks, context_blocks)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_representation(self, blocks, context_blocks, context_id=None):
        if context_id:
            return self.get_context_query(blocks, context_blocks, context_id)
        else:
            h_item = self.nodeproj(blocks[0].srcdata)
            h_item_dst = self.nodeproj(blocks[-1].dstdata)
            z_c = self.contextproj(context_blocks[0])
            z_c_dst = self.contextproj(context_blocks[-1])
            h = h_item_dst + self.multisage(blocks, h_item, (z_c, z_c_dst))
            return h

    def get_context_query(self, blocks, context_blocks, context_id):
        # check sub-graph contains context id
        context_id = context_blocks[-1]['_ID'][0].item()
        print(context_id)
        print(context_blocks[-1]['_ID'])
        context_index = (context_id == context_blocks[-1]['_ID']).nonzero(as_tuple=True)[0]
        if context_index.size()[0] == 0:  # if context id not in sub-graph, only random sample context using for repr
            print("context not in sub graph")
            return self.get_representation(blocks, context_blocks)
        else:  # if context id in sub-graph, get MultiSAGE's context query
            print("execute context query")
            attn_index = torch.ones(context_blocks[-1]['_ID'].shape[0], dtype=bool)
            attn_index[context_index] = False
            h_item = self.nodeproj(blocks[0].srcdata)
            h_item_dst = self.nodeproj(blocks[-1].dstdata)
            z_c = self.contextproj(context_blocks[0])
            z_c_dst = self.contextproj(context_blocks[-1])
            h = h_item_dst + self.multisage(blocks, h_item, (z_c, z_c_dst), attn_index)
            return h

In [79]:
import torch
import pickle

with open('./multisage/graph_data.pickle', 'rb') as f:
    dataset = pickle.load(f)
g = dataset['train-graph']
load_dict = torch.load('./multisage/MultiSAGE_weights.pth')

In [80]:
model = MultiSAGEModel(g, 'movie', 'genre', 512, 2, 3)
model.load_state_dict(load_dict)

<All keys matched successfully>

In [81]:
batch_sampler = ItemToItemBatchSampler(g, 'genre', 'movie', 512)
neighbor_sampler = NeighborSampler(
    g, 'genre', 'movie', 2, 0.5, 10, 5, 2)
collator = PinSAGECollator(neighbor_sampler, g, 'movie', 'genre')

index_id = movie_id_to_index_id[3]
with torch.no_grad():
    blocks, context_blocks = collator.collate_point(index_id=index_id)
    context_batch = model.get_representation(blocks, context_blocks, context_id=4)

5
tensor([5, 5, 5, 5, 5])
execute context query


In [83]:
# Comedy Query
index_ids = tree.query(context_batch.numpy()[0], 10)[1]
movie_ids = [index_id_to_movie_id[idx] for idx in index_ids]
for mid in movie_ids:
    print(movies[movies['movie_id']==mid]['title'].values)

['Perez Family, The']
['Sabrina']
['Pie in the Sky']
['Englishman Who Went Up a Hill, But Came Down a Mountain, The']
['Forget Paris']
['Clueless']
['Rendezvous in Paris (Rendez-vous de Paris, Les)']
['While You Were Sleeping']
["Pyromaniac's Love Story, A"]
['Vampire in Brooklyn']


In [85]:
index_id = movie_id_to_index_id[3]
with torch.no_grad():
    blocks, context_blocks = collator.collate_point(index_id=index_id)
    context_batch = model.get_representation(blocks, context_blocks, context_id=5)

4
tensor([4, 5, 5, 5, 4])
execute context query


In [86]:
# Romance Query
index_ids = tree.query(context_batch.numpy()[0], 10)[1]
movie_ids = [index_id_to_movie_id[idx] for idx in index_ids]
for mid in movie_ids:
    print(movies[movies['movie_id']==mid]['title'].values)

['French Kiss']
['French Twist (Gazon maudit)']
['Grumpier Old Men']
['Incredibly True Adventure of Two Girls in Love, The']
["Pyromaniac's Love Story, A"]
['Vampire in Brooklyn']
['While You Were Sleeping']
['Rendezvous in Paris (Rendez-vous de Paris, Les)']
['Forget Paris']
['Englishman Who Went Up a Hill, But Came Down a Mountain, The']


----
## Visualize

In [87]:
%matplotlib inline

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from bokeh.models import *
from bokeh.plotting import *
from bokeh.io import *
from bokeh.tile_providers import *
from bokeh.palettes import *
from bokeh.transform import *
from bokeh.layouts import *

from bokeh.plotting import figure, show
from bokeh.sampledata.iris import flowers
from bokeh.models import HoverTool

model = TSNE(learning_rate=300)
transformed = model.fit_transform(h_item)

In [88]:
movies['x'] = transformed[:, 0]
movies['y'] = transformed[:, 1]
movies = movies[['title', 'x', 'y', 'genre']]

In [89]:
output_notebook()

p = figure(title = "Movie t-SNE by GNN")
p.xaxis.axis_label = 'x'
p.yaxis.axis_label = 'y'

color_column = []
for genre in movies['genre'].unique().tolist():
    color_column.append(genre)

c = p.circle(x='x', y='y', color='red', fill_alpha=0.2, size=3, source=movies)

c = p.circle(x='x', 
             y='y', 
             legend_field="genre",
             color=factor_cmap('genre', d3['Category20'][18], color_column),
             fill_alpha=1, 
             size=3, 
             source=movies)


circle_hover = HoverTool(tooltips=[('title:', '@title')], 
                         mode='mouse', 
                         point_policy='follow_mouse', 
                         renderers=[c])
circle_hover.renderers.append(c)

# mouse hover와 legend 정보 부착
p.tools.append(circle_hover)
p.legend.label_text_font_size = '5pt'
p.legend.location = 'bottom_left'

show(p)