In [79]:
import pandas as pd
import os.path as osp
import numpy as np
from sklearn.metrics import pairwise_distances
import matplotlib.image as mpimg
import matplotlib.pyplot as plt


In [80]:
# Input locations for this analysis. Both pipeline-output paths are derived from
# `category`, so switching categories means editing that constant plus the run ids
# of the matching pipeline outputs, instead of hunting for the string in each path.
data_dir = "/data/amazon_dataset/"  # NOTE(review): absolute, machine-specific path; not read by the cells shown here
category = "Clothing_Shoes_and_Jewelry"
outputs_dir = "../outputs"

# Timestamped run of the label-processing step.
label_run_id = "20211019_120707"
df_w_labels_path = osp.join(
    outputs_dir, f"process_label_{category}_{label_run_id}", "df_w_labels.pkl"
)

# Timestamped run of the recommender-training step.
recommender_run_id = "20211019_092258"
base_path = osp.join(outputs_dir, f"train_recommender_{category}_{recommender_run_id}")
cf_vector_path = osp.join(base_path, "cf_df.pkl")


In [81]:
# Collaborative-filtering vectors produced by the recommender-training run.
# NOTE: unpickling can execute arbitrary code — only point these paths at trusted, locally produced outputs.
cf_df = pd.read_pickle(cf_vector_path)

# Item labels produced by the label-processing run.
label_df = pd.read_pickle(df_w_labels_path)

# Single working frame: only items present in BOTH inputs survive (inner join on "asin").
df = label_df.merge(cf_df, on=["asin"], how="inner")


In [82]:
# Full item-by-item distance matrix over the CF embeddings (sklearn's default metric).
# NOTE(review): memory grows quadratically with len(df).
cf_vectors = np.asarray(df["embs"].tolist())
dists = pairwise_distances(cf_vectors)

# Push self-distances to +inf so a row-wise argmin never returns the item itself.
dists[np.diag_indices_from(dists)] = np.inf


In [103]:
def most_frequent(List):
    """Return the most common element of ``List``.

    Counts every element in a single pass (O(n)) instead of calling
    ``list.count`` once per distinct value (O(n * k)), which is slow when
    ``List`` holds every label in the dataset. Ties are broken in favour of
    the value encountered first, so the result is deterministic.

    Raises:
        ValueError: if ``List`` is empty (same exception type as before).
    """
    from collections import Counter

    counts = Counter(List)
    if not counts:
        raise ValueError("most_frequent() arg is an empty sequence")
    return counts.most_common(1)[0][0]

import itertools
# Flatten every item's label list into one list, then keep the single most common
# label (wrapped in a list) as the "always predict the most popular label" baseline.
labels = list(itertools.chain.from_iterable(df["merged_labels"]))
top_label = most_frequent(labels)
label_most_freq = [top_label]
print(label_most_freq)


['Clothing, Shoes & Jewelry + Women']


In [104]:
def calc_num_intersection(lst1, lst2):
    """Return how many distinct values appear in both ``lst1`` and ``lst2``.

    Duplicates within either input are ignored because both sides are reduced to sets.
    """
    common = set(lst1) & set(lst2)
    return len(common)


# Check how much each item's nearest CF neighbour shares its labels.
# Overlap is scored with the Dice coefficient 2*|A ∩ B| / (|A| + |B|), computed
# (a) against the nearest neighbour in CF space and (b) against a constant
# baseline that always predicts the single most frequent label.
# Items without labels score 0 on both and are tallied in `no_label`.


def dice_score(lst1, lst2):
    """Dice overlap of two label lists; eps keeps the ratio defined when both are empty."""
    num_inter = calc_num_intersection(lst1, lst2)
    num_total = len(lst1) + len(lst2) + np.finfo("float").eps
    return 2 * num_inter / num_total


n_items = len(dists)
assert n_items > 0, "no items to evaluate"
assert n_items == len(df), "distance matrix rows must align with df rows"

# One vectorised pass instead of an argmin per loop iteration.
nearest_idxs = dists.argmin(axis=1)
# Read the label column directly rather than materialising a whole row Series per item.
label_col = df["merged_labels"]

metric_accumulate, metric_accumulate_baseline = 0, 0
no_label = 0
for anchor_idx, pos_idx in enumerate(nearest_idxs):
    label1 = label_col.iat[anchor_idx]
    label2 = label_col.iat[pos_idx]

    metric_accumulate += dice_score(label1, label2)
    metric_accumulate_baseline += dice_score(label1, label_most_freq)

    if len(label1) == 0:
        no_label += 1


print(
    "2*intersection/All [CF Most_pop]=[{:.2f} {:.2f}] {}".format(
        metric_accumulate / n_items, metric_accumulate_baseline / n_items, no_label
    )
)


2*intersection/All [CF Most_pop]=[0.37 0.27] 394


# Visualize Anchor and Positive

In [None]:
# Pick anchors to visualise. A fixed seed keeps the figure reproducible across
# Restart & Run All, and sampling without replacement avoids plotting the same
# anchor twice (the count is clamped so tiny frames still work).
num_anchors = 10
anchor_seed = 0
anchor_rng = np.random.default_rng(anchor_seed)
anchor_idxs = anchor_rng.choice(len(df), size=min(num_anchors, len(df)), replace=False)


In [None]:
def format_label(merged_labels):
    """Turn a list of label strings into a multi-line plot title.

    Each occurrence of "Clothing, Shoes & Jewelry +" in a label is replaced by
    that label's 0-based position followed by a dot. Every entry ends with a
    newline and entries are joined with another newline, so consecutive labels
    are separated by a blank line.
    """
    prefix = "Clothing, Shoes & Jewelry +"
    lines = []
    for idx, raw in enumerate(merged_labels):
        lines.append(raw.replace(prefix, f"{idx}.") + "\n")
    return "\n".join(lines)


def _show_image(ax, img, title):
    """Draw ``img`` on ``ax`` with ``title`` and hide both axis ticks."""
    ax.imshow(img)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])


# Output images: one row per anchor. The anchor is on the left and its nearest CF neighbour on the right.
# squeeze=False keeps `axs` 2-D even when a single anchor is plotted, so axs[i, 0] is always valid.
fig, axs = plt.subplots(
    len(anchor_idxs),
    2,
    figsize=(20, 4 * len(anchor_idxs)),
    facecolor="white",
    squeeze=False,
)
for i, anchor_idx in enumerate(anchor_idxs):
    anchor_row = df.iloc[anchor_idx]

    pos_idx = dists[anchor_idx].argmin()
    dist = dists[anchor_idx][pos_idx]
    pos_row = df.iloc[pos_idx]

    anchor_ax, pos_ax = axs[i, 0], axs[i, 1]

    _show_image(
        anchor_ax,
        mpimg.imread(anchor_row["img_path"]),
        format_label(anchor_row["merged_labels"]),
    )
    anchor_ax.set_ylabel(f"dist={dist:.2f}")

    _show_image(
        pos_ax,
        mpimg.imread(pos_row["img_path"]),
        format_label(pos_row["merged_labels"]),
    )


plt.tight_layout()
plt.show()
