In [1]:
import escnn
import escnn.group
from escnn.group import CyclicGroup, DihedralGroup, DirectProductGroup, Group, Representation
from escnn.nn import FieldType, EquivariantModule, GeometricTensor

from morpho_symm.utils.algebra_utils import gen_permutation_matrix
from morpho_symm.utils.rep_theory_utils import group_rep_from_gens

import torch
import torch.nn as nn
from torch.distributions import Normal

import numpy as np

from typing import Union, List

In [2]:
# Print numpy arrays with one decimal place and fixed-point (non-scientific) notation.
np.set_printoptions(suppress=True, precision=1)

In [3]:
def add_repr_to_gspace(G: escnn.group.Group,
                       permutation_conf: Union[List[int], np.ndarray],
                       reflection_conf: Union[List[int], np.ndarray],
                       name: str
    ):
    """Register a permutation/sign representation on ``G`` under ``name``.

    The representation is specified on the generators of ``G``: row ``i`` of
    ``permutation_conf`` / ``reflection_conf`` describes the image of
    ``G.generators[i]``. Each row is turned into a matrix by
    ``gen_permutation_matrix`` and ``group_rep_from_gens`` builds an escnn
    representation from those generator images. The result is stored in
    ``G.representations[name]`` so it can later be used to build escnn
    ``FieldType``s / equivariant networks.

    Args:
        G (escnn.group.Group): an escnn group; it is modified in place.
        permutation_conf (Union[List[int], np.ndarray]): permutation(s) in one-line
            notation, shape ``(L,)`` for a single generator or
            ``(num_generators, L)``.
        reflection_conf (Union[List[int], np.ndarray]): per-entry signs (1 or -1),
            same shape as ``permutation_conf``.
        name (str): key under which the representation is registered; also set as
            the representation's ``name``.

    Returns:
        escnn.group.Group: the same ``G`` object (returned for convenience).

    Raises:
        AssertionError: if the configs are not 1-D/2-D, their shapes differ, or the
            number of rows does not match ``len(G.generators)``.
    """
    # Normalise both configs to 2-D arrays of shape (num_generators, conf_length).
    permutation_conf = np.array(permutation_conf, dtype=int)
    reflection_conf = np.array(reflection_conf, dtype=float)
    if permutation_conf.ndim == 1:
        permutation_conf = permutation_conf[None]
    if reflection_conf.ndim == 1:
        reflection_conf = reflection_conf[None]

    # Validate the configs before modifying G.
    assert permutation_conf.ndim == 2, (
        f"permutation_conf must be 1-D or 2-D, got shape {permutation_conf.shape}")
    assert permutation_conf.shape == reflection_conf.shape, (
        f"permutation_conf {permutation_conf.shape} and reflection_conf "
        f"{reflection_conf.shape} must have the same shape")
    conf_num, conf_length = permutation_conf.shape
    assert conf_num == len(G.generators), (
        f"expected one configuration row per generator ({len(G.generators)}), got {conf_num}")
    print(f"Conf name = {name}, length = {conf_length}")

    # Matrix images of the identity and of each generator.
    generator_images = {G.identity: np.eye(conf_length, dtype=float)}
    for g_gen, perm, signs in zip(G.generators, permutation_conf, reflection_conf):
        generator_images[g_gen] = gen_permutation_matrix(oneline_notation=perm, reflections=signs)

    # Convert the generator images into an escnn Representation and register it on G.
    rep = group_rep_from_gens(G, generator_images)
    rep.name = name
    G.representations[name] = rep
    return G

In [11]:
G = CyclicGroup(2)
gspace = escnn.gspaces.no_base_space(G)

# Custom representations of C2: one (permutation, reflection signs, name) triple each.
# 'tri' is the identity permutation with sign +1 on a 1-D space.
rep_specs = [
    ([2, 1, 0], [-1, 1, -1], 't1'),
    ([1, 0, 2], [1, 1, -1], 't2'),
    ([0], [1], 'tri'),
]
for perm, signs, rep_name in rep_specs:
    add_repr_to_gspace(G, perm, signs, rep_name)

