In [1]:
import sys
import time
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
from torch import nn
from torch.functional import F

import torch.nn.utils.parametrize as parametrize
from torch.nn.utils.parametrizations import _Orthogonal

In [2]:
# User defined model class that inherits from Pytorch's neural network module
# NumComponents is the number of low rank components to use in the low rank
# part of the diagonal plus low rank covariance matrix.
# numCESamples is the number of stochastic samples of the cross entropy
# numNLLSamples is the number of stochastic samples of the negative log likelihood.
class Model(nn.Module):
    def __init__(self,trainX,testX,trainY,testY,numComponents,numCESamples,numNLLSamples,xiVal):
        super().__init__()
        
        # Input data
        self.X = trainX
        self.testX = testX
        self.Y = trainY
        self.testY = testY
        self.trainN = self.Y.size(dim=0)
        self.testN = self.testY.size(dim=0)
        self.nw = self.X.size(dim=1)
        self.nc = numComponents
        self.ns = numCESamples
        self.nsNLL = numNLLSamples
        
        # Regression parameter means
        betaInit = torch.squeeze(torch.distributions.Uniform(-0.0001,0.0001).sample((1,self.nw)).double())
        self.betaMu = nn.Parameter(betaInit)
        self.scaleSqMu = nn.Parameter(torch.tensor([math.log(0.000025)-0.5,-0.5],dtype=torch.double))
        self.xi = nn.Parameter(torch.tensor(xiVal,dtype=torch.double))
        
        # Diagonal component of parameter covariance
        self.paramVar = nn.Parameter(torch.ones((self.nw+2),dtype=torch.double)*0.0001)
        
        # Low rank covariance of all parameters
        self.V = nn.Parameter(torch.distributions.Uniform(-0.0001,0.0001).sample((self.nw+2,self.nc)).double())
        
        # Diagonal matrix of eigenvalues for low rank component
        self.C = nn.Parameter(torch.ones(self.nc,dtype=torch.double)*-10.0)
        
        # Apply orthogonal parametrization to V
        # Use householder reflections instead of the matrix exponential or Cayley map as householder
        # reflections are computationally cheaper
        parametrize.register_parametrization(self,"V", _Orthogonal(self.V, orthogonal_map="householder"))
        
        # Scale parameter priors
        self.lambdaPrior = torch.tensor(1.0,dtype=torch.double)
        self.regVarPrior = torch.tensor(1.0,dtype=torch.double)
        
        # Training metrics
        self.listTrainMSE = []
        self.listTestMSE = []
        self.listTrainR2 = []
        self.listTestR2 = []
        

# Samples the weights wrt to the q density
def sampleWeights(mod):
    diagC = torch.diag(torch.exp(mod.C))
    mod.paramCov = torch.diag(torch.exp(mod.paramVar)) + torch.matmul(torch.matmul(mod.V,diagC),torch.transpose(mod.V,0,1))
    lsCov = mod.paramCov[0:2,0:2]
     
    predWeights = 0.0
    for i in range(mod.ns):
        
        # Sample diagonal lambdaSq and sigmaSq
        mod.stdNormalDistNP = torch.distributions.MultivariateNormal(mod.scaleSqMu, covariance_matrix = lsCov)
        lsParams = mod.stdNormalDistNP.rsample()
        # Sample weights
        predBeta = (torch.unsqueeze(mod.betaMu,1) + torch.matmul(mod.paramCov[2:,0:2],
                    torch.matmul(torch.inverse(lsCov),torch.unsqueeze(lsParams,1) - torch.unsqueeze(mod.scaleSqMu,1))))
        scaleParams = torch.unsqueeze(torch.cat(((torch.unsqueeze(torch.tensor(1.0,dtype=torch.double),0)).clone(),
                      torch.exp(0.5*mod.xi*lsParams[0])*torch.ones((mod.nw-1))),0),1)
        predWeights = predWeights + torch.mul(predBeta,scaleParams)
        
    return (1.0/mod.ns)*predWeights


