In [1]:
import datetime
import dill
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from zipfile import ZipFile
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from pandarallel import pandarallel
import warnings
import collections
import pickle
from scipy.sparse import csr_matrix
from scipy import sparse
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides pandas warnings (e.g. SettingWithCopyWarning) — confirm that is intended.
warnings.simplefilter('ignore')

# Display all columns of a frame, but truncate each cell to 20 characters.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 20)

# Start the pandarallel worker pool (20 workers, progress bars on) that backs
# the parallel_apply / parallel_map calls in this notebook.
pandarallel.initialize(progress_bar=True, nb_workers=20)

INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


## Загрузим данные

In [7]:
# Interaction log; rows are stored in chronological order.
#   timespent: minutes the user spent on the item (from 0 to 60)
#   reaction:  1 is a like, -1 is a dislike
# The positional index is turned into an explicit "time" column.
train = (
    pd.read_parquet('train.parquet.gzip')
    .reset_index()
    .rename(columns={"index": "time"})
)
# train = train[train.timespent > 0]

print(f"Уникальных юзеров в interactions: {train['user_id'].nunique():_}")
print(f"Уникальных айтемов в interactions: {train['item_id'].nunique():_}")

train

Уникальных юзеров в interactions: 1_000_183
Уникальных айтемов в interactions: 227_606


Unnamed: 0,time,user_id,item_id,timespent,reaction
0,0,707536,67950,0,0
1,1,707536,151002,0,0
2,2,707536,134736,0,0
3,3,707536,196151,0,0
4,4,707536,94182,0,0
...,...,...,...,...,...
144440010,144440010,849764,80910,0,0
144440011,144440011,993316,132328,0,0
144440012,144440012,993316,186701,0,0
144440013,144440013,666981,81857,0,0


In [3]:
# items_meta holds, for every item_id, its author (source_id) and a content embedding.
items_meta = pd.read_parquet('items_meta.parquet.gzip')
# Keep only the items that occur in the interaction log.
items_meta = items_meta[items_meta.item_id.isin(train.item_id.unique())]

# Expand the per-row embedding vectors into feat_<i> columns.
# Stacking the arrays once is vectorised (no per-row pd.Series construction and
# no worker pool needed), and pd.concat builds a fresh frame instead of
# assigning hundreds of columns into a filtered slice of the original frame.
emb_matrix = np.vstack(items_meta.embeddings.to_numpy())
feat_columns = [f"feat_{i}" for i in range(emb_matrix.shape[1])]
items_meta = pd.concat(
    [
        items_meta,
        pd.DataFrame(emb_matrix, index=items_meta.index, columns=feat_columns),
    ],
    axis=1,
)

print("len rows: ", items_meta.shape[0])
print("Unique items: ", items_meta.item_id.nunique())
print("Unique writers: ", items_meta.source_id.nunique())
print("emb shape: ", items_meta.iloc[0].embeddings.shape)
items_meta.head()

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=11381), Label(value='0 / 11381')))…

len rows:  227606
Unique items:  227606
Unique writers:  24438
emb shape:  (312,)


