<a href="https://colab.research.google.com/github/jameschapman19/cca_zoo/blob/master/cca_zoo_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install cca-zoo

Requirement already up-to-date: cca-zoo in /usr/local/lib/python3.6/dist-packages (1.1.8)


In [None]:
# Imports
import numpy as np
from cca_zoo import wrappers
from cca_zoo import data
import itertools
import os
from cca_zoo.configuration import Config
import matplotlib.pyplot as plt
from torch.utils.data import Subset

import torch

# Load the noisy FashionMNIST data: two views per sample, plus rotations and (one-hot) labels
# NOTE(review): os.chdir is not idempotent -- every re-run of this cell moves one directory further up.
# Presumably it points the dataset loader at the expected data location; confirm before relying on it.
os.chdir('..')

# Seed numpy (drives the shuffles below) and torch (presumably used by the deep models later on) so that
# Restart & Run All draws the same train/val/test subsets every time.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

N = 1000
dataset = data.Noisy_MNIST_Dataset(mnist_type='FashionMNIST', train=True)
# Take the first 2N training samples, shuffle them, and split them evenly into train and validation sets
ids = np.arange(min(2 * N, len(dataset)))
np.random.shuffle(ids)
train_ids, val_ids = np.array_split(ids, 2)
val_dataset = Subset(dataset, val_ids)
train_dataset = Subset(dataset, train_ids)
# Up to N shuffled samples from the official test split
test_dataset = data.Noisy_MNIST_Dataset(mnist_type='FashionMNIST', train=False)
test_ids = np.arange(min(N, len(test_dataset)))
np.random.shuffle(test_ids)
test_dataset = Subset(test_dataset, test_ids)
train_view_1, train_view_2, train_rotations, train_OH_labels, train_labels = train_dataset.dataset.to_numpy(
    train_dataset.indices)
val_view_1, val_view_2, val_rotations, val_OH_labels, val_labels = val_dataset.dataset.to_numpy(val_dataset.indices)
test_view_1, test_view_2, test_rotations, test_OH_labels, test_labels = test_dataset.dataset.to_numpy(
    test_dataset.indices)

# Settings

# The number of latent dimensions across models
latent_dims = 2
# The number of folds used for cross-validation/hyperparameter tuning
cv_folds = 5
# For running hyperparameter tuning in parallel (0 if not)
jobs = 2
# Number of iterations for iterative algorithms
max_iter = 10

# Canonical Correlation Analysis

In [None]:
"""
### Linear CCA via alternating least squares (can pass more than 2 views)
"""

# %%
linear_cca = wrappers.CCA_ALS(latent_dims=latent_dims)
linear_cca.fit(train_view_1, train_view_2)

# Row 0: train correlations; row 1: correlations on the held-out test views.
linear_cca_results = np.stack([
    linear_cca.train_correlations[0, 1],
    linear_cca.predict_corr(test_view_1, test_view_2)[0, 1],
])

# Canonical Correlation Analysis and Partial Least Squares using scikit-learn implementations

These will likely be deprecated, since our alternating least squares algorithm is very similar to the NIPALS algorithm used by scikit-learn. For the moment we keep them to help test simple 2-view unregularized examples.

In [None]:
"""
### Linear CCA with scikit-learn (only permits 2 views)
"""

scikit_cca = wrappers.CCA_scikit(latent_dims=latent_dims)
scikit_cca.fit(train_view_1, train_view_2)

# Row 0: train correlations; row 1: correlations on the held-out test views.
scikit_cca_results = np.stack([
    scikit_cca.train_correlations[0, 1],
    scikit_cca.predict_corr(test_view_1, test_view_2)[0, 1],
])

"""
### PLS with scikit-learn (only permits 2 views)
"""

# %%
pls = wrappers.PLS_scikit(latent_dims=latent_dims)
pls.fit(train_view_1, train_view_2)

# Same layout as the other models: train correlations stacked over test correlations.
pls_results = np.stack([
    pls.train_correlations[0, 1],
    pls.predict_corr(test_view_1, test_view_2)[0, 1],
])

# Extension to multiple views



In [None]:
"""
### (Regularized) Generalized CCA (can pass more than 2 views)
"""