# Calculates the negative log likelihood
def NLL(mod,startBatch,endBatch):
    diagC = torch.diag(torch.exp(mod.C))
    diagC_sqrt = torch.diag(torch.sqrt(torch.exp(mod.C)))
    mod.paramCov = torch.diag(torch.exp(mod.paramVar)) + torch.matmul(torch.matmul(mod.V,diagC),torch.transpose(mod.V,0,1))
    lsCov = mod.paramCov[0:2,0:2]
    
#     K = (torch.eye(mod.nc) + torch.matmul(torch.matmul(torch.transpose(mod.V[0:2,:],0,1),
#          torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-1.0))),torch.matmul(mod.V[0:2,:],diagC)))
    
    K = (torch.diag(torch.pow(torch.exp(mod.C), -1.0))
         + torch.matmul(torch.matmul(torch.transpose(mod.V[0:2,:],0,1),
                                     torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-1.0))),mod.V[0:2,:]))
    
    LK = torch.linalg.cholesky(K)
    mod.stdNormalDistNP = torch.distributions.MultivariateNormal(mod.scaleSqMu,covariance_matrix = lsCov)
    
    nll = 0.0
    for i in range(mod.nsNLL):
        
        # Sample diagonal lambdaSq and sigmaSq
        lsParams = mod.stdNormalDistNP.rsample(sample_shape=torch.Size([endBatch-startBatch]))
        
        ### Calculate pre-activation distribution
        # calculate mean beta conditional on sampled scale parameters
        predBeta = (torch.unsqueeze(mod.betaMu,1) + torch.matmul(mod.paramCov[2:,0:2],
                    torch.matmul(torch.inverse(lsCov),torch.transpose(lsParams,0,1)
                 - torch.unsqueeze(mod.scaleSqMu,1).repeat(1,endBatch-startBatch))))
        scaleData = torch.cat(((torch.unsqueeze(mod.X[startBatch:endBatch,0],1)).clone(),
                    torch.unsqueeze(torch.exp(0.5*mod.xi*lsParams[:,0]),1)*mod.X[startBatch:endBatch,1:]),1)
        predMu = torch.sum(torch.mul(torch.transpose(predBeta,0,1),scaleData),dim=1)
        
#         Y = torch.linalg.solve(torch.transpose(LK,0,1),torch.matmul(torch.matmul(torch.transpose(mod.V[0:2,:],0,1),
#             torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-1))),torch.matmul(torch.matmul(mod.V[0:2,:],diagC),
#             torch.matmul(torch.transpose(mod.V[2:,:],0,1),torch.transpose(scaleData,0,1)))))
        
        Y = torch.linalg.solve(torch.transpose(LK,0,1),
                               torch.matmul(torch.matmul(torch.matmul(torch.transpose(mod.V[0:2,:],0,1),
                                                                      torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-1))),
                                                         torch.matmul(mod.V[0:2,:], diagC_sqrt)),
                                            torch.matmul(torch.matmul(diagC_sqrt, torch.transpose(mod.V[2:,:],0,1)),
                                                         torch.transpose(scaleData,0,1))))
        
        # Standard deviation calculation with updated low-rank component
        # Use sqrt(C) instead of C when calculating standard deviation of pre-activations
