In [1]:
import os
import utils
import lightfm
import hueristics
import pickle

import sklearn.model_selection
import rectools.models
import rectools.metrics

import numpy as np
import pandas as pd
import tqdm.notebook as tqdm

from pandarallel import pandarallel
pandarallel.initialize()


from IPython.display import clear_output, HTML, display


import warnings  # local to the config cell: only needed for the credential check below

# Seed shared by every stochastic step in this notebook.
RANDOM_STATE = 1337
NUM_JOBS = -1

# Project root and database connection settings, exported through the process
# environment (presumably read by `utils` when loading data - verify there).
os.environ['DIR'] = "/home/ml/softezza_ml/"
os.environ['DB_ENDPOINT'] = "apollo-api-staging-f82be878-d243-4113-8052-ef36565618e0.cpljy7lbflfq.eu-west-1.rds.amazonaws.com"
os.environ['DB_PORT'] = '3306'
os.environ['DB_USER'] = "admin"
os.environ['DB_NAME'] = "vapor"

# SECURITY: the database password must never live in the notebook source.
# Export DB_PASSWORD before starting the kernel (or inject it from a secrets
# manager). NOTE(review): the password that used to be committed here should be
# treated as compromised and rotated.
if not os.environ.get('DB_PASSWORD'):
    warnings.warn(
        "DB_PASSWORD is not set in the environment; "
        "export it before running the cells that need database access."
    )

DATA_DIR = os.path.join(os.environ['DIR'], 'data')
REPORTS_DIR = os.path.join(os.environ['DIR'], 'reports')

DATA_DIR, REPORTS_DIR

INFO: Pandarallel will run on 6 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


('/home/ml/softezza_ml/data', '/home/ml/softezza_ml/reports')

In [2]:
# Build the data configuration piece by piece so each strategy has a name.
# The loader output below shows the resulting split is chronological
# (train period ends where the test period starts).
split_strategy = utils.TimeSortSplit(num_interactions='all', splits=(.8, .2))

# NOTE(review): presumably (10, 500) are min/max interaction counts and 20 is the
# number of most recent interactions kept per user - confirm in `utils`.
interaction_filters = [
    utils.MinNumInteractionsFilter(10, 500),
    utils.OnlyLastInteractionsFilter('user_id', 20),
]

features_config = utils.FeaturesConfig(use_labels=False)

config = utils.DataConfig(
    split_strategy=split_strategy,
    filter_strategy=interaction_filters,
    features_config=features_config,
)

data = utils.load_data(config)

data.train_interactions.head()

Data after filter:
Len of train interactions with period [['2019-10-11T03:09:32.000000000'] / ['2023-09-14T06:04:32.000000000']] - 6769224
Len of test interactions with period [['2023-09-14T06:04:32.000000000'] / ['2023-10-23T12:10:54.000000000']] - 1692307
Num of uniq users 423917Num of uniq items 8260


Unnamed: 0,user_id,item_id,timestamp,weight,index
0,3518601,tt8201852,2023-09-14 06:04:32,0.947491,1692307
1,80783501,tt0455944,2023-09-14 06:04:30,0.261237,1692308
2,17678705,tt10366206,2023-09-14 06:04:28,0.908876,1692309
3,45173701,tt14308636,2023-09-14 06:04:27,0.01,1692310
4,52970501,tt0468569,2023-09-14 06:04:25,0.170943,1692311


## Grid-Search LightFM No-Features

### Общая проверка по параметрам 1

In [10]:
# Coarse hyper-parameter sweep for LightFM trained on interactions only.
# 'epochs' is a fit-time setting of the rectools wrapper rather than a LightFM
# constructor argument, so it is popped from every combination before the
# model is built and added back to the result row afterwards.
grids = {
    'lightfm': {
        'model': lightfm.LightFM,
        'grid': {
            'no_components': [50, 100, 200],
            'loss': ['warp'],
            'max_sampled': [10, 15, 20],
            'epochs': [1, 3, 5]
        }
    }
}

results = []

