In [20]:
import os 

import pandas as pd
import numpy as np 
import json
from tqdm.notebook import tqdm
import pickle

from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data, Dataset
import torch

import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU, GCLSTM, DCRNN, A3TGCN

from sklearn.metrics import accuracy_score, f1_score, classification_report

from importlib import reload
"""
import songs_mobility
reload(songs_mobility)
from songs_mobility import SongMobilityDatasetLoader
"""

import song_mobility_static_graph
reload(song_mobility_static_graph)
from song_mobility_static_graph import StaticSongDatasetLoader

from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings('ignore')

# Name of the data source; it is interpolated into the dataset and results paths
# and forwarded to the dataset loader. Values noted by the author: 'lfm' or 'spotify'.
source= 'spotify' # 'lfm' or 'spotify'

In [9]:
import glob

# Every per-song dataset file for the selected source matches this pattern.
dataset_pattern = os.path.join('data', 'gnn_datasets', f'{source}', '*_static_v2.json')

# The song id is the second '_'-separated token of each file name.
target_songs = [
    os.path.basename(dataset_path).split('_')[1]
    for dataset_path in glob.glob(dataset_pattern)
]
target_songs

['3cfOd4CMv2snFaKAnMdnvK',
 '7Hz6LLOVxrojLPIHJJ1S0E',
 '1Wrzhfa5bNlqvsnCztz190',
 '3H3r2nKWa3Yk5gt8xgmsEt',
 '62aP9fBQKYKxi7PDXwcUAS',
 '2t16D9V5FmmRAJjsSpwvZf',
 '0d28khcov6AiegSCpG5TuT',
 '7hPtrGqzMOO0KMevkaQnYR',
 '1ot6jEe4w4hYnsOPjd3xKQ',
 '37vqCDZ6g5pZ5gqDS90MDP',
 '0ljEFYEFuqspdPCWn75FPP',
 '0ezxVlILGg3XiTJONE7PEn',
 '7yO48FWUkqsrdxrwkGcnwl',
 '2tnVG71enUj33Ic2nFN6kZ',
 '77V0TWycxV4wuhcip5LF3X',
 '6ScCr1C4VIFzulq8bP65mj',
 '77GcYr9JP6uvJM0kPa6Nzk',
 '61YzdCCBPM5Pc7lIiD5i8C',
 '2FPfeYlrbSBR8PwCU0zaqq',
 '4kKdvXD0ez7jp1296JmAts',
 '4zii9cq9ZgxYJFnaxM3Tul',
 '5sICkBXVmaCQk5aISGR3x1',
 '4QYQgJhBryglC2hEpVGrZU',
 '1RwoXmQ3RvIdfG8bLeX65g',
 '3C2xQqHVkhA1Ht17SzPRke',
 '0NKoPGc72yUKTIneVhFfg1',
 '6Qyc6fS4DsZjB2mRW9DsQs',
 '6ap9lSRJ0iLriGLqoJ44cq',
 '6ZfXA2xakAvphXOSOJ3u1W',
 '4QX5pZQpQTgVlkqfUTDim0',
 '7B3z0ySL9Rr0XvZEAjWZzM',
 '484sAt8h6YYMXNuBV2oT83',
 '4GGSmoTmc5LClrowB4SmEw',
 '2AkmdLbVKS1steeZdy8H1l',
 '59g1IKRYrCBsYLdd0WhKus',
 '3PKk9pGTfQBINlJrEyrH6c',
 '2WD9ggmpZE7Wodh3qVVCgg',
 

In [10]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree  # imported here so this cell runs on a fresh kernel


class LightGCN(MessagePassing):
    """
    A single LightGCN layer. Extends the MessagePassing class from PyTorch Geometric.

    The layer has no trainable parameters: it only propagates neighbor
    embeddings scaled by a symmetric degree normalization and, optionally,
    by per-edge weights.
    """
    def __init__(self):
        super().__init__(aggr='add')  # messages arriving at a node are summed

    def message(self, x_j, norm):
        """
        Specifies how to perform message passing during GNN propagation. For LightGCN, we simply pass along each
        source node's embedding to the target node, normalized by the normalization term for that edge.
        args:
          x_j: node embeddings of the neighbor nodes, which will be passed to the central node (shape: [E, emb_dim])
          norm: the normalization terms we calculated in forward() and passed into propagate()
        returns:
          messages from neighboring nodes j to central node i
        """
        # Broadcast the per-edge scalar across the embedding dimension.
        return norm.view(-1, 1) * x_j

    def forward(self, x, edge_index, edge_weight=None):
        """
        Performs the LightGCN message passing/aggregation/update to get updated node embeddings

        args:
          x: current node embeddings (shape: [N, emb_dim])
          edge_index: message passing edges (shape: [2, E])
          edge_weight: optional per-edge weights (one value per edge) multiplied into the
            normalization term; when None the edges are treated as unweighted
        returns:
          updated embeddings after this layer
        """
        row, col = edge_index

        # Size the degree vector by the number of nodes in x. Without num_nodes,
        # degree() infers the size from the largest index in `col`, so a node that
        # only appears as a source would make deg_inv_sqrt[row] index out of range.
        deg = degree(col, x.size(0))
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with zero in-degree yield inf; zero them so their edges contribute nothing.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        if edge_weight is not None:
            norm = norm * edge_weight

        # Begin propagation. Will perform message passing and aggregation and return updated node embeddings.
        return self.propagate(edge_index, x=x, norm=norm)

In [11]:
from torch_geometric.nn import GATConv
from torch_geometric.utils import degree


class Net(torch.nn.Module):
    """One LightGCN propagation step followed by an element-wise sigmoid.

    ``lin1`` and ``lin2`` are instantiated but are not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = LightGCN()

        self.lin1 = torch.nn.Linear(1, 1024)
        self.lin2 = torch.nn.Linear(1024, 1)

    def forward(self, x, edge_index, edge_weight):
        """Propagate ``x`` once over the graph and squash the result into (0, 1).

        args:
          x: node features handed to the LightGCN layer
          edge_index: edges used for message passing
          edge_weight: per-edge weights handed to the LightGCN layer
        returns:
          sigmoid of the propagated node features
        """
        propagated = self.conv1(x, edge_index, edge_weight)

        # Only the single propagated layer is collected, so averaging over the
        # stacked layer axis leaves it unchanged.
        layer_outputs = [propagated]
        averaged = torch.stack(layer_outputs, dim=0).mean(dim=0)

        return torch.sigmoid(averaged)


def train_and_eval(data):
    """Score the module-level ``model`` on one graph and return ``(accuracy, f1)``.

    Despite the name, no optimization step is performed here: the function only
    runs a forward pass and compares thresholded outputs against ``data.y``.

    args:
      data: graph object exposing ``x``, ``edge_index``, ``edge_attr`` and ``y``;
        it is moved to the module-level ``device`` before the forward pass
    returns:
      tuple ``(acc, f1)`` computed with scikit-learn's ``accuracy_score`` and ``f1_score``
    """
    data = data.to(device)
    y = data.y.detach().cpu().numpy()

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        y_hat = model(data.x, data.edge_index, data.edge_attr)
    y_hat_flatten = y_hat.detach().cpu().numpy().reshape(-1)

    # Outputs strictly above 0.5 are labelled 1.0, everything else 0.0.
    y_hat_crisp = (y_hat_flatten > 0.5).astype(float)
    acc = accuracy_score(y, y_hat_crisp)
    f1 = f1_score(y, y_hat_crisp)

    return acc, f1

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


results = []
for s in tqdm(target_songs, desc='Songs...'):
    # train_and_eval() reads the model through the module-level name `model`,
    # so it is rebound here for every song.
    model = Net().to(device)
    loader = StaticSongDatasetLoader(source=source, target_song=s)
    dataset = loader.get_dataset(lags=1)

    # One (accuracy, F1) pair per item yielded by the dataset.
    per_item_metrics = [train_and_eval(data) for data in dataset]
    accs = [item_acc for item_acc, _ in per_item_metrics]
    f1_scores = [item_f1 for _, item_f1 in per_item_metrics]

    mean_f1 = np.mean(f1_scores)
    mean_acc = np.mean(accs)

    results.append((s, mean_f1, mean_acc))

# One row per song with its mean F1 and mean accuracy.
results_df = pd.DataFrame.from_records(results, columns=['song', 'f1_score', 'acc'])
results_df

Songs...:   0%|          | 0/774 [00:00<?, ?it/s]

Unnamed: 0,song,f1_score,acc
0,3cfOd4CMv2snFaKAnMdnvK,0.441790,0.844643
1,7Hz6LLOVxrojLPIHJJ1S0E,0.716028,0.831579
2,1Wrzhfa5bNlqvsnCztz190,0.749966,0.937143
3,3H3r2nKWa3Yk5gt8xgmsEt,0.675294,0.839560
4,62aP9fBQKYKxi7PDXwcUAS,0.739347,0.820463
...,...,...,...
769,1WOmEg42Mgaz1nRimIMR9C,0.327683,0.744444
770,51cdazVlYWqmnsfrVaeTeG,0.439202,0.880952
771,3YQUpKTmtgPJf8u6ooi9KA,0.554138,0.798980
772,7wBrglFVTNCcW6IhdgBkm1,0.344773,0.803361


In [12]:
# Mean and standard deviation of the metric columns. numeric_only=True keeps the
# string `song` column out of the aggregation, which otherwise raises on pandas >= 2.0.
results_df.mean(axis=0, numeric_only=True), results_df.std(axis=0, numeric_only=True)

(f1_score    0.539838
 acc         0.845460
 dtype: float64,
 f1_score    0.175845
 acc         0.057161
 dtype: float64)

In [13]:
# Make sure the output directory exists so the save does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)
results_df.to_csv(os.path.join('results', f'results_gnn_static_v2_{source}.csv'))

### Analysis of the results by music genre

In [21]:
# Reload the saved per-song metrics; note the file name is fixed to the spotify run.
results_path = os.path.join('results', 'results_gnn_static_v2_spotify.csv')
results_df = pd.read_csv(results_path, index_col=0)
results_df.head()

Unnamed: 0,song,f1_score,acc
0,3cfOd4CMv2snFaKAnMdnvK,0.44179,0.844643
1,7Hz6LLOVxrojLPIHJJ1S0E,0.716028,0.831579
2,1Wrzhfa5bNlqvsnCztz190,0.749966,0.937143
3,3H3r2nKWa3Yk5gt8xgmsEt,0.675294,0.83956
4,62aP9fBQKYKxi7PDXwcUAS,0.739347,0.820463


In [22]:
# Genre tags per song; the first CSV column (song id) becomes the index.
song_genres_path = os.path.join('data', 'songs_genres_2017_2022.csv')
song_genres_df = pd.read_csv(song_genres_path, index_col=0)
song_genres_df.head()

Unnamed: 0,0,1,2,3,4,5,6
6m2LNopVJKsvBB9l7Z1rwn,funk,pop,,,,,
7Ckhk1XW5NV2k4jpqtQNlz,reggaeton,latin,,,,,
7MHN1aCFtLXjownGhvEQlF,reggaeton,latin,,,,,
1xndOD8CreR5ctkOv5G1LN,hardcore,,,,,,
3xWEI23MLJrD0dpDcDUTW6,reggaeton,latin,pop,,,,


In [29]:
# Peek at the genre tags of a single song (unused slots show up as NaN).
song_genres_df.loc['6m2LNopVJKsvBB9l7Z1rwn'].tolist()

['funk', 'pop', nan, nan, nan, nan, nan]

In [23]:
### Read songs genre groups

import json

# Load the genre -> list-of-sub-genres mapping. The context manager closes the
# file handle as soon as the JSON has been parsed (it was previously left open).
with open(os.path.join('data','music_styles_unique.json')) as f:
    music_styles = json.load(f)

# Invert the mapping so every sub-genre points at its parent genre. If a
# sub-genre is listed under several genres, the last one encountered wins.
sub_gender_to_gender = {
    sb: gender
    for gender, sub_genders in music_styles.items()
    for sb in sub_genders
}
sub_gender_to_gender

{'pop rock': 'pop',
 'pop': 'pop',
 'british': 'pop',
 'j-pop': 'pop',
 'k-pop': 'pop',
 'mandopop': 'pop',
 'cantopop': 'pop',
 'hip-hop': 'hip-hop',
 'funk': 'hip-hop',
 'trap': 'hip-hop',
 'rap': 'hip-hop',
 'rock': 'rock',
 'rockabilly': 'rock',
 'metal': 'rock',
 'punk': 'rock',
 'metalcore': 'rock',
 'hardcore': 'rock',
 'j-rock': 'rock',
 'grunge': 'rock',
 'trance': 'electronic',
 'chill': 'electronic',
 'dance': 'electronic',
 'edm': 'electronic',
 'house': 'electronic',
 'dubstep': 'electronic',
 'dancehall': 'electronic',
 'electro': 'electronic',
 'techno': 'electronic',
 'electronic': 'electronic',
 'club': 'electronic',
 'latin': 'latin',
 'reggaeton': 'latin',
 'spanish': 'latin',
 'pagode': 'latin',
 'cumbia': 'latin',
 'salsa': 'latin',
 'latino': 'latin',
 'sertanejo': 'latin',
 'mpb': 'latin',
 'anime': 'indie',
 'indie': 'indie',
 'alternative': 'indie',
 'emo': 'indie',
 'soundtracks': 'classical and ost',
 'classical': 'classical and ost',
 'piano': 'classical and

In [24]:
# Mean and standard deviation of the metric columns. numeric_only=True keeps the
# string `song` column out of the aggregation, which otherwise raises on pandas >= 2.0.
results_df.mean(axis=0, numeric_only=True), results_df.std(axis=0, numeric_only=True)

(f1_score    0.539838
 acc         0.845460
 dtype: float64,
 f1_score    0.175845
 acc         0.057161
 dtype: float64)

In [30]:
# Scratch check of np.mean on a small list.
np.mean([1, 3, 4])

2.6666666666666665

In [40]:
import math

# Collect each song's accuracy under the parent genre of its first recognised tag.
results_by_gender = {}
for song_id, acc in zip(results_df['song'], results_df['acc']):
    # Songs without a usable accuracy are left out of the aggregation.
    if math.isnan(acc):
        continue

    genders = song_genres_df.loc[song_id].values.tolist()
    # Only the first tag that has a known parent genre counts; later tags are ignored.
    parent_genre = next(
        (sub_gender_to_gender[sg] for sg in genders if sg in sub_gender_to_gender),
        None,
    )
    if parent_genre is not None:
        results_by_gender.setdefault(parent_genre, []).append(acc)

# Average accuracy per parent genre, rounded to three decimals.
mean_acc_by_gender = {
    g: round(np.mean(genre_accs), 3)
    for g, genre_accs in results_by_gender.items()
}
mean_acc_by_gender

{'pop': 0.827,
 'latin': 0.913,
 'electronic': 0.815,
 'hip-hop': 0.836,
 'rythm and blues': 0.843,
 'rock': 0.817,
 'indie': 0.796,
 'folk/traditional': 0.854,
 'classical and ost': 0.86,
 'miscellaneous': 0.732}

In [34]:
# Display the accuracy lists accumulated per parent genre (the output is long).
results_by_gender

{'pop': [0.8446428571428571,
  0.8315789473684209,
  0.8395604395604397,
  0.8204633204633205,
  0.8402597402597404,
  0.8047619047619047,
  0.8585714285714285,
  0.8371428571428572,
  0.8035714285714286,
  0.8021428571428573,
  0.8523809523809524,
  0.7857142857142857,
  0.8271428571428571,
  0.8653061224489795,
  0.8558441558441559,
  0.86,
  0.7755952380952381,
  nan,
  0.7999999999999999,
  0.8040816326530612,
  0.7789115646258503,
  0.8964285714285715,
  0.8571428571428571,
  0.807142857142857,
  0.8301020408163265,
  0.8285714285714286,
  0.8142857142857144,
  0.763265306122449,
  0.826530612244898,
  0.7464285714285714,
  0.8133333333333336,
  0.8375,
  0.8375,
  0.812987012987013,
  0.8821428571428571,
  0.8821428571428571,
  0.8482142857142857,
  0.8023809523809523,
  0.7839285714285714,
  0.9457142857142856,
  0.8809523809523809,
  0.780952380952381,
  nan,
  0.8297619047619048,
  0.8428571428571429,
  0.8428571428571429,
  0.8398496240601504,
  0.8732142857142857,
  0.873214

In [25]:
# End-of-notebook marker: prints a fixed message once execution reaches the last cell.
print("That's all folks!")

That's all folks!
