In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import jax.numpy as jnp
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
import numpy as np
import jax.numpy as jnp

In [None]:
# Load scikit-learn's bundled digits dataset and preview its last sample.
digits = datasets.load_digits()
plt.figure(num=1, figsize=(3, 3))
plt.imshow(
    digits.images[-1],
    interpolation="nearest",
    cmap=plt.cm.gray_r,
)
plt.show()

In [None]:
# Report how many samples the dataset holds and how they spread over the classes.
num_digits = len(digits.images)
print(f"Number of digits in the dataset: {num_digits}")

unique, counts = np.unique(digits.target, return_counts=True)
class_distribution = {label: count for label, count in zip(unique, counts)}
print("Class distribution:", class_distribution)

# One panel per class: show the first sample found for each of the classes 0-4.
classes_to_view = [0, 1, 2, 3, 4]
fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(10, 3))
for i, cls in enumerate(classes_to_view):
    idx = np.flatnonzero(digits.target == cls)[0]
    axs[i].imshow(digits.images[idx], interpolation="nearest", cmap=plt.cm.gray_r)
    axs[i].set_title(f"Class {cls}")
    axs[i].axis("off")

plt.show()

## Filtering only a few classes

In [None]:
# Restrict the dataset to the samples whose label is one of the chosen digits.
classes_to_keep = [0, 1, 3, 4]
indices_to_keep = np.isin(digits.target, classes_to_keep)

images = digits.images

filtered_images, filtered_labels = (
    images[indices_to_keep],
    digits.target[indices_to_keep],
)

# Summarise the filtered subset: size, per-class counts and array shapes.
num_filtered_images = len(filtered_images)
print(f"Number of filtered images: {num_filtered_images}")

unique_filtered, counts_filtered = np.unique(filtered_labels, return_counts=True)
filtered_class_distribution = {
    label: count for label, count in zip(unique_filtered, counts_filtered)
}
print("Filtered class distribution:", filtered_class_distribution)
print(f"filtered images shape: {filtered_images.shape}")
print(f"filtered labels shape: {filtered_labels.shape}")

In [None]:
# Give every 8x8 image an explicit trailing channel axis.
X = filtered_images.reshape((-1, 8, 8, 1))

# Map the kept digit values onto consecutive integer class indices.
label_encoder = LabelEncoder()
Y = label_encoder.fit_transform(filtered_labels)

print("X shape:", X.shape)
print("Y shape:", Y.shape)
print(
    "Class mapping:",
    {
        original: encoded
        for original, encoded in zip(
            label_encoder.classes_,
            label_encoder.transform(label_encoder.classes_),
        )
    },
)

In [None]:
# Print the feature and label array shapes, one per line.
print(X.shape, Y.shape, sep="\n")

In [None]:
# Hold out 20% of the samples for testing, keeping class proportions equal
# in both parts (stratified on Y) and fixing the shuffle for repeatability.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, stratify=Y, random_state=42, test_size=0.2
)

print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)
print("Y_train shape:", Y_train.shape)
print("Y_test shape:", Y_test.shape)

## Utilities for rotations and reflections

In [None]:
def inverse_grid_number(n, number):
    """Convert a 1-based, row-major cell number of an n x n grid to (row, column).

    Both returned indices are 1-based.

    Raises:
        ValueError: if ``number`` is not in the range 1..n**2.
    """
    if not 1 <= number <= n**2:
        raise ValueError("Number must be between 1 and n^2 inclusive.")
    row_offset, column_offset = divmod(number - 1, n)
    return row_offset + 1, column_offset + 1


def grid_number(n, a, b):
    """Convert 1-based (row ``a``, column ``b``) of an n x n grid to its row-major cell number.

    The returned number is 1-based.

    Raises:
        ValueError: if ``a`` or ``b`` is outside the range 1..n.
    """
    inside_grid = 1 <= a <= n and 1 <= b <= n
    if not inside_grid:
        raise ValueError("Row and column indices must be between 1 and n inclusive.")
    return (a - 1) * n + b


def reflection_grid(n, coordinates):
    """Mirror a 1-based (row, column) pair left-to-right in an n x n grid.

    The row is unchanged; column ``b`` becomes ``n - b + 1``.
    """
    row, column = coordinates
    return row, n - column + 1