#         sdB = torch.sqrt(torch.sum(torch.square(torch.matmul(torch.diag(torch.exp(0.5*mod.paramVar[2:])),
#               torch.transpose(scaleData,0,1))),dim=0)
#             + torch.sum(torch.square(torch.matmul(torch.matmul(diagC_sqrt,torch.transpose(mod.V[2:,:],0,1)),
#               torch.transpose(scaleData,0,1))),dim=0) - torch.sum(torch.square(Y),dim=0))
        
        # Term 1: ||A_β^{1/2} · scaleData^T||²
        term1 = torch.sum(torch.square(torch.matmul(torch.diag(torch.exp(0.5*mod.paramVar[2:])),
                                                    torch.transpose(scaleData,0,1))),dim=0)

        # Term 2: ||Ṽ_β^T · scaleData^T||²
        term2 = torch.sum(torch.square(torch.matmul(torch.matmul(diagC_sqrt,torch.transpose(mod.V[2:,:],0,1)),
                                                    torch.transpose(scaleData,0,1))),dim=0)

        # Term 3: -||A_Λ^{-1/2} · Ṽ_Λ · Ṽ_β^T · scaleData^T||²
        temp3 = torch.matmul(torch.matmul(diagC_sqrt,torch.transpose(mod.V[2:,:],0,1)),
                             torch.transpose(scaleData,0,1))  # Ṽ_β^T · scaleData^T
        temp3 = torch.matmul(torch.matmul(mod.V[0:2,:],diagC_sqrt),temp3)  # Ṽ_Λ · Ṽ_β^T · scaleData^T
        temp3 = torch.matmul(torch.diag(torch.pow(torch.exp(mod.paramVar[0:2]),-0.5)),temp3)  # A_Λ^{-1/2} · ...
        term3 = torch.sum(torch.square(temp3),dim=0)

        # Term 4: +||Y||²
        term4 = torch.sum(torch.square(Y), dim=0)

        # Combine: note the signs!
        sdB = torch.sqrt(term1 + term2 - term3 + term4)
        
        # Sample pre-activation and calculate negative log-likelihood
        mod.stdNormalDistB = torch.distributions.Normal(predMu,sdB)
        sampleB = mod.stdNormalDistB.rsample()
        trainResiduals = mod.Y[startBatch:endBatch,:] - torch.unsqueeze(sampleB,1)
        SE = torch.square(trainResiduals)
        # Calculate log likelihood
        nll = (nll + torch.sum(torch.mul(0.5*torch.unsqueeze(torch.exp(-1.0*lsParams[:,1]),1),SE)
            + 0.5*(math.log(2.0*math.pi) + torch.unsqueeze(lsParams[:,1],1))))
        
    return (1.0/mod.nsNLL)*nll


# Calculates the cross entropy
def crossEntropy(mod):
        
    ## Sample all parameters together
    mod.stdNormalDistNS1 = torch.distributions.Normal(torch.zeros((mod.ns,mod.nw+2),dtype=torch.double),
                                                      torch.ones((mod.ns,mod.nw+2),dtype=torch.double))
    mod.stdNormalDistNS2 = torch.distributions.Normal(torch.zeros((mod.nc),dtype=torch.double),
                                                      torch.ones((mod.nc),dtype=torch.double))
    
    # Sample all parameters with C matrix - VECTORIZED
    diagC_sqrt = torch.sqrt(torch.exp(mod.C))

    # Sample diagonal component
    paramSample = torch.mul(torch.sqrt(torch.exp(mod.paramVar)), mod.stdNormalDistNS1.sample())

    # Sample low-rank component (vectorized)
    z2_samples = mod.stdNormalDistNS2.sample(sample_shape=torch.Size([mod.ns]))
    lowrank_component = torch.matmul(mod.V, torch.mul(diagC_sqrt.unsqueeze(1), z2_samples.T))
    paramSample = paramSample + lowrank_component.T
    
    ceVal = 0.0
    
    # Calculate log prior density of lambdaSq multiplied by Jacobian determinant
    sampledLambda = torch.exp(0.5*(mod.scaleSqMu[0] + paramSample[:,0]))
    ceVal = (ceVal + (-1.0/mod.ns)*torch.sum(torch.log(torch.mul(0.5*sampledLambda,
            ((2.0/(math.pi)))/(1.0 + torch.square(sampledLambda))))))
    
    # Calculate log prior density of sigmaSq multiplied by Jacobian determinant
    sampledSigma = torch.exp(mod.scaleSqMu[1] + paramSample[:,1])
    ceVal = (ceVal + (-1.0/mod.ns)*torch.sum(torch.log(torch.mul(sampledSigma,
            ((2.0/(math.pi)))/(1.0 + torch.square(sampledSigma))))))
    
    # Calculate log prior density of bias
    sampledBeta = mod.betaMu[0] + paramSample[:,2]
    ceVal = ceVal + (-1.0/mod.ns)*torch.sum(-0.5*torch.square(sampledBeta))
    ceVal = ceVal + (-1.0/mod.ns)*(-1.0*mod.ns)*0.5*math.log(2.0*math.pi)
    
    # Calculate log prior density of regression coefficients
    sampledLambda = torch.exp((1.0-mod.xi)*(mod.scaleSqMu[0] + paramSample[:,0]))
    sampledBeta = mod.betaMu[1:].repeat(mod.ns,1) + paramSample[:,3:]
    priorVar = torch.unsqueeze(sampledLambda*sampledSigma,1).repeat(1,mod.nw-1)
    ceVal = (ceVal + (-1.0/mod.ns)*torch.sum(torch.mul(-0.5*torch.square(sampledBeta),
             torch.pow(priorVar,-1.0))))
    ceVal = (ceVal + (-1.0/mod.ns)*torch.sum(-0.5*(mod.nw-1)*(math.log(2.0*math.pi)
          + torch.log(sampledLambda*sampledSigma))))
    
    return ceVal