gcca = wrappers.GCCA(latent_dims=latent_dims)
# a small amount of regularisation is added because the data is not full rank
# (presumably one 'c' value per view)
params = {'c': [1, 1]}

# further views can be handled by simply passing them as extra arguments to .fit()
gcca.fit(train_view_1, train_view_2, params=params)

gcca_results = np.stack([
    gcca.train_correlations[0, 1],
    gcca.predict_corr(test_view_1, test_view_2)[0, 1],
])

"""
### (Regularized) Multiset CCA (can pass more than 2 views)
"""

mcca = wrappers.MCCA(latent_dims=latent_dims)
# a small amount of regularisation is added because the data is not full rank
# (presumably one 'c' value per view)
params = {'c': [0.5, 0.5]}

# further views can be handled by simply passing them as extra arguments to .fit()
mcca.fit(train_view_1, train_view_2, params=params)

mcca_results = np.stack([
    mcca.train_correlations[0, 1],
    mcca.predict_corr(test_view_1, test_view_2)[0, 1],
])

# Regularised CCA solutions based on alternating minimisation/alternating least squares

We implement Witten's penalized matrix decomposition form of sparse CCA using method='pmd'

We implement Waaijenborg's elastic-net penalized CCA using method='elastic'

We implement Mai's sparse CCA using method='scca'

Furthermore, any of these methods can be extended to multiple views; Witten describes this multi-view extension explicitly.

In [None]:
"""
### Sparse CCA (Penalized Matrix Decomposition) (can pass more than 2 views)
"""

# PMD: search every pairing of the candidate penalties (presumably one penalty per view)
c1 = [1, 3, 7, 9]
c2 = [1, 3, 7, 9]
param_candidates = {'c': list(itertools.product(c1, c2))}

pmd = wrappers.CCA_ALS(
    latent_dims=latent_dims, method='pmd', tol=1e-5, max_iter=max_iter,
).gridsearch_fit(
    train_view_1,
    train_view_2,
    param_candidates=param_candidates,
    folds=cv_folds,
    verbose=True,
    jobs=jobs,
    plot=True,
)

# Row 0: train correlations; row 1: correlations on the test views.
pmd_results = np.stack([
    pmd.train_correlations[0, 1, :],
    pmd.predict_corr(test_view_1, test_view_2)[0, 1, :],
])

"""
### Elastic CCA (can pass more than 2 views)
"""

# Elastic CCA: grid over the regularisation strengths and the L1 ratios
c1 = [0.0001, 0.001]
c2 = [0.0001, 0.001]
l1_1 = [0.01, 0.1]
l1_2 = [0.01, 0.1]
param_candidates = {
    'c': list(itertools.product(c1, c2)),
    'l1_ratio': list(itertools.product(l1_1, l1_2)),
}

elastic = wrappers.CCA_ALS(
    latent_dims=latent_dims, method='elastic', max_iter=max_iter,
).gridsearch_fit(
    train_view_1,
    train_view_2,
    param_candidates=param_candidates,
    folds=cv_folds,
    verbose=True,
    jobs=jobs,
    plot=True,
)

# Row 0: train correlations; row 1: correlations on the test views.
elastic_results = np.stack([
    elastic.train_correlations[0, 1, :],
    elastic.predict_corr(test_view_1, test_view_2)[0, 1, :],
])

"""
### Sparse CCA (can pass more than 2 views)
"""

# Sparse CCA: grid over pairs of regularisation strengths
c1 = [0.0001, 0.001]
c2 = [0.0001, 0.001]
param_candidates = {'c': list(itertools.product(c1, c2))}

scca = wrappers.CCA_ALS(
    latent_dims=latent_dims, method='scca', max_iter=max_iter,
).gridsearch_fit(
    train_view_1,
    train_view_2,
    param_candidates=param_candidates,
    folds=cv_folds,
    verbose=True,
    jobs=jobs,
    plot=True,
)

# Row 0: train correlations; row 1: correlations on the test views.
scca_results = np.stack([
    scca.train_correlations[0, 1, :],
    scca.predict_corr(test_view_1, test_view_2)[0, 1, :],
])

cross validation
number of folds:  5
Best score :  1.263031777484701
{'c': (9, 9)}
cross validation
number of folds:  5