Unnamed: 0,item_id,source_id,embeddings,feat_0,feat_1,feat_2,feat_3,feat_4,feat_5,feat_6,feat_7,feat_8,feat_9,feat_10,feat_11,feat_12,feat_13,feat_14,feat_15,feat_16,feat_17,feat_18,feat_19,feat_20,feat_21,feat_22,feat_23,feat_24,feat_25,feat_26,feat_27,feat_28,feat_29,feat_30,feat_31,feat_32,feat_33,feat_34,feat_35,feat_36,feat_37,feat_38,feat_39,feat_40,feat_41,feat_42,feat_43,feat_44,feat_45,feat_46,feat_47,feat_48,feat_49,feat_50,feat_51,feat_52,feat_53,feat_54,feat_55,feat_56,feat_57,feat_58,feat_59,feat_60,feat_61,feat_62,feat_63,feat_64,feat_65,feat_66,feat_67,feat_68,feat_69,feat_70,feat_71,feat_72,feat_73,feat_74,feat_75,feat_76,feat_77,feat_78,feat_79,feat_80,feat_81,feat_82,feat_83,feat_84,feat_85,feat_86,feat_87,feat_88,feat_89,feat_90,feat_91,feat_92,feat_93,feat_94,feat_95,feat_96,feat_97,feat_98,feat_99,feat_100,feat_101,feat_102,feat_103,feat_104,feat_105,feat_106,feat_107,feat_108,feat_109,feat_110,feat_111,feat_112,feat_113,feat_114,feat_115,feat_116,feat_117,feat_118,feat_119,feat_120,feat_121,feat_122,feat_123,feat_124,feat_125,feat_126,feat_127,feat_128,feat_129,feat_130,feat_131,feat_132,feat_133,feat_134,feat_135,feat_136,feat_137,feat_138,feat_139,feat_140,feat_141,feat_142,feat_143,feat_144,feat_145,feat_146,feat_147,feat_148,feat_149,feat_150,feat_151,feat_152,feat_153,feat_154,feat_155,feat_156,feat_157,feat_158,feat_159,feat_160,feat_161,feat_162,feat_163,feat_164,feat_165,feat_166,feat_167,feat_168,feat_169,feat_170,feat_171,feat_172,feat_173,feat_174,feat_175,feat_176,feat_177,feat_178,feat_179,feat_180,feat_181,feat_182,feat_183,feat_184,feat_185,feat_186,feat_187,feat_188,feat_189,feat_190,feat_191,feat_192,feat_193,feat_194,feat_195,feat_196,feat_197,feat_198,feat_199,feat_200,feat_201,feat_202,feat_203,feat_204,feat_205,feat_206,feat_207,feat_208,feat_209,feat_210,feat_211,feat_212,feat_213,feat_214,feat_215,feat_216,feat_217,feat_218,feat_219,feat_220,feat_221,feat_222,feat_223,feat_224,feat_225,feat_226,feat_227,feat_228,feat_229,
feat_230,feat_231,feat_232,feat_233,feat_234,feat_235,feat_236,feat_237,feat_238,feat_239,feat_240,feat_241,feat_242,feat_243,feat_244,feat_245,feat_246,feat_247,feat_248,feat_249,feat_250,feat_251,feat_252,feat_253,feat_254,feat_255,feat_256,feat_257,feat_258,feat_259,feat_260,feat_261,feat_262,feat_263,feat_264,feat_265,feat_266,feat_267,feat_268,feat_269,feat_270,feat_271,feat_272,feat_273,feat_274,feat_275,feat_276,feat_277,feat_278,feat_279,feat_280,feat_281,feat_282,feat_283,feat_284,feat_285,feat_286,feat_287,feat_288,feat_289,feat_290,feat_291,feat_292,feat_293,feat_294,feat_295,feat_296,feat_297,feat_298,feat_299,feat_300,feat_301,feat_302,feat_303,feat_304,feat_305,feat_306,feat_307,feat_308,feat_309,feat_310,feat_311
0,0,7340,"[0.10458118, 0.0...",0.104581,0.04788,0.030944,-0.035116,-0.026452,-0.018017,0.016048,-0.045419,0.000311,0.093888,0.068062,-0.041069,0.062561,0.001897,-0.063457,-0.005675,0.015704,-0.061965,0.01497,-0.024888,0.056331,0.014982,0.047225,0.021182,0.079994,-0.021951,-0.057099,0.016034,-0.030449,0.024281,0.010529,-0.023831,0.011418,-0.067547,-0.024365,0.052528,-0.020069,0.004901,-0.036743,0.003144,-0.018055,0.248296,0.07359,0.014884,-0.042709,0.004142,-0.009092,-0.006609,0.0156,-0.00784,0.043871,-0.041044,0.020118,-0.052628,0.077389,-0.001171,0.018893,-0.041031,0.031918,-0.038071,0.003181,0.080917,-0.037467,0.020274,-0.024794,0.140124,0.085882,-0.037249,0.022786,0.001536,0.086229,-0.075313,-0.031434,-0.020224,0.009186,0.015094,0.06166,-0.044022,0.064016,0.019537,-0.03499,0.050758,0.036314,0.056562,0.006112,0.032237,-0.009762,0.030743,-0.014273,-0.00949,0.055373,-0.064723,0.011089,0.035945,-0.002951,0.01154,0.072136,0.028247,-0.043543,0.006763,-0.07649,-0.119404,-0.04722,-0.098469,-0.026963,-0.043232,0.052229,0.06321,0.077692,-0.02331,0.036534,0.002803,-0.058378,0.033859,0.001762,0.025262,-0.060648,0.101775,-0.031662,-0.045905,0.001397,-0.091876,0.014755,-0.011848,-0.019974,0.011138,0.037368,0.034023,0.012566,0.033387,0.046067,0.018949,-0.003952,-0.046605,0.03938,0.02903,0.018113,-0.013297,0.066368,-0.053902,0.014472,-0.009932,0.040023,-0.062347,0.087274,0.008949,0.076367,0.05259,0.024949,0.025704,0.009104,-0.016115,0.016704,0.026569,-0.060581,0.023069,-0.058094,0.049591,-0.06261,-0.039572,0.035351,0.106313,0.019917,-0.145175,0.026492,-0.029303,0.044426,-0.046923,0.002065,0.003374,0.024082,0.044541,0.006523,-0.003395,-0.028411,0.015958,0.007325,-0.002474,-0.086585,0.013133,-0.015525,0.061537,0.097969,-0.012246,-0.053275,0.084387,-0.007518,0.096921,-0.02751,0.025652,0.054984,-0.026778,-0.031838,0.018663,0.050039,-0.48321,0.02974,0.085101,-0.019348,-0.036927,-0.014241,-0.037933,-0.028382,-0.015813,-0.315715,-0.046879,0.069136,-0.053152,0.078296,0.021942,-0.050891,
0.003087,0.023563,0.03518,0.009088,-0.063369,-0.03893,0.065241,0.065803,-0.025753,0.017802,0.040102,-0.017162,-0.043116,0.005302,0.041633,-0.111696,0.054126,-0.035951,-0.026965,-0.003268,-0.02262,0.012516,-0.017142,0.004693,0.103053,0.031467,-0.054088,0.044733,-0.017603,-0.012578,-0.061519,0.027163,0.016403,-0.009291,0.049876,-0.047903,-0.006431,-0.061078,-0.025195,0.000364,-0.055464,-0.075298,0.016098,0.037649,-0.039622,-0.015662,0.016056,0.022568,0.046898,0.036,0.039373,0.037384,0.018342,-0.014039,-0.000122,-0.032386,0.053249,0.105708,0.013821,0.01944,-0.032066,-0.02028,0.003047,0.012052,0.013756,-0.010913,-0.007669,0.018034,0.08405,0.047649,0.043563,-0.018764,0.027969,0.013103,0.038647,0.014058,-0.01357,0.013302,0.004822,-0.052609,0.018889,-0.010206,-0.002022,-0.033855,0.003221,-0.040477,-0.022297,-0.021932,0.054266,-0.060089,0.024747,0.006862,0.00103,-0.006966,0.016704,-0.063262,-0.02483,-0.035405,-0.036204,-0.012501,-0.051993
1,1,6284,"[0.035625108, -0...",0.035625,-0.039264,-0.033103,-0.049436,-0.032331,-0.078797,0.096703,0.014842,-0.062225,0.079789,0.1096,0.012535,-0.016996,0.014885,0.010518,0.044627,-0.013888,0.00242,-0.008272,0.009635,-0.004423,0.056131,0.005206,-0.017742,0.097094,-0.033391,-0.006142,-0.019154,0.05719,0.034036,0.022394,-0.046946,0.010017,-0.018133,-0.04183,-0.026363,0.011882,0.034107,0.011325,-0.050799,-0.033238,0.190952,0.007753,0.05536,-0.0914,0.010375,-0.074658,0.009025,-0.073327,-0.018201,0.065827,-0.026588,-0.021515,0.005007,0.001322,-0.022663,0.071353,-0.091256,0.062709,0.038625,0.02048,0.067012,0.012658,-0.006893,-0.027315,0.077231,0.024025,-0.009121,0.048105,0.07028,0.023217,-0.000484,-0.017492,0.007075,-0.036041,-0.083739,-0.013475,0.02694,0.059407,-0.024123,-0.04253,-0.026396,0.014753,0.014689,-0.010855,0.000931,-0.054533,0.004243,-0.001871,-0.02689,0.060603,-0.021824,-0.00862,-0.003259,-0.003833,-0.018202,0.020536,-0.079951,-0.045038,0.017007,-0.04745,-0.123132,-0.00137,-0.075383,0.050634,-0.010404,0.040572,0.00689,-0.025162,-0.043796,0.026795,-0.009475,-0.077041,0.021245,0.015781,0.038141,-0.009931,0.067241,-0.100557,-0.051747,-0.027246,-0.026531,0.069448,-0.004324,-0.054094,0.009454,0.017498,-0.012037,-0.038381,-0.013088,0.051876,-0.056812,0.031449,0.068826,0.016532,0.000102,0.024216,0.019697,-0.01138,-0.06161,0.047395,0.050405,0.083926,0.002224,0.057851,0.005653,0.092808,0.03546,0.044991,0.00524,0.036022,0.008969,0.015229,0.067129,0.021897,-0.001214,0.034102,0.042377,0.156365,-0.03473,6e-05,0.002131,0.04438,-0.149964,-0.078064,-0.038894,0.006546,-0.016623,-0.023486,-0.008292,-0.009094,0.02714,-0.027872,-0.049239,-0.024247,0.090044,0.074983,0.014869,-0.035614,-0.052032,0.006558,0.022982,0.046687,0.03717,0.022584,0.013601,0.014621,0.105019,0.043898,0.006147,0.025432,-0.019075,-0.073027,0.016581,0.129942,-0.45731,0.005167,0.035622,0.015882,-0.095243,-0.016727,-0.004454,0.023723,0.01843,-0.321819,0.008136,0.042833,-0.050847,0.064387,0.032128,-0.01505,-0.021
448,-0.042778,0.059774,-0.048194,-0.059958,-0.020665,-0.007068,0.062741,0.014266,0.02404,0.054554,-0.057971,-0.008117,0.040581,0.041272,-0.047553,0.078273,0.019238,-0.065602,0.056987,-0.053488,0.015021,0.002688,-0.003708,0.015502,-0.052917,-0.032236,-0.01264,-0.072472,0.03327,0.004498,0.005389,0.005002,-0.023001,0.009674,-0.00341,-0.003682,-0.067879,0.010925,-0.121765,-0.066947,-0.019514,0.001729,-0.008663,-0.048593,-0.04213,0.045869,0.005386,0.03277,0.051503,0.031533,0.015302,0.058103,-0.017933,-0.07395,0.033768,0.035434,0.026163,0.033131,-0.024849,0.032894,-0.063769,0.001032,0.131275,0.062854,0.044675,0.011784,0.058894,-0.009838,-0.008196,0.023829,-0.024563,0.009418,0.038089,0.027128,0.107027,0.040551,-0.027338,0.092673,0.003481,0.054823,0.037511,-0.006672,-0.0616,-0.032384,0.070386,0.017031,-0.014194,0.040639,-0.021827,-0.014738,-0.00326,0.069096,-0.002862,0.049485,-0.081491,0.021737,0.034097,-0.016017,0.038909,-0.072126
2,2,12766,"[0.08418761, 0.0...",0.084188,0.006732,-0.003711,-0.020163,-0.02982,0.029411,0.017843,-0.01158,0.035843,0.078894,0.025562,0.045167,0.085365,0.056708,-0.0634,-0.016497,0.025046,-0.055683,-0.013206,-0.041707,0.005806,0.045827,-0.007802,0.034929,0.072549,0.016922,0.002034,0.028796,-0.029841,0.091459,0.050122,0.038738,-0.01688,-0.115528,-0.006674,-0.002313,0.017555,0.002022,-0.014358,-0.030713,0.051636,0.240008,0.019803,-0.015868,-0.026497,-0.034037,-0.035215,0.053387,-0.048209,0.045822,0.033669,0.037641,0.01074,0.032245,0.069093,0.029329,0.019683,-0.018839,-0.008873,-0.038841,-0.012251,0.121602,0.023051,0.000414,-0.020903,0.031036,0.092787,-0.035768,-0.030845,0.069676,0.06649,-0.072861,-0.012283,0.010152,-0.001587,-0.096491,-0.008677,-0.033257,0.019656,0.004105,-0.065065,-0.020688,0.028415,0.036351,0.016672,0.054702,0.013499,-0.064132,-0.029029,-0.073483,0.04098,-0.038132,0.109377,0.042458,0.016366,-0.002699,0.084556,-0.052831,-0.06458,0.010045,-0.017052,-0.065967,0.040914,-0.045249,0.054719,-0.053812,0.001542,0.043823,0.103153,0.02678,0.016293,0.016122,-0.081813,0.03903,0.025715,0.065226,-0.060537,0.021892,-0.019041,-0.02612,-0.04507,-0.049169,0.012669,0.006929,-0.075857,-0.029356,0.063172,0.050459,0.041876,0.074762,0.03678,-0.04065,-0.01367,0.029845,0.00496,-0.002323,0.018484,-0.067095,0.012363,0.040224,0.031962,0.017049,0.022777,-0.01328,0.03373,-0.038703,-0.016981,-0.011387,0.073266,0.061737,-0.013112,-0.010017,0.003697,0.009633,-0.056511,0.002955,-0.029552,0.006216,0.061964,-0.045629,0.057368,0.049045,0.037722,-0.161859,-0.017039,-0.043941,0.009616,-0.021188,-0.024246,0.014308,0.02184,0.036032,-0.020752,-0.009181,0.022839,0.078369,0.035759,-0.031123,-0.074861,0.008639,0.004115,0.010505,0.082908,0.079079,-0.066949,0.072011,-0.02072,0.08356,0.001247,0.013562,0.052212,-0.014523,-0.159691,0.023412,0.072151,-0.429264,0.027139,0.048605,0.013619,-0.053203,-0.015568,-0.051572,0.000311,-0.082107,-0.288873,-0.021212,-0.011882,-0.028436,0.06157,0.004802,-0.010853,
-0.039439,0.006094,0.036278,-0.012372,-0.023462,-0.080887,0.09125,0.052256,-0.073001,-0.04886,-0.008154,0.001016,-0.025728,0.018782,0.026078,-0.067531,0.086866,-0.064242,-0.12607,0.010082,-0.048533,0.078178,-0.025484,-0.004861,0.073434,0.072934,0.046969,0.024856,-0.061359,-0.029461,-0.03215,0.044994,0.021193,0.07128,0.064211,-0.041286,0.020037,-0.069598,-0.087855,-0.070055,-0.029714,-0.064483,0.036734,-0.001978,-0.023552,0.022827,0.045512,0.038492,0.000201,0.077635,0.046231,0.034834,-0.017659,-0.083766,0.014131,-0.003189,-0.005182,0.072023,0.075449,-0.004739,0.010462,-0.038978,0.027168,-0.049389,0.010235,0.028414,-0.020022,0.056233,0.07875,-0.033912,0.053506,-0.057579,-0.010951,0.036815,-0.03584,0.035763,0.041067,-0.023126,-0.087749,0.009956,0.010306,-0.023607,-0.012206,-0.017982,-0.004706,-0.050456,-0.035429,-0.018668,0.024941,0.023879,-0.058734,0.035034,-0.001716,0.017571,0.036196,-0.054412,0.060726,0.010904,-0.030896,-0.000183,-0.006675
3,3,14734,"[0.049901545, 0....",0.049902,0.039079,-0.038907,-0.05343,-0.031497,-0.073507,-0.008374,-0.020376,-0.038801,0.089043,0.081095,-0.026807,0.051285,-0.014023,-0.039647,-0.007285,0.012667,-0.00396,0.01326,-0.028778,0.013529,0.012102,0.036266,-0.045854,0.11345,-0.024549,0.02102,0.007278,0.003986,0.021595,-0.020225,0.042651,-0.003088,-0.057861,-0.050431,0.011364,0.059311,-0.017664,-0.007379,-0.08728,-0.033992,0.206406,0.037101,0.041275,-0.01576,-0.043748,-0.010243,0.007684,-0.046845,-0.018688,-0.003585,0.012806,0.005318,-0.037617,0.096761,-0.001184,0.008935,-0.054072,0.078012,-0.006633,0.04123,0.038649,-0.00277,-0.013016,0.032157,0.084527,0.138936,-0.035247,0.044899,0.03651,0.08784,0.012557,-0.026222,0.034793,0.054172,-0.051727,0.035691,-0.049217,0.090429,0.056759,-0.033199,-0.055258,0.054649,0.013104,-0.060519,0.063874,-0.028052,0.030076,-0.055345,0.001767,0.112657,0.019715,-0.066405,0.028865,-0.013201,-0.036434,0.064304,0.004765,-0.024432,-0.032272,-0.062715,-0.086133,-0.01627,-0.091196,-0.066351,-0.014434,0.038514,0.025118,0.106087,0.028492,-0.014891,0.024162,-0.090544,-0.044541,0.06761,0.040144,-0.008114,0.051757,-0.050955,-0.00749,-0.041993,-0.068794,0.003912,0.016746,-0.063025,0.036356,0.042323,0.052323,0.026131,0.052951,0.003775,-0.005242,0.027763,0.00938,0.07091,0.002487,0.014252,0.023032,0.101078,-0.067809,0.017865,-0.009583,-0.016613,-0.063106,0.071412,0.01426,0.032991,-0.031419,-0.027962,0.006834,0.03095,-0.017352,0.00683,0.009532,-0.04044,0.066306,-0.010768,0.040014,-0.009754,-0.010987,0.068321,0.109898,0.05503,-0.171086,0.07461,-0.071838,0.005659,-0.01808,-0.011198,-0.009967,-0.004609,0.072834,0.010911,-0.024105,-0.008993,-0.005646,0.03531,0.016512,-0.053576,-0.027351,-0.012181,0.00619,0.067753,-0.011663,0.023814,0.072442,0.008365,0.041778,0.001221,0.069102,0.013399,-0.05411,-0.064719,-0.012616,0.037223,-0.481585,-0.006109,-0.008843,-0.042331,-0.011361,0.015441,0.056181,0.002322,-0.022931,-0.297464,-0.019227,0.033243,-0.053616,0.082671,0.0366,-0.02
0999,-0.010162,0.057793,0.034159,-0.061104,-0.016325,-0.043528,0.055538,-0.015938,-0.09448,0.03463,0.08373,-0.016026,-0.079596,-0.004517,0.08648,-0.034604,0.065511,0.016562,-0.072147,0.000668,-0.039242,0.050851,0.048343,0.032979,0.062207,0.054261,-0.020061,0.024392,-0.093156,-0.016549,-0.023793,-0.00474,-0.020433,0.026419,0.067596,-0.042491,-0.002954,-0.096592,-0.057602,-0.019893,-0.081006,-0.035186,0.019734,0.008858,0.003748,-0.044692,0.029002,0.021933,-0.025748,0.041826,-0.001239,0.03866,0.052488,-0.024225,-0.000388,-0.055913,0.064345,0.102119,0.024191,-0.006752,0.009,-0.030614,0.01022,0.023221,0.008823,-0.002955,-0.027848,0.018042,0.041241,0.048861,0.014176,-0.049921,-0.030757,0.043213,-0.008733,0.074278,0.054519,0.025424,0.030737,0.003729,0.04193,0.037013,0.017572,0.004484,0.024693,-0.012353,-0.010037,-0.022155,0.018482,-0.018713,0.035288,-0.025119,0.008545,-0.016361,0.017685,-0.057258,-0.022387,-0.011892,0.000888,0.01289,-0.074769
4,4,22557,"[0.09303163, 0.0...",0.093032,0.023448,0.002949,-0.017046,-0.031863,-0.058425,-0.009482,-0.046374,0.029164,0.036861,0.090054,0.007745,-0.005526,0.069329,-0.016935,0.031366,0.011966,0.00563,0.005025,-0.037645,0.016989,-0.02584,0.013938,-0.032739,0.082881,0.00708,0.108539,0.017993,-0.041932,0.046102,-0.065696,0.042008,0.057814,0.034774,-0.05215,-0.021379,-0.030924,-0.019908,0.016058,-0.049799,0.027549,0.247541,0.093748,-0.042712,-0.050753,-0.085753,-0.03166,-0.005282,-0.037682,0.006262,-0.002973,0.040347,0.031087,-0.002169,0.068369,-0.053547,0.022028,-0.031135,0.105393,-0.048299,0.045345,0.031209,0.043008,-0.012733,0.019669,0.032967,0.024931,-0.021692,0.018941,0.043737,0.010785,0.023405,-0.096807,-0.055594,0.039856,-0.053063,0.055711,-0.043998,0.149279,0.031125,-0.028854,0.018151,0.044385,0.042196,0.026718,0.084353,0.020515,-0.006334,0.013962,-0.022904,0.040387,0.016369,0.025532,0.005551,0.006564,-0.017633,0.007458,-0.052346,-0.01592,-0.006247,-0.03663,-0.121594,-0.019246,-0.036003,-0.041416,0.024777,-0.093635,0.062302,0.026781,0.041639,-0.068893,-0.038124,-0.013465,0.034889,0.074775,-0.031127,0.018981,0.044262,0.008135,-0.05613,-0.02201,-0.064458,0.003471,0.037965,-0.030719,-0.006875,0.059746,-0.014069,0.083908,-0.010158,-0.01757,0.023113,0.017518,0.008264,-0.02725,0.049902,0.01514,-0.066964,0.059306,-0.026803,0.075806,-0.001504,0.002046,-0.04675,0.017294,-0.041971,0.069029,0.003351,0.028508,-0.005002,0.022589,-0.010469,-0.007204,0.00504,-0.041037,0.038045,0.003571,0.005444,0.113088,0.027292,-0.027611,-0.021999,0.008878,-0.140081,-0.027287,-0.029526,-0.029091,0.005484,-0.012086,-0.043745,0.071023,0.02727,-0.04304,0.063394,-0.03512,-0.014757,-0.015466,0.03151,-0.018933,-0.038739,-0.071336,-0.056551,0.039746,-0.054927,-0.022977,0.08515,-0.0487,0.038497,0.016357,0.074343,0.050765,-0.006482,0.010613,0.062125,0.025087,-0.468937,0.015873,0.018334,-0.012314,-0.015417,-0.031032,-0.03263,-0.012015,0.009407,-0.27742,-0.022397,0.067781,0.014908,0.096681,0.061002,-0.0
1812,0.024039,0.023988,0.085014,0.013368,-0.093285,-0.023058,-0.027637,0.065397,-0.093397,0.011634,-0.026983,0.023789,-0.040002,0.066387,0.069868,-0.000946,0.023136,0.025201,-0.050381,-0.00197,-0.031445,-0.011741,0.116547,-0.027673,0.055339,0.021159,0.017749,0.049359,-0.051558,-0.046654,0.047815,-0.054313,-0.02577,-0.0142,0.040386,-0.043643,-0.044719,-0.006841,-0.02416,-0.133746,-0.045637,0.034884,0.04771,0.007517,0.019203,-0.028941,-0.003936,0.079009,-0.051103,0.01667,-0.012978,-0.015301,-0.021581,0.061304,-0.001375,0.023701,0.011408,0.056801,-0.010007,-0.031414,-0.038271,0.011284,0.007432,-0.002262,0.058402,0.013878,-0.01657,0.029469,0.072547,0.120254,-0.000486,-0.104221,-0.001464,0.072668,0.037428,0.07745,-0.00522,0.000895,0.080305,0.019935,0.065262,-0.014668,-0.006669,-0.05677,0.015141,-0.040743,0.048089,-0.044117,-0.027943,0.001521,0.006696,-0.005206,0.002089,0.020247,-0.014311,-0.048222,-0.068534,-0.063528,-0.018066,0.061141,-0.050887