# Calculates the entropy
def entropy(mod):
    diagC = torch.diag(torch.exp(mod.C))
    entVal = (0.5*(mod.nw+2) + 0.5*(mod.nw+2)*math.log(2.0*math.pi) + 0.5*torch.sum(mod.paramVar)
           + 0.5*torch.logdet(torch.eye(mod.nc) + torch.matmul(torch.matmul(torch.transpose(mod.V,0,1),
             torch.diag(torch.pow(torch.exp(mod.paramVar),-1.0))),torch.matmul(mod.V, diagC)))
           + 0.5*torch.sum(mod.C))
    return entVal

In [3]:
# Trains the linear model using Stochastic Variational Inference
# maxEpochs is the maximum number of training epochs to use
def trainModel(mod,opt,maxEpochs,batchSize,intervalToPrint):
    numBatches = math.floor(mod.Y.size(dim=0)/batchSize)
    trainN = mod.Y.size(dim=0)
    testN = mod.testY.size(dim=0)
    tol = 0.0
    nelboIdx = 0
    nelboList = []
    nllList = []
    klList = []
    nelboList.append(sys.float_info.max)
    nllList.append(sys.float_info.max)
    klList.append(sys.float_info.max)
    for epoch in range(maxEpochs):
        idx = torch.randperm(mod.Y.size(dim=0))
        mod.Y = mod.Y[idx].view(mod.Y.shape)
        mod.X = mod.X[idx].view(mod.X.shape)
        for batch in range(numBatches):
            opt.zero_grad()
            NLLval = (1.0/batchSize)*NLL(mod,batch*batchSize,(batch + 1)*batchSize)
            entVal = (1.0/trainN)*entropy(mod)
            ceVal = (1.0/trainN)*crossEntropy(mod)
            loss = NLLval + ceVal - entVal
            loss.backward()
            opt.step()
            with torch.no_grad():
                mod.xi.clamp_(min = 0.0)
                mod.xi.clamp_(max = 1.0)
        if (epoch % intervalToPrint == 0):
            with torch.no_grad():
                NLLval = NLL(mod,0,trainN)
                entVal = entropy(mod)
                ceVal = crossEntropy(mod)
                priorKL = ceVal - entVal
                trainNELBO = NLLval + priorKL
                nelboList.append(trainNELBO)
                nllList.append(NLLval)
                klList.append(priorKL)
                nelboIdx = nelboIdx + 1
                weights = sampleWeights(mod)
                predTrainY = torch.matmul(mod.X,weights)
                predTestY = torch.matmul(mod.testX,weights)
                trainMSE = torch.mean(torch.square(mod.Y - predTrainY))
                testMSE = torch.mean(torch.square(mod.testY - predTestY))
                
            print('Epoch {}, Neg Train ELBO {}, PriorKL {}, Train MSE {}, Test MSE {}'.format(epoch,trainNELBO,priorKL,trainMSE,testMSE))
            # Calculate marginal variances using diagonal plus low rank covariance structure
            marginalVar = torch.exp(mod.paramVar) + torch.sum(torch.square(mod.V)*torch.exp(mod.C).unsqueeze(0),dim=1)
            print('Lambda Sq: {:.4e}'.format(torch.exp(mod.scaleSqMu[0] + 0.5*marginalVar[0])))
            print('Sigma Sq: {:.4f}'.format(torch.exp(mod.scaleSqMu[1] + 0.5*marginalVar[1])))
            print(mod.xi)
            
    return mod, nelboList, nllList, klList

