In [14]:
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import polars as pl
import torch
from torch.utils.data import Dataset
from data_storage.DBManager import DBManager

In [111]:
class CBOWDataset(Dataset):
    """
    Next-track prediction samples built from playlists in the project database
    (accessed through DBManager).

    Every row of ``self.dataset`` is one sample: the row's track is the target
    and the (up to) ``context_size`` tracks that precede it in the same playlist
    are the context. Contexts shorter than ``context_size`` are left-padded with
    the index of the ``"PAD"`` vocabulary entry.
    """
    def __init__(self, n_playlists: int, context_size: int = 5):
        """
        n_playlists: number of playlists to load (used as the SQL LIMIT)
        context_size: number of preceding tracks returned as context
        """
        self.n_playlists = n_playlists
        self.context_size = context_size
        self._db = DBManager()
        self.cur = self._db.get_cursor()
        self.dataset = self._make_dataset()

    def _make_dataset(self) -> pl.DataFrame:
        """
        Returns polars dataframe, sorted by (pid, pos), containing:
            pid: playlist id
            pos: position of the track in the playlist
            track_idx: internal index of the track

        Side effect: sets ``self.track_vocab`` and ``self.track_2_idx``.
        """
        # int() guarantees that only a plain integer ever reaches the SQL text.
        limit = int(self.n_playlists)
        query = f"""
            SELECT PT.track_uri, PT.pid, PT.pos
            FROM playlist_track PT
            WHERE PT.pid IN (SELECT pid FROM playlist LIMIT {limit});
        """
        # Playlist tracks as dataframe
        df_pl_tracks = pl.read_database(query, self._db.get_connection())

        # Unique tracks, sorted so that the same data always produces the same
        # track -> index mapping (unique() alone gives no ordering guarantee).
        # "PAD" is appended last, so it owns the highest index.
        self.track_vocab = df_pl_tracks["track_uri"].unique().sort().to_list() + ["PAD"]
        self.track_2_idx = {track: idx for idx, track in enumerate(self.track_vocab)}

        # Map track_uri to track_idx and drop track_uri
        df_pl_tracks = df_pl_tracks.with_columns(
            track_idx=pl.col("track_uri").replace(
                self.track_2_idx,
                return_dtype=pl.UInt32
            )
        )
        df_pl_tracks = df_pl_tracks.drop("track_uri")

        # The query has no ORDER BY, so impose the order __getitem__ relies on:
        # rows of one playlist are contiguous and ascending in position.
        return df_pl_tracks.sort(["pid", "pos"])

    def idx_2_track(self, idx: int) -> str:
        """
        Returns the vocabulary entry (a track_uri or "PAD") stored at idx.
        """
        return self.track_vocab[idx]

    @property
    def n_tracks(self) -> int:
        """
        Vocabulary size, including the "PAD" entry.
        """
        return len(self.track_vocab)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Treats row idx as the target track and returns (context, target):
            context: tensor with the indices of the previous context_size
                     tracks of the same playlist, left-padded with PAD
            target: track_idx of row idx (plain int)

        Raises IndexError when idx is outside the dataset.
        """
        n_rows = len(self.dataset)
        idx = int(idx)
        if idx < 0:
            idx += n_rows
        if not 0 <= idx < n_rows:
            raise IndexError(f"index {idx} is out of range for dataset of length {n_rows}")

        target = self.dataset["track_idx"][idx]
        pid = self.dataset["pid"][idx]
        pos = self.dataset["pos"][idx]

        # The frame is sorted by (pid, pos), so every candidate context row sits
        # in the context_size rows directly above idx; only that small window
        # is filtered instead of the whole frame.
        # Assumes (pid, pos) pairs are unique - TODO confirm against the schema.
        start = max(0, idx - self.context_size)
        context = (
            self.dataset
            .slice(start, idx - start)
            .filter((pl.col("pid") == pid) & (pl.col("pos") < pos))
            .get_column("track_idx")
            .to_list()
        )

        # Pad context if necessary
        if len(context) < self.context_size:
            pad_size = self.context_size - len(context)
            context = [self.track_2_idx["PAD"]] * pad_size + context

        # Convert to pytorch tensors
        context = torch.tensor(context)
        # no need for one-hot encoding with NLLLoss
        return context, target

In [102]:
class Track2Vec(nn.Module):
    """
    Feed-forward model that predicts a track from a fixed-size context of tracks.

    Pipeline: embed every context index, concatenate the embeddings, pass them
    through one hidden layer and emit log-probabilities over all tracks
    (suitable for nn.NLLLoss).
    """
    def __init__(self, num_tracks: int, embedding_dim: int, context_size: int,
                 hidden_dim: int = 128):
        """
        num_tracks: vocabulary size (rows of the embedding table and number of
                    output classes)
        embedding_dim: size of each track embedding
        context_size: number of context indices per sample
        hidden_dim: width of the hidden layer (defaults to the previous
                    hard-coded value of 128)
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Embedding(num_tracks, embedding_dim),
            nn.Flatten(1),  # concatenate context
            nn.Linear(context_size * embedding_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_tracks),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """
        x: integer tensor of shape (batch, context_size) holding track indices
        Returns a (batch, num_tracks) tensor of log-probabilities.
        """
        return self.model(x)

