# Creating a simple GNN Music Recommender System

## Loading Data

Data is from #nowplaying dataset. 

In [1]:
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import GroupShuffleSplit # used for splitting data 
from torch_geometric.data import Data


In [2]:
# Loading data
data_dir = os.path.normpath(r'D:\Projects\cs224-multimodal-recommender-system\processed_data\nowplaying')

train = pd.read_parquet(
    os.path.join(data_dir, 'session_candidates_train.parquet')
)

test = pd.read_parquet(
    os.path.join(data_dir, 'session_candidates_test.parquet')
)

features = pd.read_parquet(
    os.path.join(data_dir, 'Lyrics_HSP-L_Nowplay_Data.parquet')
)

track_map = pd.read_parquet(
    os.path.join(data_dir, 'track_id_map.parquet')
).set_index(['Artist', 'Title'])


## view data

In [3]:
track_features = (features
                  .drop_duplicates(subset= ['Artist', 'Title'], keep = 'first')
                  .drop(columns = ['user id', 'source of the tweet', 'track title', 'artist name', ])
                  )
print('Initial number of tracks:', track_features.shape[0])
track_features = track_features.merge(track_map, on = ['Artist', 'Title'], how = 'inner')
track_features = (track_features
                  .set_index('track_id')
                  .lyrics_embedding
                  .apply(lambda x: pd.Series(x))
                  )

print('Number of tracks post-inner join:', track_features.shape[0])
print('Track features')
print(79*'-')
display(track_features.head())

train_user_listening_history = train.loc[:, ['user', 'past_interactions']]
print('Train: User listening history')
print(79*'-')
display(train_user_listening_history.head())


train_user_item_interactions = train.loc[:, ['user', 'positive_song_id']]
print('Train: User-item interactions')
print(79*'-')
display(train_user_item_interactions.head())

train_user_item_all_interactions = train.loc[: , ['user', 'candidates']].explode('candidates').rename(columns = {'candidates': 'song_id'})
print('Train: User-item negative + positive interactions')
print(79*'-')
display(train_user_item_all_interactions.head())

train_track_features = track_features.loc[train_user_item_all_interactions.song_id.unique(), :]



test_user_listening_history = test.loc[:, ['user', 'past_interactions']]
print('Test: User listening history')
print(79*'-')
display(test_user_listening_history.head())


test_user_item_interactions = test.loc[:, ['user', 'positive_song_id']]
print('Test: User-item interactions')
print(79*'-')
display(test_user_item_interactions.head())

test_user_item_all_interactions = test.loc[: , ['user', 'candidates']].explode('candidates').rename(columns = {'candidates': 'song_id'})
print('Test: User-item negative + positive interactions')
print(79*'-')
display(test_user_item_all_interactions.head())

test_track_features = track_features.loc[test_user_item_all_interactions.song_id.unique(), :]

# positive interaction tuples
train_positive_interactions_set = set(
    train_user_item_interactions
    .drop_duplicates()
    .apply(lambda x: (x['user'], x['positive_song_id']), axis = 1)
    .to_list()
)
test_positive_interactions_set = set(
    test_user_item_interactions
    .drop_duplicates()
    .apply(lambda x: (x['user'], x['positive_song_id']), axis = 1)
    .to_list()
)



Initial number of tracks: 2471
Number of tracks post-inner join: 2155
Track features
-------------------------------------------------------------------------------


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,-0.016555,0.016258,0.032462,-0.065711,0.035892,-0.030516,0.04403,0.046941,0.031009,-0.051097,...,0.137975,-0.032969,-0.178082,0.032461,0.081872,0.018873,0.134802,-0.184255,-0.02097,-0.009385
2,-0.043866,0.079759,0.031227,-0.009023,0.198229,-0.117961,0.023687,0.126027,0.002055,-0.055694,...,0.082887,-0.094865,-0.103203,0.029579,0.168657,0.054824,0.156257,0.036643,-0.045506,-0.051694
3,0.040213,0.071726,-0.022398,-0.148623,0.250008,-0.033887,0.026806,0.242937,0.030847,0.016596,...,0.093019,-0.006108,-0.10845,0.160098,0.071841,0.118777,0.233991,0.042564,-0.028533,-0.093544
4,-0.009909,-0.046708,0.058077,0.009227,0.064153,-0.003633,0.027486,0.102278,0.117365,-0.024824,...,0.10089,-0.193991,-0.133274,0.1081,0.134432,0.114251,0.241544,-0.116317,-0.074772,-0.075406
5,0.003207,0.011662,0.044229,-0.107282,0.164013,-0.027828,0.0179,0.089839,0.104875,-0.010808,...,0.083276,-0.045836,-0.046923,0.081197,0.055439,0.086655,0.196141,-0.123199,-0.025011,0.013837


