**------------------------------------------------------------------------------------------------------------------------------------------------------**

k-fold cross-validation

**Input: Knowledge Graph**

**GNN-based Link Prediction Model: GCN, GraphSAGE, and GAT**

**Output: "Food" Embeddings**

**------------------------------------------------------------------------------------------------------------------------------------------------------**

# Libraries

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import re
import random
import itertools
import warnings
warnings.simplefilter("ignore")

import torch_geometric
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, Linear, to_hetero

import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

In [2]:
# Encoder architecture used by Model for this run: one of 'GCN', 'GraphSAGE', 'GAT'.
vModel = 'GraphSAGE'

In [3]:
# Experiment settings:
#   k_fold       - number of split/train/export rounds run by the main loop
#   n_input_feat - width of the random feature vector attached to every node
#   n_epochs     - training epochs per round (201 so that epoch 200 is logged)
k_fold, n_input_feat, n_epochs = 10, 10, 201

# 1) Create HeteroGraph

In [4]:
def load_node_csv(path, index_col, encoders=None, **kwargs):
    """Read a CSV and map each distinct value of ``index_col`` to a contiguous int.

    Args:
        path: anything ``pd.read_csv`` accepts (file path or buffer).
        index_col: column whose distinct values identify the nodes.
        encoders: optional ``{column: callable}``; each callable turns a column
            (a pandas Series) into a tensor, and the results are concatenated
            along the last dimension to form node features.
        **kwargs: forwarded to ``pd.read_csv``.

    Returns:
        ``(x, mapping)``. ``x`` is the feature tensor, or ``None`` when no
        encoders are given. ``mapping`` maps each node id to its position,
        numbered in order of first appearance.
    """
    frame = pd.read_csv(path, index_col=index_col, **kwargs)
    unique_ids = frame.index.unique()
    mapping = dict(zip(unique_ids, range(len(unique_ids))))

    if encoders is None:
        return None, mapping

    features = torch.cat(
        [encode(frame[column]) for column, encode in encoders.items()],
        dim=-1,
    )
    return features, mapping

In [5]:
def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping, encoders=None, **kwargs):
    """Read an edge list from CSV and convert it to an ``edge_index`` tensor.

    Args:
        path: anything ``pd.read_csv`` accepts (file path or buffer).
        src_index_col: column holding the source node ids.
        src_mapping: dict from source node id to node index.
        dst_index_col: column holding the destination node ids.
        dst_mapping: dict from destination node id to node index.
        encoders: optional ``{column: callable}``; each callable turns a column
            (a pandas Series) into a tensor, and the results are concatenated
            along the last dimension to form edge attributes.
        **kwargs: forwarded to ``pd.read_csv``.

    Returns:
        ``(edge_index, edge_attr)``. ``edge_index`` is an int64 tensor of shape
        ``[2, num_rows]`` (row 0 = sources, row 1 = destinations).
        ``edge_attr`` is ``None`` when no encoders are given.

    Raises:
        KeyError: if an id in the CSV is missing from its mapping.
    """
    df = pd.read_csv(path, **kwargs)

    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    # Force int64: for an empty CSV, torch.tensor([[], []]) would infer float32,
    # and an index tensor must be integral.
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr

In [6]:
# Build an id -> contiguous-index mapping for every node type. No encoders are
# passed, so the first return value (the feature tensor) is always None and is
# discarded. Food ids come from all_foods.csv. The other node types take their
# ids from the 'object' column of the corresponding Food -> X edge file (or the
# 'subject'/'object' columns of the product-ingredient file).
_, food_mapping = load_node_csv('../Input Data/data/all_foods.csv', index_col='subject')
_, nutrient_mapping = load_node_csv('../Input Data/data/df_food_nutrient.csv', index_col='object')
_, tag_mapping = load_node_csv('../Input Data/data/df_food_tag.csv', index_col='object')
_, category_mapping = load_node_csv('../Input Data/data/df_food_cat.csv', index_col='object')
_, flavor_mapping = load_node_csv('../Input Data/data/df_food_flavor.csv', index_col='object')
# NOTE(review): the Product and Ingredient mappings are only used to build edge
# indices whose node/edge types are commented out of the HeteroData graph.
_, product_mapping = load_node_csv('../Input Data/data/df_product_ingredient.csv', index_col='subject')
_, ingredient_mapping = load_node_csv('../Input Data/data/df_product_ingredient.csv', index_col='object')

