In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

In [3]:
import skipgrammar.datasets as dset
from skipgrammar.models import NegativeSamplingLoss, SGNS

### Data

In [4]:
# Identifier of the Last.fm dataset variant handed to the dataset loader below.
variant = "lastfm-50"

In [5]:
%%time

lastfm_dataset = dset.LastFMUserItemDataset(variant)
print(f'Number of Total Listens: {len(lastfm_dataset.df):,}')
print(f'Number of Unique Artists: {lastfm_dataset.df.artist_name.nunique():,}')
lastfm_dataset.df.head(5)

Number of Total Listens: 169,555
Number of Unique Artists: 8,275
CPU times: user 1.03 s, sys: 246 ms, total: 1.27 s
Wall time: 947 ms


Unnamed: 0,user_id,timestamp,artist_id,artist_name,track_id,track_name,artist_cd,session_end,session_nbr,session_id
4,user_000001,2006-08-13 14:19:06+00:00,1cfbc7d1-299c-46e6-ba4c-1facb84ba435,Artful Dodger,120bb01c-03e4-465f-94a0-dce5e9fac711,What You Gonna Do?,478,0,1,user_000001-1
16,user_000001,2006-08-13 15:40:13+00:00,8522b9b6-b295-48d7-9a10-8618fb80beb8,Battles,523eaf59-8298-4b1c-9950-5864c5f4c1ff,Tras,664,1,2,user_000001-2
18,user_000001,2006-08-13 15:49:22+00:00,f9114439-1662-4415-b761-05a4170c9579,Boom Boom Satellites,099eaa23-3846-4670-a4d2-ca909b7b1f15,Moment I Count,933,0,2,user_000001-2
21,user_000001,2006-08-13 16:00:07+00:00,3a238c56-3790-4a6a-89af-4aa0c71fa732,José Padilla,1c061863-1d3e-4066-aa93-5c9ce0bf72f2,Solo,3551,0,2,user_000001-2
27,user_000001,2006-08-13 16:36:15+00:00,87225a21-c925-41cd-852f-be4b052d0824,Afx,c52348d2-dfc3-4754-9d02-f8b44cc5e9ec,Pwsteal.Ldpinch.D,174,1,3,user_000001-3


In [6]:
# Summary statistics of the number of rows (listens) recorded per user.
listens_per_user = lastfm_dataset.df.groupby('user_id').size()
listens_per_user.describe()

count       50.000000
mean      3391.100000
std       3422.438588
min          6.000000
25%       1044.250000
50%       1942.500000
75%       4954.000000
max      14070.000000
dtype: float64

In [7]:
# Wrap the dataset in a DataLoader that yields two samples per batch.
lastfm_dataloader = DataLoader(
    dataset=lastfm_dataset,
    batch_size=2,
)

In [8]:
%%time

for batch_num, (anchors, targets) in enumerate(lastfm_dataloader, start=0):
    print('batch', batch_num + 1, '| anchors:', len(anchors), anchors, ' | targets:', len(targets), targets)
    if batch_num == 4:
        break

batch 1 | anchors: 2 tensor([5800, 4881])  | targets: 2 tensor([5800, 4881])
batch 2 | anchors: 2 tensor([3063, 4389])  | targets: 2 tensor([3063, 4389])
batch 3 | anchors: 2 tensor([5416, 7436])  | targets: 2 tensor([7839, 7436])
batch 4 | anchors: 2 tensor([2870, 8089])  | targets: 2 tensor([5416, 8089])
batch 5 | anchors: 2 tensor([2155,  402])  | targets: 2 tensor([ 455, 1008])
CPU times: user 131 ms, sys: 11.3 ms, total: 143 ms
Wall time: 166 ms


In [9]:
# Instantiate the SGNS model sized to the dataset's item count, with
# 10-dimensional embeddings.
# `nn_embedding_kwargs` is presumably forwarded to the underlying embedding
# layer(s) -- TODO confirm against skipgrammar.models.SGNS.
model = SGNS(
    num_embeddings=lastfm_dataset.num_items,
    embedding_dim=10,
    nn_embedding_kwargs=dict(sparse=True),
)

In [10]:
# Run one forward pass on the last batch produced by the preview loop above.
# The model is invoked as a callable rather than through `.forward(...)`:
# calling a torch module instance dispatches to `forward` while also running
# any registered hooks, which a direct `.forward` call silently skips.
# NOTE(review): assumes SGNS subclasses torch.nn.Module -- confirm.
anchor_embeddings, target_embeddings, negative_embeddings = model(anchors, targets)

In [11]:
# Shape of the anchor embeddings returned by the forward pass.
anchor_embeddings.shape

torch.Size([2, 10])

In [12]:
# Sanity check: the embedding looked up directly for the first anchor id should
# match, element by element, row 0 of the anchor embeddings from the forward pass.
first_anchor_id = anchors[0].item()
first_anchor_row = anchor_embeddings[0, :]
model.as_embedding(first_anchor_id) == first_anchor_row

tensor([True, True, True, True, True, True, True, True, True, True])