In [40]:
import torch
import pandas as pd
import numpy as np

In [4]:
# Load the interaction table from a CSV file in the working directory.
# TorchRecSys.load_from_df reads its 'customer_id' and 'community_id' columns.
df = pd.read_csv('experiments.csv')

In [97]:
class TorchRecSys:

    def __init__(self):
        self.groups = []
        self.groups_keys = []
        self.n_groups = 0
        self.n_users = 0
        self.users = None

    def load_from_df(self, df):
        """
        Функция построения талбицы пользователей
        По сути обучение
        """
        self.groups = df['community_id'].unique()
        self.groups_keys = dict([[value, index] for index, value in enumerate(self.groups)])
        self.n_groups = len(self.groups)

        users = []
        self.users_index = df['customer_id'].unique()
        for user in self.users_index:
            user_vector = torch.zeros([self.n_groups])
            for line in df[df['customer_id'] == user].iterrows():
                user_vector[self.groups_keys[line[1]['community_id']]] +=1
            users.append(user_vector)

        self.users = torch.stack(users)
        self.n_users = len(self.users)

    def predict(self, target_user_id, n):
        """
        :param: target_user_id - id пользователя, он должен был быть в обучающей выборке
        :param: n - сколько каналов рекомендовать

        Функция возвращает топ-n каналов, на которые пользователь ещё не подписан
        """
        target_user = self.users[np.where(self.users_index == target_user_id)][0]
        return self.predict_for_user_vector(target_user, n)


    def predict_for_user_vector(self, target_user, n):
        sim = torch.nn.functional.cosine_similarity(torch.stack([target_user for _ in range(self.n_users)]), self.users)
        res = torch.mv(self.users.T, sim)
        channels_id = torch.topk(res * ((res/max(res)).ceil() - target_user), n).indices.tolist()
        return [self.groups[channel_id] for channel_id in channels_id]

    # def save_model(self, file_path):
    #     torch.save(self.users, file_path)

    # def load_model(self, file_path):
    #     self.users = torch.load(file_path)


SyntaxError: ignored

In [92]:
# Create an empty recommender; it holds no users until load_from_df is called.
recsys = TorchRecSys()

In [93]:
# Build the per-user community count vectors from the loaded DataFrame.
recsys.load_from_df(df)

In [94]:
# Display the customer_id of the first row to pick an example user.
df['customer_id'].iloc[0]

'947224211267aefcc2e3e9c524fdf46ce329bc638e8bf1b7ce5a43f58d9dba320c133d58b3ca781fac348fa4beaec4e39b9717168b3df67a1d187a706523cbf2'

In [96]:
# Request 7 recommendations for a hard-coded customer id
# (the same value as the first row's customer_id in df).
recsys.predict("947224211267aefcc2e3e9c524fdf46ce329bc638e8bf1b7ce5a43f58d9dba320c133d58b3ca781fac348fa4beaec4e39b9717168b3df67a1d187a706523cbf2", 7)

['4e035d2ac419b6c6ea89c6a4a2e8157177b3ecaf85df612d07596ba300f389712b6a9f3f1552ba1bdbc994b3219c6c12fb6ab029840f247a9520006cad2500de',
 '5bf72e0f30296cdf893cac3b9f679c809f9c491fd3f471e18d9b018b61e729d87d44c9885a5c0bbc064b0853f83f5f3478cb2cf6abc205085eed11baac23abed',
 '507eab83ef051545226393d21d65cd50b5b6d91de7a4221953f053ec05dc58b42d83a729aa246cbf698adf99cb86776eceaffdd23f22dd9d5978250e1451cda4',
 '4ea139a01a60ff0b309771b6d01b08a4f9ec33965f3fbfe84efc85ae889555d1e432ef0bb135d168be6dee4ff2eb57138b796d171a50d1243ba46c97a646d11c',
 '4e3418966b6fdd276b16b6f2c5349899687553f53047f27ed5e6b440d4b46cfa33416ca41f3445ae8869778d229e663e0dd574daa7e473a408390567397f6d7e',
 '3c46dd6baed8a58687695585666f41703907161a1ec6425de845732f263bf34b13531a50acb6a37d2586e3fe560ea4c7f16b61addb16e73fcf22e3dc8b10c934',
 '3bbc950620d9c02051d2040f4194852efa4b787f5e373c8a14bb2f7808553600756cde1466b43e72bf7fa199b15a510415050e1381f3ff678708bc1977ca210b']