In [4]:
# candidates holds the item_ids of the fresh candidates from which the test-time predictions must be drawn
candidates_df = pd.read_parquet('fresh_candidates.parquet.gzip')
print("len rows: ", candidates_df.shape[0])
print("Items for predict: ", candidates_df.item_id.nunique())
candidates_df.head()

len rows:  100000
Items for predict:  100000


Unnamed: 0,item_id
0,0
1,2
2,5
3,6
4,7


In [5]:
# Test set: a table of user_id values; predictions are generated for these users further down.
test_id = pd.read_parquet('test.parquet.gzip')
print("len rows: ", test_id.shape[0])
print("Test users: ", test_id.user_id.nunique())
test_id.head()

len rows:  200000
Test users:  200000


Unnamed: 0,user_id
0,7
1,8
2,9
3,11
4,18


In [8]:
# Simple filter: keep an interaction only when its user is a test user
# and its item is one of the fresh candidates.
is_test_user = train.user_id.isin(test_id.user_id)
is_candidate_item = train.item_id.isin(candidates_df.item_id)
train = train[is_test_user & is_candidate_item]

print(train.shape)
train.head()

(16696198, 5)


Unnamed: 0,time,user_id,item_id,timespent,reaction
4940,4940,863936,72872,4,0
7366,7366,45679,161406,0,0
12790,12790,653187,92984,0,0
13038,13038,890059,72872,3,0
22176,22176,862054,70829,1,0


