Modify collocation.py. Solved the problem of using multiple outputs with the Scikit-learn regression methods.

Local and scikit-learn methods have different u_hat types. The problem occurs at line 929 because u_hat is a list and cannot be transposed.
jp5000 committed Mar 24, 2016
1 parent 57ca660 commit b6e236c
Showing 1 changed file with 13 additions and 11 deletions.
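Note on the fix: la.lstsq returns an ndarray whose axes can be swapped with .T, while the scikit-learn-backed rules hand coefficients back in the transposed orientation, or as a plain Python list, which has no .T at all. A minimal sketch of the mismatch, assuming illustrative shapes rather than chaospy's actual internals:

import numpy as np
from sklearn.linear_model import LinearRegression

# Illustrative shapes only -- an assumption, not chaospy's actual code.
Q = np.random.rand(9, 6)   # 9 collocation samples evaluated in 6 basis terms
u = np.random.rand(9, 2)   # 2 model outputs per sample

# Local rule ("LS"): numpy packs coefficients as (n_terms, n_outputs).
uhat_local = np.linalg.lstsq(Q, u, rcond=None)[0]   # ndarray, shape (6, 2)

# scikit-learn convention: coef_ comes back as (n_outputs, n_terms).
uhat_skl = LinearRegression(fit_intercept=False).fit(Q, u).coef_  # shape (2, 6)

# Transposing the local result up front makes both branches agree, so the
# shared combination step can drop its `.T` -- which a list-valued uhat
# (as some scikit-learn paths produce) could not satisfy anyway.
assert uhat_local.T.shape == uhat_skl.shape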
24 changes: 13 additions & 11 deletions chaospy/collocation.py
@@ -233,18 +233,18 @@ def pcm_cc(func, order, dist_out, dist_in=None, acc=None,
 # Examples
 # --------
 #
-# 
+#
 # Define function and distribution:
 # >>> func = lambda z: -z[1]**2 + 0.1*z[0]
 # >>> dist = cp.J(cp.Uniform(), cp.Uniform())
-# 
+#
 # Perform pcm:
 # >>> q, x, w, y = cp.pcm_cc(func, 2, dist, acc=2, retall=1)
 # >>> print cp.around(q, 10)
 # -q1^2+0.1q0
 # >>> print len(w)
 # 9
-# 
+#
 # With Smolyak sparsegrid
 # >>> q, x, w, y = cp.pcm_cc(func, 2, dist, acc=2, retall=1, sparse=1)
 # >>> print cp.around(q, 10)
@@ -410,10 +410,10 @@ def pcm_gq(func, order, dist_out, dist_in=None, acc=None,
 # --------
 # Define function:
 # >>> func = lambda z: z[1]*z[0]
-# 
+#
 # Define distribution:
 # >>> dist = cp.J(cp.Normal(), cp.Normal())
-# 
+#
 # Perform pcm:
 # >>> p, x, w, y = cp.pcm_gq(func, 2, dist, acc=3, retall=True)
 # >>> print cp.around(p, 10)
@@ -525,13 +525,13 @@ def pcm_lr(func, order, dist_out, sample=None,
 # Examples
 # --------
-# 
+#
 # Define function:
 # >>> func = lambda z: -z[1]**2 + 0.1*z[0]
-# 
+#
 # Define distribution:
 # >>> dist = cp.J(cp.Normal(), cp.Normal())
-# 
+#
 # Perform pcm:
 # >>> q, x, y = cp.pcm_lr(func, 2, dist, retall=True)
 # >>> print cp.around(q, 10)
@@ -856,15 +856,17 @@ def fit_regression(P, x, u, rule="LS", retall=False, **kws):
 
     # Local rules
     if rule=="LS":
-        uhat = la.lstsq(Q, u)[0]
+        uhat = la.lstsq(Q, u)[0].T
 
     elif rule=="T":
         uhat, alphas = rlstsq(Q, u, kws.get("order",0),
                 kws.get("alpha", None), False, True)
+        uhat = uhat.T
 
     elif rule=="TC":
         uhat = rlstsq(Q, u, kws.get("order",0),
                 kws.get("alpha", None), True)
+        uhat = uhat.T
 
     else:
 
@@ -926,7 +928,7 @@ def fit_regression(P, x, u, rule="LS", retall=False, **kws):
 
     u = u.reshape(u.shape[0], *shape)
 
-    R = po.sum((P*uhat.T), -1)
+    R = po.sum((P*uhat), -1)
     R = po.reshape(R, shape)
 
     if retall==1:
@@ -946,7 +948,7 @@ def fit_lagrange(X, Y):
 
     if len(X.shape) == 1:
         X = X.reshape(1, X.size)
-    
+
     N, dim = X.shape
 
     basis = []
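With uhat oriented consistently across branches, multi-output fits work under every rule. A hedged usage sketch of the repaired path (the helpers cp.orth_ttr and dist.sample are assumptions about this chaospy version's API, not taken from this diff):

import numpy as np
import chaospy as cp

# Hedged sketch, not a verified chaospy example: exercises fit_regression
# with a two-output model, the case this commit repairs.
dist = cp.J(cp.Normal(), cp.Normal())
P = cp.orth_ttr(2, dist)          # orthogonal polynomial basis (assumed helper)
x = dist.sample(20)               # shape (2, 20): 20 samples of 2 variables
u = np.array([[-z1**2 + 0.1*z0, z0*z1] for z0, z1 in x.T])  # 2 outputs/sample
approx = cp.fit_regression(P, x, u, rule="LS")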