G

Conf name = t1, length = 3
Conf name = t2, length = 3
Conf name = tri, length = 1


C2

In [12]:
# Representations registered on G: escnn's built-in ones plus the custom 't1', 't2', 'tri'.
G.representations

{'irrep_0': C2|[irrep_0]:1,
 'irrep_1': C2|[irrep_1]:1,
 'regular': C2|[regular]:2,
 't1': C2|[t1]:3,
 't2': C2|[t2]:3,
 'tri': C2|[tri]:1}

In [13]:
G.representations['tri'](G.elements[1])  # this is exactly the trivial representation

array([[1.]])

In [14]:
# Compare with escnn's built-in trivial representation.
G.trivial_representation(G.elements[0])

array([[1.]])

In [22]:
class SimpleEMLP(EquivariantModule):
    """Equivariant MLP built from escnn ``Linear`` layers and pointwise activations.

    Every hidden layer is made of copies of the group's regular representation, so
    each requested hidden width ``h`` becomes ``(h // |G|) * |G|``.

    Args:
        in_type: field type of the input ``GeometricTensor``.
        out_type: field type of the output.
        hidden_dims: requested width of each hidden layer.
        bias: whether the linear layers use a bias.
        actor: if True, the last layer is an equivariant ``escnn.nn.Linear`` to
            ``out_type`` and ``forward`` returns a ``GeometricTensor``. If False, a
            plain ``torch.nn.Linear`` (no bias) maps the last hidden features to
            ``out_type.size`` values and ``forward`` returns a ``torch.Tensor``.
        activation: "relu", "elu" or "lrelu" (case-insensitive).
    """

    def __init__(self,
                 in_type: FieldType,
                 out_type: FieldType,
                 hidden_dims=(256, 256, 256),
                 bias: bool = True,
                 actor: bool = True,
                 activation: str = "ReLU"):
        super().__init__()
        self.in_type = in_type
        self.out_type = out_type
        gspace = in_type.gspace
        group = gspace.fibergroup

        layer_in_type = in_type
        self.net = escnn.nn.SequentialModule()
        for n, hidden_dim in enumerate(hidden_dims):
            # Integer division: widths that are not a multiple of |G| are rounded down.
            num_regular_fields = int(hidden_dim) // group.order()
            layer_out_type = FieldType(gspace, [group.regular_representation] * num_regular_fields)

            # Module names are kept stable because they become the state_dict keys.
            self.net.add_module(f"linear_{n}: in={layer_in_type.size}-out={layer_out_type.size}",
                                escnn.nn.Linear(layer_in_type, layer_out_type, bias=bias))
            self.net.add_module(f"act_{n}", self.get_activation(activation, layer_out_type))

            layer_in_type = layer_out_type

        if actor:
            self.net.add_module(f"linear_{len(hidden_dims)}: in={layer_in_type.size}-out={out_type.size}",
                                escnn.nn.Linear(layer_in_type, out_type, bias=bias))
            self.extra_layer = None
        else:
            # forward() feeds x.tensor (layer_in_type.size columns) into this layer, so
            # its input width must be the field size, not the number of irreps (the two
            # only coincide when every irrep is 1-dimensional).
            # NOTE(review): this plain Linear acts on the (regular-representation) hidden
            # features directly, so the critic output is not constrained to be
            # G-invariant; confirm whether an invariant pooling step was intended.
            self.extra_layer = torch.nn.Linear(layer_in_type.size, out_type.size, bias=False)

    def forward(self, x: GeometricTensor) -> Union[GeometricTensor, torch.Tensor]:
        """Run the network; returns a ``torch.Tensor`` in critic mode (``actor=False``)."""
        x = self.net(x)
        if self.extra_layer is not None:
            return self.extra_layer(x.tensor)
        return x

    @staticmethod
    def get_activation(activation: str, hidden_type: FieldType) -> EquivariantModule:
        """Return the escnn pointwise activation named by ``activation`` for ``hidden_type``."""
        name = activation.lower()
        if name == "relu":
            return escnn.nn.ReLU(hidden_type)
        if name == "elu":
            return escnn.nn.ELU(hidden_type)
        if name == "lrelu":
            return escnn.nn.LeakyReLU(hidden_type)
        raise NotImplementedError(
            f"Unsupported activation '{activation}'; expected 'relu', 'elu' or 'lrelu'")

    def evaluate_output_shape(self, input_shape):
        """Returns the output shape of the model given an input shape."""
        batch_size = input_shape[0]
        return batch_size, self.out_type.size

    def export(self):
        """Exports the model to a plain ``torch.nn.Sequential`` (including the critic head)."""
        import copy

        sequential = nn.Sequential()
        for name, module in self.net.named_children():
            sequential.add_module(name, module.export())
        if self.extra_layer is not None:
            # The critic head is already a torch module; copy it so the exported model
            # does not share parameters with this one.
            sequential.add_module("extra_layer", copy.deepcopy(self.extra_layer))
        return sequential