In [9]:
# Items every test user has already interacted with, in the row order of `train`.
# The mapping is handed to the recommender below so that already-seen items
# can be removed from the recommendations.
#
# Changes compared with the previous version of this cell:
#   * both filters are applied together; the second filter used to restart from
#     `train`, which silently discarded the candidate-item filter;
#   * a single groupby replaces the row-by-row iterrows() loop over the whole frame;
#   * the builtin `id` is no longer shadowed by a loop variable.
tmp = train[
    train.item_id.isin(candidates_df.item_id.values)
    & train.user_id.isin(test_id.user_id.values)
]

# Still a defaultdict(list), so code relying on that type keeps working.
known_items = collections.defaultdict(list)
known_items.update(
    tmp.groupby('user_id', sort=False)['item_id'].agg(list).to_dict()
)

# Length of the longest per-user history (0 when there are no interactions at all).
max_len = max(map(len, known_items.values()), default=0)

print(max_len)

  0%|          | 0/16696198 [00:00<?, ?it/s]

534


In [7]:
# Item feature matrix
# item_feat_map: real item_id -> row position of that item in the CSR matrix below
item_feat_map = {
    item_id: row_pos
    for row_pos, item_id in enumerate(items_meta.item_id.tolist())
}

# Dense feat_0..feat_311 block of items_meta, wrapped as a CSR matrix.
item_feature_csr = csr_matrix(
    items_meta[[f"feat_{i}" for i in range(312)]].to_numpy()
)
item_feature_csr

