Skip to content

Commit

Permalink
bug fixes with torch.linalg.solve() and default variance calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusMNoack committed Sep 14, 2022
1 parent 00da667 commit 8b98c33
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions fvgp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def __init__(
#######prepare variances##################
##########################################
if variances is None:
self.variances = np.ones((self.y_data.shape)) * abs(np.mean(self.y_data) / 100.0)
self.variances = np.ones((self.y_data.shape)) * (np.mean(abs(self.y_data)) / 100.0)
logger.warning("CAUTION: you have not provided data variances in fvGP, "
"they will be set to 1 percent of the data values!")
"they will be set to 1 percent of the |data values|!")
elif np.ndim(variances) == 2:
self.variances = variances[:,0]
elif np.ndim(variances) == 1:
Expand Down Expand Up @@ -198,7 +198,7 @@ def update_gp_data(
#######prepare variances##################
##########################################
if variances is None:
self.variances = np.ones((self.y_data.shape)) * abs(self.y_data / 100.0)
self.variances = np.ones((self.y_data.shape)) * (np.mean(abs(self.y_data)) / 100.0)
elif np.ndim(variances) == 2:
self.variances = variances[:,0]
elif np.ndim(variances) == 1:
Expand Down Expand Up @@ -525,7 +525,7 @@ def _optimize_log_likelihood(self,starting_hps,
radius = deflation_radius,
num_epochs = max_iter,
constraints = constraints)

print("gfdsdd: ", np.array(starting_hps).reshape(1,-1))
obj = opt.optimize(dask_client = dask_client, x0 = np.array(starting_hps).reshape(1,-1))
res = opt.get_final()
hyperparameters = res["x"][0]
Expand Down Expand Up @@ -731,9 +731,9 @@ def solve(self, A, b):
logger.error("torch.solve() on cpu did not work")
logger.error("reason: ", str(e))
#x, qr = torch.lstsq(b,A)
x, qr = torch.linalg.lstsq(A,b)
x, res, rank, s = torch.linalg.lstsq(A,b)
except Exception as e:
logger.error("torch.solve() and torch.lstsq() on cpu did not work; falling back to numpy.lstsq()")
logger.error("torch.solve() and torch.lstsq() on cpu did not work; falling back to numpy.linalg.lstsq()")
logger.error("reason: {}", str(e))
x,res,rank,s = np.linalg.lstsq(A.numpy(),b.numpy())
return x
Expand All @@ -750,7 +750,7 @@ def solve(self, A, b):
#x, qr = torch.lstsq(b,A)
x = torch.linalg.lstsq(A,b)
except Exception as e:
logger.error("torch.solve() and torch.lstsq() on gpu did not work; falling back to numpy.lstsq()")
logger.error("torch.solve() and torch.lstsq() on gpu did not work; falling back to numpy.linalg.lstsq()")
logger.error("reason: ", str(e))
x,res,rank,s = np.linalg.lstsq(A.numpy(),b.numpy())
return x
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ numpy
matplotlib
dask >= 2021.6.2
distributed >= 2021.6.2
hgdl == 2.0.2
hgdl == 2.0.3
notebook
plotly
loguru

0 comments on commit 8b98c33

Please sign in to comment.