# Explainable Graph Neural Network Recommendation

This notebook will use the tutorial notebook for working with Graph Neural Networks for Link Regression using PyTorch Geometric provided in the documentation: https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html

Applying the library and model instead to the Anime Recommendations Dataset provided: https://www.kaggle.com/datasets/CooperUnion/anime-recommendations-database

In [None]:
import torch
from torch import Tensor
print(torch.__version__)

In [None]:
# Install required packages
import os

os.environ['TORCH'] = torch.__version__
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

!pip install sentence_transformers
!pip install captum

In [None]:
# packages
!pip3 install fuzzywuzzy[speedup]


## Loading Data / EDA

This section examines the data provided from Kaggle. The data is loaded and a breakdown is given of different genres and types of anime.

There is also an indication of average rating across these features.

In [5]:
import pandas as pd

directory_path = 'path'
anime = pd.read_csv(directory_path+'anime.csv')
ratings = pd.read_csv(directory_path + 'rating.csv')
ratings = ratings[ratings['rating']>=0]
anime.head()

Unnamed: 0,anime_id,name,genre,type,episodes,rating,members
0,32281,Kimi no Na wa.,"Drama, Romance, School, Supernatural",Movie,1,9.37,200630
1,5114,Fullmetal Alchemist: Brotherhood,"Action, Adventure, Drama, Fantasy, Magic, Mili...",TV,64,9.26,793665
2,28977,Gintama°,"Action, Comedy, Historical, Parody, Samurai, S...",TV,51,9.25,114262
3,9253,Steins;Gate,"Sci-Fi, Thriller",TV,24,9.17,673572
4,9969,Gintama&#039;,"Action, Comedy, Historical, Parody, Samurai, S...",TV,51,9.16,151266


In [6]:
ratings = ratings[ratings['user_id'].isin(ratings['user_id'].unique()[:6000])].reset_index(drop=True)
anime = anime[anime['anime_id'].isin(ratings['anime_id'])].reset_index(drop=True)

In [7]:
from fuzzywuzzy import fuzz

# Specify your userId
our_user_id = 999999999

print('Most rated anime:')
print('==================')
most_rated_anime = ratings['anime_id'].value_counts().head(20)
print(anime[anime['anime_id'].isin(most_rated_anime)]['name'])

# Initialize your rating list
my_ratings = []

Most rated anime:
399                                 Byousoku 5 Centimeter
433                               Saiyuuki Reload: Burial
453                               Flanders no Inu (Movie)
1363                City Hunter: Hyakuman Dollar no Inbou
1623                   Legendz: Yomigaeru Ryuuou Densetsu
2243                                    Black Cat Special
2602                         Kindaichi Shounen no Jikenbo
2769                                   Rosario to Vampire
2966         Urusei Yatsura Movie 6: Itsudatte My Darling
3060                                      Queen Emeraldas
3100         Pokemon Crystal: Raikou Ikazuchi no Densetsu
3236    Mobile Suit Gundam MS IGLOO: The Hidden One Ye...
4428                                Kaze no Na wa Amnesia
4560                  Pokemon: Senritsu no Mirage Pokemon
4688                                   To Heart 2 Special
5112                                               Wild 7
6075                             Hyper-Psychic Geo Gar

In [None]:
# Add your ratings here:
num_ratings = 5
while len(my_ratings) < num_ratings:
    print(f'Select the {len(ratings) + 1}. movie:')
    print('=====================================')
    movie = input('Please enter the anime title: ')
    anime['name_score'] = anime['name'].apply(lambda x: fuzz.ratio(x, movie))
    print(anime.sort_values('name_score', ascending=False)[['name']].head(5))
    movie_id = input('Please enter the movie id: ')
    if not movie_id:
        continue
    movie_id = int(movie_id)
    rating = float(input('Please enter your rating: '))
    if not rating:
        continue
    assert 0 <= rating <= 10
    my_ratings.append({'anime_id': movie_id, 'rating': rating, 'user_id': our_user_id})
print('Complete')

In [9]:
ratings = pd.concat([ratings, pd.DataFrame.from_records(my_ratings)])

In [10]:
import plotly.express as px

rating_aggregate_type = anime.groupby('type').rating.median().reset_index()
px.bar(rating_aggregate_type, 'type', 'rating', template='simple_white', title = 'Average Rating by Type')

In [11]:
# find the top ten genres
anime.dropna(subset='genre', inplace=True)
genre_dict = {}
for i, r in anime.iterrows():
  genres = r['genre'].split(', ')
  for g in genres:
    if genre_dict.get(g) is None:
      genre_dict.update({g: 1})
    else:
      genre_dict.update({g: genre_dict[g]+1})

top_ten_genres = sorted(genre_dict, key=genre_dict.get, reverse=True)[:10]

for i, r in anime.iterrows():
  genres = r['genre'].split(', ')
  for genre in genres:
    if genre in top_ten_genres:
      if genre not in anime.columns:
        anime[genre] = 0
      anime.loc[i, genre] = 1

