
🐛 [BUG] Equivalent irreps yield different output in FullyConnectedTensorProduct #296

Closed
RobDHess opened this issue Aug 24, 2021 · 9 comments
Labels: bug (Something isn't working)

@RobDHess
Bug

The output of o3.FullyConnectedTensorProduct is dependent on how the irreps are presented. For example:

irreps_in1_a = Irreps("3x0e + 2x1o + 2x2e")
irreps_in2_a = Irreps("1x0e + 1x1o + 1x2e + 1x3o")
irreps_out_a = Irreps("1x0e + 1x1o + 1x2e")

irreps_in1_b = Irreps("1x0e + 1x1o + 1x2e + 1x0e + 1x1o + 1x2e + 1x0e")
irreps_in2_b = Irreps("1x0e + 1x1o + 1x2e + 1x3o")
irreps_out_b = Irreps("1x0e + 1x1o + 1x2e")

These are equivalent, but the same input yields different results. I believe this is because normalisation is done "per path" instead of "per unique path". Is there a good reason for them to differ? I have found some mention of this issue here, but this is not something .simplify() can solve.

I've included code below to reproduce the issue.

import torch
from e3nn import o3
from e3nn.o3 import Irreps


def equivalent(irreps_a, irreps_b):
    """ Sorts, simplifies and converts irreps to string for comparison """
    return str(irreps_a.sort().irreps.simplify()) == str(irreps_b.sort().irreps.simplify())


def set_params_to_one(f):
    """ Set parameters to one """
    for p in f.parameters():
        p.data.fill_(1.)


def get_indices(irreps):
    """ Get the indices of the values belonging to each irrep type """
    idx = {}
    c = 0
    for mul, (l, p) in irreps:
        dim = 2 * l + 1
        name = f"{l}{'e' if p == 1 else 'o'}"  # p == 1 is even parity in e3nn
        block = torch.arange(c, c + mul * dim).long()
        idx[name] = block if name not in idx else torch.cat([idx[name], block])
        c += mul * dim
    return idx


def permute(input, irreps_a, irreps_b):
    """ Permute input from a to b """
    assert equivalent(irreps_a, irreps_b), "They aren't equal:" + str(irreps_a) + " " + str(irreps_b)

    idx_a = get_indices(irreps_a)
    idx_b = get_indices(irreps_b)

    permuted_input = torch.zeros_like(input)

    for irrep in idx_a.keys():
        permuted_input[..., idx_b[irrep]] = input[..., idx_a[irrep]]
    return permuted_input


# Create three sets of equivalent irreps
irreps_in1_a = Irreps("3x0e + 2x1o + 2x2e")
irreps_in2_a = Irreps("1x0e + 1x1o + 1x2e + 1x3o")
irreps_out_a = Irreps("1x0e + 1x1o + 1x2e")

irreps_in1_b = Irreps("1x0e + 1x1o + 1x2e + 1x0e + 1x1o + 1x2e + 1x0e")
irreps_in2_b = Irreps("1x0e + 1x1o + 1x2e + 1x3o")
irreps_out_b = Irreps("1x0e + 1x1o + 1x2e")

irreps_in1_c = Irreps("3x0e + 2x1o + 2x2e")  # This one is for sanity checking
irreps_in2_c = Irreps("1x0e + 1x1o + 1x2e + 1x3o")
irreps_out_c = Irreps("1x0e + 1x1o + 1x2e")


# Generate random input and permute according to the other irreps.
in1_a = irreps_in1_a.randn(1, -1)
in2_a = irreps_in2_a.randn(1, -1)
in1_b = permute(in1_a, irreps_in1_a, irreps_in1_b)
in2_b = permute(in2_a, irreps_in2_a, irreps_in2_b)
in1_c = permute(in1_a, irreps_in1_a, irreps_in1_c)
in2_c = permute(in2_a, irreps_in2_a, irreps_in2_c)

# Create functions and set their parameters
normalization = "component"
f_a = o3.FullyConnectedTensorProduct(irreps_in1_a, irreps_in2_a, irreps_out_a, normalization=normalization)
f_b = o3.FullyConnectedTensorProduct(irreps_in1_b, irreps_in2_b, irreps_out_b, normalization=normalization)
f_c = o3.FullyConnectedTensorProduct(irreps_in1_c, irreps_in2_c, irreps_out_c, normalization=normalization)

# Make sure weights aren't affected by permutation
set_params_to_one(f_a)
set_params_to_one(f_b)
set_params_to_one(f_c)

# Get some output
out_a = f_a(in1_a, in2_a)
out_b = f_b(in1_b, in2_b)
out_c = f_c(in1_c, in2_c)

# Let's do some sanity checks:
# Is the permutation invertible, as it should be?
permuted_in1_a = permute(in1_a, irreps_in1_a, irreps_in1_b)
assert torch.equal(in1_a, permute(permuted_in1_a, irreps_in1_b, irreps_in1_a)), "permutation is not invertible"

# Are all parameters equal to one?
for p_a, p_b, p_c in zip(f_a.parameters(), f_b.parameters(), f_c.parameters()):
    assert torch.equal(p_a, torch.ones_like(p_a)), "problems with a"
    assert torch.equal(p_b, torch.ones_like(p_b)), "problems with b"
    assert torch.equal(p_c, torch.ones_like(p_c)), "problems with c"

# Are the outputs the same, after being permuted?
assert torch.equal(out_a, out_c), "The unpermuted outputs are not equal"
assert torch.equal(out_a, permute(out_b, irreps_out_b, irreps_out_a)), "The permuted outputs are not equal"
@mariogeiger
Member

Hi @RobDHess,
This is a design choice.
I'm not against changing it, or giving the option to the user.
In principle, the user could counteract this normalization and impose their own via path_weight in the instructions, but that would be non-trivial.
So I agree that we should provide a simple way for the user to choose the normalization scheme.

@mariogeiger
Member

The current normalization works like that:

output = N[
  sum(
    N[
      sum_{uv} N[tensor_product[i]_{uv}]
    ] 
    for i in instructions if i.i_out == this_output
  )
]

where N[...] stands for normalization: essentially dividing by the square root of the number of components summed.
The lines in the code corresponding to each normalization are:
- outer N[...]
- middle N[...]
- inner N[...]

A normalization scheme that would solve your problem is to remove the middle normalization and treat all paths equally:

output = N[
  sum(
    # N[ remove this one
      sum_{uv} N[tensor_product[i]_{uv}]
    # ] 
    for i in instructions if i.i_out == this_output
  )
]

I chose the first option because I thought that if someone splits their irreps in two (10x1e -> 3x1e + 7x1e), it would mean that the two packets mean something different, and perhaps the user would like each packet to contribute equally.

@mariogeiger
Member

mariogeiger commented Aug 24, 2021

A simpler example that shows the difference (by the way, thanks for the clean test code):

irreps_in1_a = Irreps("10x0e")
irreps_in2_a = Irreps("1x0e")
irreps_out_a = Irreps("1x0e")

irreps_in1_b = Irreps("3x0e + 7x0e")
irreps_in2_b = Irreps("1x0e")
irreps_out_b = Irreps("1x0e")

Side note: I replaced torch.equal with torch.allclose; otherwise 5x0e + 5x0e triggered the assert even though it should not.
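To make the difference concrete, here is a back-of-the-envelope sketch of the current scheme on this simpler example, in plain Python. It assumes each N[...] simply divides the sum by the square root of the number of terms summed, and that the inner N is trivial for 0e x 0e -> 0e paths with all weights set to one (a sketch of the scheme, not e3nn's actual implementation):

```python
import math

def N(terms):
    # N[...]: sum, then divide by the square root of the number of terms summed
    return sum(terms) / math.sqrt(len(terms))

# All-ones scalar inputs, all weights set to one.
x = [1.0] * 10

# "10x0e": a single instruction covering all 10 uv paths.
out_a = N([N(x)])

# "3x0e + 7x0e": two instructions, each normalized separately by the middle N.
out_b = N([N(x[:3]), N(x[3:])])

print(out_a)  # sqrt(10) ~ 3.1623
print(out_b)  # (sqrt(3) + sqrt(7)) / sqrt(2) ~ 3.0956
```

The two layouts disagree because the middle N is applied per instruction, so splitting 10 paths into 3 + 7 changes the weighting of each path.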

@mariogeiger
Member

I guess what you propose is this one:

output = N[
  sum(
    N[
      sum(
        sum_{uv} N[tensor_product[i]_{uv}]
        for i in instructions if i.i_out == this_output and irreps_in1[i.i_in1].ir == ir1 and irreps_in2[i.i_in2].ir == ir2
      )
    ]
    for ir1, ir2 in all possibilities
  )
]

That sounds good to me.
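Under that grouping, instructions with the same (ir1, ir2) share a single middle N, so splitting an input's multiplicities no longer changes the result. A minimal plain-Python sketch (assuming, as above, that N[...] divides the sum by the square root of the number of terms summed; this models the arithmetic only, not e3nn's code):

```python
import math

def N(terms):
    # N[...]: sum, then divide by the square root of the number of terms summed
    return sum(terms) / math.sqrt(len(terms))

# "3x0e + 7x0e" with all-ones inputs and weights: two instructions of the
# same kind (0e x 0e -> 0e).
instr = [[1.0] * 3, [1.0] * 7]

# Current scheme: the middle N is applied per instruction.
current = N([N(terms) for terms in instr])

# Proposed scheme: merge instructions with the same (ir1, ir2) before
# applying the middle N.
merged = N([N(instr[0] + instr[1])])

print(merged)  # sqrt(10), identical to the unsplit "10x0e" case
```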

@RobDHess
Author

Hi,

I agree with your suggestion; it would make the code behave more predictably. That said, both ways of normalising have some merit.

By the way, there is a slight difference between your example irreps and mine. Your example can be "fixed" by applying .simplify(), since the order of the irreps is irrelevant there. It is harder to turn "1x0e + 1x1o + 1x0e + 1x1o" into "2x0e + 2x1o", because there the order is relevant, even though the two are equivalent. The latter arises when, e.g., concatenating two steerable vectors. With the current normalisation, I would either need to set path_weights or permute the resulting vector to sort the irreps. With the suggested normalisation, it would make no difference.
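The order sensitivity comes from simplification merging only adjacent irreps of the same type. A toy model of that merging rule (a plain-Python sketch with a hypothetical simplify helper, not e3nn's implementation):

```python
def simplify(irreps):
    # Merge only *adjacent* entries with the same irrep, mirroring the fact
    # that simplification cannot reorder the underlying data layout.
    out = []
    for mul, ir in irreps:
        if out and out[-1][1] == ir:
            out[-1] = (out[-1][0] + mul, ir)
        else:
            out.append((mul, ir))
    return out

# Concatenating two "1x0e + 1x1o" vectors interleaves the irreps:
concatenated = [(1, "0e"), (1, "1o"), (1, "0e"), (1, "1o")]
print(simplify(concatenated))  # unchanged: no adjacent entries to merge

# Only after sorting do they collapse to "2x0e + 2x1o":
print(simplify(sorted(concatenated, key=lambda t: t[1])))
```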

Thank you for the quick replies!

@mariogeiger
Member

I created a PR
@Linux-cpp-lisp @blondegeek @simonbatzner do you have any thoughts about that?

@Linux-cpp-lisp
Contributor

Hm, unless I misunderstand, this desired behavior can be achieved by just setting out_var or in1_var, right?

I think by default it makes most sense to understand separated things in irreps as distinct inputs with distinct normalization.

Though I'm not sure if I'm actually correctly understanding the 1 vs 2 distinction.

@mariogeiger
Copy link
Member

It's not clear to me if var_in1 and var_out are enough. Don't you need .path_weight as well?

@mariogeiger
Member

  0. All instructions are equal in normalization
  1. All products ir x ir -> ir are equal in normalization
  2. Is like 0, but paths of similar kind are merged

0 is not wanted.
2 fixes 0 with minimal breaking changes.
1 breaks more things but is much simpler to reason about.
