In [2]:
from collections import defaultdict
import re
import pickle
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import time

In [34]:
import torch 
torch.cuda.is_available()

True

In [4]:
from google.cloud import storage


def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Downloads a blob from the bucket."""
    # bucket_name = "your-bucket-name"
    # source_blob_name = "storage-object-name"
    # destination_file_name = "local/path/to/file"

    storage_client = storage.Client()

    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)

    print(
        "Blob {} downloaded to {}.".format(
            source_blob_name, destination_file_name
        )
    )

In [8]:
download_blob("glisten", "datasets/google_rappi_category.pkl", '/home/sarahwooders_gmail_com/transformers/datasets/google_rappi_category.pkl')
download_blob("glisten", "datasets/instacart_rappi_category.pkl", '/home/sarahwooders_gmail_com/transformers/datasets/instacart_rappi_category.pkl')

Blob datasets/google_rappi_category.pkl downloaded to /home/sarahwooders_gmail_com/transformers/datasets/google_rappi_category.pkl.
Blob datasets/instacart_rappi_category.pkl downloaded to /home/sarahwooders_gmail_com/transformers/datasets/instacart_rappi_category.pkl.


In [9]:
with open('/home/sarahwooders_gmail_com/transformers/datasets/instacart_rappi_category.pkl', 'rb') as f:
    instacart_all = pickle.load(f)
with open('/home/sarahwooders_gmail_com/transformers/datasets/google_rappi_category.pkl', 'rb') as f:
    google_all = pickle.load(f)

In [10]:
instacart_all[0]

{'category': ['grilling out', 'condiments and supplies'],
 'original_category': 'Home > Publix > Grilling Out > Condiments and Supplies',
 'seller': 'Publix',
 'title': 'Kraft Singles American Cheese Slices (24 ct) from Publix - Instacart',
 'description': 'Kraft Singles American Slices feature the melty, great taste that you love with no artificial preservatives or flavors. This sliced cheese has a smooth, creamy texture and mild flavor. Pre-sliced for your convenience, this cheese melts beautifully over hot foods or in the oven. Kraft Singles slices of American cheese are made with quality ingredients, like fresh pasteurized milk, so you can feel good about feeding them to your family. Slide a slice of this deli style cheese into a grilled cheese sandwich, or melt one on top of a juicy burger. For optimum flavor, keep this 16 ounce package of 24 cheese slices refrigerated until use.',
 'ingredients': 'Milk, Cheddar Cheese (milk, Cheese Culture, Salt, Enzymes), Whey, Milk Protein Conc

In [62]:
# with open('/../rappi-data/instacart_data_all.pkl', 'rb') as f:
#     instacart_all = pickle.load(f)
# with open('/home/google_data_all.pkl', 'rb') as f:
#     google_all = pickle.load(f)

In [11]:
def filter_dataset(dataset, excluded_categories=[]):
    filtered_dataset = []
    for entry in dataset:
        if not entry['category']:
            continue
        skip = False
        for c in entry['category']:
            if c in excluded_categories:
                skip = True
        if not skip:
            filtered_dataset.append(entry)
    return filtered_dataset

In [12]:
instacart = filter_dataset(instacart_all)
google = filter_dataset(google_all, ['home garden', 'home improvement tools', 'arts crafts party supplies',
                                     'automotive', 'musical instruments', 'books movies music',
                                     'electronics', 'apparel', 'office school supplies',])

In [13]:
# rappi versions are same, except with category replaced by rappi_category
def get_rappi_version(dataset):
    rappi_version = []
    for entry in dataset:
        new_entry = dict(entry)
        new_entry['category'] = new_entry['rappi_category']
        rappi_version.append(new_entry)
    return rappi_version

In [14]:
instacart_rappi = get_rappi_version(instacart)
google_rappi = get_rappi_version(google)

In [15]:
cats = []
for entry in instacart_rappi:
    cats.append(entry['category'][0])
print(len(set(cats)))
Counter(cats).most_common()

13


[('food and drinks', 248986),
 ('care & beauty', 26865),
 ('home', 23713),
 ('babies', 11699),
 ('pets', 8406),
 ('cars', 2960),
 ('other', 1764),
 ('fashion', 685),
 ('hobbies', 681),
 ('office', 281),
 ('Other', 70),
 ('services', 38),
 ('dairy & eggs/  > eggs', 13)]

In [16]:
num_cats = []
for entry in instacart:
    num_cats.append(len(entry['category']))

In [17]:
Counter(num_cats)

Counter({2: 322633, 1: 3524, 3: 4})

## Take data and tokenize

In [18]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=True)

In [31]:
from tqdm import tqdm 

def relative_pos(og_indices, new_indices):
    idxs = []
    for n in new_indices:
        idxs.append(og_indices.index(n))
    return idxs

def tokenize_dataset(dataset, max_level, msl_title=128, msl_cat=10):
    tokenization_by_category_level = []
    zero_indices = None
    for level in range(max_level):
        print('tokenizer level:', level)
        print('dataset size', len(dataset))
        titles_and_desc = []
        cats = []
        indices = []
        for idx, entry in tqdm(enumerate(dataset)):
            if not entry['category']:  # no categories
                continue
            if len(entry['category'])-1 < level:
                continue
            indices.append(idx)
            title = entry['title']
            desc = entry['description']
            category = entry['category'][level]
            titles_and_desc.append((title, desc))
            cats.append(category)
        if len(titles_and_desc) == 0:
            # Nothing with category in this level
            break
        if level > 0:
            relative_indices = relative_pos(zero_indices, indices)
            tokenized_titles_and_desc = {}
            for input_key, arr in tokenization_by_category_level[0][0].items():
                tokenized_titles_and_desc[input_key] = np.array(arr)[relative_indices]
        else:
            print(len(titles_and_desc), flush=True)
            zero_indices = indices
            start = time.time()
            tokenized_titles_and_desc = tokenizer.batch_encode_plus(
                titles_and_desc, max_length=msl_title, pad_to_max_length=True)
            print(time.time()-start)
        tokenized_cats = tokenizer.batch_encode_plus(cats, max_length=msl_cat, pad_to_max_length=True)
        tokenization_by_category_level.append((tokenized_titles_and_desc, tokenized_cats))
    return tokenization_by_category_level

In [29]:
tokenized_dataset = tokenize_dataset(instacart[:10000], max_level=6, msl_title=128, msl_cat=10)

10000it [00:00, 583376.78it/s]

level: 0
10000





20.810821771621704


10000it [00:00, 752112.18it/s]

level: 1



10000it [00:00, 1596552.85it/s]

level: 2





In [21]:
instacart[:10]

[{'category': ['grilling out', 'condiments and supplies'],
  'original_category': 'Home > Publix > Grilling Out > Condiments and Supplies',
  'seller': 'Publix',
  'title': 'Kraft Singles American Cheese Slices (24 ct) from Publix - Instacart',
  'description': 'Kraft Singles American Slices feature the melty, great taste that you love with no artificial preservatives or flavors. This sliced cheese has a smooth, creamy texture and mild flavor. Pre-sliced for your convenience, this cheese melts beautifully over hot foods or in the oven. Kraft Singles slices of American cheese are made with quality ingredients, like fresh pasteurized milk, so you can feel good about feeding them to your family. Slide a slice of this deli style cheese into a grilled cheese sandwich, or melt one on top of a juicy burger. For optimum flavor, keep this 16 ounce package of 24 cheese slices refrigerated until use.',
  'ingredients': 'Milk, Cheddar Cheese (milk, Cheese Culture, Salt, Enzymes), Whey, Milk Protei

In [23]:
for ik, arr in tokenized_dataset[1][1].items():
    print(ik)
    print(np.array(arr).shape)
    print(arr)
    if ik == 'input_ids':
        for ids in arr:
            print(tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)))

input_ids
(9653, 10)
[[101, 10173, 45094, 24384, 10111, 49963, 102, 0, 0, 0], [101, 62432, 10162, 108193, 14724, 14273, 10111, 47353, 102, 0], [101, 10680, 102, 0, 0, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10680, 102, 0, 0, 0, 0, 0, 0, 0], [101, 32650, 36269, 11945, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 15263, 102, 0, 0, 0, 0, 0, 0, 0], [101, 10680, 102, 0, 0, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 10347, 45918, 13156, 102, 0, 0, 0, 0, 0], [101, 67622, 73768, 102, 0, 0, 0, 0, 0, 0], [101, 58768, 103189, 10336, 30521, 10347, 45918, 13156, 102, 0], [101, 67622, 73768, 102, 0, 0, 0, 0, 0, 0], [101, 31084, 73768, 102, 0, 0, 0, 0, 0, 0], [101, 31084, 73768, 102, 0, 0, 0, 0, 0, 0], [101, 67622, 73768, 102, 0, 0, 0, 0, 0, 0], [101, 31084, 73768, 102, 0, 

hot dogs bacon and sausage
hot dogs bacon and sausage
hot dogs bacon and sausage
hot dogs bacon and sausage
hot dogs bacon and sausage
hot dogs bacon and sausage
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
specialty cheeses
lunch meat
lunch meat
packaged meat
lunch meat
lunch meat
lunch meat
lunch meat
lunch meat
lunch meat
lunch meat
lunch meat
lunch meat
prepared meals
prepared meals
prepared meals
dips
food ( limited time )
prepared meals
prepared meals
prepared meals
prepared meals
prepared meals
fresh dips and tapenades
fresh dips and tapenades
fresh dips and tapenades
fresh dips and tapenades
fresh dips and tapenades
fresh dips and tapenades
fresh dips and tapen

hair tools and brushes
hair tools and brushes
hair tools and brushes
hair tools and brushes
hair tools and brushes
hair tools and brushes
hair tools and brushes
hair tools and brushes
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
face makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
eye makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
lip makeup
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
makeup brushes and applicators
beauty tools and accessories
beauty 

nuts
nuts
soup
soup
soup
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
pasta and pasta sauce
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
instant foods
members mark
members mark
members mark
members mark
members mark
members mark
members mark
prepared
members mark
members mark
members mark
members mark
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh fruits
fresh vegetables
fresh vegetables
fresh vegetables
fresh vegetables
fresh vegetables
fresh vegetables
fresh produce
fresh vegetables
fresh vegetables
fresh vegetables
fresh vegetables
fresh produce
packaged vegetables and fruits
packag

soft drinks
soft drinks
soft drinks
soft drinks
soft drinks
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
water seltzer and sparkling water
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cocoa and drink mixes
cereal
cereal
cereal
cereal
cereal
cereal
cereal
cereal
cereal
cereal
cereal
cereal
hot cereal and pancake mix
hot cereal and pancake mix
hot cereal and pancake mix
hot cereal and pancake mix
hot cereal and pancake mix
hot cereal and pancake mix
hot cereal and pancake

ice cream and ice
ice cream and ice
ice cream and ice
ice cream and ice
ice cream and ice
ice cream and ice
ice cream and ice
facial care
facial care
facial care
facial care
facial care
facial care
facial care
facial care
facial care
facial care
facial care
facial care
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
deodorants
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
body lotions and soap
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
shave needs
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
oral hygiene
eye and ear care
eye and ear care


other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
other creams and cheeses
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
soy and lactose free
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
refrigerated pudding and desse
canned fruit and applesauce
canned fruit and applesauce
canned fruit and applesauce
canned fruit and applesauce
canned fruit and applesauce
canned frui

dips
dips
dips
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
meat counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
poultry counter
seafood counter
seafood counter
fresh vegetables
seafood counter
seafood counter
seafood counter
seafood counter
seafood counter
seafood counter
seafood counter
seafood counter
seafood counter
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged meat
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged poultry
packaged seafood
packaged seafood
packaged seafood
packaged seafood
packaged seafood
p

frozen pizza
frozen pizza
frozen pizza
frozen pizza
frozen pizza
frozen pizza
doughs gelatins and bake
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen breads and dough
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
frozen produce
canned and jarred vegeta
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen meat and seafood
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan and vegetarian
frozen vegan a

In [26]:
rootdir = '../rappi-data/tokenized'
assert os.path.exists(rootdir)

In [35]:
file_pattern = "{dataset}_{split}_{level}_{msl_title}_{msl_cat}_{input_key}"
lens = {}
msl_title = 128
msl_cat = 10
dataset_to_obj = {'instacart': instacart, 'google': google,
                  'instacart-rappi': instacart_rappi,
                  'google-rappi': google_rappi}

for dataset in ['google', 'instacart-rappi', 'google-rappi']:
    tokenized_dataset = tokenize_dataset(dataset_to_obj[dataset],
                                         max_level=7, msl_title=msl_title, msl_cat=msl_cat)
    print(dataset, flush=True)
    print("got dataset")
    for level in range(len(tokenized_dataset)):
        print('level:', level, flush=True)
        titles, cats = tokenized_dataset[level]
        split_to_arr = {'title-and-desc': titles, 'category': cats}
        for split in ['title-and-desc', 'category']:
            tokenized_output = split_to_arr[split]
            for input_key, arr in tokenized_output.items():
                filename = os.path.join(rootdir,
                                        file_pattern.format(
                                            dataset=dataset,
                                            split=split,
                                            level=level,
                                            msl_title=msl_title,
                                            msl_cat=msl_cat,
                                            input_key=input_key))
                arr = np.array(arr)
                fp = np.memmap(filename, dtype='int64', mode='w+',
                              shape=arr.shape)
                lens[(dataset, split, level)] = arr.shape
                fp[:] = arr[:]
                del fp

56224it [00:00, 562208.37it/s]

tokenizer level: 0
dataset size 298454


298454it [00:00, 600921.59it/s]

298454





735.9970207214355


78290it [00:00, 782893.28it/s]

tokenizer level: 1
dataset size 298454


298454it [00:00, 812087.97it/s]
78648it [00:00, 786477.00it/s]

tokenizer level: 2
dataset size 298454


298454it [00:00, 813828.64it/s]
79226it [00:00, 792130.44it/s]

tokenizer level: 3
dataset size 298454


298454it [00:00, 833587.03it/s]
118283it [00:00, 1182822.67it/s]

tokenizer level: 4
dataset size 298454


298454it [00:00, 1147929.52it/s]
298454it [00:00, 1618105.42it/s]

tokenizer level: 5
dataset size 298454



298454it [00:00, 1719436.90it/s]

tokenizer level: 6
dataset size 298454





google
got dataset
level: 0
level: 1
level: 2
level: 3
level: 4
level: 5
level: 6


0it [00:00, ?it/s]

tokenizer level: 0
dataset size 326161


326161it [00:00, 634531.89it/s]

326161





690.3850283622742


80835it [00:00, 808344.99it/s]

tokenizer level: 1
dataset size 326161


326161it [00:00, 820515.97it/s]
84882it [00:00, 848818.79it/s]

tokenizer level: 2
dataset size 326161


326161it [00:00, 872864.85it/s]
108112it [00:00, 1081118.45it/s]

tokenizer level: 3
dataset size 326161


326161it [00:00, 1114617.83it/s]
326161it [00:00, 1823817.82it/s]

tokenizer level: 4
dataset size 326161
instacart-rappi





got dataset
level: 0
level: 1
level: 2
level: 3


81680it [00:00, 816794.94it/s]

tokenizer level: 0
dataset size 298454


298454it [00:00, 830954.24it/s]

298454





730.2772514820099


83257it [00:00, 832568.81it/s]

tokenizer level: 1
dataset size 298454


298454it [00:00, 847056.23it/s]
84874it [00:00, 848633.57it/s]

tokenizer level: 2
dataset size 298454


298454it [00:00, 862160.34it/s]
89021it [00:00, 890204.48it/s]

tokenizer level: 3
dataset size 298454


298454it [00:00, 892828.88it/s]
298454it [00:00, 1937295.61it/s]

tokenizer level: 4
dataset size 298454





google-rappi
got dataset
level: 0
level: 1
level: 2
level: 3


In [36]:
with open(os.path.join(rootdir, 'lens_{}_{}.pkl'.format(msl_title, msl_cat)), 'wb') as f:
    # Keep lens which is needed for loading a memmap
    pickle.dump(lens, f)

# Dataloader

In [37]:
import torch
import itertools
from collections import Counter, OrderedDict, defaultdict
from torch.utils.data import DataLoader

In [38]:
class IterableTitles(torch.utils.data.IterableDataset):
    def __init__(self, root_dir, dataset, level, msl_title, msl_cat):
        super(IterableTitles).__init__()
        file_pattern = "{dataset}_{{split}}_{level}_{msl_title}_{msl_cat}_{{input_key}}"
        self.datafile_pattern = os.path.join(
            root_dir,
            file_pattern.format(dataset=dataset, level=level, msl_title=msl_title, msl_cat=msl_cat))
        len_filename = 'lens_{}_{}.pkl'.format(msl_title, msl_cat)
        with open(os.path.join(root_dir, len_filename), 'rb') as f:
            len_dict = pickle.load(f)
        self.len_dict = len_dict
        assert self.len_dict[(dataset, 'title-and-desc', level)][0] == self.len_dict[(dataset, 'category', level)][0]
        self.length = self.len_dict[(dataset, 'title-and-desc', level)][0]
        self.dataset = dataset
        self.level = level

    def __len__(self):
        return self.length

    def __iter__(self):
        split_input_key_to_memmap = {}
        for split in ['title-and-desc', 'category']:
            for input_key in ['input_ids',
                              'attention_mask',
                              'token_type_ids']:
                datafile = self.datafile_pattern.format(split=split, input_key=input_key)
                fp = np.memmap(
                    datafile, dtype='int64', mode='r+',
                    shape=self.len_dict[(self.dataset, split, self.level)])
                split_input_key_to_memmap[(split, input_key)] = fp
        while True:
            i = random.choice(range(self.length))
            yield (
                split_input_key_to_memmap[('title-and-desc', 'input_ids')][i],
                split_input_key_to_memmap[('title-and-desc', 'token_type_ids')][i],
                split_input_key_to_memmap[('title-and-desc', 'attention_mask')][i],
                split_input_key_to_memmap[('category', 'input_ids')][i],
                split_input_key_to_memmap[('category', 'token_type_ids')][i],
                split_input_key_to_memmap[('category', 'attention_mask')][i],
            )


class MultiStreamDataLoader:

    def __init__(self, root_dir, msl_title, msl_cat, batch_size):
        len_filename = 'lens_{}_{}.pkl'.format(msl_title, msl_cat)
        with open(os.path.join(root_dir, len_filename), 'rb') as f:
            len_dict = pickle.load(f)
        print(len_dict)
        self.len_dict = dict([(k, v[0]) for k,v in len_dict.items()])  # in len_dict shape is stored
        self.batch_size = batch_size
        self.total_samples = sum(self.len_dict.values())//2  # title and category => duplicate
        self.dataset_lvl_to_iter = {}
        for dataset in ['instacart']:
            for key in self.len_dict:
                if key[0] == dataset and key[1] == 'title-and-desc':
                    level = key[2]
                    dataset_iter = IterableTitles(root_dir, dataset, level, msl_title, msl_cat)
                    self.dataset_lvl_to_iter[(dataset, level)] = iter(DataLoader(dataset_iter, batch_size=None))

    def __len__(self):
        return self.total_samples//self.batch_size
    

    def __iter__(self):
        dataset_keys = list(self.len_dict.keys())
        print(self.dataset_lvl_to_iter.keys())
        while True:
            buffer = []
            labels = []
            key_choices = random.choices(dataset_keys, weights=list(self.len_dict.values()), k=self.batch_size)
            for kc in key_choices:
                key = (kc[0], kc[2])
                if key[0] == 'google':
                    continue
                buffer.extend(
                    [next(self.dataset_lvl_to_iter[key])]
                )
            yield (torch.stack([b[0] for b in buffer]),
                   torch.stack([b[1] for b in buffer]),
                   torch.stack([b[2] for b in buffer]),
                   torch.stack([b[3] for b in buffer]),
                   torch.stack([b[4] for b in buffer]),
                   torch.stack([b[5] for b in buffer]),
                  )

In [39]:
print("hi")

hi


In [44]:
rootdir = '../rappi-data/tokenized'
assert os.path.exists(rootdir)

In [45]:
ds = MultiStreamDataLoader(root_dir=rootdir,
                           msl_title=msl_title,
                           msl_cat=msl_cat,
                           batch_size=64)

{('google', 'title-and-desc', 0): (298454, 128), ('google', 'category', 0): (298454, 10), ('google', 'title-and-desc', 1): (298454, 128), ('google', 'category', 1): (298454, 10), ('google', 'title-and-desc', 2): (297432, 128), ('google', 'category', 2): (297432, 10), ('google', 'title-and-desc', 3): (281753, 128), ('google', 'category', 3): (281753, 10), ('google', 'title-and-desc', 4): (132557, 128), ('google', 'category', 4): (132557, 10), ('google', 'title-and-desc', 5): (20540, 128), ('google', 'category', 5): (20540, 10), ('google', 'title-and-desc', 6): (3019, 128), ('google', 'category', 6): (3019, 10), ('instacart-rappi', 'title-and-desc', 0): (326161, 128), ('instacart-rappi', 'category', 0): (326161, 10), ('instacart-rappi', 'title-and-desc', 1): (310969, 128), ('instacart-rappi', 'category', 1): (310969, 10), ('instacart-rappi', 'title-and-desc', 2): (278850, 128), ('instacart-rappi', 'category', 2): (278850, 10), ('instacart-rappi', 'title-and-desc', 3): (156964, 128), ('in

In [46]:
for i in ds:
    print(i)
    break

dict_keys([])


KeyError: ('instacart-rappi', 0)

In [47]:
for batch_num, batch in enumerate(ds):
    if batch_num == 1:
        break
    input_ids_title = batch[0]
    input_ids_cat = batch[3]
    print(input_ids_title.shape)
    for _, ids in enumerate(input_ids_title):
        print(
            tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)),
        )
    print("---")
    for _, ids in enumerate(input_ids_cat):
        print(
            tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)),
        )
    print("=====")

dict_keys([])


KeyError: ('google-rappi', 1)