Best score :  1.1217213066451142
{'c': (0.001, 0.001), 'l1_ratio': (0.1, 0.1)}
cross validation
number of folds:  5
Best score :  1.3778632445701178
{'c': (0.001, 0.0001)}


# Kernel CCA

In [None]:
"""
### Kernel CCA

Similarly, we can use kernel CCA methods with wrappers.KCCA

We can use different kernels and their associated parameters in a similar manner to before
- regularized linear kernel CCA: parameters :  'kernel'='linear', 0<'c'<1
- polynomial kernel CCA: parameters : 'kernel'='poly', 'degree', 0<'c'<1
- gaussian rbf kernel CCA: parameters : 'kernel'='rbf', 'sigma', 0<'c'<1
"""
# %%


def _kcca_gridsearch(param_candidates):
    """Grid-search a KCCA model over ``param_candidates`` and score it.

    Uses the notebook-wide settings (latent_dims, cv_folds, jobs) and the train/test views.

    Returns
    -------
    (model, results) : the fitted model, and an array stacking its train correlations
        (row 0) over its correlations on the test views (row 1).
    """
    model = wrappers.KCCA(latent_dims=latent_dims).gridsearch_fit(train_view_1, train_view_2,
                                                                  folds=cv_folds,
                                                                  param_candidates=param_candidates,
                                                                  verbose=True, jobs=jobs,
                                                                  plot=True)
    results = np.stack((model.train_correlations[0, 1, :],
                        model.predict_corr(test_view_1, test_view_2)[0, 1, :]))
    return model, results


# Regularisation grid shared by all three kernels (presumably one 'c' per view)
c1 = [0.9, 0.99]
c2 = [0.9, 0.99]

# r-kernel cca (linear kernel)
param_candidates = {'kernel': ['linear'], 'c': list(itertools.product(c1, c2))}
kernel_reg, kernel_reg_results = _kcca_gridsearch(param_candidates)

# kernel cca (poly)
param_candidates = {'kernel': ['poly'], 'degree': [2, 3], 'c': list(itertools.product(c1, c2))}
kernel_poly, kernel_poly_results = _kcca_gridsearch(param_candidates)

# kernel cca (gaussian, selected with kernel='rbf')
param_candidates = {'kernel': ['rbf'], 'sigma': [1e+1, 1e+2, 1e+3], 'c': list(itertools.product(c1, c2))}
kernel_gaussian, kernel_gaussian_results = _kcca_gridsearch(param_candidates)

cross validation
number of folds:  5
Best score :  1.4950788066772576
{'kernel': 'linear', 'c': (0.99, 0.99)}
cross validation
number of folds:  5
Best score :  1.269053452297602
{'kernel': 'poly', 'degree': 3, 'c': (0.9, 0.9)}
cross validation
number of folds:  5
Best score :  1.1206782413713658
{'kernel': 'rbf', 'sigma': 1000.0, 'c': (0.9, 0.9)}


  fig, axs = plt.subplots(1, n_uniques[-1], subplot_kw={'projection': '3d'})


# Deep CCA, Deep Generalized CCA & Deep Multiset CCA

In [None]:
"""
### Deep Learning

We also have deep CCA methods (and autoencoder variants)
- Deep CCA (DCCA)
- Deep Canonically Correlated Autoencoders (DCCAE)

We introduce a Config class from configuration.py. This contains a number of default settings for running DCCA.

"""
from cca_zoo import deepwrapper,objectives
# %%


def _fit_deep_wrapper(loss_type=None):
    """Train a DeepWrapper for 100 epochs on the training views and score it.

    Each call builds its own Config, so choosing a loss for one model never
    modifies the configuration object another model was constructed with.

    Parameters
    ----------
    loss_type : optional
        Objective from cca_zoo.objectives; None keeps the Config default.

    Returns
    -------
    (model, results) : the trained model, and an array stacking its train
        correlations (row 0) over its correlations on the test views (row 1).
    """
    model_cfg = Config()
    model_cfg.epoch_num = 100
    if loss_type is not None:
        model_cfg.loss_type = loss_type
    model = deepwrapper.DeepWrapper(model_cfg)
    model.fit(train_view_1, train_view_2)
    results = np.stack((model.train_correlations, model.predict_corr(test_view_1, test_view_2)))
    return model, results


