# Creating a simple GNN Music Recommender System

## Loading Data

Data is from #nowplaying dataset. 

In [38]:
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import GroupShuffleSplit # used for splitting data 
from torch_geometric.data import Data


In [39]:
# Loading data
data_dir = os.path.normpath(r'D:\Projects\cs224-multimodal-recommender-system\processed_data\nowplaying')

train = pd.read_parquet(
    os.path.join(data_dir, 'session_candidates_train.parquet')
)

test = pd.read_parquet(
    os.path.join(data_dir, 'session_candidates_test.parquet')
)

features = pd.read_parquet(
    os.path.join(data_dir, 'Lyrics_HSP-L_Nowplay_Data.parquet')
)

track_map = pd.read_parquet(
    os.path.join(data_dir, 'track_id_map.parquet')
).set_index(['Artist', 'Title'])


## view data

In [40]:
track_features = (features
                  .drop_duplicates(subset= ['Artist', 'Title'], keep = 'first')
                  .drop(columns = ['user id', 'source of the tweet', 'track title', 'artist name', ])
                  )
print('Initial number of tracks:', track_features.shape[0])
track_features = track_features.merge(track_map, on = ['Artist', 'Title'], how = 'inner')
track_features = (track_features
                  .set_index('track_id')
                  .lyrics_embedding
                  .apply(lambda x: pd.Series(x))
                  )

print('Number of tracks post-inner join:', track_features.shape[0])
print('Track features')
print(79*'-')
display(track_features.head())

train_user_listening_history = train.loc[:, ['user', 'past_interactions']]
print('Train: User listening history')
print(79*'-')
display(train_user_listening_history.head())


train_user_item_interactions = train.loc[:, ['user', 'positive_song_id']]
print('Train: User-item interactions')
print(79*'-')
display(train_user_item_interactions.head())

train_user_item_all_interactions = train.loc[: , ['user', 'candidates']].explode('candidates').rename(columns = {'candidates': 'song_id'})
print('Train: User-item negative + positive interactions')
print(79*'-')
display(train_user_item_all_interactions.head())

train_track_features = track_features.loc[train_user_item_all_interactions.song_id.unique(), :]



test_user_listening_history = test.loc[:, ['user', 'past_interactions']]
print('Test: User listening history')
print(79*'-')
display(test_user_listening_history.head())


test_user_item_interactions = test.loc[:, ['user', 'positive_song_id']]
print('Test: User-item interactions')
print(79*'-')
display(test_user_item_interactions.head())

test_user_item_all_interactions = test.loc[: , ['user', 'candidates']].explode('candidates').rename(columns = {'candidates': 'song_id'})
print('Test: User-item negative + positive interactions')
print(79*'-')
display(test_user_item_all_interactions.head())

test_track_features = track_features.loc[test_user_item_all_interactions.song_id.unique(), :]

# positive interaction tuples
train_positive_interactions_set = set(
    train_user_item_interactions
    .drop_duplicates()
    .apply(lambda x: (x['user'], x['positive_song_id']), axis = 1)
    .to_list()
)
test_positive_interactions_set = set(
    test_user_item_interactions
    .drop_duplicates()
    .apply(lambda x: (x['user'], x['positive_song_id']), axis = 1)
    .to_list()
)



