Skip to content

Commit

Permalink
Merge pull request #82 from jameschapman19/pytorchlightning
Browse files Browse the repository at this point in the history
Pytorchlightning
  • Loading branch information
jameschapman19 committed Nov 15, 2021
2 parents 316d0e9 + e103c73 commit 5cc2c94
Show file tree
Hide file tree
Showing 31 changed files with 788 additions and 305 deletions.
1 change: 0 additions & 1 deletion cca_zoo/data/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
self.dataset = datasets.KMNIST("../../data", train=train, download=True)

self.data = self.dataset.data
self.base_transform = transforms.ToTensor()
self.targets = self.dataset.targets
self.flatten = flatten

Expand Down
2 changes: 2 additions & 0 deletions cca_zoo/deepmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import cca_zoo.deepmodels.objectives
from ._dcca_base import _DCCA_base
from .dcca import DCCA
from .dcca_barlow_twins import BarlowTwins
from .dcca_noi import DCCA_NOI
from .dcca_sdl import DCCA_SDL
from .dccae import DCCAE
from .dtcca import DTCCA
from .dvcca import DVCCA
Expand Down
55 changes: 55 additions & 0 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Iterable

import torch

from cca_zoo.deepmodels import DCCA
from cca_zoo.deepmodels.architectures import BaseEncoder, Encoder


