# In [None]:
import numpy as np
import pandas as pd
from sklearn import metrics
from numba import jit
import scipy as sp

class LMLFM:
    """Latent factor model over (individual, observation) pairs with Laplace priors.

    ``theta`` stacks the latent vectors of all individuals (rows
    ``0 .. num_indv-1``) followed by those of all observations (rows
    ``num_indv .. num_indv+num_obsr-1``). Data frames passed to the model
    must provide the columns ``'iid'``, ``'oid'``, ``'label'`` and
    ``'feature'`` (one length-``num_feature`` vector per row). ``'iid'`` and
    ``'oid'`` are used directly as row indices into ``theta``, so ``'oid'``
    is presumably already offset by ``num_indv``; the priors treat rows
    ``num_indv:`` as observations. TODO confirm against the data preparation.

    The prediction for one row is
    ``theta[iid] @ theta[oid] + feature @ (theta[iid] + theta[oid])``.
    The model combines:

    * a Gaussian label likelihood with precision ``alpha`` (Gamma prior);
    * Laplace priors on the latent vectors with locations ``mu_i``/``mu_o``
      and per-feature scales ``b_i``/``b_o``;
    * inverse-gamma priors on the scales, whose scale hyper-parameters
      ``beta_bi``/``beta_bo`` are clamped to ``[alpha_beta0, beta_beta0]``.

    Note: these methods used to carry ``numba.jit()`` decorators. numba cannot
    type ``self`` (an arbitrary Python object). The decorators only ever worked
    through the deprecated object-mode fallback, and they fail on numba >= 0.59
    where ``jit`` defaults to nopython mode. The code is therefore plain
    Python/NumPy.
    """

    def __init__(self, num_indv: int, num_obsr: int, num_feature: int, theta=None, iterations: int = 20,
                 fraction: float = 1, verbose: bool = True, seed: int = 0, alpha: float = 1,
                 alpha_beta0: "float | None" = None, beta_beta0: float = 1e10, theta_update_order=None,
                 epsilon: float = 1e-5, initb: float = 1, pilot: bool = True, pilotMaxIter: int = 5):
        """
        Parameters
        ----------
        num_indv, num_obsr : number of individuals / observations (row blocks of ``theta``).
        num_feature : length of every latent vector and every feature vector.
        theta : optional initial latent matrix; defaults to zeros of shape
            ``(num_indv + num_obsr, num_feature)``.
        iterations : maximum number of sweeps performed by ``fit``.
        fraction : fraction of rows sampled (without replacement) per sweep; 1 uses all rows.
        verbose : print the full evaluation table after every sweep.
        seed : seed for NumPy's global RNG (row sampling and coordinate shuffling).
        alpha : initial precision of the Gaussian label noise.
        alpha_beta0, beta_beta0 : lower / upper clamp for ``beta_bi`` and ``beta_bo``;
            ``alpha_beta0`` defaults to 1e-4.
        theta_update_order : fixed order of coordinates visited by the latent updates;
            a fresh random permutation is used each time when None.
        epsilon : ``fit`` stops once the normalized log posterior changes by at most this.
        initb : initial value of every Laplace scale in ``b_i`` and ``b_o``.
        pilot : run ``initialization`` before the main sweeps in ``fit``.
        pilotMaxIter : number of latent-vector sweeps done by ``initialization``.
        """
        self.num_indv = num_indv
        self.num_obsr = num_obsr
        self.num_feature = num_feature
        self.iterations = iterations
        self.fraction = fraction
        self.verbose = verbose
        self.dim = num_indv + num_obsr
        self.pilot = pilot  # whether to use pilot runs
        self.pilotMaxIter = pilotMaxIter  # the iterations of pilot runs
        # Laplace scales of the individual / observation latent vectors
        self.b_i = np.repeat(initb, num_feature)
        self.b_o = np.repeat(initb, num_feature)
        # clamping range applied to the Laplace locations mu_i / mu_o
        self.alpha_mu0 = -1e10
        self.beta_mu0 = 1e10
        self.alpha = alpha

        # Previously a caller-supplied alpha_beta0 was dropped and the
        # attribute was never created, so updateBeta_b_uniform crashed.
        self.alpha_beta0 = 1e-4 if alpha_beta0 is None else alpha_beta0
        self.beta_beta0 = beta_beta0
        self.epsilon = epsilon
        np.random.seed(seed)
        self.theta_update_order = theta_update_order
        if theta is None:
            self.theta = np.zeros([self.dim, num_feature])
        else:
            self.theta = theta
        self.mu_i = np.zeros(num_feature)
        self.mu_o = np.zeros(num_feature)

        # Gamma(alpha_y0, beta_y0) prior on the noise precision alpha
        self.alpha_y0 = 1
        self.beta_y0 = 1

        # for inverse gamma prior of b_v
        self.alpha_b = 1  # this hyper parameter is fixed
        self.beta_bi = 1
        self.beta_bo = 1

        self.lastLoss = 0  # log posterior of the previous sweep, used for the convergence test

        # cached groupings of the training rows by individual / observation
        self.indvGroups = None
        self.obsrGroups = None

    def copyFrom(self, model):
        """
        Copy all the parameters from a pre-computed MixRFM model.

        NOTE(review): the source model names its scales ``b_u`` (copied to
        ``b_i``) and ``b_i`` (copied to ``b_o``); presumably u = individual and
        i = observation there. Verify against that class.
        """
        self.theta = np.copy(model.theta)
        self.b_i = np.copy(model.b_u)
        self.b_o = np.copy(model.b_i)
        self.beta_bi = model.beta_bi
        self.beta_bo = model.beta_bo
        self.alpha = model.alpha
        self.mu_i = model.mu_i
        self.mu_o = model.mu_o

    def _collect_predictions(self, obsrGroups):
        """Return ``(label, pred)`` concatenated over an iterable of ``(oid, group)`` pairs.

        Groups are visited in iteration order. Returns two empty float
        arrays when there are no groups.
        """
        labels, preds = [], []
        for _, obsrG in obsrGroups:
            labels.append(np.asarray(obsrG['label'].values))
            preds.append(self.predict_block(obsrG))
        if not preds:
            return np.array([]), np.array([])
        # a single concatenate instead of growing the arrays once per group
        return np.concatenate(labels), np.concatenate(preds)

    def evaluate(self, test):
        """
        Evaluate the current parameters on ``test``.

        Returns
        -------
        tuple
            ``(likelihood, rmse, r2)``: the log joint density normalized by
            ``num_indv * num_feature`` (see ``_log_fullPos``), the root mean
            squared prediction error, and ``sklearn.metrics.r2_score``.
        """
        label, pred = self._collect_predictions(test.groupby('oid'))
        likelihood, loss = self._log_fullPos(label, pred)
        r2 = metrics.r2_score(label, pred)
        return (likelihood, loss, r2)

    def _log_fullPos(self, label, pred):
        """Return ``(normalized log joint, rmse)`` for the given labels and predictions."""
        full_pos = sum(self.likelihood(label, pred, flag)
                       for flag in ('y', 'thetai', 'thetao', 'alpha', 'bi', 'bo'))
        likelihood_y = self.likelihood(label, pred, 'y')
        # invert likelihood_y = n/2*log(alpha) - alpha/2*SSE to recover the SSE
        sumloss = -(2 * likelihood_y - len(pred) * np.log(self.alpha)) / self.alpha
        loss = np.sqrt(sumloss / len(pred))
        return full_pos / (self.num_indv * self.num_feature), loss

    def _likelihood_data(self, test):
        """Return the Gaussian label log likelihood (up to constants) on ``test``."""
        label, pred = self._collect_predictions(test.groupby('oid'))
        return self.likelihood(label, pred, 'y')

    def likelihood(self, label, pred, flag):
        """
        Compute the log density (up to additive constants) of one model component.

        ``flag`` selects the component:

        * ``'y'`` — Gaussian labels with precision ``alpha``;
        * ``'thetai'``/``'thetao'`` — Laplace priors on the individual/observation latent vectors;
        * ``'bi'``/``'bo'`` — inverse-gamma priors on the scales;
        * ``'alpha'`` — Gamma prior on ``alpha``;
        * ``'betau'``/``'betai'`` — Gaussian terms on ``beta_bi``/``beta_bo``.

        Any other flag yields 0. ``label`` and ``pred`` are only read for ``'y'``.
        """
        if flag == 'y':
            col = label - pred
            sumloss = col @ col
            return len(pred) * np.log(self.alpha) * 0.5 - self.alpha * 0.5 * sumloss
        elif flag == 'thetai':
            rows = self.theta[:self.num_indv]
            return np.sum(-np.log(2 * self.b_i) - np.abs(rows - self.mu_i) / self.b_i)
        elif flag == 'thetao':
            rows = self.theta[self.num_indv:self.num_indv + self.num_obsr]
            return np.sum(-np.log(2 * self.b_o) - np.abs(rows - self.mu_o) / self.b_o)
        elif flag == 'bi':
            sum_log_bi = np.sum(np.log(self.b_i))
            return (- (self.alpha_b + 1) * sum_log_bi - self.beta_bi * np.sum(1 / self.b_i)
                    + self.num_feature * self.alpha_b * np.log(self.beta_bi))
        elif flag == 'bo':
            sum_log_bo = np.sum(np.log(self.b_o))
            return (- (self.alpha_b + 1) * sum_log_bo - self.beta_bo * np.sum(1 / self.b_o)
                    + self.num_feature * self.alpha_b * np.log(self.beta_bo))
        elif flag == 'alpha':
            return (self.alpha_y0 - 1) * np.log(self.alpha) - self.beta_y0 * self.alpha
        # NOTE(review): precision_beta and mu_beta are never assigned in this
        # class, so the two branches below raise AttributeError unless a caller
        # sets them externally.
        elif flag == 'betau':
            return - self.precision_beta * 0.5 * (self.beta_bi - self.mu_beta) ** 2
        elif flag == 'betai':
            return - self.precision_beta * 0.5 * (self.beta_bo - self.mu_beta) ** 2
        else:
            return 0

    def predict_w(self, data, theta):
        """Predict every row of ``data`` using the latent matrix ``theta`` (row by row)."""
        pred = np.zeros(data.shape[0])
        for i, (_, row) in enumerate(data.iterrows()):
            lv_indv = theta[row['iid']]
            lv_obsr = theta[row['oid']]
            pred[i] = lv_indv @ lv_obsr + row['feature'] @ (lv_indv + lv_obsr)
        return pred

    def updateAlpha(self, data):
        """Set ``alpha`` to the mode of its Gamma conditional posterior given current residuals.

        Uses the cached ``obsrGroups`` when available, otherwise groups ``data`` by ``'oid'``.
        """
        obsrGroups = data.groupby('oid') if self.obsrGroups is None else self.obsrGroups
        label, pred = self._collect_predictions(obsrGroups)
        sumOfError = 0.5 * np.sum(np.square(label - pred))
        # count the residuals actually summed, so the posterior stays consistent
        # even if the cached grouping was built from different rows than `data`
        self.alpha = (self.alpha_y0 + len(label) * 0.5 - 1) / (self.beta_y0 + sumOfError)

    def updateMu(self, isindv):
        """Set ``mu_i`` (or ``mu_o``) to the column-wise median of the latent vectors,
        clamped to ``[alpha_mu0, beta_mu0]``.

        For fixed scales, the median maximizes the Laplace prior term.
        """
        rows = self.theta[:self.num_indv] if isindv else self.theta[self.num_indv:]
        center = np.clip(np.median(rows, axis=0), self.alpha_mu0, self.beta_mu0)
        if isindv:
            self.mu_i = center
        else:
            self.mu_o = center

    def updateB_invGamma(self, isindv):
        """Set ``b_i`` (or ``b_o``) to the mode of its inverse-gamma conditional posterior."""
        if isindv:
            abs_dev = np.sum(np.abs(self.theta[:self.num_indv] - self.mu_i), axis=0)
            self.b_i = (self.beta_bi + abs_dev) / (self.alpha_b + self.num_indv + 1)
        else:
            abs_dev = np.sum(np.abs(self.theta[self.num_indv:] - self.mu_o), axis=0)
            self.b_o = (self.beta_bo + abs_dev) / (self.alpha_b + self.num_obsr + 1)

    def _clamp_beta(self, value):
        """Clamp ``value`` to ``[alpha_beta0, beta_beta0]``; values failing both comparisons (nan) map to the upper bound."""
        minv = self.alpha_beta0
        maxv = self.beta_beta0
        if minv <= value <= maxv:
            return value
        if value < minv:
            return minv
        return maxv

    def updateBeta_b_uniform(self, isindv):
        """Set ``beta_bi`` (or ``beta_bo``) to the maximizer of the ``'bi'``/``'bo'``
        log density, ``num_feature * alpha_b / sum(1 / b)``.

        The result is clamped to ``[alpha_beta0, beta_beta0]``.
        """
        if isindv:
            self.beta_bi = self._clamp_beta(self.num_feature * self.alpha_b / np.sum(1 / self.b_i))
        else:
            self.beta_bo = self._clamp_beta(self.num_feature * self.alpha_b / np.sum(1 / self.b_o))

    def _coordinate_order(self):
        """Return the coordinate visiting order: the fixed order if set, else a fresh random permutation."""
        if self.theta_update_order is None:
            indexes = np.arange(self.num_feature)
            np.random.shuffle(indexes)
            return indexes
        return self.theta_update_order

    def latentVectorUpdate(self, featureMatrix, factorMatrix, b, v, labels, residual, iid, mu):
        """
        Perform one coordinate-descent pass on a latent vector under a Laplace prior.

        Coordinate ``i`` minimizes
        ``alpha/2 * ||labels - h @ x - residual||^2 + |x_i - mu_i| / b_i``
        in closed form by soft-thresholding around ``mu_i``, where
        ``h = featureMatrix + factorMatrix``. ``iid`` is unused and kept for
        interface compatibility.

        Returns
        -------
        ndarray
            A new float vector. Coordinates not visited (partial
            ``theta_update_order``) keep their value from ``v``.
        """
        alpha = self.alpha
        h = featureMatrix + factorMatrix
        col = labels - (h @ v + residual)  # current residual

        # start from v (not zeros) so unvisited coordinates are not silently
        # zeroed while `col` still accounts for their old values
        newv = np.array(v, dtype=float)
        for i in self._coordinate_order():
            hi = h[:, i]
            bottom = hi @ hi
            col += hi * v[i]  # remove coordinate i's contribution
            if bottom > 0:
                C = hi @ col
                Ccheck = C - bottom * mu[i]
                sub = 1 / (alpha * b[i])
                if -sub <= Ccheck <= sub:
                    newv[i] = mu[i]
                elif Ccheck > sub:
                    newv[i] = (C - sub) / bottom
                else:
                    newv[i] = (C + sub) / bottom
                col -= hi * newv[i]
            else:
                # degenerate column (h[:, i] is all zeros): keep the old value;
                # `col` is unchanged since hi * v[i] == 0
                newv[i] = v[i]
        return newv

    def initLatentVectorUpdate(self, featureMatrix, factorMatrix, b, v, labels, residual, iid):
        """
        Perform one unpenalized coordinate-descent pass (plain least squares per coordinate).

        ``b`` and ``iid`` are unused and kept for interface compatibility.
        Coordinates with an all-zero design column keep their value from
        ``v`` instead of producing 0/0.
        """
        h = featureMatrix + factorMatrix
        col = labels - (h @ v + residual)

        newv = np.array(v, dtype=float)
        for i in self._coordinate_order():
            hi = h[:, i]
            bottom = hi @ hi
            if bottom <= 0:
                newv[i] = v[i]
                continue
            col += hi * v[i]
            newv[i] = (hi @ col) / bottom
            col -= hi * newv[i]
        return newv

    def indvUpdateWithFeature(self, indvG):
        """Update the latent vector of the single individual whose rows make up ``indvG``."""
        iid = indvG['iid'].values[0]
        obsrs = np.asarray(indvG['oid'].values)
        labels = indvG['label'].values
        obsrFactors = self.theta[obsrs]  # fancy indexing already copies
        featureMatrix = np.array(list(indvG['feature'])).reshape(len(obsrs), self.num_feature)
        # row-wise dot product obsrFactors[r] @ featureMatrix[r]
        residual = np.einsum('ij,ij->i', obsrFactors, featureMatrix)

        self.theta[iid] = self.latentVectorUpdate(featureMatrix, obsrFactors, self.b_i, self.theta[iid],
                                                  labels, residual, iid, self.mu_i)

    def obsrUpdateWithFeature(self, obsrG):
        """Update the latent vector of the single observation whose rows make up ``obsrG``."""
        oid = obsrG['oid'].values[0]
        indvs = np.asarray(obsrG['iid'].values)
        labels = obsrG['label'].values
        indvFactors = self.theta[indvs]
        featureMatrix = np.array(list(obsrG['feature'])).reshape(len(indvs), self.num_feature)
        residual = np.einsum('ij,ij->i', indvFactors, featureMatrix)

        self.theta[oid] = self.latentVectorUpdate(featureMatrix, indvFactors, self.b_o, self.theta[oid],
                                                  labels, residual, oid, self.mu_o)

    def predict_block(self, obsrG):
        """Predict all rows of ``obsrG``, which must share a single ``'oid'``."""
        oid = obsrG['oid'].values[0]
        indvs = np.asarray(obsrG['iid'].values)
        indvFactors = self.theta[indvs]
        featureMatrix = np.array(list(obsrG['feature'])).reshape(len(indvs), self.num_feature)
        lv_obsr = self.theta[oid]
        residual = np.einsum('ij,ij->i', indvFactors, featureMatrix)
        return (indvFactors + featureMatrix) @ lv_obsr + residual

    def _printInfo(self, i, test, curEva):
        """Print one row of the progress table; test columns show -1 when ``test`` is None."""
        testL = (-1, -1, -1)
        if test is not None:
            testL = self.evaluate(test)
        print("%d    %.4f    %.4f    %.4f    %.4f    %.4f    %.4f" % (i, curEva[0], testL[0], curEva[1], testL[1], curEva[2], testL[2]))

    def fit(self, data, test, theta=None):
        """
        Fit the model on ``data``, optionally reporting on ``test`` (may be None).

        Runs up to ``iterations`` sweeps of ``update``. Stops early once the
        normalized log posterior on ``data`` changes by at most ``epsilon``.
        Returns the fitted ``theta``.
        """
        if theta is not None:
            self.theta = theta
        if self.pilot:
            self.initialization(data, test)

        curEva = self.evaluate(data)
        self.lastLoss = curEva[0]
        if self.verbose:
            print("Iter    TrLL    TeLL   TrRMSE   TeRMSE   TrR2   TeR2")
            self._printInfo(0, test, curEva)

        i = 1
        while i <= self.iterations:
            if not self.verbose:
                # without the full table, still emit a one-line progress marker
                print("iteration: {0}".format(i))
            self.update(data, i)
            curEva = self.evaluate(data)
            if self.verbose:
                self._printInfo(i, test, curEva)
            if abs(curEva[0] - self.lastLoss) <= self.epsilon:
                print('convergence!')
                break
            self.lastLoss = curEva[0]
            i += 1
        if i > self.iterations:
            print("finish!")
        return self.theta

    def temporalPopulationEffects(self):
        """Return the observation latent vectors shifted by the mean individual latent vector."""
        mU = np.mean(self.theta[:self.num_indv], axis=0)
        return self.theta[self.num_indv:] + mU

    def subjectEffects(self):
        """Return the individual latent vectors shifted by the mean observation latent vector."""
        mO = np.mean(self.theta[self.num_indv:], axis=0)
        return self.theta[:self.num_indv] + mO

    def overallEffects(self):
        """Return the sum of the mean individual and mean observation latent vectors."""
        return np.mean(self.theta[:self.num_indv], axis=0) + np.mean(self.theta[self.num_indv:], axis=0)

    def fixedEffects(self):
        """Return the sum of the Laplace prior locations ``mu_i + mu_o``."""
        return self.mu_i + self.mu_o

    def update(self, data, cur):
        """
        Perform one full sweep: all latent vectors, then every hyper-parameter.

        ``cur`` is the 1-based sweep number. With ``fraction == 1`` the
        groupings are rebuilt on the first sweep (or when missing) and reused
        afterwards. Otherwise a fresh row sample is drawn every sweep.
        """
        if self.fraction == 1:
            train = data
            if cur == 1 or self.indvGroups is None or self.obsrGroups is None:
                self.indvGroups = train.groupby('iid')
                self.obsrGroups = train.groupby('oid')
        else:
            train = data.sample(frac=self.fraction, replace=False)
            self.indvGroups = train.groupby('iid')
            self.obsrGroups = train.groupby('oid')

        # Iterate explicitly rather than via groupby(...).apply: these are
        # side-effecting updates that read the grouping columns, which pandas
        # >= 2.2 deprecates passing to apply().
        for _, indvG in self.indvGroups:
            self.indvUpdateWithFeature(indvG)
        for _, obsrG in self.obsrGroups:
            self.obsrUpdateWithFeature(obsrG)
        self.updateMu(True)
        self.updateMu(False)
        self.updateAlpha(train)
        self.updateB_invGamma(True)
        self.updateB_invGamma(False)
        self.updateBeta_b_uniform(True)
        self.updateBeta_b_uniform(False)

    def initialization(self, data, test):
        """Run ``pilotMaxIter`` latent-vector sweeps on all of ``data``, then update every hyper-parameter once."""
        print('initalization ...')
        self.indvGroups = data.groupby('iid')
        self.obsrGroups = data.groupby('oid')
        print("Iter    TrLL    TeLL   TrRMSE   TeRMSE   TrR2   TeR2")
        curEva = self.evaluate(data)
        self._printInfo(0, test, curEva)

        for cur in range(self.pilotMaxIter):
            for _, indvG in self.indvGroups:
                self.indvUpdateWithFeature(indvG)
            for _, obsrG in self.obsrGroups:
                self.obsrUpdateWithFeature(obsrG)

            curEva = self.evaluate(data)
            self._printInfo(cur + 1, test, curEva)

        # update the prior locations, the noise precision and the scale hyper-parameters
        self.updateMu(True)
        self.updateMu(False)

        self.updateAlpha(data)

        self.updateB_invGamma(True)
        self.updateB_invGamma(False)

        self.updateBeta_b_uniform(True)
        self.updateBeta_b_uniform(False)
        print('finish initialization...')