# Learning Z/nZ x Z/nZ group actions
This notebook is adapted from the `modular arithmetic` notebook, replacing the `Z/nZ` group action with the `Z/nZ x Z/nZ` group action.

In [None]:
import numpy as np
import random
import torch
import os
import torch.nn as nn
import torch.optim as optim
import shutil
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.animation import FuncAnimation
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import MaxNLocator

import importlib
import pickle

import group_agf.binary_action_learning.models as models
import group_agf.binary_action_learning.datasets as datasets
import group_agf.binary_action_learning.power as power
import group_agf.binary_action_learning.train as train
import group_agf.binary_action_learning.plot as plot


# Define Dataset and Visualize

In [None]:
import os

# Experiment configuration for this notebook.
#
# TEST_MODE is enabled by setting the NOTEBOOK_TEST_MODE environment variable
# to "1"; it shrinks the sizes and the epoch count below so the notebook can
# run quickly under automated testing.
TEST_MODE = os.environ.get("NOTEBOOK_TEST_MODE", "0") == "1"

# --- Data -------------------------------------------------------------------
p = 3 if TEST_MODE else 5  # the template is later reshaped to (p, p)
mnist_digit = 4
dataset_fraction = 0.1 if TEST_MODE else 0.2  # passed as `fraction` to the dataset loader
template_type = 'mnist'
seed = 47

# --- Model / optimisation ---------------------------------------------------
batch_size = 32 if TEST_MODE else 128
hidden_size = 32 if TEST_MODE else 128
lr = 0.001
mom = 0.9  # used as the first Adam beta when the optimizer is built
init_scale = 1e-2
epochs = 2 if TEST_MODE else 1000

# Derived from `epochs` (one tenth of the run, never below 1). The package
# default in `default_config` is deliberately not imported: it would be
# overwritten by this assignment before ever being read.
verbose_interval = max(1, epochs // 10)

# Path handed to the training routine; the file name encodes the settings above.
# NOTE(review): "/tmp/adele" is a hard-coded, user-specific directory — confirm
# that it exists (or that the training code creates it) before running.
model_save_path = (
    "/tmp/adele/model_"
    f"p{p}_"
    f"digit{mnist_digit}_"
    f"frac{dataset_fraction}_"
    f"type{template_type}_"
    f"seed{seed}.pkl"
)

In [None]:
# Group tag passed to the plotting helper below.
group = 'cnxcn'

# Pick the template for the configured size, template type and MNIST digit.
template = datasets.choose_template(p, template_type, mnist_digit)

# Plot the template's top components for this group; the return value is kept
# so it can be displayed or reused later.
top_frequency_plot = plot.plot_top_template_components(group, template, p)

In [None]:
# Load the 2D modular-addition dataset for this template, keeping only
# `dataset_fraction` of it; `seed` is passed as the loader's random state.
X, Y, translations = datasets.load_modular_addition_dataset_2d(
    p,
    template,
    fraction=dataset_fraction,
    random_state=seed,
    template_type=template_type,
)

# device=None is passed through; the helper returns the device it used, which
# is reused below when the model is created.
X, Y, device = datasets.move_dataset_to_device_and_flatten(
    X, Y, p, device=None
)

# Wrap the tensors for mini-batch iteration; batches are served in a fixed
# order (shuffle=False).
dataset = TensorDataset(X, Y)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Define Model and Train

In [None]:
# Seed every random number generator used here (PyTorch CPU, PyTorch CUDA,
# NumPy) with the same value.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if using GPU
np.random.seed(seed)

# Two-layer network with a square nonlinearity, placed on the same device as
# the dataset.
model = models.TwoLayerNet(
    p=p,
    hidden_size=hidden_size,
    nonlinearity='square',
    init_scale=init_scale,
    output_scale=1e0,
).to(device)

# Mean-squared-error objective, optimised with Adam; `mom` is the first beta.
loss = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(mom, 0.999),
)

# Run training and keep the three histories it returns for the plots below.
loss_history, accuracy_history, param_history = train.train(
    model,
    dataloader,
    loss,
    optimizer,
    epochs=epochs,
    verbose_interval=verbose_interval,
    model_save_path=model_save_path,
)

# Plot loss, power, and model output

In [None]:
# Plot the loss history recorded during training; the template is passed along
# to the plotting helper as well.
loss_plot = plot.plot_loss_curve(loss_history, template)

In [None]:
# View the template as a (p, p) grid for the power plot.
template_2d = template.reshape(p, p)

# Plot power over training from the saved parameter history; nothing is
# written to disk (save_path=None) and show=False is passed through.
power_over_training_plot = plot.plot_training_power_over_time(
    template_2d,
    model,
    device,
    param_history,
    X,
    p,
    save_path=None,
    show=False,
)

In [None]:
# Group tag, set again here so this cell does not depend on the earlier one.
group = 'cnxcn'

# Plot the weights of the first 20 neurons (indices 0..19).
neuron_indices = list(range(20))
print(neuron_indices)

neuron_weights_plot = plot.plot_neuron_weights(
    group,
    model,
    p,
    neuron_indices=neuron_indices,
    show=True,
)

In [None]:
# Index handed to the output-plotting helper.
# NOTE(review): presumably this selects a sample from X/Y; with the reduced
# TEST_MODE dataset, 13 may be out of range — confirm against the helper.
idx = 13
plot.plot_model_outputs(p, model, X, Y, idx)