In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import polars as pl
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import pandas as pd
import pickle
import os
import sqlite3
from scipy.stats import spearmanr
import torch

In [3]:
df = pl.read_parquet('enc2_fast_cat_price.parquet')

df = df.with_columns(
    pl.col('visits').fill_null(0),
    pl.col('adds').fill_null(0),
    pl.col('rms').fill_null(0),
    pl.col('searchs').fill_null(0),
    pl.col('buys').fill_null(0),
    pl.col('kde_visits').fill_null([0] * 11),
    pl.col('kde_adds').fill_null([0] * 11),
    pl.col('kde_rms').fill_null([0] * 11),
    pl.col('kde_searchs').fill_null([0] * 11),
    pl.col('kde_buys').fill_null([0] * 11),
    pl.col('url_visit').fill_null([]),
    pl.col('sku_add').fill_null([]),
    pl.col('sku_rm').fill_null([]),
    pl.col('query_search').fill_null([0] * 64),
    pl.col('sku_buy').fill_null([]),
    pl.col('sku_add_cat').fill_null([]),
    pl.col('sku_rm_cat').fill_null([]),
    pl.col('sku_buy_cat').fill_null([]),
    pl.col('sku_add_price').fill_null([]),
    pl.col('sku_rm_price').fill_null([]),
    pl.col('sku_buy_price').fill_null([])
)

In [4]:
df = df.with_columns(pl.col('sku_buy_price').map_elements(lambda x: 0 if len(x) == 0 else sum(x)/len(x), return_dtype=pl.Float64).alias('price_buy'),
                    pl.col('sku_rm_price').map_elements(lambda x: 0 if len(x) == 0 else sum(x)/len(x), return_dtype=pl.Float64).alias('price_rm'),
                    pl.col('sku_add_price').map_elements(lambda x: 0 if len(x) == 0 else sum(x)/len(x), return_dtype=pl.Float64).alias('price_add'))

In [5]:
points = np.arange(11) / 10
masses = {}
for c in ['kde_buys', 'kde_visits', 'kde_adds', 'kde_rms', 'kde_searchs']:
    alias = 'mass_' + c.split('_')[1]
    density = np.asarray(df[c].to_list())
    mass = np.sum(density * points, axis=1) / np.clip(np.sum(density, axis=1), a_min=1e-8, a_max=None)
    masses[alias] = mass

df_mass = pl.DataFrame(data=masses)

df = pl.concat([df, df_mass], how='horizontal')

In [6]:
df.describe()

