In [1]:
import os
import numpy as np
import pandas as pd
import modin.pandas as mpd
import random

from tqdm import tqdm
import torch
from torch import Tensor

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from torch_geometric.loader import LinkNeighborLoader

from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

from torch_geometric.nn import GATConv, to_hetero
import torch.nn.functional as F


def _read_edge_table(path):
    """Read one parquet edge list with pandas and wrap it in a modin DataFrame."""
    return mpd.DataFrame(pd.read_parquet(path))


def _build_id_map(raw_ids, key):
    """Pair every distinct raw id with a dense integer ``mappedID``.

    ``np.unique`` returns the distinct ids in sorted order, and ``mappedID``
    is simply the position of each id in that sorted array.  The id column is
    placed first because ``make_edge`` joins on ``columns[0]`` of the map.
    """
    distinct = np.unique(raw_ids.values)
    return mpd.DataFrame(data={
        key: distinct,
        'mappedID': pd.RangeIndex(len(distinct)),
    })


# Raw two-column edge lists, one parquet file per relation.
playlist_song = _read_edge_table("data/gnn_playlists2songs.parquet")
song_artist = _read_edge_table("data/gnn_songs2artists.parquet")
playlist_tag = _read_edge_table("data/gnn_playlists2tags.parquet")
playlist_dj = _read_edge_table("data/gnn_playlists2djs.parquet")

# One raw-id -> mappedID table per node type.  Playlist and song ids are both
# taken from the playlist/song edge list.
unique_playlist_id = _build_id_map(playlist_song['playlist_id'], 'playlist_id')
unique_song_id = _build_id_map(playlist_song['song_id'], 'song_id')
unique_artist = _build_id_map(song_artist['artist_id'], 'artist')
unique_tag = _build_id_map(playlist_tag['tag_id'], 'tag')
unique_dj = _build_id_map(playlist_dj['dj_id'], 'dj')

def make_edge(edge_df, u0, u1):
    """Translate a two-column edge list of raw ids into an edge-index tensor.

    Args:
        edge_df: frame with exactly two columns; the first holds source ids,
            the second holds destination ids.  It is NOT modified.
        u0: id map for the source node type.  Its first column holds the raw
            ids and it must have a ``mappedID`` column.
        u1: id map for the destination node type, same layout as ``u0``.

    Returns:
        An int64 tensor of shape ``[2, num_unique_edges]``; row 0 holds the
        source ``mappedID`` values and row 1 the destination ones.

    Raises:
        ValueError: if ``edge_df`` does not have exactly two columns (from
            the unpacking below), or if any id in ``edge_df`` has no entry in
            the corresponding id map.
    """
    # De-duplicate a copy: an in-place drop would silently change the
    # caller's frame as a side effect of building the tensor.
    edges = edge_df.drop_duplicates()
    src_key, dst_key = list(edges.columns)

    def _lookup(edge_key, id_map):
        """Map one edge column to its mappedID values, preserving row order."""
        id_key = id_map.columns[0]
        merged = mpd.merge(edges[[edge_key]], id_map, left_on=edge_key, right_on=id_key, how='left')
        mapped = merged['mappedID']
        # A left join leaves NaN where an id is unknown; that would turn the
        # index into floats containing NaN, so fail loudly instead.
        missing = int(mapped.isna().sum())
        if missing:
            raise ValueError(
                f"{missing} value(s) in edge column '{edge_key}' have no entry "
                f"in id map column '{id_key}'"
            )
        return torch.from_numpy(mapped.values).long()

    return torch.stack([_lookup(src_key, u0), _lookup(dst_key, u1)], dim=0)

# Convert each relation's raw-id edge list into a [2, num_edges] tensor of
# mapped node indices.
edge_playlist_song = make_edge(playlist_song, unique_playlist_id, unique_song_id)
edge_song_artist = make_edge(song_artist, unique_song_id, unique_artist)
edge_playlist_tag = make_edge(playlist_tag, unique_playlist_id, unique_tag)
edge_playlist_dj = make_edge(playlist_dj, unique_playlist_id, unique_dj)

data = HeteroData()

# Each node type only carries its own index range as ``node_id``.
for node_type, id_map in (
    ("playlist", unique_playlist_id),
    ("song", unique_song_id),
    ("artist", unique_artist),
    ("tag", unique_tag),
    ("dj", unique_dj),
):
    data[node_type].node_id = torch.arange(len(id_map))

