Skip to content

Commit

Permalink
Merge pull request #8 from jdtuck/refactor
Browse files Browse the repository at this point in the history
Refactor to multiple files
  • Loading branch information
dfrancom committed May 8, 2024
2 parents 0a95e50 + 801dea4 commit 9fdc9de
Show file tree
Hide file tree
Showing 9 changed files with 812 additions and 167 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/Build.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: Build

on: [push]
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

jobs:
build:
Expand All @@ -9,8 +13,8 @@ jobs:
continue-on-error: ${{ matrix.os == 'windows-latest' }}
strategy:
matrix:
os: [ubuntu-20.04, macos-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9, '3.10']
os: [ubuntu-22.04, macos-latest, windows-latest]
python-version: [3.7, 3.8, 3.9, '3.10']

steps:
- uses: actions/checkout@v3
Expand All @@ -21,6 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install .
pip install pytest
- name: Test with pytest
Expand Down
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pyBASS 0.3.1.9999
* added sobol for Basis

# pyBASS 0.3.1

# pyBASS 0.3.2
* initial version of package
160 changes: 13 additions & 147 deletions pyBASS/pyBASS.py → pyBASS/BASS.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,143 +18,12 @@

import numpy as np
import scipy as sp
from scipy import stats
import matplotlib.pyplot as plt
from itertools import combinations, chain
from scipy.special import comb
import pyBASS.utils as uf
from collections import namedtuple
#from pathos.multiprocessing import ProcessingPool as Pool
from multiprocessing import Pool
import time


def abline(slope, intercept):
"""Plot a line from slope and intercept"""
axes = plt.gca()
x_vals = np.array(axes.get_xlim())
y_vals = intercept + slope * x_vals
plt.plot(x_vals, y_vals, '--', color='red')


pos = lambda a: (abs(a) + a) / 2 # same as max(0,a)


def const(signs, knots):
"""Get max value of BASS basis function, assuming 0-1 range of inputs"""
cc = np.prod(((signs + 1) / 2 - signs * knots))
if cc == 0:
return 1
return cc


def makeBasis(signs, vs, knots, xdata):
"""Make basis function using continuous variables"""
cc = const(signs, knots)
temp1 = pos(signs * (xdata[:, vs] - knots))
if len(signs) == 1:
return temp1 / cc
temp2 = np.prod(temp1, axis=1) / cc
return temp2


def normalize(x, bounds):
"""Normalize to 0-1 scale"""
return (x - bounds[:, 0]) / (bounds[:, 1] - bounds[:, 0])


def unnormalize(z, bounds):
"""Inverse of normalize"""
return z * (bounds[:, 1] - bounds[:, 0]) + bounds[:, 0]


def comb_index(n, k):
"""Get all combinations of indices from 0:n of length k"""
# https://stackoverflow.com/questions/16003217/n-d-version-of-itertools-combinations-in-numpy
count = comb(n, k, exact=True)
index = np.fromiter(chain.from_iterable(combinations(range(n), k)),
int, count=count * k)
return index.reshape(-1, k)


def dmwnchBass(z_vec, vars_use):
"""Multivariate Walenius' noncentral hypergeometric density function with some variables fixed"""
with np.errstate(divide='ignore'):
alpha = z_vec[vars_use - 1] / sum(np.delete(z_vec, vars_use))
j = len(alpha)
ss = 1 + (-1) ** j * 1 / (sum(alpha) + 1)
for i in range(j - 1):
idx = comb_index(j, i + 1)
temp = alpha[idx]
ss = ss + (-1) ** (i + 1) * sum(1 / (temp.sum(axis=1) + 1))
return ss


Qf = namedtuple('Qf', 'R bhat qf')

def getQf(XtX, Xty):
"""Get the quadratic form y'X solve(X'X) X'y, as well as least squares beta and cholesky of X'X"""
try:
R = sp.linalg.cholesky(XtX, lower=False) # might be a better way to do this with sp.linalg.cho_factor
except np.linalg.LinAlgError as e:
return None
dr = np.diag(R)
if len(dr) > 1:
if max(dr[1:]) / min(dr) > 1e3:
return None
bhat = sp.linalg.solve_triangular(R, sp.linalg.solve_triangular(R, Xty, trans=1))
qf = np.dot(bhat, Xty)
return Qf(R, bhat, qf)


