In [1]:

from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
import torch_geometric.transforms as T

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'USING: {device}')
print(f'CUDA Version: {torch.version.cuda}')

USING: cpu
CUDA Version: None


# Import Data

In [2]:
rest = pd.read_feather('data/yelp_restaurants.feather')
reviews = pd.read_feather('data/yelp_restaurants_reviews.feather')

In [3]:
reviews['datetime'] = pd.to_datetime(reviews['date'])
reviews = reviews[reviews.datetime.dt.year == 2018].reset_index(drop=True)

In [4]:
rest = rest[rest.business_id.isin(reviews.business_id.unique())].reset_index(drop=True)
rest.drop(['address', 'state', 'postal_code', 'hours'], axis='columns', inplace=True)

In [5]:
rest.head()

Unnamed: 0,business_id,name,city,latitude,longitude,stars,review_count,is_open,attributes,categories
0,MTSW4McQd7CbVtyjqoe9mw,St Honore Pastries,Philadelphia,39.955505,-75.155564,4.0,80.0,1.0,"{'AcceptsInsurance': None, 'AgesAllowed': None...","Restaurants, Food, Bubble Tea, Coffee & Tea, B..."
1,CF33F8-E6oudUQ46HnavjQ,Sonic Drive-In,Ashland City,36.269593,-87.058943,2.0,6.0,1.0,"{'AcceptsInsurance': None, 'AgesAllowed': None...","Burgers, Fast Food, Sandwiches, Food, Ice Crea..."
2,bBDDEgkFA1Otx9Lfe7BZUQ,Sonic Drive-In,Nashville,36.208102,-86.76817,1.5,10.0,1.0,"{'AcceptsInsurance': None, 'AgesAllowed': None...","Ice Cream & Frozen Yogurt, Fast Food, Burgers,..."
3,eEOYSgkmpB90uNA7lDOMRA,Vietnamese Food Truck,Tampa Bay,27.955269,-82.45632,4.0,10.0,1.0,"{'AcceptsInsurance': None, 'AgesAllowed': None...","Vietnamese, Food, Restaurants, Food Trucks"
4,il_Ro8jwPlHresjw9EGmBg,Denny's,Indianapolis,39.637133,-86.127217,2.5,28.0,1.0,"{'AcceptsInsurance': None, 'AgesAllowed': None...","American (Traditional), Restaurants, Diners, B..."


In [6]:
# Extract unique restaurant and user IDs and create a dictionary mapping them to indices
restaurant_ids = rest['business_id'].unique().tolist()
num_restaurants = len(restaurant_ids)
restaurant_indices = dict(zip(restaurant_ids, range(num_restaurants)))

user_ids = reviews['user_id'].unique().tolist()
num_users = len(user_ids)
user_indices = dict(zip(user_ids, range(num_users)))

In [7]:
rest.insert(loc=1, column="mapped_business_id", value=rest.business_id.map(restaurant_indices))

In [8]:
reviews_merged = reviews.merge(rest[['business_id']], on='business_id', how='inner')

In [9]:
# Encode IDs
reviews_merged.insert(loc=3, column='mapped_user_id', value=reviews_merged.user_id.map(user_indices))
reviews_merged.insert(loc=4, column='mapped_business_id', value=reviews_merged.business_id.map(restaurant_indices))

In [10]:
reviews_merged.head()

Unnamed: 0,review_id,user_id,business_id,mapped_user_id,mapped_business_id,stars,useful,funny,cool,text,date,datetime
0,KU_O5udG6zpxOg-VcAEodg,mh_-eMZ6K5RLWhZyISBhwA,XQfwVwDr-v0ZS3_CbbE5Xw,0,388,3.0,0.0,0.0,0.0,"If you decide to eat here, just be aware it is...",2018-07-07 22:09:11,2018-07-07 22:09:11
1,uyS0ysaMd4mzw5rNYbgcjA,ql0XsKTjM7VeBAUqbphQDw,XQfwVwDr-v0ZS3_CbbE5Xw,1878,388,3.0,0.0,0.0,0.0,"Food is fantastic, service is quite awful! Ca...",2018-03-24 17:50:37,2018-03-24 17:50:37
2,R10wk4xEHX9r-qs5Z_2vvw,ZeBgfIMxp9K8OFmlXmQ3yA,XQfwVwDr-v0ZS3_CbbE5Xw,5139,388,3.0,0.0,0.0,0.0,Update: I deducted a star because they no long...,2018-07-21 09:26:33,2018-07-21 09:26:33
3,pDN3hRBarmGWXbK64A83MA,IBrReMAeZkVIbjZIe1E_Hw,XQfwVwDr-v0ZS3_CbbE5Xw,7057,388,1.0,0.0,0.0,0.0,never coming back here again. all of the glass...,2018-09-08 17:03:53,2018-09-08 17:03:53
4,HxWtq5q4OQ-4osStqn54bA,k4_8Cw2icH0nFV5MskGK1A,XQfwVwDr-v0ZS3_CbbE5Xw,12935,388,2.0,0.0,0.0,0.0,Unfortunately the weekend chef doesn't know ho...,2018-09-09 14:30:29,2018-09-09 14:30:29