# DCCA: encoder architecture and loss are the Config defaults
dcca, dcca_results = _fit_deep_wrapper()

# DGCCA: note the different loss function
dgcca, dgcca_results = _fit_deep_wrapper(objectives.GCCA)

# DMCCA: note the different loss function
dmcca, dmcca_results = _fit_deep_wrapper(objectives.MCCA)

total parameters:  402948
====> Epoch: 1 Average train loss: -0.1573
====> Epoch: 1 Average val loss: -0.6672
Min loss -0.67
====> Epoch: 2 Average train loss: -0.8956
====> Epoch: 2 Average val loss: -1.1338
Min loss -1.13
====> Epoch: 3 Average train loss: -1.3275
====> Epoch: 3 Average val loss: -1.2245
Min loss -1.22
====> Epoch: 4 Average train loss: -1.4298
====> Epoch: 4 Average val loss: -1.2658
Min loss -1.27
====> Epoch: 5 Average train loss: -1.5004
====> Epoch: 5 Average val loss: -1.2973
Min loss -1.30
====> Epoch: 6 Average train loss: -1.5585
====> Epoch: 6 Average val loss: -1.3205
Min loss -1.32
====> Epoch: 7 Average train loss: -1.6067
====> Epoch: 7 Average val loss: -1.3305
Min loss -1.33
====> Epoch: 8 Average train loss: -1.6449
====> Epoch: 8 Average val loss: -1.3334
Min loss -1.33
====> Epoch: 9 Average train loss: -1.6779
====> Epoch: 9 Average val loss: -1.3323
====> Epoch: 10 Average train loss: -1.7053
====> Epoch: 10 Average val loss: -1.3362
Min loss -1.

  plt.figure()


====> Epoch: 1 Average train loss: -0.2987
====> Epoch: 1 Average val loss: -1.0767
Min loss -1.08
====> Epoch: 2 Average train loss: -1.2313
====> Epoch: 2 Average val loss: -1.2231
Min loss -1.22
====> Epoch: 3 Average train loss: -1.4267
====> Epoch: 3 Average val loss: -1.2763
Min loss -1.28
====> Epoch: 4 Average train loss: -1.5150
====> Epoch: 4 Average val loss: -1.3096
Min loss -1.31
====> Epoch: 5 Average train loss: -1.5763
====> Epoch: 5 Average val loss: -1.3333
Min loss -1.33
====> Epoch: 6 Average train loss: -1.6258
====> Epoch: 6 Average val loss: -1.3512
Min loss -1.35
====> Epoch: 7 Average train loss: -1.6694
====> Epoch: 7 Average val loss: -1.3607
Min loss -1.36
====> Epoch: 8 Average train loss: -1.7083
====> Epoch: 8 Average val loss: -1.3578
====> Epoch: 9 Average train loss: -1.7380
====> Epoch: 9 Average val loss: -1.3492
====> Epoch: 10 Average train loss: -1.7621
====> Epoch: 10 Average val loss: -1.3477
====> Epoch: 11 Average train loss: -1.7884
====> Epo

  plt.figure()


total parameters:  402948
====> Epoch: 1 Average train loss: -0.1679
====> Epoch: 1 Average val loss: -0.7035
Min loss -0.70
====> Epoch: 2 Average train loss: -0.7201
====> Epoch: 2 Average val loss: -1.1391
Min loss -1.14
====> Epoch: 3 Average train loss: -1.3139
====> Epoch: 3 Average val loss: -1.2081
Min loss -1.21
====> Epoch: 4 Average train loss: -1.4056
====> Epoch: 4 Average val loss: -1.2373
Min loss -1.24
====> Epoch: 5 Average train loss: -1.4574
====> Epoch: 5 Average val loss: -1.2524
Min loss -1.25
====> Epoch: 6 Average train loss: -1.4938
====> Epoch: 6 Average val loss: -1.2643
Min loss -1.26
====> Epoch: 7 Average train loss: -1.5232
====> Epoch: 7 Average val loss: -1.2754
Min loss -1.28
====> Epoch: 8 Average train loss: -1.5500
====> Epoch: 8 Average val loss: -1.2859
Min loss -1.29
====> Epoch: 9 Average train loss: -1.5750
====> Epoch: 9 Average val loss: -1.2957
Min loss -1.30
====> Epoch: 10 Average train loss: -1.5991
====> Epoch: 10 Average val loss: -1.30

  plt.figure()