for label, params in grids.items():
    grid = sklearn.model_selection.ParameterGrid(params['grid'])
    # NOTE(review): the return value is unpacked as a 2-tuple here, while the
    # feature-based cells use it directly as the dataset - confirm the current API.
    train_dataset, _ = data.get_rectools_dataset()

    for train_index, p in enumerate(grid):
        print(f"Train {train_index+1}/{len(grid)}")

        epochs = p.pop('epochs')

        # Seed the model so that runs are comparable across the grid and across
        # kernel restarts. The grid may still override 'random_state' explicitly.
        # Multi-threaded training may still introduce some run-to-run variation.
        model_kwargs = {'random_state': RANDOM_STATE, **p}
        model = rectools.models.LightFMWrapperModel(params['model'](**model_kwargs), epochs=epochs, num_threads=12)
        model.fit(train_dataset)

        recos = model.recommend(
            k=10,
            users=data.all_users,
            dataset=train_dataset,
            filter_viewed=True,
            add_rank_col=True,
        )

        metrics = rectools.metrics.calc_metrics(
            {
                'MAP@10': rectools.metrics.MAP(10),
                'Recall@10': rectools.metrics.Recall(10),
                'Siren@10': rectools.metrics.Serendipity(10),
                'MIUF@10': rectools.metrics.MeanInvUserFreq(10)
            },
            reco=recos,
            interactions=data.test_interactions,
            prev_interactions=data.train_interactions,
            catalog=data.all_items
        )
        # Project-specific metrics from `utils`, computed on the same recommendations.
        metrics['PopInt@10'] = utils.PopularIntersect(10).calc(reco=recos, prev_interactions=data.train_interactions)
        metrics['RecallNoPop@10'] = utils.RecallNoPop(10).calc(reco=recos, interactions=data.test_interactions, prev_interactions=data.train_interactions)
        metrics['model'] = label
        # Only the swept parameters are recorded, so the report columns stay unchanged.
        metrics = {**metrics, **p, 'epochs': epochs}
        results.append(metrics)

        # Live progress table: replace the previous output with the rows so far.
        clear_output(wait=True)
        display(HTML(pd.DataFrame.from_records(results).fillna('').head(100).to_html()))


grid_data = pd.DataFrame.from_records(results).fillna('')
grid_data.to_csv('grid_report.csv', index=False)
clear_output(wait=True)

# Rank by non-popular recall first, then prefer a lower overlap with popular items.
grid_data.sort_values(['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10'], ascending=[False, True, False, False]).head(100)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,loss,max_sampled,no_components,epochs
25,0.176187,0.076549,4.167758,9.4e-05,0.322088,0.048418,lightfm,warp,20,100,5
26,0.173279,0.073879,4.315791,0.000114,0.316138,0.048189,lightfm,warp,20,200,5
23,0.177518,0.075739,4.226113,0.000107,0.327849,0.047663,lightfm,warp,15,200,5
24,0.176016,0.076831,4.029857,7e-05,0.330862,0.046867,lightfm,warp,20,50,5
22,0.177327,0.076978,4.105724,8.6e-05,0.336176,0.046695,lightfm,warp,15,100,5
17,0.183993,0.080377,4.000955,8.6e-05,0.363181,0.046114,lightfm,warp,20,200,3
20,0.181496,0.078648,4.087813,9.9e-05,0.352831,0.046098,lightfm,warp,10,200,5
21,0.178822,0.078345,3.957804,6.6e-05,0.348019,0.04593,lightfm,warp,15,50,5
16,0.184208,0.080856,3.892584,6.8e-05,0.369999,0.0459,lightfm,warp,20,100,3
19,0.180666,0.07814,3.994206,7.9e-05,0.357836,0.045085,lightfm,warp,10,100,5


### Общая проверка по параметрам 2

In [13]:
# Second, narrower sweep for the feature-less LightFM model: the embedding size
# is fixed at 150 and only 'max_sampled' and the number of epochs are varied.
# 'epochs' is popped from each combination because it is passed to the rectools
# wrapper, not to the LightFM constructor.
grids = {
    'lightfm': {
        'model': lightfm.LightFM,
        'grid': {
            'no_components': [150],
            'loss': ['warp'],
            'max_sampled': [15, 20, 25],
            'epochs': [5, 6, 7]
        }
    }
}

results = []