Train: User listening history
-------------------------------------------------------------------------------


Unnamed: 0,user,past_interactions
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[96, 46, 46, 46, 46, 232, 325, 232, 125, 325]"
1,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 46, 46, 232, 325, 232, 125, 325, 125]"
2,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 46, 232, 325, 232, 125, 325, 125, 232]"
3,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 46, 232, 325, 232, 125, 325, 125, 232, 46]"
4,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,"[46, 232, 325, 232, 125, 325, 125, 232, 46, 96]"


Train: User-item interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,positive_song_id
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,125
1,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,232
2,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,46
3,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,96
4,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,325


Train: User-item negative + positive interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,song_id
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,278
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,1690
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,366
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,1984
0,1001b22c8e4adeb77ef10481ad06ff9c35006cb3,795


Test: User listening history
-------------------------------------------------------------------------------


Unnamed: 0,user,past_interactions
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[595, 84, 152, 614, 688, 48, 34, 191, 1362, 171]"
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[84, 152, 614, 688, 48, 34, 191, 1362, 171, 111]"
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[152, 614, 688, 48, 34, 191, 1362, 171, 111, 196]"
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[614, 688, 48, 34, 191, 1362, 171, 111, 196, 152]"
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[688, 48, 34, 191, 1362, 171, 111, 196, 152, 643]"


Test: User-item interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,positive_song_id
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,111
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,196
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,152
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,643
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,144


Test: User-item negative + positive interactions
-------------------------------------------------------------------------------


Unnamed: 0,user,song_id
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,427
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,46
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,995
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,1350
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,1548


In [4]:
import torch

def prepare_graph(user_item_interactions, user_item_all_interactions , item_features, positive_interactions_set):
    # item feature manipulation 
    # get user id and item id
    user_ids = list(set([ui for ui, ii in user_item_all_interactions]))
    item_ids = item_features.index
    
    # create map for users and items
    user_id_map = {uid: idx for idx, uid in enumerate(user_ids)}
    num_users = len(user_ids)
    item_id_map = {iid: idx + num_users for idx, iid in enumerate(item_ids)}
    print('Number of users: ', num_users)
    num_items = len(item_ids)
    print('Number of items: ', num_items)
    
    
    num_features = item_features.shape[1]
    num_nodes = num_users + num_items
    x = torch.zeros(num_nodes, num_features)
    for iid, idx in item_id_map.items():
        x[idx] = torch.tensor(item_features.loc[iid, :].values, dtype = torch.float)
    # for any indexed value in x, if all the values in the row are nan, then convert it to a zero row
    nan_rows = torch.isnan(x).all(dim=1)  # Find rows where all elements are NaN
    x[nan_rows] = 0
    
    edge_index = []
    for uid, iid in user_item_interactions:
        u_idx = user_id_map[uid]
        i_idx = item_id_map[iid]
        # add undirected edges
        # TODO: Undirected edges is a major caveat of this architecture 
        #  However it is necessary for us to consider this style of architecture because
        #  otherwise the maximum depth of the GNN is 1-hop (user -> item x ) [item has no outgoing nodes]
        # Add edges in both directions 
        edge_index.append([u_idx, i_idx])
        edge_index.append([i_idx, u_idx])
    edge_index = torch.tensor(edge_index).t().contiguous()
    # Masks for users and items
    user_mask = torch.zeros(num_users + num_items, dtype=torch.bool)
    item_mask = torch.zeros(num_users + num_items, dtype=torch.bool)
    user_mask[:num_users] = True
    item_mask[num_users:] = True    # Create data object
    data = Data(x=x, edge_index=edge_index)
    data.user_mask = user_mask
    data.item_mask = item_mask
    # positive edge index
    all_sample_interactions = []
    for uid, iid in user_item_all_interactions:
        if iid in item_id_map:
            u_idx = user_id_map[uid]
            i_idx = item_id_map[iid]
            all_sample_interactions.append([u_idx, i_idx])
        else:
            print(f'{iid} not in item_id_map')
    data.all_sample_index = torch.tensor(all_sample_interactions, dtype = torch.long).t().contiguous()
    
    labels = []
    for uid, iid in user_item_all_interactions:
        if (uid, iid) in positive_interactions_set:
            labels.append(1)
        else:
            labels.append(0)
    data.labels = torch.tensor(labels, dtype=torch.float)
    data.edge_users = [user_id_map[uid] for uid, iid in user_item_all_interactions]
    data.edge_items = [item_id_map[iid] for uid, iid in user_item_all_interactions]
    
    return  data, (item_id_map, user_id_map)
