MAINT: Remove chain dot
Remove chain_dot in favor of using the @ (matmul) operator
bashtage committed Dec 5, 2019
1 parent 8e7d091 commit a2900db
Showing 7 changed files with 41 additions and 87 deletions.
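
For reference, for 2-D arrays the removed helper and the @ operator compute the same product in the same left-to-right order, so each call site below is a mechanical rewrite. A minimal sketch using the arrays from the removed test_chain_dot test (expected output taken from that test):

>>> import numpy as np
>>> A = np.arange(1, 13).reshape(3, 4)
>>> B = np.arange(3, 15).reshape(4, 3)
>>> C = np.arange(5, 8).reshape(3, 1)
>>> A @ B @ C  # same result as the removed chain_dot(A, B, C)
array([[1820],
       [4300],
       [6780]])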
13 changes: 6 additions & 7 deletions statsmodels/regression/linear_model.py
@@ -40,7 +40,7 @@
from scipy import stats
from scipy import optimize

from statsmodels.tools.tools import chain_dot, pinv_extended
from statsmodels.tools.tools import pinv_extended
from statsmodels.tools.decorators import (cache_readonly,
cache_writable)
import statsmodels.base.model as base
@@ -1855,9 +1855,8 @@ def cov_HC2(self):
Heteroscedasticity robust covariance matrix. See HC2_se.
"""
# probably could be optimized
h = np.diag(chain_dot(self.model.wexog,
self.normalized_cov_params,
self.model.wexog.T))
wexog = self.model.wexog
h = np.diag(wexog @ self.normalized_cov_params @ wexog.T)
self.het_scale = self.wresid**2/(1-h)
cov_HC2 = self._HCCM(self.het_scale)
return cov_HC2
@@ -1867,8 +1866,8 @@ def cov_HC3(self):
"""
Heteroscedasticity robust covariance matrix. See HC3_se.
"""
h = np.diag(chain_dot(
self.model.wexog, self.normalized_cov_params, self.model.wexog.T))
wexog = self.model.wexog
h = np.diag(wexog @ self.normalized_cov_params @ wexog.T)
self.het_scale = (self.wresid / (1 - h))**2
cov_HC3 = self._HCCM(self.het_scale)
return cov_HC3
@@ -2084,7 +2083,7 @@ def compare_lm_test(self, restricted, demean=True, use_lr=False):
raise ValueError('Only nonrobust, HC, HAC and cluster are ' +
'currently connected')

lm_value = n * chain_dot(s, s_inv, s.T)
lm_value = n * (s @ s_inv @ s.T)
p_value = stats.chi2.sf(lm_value, df_diff)
return lm_value, p_value, df_diff

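The cov_HC2 and cov_HC3 rewrites above still form the full nobs x nobs product before taking its diagonal, as the existing "probably could be optimized" comment notes. A hedged sketch (not part of this commit) of computing only the leverage values directly, assuming wexog has the usual nobs x k shape and normalized_cov_params is k x k:

import numpy as np

def leverage_diag(wexog, normalized_cov_params):
    # Diagonal of wexog @ normalized_cov_params @ wexog.T without building
    # the full nobs x nobs matrix; equivalent to the np.diag(...) form above.
    return np.einsum('ij,jk,ik->i', wexog, normalized_cov_params, wexog)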
3 changes: 1 addition & 2 deletions statsmodels/regression/quantile_regression.py
@@ -22,7 +22,6 @@
import scipy.stats as stats
from scipy.linalg import pinv
from scipy.stats import norm
from statsmodels.tools.tools import chain_dot
from statsmodels.tools.decorators import cache_readonly
from statsmodels.regression.linear_model import (RegressionModel,
RegressionResults,
@@ -206,7 +205,7 @@ def fit(self, q=.5, vcov='robust', kernel='epa', bandwidth='hsheather',
d = np.where(e > 0, (q/fhat0)**2, ((1-q)/fhat0)**2)
xtxi = pinv(np.dot(exog.T, exog))
xtdx = np.dot(exog.T * d[np.newaxis, :], exog)
vcov = chain_dot(xtxi, xtdx, xtxi)
vcov = xtxi @ xtdx @ xtxi
elif vcov == 'iid':
vcov = (1. / fhat0)**2 * q * (1 - q) * pinv(np.dot(exog.T, exog))
else:
7 changes: 0 additions & 7 deletions statsmodels/tools/tests/test_tools.py
@@ -527,13 +527,6 @@ def test_pandas_const_df_prepend():
assert_equal(dta.var(0)[0], 0)


def test_chain_dot():
A = np.arange(1,13).reshape(3,4)
B = np.arange(3,15).reshape(4,3)
C = np.arange(5,8).reshape(3,1)
assert_equal(tools.chain_dot(A,B,C), np.array([[1820],[4300],[6780]]))


class TestNanDot(object):
@classmethod
def setup_class(cls):
29 changes: 0 additions & 29 deletions statsmodels/tools/tools.py
@@ -1,8 +1,6 @@
"""
Utility functions models code
"""
from functools import reduce

import numpy as np
import numpy.lib.recfunctions as nprf
import pandas as pd
@@ -530,33 +528,6 @@ def unsqueeze(data, axis, oldshape):
return data.reshape(newshape)


def chain_dot(*arrs):
"""
Returns the dot product of the given matrices.
Parameters
----------
arrs: argument list of ndarray
Returns
-------
Dot product of all arguments.
Examples
--------
>>> import numpy as np
>>> from statsmodels.tools import chain_dot
>>> A = np.arange(1,13).reshape(3,4)
>>> B = np.arange(3,15).reshape(4,3)
>>> C = np.arange(5,8).reshape(3,1)
>>> chain_dot(A,B,C)
array([[1820],
[4300],
[6780]])
"""
return reduce(lambda x, y: np.dot(y, x), arrs[::-1])


def nan_dot(A, B):
"""
Returns np.dot(left_matrix, right_matrix) with the convention that
31 changes: 13 additions & 18 deletions statsmodels/tsa/vector_ar/irf.py
@@ -9,8 +9,6 @@


from statsmodels.tools.decorators import cache_readonly
from statsmodels.tools.tools import chain_dot
#from statsmodels.tsa.api import VAR
import statsmodels.tsa.tsatools as tsa
import statsmodels.tsa.vector_ar.plotting as plotting
import statsmodels.tsa.vector_ar.util as util
@@ -280,7 +278,7 @@ def cov(self, orth=False):
covs[0] = np.zeros((self.neqs ** 2, self.neqs ** 2))
for i in range(1, self.periods + 1):
Gi = self.G[i - 1]
covs[i] = chain_dot(Gi, self.cov_a, Gi.T)
covs[i] = Gi @ self.cov_a @ Gi.T

return covs

@@ -572,10 +570,10 @@ def _orth_cov(self):
apiece = 0
else:
Ci = np.dot(PIk, self.G[i-1])
apiece = chain_dot(Ci, self.cov_a, Ci.T)
apiece = Ci @ self.cov_a @ Ci.T

Cibar = np.dot(np.kron(Ik, self.irfs[i]), H)
bpiece = chain_dot(Cibar, self.cov_sig, Cibar.T) / self.T
bpiece = (Cibar @ self.cov_sig @ Cibar.T) / self.T

# Lutkepohl typo, cov_sig correct
covs[i] = apiece + bpiece
@@ -613,18 +611,18 @@ def cum_effect_cov(self, orth=False):
apiece = 0
else:
Bn = np.dot(PIk, F)
apiece = chain_dot(Bn, self.cov_a, Bn.T)
apiece = Bn @ self.cov_a @ Bn.T

Bnbar = np.dot(np.kron(Ik, self.cum_effects[i]), self.H)
bpiece = chain_dot(Bnbar, self.cov_sig, Bnbar.T) / self.T
bpiece = (Bnbar @ self.cov_sig @ Bnbar.T) / self.T

covs[i] = apiece + bpiece
else:
if i == 0:
covs[i] = np.zeros((self.neqs**2, self.neqs**2))
continue

covs[i] = chain_dot(F, self.cov_a, F.T)
covs[i] = F @ self.cov_a @ F.T

return covs

@@ -652,10 +650,10 @@ def lr_effect_cov(self, orth=False):
Binf = np.dot(np.kron(self.P.T, np.eye(self.neqs)), Finfty)
Binfbar = np.dot(np.kron(Ik, lre), self.H)

return (chain_dot(Binf, self.cov_a, Binf.T) +
chain_dot(Binfbar, self.cov_sig, Binfbar.T))
return (Binf @ self.cov_a @ Binf.T +
Binfbar @ self.cov_sig @ Binfbar.T)
else:
return chain_dot(Finfty, self.cov_a, Finfty.T)
return Finfty @ self.cov_a @ Finfty.T

def stderr(self, orth=False):
return np.array([tsa.unvec(np.sqrt(np.diag(c)))
@@ -680,14 +678,11 @@ def H(self):
Kkk = tsa.commutation_matrix(k, k)
Ik = np.eye(k)

# B = chain_dot(Lk, np.eye(k**2) + commutation_matrix(k, k),
# np.kron(self.P, np.eye(k)), Lk.T)
# B = Lk @ (np.eye(k**2) + commutation_matrix(k, k)) @ \
# np.kron(self.P, np.eye(k)) @ Lk.T
# return Lk.T @ L.inv(B)

# return np.dot(Lk.T, L.inv(B))

B = chain_dot(Lk,
np.dot(np.kron(Ik, self.P), Kkk) + np.kron(self.P, Ik),
Lk.T)
B = Lk @ (np.kron(Ik, self.P) @ Kkk + np.kron(self.P, Ik)) @ Lk.T

return np.dot(Lk.T, L.inv(B))

25 changes: 12 additions & 13 deletions statsmodels/tsa/vector_ar/var_model.py
@@ -25,7 +25,6 @@
from statsmodels.tools.decorators import cache_readonly, deprecated_alias
from statsmodels.tools.linalg import logdet_symm
from statsmodels.tools.sm_exceptions import OutputWarning
from statsmodels.tools.tools import chain_dot
from statsmodels.tsa.tsatools import vec, unvec, duplication_matrix
from statsmodels.tsa.vector_ar import output, plotting, util
from statsmodels.tsa.vector_ar.hypothesis_test_results import \
@@ -193,7 +192,7 @@ def forecast_cov(ma_coefs, sigma_u, steps):
for h in range(steps):
# Sigma(h) = Sigma(h-1) + Phi Sig_u Phi'
phi = ma_coefs[h]
var = chain_dot(phi, sigma_u, phi.T)
var = phi @ sigma_u @ phi.T
forc_covs[h] = prior = prior + var

return forc_covs
@@ -1079,7 +1078,7 @@ def mse(self, steps):
for h in range(steps):
# Sigma(h) = Sigma(h-1) + Phi Sig_u Phi'
phi = ma_coefs[h]
var = chain_dot(phi, self.sigma_u, phi.T)
var = phi @ self.sigma_u @ phi.T
forc_covs[h] = prior = prior + var

return forc_covs
@@ -1375,7 +1374,7 @@ def cov_ybar(self):
"""

Ainv = scipy.linalg.inv(np.eye(self.neqs) - self.coefs.sum(0))
return chain_dot(Ainv, self.sigma_u, Ainv.T)
return Ainv @ self.sigma_u @ Ainv.T

# ------------------------------------------------------------
# Estimation-related things
@@ -1403,7 +1402,7 @@ def _cov_sigma(self):
D_Kinv = np.linalg.pinv(D_K)

sigxsig = np.kron(self.sigma_u, self.sigma_u)
return 2 * chain_dot(D_Kinv, sigxsig, D_Kinv.T)
return 2 * D_Kinv @ sigxsig @ D_Kinv.T

@cache_readonly
def llf(self):
@@ -1623,7 +1622,7 @@ def _omega_forc_cov(self, steps):

# memoize powers of B for speedup
# TODO: see if can memoize better
# TODO: much lower-hanging fruit in caching `np.trace` and `chain_dot` below.
# TODO: much lower-hanging fruit in caching `np.trace` below.
B = self._bmat_forc_cov()
_B = {}

@@ -1647,8 +1646,8 @@ def bpow(i):
for j in range(h):
Bi = bpow(h - 1 - i)
Bj = bpow(h - 1 - j)
mult = np.trace(chain_dot(Bi.T, Ginv, Bj, G))
om += mult * chain_dot(phis[i], sig_u, phis[j].T)
mult = np.trace(Bi.T @ Ginv @ Bj @ G)
om += mult * phis[i] @ sig_u @ phis[j].T
omegas[h-1] = om

return omegas
@@ -1819,10 +1818,10 @@ def test_causality(self, caused, causing=None, kind='f', signif=0.05):

# Lütkepohl 3.6.5
Cb = np.dot(C, vec(self.params.T))
middle = scipy.linalg.inv(chain_dot(C, self.cov_params(), C.T))
middle = scipy.linalg.inv(C @ self.cov_params() @ C.T)

# wald statistic
lam_wald = statistic = chain_dot(Cb, middle, Cb)
lam_wald = statistic = Cb @ middle @ Cb

if kind.lower() == 'wald':
df = num_restr
@@ -1940,9 +1939,9 @@ def test_inst_causality(self, causing, signif=0.05):
Cs = np.dot(C, vech_sigma_u)
d = np.linalg.pinv(duplication_matrix(k))
Cd = np.dot(C, d)
middle = scipy.linalg.inv(chain_dot(Cd, np.kron(sigma_u, sigma_u), Cd.T)) / 2
middle = scipy.linalg.inv(Cd @ np.kron(sigma_u, sigma_u) @ Cd.T) / 2

wald_statistic = t * chain_dot(Cs.T, middle, Cs)
wald_statistic = t * (Cs.T @ middle @ Cs)
df = num_restr
dist = stats.chi2(df)

@@ -1982,7 +1981,7 @@ def test_whiteness(self, nlags=10, signif=0.05, adjusted=False):
cov0_inv = scipy.linalg.inv(acov_list[0])
for t in range(1, nlags+1):
ct = acov_list[t]
to_add = np.trace(chain_dot(ct.T, cov0_inv, ct, cov0_inv))
to_add = np.trace(ct.T @ cov0_inv @ ct @ cov0_inv)
if adjusted:
to_add /= (self.nobs - t)
statistic += to_add
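A note on the parentheses in replacements such as lm_value = n * (s @ s_inv @ s.T) and wald_statistic = t * (Cs.T @ middle @ Cs): in Python, * and @ share the same precedence and associate left to right, so dropping the parentheses folds the scalar into the first matrix product. For a scalar multiplier the value is unchanged; the parenthesized form just keeps the quadratic form readable. A small sketch with hypothetical values:

import numpy as np

n = 100.0                      # hypothetical scalar (e.g. a sample size)
s = np.array([[0.3, -0.1]])    # hypothetical 1 x 2 score vector
s_inv = np.eye(2)              # hypothetical 2 x 2 inverse covariance

a = n * (s @ s_inv @ s.T)      # form used in the commit
b = n * s @ s_inv @ s.T        # parsed as ((n * s) @ s_inv) @ s.T
assert np.allclose(a, b)       # same value; the parentheses are for clarity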
20 changes: 9 additions & 11 deletions statsmodels/tsa/vector_ar/vecm.py
@@ -12,7 +12,6 @@
from statsmodels.iolib.table import SimpleTable
from statsmodels.tools.decorators import cache_readonly
from statsmodels.tools.sm_exceptions import HypothesisTestWarning
from statsmodels.tools.tools import chain_dot
from statsmodels.tsa.tsatools import duplication_matrix, vec, lagmat

import statsmodels.tsa.base.tsa_model as tsbase
@@ -401,7 +400,7 @@ def _sij(delta_x, delta_y_1_T, y_lag1):
s11_ = inv(_mat_sqrt(s11))
# p. 295:
s01_s11_ = np.dot(s01, s11_)
eig = np.linalg.eig(chain_dot(s01_s11_.T, inv(s00), s01_s11_))
eig = np.linalg.eig(s01_s11_.T @ inv(s00) @ s01_s11_)
lambd = eig[0]
v = eig[1]
# reorder eig_vals to make them decreasing (and order eig_vecs accordingly)
@@ -1323,7 +1322,7 @@ def _cov_sigma(self):
d = duplication_matrix(self.neqs)
d_K_plus = np.linalg.pinv(d)
# compare p. 93, 297 Lutkepohl (2005)
return 2 * chain_dot(d_K_plus, np.kron(sigma_u, sigma_u), d_K_plus.T)
return 2 * (d_K_plus @ np.kron(sigma_u, sigma_u) @ d_K_plus.T)

@cache_readonly
def cov_params_default(self): # p.296 (7.2.21)
@@ -1409,8 +1408,7 @@ def stderr_coint(self):
mat1 = np.kron(mat1.T, np.identity(r))
det = self.det_coef_coint.shape[0]
mat2 = np.kron(np.identity(self.neqs-r+det),
inv(chain_dot(
self.alpha.T, inv(self.sigma_u), self.alpha)))
inv(self.alpha.T @ inv(self.sigma_u) @ self.alpha))
first_rows = np.zeros((r, r))
last_rows_1d = np.sqrt(np.diag(mat1.dot(mat2)))
last_rows = last_rows_1d.reshape((self.neqs-r+det, r),
@@ -1575,7 +1573,7 @@ def cov_var_repr(self):
#
# w_eye = np.kron(w, np.identity(K))
#
# return chain_dot(w_eye.T, self.cov_params_default, w_eye)
# return w_eye.T @ self.cov_params_default @ w_eye

if self.k_ar - 1 == 0:
return self.cov_params_wo_det
@@ -1594,8 +1592,8 @@
start_col:start_col+2*self.neqs**2] = hstack((-eye, eye))
# for A_p:
vecm_var_transformation[-self.neqs**2:, -self.neqs**2:] = -eye
return chain_dot(vecm_var_transformation, self.cov_params_wo_det,
vecm_var_transformation.T)
vvt = vecm_var_transformation
return vvt @ self.cov_params_wo_det @ vvt.T

def ma_rep(self, maxn=10):
return ma_rep(self.var_rep, maxn)
@@ -1885,9 +1883,9 @@ def test_granger_causality(self, caused, causing=None, signif=0.05):
# same results as the reference software JMulTi.
sigma_u = var_results.sigma_u * (t-k*p-num_det_terms) / t
sig_alpha_min_p = t * np.kron(x_x_11, sigma_u) # k**2*(p-1)xk**2*(p-1)
middle = inv(chain_dot(C, sig_alpha_min_p, C.T))
middle = inv(C @ sig_alpha_min_p @ C.T)

wald_statistic = t * chain_dot(Ca.T, middle, Ca)
wald_statistic = t * (Ca.T @ middle @ Ca)
f_statistic = wald_statistic / num_restr
df = (num_restr, k * var_results.df_resid)
f_distribution = scipy.stats.f(*df)
@@ -2050,7 +2048,7 @@ def test_whiteness(self, nlags=10, signif=0.05, adjusted=False):
c0_inv = np.real(c0_inv)
for t in range(1, nlags+1):
ct = acov_list[t]
to_add = np.trace(chain_dot(ct.T, c0_inv, ct, c0_inv))
to_add = np.trace(ct.T @ c0_inv @ ct @ c0_inv)
if adjusted:
to_add /= (self.nobs - t)
statistic += to_add
