In [1]:
# Maps each raw archive file name to the short dataset name that is embedded
# in directory names elsewhere in this notebook (e.g. finetune_data/<name>_ours).
data_map = dict([
    ('Industrial_and_Scientific.json.gz', 'Scientific'),
    ('Musical_Instruments.json.gz', 'Instruments'),
    ('Prime_Pantry.json.gz', 'Pantry'),
])
class Namespace:
    """Bare attribute container: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Write straight into the instance dict, one entry per keyword.
        attributes = vars(self)
        for attr_name, attr_value in kwargs.items():
            attributes[attr_name] = attr_value

# File names of the splits, the item-id map and the item metadata. A
# data_path attribute is attached to args later, before it is passed to
# load_data.
args = Namespace(**{
    "train_file": "train.json",
    "dev_file": "val.json",
    "test_file": "test.json",
    "item2id_file": "smap.json",
    "meta_file": "meta_data.json",
})

In [2]:
import json
from tqdm import tqdm
import gzip
from pathlib import Path

# Archive to process; must be one of the keys of data_map.
data_name = "Prime_Pantry.json.gz"
# Directory of the existing dataset: umap.json is read from it below and,
# presumably, so are the files loaded by load_data (confirm in utils).
args.data_path = "finetune_data/{}_ours".format(data_map[data_name])

In [3]:
from utils import load_data
# load_data lives in the project-local utils module and is not visible here.
# Presumably it reads the files named in args from args.data_path — TODO confirm.
# How the results are used further down: train/val are dicts whose values are
# sequences of item ids, id2item/item2id translate between ids and item keys,
# and item_meta_dict is keyed by item.
train, val, test, item_meta_dict, item2id, id2item = load_data(args)
len(item_meta_dict)

4968

In [4]:
import os
# Load the existing user -> id mapping. The with-block closes the file handle
# deterministically (a bare open() passed to json.load is never closed
# explicitly), and the explicit encoding matches how this notebook writes its
# own JSON files.
with open(os.path.join(args.data_path, "umap.json"), encoding="utf-8") as umap_fp:
    umap = json.load(umap_fp)
len(umap)

14180

In [5]:
# Distinct item keys (looked up through id2item) occurring in the training sessions.
train_item = {
    id2item[item_id]
    for train_session in tqdm(train.values())
    for item_id in train_session
}

# Distinct item keys occurring in the validation sessions.
val_item = {
    id2item[item_id]
    for val_session in tqdm(val.values())
    for item_id in val_session
}

len(train_item), len(val_item), len(id2item)

100%|██████████| 14180/14180 [00:00<00:00, 841126.74it/s]
100%|██████████| 14178/14178 [00:00<?, ?it/s]


(4961, 3977, 4968)

In [6]:
# Every item that appears in the training or the validation sessions.
train_val_item = train_item | val_item
len(train_val_item)

4968

# Raw data: item metadata and per-user review sequences

In [7]:
def extract_meta_data(path):
    """Read a gzipped JSON-lines metadata file into a dict keyed by ``asin``.

    Every line of the file is parsed as one JSON object. Records whose
    ``title`` is empty are skipped; for all others the ``title``, the
    ``brand`` and the space-joined ``category`` entries are kept.

    Args:
        path: Location of the gzip-compressed file (anything ``gzip.open``
            accepts).

    Returns:
        dict mapping ``asin`` to a dict with the keys ``"title"``,
        ``"brand"`` and ``"category"``. If an ``asin`` occurs more than once
        the last record wins.

    Raises:
        KeyError: If a record lacks ``title`` or, for a kept record,
            ``asin``, ``brand`` or ``category``.
    """
    meta_data = dict()
    # JSON text is UTF-8. Without an explicit encoding gzip.open falls back to
    # the locale default, which can mis-decode or fail on non-ASCII text.
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for raw_line in tqdm(f):
            record = json.loads(raw_line)
            title = record["title"]

            # Items without a title are not usable downstream; drop them.
            if len(title) == 0:
                continue

            meta_data[record["asin"]] = {
                "title": title,
                "brand": record["brand"],
                "category": " ".join(record["category"]),
            }

    return meta_data

In [8]:
# Parse the metadata archive that belongs to the selected review file
# (same file name with a "meta_" prefix, inside the data directory).
raw_meta_dict = extract_meta_data(Path('data', f"meta_{data_name}"))
len(raw_meta_dict)

10813it [00:00, 54061.89it/s]


10812

In [9]:
from collections import defaultdict

# Group the raw reviews per user as (item, timestamp) pairs, keeping only
# reviews of items for which metadata was extracted.
raw_sequences = defaultdict(list)

# The with-block guarantees the gzip handle is closed once all lines are read;
# previously it was opened and never closed.
with gzip.open(Path('data') / Path(data_name), 'r') as total:
    for line in tqdm(total):
        data = json.loads(line)
        user = data['reviewerID']
        item = data['asin']
        review_time = data['unixReviewTime']

        # Items without usable metadata cannot be part of a sequence.
        if item not in raw_meta_dict:
            continue
        raw_sequences[user].append((item, review_time))

len(raw_sequences)

471614it [00:02, 157215.94it/s]


247640

In [10]:
# Keep a user's sequence only when every history item is already seen
# while the final (target) item is unseen.

new_item_sequences = dict()

for user, interactions in tqdm(raw_sequences.items()):
    # Items in chronological order (stable sort on the timestamp).
    ordered_items = [asin for asin, _ in sorted(interactions, key=lambda pair: pair[1])]

    # Sequences with fewer than four items are dropped.
    if len(ordered_items) < 4:
        continue

    *history, target = ordered_items

    # The target must not be part of the existing item metadata ...
    if target in item_meta_dict:
        continue

    # ... and every earlier item must occur in the train/val item set.
    if any(past_item not in train_val_item for past_item in history):
        continue

    new_item_sequences[user] = ordered_items

len(new_item_sequences), len(raw_sequences)
    

100%|██████████| 247640/247640 [00:00<00:00, 995313.62it/s]


(1904, 247640)

In [11]:
# TODO
# 1. update item meta dict
# 2. update item2id, id2item
# 3. update train, val, test

# 1. update item meta dict
# Work on a copy so the loaded item_meta_dict stays untouched.
new_item_meta_dict = item_meta_dict.copy()
print(f"meta dict len before update: {len(new_item_meta_dict)}")
for session in tqdm(new_item_sequences.values()):
    history, target = session[:-1], session[-1]
    # Sanity check: all history items must already have metadata ...
    for past_item in history:
        if past_item not in item_meta_dict:
            raise ValueError(f"{past_item} not in meta dict")
    # ... and the target must be new with respect to the loaded metadata.
    if target in item_meta_dict:
        raise ValueError(f"{target} in meta dict")
    # Take the new item's metadata from the raw metadata file.
    new_item_meta_dict[target] = raw_meta_dict[target]

print(f"meta dict len after update: {len(new_item_meta_dict)}")

meta dict len before update: 4968


100%|██████████| 1904/1904 [00:00<00:00, 589369.36it/s]

meta dict len after update: 6505





In [12]:
# 2. update item2id, id2item
# Note: item2id and id2item are extended in place.
print(f"item2id len before update: {len(item2id)}")

# New ids continue after the existing ones. This presumes the current ids are
# exactly 0..len(item2id)-1; the collision check below guards that assumption.
last_id = len(item2id)
for session in tqdm(new_item_sequences.values()):
    for item in session:
        if item in item2id:
            continue
        if last_id in id2item:
            # Fail loudly: printing and breaking out of the inner loop would
            # leave the remaining items unmapped and the maps inconsistent,
            # and those maps are later written to disk.
            raise ValueError(f"{last_id} already in id2item")
        id2item[last_id] = item
        item2id[item] = last_id
        last_id += 1
print(f"item2id len after update: {len(item2id)}")

item2id len before update: 4968


100%|██████████| 1904/1904 [00:00<?, ?it/s]

item2id len after update: 6505





In [13]:
# 2.5 update umap
print(f"umap before update: {len(umap)}")
# Work on a copy so the loaded umap stays untouched.
new_umap = umap.copy()
# Build the set of already-used ids once; scanning umap.values() for every
# new user made this loop quadratic.
taken_ids = set(umap.values())
last_id = len(umap)
for user in tqdm(new_item_sequences):
    # Users that already have an id keep it; only unknown users get a new one.
    if user not in umap:
        assert last_id not in taken_ids
        new_umap[user] = last_id
        last_id += 1
print(f"umap after update: {len(new_umap)}")
        

umap before update: 14180


100%|██████████| 1904/1904 [00:00<00:00, 11164.73it/s]

umap after update: 15261





In [14]:
# 3. update train, val, test
# Per user: the last item goes to test, the second to last to val and
# everything before that to train. Keys are the ids from new_umap.
train_dict, val_dict, test_dict = dict(), dict(), dict()

for raw_user, session in tqdm(new_item_sequences.items()):
    user_id = new_umap[raw_user]
    # session[:-2] is already empty for sessions shorter than three items,
    # so no separate short-session branch is required.
    train_dict[user_id] = [item2id[asin] for asin in session[:-2]]
    val_dict[user_id] = [item2id[session[-2]]]
    test_dict[user_id] = [item2id[session[-1]]]

len(train_dict), len(val_dict), len(test_dict)

100%|██████████| 1904/1904 [00:00<00:00, 209974.36it/s]


(1904, 1904, 1904)

In [15]:
# Destination directory for the files written in the next cell.
output_path = Path('cold_data') / Path(f"{data_map[data_name]}_cold")
# exist_ok=True makes directory creation idempotent and removes the racy
# check-then-create sequence.
output_path.mkdir(parents=True, exist_ok=True)

# One output file per artefact.
train_file = os.path.join(output_path, "train.json")
dev_file = os.path.join(output_path, "val.json")
test_file = os.path.join(output_path, "test.json")
umap_file = os.path.join(output_path, "umap.json")
smap_file = os.path.join(output_path, "smap.json")
meta_file = os.path.join(output_path, "meta_data.json")

In [16]:
# Write every artefact as UTF-8 JSON. Each file is opened in a with-block so
# it is flushed and closed before the next one is written, instead of relying
# on garbage collection to close the handle.
for payload, target_file in (
    (train_dict, train_file),
    (val_dict, dev_file),
    (test_dict, test_file),
    (new_umap, umap_file),
    (item2id, smap_file),
    (new_item_meta_dict, meta_file),
):
    with open(target_file, "w", encoding='utf-8') as out_fp:
        json.dump(payload, out_fp, indent=1, ensure_ascii=False)