In [1]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

In [5]:
from collections import Counter
import itertools

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils.logger import logger

In [3]:
# Experiment settings for this notebook, keyed by name.
PARAMS = dict(
    dim=128,
    window=5,
    min_count=1,
    negative_samp=5,
    epochs=10,
    seed=42,
)

### Load data

In [4]:
# Read the sampled sequences from disk and turn the array into nested Python lists.
sequences = np.load('../data/books_sequences_sample.npy').tolist()

In [11]:
val = pd.read_csv('../data/books_edges_val_samp.csv')

In [None]:
# seq_samp = sequences[:1000]
# seq_samp = np.array(seq_samp)
# np.save('../data/books_sequences_sample.npy', seq_samp)

### Negative sampling

In [13]:
def get_word_freq(sequences):
    """Count how many times each item occurs across all sequences.

    Args:
        sequences: Iterable of iterables of hashable items.

    Returns:
        A `collections.Counter` mapping each item to its total number of
        occurrences, in first-seen order.
    """
    return Counter(itertools.chain.from_iterable(sequences))

In [14]:
word_freq = get_word_freq(sequences)

In [15]:
def get_mapping_dicts(word_freq):
    """Build the item <-> integer-ID lookup tables.

    IDs are assigned consecutively from 0 in the iteration order of
    `word_freq`; the counts themselves are not used.

    Args:
        word_freq: Mapping of item -> count.

    Returns:
        Tuple `(word2id, id2word)` of two dicts that are inverses of each other.
    """
    id2word = dict(enumerate(word_freq.keys()))
    word2id = {word: idx for idx, word in id2word.items()}
    return word2id, id2word

In [16]:
word2id, id2word = get_mapping_dicts(word_freq)

In [17]:
len(word2id)

7757

### Add validation data into word2id

In [21]:
val_product_set = set(val['product1'].values).union(set(val['product2'].values))

In [22]:
len(val_product_set)

179397

In [29]:
wid = max(word2id.values()) + 1

In [30]:
# Extend the mappings with every validation product that is not already in
# `word2id`; products that are already mapped keep their existing ID.
# NOTE(review): iteration order over a set is not guaranteed, so which ID each
# new product receives may differ between kernel runs — confirm whether stable
# IDs are needed here.
for w in val_product_set:
    if w not in word2id:
        # `wid` continues from max(word2id.values()) + 1 set in the previous cell.
        word2id[w] = wid
        id2word[wid] = w
        wid += 1

In [31]:
len(word2id)

185207

In [None]:
def convert_sequence_to_id(sequences, word2id):
    """Map every item in `sequences` to its ID.

    Applies `word2id.get` element-wise through `np.vectorize`.

    Args:
        sequences: Array-like of items.
        word2id: Mapping of item -> ID.

    Returns:
        NumPy array with the same shape as `sequences` holding the looked-up IDs.
    """
    lookup = np.vectorize(word2id.get)
    return lookup(sequences)

In [None]:
sequences = convert_sequence_to_id(sequences, word2id)

In [None]:
sequences

In [None]:
def convert_word_freq_to_id(word_freq, word2id):
    """Re-key a frequency table from items to their IDs.

    Args:
        word_freq: Mapping of item -> count.
        word2id: Mapping of item -> ID; must contain every key of `word_freq`.

    Returns:
        Dict of ID -> count.

    Raises:
        KeyError: If an item in `word_freq` has no entry in `word2id`.
    """
    id_freq = {}
    for word, count in word_freq.items():
        id_freq[word2id[word]] = count
    return id_freq

In [None]:
word_freq = convert_word_freq_to_id(word_freq, word2id)

In [None]:
def get_discard_probs(sequences, word_freq, sample=0.001):
    """Compute the word2vec subsampling value for every item ID.

    For an item whose share of all occurrences is ``z`` the value is
    ``(sqrt(z / sample) + 1) * (sample / z)``, following
    http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/

    NOTE: despite the function name, `get_pairs` in this notebook keeps a
    context when ``np.random.rand() < value``, so a larger value means the item
    is kept more often. The value is not clipped and can exceed 1 for rare items.

    Args:
        sequences: Unused; kept so existing calls keep working.
        word_freq: Mapping of integer item ID -> occurrence count.
        sample: Subsampling threshold used in the formula above.

    Returns:
        Dict of int item ID -> float value. Empty if `word_freq` is empty.
    """
    # An empty table would otherwise fail when indexing the (0,)-shaped array.
    if not word_freq:
        return {}

    ids = list(word_freq.keys())
    counts = np.array(list(word_freq.values()), dtype=np.float64)

    # Convert counts to each item's share of all occurrences.
    share = counts / counts.sum()

    # Perform subsampling.
    values = (np.sqrt(share / sample) + 1) * (sample / share)

    # IDs are kept out of the float array so they are never rounded.
    return {int(item_id): value for item_id, value in zip(ids, values.tolist())}