print('Creating train graph')
train_data, (train_item_id_map, train_user_id_map) = prepare_graph(
    user_item_interactions = train_user_item_interactions.to_numpy(),
    user_item_all_interactions = train_user_item_all_interactions.to_numpy(),
    item_features = train_track_features,
    positive_interactions_set=train_positive_interactions_set
)
print('Creating test graph')
test_data, (test_item_id_map, test_user_id_map) = prepare_graph(
    user_item_interactions = test_user_item_interactions.to_numpy(),
    user_item_all_interactions = test_user_item_all_interactions.to_numpy(),
    item_features = test_track_features,
    positive_interactions_set=test_positive_interactions_set
    
)


Creating train graph
Number of users:  453
Number of items:  1981
Creating test graph
Number of users:  114
Number of items:  2132


In [5]:
train_data

Data(x=[2434, 768], edge_index=[2, 124402], user_mask=[2434], item_mask=[2434], all_sample_index=[2, 622010], labels=[622010], edge_users=[622010], edge_items=[622010])

In [6]:
torch.save(train_data,os.path.join(data_dir, 'train_graph.pt'))
torch.save(test_data, os.path.join(data_dir, 'test_graph.pt'))

## Descriptive statistics

In [7]:
train_graph = torch.load(os.path.join(data_dir, 'train_graph.pt'))
test_graph = torch.load(os.path.join(data_dir, 'test_graph.pt'))

  train_graph = torch.load(os.path.join(data_dir, 'train_graph.pt'))
  test_graph = torch.load(os.path.join(data_dir, 'test_graph.pt'))


In [8]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SimpleGNN(torch.nn.Module):
    def __init__(self, ):
        super(SimpleGNN, self).__init__()
        self.conv1 = GCNConv(train_data.num_features, 256, aggr = 'sum')
        self.conv2 = GCNConv(256, 256, aggr = 'sum')
        self.conv3 = GCNConv(256, 1, aggr= 'sum')
        # self.bn1 = torch.nn.BatchNorm1d(256)
        # self.bn2 = torch.nn.BatchNorm1d(256)
        # self.sigmoid = torch.nn.LogSigmoid()
        
        
        
    def forward(self, data):
        # for each node in the graph, apply GNN to pool node embeddings
        x = self.conv1(data.x, data.all_sample_index)
        # x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p = 0.5,  training=self.training)
        x = self.conv2(x, data.all_sample_index)
        # x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, p = 0.5, training = self.training)
        x = self.conv3(x, data.all_sample_index)
        # the final output of the node embeddings 
        out = F.sigmoid(x)

        # create edge embeddings as the product of the node embeddings
        edge_out = out[data.all_sample_index[0]] * out[data.all_sample_index[1]]

        return edge_out

# use cuda if available
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cuda'

# Create a model and optimizer
model = SimpleGNN().to(device)
train_data = train_data.to(device)
test_data = test_data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