# Register the four relations in a fixed order.
for edge_type, edge_index in (
    (("playlist", "playlist2song", "song"), edge_playlist_song),
    (("song", "song2artist", "artist"), edge_song_artist),
    (("playlist", "playlist2tag", "tag"), edge_playlist_tag),
    (("playlist", "playlist2dj", "dj"), edge_playlist_dj),
):
    data[edge_type].edge_index = edge_index

# Apply the ToUndirected transform to the assembled graph.
data = T.ToUndirected()(data)


# Use the GPU when one is visible; to force CPU, assign torch.device('cpu').
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: '{device}'")

from torch_geometric.nn import GATConv, to_hetero
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """Two-layer graph attention encoder.

    The first layer uses ``heads`` attention heads; the second layer takes
    ``hidden_channels * heads`` input features and uses a single head with
    ``hidden_channels`` output channels.  ``Model`` wraps an instance with
    ``to_hetero``, so the ``forward`` parameter names ``x`` and
    ``edge_index`` should be kept as they are.

    Args:
        hidden_channels: input feature size and per-head output size.
        heads: number of attention heads in the first layer.
        dropout: probability used for the feature dropout in ``forward`` and
            passed as ``dropout`` to both ``GATConv`` layers.  Defaults to
            0.6, the value that was previously hard-coded.
    """

    def __init__(self, hidden_channels, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        # NOTE(review): add_self_loops=False is presumably required because the
        # wrapped relations connect different node types — confirm.
        self.conv1 = GATConv(hidden_channels, hidden_channels, heads, dropout=dropout, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, dropout=dropout, add_self_loops=False)

    def forward(self, x, edge_index):
        """Encode node features ``x`` using the connectivity in ``edge_index``."""
        # Feature dropout is only active in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    
class Classifier(torch.nn.Module):
    """Score candidate playlist/song links with a dot product."""

    def forward(self, x_playlist: Tensor, x_song: Tensor, edge_label_index: Tensor) -> Tensor:
        """Return one score per column of ``edge_label_index``.

        Row 0 of ``edge_label_index`` indexes into ``x_playlist`` and row 1
        into ``x_song``; the score is the sum over the last dimension of the
        element-wise product of the two selected rows.
        """
        playlist_rows = edge_label_index[0]
        song_rows = edge_label_index[1]
        pairwise = x_playlist[playlist_rows] * x_song[song_rows]
        return torch.sum(pairwise, dim=-1)


class Model(torch.nn.Module):
    """Link-prediction model: per-type embeddings -> hetero GNN -> dot-product scorer.

    The constructor reads node counts and graph metadata from the
    module-level ``data`` object, so that graph must exist before a ``Model``
    is created.  Attribute names are part of the checkpoint's state-dict keys
    and must not be renamed.
    """

    def __init__(self, hidden_channels):
        super().__init__()

        def embedding_for(node_type):
            # One learnable vector per node of the given type.
            return torch.nn.Embedding(data[node_type].num_nodes, hidden_channels)

        self.playlist_emb = embedding_for("playlist")
        self.song_emb = embedding_for("song")
        self.tag_emb = embedding_for("tag")
        self.dj_emb = embedding_for("dj")
        self.artist_emb = embedding_for("artist")
        # Convert the homogeneous GNN with to_hetero using the graph's metadata.
        self.gnn = to_hetero(GNN(hidden_channels), metadata=data.metadata())
        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> Tensor:
        """Score the playlist/song pairs listed in the batch's ``edge_label_index``.

        Note that the ``data`` parameter shadows the module-level graph: here
        it is the (sub)graph passed in by the caller.
        """
        tables = {
            "playlist": self.playlist_emb,
            "song": self.song_emb,
            "artist": self.artist_emb,
            "tag": self.tag_emb,
            "dj": self.dj_emb,
        }
        # Look up the initial feature vector of every node in the batch.
        x_dict = {node_type: table(data[node_type].node_id) for node_type, table in tables.items()}

        x_dict = self.gnn(x_dict, data.edge_index_dict)
        label_index = data["playlist", "playlist2song", "song"].edge_label_index
        return self.classifier(x_dict["playlist"], x_dict["song"], label_index)

        
# Fresh, untrained model with 64-dimensional embeddings; a later cell rebuilds
# `model` and loads checkpoint weights into it.
model = Model(hidden_channels=64)
# Number of candidate links per batch; passed to LinkNeighborLoader in the scoring cell.
batch_size = 512
# NOTE(review): `epochs` is not referenced in the visible cells — presumably a
# leftover from the training notebook; confirm before removing.
epochs = 1000

In [None]:
# Raw playlist ids to score: the first 500 values of the `playlist_id` column.
pas = pd.read_csv("playlists.csv")['playlist_id'].iloc[:500]
# Path of the checkpoint that the next cell passes to torch.load.
model_name = "model/model_20250206_0/epoch_0982__f1_0874.pth"
# Table id that the scoring cell passes to BigQuery's insert_rows_json.
save_table_name = "dev-ai-project-357507.leo_melon_simple_mode.gnn_playlists_5nodes"

In [1]:
# SECURITY: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints that come from a trusted source.
# map_location remaps the stored tensors onto the device selected above, which
# lets a checkpoint saved on a GPU load on a CPU-only machine.
old_model = torch.load(model_name, map_location=device, weights_only=False)
state_dict = old_model.state_dict()

# Rebuild the model from the current class definition and copy the trained
# weights into it, rather than using the unpickled object directly.
model = Model(hidden_channels=64)
model = model.to(device)
model.load_state_dict(state_dict)
# Evaluation mode disables dropout; the assignment suppresses the cell's repr output.
_ = model.eval()

2025-02-17 03:30:09,896	INFO worker.py:1786 -- Started a local Ray instance.


Device: 'cuda'




In [31]:
from google.cloud import bigquery

# Initialize a BigQuery client
client = bigquery.Client()

# Destination table for the streamed prediction rows
table_id = save_table_name

# Edge type to score.  It must be the type the graph defines and that
# Model.forward reads `edge_label_index` from.
target_edge_type = ("playlist", "playlist2song", "song")
# Only predictions strictly above this raw score are written out.
score_threshold = 3

# Loop-invariant lookups, hoisted out of the per-playlist loop.
all_songs = torch.arange(len(model.song_emb.weight))
# mappedID equals the row position in unique_song_id, so a plain array
# indexed by mappedID yields the raw song id.
song_ids_by_mapped = unique_song_id['song_id'].values

for pa in tqdm(pas):
    matches = unique_playlist_id.loc[unique_playlist_id['playlist_id'] == pa]['mappedID']
    if len(matches) == 0:
        # A playlist missing from the graph must not abort the whole run.
        tqdm.write(f"Skipping playlist {pa}: not present in the graph")
        continue
    pa_id = int(matches.iloc[0])

    # Songs already in the playlist are excluded from the candidates.
    songsin = edge_playlist_song[1][edge_playlist_song[0] == pa_id]
    target_songs = all_songs[~torch.isin(all_songs, songsin)]
    if target_songs.numel() == 0:
        continue
    # Shape [2, num_candidates]: row 0 repeats the playlist, row 1 lists the songs.
    playlist_song_pairs = torch.stack([torch.full_like(target_songs, pa_id), target_songs], dim=0)

    # Sample a neighbourhood subgraph around each batch of candidate links.
    link_loader = LinkNeighborLoader(
        data=data,
        num_neighbors=[20, 10],  # Number of neighbors to sample at each hop
        edge_label_index=(target_edge_type, playlist_song_pairs),
        batch_size=batch_size,  # Number of pairs per batch
        shuffle=False,  # Inference only: keep the candidates in order
        num_workers=0,
    )
    for sampled_data in link_loader:
        sampled_data = sampled_data.to(device)
        # No gradients are needed for scoring.
        with torch.no_grad():
            preds = model(sampled_data).cpu().numpy()
        # edge_label_index is local to the sampled subgraph; node_id maps it
        # back to the global mappedID values.
        label_index = sampled_data[target_edge_type].edge_label_index
        songs = sampled_data['song'].node_id[label_index[1]].cpu().numpy()
        temp_song_ids = song_ids_by_mapped[songs]
        rows_to_insert = [
            {"source_playlist_id": int(pa),
             "song_id": int(tsi),
             "prediction": float(p)}
            for tsi, p in zip(temp_song_ids, preds) if p > score_threshold]
        if len(rows_to_insert) > 0:
            errors = client.insert_rows_json(table_id, rows_to_insert)
            if errors:
                print("Errors:", errors)

  0%|          | 5/9899 [01:18<43:15:07, 15.74s/it]


KeyboardInterrupt: 

*** SIGTERM received at time=1739768163 on cpu 10 ***
PC: @     0x7f4827d6fdf6  (unknown)  epoll_wait
    @     0x7f4827faa140  (unknown)  (unknown)
[2025-02-17 04:56:03,209 E 156075 156075] logging.cc:440: *** SIGTERM received at time=1739768163 on cpu 10 ***
[2025-02-17 04:56:03,209 E 156075 156075] logging.cc:440: PC: @     0x7f4827d6fdf6  (unknown)  epoll_wait
[2025-02-17 04:56:03,210 E 156075 156075] logging.cc:440:     @     0x7f4827faa140  (unknown)  (unknown)