In [110]:
def train(model: Track2Vec, ds_train: CBOWDataset, epochs : int,
          batch_size: int = 32, lr: float = 0.001) -> None:
    """
    Trains model in place on ds_train with Adam and NLLLoss.

    model: network returning log-probabilities of shape (batch, n_classes)
    ds_train: dataset yielding (context, target) pairs
    epochs: number of full passes over ds_train
    batch_size: samples per optimisation step
    lr: Adam learning rate

    Prints, per epoch, the sum of the per-batch losses.
    """
    print("Training CBOW model")
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    data_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)

    # Make sure the model is in training mode even if the caller left it in eval().
    model.train()
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}")
        total_loss = 0.0
        for context, target in tqdm(data_loader):
            # Clear the gradients accumulated by the previous step.
            optimizer.zero_grad()
            log_probs = model(context)
            loss = criterion(log_probs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}: Loss: {total_loss}")

In [18]:
# Hyperparameters shared by the dataset and the model cells below.
N_PLAYLISTS = 1_000    # playlists loaded into the dataset
CONTEXT_SIZE = 5       # preceding tracks used as context for each target
EMBEDDING_DIM = 250    # size of each track embedding vector

In [108]:
# Build the dataset from the configured number of playlists.
ds = CBOWDataset(
    n_playlists=N_PLAYLISTS,
    context_size=CONTEXT_SIZE,
)

INFO:root:Attempting to connect to database.
INFO:root:Connected to database


In [119]:
# Sanity check: playlist position of the first row whose track index is 0.
ds.dataset.filter(pl.col("track_idx") == 0)["pos"][0]

55

In [103]:
# Seed PyTorch's global RNG so the weight initialisation below (and anything
# later that draws from the global generator) is repeatable across runs.
torch.manual_seed(42)

# The vocabulary size comes from the dataset, so the model has to be rebuilt
# whenever the dataset is rebuilt.
model = Track2Vec(
    num_tracks=ds.n_tracks,
    embedding_dim=EMBEDDING_DIM,
    context_size=CONTEXT_SIZE
)

In [112]:
# Fit the model for ten passes over the dataset; progress and the per-epoch
# loss are printed by train().
train(model=model, ds_train=ds, epochs=10)

Training CBOW model
Epoch 1


100%|██████████| 2086/2086 [04:19<00:00,  8.05it/s]


Epoch 1: Loss: 22037.149313926697
Epoch 2


100%|██████████| 2086/2086 [04:19<00:00,  8.04it/s]


Epoch 2: Loss: 20631.313497543335
Epoch 3


100%|██████████| 2086/2086 [04:12<00:00,  8.25it/s]


Epoch 3: Loss: 17898.65234708786
Epoch 4


100%|██████████| 2086/2086 [04:14<00:00,  8.19it/s]


Epoch 4: Loss: 13039.324611663818
Epoch 5


100%|██████████| 2086/2086 [04:12<00:00,  8.26it/s]


Epoch 5: Loss: 8195.068344116211
Epoch 6


100%|██████████| 2086/2086 [04:19<00:00,  8.03it/s]


Epoch 6: Loss: 4539.532862782478
Epoch 7


100%|██████████| 2086/2086 [04:12<00:00,  8.25it/s]


Epoch 7: Loss: 2219.0709552019835
Epoch 8


100%|██████████| 2086/2086 [04:10<00:00,  8.32it/s]


Epoch 8: Loss: 1214.2023377857986
Epoch 9


100%|██████████| 2086/2086 [04:12<00:00,  8.28it/s]


Epoch 9: Loss: 858.3454581960905
Epoch 10


100%|██████████| 2086/2086 [04:11<00:00,  8.30it/s]

Epoch 10: Loss: 707.6106652672161