genre_vis = pd.melt(anime, id_vars = ['anime_id', 'name', 'genre', 'type', 'episodes', 'rating', 'members'])
genre_vis = genre_vis[genre_vis['value']>0]
genre_vis = genre_vis.groupby('variable').rating.median().reset_index()
px.bar(genre_vis, 'variable', 'rating', template='simple_white', title='Average rating by genre')

In [12]:
import numpy as np
anime = anime[anime['episodes']!='Unknown']
anime['episodes'] = anime['episodes'].astype(np.int64)
px.scatter(anime, 'episodes', 'rating', template='simple_white', title='Rating by number of episodes')

# Feature Engineering

The features selected above were:

These features are now converted into numeric representations using either the get dummies function or a custom function with the same encoding behaviour.

An additional feature is derived which is a sentence transformed embedding representation of the title of the Anime.

In [13]:
types = anime['type'].str.get_dummies('|') # simple dummies
episodes = anime['episodes'] # simple mapping

In [14]:
genres = anime[top_ten_genres] # extract only the genres based on the visualisation cell above

In [15]:
# convert to numpy from pandas
genres = genres.to_numpy()
types = types.to_numpy()
episodes = episodes.to_numpy()
episodes = episodes.reshape((len(episodes), 1))

# convert to tensors
genres = torch.from_numpy(genres).to(torch.float)
types = torch.from_numpy(types).to(torch.float)
episodes = torch.from_numpy(episodes).to(torch.float)

In [16]:
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# encode anime title using pre-trained transformer
model = SentenceTransformer('all-MiniLM-L6-v2')
with torch.no_grad():
    titles = model.encode(anime['name'].tolist(), convert_to_tensor=True, show_progress_bar=True)
    titles = titles.cpu()

Batches:   0%|          | 0/224 [00:00<?, ?it/s]

In [17]:
# concatenate all of the features
anime_features = torch.cat([genres, types, titles, episodes], dim=-1)

In [18]:
# there are no user features in this model
user_features = torch.eye(len(ratings['user_id'].unique()))

# Creating Graph Data Structure

For the network to learn the connections between users and anime it requires a mapping of user features to anime features as well as a representation of the edges between anime and users.

The graph data structure is then what is passed to the neural network. In this case a Heterogenous graph as there are both users and anime as nodes.

In [19]:
# find unique users and make sequential mapping
unique_user_id = ratings['user_id'].unique()
unique_user_id = pd.DataFrame(data={
    'user_id': unique_user_id,
    'mappedUserId': pd.RangeIndex(len(unique_user_id))
    })

# find unique anime and make sequential mapping
unique_anime_id = anime['anime_id'].unique()
unique_anime_id = pd.DataFrame(data={
    'anime_id': unique_anime_id,
    'mappedAnimeID': pd.RangeIndex(len(unique_anime_id))
    })


# add the mappings back to the original data
ratings = ratings.merge(unique_user_id, on='user_id')
ratings = ratings.merge(unique_anime_id, on='anime_id')

# With this, we are ready to create the edge_index representation in COO format
# following the PyTorch Geometric semantics:
edge_index = torch.stack([
    torch.tensor(ratings['mappedUserId'].values),
    torch.tensor(ratings['mappedAnimeID'].values)]
    , dim=0)

In [20]:
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

# initialise obj
data = HeteroData()

# add users
data['user'].x = user_features

# add anime
data['anime'].x = anime_features

# add ratings
data['user', 'rates', 'anime'].edge_index = edge_index

# add labels to ratings (rating itself)
rating = torch.from_numpy(ratings['rating'].values).to(torch.float)
data['user', 'rates', 'anime'].edge_label = rating

# add connections from anime to users (undirected connections)
# so that messages can be passed between them
data = T.ToUndirected()(data)

# remove edge labels from reverse ratings
del data['anime', 'rev_rates', 'user'].edge_label

In [21]:
data

HeteroData(
  user={ x=[6001, 6001] },
  anime={ x=[7156, 401] },
  (user, rates, anime)={
    edge_index=[2, 539133],
    edge_label=[539133],
  },
  (anime, rev_rates, user)={ edge_index=[2, 539133] }
)

In [22]:
# split into training/validation/test data
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'anime')],
    rev_edge_types=[('anime', 'rev_rates', 'user')],
)(data)
train_data, val_data

(HeteroData(
   user={ x=[6001, 6001] },
   anime={ x=[7156, 401] },
   (user, rates, anime)={
     edge_index=[2, 431307],
     edge_label=[431307],
     edge_label_index=[2, 431307],
   },
   (anime, rev_rates, user)={ edge_index=[2, 431307] }
 ),
 HeteroData(
   user={ x=[6001, 6001] },
   anime={ x=[7156, 401] },
   (user, rates, anime)={
     edge_index=[2, 431307],
     edge_label=[53913],
     edge_label_index=[2, 53913],
   },
   (anime, rev_rates, user)={ edge_index=[2, 431307] }
 ))

# Build and Train Graph Neural Network