In [None]:
discard_probs = get_discard_probs(sequences, word_freq)

In [None]:
discard_probs[9]

In [None]:
def get_negative_sample_table(sequences, word_freq, power=0.75, sample_table_size=1000000):
    """Build a shuffled lookup table for drawing negative samples.

    Each item ID appears in the table a number of times proportional to
    ``count ** power``, so reading consecutive entries of the shuffled table
    draws IDs according to that smoothed frequency distribution.

    Args:
        sequences: Unused; kept so existing calls keep working.
        word_freq: Mapping of integer item ID -> occurrence count.
        power: Exponent applied to the counts before normalising.
        sample_table_size: Target table length. The actual length can differ
            slightly because each item's slot count is rounded independently.

    Returns:
        1-D NumPy integer array of item IDs in random order. Empty if
        `word_freq` is empty. The shuffle uses NumPy's global random state,
        so call `np.random.seed(...)` beforehand for a reproducible table.
    """
    # An empty table would otherwise fail when indexing the (0,)-shaped array.
    if not word_freq:
        return np.array([], dtype=int)

    ids = np.array(list(word_freq.keys()), dtype=int)
    counts = np.array(list(word_freq.values()), dtype=np.float64)

    # Adjust by power, then normalise to probabilities.
    weights = counts ** power
    probabilities = weights / weights.sum()

    # Number of table slots each ID receives.
    slots = np.round(probabilities * sample_table_size).astype(int)

    # Repeat each ID by its slot count, then shuffle in place.
    sample_table = np.repeat(ids, slots)
    np.random.shuffle(sample_table)

    return sample_table

In [None]:
neg_table = get_negative_sample_table(sequences, word_freq)

In [None]:
# Read position in `neg_table`; kept at module level so it persists between calls.
negative_idx = 0


def get_negative_samples(sample_size=5):
    """Return the next `sample_size` entries of the shuffled `neg_table`.

    Successive calls walk through the global `neg_table` and wrap around to its
    start when the end is reached, so every call yields exactly `sample_size`
    IDs and consecutive calls return different windows.

    Args:
        sample_size: Number of negative samples to return.

    Returns:
        1-D NumPy array of `sample_size` item IDs.
    """
    global negative_idx

    # mode='wrap' indexes modulo len(neg_table), which handles the wrap-around
    # at the end of the table without a separate concatenation step.
    positions = np.arange(negative_idx, negative_idx + sample_size)
    neg_sample = np.take(neg_table, positions, mode='wrap')

    # Advance the persistent cursor for the next call.
    negative_idx = (negative_idx + sample_size) % len(neg_table)

    return neg_sample

In [None]:
neg_samples = get_negative_samples()

### Get pairs (with subsampling)

In [None]:
sequences

In [None]:
sequence = sequences[1]
sequence

In [None]:
# Works on per sequence
def get_pairs(idx, window=5):
    """Generate (center, context) pairs for one sequence, with subsampling.

    Reads the global `sequences` and `discard_probs`. For every position in the
    sequence, each neighbour within `window` positions on either side
    (including the very first element of the sequence) is a candidate context.
    A candidate is skipped when it equals the center item, and otherwise kept
    only if ``np.random.rand() < discard_probs[context]``.

    Args:
        idx: Index into the global `sequences`.
        window: Maximum distance between center and context positions.

    Returns:
        List of `(center, context)` tuples.
    """
    sequence = sequences[idx]
    pairs = []

    for center_idx, node in enumerate(sequence):
        # Clamp the window to the sequence bounds; the lower bound is 0 so the
        # first element can also serve as a context.
        first = max(0, center_idx - window)
        last = min(len(sequence), center_idx + window + 1)

        for context_idx in range(first, last):
            context = sequence[context_idx]
            # Also skips the center position itself.
            if node == context:
                continue
            if np.random.rand() < discard_probs[context]:
                pairs.append((node, context))

    return pairs

