In [1]:
import sys
import time
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
from torch import nn
from torch.functional import F
import torch.nn.utils.parametrize as parametrize
from torch.nn.utils.parametrizations import _Orthogonal

In [2]:
class Model(nn.Module):
    """Holds the data and the variational parameters of the linear model.

    The q density covers ``nw + 2`` quantities.  The functions in this file
    treat entries 0-1 as the two scale parameters (mean ``scaleSqMu``) and
    entries 2 onward as the regression coefficients (mean ``betaMu``).  The
    covariance they build is ``diag(exp(paramVar)) + V diag(exp(C)) V^T``.

    Arguments:
        trainX, testX: feature tensors; ``nw`` is the second dimension of ``trainX``.
        trainY, testY: response tensors; their first dimension gives ``trainN`` / ``testN``.
        numComponents: number of low rank components (columns of ``V``), stored as ``nc``.
        numCESamples: number of stochastic samples of the cross entropy, stored as ``ns``.
        numNLLSamples: number of stochastic samples of the negative log likelihood,
            stored as ``nsNLL``.
        xiVal: initial value of the ``xi`` parameter.
    """

    def __init__(self,trainX,testX,trainY,testY,numComponents,numCESamples,numNLLSamples,xiVal):
        super().__init__()

        # Data sets and the sizes derived from them
        self.X,self.testX = trainX,testX
        self.Y,self.testY = trainY,testY
        self.trainN = self.Y.shape[0]
        self.testN = self.testY.shape[0]
        self.nw = self.X.shape[1]

        # Low rank size and Monte Carlo sample counts
        self.nc = numComponents
        self.ns = numCESamples
        self.nsNLL = numNLLSamples

        numParams = self.nw + 2
        initDist = torch.distributions.Uniform(-0.0001,0.0001)

        # Variational means.  Parameters are registered in the same order as before
        # because optimizer checkpoints index parameters by position.
        self.betaMu = nn.Parameter(torch.squeeze(initDist.sample((1,self.nw)).double()))
        self.scaleSqMu = nn.Parameter(torch.tensor([math.log(0.000025)-0.5,-0.5],dtype=torch.double))
        self.xi = nn.Parameter(torch.tensor(xiVal,dtype=torch.double))

        # Diagonal component of the covariance.
        # NOTE(review): every consumer exponentiates paramVar, so 0.0001 gives an initial
        # variance of about 1 - confirm this is intended rather than log(0.0001).
        self.paramVar = nn.Parameter(torch.full((numParams,),0.0001,dtype=torch.double))

        # Low rank factor and the logs of its per-component scales
        self.V = nn.Parameter(initDist.sample((numParams,self.nc)).double())
        self.C = nn.Parameter(torch.full((self.nc,),-10.0,dtype=torch.double))

        # Keep V orthogonal.  Householder reflections are used instead of the matrix
        # exponential or Cayley map because they are computationally cheaper.
        parametrize.register_parametrization(self,"V",_Orthogonal(self.V,orthogonal_map="householder"))

        # Scale parameter priors
        self.lambdaPrior = torch.tensor(1.0,dtype=torch.double)
        self.regVarPrior = torch.tensor(1.0,dtype=torch.double)

        # Training metrics
        self.listTrainMSE,self.listTestMSE = [],[]
        self.listTrainR2,self.listTestR2 = [],[]
        