In [7]:
# Convert each edge file into an edge_index tensor. Every id must be present in
# the relevant mapping, otherwise load_edge_csv raises KeyError.
# Substitution edges are loaded in both directions: source -> destination for
# 'isSubstitutedBy', and the same rows with the columns swapped for 'substitutes'.
food_subs_index, _ = load_edge_csv('../Input Data/data/df_food_subs.csv', src_index_col='source_id', src_mapping=food_mapping, dst_index_col='destination_id', dst_mapping=food_mapping)
subs_food_index, _ = load_edge_csv('../Input Data/data/df_food_subs.csv', src_index_col='destination_id', src_mapping=food_mapping, dst_index_col='source_id', dst_mapping=food_mapping)
food_nutrient_index, _ = load_edge_csv('../Input Data/data/df_food_nutrient.csv', src_index_col='subject', src_mapping=food_mapping, dst_index_col='object', dst_mapping=nutrient_mapping)
food_tag_index, _ = load_edge_csv('../Input Data/data/df_food_tag.csv', src_index_col='subject', src_mapping=food_mapping, dst_index_col='object', dst_mapping=tag_mapping)
food_cat_index, _ = load_edge_csv('../Input Data/data/df_food_cat.csv', src_index_col='subject', src_mapping=food_mapping, dst_index_col='object', dst_mapping=category_mapping)
food_flavor_index, _ = load_edge_csv('../Input Data/data/df_food_flavor.csv', src_index_col='subject', src_mapping=food_mapping, dst_index_col='object', dst_mapping=flavor_mapping)
# NOTE(review): the three Product/Ingredient edge indices below are built but are
# only referenced by commented-out lines of the HeteroData construction.
product_ingredient_index, _ = load_edge_csv('../Input Data/data/df_product_ingredient.csv', src_index_col='subject', src_mapping=product_mapping, dst_index_col='object', dst_mapping=ingredient_mapping)
food_ingredient_index, _ = load_edge_csv('../Input Data/data/df_food_ingredient.csv', src_index_col='subject', src_mapping=food_mapping, dst_index_col='object', dst_mapping=ingredient_mapping)
ingredient_food_index, _ = load_edge_csv('../Input Data/data/df_food_ingredient.csv', src_index_col='object', src_mapping=ingredient_mapping, dst_index_col='subject', dst_mapping=food_mapping)

In [8]:
# Assemble the heterogeneous knowledge graph. Nodes carry only a count here;
# input features are attached later.
data = HeteroData()

data['Food'].num_nodes = len(food_mapping)
data['Nutrient'].num_nodes = len(nutrient_mapping)
data['Tag'].num_nodes = len(tag_mapping)
data['Category'].num_nodes = len(category_mapping)
data['Flavor'].num_nodes = len(flavor_mapping)
# Product/Ingredient node types are currently disabled.
#data['Product'].num_nodes = len(product_mapping)
#data['Ingredient'].num_nodes = len(ingredient_mapping)

# 'isSubstitutedBy' is the edge type the model is trained to predict;
# 'substitutes' is its reverse, so substitution information flows both ways.
data['Food', 'isSubstitutedBy', 'Food'].edge_index = food_subs_index
data['Food', 'substitutes', 'Food'].edge_index = subs_food_index
data['Food', 'containsNutrient', 'Nutrient'].edge_index = food_nutrient_index
data['Food', 'hasTag', 'Tag'].edge_index = food_tag_index
data['Food', 'isInCategory', 'Category'].edge_index = food_cat_index
data['Food', 'hasFlavor', 'Flavor'].edge_index = food_flavor_index
# Disabled edge types. If re-enabled, each needs the `.edge_index` attribute
# (the Product line below previously lacked it) and the matching node types above.
#data['Product', 'containsIngredient', 'Ingredient'].edge_index = product_ingredient_index
#data['Food', 'sameAs', 'Ingredient'].edge_index = food_ingredient_index
#data['Ingredient', 'sameAs', 'Food'].edge_index = ingredient_food_index

In [9]:
# Show a summary of the graph: node counts per type and edge_index shapes per edge type.
data

HeteroData(
  Food={ num_nodes=9372 },
  Nutrient={ num_nodes=63883 },
  Tag={ num_nodes=25 },
  Category={ num_nodes=13 },
  Flavor={ num_nodes=272 },
  (Food, isSubstitutedBy, Food)={ edge_index=[2, 1841] },
  (Food, substitutes, Food)={ edge_index=[2, 1841] },
  (Food, containsNutrient, Nutrient)={ edge_index=[2, 300523] },
  (Food, hasTag, Tag)={ edge_index=[2, 17746] },
  (Food, isInCategory, Category)={ edge_index=[2, 1667] },
  (Food, hasFlavor, Flavor)={ edge_index=[2, 11167] }
)

