-
Notifications
You must be signed in to change notification settings - Fork 38
/
_dmcca.py
62 lines (51 loc) · 1.8 KB
/
_dmcca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import List, Optional

import torch

from ._dcca import DCCA
from .._utils import inv_sqrtm
class DMCCA(DCCA):
    """
    A class used to fit a DMCCA model.

    This is a thin wrapper around DCCA: everything is inherited from the
    parent class except the objective (``loss``) and the two covariance
    helpers it relies on (``A`` and ``B``).

    Notes
    -----
    The objective is built from two matrices computed on the centred view
    representations:

    * ``A`` -- the between-view part of the joint covariance matrix.
    * ``B`` -- the block-diagonal within-view covariance, shrunk towards the
      identity by ``self.eps``.

    ``A + B`` is whitened with the matrix returned by ``inv_sqrtm(B, eps)``
    and the loss is minus the sum of the largest positive eigenvalues of the
    whitened matrix.
    """

    # TODO(review): the original docstring had an empty "References" section;
    # add the citation for the DMCCA objective once confirmed.

    def loss(
        self,
        representations: List[torch.Tensor],
        independent_representations: Optional[List[torch.Tensor]] = None,
    ):
        """Compute the DMCCA objective for a batch of view representations.

        Parameters
        ----------
        representations : list of torch.Tensor
            One 2-D tensor per view. Rows are treated as samples and columns
            as latent dimensions (assumed shape ``(n_samples, latent_dims)``);
            the number of columns of the first view sets how many
            eigenvalues are kept.
        independent_representations : list of torch.Tensor, optional
            Accepted for signature compatibility; it is not read by this
            objective.

        Returns
        -------
        dict
            ``{"objective": value}`` where ``value`` is minus the sum of the
            retained positive eigenvalues (a scalar tensor to be minimised).
        """
        latent_dims = representations[0].shape[1]

        # Centre every view column-wise before forming covariances.
        centred = [view - view.mean(dim=0) for view in representations]

        between_cov = self.A(centred)
        within_cov = self.B(centred)
        # Out-of-place sum: leaves the tensor returned by ``self.A`` untouched.
        total_cov = between_cov + within_cov

        # ``inv_sqrtm`` lives in ``.._utils``; presumably it returns the
        # (regularised) inverse matrix square root of ``within_cov``.
        whitener = inv_sqrtm(within_cov, self.eps)
        whitened = whitener @ total_cov @ whitener.T

        eigvals = torch.linalg.eigvalsh(whitened)
        # Keep the ``latent_dims`` largest eigenvalues.
        order = torch.argsort(eigvals, descending=True)
        top_eigvals = eigvals[order[:latent_dims]]

        # Only strictly positive eigenvalues contribute. The original code
        # additionally passed them through LeakyReLU, which is exactly the
        # identity on positive inputs, so that step is dropped.
        positive_eigvals = top_eigvals[torch.gt(top_eigvals, 0)]
        corr = positive_eigvals.sum()
        return {"objective": -corr}

    def A(self, representations: List[torch.Tensor]):
        """Calculate cross-covariance matrix.

        The covariance of all views concatenated along the feature axis is
        computed, then the block-diagonal matrix of per-view covariances is
        subtracted so that only the between-view blocks remain. The result
        is divided by the number of views.

        Parameters
        ----------
        representations : list of torch.Tensor
            One 2-D tensor per view with samples in rows (they are transposed
            before being handed to ``torch.cov``).

        Returns
        -------
        torch.Tensor
            Square matrix whose side is the total number of columns across
            all views.
        """
        all_views = torch.cat(representations, dim=1)
        joint_cov = torch.cov(all_views.T)
        within_view_cov = torch.block_diag(
            *[torch.cov(view.T) for view in representations]
        )
        return (joint_cov - within_view_cov) / len(representations)

    def B(self, representations: List[torch.Tensor]):
        """Calculate block covariance matrix.

        Each view contributes one diagonal block
        ``(1 - eps) * cov(view) + eps * I``; the assembled block-diagonal
        matrix is divided by the number of views.

        Parameters
        ----------
        representations : list of torch.Tensor
            One 2-D tensor per view with samples in rows.

        Returns
        -------
        torch.Tensor
            Block-diagonal matrix with the same side length as ``A`` returns.
        """
        blocks = [
            (1 - self.eps) * torch.cov(view.T)
            # Build the identity in the view's own dtype: with the default
            # float32 identity, ``eps`` would be rounded to float32 before
            # being promoted when the views are float64.
            + self.eps
            * torch.eye(view.shape[1], dtype=view.dtype, device=view.device)
            for view in representations
        ]
        return torch.block_diag(*blocks) / len(representations)