In [4]:
trainFeatures = pd.read_csv("trainImagingFeatures.csv", index_col = False)
trainFeatures = trainFeatures.drop(trainFeatures.columns[0], axis = 1)
trainFeatures = torch.tensor(trainFeatures.values, dtype = torch.double)
trainFeatures = torch.cat((torch.ones((trainFeatures.size(dim=0),1)),trainFeatures),1)

testFeatures = pd.read_csv("testImagingFeatures.csv", index_col = False)
testFeatures = testFeatures.drop(testFeatures.columns[0], axis = 1)
testFeatures = torch.tensor(testFeatures.values, dtype = torch.double)
testFeatures = torch.cat((torch.ones((testFeatures.size(dim=0),1)),testFeatures),1)

trainResponse = pd.read_csv("train_y_residualized.csv", index_col = False)
trainResponse = trainResponse.drop(trainResponse.columns[0], axis = 1)
trainResponse = torch.tensor(trainResponse.values, dtype = torch.double)

testResponse = pd.read_csv("test_y_residualized.csv", index_col = False)
testResponse = testResponse.drop(testResponse.columns[0], axis = 1)
testResponse = torch.tensor(testResponse.values, dtype = torch.double)

print(trainFeatures.shape)
print(testFeatures.shape)
print(trainResponse.shape)
print(testResponse.shape)

print(f"Training size: {trainFeatures.shape[0]}")
print(f"Testing size: {testFeatures.shape[0]}")

print(torch.std(trainResponse))
print(torch.std(testResponse))

torch.Size([2700, 1772])
torch.Size([300, 1772])
torch.Size([2700, 1])
torch.Size([300, 1])
Training size: 2700
Testing size: 300
tensor(1.1458, dtype=torch.float64)
tensor(1.1787, dtype=torch.float64)


In [5]:
intervalToPrint = 1
batchSize = 256
numComponents = 16
numCESamples = 128
numNLLSamples = 2
xiVal = 1.0

# initialize model
model = Model(trainFeatures,testFeatures,trainResponse,testResponse,numComponents,numCESamples,numNLLSamples,xiVal)

# first train at low batch size
maxEpochs = 50
optimizer = torch.optim.Adam(model.parameters(), lr = 0.025)
model, nelboList, nllList, klList = trainModel(model,optimizer,maxEpochs,batchSize,intervalToPrint)

# Save
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'nelboList': nelboList,
    'nllList': nllList,
    'klList': klList
}, 'model_and_metrics.pt')

Epoch 0, Neg Train ELBO 20288.542787008315, PriorKL 494.4406257456494, Train MSE 1.363782214889798, Test MSE 1.4929660811742518
Lambda Sq: 1.7747e-05
Sigma Sq: 1.1463
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 1, Neg Train ELBO 12180.761135040897, PriorKL 433.5314441017449, Train MSE 1.2856789693633868, Test MSE 1.3957823835584353
Lambda Sq: 1.3780e-05
Sigma Sq: 1.3030
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 2, Neg Train ELBO 9688.45430742906, PriorKL 343.35170319795657, Train MSE 1.2355012337433129, Test MSE 1.3685945697974877
Lambda Sq: 1.1587e-05
Sigma Sq: 1.4467
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 3, Neg Train ELBO 8037.961411497241, PriorKL 389.9760538401838, Train MSE 1.2250775996489203, Test MSE 1.390805841350111
Lambda Sq: 1.0216e-05
Sigma Sq: 1.5685
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 4, Neg Train ELBO 7389.8147776

