In [1]:
from collections import defaultdict
import re
import pickle
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import time

In [3]:
# Load the full Instacart and Google datasets from pickled files.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so these
# files must come from a trusted source. The absolute /home paths are
# machine-specific.
with open('/home/instacart_data_all.pkl', 'rb') as f:
    instacart_all = pickle.load(f)
with open('/home/google_data_all.pkl', 'rb') as f:
    google_all = pickle.load(f)

In [4]:
def filter_dataset(dataset, excluded_categories=()):
    """Keep the entries that have categories and none of the excluded ones.

    Args:
        dataset: iterable of dict-like entries; each must have a 'category'
            key holding a sequence of category names (or a falsy value).
        excluded_categories: iterable of category names. An entry is dropped
            when any of its categories appears here. Defaults to no exclusions.

    Returns:
        A new list with the surviving entries, in their original order. The
        entries themselves are not copied.
    """
    # Materialise once: gives O(1) membership tests and makes one-shot
    # iterables (generators, iterators) safe to pass as the exclusion list.
    excluded = set(excluded_categories)
    return [
        entry for entry in dataset
        # Entries with an empty/missing category list are always dropped.
        if entry['category'] and excluded.isdisjoint(entry['category'])
    ]

In [5]:
# Instacart: no exclusion list is passed, so only entries without a category are dropped.
instacart = filter_dataset(instacart_all)
# Google: additionally drop every entry tagged with one of the categories listed here.
google = filter_dataset(google_all, ['home garden', 'home improvement tools', 'arts crafts party supplies',
                                     'automotive', 'musical instruments', 'books movies music',
                                     'electronics', 'apparel', 'office school supplies',])

In [6]:
# Tally the Google entries by their first (top-level) category.
cats = [product['category'][0] for product in google]
print(len(set(cats)))
Counter(cats).most_common()

8


[('sports outdoors', 169457),
 ('health beauty', 146447),
 ('pet supplies', 43175),
 ('toys games', 41703),
 ('household supplies', 28652),
 ('baby kids', 28288),
 ('grocery', 25151),
 ('travel luggage bags', 18561)]

In [34]:
# Number of category levels attached to each Instacart entry.
num_cats = [len(product['category']) for product in instacart]

In [35]:
# How many Instacart entries have 1, 2, 3, ... category levels.
Counter(num_cats)

Counter({2: 322634, 1: 3524, 3: 4})

## Tokenize the data

In [9]:
from transformers import BertTokenizer
# Shared WordPiece tokenizer used for both product text and category names.
# NOTE(review): this loads the *cased* multilingual checkpoint but lower-cases the
# input (do_lower_case=True). That combination looks unintended — confirm; changing
# it would alter every token id that is exported further down.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=True)

In [18]:
def relative_pos(og_indices, new_indices):
    """Return, for each value in ``new_indices``, its position in ``og_indices``.

    Args:
        og_indices: sequence of hashable values (the reference ordering).
        new_indices: iterable of values to look up in ``og_indices``.

    Returns:
        list of int positions, one per element of ``new_indices``. When a value
        occurs more than once in ``og_indices`` the first position is used.

    Raises:
        ValueError: if a value of ``new_indices`` is not present in
            ``og_indices`` (same exception type ``list.index`` raises).
    """
    # One pass to build a value -> first position map, instead of a linear
    # list.index() scan per lookup (which is quadratic overall).
    first_position = {}
    for position, value in enumerate(og_indices):
        first_position.setdefault(value, position)
    try:
        return [first_position[n] for n in new_indices]
    except KeyError as err:
        raise ValueError('{!r} is not in list'.format(err.args[0])) from None