## Restaurant features

In [11]:
rest = rest.sort_values(by='mapped_business_id')

In [12]:
cat_dummy = pd.Series(rest['categories']).str.get_dummies(', ')
cat_dummy

Unnamed: 0,Acai Bowls,Accessories,Accountants,Active Life,Acupuncture,Adult,Adult Education,Adult Entertainment,Advertising,Afghan,...,Wine & Spirits,Wine Bars,Wine Tasting Classes,Wine Tasting Room,Wine Tours,Wineries,Women's Clothing,Wraps,Yelp Events,Yoga
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31212,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31213,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31214,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31215,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [13]:
rest_features = pd.concat([rest[['stars', 'review_count']], cat_dummy], axis=1)
rest_features.drop(columns=['Restaurants', 'Food'],inplace=True,axis=1)

del cat_dummy

In [14]:
rest_features

Unnamed: 0,stars,review_count,Acai Bowls,Accessories,Accountants,Active Life,Acupuncture,Adult,Adult Education,Adult Entertainment,...,Wine & Spirits,Wine Bars,Wine Tasting Classes,Wine Tasting Room,Wine Tours,Wineries,Women's Clothing,Wraps,Yelp Events,Yoga
0,4.0,80.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,2.0,6.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1.5,10.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4.0,10.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2.5,28.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31212,3.0,11.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31213,4.0,33.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31214,4.5,35.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
31215,4.5,14.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [15]:
rest_feature_np = rest_features.to_numpy()

In [16]:
rest_feature_np.shape

(31217, 640)

In [17]:
rest_feature_np.shape[0] == num_restaurants

True

## User features

In [18]:
users = pd.read_feather('data/yelp_restaurants_user.feather')

In [19]:
users = users[users.user_id.isin(reviews_merged.user_id.unique())]

In [20]:
users = users.drop(columns='mapped_user_id') # mapped for all dataset 
users.insert(loc=1, column='mapped_user_id', value=users.user_id.map(user_indices)) # 2018 only

In [21]:
users.info()

<class 'pandas.core.frame.DataFrame'>
Index: 290714 entries, 1 to 1416814
Data columns (total 23 columns):
 #   Column              Non-Null Count   Dtype  
---  ------              --------------   -----  
 0   user_id             290714 non-null  object 
 1   mapped_user_id      290714 non-null  int64  
 2   name                290714 non-null  object 
 3   review_count        290714 non-null  int64  
 4   yelping_since       290714 non-null  object 
 5   useful              290714 non-null  int64  
 6   funny               290714 non-null  int64  
 7   cool                290714 non-null  int64  
 8   elite               290714 non-null  object 
 9   friends             290714 non-null  object 
 10  fans                290714 non-null  int64  
 11  average_stars       290714 non-null  float64
 12  compliment_hot      290714 non-null  int64  
 13  compliment_more     290714 non-null  int64  
 14  compliment_profile  290714 non-null  int64  
 15  compliment_cute     290714 non-null  i

In [22]:
users.describe()

Unnamed: 0,mapped_user_id,review_count,useful,funny,cool,fans,average_stars,compliment_hot,compliment_more,compliment_profile,compliment_cute,compliment_list,compliment_note,compliment_plain,compliment_cool,compliment_funny,compliment_writer,compliment_photos
count,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0,290714.0
mean,145356.5,37.403194,76.668485,29.750019,47.812568,2.566632,3.732474,2.872934,0.420878,0.259403,0.131721,0.071785,2.761071,5.726195,4.861912,4.861912,1.792249,2.542031
std,83922.047416,126.241182,1084.978975,681.073432,952.948519,26.783882,1.013966,116.071651,10.829016,14.695337,7.641787,6.10344,138.631717,229.148695,147.173501,147.173501,44.14529,134.122374
min,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,72678.25,4.0,1.0,0.0,0.0,0.0,3.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,145356.5,10.0,5.0,1.0,1.0,0.0,3.93,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,218034.75,27.0,20.0,4.0,5.0,1.0,4.47,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,290713.0,17473.0,206296.0,185823.0,195814.0,3642.0,5.0,25784.0,4138.0,6411.0,2974.0,2413.0,59031.0,101097.0,49967.0,49967.0,15934.0,56104.0