In [23]:
# Actor: observations of type t1 + t2 (6 dims) -> actions of type t1 (3 dims).
actor_input_transitions = [G.representations['t1'],
                           G.representations['t2']]
actor_output_transitions = [G.representations['t1']]

in_field_type = FieldType(gspace, actor_input_transitions)
out_field_type = FieldType(gspace, actor_output_transitions)

# Seed torch so the random weight initialisation (and all outputs below) is reproducible.
torch.manual_seed(0)
actor = SimpleEMLP(in_field_type, out_field_type,
                   hidden_dims=[256, 256, 256],
                   activation='elu')

In [24]:
# Use the GPU when one is available so the notebook also runs on CPU-only kernels.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

observations = np.arange(12, dtype=np.float32).reshape(2, 6)
obs_torch = torch.as_tensor(observations, device=device)
actor = actor.to(device=device)

res1 = actor(in_field_type(obs_torch))

In [33]:
# 

In [8]:
# Raw observation batch fed to the actor above.
observations

array([[ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10., 11.]], dtype=float32)

In [9]:
# Actor output for the untransformed observations.
res1

g_tensor([[  0.2529,  -4.7182,  -1.1217],
          [  2.6172, -13.7864,  -3.6580]], device='cuda:0',
         grad_fn=<AddmmBackward0>, [C2: {t1 (x1)}(3)])

In [10]:
def get_symm_tensor(data: torch.Tensor,
                    G: "escnn.group.Group",
                    reprs: Union[str, List[str]],
                    g: "escnn.group.GroupElement" = None) -> torch.Tensor:
    """Apply a group element to ``data`` block-wise, one representation per column block.

    The columns of ``data`` are split into consecutive blocks whose widths are the
    sizes of the representations named in ``reprs`` (in order). Each block is
    transformed by the matrix ``G.representations[name](g)``.

    Args:
        data (torch.Tensor): shape ``(batch, N)`` or ``(N,)``; a 1-D input is treated
            as a batch of one.
        G (escnn.group.Group): group holding the named representations.
        reprs (str | List[str]): representation name(s) in ``G.representations``;
            their sizes must sum to ``N``.
        g: group element to apply; defaults to ``G.elements[1]``.

    Returns:
        torch.Tensor: shape ``(batch, N)``, on the same device as ``data``.

    Raises:
        AssertionError: if ``data`` is not a 1-D/2-D tensor or the representation
            sizes do not sum to ``N``.
    """
    assert isinstance(data, torch.Tensor), (
        f"data must be a torch.Tensor, got {type(data).__name__}")
    assert data.ndim in (1, 2), f"data must be 1-D or 2-D, got shape {tuple(data.shape)}"
    batch = data[None] if data.ndim == 1 else data
    names = [reprs] if isinstance(reprs, str) else list(reprs)
    element = G.elements[1] if g is None else g

    # Check up front that the representations exactly cover the data columns.
    reps = [G.representations[name] for name in names]
    total_size = sum(rep.size for rep in reps)
    assert total_size == batch.shape[-1], (
        f"representation sizes sum to {total_size} but data has {batch.shape[-1]} columns")

    # Match the float precision of data; non-float data keeps float32 matrices.
    mat_dtype = batch.dtype if batch.is_floating_point() else torch.float32
    blocks = []
    start = 0
    for rep in reps:
        rho = torch.as_tensor(rep(element), dtype=mat_dtype, device=batch.device)
        # Row-vector form of rho @ x, applied to every sample x in the batch.
        blocks.append(batch[:, start:start + rep.size] @ rho.T)
        start += rep.size
    return torch.cat(blocks, dim=-1)

