In [None]:
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
def weighted_median(data, weights):
    """Weighted median of a 1-D tensor.

    The values are sorted (their weights travel with them) and the sorted
    weights are accumulated. The function then finds the last sorted position
    whose running weight is still at or below half of the total weight:

    * no such position -> the smallest value is returned;
    * running weight equals exactly half -> mean of that value and the next;
    * otherwise -> the value right after that position.

    Args:
        data: 1-D tensor of values.
        weights: 1-D tensor of non-negative weights aligned with ``data``.

    Returns:
        0-dim tensor holding the weighted median.
    """
    order = torch.argsort(data)
    sorted_vals = data[order]
    sorted_w = weights[order]
    half_total = 0.5 * sorted_w.sum()
    running = torch.cumsum(sorted_w, dim=0)

    at_or_below = torch.where(running <= half_total)[0]
    if len(at_or_below) == 0:
        # the very first value already carries more than half the weight
        return sorted_vals[0]

    pos = at_or_below[-1]
    if running[pos] != half_total:
        return sorted_vals[pos + 1]
    # exact tie at the half-way mark: split the difference
    return torch.mean(sorted_vals[pos:pos + 2])

In [None]:
class LMLFM:
    def __init__(self,n,m,p,alpha = 1,dtype=torch.float,device='cuda'):
        self.n = n
        self.m = m
        self.p = p
        self.dtype = dtype
        self.device = device
        
        self.mu_i = torch.zeros(p).type(dtype).to(device)
        self.mu_o = torch.zeros(p).type(dtype).to(device)
        self.Theta_i = torch.zeros([n,p]).type(dtype).to(device)
        self.Theta_o = torch.zeros([n,p]).type(dtype).to(device)
        self.b_i = torch.ones(p).type(dtype).to(device)
        self.b_o = torch.ones(p).type(dtype).to(device)
        self.b_mu_0 = self.b_b_0 = torch.ones(1).type(dtype).to(device)
        self.alpha_0 = self.beta_0 = torch.ones(1).type(dtype).to(device)
        self.alpha = alpha
        
        # constant
        self.zeros = torch.zeros([1,self.p]).type(self.dtype).to(self.device)
    
    def update_alpha(self,y,y_hat):
        residual = y - y_hat
        self.alpha = (self.alpha_0 + len(y)/2 - 1) / (self.beta_0 + residual @ residual / 2)
    
    def update_b(self,I):
        if I:
            tmp = (self.n**2 + 4/self.b_b_0 * torch.norm(self.Theta_i-self.mu_i[None,:],p=1,dim=0)).sqrt()
            self.b_i = 2 * self.b_b_0 * (tmp - self.n)
        else:
            tmp = (self.n**2 + 4/self.b_b_0 * torch.norm(self.Theta_o-self.mu_o[None,:],p=1,dim=0)).sqrt()
            self.b_o = 2 * self.b_b_0 * (tmp - self.n)
    
    def update_mu(self,I):
        if I:
            dta = torch.cat([self.Theta_i,self.zeros],dim=0)
            for k in range(self.p):
                weights = torch.ones_like(dta[:,k]) * self.b_i[k]
                weights[-1] = self.b_mu_0
                self.mu_i[k] = weighted_median(dta[:,k],weights)
        else:
            dta = torch.cat([self.Theta_o,self.zeros],dim=0)
            for k in range(self.p):
                weights = torch.ones_like(dta[:,k]) * self.b_o[k]
                weights[-1] = self.b_mu_0
                self.mu_o[k] = weighted_median(dta[:,k],weights)
    
    def update_theta(self,X,y,target,indexes,I):
        if I:
            theta = self.Theta_i[target].reshape(-1)
            g = torch.zeros_like(y)
            for i in range(len(g)):
                g[i] = X[i] @ self.Theta_o[indexes[i]]