In [None]:
# Read position in `neg_table`; kept at module level so it persists between calls.
negative_idx = 0


def get_negative_samples(context, sample_size=5):
    """Return the next window of negative samples that does not contain `context`.

    Successive calls walk through the global shuffled `neg_table`, wrapping
    around at its end. Windows that contain `context` are skipped.

    Args:
        context: Item ID that must not appear among the negatives.
        sample_size: Number of negative samples to return.

    Returns:
        1-D NumPy array of `sample_size` item IDs, none equal to `context`.

    Raises:
        ValueError: If no reachable window of `neg_table` is free of `context`.
    """
    global negative_idx
    table_len = len(neg_table)

    # The cursor moves in steps of `sample_size` modulo `table_len`, so after
    # `table_len` attempts every reachable window has been tried; stopping
    # there avoids looping forever on a degenerate table.
    for _ in range(table_len):
        # mode='wrap' indexes modulo the table length, covering the wrap-around.
        positions = np.arange(negative_idx, negative_idx + sample_size)
        neg_sample = np.take(neg_table, positions, mode='wrap')

        # Advance the persistent cursor before deciding, so a rejected window
        # is not retried.
        negative_idx = (negative_idx + sample_size) % table_len

        if context not in neg_sample:
            return neg_sample

    raise ValueError(
        'No window of {} negative samples excludes context {}'.format(sample_size, context))

In [None]:
get_negative_samples(7726)

In [None]:
# Collect one array of negative samples per (center, context) pair.
# NOTE(review): no earlier cell in this notebook assigns `pairs` — presumably it
# should come from get_pairs(...) for some sequence index; confirm before
# running on a fresh kernel.
neg_samples = []
for center, context in pairs:
    neg_samples.append(get_negative_samples(context))

In [None]:
# # Works on batch
# def get_pairs(sequences, window=5):
#     pairs = []
#     window = PARAMS['window']

#     for sequence in sequences:
#         for center_idx, node in enumerate(sequence):
#             for i in range(-window, window+1):
#                 context_idx = center_idx + i
#                 if context_idx > 0 and context_idx < len(sequence) and node != sequence[context_idx] and np.random.rand() < discard_probs[sequence[context_idx]]:
#                     pairs.append((node, sequence[context_idx]))
                    
#     return pairs

In [None]:
# pairs = get_pairs(sequences, PARAMS['window'])
# logger.info('Len of pairs: {:,}'.format(len(pairs)))

### Try the `Sequences` class

In [32]:
from src.ml.data_loader import Sequences, SequencesDataset
from torch.utils.data import DataLoader

In [33]:
sequences = Sequences('../data/books_sequences_sample.npy', '../data/books_edges_val_samp.csv')

2019-12-05 09:52:14,426 - Sequences loaded (length = 1,000)
2019-12-05 09:52:14,513 - Validation set loaded: (100000, 3)
2019-12-05 09:52:14,516 - Word frequency calculated
2019-12-05 09:52:14,566 - Adding val products to mapping dict, original dict size: 7757
2019-12-05 09:52:14,680 - Added val products to mapping dict, updated dict size: 185207
2019-12-05 09:52:16,946 - Model saved to model/word2id
2019-12-05 09:52:19,193 - Model saved to model/id2word
2019-12-05 09:52:19,194 - Word2Id and Id2Word created and saved
2019-12-05 09:52:19,202 - Convert sequence and wordfreq to ID
2019-12-05 09:52:19,212 - Discard probability calculated
2019-12-05 09:52:20,673 - Negative sample table created


In [None]:
pairs = sequences.get_pairs(2)

In [None]:
pairs

In [None]:
neg_samples = []
for center, context in pairs:
    neg_samples.append(sequences.get_negative_samples(context))
neg_samples[:5]

In [5]:
seq_dset = SequencesDataset(sequences)

In [None]:
for i, batch in enumerate(seq_dset):
    logger.info(batch)
    if i > 3:
        break

In [None]:
center = [pair[0] for pair in batch[0]]
context = [pair[1] for pair in batch[0]]
neg_context = batch[1]

In [6]:
seq_dloader = DataLoader(dataset=seq_dset, batch_size=2, shuffle=False, collate_fn=seq_dset.collate)

In [13]:
for i, batches in enumerate(seq_dloader):
    centers, contexts, neg_contexts = batches
    if i == 0:
        break

In [19]:
batches[0].to('cpu')

tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  3,  3,
         3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,
         5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,
         8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  0,  0,  0,  0,  0, 10, 10,
        10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 13, 13,
        13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15,
        15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 18,
        18, 18, 18, 18])

In [None]:
batches = [([(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (2, 1), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (3, 1), (3, 2), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (4, 1), (4, 2), (4, 3), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (5, 1), (5, 2), (5, 3), (5, 4), (5, 6), (5, 7), (5, 8), (5, 9), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 7), (6, 8), (6, 9), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 8), (7, 9), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 9), (9, 4), (9, 5), (9, 6), (9, 7), (9, 8)], [np.array([6911, 7062, 1107, 1246,  677]), np.array([ 697, 2655, 4380, 3183, 6465]), np.array([3425, 7452, 2766, 5655, 3064]), np.array([2274, 1321,  337, 7428, 5732]), np.array([2333, 1719, 5978, 6495, 7674]), np.array([1285, 1558, 2029, 2436, 1561]), np.array([5698, 6331, 2001,  761, 2794]), np.array([6849, 1593, 5863, 3042, 6586]), np.array([4262, 5396, 5854,  963, 6818]), np.array([3197, 1698, 3044, 7672, 2789]), np.array([3875, 2656, 2602,  499, 5877]), np.array([ 479, 1770, 7674, 2934, 1814]), np.array([4175,  750, 2026, 3953, 5545]), np.array([4971, 5452, 4985, 6671, 2247]), np.array([6880, 1743, 6300, 7239, 3404]), np.array([3321, 3110, 6075,  866, 6466]), np.array([4208, 6886, 6723, 2720, 4556]), np.array([6243, 2416,  389, 7655,  109]), np.array([2788, 6081, 3591, 5346, 7028]), np.array([7160, 7055, 5543,   64, 2285]), np.array([7196, 5277, 3923, 4398, 2139]), np.array([5860,  969, 7271, 5856, 1801]), np.array([3752, 2161, 4823, 1817, 7126]), np.array([3388, 6590, 5993,  817, 6017]), np.array([6088, 6972, 5963, 5691, 4207]), np.array([6633, 5380, 5708, 4590, 3690]), np.array([1214, 1494, 3179, 4345, 3285]), np.array([ 460, 5675, 4791, 6728, 3241]), np.array([6035, 6050, 1021, 1581, 7406]), np.array([6684, 3291,  479, 2506, 2648]), np.array([7062, 2731, 7442, 6461, 3165]), np.array([6506, 2869, 4135,  270, 1400]), np.array([6433, 2757, 3367, 4317, 5401]), np.array([2161, 5163, 4694, 4700, 1168]), np.array([4382,  232, 5843, 2715, 
6432]), np.array([ 246, 5236, 1957, 2142, 1255]), np.array([ 410, 6860, 3555, 6867, 7416]), np.array([4552,  810, 3043, 3214,  249]), np.array([7294, 3650, 3357, 7208, 5200]), np.array([3559, 5036, 1550, 1672, 5395]), np.array([5339, 6610, 5293, 2438, 2199]), np.array([6107, 3755, 2167,  698, 2596]), np.array([6934, 2932, 1570, 3895, 3359]), np.array([2044, 6838, 1183, 4118, 7576]), np.array([6532, 2135, 1801, 7232, 7252]), np.array([3964, 5793, 5463, 4373, 6860]), np.array([4733, 7739, 5158,  326, 1178]), np.array([6363,  484, 7726, 5215,  998]), np.array([1514, 6849, 4165, 1207, 2644]), np.array([3551, 3886, 3811,  389, 2897]), np.array([1197, 4981, 7400,  693, 5228]), np.array([4644, 5649, 7351, 1144, 1885]), np.array([7295, 2781, 7547, 3249, 4184]), np.array([4703, 1633, 5622, 1158, 1053]), np.array([5140, 6627, 5539, 2980, 3438]), np.array([7103, 3022, 7239, 1877, 4911]), np.array([5372, 2445, 2805, 6498, 5147]), np.array([3033, 7609, 1753, 1338, 4162]), np.array([1841, 1940, 5864, 6642, 1151]), np.array([3953,  439, 2519, 2288, 5552]), np.array([2959, 7729, 3274, 4970, 5465]), np.array([4289, 7229, 4241, 6115, 6679]), np.array([2953, 2573, 6647,  247, 1080]), np.array([5861, 7712,  844, 7455,  891]), np.array([  11, 5957, 2488, 4967, 7544])]), ([(0, 10), (0, 11), (0, 12), (0, 13), (0, 14), (10, 11), (10, 12), (10, 13), (10, 14), (10, 15), (11, 10), (11, 12), (11, 13), (11, 14), (11, 15), (11, 16), (12, 10), (12, 11), (12, 13), (12, 14), (12, 15), (12, 16), (12, 17), (13, 10), (13, 11), (13, 12), (13, 14), (13, 15), (13, 16), (13, 17), (13, 18), (14, 10), (14, 11), (14, 12), (14, 13), (14, 15), (14, 16), (14, 17), (14, 18), (15, 10), (15, 11), (15, 12), (15, 13), (15, 14), (15, 16), (15, 17), (15, 18), (16, 11), (16, 12), (16, 13), (16, 14), (16, 15), (16, 17), (16, 18), (17, 12), (17, 13), (17, 14), (17, 15), (17, 16), (17, 18), (18, 13), (18, 14), (18, 15), (18, 16), (18, 17)], [np.array([5931, 6670, 1539, 6831,  229]), np.array([3534,  335, 2857, 3945, 
1002]), np.array([3745, 7254, 3647, 7435, 1460]), np.array([7551, 3893, 3030, 7056, 4245]), np.array([2238, 1801, 1849, 4584, 5860]), np.array([5364, 6752,  136, 1999, 2197]), np.array([4928, 6525, 6690, 7324, 6559]), np.array([ 767, 1452, 3152, 6733, 5854]), np.array([4527,  443, 2060, 6631, 5419]), np.array([6324, 5067, 7077, 3987, 2648]), np.array([2765,  803, 4050, 7003, 7702]), np.array([6372, 5535, 6030, 4797, 3045]), np.array([6293, 3248, 3969, 6346, 4515]), np.array([1610, 2189, 2687, 7000, 6292]), np.array([4584, 6160, 5294, 7285, 4823]), np.array([3523, 3370, 1040, 3238, 4736]), np.array([6732, 1642, 6894, 1556, 3084]), np.array([1493, 1147,   36, 3993, 1290]), np.array([5404, 3555, 4335, 5448,  993]), np.array([ 413, 1589, 3720, 4410, 4651]), np.array([2392, 5986,   30, 5717, 1325]), np.array([ 417, 5694, 7550, 1830, 2186]), np.array([4957, 5000, 1134, 3309, 7673]), np.array([5162, 3574, 3039, 4348, 1725]), np.array([2778, 4587, 1237, 5632, 1791]), np.array([1911,  262, 5795,  976, 3314]), np.array([2529, 1150, 5177, 2350, 2817]), np.array([   2,   19, 2200, 4321,  623]), np.array([ 840, 3369, 4327, 3600, 3226]), np.array([5020, 4182, 2828, 6970, 3940]), np.array([4463, 6296, 3938, 3258, 7131]), np.array([1628, 6441, 3357, 3881, 2199]), np.array([3675, 1636, 4676, 6875, 6635]), np.array([3187, 6134, 5877, 2543, 6084]), np.array([ 168,  998, 1801, 5181, 6515]), np.array([7674, 5657, 5426, 4797, 1791]), np.array([3236, 6282, 2594, 6574, 4973]), np.array([5953, 6162,   43, 3405, 2965]), np.array([2561, 6764, 6886, 2977, 3234]), np.array([3458, 3621, 7560, 5445, 6401]), np.array([3629, 2915, 3767, 1209, 2609]), np.array([2246, 5444,  418, 1005, 1104]), np.array([2389, 4028, 1055, 1917, 1974]), np.array([ 783, 3555, 2129, 3516,  338]), np.array([2128, 7568, 5877, 5035, 5485]), np.array([1087, 6531,  534, 5801, 4291]), np.array([4858, 5600, 3209, 3577, 2157]), np.array([2923,  419,  974, 1755, 3089]), np.array([6218,  971, 5900, 1899, 4525]), np.array([5791, 
4764, 5024, 1967, 5440]), np.array([6739, 6139, 4461, 3399,  236]), np.array([ 538, 6594, 4339, 6017, 5228]), np.array([4373, 3011, 6073, 3369, 3602]), np.array([6711, 1964, 3377, 5045,   79]), np.array([ 840, 2335, 3257, 1618, 5952]), np.array([7183, 7152, 5047, 7391, 5233]), np.array([4816, 5445, 4579, 1345,  307]), np.array([3333,   21,  417, 7208,  833]), np.array([1883,  141, 1779, 3602,  501]), np.array([4128, 4665, 3510, 5313, 5394]), np.array([7725, 2627, 7262, 2070, 1826]), np.array([5696, 4085, 1396, 2880, 6592]), np.array([2041, 5380, 6268, 7261, 4661]), np.array([6412, 5223, 6400, 3328, 4911]), np.array([7543, 2502, 2974, 6212, 7199])])]

