In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put the parent of the current working directory on the import path so that the
# `src` package imported in the next cell can be resolved.
# NOTE(review): presumably the notebook is run from a sub-folder of the repo root - confirm.
sys.path.append(os.path.dirname(os.getcwd()))

In [3]:
import itertools
from collections import Counter

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from src.utils.logger import logger

In [4]:
dataset = 'electronics'

### Import edges

In [5]:
# Load the sampled train / validation edge lists for the selected dataset.
# Later cells read the `product1`, `product2` and `weight` columns of `df`
# and the `product1` / `product2` columns of `val`.
df = pd.read_csv('../data/{}_edges_train_samp.csv'.format(dataset))
val = pd.read_csv('../data/{}_edges_val_samp.csv'.format(dataset))

In [6]:
# Vocabulary: every product that appears on either end of an edge in train or validation.
product_set = set(df['product1'].tolist() + df['product2'].tolist() + val['product1'].tolist() + val['product2'].tolist())

In [7]:
def get_mapping_dicts(product_set):
    """
    Builds the two lookup tables that translate between products and integer ids.

    Ids are assigned consecutively from 0 in the iteration order of `product_set`.

    Args:
        product_set: Iterable of products to enumerate.

    Returns:
        Tuple `(word2id, id2word)`: product -> id and id -> product dictionaries.

    """
    # Position in the iteration order is the id
    id2word = dict(enumerate(product_set))

    # Invert the mapping; ids are visited in ascending order, so the last id wins on duplicates
    word2id = {token: idx for idx, token in id2word.items()}

    return word2id, id2word

In [8]:
# Build the product <-> integer id lookup tables and log the vocabulary size.
word2id, id2word = get_mapping_dicts(product_set)
logger.info('Len of word2id: {:,}'.format(len(word2id)))

2019-12-13 10:07:55,570 - Len of word2id: 128,802


In [9]:
def get_product_id(x):
    """
    Looks up the integer id of a product in the module-level `word2id` mapping.

    Args:
        x: Product to look up.

    Returns:
        The id stored in `word2id` for `x`, or -1 if `x` is not a key of `word2id`.

    """
    if x in word2id:
        return word2id[x]
    return -1

In [10]:
# Element-wise wrapper so the id lookup can be applied to a whole column in one call.
# NOTE: np.vectorize is a convenience wrapper around a Python-level loop, not a speed-up.
word2id_func = np.vectorize(get_product_id)

In [11]:
# Encode both edge endpoints as integer ids (get_product_id yields -1 for products missing from word2id).
df['product1_id'] = word2id_func(df['product1'])
df['product2_id'] = word2id_func(df['product2'])

In [14]:
# NOTE: despite the name, `pair_df` is a NumPy array (`.values`), not a DataFrame.
# Each row is (product1_id, product2_id, weight); the array is displayed as floats in the
# next cell, which is why the id columns are cast with astype(int) before counting.
pair_df = df[['product1_id', 'product2_id', 'weight']].copy().values

In [15]:
pair_df

array([[7.85400e+03, 4.11320e+04, 1.00000e+00],
       [1.57460e+04, 5.25720e+04, 1.00000e+00],
       [1.21606e+05, 6.80420e+04, 2.00000e+00],
       ...,
       [8.40210e+04, 1.14213e+05, 1.00000e+00],
       [1.95740e+04, 1.07128e+05, 1.00000e+00],
       [1.27100e+04, 1.16752e+05, 5.00000e-01]])

### Get negative samples

In [None]:
def get_word_freq(pair_df):
    """
    Counts how often each item occurs across all rows of `pair_df`.

    Args:
        pair_df: Iterable of rows, each row itself an iterable of items (e.g. id pairs).

    Returns:
        Counter mapping each item to its total number of occurrences.

    """
    return Counter(item for row in pair_df for item in row)

In [None]:
# Count product occurrences over the two id columns only (the weight column is excluded);
# the ids are cast back to int because `pair_df` is a float array.
word_freq = get_word_freq(pair_df[:, :2].astype(int))

