# Preprocess Data into graphs for training, testing, and validation

In [1]:
import pandas as pd
import numpy as np
import sys
import os 

## Read Datasets

In [2]:
# with pd.option_context('display.max_columns', None):
#     display(track_features.head())

## Creating negative samples


In [41]:
from numba import njit
import pandas as pd
import random


# Read and format the listening-event table.
# Rows are ordered ascending by the stored `timestamp` column; note the sort runs
# BEFORE the datetime conversion below, so it orders by the raw stored values.
# NOTE(review): if timestamps are stored as strings, confirm their format sorts chronologically.
m4a_le = pd.read_parquet('../../datasets/m4a-onion-updated-20241210/M4A_Listening Events_with img_Final.parquet')
m4a_le = m4a_le.sort_values(by = 'timestamp', ascending = True).copy()
# Convert to datetime64[ns], then re-express every timestamp as whole seconds
# elapsed since the earliest event in the table (so the minimum becomes 0).
m4a_le.timestamp = m4a_le.timestamp.astype('datetime64[ns]')
m4a_le.timestamp =(m4a_le.timestamp - m4a_le.timestamp.min()).dt.total_seconds().astype(int)
# Map track IDs and user IDs to unique integer indices.
# Tracks take 0 .. len(track_id_mapping)-1 in order of first appearance in the sorted table;
# users are offset by len(track_id_mapping), so user and track indices never overlap.
track_id_mapping = {id_: idx for idx, id_ in enumerate(m4a_le['track_id'].unique())}
user_id_mapping = {id_: idx + len(track_id_mapping) for idx, id_ in enumerate(m4a_le['user_id'].unique())}
# Add the graph-level indices (graph_user_id, graph_track_id) as columns of m4a_le.
m4a_le['graph_user_id'] = m4a_le['user_id'].map(user_id_mapping)
m4a_le['graph_track_id'] = m4a_le['track_id'].map(track_id_mapping)
# Disabled: de-duplication of repeated (user, track) pairs, so repeated listens are kept as separate rows.
# m4a_le = m4a_le.drop_duplicates(subset = ['user_id', 'track_id'], keep = 'first') # drop duplicate edges, keep first
# Per-track feature table, read from the same dataset directory as the events.
m4a_features = pd.read_parquet('../../datasets/m4a-onion-updated-20241210/M4A_song features_with img_Final.parquet')

from importlib import reload
import sampling
reload(sampling)
from sampling import sample_negative_edges, get_aux_candidates, get_sorted_tracks_by_timestamp, get_existing_edges, split_user_timeline, get_splits


# CREATING SPLITS
# get_splits comes from the project-local `sampling` module; it is called with
# train_frac=0.6 and val_frac=0.2 and returns three objects.
# NOTE(review): presumably the remaining 0.2 becomes the test split — confirm in sampling.get_splits.
train, val, test = get_splits(m4a_le, train_frac = 0.6, val_frac = 0.2)


# NEGATIVE SAMPLING
# The helpers below are defined in `sampling`; the step descriptions are the
# notebook author's and are not verifiable from this file alone.
# step 1: get existing edges set from all data
existing_edges_set = get_existing_edges(m4a_le)
# step 2: get tracks sorted by first listening event
ts_tracks = get_sorted_tracks_by_timestamp(m4a_le)
# step 3: get alternate candidates when candidate list is empty or small (older songs)
# NOTE(review): how release_year = 2018 is used as a cut-off is decided inside get_aux_candidates — confirm there.
aux_candidates = get_aux_candidates(m4a_features, track_id_mapping, release_year = 2018)

## implement negative sampling
# Sampling uses only the training split; the three returned objects are not
# referenced again in the visible cells of this notebook.
user_id, pos_sample, neg_sample = sample_negative_edges(train, ts_tracks, existing_edges_set, aux_candidates)


