In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
DATASET = "ml-1m"
# DATASET = "ml-20m"
SAVE_PATH = "/content/drive/MyDrive/Capstone/Dataset"
SEED = 2107

COLD_ITEMS_PROPORTION = 0.3
TEST_WARM_INTERACTIONs_PROPORTION = 0.1
IMAGE_FEATURES = 512

In [3]:
!mkdir {SAVE_PATH}/{DATASET}

In [4]:
!pip install -U sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/171.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m122.9/171.5 kB[0m [31m3.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (

In [5]:
!wget https://files.grouplens.org/datasets/movielens/{DATASET}.zip
# !wget https://datasets.imdbws.com/title.basics.tsv.gz
# !wget https://datasets.imdbws.com/title.crew.tsv.gz
# !wget https://datasets.imdbws.com/title.principals.tsv.gz
# !wget https://datasets.imdbws.com/name.basics.tsv.gz

--2024-05-14 02:32:05--  https://files.grouplens.org/datasets/movielens/ml-1m.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5917549 (5.6M) [application/zip]
Saving to: ‘ml-1m.zip’


2024-05-14 02:32:06 (11.9 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]



In [6]:
!unzip /content/{DATASET}.zip
# !gzip -d title.basics.tsv.gz
# !gzip -d title.crew.tsv.gz
# !gzip -d title.principals.tsv.gz
# !gzip -d name.basics.tsv.gz

Archive:  /content/ml-1m.zip
   creating: ml-1m/
  inflating: ml-1m/movies.dat        
  inflating: ml-1m/ratings.dat       
  inflating: ml-1m/README            
  inflating: ml-1m/users.dat         


In [7]:
import os
import gc
import re
import gzip
import json
import numpy as np
import pandas as pd
import torch
import requests
import urllib
import random
from tqdm import tqdm
from PIL import Image
from io import BytesIO
from bs4 import BeautifulSoup
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from sentence_transformers import SentenceTransformer
from torchvision.models import efficientnet_v2_m
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
TEXT_FEATURES = 384
IMAGE_FEATURES = 4096
TEXT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

#  Process Raw Data

In [9]:
df_movies = pd.read_csv(f"{DATASET}/movies.dat", sep="::", encoding='latin-1', header=None)
df_ratings = pd.read_csv(f"{DATASET}/ratings.dat", sep="::", header=None)
df_users = pd.read_csv(f"{DATASET}/users.dat", sep="::", header=None)

  df_movies = pd.read_csv(f"{DATASET}/movies.dat", sep="::", encoding='latin-1', header=None)
  df_ratings = pd.read_csv(f"{DATASET}/ratings.dat", sep="::", header=None)
  df_users = pd.read_csv(f"{DATASET}/users.dat", sep="::", header=None)


In [10]:
df_movies

Unnamed: 0,0,1,2
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy
...,...,...,...
3878,3948,Meet the Parents (2000),Comedy
3879,3949,Requiem for a Dream (2000),Drama
3880,3950,Tigerland (2000),Drama
3881,3951,Two Family House (2000),Drama


In [11]:
df_users

Unnamed: 0,0,1,2,3,4
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,02460
4,5,M,25,20,55455
...,...,...,...,...,...
6035,6036,F,25,15,32603
6036,6037,F,45,1,76006
6037,6038,F,56,1,14706
6038,6039,F,45,0,01060


In [12]:
df_ratings

Unnamed: 0,0,1,2,3
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291
...,...,...,...,...
1000204,6040,1091,1,956716541
1000205,6040,1094,5,956704887
1000206,6040,562,5,956704746
1000207,6040,1096,4,956715648


In [13]:
def combine_text(row):
    context = ""
    context += row[1] + ' ' + row[2]
    context = context.strip()
    context = re.sub(r'\s+', ' ', context)
    return context

In [14]:
n_items = len(df_movies)
items_id = {}
for id, row in df_movies.iterrows():
    items_id[row[0]] = id

contexts = [""] * n_items
for id, row in df_movies.iterrows():
    data = row.tolist()
    try:
        contexts[items_id[data[0]]] = combine_text(data)
    except:
        print(data)

print(n_items)
print(len(items_id))

3883
3883


In [15]:
n_users = len(df_users)
users_id = {}
for id, row in df_users.iterrows():
    users_id[row[0]] = id

print(n_users)
print(len(users_id))

6040
6040


In [16]:
items_users_set = {}
users_items_set = {}

users_cnt = [0] * n_users
interactions = []

error_cnt = 0
for id, row in df_ratings.iterrows():
    reviewer_id, prod_id, overall, timestamp = row.tolist()

    try:
        user_id = users_id[reviewer_id]
        item_id = items_id[prod_id]

        if item_id not in items_users_set:
            items_users_set[item_id] = set()
        items_users_set[item_id].add(user_id)

        if user_id not in users_items_set:
            users_items_set[user_id] = set()
        users_items_set[user_id].add(item_id)

        interactions.append((user_id, item_id))

        users_cnt[user_id] += 1
    except:
        error_cnt += 1

print(len(interactions))
print(n_users, n_items)

1000209
6040 3883


In [17]:
print(error_cnt)

0


## Filter user have greater or equal to 5 interactions

In [18]:
users_cnt = np.array(users_cnt)
print(users_cnt)

[ 53 129  51 ...  20 123 341]


In [19]:
print(np.sum(users_cnt))

1000209


In [20]:
users_greater_than_5_cnt = np.where(users_cnt >= 5)[0]
set_users_greater_than_5_cnt = set(users_greater_than_5_cnt)
print(np.sum(users_cnt[users_greater_than_5_cnt]))
print(len(users_greater_than_5_cnt))
print(len(set_users_greater_than_5_cnt))

1000209
6040
6040


# Split Data

In [21]:
warm_items, cold_items = train_test_split(np.arange(n_items), test_size=COLD_ITEMS_PROPORTION, random_state=SEED, shuffle=True)
val_cold_items, test_cold_items = train_test_split(cold_items, test_size=0.6, random_state=SEED, shuffle=True)

set_warm_items, set_cold_items = set(warm_items), set(cold_items)
set_val_cold_items, set_test_cold_items = set(val_cold_items), set(test_cold_items)

print(len(set_warm_items), len(set_val_cold_items), len(set_test_cold_items))
assert len(set_warm_items) + len(set_val_cold_items) + len(set_test_cold_items) == n_items

2718 466 699


In [22]:
print(type(set_cold_items))
print(set_cold_items)

print(type(set_warm_items))
print(set_warm_items)

assert len(set_cold_items.intersection(set_warm_items)) == 0

<class 'set'>
{2050, 3, 2052, 2053, 5, 6, 2056, 2057, 2058, 2060, 2061, 14, 13, 16, 17, 2064, 21, 22, 2069, 25, 27, 37, 38, 2087, 2088, 43, 2095, 2097, 52, 2100, 2102, 2106, 2108, 61, 63, 2112, 2113, 64, 2116, 69, 2118, 73, 75, 76, 2125, 2127, 82, 2130, 2136, 88, 90, 91, 2139, 2141, 100, 2151, 103, 2152, 104, 107, 2156, 2158, 112, 2161, 2163, 116, 2164, 118, 2169, 122, 125, 2173, 2175, 2176, 2180, 134, 135, 2184, 141, 2190, 145, 2194, 146, 149, 2198, 2200, 153, 154, 152, 157, 158, 2205, 2208, 2206, 2211, 164, 165, 2215, 171, 2219, 172, 174, 2223, 177, 2226, 2230, 184, 185, 2233, 2235, 188, 187, 186, 2239, 2237, 189, 194, 2247, 2251, 204, 2252, 2254, 2256, 209, 2258, 210, 212, 213, 214, 215, 216, 2265, 2257, 219, 2268, 2267, 222, 2271, 2272, 2273, 2270, 224, 223, 226, 2279, 2280, 233, 2281, 232, 236, 237, 2285, 2287, 2288, 241, 2290, 242, 2293, 245, 248, 2297, 253, 254, 2303, 256, 257, 2302, 2304, 2309, 2310, 263, 265, 266, 268, 269, 270, 2316, 274, 2322, 2325, 2326, 281, 2329, 283, 168

In [23]:
warm_interactions = [interaction for interaction in interactions if interaction[1] in set_warm_items]
train_warm_interactions, val_warm_interactions = train_test_split(warm_interactions, test_size=TEST_WARM_INTERACTIONs_PROPORTION, random_state=SEED, shuffle=True)
val_warm_interactions, test_warm_interactions = train_test_split(val_warm_interactions, test_size=0.5, random_state=SEED, shuffle=True)
print(len(train_warm_interactions), len(val_warm_interactions), len(test_warm_interactions))

val_cold_interactions = [interaction for interaction in interactions if interaction[1] in set_val_cold_items]
test_cold_interactions = [interaction for interaction in interactions if interaction[1] in set_test_cold_items]
print(len(val_cold_interactions), len(test_cold_interactions))

val_interactions = val_warm_interactions + val_cold_interactions
test_interactions = test_warm_interactions + test_cold_interactions
print(len(val_interactions), len(test_interactions))

assert len(val_cold_interactions) + len(test_cold_interactions) + len(warm_interactions) == len(interactions)

641713 35651 35651
102753 184441
138404 220092


# Item Features Vector


In [24]:
each_movie_genre = list(map(lambda x: x.split('|'), df_movies[2]))
movie_genres = set(sum(each_movie_genre, []))
movie_genres_id = {movie_genre: id for id, movie_genre in enumerate(movie_genres)}
print(len(movie_genres))
print(movie_genres_id)

18
{'Fantasy': 0, 'Horror': 1, 'Mystery': 2, "Children's": 3, 'Sci-Fi': 4, 'Adventure': 5, 'Comedy': 6, 'Drama': 7, 'Action': 8, 'Documentary': 9, 'Western': 10, 'Crime': 11, 'Thriller': 12, 'Musical': 13, 'Romance': 14, 'Animation': 15, 'Film-Noir': 16, 'War': 17}


In [25]:
onehot_features = np.full((n_items, len(movie_genres)), len(movie_genres))
print(onehot_features.shape)
for id, row in df_movies.iterrows():
    movie_id, genres = items_id[row[0]], row[2].split('|')
    for genre in genres:
        onehot_features[movie_id, movie_genres_id[genre]] = movie_genres_id[genre]
print(onehot_features[:10])

(3883, 18)
[[18 18 18  3 18 18  6 18 18 18 18 18 18 18 18 15 18 18]
 [ 0 18 18  3 18  5 18 18 18 18 18 18 18 18 18 18 18 18]
 [18 18 18 18 18 18  6 18 18 18 18 18 18 18 14 18 18 18]
 [18 18 18 18 18 18  6  7 18 18 18 18 18 18 18 18 18 18]
 [18 18 18 18 18 18  6 18 18 18 18 18 18 18 18 18 18 18]
 [18 18 18 18 18 18 18 18  8 18 18 11 12 18 18 18 18 18]
 [18 18 18 18 18 18  6 18 18 18 18 18 18 18 14 18 18 18]
 [18 18 18  3 18  5 18 18 18 18 18 18 18 18 18 18 18 18]
 [18 18 18 18 18 18 18 18  8 18 18 18 18 18 18 18 18 18]
 [18 18 18 18 18  5 18 18  8 18 18 18 12 18 18 18 18 18]]


# Find negative user and negative

## Negative Items

In [26]:
for item_id in range(n_items):
    if item_id not in items_users_set:
        items_users_set[item_id] = set()

In [27]:
negative_items = {}
for item_id in range(n_items):
    for neg_item_id in range(n_items):
        if item_id == neg_item_id or len(items_users_set[item_id].intersection(items_users_set[neg_item_id])) > 0:
            continue
        if item_id not in negative_items:
            negative_items[item_id] = set()
        negative_items[item_id].add(neg_item_id)

In [28]:
mn = 1e9
for item_id in range(n_items):
    mn = min(mn, len(negative_items[item_id]))
print(mn)
for id in range(10):
    print(id, len(negative_items[id]))

223
0 310
1 363
2 370
3 507
4 441
5 339
6 421
7 795
8 847
9 356


In [29]:
print(negative_items[0])

{50, 2130, 2144, 2145, 2147, 2148, 2149, 2151, 2153, 2155, 107, 2156, 2157, 2159, 2160, 113, 2161, 2166, 125, 128, 131, 136, 137, 140, 141, 2201, 2205, 2208, 2230, 2239, 2250, 223, 281, 282, 283, 2369, 2411, 2415, 2420, 391, 2439, 395, 396, 397, 398, 399, 2478, 2495, 2515, 2519, 2522, 2526, 2532, 2534, 2535, 2539, 526, 541, 2611, 2615, 572, 575, 578, 580, 2629, 597, 600, 603, 616, 618, 619, 620, 624, 2673, 631, 637, 638, 639, 645, 648, 649, 652, 654, 660, 666, 669, 670, 672, 676, 677, 684, 690, 692, 704, 2752, 712, 714, 2763, 718, 721, 2769, 729, 730, 739, 742, 748, 753, 758, 760, 762, 763, 767, 779, 781, 782, 784, 785, 787, 2841, 802, 806, 808, 811, 814, 815, 816, 832, 834, 2885, 2888, 2889, 844, 845, 846, 848, 857, 860, 861, 862, 2911, 867, 878, 882, 883, 2940, 2954, 2990, 963, 3011, 967, 3015, 971, 977, 988, 3054, 3057, 1026, 1032, 1038, 1048, 1051, 3101, 3103, 1056, 1059, 3122, 3124, 3126, 3133, 1090, 1092, 1093, 1094, 3140, 3143, 1099, 3147, 1102, 3151, 1106, 3157, 3158, 3160, 316

## Negative User for Training Interactions

In [30]:
set_all_users = set(range(n_users))
negative_users = {}
for item_id in range(n_items):
    negative_users[item_id] = set_all_users.difference(items_users_set[item_id])
for item_id in range(n_items):
    negative_users[item_id] = list(negative_users[item_id])

In [31]:
warm_interactions_negative_users = []
for id, interaction in enumerate(warm_interactions):
    user_id = interaction[0]
    item_id = interaction[1]

    random.seed(SEED + id)
    neg_user_id = random.choice(negative_users[item_id])

    warm_interactions_negative_users.append([user_id, item_id, neg_user_id])

print(len(warm_interactions_negative_users))
print(warm_interactions_negative_users[:10])

713015
[[0, 1176, 366], [0, 655, 2342], [0, 902, 3616], [0, 3339, 1203], [0, 2286, 3960], [0, 1267, 853], [0, 2735, 102], [0, 590, 2300], [0, 907, 3994], [0, 591, 5673]]


In [32]:
print(len(warm_interactions))
print(len(warm_interactions_negative_users))
print(random.choice(warm_interactions_negative_users))

713015
713015
[2063, 1224, 5394]


# Extract Image Features Vector

In [33]:
img_features_all_movies = np.load(os.path.join(SAVE_PATH, "img_features_all_movies.npy"), allow_pickle=True)
print(img_features_all_movies.shape)

(24425, 4098)


In [34]:
IMAGE_FEATURES = img_features_all_movies.shape[1] - 2
print(IMAGE_FEATURES)

4096


In [35]:
img_features = np.zeros((n_items, IMAGE_FEATURES))
cnt = 0
for row in img_features_all_movies:
    if row[0] not in items_id:
        continue
    if row[0] == 1054:
        print(row[1])
    img_features[items_id[row[0]]] = row[2:]
    cnt += 1
print(img_features.shape)

49471.0
(3883, 4096)


In [36]:
df_movies[df_movies[0] == 1054]

Unnamed: 0,0,1,2
1040,1054,Get on the Bus (1996),Drama


In [37]:
df_users[df_users[0] == 1292]

Unnamed: 0,0,1,2,3,4
1291,1292,M,25,7,94904


# Extract Text Features Vector


In [38]:
BATCH_SIZE = 128
model = SentenceTransformer(TEXT_MODEL_NAME, device=DEVICE)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [39]:
text_features = np.zeros((n_items, TEXT_FEATURES))
for i in tqdm(range(0, n_items, BATCH_SIZE)):
    s, e = i, min(n_items, i + BATCH_SIZE)
    text_features[s:e] = model.encode(contexts[s:e], batch_size=e-s)

100%|██████████| 31/31 [00:02<00:00, 15.28it/s]


In [40]:
print(contexts[0:10])
print(text_features[0:10])

["Toy Story (1995) Animation|Children's|Comedy", "Jumanji (1995) Adventure|Children's|Fantasy", 'Grumpier Old Men (1995) Comedy|Romance', 'Waiting to Exhale (1995) Comedy|Drama', 'Father of the Bride Part II (1995) Comedy', 'Heat (1995) Action|Crime|Thriller', 'Sabrina (1995) Comedy|Romance', "Tom and Huck (1995) Adventure|Children's", 'Sudden Death (1995) Action', 'GoldenEye (1995) Action|Adventure|Thriller']
[[-0.08772578  0.03191983  0.0387254  ...  0.00568977  0.0635334
   0.08347662]
 [-0.05130553  0.1168138   0.0357864  ...  0.01607435 -0.12300505
   0.02443848]
 [-0.09299542 -0.01534201 -0.02820371 ... -0.00981735  0.05475814
  -0.02729915]
 ...
 [-0.15326194  0.05590968 -0.00471486 ...  0.08766425 -0.03980549
   0.05817764]
 [-0.06021889  0.01161089  0.00457912 ... -0.00667278 -0.01517241
   0.05044143]
 [-0.04592343 -0.01904068  0.01369785 ... -0.0037353  -0.05472192
   0.02981756]]


In [41]:
x = 1
y = 5
sim = np.dot(text_features[x], text_features[y]) / (np.linalg.norm(text_features[x]) * np.linalg.norm(text_features[y]))
print(contexts[x])
print(contexts[y])
print(sim)

Jumanji (1995) Adventure|Children's|Fantasy
Heat (1995) Action|Crime|Thriller
0.3167491067873524


# Save Data

In [42]:
metadata = {}
metadata["n_users"] = n_users
metadata["n_items"] = n_items
metadata["n_text_features"] = TEXT_FEATURES
metadata["n_image_features"] = IMAGE_FEATURES

metadata["n_warm_items"] = len(warm_items)
metadata["n_cold_items"] = len(cold_items)
metadata["n_val_cold_items"] = len(val_cold_items)
metadata["n_test_cold_items"] = len(test_cold_items)

metadata["n_train_interactions"] = len(train_warm_interactions)
metadata["n_val_interactions"] = len(val_interactions)
metadata["n_test_interactions"] = len(test_interactions)

metadata["n_val_warm_interactions"] = len(val_warm_interactions)
metadata["n_val_cold_interactions"] = len(val_cold_interactions)
metadata["n_test_warm_interactions"] = len(test_warm_interactions)
metadata["n_test_cold_interactions"] = len(test_cold_interactions)

np.save(os.path.join(SAVE_PATH, DATASET, "metadata.npy"), metadata)
print(metadata)

{'n_users': 6040, 'n_items': 3883, 'n_text_features': 384, 'n_image_features': 4096, 'n_warm_items': 2718, 'n_cold_items': 1165, 'n_val_cold_items': 466, 'n_test_cold_items': 699, 'n_train_interactions': 641713, 'n_val_interactions': 138404, 'n_test_interactions': 220092, 'n_val_warm_interactions': 35651, 'n_val_cold_interactions': 102753, 'n_test_warm_interactions': 35651, 'n_test_cold_interactions': 184441}


In [43]:
np.save(os.path.join(SAVE_PATH, DATASET, "train_all_warm_interactions.npy"), warm_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "train_interactions.npy"), train_warm_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "train_all_warm_interactions_negative_users.npy"), warm_interactions_negative_users)

np.save(os.path.join(SAVE_PATH, DATASET, "val_interactions.npy"), val_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "val_warm_interactions.npy"), val_warm_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "val_cold_interactions.npy"), val_cold_interactions)

np.save(os.path.join(SAVE_PATH, DATASET, "test_interactions.npy"), test_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "test_warm_interactions.npy"), test_warm_interactions)
np.save(os.path.join(SAVE_PATH, DATASET, "test_cold_interactions.npy"), test_cold_interactions)

np.save(os.path.join(SAVE_PATH, DATASET, "warm_items.npy"), set_warm_items)
np.save(os.path.join(SAVE_PATH, DATASET, "cold_items.npy"), set_cold_items)
np.save(os.path.join(SAVE_PATH, DATASET, "val_cold_items.npy"), set_val_cold_items)
np.save(os.path.join(SAVE_PATH, DATASET, "test_cold_items.npy"), set_test_cold_items)

In [44]:
np.save(os.path.join(SAVE_PATH, DATASET, "negative_items.npy"), negative_items)

In [45]:
np.save(os.path.join(SAVE_PATH, DATASET, "onehot_features.npy"), onehot_features)

In [46]:
np.save(os.path.join(SAVE_PATH, DATASET, "t_features.npy"), text_features)

In [47]:
np.save(os.path.join(SAVE_PATH, DATASET, "v_features.npy"), img_features)

In [48]:
test = np.load(os.path.join(SAVE_PATH, DATASET, "train_all_warm_interactions.npy"), allow_pickle=True)
print(test.shape)
assert test.shape[0] == len(train_warm_interactions) + len(val_warm_interactions) + len(test_warm_interactions)

(713015, 2)
