In [2]:
import pickle
import os
import numpy as np

In [3]:
def unpickle_file(path, type_of_split, data, prefix, encoding):
    """Load one split from ``<path>/<prefix><type_of_split>.pkl`` into ``data``.

    The pickle is expected to be subscriptable by ``type_of_split``; only that
    entry is copied into ``data`` under the same key.

    Args:
        path: directory that contains the pickle file.
        type_of_split: split name; used both in the file name and as the key.
        data: mutable mapping that receives the loaded entry (updated in place).
        prefix: string placed in front of the split name in the file name.
        encoding: forwarded unchanged to ``pickle.load``.

    Returns:
        The same ``data`` object that was passed in.

    Note:
        ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    filename = prefix + type_of_split + '.pkl'
    with open(os.path.join(path, filename), 'rb') as handle:
        payload = pickle.load(handle, encoding=encoding)
    data[type_of_split] = payload[type_of_split]
    return data

def load_pkl(path, type_of_split, encoding='ASCII'):
    """Read and return the whole object stored in ``<path>/<type_of_split>.pkl``.

    Args:
        path: directory that contains the pickle file.
        type_of_split: file name without the ``.pkl`` extension.
        encoding: forwarded unchanged to ``pickle.load``.

    Returns:
        Whatever object the pickle file contains.

    Note:
        ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    pkl_file = os.path.join(path, type_of_split + '.pkl')
    with open(pkl_file, 'rb') as handle:
        return pickle.load(handle, encoding=encoding)

def save_pkl(data, path, type_of_split, encoding='ASCII'):
    """Pickle ``data`` to ``<path>/<type_of_split>.pkl``, creating ``path`` if needed.

    Args:
        data: object to serialise.
        path: destination directory; missing directories are created.
        type_of_split: file name without the ``.pkl`` extension.
        encoding: accepted for symmetry with ``load_pkl``; not used when writing.

    Returns:
        None.
    """
    destination = os.path.join(path, type_of_split + '.pkl')
    # exist_ok=True makes repeated runs of the same cell harmless.
    os.makedirs(path, exist_ok=True)
    with open(destination, 'wb') as handle:
        pickle.dump(data, handle)

def preprocess_split(data: list[list[dict]]) -> list[list[dict]]:
    """Keep only users that have more than one non-empty basket.

    A basket is dropped when ``basket['type_event'].sum()`` equals 0. A user
    is dropped when their history has at most one basket, either as given or
    once the empty baskets have been removed.

    Args:
        data: one sequence of basket dicts per user; each basket's
            ``'type_event'`` value must provide a ``.sum()`` method.

    Returns:
        A new list of per-user basket lists. The basket dicts are the same
        objects as in the input; the input lists are not modified.
    """
    kept_users = []
    for baskets in data:
        # Histories this short can never pass the final length check.
        if len(baskets) < 2:
            continue
        non_empty = [b for b in baskets if b['type_event'].sum() != 0.0]
        if len(non_empty) >= 2:
            kept_users.append(non_empty)
    return kept_users

In [6]:
# Configuration for the preprocessing pass over the raw instacart splits.
name = 'instacart'
# Input directory holding one '<split>.pkl' file per entry of types_of_split.
path = f'tcmbn_data/{name}/split_1'
types_of_split = ['train', 'dev', 'test']
# Output directory; save_pkl creates it if it does not exist.
new_path = f'tcmbn_data/{name}_preprocessed/split_1'

In [7]:
# For every split: load '<split>.pkl' from `path`, filter the user histories
# stored under the split's own key, and write the result to `new_path`.
for split in types_of_split:
    data = load_pkl(path, split)
    # Only data[split] is replaced; any other keys in the pickle are saved as loaded.
    data[split] = preprocess_split(data[split])
    save_pkl(data, new_path, split)
    

## Data preparation for time information ablation study

In [20]:
def remove_time(data: list[list[dict]]) -> list[list[dict]]:
    """Replace real timestamps with each basket's position in the user history.

    For every basket, ``'time_since_start'`` becomes the basket's 1-based
    index within its user's history and ``'time_since_last_event'`` becomes 0.
    All other keys are carried over unchanged.

    Args:
        data: one sequence of basket dicts per user.

    Returns:
        A new list of per-user lists of new basket dicts. Baskets are
        shallow-copied, so the dicts passed in are not modified; values other
        than the two time fields are shared with the input.
    """
    return [
        [
            # Shallow copy keeps the caller's baskets intact and preserves key order.
            {**basket, 'time_since_start': position, 'time_since_last_event': 0}
            for position, basket in enumerate(user, start=1)
        ]
        for user in data
    ]
        
def remove_all_time(data: list[list[dict]]) -> list[list[dict]]:
    """Erase all temporal information, including basket order, from the time fields.

    For every basket, ``'time_since_start'`` becomes the constant 1 and
    ``'time_since_last_event'`` becomes 0. All other keys are carried over
    unchanged.

    Args:
        data: one sequence of basket dicts per user.

    Returns:
        A new list of per-user lists of new basket dicts. Baskets are
        shallow-copied, so the dicts passed in are not modified; values other
        than the two time fields are shared with the input.
    """
    return [
        [
            # Shallow copy keeps the caller's baskets intact and preserves key order.
            {**basket, 'time_since_start': 1, 'time_since_last_event': 0}
            for basket in user
        ]
        for user in data
    ]

In [29]:
# Configuration for the "without time" variant: reads the preprocessed
# instacart splits and writes to a sibling directory with a '_wo_time' suffix.
name = 'instacart_preprocessed'
path = f'tcmbn_data/{name}/split_1'
types_of_split = ['train', 'dev', 'test']
new_path = f'tcmbn_data/{name}_wo_time/split_1'

In [30]:
# For every split: load it from `path`, overwrite the time fields of each
# basket via remove_time, and write the result to `new_path`.
for split in types_of_split:
    data = load_pkl(path, split)
    # Only data[split] is replaced; any other keys in the pickle are saved as loaded.
    data[split] = remove_time(data[split])
    save_pkl(data, new_path, split)
    

In [61]:
# Configuration for the "without any time" variant, here on the synthea data.
# NOTE(review): no visible cell writes tcmbn_data/synthea_preprocessed; presumably
# it comes from re-running the preprocessing cells with name = 'synthea' — confirm.
name = 'synthea_preprocessed'
path = f'tcmbn_data/{name}/split_1'
types_of_split = ['train', 'dev', 'test']
new_path = f'tcmbn_data/{name}_wo_all_time/split_1'

In [62]:
# For every split: load it from `path`, set the time fields of each basket to
# constants via remove_all_time, and write the result to `new_path`.
for split in types_of_split:
    data = load_pkl(path, split)
    # Only data[split] is replaced; any other keys in the pickle are saved as loaded.
    data[split] = remove_all_time(data[split])
    save_pkl(data, new_path, split)
    