In [2]:
import numpy as np 
import torch 
import math
import matplotlib.pyplot as plt

In [67]:
def kron(X):
    ans = X[0]
    for x in X[1:]: 
        ans = torch.kron(ans.contiguous(), x)
    # print(len(X), [x.shape for x in X])
    return ans 

Adj = lambda x: x.transpose(-1, -2).conj()

class PfSystem:
    def __init__(self, p, n):
        self.p = p 
        self.n = n
        self.omega = np.exp(1j * 2 * math.pi / p)
        self.X = torch.zeros([p, p], dtype=torch.complex64)
        self.Z = torch.zeros([p, p], dtype=torch.complex64)
        for j in range(p):
            self.X[j, (j+1)%p] = 1 
            self.Z[j, j] = self.omega ** j
        self.Id = torch.eye(p)

        self.generators = []
        for j in range(n):
            Gammaj = kron(
                [self.Z] * j + [self.X] 
                + [self.Id] * (self.n - 1 - j)
            )
            Deltaj = kron(
                [self.Z] * j + [self.X @ self.Z] 
                + [self.Id] * (self.n - 1 - j)
            )
            print(Gammaj.shape, Deltaj.shape)
            self.generators.append(Gammaj)
            self.generators.append(Deltaj)

    def __call__(self, i):
        if i == 0:
            return kron([self.Id] * self.n)
        return self.generators[i - 1]

In [70]:
S = PfSystem(3, 2)
bracket = lambda a, b: a @ b - b @ a
torch.allclose(
    np.conj(S.omega) * S(1) @ S(2), S(2) @ S(1)
)

torch.Size([9, 9]) torch.Size([9, 9])
torch.Size([9, 9]) torch.Size([9, 9])


True

In [71]:
torch.allclose(
    bracket(
        S(1) @ torch.matrix_power(S(2), 2), 
        S(2) @ torch.matrix_power(S(3), 2)
    ), (1 - S.omega ** 2) * S(1) @ torch.matrix_power(S(3), 2)
)

True

In [None]:
Sp = PfSystem(3, 1)
Sp(1)

torch.Size([3, 3]) torch.Size([3, 3])


tensor([[ 0.0000+0.0000j, -0.5000+0.8660j,  0.0000+0.0000j],
        [ 0.0000+0.0000j,  0.0000+0.0000j, -0.5000-0.8660j],
        [ 1.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j]])