In [None]:
pairs_batch = [batch[0] for batch in batches]
neg_contexts_batch = [batch[1] for batch in batches]

In [None]:
pairs_batch = list(itertools.chain.from_iterable(pairs_batch))
neg_contexts = list(itertools.chain.from_iterable(neg_contexts_batch))

In [None]:
centers = [center for center, _ in pairs_batch]
contexts = [context for _, context in pairs_batch]

In [None]:
len(centers)

In [None]:
len(contexts)

In [None]:
len(neg_contexts)

In [None]:
# Flatten the per-batch lists of negative-sample arrays into a single list:
# iterate each batch's own list and yield its elements.
[neg for neg_context in neg_contexts_batch for neg in neg_context]

In [None]:
pairs = batches[0][0]
neg_contexts = batches[0][1]

In [None]:
pairs

In [None]:
batches[0]

In [None]:
pairs = batches[0]
neg_contexts = batches[1]

In [None]:
pairs

In [None]:
neg_contexts

In [None]:
centers = [center for pair in pairs for center in pair[0]]
contexts = [context for pair in pairs for context in pair[1]]
negs = [negs for neg_context in neg_contexts for negs in neg_context]

In [None]:
len(centers)

In [None]:
len(contexts)

In [None]:
len(negs)

In [None]:
centers

