In [7]:
from datetime import datetime
import gc
from pathlib import Path
from typing import Annotated
from warnings import simplefilter
import numpy as np
import pandas as pd
from typing import List
import pickle
import json

from napkinxc.datasets import load_dataset
from napkinxc.models import PLT, HSM, BR, OVR
from napkinxc.measures import precision_at_k

import polars as pl
from scipy.sparse import csr_matrix
import scipy.sparse as sp

In [8]:
def mean_binary_cross_entropy(predictuion, target):
    """Mean (over rows) of the row-wise summed binary cross-entropy.

    Args:
        predictuion: 2-D array-like of predicted probabilities
            (rows are reduced with ``sum(axis=1)``).
        target: array-like of the same shape holding the ground truth
            (the formula weights the two log terms by ``target`` and
            ``1 - target``).

    Returns:
        The per-row cross-entropy sums averaged over all rows.

    Raises:
        AssertionError: if the two inputs do not have the same shape.
    """
    assert predictuion.shape == target.shape

    # Probabilities are clipped away from 0 and 1 so the logarithm stays finite.
    eps = 1e-8
    log_positive = np.log(np.clip(predictuion, eps, 1 - eps))
    log_negative = np.log(np.clip(1 - predictuion, eps, 1 - eps))

    per_cell = -log_positive * target - log_negative * (1 - target)
    return per_cell.sum(axis=1).mean()

In [15]:
    %%time
    hexses_target_path = "data/hexses_target.lst"
    hexses_data_path = "data/hexses_data.lst"
    with open(hexses_target_path, "r") as f:
        hexses_target = [x.strip() for x in f.readlines()]       
    with open(hexses_data_path, "r") as f:
        hexses_data = [x.strip() for x in f.readlines()]
        
    targ2idx = {vac_id: idx for idx, vac_id in enumerate(hexses_target)}
    idx2targ = {idx: vac_id for vac_id, idx in targ2idx.items()}
        

    # Both transaction tables get the same treatment, so the grouping keys and
    # the aggregation expressions are spelled out once and reused.
    group_keys = ("customer_id", "h3_09", "datetime_id", "mcc_code")
    merged_stats = (
        pl.col("count").sum().alias("count"),
        pl.col("sum").sum(),
        pl.col("min").min(),
        pl.col("max").max(),
        pl.col("count_distinct").max(),
    )
    # "hmd" = hex id concatenated with the datetime bucket id.
    hmd_column = pl.concat_str([pl.col("h3_09"), pl.col("datetime_id")]).alias("hmd")

    trans_train = pl.read_parquet("data/trans_train.parquet")
    trans_test = pl.read_parquet("data/trans_test.parquet")

    # Merge rows that share customer, hex, datetime bucket and MCC code.
    trans_train = trans_train.group_by(*group_keys).agg(*merged_stats)
    trans_test = trans_test.group_by(*group_keys).agg(*merged_stats)

    trans_train = trans_train.with_columns(hmd_column)
    trans_test = trans_test.with_columns(hmd_column)

    # hmd values that occur in both the train and the test transactions.
    train_hmds = set(trans_train["hmd"].unique())
    test_hmds = set(trans_test["hmd"].unique())
    hmds = train_hmds & test_hmds

    ###################################################### tt train 
    # Collapse the aggregated train transactions to one row per
    # (customer_id, hmd) pair. "count" is the literal 1, so the feature is a
    # presence flag rather than a transaction count.
    tt = trans_train.group_by("customer_id","hmd"
                                    ).agg( pl.lit(1).alias("count") 
                                         )
    # Disabled alternative: divide "count" by the customer's total count.
    #tt = tt.with_columns( ( pl.col("count")/pl.col("count").sum().over( pl.col("customer_id") ) ).alias("count") )#.explode("h3_09")
    # Keep only hmd values present in `hmds` (the train/test intersection built earlier in this cell).
    tt = tt.filter( pl.col("hmd").is_in(hmds) )
   
    # Feature vocabulary: every remaining hmd value gets a column index.
    # NOTE(review): "hmd" is produced by concat_str, so explode() is expected to
    # be a no-op on this column — confirm for the installed polars version.
    unique_hmd1 = tt['hmd'].explode().unique().to_list()   
    hmd2idx1 = {vac_id: idx for idx, vac_id in enumerate(unique_hmd1)}
    idx2hmd1 = {idx: vac_id for vac_id, idx in hmd2idx1.items()}    
    
    # Row vocabulary: every remaining train customer gets a row index.
    unique_cu = tt['customer_id'].unique().to_list()
    cu2idx = {user_id: idx for idx, user_id in enumerate(unique_cu)}
    idx2cu = {idx: vac_id for vac_id, idx in cu2idx.items()}
    
    # Translate the (customer, hmd, count) triples into COO-style index/value arrays.
    # NOTE(review): this relies on Series.replace returning the mapped integer
    # indices; the dtype handling of replace differs between polars releases — confirm.
    pairs = tt.select(['customer_id', 'hmd', 'count'])
    users = pairs['customer_id'].replace(cu2idx).to_numpy()
    vacancies = pairs['hmd'].replace(hmd2idx1).to_numpy()
    preferences = pairs['count'].cast(pl.Float32).to_numpy()

    # Sparse customer x hmd feature matrix for training (float32).
    cutran_mat_1train = csr_matrix((preferences, (users, vacancies)), 
                              shape=(len(unique_cu), len(hmd2idx1)), 
                              dtype=np.float32)
    # Training targets: hex ids mapped through targ2idx, customer ids through cu2idx.
    targe = pl.read_parquet("data/targe_train.parquet")
    targ = targe.with_columns( pl.col("h3_09").replace(targ2idx).cast(pl.Int16), pl.col("customer_id").replace(cu2idx).cast(pl.Int32))
    # One list of unique label indices per customer, ordered by the remapped customer id.
    # NOTE(review): targ[i] lines up with matrix row i only if the target file
    # contains exactly the customers that are in cu2idx — confirm.
    targ = targ.group_by("customer_id").agg( pl.col("h3_09").unique() ).sort("customer_id")["h3_09"].to_list()
    