# 2) Run Model

**GCN**

In [10]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder: GCNConv -> ReLU -> GCNConv.

    Both convolutions use a lazily-inferred (``-1``) input width, so the input
    size is fixed on the first forward pass.

    NOTE(review): ``Model`` wraps this encoder with ``to_hetero``. Edge types
    whose source and destination node types differ (e.g. Food -> Nutrient)
    then need a layer that accepts a (source, destination) feature pair.
    GCNConv is declared with a single ``-1`` here, unlike the ``(-1, -1)`` of
    the SAGE/GAT encoders. Confirm this encoder actually runs on this graph.
    Any PyG warning about it would be hidden by
    ``warnings.simplefilter("ignore")``.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(-1, hidden_channels)
        self.conv2 = GCNConv(-1, out_channels)

    def forward(self, x, edge_index):
        # Parameter names must stay `x` / `edge_index`: to_hetero inspects them.
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

**GraphSAGE**

In [11]:
class GraphSAGEEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder: SAGEConv -> ReLU -> SAGEConv.

    Both layers declare lazily-inferred ``(-1, -1)`` (source, destination)
    input widths, fixed on the first forward pass.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        # Parameter names must stay `x` / `edge_index`: to_hetero inspects them.
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

**GAT**

In [12]:
class GATEncoder(torch.nn.Module):
    """Two-layer graph-attention encoder with a linear skip path per layer.

    Each layer's output is ``GATConv(input) + Linear(input)``. A ReLU sits
    between the two layers. The convolutions declare lazily-inferred
    ``(-1, -1)`` (source, destination) input widths and do not add self-loops.

    Note: ``lin1`` is sized from the module-level ``n_input_feat``, so the
    input node features must have exactly that width.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(n_input_feat, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # Parameter names must stay `x` / `edge_index`: to_hetero inspects them.
        attended = self.conv1(x, edge_index)
        skipped = self.lin1(x)
        hidden = (attended + skipped).relu()
        return self.conv2(hidden, edge_index) + self.lin2(hidden)

**------------------------------------------------------------------------------------------------------------------------------------------------------**

In [13]:
class Model(torch.nn.Module):
    """Link-prediction model: heterogeneous GNN encoder + Food-Food edge decoder.

    The encoder architecture is chosen by the module-level ``vModel`` setting.
    It is converted to a heterogeneous module with ``to_hetero`` using the
    metadata of the global ``data`` graph, and contributions from different
    edge types are combined with ``aggr='sum'``.

    Args:
        hidden_channels: width of both encoder layers and of the decoder MLP.

    Raises:
        ValueError: if ``vModel`` is not 'GCN', 'GraphSAGE' or 'GAT'.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        if vModel == 'GCN':
            encoder = GCNEncoder(hidden_channels, hidden_channels)
        elif vModel == 'GraphSAGE':
            encoder = GraphSAGEEncoder(hidden_channels, hidden_channels)
        elif vModel == 'GAT':
            encoder = GATEncoder(hidden_channels, hidden_channels)
        else:
            # Fail fast with a clear message instead of an AttributeError later.
            raise ValueError(
                f"No Model is chosen: unknown vModel {vModel!r}; "
                "expected 'GCN', 'GraphSAGE' or 'GAT'"
            )
        self.encoder = to_hetero(encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Return ``(z_dict, scores)``.

        ``z_dict`` holds the per-node-type embeddings. ``scores`` has one
        decoder score per Food-Food pair in ``edge_label_index``.
        """
        z_dict = self.encoder(x_dict, edge_index_dict)
        return z_dict, self.decoder(z_dict, edge_label_index)

In [14]:
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores Food -> Food pairs from their embeddings.

    For each column ``(src, dst)`` of ``edge_label_index``, the two Food
    embeddings are concatenated, passed through Linear -> ReLU -> Linear,
    and reduced to one scalar.

    Args:
        hidden_channels: embedding width; the first layer takes
            ``2 * hidden_channels`` inputs.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        pair_width = 2 * hidden_channels
        self.lin1 = Linear(pair_width, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return a 1-D tensor with one raw (unbounded) score per edge."""
        food_emb = z_dict['Food']
        src_idx, dst_idx = edge_label_index
        pair_feats = torch.cat((food_emb[src_idx], food_emb[dst_idx]), dim=-1)
        hidden = self.lin1(pair_feats).relu()
        scores = self.lin2(hidden)
        return scores.view(-1)

In [15]:
def mse_loss(pred, target):
    """Return the mean squared error between ``pred`` and ``target`` as a 0-dim tensor.

    ``target`` is cast to ``pred``'s dtype first, so integer label tensors can
    be passed directly.
    """
    residual = pred - target.to(pred.dtype)
    squared = residual.pow(2)
    return squared.mean()

**------------------------------------------------------------------------------------------------------------------------------------------------------**

**Add Random Node Features**

In [16]:
# The graph has no intrinsic node features (only num_nodes is set), so every
# node type gets a uniform-random feature matrix of shape
# [num_nodes, n_input_feat] to serve as GNN input.
# NOTE(review): no random seed is set anywhere visible in this notebook, so
# these features (and hence the exported embeddings) differ between runs.
node_types, edge_types = data.metadata()
for node in node_types: 
    data[node].x = torch.rand(data[node].num_nodes, n_input_feat)

**------------------------------------------------------------------------------------------------------------------------------------------------------**

In [17]:
def train(data):
    """Train the global ``model`` with the global ``optimizer`` for ``n_epochs`` epochs.

    Each epoch is one full-graph forward/backward pass. It regresses the
    decoder's scores for the ('Food', 'isSubstitutedBy', 'Food') edges in
    ``edge_index`` onto their ``edge_label`` using MSE. Every 10 epochs the
    loss and the training RMSE (from ``eval_model``) are printed.

    Args:
        data: a split of the HeteroData graph with ``edge_label`` set on the
            ('Food', 'isSubstitutedBy', 'Food') edge type, one label per edge
            in ``edge_index``.
    """
    subs_store = data['Food', 'isSubstitutedBy', 'Food']
    print('------------------------------------------------------------------')
    for epoch in range(n_epochs):
        # eval_model() switches the model to eval mode, so re-enter training
        # mode at the start of every epoch.
        model.train()

        # forward
        _, pred = model(data.x_dict, data.edge_index_dict, subs_store.edge_index)
        target = subs_store.edge_label

        # loss
        loss = mse_loss(pred, target)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only run the (full-graph) evaluation when its result is reported.
        if epoch % 10 == 0:
            _, _, train_rmse = eval_model(data)
            print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}, Train: {train_rmse:.4f}')