In [8]:
%%time
# User feature matrix
# Each user is described by the mean content embedding of their last `n_last` interacted items.
n_last = 3
# Last n_last item_ids per user, taken in the row order of `train`.
user_last_watches = train.groupby("user_id")["item_id"].apply(lambda x: list(x.iloc[-n_last:])).reset_index()
# One row per (user, item) pair.
user_last_watches = user_last_watches.explode('item_id')
# Attach the embedding of every item (pandas merges with an inner join by default).
user_last_watches = user_last_watches.merge(items_meta[["item_id", "embeddings"]], on='item_id')
# Average the embeddings of each user's items.
user_last_watches = pd.DataFrame(user_last_watches.groupby(['user_id'])['embeddings'].agg(np.mean)).reset_index()
# Spread the averaged vector into feat_0..feat_311 columns using the pandarallel workers.
user_last_watches[[f"feat_{i}" for i in range(312)]] = user_last_watches.embeddings.parallel_apply(pd.Series)
user_feature_csr = csr_matrix(user_last_watches[[f"feat_{i}" for i in range(312)]].to_numpy())
# NOTE(review): rows follow the groupby order of user_id; this presumably differs from
# LightFM's internal user order — confirm the alignment before passing this as user_features.
user_feature_csr

In [9]:
# Disabled: line that writes the user feature matrix to disk.
# sparse.save_npz("user_feature_csr_3.npz", user_feature_csr)
# Load the stored user feature matrix; this rebinds user_feature_csr to the file's contents.
user_feature_csr = sparse.load_npz("user_feature_csr_3.npz")

