Add Orthogonal Procrustes Distance + similarity->distance #11

Merged 5 commits on Oct 29, 2021

Conversation

moskomule
Owner

#10

@brando90
Contributor

Warning: this pull request makes anatome fail my sanity checks. For example, when D is much larger than the number of data points, the similarity should be 1.0, since the linear maps have more than enough capacity to correlate the two data sets perfectly.
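
A minimal NumPy sketch of why this should hold, assuming plain mean CCA computed from orthonormal bases of the centered activations. The helper `mean_cca_similarity` is hypothetical and is not anatome's or uutils' API. For generic (e.g. random) data with n data points and D >= n - 1 features, both centered matrices span the same (n - 1)-dimensional subspace, so every canonical correlation is 1.

```python
import numpy as np


def mean_cca_similarity(x: np.ndarray, y: np.ndarray) -> float:
    """Mean canonical correlation between x [n, d1] and y [n, d2] (hypothetical helper)."""

    def orthonormal_basis(a: np.ndarray) -> np.ndarray:
        # Center each feature over the data dimension, then keep the left
        # singular vectors that belong to non-negligible singular values.
        a = a - a.mean(axis=0, keepdims=True)
        u, s, _ = np.linalg.svd(a, full_matrices=False)
        tol = max(a.shape) * np.finfo(a.dtype).eps * s[0]
        return u[:, s > tol]

    qx, qy = orthonormal_basis(x), orthonormal_basis(y)
    # Canonical correlations are the singular values of Qx^T Qy.
    rho = np.linalg.svd(qx.T @ qy, compute_uv=False)
    return float(rho.mean())


rng = np.random.default_rng(0)

# Few data points, many features: two unrelated random matrices look identical.
sim_small_n = mean_cca_similarity(rng.normal(size=(10, 300)), rng.normal(size=(10, 300)))

# Many data points, few features: the spurious correlation disappears.
sim_large_n = mean_cca_similarity(rng.normal(size=(2000, 5)), rng.normal(size=(2000, 5)))

print(f"{sim_small_n=:.6f}, {sim_large_n=:.6f}")
```

The first value should be 1.0 up to floating-point error and the second should be small, which is the behavior the sanity checks in this thread test for.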

@brando90
Contributor

Code for the sanity check:

@brando90
Contributor

#%%
"""
The similarity of the same network should always be 1.0 on same input.
"""
import torch
import torch.nn as nn

import uutils.torch_uu
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_identity_one_layer_linear_model

print('--- Sanity check: sCCA = 1.0 when using same net twice with same input. --')

Din: int = 10
Dout: int = Din
B: int = 2000
mdl1: nn.Module = get_named_identity_one_layer_linear_model(D=Din)
mdl2: nn.Module = mdl1
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - ends up comparing two matrices of size [B, Dout], on same data, on same model
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)

print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

#%%
"""
Reproducing: How many data points: https://github.com/google/svcca/blob/master/tutorials/001_Introduction.ipynb

As n increases, the cca sim should decrease until it converges to the true max linear correlation in the data.
This is because when n is small it's easy to correlate Xw and Yw, since there are fewer equations (n data points) than unknowns (D features).
Similarly, the similarity decreases because the more data there is, the more variation has to be captured and thus the less
correlation there will be.
This is consistent with the fact that, for standardized Xw and Yw, E[(Xw - Yw)^2] = 2 - 2*corr(Xw, Yw), so a smaller squared distance means a larger Pearson correlation.

"""
from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=300]: easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
# assert(approx_equal(sim, 1.0))

print('\n-- Santity: just makes sure that when low data is present sim is high and afterwards (as n->infty) sim (CCA) '
      'converges to the "true" cca value (eventually)')
# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000]
data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000, 100_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for b in data_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(b)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    # print(f'{b=}')
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{sim=}')
    sims.append(sim)

print(f'{sims=}')
uulot.plot(x=data_sizes, y=sims, xlabel='number of data points (n)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='ndata_vs_svcca_sim', title='Data points (n) vs Sim (SVCCA)', x_hline=Dout, x_hline_label=f'B=D={Dout}')

#%%

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

from uutils.plot import plot, save_to_desktop
import uutils.plot as uuplot

B: int = 10  # [101, 200, 500, 1000, 2000, 5000, 10000]
Din: int = B
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)

print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')

# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000]
B: int = 300
D_feature_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for d in D_feature_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(B, d)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(B, d)
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{d=}, {sim=}')
    sims.append(sim)

print(f'{sims=}')
uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim_svcca', title='Features (D) vs Sim (SVCCA)', x_hline=B, x_hline_label=f'B=D={B}')
# uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim', title='Features (D) vs Sim (SVCCA)')

@brando90
Contributor

Plots you should get:
[plot: ndata_vs_svcca_sim]
[plot: D_vs_sim_svcca]

@brando90
Contributor

Something in this pull request breaks anatome...

@moskomule
Owner Author

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may behave differently from torch.svd.

@brando90
Contributor

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may behave differently from torch.svd.

For now I will use the current version. If I have time I will test or try to fix the new one, but I'm also busy hehehe :) Hope this helps though; at least the sanity checks should allow us to catch obvious bugs.

@brando90
Contributor

I see other differences too like:

_matrix_normalize

instead of

_zero_mean

@brando90
Contributor

brando90 commented Oct 26, 2021

I don't think you need _matrix_normalize for CCA (I don't know about the others). The formula already includes the normalization:

$$\max_{a, b} \frac{a^\top X^\top Y b}{\sqrt{a^\top X^\top X a}\,\sqrt{b^\top Y^\top Y b}}$$

It only assumes centering. In short, the Pearson correlation already normalizes in the denominator.

Though, I don't think this should have made a difference.

Centering is for sure needed since the product only gives the covariance when things are centered.
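
To illustrate the point in one dimension, here is a small pure-Python sketch. The helpers `pearson` and `uncentered_cosine` are made up for this example. Rescaling either variable leaves the Pearson correlation unchanged, while skipping the centering step changes the answer.

```python
import math


def uncentered_cosine(u: list[float], v: list[float]) -> float:
    """<u, v> / (||u|| ||v||): the normalized inner product, without centering."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def pearson(u: list[float], v: list[float]) -> float:
    """Pearson correlation: center first, then apply the same normalized inner product."""
    mu, mv = sum(u) / len(u), sum(v) / len(v)
    return uncentered_cosine([a - mu for a in u], [b - mv for b in v])


u = [1.0, 2.0, 3.0, 4.0, 5.0]
v = [2.0, 1.0, 4.0, 3.0, 5.0]

rho = pearson(u, v)
rho_scaled = pearson([10.0 * a for a in u], [0.01 * b for b in v])  # rescaling changes nothing
rho_uncentered = uncentered_cosine(u, v)  # no centering: a different, larger value here

print(f"{rho=:.4f}, {rho_scaled=:.4f}, {rho_uncentered=:.4f}")
```

Dividing by a global constant such as the Frobenius norm is just a rescaling, so in exact arithmetic it should not change a correctly centered CCA.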

@brando90
Contributor

brando90 commented Oct 26, 2021

If I may, I suggest this implementation for OPD, since it reuses other code you already wrote:

from functools import partial

import torch
from torch import Tensor


def orthogonal_procrustes_distance(x: Tensor,
                                   y: Tensor,
                                   ) -> Tensor:
    """ Orthogonal Procrustes distance used in Ding+21.
    Returns a distance in the interval [0, 1].

    Note:
        - for a raw representation A we first subtract the mean value from each column, then divide
    by the Frobenius norm, to produce the normalized representation A*, used in all our dissimilarity computation.
        - see uutils.torch_uu.orthogonal_procrustes_distance to see my implementation
    Args:
        x: input tensor of Shape DxH
        y: input tensor of Shape DxW
    Returns:
        0.5 * d_proc(x, y) = 1 - nuclear_norm(x*^T y*)
    """
    # anatome helper: x and y must have the same size along dim 0
    _check_shape_equal(x, y, 0)

    nuclear_norm = partial(torch.linalg.norm, ord="nuc")

    # center each column and divide by the Frobenius norm, so that
    # frobenius_norm(x) = 1 and frobenius_norm(y) = 1
    x = _matrix_normalize(x, dim=0)
    y = _matrix_normalize(y, dim=0)
    # d_proc(x, y) = 2 - 2 * nuclear_norm(x^T y); returning half of it keeps the output in [0, 1]
    return 1 - nuclear_norm(x.t() @ y)
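
As a quick way to check the suggestion above without anatome installed, here is a NumPy sketch of the same formula. The function name `opd` and the use of NumPy are assumptions made for this example, and it normalizes by the Frobenius norm of the centered matrix.

```python
import numpy as np


def opd(x: np.ndarray, y: np.ndarray) -> float:
    """1 - ||x*^T y*||_nuc for x [n, d1] and y [n, d2], with x*, y* centered and of unit Frobenius norm."""

    def normalize(a: np.ndarray) -> np.ndarray:
        a = a - a.mean(axis=0, keepdims=True)  # center each column
        return a / np.linalg.norm(a, "fro")    # unit Frobenius norm

    return float(1.0 - np.linalg.norm(normalize(x).T @ normalize(y), "nuc"))


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 8))
y = rng.normal(size=(200, 8))
q, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix

d_same = opd(x, x)         # identical representations
d_rotated = opd(x, x @ q)  # same representation up to an orthogonal transform
d_random = opd(x, y)       # unrelated representations

print(f"{d_same=:.2e}, {d_rotated=:.2e}, {d_random=:.4f}")
```

The first two distances should be zero up to floating-point error, and the third should fall inside [0, 1].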

@brando90
Contributor

brando90 commented Oct 26, 2021

I can confirm that normalizing by the Frobenius norm breaks one of the CCA sanity checks: the one where D is much larger than the number of data points.

See:

--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=0.9341215491294861 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? False

Code:

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 100
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=100]: easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
# assert(approx_equal(sim, 1.0))

Surprisingly, this didn't break the plots.

@brando90
Contributor

Something in this pull request breaks anatome...

Thank you for reporting. I introduced torch.linalg.svd etc., but it may behave differently from torch.svd.

No, that is not it, according to my sanity checks above, which ran with U, S, Vh = torch.linalg.svd(input, full_matrices=False). I think it is the division by the Frobenius norm for CCA. It would be nice to figure out which metrics need it. AFAIK, only orthogonal Procrustes does.

@brando90
Contributor

OK, found the bug!

The centering has to be done correctly: / binds more tightly than -, so the subtraction needs parentheses. The normalization should be as follows:

def _matrix_normalize(input: Tensor,
                      dim: int
                      ) -> Tensor:
    """
    Center and normalize according to the Frobenius norm (not the standard deviation).

    Warning: this does not create standardized random variables in a random vector.

    Note: careful with this, it makes CCA behave in unexpected ways
    :param input:
    :param dim:
    :return:
    """
    from torch.linalg import norm
    return (input - input.mean(dim=dim, keepdim=True)) / norm(input, 'fro')

Or, even better, reuse your _zero_mean.
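
The precedence issue can be reproduced without torch. The following plain-Python sketch works on a single column of numbers (the variable names are made up for this example) and, as an assumption, divides by the norm of the raw column as in the snippet above.

```python
import math

x = [1.0, 2.0, 3.0, 4.0]
mean = sum(x) / len(x)                   # 2.5
norm = math.sqrt(sum(v * v for v in x))  # sqrt(30)

# Buggy: parsed as v - (mean / norm), so only a rescaled mean is subtracted.
buggy = [v - mean / norm for v in x]

# Fixed: center first, then divide the centered values by the norm.
fixed = [(v - mean) / norm for v in x]

buggy_mean = sum(buggy) / len(buggy)
fixed_mean = sum(fixed) / len(fixed)
print(f"{buggy_mean=:.4f}, {fixed_mean=:.4f}")
```

The buggy version leaves the column with a mean far from zero, so any covariance computed from it is off, while the fixed version is centered.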

My sanity checks now pass for all metrics:

/Users/brando/anaconda3/envs/metalearning/bin/python /Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py --cmd-line --multiproc --qt-support=auto --client 127.0.0.1 --port 55456 --file /Users/brando/ultimate-utils/tutorials_for_myself/anatome_pg/sanity_checks_anatome.py
Connected to pydev debugger (build 212.5080.64)
--- Sanity check: sCCA = 1.0 when using same net twice with same input. --
Should be very very close to 1.0: sim=1.000000238418579 (cxa_dist_type='svcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=0.9999998807907104 (cxa_dist_type='pwcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0 (cxa_dist_type='lincka')
Is it close to 1.0? True
Should be very very close to 1.0: sim=0.9997346997261047 (cxa_dist_type='opd')
Is it close to 1.0? True
--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=1.000000238418579 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? True
-- Santity: just makes sure that when low data is present sim is high and afterwards (as n->infty) sim (CCA) converges to the "true" cca value (eventually)
sims=[0.9999998807907104, 0.9999999403953552, 0.9801047444343567, 0.9169793725013733, 0.6231850981712341, 0.3799371123313904, 0.2702748775482178, 0.18766719102859497, 0.11999624967575073, 0.08386451005935669]
Should be very very close to 1.0: sim=1.0
Is it close to 1.0? True
sims=[0.2898038625717163, 0.44516634941101074, 0.6200690865516663, 0.9168117046356201, 0.9173185229301453, 0.9742245674133301, 0.9898524284362793, 0.9903322458267212, 0.9898055791854858, 0.98990398645401, 0.9907135367393494]
import sys; print('Python %s on %s' % (sys.version, sys.p

@brando90
Contributor

final sanity check code:

#%%
"""
The similarity of the same network should always be 1.0 on same input.
"""
import torch
import torch.nn as nn

import uutils.torch_uu
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_identity_one_layer_linear_model

print('--- Sanity check: sCCA = 1.0 when using same net twice with same input. --')

Din: int = 10
Dout: int = Din
B: int = 2000
mdl1: nn.Module = get_named_identity_one_layer_linear_model(D=Din)
mdl2: nn.Module = mdl1
layer_name = 'fc0'

# - ends up comparing two matrices of size [B, Dout], on same data, on same model
cxa_dist_type = 'svcca'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'pwcca'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'lincka'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0)), f'Sim should be close to 1.0 but got {sim=}'

cxa_dist_type = 'opd'
X: torch.Tensor = torch.distributions.Normal(loc=0.0, scale=1.0).sample((B, Din))
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} ({cxa_dist_type=})')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0, tolerance=1e-2)}')
assert(approx_equal(sim, 1.0, tolerance=1e-2)), f'Sim should be close to 1.0 but got {sim=}'

#%%
"""
Reproducing: How many data points: https://github.com/google/svcca/blob/master/tutorials/001_Introduction.ipynb

As n increases, the cca sim should decrease until it converges to the true max linear correlation in the data.
This is because when n is small it's easy to correlate Xw and Yw, since there are fewer equations (n data points) than unknowns (D features).
Similarly, the similarity decreases because the more data there is, the more variation has to be captured and thus the less
correlation there will be.
This is consistent with the fact that, for standardized Xw and Yw, E[(Xw - Yw)^2] = 2 - 2*corr(Xw, Yw), so a smaller squared distance means a larger Pearson correlation.

"""
from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

import uutils.plot as uulot

print('\n--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 '
      '(even if nets/matrices are different)')
B: int = 10
Dout: int = 100
mdl1: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(B, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

# - get sim for B << D e.g. [B=10, D=100]: easy to "fit", too many degrees of freedom
X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
# mdl1(X) : [B, Dout] = [B, B] [B, Dout]
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=} (since we have many features to match the two Xw1, Yw2).')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

print('\n-- Santity: just makes sure that when low data is present sim is high and afterwards (as n->infty) sim (CCA) '
      'converges to the "true" cca value (eventually)')
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000]
data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000, 100_000]
# data_sizes: list[int] = [10, 25, 50, 100, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for b in data_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(b)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(b, Dout)
    # print(f'{b=}')
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{sim=}')
    sims.append(sim)

print(f'{sims=}')
uulot.plot(x=data_sizes, y=sims, xlabel='number of data points (n)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='ndata_vs_svcca_sim', title='Data points (n) vs Sim (SVCCA)', x_hline=Dout, x_hline_label=f'B=D={Dout}')

#%%

from pathlib import Path
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import uutils
from uutils.torch_uu import cxa_sim, approx_equal
from uutils.torch_uu.models import get_named_one_layer_random_linear_model

from uutils.plot import plot, save_to_desktop
import uutils.plot as uuplot

B: int = 10  # [101, 200, 500, 1000, 2000, 5000, 10000]
Din: int = B
Dout: int = 300
mdl1: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
mdl2: nn.Module = get_named_one_layer_random_linear_model(Din, Dout)
layer_name = 'fc0'
# cxa_dist_type = 'pwcca'
cxa_dist_type = 'svcca'

X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
print(f'Should be very very close to 1.0: {sim=}')
print(f'Is it close to 1.0? {approx_equal(sim, 1.0)}')
assert(approx_equal(sim, 1.0))

# data_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000, 50_000]
B: int = 100
D_feature_sizes: list[int] = [10, 25, 50, 100, 101, 200, 500, 1_000, 2_000, 5_000, 10_000]
sims: list[float] = []
for d in D_feature_sizes:
    X: torch.Tensor = uutils.torch_uu.get_identity_data(B)
    mdl1: nn.Module = get_named_one_layer_random_linear_model(B, d)
    mdl2: nn.Module = get_named_one_layer_random_linear_model(B, d)
    sim: float = cxa_sim(mdl1, mdl2, X, layer_name, downsample_size=None, iters=1, cxa_dist_type=cxa_dist_type)
    # print(f'{d=}, {sim=}')
    sims.append(sim)

print(f'{sims=}')
uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim_svcca', title='Features (D) vs Sim (SVCCA)', x_hline=B, x_hline_label=f'B=D={B}')
# uuplot.plot(x=D_feature_sizes, y=sims, xlabel='number of features/size of dimension (D)', ylabel='similarity (svcca)', show=True, save_plot=True, plot_filename='D_vs_sim', title='Features (D) vs Sim (SVCCA)')

@brando90
Contributor

Note: it's better to divide by the Frobenius norm of the centered data. The accuracy of OPD increases dramatically: comparing the same matrix twice finally gives 1.0 up to 1e-4 instead of 1e-2.
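
One way to see why this matters, sketched in NumPy under the assumption that the self-similarity is the nuclear norm of X*^T X* (one minus the OPD formula above). That quantity equals the squared Frobenius norm of X*, which is exactly 1 only when the norm is taken of the centered matrix. The helper name and the data (with a deliberately non-zero mean) are made up for this example.

```python
import numpy as np


def self_similarity(x_star: np.ndarray) -> float:
    """||X*^T X*||_nuc, which equals ||X*||_F^2 because X*^T X* is positive semi-definite."""
    return float(np.linalg.norm(x_star.T @ x_star, "nuc"))


rng = np.random.default_rng(0)
x = rng.normal(loc=0.5, scale=1.0, size=(2000, 10))
x_centered = x - x.mean(axis=0, keepdims=True)

# Divide by the norm of the raw data: the result has Frobenius norm below 1.
sim_raw_norm = self_similarity(x_centered / np.linalg.norm(x, "fro"))

# Divide by the norm of the centered data: the result has Frobenius norm 1.
sim_centered_norm = self_similarity(x_centered / np.linalg.norm(x_centered, "fro"))

print(f"{sim_raw_norm=:.4f}, {sim_centered_norm=:.10f}")
```

The gap between the two values grows with the size of the column means relative to the spread of the data, so it is small but non-zero for samples drawn around zero.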

@brando90
Contributor

def _matrix_normalize_using_centered_data(X: Tensor, dim: int = 1) -> Tensor:
    """
    Normalize a matrix with respect to the data dimension according to the similarity preprocessing standard.
    Assumption is that X is of size [n, d].
    Otherwise, specify which dimension to normalize with dim.

    ref: https://stats.stackexchange.com/questions/544812/how-should-one-normalize-activations-of-batches-before-passing-them-through-a-si
    """
    from torch.linalg import norm
    X_centered: Tensor = _zero_mean(X, dim=dim)
    X_star: Tensor = X_centered / norm(X_centered, "fro")
    return X_star

results:

Connected to pydev debugger (build 212.5080.64)
--- Sanity check: sCCA = 1.0 when using same net twice with same input. --
Should be very very close to 1.0: sim=1.0000004768371582 (cxa_dist_type='svcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.000000238418579 (cxa_dist_type='pwcca')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0 (cxa_dist_type='lincka')
Is it close to 1.0? True
Should be very very close to 1.0: sim=1.0000001192092896 (cxa_dist_type='opd')
Is it close to 1.0? True
--- Sanity check: when number of data points B is smaller than D, then it should be trivial to make similiarty 1.0 (even if nets/matrices are different)
Should be very very close to 1.0: sim=0.9999998807907104 (since we have many features to match the two Xw1, Yw2).
Is it close to 1.0? True
-- Santity: just makes sure that when low data is present sim is high and afterwards (as n->infty) sim (CCA) converges to the "true" cca value (eventually)
sims=[1.0000001192092896, 1.0, 0.9794895648956299, 0.9188864231109619, 0.6179077625274658, 0.3843235969543457, 0.2695028781890869, 0.18886375427246094, 0.11978656053543091, 0.0842815637588501]
Should be very very close to 1.0: sim=1.0
Is it close to 1.0? True
sims=[0.24919700622558594, 0.43115103244781494, 0.6279942393302917, 0.9188255667686462, 0.9206753969192505, 0.9731308817863464, 0.9901297688484192, 0.9902339577674866, 0.990931510925293, 0.9907766580581665, 0.9902600049972534]
import sys; print('Python %s on %s' % (sys.version, sys.platform))

Comment on lines +229 to +232
x = _zero_mean(x, dim=0)
x /= frobenius_norm(x)
y = _zero_mean(y, dim=0)
y /= frobenius_norm(y)
Owner Author

I don't get your point. Are you referring to this part?

Contributor

@brando90 brando90 Oct 26, 2021

Yes, but that is not the code you are using for CCA or CKA.

Contributor

@moskomule

also at some point anatome was using:

def _matrix_normalize(input: Tensor,
                      dim: int
                      ) -> Tensor:
    from torch.linalg import norm
    return input - input.mean(dim=dim, keepdim=True) / norm(input, 'fro')

which doesn't center correctly.

Contributor

I can't find where I got that code... I am 100% sure I didn't write it...

Owner Author

I usually don't write from torch.linalg import norm...

Contributor

Interesting! Who knows, something weird might have happened in a merge attempt I did in my fork.

But I think the message is clear: if we normalize, it is important to center first and then take the norm of that centered data (especially if the same function is used for all metrics).

I decided I will only normalize for OPD and only center for CCA.
Undecided for CKA.

Comment on lines 221 to 224
x = _zero_mean(x, dim=0) / frobenius_norm(x)
y = _zero_mean(y, dim=0) / frobenius_norm(x)
x = frobenius_norm(x) ** 2
y = frobenius_norm(y) ** 2
Contributor

@moskomule I think this is one source of the bugs (for example, y is divided by frobenius_norm(x)). It's better to wrap the normalization in one function and have all metrics use that.

@brando90
Contributor

brando90 commented Oct 27, 2021

@moskomule I am curious: what is your final conclusion about normalizing the matrices before computing the distances?
Do you plan to divide by the Frobenius norm (of the centered matrix) for:

  1. Only for OPD?
  2. CCA?
  3. CKA?

My hunch is that OPD is the only one that needs it and only centering is enough for the other two.
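
A small NumPy sketch supporting this hunch, assuming the standard linear CKA formula and the nuclear-norm form of OPD discussed earlier in this thread (function names are made up for this example). Linear CKA on centered data does not change when the inputs are rescaled, whereas the unnormalized Procrustes term does.

```python
import numpy as np


def center(a: np.ndarray) -> np.ndarray:
    return a - a.mean(axis=0, keepdims=True)


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F) on centered inputs; no norm division needed."""
    x, y = center(x), center(y)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return float(cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")))


def procrustes_term(x: np.ndarray, y: np.ndarray) -> float:
    """||X^T Y||_nuc on centered inputs, deliberately without Frobenius normalization."""
    return float(np.linalg.norm(center(x).T @ center(y), "nuc"))


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 6))
y = x @ rng.normal(size=(6, 6)) + 0.1 * rng.normal(size=(100, 6))

cka, cka_scaled = linear_cka(x, y), linear_cka(3.0 * x, 0.5 * y)
proc, proc_scaled = procrustes_term(x, y), procrustes_term(3.0 * x, 0.5 * y)

print(f"{cka=:.6f}, {cka_scaled=:.6f}")    # equal: linear CKA is scale-invariant
print(f"{proc=:.4f}, {proc_scaled=:.4f}")  # differ by the factor 3.0 * 0.5
```

Because the Procrustes term scales with the inputs, it only maps to a distance in [0, 1] after both matrices are brought to unit Frobenius norm.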

@moskomule
Owner Author

I agree, and if I remember correctly, that is how I implemented it.

@brando90
Contributor

brando90 commented Oct 27, 2021

I agree, and if I remember correctly, that is how I implemented it.

At the risk of being redundant, I do want to note that this is not what the authors of the OPD paper do (see js-d/sim_metric#4 (comment)): they always normalize. But given my sanity checks I doubt the difference will be large, so I will do what you do: only center for CCA and CKA, and only normalize for OPD.

Thanks for discussions! :)

@moskomule moskomule merged commit ebadb51 into master Oct 29, 2021
@moskomule moskomule deleted the dev branch October 29, 2021 19:04