In [1]:
import pandas as pd
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
BASE_DATE = datetime.strptime("2025-01-03", '%Y-%m-%d').timestamp()
print(BASE_DATE)

1735858800.0


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [6]:
tracks_raw_data = pd.read_json(Path('../data_v2/tracks.jsonl'), lines=True)
artists_raw_data = pd.read_json(Path('../data_v2/artists.jsonl'), lines=True)

# NORMALIZE DATES
tracks_raw_data["release_date"] = (pd.to_datetime(tracks_raw_data["release_date"], format='mixed').apply(lambda x: x.timestamp())).div(BASE_DATE)

# NORMALIZE DURATION
tracks_raw_data["duration_ms"] = tracks_raw_data["duration_ms"].div(tracks_raw_data["duration_ms"].max())

# NORMALIZE TEMPO
tracks_raw_data["tempo"] = tracks_raw_data["tempo"].div(tracks_raw_data["tempo"].max())

# EXPLICITE ENCODING
tracks_raw_data["explicit"] = tracks_raw_data["explicit"].apply(lambda x: [0, 1] if x else [1, 0])

# ARTISTS INJECTION
def couple_artist_to_track(artist_id: str):
    artist_index = np.where(artists_raw_data["id"] == artist_id)[0][0]
    return [artists_raw_data["genres"][artist_index], artists_raw_data["id_artist_hash"][artist_index]]

tracks_raw_data["geners"], tracks_raw_data["artists_hash"] = zip(*list(tracks_raw_data["id_artist"].apply(couple_artist_to_track).values))


In [9]:
class TracksDataset(Dataset):
    def __init__(self, tracks_data: pd.DataFrame):
        self.data = tracks_data

    def __len__(self):
        return len(self.data)

    def get_item(self, idx):
        return self.data.iloc[idx].values

    def __getitem__(self, idx):
        unpacked_data = []
        for data in self.data.iloc[idx].drop("id").drop("id_artist").values:
            if type(data) != list:
                unpacked_data.append(data)
            else:
                unpacked_data += data
        return unpacked_data

In [10]:
tracks_raw_dataset = TracksDataset(tracks_raw_data)
for i in range(10):
    print(tracks_raw_dataset[i+100])

[np.float64(0.51), np.float64(0.0244967669500308), 1, 0, np.float64(-0.11442958378872752), np.float64(0.665), np.float64(0.7010000000000001), np.int64(8), np.float64(0.1101166667), np.float64(0.0329), np.float64(0.622), np.float64(6.48e-05), np.float64(0.14100000000000001), np.float64(0.977), np.float64(0.6081490601956393), 0, 1, 0, 0, 0, 0, 0, 'rock', 'pop', np.int64(17241872)]
[np.float64(0.51), np.float64(0.030182575945486907), 1, 0, np.float64(-0.12328929058054722), np.float64(0.635), np.float64(0.656), np.int64(2), np.float64(0.1416666667), np.float64(0.0291), np.float64(0.389), np.float64(0.00127), np.float64(0.0828), np.float64(0.77), np.float64(0.610793324822012), 0, 1, 0, 0, 0, 0, 0, 'rock', np.int64(96828466)]
[np.float64(0.75), np.float64(0.04500033735751499), 1, 0, np.float64(-0.09456990395762604), np.float64(0.525), np.float64(0.216), np.int64(6), np.float64(0.22585000000000002), np.float64(0.030100000000000002), np.float64(0.837), np.float64(0.0), np.float64(0.107), np.fl