def rotation_grid(n, coordinates):
    """Apply one quarter-turn to a 1-based (row, column) pair of an n x n grid.

    The point (a, b) is sent to (n - b + 1, a).
    """
    row, column = coordinates
    return n - column + 1, row


def rotate(n, number):
    """Return the cell number that cell ``number`` moves to under one quarter-turn.

    Cell numbers are 1-based and row-major on an n x n grid.
    """
    position = inverse_grid_number(n, number)
    return grid_number(n, *rotation_grid(n, position))


def reflect(n, number):
    """Return the cell number that cell ``number`` moves to under a left-right mirror.

    Cell numbers are 1-based and row-major on an n x n grid.
    """
    position = inverse_grid_number(n, number)
    return grid_number(n, *reflection_grid(n, position))


def _permutation_matrix(n, move):
    """Build the n**2 x n**2 0/1 matrix of a cell permutation of an n x n grid.

    Args:
        n: grid side length.
        move: callable ``move(n, m)`` returning the 1-based cell number that
            1-based cell ``m`` is sent to.

    Returns:
        An integer JAX array with a 1 at (row = move(n, m) - 1, column = m - 1)
        for every cell m, and 0 elsewhere.
    """
    matrix_size = n**2
    source_columns = jnp.arange(matrix_size)
    # Convert the 1-based destinations to 0-based row indices.
    target_rows = jnp.array(
        [move(n, m) - 1 for m in range(1, matrix_size + 1)], dtype=int
    )
    # One scatter instead of matrix_size separate functional updates, each of
    # which would create a fresh copy of the whole matrix.
    return (
        jnp.zeros((matrix_size, matrix_size), dtype=int)
        .at[target_rows, source_columns]
        .set(1)
    )


def generate_rotation_matrix(n):
    """Return the n**2 x n**2 integer matrix whose column m-1 marks ``rotate(n, m)``."""
    return _permutation_matrix(n, rotate)


def generate_reflection_matrix(n):
    """Return the n**2 x n**2 integer matrix whose column m-1 marks ``reflect(n, m)``."""
    return _permutation_matrix(n, reflect)


def generate_d4_matrices(n):
    """Return eight n**2 x n**2 matrices: I, R, R^2, R^3, S, SR, SR^2, SR^3.

    R comes from ``generate_rotation_matrix(n)`` and S from
    ``generate_reflection_matrix(n)``. The identity is created with
    ``jnp.eye``'s default dtype; the other entries are products of R and S.
    """
    rotation = generate_rotation_matrix(n)
    reflection = generate_reflection_matrix(n)

    # Powers of the rotation.
    rotation_sq = jnp.dot(rotation, rotation)
    rotation_cu = jnp.dot(rotation_sq, rotation)

    # Each rotation power followed by the reflection.
    reflected = [
        jnp.dot(reflection, power) for power in (rotation, rotation_sq, rotation_cu)
    ]

    return [jnp.eye(n**2), rotation, rotation_sq, rotation_cu, reflection, *reflected]

In [None]:
def apply_transformation(image, transformation_matrix, n):
    """Apply an n**2 x n**2 matrix to a flattened image and return it as n x n.

    ``image`` is flattened with its own ``flatten()`` method, multiplied by
    ``transformation_matrix`` via ``jnp.dot``, and reshaped to (n, n).
    """
    pixels = image.flatten()
    return jnp.dot(transformation_matrix, pixels).reshape((n, n))


def apply_transformation_tf(image, matrix, n):
    """TensorFlow version: apply ``matrix`` to one image and return it as (n, n, 1).

    The image is reshaped to a length n**2 vector, multiplied with
    ``tf.linalg.matvec(matrix, vector)``, and reshaped back with a trailing
    channel axis.
    """
    pixel_vector = tf.reshape(image, [n**2])
    permuted = tf.linalg.matvec(matrix, pixel_vector)
    return tf.reshape(permuted, [n, n, 1])

## Augmenting the dataset

In [None]:
# Side length of the square images; the D4 matrices act on n*n flattened pixels.
n = 8
d4_matrices = generate_d4_matrices(n)
# float32 TensorFlow copies of the same matrices, used by the custom conv layer.
tf_d4_matrices = [tf.convert_to_tensor(matrix, dtype=tf.float32) for matrix in d4_matrices]
augmented_images = []
augmented_labels = []

print(len(filtered_images))