for label, params in grids.items():
    grid = sklearn.model_selection.ParameterGrid(params['grid'])
    # NOTE(review): the return value is unpacked as a 2-tuple here, while the
    # feature-based cells use it directly as the dataset - confirm the current API.
    train_dataset, _ = data.get_rectools_dataset()

    for train_index, p in enumerate(grid):
        print(f"Train {train_index+1}/{len(grid)}")

        epochs = p.pop('epochs')

        # Seed the model so that runs are comparable across the grid and across
        # kernel restarts. The grid may still override 'random_state' explicitly.
        # Multi-threaded training may still introduce some run-to-run variation.
        model_kwargs = {'random_state': RANDOM_STATE, **p}
        model = rectools.models.LightFMWrapperModel(params['model'](**model_kwargs), epochs=epochs, num_threads=12)
        model.fit(train_dataset)

        recos = model.recommend(
            k=10,
            users=data.all_users,
            dataset=train_dataset,
            filter_viewed=True,
            add_rank_col=True,
        )

        metrics = rectools.metrics.calc_metrics(
            {
                'MAP@10': rectools.metrics.MAP(10),
                'Recall@10': rectools.metrics.Recall(10),
                'Siren@10': rectools.metrics.Serendipity(10),
                'MIUF@10': rectools.metrics.MeanInvUserFreq(10)
            },
            reco=recos,
            interactions=data.test_interactions,
            prev_interactions=data.train_interactions,
            catalog=data.all_items
        )
        # Project-specific metrics from `utils`, computed on the same recommendations.
        metrics['PopInt@10'] = utils.PopularIntersect(10).calc(reco=recos, prev_interactions=data.train_interactions)
        metrics['RecallNoPop@10'] = utils.RecallNoPop(10).calc(reco=recos, interactions=data.test_interactions, prev_interactions=data.train_interactions)
        metrics['model'] = label
        # Only the swept parameters are recorded, so the report columns stay unchanged.
        metrics = {**metrics, **p, 'epochs': epochs}
        results.append(metrics)

        # Live progress table: replace the previous output with the rows so far.
        clear_output(wait=True)
        display(HTML(pd.DataFrame.from_records(results).fillna('').head(100).to_html()))


grid_data = pd.DataFrame.from_records(results).fillna('')
# Written to its own file: 'grid_report.csv' is already used by the first sweep
# and would otherwise be silently overwritten.
grid_data.to_csv('grid_report_2.csv', index=False)
clear_output(wait=True)

# Rank by non-popular recall first, then prefer a lower overlap with popular items.
grid_data.sort_values(['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10'], ascending=[False, True, False, False]).head(100)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,loss,max_sampled,no_components,epochs
5,0.170732,0.072919,4.447279,0.00012,0.292028,0.049849,lightfm,warp,25,150,6
8,0.167907,0.071413,4.551884,0.000126,0.280736,0.049716,lightfm,warp,25,150,7
7,0.171279,0.073304,4.469367,0.00012,0.29117,0.049346,lightfm,warp,20,150,7
4,0.171326,0.072835,4.370964,0.000113,0.302832,0.049056,lightfm,warp,20,150,6
2,0.174302,0.074258,4.324271,0.00011,0.308312,0.049034,lightfm,warp,25,150,5
1,0.176486,0.07599,4.265632,0.000106,0.319672,0.048793,lightfm,warp,20,150,5
6,0.172893,0.073595,4.372436,0.000113,0.304321,0.048625,lightfm,warp,15,150,7
3,0.173781,0.074005,4.277359,0.000107,0.318255,0.048131,lightfm,warp,15,150,6
0,0.178123,0.077711,4.177651,9.9e-05,0.331425,0.047631,lightfm,warp,15,150,5


### Проверка user- и item-alpha

In [14]:
# Regularisation check: sweep LightFM's 'user_alpha' / 'item_alpha' penalties
# around the best configuration found so far. The alpha values are the ones
# recorded in this cell's saved output (0.0001 and 0.001 for each), giving
# 3 epochs x 2 x 2 = 12 runs.
# 'epochs' is popped from each combination because it is passed to the rectools
# wrapper, not to the LightFM constructor.
grids = {
    'lightfm': {
        'model': lightfm.LightFM,
        'grid': {
            'no_components': [150],
            'loss': ['warp'],
            'max_sampled': [25],
            'epochs': [5, 6, 7],
            'user_alpha': [1e-4, 1e-3],
            'item_alpha': [1e-4, 1e-3],
        }
    }
}