In [23]:
users = users.sort_values(by='mapped_user_id')

In [24]:
user_features = users[['review_count', 'useful', 'funny', 'cool', 'fans',]]

In [25]:
user_features

Unnamed: 0,review_count,useful,funny,cool,fans
114135,33,32,3,8,0
75067,240,264,17,176,5
51580,84,75,7,38,5
80878,12,0,0,0,0
94400,156,63,20,31,3
...,...,...,...,...,...
522493,43,34,1,13,1
1389220,3,0,0,0,0
483420,4,10,0,1,0
1395021,1,0,0,0,0


In [26]:
user_features_np = user_features.to_numpy()
user_features_np.shape

(290714, 5)

In [27]:
user_features_np.shape[0] == num_users

True

# Graph creation

In [28]:
edge_index = torch.stack(
    [torch.from_numpy(reviews_merged.mapped_user_id.values), torch.from_numpy(reviews_merged.mapped_business_id.values)],
    dim=0
)

data = HeteroData()

data['user'].node_id = torch.arange(len(reviews_merged.mapped_user_id.unique()))
data['user'].x = torch.from_numpy(user_features_np)

data['restaurant'].node_id = torch.arange(len(reviews_merged.mapped_business_id.unique()))
data['restaurant'].x = torch.from_numpy(rest_feature_np)

data['user', 'rating', 'restaurant'].edge_index = edge_index
data['user', 'rating', 'restaurant'].edge_label = torch.from_numpy(reviews_merged.stars.values)

data = T.ToUndirected()(data)
del data['restaurant', 'rev_rating', 'user'].edge_label 

In [29]:
data