In [11]:
# Apply G.elements[1] to the observation batch: 't1' on columns 0-2, 't2' on columns 3-5.
o2 = get_symm_tensor(obs_torch, G, reprs=['t1', 't2'])
print(o2.to(torch.int32))

tensor([[ -2,   1,   0,   4,   3,  -5],
        [ -8,   7,  -6,  10,   9, -11]], device='cuda:0', dtype=torch.int32)


In [12]:
res2 = actor(in_field_type(o2))
print(res2)

# Equivariance check: actor(g . x) must equal g . actor(x), with the output acted on by 't1'.
expected_res2 = get_symm_tensor(res1.tensor.detach(), G, ['t1'])
torch.testing.assert_close(res2.tensor.detach(), expected_res2, rtol=1e-4, atol=1e-4)

g_tensor([[  1.1217,  -4.7182,  -0.2529],
          [  3.6580, -13.7865,  -2.6172]], device='cuda:0',
         grad_fn=<AddmmBackward0>, [C2: {t1 (x1)}(3)])


In [34]:
# import numpy as np


# a = np.array([-0.0001, 1, -2,\
#         3, -4, 5,\
#         6, -7, -8,])

# def transform_zycofig_to_code(config:list):
#     a = np.array(config)
#     tmp = np.abs(a).astype(int)
#     permutation = tmp - tmp.min()
#     is_negative = np.where(a>0, 1, -1)
#     return permutation.tolist(), is_negative.tolist()
    
# def transform_all(*configs):
#     for i,config in enumerate(configs):
#         permutation, is_negative = transform_zycofig_to_code(config)
#         print(f"i:{i}, length:{len(permutation)}")
#         print(permutation)
#         print(is_negative)
#         print()

# transform_all([
#     -0.0001, 1, -2,\
#     ],[
#     3, -4, 5,\
#     ],[    
#     6, -7, -8,\
#     ],[
#     15, -16, -17, 18, 19, -20,\
#     9, -10, -11, 12, 13, -14,\
#     -21,\
#     29, -30, -31, 32, -33, 34, -35,\
#     22, -23, -24, 25, -26, 27, -28,\
#     ],[
#     42, -43, -44, 45, 46, -47,\
#     36, -37, -38, 39, 40, -41,\
#     -48,\
#     56, -57, -58, 59, -60, 61, -62,\
#     49, -50, -51, 52, -53, 54, -55,\
#     ],[
#     69, -70, -71, 72, 73, -74,\
#     63, -64, -65, 66, 67, -68,\
#     -75,\
#     83, -84, -85, 86, -87, 88, -89,\
#     76, -77, -78, 79, -80, 81, -82,
#     ],[
#     -90, -91 
#                           ])


