Skip to content

Commit

Permalink
Merge pull request #97 from jameschapman19/pytorchlightning
Browse files Browse the repository at this point in the history
Pytorchlightning
  • Loading branch information
jameschapman19 committed Nov 16, 2021
2 parents 159e680 + e72af45 commit a5523f3
Show file tree
Hide file tree
Showing 30 changed files with 355 additions and 285 deletions.
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class DCCA(_DCCA_base):
"""
A class used to fit a DCCA model.
Citation
--------
:Citation:
Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class BarlowTwins(DCCA):
"""
A class used to fit a Barlow Twins model.
Citation
--------
:Citation:
Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ class DCCA_NOI(DCCA):
A class used to fit a DCCA model by non-linear orthogonal iterations
Citation
--------
:Citation:
Wang, Weiran, et al. "Stochastic optimization for deep CCA via nonlinear orthogonal iterations." 2015 53rd Annual Allerton Conference on Communication, Control, and Computing (Allerton). IEEE, 2015.
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dcca_sdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ class DCCA_SDL(DCCA_NOI):
"""
A class used to fit a Deep CCA by Stochastic Decorrelation model.
Citation
--------
:Citation:
Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep CCA via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class DCCAE(_DCCA_base):
"""
A class used to fit a DCCAE model.
Citation
--------
:Citation:
Wang, Weiran, et al. "On deep multi-view representation learning." International conference on machine learning. PMLR, 2015.
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/dtcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ class DTCCA(DCCA):
Is just a thin wrapper round DCCA with the DTCCA objective and a TCCA post-processing
Citation
--------
:Citation:
Wong, Hok Shing, et al. "Deep Tensor CCA for Multi-view Learning." IEEE Transactions on Big Data (2021).
Expand Down
5 changes: 3 additions & 2 deletions cca_zoo/deepmodels/dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ class DVCCA(_DCCA_base):
"""
A class used to fit a DVCCA model.
Citation
--------
:Citation:
Wang, Weiran, et al. "Deep variational canonical correlation analysis." arXiv preprint arXiv:1610.03454 (2016).
https://arxiv.org/pdf/1610.03454.pdf
https://github.com/pytorch/examples/blob/master/vae/main.py
"""
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/deepmodels/splitae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ class SplitAE(_DCCA_base):
"""
A class used to fit a Split Autoencoder model.
Citation
--------
:Citation:
Ngiam, Jiquan, et al. "Multimodal deep learning." ICML. 2011.
Expand Down
35 changes: 28 additions & 7 deletions cca_zoo/models/gcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,20 @@


class GCCA(rCCA):
"""
r"""
A class used to fit a GCCA model. For more than 2 views, GCCA optimizes the sum of correlations with a shared auxiliary vector.
Citation
--------
:Maths:
.. math::
w_{opt}=\underset{w}{\mathrm{argmax}}\{ \sum_iw_i^TX_i^TT \}\\
\text{subject to:}
T^TT=1
:Citation:
Tenenhaus, Arthur, and Michel Tenenhaus. "Regularized generalized canonical correlation analysis." Psychometrika 76.2 (2011): 257.
Expand Down Expand Up @@ -99,11 +108,21 @@ def _solve_evp(self, views: Iterable[np.ndarray], C, D=None, **kwargs):


class KGCCA(GCCA):
"""
r"""
A class used to fit a KGCCA model. For more than 2 views, KGCCA optimizes the sum of correlations with a shared auxiliary vector.
Citation
--------
:Maths:
.. math::
w_{opt}=\underset{w}{\mathrm{argmax}}\{ \sum_i\alpha_i^TK_i^TT \}\\
\text{subject to:}
T^TT=1
:Citation:
Tenenhaus, Arthur, Cathy Philippe, and Vincent Frouin. "Kernel generalized canonical correlation analysis." Computational Statistics & Data Analysis 90 (2015): 114-131.
:Example:
Expand Down Expand Up @@ -135,6 +154,8 @@ def __init__(
kernel_params: Iterable[dict] = None,
):
"""
Constructor for KGCCA
:param latent_dims: number of latent dimensions to fit
:param scale: normalize variance in each column before fitting
:param centre: demean data by column before fitting (and before transforming out-of-sample data)
Expand Down Expand Up @@ -175,7 +196,7 @@ def _check_params(self):
)