def sampleWeights(mod):
    """Monte Carlo average of the regression weights under the q density.

    For each of ``mod.ns`` draws of the two scale parameters, the coefficient
    mean conditional on that draw is computed and every coefficient except the
    first (the intercept) is multiplied by ``exp(0.5 * xi * lsParams[0])``.

    Side effects: stores ``mod.paramCov`` and ``mod.stdNormalDistNP``.

    Returns:
        Tensor of shape ``(mod.nw, 1)`` holding the averaged weights.
    """
    lowRankScale = torch.diag(torch.exp(mod.C))
    mod.paramCov = (torch.diag(torch.exp(mod.paramVar))
                    + torch.matmul(torch.matmul(mod.V,lowRankScale),torch.transpose(mod.V,0,1)))
    lsCov = mod.paramCov[0:2,0:2]

    # None of these depend on the individual draw, so build them once.
    mod.stdNormalDistNP = torch.distributions.MultivariateNormal(mod.scaleSqMu,covariance_matrix = lsCov)
    lsCovInv = torch.inverse(lsCov)
    crossCov = mod.paramCov[2:,0:2]
    scaleMu = torch.unsqueeze(mod.scaleSqMu,1)
    betaMu = torch.unsqueeze(mod.betaMu,1)

    # Explicit double dtype: a default (float32) ones tensor multiplied by the 0-dim
    # double scale factor would produce a float32 result and round the factor.
    interceptScale = torch.ones(1,dtype=torch.double)
    coefOnes = torch.ones(mod.nw-1,dtype=torch.double)

    predWeights = 0.0
    for _ in range(mod.ns):

        # Sample the two scale parameters
        lsParams = mod.stdNormalDistNP.rsample()

        # Coefficient mean conditional on the sampled scale parameters
        predBeta = betaMu + torch.matmul(crossCov,
                    torch.matmul(lsCovInv,torch.unsqueeze(lsParams,1) - scaleMu))

        # The intercept is left unscaled; the remaining coefficients share one factor
        scaleParams = torch.unsqueeze(torch.cat((interceptScale,
                      torch.exp(0.5*mod.xi*lsParams[0])*coefOnes),0),1)
        predWeights = predWeights + torch.mul(predBeta,scaleParams)

    return (1.0/mod.ns)*predWeights