Epoch 34, Neg Train ELBO 4834.790846381227, PriorKL 240.86625766548104, Train MSE 1.1156409387732793, Test MSE 1.2347283153127615
Lambda Sq: 3.4853e-06
Sigma Sq: 1.8677
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 35, Neg Train ELBO 4874.762756218592, PriorKL 222.91848758413198, Train MSE 1.1187860066730753, Test MSE 1.230432081177124
Lambda Sq: 3.4023e-06
Sigma Sq: 1.8519
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 36, Neg Train ELBO 4789.906200687135, PriorKL 215.2807604764639, Train MSE 1.1154293512732247, Test MSE 1.225597430866027
Lambda Sq: 3.3253e-06
Sigma Sq: 1.8345
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 37, Neg Train ELBO 4797.832420676283, PriorKL 225.18075870510393, Train MSE 1.1167834149629352, Test MSE 1.2280243556820076
Lambda Sq: 3.2489e-06
Sigma Sq: 1.8175
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 38, Neg Train ELBO 4754.1

In [8]:
intervalToPrint = 1
batchSize = 256
numComponents = 16
numCESamples = 128
numNLLSamples = 2
xiVal = 1.0

# Load and continue
checkpoint = torch.load('model_and_metrics.pt')
model = Model(trainFeatures, testFeatures, trainResponse, testResponse, 
              numComponents, numCESamples, numNLLSamples, xiVal)
model.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

nelboList = checkpoint['nelboList']
nllList = checkpoint['nllList']
klList = checkpoint['klList']

# continue training model
maxEpochs = 100
model, nelboList, nllList, klList = trainModel(model,optimizer,maxEpochs,batchSize,intervalToPrint)

# Save
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'nelboList': nelboList,
    'nllList': nllList,
    'klList': klList
}, 'model_and_metrics.pt')

Epoch 0, Neg Train ELBO 4147.082245518507, PriorKL 84.04382353108622, Train MSE 1.0978926517934993, Test MSE 1.2003623815941837
Lambda Sq: 4.4965e-07
Sigma Sq: 1.1802
Parameter containing:
tensor(0.9936, dtype=torch.float64, requires_grad=True)
Epoch 1, Neg Train ELBO 4131.212215905014, PriorKL 81.63837020289111, Train MSE 1.101508227322456, Test MSE 1.2068455531366562
Lambda Sq: 4.4975e-07
Sigma Sq: 1.1769
Parameter containing:
tensor(0.9969, dtype=torch.float64, requires_grad=True)
Epoch 2, Neg Train ELBO 4140.45716872624, PriorKL 84.84582111286409, Train MSE 1.1000410661468527, Test MSE 1.1969576987021988
Lambda Sq: 4.4652e-07
Sigma Sq: 1.1744
Parameter containing:
tensor(0.9929, dtype=torch.float64, requires_grad=True)
Epoch 3, Neg Train ELBO 4141.616212335785, PriorKL 83.82047883943324, Train MSE 1.1056672289015876, Test MSE 1.20926620129441
Lambda Sq: 4.4190e-07
Sigma Sq: 1.1790
Parameter containing:
tensor(0.9981, dtype=torch.float64, requires_grad=True)
Epoch 4, Neg Train ELBO 

Epoch 34, Neg Train ELBO 4145.788965278867, PriorKL 78.28092633097049, Train MSE 1.100824831286708, Test MSE 1.200830015701004
Lambda Sq: 3.8206e-07
Sigma Sq: 1.1714
Parameter containing:
tensor(0.9948, dtype=torch.float64, requires_grad=True)
Epoch 35, Neg Train ELBO 4135.328594891378, PriorKL 86.54782422943481, Train MSE 1.1068852046755453, Test MSE 1.2103106869325837
Lambda Sq: 3.7969e-07
Sigma Sq: 1.1750
Parameter containing:
tensor(1., dtype=torch.float64, requires_grad=True)
Epoch 36, Neg Train ELBO 4118.092805120903, PriorKL 79.46715916317862, Train MSE 1.10155609873059, Test MSE 1.2022439869178736
Lambda Sq: 3.7851e-07
Sigma Sq: 1.1696
Parameter containing:
tensor(0.9960, dtype=torch.float64, requires_grad=True)
Epoch 37, Neg Train ELBO 4116.12565418544, PriorKL 77.1464694055826, Train MSE 1.100730787612541, Test MSE 1.1989493110360439
Lambda Sq: 3.7709e-07
Sigma Sq: 1.1678
Parameter containing:
tensor(0.9954, dtype=torch.float64, requires_grad=True)
Epoch 38, Neg Train ELBO 41