# Every image contributes one variant per D4 matrix, in image-major order,
# and each variant carries the label of its source image.
# NOTE: these labels are the original digit values from `filtered_labels`,
# not the LabelEncoder indices stored in `Y`.
for img, lbl in zip(filtered_images, filtered_labels):
    for matrix in d4_matrices:
        # Use the grid size defined above rather than a second hard-coded 8.
        augmented_images.append(apply_transformation(img, matrix, n))
        augmented_labels.append(lbl)


augmented_images = np.array(augmented_images)
augmented_labels = np.array(augmented_labels)

print(f"augmented images shape: {augmented_images.shape}")
print(f"augmented labels shape: {augmented_labels.shape}")

num_augmented_images = len(augmented_images)
print(f"Number of augmented images: {num_augmented_images}")

unique_augmented, counts_augmented = np.unique(augmented_labels, return_counts=True)
augmented_class_distribution = dict(zip(unique_augmented, counts_augmented))
print("Augmented class distribution:", augmented_class_distribution)

In [None]:
# Add the trailing channel axis, then split the augmented set 80/20,
# stratified on the (unencoded) augmented labels.
# NOTE(review): the split is over individual augmented samples, so different
# D4 variants of the same source image can land in both the train and the
# test part — confirm this is intended before using these splits for evaluation.
augmented_images_reshaped = augmented_images.reshape((-1, 8, 8, 1))
print(augmented_images_reshaped.shape)
print(augmented_labels.shape)
X_train_aug, X_test_aug, Y_train_aug, Y_test_aug = train_test_split(
    augmented_images_reshaped,
    augmented_labels,
    stratify=augmented_labels,
    random_state=42,
    test_size=0.2,
)

In [None]:
def apply_transformation_batch(inputs, matrix, n=8):
    """Apply one n**2 x n**2 matrix to every image of a batch at once.

    Args:
        inputs: batch of images holding n*n values per sample
            (e.g. shape (batch, n, n, 1)).
        matrix: (n**2, n**2) tensor with the same dtype as ``inputs``.
        n: image side length; defaults to 8, the size used in this notebook.

    Returns:
        Tensor of shape (batch, n, n, 1) whose sample b is ``matrix`` applied
        to the flattened sample b of ``inputs``.
    """
    flat_inputs = tf.reshape(inputs, [-1, n * n])
    # Batched form of `matrix @ x` for each row x: X @ matrix^T. This replaces
    # a per-sample tf.map_fn, which runs the samples one after another.
    transformed = tf.matmul(flat_inputs, matrix, transpose_b=True)
    return tf.reshape(transformed, [-1, n, n, 1])

In [None]:
class CustomConvLayer(layers.Layer):
    """Convolution whose response is averaged over the transformed inputs.

    The layer owns a single trainable kernel of shape
    (kernel_size, kernel_size, 1, 1). On each call, the input batch is
    transformed with every matrix in the module-level ``tf_d4_matrices``,
    each transformed batch is convolved with the same kernel (stride 1,
    "SAME" padding), and the results are averaged.

    Args:
        kernel_size: side length of the square kernel.
        **kwargs: standard ``keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, kernel_size, **kwargs):
        # Forward the standard Layer keyword arguments instead of dropping them.
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.kernel = self.add_weight(
            shape=(kernel_size, kernel_size, 1, 1),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Return the mean, over all matrices in ``tf_d4_matrices``, of conv(transformed inputs)."""
        convolved_results = [
            tf.nn.conv2d(
                apply_transformation_batch(inputs, matrix),
                self.kernel,
                strides=[1, 1, 1, 1],
                padding="SAME",
            )
            for matrix in tf_d4_matrices
        ]
        return tf.reduce_mean(tf.stack(convolved_results), axis=0)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({"kernel_size": self.kernel_size})
        return config