Initial number of tracks: 2471
Number of tracks post-inner join: 2155
Track features
-------------------------------------------------------------------------------


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,-0.016555,0.016258,0.032462,-0.065711,0.035892,-0.030516,0.04403,0.046941,0.031009,-0.051097,...,0.137975,-0.032969,-0.178082,0.032461,0.081872,0.018873,0.134802,-0.184255,-0.02097,-0.009385
2,-0.043866,0.079759,0.031227,-0.009023,0.198229,-0.117961,0.023687,0.126027,0.002055,-0.055694,...,0.082887,-0.094865,-0.103203,0.029579,0.168657,0.054824,0.156257,0.036643,-0.045506,-0.051694
3,0.040213,0.071726,-0.022398,-0.148623,0.250008,-0.033887,0.026806,0.242937,0.030847,0.016596,...,0.093019,-0.006108,-0.10845,0.160098,0.071841,0.118777,0.233991,0.042564,-0.028533,-0.093544
4,-0.009909,-0.046708,0.058077,0.009227,0.064153,-0.003633,0.027486,0.102278,0.117365,-0.024824,...,0.10089,-0.193991,-0.133274,0.1081,0.134432,0.114251,0.241544,-0.116317,-0.074772,-0.075406
5,0.003207,0.011662,0.044229,-0.107282,0.164013,-0.027828,0.0179,0.089839,0.104875,-0.010808,...,0.083276,-0.045836,-0.046923,0.081197,0.055439,0.086655,0.196141,-0.123199,-0.025011,0.013837


Train: User listening history
-------------------------------------------------------------------------------


Unnamed: 0,user,past_interactions
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[96, 46, 46, 46, 46, 232, 325, 232, 125, 325]"
1,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 46, 46, 232, 325, 232, 125, 325, 125]"
2,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 46, 232, 325, 232, 125, 325, 125, 232]"
3,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 232, 325, 232, 125, 325, 125, 232, 46]"
4,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 232, 325, 232, 125, 325, 125, 232, 46, 96]"


Train: User-item interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,positive_song_id
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,125
1,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,232
2,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,46
3,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,96
4,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,325


Train: User-item negative + positive interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,song_id
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,278
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,1690
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,366
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,1984
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,795


Test: User listening history
-------------------------------------------------------------------------------


Unnamed: 0,user,past_interactions
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[595, 84, 152, 614, 688, 48, 34, 191, 1362, 171]"
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[84, 152, 614, 688, 48, 34, 191, 1362, 171, 111]"
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[152, 614, 688, 48, 34, 191, 1362, 171, 111, 196]"
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[614, 688, 48, 34, 191, 1362, 171, 111, 196, 152]"
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[688, 48, 34, 191, 1362, 171, 111, 196, 152, 643]"


Test: User-item interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,positive_song_id
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,111
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,196
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,152
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,643
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,144


Test: User-item negative + positive interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,song_id
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,427
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,46
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,995
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,1350
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,1548


In [41]:
import torch

def prepare_graph(user_item_interactions, user_item_all_interactions , item_features, positive_interactions_set):
    # item feature manipulation 
    # get user id and item id
    user_ids = list(set([ui for ui, ii in user_item_all_interactions]))
    item_ids = item_features.index
    
    # create map for users and items
    user_id_map = {uid: idx for idx, uid in enumerate(user_ids)}
    num_users = len(user_ids)
    item_id_map = {iid: idx + num_users for idx, iid in enumerate(item_ids)}
    print('Number of users: ', num_users)
    num_items = len(item_ids)
    print('Number of items: ', num_items)
    
    
    num_features = item_features.shape[1]
    num_nodes = num_users + num_items
    x = torch.zeros(num_nodes, num_features)
    for iid, idx in item_id_map.items():
        x[idx] = torch.tensor(item_features.loc[iid, :].values, dtype = torch.float)
    # for any indexed value in x, if all the values in the row are nan, then convert it to a zero row
    nan_rows = torch.isnan(x).all(dim=1)  # Find rows where all elements are NaN
    x[nan_rows] = 0
    
    edge_index = []
    for uid, iid in user_item_interactions:
        u_idx = user_id_map[uid]
        i_idx = item_id_map[iid]
        # add undirected edges
        # TODO: Undirected edges is a major caveat of this architecture 
        #  However it is necessary for us to consider this style of architecture because
        #  otherwise the maximum depth of the GNN is 1-hop (user -> item x ) [item has no outgoing nodes]
        # Add edges in both directions 
        edge_index.append([u_idx, i_idx])
        edge_index.append([i_idx, u_idx])
    edge_index = torch.tensor(edge_index).t().contiguous()
    # Masks for users and items
    user_mask = torch.zeros(num_users + num_items, dtype=torch.bool)
    item_mask = torch.zeros(num_users + num_items, dtype=torch.bool)
    user_mask[:num_users] = True
    item_mask[num_users:] = True    # Create data object
    data = Data(x=x, edge_index=edge_index)
    data.user_mask = user_mask
    data.item_mask = item_mask
    # positive edge index
    all_sample_interactions = []
    for uid, iid in user_item_all_interactions:
        if iid in item_id_map:
            u_idx = user_id_map[uid]
            i_idx = item_id_map[iid]
            all_sample_interactions.append([u_idx, i_idx])
        else:
            print(f'{iid} not in item_id_map')
    data.all_sample_index = torch.tensor(all_sample_interactions, dtype = torch.long).t().contiguous()



    # for a user in user_item_all_interactions, 
    
    labels = []
    for uid, iid in user_item_all_interactions:
        if (uid, iid) in positive_interactions_set:
            labels.append(1)
        else:
            labels.append(0)
    data.labels = torch.tensor(labels, dtype=torch.float)
    data.edge_users = [user_id_map[uid] for uid, iid in user_item_all_interactions]
    data.edge_items = [item_id_map[iid] for uid, iid in user_item_all_interactions]
    
    return  data, (item_id_map, user_id_map)