results = []

for label, params in grids.items():
    grid = sklearn.model_selection.ParameterGrid(params['grid'])
    # NOTE(review): the return value is unpacked as a 2-tuple here, while the
    # feature-based cells use it directly as the dataset - confirm the current API.
    train_dataset, _ = data.get_rectools_dataset()

    for train_index, p in enumerate(grid):
        print(f"Train {train_index+1}/{len(grid)}")

        epochs = p.pop('epochs')

        # Seed the model so that differences between runs come from the alphas
        # rather than from the random initialisation. The grid may still override
        # 'random_state' explicitly. Multi-threaded training may still introduce
        # some run-to-run variation.
        model_kwargs = {'random_state': RANDOM_STATE, **p}
        model = rectools.models.LightFMWrapperModel(params['model'](**model_kwargs), epochs=epochs, num_threads=12)
        model.fit(train_dataset)

        recos = model.recommend(
            k=10,
            users=data.all_users,
            dataset=train_dataset,
            filter_viewed=True,
            add_rank_col=True,
        )

        metrics = rectools.metrics.calc_metrics(
            {
                'MAP@10': rectools.metrics.MAP(10),
                'Recall@10': rectools.metrics.Recall(10),
                'Siren@10': rectools.metrics.Serendipity(10),
                'MIUF@10': rectools.metrics.MeanInvUserFreq(10)
            },
            reco=recos,
            interactions=data.test_interactions,
            prev_interactions=data.train_interactions,
            catalog=data.all_items
        )
        # Project-specific metrics from `utils`, computed on the same recommendations.
        metrics['PopInt@10'] = utils.PopularIntersect(10).calc(reco=recos, prev_interactions=data.train_interactions)
        metrics['RecallNoPop@10'] = utils.RecallNoPop(10).calc(reco=recos, interactions=data.test_interactions, prev_interactions=data.train_interactions)
        metrics['model'] = label
        # Only the swept parameters are recorded, so the report columns match the grid.
        metrics = {**metrics, **p, 'epochs': epochs}
        results.append(metrics)

        # Live progress table: replace the previous output with the rows so far.
        clear_output(wait=True)
        display(HTML(pd.DataFrame.from_records(results).fillna('').head(100).to_html()))


grid_data = pd.DataFrame.from_records(results).fillna('')
# Written to its own file: 'grid_report.csv' is already used by the first sweep
# and would otherwise be silently overwritten.
grid_data.to_csv('grid_report_alpha.csv', index=False)
clear_output(wait=True)

# Rank by non-popular recall first, then prefer a lower overlap with popular items.
grid_data.sort_values(['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10'], ascending=[False, True, False, False]).head(100)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,item_alpha,loss,max_sampled,no_components,user_alpha,epochs
4,0.236613,0.081773,2.603238,6e-06,0.813722,0.017378,lightfm,0.0001,warp,25,150,0.0001,6
0,0.236467,0.08667,2.614927,6e-06,0.813331,0.017266,lightfm,0.0001,warp,25,150,0.0001,5
8,0.232195,0.078894,2.774314,6e-06,0.801425,0.016671,lightfm,0.0001,warp,25,150,0.0001,7
5,0.008003,0.0021,11.376978,9e-06,0.01704275,0.000895,lightfm,0.0001,warp,25,150,0.001,6
9,0.007586,0.001918,11.900195,7e-06,0.01734048,0.000809,lightfm,0.0001,warp,25,150,0.001,7
1,0.008025,0.002016,11.334465,8e-06,0.01671981,0.000788,lightfm,0.0001,warp,25,150,0.001,5
3,0.000502,0.000131,11.263582,1.2e-05,0.0,0.000625,lightfm,0.001,warp,25,150,0.001,5
10,0.000337,7.7e-05,11.632719,7e-06,8.224473e-07,0.000423,lightfm,0.001,warp,25,150,0.0001,7
11,0.000266,7.1e-05,12.012013,8e-06,2.741491e-06,0.000328,lightfm,0.001,warp,25,150,0.001,7
7,0.000242,7.8e-05,11.799239,6e-06,2.467342e-06,0.000291,lightfm,0.001,warp,25,150,0.001,6