class BarlowTwins(DCCA):
    """
    A class used to fit a Barlow Twins model.

    Each view is passed through its encoder and an affine-free batch
    normalization; the loss drives the diagonal of the cross-covariance
    between the two views towards 1 (invariance) and the off-diagonal
    entries towards 0 (redundancy reduction).

    Citation
    --------
    Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).

    Examples
    --------
    """

    def __init__(
        self,
        latent_dims: int,
        encoders: Iterable[BaseEncoder] = None,
        lam=1,
    ):
        """
        Constructor class for Barlow Twins

        :param latent_dims: # latent dimensions
        :param encoders: list of encoder networks (defaults to two ``Encoder``)
        :param lam: weighting of off diagonal loss terms
        """
        # Avoid a mutable default argument: the previous default
        # ([Encoder, Encoder]) was one list object shared by every call.
        if encoders is None:
            encoders = [Encoder, Encoder]
        super().__init__(latent_dims=latent_dims, encoders=encoders)
        self.lam = lam
        # One affine-free BatchNorm per view, applied to that view's latents.
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(latent_dims, affine=False) for _ in self.encoders]
        )

    def forward(self, *args):
        """Encode each view and batch-normalize the resulting latents."""
        return tuple(
            bn(encoder(x)) for encoder, bn, x in zip(self.encoders, self.bns, args)
        )

    def loss(self, *args):
        """
        Barlow Twins loss on the cross-covariance of the two views:
        (diagonal - 1)^2 invariance term plus lam-weighted squared
        off-diagonal redundancy term.
        """
        z = self(*args)
        # Unbiased cross-covariance between the two batch-normalized views.
        cross_cov = z[0].T @ z[1] / (z[0].shape[0] - 1)
        invariance = torch.mean(torch.pow(1 - torch.diag(cross_cov), 2))
        covariance = torch.mean(
            torch.triu(torch.pow(cross_cov, 2), diagonal=1)
        ) + torch.mean(torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
        # Apply the documented off-diagonal weighting (previously stored in
        # __init__ but never used; identical behavior at the default lam=1).
        return invariance + self.lam * covariance
35 changes: 12 additions & 23 deletions cca_zoo/deepmodels/dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-3,
eps: float = 1e-9,
shared_target: bool = False,
):
"""
Expand All @@ -45,7 +45,7 @@ def __init__(
self.eps = eps
self.rho = rho
self.shared_target = shared_target
self.mse = torch.nn.MSELoss()
self.mse = torch.nn.MSELoss(reduction='sum')
# Authors state that a final linear layer is an important part of their algorithmic implementation
self.linear_layers = torch.nn.ModuleList(
[
Expand All @@ -67,7 +67,7 @@ def forward(self, *args):
def loss(self, *args):
z = self(*args)
z_copy = [z_.detach().clone() for z_ in z]
self.update_covariances(*z_copy)
self._update_covariances(*z_copy)
covariance_inv = [
torch.linalg.inv(objectives.MatrixSquareRoot.apply(cov))
for cov in self.covs
Expand All @@ -76,25 +76,14 @@ def loss(self, *args):
loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0])
return loss

def update_mean(self, *z):
batch_means = [torch.mean(z_, dim=0) for z_ in z]
if self.means is not None:
self.means = [
self.rho * self.means[i].detach() + (1 - self.rho) * batch_mean
for i, batch_mean in enumerate(batch_means)
]
else:
self.means = batch_means
z = [z_ - mean for (z_, mean) in zip(z, self.means)]
return z

def update_covariances(self, *z):
def _update_covariances(self, *z, train=True):
b = z[0].shape[0]
batch_covs = [self.N * z_.T @ z_ / b for z_ in z]
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
if train:
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
93 changes: 93 additions & 0 deletions cca_zoo/deepmodels/dcca_sdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
import torch.nn.functional as F

from cca_zoo.deepmodels import DCCA_NOI


class DCCA_SDL(DCCA_NOI):
    """
    A class used to fit a Deep CCA by Stochastic Decorrelation model.

    Citation
    --------
    Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep CCA via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

    Examples
    --------
    """

    def __init__(
        self,
        latent_dims: int,
        N: int,
        encoders=None,
        r: float = 0,
        rho: float = 0.2,
        eps: float = 1e-3,
        shared_target: bool = False,
        lam=0.5,
    ):
        """
        Constructor class for DCCA_SDL

        :param latent_dims: # latent dimensions
        :param N: passed through to DCCA_NOI (presumably # training samples — TODO confirm against DCCA_NOI)
        :param encoders: list of encoder networks
        :param r: regularisation parameter of tracenorm CCA like ridge CCA
        :param rho: covariance memory like DCCA non-linear orthogonal iterations paper
        :param eps: epsilon used throughout
        :param shared_target: not used
        :param lam: weighting of the soft decorrelation (SDL) loss term
        """
        super().__init__(
            latent_dims=latent_dims,
            N=N,
            encoders=encoders,
            r=r,
            rho=rho,
            eps=eps,
            shared_target=shared_target,
        )
        self.cross_cov = None
        self.lam = lam
        # One affine-free BatchNorm per view, as in BarlowTwins. The previous
        # `for _ in range(latent_dims)` allocated one module per latent
        # dimension; zip() in forward() only ever consumed the first
        # len(self.encoders) of them, so the extras were dead weight.
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(latent_dims, affine=False) for _ in self.encoders]
        )

    def forward(self, *args):
        """Encode each view and batch-normalize the resulting latents."""
        return tuple(
            bn(encoder(x)) for encoder, bn, x in zip(self.encoders, self.bns, args)
        )

    def loss(self, *args):
        """
        Alignment (MSE between the two views' latents) plus the lam-weighted
        soft decorrelation penalty on the running covariance estimates.
        """
        z = self(*args)
        # Covariance estimates are only updated while training.
        self._update_covariances(*z, train=self.training)
        SDL_loss = self._sdl_loss(self.covs)
        l2_loss = F.mse_loss(z[0], z[1])
        return l2_loss + self.lam * SDL_loss

    def _sdl_loss(self, covs):
        """Mean absolute off-diagonal covariance, summed over all views."""
        loss = 0
        for cov in covs:
            # sign(cov) with a zeroed diagonal makes cov * sgn equal to
            # |cov| off the diagonal and 0 on it.
            sgn = torch.sign(cov)
            sgn.fill_diagonal_(0)
            loss += torch.mean(cov * sgn)
        return loss

    def _update_covariances(self, *z, train=True):
        """
        Exponentially-weighted update of the per-view covariance estimates.

        :param z: latent representations, one tensor per view
        :param train: when False, leave the stored estimates untouched
        """
        batch_covs = [z_.T @ z_ for z_ in z]
        if train:
            if self.covs is not None:
                # NOTE(review): self.c looks like an Adam-style bias-correction
                # counter, but self.covs is stored already divided by c from the
                # previous step before being blended again — verify the running
                # average is the intended estimator.
                self.c = self.rho * self.c + 1
                self.covs = [
                    self.rho * self.covs[i].detach() + (1 - self.rho) * batch_cov
                    for i, batch_cov in enumerate(batch_covs)
                ]
            else:
                self.c = 1
                self.covs = batch_covs
            self.covs = [cov / self.c for cov in self.covs]
5 changes: 2 additions & 3 deletions cca_zoo/deepmodels/dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def decode(self, *z):
"""
This method is used to decode from the latent space to the best prediction of the original views
:param args:
"""
recon = []
for i, decoder in enumerate(self.decoders):
Expand All @@ -72,11 +71,11 @@ def decode(self, *z):
def loss(self, *args):
z = self(*args)
recon = self.decode(*z)
recon_loss = self.recon_loss(args[: len(recon)], recon)
recon_loss = self._recon_loss(args[: len(recon)], recon)
return self.lam * recon_loss + self.objective.loss(*z)

@staticmethod
def recon_loss(x, recon):
def _recon_loss(x, recon):
recons = [
F.mse_loss(recon_, x_, reduction="mean") for recon_, x_ in zip(recon, x)
]
Expand Down
22 changes: 11 additions & 11 deletions cca_zoo/deepmodels/dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def forward(self, *args, mle=True):
:return:
"""
# Used when we get reconstructions
mu, logvar = self.encode(*args)
mu, logvar = self._encode(*args)
if mle:
z = mu
else:
Expand All @@ -62,7 +62,7 @@ def forward(self, *args, mle=True):
if len(self.encoders) == 1:
z = z * len(args)
if self.private_encoders:
mu_p, logvar_p = self.encode_private(*args)
mu_p, logvar_p = self._encode_private(*args)
if mle:
z_p = mu_p
else:
Expand All @@ -71,7 +71,7 @@ def forward(self, *args, mle=True):
z = [torch.cat([z_] + z_p, dim=-1) for z_ in z]
return z

def encode(self, *args):
def _encode(self, *args):
"""
:param args:
:return:
Expand All @@ -84,7 +84,7 @@ def encode(self, *args):
logvar.append(logvar_i)
return mu, logvar

def encode_private(self, *args):
def _encode_private(self, *args):
"""
:param args:
:return:
Expand All @@ -97,14 +97,14 @@ def encode_private(self, *args):
logvar.append(logvar_i)
return mu, logvar

def decode(self, z):
def _decode(self, z):
"""
:param z:
:return:
"""
x = []
for i, decoder in enumerate(self.decoders):
x_i = decoder(z)
x_i = F.sigmoid(decoder(z))
x.append(x_i)
return x

Expand All @@ -114,16 +114,16 @@ def recon(self, *args):
:return:
"""
z = self(*args)
return [self.decode(z_i) for z_i in z][0]
return [self._decode(z_i) for z_i in z][0]

def loss(self, *args):
"""
:param args:
:return:
"""
mus, logvars = self.encode(*args)
mus, logvars = self._encode(*args)
if self.private_encoders:
mus_p, logvars_p = self.encode_private(*args)
mus_p, logvars_p = self._encode_private(*args)
losses = [
self.vcca_private_loss(
*args, mu=mu, logvar=logvar, mu_p=mu_p, logvar_p=logvar_p
Expand All @@ -150,7 +150,7 @@ def vcca_loss(self, *args, mu, logvar):
kl = torch.mean(
-0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
)
recons = self.decode(z)
recons = self._decode(z)
bces = torch.stack(
[
F.binary_cross_entropy(recon, arg, reduction="sum") / batch_n
Expand Down Expand Up @@ -185,7 +185,7 @@ def vcca_private_loss(self, *args, mu, logvar, mu_p, logvar_p):
-0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
)
z_combined = torch.cat([z, z_p], dim=-1)
recon = self.decode(z_combined)
recon = self._decode(z_combined)
bces = torch.stack(
[
F.binary_cross_entropy(recon[i], args[i], reduction="sum") / batch_n
Expand Down
4 changes: 3 additions & 1 deletion cca_zoo/deepmodels/splitae.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,7 @@ def loss(self, *args):

@staticmethod
def recon_loss(x, recon):
recons = [F.mse_loss(recon[i], x[i], reduction="mean") for i in range(len(recon))]
recons = [
F.mse_loss(recon[i], x[i], reduction="mean") for i in range(len(recon))
]
return torch.stack(recons).sum(dim=0)
Loading

0 comments on commit 5cc2c94

Please sign in to comment.