In [35]:
# Fresh CPU model checked with escnn's built-in equivariance test (both tolerances 1e-5).
a2 = SimpleEMLP(
    in_field_type,
    out_field_type,
    hidden_dims=[256, 256, 256],
    activation='elu',
).to('cpu')
a2.check_equivariance(1e-5, 1e-5)

0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
1[2pi/2]
1[2pi/2] 2.3841858e-07 8.443991e-08 4.8356383e-15
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0
0[2pi/2]
0[2pi/2] 0.0 0.0 0.0


[(0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (1[2pi/2], 8.443991e-08),
 (1[2pi/2], 8.443991e-08),
 (0[2pi/2], 0.0),
 (1[2pi/2], 8.443991e-08),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0),
 (1[2pi/2], 8.443991e-08),
 (0[2pi/2], 0.0),
 (1[2pi/2], 8.443991e-08),
 (1[2pi/2], 8.443991e-08),
 (1[2pi/2], 8.443991e-08),
 (0[2pi/2], 0.0),
 (0[2pi/2], 0.0)]

In [39]:
# Layer structure; module names encode each linear layer's input/output sizes.
print(a2)

SimpleEMLP(
  (net): SequentialModule(
    (linear_0: in=6-out=256): Linear(
      (_basisexpansion): BlocksBasisExpansion(
        (block_expansion_('t1', 'regular')): SingleBlockBasisExpansion()
        (block_expansion_('t2', 'regular')): SingleBlockBasisExpansion()
      )
    )
    (act_0): ELU(alpha=1.0, inplace=False, type=[C2: {regular (x128)}(256)])
    (linear_1: in=256-out=256): Linear(
      (_basisexpansion): BlocksBasisExpansion(
        (block_expansion_('regular', 'regular')): SingleBlockBasisExpansion()
      )
    )
    (act_1): ELU(alpha=1.0, inplace=False, type=[C2: {regular (x128)}(256)])
    (linear_2: in=256-out=256): Linear(
      (_basisexpansion): BlocksBasisExpansion(
        (block_expansion_('regular', 'regular')): SingleBlockBasisExpansion()
      )
    )
    (act_2): ELU(alpha=1.0, inplace=False, type=[C2: {regular (x128)}(256)])
    (linear_3: in=256-out=3): Linear(
      (_basisexpansion): BlocksBasisExpansion(
        (block_expansion_('regular', 't1')

In [42]:
# state_dict keys are prefixed with the module names chosen in SimpleEMLP.__init__.
list(a2.state_dict().keys())

['net.linear_0: in=6-out=256.bias',
 'net.linear_0: in=6-out=256.weights',
 'net.linear_0: in=6-out=256.bias_expansion',
 'net.linear_0: in=6-out=256.expanded_bias',
 'net.linear_0: in=6-out=256.matrix',
 "net.linear_0: in=6-out=256._basisexpansion.block_expansion_('t1', 'regular').sampled_basis",
 "net.linear_0: in=6-out=256._basisexpansion.block_expansion_('t2', 'regular').sampled_basis",
 'net.linear_1: in=256-out=256.bias',
 'net.linear_1: in=256-out=256.weights',
 'net.linear_1: in=256-out=256.bias_expansion',
 'net.linear_1: in=256-out=256.expanded_bias',
 'net.linear_1: in=256-out=256.matrix',
 "net.linear_1: in=256-out=256._basisexpansion.block_expansion_('regular', 'regular').sampled_basis",
 'net.linear_2: in=256-out=256.bias',
 'net.linear_2: in=256-out=256.weights',
 'net.linear_2: in=256-out=256.bias_expansion',
 'net.linear_2: in=256-out=256.expanded_bias',
 'net.linear_2: in=256-out=256.matrix',
 "net.linear_2: in=256-out=256._basisexpansion.block_expansion_('regular', '

In [43]:
# Round-trip sanity check: the model's own state_dict loads back into it.
a2.load_state_dict(a2.state_dict())

<All keys matched successfully>