# Using different model architectures

In [None]:
"""
### Convolutional Deep Learning

We can vary the encoder architecture from the default fcn to encoder/decoder based on the brainnetcnn architecture or a simple cnn
"""
from cca_zoo import deep_models
# %%
cfg = Config()
cfg.epoch_num = 100
# Replace the default encoders with convolutional ones (deep_models.CNNEncoder) for both views.
# cca_zoo.deep_models also provides a CNN decoder and brainnet-based models, and you can equally
# put your own encoder/decoder classes into cfg.encoder_models.
cfg.encoder_models = [deep_models.CNNEncoder] * 2
cfg.encoder_args = [{'channels': [3, 3]} for _ in range(2)]

dcca_conv = deepwrapper.DeepWrapper(cfg)

# Reshape the flat views into (N, 1, 28, 28) single-channel images for the CNN encoders.
dcca_conv.fit(train_view_1.reshape((-1, 1, 28, 28)), train_view_2.reshape((-1, 1, 28, 28)))

dcca_conv_results = np.stack([
    dcca_conv.train_correlations,
    dcca_conv.predict_corr(test_view_1.reshape((-1, 1, 28, 28)), test_view_2.reshape((-1, 1, 28, 28))),
])

total parameters:  1204
====> Epoch: 1 Average train loss: -0.1742
====> Epoch: 1 Average val loss: -0.1891
Min loss -0.19
====> Epoch: 2 Average train loss: -0.2741
====> Epoch: 2 Average val loss: -0.2190
Min loss -0.22
====> Epoch: 3 Average train loss: -0.3481
====> Epoch: 3 Average val loss: -0.2434
Min loss -0.24
====> Epoch: 4 Average train loss: -0.4097
====> Epoch: 4 Average val loss: -0.2683
Min loss -0.27
====> Epoch: 5 Average train loss: -0.4734
====> Epoch: 5 Average val loss: -0.2936
Min loss -0.29
====> Epoch: 6 Average train loss: -0.5331
====> Epoch: 6 Average val loss: -0.3175
Min loss -0.32
====> Epoch: 7 Average train loss: -0.5839
====> Epoch: 7 Average val loss: -0.3442
Min loss -0.34
====> Epoch: 8 Average train loss: -0.6289
====> Epoch: 8 Average val loss: -0.3711
Min loss -0.37
====> Epoch: 9 Average train loss: -0.6714
====> Epoch: 9 Average val loss: -0.4000
Min loss -0.40
====> Epoch: 10 Average train loss: -0.7146
====> Epoch: 10 Average val loss: -0.4321

  plt.figure()


# Deep Variational CCA

In [None]:
"""
### Deep Variational Learning
Finally we have Deep Variational CCA methods.
- Deep Variational CCA (DVCCA)
- Deep Variational CCA - private (DVCCA_p)

These are both implemented by the DVCCA class with private=True/False and both_encoders=True/False. If both_encoders,
the encoder to the shared information Q(z_shared|x) is modelled for both x_1 and x_2 whereas if both_encoders is false
it is modelled for x_1 as in the paper
"""
from cca_zoo import dvcca
# %%
# DVCCA (technically bi-DVCCA)
# The trained model is stored as `dvcca_model`, not `dvcca`: rebinding `dvcca` would shadow the
# cca_zoo.dvcca module imported above, so re-running from the `# %%` marker would fail on `dvcca.DVCCA`.
cfg = Config()
cfg.method = dvcca.DVCCA
cfg.epoch_num = 100
dvcca_model = deepwrapper.DeepWrapper(cfg)

dvcca_model.fit(train_view_1, train_view_2)

dvcca_results = np.stack((dvcca_model.train_correlations, dvcca_model.predict_corr(test_view_1, test_view_2)))

# DVCCA_private (technically bi-DVCCA_private)
# Same settings on a fresh Config (leaving the first model's config untouched),
# switching the private=False default to private=True
cfg = Config()
cfg.method = dvcca.DVCCA
cfg.epoch_num = 100
cfg.private = True