#             g = torch.diag(X @ self.Theta_o[indexes].t()) # ni
            h = X + self.Theta_o[indexes] # ni * p
            y_hat = h @ theta + g
            col = y - y_hat
        
            newv = torch.zeros_like(theta)
            for k in torch.randperm(self.p):
                if self.b_i[k] == 0:
                    newv[k] = self.mu_i[k]
                else:
                    bottom = h[:,k] @ h[:,k]
                    col += h[:,k] * theta[k]
                    C = h[:,k] @ col
                    Ccheck = C - bottom * self.mu_i[k]
                    sub = 1 / (self.alpha * self.b_i[k])
                    if Ccheck >= -sub and Ccheck <= sub:
                        newv[k] = self.mu_i[k]
                    elif Ccheck > sub:
                        newv[k] = (C - sub) / bottom
                    else:
                        newv[k] = (C + sub) / bottom
                    col -= h[:, k] * newv[k]
            self.Theta_i[target] = newv
            return h @ newv + g
        else:
            theta = self.Theta_o[target].reshape(-1)
            g = torch.zeros_like(y)
            for i in range(len(g)):
                g[i] = X[i] @ self.Theta_i[indexes[i]]
#             g = torch.diag(X @ self.Theta_i[indexes].t()) # ni
            h = X + self.Theta_i[indexes] # ni * p
            y_hat = h @ theta + g
            col = y - y_hat
            
            newv = torch.zeros_like(theta)
            for k in torch.randperm(self.p):
                if self.b_o[k] == 0:
                    newv[k] = self.mu_o[k]
                else:
                    bottom = h[:,k] @ h[:,k]
                    col += h[:,k] * theta[k]
                    C = h[:,k] @ col
                    Ccheck = C - bottom * self.mu_o[k]
                    sub = 1 / (self.alpha * self.b_o[k])
                    if Ccheck >= -sub and Ccheck <= sub:
                        newv[k] = self.mu_o[k]
                    elif Ccheck > sub:
                        newv[k] = (C - sub) / bottom
                    else:
                        newv[k] = (C + sub) / bottom
                    col -= h[:, k] * newv[k]
            self.Theta_o[target] = newv
            return h @ newv + g
        
    def predict(self,X, target, indexes, I):
        if I:
            if target < self.n:
                theta = self.Theta_i[target].reshape(-1)
            else:
                theta = self.mu_i
            factors = torch.zeros_like(self.Theta_o[:len(indexes)])
            for i,each in enumerate(indexes):
                factors[i] = self.Theta_o[each] if each < self.m else self.mu_o
            g = torch.diag(X @ factors.t()) # ni
            h = X + factors # ni * p
            y_hat = h @ theta + g
        else:
            if target < self.m:
                theta = self.Theta_o[target].reshape(-1)
            else:
                theta = self.mu_o
            factors = torch.zeros_like(self.Theta_i[:len(indexes)])
            for i,each in enumerate(indexes):
                factors[i] = self.Theta_i[each] if each < self.n else self.mu_i
            g = torch.diag(X @ factors.t()) # ni
            h = X + factors # ni * p
            y_hat = h @ theta + g
        return y_hat
    
    def fixedEffect(self,rounding=2):
        mask = (self.b_i + self.b_o) == 0
        effects = self.mu_i + self.mu_o
        return np.round(effects.cpu().numpy(),rounding),mask
    
    def mapLoss(self, y, y_hat,eps=1e-6): # the larger the better, ascent property and convergence guarantee
        residual = y - y_hat
        sumloss = residual @ residual
        lly = 0.5 * len(y) * (self.alpha/2.51).log() - 0.5 * self.alpha * sumloss
#         pos = self.b_i > 0
#         llTheta_i = -(2*self.b_i[pos]).log().sum() * self.n - ((self.Theta_i[:,pos] - self.mu_i[None,pos]).abs() / self.b_i[None,pos]).sum()
        llTheta_i = -(2*self.b_i + eps).log().sum() * self.n - ((self.Theta_i - self.mu_i[None,:]).abs() / (self.b_i[None,:] + eps)).sum()