# Ensure your edge_labels are on the same device and correct shape
edge_labels = train_data.labels.to(device)

model.train()  # Ensure the model is in training mode
for epoch in range(300):  # You can change 100 to however many epochs you want to train for 
    optimizer.zero_grad()  # Clear gradients from the last step
    edge_out = model(train_data)  # Forward pass through network
    loss = F.binary_cross_entropy(edge_out.squeeze(), train_data.labels.float().squeeze())  # Calculate the loss
    loss.backward()  # Calculate gradients based on the loss
    optimizer.step()  # Adjust model parameters based on the calculated gradients

    print(f"Epoch: {epoch}, Loss: {loss.item()}")



Epoch: 0, Loss: 0.3975432813167572
Epoch: 1, Loss: 0.4187798500061035
Epoch: 2, Loss: 0.3828023076057434
Epoch: 3, Loss: 0.3741346299648285
Epoch: 4, Loss: 0.364216685295105
Epoch: 5, Loss: 0.3582327365875244
Epoch: 6, Loss: 0.36553215980529785
Epoch: 7, Loss: 0.36643531918525696
Epoch: 8, Loss: 0.36201217770576477
Epoch: 9, Loss: 0.35952284932136536
Epoch: 10, Loss: 0.36237189173698425
Epoch: 11, Loss: 0.36136239767074585
Epoch: 12, Loss: 0.3595818281173706
Epoch: 13, Loss: 0.35637959837913513
Epoch: 14, Loss: 0.3575963079929352
Epoch: 15, Loss: 0.35773277282714844
Epoch: 16, Loss: 0.35678839683532715
Epoch: 17, Loss: 0.35589370131492615
Epoch: 18, Loss: 0.3551081717014313
Epoch: 19, Loss: 0.35481584072113037
Epoch: 20, Loss: 0.35466933250427246
Epoch: 21, Loss: 0.35215792059898376
Epoch: 22, Loss: 0.35435807704925537
Epoch: 23, Loss: 0.35248076915740967
Epoch: 24, Loss: 0.3546645939350128
Epoch: 25, Loss: 0.35064220428466797
Epoch: 26, Loss: 0.35045692324638367
Epoch: 27, Loss: 0.351

In [30]:
# Calculate Bayesian Personalized Ranking metric based on this model for the pyg dataset `test_data`
from torch_geometric.nn import global_sort_pool
from torch_geometric.loader import DataLoader

def evaluate_model(model, data):
    model.eval()

    out = model(data)
    return out.detach().cpu().squeeze(), data.labels.detach().cpu()
test_predictions, test_targets = evaluate_model(model, test_data)
print('Test predictions:', test_predictions)
print('Test targets:', test_targets)
# Save the trained model
torch.save(model.state_dict(), os.path.join(data_dir, 'trained_model.pth'))


Test predictions: tensor([0.0627, 0.0655, 0.0618,  ..., 0.0630, 0.0600, 0.0620])
Test targets: tensor([0., 0., 0.,  ..., 0., 0., 0.])


In [31]:
candidate_index = test_data.all_sample_index.detach().cpu()
test_labels = test_data.labels.detach().cpu()
true_index = test_data.edge_index.detach().cpu()

In [38]:
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
# preds = torch.tensor([0.2, 0.3, 0.5])
# target = torch.tensor([False, True, False])
# retrieval_reciprocal_rank(preds, target)




reciprocal_rank_list = []
for index_chunk in torch.split(candidate_index[1], 10):
    candidate_list = index_chunk
    query_labels = test_labels[candidate_list]
    query_predictions = test_predictions[candidate_list]
    reciprocal_rank = retrieval_reciprocal_rank(query_predictions, query_labels)
    print(f'Query Labels: {query_labels}')
    print(f'Query Predictions: {query_predictions}')
    print(f'Reciprocal Rank: {reciprocal_rank}')

Query Labels: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])
Query Predictions: tensor([0.0634, 0.0639, 0.0611, 0.0629, 0.0608, 0.0652, 0.0625, 0.0597, 0.0660,
        0.0602])