# Node counts taken from the id mappings; num_tracks is read as a global by
# sample_mrr_candidates below.
num_users = len(user_id_mapping)
num_tracks = len(track_id_mapping)

## ADDING CANDIDATE FOR TEST SET TO USE IN FINAL MRR CALCULATION FOR MODEL
import random as rd

def sample_mrr_candidates(all_pos_tracks, pos_track, num_candidates, total_tracks=None):
    """Build a ranking candidate list: the positive track followed by random negatives.

    Args:
        all_pos_tracks: iterable of track indices the user has interacted with
            (a list, set, numpy array or pandas Series of values).
        pos_track: the positive track index; always placed first in the result.
        num_candidates: total length of the returned list, positive included.
        total_tracks: size of the track index space; negatives are drawn from
            ``range(total_tracks)``. Defaults to the module-level ``num_tracks``.

    Returns:
        list: ``[pos_track, neg_1, ..., neg_k]`` where the negatives are distinct,
        differ from ``pos_track`` and are absent from ``all_pos_tracks``.

    Raises:
        ValueError: if fewer than ``num_candidates - 1`` eligible negatives exist
            (the rejection loop could otherwise never finish).
    """
    if total_tracks is None:
        total_tracks = num_tracks

    # Materialise the positives as a set of VALUES. `x in series` on a pandas
    # Series tests the index labels, not the stored track ids, so membership
    # must not be checked against the Series directly.
    positives = set(all_pos_tracks)
    positives.add(pos_track)

    # Only positives that fall inside the sampling range reduce the negative pool.
    blocked = sum(1 for track in positives if 0 <= track < total_tracks)
    needed = num_candidates - 1
    if needed > total_tracks - blocked:
        raise ValueError(
            f"cannot sample {needed} negatives: only {total_tracks - blocked} eligible tracks"
        )

    candidates = [pos_track]
    chosen = set()  # negatives picked so far, for O(1) duplicate checks
    while len(candidates) < num_candidates:
        candidate = rd.randrange(total_tracks)  # same draw as rd.choice(range(total_tracks))
        if candidate not in positives and candidate not in chosen:
            chosen.add(candidate)
            candidates.append(candidate)
    return candidates


# Collect every user's listened tracks once, as a dict of sets, instead of
# filtering the whole event table again for each test row. Sets also make the
# membership test inside sample_mrr_candidates compare track VALUES (a pandas
# Series would compare index labels).
_positives_by_user = m4a_le.groupby('graph_user_id')['graph_track_id'].apply(set).to_dict()


def assign_test_candidates(row):
    """Attach a `test_candidates` list to one test row.

    Args:
        row: a row of the test split carrying `graph_user_id` and `graph_track_id`.

    Returns:
        The same row with `test_candidates` set to a 20-item list produced by
        sample_mrr_candidates: the row's own track first, then sampled tracks
        that the user has no listening event for anywhere in m4a_le.
    """
    all_pos_tracks = _positives_by_user.get(row['graph_user_id'], set())
    row['test_candidates'] = sample_mrr_candidates(all_pos_tracks, row['graph_track_id'], num_candidates=20)
    return row


# NOTE(review): no random seed is set in the visible cells, so the sampled
# candidates differ between runs — consider seeding `rd` before this call.
test = test.apply(assign_test_candidates, axis=1)

test.head()

Unnamed: 0,user_id,track_id,timestamp,graph_user_id,graph_track_id,test_candidates
0,32911,3Y38upRWxIOVg0Uq,4154564,32685,3121,"[3121, 6117, 21572, 9611, 28712, 4455, 8037, 1..."
1,32911,6DwzPznhYUCUe4Ud,4156633,32685,1118,"[1118, 31238, 24627, 12243, 5682, 1872, 15925,..."
2,32911,uy4i5oHWdJOcP9EQ,4157710,32685,17633,"[17633, 32614, 27016, 5545, 14715, 4549, 2256,..."
3,32911,CmsmtpykCF4w97nF,4159118,32685,1766,"[1766, 28944, 15598, 19309, 18916, 14730, 2763..."
4,32911,qDukKJfNPQJahzMA,4159412,32685,10120,"[10120, 27762, 21662, 4906, 4601, 7780, 31600,..."