The graph neural network is built using PyTorch Geometric and is split into three elements:
* GNN Encoder - Use SageConv layers to encode graph features with message passing.
* The Edge decoder - Use Linear layers to decode the edges and predict the ratings.
* Model class - combines the above two objects and defines overall training process of encoding/decoding.

In [23]:
from torch_geometric.nn import SAGEConv, to_hetero

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # layers to encode the edges and nodes using message passing
        self.conv1 = SAGEConv((-1, -1), hidden_channels) # first layer
        self.conv2 = SAGEConv((-1, -1), out_channels) # second layer

    def forward(self, x, edge_index):
        # action layers on data (forward pass)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # layers to take the encoded layers and decode to predict rating
        # (edge weight)
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        # concatenate the user and anime features for linked nodes
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['anime'][col]], dim=-1)

        # action the layers to predict the rating
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # set up the encoder and decoder
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        # perform the entire forward pass
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model(hidden_channels=32).to(device)

print(model)

Model(
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__anime): SAGEConv((-1, -1), 32, aggr=mean)
      (anime__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__anime): SAGEConv((-1, -1), 32, aggr=mean)
      (anime__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=64, out_features=32, bias=True)
    (lin2): Linear(in_features=32, out_features=1, bias=True)
  )
)


In [24]:
import torch.nn.functional as F

# define training and testing functions
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train() # train model step
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'anime'].edge_label_index) # make pred
    target = train_data['user', 'anime'].edge_label # find target
    loss = F.mse_loss(pred, target) # find loss
    loss.backward() # compute gradient loss
    optimizer.step() # update parameters
    return float(loss)

@torch.no_grad()
def test(data):
    data = data.to(device)
    model.eval() # put model into evaluation mode
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'anime'].edge_label_index) # make prediction
    pred = pred.clamp(min=0, max=10) # clamp prediction between min and max of ratings
    target = data['user', 'anime'].edge_label.float() # get target
    rmse = F.mse_loss(pred, target).sqrt() # find loss (RMSE)
    return float(rmse)

In [25]:
# train the model for 300 epochs
for epoch in range(1, 300):
    train_data = train_data.to(device)
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}')

Epoch: 001, Loss: 60.5161, Train: 7.0867, Val: 7.0965
Epoch: 002, Loss: 50.2228, Train: 5.9069, Val: 5.9164
Epoch: 003, Loss: 34.9201, Train: 3.8235, Val: 3.8338
Epoch: 004, Loss: 15.3794, Train: 1.9096, Val: 1.9068
Epoch: 005, Loss: 12.8251, Train: 2.3073, Val: 2.2986
Epoch: 006, Loss: 19.9117, Train: 1.9299, Val: 1.9264
Epoch: 007, Loss: 8.9826, Train: 2.0599, Val: 2.0665
Epoch: 008, Loss: 5.1493, Train: 2.8594, Val: 2.8763
Epoch: 009, Loss: 8.3864, Train: 3.1057, Val: 3.1272
Epoch: 010, Loss: 9.7779, Train: 2.8202, Val: 2.8428
Epoch: 011, Loss: 8.1102, Train: 2.2086, Val: 2.2292
Epoch: 012, Loss: 5.1495, Train: 1.7307, Val: 1.7411
Epoch: 013, Loss: 3.5889, Train: 1.8725, Val: 1.8687
Epoch: 014, Loss: 4.8727, Train: 2.0639, Val: 2.0563
Epoch: 015, Loss: 6.1705, Train: 1.9279, Val: 1.9221
Epoch: 016, Loss: 5.0066, Train: 1.6819, Val: 1.6841
Epoch: 017, Loss: 3.4088, Train: 1.7133, Val: 1.7247
Epoch: 018, Loss: 3.1877, Train: 1.9314, Val: 1.9478
Epoch: 019, Loss: 3.8635, Train: 2.0608,

# Evaluate and Explain the model

Now let's evaluate the model and get some predictions for ourselves.

We will also make the model explain it's decision to see how it arrived at it's conclusion.

In [26]:
with torch.no_grad():
    # get test data and make clamped predictions
    test_data = test_data.to(device)
    pred = model(test_data.x_dict, test_data.edge_index_dict,
                 test_data['user', 'anime'].edge_label_index)
    pred = pred.clamp(min=0, max=10)

    # find the loss
    target = test_data['user', 'anime'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    print(f'Test RMSE: {rmse:.4f}')

# print the prediction for the first user
userId = test_data['user', 'anime'].edge_label_index[0].cpu().numpy()
movieId = test_data['user', 'anime'].edge_label_index[1].cpu().numpy()
pred = pred.cpu().numpy()
target = target.cpu().numpy()

test_res = pd.DataFrame({'userId': userId, 'animeID': movieId, 'rating': pred, 'target': target})

Test RMSE: 1.2333


In [27]:
test_res['rating'] = test_res.apply(lambda row: round(row['rating'], 0), axis=1)
test_res_mat = test_res.groupby(['rating', 'target']).animeID.count().reset_index().sort_values('target')
px.bar(test_res_mat, 'rating', 'animeID', facet_col = 'target', template='seaborn')