Reciprocal Rank: 0.5
Query Labels: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.])
Query Predictions: tensor([0.0602, 0.0620, 0.0604, 0.0580, 0.0600, 0.0616, 0.0778, 0.0615, 0.0595,
        0.0623])
Reciprocal Rank: 0.5
Query Labels: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Query Predictions: tensor([0.0623, 0.0627, 0.0606, 0.0615, 0.0611, 0.0608, 0.0635, 0.0602, 0.0615,
        0.0729])
Reciprocal Rank: 0.0
Query Labels: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
Query Predictions: tensor([0.0595, 0.0634, 0.0600, 0.0587, 0.0609, 0.0602, 0.0618, 0.0604, 0.0637,
        0.1558])
Reciprocal Rank: 1.0
Query Labels: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Query Predictions: tensor([0.0618, 0.0623, 0.0778, 0.0609, 0.0622, 0.0593, 0.0600, 0.0625, 0.0623,
        0.0618])
Reciprocal Rank: 0.0


(tensor([114, 115, 116, 117, 118, 119, 120, 121, 122, 123]),
 tensor([124, 125, 126, 127, 128, 129, 130, 131, 132, 133]),
 tensor([134, 135, 136, 137, 138, 139, 140, 141, 142, 143]),
 tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153]),
 tensor([154, 155, 130, 156, 157, 158, 159, 160, 161, 162]),
 tensor([163, 164, 165, 166, 167, 168, 169, 170, 171, 172]),
 tensor([173, 174, 175, 176, 177, 178, 179, 180, 181, 182]),
 tensor([183, 184, 185, 186, 187, 161, 188, 189, 190, 191]),
 tensor([192, 193, 194, 195, 196, 153, 197, 198, 199, 200]),
 tensor([201, 202, 203, 204, 205, 206, 207, 208, 209, 210]),
 tensor([211, 212, 213, 214, 215, 216, 217, 218, 219, 220]),
 tensor([221, 222, 223, 224, 225, 226, 227, 228, 229, 230]),
 tensor([231, 232, 233, 234, 235, 236, 237, 238, 239, 240]),
 tensor([241, 242, 243, 244, 245, 246, 247, 248, 249, 250]),
 tensor([251, 140, 252, 253, 254, 255, 256, 257, 258, 259]),
 tensor([260, 261, 262, 263, 172, 264, 265, 266, 267, 268]),
 tensor([269, 270, 172, 

In [17]:
candidate_index[1, candidate_list]

array([225, 226, 227, 228, 229, 230, 231, 232, 233, 234], dtype=int64)

In [80]:
test_data.all_sample_index[test_data.all_sample_index[1, candidate_list].detach().cpu()]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [69]:
record

user                          103aaf06662cc01fe8b1ddaa39e41f59a8332832
past_interactions     [595, 84, 152, 614, 688, 48, 34, 191, 1362, 171]
candidates           [427, 46, 995, 1350, 1548, 88, 895, 871, 1676,...
positive_song_id                                                   111
Name: 0, dtype: object

Unnamed: 0,user,past_interactions,candidates,positive_song_id
0,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[595, 84, 152, 614, 688, 48, 34, 191, 1362, 171]","[427, 46, 995, 1350, 1548, 88, 895, 871, 1676,...",111
1,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[84, 152, 614, 688, 48, 34, 191, 1362, 171, 111]","[2376, 585, 1842, 196, 1296, 528, 1258, 1321, ...",196
2,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[152, 614, 688, 48, 34, 191, 1362, 171, 111, 196]","[2381, 152, 1330, 832, 2467, 116, 877, 1051, 2...",152
3,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[614, 688, 48, 34, 191, 1362, 171, 111, 196, 152]","[1883, 2426, 2211, 1801, 1446, 1634, 643, 1626...",643
4,103aaf06662cc01fe8b1ddaa39e41f59a8332832,"[688, 48, 34, 191, 1362, 171, 111, 196, 152, 643]","[2072, 2277, 1258, 2222, 691, 207, 144, 2194, ...",144


In [44]:
test_targets.shape

torch.Size([189230])