In [None]:
# Target number of slots in the negative-sampling table. The realised table length can
# differ slightly because each product's slot count is rounded individually.
NEGATIVE_SAMPLE_TABLE_SIZE = 100000

In [None]:
def get_negative_sample_table(word_freq, power=0.75, table_size=None) -> np.ndarray:
    """
    Returns a shuffled table of ids from which negative samples can be drawn via indexing.

    Each id receives a number of slots proportional to `count ** power`, so drawing
    uniformly from the table approximates sampling from the smoothed frequency
    distribution. Slot counts are rounded per id, which means the table length is only
    approximately `table_size` and ids whose share rounds to zero are left out.

    Args:
        word_freq: Mapping of id -> occurrence count (e.g. a Counter).
        power: Exponent applied to the counts before normalising.
        table_size: Target number of slots. Defaults to the module-level
            NEGATIVE_SAMPLE_TABLE_SIZE when None.

    Returns:
        1-D integer array of ids in random order (shuffled with the global NumPy RNG).
        Empty if `word_freq` is empty or its adjusted counts sum to zero.

    """
    if table_size is None:
        table_size = NEGATIVE_SAMPLE_TABLE_SIZE

    # Nothing to sample from; the 2-D indexing below would fail on an empty array
    if len(word_freq) == 0:
        return np.array([], dtype=int)

    # Convert to array: column 0 = id, column 1 = count
    freq = np.array(list(word_freq.items()), dtype=np.float64)
    ids = freq[:, 0].astype(int)

    # Adjust by power
    weights = freq[:, 1] ** power

    # Get probabilities
    total = weights.sum()
    if not total > 0:
        return np.array([], dtype=int)
    probs = weights / total

    # Multiply probabilities by sample table size and convert to whole slot counts
    slots = np.round(probs * table_size).astype(int)

    # Create sample table: each id repeated `slots` times, then shuffled in place
    sample_table = np.repeat(ids, slots)
    np.random.shuffle(sample_table)

    return sample_table

In [None]:
# Build the shuffled lookup table that get_negative_samples reads from.
neg_table = get_negative_sample_table(word_freq)

In [None]:
negative_idx = 0  # Cursor into neg_table; persists across calls to get_negative_samples
def get_negative_samples(context, sample_size=5) -> np.ndarray:
    """
    Returns an array of negative samples, where len = sample_size.

    Samples are read sequentially from the pre-shuffled module-level `neg_table`,
    continuing from where the previous call stopped (tracked by the module-level
    `negative_idx`) and wrapping around at the end of the table. A batch that contains
    `context` is discarded and the next batch is tried, so the call does not return
    until it finds a batch without `context`.

    Args:
        context: Id that must not appear among the returned negative samples.
        sample_size: Number of negative samples to return.

    Returns:
        Array of `sample_size` ids taken from `neg_table`, none of them equal to `context`.

    Raises:
        ValueError: If `neg_table` is empty or holds fewer than `sample_size` entries.

    """
    # The cursor must be the module-level one; a local cursor would restart at 0 on every
    # call and hand out the same batch each time.
    global negative_idx

    table_len = len(neg_table)
    if table_len == 0 or sample_size > table_len:
        raise ValueError(
            'neg_table (len = {}) is too small to draw {} samples'.format(table_len, sample_size))

    while True:
        # Get a batch from the shuffled table
        neg_sample = neg_table[negative_idx:negative_idx + sample_size]

        # Update negative index (wraps around at the end of the table)
        negative_idx = (negative_idx + sample_size) % table_len

        # Batch ran off the end of the table: top it up from the start.
        # After the wrap, negative_idx equals exactly the number of missing entries.
        if len(neg_sample) != sample_size:
            neg_sample = np.concatenate((neg_sample, neg_table[:negative_idx]))

        # Only accept batches that do not contain the context itself
        if context not in neg_sample:
            return neg_sample

In [None]:
# Smoke test: draw a batch of negatives, passing 121656 as the context id to exclude.
get_negative_samples(121656)