## Grid-Search LightFM With-Features

In [3]:
def feature2columns(fstr: str) -> list:
    """Expand a feature-group name into the list of column names that encode it.

    Known group names map to a fixed list of columns. Any other name is treated
    as a single column and returned as a one-element list.
    """
    # Rebuilt on every call so that each caller receives its own list objects.
    known_groups = (
        ('device', ['unknown', 'android', 'ios']),
        ('account_type', ['account_type_-1.0', 'account_type_6.0', 'account_type_3.0', 'account_type_1.0', 'account_type_12.0', 'account_type_0.0']),
        ('year', ['-1980', '2000-2010', '2010-2020', '1980-2000', '+2020']),
        ('rating', ['6.0-8.0', '8.0+', '-6.0']),
        ('genres', ['Sci-Fi', 'Adventure', 'Action', 'Comedy', 'Crime', 'Romance', 'Fantasy', 'Thriller', 'Mystery', 'Drama']),
        ('time', ['short', 'normal', 'long']),
        ('MPPA', ['R', 'PG-13', 'TV-MA', 'TV-14', 'N']),
    )

    for group_name, columns in known_groups:
        if fstr == group_name:
            return columns

    # Unknown name: assume it is already a single column name.
    return [fstr]


def filter_cols(fcols: list, features: pd.DataFrame) -> list:
    """Select the requested feature columns that exist in `features`, plus its id column.

    Args:
        fcols: Candidate feature column names.
        features: Feature frame keyed by 'user_id'. If it has no 'user_id'
            column it is treated as an item frame keyed by 'item_id'.

    Returns:
        The columns of `features` that are listed in `fcols`, in the frame's own
        column order, followed by the id column. The id column appears exactly
        once even if it is also listed in `fcols`.
    """
    id_col = 'user_id' if 'user_id' in features.columns else 'item_id'
    wanted = set(fcols)
    # Walk the frame's columns instead of intersecting two sets: set iteration
    # order for strings changes between interpreter runs, which made the feature
    # column order (and anything derived from it) non-reproducible.
    selected = [
        col for col in dict.fromkeys(features.columns)
        if col in wanted and col != id_col
    ]
    return selected + [id_col]

### Проверка с одной фичей 1

In [5]:
# Sweep LightFM with exactly one feature group attached at a time.
# 'with_feature' names the group (expanded to columns by feature2columns);
# 'epochs' and 'with_feature' are popped from each combination because neither
# is a LightFM constructor argument.
grids = {
    'lightfm': {
        'model': lightfm.LightFM,
        'grid': {
            'no_components': [125, 150, 175],
            'loss': ['warp'],
            'max_sampled': [22, 25, 27],
            'epochs': [6, 7],
            'with_feature': ['device', 'account_type', 'year', 'rating', 'genres', 'time', 'MPPA', 'lifetime']
        }
    }
}

results = []

for label, params in grids.items():
    grid = sklearn.model_selection.ParameterGrid(params['grid'])

    for train_index, p in enumerate(grid):
        print(f"Train {train_index+1}/{len(grid)}")

        epochs = p.pop('epochs')
        feature_str = p.pop('with_feature')
        feature_cols = feature2columns(feature_str)

        # Keep only the columns of the current feature group (plus the id column)
        # on whichever side - user or item - actually has them.
        user_features = data.user_features[filter_cols(feature_cols, data.user_features)]
        item_features = data.item_features[filter_cols(feature_cols, data.item_features)]

        train_dataset = data.get_rectools_dataset(item_features, user_features)

        # Seed the model so that differences between runs come from the feature
        # and hyper-parameters rather than from the random initialisation. The
        # grid may still override 'random_state' explicitly. Multi-threaded
        # training may still introduce some run-to-run variation.
        model_kwargs = {'random_state': RANDOM_STATE, **p}
        model = rectools.models.LightFMWrapperModel(params['model'](**model_kwargs), epochs=epochs, num_threads=12)
        model.fit(train_dataset)

        recos = model.recommend(
            k=10,
            users=train_dataset.user_id_map.external_ids,
            dataset=train_dataset,
            filter_viewed=True,
            add_rank_col=True,
        )

        metrics = rectools.metrics.calc_metrics(
            {
                'MAP@10': rectools.metrics.MAP(10),
                'Recall@10': rectools.metrics.Recall(10),
                'Siren@10': rectools.metrics.Serendipity(10),
                'MIUF@10': rectools.metrics.MeanInvUserFreq(10)
            },
            reco=recos,
            interactions=data.test_interactions,
            prev_interactions=data.train_interactions,
            catalog=data.all_items
        )
        # Project-specific metrics from `utils`, computed on the same recommendations.
        metrics['PopInt@10'] = utils.PopularIntersect(10).calc(reco=recos, prev_interactions=data.train_interactions)
        metrics['RecallNoPop@10'] = utils.RecallNoPop(10).calc(reco=recos, interactions=data.test_interactions, prev_interactions=data.train_interactions)
        metrics['model'] = label
        # Only the swept parameters are recorded, so the report columns stay unchanged.
        metrics = {**metrics, **p, 'epochs': epochs, 'with_feature': feature_str}
        results.append(metrics)

        # Live progress table: replace the previous output with the rows so far.
        clear_output(wait=True)
        display(HTML(pd.DataFrame.from_records(results).fillna('').head(100).to_html()))