print('Creating train graph')
train_data, (train_item_id_map, train_user_id_map) = prepare_graph(
    user_item_interactions = train_user_item_interactions.to_numpy(),
    user_item_all_interactions = train_user_item_all_interactions.to_numpy(),
    item_features = train_track_features,
    positive_interactions_set=train_positive_interactions_set
)
print('Creating test graph')
test_data, (test_item_id_map, test_user_id_map) = prepare_graph(
    user_item_interactions = test_user_item_interactions.to_numpy(),
    user_item_all_interactions = test_user_item_all_interactions.to_numpy(),
    item_features = test_track_features,
    positive_interactions_set=test_positive_interactions_set
    
)


Creating train graph
Number of users:  453
Number of items:  1981
Creating test graph
Number of users:  114
Number of items:  2132


In [42]:
train_data

Data(x=[2434, 768], edge_index=[2, 124402], user_mask=[2434], item_mask=[2434], all_sample_index=[2, 622010], labels=[622010], edge_users=[622010], edge_items=[622010])

In [43]:
torch.save(train_data,os.path.join(data_dir, 'train_graph.pt'))
torch.save(test_data, os.path.join(data_dir, 'test_graph.pt'))

## Descriptive statistics

In [44]:
train_graph = torch.load(os.path.join(data_dir, 'train_graph.pt'))
test_graph = torch.load(os.path.join(data_dir, 'test_graph.pt'))

  train_graph = torch.load(os.path.join(data_dir, 'train_graph.pt'))
  test_graph = torch.load(os.path.join(data_dir, 'test_graph.pt'))


In [45]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SimpleGNN(torch.nn.Module):
    def __init__(self, ):
        super(SimpleGNN, self).__init__()
        self.conv1 = GCNConv(train_data.num_features, 256, aggr = 'sum')
        self.conv2 = GCNConv(256, 256, aggr = 'sum')
        self.conv3 = GCNConv(256, 1, aggr= 'sum')
        # self.bn1 = torch.nn.BatchNorm1d(256)
        # self.bn2 = torch.nn.BatchNorm1d(256)
        # self.sigmoid = torch.nn.LogSigmoid()
        
        
        
    def forward(self, data):
        # for each node in the graph, apply GNN to pool node embeddings
        x = self.conv1(data.x, data.all_sample_index)
        # x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p = 0.5,  training=self.training)
        x = self.conv2(x, data.all_sample_index)
        # x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, p = 0.5, training = self.training)
        x = self.conv3(x, data.all_sample_index)
        # the final output of the node embeddings 
        out = F.sigmoid(x)

        # create edge embeddings as the product of the node embeddings
        edge_out = out[data.all_sample_index[0]] * out[data.all_sample_index[1]]

        return edge_out