## Обучение LightFM

In [10]:
from lightfm.data import Dataset
from lightfm import LightFM
from lightfm.evaluation import precision_at_k, recall_at_k

In [11]:
# Register the users and items present in the (filtered) train frame with a LightFM Dataset.
dataset = Dataset()
dataset.fit(train['user_id'].unique(), train['item_id'].unique())

In [12]:
%%time
# matrix for training
# The zip yields one (user_id, item_id, timespent) triple per train row;
# build_interactions returns two matrices, unpacked here as interactions and weights.
interactions_matrix, weights_matrix = dataset.build_interactions(
    zip(*train[['user_id', 'item_id', 'timespent']].values.T)
)

# CSR form of the weights matrix; this is the matrix passed to the model during training.
weights_matrix_csr = weights_matrix.tocsr()

CPU times: user 43.4 s, sys: 521 ms, total: 43.9 s
Wall time: 43.9 s


In [13]:
# user / item mappings
# Positions 0 and 2 of dataset.mapping() are stored as the user and item id
# mappings. The *_inv_mapping entries are the same dictionaries with keys and
# values swapped.
raw_mappings = dataset.mapping()
lightfm_mapping = {
    'users_mapping': raw_mappings[0],
    'items_mapping': raw_mappings[2],
    'users_inv_mapping': {
        internal: external for external, internal in raw_mappings[0].items()
    },
    'items_inv_mapping': {
        internal: external for external, internal in raw_mappings[2].items()
    },
}

