In [None]:
from scipy.linalg import cho_factor,cho_solve,cholesky,inv,solve_triangular
import copy

    
class kernel_SE:
    """Squared-exponential (Gaussian / RBF) kernel.

    k(x, x') = k0 * exp(-|x - x'|^2 / (2 * length^2))
    """

    def __init__(self, length=1, k0=1):
        """The kernel as the squared exponential kernel or a gaussian distribution.

        Parameters
        ----------
        length : float
            Length scale dividing the squared distance.
        k0 : float
            Prefactor of the kernel; equals k(x, x).
        """
        self.length = length
        self.k0 = k0

    def matrix(self, x_data1, x_data2, dis_m=None):
        """Make the kernel matrix from two sets of data.

        Parameters
        ----------
        x_data1 : ndarray, shape (n, d)
        x_data2 : ndarray, shape (m, d)
        dis_m : ndarray, optional
            Precomputed coordinate differences whose squares are summed over
            axis 0 to give the squared distances (presumably shape (d, n, m);
            TODO confirm with callers). When given, x_data1/x_data2 are unused.

        Returns
        -------
        ndarray, shape (n, m)
        """
        if dis_m is None:
            # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b. The expansion is fast but
            # cancellation can leave tiny negative values, which would give
            # kernel entries above k0 and can spoil positive-definiteness.
            sq1 = np.sum(x_data1**2, axis=1).reshape(-1, 1)
            sq2 = np.sum(x_data2**2, axis=1)
            dist2 = sq1 + sq2 - 2 * np.dot(x_data1, x_data2.T)
            dist2 = np.maximum(dist2, 0.0)
        else:
            dist2 = np.sum(dis_m**2, axis=0)
        return self.k0 * np.exp(-0.5 * dist2 / (self.length**2))
    

