In [9]:
import json
from tqdm import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ndcg_score

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, HeteroConv, to_hetero
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.utils.convert import to_networkx
import torch_geometric.transforms as T


In [1]:
from torch_geometric.datasets import MovieLens
dataset = MovieLens('data/movielens/', model_name='all-MiniLM-L6-v2')


In [12]:
data = torch.load('data/yelp-hetero.pt')

In [13]:
data

HeteroData(
  [1muser[0m={
    node_id=[290714],
    x=[290714, 5]
  },
  [1mrestaurant[0m={
    node_id=[31217],
    x=[31217, 640]
  },
  [1m(user, rating, restaurant)[0m={
    edge_index=[2, 596895],
    edge_label=[596895]
  },
  [1m(restaurant, rev_rating, user)[0m={ edge_index=[2, 596895] }
)

In [35]:
weight = torch.bincount(data['user', 'rating', 'restaurant'].edge_label.int())
weight = weight.max() / weight

In [36]:
weight

tensor([   inf, 3.8545, 6.1450, 4.8445, 2.3577, 1.0000])

In [10]:
with open('data/yelp-hetero-meta.json', 'r') as f:
    meta_data = json.load(f)
num_users = meta_data['num_users']
num_restaurants = meta_data['num_restaurants']
user_mapping = meta_data['user_mapping']
restaurant_mapping = meta_data['restaurant_mapping']

In [14]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['restaurant'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)

class RecoModel(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

In [15]:
model = RecoModel(64)
model.load_state_dict(torch.load('output/recommender-yelp-final.pth', map_location=torch.device('cpu'))) 
model.eval()

RecoModel(
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rating__restaurant): SAGEConv((-1, -1), 64, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv((-1, -1), 64, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rating__restaurant): SAGEConv((-1, -1), 64, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv((-1, -1), 64, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=128, out_features=64, bias=True)
    (lin2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [37]:
reverse_movie_mapping = dict(zip(restaurant_mapping.values(),restaurant_mapping.keys()))
reverse_user_mapping = dict(zip(user_mapping.values(),user_mapping.keys()))

results = []

for user_id in range(0, 1): 

    row = torch.tensor([user_id] * num_restaurants)
    col = torch.arange(num_restaurants)
    edge_label_index = torch.stack([row, col], dim=0)
    edge_label_index
    
    print(edge_label_index)
    data.to('cpu')
    pred = model(data.x_dict, data.edge_index_dict,
                 edge_label_index)
    pred = pred.clamp(min=0, max=5)
    
    user_id_str = reverse_user_mapping[user_id]

    mask = (pred == 5).nonzero(as_tuple=True)

    ten_predictions = [reverse_movie_mapping[el] for el in  mask[0].tolist()[:10]]
    results.append({'user': user_id_str, 'restaurant': ten_predictions})

tensor([[    0,     0,     0,  ...,     0,     0,     0],
        [    0,     1,     2,  ..., 31214, 31215, 31216]])


In [45]:
data['user', 'rating', 'restaurant'].edge_index[0]

tensor([     0,   1878,   5139,  ...,   3514,  48695, 290691])

In [42]:
[x for x in data['user', 'rating', 'restaurant'] if x[0] == 0]

[]

In [17]:
results

[{'user': 'mh_-eMZ6K5RLWhZyISBhwA',
  'restaurant': ['LGqiubTmpJ-A1L5n7dmc6g',
   '_T0cPZE2ZJOTTlYYKMP64Q',
   '1E9o1SNo7UTf1XHTFPv1_Q',
   'FF45pKN_lzqG8Bqk-_HQvw',
   'adATTqggIQX5xxLDISkFTw',
   'IazLGcO9aggJnMMa_5UO1Q',
   'Xjal8g4PsYinAfeQ8RWf4Q',
   'U30ggGzFpXvc2NZYwOW3qg',
   'Pb5agnsD9EdCl6yuZp2jJA',
   'ruFtZKwlJASx5BTk1dh5AQ']},
 {'user': 'RreNy--tOmXMl1en0wiBOg',
  'restaurant': ['lk9IwjZXqUMqqOhM774DtQ',
   'knQ4vIgx-r85kjlWVVjcpQ',
   'LGqiubTmpJ-A1L5n7dmc6g',
   '1E9o1SNo7UTf1XHTFPv1_Q',
   'adATTqggIQX5xxLDISkFTw',
   'sophKEDc2rBDe-cuOaJDkA',
   '2DsplH_vy4GCcEnVpn0AbA',
   'biGIDbCGsAZJ-Y4zyV_b_A',
   'lj-E32x9_FA7GmUrBGBEWg',
   'IazLGcO9aggJnMMa_5UO1Q']}]