In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

In [3]:
import skipgrammar.datasets as dset
from skipgrammar.models import NegativeSamplingLoss, SGNS

### Data

In [4]:
# Name of the Last.fm dataset variant to load; passed to
# dset.LastFMUserItemDataset in the loading cell below.
variant = 'lastfm-50'

In [11]:
%%time

# Build the user/item dataset for the configured variant; the underlying
# listen log is exposed as a DataFrame on `.df`.
lastfm_dataset = dset.LastFMUserItemDataset(variant)

# Headline figures: total rows in the log and number of distinct artists.
print('Number of Total Listens: {:,}'.format(len(lastfm_dataset.df)))
print('Number of Unique Artists: {:,}'.format(lastfm_dataset.df['artist_name'].nunique()))

# Show the first five rows (head() defaults to 5) as the cell's rich output.
lastfm_dataset.df.head()

Number of Total Listens: 765,399
Number of Unique Artists: 11,507
CPU times: user 1.19 s, sys: 131 ms, total: 1.32 s
Wall time: 1.32 s


Unnamed: 0,user_id,timestamp,artist_id,artist_name,track_id,track_name,artist_cd
0,user_000001,2006-08-13 13:59:20+00:00,09a114d9-7723-4e14-b524-379697f6d2b5,Plaid & Bob Jaroc,c4633ab1-e715-477f-8685-afa5f2058e42,The Launching Of Big Face,7375
1,user_000001,2006-08-13 14:03:29+00:00,09a114d9-7723-4e14-b524-379697f6d2b5,Plaid & Bob Jaroc,bc2765af-208c-44c5-b3b0-cf597a646660,Zn Zero,7375
2,user_000001,2006-08-13 14:10:43+00:00,09a114d9-7723-4e14-b524-379697f6d2b5,Plaid & Bob Jaroc,aa9c5a80-5cbe-42aa-a966-eb3cfa37d832,The Return Of Super Barrio - End Credits,7375
4,user_000001,2006-08-13 14:19:06+00:00,1cfbc7d1-299c-46e6-ba4c-1facb84ba435,Artful Dodger,120bb01c-03e4-465f-94a0-dce5e9fac711,What You Gonna Do?,683
5,user_000001,2006-08-13 14:23:03+00:00,6b77d8ef-c405-4846-9d5f-2b93e6533101,Rei Harakami,777ac51f-8ffc-4c44-92b6-a2c75cbc6915,Joy,7728


In [12]:
# Distribution of listen counts per user: count the rows for each user_id,
# then summarise those counts (count/mean/std/quartiles).
(
    lastfm_dataset.df
    .groupby('user_id')
    .size()
    .describe()
)

count       50.000000
mean     15307.980000
std      17398.331646
min          9.000000
25%       3060.500000
50%      10451.500000
75%      20102.000000
max      74141.000000
dtype: float64

In [13]:
# Wrap the dataset in a DataLoader; a tiny batch size keeps the preview
# output in the next cell readable.
lastfm_dataloader = DataLoader(
    dataset=lastfm_dataset,
    batch_size=2,
)

In [14]:
%%time

# Preview the first five (anchors, targets) batches from the loader.
# batch_num is 0-based, so it is shown as batch_num + 1 and the loop stops
# once index 4 (the fifth batch) has been printed.
# The `anchors` / `targets` from that last batch stay bound after the loop
# and are reused by the forward-pass cell further down.
for batch_num, (anchors, targets) in enumerate(lastfm_dataloader):
    print(
        'batch', batch_num + 1,
        '| anchors:', len(anchors), anchors,
        ' | targets:', len(targets), targets,
    )
    if batch_num >= 4:
        break

batch 1 | anchors: 2 tensor([ 6766, 10022])  | targets: 2 tensor([9020, 7227])
batch 2 | anchors: 2 tensor([ 5575, 10022])  | targets: 2 tensor([ 490, 7227])
batch 3 | anchors: 2 tensor([3408, 3070])  | targets: 2 tensor([7594, 2789])
batch 4 | anchors: 2 tensor([1432, 4881])  | targets: 2 tensor([1432, 4881])
batch 5 | anchors: 2 tensor([2997,    9])  | targets: 2 tensor([2997, 7323])
CPU times: user 427 ms, sys: 69.6 ms, total: 496 ms
Wall time: 496 ms


In [15]:
# Instantiate the skip-gram negative-sampling model with one embedding row
# per item in the dataset and a small 10-dimensional embedding for this demo.
# NOTE(review): nn_embedding_kwargs is presumably forwarded to the underlying
# embedding layer(s), requesting sparse gradients — confirm in skipgrammar.models.
model = SGNS(
    num_embeddings=lastfm_dataset.num_items,
    embedding_dim=10,
    nn_embedding_kwargs=dict(sparse=True),
)

In [16]:
# Run one forward pass on the last batch drawn by the preview loop above; the
# model returns embeddings for the anchors, the targets, and the negatives.
# The module is invoked as model(...) rather than model.forward(...): for a
# torch.nn.Module, calling the instance runs any registered hooks around
# forward(), whereas calling forward() directly silently skips them.
# NOTE(review): assumes SGNS subclasses torch.nn.Module — confirm in
# skipgrammar.models.
anchor_embeddings, target_embeddings, negative_embeddings = model(anchors, targets)

In [17]:
# Sanity check: one embedding row per anchor in the batch, i.e.
# (batch_size, embedding_dim) = (2, 10) for the settings used above.
anchor_embeddings.shape

torch.Size([2, 10])

In [18]:
# Element-wise check that looking up the first anchor's item id through
# model.as_embedding gives the same vector as the first row returned by the
# forward pass.
model.as_embedding(anchors[0].item()) == anchor_embeddings[0]

tensor([True, True, True, True, True, True, True, True, True, True])