# use cuda if available
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cuda'

# Create a model and optimizer
model = SimpleGNN().to(device)
train_data = train_data.to(device)
test_data = test_data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


# Ensure your edge_labels are on the same device and correct shape
edge_labels = train_data.labels.to(device)

model.train()  # Ensure the model is in training mode
for epoch in range(300): 
    optimizer.zero_grad()  # Clear gradients from the last step
    edge_out = model(train_data)  # Forward pass through network
    loss = F.binary_cross_entropy(edge_out.squeeze(), train_data.labels.float().squeeze())  # Calculate the loss
    loss.backward()  # Calculate gradients based on the loss
    optimizer.step()  # Adjust model parameters based on the calculated gradients

    print(f"Epoch: {epoch}, Loss: {loss.item()}")



Epoch: 0, Loss: 0.3975432813167572
Epoch: 1, Loss: 0.46186187863349915
Epoch: 2, Loss: 0.3831929862499237
Epoch: 3, Loss: 0.3761373460292816
Epoch: 4, Loss: 0.3661309480667114
Epoch: 5, Loss: 0.3621349632740021
Epoch: 6, Loss: 0.3628937900066376
Epoch: 7, Loss: 0.3655337691307068
Epoch: 8, Loss: 0.36000293493270874
Epoch: 9, Loss: 0.35858821868896484
Epoch: 10, Loss: 0.3625227212905884
Epoch: 11, Loss: 0.3597876727581024
Epoch: 12, Loss: 0.35881590843200684
Epoch: 13, Loss: 0.3581443727016449
Epoch: 14, Loss: 0.35657405853271484
Epoch: 15, Loss: 0.3558272123336792
Epoch: 16, Loss: 0.35847166180610657
Epoch: 17, Loss: 0.35666078329086304
Epoch: 18, Loss: 0.3562353551387787
Epoch: 19, Loss: 0.3546087443828583
Epoch: 20, Loss: 0.3521776497364044
Epoch: 21, Loss: 0.35553261637687683
Epoch: 22, Loss: 0.35286498069763184
Epoch: 23, Loss: 0.3541761636734009
Epoch: 24, Loss: 0.35333052277565
Epoch: 25, Loss: 0.35105791687965393
Epoch: 26, Loss: 0.35110336542129517
Epoch: 27, Loss: 0.3519822359

In [46]:
# Calculate Bayesian Personalized Ranking metric based on this model for the pyg dataset `test_data`
from torch_geometric.nn import global_sort_pool
from torch_geometric.loader import DataLoader

def evaluate_model(model, data):
    model.eval()

    out = model(data)
    return out.detach().cpu().squeeze(), data.labels.detach().cpu()
test_predictions, test_targets = evaluate_model(model, test_data)
print('Test predictions:', test_predictions)
print('Test targets:', test_targets)
# Save the trained model
torch.save(model.state_dict(), os.path.join(data_dir, 'trained_model.pth'))


Test predictions: tensor([0.0673, 0.0704, 0.0663,  ..., 0.0676, 0.0643, 0.0665])
Test targets: tensor([0., 0., 0.,  ..., 0., 0., 0.])


In [47]:

candidate_index = test_data.all_sample_index.detach().cpu()
test_labels = test_data.labels.detach().cpu()
true_index = test_data.edge_index.detach().cpu()

In [48]:
test_retrieval_df = test.copy()

test_retrieval_df['user_node_id']= test_retrieval_df['user'].map(test_user_id_map)
test_retrieval_df['positive_item_id'] = test_retrieval_df['positive_song_id'].map(test_item_id_map)
test_retrieval_df['candidate_item_ids'] = test_retrieval_df['candidates'].apply(lambda x: [test_item_id_map[c] for c in x])
test_retrieval_df['labels'] = test_retrieval_df.apply(lambda x: [int(c == x['positive_item_id']) for c in x['candidate_item_ids']], axis = 1)
test_retrieval_df.head()