def tokenize_dataset(dataset, max_level, msl_title=128, msl_cat=10):
    """Tokenize (title, description) pairs and category names level by level.

    For every level in ``range(max_level)`` the entries that have a category
    at that depth are selected. The (title, description) pairs are only run
    through the tokenizer at level 0; deeper levels reuse the level-0 result
    by picking the matching rows. Category names are tokenized at every level.
    Processing stops at the first level that has no entries.

    Args:
        dataset: sequence of entries with 'title', 'description' and
            'category' keys.
        max_level: upper bound on the number of category levels to process.
        msl_title: ``max_length`` passed to the tokenizer for the text pairs.
        msl_cat: ``max_length`` passed to the tokenizer for category names.

    Returns:
        list with one ``(tokenized_titles_and_desc, tokenized_cats)`` tuple per
        processed level. At level 0 the first element is the tokenizer's own
        output; at deeper levels it is a dict of numpy arrays.
    """
    per_level = []
    base_indices = None
    for level in range(max_level):
        print('level:', level)
        # (position, entry) pairs for entries that have a category this deep.
        selected = [
            (pos, item) for pos, item in enumerate(dataset)
            if item['category'] and len(item['category']) > level
        ]
        text_pairs = [(item['title'], item['description']) for _, item in selected]
        level_cats = [item['category'][level] for _, item in selected]
        indices = [pos for pos, _ in selected]
        if not text_pairs:
            # Nothing with a category at this level
            break
        if level == 0:
            print(len(text_pairs), flush=True)
            # Remember which dataset rows the level-0 encoding covers.
            base_indices = indices
            started = time.time()
            encoded_text = tokenizer.batch_encode_plus(
                text_pairs, max_length=msl_title, pad_to_max_length=True)
            print(time.time() - started)
        else:
            # Map this level's dataset positions onto rows of the level-0 encoding.
            rows = relative_pos(base_indices, indices)
            encoded_text = {
                input_key: np.array(values)[rows]
                for input_key, values in per_level[0][0].items()
            }
        encoded_cats = tokenizer.batch_encode_plus(
            level_cats, max_length=msl_cat, pad_to_max_length=True)
        per_level.append((encoded_text, encoded_cats))
    return per_level

In [21]:
# Trial run on the first 10000 Instacart entries (the full dataset is tokenized in a later cell).
tokenized_dataset = tokenize_dataset(instacart[:10000], max_level=6, msl_title=128, msl_cat=10)

level: 0
10000
20.750762462615967
level: 1
level: 2


In [None]:
# Peek at the first ten filtered Instacart entries.
instacart[:10]

In [24]:
# Inspect the level-1 category encodings: for each tokenizer output field print its
# name, array shape and raw values; for input_ids also decode every row back to text.
# NOTE(review): the saved output reports shape (10, 10), which does not match the
# 10000-entry tokenize_dataset call — presumably it came from a different run.
for ik, arr in tokenized_dataset[1][1].items():
    print(ik)
    print(np.array(arr).shape)
    print(arr)
    if ik == 'input_ids':
        for ids in arr:
            print(tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)))