def compute_posterior_means(mod,num_samples):
    """Posterior summaries of the regression coefficients, accounting for partial centering.

    Draws ``num_samples`` samples of the two scale parameters from the q density.
    For each draw it forms the coefficient mean conditional on that draw and
    multiplies every coefficient except the first (the intercept) by
    ``exp(0.5 * xi * lsParams[0])``.

    Arguments:
        mod: the fitted model.  Only this argument is read; the function does not
            rely on a notebook-level ``model`` variable.
        num_samples: number of draws (must be at least 1).

    Returns:
        ``(posterior_mean, ci_lower, ci_upper)``, each of shape ``(mod.nw, 1)``.
        The bounds are the 2.5% and 97.5% sample quantiles.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    with torch.no_grad():
        # Parameter covariance: diagonal plus low rank
        V = mod.V
        paramCov = (torch.diag(torch.exp(mod.paramVar))
                    + torch.matmul(torch.matmul(V,torch.diag(torch.exp(mod.C))),torch.transpose(V,0,1)))

        # Blocks for the scale parameters (entries 0-1) and their covariance with the coefficients
        lsCov = paramCov[0:2,0:2]
        crossCov = paramCov[2:,0:2]
        lsCovInv = torch.inverse(lsCov)
        scaleMu = torch.unsqueeze(mod.scaleSqMu,1)
        betaMu = torch.unsqueeze(mod.betaMu,1)
        dist = torch.distributions.MultivariateNormal(mod.scaleSqMu,covariance_matrix=lsCov)

        # Explicit double dtype so the 0-dim scale factor is not rounded to float32
        interceptScale = torch.ones(1,dtype=torch.double)
        coefOnes = torch.ones(mod.nw - 1,dtype=torch.double)

        all_samples = torch.zeros((num_samples,mod.nw),dtype=torch.double)
        for i in range(num_samples):
            # Sample scale parameters
            lsParams = dist.rsample()

            # Conditional mean of the coefficients given the sampled scale parameters
            predBeta = betaMu + torch.matmul(crossCov,
                        torch.matmul(lsCovInv,torch.unsqueeze(lsParams,1) - scaleMu))

            # Partially center the regression coefficients but not the intercept
            scaleParams = torch.unsqueeze(torch.cat((interceptScale,
                                                     torch.exp(0.5*mod.xi*lsParams[0])*coefOnes),0),1)

            all_samples[i,:] = torch.mul(predBeta,scaleParams).squeeze()

    # Summary statistics over the draws
    posterior_mean = all_samples.mean(dim=0,keepdim=True).T
    ci_lower = torch.quantile(all_samples,0.025,dim=0,keepdim=True).T
    ci_upper = torch.quantile(all_samples,0.975,dim=0,keepdim=True).T

    return posterior_mean,ci_lower,ci_upper


def NLL(mod,startBatch,endBatch):
    """Monte Carlo estimate of the negative log likelihood of rows ``startBatch:endBatch``.

    For each of ``mod.nsNLL`` repetitions, one pair of scale parameters is drawn
    per observation.  The coefficient distribution conditional on that pair is
    projected onto the (partially centred) features to give a mean and a standard
    deviation of the linear predictor.  A predictor is sampled and the Gaussian
    negative log density of the response is accumulated, using ``lsParams[:,1]``
    as the log variance.

    With covariance ``D + V C V^T`` and index blocks s = 0:2, b = 2:, the variance
    of the predictor ``x^T beta`` given the scale parameters is
    ``x^T (S_bb - S_bs S_ss^-1 S_sb) x``.  ``S_ss^-1`` is expanded with the Woodbury
    identity, using ``K = C^-1 + V_s^T D_s^-1 V_s``, which gives
    ``term1 + term2 - term3 + term4``.

    Side effects: stores ``mod.paramCov``, ``mod.stdNormalDistNP`` and
    ``mod.stdNormalDistB`` (the predictor distribution of the last repetition).

    Returns:
        0-dim tensor: the summed negative log likelihood of the batch, averaged
        over the ``mod.nsNLL`` repetitions.
    """
    batchN = endBatch - startBatch

    # mod.V is read once and reused for every product below
    V = mod.V
    Vs = V[0:2,:]
    Vb = V[2:,:]

    expC = torch.exp(mod.C)
    diagC = torch.diag(expC)
    diagC_sqrt = torch.diag(torch.sqrt(expC))
    mod.paramCov = torch.diag(torch.exp(mod.paramVar)) + torch.matmul(torch.matmul(V,diagC),torch.transpose(V,0,1))
    lsCov = mod.paramCov[0:2,0:2]

    # K matrix of the Woodbury expansion and its lower Cholesky factor (K = LK LK^T)
    invDs = torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-1.0))
    K = (torch.diag(torch.pow(expC,-1.0))
         + torch.matmul(torch.matmul(torch.transpose(Vs,0,1),invDs),Vs))
    LK = torch.linalg.cholesky(K)
    mod.stdNormalDistNP = torch.distributions.MultivariateNormal(mod.scaleSqMu,covariance_matrix = lsCov)

    # Quantities that do not depend on the sampled scale parameters
    lsCovInv = torch.inverse(lsCov)
    crossCov = mod.paramCov[2:,0:2]
    betaMu = torch.unsqueeze(mod.betaMu,1)
    scaleMu = torch.unsqueeze(mod.scaleSqMu,1)
    sqrtDb = torch.diag(torch.exp(0.5*mod.paramVar[2:]))
    invSqrtDs = torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-0.5))
    VsSqrtC = torch.matmul(Vs,diagC_sqrt)
    sqrtCVbT = torch.matmul(diagC_sqrt,torch.transpose(Vb,0,1))
    woodburyLeft = torch.matmul(torch.matmul(torch.transpose(Vs,0,1),invDs),VsSqrtC)

    batchX = mod.X[startBatch:endBatch,:]
    batchY = mod.Y[startBatch:endBatch,:]

    # Negative log likelihood accumulated over repeated samples from the q density
    nll = 0.0
    for _ in range(mod.nsNLL):

        # One draw of the two scale parameters per observation
        lsParams = mod.stdNormalDistNP.rsample(sample_shape=torch.Size([batchN]))

        # Mean of beta conditional on the sampled scale parameters (one column per observation)
        predBeta = betaMu + torch.matmul(crossCov,
                    torch.matmul(lsCovInv,torch.transpose(lsParams,0,1) - scaleMu))

        # First column is left as is; the others are scaled by exp(0.5 * xi * lsParams[:,0])
        scaleData = torch.cat((batchX[:,0:1],
                    torch.unsqueeze(torch.exp(0.5*mod.xi*lsParams[:,0]),1)*batchX[:,1:]),1)
        scaleDataT = torch.transpose(scaleData,0,1)
        predMu = torch.sum(torch.mul(torch.transpose(predBeta,0,1),scaleData),dim=1)

        # sqrt(C) V_b^T x for every observation
        projX = torch.matmul(sqrtCVbT,scaleDataT)

        # M^T K^-1 M = ||LK^-1 M||^2 because K = LK LK^T, so the solve must use LK
        # itself.  Solving against LK^T would give M^T (LK^T LK)^-1 M instead.
        Y = torch.linalg.solve(LK,torch.matmul(woodburyLeft,projX))

        term1 = torch.sum(torch.square(torch.matmul(sqrtDb,scaleDataT)),dim=0)
        term2 = torch.sum(torch.square(projX),dim=0)
        term3 = torch.sum(torch.square(torch.matmul(invSqrtDs,torch.matmul(VsSqrtC,projX))),dim=0)
        term4 = torch.sum(torch.square(Y),dim=0)
        sdB = torch.sqrt(term1 + term2 - term3 + term4)

        # Sample the linear predictor and accumulate the Gaussian negative log density
        mod.stdNormalDistB = torch.distributions.Normal(predMu,sdB)
        sampleB = mod.stdNormalDistB.rsample()
        SE = torch.square(batchY - torch.unsqueeze(sampleB,1))
        nll = (nll + torch.sum(torch.mul(0.5*torch.unsqueeze(torch.exp(-1.0*lsParams[:,1]),1),SE)
            + 0.5*(math.log(2.0*math.pi) + torch.unsqueeze(lsParams[:,1],1))))

    return (1.0/mod.nsNLL)*nll


def crossEntropy(mod):
    """Monte Carlo estimate of the cross entropy between the q density and the prior.

    ``mod.ns`` joint draws of all ``nw + 2`` deviations are taken from the
    diagonal-plus-low-rank covariance and added to the variational means.  The
    negative log prior densities are averaged over the draws and summed.  They
    cover the two scale parameters (with their Jacobian factors), the intercept
    and the remaining coefficients.

    Side effects: stores ``mod.stdNormalDistNS1`` and ``mod.stdNormalDistNS2``.

    Returns:
        0-dim tensor with the estimated cross entropy.
    """
    numParams = mod.nw + 2
    avgNeg = -1.0/mod.ns
    log2Pi = math.log(2.0*math.pi)

    ## Standard normal sources for the diagonal and the low rank part
    mod.stdNormalDistNS1 = torch.distributions.Normal(torch.zeros((mod.ns,numParams),dtype=torch.double),
                                                      torch.ones((mod.ns,numParams),dtype=torch.double))
    mod.stdNormalDistNS2 = torch.distributions.Normal(torch.zeros((mod.nc),dtype=torch.double),
                                                      torch.ones((mod.nc),dtype=torch.double))

    # Diagonal part of the deviation (drawn first, as before)
    diagPart = torch.mul(torch.sqrt(torch.exp(mod.paramVar)),mod.stdNormalDistNS1.sample())

    # Low rank part of the deviation: V sqrt(C) z
    sqrtC = torch.sqrt(torch.exp(mod.C))
    lowRankNoise = mod.stdNormalDistNS2.sample(sample_shape=torch.Size([mod.ns]))
    lowRankPart = torch.matmul(mod.V,torch.mul(sqrtC.unsqueeze(1),lowRankNoise.T))
    deviation = diagPart + lowRankPart.T

    logLambdaSq = mod.scaleSqMu[0] + deviation[:,0]
    logSigmaSq = mod.scaleSqMu[1] + deviation[:,1]

    total = 0.0

    # Log prior density of lambdaSq multiplied by Jacobian determinant
    lam = torch.exp(0.5*logLambdaSq)
    total = (total + avgNeg*torch.sum(torch.log(torch.mul(0.5*lam,
            ((2.0/(math.pi)))/(1.0 + torch.square(lam))))))

    # Log prior density of sigmaSq multiplied by Jacobian determinant
    sig = torch.exp(logSigmaSq)
    total = (total + avgNeg*torch.sum(torch.log(torch.mul(sig,
            ((2.0/(math.pi)))/(1.0 + torch.square(sig))))))

    # Log prior density of the intercept (standard normal)
    intercept = mod.betaMu[0] + deviation[:,2]
    total = total + avgNeg*torch.sum(-0.5*torch.square(intercept))
    total = total + avgNeg*(-1.0*mod.ns)*0.5*log2Pi

    # Log prior density of the remaining coefficients; the prior variance is
    # exp((1 - xi) * log lambdaSq) * exp(log sigmaSq), shared by all of them
    lamPow = torch.exp((1.0-mod.xi)*logLambdaSq)
    coefs = mod.betaMu[1:] + deviation[:,3:]
    coefVar = torch.unsqueeze(lamPow*sig,1)
    total = (total + avgNeg*torch.sum(torch.mul(-0.5*torch.square(coefs),
             torch.pow(coefVar,-1.0))))
    total = (total + avgNeg*torch.sum(-0.5*(mod.nw-1)*(log2Pi
          + torch.log(lamPow*sig))))

    return total


def entropy(mod):
    """Entropy of the Gaussian q density with covariance ``D + V C V^T``.

    Here ``D = diag(exp(paramVar))`` and ``C = diag(exp(mod.C))``.  By the matrix
    determinant lemma,

        log|D + V C V^T| = sum(paramVar) + log|I + V^T D^-1 V C|

    so the entropy is ``0.5*d*(1 + log(2*pi)) + 0.5*log|D + V C V^T|`` with
    ``d = nw + 2``.  The identity above already accounts for ``C``.  Adding a
    separate ``0.5*sum(mod.C)`` term would count ``log|C|`` twice and overstate
    the entropy.

    Returns:
        0-dim tensor with the entropy.
    """
    numParams = mod.nw + 2
    V = mod.V
    diagC = torch.diag(torch.exp(mod.C))
    invD = torch.diag(torch.pow(torch.exp(mod.paramVar),-1.0))

    # I + V^T D^-1 V C, an nc x nc matrix
    capacitance = (torch.eye(mod.nc,dtype=V.dtype,device=V.device)
                   + torch.matmul(torch.matmul(torch.transpose(V,0,1),invD),torch.matmul(V,diagC)))

    entVal = (0.5*numParams + 0.5*numParams*math.log(2.0*math.pi) + 0.5*torch.sum(mod.paramVar)
           + 0.5*torch.logdet(capacitance))
    return entVal

In [3]:
def trainModel(mod,opt,maxEpochs,batchSize,intervalToPrint):
    """Trains the linear model using Stochastic Variational Inference.

    Every epoch reshuffles ``mod.X`` / ``mod.Y`` in place and takes one optimizer
    step per full batch; rows left over after ``floor(trainN / batchSize)``
    batches are not used in that epoch.  After every step ``mod.xi`` is clamped
    to the interval [0, 1].  Every ``intervalToPrint`` epochs the full-data
    objective and the train/test MSE are evaluated and printed.

    Arguments:
        mod: the model to train (modified in place).
        opt: optimizer over the model parameters.
        maxEpochs: the maximum number of training epochs to use.
        batchSize: number of rows per optimizer step.
        intervalToPrint: evaluation / printing period in epochs.

    Returns:
        ``(mod, nelboList, nllList, klList)``.  Each list starts with a
        ``sys.float_info.max`` placeholder, followed by one entry per evaluation.
    """
    trainN = mod.Y.size(dim=0)
    numBatches = math.floor(trainN/batchSize)

    # Metric histories, each seeded with a placeholder entry
    nelboList = [sys.float_info.max]
    nllList = [sys.float_info.max]
    klList = [sys.float_info.max]

    for epoch in range(maxEpochs):
        # Reshuffle the training rows
        perm = torch.randperm(mod.Y.size(dim=0))
        mod.Y = mod.Y[perm].view(mod.Y.shape)
        mod.X = mod.X[perm].view(mod.X.shape)

        for batch in range(numBatches):
            firstRow = batch*batchSize
            opt.zero_grad()
            # Per-observation objective: batch NLL plus the KL term spread over the data set
            NLLval = (1.0/batchSize)*NLL(mod,firstRow,(batch + 1)*batchSize)
            entVal = (1.0/trainN)*entropy(mod)
            ceVal = (1.0/trainN)*crossEntropy(mod)
            loss = NLLval + ceVal - entVal
            loss.backward()
            opt.step()
            with torch.no_grad():
                # Prior parameterization must be between 0 and 1 inclusive
                mod.xi.clamp_(min = 0.0,max = 1.0)

        if epoch % intervalToPrint != 0:
            continue

        with torch.no_grad():
            # Full-data objective
            NLLval = NLL(mod,0,trainN)
            entVal = entropy(mod)
            ceVal = crossEntropy(mod)
            priorKL = ceVal - entVal
            trainNELBO = NLLval + priorKL
            nelboList.append(trainNELBO)
            nllList.append(NLLval)
            klList.append(priorKL)

            # Predictive accuracy with Monte Carlo averaged weights
            weights = sampleWeights(mod)
            trainMSE = torch.mean(torch.square(mod.Y - torch.matmul(mod.X,weights)))
            testMSE = torch.mean(torch.square(mod.testY - torch.matmul(mod.testX,weights)))

        print('Epoch {}, Neg Train ELBO {}, PriorKL {}, Train MSE {}, Test MSE {}'.format(epoch,trainNELBO,priorKL,trainMSE,testMSE))
        # Marginal variances from the diagonal plus low rank covariance structure
        marginalVar = torch.exp(mod.paramVar) + torch.sum(torch.square(mod.V)*torch.exp(mod.C).unsqueeze(0),dim=1)
        print('Lambda Sq: {:.4e}'.format(torch.exp(mod.scaleSqMu[0] + 0.5*marginalVar[0])))
        print('Sigma Sq: {:.4f}'.format(torch.exp(mod.scaleSqMu[1] + 0.5*marginalVar[1])))
        print(mod.xi)

    return mod,nelboList,nllList,klList

In [4]:
def _csvToTensor(csvPath):
    """Read a CSV file, discard its first column and return the rest as a double tensor."""
    frame = pd.read_csv(csvPath, index_col = False)
    frame = frame.drop(frame.columns[0], axis = 1)
    return torch.tensor(frame.values, dtype = torch.double)


def _withOnesColumn(features):
    """Prepend a column of ones to a 2-D feature tensor."""
    return torch.cat((torch.ones((features.size(dim=0),1)),features),1)


# Features get a leading column of ones; responses are used as read
trainFeatures = _withOnesColumn(_csvToTensor("trainImagingFeatures.csv"))
testFeatures = _withOnesColumn(_csvToTensor("testImagingFeatures.csv"))
trainResponse = _csvToTensor("train_y_residualized.csv")
testResponse = _csvToTensor("test_y_residualized.csv")

for loadedTensor in (trainFeatures,testFeatures,trainResponse,testResponse):
    print(loadedTensor.shape)

print(f"Training size: {trainFeatures.shape[0]}")
print(f"Testing size: {testFeatures.shape[0]}")

for response in (trainResponse,testResponse):
    print(torch.std(response))

torch.Size([2700, 1772])
torch.Size([300, 1772])
torch.Size([2700, 1])
torch.Size([300, 1])
Training size: 2700
Testing size: 300
tensor(1.1458, dtype=torch.float64)
tensor(1.1787, dtype=torch.float64)


In [None]:
# Run configuration for the first training phase
SEED = 0
intervalToPrint = 10
batchSize = 256
numComponents = 16
numCESamples = 128
numNLLSamples = 2
xiVal = 1.0

# Parameter initialisation, shuffling and every Monte Carlo estimate draw from torch's
# global generator, so seed it before the model is built to make the run repeatable.
torch.manual_seed(SEED)

# initialize model
model = Model(trainFeatures,testFeatures,trainResponse,testResponse,numComponents,numCESamples,numNLLSamples,xiVal)

# first train at low batch size
maxEpochs = 50
optimizer = torch.optim.Adam(model.parameters(),lr = 0.025)
model,nelboList,nllList,klList = trainModel(model,optimizer,maxEpochs,batchSize,intervalToPrint)

# Save the model, the optimizer state and the metric histories for the next phase
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'nelboList': nelboList,
            'nllList': nllList,
            'klList': klList},'model_and_metrics.pt')

In [None]:
# Run configuration for the second training phase
intervalToPrint = 10
batchSize = 512
numComponents = 16
numCESamples = 128
numNLLSamples = 2
xiVal = 1.0
learningRate = 0.005

# Load model
checkpoint = torch.load('model_and_metrics.pt')
model = Model(trainFeatures,testFeatures,trainResponse,testResponse,
              numComponents,numCESamples,numNLLSamples,xiVal)
model.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.Adam(model.parameters(),lr = learningRate)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# load_state_dict also restores the checkpointed param_groups, including the learning
# rate that was saved with them.  Re-apply the rate wanted for this phase.
for paramGroup in optimizer.param_groups:
    paramGroup['lr'] = learningRate

# Metric histories from the earlier phase(s)
prevNelboList = checkpoint['nelboList']
prevNllList = checkpoint['nllList']
prevKlList = checkpoint['klList']

# Continue training model
maxEpochs = 60
model,newNelboList,newNllList,newKlList = trainModel(model,optimizer,maxEpochs,batchSize,intervalToPrint)

# trainModel starts each list with a sys.float_info.max placeholder; drop it and append
# the new evaluations to the loaded histories so the earlier phase is not lost.
nelboList = prevNelboList + newNelboList[1:]
nllList = prevNllList + newNllList[1:]
klList = prevKlList + newKlList[1:]

# Save
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'nelboList': nelboList,
            'nllList': nllList,
            'klList': klList},'model_and_metrics.pt')

In [5]:
# Settings must match the ones the checkpoint was trained with
intervalToPrint = 10
batchSize = 512
numComponents = 16
numCESamples = 128
numNLLSamples = 2
xiVal = 1.0

# Rebuild the model and optimizer, then restore their saved state
checkpoint = torch.load('model_and_metrics.pt')
model = Model(trainFeatures,testFeatures,trainResponse,testResponse,
              numComponents,numCESamples,numNLLSamples,xiVal)
model.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.Adam(model.parameters(),lr = 0.005)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Sample the true regression parameters (accounting for partial centering) using the q density
posterior_mean,ci_lower,ci_upper = compute_posterior_means(model,num_samples=4000)

# Convert to numpy
mean_np,lower_np,upper_np = (
    summary.squeeze().cpu().numpy() for summary in (posterior_mean,ci_lower,ci_upper)
)

# Save the posterior mean and the interval bounds, one CSV each
for csvName,csvValues in (('posterior_mean_svi.csv',mean_np),
                          ('posterior_ci_lower_svi.csv',lower_np),
                          ('posterior_ci_upper_svi.csv',upper_np)):
    np.savetxt(csvName,csvValues,delimiter=',')