Unnamed: 0,user,past_interactions,candidates,positive_song_id,user_node_id,positive_item_id,candidate_item_ids,labels
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[595, 84, 152, 614, 688, 48, 34, 191, 1362, 171]","[427, 46, 995, 1350, 1548, 88, 895, 871, 1676,...",111,80,123,"[114, 115, 116, 117, 118, 119, 120, 121, 122, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]"
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[84, 152, 614, 688, 48, 34, 191, 1362, 171, 111]","[2376, 585, 1842, 196, 1296, 528, 1258, 1321, ...",196,80,127,"[124, 125, 126, 127, 128, 129, 130, 131, 132, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[152, 614, 688, 48, 34, 191, 1362, 171, 111, 196]","[2381, 152, 1330, 832, 2467, 116, 877, 1051, 2...",152,80,135,"[134, 135, 136, 137, 138, 139, 140, 141, 142, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]"
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[614, 688, 48, 34, 191, 1362, 171, 111, 196, 152]","[1883, 2426, 2211, 1801, 1446, 1634, 643, 1626...",643,80,150,"[144, 145, 146, 147, 148, 149, 150, 151, 152, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]"
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[688, 48, 34, 191, 1362, 171, 111, 196, 152, 643]","[2072, 2277, 1258, 2222, 691, 207, 144, 2194, ...",144,80,159,"[154, 155, 130, 156, 157, 158, 159, 160, 161, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]"


In [49]:
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
from tqdm import tqdm
def find_tensor_indices(
    candidate_index,
    row_value,
    col_values):

    # Convert col_values to tensor if it's a list
    if isinstance(col_values, list):
        col_values = torch.tensor(col_values)

    row_mask = (candidate_index[0, :] == row_value)
    col_mask = torch.isin(candidate_index[1,:], col_values)
    final_mask = row_mask & col_mask
    indices = torch.nonzero(final_mask).squeeze()
    return indices

def dedupe_edges_and_predictions(
    query_edges: torch.Tensor,
    query_predictions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Convert edges to tuples for hashing
    edge_tuples = list(map(tuple, query_edges.t().tolist()))
    
    # Create dictionary to store unique edges and their first occurrence index
    unique_dict = {}
    for idx, edge in enumerate(edge_tuples):
        if edge not in unique_dict:
            unique_dict[edge] = idx
            
    # Get indices of unique edges
    unique_indices = torch.tensor(list(unique_dict.values()))
    
    # Get unique edges and their corresponding predictions
    unique_edges = query_edges[:, unique_indices]
    unique_predictions = query_predictions[unique_indices]
    
    return unique_edges, unique_predictions


print(f'Reciprocal Rank: ', end = '')
query_results_list = []
reciprocal_rank_list = []
for i, record in tqdm(test_retrieval_df.iterrows(), desc = 'Calculating reciprocal rank', total = test_retrieval_df.shape[0]):
    candidate_items = record['candidate_item_ids']
    user_id = record['user_node_id']
    query_labels = record['labels']
    # get preditions from model based on candidate_items and user_id
    # Find indices where row equals 23
    row_mask = (candidate_index[0, :] == user_id)
    # Find indices where column is in our desired values
    col_mask = torch.isin(candidate_index[1, :], torch.tensor(candidate_items))
    # Combine both conditions
    final_mask = row_mask & col_mask
    indices = torch.nonzero(final_mask).squeeze()
    # get query preds and labels
    query_target_nodes = pd.Series(candidate_index[1, indices], name = 'candidate_nodes')
    query_predictions = pd.Series(test_predictions[indices], index = query_target_nodes, name = 'predictions')
    query_labels = pd.Series(query_labels, index = pd.Series(candidate_items, name ='candidate_nodes'), name = 'labels')
    query_results = pd.merge(query_predictions, query_labels, on = 'candidate_nodes').reset_index().drop_duplicates()
    query_results['source_node'] = user_id
    query_results['query_id'] = i
    # get reciprocal rank
    reciprocal_rank = retrieval_reciprocal_rank(
        torch.tensor(query_results['predictions'].values) ,
        torch.tensor(query_results['labels'].values))
    query_results = query_results.groupby(['query_id', 'source_node']).agg({'candidate_nodes':  list, 'predictions': list, 'labels': list}).reset_index()
    query_results['reciprocal_rank'] = reciprocal_rank
    reciprocal_rank_list.append(reciprocal_rank)
    query_results_list.append(query_results)
    # print(f'{reciprocal_rank}', end = ' ')
# for index_chunk in torch.split(candidate_index[1], 10):
#     candidate_list = index_chunk
#     query_labels = test_labels[candidate_list]
#     query_predictions = test_predictions[candidate_list]
#     reciprocal_rank = retrieval_reciprocal_rank(query_predictions, query_labels)
#     print(f'Query Labels: {query_labels}')
#     print(f'Query Predictions: {query_predictions}')
#     print(f'Reciprocal Rank: {reciprocal_rank}')

Reciprocal Rank: 

Calculating reciprocal rank: 100%|██████████| 18923/18923 [01:10<00:00, 268.35it/s]


In [50]:
all_query_results = pd.concat(query_results_list,ignore_index = True)

In [51]:
test_df = all_query_results.sample(15)
test_df

Unnamed: 0,query_id,source_node,candidate_nodes,predictions,labels,reciprocal_rank
578,578,91,"[263, 493, 358, 687, 877, 1571, 970, 681, 488,...","[0.16500170528888702, 0.06611007452011108, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",tensor(1.)
18023,18023,53,"[821, 1971, 912, 1683, 1007, 2026, 1361, 580, ...","[0.06591524183750153, 0.06452807784080505, 0.0...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]",tensor(0.1429)
1010,1010,91,"[1069, 317, 985, 358, 924, 761, 1845, 260, 195...","[0.17056889832019806, 0.06452805548906326, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",tensor(1.)
18799,18799,90,"[1313, 575, 333, 1802, 545, 353, 197, 1569, 40...","[0.06669006496667862, 0.06492862105369568, 0.0...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]",tensor(1.)
12435,12435,110,"[2197, 676, 264, 1308, 283, 1532, 373, 1242, 8...","[0.04532736539840698, 0.06392043083906174, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",tensor(0.1000)
18394,18394,26,"[495, 485, 642, 1673, 509, 1420, 1175, 1892, 1...","[0.07198337465524673, 0.06611008197069168, 0.0...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",tensor(1.)
12223,12223,89,"[222, 1991, 912, 791, 1730, 297, 1697, 235, 50...","[0.06225760281085968, 0.0628887489438057, 0.06...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]",tensor(0.1667)
384,384,91,"[1111, 214, 1019, 121, 678, 371, 1549, 1595, 9...","[0.06669007241725922, 0.06591524183750153, 0.0...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]",tensor(1.)
7096,7096,50,"[1666, 1751, 1017, 973, 932, 1966, 1757, 611, ...","[0.06726326048374176, 0.06857578456401825, 0.0...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]",tensor(0.5000)
16685,16685,108,"[618, 2054, 237, 1116, 1999, 1300, 1358, 798, ...","[0.06707293540239334, 0.07003655284643173, 0.0...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]",tensor(0.1250)


In [52]:
test_df.loc[637]

KeyError: 637

In [53]:
all_query_results.to_pickle(os.path.join(data_dir, 'test_results_mrr_baseline_v1.pickle'))

# Calculating Mean Reciprocal Rank

In [54]:
all_query_results.reciprocal_rank.mean()

0.819841636170269

In [55]:
all_query_results.shape[0]

18923

torch.Size([2, 37846])