class CustomPoolingLayer(layers.Layer):
    """Non-overlapping average pooling over the two spatial axes.

    Args:
        pool_size: pair (height, width); used both as window size and stride.
        **kwargs: standard ``keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, pool_size, **kwargs):
        # Forward the standard Layer keyword arguments instead of dropping them.
        super().__init__(**kwargs)
        self.pool_size = pool_size

    def call(self, inputs):
        """Average-pool ``inputs`` with "VALID" padding; window and stride equal ``pool_size``."""
        window = [1, self.pool_size[0], self.pool_size[1], 1]
        return tf.nn.avg_pool(
            inputs,
            ksize=window,
            strides=window,
            padding="VALID",
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({"pool_size": self.pool_size})
        return config

In [None]:
kernel_size = 7

# Model trained on the un-augmented split. There is no dense layer: the 8x8
# input stays 8x8 after the "SAME" convolution, 4x4 pooling leaves 2x2, and
# the 4 flattened values are fed straight into the softmax.
model = models.Sequential()
model.add(CustomConvLayer(kernel_size=kernel_size))
model.add(CustomPoolingLayer(pool_size=(4, 4)))
model.add(layers.Flatten())
model.add(layers.Activation("softmax"))

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# NOTE(review): no TensorFlow seed is set in this cell, so weights and the
# reported accuracy presumably vary between runs.
model.fit(X_train, Y_train, epochs=100, batch_size=32)
print("Model training complete.")
test_loss, test_acc = model.evaluate(X_test, Y_test)
print("Test accuracy:", test_acc)

# Second model with a 4-unit dense softmax head. It is only constructed in
# this cell; it is neither compiled nor trained here.
aug_model = models.Sequential()
aug_model.add(CustomConvLayer(kernel_size=kernel_size))
aug_model.add(CustomPoolingLayer(pool_size=(4, 4)))
aug_model.add(layers.Flatten())
aug_model.add(layers.Dense(4, activation="softmax"))

In [None]:
def create_circulant_matrix(kernel, image_size=8):
    """Express a stride-1, zero-padded 2-D cross-correlation as one matrix.

    Row ``i * image_size + j`` of the result holds, flattened row-major, the
    weights with which each input pixel contributes to output pixel (i, j).
    Multiplying the matrix with a flattened image therefore gives the
    flattened filter response.

    Despite the name, the matrix is not circulant: kernel taps that fall
    outside the image are dropped (zero padding) rather than wrapped around.

    The kernel offset is ``(kernel_size - 1) // 2``, which is the leading
    padding TensorFlow's ``padding="SAME"`` uses at stride 1. For odd kernel
    sizes this equals ``kernel_size // 2``; for even sizes it avoids a
    one-pixel shift relative to ``tf.nn.conv2d``.

    Args:
        kernel: square 2-D array-like of shape (kernel_size, kernel_size);
            only ``shape[0]`` is used as the kernel size.
        image_size: side length of the square image.

    Returns:
        NumPy array of shape (image_size**2, image_size**2).
    """
    # Convert once so indexing below is plain NumPy, whatever array type was passed.
    kernel = np.asarray(kernel)
    kernel_size = kernel.shape[0]
    pad = (kernel_size - 1) // 2

    circ_matrix_size = image_size * image_size
    circ_matrix = np.zeros((circ_matrix_size, circ_matrix_size))

    for i in range(image_size):
        for j in range(image_size):
            # Top-left image coordinate covered by kernel tap (0, 0).
            top, left = i - pad, j - pad
            # Clip the kernel window to the image bounds.
            row_start, row_stop = max(top, 0), min(top + kernel_size, image_size)
            col_start, col_stop = max(left, 0), min(left + kernel_size, image_size)

            row = np.zeros((image_size, image_size))
            row[row_start:row_stop, col_start:col_stop] = kernel[
                row_start - top : row_stop - top,
                col_start - left : col_stop - left,
            ]
            circ_matrix[i * image_size + j, :] = row.flatten()

    return circ_matrix

In [None]:
# Display the shape of the kernel weight held by the first layer of `model`.
model.layers[0].kernel.shape

In [None]:
def create_averaged_kernel(kernel):
    """Average ``kernel @ matrix`` over every matrix in the module-level ``d4_matrices``.

    The individual products are stacked with ``tf.stack`` and reduced with
    ``tf.reduce_mean`` along the stacking axis.
    """
    rotated_kernels = [kernel @ matrix for matrix in d4_matrices]
    return tf.reduce_mean(tf.stack(rotated_kernels), axis=0)

In [None]:
# Random 64x64 kernel used as a sanity-check input for the averaging step.
# Drawn from a locally seeded generator so the resulting figure is the same
# on every Restart & Run All, without touching NumPy's global random state.
test_kernel = np.random.default_rng(42).standard_normal((64, 64))
averaged_test_kernel = create_averaged_kernel(test_kernel)

In [None]:
# Drop the two singleton channel axes of the first layer's kernel and expand
# it into its 64x64 matrix form.
kernel_matrix = tf.reshape(model.layers[0].kernel, [kernel_size, kernel_size])
kernel_circulant_matrix = create_circulant_matrix(kernel_matrix)

# Compose the kernel matrix with every D4 matrix and average the products.
rotated_kernels = [kernel_circulant_matrix @ matrix for matrix in tf_d4_matrices]

averaged_circulant_kernel = tf.reduce_mean(tf.stack(rotated_kernels), axis=0)

In [None]:
# Display the shape of the group-averaged kernel matrix.
averaged_circulant_kernel.shape

In [None]:
def visualize_combined_matrix(
    matrix,
    title="Combined Transformation Matrix",
    width=20,
    height=16,
    title_fontsize=18,
    axis_fontsize=14,
    cbar_label="Color Bar Label",
    cbar_tick_fontsize=14,
):
    """Show ``matrix`` as a viridis heatmap in a new figure.

    Args:
        matrix: 2-D data passed to ``seaborn.heatmap``.
        title: figure title.
        width, height: figure size passed to ``plt.figure``.
        title_fontsize: font size of the title.
        axis_fontsize: font size of the x and y tick labels.
        cbar_label: accepted for interface compatibility; not applied to the plot.
        cbar_tick_fontsize: font size of the colour-bar tick labels.
    """
    plt.figure(figsize=(width, height))
    heatmap_ax = sns.heatmap(matrix, cmap="viridis", annot=False, fmt=".2f")
    heatmap_ax.set_title(title, fontsize=title_fontsize)
    for axis_name in ("x", "y"):
        heatmap_ax.tick_params(axis=axis_name, labelsize=axis_fontsize)

    # The colour bar is attached to the heatmap's first collection.
    heatmap_ax.collections[0].colorbar.ax.tick_params(labelsize=cbar_tick_fontsize)
    plt.show()

In [None]:
# Heatmap of the learned 7x7 convolution kernel.
visualize_combined_matrix(
    kernel_matrix,
    title="Convolution Kernel",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=50,
    cbar_tick_fontsize=50,
)

In [None]:
# Heatmap of the kernel in matrix form, before averaging over the D4 matrices.
visualize_combined_matrix(
    kernel_circulant_matrix,
    title="Pre-Transformed Circulant Kernel",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=20,
    cbar_tick_fontsize=50,
)

In [None]:
# Heatmap of the kernel matrix after averaging over the D4 matrices.
visualize_combined_matrix(
    averaged_circulant_kernel,
    title="Transformed Circulant Kernel",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=20,
    cbar_tick_fontsize=50,
)

## Step 2 - Change of basis (COB)

We now have `averaged_circulant_kernel`. Next, we build the change-of-basis matrix:

In [None]:
from sympy.parsing.sympy_parser import parse_expr

# Read the first 129 lines of the change-of-basis file. `next(f)` raises
# StopIteration if the file has fewer lines than that.
with open('8x8COB.txt', 'r') as f:
    Q = [next(f) for _ in range(129)] # f.read()
# Skip blank lines and parse every remaining line with sympy.
# NOTE(review): parse_expr evaluates the text it is given — only read a trusted file.
Q = [parse_expr(s.strip('\n')) for s in Q if s != '\n']
Q = Q[1:] # Remove constant
# Split the parsed entries into consecutive groups of 10, 6, 6, 10 and 16
# entries; Q78 receives whatever remains after the first 48.
Q1, Q2, Q3, Q4, Q56, Q78 = Q[:10], Q[10:16], Q[16:22], Q[22:32], Q[32:48], Q[48:]
# From every entry keep one fixed row of column 0 (rows 0-3 for Q1-Q4) or a
# pair of rows (4,5 for Q56 and 6,7 for Q78), the pairs converted to lists.
# Presumably each entry is an 8-row sympy Matrix — TODO confirm against 8x8COB.txt.
Q1, Q2, Q3, Q4, Q56, Q78 = [elem[0,0] for elem in Q1], [elem[1,0] for elem in Q2], [elem[2,0] for elem in Q3], [elem[3,0] for elem in Q4], [elem[(4,5),0].tolist() for elem in Q56], [elem[(6,7),0].tolist() for elem in Q78]

In [None]:
import re
# Turn every expression into a list of signed index strings: the expression is
# printed, split on ' + ', and every 'x' is removed from each term
# (e.g. a term printed as "x3" becomes "3").
# Q1 is split directly, without the ' - ' -> ' + -' rewrite applied to the
# other groups (presumably its expressions contain no subtractions — TODO confirm).
Q1_Str = [[elem.replace('x','') for elem in l] for l in [str(q).split(' + ') for q in Q1]]
# Q2-Q4: first rewrite ' - ' as ' + -' so subtracted terms keep their minus
# sign after the split, then split and strip 'x' as above.
Q2_str = [str(q).replace(' - ',' + -') for q in Q2]
Q2_Str = [[elem.replace('x','') for elem in l] for l in [str(q).split(' + ') for q in Q2_str]]
Q3_str = [str(q).replace(' - ',' + -') for q in Q3]
Q3_Str = [[elem.replace('x','') for elem in l] for l in [str(q).split(' + ') for q in Q3_str]]
Q4_str = [str(q).replace(' - ',' + -') for q in Q4]
Q4_Str = [[elem.replace('x','') for elem in l] for l in [str(q).split(' + ') for q in Q4_str]]
# Q56 entries are pairs: both components get the same treatment, and the
# square brackets left over from the list representation are stripped too.
# Q78 is not converted in this cell.
Q56_str = [(str(q[0]).replace(' - ',' + -'),str(q[1]).replace(' - ',' + -')) for q in Q56]
Q56_Str = [[[elem.replace('x','').replace('[','').replace(']','') for elem in l0],[elem.replace('x','').replace('[','').replace(']','') for elem in l1]] for l0,l1 in [(q[0].split(' + '),q[1].split(' + ')) for q in Q56_str]]

In [None]:
from math import copysign

# Assemble the 64x64 matrix row by row: each term string is a signed, 1-based
# column index, and the entry at that column is set to the sign (+1 or -1).
X_matrix = jnp.zeros((64, 64))

Q1_to_4 = Q1_Str + Q2_Str + Q3_Str + Q4_Str

# Rows 0-31 come from the four scalar groups, in order.
for row in range(32):
    for term in Q1_to_4[row]:
        signed_index = int(term)
        X_matrix = X_matrix.at[row, abs(signed_index) - 1].set(
            np.sign(signed_index) * 1
        )

# Each Q56 pair fills two rows: its first component goes to rows 32-47 and
# its second component to rows 48-63.
for pair_number in range(16):
    first_terms, second_terms = Q56_Str[pair_number][0], Q56_Str[pair_number][1]
    for term in first_terms:
        signed_index = int(term)
        X_matrix = X_matrix.at[pair_number + 32, abs(signed_index) - 1].set(
            np.sign(signed_index) * 1
        )
    for term in second_terms:
        signed_index = int(term)
        X_matrix = X_matrix.at[pair_number + 48, abs(signed_index) - 1].set(
            np.sign(signed_index) * 1
        )

In [None]:
# Work on a copy of the assembled matrix and preview it as an image.
B = X_matrix.copy()
plt.imshow(B)

In [None]:
# Heatmap of the change-of-basis matrix B.
visualize_combined_matrix(
    B,
    title="Change of Basis Matrix",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=20,
    cbar_tick_fontsize=50,
)

In [None]:
# Convert B to a NumPy array, invert it, and right-multiply the averaged
# kernel matrix by that inverse.
# np.linalg.inv raises LinAlgError if B is singular.
B = np.array(B)
B_inv = np.linalg.inv(B)
M = averaged_circulant_kernel @ B_inv

In [None]:
# Heatmap of the inverse of the change-of-basis matrix.
visualize_combined_matrix(
    B_inv,
    title="Inverse Change of Basis Matrix",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=20,
    cbar_tick_fontsize=50,
)

In [None]:
# Heatmap of M (the averaged kernel matrix multiplied by the inverse of B) ...
visualize_combined_matrix(
    M,
    title="Linear Map on Fundamental Invariants",
    width=40,
    height=32,
    title_fontsize=60,
    axis_fontsize=20,
    cbar_tick_fontsize=50,
)
# ... followed by the averaged kernel matrix itself, for comparison.
visualize_combined_matrix(
    averaged_circulant_kernel,
    title="Transformed Circulant Kernel",
    width=40,
    height=32,
    title_fontsize=50,
    axis_fontsize=20,
    cbar_tick_fontsize=20,
)

In [None]:
# Same change of basis applied to the averaged random test kernel, plotted
# with the default title and sizes.
visualize_combined_matrix(matrix=averaged_test_kernel @ B_inv)