In [18]:
@torch.no_grad()
def eval_model(data):
    """Score the Food-isSubstitutedBy-Food edges of ``data`` with the global ``model``.

    Puts ``model`` into eval mode (it is not switched back) and runs a
    full-graph forward pass without gradients. The edge scores are clamped to
    [0, 1] and compared with the stored ``edge_label``.

    Returns:
        ``(h, pred, rmse)``: the encoder's per-node-type embedding dict, the
        clamped edge scores, and their RMSE as a Python float.

    NOTE(review): the scored edges are ``edge_index`` (the message-passing
    edges), not ``edge_label_index``. For a RandomLinkSplit test split these
    are presumably the train/val edges rather than the held-out test edges.
    Confirm this is intended before treating the RMSE as a test metric.
    """
    model.eval()
    subs_store = data['Food', 'isSubstitutedBy', 'Food']
    h, scores = model(data.x_dict, data.edge_index_dict, subs_store.edge_index)
    scores = torch.clamp(scores, min=0, max=1)
    labels = subs_store.edge_label.float()
    rmse = torch.sqrt(F.mse_loss(scores, labels))
    return h, scores, float(rmse)

In [19]:
def get_test_foods(data, i):
    """Save the ids of the source foods of the test substitution edges.

    Reads row 0 (the source side) of ``edge_label_index`` on the
    ('Food', 'isSubstitutedBy', 'Food') edge type. It translates each node
    index back to its original id via ``food_mapping`` and writes one row per
    edge (duplicates kept) to
    ``../Output/k_fold/{vModel}_{i}_foods_2_test.csv``.

    Args:
        data: the test split of the graph.
        i: round number, used only in the output file name.
    """
    source_positions = data['Food', 'isSubstitutedBy', 'Food']['edge_label_index'][0]

    id_by_position = {position: food_id for food_id, position in food_mapping.items()}
    test_ids = [id_by_position[position] for position in source_positions.tolist()]

    foods_2_test = pd.DataFrame()
    foods_2_test['id'] = test_ids
    foods_2_test.to_csv(f'../Output/k_fold/{vModel}_{i}_foods_2_test.csv')