ValueError: Expected parameter scale (Tensor of shape (256,)) of distribution Normal(loc: torch.Size([256]), scale: torch.Size([256])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([0.2134, 0.0637, 0.0341, 0.1736, 0.2258, 0.1781, 0.1566, 0.2153, 0.0917,
        0.1952, 0.0949, 0.1218, 0.1079, 0.2080, 0.1523, 0.1277, 0.2200,    nan,
        0.2126, 0.2286, 0.1045, 0.2360, 0.1011, 0.1604, 0.0187, 0.2087, 0.1234,
        0.0924,    nan, 0.2394, 0.0407, 0.1377, 0.1048, 0.0410, 0.1345, 0.2012,
        0.2275, 0.1445, 0.1743, 0.2055, 0.0665, 0.1865, 0.1091, 0.1397, 0.1399,
        0.1407, 0.2157, 0.1841, 0.0830, 0.1611, 0.1287, 0.1384, 0.2148, 0.1534,
        0.1204, 0.1672, 0.2103, 0.0426, 0.2361, 0.1973, 0.1717, 0.1208, 0.0600,
        0.1750, 0.1595, 0.2302, 0.0804, 0.1941, 0.2685, 0.1439, 0.0784, 0.2066,
        0.0713, 0.1413, 0.1231, 0.1986, 0.0732, 0.1241, 0.1308, 0.0226, 0.1724,
        0.2342, 0.1574, 0.2620, 0.1446, 0.1685, 0.1530, 0.1332, 0.1932, 0.0914,
           nan, 0.1821, 0.1310, 0.1863, 0.2052, 0.1728, 0.2074, 0.1846, 0.2012,
        0.1619, 0.1337, 0.1600, 0.2266,    nan, 0.1329, 0.0535, 0.1121, 0.1667,
        0.1346,    nan, 0.1055, 0.1644, 0.1445, 0.1899, 0.2365, 0.1826, 0.1031,
        0.2797, 0.1497, 0.1590, 0.1280, 0.1009, 0.1942, 0.2357, 0.1192, 0.2061,
        0.1877, 0.1850, 0.1029, 0.1110, 0.1785, 0.1302, 0.0653, 0.2315, 0.0826,
        0.1567, 0.1726, 0.1715, 0.1803, 0.1448, 0.1306, 0.1397, 0.1089, 0.1047,
        0.3595, 0.2233, 0.2161, 0.0677, 0.1382, 0.2094, 0.1988, 0.0703, 0.2300,
        0.1929, 0.1190, 0.1484, 0.0714, 0.1244, 0.1868, 0.2348, 0.1648, 0.1576,
        0.1304, 0.0359, 0.0986, 0.2454, 0.2426, 0.1331, 0.1546, 0.1324, 0.1312,
        0.1698, 0.1367, 0.2104, 0.1810, 0.1358, 0.1384, 0.2697, 0.1132, 0.2068,
        0.1645, 0.1450, 0.1610, 0.2327, 0.1521, 0.0828, 0.0906, 0.0863, 0.0459,
        0.1490, 0.1958, 0.1794, 0.1723, 0.0848, 0.0497, 0.2804, 0.1378, 0.1416,
        0.2038, 0.2458, 0.0446, 0.1720, 0.1728, 0.1674, 0.1136, 0.1161, 0.1526,
        0.1388, 0.0884, 0.1511, 0.0650, 0.1406, 0.2239, 0.2252, 0.0342, 0.0354,
        0.1485, 0.2456, 0.0758, 0.2696, 0.2156, 0.1619, 0.0793, 0.0484, 0.1702,
        0.1024, 0.0826, 0.1937, 0.3427, 0.1640, 0.1009, 0.1984, 0.1622, 0.1500,
        0.1467, 0.1895, 0.0737,    nan, 0.0991, 0.0895, 0.1135, 0.1347, 0.0948,
        0.2136, 0.1338, 0.1944, 0.0848, 0.1518, 0.2205, 0.0768, 0.1096, 0.3377,
        0.2302, 0.0646, 0.1649, 0.1875], dtype=torch.float64,
       grad_fn=<SqrtBackward0>)