In [44]:
# Write the three splits and the full event table, in that order.
for _path, _frame in (
    ('../../processed_data/m4a-onion/m4a_train.parquet', train),
    ('../../processed_data/m4a-onion/m4a_val.parquet', val),
    ('../../processed_data/m4a-onion/m4a_test.parquet', test),
    ('../../processed_data/m4a-onion/m4a_all.parquet', m4a_le),
):
    _frame.to_parquet(_path)

# Write the original-id -> graph-index lookups as two-column tables.
_user_lookup = pd.DataFrame({
    'user_id': list(user_id_mapping.keys()),
    'graph_user_id': list(user_id_mapping.values()),
})
_user_lookup.to_parquet('../../processed_data/m4a-onion/m4a_user_mapping.parquet')

_track_lookup = pd.DataFrame({
    'track_id': list(track_id_mapping.keys()),
    'graph_track_id': list(track_id_mapping.values()),
})
_track_lookup.to_parquet('../../processed_data/m4a-onion/m4a_track_mapping.parquet')

## Create Graphs

In [111]:
#### TRACK FEATURES
# torch is not imported by any other visible cell, so import it here to keep
# this cell runnable after a kernel restart.
import torch

# Extract node features for tracks: drop identifier / text columns (and the raw
# lyrics embedding, which is expanded separately below); missing values become 0.
dropped_columns = ['id', 'genres', 'tags', 'track_id', '(tag, weight)', 'artist', 'song', 'lang', 'spotify_id', 'album_name', 'lyrics_embedding']
track_features = m4a_features.drop(dropped_columns, axis=1, errors='ignore').fillna(0)
# Check which of the remaining columns contain strings
str_columns = [col for col in track_features.columns if track_features[col].apply(lambda x: isinstance(x, str)).any()]
# Convert string columns to float
track_features[str_columns] = track_features[str_columns].astype(float)
# Expand the lyrics embeddings into one column per dimension (lyrics_0, lyrics_1, ...).
# The width is taken from the data instead of being hard-coded.
lyrics_embedding = pd.DataFrame(m4a_features.lyrics_embedding.tolist(), index=m4a_features.index).add_prefix('lyrics_')
track_features = pd.concat([track_features, lyrics_embedding], axis=1)
# Convert straight to a float32 array; avoids building a nested Python list of every value.
track_features_tensor = torch.tensor(track_features.to_numpy(dtype='float32'))
# NOTE(review): rows follow the row order of m4a_features, not graph_track_id,
# and a later cell reports event tracks that have no feature row. If this
# tensor is indexed by graph_track_id downstream it must be re-aligned — confirm.

In [118]:
# Display the track-feature table as loaded (one row per track, keyed by `id`).
m4a_features