print(f"users_mapping amount: {len(lightfm_mapping['users_mapping'])}")
print(f"items_mapping amount: {len(lightfm_mapping['items_mapping'])}")

users_mapping amount: 199985
items_mapping amount: 99812


In [15]:
# LightFM model trained with the WARP loss; random_state is fixed for reproducibility.
lfm_model = LightFM(
    no_components=256, # dimensionality of the latent space
    learning_rate=0.01, 
    loss='warp', 
    # item_alpha=0.0, # regularisation
    # user_alpha=0.0, # regularisation
    max_sampled=10, # how many negatives to sample
    random_state=42,
)

In [16]:
#  train model

num_epochs = 10

# One epoch per fit_partial call, so tqdm reports progress epoch by epoch.
# NOTE(review): the timespent-weighted matrix is passed as `interactions` while
# `sample_weight` is left unset — confirm against the LightFM docs that the stored
# values act as weights under WARP rather than only marking positive entries.
for _ in tqdm(range(num_epochs)):
    lfm_model.fit_partial(
        interactions=weights_matrix_csr,
        user_features=None, # np.float32 csr_matrix of shape [n_users, n_user_features]
        item_features=None, # np.float32 csr_matrix of shape [n_items, n_item_features]
        # sample_weight=None, # np.float32 coo_matrix of shape [n_users, n_items] by timespent (accounted for in interactions)
        epochs=1,
        num_threads=20,
        # verbose=False,
    )
    # test_precision = precision_at_k(
        # lfm_model, 
        # weights_matrix_csr_test, 
        # train_interactions=weights_matrix_csr, 
        # k=20, 
        # num_threads=20).mean()
    # print(test_precision)

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# save model  
# Serialise the trained model to disk with dill.
# NOTE(review): the file name mentions item features, but the model in this notebook is
# fitted with item_features=None — confirm the name is intended.
with open(f"lfm_model_itemFEAT_11jan.dill", 'wb') as f:
    dill.dump(lfm_model, f)

## inference

In [17]:
# Split the test users into warm ones (present in the LightFM user mapping,
# i.e. seen during training) and cold ones (absent from it).
warm_users, cold_users = [], []
for user in test_id.user_id.tolist():
    bucket = warm_users if user in lightfm_mapping['users_mapping'] else cold_users
    bucket.append(user)

print("warm users: ", len(warm_users))
print("cold users: ", len(cold_users))

warm users:  199985
cold users:  15


In [18]:
# Split the candidate items into warm ones (present in the LightFM item
# mapping) and cold ones (absent from it).
warm_items, cold_items = [], []
for item in candidates_df.item_id.tolist():
    bucket = warm_items if item in lightfm_mapping['items_mapping'] else cold_items
    bucket.append(item)

print("Warm items: ", len(warm_items))
print("Cold items: ", len(cold_items))

Warm items:  99812
Cold items:  188


In [19]:
# One row per warm user; a predictions column is added to this frame later in the notebook.
candidates = pd.DataFrame({
    'user_id': warm_users
})

candidates.head(3)

Unnamed: 0,user_id
0,7
1,8
2,9