grid_data = pd.DataFrame.from_records(results).fillna('')
grid_data.to_csv('lightfm_grid_report.csv', index=False)
clear_output(wait=True)

# Rank by non-popular recall first, then prefer a lower overlap with popular items.
grid_data.sort_values(['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10'], ascending=[False, True, False, False]).head(100)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,loss,max_sampled,no_components,epochs,with_feature
7,0.134157,0.060028,4.085855,0.000159,0.325782,0.064682,lightfm,warp,22,125,6,lifetime
55,0.133237,0.060117,4.149332,0.000166,0.315199,0.064254,lightfm,warp,27,125,6,lifetime
127,0.131231,0.058236,4.273753,0.000189,0.300931,0.063966,lightfm,warp,27,125,7,lifetime
23,0.131436,0.058123,4.249661,0.000187,0.307612,0.063671,lightfm,warp,22,175,6,lifetime
31,0.134345,0.060430,4.080820,0.000158,0.328111,0.063590,lightfm,warp,25,125,6,lifetime
...,...,...,...,...,...,...,...,...,...,...,...,...
133,0.124263,0.053525,4.542418,0.000220,0.283970,0.059918,lightfm,warp,27,150,7,time
61,0.126602,0.054954,4.433134,0.000206,0.296487,0.059874,lightfm,warp,27,150,6,time
70,0.126256,0.055218,4.493920,0.000199,0.297645,0.059868,lightfm,warp,27,175,6,MPPA
123,0.124641,0.054673,4.477716,0.000204,0.286029,0.059847,lightfm,warp,27,125,7,rating


In [10]:
# Show the 50 best runs with a colour gradient over the metric columns.
# Ranking: highest RecallNoPop@10 first, then lowest PopInt@10, then highest
# Recall@10 and MAP@10.
ranking_columns = ['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10']
ranking_ascending = [False, True, False, False]
metric_columns = ['Recall@10', 'MAP@10', 'MIUF@10', 'Siren@10', 'PopInt@10', 'RecallNoPop@10']

top_runs = grid_data.sort_values(ranking_columns, ascending=ranking_ascending).head(50)

# axis=0 applies the gradient to each metric column separately.
top_runs.style.text_gradient(axis=0, cmap='PiYG', subset=metric_columns)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,loss,max_sampled,no_components,epochs,with_feature
7,0.134157,0.060028,4.085855,0.000159,0.325782,0.064682,lightfm,warp,22,125,6,lifetime
55,0.133237,0.060117,4.149332,0.000166,0.315199,0.064254,lightfm,warp,27,125,6,lifetime
127,0.131231,0.058236,4.273753,0.000189,0.300931,0.063966,lightfm,warp,27,125,7,lifetime
23,0.131436,0.058123,4.249661,0.000187,0.307612,0.063671,lightfm,warp,22,175,6,lifetime
31,0.134345,0.06043,4.08082,0.000158,0.328111,0.06359,lightfm,warp,25,125,6,lifetime
79,0.132331,0.058992,4.200548,0.000175,0.314591,0.063506,lightfm,warp,22,125,7,lifetime
63,0.132072,0.059077,4.247295,0.00018,0.307119,0.063477,lightfm,warp,27,150,6,lifetime
71,0.130095,0.057681,4.335527,0.000198,0.298057,0.063377,lightfm,warp,27,175,6,lifetime
119,0.128726,0.056937,4.422664,0.000211,0.290346,0.063292,lightfm,warp,25,175,7,lifetime
103,0.132354,0.059235,4.210314,0.000178,0.312706,0.063253,lightfm,warp,25,125,7,lifetime