### Testing

In [34]:
from src.ml.data_loader_edges import Edges, EdgesDataset

In [35]:
# Load the same sampled files through the packaged `Edges` class.
# NOTE(review): presumably this is the packaged counterpart of the prototype functions
# defined earlier in this notebook - confirm against src/ml/data_loader_edges.
# `dataset` is re-assigned to the same value it already had at the top of the notebook.
dataset = 'electronics'
edges = Edges(edge_path='../data/{}_edges_train_samp.csv'.format(dataset),
              val_path='../data/{}_edges_val_samp.csv'.format(dataset))

2019-12-13 10:10:20,678 - Edges loaded (length = 9,999)
2019-12-13 10:10:20,756 - Validation set loaded: (100000, 3)
2019-12-13 10:10:20,838 - No. of unique tokens: 128802
2019-12-13 10:10:22,034 - Model saved to model/word2id_edge
2019-12-13 10:10:23,217 - Model saved to model/id2word_edge
2019-12-13 10:10:23,217 - Word2Id and Id2Word created and saved
2019-12-13 10:10:23,237 - Edges prepared


In [20]:
len(edges.neg_table)

9996217

In [21]:
edges.get_negative_samples(1)

array([ 57910, 123297,  72649, 123626,  19100])

In [22]:
pair = edges.edges[4]

In [23]:
context = pair[1]

In [24]:
pair

array([45651, 96516,     2])

In [25]:
context

96516

In [36]:
# NOTE: this rebinds `dataset`, which held the dataset-name string until now. Cells that
# build file paths from `dataset` must not be re-run after this one without resetting it.
dataset = EdgesDataset(edges)

In [37]:
# batch_size counts edges. In the logged output of the next cell, each batch of 5 edges
# yields 30 rows: every centre id is repeated 6 times, once with its edge weight as the
# label and 5 times with label 0 (presumably the negatives added by collate_continuous).
dataloader = DataLoader(dataset, batch_size=5, collate_fn=dataset.collate_continuous)

In [38]:
# Log only the first three batches as a sanity check of the collated output.
for i, batches in enumerate(dataloader):
    if i > 2:
        break
    logger.info('i: {}, batches: {}'.format(i, batches))

2019-12-13 10:10:28,521 - i: 0, batches: (tensor([  7854,   7854,   7854,   7854,   7854,   7854,  15746,  15746,  15746,
         15746,  15746,  15746, 121606, 121606, 121606, 121606, 121606, 121606,
         98156,  98156,  98156,  98156,  98156,  98156,  45651,  45651,  45651,
         45651,  45651,  45651]), tensor([ 41132,  76249,  79665,  44770,  44310,  49603,  52572, 117187,  31609,
         71384,  12211,   4100,  68042,  67788,  86698,  18276,   7869,  47684,
        112014, 124112,  13020,  71960,  70662,  25855,  96516,  92665,  54844,
         22578, 112309,  95797]), tensor([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0.]))
2019-12-13 10:10:28,526 - i: 1, batches: (tensor([ 87057,  87057,  87057,  87057,  87057,  87057,  48254,  48254,  48254,
         48254,  48254,  48254, 122944, 122944, 122944, 122944, 122944, 122944,
         66545,  66545,  66545,  66545,  66545,  66545, 121354, 121354

In [44]:
pd.Series(edges.edges[:, 2]).value_counts()

1.0    7706
2.0    1299
0.5     504
2.2     199
3.2     122
1.5     113
2.5      18
4.4      12
1.2      11
2.7       5
3.7       4
3.0       4
1.7       1
4.2       1
dtype: int64

In [106]:
batches[0]

tensor([287628, 287628, 287628, 287628, 287628, 287628, 394067, 394067, 394067,
        394067, 394067, 394067,  97662,  97662,  97662,  97662,  97662,  97662,
        306502, 306502, 306502, 306502, 306502, 306502, 385155, 385155, 385155,
        385155, 385155, 385155])