# In [None]:
import numpy as np
import pandas as pd
from sklearn import metrics
from numba import jit
import scipy as sp

class LMLFM:
    """Latent-factor rating model with per-row feature effects.

    A rating row with user index ``u``, item index ``i`` and feature vector
    ``x`` is predicted as ``w[u] @ w[i] + x @ (w[u] + w[i])``, where ``w`` is a
    ``(num_user + num_item, num_feature)`` matrix.  Rows ``[:num_user]`` are
    treated as user vectors and rows ``[num_user:]`` as item vectors.

    Data frames handed to the methods must provide the columns ``uid``,
    ``iid``, ``rating`` and ``feature`` (one length-``num_feature`` vector per
    row).  ``uid`` and ``iid`` are used directly as row indices into ``w``.
    NOTE(review): this implies item ids are already offset by ``num_user`` --
    confirm against the code that builds the data frames.

    The methods are deliberately plain Python: they take ``self`` and pandas
    objects, which numba cannot compile, so they carry no ``@jit`` decorator.
    """

    def __init__(self, num_user: int, num_item: int, num_feature: int, w=None, iterations: int = 20,
                 fraction: float = 1, verbose: bool = True, seed: int = 0, alpha: float = 1,
                 alpha_beta0: float = None, beta_beta0: float = 1e10, w_update_order=None,
                 reg: str = 'lasso', epsilon: float = 1e-5, initb: float = 1, initial: bool = True,
                 initialR: int = 5):
        """Store the hyper-parameters and allocate the model state.

        Parameters
        ----------
        num_user, num_item, num_feature : sizes of the factor matrix.
        w : optional initial factor matrix; zeros are used when omitted.
        iterations : maximum number of update sweeps run by ``fit``.
        fraction : fraction of the data sampled for every sweep (1 = all).
        verbose : print the per-iteration evaluation table in ``fit``.
        seed : seed passed to ``np.random.seed`` (global NumPy RNG).
        alpha : initial value of ``self.alpha`` (noise precision).
        alpha_beta0, beta_beta0 : lower / upper clamp used when updating
            ``beta_bu`` / ``beta_bi``.  When ``alpha_beta0`` is omitted it
            defaults to ``1e-4`` for lasso and ``0`` otherwise.
        w_update_order : fixed coordinate order; a random permutation is drawn
            for every vector update when omitted.
        reg : ``'lasso'`` or anything else for the l2 branch.
        epsilon : convergence threshold on the normalized log posterior.
        initb : initial value of every entry of ``b_u`` and ``b_i``.
        initial : whether ``fit`` runs the pilot sweeps first.
        initialR : number of pilot sweeps.
        """
        self.num_user = num_user
        self.num_item = num_item
        self.num_feature = num_feature
        self.dim = num_user + num_item
        self.iterations = iterations
        self.fraction = fraction
        self.verbose = verbose
        self.initial = initial  # whether to use pilot runs
        self.initialR = initialR  # the iterations of pilot runs
        self.epsilon = epsilon
        self.reg = reg  # the regularization. Currently only lasso and l2 are supported.
        self.w_update_order = w_update_order

        # Per-feature scale parameters, kept as float arrays.
        self.b_u = np.full(num_feature, initb, dtype=float)
        self.b_i = np.full(num_feature, initb, dtype=float)

        # Clamp bounds for the per-feature locations mu_u / mu_i.
        self.alpha_mu0 = -1e10
        self.beta_mu0 = 1e10
        self.alpha = alpha

        # Bounds for beta_bu / beta_bi.  An explicitly supplied lower bound
        # must be stored as well, otherwise updateBeta_b_uniform cannot run.
        if alpha_beta0 is None:
            alpha_beta0 = 1e-4 if reg == 'lasso' else 0
        self.alpha_beta0 = alpha_beta0
        self.beta_beta0 = beta_beta0

        np.random.seed(seed)
        self.w = np.zeros([self.dim, num_feature]) if w is None else w
        self.mu_u = np.zeros(num_feature)
        self.mu_i = np.zeros(num_feature)

        self.alpha_y0 = 1
        self.beta_y0 = 1

        # for inverse gamma prior of b_v
        self.alpha_b = 1  # this hyper parameter is fixed
        self.beta_bu = 1
        self.beta_bi = 1

        self.lossChain = []  # normalized log posterior density recorded per evaluation
        self.lastLoss = 0  # previous value, used by the convergence test in fit

        self.intervention = False
        self.userGroups = None
        self.itemGroups = None

    def copyFrom(self, model):
        """Copy the fitted parameters of another model into this one.

        Array-valued parameters are copied so that later in-place updates on
        this instance cannot modify ``model``.
        """
        self.w = np.copy(model.w)
        self.b_u = np.copy(model.b_u)
        self.b_i = np.copy(model.b_i)
        self.beta_bu = model.beta_bu
        self.beta_bi = model.beta_bi
        self.alpha = model.alpha
        self.mu_u = np.copy(model.mu_u)
        self.mu_i = np.copy(model.mu_i)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _feature_matrix(self, group, rowLen):
        """Stack the ``feature`` column of ``group`` into a (rowLen, num_feature) array."""
        return np.array(list(group['feature'])).reshape(rowLen, self.num_feature)

    @staticmethod
    def _row_dot(left, right):
        """Row-wise dot product of two equally shaped 2-D arrays."""
        return np.sum(left * right, axis=1)

    def _collect_predictions(self, itemGroups):
        """Return ``(ratings, predictions)`` concatenated over all item groups."""
        ratings, preds = [], []
        for _, itemG in itemGroups:
            ratings.append(np.asarray(itemG['rating'].values))
            preds.append(self.predict_block(itemG))
        if not preds:
            raise ValueError('no rating rows to predict')
        return np.concatenate(ratings), np.concatenate(preds)

    def _coordinate_order(self):
        """Return the coordinate order for one vector update (random unless fixed)."""
        if self.w_update_order is not None:
            return self.w_update_order
        order = np.arange(self.num_feature)
        np.random.shuffle(order)
        return order

    def _prior_loglik(self, factors, b, mu):
        """Log prior of a block of factor rows under the active regularization."""
        if self.reg == 'lasso':
            terms = -np.log(2 * b) - np.abs(factors - mu) / b
        else:
            # The l2 branch does not use mu.
            terms = -0.5 * np.log(b) - 0.5 * np.square(factors) / b
        return np.sum(terms)

    def _update_hyperparameters(self, data):
        """Run the mu, alpha, b and beta updates in their fixed order."""
        self.updateMu(True)
        self.updateMu(False)
        self.updateAlpha(data)
        self.updateB_invGamma(True)
        self.updateB_invGamma(False)
        self.updateBeta_b_uniform(True)
        self.updateBeta_b_uniform(False)

    # ------------------------------------------------------------------
    # evaluation
    # ------------------------------------------------------------------
    def evaluate(self, test):
        """Return ``(normalized log posterior, RMSE, R^2)`` for the rows of ``test``."""
        rating, pred = self._collect_predictions(test.groupby('iid'))
        likelihood, loss = self._log_fullPos(rating, pred)
        r2 = metrics.r2_score(rating, pred)
        return (likelihood, loss, r2)

    def _log_fullPos(self, rating, pred):
        """Return the normalized log posterior and the RMSE of ``pred``.

        The 'betau' / 'betai' components are not included (they contribute 0).
        """
        likelihood_y = self.likelihood(rating, pred, 'y')
        full_pos = likelihood_y + sum(
            self.likelihood(rating, pred, flag)
            for flag in ('thetau', 'thetai', 'alpha', 'bu', 'bi')
        )
        # Recover the sum of squared errors from the data log-likelihood.
        sumloss = -(2 * likelihood_y - len(pred) * np.log(self.alpha)) / self.alpha
        loss = np.sqrt(sumloss / len(pred))
        return full_pos / (self.num_user * self.num_feature), loss

    def _likelihood_data(self, test):
        """Return the data log-likelihood (the 'y' component) of ``test``."""
        rating, pred = self._collect_predictions(test.groupby('iid'))
        return self.likelihood(rating, pred, 'y')

    def likelihood(self, rating, pred, flag):
        """Compute the log density of one model component.

        ``flag`` selects the component: ``'y'`` (data, needs ``rating`` and
        ``pred``), ``'thetau'`` / ``'thetai'`` (user / item factors), ``'bu'`` /
        ``'bi'`` (scales), ``'alpha'``, ``'betau'`` / ``'betai'``.  Any other
        flag yields 0.
        """
        if flag == 'y':
            col = rating - pred
            return len(pred) * np.log(self.alpha) * 0.5 - self.alpha * 0.5 * (col @ col)
        if flag == 'thetau':
            return self._prior_loglik(self.w[:self.num_user], self.b_u, self.mu_u)
        if flag == 'thetai':
            item_rows = self.w[self.num_user:self.num_user + self.num_item]
            return self._prior_loglik(item_rows, self.b_i, self.mu_i)
        if flag == 'bu':
            return (- (self.alpha_b + 1) * np.sum(np.log(self.b_u))
                    - self.beta_bu * np.sum(1 / self.b_u)
                    + self.num_feature * self.alpha_b * np.log(self.beta_bu))
        if flag == 'bi':
            return (- (self.alpha_b + 1) * np.sum(np.log(self.b_i))
                    - self.beta_bi * np.sum(1 / self.b_i)
                    + self.num_feature * self.alpha_b * np.log(self.beta_bi))
        if flag == 'alpha':
            return (self.alpha_y0 - 1) * np.log(self.alpha) - self.beta_y0 * self.alpha
        # NOTE(review): precision_beta and mu_beta are not set anywhere in this
        # class, so the two flags below only work if a caller assigns them.
        if flag == 'betau':
            return - self.precision_beta * 0.5 * (self.beta_bu - self.mu_beta) ** 2
        if flag == 'betai':
            return - self.precision_beta * 0.5 * (self.beta_bi - self.mu_beta) ** 2
        return 0

    def predict_w(self, data, w):
        """Predict every row of ``data`` with the factor matrix ``w``."""
        pred = np.zeros(data.shape[0])
        rows = zip(data['uid'], data['iid'], data['feature'])
        for pos, (uid, iid, feature) in enumerate(rows):
            pred[pos] = w[uid] @ w[iid] + feature @ (w[uid] + w[iid])
        return pred

    # ------------------------------------------------------------------
    # hyper-parameter updates
    # ------------------------------------------------------------------
    def updateAlpha(self, data):
        """Update ``self.alpha`` from the squared residuals.

        Uses the cached item groups when available, otherwise groups ``data``
        by ``iid``.  The row count always comes from ``data``.
        """
        itemGroups = data.groupby('iid') if self.itemGroups is None else self.itemGroups
        rating, pred = self._collect_predictions(itemGroups)
        sumOfError = 0.5 * np.sum(np.square(rating - pred))
        self.alpha = (self.alpha_y0 + data.shape[0] * 0.5 - 1) / (self.beta_y0 + sumOfError)

    def updateMu(self, isUser):
        """Set ``mu_u`` (or ``mu_i``) to the clamped per-feature median of the factors.

        Only active for lasso regularization; the arrays are updated in place.
        """
        if self.reg != 'lasso':
            return
        if isUser:
            rows, target = self.w[:self.num_user], self.mu_u
        else:
            rows, target = self.w[self.num_user:], self.mu_i
        target[:] = np.clip(np.median(rows, axis=0), self.alpha_mu0, self.beta_mu0)

    def updateB_invGamma(self, isUser):
        """Update the per-feature scales ``b_u`` (or ``b_i``).

        Prints a diagnostic when the update lowered the sum of the affected
        log-density components.
        """
        if isUser:
            theta_flag, b_flag = 'thetau', 'bu'
            rows, mu, beta, count = self.w[:self.num_user], self.mu_u, self.beta_bu, self.num_user
            warning = 'problem occur while updating b_u'
        else:
            theta_flag, b_flag = 'thetai', 'bi'
            rows, mu, beta, count = self.w[self.num_user:], self.mu_i, self.beta_bi, self.num_item
            warning = 'problem occur while updating b_i'

        bef = self.likelihood(None, None, theta_flag) + self.likelihood(None, None, b_flag)
        if self.reg == 'lasso':
            new_b = (beta + np.sum(np.abs(rows - mu), axis=0)) / (self.alpha_b + count + 1)
        else:
            new_b = (beta + 0.5 * np.sum(np.square(rows), axis=0)) / (self.alpha_b + count / 2 + 1)
        if isUser:
            self.b_u = new_b
        else:
            self.b_i = new_b

        aft = self.likelihood(None, None, theta_flag) + self.likelihood(None, None, b_flag)
        if aft < bef:
            print(warning)

    def updateBeta_b_uniform(self, isUser):
        """Update ``beta_bu`` (or ``beta_bi``), clamped to ``[alpha_beta0, beta_beta0]``.

        Prints a diagnostic when the update lowered the 'bu' / 'bi' component.
        """
        minv = self.alpha_beta0
        maxv = self.beta_beta0
        flag = 'bu' if isUser else 'bi'
        b = self.b_u if isUser else self.b_i

        bef = self.likelihood(None, None, flag)
        tmp = self.num_feature * self.alpha_b / np.sum(1 / b)
        # Explicit comparisons (rather than min/max) so that a NaN candidate
        # falls through to the upper bound.
        if tmp < minv:
            new_beta = minv
        elif tmp <= maxv:
            new_beta = tmp
        else:
            new_beta = maxv
        if isUser:
            self.beta_bu = new_beta
        else:
            self.beta_bi = new_beta

        aft = self.likelihood(None, None, flag)
        if aft < bef:
            print('problem occur while updating beta_u' if isUser
                  else 'problem occur while updating beta_i')

    # ------------------------------------------------------------------
    # factor updates
    # ------------------------------------------------------------------
    def latentVectorUpdate(self, featureMatrix, factorMatrix, b, v, ratings, residual, uid, mu):
        """Return the coordinate-wise update of one factor vector ``v``.

        ``featureMatrix + factorMatrix`` is the design matrix of the rows that
        involve this vector, ``residual`` the part of the prediction that does
        not depend on ``v``.  ``uid`` is accepted for interface compatibility
        and not used.  ``v`` itself is left untouched.
        """
        alpha = self.alpha
        h = featureMatrix + factorMatrix
        # Float residual so the in-place updates below never hit an int array.
        col = np.asarray(ratings - (h @ v + residual), dtype=float)

        newv = np.zeros(len(v))
        for i in self._coordinate_order():
            h_i = h[:, i]
            bottom = h_i @ h_i
            col += h_i * v[i]  # residual with coordinate i removed
            if self.reg == 'lasso':
                if bottom > 0:
                    C = h_i @ col
                    Ccheck = C - bottom * mu[i]
                    sub = 1 / (alpha * b[i])
                    if -sub <= Ccheck <= sub:
                        newv[i] = mu[i]
                    elif Ccheck > sub:
                        newv[i] = (C - sub) / bottom
                    else:
                        newv[i] = (C + sub) / bottom
                    col -= h_i * newv[i]
                else:
                    # degenerate condition: keep the current value
                    newv[i] = v[i]
            else:
                C = h_i @ col
                newv[i] = C / (bottom + 1 / (b[i] * alpha))
                col -= h_i * newv[i]
        return newv

    def initLatentVectorUpdate(self, featureMatrix, factorMatrix, b, v, ratings, residual, uid):
        """Return the unregularized coordinate-wise update of ``v``.

        ``b`` and ``uid`` are accepted for interface compatibility and not
        used.  A coordinate whose design column has zero norm keeps its current
        value instead of being divided by zero.
        """
        h = featureMatrix + factorMatrix
        col = np.asarray(ratings - (h @ v + residual), dtype=float)

        newv = np.zeros(len(v))
        for i in self._coordinate_order():
            h_i = h[:, i]
            bottom = h_i @ h_i
            col += h_i * v[i]
            if bottom > 0:
                newv[i] = (h_i @ col) / bottom
                col -= h_i * newv[i]
            else:
                newv[i] = v[i]
        return newv

    def userUpdateWithFeature(self, userG):
        """Update the factor row of the single user contained in ``userG``."""
        uid = userG['uid'].values[0]
        items = np.array(userG['iid'].values)
        ratings = userG['rating'].values
        itemFactors = np.copy(self.w[items,])
        featureMatrix = self._feature_matrix(userG, len(items))
        residual = self._row_dot(itemFactors, featureMatrix)

        self.w[uid,] = self.latentVectorUpdate(featureMatrix, itemFactors, self.b_u, self.w[uid,],
                                               ratings, residual, uid, self.mu_u)

    def itemUpdateWithFeature(self, itemG):
        """Update the factor row of the single item contained in ``itemG``."""
        iid = itemG['iid'].values[0]
        users = np.array(itemG['uid'].values)
        ratings = itemG['rating'].values
        userFactors = self.w[users,]
        featureMatrix = self._feature_matrix(itemG, len(users))
        residual = self._row_dot(userFactors, featureMatrix)

        self.w[iid,] = self.latentVectorUpdate(featureMatrix, userFactors, self.b_i, self.w[iid,],
                                               ratings, residual, iid, self.mu_i)

    def predict_block(self, itemG):
        """Predict all rows of ``itemG``, which must belong to a single item."""
        iid = itemG['iid'].values[0]
        users = np.array(itemG['uid'].values)
        userFactors = self.w[users,]
        featureMatrix = self._feature_matrix(itemG, len(users))
        residual = self._row_dot(userFactors, featureMatrix)
        return (userFactors + featureMatrix) @ self.w[iid,] + residual

    # ------------------------------------------------------------------
    # training loop
    # ------------------------------------------------------------------
    def _printInfo(self, i, test, curEva):
        """Print one row of the evaluation table; test columns are -1 without test data."""
        testL = (-1, -1, -1)
        if test is not None:
            testL = self.evaluate(test)
        print("%d    %.4f    %.4f    %.4f    %.4f    %.4f    %.4f" % (i,curEva[0],testL[0],curEva[1],testL[1],curEva[2],testL[2]))

    def fit(self, data, test, w=None):
        """Fit the model on ``data`` and return the factor matrix.

        Runs the pilot sweeps when enabled, then up to ``iterations`` update
        sweeps, stopping early once the normalized log posterior changes by at
        most ``epsilon``.  ``test`` may be ``None``.
        """
        if w is not None:
            self.w = w
        if self.initial:
            self.initialization(data, test)

        curEva = self.evaluate(data)
        self.lossChain.append(curEva[0])
        self.lastLoss = curEva[0]
        if self.verbose:
            print("Iter    TrLL    TeLL   TrRMSE   TeRMSE   TrR2   TeR2")
            self._printInfo(0, test, curEva)

        for i in range(1, self.iterations + 1):
            if not self.verbose:
                # Without the table, at least show progress.
                print("iteration: {0}".format(i))
            self.update(data, i)
            curEva = self.evaluate(data)
            self.lossChain.append(curEva[0])
            if self.verbose:
                self._printInfo(i, test, curEva)
            if abs(curEva[0] - self.lastLoss) <= self.epsilon:
                print('convergence!')
                break
            self.lastLoss = curEva[0]
        else:
            # All sweeps were used without meeting the convergence test.
            print("finish!")
        return self.w

    def temporalPopulationEffects(self):
        """Return the item rows of ``w`` shifted by the mean user vector."""
        mU = np.mean(self.w[:self.num_user,], axis=0)
        return self.w[self.num_user:,] + mU

    def subjectEffects(self):
        """Return the user rows of ``w`` shifted by the mean item vector."""
        mO = np.mean(self.w[self.num_user:,], axis=0)
        return self.w[:self.num_user,] + mO

    def overallEffects(self):
        """Return the mean user vector plus the mean item vector."""
        return np.mean(self.w[:self.num_user,], axis=0) + np.mean(self.w[self.num_user:,], axis=0)

    def fixedEffects(self):
        """Return ``mu_u + mu_i``."""
        return self.mu_u + self.mu_i

    def update(self, data, cur):
        """Run one sweep: all user rows, all item rows, then the hyper-parameters.

        With ``fraction == 1`` the groupings are rebuilt on the first sweep
        (``cur == 1``) or when none are cached; otherwise a fresh sample is
        drawn and grouped every sweep.  When ``self.intervention`` is set the
        hyper-parameters are left untouched.
        """
        if self.fraction == 1:
            train = data
            regroup = cur == 1 or self.userGroups is None or self.itemGroups is None
        else:
            train = data.sample(frac=self.fraction, replace=False)
            regroup = True
        if regroup:
            self.userGroups = train.groupby('uid')
            self.itemGroups = train.groupby('iid')

        # Iterate the groups explicitly: each group frame then keeps its
        # uid / iid columns and every group is processed exactly once.
        for _, userG in self.userGroups:
            self.userUpdateWithFeature(userG)
        for _, itemG in self.itemGroups:
            self.itemUpdateWithFeature(itemG)

        if not self.intervention:
            self._update_hyperparameters(train)

    def initialization(self, data, test):
        """Run ``initialR`` pilot sweeps over the factors, then update the hyper-parameters."""
        print('initalization ...')
        self.userGroups = data.groupby('uid')
        self.itemGroups = data.groupby('iid')
        print("Iter    TrLL    TeLL   TrRMSE   TeRMSE   TrR2   TeR2")
        self._printInfo(0, test, self.evaluate(data))

        for cur in range(self.initialR):
            for _, userG in self.userGroups:
                self.userUpdateWithFeature(userG)
            for _, itemG in self.itemGroups:
                self.itemUpdateWithFeature(itemG)
            self._printInfo(cur + 1, test, self.evaluate(data))

        self._update_hyperparameters(data)
        print('finish initialization... {0}'.format(self.alpha))