### Проверка с одной фичей 2

In [4]:
# Follow-up sweep with the 'lifetime' feature fixed, varying only the model
# hyper-parameters. 'epochs' is popped from each combination because it is
# passed to the rectools wrapper, not to the LightFM constructor.
grids = {
    'lightfm': {
        'model': lightfm.LightFM,
        'grid': {
            'no_components': [125, 175, 200],
            'loss': ['warp'],
            'max_sampled': [22, 25, 27],
            'epochs': [6, 7],
        }
    }
}

results = []

for label, params in grids.items():
    grid = sklearn.model_selection.ParameterGrid(params['grid'])

    # The feature set is the same for every combination, so the dataset is
    # built once outside the inner loop.
    feature_cols = feature2columns('lifetime')
    user_features = data.user_features[filter_cols(feature_cols, data.user_features)]
    item_features = data.item_features[filter_cols(feature_cols, data.item_features)]
    train_dataset = data.get_rectools_dataset(item_features, user_features)

    for train_index, p in enumerate(grid):
        print(f"Train {train_index+1}/{len(grid)}")

        epochs = p.pop('epochs')

        # Seed the model so that runs are comparable across the grid and across
        # kernel restarts. The grid may still override 'random_state' explicitly.
        # Multi-threaded training may still introduce some run-to-run variation.
        model_kwargs = {'random_state': RANDOM_STATE, **p}
        model = rectools.models.LightFMWrapperModel(params['model'](**model_kwargs), epochs=epochs, num_threads=12)
        model.fit(train_dataset)

        recos = model.recommend(
            k=10,
            users=train_dataset.user_id_map.external_ids,
            dataset=train_dataset,
            filter_viewed=True,
            add_rank_col=True,
        )

        metrics = rectools.metrics.calc_metrics(
            {
                'MAP@10': rectools.metrics.MAP(10),
                'Recall@10': rectools.metrics.Recall(10),
                'Siren@10': rectools.metrics.Serendipity(10),
                'MIUF@10': rectools.metrics.MeanInvUserFreq(10)
            },
            reco=recos,
            interactions=data.test_interactions,
            prev_interactions=data.train_interactions,
            catalog=data.all_items
        )
        # Project-specific metrics from `utils`, computed on the same recommendations.
        metrics['PopInt@10'] = utils.PopularIntersect(10).calc(reco=recos, prev_interactions=data.train_interactions)
        metrics['RecallNoPop@10'] = utils.RecallNoPop(10).calc(reco=recos, interactions=data.test_interactions, prev_interactions=data.train_interactions)
        metrics['model'] = label
        # Only the swept parameters are recorded, so the report columns stay unchanged.
        metrics = {**metrics, **p, 'epochs': epochs, 'with_feature': 'lifetime'}
        results.append(metrics)

        # Live progress table: replace the previous output with the rows so far.
        clear_output(wait=True)
        display(HTML(pd.DataFrame.from_records(results).fillna('').head(100).to_html()))


grid_data = pd.DataFrame.from_records(results).fillna('')
grid_data.to_csv('lightfm_grid_report_2.csv', index=False)
clear_output(wait=True)

# Top 50 runs, ranked by non-popular recall first and then by a lower overlap
# with popular items, with a per-column colour gradient over the metrics.
(
    grid_data.sort_values(['RecallNoPop@10', 'PopInt@10', 'Recall@10', 'MAP@10'], ascending=[False, True, False, False])
        .head(50)
        .style.text_gradient(
            axis=0,
            cmap='PiYG',
            subset=['Recall@10', 'MAP@10', 'MIUF@10', 'Siren@10', 'PopInt@10', 'RecallNoPop@10']
        )
)