statistic,client_id,visits,adds,rms,searchs,buys,kde_visits,kde_adds,kde_rms,kde_searchs,kde_buys,url_visit,sku_add,sku_rm,query_search,sku_buy,sku_add_cat,sku_rm_cat,sku_buy_cat,sku_add_price,sku_rm_price,sku_buy_price,price_buy,price_rm,price_add,mass_buys,mass_visits,mass_adds,mass_rms,mass_searchs
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",499999.5,36.763237,2.707519,1.148122,4.46087,1.096716,,,,,,,,,,,,,,,,,29.047337,15.769935,31.259434,0.263441,0.458623,0.298378,0.160259,0.193171
"""std""",288675.278932,116.908754,9.765367,6.043549,20.112477,2.889363,,,,,,,,,,,,,,,,,33.110316,28.362211,33.331405,0.3203,0.30092,0.330088,0.288542,0.307538
"""min""",0.0,0.0,0.0,0.0,0.0,0.0,,,,,,,,,,,,,,,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""25%""",250000.0,2.0,0.0,0.0,0.0,0.0,,,,,,,,,,,,,,,,,0.0,0.0,0.0,0.0,0.194966,0.0,0.0,0.0
"""50%""",500000.0,11.0,1.0,0.0,0.0,1.0,,,,,,,,,,,,,,,,,9.0,0.0,22.166667,0.068849,0.480763,0.153164,0.0,0.0
"""75%""",749999.0,34.0,2.0,1.0,2.0,1.0,,,,,,,,,,,,,,,,,60.0,23.0,61.0,0.532166,0.718896,0.597072,0.19811,0.375758
"""max""",999999.0,29676.0,1597.0,1145.0,3384.0,644.0,,,,,,,,,,,,,,,,,99.0,99.0,99.0,0.94799,0.94799,0.947987,0.947985,0.947988


In [7]:
clients = np.load('data/input/relevant_clients.npy')
buys = pl.read_parquet('data/product_buy.parquet')

In [8]:
buys = buys.filter(pl.col('client_id').is_in(set(clients)))

In [9]:
df.filter(pl.col('buys') > 0).describe()

statistic,client_id,visits,adds,rms,searchs,buys,kde_visits,kde_adds,kde_rms,kde_searchs,kde_buys,url_visit,sku_add,sku_rm,query_search,sku_buy,sku_add_cat,sku_rm_cat,sku_buy_cat,sku_add_price,sku_rm_price,sku_buy_price,price_buy,price_rm,price_add,mass_buys,mass_visits,mass_adds,mass_rms,mass_searchs
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0,510971.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",499889.265857,50.143593,3.895321,1.653626,6.231007,2.146337,,,,,,,,,,,,,,,,,56.847329,18.016703,36.994913,0.515569,0.399931,0.359835,0.191306,0.213084
"""std""",288606.515506,145.039547,12.821275,8.002104,25.832602,3.753066,,,,,,,,,,,,,,,,,23.772945,29.389455,32.778649,0.266061,0.318315,0.330652,0.307445,0.318119
"""min""",1.0,0.0,0.0,0.0,0.0,1.0,,,,,,,,,,,,,,,,,0.0,0.0,0.0,0.052014,0.0,0.0,0.0,0.0
"""25%""",249976.0,0.0,0.0,0.0,0.0,1.0,,,,,,,,,,,,,,,,,39.0,0.0,0.0,0.283735,0.0,0.0,0.0,0.0
"""50%""",499265.0,17.0,1.0,0.0,0.0,1.0,,,,,,,,,,,,,,,,,59.0,0.0,37.5,0.524185,0.4246,0.333922,0.0,0.0
"""75%""",749780.0,49.0,3.0,1.0,3.0,2.0,,,,,,,,,,,,,,,,,76.0,34.0,66.0,0.746678,0.679975,0.662614,0.380899,0.4672
"""max""",999995.0,29676.0,1597.0,1145.0,3384.0,644.0,,,,,,,,,,,,,,,,,99.0,99.0,99.0,0.94799,0.947938,0.947958,0.947938,0.947978


In [10]:
embs = np.load('analysis/full.npy')

In [11]:
df_embs = pl.from_numpy(embs, schema=['x', 'y'])

In [12]:
df_full = pl.concat([df, df_embs], how="horizontal")

In [13]:
from collections import Counter

df_part = df_full.select(['client_id', 'sku_buy_cat'])

interactions = []
values = []

for x in df_part.iter_rows():
    idx = x[0]
    adds = Counter()
    for v in x[1]:
        interactions.append([idx, v])
        values.append(1)


In [14]:
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Example dummy data
user_item_pairs = interactions
ratings = values

# Convert to tensors
user_ids = torch.tensor([u for u, _ in user_item_pairs], dtype=torch.long)
item_ids = torch.tensor([i for _, i in user_item_pairs], dtype=torch.long)
ratings = torch.tensor(ratings, dtype=torch.float).to(device)

# Hyperparameters
num_users = 1_000_0000 #max(user_ids).item() + 1
num_items = 6443 + 1 #max(item_ids).item() + 1


embedding_dim = 32
lr = 0.001 #to test it waw 0.01
epochs = 100

# Matrix Factorization Model
class MatrixFactorization(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super().__init__()
        self.user_embed = nn.Embedding(num_users, embedding_dim)
        self.item_embed = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_vecs = self.user_embed(user_ids)
        item_vecs = self.item_embed(item_ids)
        preds = (user_vecs * item_vecs).sum(dim=1)  # Dot product
        return preds

model = MatrixFactorization(num_users, num_items, embedding_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

print('Starting...')
# Training loop
idx = np.arange(user_ids.shape[0])

bs = 4000
for epoch in tqdm(range(epochs)):
    model.train()
    np.random.shuffle(idx)
    total_loss = 0
    for i in range(0, idx.shape[0], bs):
        u, i, r = user_ids[idx[i:i+bs]].to(device), item_ids[idx[i:i+bs]].to(device), ratings[idx[i:i+bs]].to(device)
        u_n = torch.randint(0, num_users, size=(bs,), device=device).long()
        i_n = torch.randint(0, num_items, size=(bs,), device=device).long()
        r_n = torch.zeros_like(i_n, device=device)
        u = torch.concat((u, u_n), dim=0)
        i = torch.concat((i, i_n), dim=0)
        r = torch.concat((r, r_n), dim=0)
        preds = model(u, i)
        loss = criterion(preds, r)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(range(0, idx.shape[0], bs)):.4f}")
        
del idx

Starting...


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 10, Loss: 8.9988
Epoch 20, Loss: 2.8767
Epoch 30, Loss: 0.7607
Epoch 40, Loss: 0.2267
Epoch 50, Loss: 0.1190
Epoch 60, Loss: 0.0769
Epoch 70, Loss: 0.0540
Epoch 80, Loss: 0.0401
Epoch 90, Loss: 0.0311
Epoch 100, Loss: 0.0251


In [15]:
model.to('cpu')

MatrixFactorization(
  (user_embed): Embedding(10000000, 32)
  (item_embed): Embedding(6444, 32)
)

In [16]:
embs_c = model.item_embed.weight.detach().numpy()
embs_u = model.user_embed.weight.detach().numpy()

In [17]:
embs = np.load('Models/EMA_OneCycleLarge/embeddings.npy')

In [18]:
from scipy.optimize import linear_sum_assignment


def analysis_1(df_full, embs, sample_size=1_000_000, alternative='two-sided'):
    dfs = df_full.select(['client_id']).to_pandas()
    
    df_1 = dfs.sample(sample_size, random_state=42, replace=True).copy()
    df_2 = dfs.sample(sample_size, random_state=1024, replace=True).copy()
    x = df_1['client_id'].to_list()
    y = df_2['client_id'].to_list()
    s = np.sum(embs[x, :] * embs[y, :], axis=1)
    u = np.sum(embs_u[x, :] * embs_u[y, :], axis=1)
    print(spearmanr(u, s, alternative=alternative))

def analysis(df_full, embs, action='add', sample_size=1_000_000, alternative='two-sided'):
    sku, cat =  f'sku_{action}', f'sku_{action}_cat'
    dfs = df_full.filter(pl.col(f'{action}s') > 0).select(['client_id', sku, cat]).to_pandas()
    print(len(dfs))
    df_1 = dfs.sample(sample_size, random_state=42, replace=True).copy()
    df_2 = dfs.sample(sample_size, random_state=1024, replace=True).copy()
    
    sets_1 = []
    for x in tqdm(df_1.itertuples(), total=len(df_1)):
        x_c = list(x.__getattribute__(cat))
        sets_1.append((x.client_id, x_c))
    
    sets_2 = []
    for x in tqdm(df_2.itertuples(), total=len(df_2)):
        x_c = list(x.__getattribute__(cat))
        sets_2.append((x.client_id, x_c))
    
    sims = set()
    for i in tqdm(range(len(df_1))):
        x_id, x_c = sets_1[i]
        y_id, y_c = sets_2[i]
        if x_id == y_id:
            continue
        if x_id < y_id:
            x_id, y_id = y_id, x_id
        #s = np.mean(np.max(embs_cat[x_c] @ embs_cat[y_c].T, axis=1)) + np.mean(np.max(embs_cat[x_c] @ embs_cat[y_c].T, axis=0))
        mat = embs_c[x_c] @ embs_c[y_c].T
        row_ind, col_ind = linear_sum_assignment(-mat)
        s = mat[row_ind, col_ind].sum()
        sims.add((x_id, y_id, s))
    
    data = list(sims)
    x = [d[0] for d in data]
    y = [d[1] for d in data]
    c = [d[2] for d in data]
    s = np.sum(embs[x, :] * embs[y, :], axis=1)
    u = np.sum(embs_u[x, :] * embs_u[y, :], axis=1)
    print(spearmanr(c, s, alternative=alternative))
    print(spearmanr(u, s, alternative=alternative))
    print(len(data))

In [19]:
analysis_1(df_full, embs)
analysis_1(df_full, embs[:,:1024])
analysis_1(df_full, embs[:, 1024:])

SignificanceResult(statistic=0.0033087039605329935, pvalue=0.0009372723990262921)
SignificanceResult(statistic=0.007793345961656196, pvalue=6.5201536791437706e-15)
SignificanceResult(statistic=-0.007001227301745633, pvalue=2.5359038524982805e-12)


In [20]:
analysis(df_full, embs, action='buy')

510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.04461362592320651, pvalue=0.0)
SignificanceResult(statistic=0.011134859173805606, pvalue=8.460108791246006e-29)
999993


In [21]:
print('full')

full


In [22]:
print('buy')
analysis(df_full, embs, action='buy')
print('add')
analysis(df_full, embs, action='add')
print('rm')
analysis(df_full, embs, action='rm')

buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.04461362592320651, pvalue=0.0)
SignificanceResult(statistic=0.011134859173805606, pvalue=8.460108791246006e-29)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.13399380687282447, pvalue=0.0)
SignificanceResult(statistic=0.019801024907761717, pvalue=2.811396716635924e-87)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.1366799524205363, pvalue=0.0)
SignificanceResult(statistic=0.03454076718537825, pvalue=1.381932497687098e-261)
999992


In [23]:
print('1c')

1c


In [24]:
print('buy')
analysis(df_full, embs[:, :1024], action='buy')
print('add')
analysis(df_full, embs[:, :1024], action='add')
print('rm')
analysis(df_full, embs[:, :1024], action='rm')

buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.06498378026698967, pvalue=0.0)
SignificanceResult(statistic=0.01263733103889208, pvalue=1.3069933857907169e-36)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.16655750092284746, pvalue=0.0)
SignificanceResult(statistic=0.022476427348290563, pvalue=6.636663637413625e-112)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.13739591574646817, pvalue=0.0)
SignificanceResult(statistic=0.03495362933969878, pvalue=7.898828040177763e-268)
999992


In [25]:
print('ema')

ema


In [26]:
print('buy')
analysis(df_full, embs[:, 1024:], action='buy')
print('add')
analysis(df_full, embs[:, 1024:], action='add')
print('rm')
analysis(df_full, embs[:, 1024:], action='rm')

buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=-0.02513945110165957, pvalue=1.6727614120039385e-139)
SignificanceResult(statistic=0.0015354770574672243, pvalue=0.12466825647553213)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=-0.009012469739000384, pvalue=2.0120159136757018e-19)
SignificanceResult(statistic=0.004634690191506849, pvalue=3.574634704127889e-06)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.07253475567953865, pvalue=0.0)
SignificanceResult(statistic=0.01893497486113153, pvalue=5.698787275617193e-80)
999992


In [27]:
print('full')
print('buy')
analysis(df_full, embs, action='buy', alternative='greater')
print('add')
analysis(df_full, embs, action='add', alternative='greater')
print('rm')
analysis(df_full, embs, action='rm', alternative='greater')

print('1c')
print('buy')
analysis(df_full, embs[:, :1024], action='buy', alternative='greater')
print('add')
analysis(df_full, embs[:, :1024], action='add', alternative='greater')
print('rm')
analysis(df_full, embs[:, :1024], action='rm', alternative='greater')

print('ema')
print('buy')
analysis(df_full, embs[:, 1024:], action='buy', alternative='greater')
print('add')
analysis(df_full, embs[:, 1024:], action='add', alternative='greater')
print('rm')
analysis(df_full, embs[:, 1024:], action='rm', alternative='greater')

full
buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.04461362592320651, pvalue=0.0)
SignificanceResult(statistic=0.011134859173805606, pvalue=4.230054395623003e-29)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.13399380687282447, pvalue=0.0)
SignificanceResult(statistic=0.019801024907761717, pvalue=1.405698358317962e-87)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.1366799524205363, pvalue=0.0)
SignificanceResult(statistic=0.03454076718537825, pvalue=6.90966248843549e-262)
999992
1c
buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.06498378026698967, pvalue=0.0)
SignificanceResult(statistic=0.01263733103889208, pvalue=6.5349669289535845e-37)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.16655750092284746, pvalue=0.0)
SignificanceResult(statistic=0.022476427348290563, pvalue=3.3183318187068126e-112)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.13739591574646817, pvalue=0.0)
SignificanceResult(statistic=0.03495362933969878, pvalue=3.9494140200888815e-268)
999992
ema
buy
510971


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=-0.02513945110165957, pvalue=1.0)
SignificanceResult(statistic=0.0015354770574672243, pvalue=0.062334128237766066)
999993
add
557093


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=-0.009012469739000384, pvalue=1.0)
SignificanceResult(statistic=0.004634690191506849, pvalue=1.7873173520639445e-06)
999994
rm
289432


  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

SignificanceResult(statistic=0.07253475567953865, pvalue=0.0)
SignificanceResult(statistic=0.01893497486113153, pvalue=2.8493936378085965e-80)
999992
