In [39]:
from collections import Counter
from pathlib import Path
import pickle

import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import dgl
from tqdm.auto import tqdm
from scipy.special import softmax

from train import MultiSAGEModel

In [2]:
datapath = Path('..', 'data')

In [3]:
with open(datapath / 'dataset.pkl', 'rb') as f:
    dataset = pickle.load(f)

In [4]:
with open(datapath / 'cuisine.pkl', 'rb') as f:
    cuisine = pickle.load(f)

with open(datapath / 'ingredient.pkl', 'rb') as f:
    ingredient = pickle.load(f)

In [5]:
graph = dgl.load_graphs(str(datapath / 'graph.bin'))[0][0]
graph

Graph(num_nodes={'cuisine': 20, 'ingredient': 6714, 'recipe': 23547},
      num_edges={('cuisine', 'c2r', 'recipe'): 23547, ('ingredient', 'i2r', 'recipe'): 253459, ('recipe', 'r2c', 'cuisine'): 23547, ('recipe', 'r2i', 'ingredient'): 253459},
      metagraph=[('cuisine', 'recipe', 'c2r'), ('recipe', 'cuisine', 'r2c'), ('recipe', 'ingredient', 'r2i'), ('ingredient', 'recipe', 'i2r')])

In [6]:
model = MultiSAGEModel(
    graph,
    'ingredient',
    'recipe',
    'cuisine',
    256,
    2,
    3,
)
model.load_state_dict(torch.load(
    Path('..', 'saved', '39_3model_state_dict.pt'),
    map_location=torch.device('cpu')
))

<All keys matched successfully>

In [7]:
ingredient = ingredient.detach().numpy().astype('float')

In [8]:
h_item = ingredient

In [13]:
h_item[[1,2]].shape

(2, 256)

In [15]:
ingredients = dataset['X_cmp_val'][0]

In [21]:
biases = model.node_scorer.bias.detach().numpy().astype('float')

In [36]:
def ingredient_scores(ingredients):
    scores = np.dot(h_item[ingredients], h_item.T) + biases[ingredients].reshape(-1, 1) + biases
    scores = scores.sum(axis=0)
    return scores

In [37]:
all_scores = []
for i, recipe in enumerate(tqdm(dataset['X_cmp_val'])):
    all_scores.append(ingredient_scores(recipe))

  0%|          | 0/7848 [00:00<?, ?it/s]

In [40]:
all_scores = np.array(all_scores)

In [41]:
all_scores.shape

(7848, 6714)

In [47]:
probs = softmax(all_scores, axis=1)


In [51]:
probs[dataset['y_cmp_val']]

array([[2.22572195e-203, 1.05824214e-150, 1.17421788e-122, ...,
        4.93332991e-170, 1.06091453e-191, 7.32989675e-204],
       [1.17168366e-156, 1.00643402e-152, 3.72235551e-158, ...,
        4.67394216e-141, 8.00419600e-174, 6.26580488e-171],
       [2.78212096e-155, 1.08686225e-180, 5.69258488e-192, ...,
        2.42419642e-159, 3.86004957e-188, 7.64327165e-112],
       ...,
       [2.67451646e-105, 2.71072200e-093, 1.91153142e-142, ...,
        2.42757458e-163, 3.35363696e-121, 1.44116514e-114],
       [1.68179142e-160, 6.42721325e-158, 4.59639033e-151, ...,
        7.88974860e-155, 2.89946795e-145, 2.16222971e-149],
       [2.17525226e-115, 2.47559895e-110, 7.83849252e-138, ...,
        1.54931709e-120, 5.70752661e-152, 2.08768229e-130]])

In [None]:
dataset['y_cmp_val']