def logProbChangeMod(n_int, vars_use, I_vec, z_vec, p, maxInt):
"""Get reversibility factor for RJMCMC acceptance ratio, and also prior"""
if n_int == 1:
out = (np.log(I_vec[n_int - 1]) - np.log(2 * p) # proposal
+ np.log(2 * p) + np.log(maxInt))
else:
x = np.zeros(p)
x[vars_use] = 1
lprob_vars_noReplace = np.log(dmwnchBass(z_vec, vars_use))
out = (np.log(I_vec[n_int - 1]) + lprob_vars_noReplace - n_int * np.log(2) # proposal
+ n_int * np.log(2) + np.log(comb(p, n_int)) + np.log(maxInt)) # prior
return out


CandidateBasis = namedtuple('CandidateBasis', 'basis n_int signs vs knots lbmcmp')


def genCandBasis(maxInt, I_vec, z_vec, p, xdata):
"""Generate a candidate basis for birth step, as well as the RJMCMC reversibility factor and prior"""
n_int = int(np.random.choice(range(maxInt), p=I_vec) + 1)
signs = np.random.choice([-1, 1], size=n_int, replace=True)
# knots = np.random.rand(n_int)
knots = np.zeros(n_int)
if n_int == 1:
vs = np.random.choice(p)
knots = np.random.choice(xdata[:, vs], size=1)
else:
vs = np.sort(np.random.choice(p, size=n_int, p=z_vec, replace=False))
for i in range(n_int):
knots[i] = np.random.choice(xdata[:, vs[i]], size=1)

basis = makeBasis(signs, vs, knots, xdata)
lbmcmp = logProbChangeMod(n_int, vs, I_vec, z_vec, p, maxInt)
return CandidateBasis(basis, n_int, signs, vs, knots, lbmcmp)


BasisChange = namedtuple('BasisChange', 'basis signs vs knots')


def genBasisChange(knots, signs, vs, tochange_int, xdata):
"""Generate a condidate basis for change step"""
knots_cand = knots.copy()
signs_cand = signs.copy()
signs_cand[tochange_int] = np.random.choice([-1, 1], size=1)
knots_cand[tochange_int] = np.random.choice(xdata[:, vs[tochange_int]], size=1) # np.random.rand(1)
basis = makeBasis(signs_cand, vs, knots_cand, xdata)
return BasisChange(basis, signs_cand, vs, knots_cand)


class BassPrior:
"""Structure to store prior"""
def __init__(self, maxInt, maxBasis, npart, g1, g2, s2_lower, h1, h2, a_tau, b_tau, w1, w2):
Expand All @@ -181,7 +50,7 @@ def __init__(self, xx, y):
self.ssy = sum(y * y)
self.n, self.p = xx.shape
self.bounds = np.column_stack([xx.min(0), xx.max(0)])
self.xx = normalize(self.xx_orig, self.bounds)
self.xx = uf.normalize(self.xx_orig, self.bounds)
return


Expand Down Expand Up @@ -251,7 +120,7 @@ def update(self):
if move_type == 1:
## BIRTH step

cand = genCandBasis(self.prior.maxInt, self.I_vec, self.z_vec, self.data.p, self.data.xx)
cand = uf.genCandBasis(self.prior.maxInt, self.I_vec, self.z_vec, self.data.p, self.data.xx)

if (cand.basis > 0).sum() < self.prior.npart: # if proposed basis function has too few non-zero entries, dont change the state
return
Expand All @@ -265,7 +134,7 @@ def update(self):
self.XtX[self.nc, 0:(self.nc)] = Xta
self.XtX[self.nc, self.nc] = ata

qf_cand = getQf(self.XtX[0:(self.nc + 1), 0:(self.nc + 1)], self.Xty[0:(self.nc + 1)])
qf_cand = uf.getQf(self.XtX[0:(self.nc + 1), 0:(self.nc + 1)], self.Xty[0:(self.nc + 1)])

fullRank = qf_cand != None
if not fullRank:
Expand Down Expand Up @@ -304,7 +173,7 @@ def update(self):
ind = list(range(self.nc))
del ind[tokill_ind + 1]

qf_cand = getQf(self.XtX[np.ix_(ind, ind)], self.Xty[ind])
qf_cand = uf.getQf(self.XtX[np.ix_(ind, ind)], self.Xty[ind])

fullRank = qf_cand != None
if not fullRank:
Expand All @@ -319,7 +188,7 @@ def update(self):

z_vec = z_star / sum(z_star)