def _get_kernel(self, view, X, Y=None):
if callable(self.kernel):
if callable(self.kernel[view]):
params = self.kernel_params[view] or {}
else:
params = {
Expand Down
61 changes: 5 additions & 56 deletions cca_zoo/models/innerloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,14 @@ def __init__(
self,
max_iter: int = 100,
tol: float = 1e-5,
generalized: bool = False,
initialization: str = "unregularized",
random_state=None,
):
"""
:param max_iter: maximum number of iterations to perform if tol is not reached
:param tol: tolerance value used for stopping criteria
:param generalized: use an auxiliary variable to
:param initialization: initialise the optimisation with either the 'unregularized' (CCA/PLS) solution, or a 'random' initialisation
"""
self.generalized = generalized
self.initialization = initialization
self.max_iter = max_iter
self.tol = tol
Expand Down Expand Up @@ -124,14 +121,12 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
random_state=None,
):
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down Expand Up @@ -179,7 +174,6 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
c=None,
positive=None,
Expand All @@ -188,7 +182,6 @@ def __init__(
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down Expand Up @@ -242,15 +235,13 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
c=None,
random_state=None,
):
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down Expand Up @@ -283,27 +274,25 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
c=None,
l1_ratio=None,
constrained=False,
maxvar=True,
stochastic=True,
positive=None,
random_state=None,
):
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
self.stochastic = stochastic
self.constrained = constrained
self.c = c
self.l1_ratio = l1_ratio
self.positive = positive
self.maxvar = maxvar

def _check_params(self):
self.c = _process_parameter("c", self.c, 0, len(self.views))
Expand All @@ -313,8 +302,6 @@ def _check_params(self):
self.positive = _process_parameter(
"positive", self.positive, False, len(self.views)
)
if self.constrained:
self.gamma = np.zeros(len(self.views))
self.regressors = []
for alpha, l1_ratio, positive in zip(self.c, self.l1_ratio, self.positive):
if self.stochastic:
Expand Down Expand Up @@ -398,14 +385,12 @@ def _update_view(self, view_index: int):
:param view_index: index of view being updated
:return: updated weights
"""
if self.generalized:
if self.maxvar:
target = self.scores.mean(axis=0)
target /= np.linalg.norm(target)
else:
target = self.scores[view_index - 1]
if self.constrained:
self._elastic_solver_constrained(self.views[view_index], target, view_index)
else:
self._elastic_solver(self.views[view_index], target, view_index)
self._elastic_solver(self.views[view_index], target, view_index)
_check_converged_weights(self.weights[view_index], view_index)
self.scores[view_index] = self.views[view_index] @ self.weights[view_index]

Expand All @@ -418,35 +403,6 @@ def _elastic_solver(self, X, y, view_index):
self.views[view_index] @ self.weights[view_index]
) / np.sqrt(self.n)

@ignore_warnings(category=ConvergenceWarning)
def _elastic_solver_constrained(self, X, y, view_index):
converged = False
min_ = -1
max_ = 1
previous = self.gamma[view_index]
previous_val = None
i = 0
while not converged:
i += 1
coef = (
self.regressors[view_index]
.fit(
np.sqrt(self.gamma[view_index] + 1) * X,
y.ravel() / np.sqrt(self.gamma[view_index] + 1),
)
.coef_
)
current_val = 1 - (np.linalg.norm(X @ coef) ** 2) / self.n
self.gamma[view_index], previous, min_, max_ = _bin_search(
self.gamma[view_index], previous, current_val, previous_val, min_, max_
)
previous_val = current_val
if np.abs(current_val) < 1e-5:
converged = True
elif np.abs(max_ - min_) < 1e-30 or i == 50:
converged = True
self.weights[view_index] = coef

def _objective(self):
views = len(self.views)
c = np.array(self.c)
Expand All @@ -455,7 +411,6 @@ def _objective(self):
l2 = c * (1 - ratio)
total_objective = 0
for i in range(views):
# TODO this looks like it could be tidied up. In particular can we make the generalized objective correspond to the 2 view
target = self.scores.mean(axis=0)
objective = (
views
Expand All @@ -480,7 +435,6 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
mu=None,
lam=None,
Expand All @@ -491,7 +445,6 @@ def __init__(
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down Expand Up @@ -601,7 +554,6 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-5,
generalized: bool = False,
initialization: str = "unregularized",
c=None,
regularisation="l0",
Expand All @@ -612,7 +564,6 @@ def __init__(
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down Expand Up @@ -664,7 +615,6 @@ def __init__(
self,
max_iter: int = 100,
tol=1e-20,
generalized: bool = False,
initialization: str = "unregularized",
regularisation="l0",
c=None,
Expand All @@ -675,7 +625,6 @@ def __init__(
super().__init__(
max_iter=max_iter,
tol=tol,
generalized=generalized,
initialization=initialization,
random_state=random_state,
)
Expand Down
Loading

0 comments on commit a5523f3

Please sign in to comment.