########################################################################################################

    # Create the napkinxc PLT model. The string is its first constructor
    # argument — presumably the directory where the model is stored (TODO confirm).
    # NOTE(review): the name `plt` is conventionally the matplotlib.pyplot alias;
    # here it is the model object and later code in this cell depends on it.
    plt = PLT("eurlex-model2")

    # Fit on the train feature matrix and the per-customer label lists.
    plt.fit(cutran_mat_1train, targ, )
    # Predict up to 250 labels per customer; precision is evaluated with k=2.
    Y_pred = plt.predict(cutran_mat_1train, top_k=250)
    print("precision_at_k на train ",precision_at_k(targ, Y_pred, k=2))    

    # ---- MBCE on the train set ----
    # predict_proba is consumed below as, per customer row, an iterable of
    # (label index, probability) pairs. Only the 250 most probable labels are
    # requested, so every other label is 0 in the dense frame and ends up
    # clipped to eps inside mean_binary_cross_entropy.
    # NOTE(review): the test-set section of this cell requests top_k=1657, so
    # the train and test MBCE figures are not computed on equally complete rows.
    predi = plt.predict_proba(cutran_mat_1train, top_k=250)
    n_targets = len(idx2targ)

    cus=[]
    h3s=[]
    prs=[]
    for icu, cu in enumerate(predi):
        for h3 in cu:
            cus.append(icu)
            h3s.append(h3[0])
            prs.append(h3[1])
    # The shape is given explicitly: without it scipy infers the size from the
    # largest indices seen, so a never-predicted high label index (or a trailing
    # customer without predictions) would make the frame narrower/shorter than
    # the column list and pd.DataFrame would raise.
    # The former "add missing columns" loop was removed: the columns come from
    # idx2targ, which already holds every id of hexses_target, so it never ran.
    dfpred = pd.DataFrame(
        csr_matrix((prs, (cus, h3s)),
                   shape=(cutran_mat_1train.shape[0], n_targets)).toarray(),
        columns=list(idx2targ.values()),
    )

    # Dense 0/1 ground-truth frame with the same column layout.
    cus=[]
    h3s=[]
    prs=[]
    for icu, cu in enumerate(targ):
        for h3 in cu:
            cus.append(icu)
            h3s.append(h3)
            prs.append(1)
    dftrue = pd.DataFrame(
        csr_matrix((prs, (cus, h3s)), shape=(len(targ), n_targets)).toarray(),
        columns=list(idx2targ.values()),
    )

    print("MBCE на train", mean_binary_cross_entropy(dfpred, dftrue) )
    
    #############################################################################################
    # Same feature construction as for train, now on the test transactions:
    # one row per (customer_id, hmd) pair with "count" fixed to the literal 1.
    tt = trans_test.group_by("customer_id","hmd"
                                    ).agg( pl.lit(1).alias("count") 
                                         )
    # Disabled alternative: divide "count" by the customer's total count.
    #tt = tt.with_columns( ( pl.col("count")/pl.col("count").sum().over( pl.col("customer_id") ) ).alias("count") )#.explode("h3_09")
    # Keep only hmd values present in `hmds` (the train/test intersection).
    tt = tt.filter( pl.col("hmd").is_in(hmds) )
    
    # Row vocabulary is rebuilt for the test customers; this overwrites the
    # train cu2idx / idx2cu / unique_cu defined earlier in the cell.
    unique_cu = tt['customer_id'].unique().to_list()
    cu2idx = {user_id: idx for idx, user_id in enumerate(unique_cu)}
    idx2cu = {idx: vac_id for vac_id, idx in cu2idx.items()}
    # Index/value arrays; columns reuse the train hmd vocabulary (hmd2idx1) so
    # the test matrix has the same feature layout as the matrix the model was fit on.
    # NOTE(review): relies on Series.replace returning the mapped integer
    # indices; dtype handling of replace differs between polars releases — confirm.
    pairs = tt.select(['customer_id', 'hmd', 'count'])
    users = pairs['customer_id'].replace(cu2idx).to_numpy()
    vacancies = pairs['hmd'].replace(hmd2idx1).to_numpy()
    preferences = pairs['count'].cast(pl.Float32).to_numpy()

    # Sparse test-customer x hmd feature matrix (float32).
    cutran_mat_1test = csr_matrix((preferences, (users, vacancies)), 
                              shape=(len(unique_cu), len(hmd2idx1)), 
                              dtype=np.float32)
    # Test targets, mapped with targ2idx and the test cu2idx.
    targe = pl.read_parquet("data/targe_test.parquet")
    targ = targe.with_columns( pl.col("h3_09").replace(targ2idx).cast(pl.Int16), pl.col("customer_id").replace(cu2idx).cast(pl.Int32))
    # One list of unique label indices per customer, ordered by the remapped customer id.
    # NOTE(review): targ[i] lines up with matrix row i only if the target file
    # contains exactly the customers that are in cu2idx — confirm.
    targ = targ.group_by("customer_id").agg( pl.col("h3_09").unique() ).sort("customer_id")["h3_09"].to_list()
    
    # Predict up to 250 labels per test customer with the fitted model and
    # report precision evaluated with k=2 against the test targets.
    Y_pred = plt.predict(cutran_mat_1test, top_k=250)
    print("precision_at_k на test", precision_at_k(targ, Y_pred, k=2))    
    
    ###############################################################################################
    # ---- MBCE on the test set and export of the predictions ----
    # predict_proba is consumed below as, per customer row, an iterable of
    # (label index, probability) pairs.
    # NOTE(review): 1657 is presumably the number of target hexes, i.e. a
    # probability is requested for every label — confirm against len(idx2targ).
    predi = plt.predict_proba(cutran_mat_1test, top_k=1657)
    n_targets = len(idx2targ)

    cus=[]
    h3s=[]
    prs=[]
    for icu, cu in enumerate(predi):
        for h3 in cu:
            cus.append(icu)
            h3s.append(h3[0])
            prs.append(h3[1])
    # The shape is given explicitly: without it scipy infers the size from the
    # largest indices seen, and any missing trailing row/column would make
    # pd.DataFrame raise because the data no longer matches the column list.
    # The former "add missing columns" loop was removed: the columns come from
    # idx2targ, which already holds every id of hexses_target, so it never ran.
    dfpred = pd.DataFrame(
        csr_matrix((prs, (cus, h3s)),
                   shape=(cutran_mat_1test.shape[0], n_targets)).toarray(),
        columns=list(idx2targ.values()),
    )

    # Dense 0/1 ground-truth frame with the same column layout.
    cus=[]
    h3s=[]
    prs=[]
    for icu, cu in enumerate(targ):
        for h3 in cu:
            cus.append(icu)
            h3s.append(h3)
            prs.append(1)
    dftrue = pd.DataFrame(
        csr_matrix((prs, (cus, h3s)), shape=(len(targ), n_targets)).toarray(),
        columns=list(idx2targ.values()),
    )

    print("MBCE на тесте", mean_binary_cross_entropy(dfpred, dftrue) )

    # Persist the dense test probabilities (one column per target hex).
    dfpred.to_parquet("output_path.pq")
    print(dfpred)
    

precision_at_k на train  [0.84556363 0.69350698]
MBCE на train 4.114155831700172
precision_at_k на test [0.44160489 0.33838588]
MBCE на тесте 10.193271182092792
       8911818610bffff  89118195133ffff  8911819513bffff  891181b2827ffff  \
0             0.000026     4.912327e-06         0.000001         0.000002   
1             0.003753     1.054190e-05         0.001152         0.001350   
2             0.001951     5.936617e-05         0.000061         0.000027   
3             0.000028     5.280091e-07         0.000004         0.000004   
4             0.000496     2.917616e-04         0.000034         0.000114   
...                ...              ...              ...              ...   
34664         0.002942     3.423307e-04         0.000014         0.000020   
34665         0.000528     4.052042e-05         0.000013         0.000057   
34666         0.000813     4.902435e-05         0.000010         0.000015   
34667         0.002257     2.651611e-05         0.000122         0.00