In [None]:
contexts

In [None]:
batch[1]

In [None]:
sequences.pairs[:10]

In [None]:
centers

In [None]:
contexts

In [None]:
neg_contexts

In [None]:
pairs, neg_samples = batch

In [None]:
pairs[0]

In [None]:
batch[0][1]

In [None]:
batch[1]

### Build Dataset

In [None]:
emb = nn.Embedding(2000000, 4, sparse=True)

In [None]:
def get_len():
    """Return the `n_pairs` attribute of the global `sequences` object."""
    n_pairs = sequences.n_pairs
    return n_pairs

In [None]:
get_len()

In [None]:
idx = 0
batch_size = 5
neg_sample_size = 3

In [None]:
# Take the next `batch_size` pairs starting at `idx` and log them.
pairs = sequences.pairs[idx:idx+batch_size]
logger.info('Sequence batch ({}): {}'.format(idx, pairs))
# Advance the cursor by the width that was just sliced, so consecutive batches
# neither overlap nor skip pairs.
idx += batch_size

In [None]:
sequences.pairs[1]

In [None]:
# NOTE(review): an earlier cell calls sequences.get_negative_samples(context)
# with a context ID as the first argument, whereas here the sample size is
# passed in that position — confirm the method's signature before relying on
# the length/shape of `neg_samples`.
neg_samples = sequences.get_negative_samples(neg_sample_size)
neg_samples

In [None]:
batch = (pairs, neg_samples)

In [None]:
pairs, neg_contexts = batch

In [None]:
centers = [center for center, _ in pairs]
centers

In [None]:
contexts = [context for _, context in pairs]
contexts

In [None]:
neg_contexts

In [None]:
center = [center for center, context in pairs]
center = torch.LongTensor(center)

In [None]:
neg_context = torch.LongTensor(neg_samples)

In [None]:
emb_center = emb(center)

In [None]:
emb_neg = emb(neg_context)

In [None]:
emb_center

In [None]:
emb_neg

In [None]:
# Dot product of every negative embedding with its center embedding:
# bmm of (batch, n_neg, dim) with (batch, dim, 1) gives (batch, n_neg, 1).
# Only the trailing singleton axis is squeezed, so the batch axis survives
# even for a batch of one and the later reduction over dim=1 stays valid.
neg_score = torch.bmm(emb_neg, emb_center.unsqueeze(2)).squeeze(2)
neg_score

In [None]:
neg_score = torch.clamp(neg_score, max=10, min=-10)
neg_score

In [None]:
neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)
neg_score