#         pos = self.b_o > 0
#         llTheta_o = -(2*self.b_o[pos]).log().sum() * self.m - ((self.Theta_o[:,pos] - self.mu_o[None,pos]).abs() / self.b_o[None,pos]).sum()
        llTheta_o = -(2*self.b_o + eps).log().sum() * self.m - ((self.Theta_o - self.mu_o[None,:]).abs() / (self.b_o[None,:] + eps)).sum()
        llmu_i = -self.mu_i.abs().sum() / self.b_mu_0
        llmu_o = -self.mu_o.abs().sum() / self.b_mu_0
        llb_i = -self.b_i.abs().sum() / self.b_b_0
        llb_o = -self.b_o.abs().sum() / self.b_b_0
        return (lly + llTheta_i + llTheta_o + llmu_i + llmu_o + llb_i + llb_o) / len(y)
    
    def bic(self, y, y_hat): # the smaller the better
        residual = y - y_hat
        sumloss = residual @ residual
        lly = 0.5 * len(y) * (self.alpha/2.51).log() - 0.5 * self.alpha * sumloss
        gt_zero = (self.Theta_i.abs()>0).sum() + (self.Theta_o.abs()>0).sum() + (self.mu_i.abs()>0).sum() \
        + (self.mu_o.abs()>0).sum() + (self.b_i>0).sum() + (self.b_o>0).sum()
        return -2 * lly + gt_zero * np.log(len(y))

In [None]:
# create dataset and dataloader
class LongitudinalData(Dataset):
    """Dataset whose items are all observations of one i-side id or one o-side id.

    Items ``0 .. n-1`` return the rows whose ``iids`` equal the item index.
    Items ``n .. n+m-1`` return the rows whose ``oids`` equal ``item - n``.
    This assumes ids are the integers ``0..n-1`` / ``0..m-1`` -- TODO confirm
    for new data sources. Each item is a dict with keys ``X``, ``y``,
    ``target`` (the id), ``indexes`` (the partner id of every returned row)
    and ``I`` (True for i-side items).

    Args:
        X: feature matrix indexable as ``X[rows, :]``.
        y: response vector indexable as ``y[rows]``.
        iids: i-side id of every row, indexable with an integer array.
        oids: o-side id of every row, indexable with an integer array.
    """

    def __init__(self, X, y, iids, oids):
        self.X = X
        self.y = y
        self.iids = iids
        self.oids = oids

        # row positions grouped by id
        self.mapI = defaultdict(list)
        self.mapO = defaultdict(list)
        for ind, (i, o) in enumerate(zip(iids, oids)):
            self.mapI[i].append(ind)
            self.mapO[o].append(ind)

        self.n = len(self.mapI)
        self.m = len(self.mapO)

    def _group(self, rows_by_id, key, partner_ids, is_i):
        """Build the item dict for id ``key`` of one side.

        A membership test (instead of ``rows_by_id[key]``) keeps the
        defaultdict from growing a phantom empty entry for unknown ids.
        """
        if key not in rows_by_id:
            raise KeyError(f"no observations for id {key}; ids are expected to be 0..n-1 / 0..m-1")
        rows = np.asarray(rows_by_id[key])
        return {
            'X': self.X[rows, :],
            'y': self.y[rows],
            'target': key,
            'indexes': partner_ids[rows],
            'I': is_i,
        }

    def __getitem__(self, idx):
        if idx < self.n:
            return self._group(self.mapI, idx, self.oids, True)
        return self._group(self.mapO, idx - self.n, self.iids, False)

    def __len__(self):
        return self.n + self.m