Unnamed: 0,Recall@10,MAP@10,MIUF@10,Siren@10,PopInt@10,RecallNoPop@10,model,loss,max_sampled,no_components,epochs,with_feature
6,0.131547,0.058467,4.277105,0.000167,0.300746,0.062004,lightfm,warp,27,125,6,lifetime
3,0.129646,0.05746,4.3022,0.000171,0.292967,0.06192,lightfm,warp,25,125,6,lifetime
5,0.126615,0.05526,4.521738,0.000209,0.281923,0.061771,lightfm,warp,25,200,6,lifetime
7,0.127824,0.056147,4.477814,0.000201,0.2803,0.061683,lightfm,warp,27,175,6,lifetime
8,0.125952,0.054545,4.551806,0.000209,0.2737,0.06163,lightfm,warp,27,200,6,lifetime
15,0.128684,0.057182,4.398405,0.000186,0.285121,0.061604,lightfm,warp,27,125,7,lifetime
4,0.128688,0.056548,4.453019,0.000196,0.28771,0.061312,lightfm,warp,25,175,6,lifetime
0,0.130723,0.058617,4.224777,0.00016,0.308431,0.061092,lightfm,warp,22,125,6,lifetime
9,0.127875,0.056176,4.366917,0.000179,0.290262,0.061001,lightfm,warp,22,125,7,lifetime
12,0.129375,0.057004,4.379948,0.000183,0.292768,0.060914,lightfm,warp,25,125,7,lifetime


## Сохранение лучших моделей

In [14]:
def format_name(_p: pd.Series):
    """Turn a parameter row into a compact tag used in model file names.

    The row is rendered as its dict repr; quotes, braces and commas are then
    dropped and every space becomes an underscore,
    e.g. ``{'loss': 'warp'}`` -> ``loss:_warp``.
    """
    # A single translation table is equivalent to the chain of replacements:
    # none of the substituted characters is itself a target of another rule.
    cleanup = str.maketrans({"'": None, ' ': '_', '}': None, '{': None, ',': None})
    return str(_p.to_dict()).translate(cleanup)

# Positions of the runs to persist. grid_data still has the default RangeIndex
# it got from DataFrame.from_records, so these positions coincide with the
# index labels shown in the sorted results table.
best_models = [6, 3, 5, 7, 8, 17]
best_params = grid_data.iloc[best_models][['loss', 'max_sampled', 'no_components', 'epochs', 'with_feature']].iterrows()

# Derive the output directory from the project root configured in
# os.environ['DIR'] instead of repeating the absolute path, and make sure it
# exists before anything is written.
MODELS_DIR = os.path.join(os.environ['DIR'], 'models', 'lightfm')
os.makedirs(MODELS_DIR, exist_ok=True)

# NOTE(review): every model is retrained with the 'lifetime' feature regardless
# of the row's 'with_feature' value - fine while all selected rows use 'lifetime'.
feature_cols = feature2columns('lifetime')
user_features = data.user_features[filter_cols(feature_cols, data.user_features)]
item_features = data.item_features[filter_cols(feature_cols, data.item_features)]
train_dataset = data.get_rectools_dataset(item_features, user_features)

for _, p in best_params:
    # 'epochs' and 'with_feature' are not LightFM constructor arguments; take
    # them out for model construction and put them back for the file name.
    epochs = p.pop('epochs')
    feature = p.pop('with_feature')

    # Seed the retraining so the persisted model is reproducible. Without a seed
    # the saved model need not match the grid-search row recorded for it in meta.csv.
    # Multi-threaded training may still introduce some run-to-run variation.
    model = rectools.models.LightFMWrapperModel(
        lightfm.LightFM(**{'random_state': RANDOM_STATE, **p}),
        epochs=epochs,
        num_threads=12,
    )
    model.fit(train_dataset)
    p['epochs'] = epochs
    p['with_feature'] = feature

    # mode='xb' is exclusive creation: an existing model file is never
    # overwritten, open() raises FileExistsError instead.
    with open(os.path.join(MODELS_DIR, f"""lightfm_{format_name(p)}.pickle"""), mode='xb') as f:
        pickle.dump(model, f)

# Store the metrics of the persisted runs next to the model files.
grid_data.iloc[best_models].to_csv(os.path.join(MODELS_DIR, 'meta.csv'));