HeteroData(
  [1muser[0m={
    node_id=[290714],
    x=[290714, 5]
  },
  [1mrestaurant[0m={
    node_id=[31217],
    x=[31217, 640]
  },
  [1m(user, rating, restaurant)[0m={
    edge_index=[2, 596895],
    edge_label=[596895]
  },
  [1m(restaurant, rev_rating, user)[0m={ edge_index=[2, 596895] }
)

In [31]:
del rest, rest_features, rest_feature_np, users, user_features, user_features_np

In [32]:
del reviews, reviews_merged

In [33]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=0.0,
    add_negative_train_samples=False,
    edge_types=("user", "rating", "restaurant"),
    rev_edge_types=("restaurant", "rev_rating", "user"), 
)

train_data, val_data, test_data = transform(data)

In [37]:
print('Train data:', train_data)
print('Val data:', val_data)
print('Test data', test_data)

Train data: HeteroData(
  [1muser[0m={
    node_id=[290714],
    x=[290714, 5]
  },
  [1mrestaurant[0m={
    node_id=[31217],
    x=[31217, 640]
  },
  [1m(user, rating, restaurant)[0m={
    edge_index=[2, 334262],
    edge_label=[143255],
    edge_label_index=[2, 143255]
  },
  [1m(restaurant, rev_rating, user)[0m={ edge_index=[2, 334262] }
)
Val data: HeteroData(
  [1muser[0m={
    node_id=[290714],
    x=[290714, 5]
  },
  [1mrestaurant[0m={
    node_id=[31217],
    x=[31217, 640]
  },
  [1m(user, rating, restaurant)[0m={
    edge_index=[2, 477517],
    edge_label=[59689],
    edge_label_index=[2, 59689]
  },
  [1m(restaurant, rev_rating, user)[0m={ edge_index=[2, 477517] }
)
Test data HeteroData(
  [1muser[0m={
    node_id=[290714],
    x=[290714, 5]
  },
  [1mrestaurant[0m={
    node_id=[31217],
    x=[31217, 640]
  },
  [1m(user, rating, restaurant)[0m={
    edge_index=[2, 537206],
    edge_label=[59689],
    edge_label_index=[2, 59689]
  },
  [1m(restauran

In [34]:
# Define the training seed changes
edge_label_index = train_data["user", "rating", "restaurant"].edge_label_index
edge_label = train_data["user", "rating", "restaurant"].edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    edge_label_index=(("user", "rating", "restaurant"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

# Define the validation seed edges:
edge_label_index = val_data["user", "rating", "restaurant"].edge_label_index
edge_label = val_data["user", "rating", "restaurant"].edge_label
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 10],
    edge_label_index=((["user", "rating", "restaurant"]), edge_label_index),
    edge_label=edge_label,
    batch_size=3 * 128,
    shuffle=False,
)
sampled_data = next(iter(val_loader))

In [35]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

# Our final regressor applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Regressor(torch.nn.Module):
    def forward(self, x_user, x_restaurant, edge_label_index):
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_restaurant = x_restaurant[edge_label_index[1]]
        # Apply dot-product to get a prediction per supervision edge:
        embeds = (edge_feat_user * edge_feat_restaurant).sum(dim=-1)
        return embeds

class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and items:
        self.restaurant_lin = torch.nn.Linear(20, hidden_channels)
        self.user_emb = torch.nn.Embedding(data["user"].num_nodes, hidden_channels)
        self.restaurant_emb = torch.nn.Embedding(data["restaurant"].num_nodes, hidden_channels)
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = Regressor()

    def forward(self, data: HeteroData):
        x_dict = {
          "user": self.user_emb(data["user"].node_id),
          "restaurant": self.restaurant_emb(data["restaurant"].node_id),
        } 

        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["user"],
            x_dict["restaurant"],
            data["user", "rating", "restaurant"].edge_label_index,
        )
        return pred


In [36]:
model = Model(hidden_channels=32)
model

Model(
  (restaurant_lin): Linear(in_features=20, out_features=32, bias=True)
  (user_emb): Embedding(290714, 32)
  (restaurant_emb): Embedding(31217, 32)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (user__rating__restaurant): SAGEConv(32, 32, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv(32, 32, aggr=mean)
    )
    (dropout): ModuleDict(
      (user): Dropout(p=0.5, inplace=False)
      (restaurant): Dropout(p=0.5, inplace=False)
    )
    (conv2): ModuleDict(
      (user__rating__restaurant): SAGEConv(32, 32, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv(32, 32, aggr=mean)
    )
  )
  (classifier): Regressor()
)

In [49]:
# Training code adapted to include in-training validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

min_val_loss = np.inf
min_train_loss = np.inf
prev_val_counts = 0
flag = False
best_base_model = None
n_epoch = 10

for epoch in range(1, n_epoch + 1):
    if flag:
        break
    total_loss = total_examples = 0
    for sampled_data in tqdm(train_loader):
        optimizer.zero_grad()
        sampled_data.to(device)
        pred = model(sampled_data)
        ground_truth = sampled_data["user", "rating", "restaurant"].edge_label
        loss = F.mse_loss(pred,ground_truth.float())
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    avg_train_loss = total_loss / total_examples
    print(f"Epoch: {epoch:03d}, Training Loss: {avg_train_loss:.4f}")

    # Calculate validation loss and stop when our loss starts increasing
    total_val_loss = total_val_examples = 0
    model.eval()
    for sampled_data in tqdm(val_loader):
        with torch.no_grad():
            sampled_data.to(device)
            pred = torch.clamp(model(sampled_data), min = 1, max = 5) # Ratings should be minimum 1 and maximum 5
            ground_truth = sampled_data["user", "rating", "restaurant"].edge_label
            loss = F.mse_loss(pred,ground_truth.float())
            total_val_loss += float(loss) * pred.numel()
            total_val_examples += pred.numel()
    avg_val_loss = total_val_loss / total_val_examples
    print(f"Epoch: {epoch:03d}, Validation Loss: {avg_val_loss:.4f}")

    if avg_val_loss > min_val_loss:
        prev_val_counts+=1
        if prev_val_counts==10:
            flag = True
    else:
        min_val_loss = avg_val_loss
        min_train_loss = avg_train_loss
        prev_val_counts =0
        best_base_model = model

Device: 'cpu'


100%|██████████| 1120/1120 [05:32<00:00,  3.37it/s]


Epoch: 001, Training Loss: 2.5824


100%|██████████| 156/156 [00:34<00:00,  4.50it/s]


Epoch: 001, Validation Loss: 2.0591


100%|██████████| 1120/1120 [05:32<00:00,  3.37it/s]


Epoch: 002, Training Loss: 1.9456


100%|██████████| 156/156 [00:35<00:00,  4.36it/s]


Epoch: 002, Validation Loss: 1.9753


100%|██████████| 1120/1120 [05:33<00:00,  3.36it/s]


Epoch: 003, Training Loss: 1.7464


100%|██████████| 156/156 [00:37<00:00,  4.15it/s]


Epoch: 003, Validation Loss: 1.9674


100%|██████████| 1120/1120 [05:45<00:00,  3.24it/s]


Epoch: 004, Training Loss: 1.4850


100%|██████████| 156/156 [00:35<00:00,  4.42it/s]


Epoch: 004, Validation Loss: 2.0044


100%|██████████| 1120/1120 [05:17<00:00,  3.53it/s]


Epoch: 005, Training Loss: 1.2079


100%|██████████| 156/156 [00:33<00:00,  4.59it/s]


Epoch: 005, Validation Loss: 2.0773


100%|██████████| 1120/1120 [05:38<00:00,  3.31it/s]


Epoch: 006, Training Loss: 0.9362


100%|██████████| 156/156 [00:38<00:00,  4.03it/s]


Epoch: 006, Validation Loss: 2.1655


100%|██████████| 1120/1120 [05:30<00:00,  3.39it/s]


Epoch: 007, Training Loss: 0.7088


100%|██████████| 156/156 [00:32<00:00,  4.87it/s]


Epoch: 007, Validation Loss: 2.2599


100%|██████████| 1120/1120 [04:53<00:00,  3.81it/s]


Epoch: 008, Training Loss: 0.5409


100%|██████████| 156/156 [00:31<00:00,  4.92it/s]


Epoch: 008, Validation Loss: 2.2824


100%|██████████| 1120/1120 [04:54<00:00,  3.80it/s]


Epoch: 009, Training Loss: 0.4186


100%|██████████| 156/156 [00:32<00:00,  4.87it/s]


Epoch: 009, Validation Loss: 2.4002


100%|██████████| 1120/1120 [05:24<00:00,  3.45it/s]


Epoch: 010, Training Loss: 0.3321


100%|██████████| 156/156 [00:35<00:00,  4.38it/s]

Epoch: 010, Validation Loss: 2.3475





In [50]:
print(f"Validation Loss: {min_val_loss:.4f}, Training Loss: {min_train_loss:.4f}")

Validation Loss: 1.9674, Training Loss: 1.7464


In [51]:
# Calculate test loss
test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[20, 10],
    edge_label_index=((["user", "rating", "restaurant"]), edge_label_index),
    edge_label=edge_label,
    batch_size=3 * 128,
    shuffle=False,
)

total_test_loss = total_test_examples = 0
best_base_model.eval()
for sampled_data in tqdm(test_loader):
    with torch.no_grad():
        sampled_data.to(device)
        pred = torch.clamp(best_base_model(sampled_data), min = 1, max = 5) # Ratings should be minimum 1 and maximum 5
        
        ground_truth = sampled_data["user", "rating", "restaurant"].edge_label
        loss = F.mse_loss(pred,ground_truth.float())
        total_test_loss += float(loss) * pred.numel()
        total_test_examples += pred.numel()
avg_test_loss = total_test_loss / total_test_examples
print(f"Test Loss: {avg_test_loss:.4f}")

100%|██████████| 156/156 [00:37<00:00,  4.14it/s]

Test Loss: 2.4504





In [52]:
best_base_model

Model(
  (restaurant_lin): Linear(in_features=20, out_features=64, bias=True)
  (user_emb): Embedding(290714, 64)
  (restaurant_emb): Embedding(31217, 64)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (user__rating__restaurant): SAGEConv(64, 64, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv(64, 64, aggr=mean)
    )
    (dropout): ModuleDict(
      (user): Dropout(p=0.5, inplace=False)
      (restaurant): Dropout(p=0.5, inplace=False)
    )
    (conv2): ModuleDict(
      (user__rating__restaurant): SAGEConv(64, 64, aggr=mean)
      (restaurant__rev_rating__user): SAGEConv(64, 64, aggr=mean)
    )
  )
  (classifier): Regressor()
)

In [54]:
torch.save(model.state_dict(), 'output/model-hetero-node-features.pth')
torch.save(best_base_model.state_dict(), 'output/best-model-hetero-node-features.pth')