dvcca_p = deepwrapper.DeepWrapper(cfg)

dvcca_p.fit(train_view_1, train_view_2)

dvcca_p_results = np.stack((dvcca_p.train_correlations, dvcca_p.predict_corr(test_view_1, test_view_2)))

total parameters:  809516
====> Epoch: 1 Average train loss: 1105.8721
====> Epoch: 1 Average val loss: 1079.6569
Min loss 1079.66
====> Epoch: 2 Average train loss: 1079.5342
====> Epoch: 2 Average val loss: 1062.3976
Min loss 1062.40
====> Epoch: 3 Average train loss: 1061.5380
====> Epoch: 3 Average val loss: 1040.3821
Min loss 1040.38
====> Epoch: 4 Average train loss: 1039.0693
====> Epoch: 4 Average val loss: 1013.6394
Min loss 1013.64
====> Epoch: 5 Average train loss: 1011.4689
====> Epoch: 5 Average val loss: 990.8416
Min loss 990.84
====> Epoch: 6 Average train loss: 986.9456
====> Epoch: 6 Average val loss: 968.4504
Min loss 968.45
====> Epoch: 7 Average train loss: 962.8571
====> Epoch: 7 Average val loss: 948.5939
Min loss 948.59
====> Epoch: 8 Average train loss: 942.2664
====> Epoch: 8 Average val loss: 933.8574
Min loss 933.86
====> Epoch: 9 Average train loss: 927.0457
====> Epoch: 9 Average val loss: 921.9867
Min loss 921.99
====> Epoch: 10 Average train loss: 914.415

  plt.figure()


total parameters:  1216568
====> Epoch: 1 Average train loss: 1101.5437
====> Epoch: 1 Average val loss: 1078.9376
Min loss 1078.94
====> Epoch: 2 Average train loss: 1079.0117
====> Epoch: 2 Average val loss: 1059.9119
Min loss 1059.91
====> Epoch: 3 Average train loss: 1060.0469
====> Epoch: 3 Average val loss: 1032.5344
Min loss 1032.53
====> Epoch: 4 Average train loss: 1033.2188
====> Epoch: 4 Average val loss: 1005.0295
Min loss 1005.03
====> Epoch: 5 Average train loss: 1006.3236
====> Epoch: 5 Average val loss: 980.5747
Min loss 980.57
====> Epoch: 6 Average train loss: 981.8994
====> Epoch: 6 Average val loss: 957.4884
Min loss 957.49
====> Epoch: 7 Average train loss: 958.9430
====> Epoch: 7 Average val loss: 938.8829
Min loss 938.88
====> Epoch: 8 Average train loss: 940.0383
====> Epoch: 8 Average val loss: 924.3047
Min loss 924.30
====> Epoch: 9 Average train loss: 925.4579
====> Epoch: 9 Average val loss: 912.7220
Min loss 912.72
====> Epoch: 10 Average train loss: 913.45

  plt.figure()


# Generate Some Plots

In [None]:
"""
### Make results plot to compare methods
"""
# %%

# Keep every label right next to its results so the two sequences cannot fall out of order.
labelled_results = [
    ('linear', linear_cca_results),
    ('scikit', scikit_cca_results),
    ('gcca', gcca_results),
    ('mcca', mcca_results),
    ('pls', pls_results),
    ('pmd', pmd_results),
    ('elastic', elastic_results),
    ('scca', scca_results),
    ('linear kernel', kernel_reg_results),
    ('polynomial kernel', kernel_poly_results),
    ('gaussian kernel', kernel_gaussian_results),
    ('deep CCA', dcca_results),
    ('deep generalized CCA', dgcca_results),
    ('deep multiset CCA', dmcca_results),
    ('deep convolutional cca', dcca_conv_results),
    ('deep variational CCA', dvcca_results),
    ('deep variational CCA (private)', dvcca_p_results),
]
all_labels = [label for label, _ in labelled_results]
all_results = np.stack([results for _, results in labelled_results], axis=0)

from cca_zoo import plot_utils
plot_utils.plot_results(all_results, all_labels)
plt.show()

  fig, ax = plt.subplots()
  plt.figure()
  plt.figure()