lbmcmp = logProbChangeMod(self.n_int[tokill_ind], self.vs[tokill_ind, 0:self.n_int[tokill_ind]], I_vec,
lbmcmp = uf.logProbChangeMod(self.n_int[tokill_ind], self.vs[tokill_ind, 0:self.n_int[tokill_ind]], I_vec,
z_vec, self.data.p, self.prior.maxInt)

alpha = .5 / self.s2 * (qf_cand.qf - self.qf) / (1 + self.tau) - np.log(self.lam) + np.log(self.nbasis) + np.log(
Expand Down Expand Up @@ -371,7 +240,7 @@ def update(self):
tochange_basis = np.random.choice(self.nbasis)
tochange_int = np.random.choice(self.n_int[tochange_basis])

cand = genBasisChange(self.knots[tochange_basis, 0:self.n_int[tochange_basis]],
cand = uf.genBasisChange(self.knots[tochange_basis, 0:self.n_int[tochange_basis]],
self.signs[tochange_basis, 0:self.n_int[tochange_basis]],
self.vs[tochange_basis, 0:self.n_int[tochange_basis]], tochange_int, self.data.xx)

Expand All @@ -391,7 +260,7 @@ def update(self):
Xty_cand = self.Xty[0:self.nc].copy()
Xty_cand[tochange_basis + 1] = aty

qf_cand = getQf(XtX_cand, Xty_cand)
qf_cand = uf.getQf(XtX_cand, Xty_cand)

fullRank = qf_cand != None
if not fullRank:
Expand Down Expand Up @@ -507,7 +376,7 @@ def plot(self):
ax = fig.add_subplot(2, 2, 3)
yhat = self.predict(self.data.xx_orig).mean(axis=0) # posterior predictive mean
plt.scatter(self.data.y, yhat)
abline(1, 0)
uf.abline(1, 0)
plt.xlabel("observed")
plt.ylabel("posterior prediction")

Expand Down Expand Up @@ -543,8 +412,8 @@ def makeBasisMatrix(self, model_ind, X):
mat[:, 0] = 1
for m in range(nb):
ind = list(range(self.samples.n_int[model_ind, m]))
mat[:, m + 1] = makeBasis(self.samples.signs[model_ind, m, ind], self.samples.vs[model_ind, m, ind],
self.samples.knots[model_ind, m, ind], X).reshape(n)
mat[:, m + 1] = uf.makeBasis(self.samples.signs[model_ind, m, ind], self.samples.vs[model_ind, m, ind],
self.samples.knots[model_ind, m, ind], X).reshape(n)
return mat


Expand All @@ -563,7 +432,7 @@ def predict(self, X, mcmc_use=None, nugget=False):
if X.ndim == 1:
X = X[None, :]

Xs = normalize(X, self.data.bounds)
Xs = uf.normalize(X, self.data.bounds)
if np.any(mcmc_use == None):
mcmc_use = np.array(range(self.nstore))
out = np.zeros([len(mcmc_use), len(Xs)])
Expand Down Expand Up @@ -778,7 +647,7 @@ def plot(self):
ax = fig.add_subplot(2, 2, 3)
yhat = self.predict(self.bm_list[0].data.xx_orig).mean(axis=0) # posterior predictive mean
plt.scatter(self.y, yhat)
abline(1, 0)
uf.abline(1, 0)
plt.xlabel("observed")
plt.ylabel("posterior prediction")

Expand Down Expand Up @@ -901,6 +770,3 @@ def bassPCA(xx, y, npc=None, percVar=99.9, ncores=1, center=True, scale=False, *
print('\rStarting bassPCA with {:d} components, using {:d} cores.'.format(npc, ncores))

return BassBasis(xx, y, basis, newy, setup.y_mean, setup.y_sd, trunc_error, ncores, **kwargs)



22 changes: 21 additions & 1 deletion pyBASS/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,21 @@
from .pyBASS import *
"""
A python package for Bayesian Adaptive Spline Surfaces
"""
__all__ = ["utils", "sobol", "BASS"]

__version__ = "0.3.2"

import sys

if sys.version_info[0] == 3 and sys.version_info[1] < 6:
raise ImportError("Python Version 3.6 or above is required for pyBASS.")
else: # Python 3
pass
# Here we can also check for specific Python 3 versions, if needed

del sys

from .BASS import *
from .utils import *
from .sobol import *
Loading

0 comments on commit 9fdc9de

Please sign in to comment.