In [None]:
def train_lmlfm(lmlfm, train_loader, epochs):
    """Fit ``lmlfm`` with ``epochs`` passes of alternating block updates.

    Each epoch does the following:

    1. Refresh the factor row of every batch's target id via ``update_theta``.
    2. Update ``alpha`` from the fitted values collected over the whole epoch.
    3. Update both ``mu`` vectors, then both ``b`` vectors.

    The BIC is printed every 5 epochs and once more at the end.

    Batches are read as collated dicts with keys ``X``, ``y``, ``target``,
    ``indexes`` and ``I``. The ``[0]`` indexing assumes ``batch_size=1`` --
    TODO confirm.

    Args:
        lmlfm: model to train. Its ``dtype`` and ``device`` are used for the
            batch tensors (rather than notebook-global variables).
        train_loader: iterable of collated batches.
        epochs: number of passes; must be at least 1.

    Returns:
        ``(lmlfm, final_bic)`` where ``final_bic`` is a Python float.

    Raises:
        ValueError: if ``epochs < 1`` or ``train_loader`` yields no batches.
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    dtype, device = lmlfm.dtype, lmlfm.device

    for epoch in range(epochs):
        ys, y_hats = [], []
        for batch in train_loader:
            X = batch['X'][0].type(dtype).to(device)
            y = batch['y'][0].reshape(-1).type(dtype).to(device)
            target = batch['target']
            indexes = batch['indexes'].reshape(-1).type(torch.long).to(device)
            I = batch['I']
            y_hats.append(lmlfm.update_theta(X, y, target, indexes, I))
            ys.append(y)
        if not ys:
            raise ValueError("train_loader yielded no batches")
        # concatenate once per epoch instead of growing tensors batch by batch
        y_cat = torch.cat(ys, dim=0)
        y_hat_cat = torch.cat(y_hats, dim=0)

        lmlfm.update_alpha(y_cat, y_hat_cat)
        lmlfm.update_mu(True)
        lmlfm.update_mu(False)
        lmlfm.update_b(True)
        lmlfm.update_b(False)
        if epoch % 5 == 0:
            loss = lmlfm.bic(y_cat, y_hat_cat)
            print(f'epoch {epoch} finished! loss (-): {loss.item()}')

    loss = lmlfm.bic(y_cat, y_hat_cat)
    print(f'training finished! loss (-): {loss.item()}')
    return lmlfm, loss.item()

In [None]:
from sklearn.metrics import r2_score
def test_lmlfm(lmlfm, test):
    """Score ``lmlfm`` on every batch of ``test`` with sklearn's ``r2_score``.

    Args:
        lmlfm: trained model. Its ``predict`` produces the fitted values, and
            its ``dtype``/``device`` are used for the batch tensors.
        test: iterable of collated batches with the same dict layout as in
            ``train_lmlfm`` (``[0]`` indexing assumes ``batch_size=1``).
            Earlier this argument was ignored and the global ``train_loader``
            was scored instead.

    Returns:
        ``(r2score, y_cat, y_hat_cat)``: the R² value plus the concatenated
        targets and predictions.

    Raises:
        ValueError: if ``test`` yields no batches.
    """
    dtype, device = lmlfm.dtype, lmlfm.device
    ys, y_hats = [], []
    for batch in test:
        X = batch['X'][0].type(dtype).to(device)
        y = batch['y'][0].reshape(-1).type(dtype).to(device)
        target = batch['target']
        indexes = batch['indexes'].reshape(-1).type(torch.long).to(device)
        I = batch['I']
        y_hats.append(lmlfm.predict(X, target, indexes, I))
        ys.append(y)
    if not ys:
        raise ValueError("test loader yielded no batches")
    y_cat = torch.cat(ys, dim=0)
    y_hat_cat = torch.cat(y_hats, dim=0)
    r2score = r2_score(y_cat.cpu().numpy(), y_hat_cat.cpu().numpy())
    return r2score, y_cat, y_hat_cat

In [None]:
def fp_fn(lmlfm, groundtruth):
    """Count false positives / false negatives of the selected features.

    A feature counts as selected when the mask returned by
    ``lmlfm.fixedEffect()`` is False for it. For LMLFM that mask marks the
    columns with ``b_i + b_o == 0``. A feature is truly active when
    ``groundtruth != 0``.

    Args:
        lmlfm: object exposing ``fixedEffect() -> (effects, mask)``.
        groundtruth: per-feature true coefficients (array-like or tensor).

    Returns:
        ``(fp, fn)`` as Python ints, after printing them.

    Raises:
        ValueError: if the mask and ``groundtruth`` differ in shape (``zip``
            used to truncate silently).
    """
    def to_numpy(values):
        # handles CPU/GPU tensors as well as array-likes
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    _, eliminated = lmlfm.fixedEffect()
    selected = ~to_numpy(eliminated).astype(bool)
    active = to_numpy(groundtruth) != 0
    if selected.shape != active.shape:
        raise ValueError(
            f"mask shape {selected.shape} does not match groundtruth shape {active.shape}")
    fp = int(np.count_nonzero(selected & ~active))
    fn = int(np.count_nonzero(~selected & active))
    print(f'f.p. => {fp}, f.n. => {fn}')
    return fp, fn