Unnamed: 0,id,lowlevel.average_loudness,lowlevel.barkbands_crest.stdev,lowlevel.barkbands_flatness_db.stdev,lowlevel.barkbands_kurtosis.stdev,lowlevel.barkbands_skewness.stdev,lowlevel.barkbands_spread.stdev,lowlevel.dissonance.stdev,lowlevel.erbbands_crest.stdev,lowlevel.erbbands_flatness_db.stdev,...,release,danceability,energy,key,mode,valence,tempo,genres,tags,duration_ms
0,0010xmHR6UICBOYT,0.985564,3.347479,0.090954,33.817204,2.315825,5.184556,0.058524,6.695590,0.071246,...,2013,0.591,0.513,7.0,0.0,0.263,172.208,underground hip hop,"instrumental hip-hop,underground hip hop,instr...",325096
1,007LIJOPQ4Sb98qV,0.974263,2.883465,0.020475,2.416713,0.617364,13.249235,0.011965,2.576587,0.020143,...,2009,0.357,0.708,9.0,1.0,0.470,123.904,"post-punk,new wave","post-punk,new wave,1985",326067
2,00CH4HJdxQQQbJfu,0.985387,2.644755,0.032223,3.445632,0.734539,7.933530,0.016695,3.363418,0.037764,...,2009,0.541,0.530,7.0,1.0,0.583,152.810,"indie rock,experimental,shoegaze,experimental","indie rock,experimental,indie,chillout,shoegaz...",175347
3,00GCd9HYEge6Ntwi,0.970022,5.783864,0.088446,41.049015,2.812269,20.235806,0.035987,4.849710,0.047168,...,2016,0.903,0.521,1.0,1.0,0.761,150.018,hip hop,"hip hop,banned,british hip hop,english hip hop...",223411
4,00P2bHdWFkghmDqz,0.951014,3.333300,0.048984,5.675779,0.924659,8.323689,0.036883,4.542486,0.040003,...,2012,0.232,0.577,4.0,0.0,0.669,65.195,soul,"fip,soul",252213
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
42756,zzbrm6nHm8grTZeF,0.980749,3.041098,0.025521,3.856150,0.803437,10.725990,0.025318,4.997821,0.032084,...,2009,0.416,0.354,0.0,1.0,0.181,80.767,country,"country,carrie underwood,female vocalists,ballads",268560
42757,zziWCQWVlwBURp9X,0.912043,3.801924,0.048703,9.231892,1.210800,13.700731,0.040273,5.450499,0.040460,...,2015,0.644,0.647,0.0,0.0,0.476,131.014,"electropop,house,electronic","electropop,house,love at first listen,beautifu...",226541
42758,zzjzIS5AWk6uHtWX,0.969503,3.946826,0.033180,6.699633,0.862496,7.988182,0.026416,3.771169,0.030554,...,2011,0.527,0.326,1.0,0.0,0.216,142.000,"progressive rock,rock","progressive rock,rock,melancholic",230000
42759,zzm0WMJ14dzbttpm,0.980208,3.325993,0.024284,3.587760,0.809922,11.766291,0.017426,3.778920,0.027107,...,2003,0.332,0.921,10.0,0.0,0.221,122.760,"alternative rock,rock","cover,covers,alternative rock,alternative,rock",386267


In [117]:
# Graph index assigned to each feature row's track id, in feature-table row order.
feature_row_graph_ids = m4a_features['id'].map(track_id_mapping)
feature_row_graph_ids.to_numpy()

array([12056, 12057, 12058, ..., 55032, 55033, 55034], dtype=int64)

In [120]:
def check(n, l):
    """Return True if `l` contains `n` adjacent items forming a run of consecutive integers.

    A window of `n` neighbouring items qualifies when, taken in any order, its
    values are `k, k+1, ..., k+n-1` for some `k` (all distinct, spread of n-1).

    Args:
        n: required run length.
        l: sequence of integers supporting slicing and `len`.

    Returns:
        bool: True if at least one qualifying window exists; False otherwise,
        including when `n < 1` or `n > len(l)`.
    """
    if n < 1 or n > len(l):
        return False
    for start in range(len(l) - n + 1):
        window = l[start:start + n]
        # n distinct integers whose max and min differ by n-1 are exactly a
        # consecutive run. The earlier `sorted(sub) in range(...)` test compared
        # a list against integers and therefore could never be true.
        if len(set(window)) == n and max(window) - min(window) == n - 1:
            return True
    return False


# Ask whether the graph indices of all feature rows together form one unbroken
# run of consecutive integers. The window length is taken from the data: the
# previous literal 42761 exceeds the 42760 rows shown in the table output above,
# so no window could ever be examined.
mapped_ids = m4a_features['id'].map(track_id_mapping).values.tolist()
check(len(mapped_ids), mapped_ids)

False

