In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all
# modules before each cell executes so edits to the project code under
# src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# Make the parent of the current working directory importable so the
# `src.*` imports in the next cell resolve.
# Guarded so that re-running this cell does not keep appending the same
# entry to sys.path.
project_root = os.path.dirname(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from collections import Counter
import itertools
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from category_encoders import OrdinalEncoder

from src.utils.logger import logger
from src.utils.io_utils import load_model, save_model
from src.ml.data_loader_with_meta import Sequences, SequencesDataset
from torch.utils.data import DataLoader

In [4]:
# Run configuration collected in one place.
# NOTE(review): the keys look like word2vec-style hyperparameters, but none of
# them is read by the cells visible in this notebook — confirm where they are
# consumed.
PARAMS = dict(
    dim=128,
    window=5,
    min_count=1,
    negative_samp=5,
    epochs=10,
    seed=42,
)

In [5]:
# Dataset name; substituted into the ../data/ file paths when loading sequences.
dataset = "electronics"

### Try the Sequences class

In [29]:
# Build the three input paths for the selected dataset, then load them.
# Per the log output below, construction loads the sequences and the validation
# set, builds and saves word2id/id2word, filters the meta table by word2id, and
# saves the encoder and meta_dict.
sequences_path = '../data/{}_sequences_samp.npy'.format(dataset)
val_edges_path = '../data/{}_edges_val_samp.csv'.format(dataset)
meta_path = '../data/{}_meta.csv'.format(dataset)

sequences = Sequences(sequences_path, val_edges_path, meta_path)

2019-12-10 10:01:04,367 - Sequences loaded (length = 5,000)
2019-12-10 10:01:04,443 - Validation set loaded: (100000, 3)
2019-12-10 10:01:04,453 - Word frequency calculated
2019-12-10 10:01:04,488 - Adding val products to word2id, original size: 28695
2019-12-10 10:01:04,552 - Added val products to word2id, updated size: 133050
2019-12-10 10:01:04,557 - No. of unique tokens: 133050
2019-12-10 10:01:05,709 - Model saved to model/word2id
2019-12-10 10:01:06,966 - Model saved to model/id2word
2019-12-10 10:01:06,967 - Word2Id and Id2Word created and saved
2019-12-10 10:01:10,836 - No. of rows in meta before filter by word2id: 498196
2019-12-10 10:01:11,174 - No. of rows in meta after filter by word2id: 79566
2019-12-10 10:01:13,267 - Model saved to model/encoder
2019-12-10 10:01:15,703 - Embedding dimensions: OrderedDict([('product', 133050), ('asin', 79567), ('category_lvl_3', 55), ('brand', 74)])
2019-12-10 10:01:16,435 - Model saved to model/meta_dict
2019-12-10 10:01:16,462 - Convert 

### Verify that the meta features are working correctly

In [30]:
# Preview only the first 50 entries of meta_dict instead of dumping the whole
# dictionary into the cell output.
# Each value holds the encoded meta ids for that product id
# (asin, category_lvl_3, brand per the logged embedding dimensions — TODO confirm order).
dict(itertools.islice(sequences.meta_dict.items(), 50))

{39732: [1, 1, 1],
 132825: [2, 2, 1],
 911: [3, 3, 1],
 3923: [4, 1, 1],
 852: [5, 1, 1],
 841: [6, 1, 1],
 11901: [7, 2, 1],
 81558: [8, 4, 1],
 39378: [9, 5, 1],
 60973: [10, 6, 1],
 11376: [11, 1, 1],
 11318: [12, 1, 1],
 101754: [13, 1, 1],
 3935: [14, 7, 1],
 11883: [15, 7, 1],
 871: [16, 7, 1],
 3987: [17, 7, 1],
 11900: [18, 1, 1],
 3927: [19, 7, 1],
 1534: [20, 7, 1],
 11331: [21, 1, 1],
 42488: [22, 8, 1],
 41277: [23, 9, 1],
 46224: [24, 2, 1],
 125831: [25, 2, 1],
 3929: [26, 1, 1],
 51694: [27, 2, 1],
 29523: [28, 10, 2],
 97530: [29, 10, 1],
 117494: [30, 11, 1],
 11569: [31, 12, 1],
 125001: [32, 10, 1],
 22930: [33, 3, 1],
 48434: [34, 10, 1],
 6470: [35, 13, 1],
 31138: [36, 14, 1],
 128798: [37, 15, 1],
 97917: [38, 13, 1],
 90378: [39, 3, 1],
 102077: [40, 16, 1],
 51868: [41, 12, 1],
 126310: [42, 6, 1],
 24864: [43, 17, 1],
 46876: [44, 18, 1],
 121350: [45, 1, 1],
 90968: [46, 10, 2],
 101419: [47, 15, 1],
 32467: [48, 13, 3],
 108558: [49, 1, 1],
 78556: [50, 13,

In [90]:
# Pick one product id from the meta_dict preview above and translate it back
# to its original token (the asin) through id2word.
product_id = 97530
asin = sequences.id2word[product_id]
print(asin)

5342765439


In [91]:
# Show the raw meta row(s) whose asin matches the product selected above.
is_selected_product = sequences.meta['asin'] == asin
sequences.meta[is_selected_product]

Unnamed: 0,asin,description,categories,title,price,related,brand,category_lvl_1,category_lvl_2,category_lvl_3,category_lvl_4,productid
144,5342765439,"hand-crafted crystal flash drive(8g),works ver...","['electronics', 'computers & accessories', 'da...",mickey logo crystal usb flash drive with neckl...,18,1,MISC,electronics,computers & accessories,data storage,usb flash drives,5342765439


In [92]:
# Same filter as above, keeping only the asin column of the matching row(s).
sequences.meta.loc[sequences.meta['asin'] == asin, 'asin']

144    5342765439
Name: asin, dtype: object

In [93]:
# Encoded id of the selected asin; compare it with the first entry of
# meta_dict[97530] shown earlier.
# category_mapping[0] is used here as the asin mapping.
asin_mapping = sequences.encoder.category_mapping[0]['mapping']
selected_rows = sequences.meta[sequences.meta['asin'] == asin]
asin_mapping[selected_rows.asin]

asin
5342765439    29
dtype: int64

In [94]:
# Encoded id of the selected product's category_lvl_3; compare it with the
# second entry of meta_dict[97530] shown earlier.
# Legend: 1 = Misc, 2 = covers, 7 = tablets, 10 = data storage
category_lvl_3_mapping = sequences.encoder.category_mapping[1]['mapping']
selected_rows = sequences.meta[sequences.meta['asin'] == asin]
category_lvl_3_mapping[selected_rows.category_lvl_3]

category_lvl_3
data storage    10
dtype: int64

In [95]:
# Encoded id of the selected product's brand; compare it with the third entry
# of meta_dict[97530] shown earlier.
# The mapping is selected by column name rather than by position:
# category_mapping[1] is the category_lvl_3 mapping, so indexing it with a
# brand only works for values (such as 'MISC') that exist in both columns.
# Legend: 1 = Misc
brand_mapping = next(
    entry['mapping']
    for entry in sequences.encoder.category_mapping
    if entry['col'] == 'brand'
)
brand_mapping[sequences.meta[sequences.meta['asin'] == asin].brand]

brand
MISC    1
dtype: int64

In [96]:
# Display the whole mapping stored at index 1 of the encoder — the same one
# used for the category_lvl_3 lookup above (1 = Misc).
category_lvl_3_mapping = sequences.encoder.category_mapping[1]['mapping']
category_lvl_3_mapping

MISC                                      1
covers                                    2
touch screen tablet accessories           3
car electronics                           4
audio & video accessories                 5
computer components                       6
tablets                                   7
NA_VALUE                                  8
blu-ray players & recorders               9
data storage                             10
lighting & studio                        11
video surveillance                       12
cables & accessories                     13
gps system accessories                   14
mp3 players & accessories                15
pdas, handhelds & accessories            16
external components                      17
accessories                              18
laptop & netbook computer accessories    19
vehicle electronics accessories          20
blank media                              21
telephone accessories                    22
cord management                 

In [None]:
# Fetch the pairs for index 0 and display them.
# The loop in the next cell unpacks each element as (center, context).
pairs = sequences.get_pairs(0)
pairs

In [None]:
# Draw negative samples for the context of every (center, context) pair;
# the center item is not needed for sampling.
neg_samples = [
    sequences.get_negative_samples(context)
    for _, context in pairs
]
neg_samples[:5]

In [None]:
# Wrap the loaded sequences in the project's dataset class so they can be
# iterated directly and fed to a DataLoader (both done in the cells below).
seq_dset = SequencesDataset(sequences)

In [None]:
# Log the first five items yielded by the dataset.
# `batch` deliberately stays bound to the fifth item afterwards: later cells
# index into it.
for i, batch in enumerate(itertools.islice(seq_dset, 5)):
    logger.info(batch)

In [None]:
# Split the last inspected dataset item: batch[0] holds the pairs, whose first
# and second elements are collected as centers and contexts, and batch[1]
# holds the negative contexts.
center, context = [], []
for pair in batch[0]:
    center.append(pair[0])
    context.append(pair[1])
neg_context = batch[1]

In [None]:
# Batch two dataset items at a time, in order, using the dataset's own
# collate function to assemble each batch.
seq_dloader = DataLoader(
    dataset=seq_dset,
    batch_size=2,
    shuffle=False,
    collate_fn=seq_dset.collate,
)

In [None]:
# Pull only the first collated batch from the DataLoader and unpack its three
# components; the loop always stops after that first iteration.
for i, batches in enumerate(seq_dloader):
    centers, contexts, neg_contexts = batches
    break

In [None]:
# First component of the collated batch (unpacked above as `centers`).
batches[0]

In [None]:
# Second component of the collated batch (unpacked above as `contexts`).
batches[1]

In [None]:
# Third component of the collated batch (unpacked above as `neg_contexts`).
batches[2]

### Initialize skip-gram embeddings

In [None]:
# Mapping whose values are used below as the number of rows of each
# nn.Embedding table (one table per key).
sequences.emb_sizes

In [None]:
# Embedding width shared by every table built below.
# Read from the notebook-level PARAMS (currently 128) instead of repeating the
# literal, so the configuration cell stays the single source of truth.
emb_dim = PARAMS['dim']

In [None]:
# One sparse embedding table per entry of sequences.emb_sizes, built twice:
# first the tables used for center items, then a parallel set for context items.
center_embeddings = nn.ModuleList(
    [nn.Embedding(size, emb_dim, sparse=True) for size in sequences.emb_sizes.values()]
)

context_embeddings = nn.ModuleList(
    [nn.Embedding(size, emb_dim, sparse=True) for size in sequences.emb_sizes.values()]
)

In [None]:
# Half-width of the uniform interval used to initialise the center embeddings
# in the next cell; it shrinks as the embedding dimension grows.
emb_range = 0.5 / emb_dim

In [None]:
# Center embeddings start as small uniform noise in [-emb_range, emb_range].
for emb in center_embeddings:
    nn.init.uniform_(emb.weight, -emb_range, emb_range)

# Context embeddings start at exactly zero. This states the intent directly
# rather than sampling from the degenerate interval uniform_(0, 0), and the
# nn.init helpers avoid mutating `.data` by hand.
for emb in context_embeddings:
    nn.init.zeros_(emb.weight)

In [None]:
# Row count of every center embedding table, for comparison with
# sequences.emb_sizes displayed above.
# The original cell read `x.num_embeddings`, but `x` is not defined anywhere
# in the notebook and fails on a fresh kernel.
[emb.num_embeddings for emb in center_embeddings]

In [None]:
# Inspect the first context embedding table.
context_embeddings[0]

In [None]:
# Inspect the full ModuleList of center embedding tables.
center_embeddings

In [None]:
# Take the center of the sixth pair from the last dataset item inspected
# earlier (`batch` is left over from that loop) and turn it into a LongTensor.
# NOTE: this rebinds `centers`, which previously held the first component of
# the DataLoader batch.
# Presumably the center is a list of ids, one per embedding table — TODO confirm.
sixth_pair = batch[0][5]
centers = torch.LongTensor(sixth_pair[0])
centers

In [None]:
# Display the tensor built in the previous cell again.
centers

In [None]:
# Accumulator for the per-table embedding vectors appended by the loop in the
# next cell.
emb_center = []

In [None]:
# Look up each id of `centers` in the embedding table at the same position
# (id i goes through center_embeddings[i]).
# The list is reset here so that re-running this cell rebuilds it from scratch
# rather than appending duplicates to the previous run's results.
emb_center = []
for i, center in enumerate(centers):
    logger.info('i: {}, center: {}'.format(i, center))
    emb_center.append(center_embeddings[i](center))

In [None]:
# Stack the per-table vectors into one tensor and average across tables
# (dimension 0) to obtain a single combined embedding.
torch.stack(emb_center).mean(dim=0)

In [None]:
# Show the stacked per-table vectors before averaging (one row per entry of
# emb_center).
torch.stack(emb_center)