In [20]:
def generate_lightfm_recs_mapper(model, item_ids, known_items,
                                 user_features, item_features, N, 
                                 user_mapping, item_inv_mapping, 
                                 num_threads=1):
    """Build a function that maps an external user id to its top-N recommendations.

    Parameters
    ----------
    model : object
        Fitted model exposing ``predict(user_id, item_ids, user_features=...,
        item_features=..., num_threads=...)`` and returning one score per entry
        of ``item_ids``.
    item_ids : sequence of int
        Internal item ids to score. Position ``i`` of the score vector refers
        to ``item_ids[i]``.
    known_items : mapping
        External user id -> collection of external item ids that must not be
        recommended to that user.
    user_features, item_features
        Forwarded unchanged to ``model.predict``.
    N : int
        Number of recommendations to return per user.
    user_mapping : mapping
        External user id -> internal user id.
    item_inv_mapping : mapping
        Internal item id -> external item id.
    num_threads : int, default 1
        Forwarded unchanged to ``model.predict``.

    Returns
    -------
    callable
        ``f(user) -> list`` of at most ``N`` external item ids, best first,
        with the user's known items removed.
    """
    def _recs_mapper(user):  # `user` is the external user id
        user_id = user_mapping[user]  # internal user id
        recs = np.asarray(model.predict(user_id,
                                        item_ids,
                                        user_features=user_features,
                                        item_features=item_features,
                                        num_threads=num_threads))

        seen = known_items[user] if user in known_items else ()
        # Over-fetch by the number of known items so that N recommendations are
        # left after filtering; never request more items than were scored.
        total_N = min(N + len(seen), len(recs))
        if total_N <= 0:
            return []

        # kth = -1, -2, ..., -total_N puts the total_N best scores, sorted
        # ascending, into the last total_N positions. (The previous
        # -np.arange(total_N) started at 0, so the total_N-th best position
        # was not guaranteed to hold the right element.)
        top_cols = np.argpartition(recs, -np.arange(1, total_N + 1))[-total_N:][::-1]

        # Position in the score vector -> internal item id -> external item id.
        final_recs = [item_inv_mapping[item_ids[col]] for col in top_cols]
        if seen:
            filter_items = set(seen)  # O(1) membership checks
            final_recs = [item for item in final_recs if item not in filter_items]
        return final_recs[:N]
    return _recs_mapper

# number of candidates to return per user
top_N = 20

# helper data
# all_cols = list(lightfm_mapping['items_mapping'].values()) # which items will be recommended

# Internal LightFM ids of the warm candidate items, in scoring order.
pseudo_item_ids = [lightfm_mapping['items_mapping'][item] for item in warm_items]
# Position in pseudo_item_ids -> internal LightFM item id.
pseudo_id_mapper = dict(enumerate(pseudo_item_ids))

mapper = generate_lightfm_recs_mapper(
    model=lfm_model,
    item_ids=pseudo_item_ids,
    known_items=known_items,  # items from train, so they are not repeated in the predictions
    user_features=None,
    item_features=None,
    N=top_N,
    user_mapping=lightfm_mapping['users_mapping'],
    item_inv_mapping=lightfm_mapping['items_inv_mapping'],
    num_threads=1,
)

In [21]:
%%time
# generate predictions
# Applies the recommender to every user_id in `candidates` via pandarallel's parallel_map.
candidates['predictions'] = candidates['user_id'].parallel_map(mapper)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=10000), Label(value='0 / 10000')))…

CPU times: user 22min 14s, sys: 1min 55s, total: 24min 10s
Wall time: 1h 21min 24s


In [22]:
# Display the frame of user ids with their prediction lists.
candidates

Unnamed: 0,user_id,predictions
0,7,"[120767, 225411,..."
1,8,"[43635, 210210, ..."
2,9,"[223491, 56204, ..."
3,11,"[168651, 43635, ..."
4,18,"[87797, 120767, ..."
...,...,...
199980,1000160,"[155049, 58272, ..."
199981,1000165,"[158392, 1537, 8..."
199982,1000166,"[14760, 183888, ..."
199983,1000168,"[76464, 117035, ..."


In [23]:
# Users absent from the LightFM user mapping all receive the same fixed list of 20 items.
cold_pred = pd.DataFrame({'user_id': cold_users})
# NOTE(review): presumably the most-viewed items, computed elsewhere — the list is
# hard-coded here; confirm its origin and that it matches the current data.
most_views = [158392, 1537, 155049, 126834, 151406, 168651, 84951, 106246, 93615, 225171, 23623, 32641, 183888, 205187, 197397, 119238, 221911, 117035, 149746, 43635]
# Every row references the same list object.
cold_pred['predictions'] = cold_pred.user_id.apply(lambda x: most_views)
cold_pred

Unnamed: 0,user_id,predictions
0,14910,"[158392, 1537, 1..."
1,31656,"[158392, 1537, 1..."
2,33612,"[158392, 1537, 1..."
3,78730,"[158392, 1537, 1..."
4,88598,"[158392, 1537, 1..."
5,105397,"[158392, 1537, 1..."
6,177374,"[158392, 1537, 1..."
7,234422,"[158392, 1537, 1..."
8,361346,"[158392, 1537, 1..."
9,366677,"[158392, 1537, 1..."


In [24]:
# Stack the cold-user and warm-user predictions into one table (cold users first).
# NOTE(review): each frame keeps its own index, so index labels repeat after the
# concat — confirm that is acceptable for the saved file.
final = pd.concat([cold_pred, candidates])
final

Unnamed: 0,user_id,predictions
0,14910,"[158392, 1537, 1..."
1,31656,"[158392, 1537, 1..."
2,33612,"[158392, 1537, 1..."
3,78730,"[158392, 1537, 1..."
4,88598,"[158392, 1537, 1..."
...,...,...
199980,1000160,"[155049, 58272, ..."
199981,1000165,"[158392, 1537, 8..."
199982,1000166,"[14760, 183888, ..."
199983,1000168,"[76464, 117035, ..."


In [25]:
# Persist the final predictions as a gzip-compressed parquet file.
final.to_parquet('lightfm_sub_12jan_need.parquet.gzip', compression='gzip')