In [123]:
# Display the feature table with repeated `id` values removed; the author's
# conclusion from this output was "no duplicates".
m4a_features.drop_duplicates(subset=['id'])

Unnamed: 0,id,lowlevel.average_loudness,lowlevel.barkbands_crest.stdev,lowlevel.barkbands_flatness_db.stdev,lowlevel.barkbands_kurtosis.stdev,lowlevel.barkbands_skewness.stdev,lowlevel.barkbands_spread.stdev,lowlevel.dissonance.stdev,lowlevel.erbbands_crest.stdev,lowlevel.erbbands_flatness_db.stdev,...,release,danceability,energy,key,mode,valence,tempo,genres,tags,duration_ms
0,0010xmHR6UICBOYT,0.985564,3.347479,0.090954,33.817204,2.315825,5.184556,0.058524,6.695590,0.071246,...,2013,0.591,0.513,7.0,0.0,0.263,172.208,underground hip hop,"instrumental hip-hop,underground hip hop,instr...",325096
1,007LIJOPQ4Sb98qV,0.974263,2.883465,0.020475,2.416713,0.617364,13.249235,0.011965,2.576587,0.020143,...,2009,0.357,0.708,9.0,1.0,0.470,123.904,"post-punk,new wave","post-punk,new wave,1985",326067
2,00CH4HJdxQQQbJfu,0.985387,2.644755,0.032223,3.445632,0.734539,7.933530,0.016695,3.363418,0.037764,...,2009,0.541,0.530,7.0,1.0,0.583,152.810,"indie rock,experimental,shoegaze,experimental","indie rock,experimental,indie,chillout,shoegaz...",175347
3,00GCd9HYEge6Ntwi,0.970022,5.783864,0.088446,41.049015,2.812269,20.235806,0.035987,4.849710,0.047168,...,2016,0.903,0.521,1.0,1.0,0.761,150.018,hip hop,"hip hop,banned,british hip hop,english hip hop...",223411
4,00P2bHdWFkghmDqz,0.951014,3.333300,0.048984,5.675779,0.924659,8.323689,0.036883,4.542486,0.040003,...,2012,0.232,0.577,4.0,0.0,0.669,65.195,soul,"fip,soul",252213
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
42756,zzbrm6nHm8grTZeF,0.980749,3.041098,0.025521,3.856150,0.803437,10.725990,0.025318,4.997821,0.032084,...,2009,0.416,0.354,0.0,1.0,0.181,80.767,country,"country,carrie underwood,female vocalists,ballads",268560
42757,zziWCQWVlwBURp9X,0.912043,3.801924,0.048703,9.231892,1.210800,13.700731,0.040273,5.450499,0.040460,...,2015,0.644,0.647,0.0,0.0,0.476,131.014,"electropop,house,electronic","electropop,house,love at first listen,beautifu...",226541
42758,zzjzIS5AWk6uHtWX,0.969503,3.946826,0.033180,6.699633,0.862496,7.988182,0.026416,3.771169,0.030554,...,2011,0.527,0.326,1.0,0.0,0.216,142.000,"progressive rock,rock","progressive rock,rock,melancholic",230000
42759,zzm0WMJ14dzbttpm,0.980208,3.325993,0.024284,3.587760,0.809922,11.766291,0.017426,3.778920,0.027107,...,2003,0.332,0.921,10.0,0.0,0.221,122.760,"alternative rock,rock","cover,covers,alternative rock,alternative,rock",386267


In [127]:
# Compare track id coverage between the event table and the feature table.
event_track_ids = set(m4a_le.track_id)
feature_track_ids = set(m4a_features.id)
# Tracks that have listening events but no feature row.
print(len(event_track_ids - feature_track_ids))
# Tracks that have a feature row but no listening events.
print(len(feature_track_ids - event_track_ids))

218
0


array([12056, 12057, 12058, ..., 55032, 55033, 55034], dtype=int64)