In [1]:
PARTITIONS = {
    1: (1),
    2: [(2), (1, 1)],
    3: [(3), (2, 1), (1, 1, 1)],
    4: [(4), (3, 1), (2, 2), (2, 1, 1), (1, 1, 1, 1)],
    5: [(5), (4, 1), (3, 2), (3, 1, 1), (2, 1, 1, 1), (1, 1, 1, 1, 1)],
    6: [(6), (5, 1), (4, 2), (4, 1, 1), (3, 3), (3, 2, 1), (3, 1, 1, 1), (2, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1)]
}


def _partitions(n):
    if n == 0:
        return None
    elif n == 1:
        return (1)
    elif n in PARTITIONS:
        return PARTITIONS[n]
    else:
        ps = [(n)]
        for i in range(1, n):
            k = n - i

            ps.append([k, i])
            i_parts = _partitions(i)
        
            
    return ps
    


In [31]:
from itertools import permutations, combinations, product



In [32]:
class Permutation:
    """A permutation of ``0 .. n-1`` stored as the tuple ``sigma``.

    Applying the permutation to a sequence ``x`` yields the sequence whose
    ``i``-th entry is ``x[sigma[i]]``.  Instances are hashable and compare
    equal when their ``sigma`` tuples are equal.
    """

    def __init__(self, sigma):
        """Store ``sigma`` after checking that it rearranges ``0 .. n-1``.

        Args:
            sigma: Iterable containing each of ``0 .. n-1`` exactly once.

        Raises:
            ValueError: If ``sigma`` is not such a rearrangement.
        """
        self.sigma = tuple(sigma)
        size = len(self.sigma)
        # n entries whose set equals {0..n-1} must each appear exactly once.
        if set(self.sigma) != set(range(size)):
            raise ValueError(
                f'{self.sigma} is not a permutation of the integers 0..{size - 1}'
            )
        self.base = list(range(size))
        # Filled in lazily by the `cycle_rep` property.
        self._cycle_rep = None

    def __len__(self):
        """Return the number of elements being permuted."""
        return len(self.sigma)

    def __eq__(self, other):
        """Compare by ``sigma`` so that equality agrees with ``__hash__``."""
        if not isinstance(other, Permutation):
            return NotImplemented
        return self.sigma == other.sigma

    def __hash__(self):
        return hash(self.sigma)

    def __repr__(self):
        return f'Permutation on {len(self)} elements: {self.sigma}'

    def __call__(self, x):
        """Apply this permutation to ``x``.

        Args:
            x: A ``Permutation`` or an indexable sequence of the same length.

        Returns:
            A new ``Permutation`` whose ``sigma[i]`` is ``x.sigma[self.sigma[i]]``
            when ``x`` is a ``Permutation``; otherwise the list
            ``[x[self.sigma[i]] for each i]``.

        Raises:
            ValueError: If ``len(x)`` differs from ``len(self)``.
        """
        if len(x) != len(self):
            raise ValueError(f'Permutation of length {len(self)} is ill-defined for given sequence of length {len(x)}')
        if isinstance(x, Permutation):
            return Permutation([x.sigma[self.sigma[i]] for i in self.base])
        return [x[self.sigma[i]] for i in self.base]

    @property
    def cycle_rep(self):
        """Cycle decomposition as a list of lists, computed once and cached.

        Each cycle starts at its smallest element and follows ``i -> sigma[i]``;
        cycles are ordered by that smallest element.  Fixed points appear as
        one-element cycles.
        """
        if self._cycle_rep is None:
            remaining = set(self.sigma)
            cycles = []
            while remaining:
                start = min(remaining)
                cycle = [start]
                following = self.sigma[start]
                # sigma is a bijection, so the walk must return to `start`.
                while following != start:
                    cycle.append(following)
                    following = self.sigma[following]
                cycles.append(cycle)
                remaining.difference_update(cycle)
            self._cycle_rep = cycles
        return self._cycle_rep

    def congruency_class(self):
        """Return the cycle lengths of this permutation as an ascending tuple."""
        return tuple(sorted(len(cycle) for cycle in self.cycle_rep))

    @property
    def parity(self):
        """Return 0 for an even permutation and 1 for an odd one.

        A cycle of even length is an odd permutation, so the parity is the
        number of even-length cycles modulo 2.
        """
        even_length_cycles = sum(1 for cycle in self.cycle_rep if len(cycle) % 2 == 0)
        return even_length_cycles % 2

In [35]:
def make_permutation_dataset(n: int):
    """Enumerate every permutation of ``range(n)`` and tabulate their products.

    Args:
        n: Number of elements being permuted.

    Returns:
        A pair ``(perms, mult_table)``.  ``perms`` lists the ``Permutation``
        objects in the order produced by ``itertools.permutations``.
        ``mult_table`` holds one ``(i, j, k)`` triple per ordered pair of
        permutations, where ``perms[k]`` has the same ``sigma`` as
        ``perms[i](perms[j])``.
    """
    perms = [Permutation(ordering) for ordering in permutations(range(n))]
    # Map each one-line tuple back to its position in `perms`.
    position_of = {p.sigma: position for position, p in enumerate(perms)}
    mult_table = [
        (position_of[left.sigma], position_of[right.sigma], position_of[left(right).sigma])
        for left, right in product(perms, repeat=2)
    ]
    return perms, mult_table


In [36]:
# Build every permutation of 5 elements together with the full table of
# (left index, right index, product index) triples.
perms, table = make_permutation_dataset(5)

In [39]:
len(table)

14400

In [35]:
# `perm` is not defined by any earlier cell, so this cell fails on a fresh kernel.
# The definition below was reconstructed to reproduce the recorded cycle type
# (1, 1, 2, 3) -- TODO confirm which permutation was originally used.
perm = Permutation([1, 0, 3, 4, 2, 5, 6])
perm.congruency_class()

(1, 1, 2, 3)

In [36]:
perm.parity

1

In [37]:
# The identity consists only of 1-cycles, so it has no even-length cycles
# and its parity is 0 (even).
ident = Permutation([0, 1, 2, 3, 4])
ident.parity

0

In [38]:
# A single transposition swapping 0 and 1: one even-length cycle, so parity 1 (odd).
perm1 = Permutation([1, 0, 2, 3, 4])
perm1.parity

1

In [39]:
perm1.cycle_rep

[[0, 1], [2], [3], [4]]

In [40]:
perm1.parity

1

In [41]:
# A single 3-cycle on 2, 3, 4: no even-length cycles, so parity 0 (even).
perm2 = Permutation([0, 1, 3, 4, 2])
perm2.parity

0

In [43]:
# Composing the odd perm1 with the even perm2 gives an odd permutation.
perm3 = perm1(perm2)
perm3.parity

1

In [40]:
import torch


In [48]:
# Pack the (left, right, product) index triples into one integer tensor, one row per triple.
all_data = torch.tensor(table)

In [60]:
# Split the three columns apart: left index, right index and product index,
# each returned as a single-column tensor.
torch.hsplit(all_data, 3)

(tensor([[  0],
         [  0],
         [  0],
         ...,
         [119],
         [119],
         [119]]),
 tensor([[  0],
         [  1],
         [  2],
         ...,
         [117],
         [118],
         [119]]),
 tensor([[ 0],
         [ 1],
         [ 2],
         ...,
         [ 6],
         [24],
         [ 0]]))

In [62]:
import math

# Sanity check: there are 5! permutations of 5 elements, so the multiplication
# table above should hold (5!) ** 2 rows.
math.factorial(5)


120