In [20]:
# Main experiment loop: for each round, split the substitution edges, train a
# fresh model, record the test foods, and export Food embeddings.
# NOTE(review): each round draws a fresh random RandomLinkSplit of the same
# graph, so the test sets of different rounds may overlap. This is repeated
# random sub-sampling rather than k disjoint folds.
for i in range(k_fold):

    # Split the 'isSubstitutedBy' edges 80/10/10 into train/val/test. The
    # reverse 'substitutes' edges are passed as rev_edge_types so they are
    # split consistently. No negative edges are sampled.
    train_data, val_data, test_data = T.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        neg_sampling_ratio=0,
        edge_types=[('Food', 'isSubstitutedBy', 'Food')],
        rev_edge_types=[('Food', 'substitutes', 'Food')],
    )(data)

    # Initialize a fresh model. One no-grad forward pass fixes the shapes of
    # the lazily-sized (-1) layers before the optimizer collects parameters.
    model = Model(hidden_channels=32)

    with torch.no_grad():
        model.encoder(train_data.x_dict, train_data.edge_index_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # train() and eval_model() score the edges in edge_index, so every one of
    # those edges gets a positive label (1.0).
    train_data['Food', 'isSubstitutedBy', 'Food'].edge_label = torch.ones(train_data['Food', 'isSubstitutedBy', 'Food'].num_edges)
    test_data['Food', 'isSubstitutedBy', 'Food'].edge_label = torch.ones(test_data['Food', 'isSubstitutedBy', 'Food'].num_edges)
    val_data['Food', 'isSubstitutedBy', 'Food'].edge_label = torch.ones(val_data['Food', 'isSubstitutedBy', 'Food'].num_edges)

    # Train Model
    train(train_data)

    # Remember the foods that are contained in this round's test set
    get_test_foods(test_data, i)

    # Compute embeddings using test_data's message-passing graph
    # (eval_model already runs under torch.no_grad()).
    h, pred, rmse = eval_model(test_data)

    foods = food_mapping.keys()
    food_embeddings = dict(zip(foods, h['Food']))

    # Write the embeddings once per round, to a file named after this round.
    # Format: a "<num_foods> <dim>" header line, then one line per food:
    # "<food_id> v1 v2 ... vdim " (space-separated, trailing space kept).
    embedding_dim = h['Food'].size(1)
    with open(f'../Output/k_fold/{vModel}_{i}_food_embeddings.txt', 'w', encoding='utf-8') as fw:
        fw.write(f'{len(foods)} {embedding_dim}\n')
        for food, vector in zip(foods, h['Food'].tolist()):
            fw.write(f'{food} ' + ''.join(f'{value} ' for value in vector) + '\n')

------------------------------------------------------------------
Epoch: 000, Loss: 0.8723, Train: 0.6578
Epoch: 010, Loss: 0.0173, Train: 0.0244
Epoch: 020, Loss: 0.0112, Train: 0.0547
Epoch: 030, Loss: 0.0155, Train: 0.0930
Epoch: 040, Loss: 0.0068, Train: 0.0713
Epoch: 050, Loss: 0.0044, Train: 0.0630
Epoch: 060, Loss: 0.0029, Train: 0.0491
Epoch: 070, Loss: 0.0025, Train: 0.0421
Epoch: 080, Loss: 0.0022, Train: 0.0379
Epoch: 090, Loss: 0.0019, Train: 0.0351
Epoch: 100, Loss: 0.0017, Train: 0.0324
Epoch: 110, Loss: 0.0015, Train: 0.0295
Epoch: 120, Loss: 0.0013, Train: 0.0268
Epoch: 130, Loss: 0.0012, Train: 0.0258
Epoch: 140, Loss: 0.0010, Train: 0.0244
Epoch: 150, Loss: 0.0009, Train: 0.0230
Epoch: 160, Loss: 0.0008, Train: 0.0220
Epoch: 170, Loss: 0.0008, Train: 0.0209
Epoch: 180, Loss: 0.0007, Train: 0.0200
Epoch: 190, Loss: 0.0006, Train: 0.0191
Epoch: 200, Loss: 0.0006, Train: 0.0183
------------------------------------------------------------------
Epoch: 000, Loss: 1.6669, 