In [89]:
import pandas as pd
from pathlib import Path
from datetime import datetime
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [20]:
BASE_DATE = datetime.strptime("2025-01-03", '%Y-%m-%d').timestamp()
print(BASE_DATE)

1735858800.0


In [37]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [120]:
tracks_raw_data = pd.read_json(Path('../data_v2/tracks.jsonl'), lines=True)
artists_raw_data = pd.read_json(Path('../data_v2/artists_2.jsonl'), lines=True)

# NORMALIZE DATES
tracks_raw_data["release_date"] = (pd.to_datetime(tracks_raw_data["release_date"], format='mixed').apply(lambda x: x.timestamp())).div(BASE_DATE)

# NORMALIZE DURATION
tracks_raw_data["duration_ms"] = tracks_raw_data["duration_ms"].div(tracks_raw_data["duration_ms"].max())

# NORMALIZE TEMPO
tracks_raw_data["tempo"] = tracks_raw_data["tempo"].div(tracks_raw_data["tempo"].max())

# EXPLICITE ENCODING
tracks_raw_data["explicit"] = tracks_raw_data["explicit"].apply(lambda x: [0, 1] if x else [1, 0])

# ARTISTS INJECTION
def couple_artist_to_track(artist_id: str):
    artist_index = np.where(artists_raw_data["id"] == artist_id)[0][0]
    return [artists_raw_data["genres"][artist_index], artists_raw_data["id_hash"][artist_index]]

tracks_raw_data["geners"], tracks_raw_data["artists_hash"] = zip(*list(tracks_raw_data["id_artist"].apply(couple_artist_to_track).values))


In [121]:
class TracksDataset(Dataset):
    def __init__(self, tracks_data: pd.DataFrame):
        self.data = tracks_data

    def __len__(self):
        return len(self.data)

    def get_item(self, idx):
        return self.data.iloc[idx].values

    def __getitem__(self, idx):
        unpacked_data = []
        for data in self.data.iloc[idx].drop("id").drop("name").drop("id_artist").values:
            if type(data) != list:
                unpacked_data.append(data)
            else:
                unpacked_data += data
        return unpacked_data

In [122]:
tracks_raw_dataset = TracksDataset(tracks_raw_data)
for i in range(10):
    print(tracks_raw_dataset[i+100])

[0.51, 0.0244967669500308, 1, 0, -0.11442958378872752, 0.665, 0.7010000000000001, 8, 0.1101166667, 0.0329, 0.622, 6.48e-05, 0.14100000000000001, 0.977, 0.6081490601956393, 0, 1, 0, 0, 0, 0, 0, 'rock', 'pop', 17241872]
[0.51, 0.030182575945486907, 1, 0, -0.12328929058054722, 0.635, 0.656, 2, 0.1416666667, 0.0291, 0.389, 0.00127, 0.0828, 0.77, 0.610793324822012, 0, 1, 0, 0, 0, 0, 0, 'rock', 96828466]
[0.75, 0.04500033735751499, 1, 0, -0.09456990395762604, 0.525, 0.216, 6, 0.22585000000000002, 0.030100000000000002, 0.837, 0.0, 0.107, 0.328, 0.4850589961789922, 0, 0, 1, 0, 0, 0, 0, 'rock', 75345512]
[0.73, 0.0325836391798766, 1, 0, -0.09541605572987849, 0.5730000000000001, 0.9390000000000001, 8, 0.10735000000000001, 0.08080000000000001, 0.493, 0.0, 0.0994, 0.963, 0.6241827541242805, 0, 0, 0, 0, 0, 0, 0, 'rock', 67110466]
[0.7000000000000001, 0.04182577887112894, 1, 0, -0.10910380498690331, 0.883, 0.464, 10, 0.1875166667, 0.0591, 0.279, 1.51e-05, 0.07440000000000001, 0.925, 0.50680829990140