In [None]:
from collections import Counter
from pathlib import Path
import pickle

import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import dgl
from tqdm.auto import tqdm
from scipy.special import softmax
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score

from train import MultiSAGEModel

In [None]:
# Directory holding the pickled dataset, the embeddings and the DGL graph,
# resolved relative to the notebook's working directory.
datapath = Path('..') / 'data'

In [None]:
# Load the pickled dataset; later cells read its 'X_cmp_val', 'y_cmp_val'
# and 'X_cmp_test' entries.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
with (datapath / 'dataset.pkl').open('rb') as dataset_file:
    dataset = pickle.load(dataset_file)

In [None]:
# Load the two remaining pickles from the data directory.
# NOTE: unpickling can execute arbitrary code - only load files you trust.

# `cuisine` is not used by any cell visible in this notebook section.
with (datapath / 'cuisine.pkl').open('rb') as cuisine_file:
    cuisine = pickle.load(cuisine_file)

# `ingredient` is later converted with `.detach().numpy()`, so it is
# presumably a torch tensor of ingredient embeddings - TODO confirm.
with (datapath / 'ingredient.pkl').open('rb') as ingredient_file:
    ingredient = pickle.load(ingredient_file)

In [None]:
# `dgl.load_graphs` returns a pair whose first element is the list of stored
# graphs; this notebook uses the first graph in that list.
loaded_graphs = dgl.load_graphs(str(datapath / 'graph.bin'))[0]
graph = loaded_graphs[0]
graph

In [None]:
# Rebuild the model on top of the loaded graph. The three node-type names and
# the trailing positional values (256, 2, 3) are interpreted by
# `MultiSAGEModel` in `train.py`, which is not visible here - TODO confirm
# they match the settings the checkpoint was trained with.
model = MultiSAGEModel(graph, 'ingredient', 'recipe', 'cuisine', 256, 2, 3)

# Restore the saved weights, mapping every tensor onto the CPU.
state_dict_path = Path('..', 'saved', '39_3model_state_dict.pt')
state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

In [None]:
# Move the ingredient embeddings and the scorer's bias out of torch into
# float64 NumPy arrays so scoring below is pure NumPy.
# Presumably `biases` holds one entry per row of `h_item` (both are indexed
# with the same ingredient ids in `ingredient_scores`) - TODO confirm.
h_item = ingredient.detach().numpy().astype(np.float64)
biases = model.node_scorer.bias.detach().numpy().astype(np.float64)

In [None]:
def ingredient_scores(ingredients):
    """Score every candidate ingredient against a set of given ingredients.

    For each given ingredient ``g`` and each candidate ``c`` the pairwise
    score is ``h_item[g] . h_item[c] + biases[g] + biases[c]``. The pairwise
    scores are then summed over all given ingredients.

    Parameters
    ----------
    ingredients : sequence of int
        Row indices into the module-level ``h_item`` and ``biases`` arrays.

    Returns
    -------
    numpy.ndarray
        One summed score per row of ``h_item``.
    """
    # (n_given, n_candidates) matrix of embedding dot products.
    pairwise = np.dot(h_item[ingredients], h_item.T)
    # Bias of each given ingredient, broadcast along its row.
    pairwise = pairwise + biases[ingredients].reshape(-1, 1)
    # Bias of each candidate, broadcast down its column.
    pairwise = pairwise + biases
    # Collapse the given-ingredient axis.
    return pairwise.sum(axis=0)

In [None]:
# One score vector (over all candidate ingredients) per validation recipe.
all_scores = [
    ingredient_scores(recipe)
    for recipe in tqdm(dataset['X_cmp_val'])
]

In [None]:
# Convert the list of per-recipe score vectors into a single NumPy array so
# the following cells can sort and index it row by row.
# Note: this rebinds `all_scores` from a list to an array.
all_scores = np.array(all_scores)

In [None]:
# Normalise the scores with a softmax along axis 1.
# `probs` is not read by any cell visible in this notebook section.
probs = softmax(all_scores, axis=1)

In [None]:
# Validation answers as an array; entry i is compared against the sorted
# candidate indices of row i when ranks are computed below.
y_ans = np.array(dataset['y_cmp_val'])

In [None]:
# Candidate indices of every row ordered by ascending score
# (the last column holds the highest-scoring candidate).
ranks = np.argsort(all_scores, axis=-1)

In [None]:
# Sanity check: number of rows in `ranks`.
ranks.shape[0]

In [None]:
# Sanity check: full shape of `ranks`.
np.shape(ranks)

In [None]:
# Rank of the correct answer for every validation recipe.
# Each row of `ranks` lists candidate indices in ascending score order, so an
# answer found at position p has rank `n_candidates - p`: the top-scoring
# candidate gets rank 1. The candidate count is taken from the array itself
# instead of being hard-coded, so a different vocabulary size cannot silently
# shift every rank.
n_candidates = ranks.shape[1]
all_ranks = []
for order, answer in zip(ranks, y_ans):
    position = np.where(order == answer)[0][0]
    all_ranks.append(n_candidates - position)

In [None]:
# Distribution of the rank assigned to the correct answer (lower is better).
# Title and axis labels make the figure readable on its own; `plt.show()`
# keeps the cell from echoing the raw (counts, bins, patches) tuple.
fig, ax = plt.subplots()
ax.hist(all_ranks, bins=30)
ax.set(
    title='Validation: rank of the correct answer',
    xlabel='Rank of the correct answer (lower is better)',
    ylabel='Number of validation recipes',
)
plt.show()

In [None]:
# One score vector (over all candidate ingredients) per test recipe.
all_test_scores = [
    ingredient_scores(recipe)
    for recipe in tqdm(dataset['X_cmp_test'])
]

In [None]:
# Convert the list of per-recipe test score vectors into a single NumPy array
# so it can be sorted along its last axis.
# Note: this rebinds `all_test_scores` from a list to an array.
all_test_scores = np.array(all_test_scores)

In [None]:
# Candidate indices of every test row ordered by ascending score.
# Note: this reuses the name `all_ranks`, which earlier held the list of
# validation ranks; from here on it is an index array.
all_ranks = np.argsort(all_test_scores, axis=-1)

In [None]:
# Pick one prediction per test recipe: the best-scoring candidate that is not
# already part of the recipe.
# `all_ranks` rows are ascending argsort indices (lowest score first), so each
# row is walked in reverse to visit candidates best-first. Walking it forwards
# would submit the *worst* unused candidate, contradicting the validation
# metric where the highest score is rank 1.
pred = []
for i in range(all_ranks.shape[0]):
    for item in all_ranks[i][::-1]:
        if item not in dataset['X_cmp_test'][i]:
            pred.append(item)
            break
    else:
        # Without a prediction the output file would drift out of alignment
        # with the test recipes, so fail loudly instead of skipping the row.
        raise ValueError(f'no candidate outside test recipe {i}')

In [None]:
# Write the predictions to the results file, one ingredient id per line and
# in the same order as the test recipes.
with open('../results/test_completion_answer.csv', 'wt') as answer_file:
    answer_file.writelines(f'{item}\n' for item in pred)