input_ids
(10, 10)
[[101, 10173, 45094, 24384, 10111, 49963, 102, 0, 0, 0], [101, 62432, 10162, 108193, 14724, 14273, 47353, 102, 0, 0], [101, 10680, 102, 0, 0, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10680, 102, 0, 0, 0, 0, 0, 0, 0], [101, 32650, 36269, 11945, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0]]
condiments and supplies
packaged vegetables fruits
red
beverages
beverages
beverages
red
sparkling
beverages
beverages
token_type_ids
(10, 10)
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
attention_mask
(10, 10)
[[1, 1,

In [32]:
# Directory the tokenized arrays are written to; fail fast if it does not exist.
rootdir = '/home/transformers-public/rappi-data-updated/'
assert os.path.exists(rootdir)

In [33]:
# Tokenize each selected dataset and write every tokenizer output array to its own
# raw int64 memmap file under rootdir, named after file_pattern.
file_pattern = "{dataset}_{split}_{level}_{msl_title}_{msl_cat}_{input_key}"
# (dataset, split, level) -> array shape; a memmap file stores no shape, so this
# mapping is needed to reopen the files later.
lens = {}
msl_title = 128
msl_cat = 10
dataset_to_obj = {'instacart': instacart, 'google': google}
# Only 'instacart' is exported here even though 'google' is registered above.
for dataset in ['instacart']:
    tokenized_dataset = tokenize_dataset(dataset_to_obj[dataset],
                                         max_level=7, msl_title=msl_title, msl_cat=msl_cat)
    print(dataset, flush=True)
    print("got dataset")
    for level in range(len(tokenized_dataset)):
        print('level:', level, flush=True)
        titles, cats = tokenized_dataset[level]
        split_to_arr = {'title-and-desc': titles, 'category': cats}
        for split in ['title-and-desc', 'category']:
            tokenized_output = split_to_arr[split]
            # One file per tokenizer output field of this split.
            for input_key, arr in tokenized_output.items():
                filename = os.path.join(rootdir,
                                        file_pattern.format(
                                            dataset=dataset,
                                            split=split,
                                            level=level,
                                            msl_title=msl_title,
                                            msl_cat=msl_cat,
                                            input_key=input_key))
                arr = np.array(arr)
                # mode='w+' creates the file or overwrites an existing one.
                fp = np.memmap(filename, dtype='int64', mode='w+',
                              shape=arr.shape)
                # Re-assigned once per input_key of the same split (same key each time).
                lens[(dataset, split, level)] = arr.shape
                fp[:] = arr[:]
                # Drop the memmap reference before moving on to the next file.
                del fp

level: 0
326162
680.991087436676
level: 1
level: 2
level: 3
instacart
got dataset
level: 0
level: 1
level: 2


In [45]:
# Persist the shape table next to the memmap files, keyed by the two max lengths.
with open(os.path.join(rootdir, 'lens_{}_{}.pkl'.format(msl_title, msl_cat)), 'wb') as f:
    # Keep lens which is needed for loading a memmap
    pickle.dump(lens, f)

# Dataloader

In [37]:
import torch
import itertools
from collections import Counter, OrderedDict, defaultdict
from torch.utils.data import DataLoader

In [38]:
class IterableTitles(torch.utils.data.IterableDataset):
    """Endless random stream of tokenized rows for one dataset and category level.

    The rows are read from int64 memmap files under ``root_dir`` whose names
    follow ``{dataset}_{split}_{level}_{msl_title}_{msl_cat}_{input_key}``. The
    array shapes come from the ``lens_{msl_title}_{msl_cat}.pkl`` file stored
    in the same directory.

    Args:
        root_dir: directory holding the memmap files and the lens pickle.
        dataset: dataset name used in the file names and lens keys.
        level: category level used in the file names and lens keys.
        msl_title: max sequence length the title/description arrays were built with.
        msl_cat: max sequence length the category arrays were built with.

    Raises:
        AssertionError: if the 'title-and-desc' and 'category' arrays of this
            dataset/level do not have the same number of rows.
    """

    def __init__(self, root_dir, dataset, level, msl_title, msl_cat):
        # Was ``super(IterableTitles).__init__()``, which builds an unbound
        # super object and never runs the base-class initializer.
        super().__init__()
        # Doubled braces keep {split} and {input_key} as placeholders for __iter__.
        file_pattern = "{dataset}_{{split}}_{level}_{msl_title}_{msl_cat}_{{input_key}}"
        self.datafile_pattern = os.path.join(
            root_dir,
            file_pattern.format(dataset=dataset, level=level, msl_title=msl_title, msl_cat=msl_cat))
        len_filename = 'lens_{}_{}.pkl'.format(msl_title, msl_cat)
        # NOTE(review): pickle.load can run arbitrary code; only point this at trusted files.
        with open(os.path.join(root_dir, len_filename), 'rb') as f:
            len_dict = pickle.load(f)
        self.len_dict = len_dict
        assert self.len_dict[(dataset, 'title-and-desc', level)][0] == self.len_dict[(dataset, 'category', level)][0]
        self.length = self.len_dict[(dataset, 'title-and-desc', level)][0]
        self.dataset = dataset
        self.level = level

    def __len__(self):
        """Number of rows stored for this dataset/level (the stream itself never ends)."""
        return self.length

    def __iter__(self):
        """Yield random rows forever.

        Each item is a 6-tuple of 1-D arrays: input_ids, token_type_ids and
        attention_mask of the title/description, followed by the same three
        fields of the category, all taken from the same row index.
        """
        split_input_key_to_memmap = {}
        for split in ('title-and-desc', 'category'):
            for input_key in ('input_ids', 'attention_mask', 'token_type_ids'):
                datafile = self.datafile_pattern.format(split=split, input_key=input_key)
                # NOTE(review): 'r+' maps the files read-write although nothing here
                # writes to them; kept so the yielded arrays stay writable.
                split_input_key_to_memmap[(split, input_key)] = np.memmap(
                    datafile, dtype='int64', mode='r+',
                    shape=self.len_dict[(self.dataset, split, self.level)])
        while True:
            # Sample with replacement; one row index is shared by all six arrays.
            i = random.randrange(self.length)
            yield (
                split_input_key_to_memmap[('title-and-desc', 'input_ids')][i],
                split_input_key_to_memmap[('title-and-desc', 'token_type_ids')][i],
                split_input_key_to_memmap[('title-and-desc', 'attention_mask')][i],
                split_input_key_to_memmap[('category', 'input_ids')][i],
                split_input_key_to_memmap[('category', 'token_type_ids')][i],
                split_input_key_to_memmap[('category', 'attention_mask')][i],
            )


class MultiStreamDataLoader:
    """Endless batch loader that mixes rows from several dataset/level streams.

    One ``IterableTitles`` stream is opened per (dataset, level) found in the
    lens pickle for the requested datasets. Every batch row is drawn from a
    stream picked at random with probability proportional to its row count.

    Args:
        root_dir: directory holding the memmap files and the lens pickle.
        msl_title: max sequence length of the title/description arrays.
        msl_cat: max sequence length of the category arrays.
        batch_size: number of rows per yielded batch.
        datasets: dataset name, or iterable of names, to open streams for.
            Defaults to ``('instacart',)``.
    """

    def __init__(self, root_dir, msl_title, msl_cat, batch_size, datasets=('instacart',)):
        if isinstance(datasets, str):
            datasets = (datasets,)
        datasets = tuple(datasets)
        len_filename = 'lens_{}_{}.pkl'.format(msl_title, msl_cat)
        # NOTE(review): pickle.load can run arbitrary code; only point this at trusted files.
        with open(os.path.join(root_dir, len_filename), 'rb') as f:
            len_dict = pickle.load(f)
        print(len_dict)
        # The pickle stores full array shapes; only the row count is needed here.
        self.len_dict = {k: v[0] for k, v in len_dict.items()}
        self.batch_size = batch_size
        self.dataset_lvl_to_iter = {}
        for dataset, split, level in self.len_dict:
            # The 'category' key of the same dataset/level shares this stream.
            if dataset in datasets and split == 'title-and-desc':
                dataset_iter = IterableTitles(root_dir, dataset, level, msl_title, msl_cat)
                self.dataset_lvl_to_iter[(dataset, level)] = iter(DataLoader(dataset_iter, batch_size=None))
        # Count each row once (not once per split) and only for streams that are served.
        self.total_samples = sum(
            rows for (dataset, split, level), rows in self.len_dict.items()
            if split == 'title-and-desc' and (dataset, level) in self.dataset_lvl_to_iter)

    def __len__(self):
        """Number of full batches needed to see ``total_samples`` rows once."""
        return self.total_samples//self.batch_size

    def __iter__(self):
        """Yield batches forever.

        Each batch is a 6-tuple of stacked tensors, in the field order produced
        by ``IterableTitles``, each with ``batch_size`` rows.

        Raises:
            ValueError: if no stream is available to sample from.
        """
        print(self.dataset_lvl_to_iter.keys())
        # Sample only among keys backed by an open stream. Previously keys of
        # other datasets were drawn and then skipped, which produced short
        # batches (or an empty torch.stack) and a KeyError for unknown datasets.
        dataset_keys = [k for k in self.len_dict
                        if (k[0], k[2]) in self.dataset_lvl_to_iter]
        if not dataset_keys:
            raise ValueError('no dataset streams available to sample from')
        weights = [self.len_dict[k] for k in dataset_keys]
        while True:
            key_choices = random.choices(dataset_keys, weights=weights, k=self.batch_size)
            buffer = [next(self.dataset_lvl_to_iter[(kc[0], kc[2])]) for kc in key_choices]
            # Transpose: list of 6-field rows -> 6 stacked tensors.
            yield tuple(torch.stack([row[field] for row in buffer]) for field in range(6))

In [101]:
# NOTE(review): looks like a leftover kernel sanity check; nothing depends on it.
print("hi")

hi


In [43]:
# Directory to read the exported memmaps from (same path as the export cell, without
# the trailing slash); fail fast if it does not exist.
rootdir = '/home/transformers-public/rappi-data-updated'
assert os.path.exists(rootdir)

In [46]:
# Build the mixed-level batch loader over the files in rootdir. msl_title and msl_cat
# are the module-level values assigned in the export cell, so that cell must have run.
ds = MultiStreamDataLoader(root_dir=rootdir,
                           msl_title=msl_title,
                           msl_cat=msl_cat,
                           batch_size=64)

{('instacart', 'title-and-desc', 0): (326162, 128), ('instacart', 'category', 0): (326162, 10), ('instacart', 'title-and-desc', 1): (322638, 128), ('instacart', 'category', 1): (322638, 10), ('instacart', 'title-and-desc', 2): (4, 128), ('instacart', 'category', 2): (4, 10)}


In [47]:
# Print the raw tensors of a single batch; the break is required because the loader
# yields batches forever.
for i in ds:
    print(i)
    break

dict_keys([('instacart', 0), ('instacart', 1), ('instacart', 2)])
(tensor([[  101, 46484, 35045,  ..., 30231, 10263,   102],
        [  101, 29177, 10245,  ...,     0,     0,     0],
        [  101, 41924, 49085,  ...,   118, 13961,   102],
        ...,
        [  101, 12541, 15797,  ..., 36796, 40564,   102],
        [  101, 13009, 11195,  ..., 10271,   118,   102],
        [  101, 14772, 50513,  ...,     0,     0,     0]]), tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0]]), tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]), tensor([[   101,  10127,  10116,    102,      0,      0,      0,      0,      0,
              0],
        [   101,  10127,  10116, 

In [48]:
# Decode the first batch back to text: all title/description rows, a "---" divider,
# then the category rows followed by "=====".
for batch_num, batch in enumerate(ds):
    if batch_num == 1:
        break
    input_ids_title = batch[0]
    input_ids_cat = batch[3]
    print(input_ids_title.shape)
    for id_rows, divider in ((input_ids_title, "---"), (input_ids_cat, "=====")):
        for ids in id_rows:
            tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
            print(tokenizer.convert_tokens_to_string(tokens))
        print(divider)

dict_keys([('instacart', 0), ('instacart', 1), ('instacart', 2)])
torch.Size([64, 128])
ghirardelli chocolate double chocolate premium brownie mix ( 2 . 3 oz ) from lucky - instacart about one minute to prepare .
red bull energy drink ( 16 fl oz ) from kings food markets - instacart red bulla® energy drink . vitalizes body and mind . made with taurine . recognized worldwide by top athletes , busy professionals , college students , and travelers on long journeys . lightly carbonated .
somersaults snack co sunflower seed crunchy bites - cinnamon ( 6 oz ) from natural grocers - instacart somersaultsa® sunflower seed crunchy bites cinnamon . 5g protein . 3g fiber . nut tree . vegan . sweet baked goodness . non - gmo - project - verified . nongmoproject . org . certified vegan , vegan . org . net wt 6 oz ( 170 g ) .
ken ' s steakhouse dressing ranch ( 16 fl oz ) from piggly wiggly ga - instacart ken ' sa® steak house dressing ranch . free from gluten . satisfaction guaranteed .
white onion 