class GP_reg:
    """Gaussian process regressor built on a pluggable kernel.

    The kernel object must provide ``matrix(x1, x2, dis_m=None)`` and the
    attributes ``length`` and ``k0`` (the hyperparameters tuned by ``lml`` /
    ``optimize``). Targets are modelled relative to the constant prior mean
    ``yp``.
    """

    def __init__(self, kernel=None, noise=0.0, yp=0):
        """The Gaussian process that uses a kernel.

        It can predict values if a training set and prediction features are
        given.

        Parameters
        ----------
        kernel : kernel object, optional
            Deep-copied so later hyperparameter changes do not leak back to
            the caller's object. Defaults to a fresh ``kernel_SE()``.
        noise : float
            Added to the diagonal of the training kernel matrix.
        yp : float or ndarray
            Prior mean, subtracted from targets before fitting and added back
            to predictions.
        """
        if kernel is None:
            # Build the default per instance rather than sharing one object
            # evaluated once at class-definition time.
            kernel = kernel_SE()
        self.kernel = copy.deepcopy(kernel)
        self.noise = noise
        self.yp = yp

    def train(self, x_train, y_train):
        """Train the Gaussian process with a training set and return ``self``.

        Stores a copy of ``x_train``, the Cholesky factorisation of
        ``K(X, X) + noise * I`` and the weights ``(K + noise*I)^-1 (y - yp)``
        used by ``predict``.
        """
        self.x_train = np.copy(x_train)
        y_tr = y_train - self.yp
        KXX = self.kernel.matrix(x_train, x_train)
        KXX = KXX + self.noise * np.identity(len(KXX))
        # cho_factor returns (factor, lower); with its default lower=False the
        # factor is stored in the upper triangle despite the attribute name L.
        self.L, self.low = cho_factor(KXX)
        self.coef = cho_solve((self.L, self.low), y_tr, check_finite=False)
        return self

    def predict(self, x_test):
        """Predict the posterior mean ``yp + K(x_test, X) @ coef``."""
        KQX = self.kernel.matrix(x_test, self.x_train)
        return self.yp + np.matmul(KQX, self.coef)

    def uncertainty(self, x_test):
        """Return the posterior standard deviation of the latent function.

        Uses ``kernel.k0`` as the prior variance k(x, x), which holds for a
        stationary kernel such as the SE kernel; observation noise is not
        included. Round-off can push the variance slightly below zero near
        training points, so it is clipped at 0 before the square root
        instead of producing NaN.
        """
        KQX = self.kernel.matrix(x_test, self.x_train)
        solved = cho_solve((self.L, self.low), KQX.T, check_finite=False)
        # Row-wise dot products = diag(KQX K^-1 KXQ) without the m x m product.
        var = self.kernel.k0 - np.einsum('ij,ji->i', KQX, solved)
        return np.sqrt(np.maximum(var, 0.0))

    def error(self, x_test, y_test):
        """Calculate the sum of squared prediction errors.

        Targets and predictions are flattened first so that a column-vector
        ``y_test`` is not broadcast against 1-D predictions into an m x m
        matrix.
        """
        resid = np.ravel(y_test) - np.ravel(self.predict(x_test))
        return np.sum(resid**2)

    def int_error(self, x_test, y_test):
        """Integrate the squared prediction error over a 1-D test grid.

        Applies the trapezoidal rule with ``x_test`` (one feature per row,
        e.g. shape (m, 1)) flattened into grid coordinates; the grid should be
        sorted ascending for a non-negative result.
        """
        sq_err = (np.ravel(y_test) - np.ravel(self.predict(x_test)))**2
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        return trapezoid(sq_err, np.ravel(x_test))

    def lml(self, theta, x_train, y_train, dis_m=None):
        """Return the log marginal likelihood for ``theta = (length, noise, k0)``.

        log p(y) = -1/2 [ r^T C^-1 r + log|C| + n log(2 pi) ]
        with r = y - yp and C = K(X, X) + noise * I.

        Side effect: the hyperparameters are written into ``self.kernel`` and
        ``self.noise`` (the model is not retrained). ``dis_m`` is forwarded to
        ``kernel.matrix``.
        """
        length, noise, k0 = theta
        self.kernel.length = length
        self.kernel.k0 = k0
        self.noise = noise
        resid = np.ravel(np.asarray(y_train) - self.yp)
        n = len(resid)
        C = self.kernel.matrix(x_train, x_train, dis_m=dis_m) + self.noise * np.identity(n)
        factor, lower = cho_factor(C)
        # Solve against the mean-centred targets so the quadratic form is r^T C^-1 r.
        alpha = cho_solve((factor, lower), resid, check_finite=False)
        data_fit = resid @ alpha
        # log|C| = 2 * sum(log(diagonal of the Cholesky factor)).
        log_det = 2.0 * np.sum(np.log(np.diagonal(factor)))
        return float(-0.5 * (data_fit + log_det + n * np.log(2 * np.pi)))

    def nmml(self, theta, x_train, y_train, dis_m=None):
        """Objective minimised by ``optimize``: ``-exp(lml / n)``.

        A monotone decreasing transform of the log marginal likelihood, so
        minimising it maximises ``lml``. Same side effects as ``lml``.
        """
        lml = self.lml(theta, x_train, y_train, dis_m=dis_m)
        return -np.exp(lml / len(y_train))

    def optimize(self, x_train, y_train, maxfun=2000, bounds=None):
        """Tune (length, noise, k0) with dual annealing, then retrain.

        Parameters
        ----------
        maxfun : int
            Maximum number of objective evaluations.
        bounds : array-like, shape (3, 2), optional
            Lower/upper bounds for (length, noise, k0). Defaults to
            [[1e-3, 1e3], [1e-5, 1e2], [1e-3, 1e3]].

        Returns
        -------
        ndarray
            The optimal (length, noise, k0).
        """
        # Local import so this method does not rely on a module-level scipy.
        import scipy.optimize

        if bounds is None:
            bounds = np.array([[1e-3, 1e3], [1e-5, 1e2], [1e-3, 1e3]])
        bounds = np.asarray(bounds, dtype=float)
        theta = [self.kernel.length, self.noise, self.kernel.k0]
        # The default noise of 0.0 lies below its lower bound; start inside the box.
        theta = np.clip(theta, bounds[:, 0], bounds[:, 1])
        sol = scipy.optimize.dual_annealing(
            self.nmml,
            x0=theta,
            bounds=bounds,
            maxfun=maxfun,
            args=(x_train, y_train, None),
        )
        theta_o = sol.x
        self.kernel.length = theta_o[0]
        self.kernel.k0 = theta_o[-1]
        self.noise = theta_o[1]
        self.train(x_train, y_train)
        return theta_o

    def update(self, length=None, k0=None, noise=None, yp=None):
        """Set any given hyperparameters and return ``self``.

        Does not retrain: call ``train`` afterwards so ``predict`` and
        ``uncertainty`` use the new values.
        """
        if length is not None:
            self.kernel.length = length
        if k0 is not None:
            self.kernel.k0 = k0
        if noise is not None:
            self.noise = noise
        if yp is not None:
            self.yp = yp
        return self
        
