# Notebook Initialization

In [None]:
from google.colab import drive
from tqdm.auto import tqdm
import os
import shutil
import json
import torch
import matplotlib.pyplot as plt

# Mount Google Drive at `/content/drive`.
drive.mount('/content/drive')
# Project folder, relative to the Drive mount point; other cells build paths as
# `/content/drive/{DRIVE_ROOT_DIR_PATH}/...`.
DRIVE_ROOT_DIR_PATH = 'MyDrive/nesy'
# Sub-folder of the project folder that receives the prepared `.pt` datasets.
DRIVE_DATASETS_DIR = 'prepared_datasets'

In [None]:
# Fetch the `visudo` scripts archive from Drive, extract it locally and delete
# the local copy of the archive.
shutil.copy(f'/content/drive/{DRIVE_ROOT_DIR_PATH}/visudo_scripts.zip', '/content/')
# Extract to an absolute path: the scripts are later invoked as
# `/content/visudo_scripts/generate-split.py`, so the target must not depend on
# the kernel's current working directory.
shutil.unpack_archive('/content/visudo_scripts.zip', '/content/visudo_scripts', 'zip')
os.remove('/content/visudo_scripts.zip')
print('The `visudo` scripts were unpacked.')

# Data Collection

## Constants

In [4]:
def _trailing_int(label) -> int:
    """Return the integer that follows the last underscore of `str(label)`."""
    return int(str(label).rpartition('_')[2])


# Per-source converters from a raw cell label to the digit stored in the dataset.
SOURCE_MAP = {
    'mnist': _trailing_int,
    # The `emnist` value is shifted down by 10
    # (presumably its class ids start at 10 -- TODO confirm against the generator).
    'emnist': lambda label: _trailing_int(label) - 10,
    'kmnist': _trailing_int,
    'fmnist': _trailing_int,
}
# Board side lengths accepted by `collect`.
BOARD_DIM_LIST = [4, 9]
# Split indices accepted by `collect`: 1 through 11 inclusive.
SPLIT_LIST = list(range(1, 12))

## Tools

In [5]:
def collect(source_name:str, board_dim:int, split:int, verbose:bool) -> None:
    """Generate one `visudo` split and store it as a single `.pt` file.

    Runs `/content/visudo_scripts/generate-split.py` into a scratch directory
    under `/content/data/`, parses the generated text files into tensors, saves
    them to `/content/data/{source_name}_{dim}x{dim}_split_{split:02}.pt` and
    finally deletes the scratch directory.

    The saved object is a dict keyed by 'train', 'val' and 'test'. Each value
    is a dict with:
      - 'symbols': uint8 tensor reshaped to (n_boards, board_dim, board_dim, 28, 28),
        holding the parsed pixel values multiplied by 255;
      - 'digits': uint8 tensor reshaped to (n_boards, board_dim, board_dim),
        holding the cell labels converted through `SOURCE_MAP[source_name]`;
      - 'labels': bool tensor with one entry per line of the puzzle-label file
        (its first tab-separated column).

    Args:
        source_name: Key of `SOURCE_MAP` naming the image source.
        board_dim: Board side length; must be in `BOARD_DIM_LIST`.
        split: Split index; must be in `SPLIT_LIST`.
        verbose: When true, print a progress line for every stage.

    Raises:
        ValueError: If an argument is not among the supported values, or if the
            generated files of a subset disagree on the number of boards.
        RuntimeError: If the generator script exits with a non-zero status.
    """
    # Local imports keep this cell self-contained; the notebook's import cell
    # does not provide these modules.
    import subprocess
    import sys

    # Initialization
    if source_name not in SOURCE_MAP:
        raise ValueError(f'Unknown source: {source_name}')
    if board_dim not in BOARD_DIM_LIST:
        raise ValueError(f'Unknown board dimension: {board_dim}')
    if split not in SPLIT_LIST:
        raise ValueError(f'Unknown split: {split}')
    if verbose:
        print(f'Source: {source_name}, Board Dimension: {board_dim}, Split: {split}')
    dataset_name = f'{source_name}_{board_dim}x{board_dim}_split_{split:02}'
    output_dir_path = f'/content/data/{dataset_name}'

    # Download
    if verbose:
        print('  Downloading...')
    os.makedirs('/content/data/', exist_ok=True)
    # An argument list (no shell) avoids quoting problems; stdout is discarded
    # and stderr is kept only so a failure can be reported.
    command = [
        sys.executable, '/content/visudo_scripts/generate-split.py',
        '--dataset', source_name,
        '--dimension', str(board_dim),
        '--split', str(split),
        '--out-dir', output_dir_path,
    ]
    completed = subprocess.run(
        command,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
        text=True,
        errors='replace',
    )
    if completed.returncode != 0:
        raise RuntimeError(
            f'generate-split.py failed with exit code {completed.returncode}:\n{completed.stderr}'
        )

    # Investigation
    if verbose:
        print('  Investigating...')
    # Descend through single-child directories until reaching the directory
    # that actually holds the generated files. Only directories are followed,
    # so a directory containing exactly one file ends the descent.
    current_dir_path = output_dir_path
    while True:
        child_list = os.listdir(current_dir_path)
        if len(child_list) != 1:
            break
        only_child_path = os.path.join(current_dir_path, child_list[0])
        if not os.path.isdir(only_child_path):
            break
        current_dir_path = only_child_path

    # Processing
    if verbose:
        print('  Processing...')
    data_dict:dict[str, dict[str, torch.Tensor]] = {'train': {}, 'val': {}, 'test': {}}
    to_digit = SOURCE_MAP[source_name]
    # (file-name prefix used by the generator, key used in the stored dict)
    for alias, subset_name in (('train', 'train'), ('valid', 'val'), ('test', 'test')):
        # One line per board: tab-separated pixel values.
        with open(os.path.join(current_dir_path, f'{alias}_puzzle_pixels.txt'), 'r') as text_file:
            symbols = torch.tensor([
                json.loads('[' + line.replace('\t', ',') + ']')
                for line in text_file if line.strip()
            ])
        # `-1` lets the board count follow the file instead of assuming 200.
        # NOTE(review): the uint8 cast truncates fractional values rather than rounding.
        symbols = (symbols.reshape(-1, board_dim, board_dim, 28, 28) * 255).to(torch.uint8)

        # One line per board: tab-separated cell labels.
        with open(os.path.join(current_dir_path, f'{alias}_cell_labels.txt'), 'r') as text_file:
            digits = torch.tensor([
                [to_digit(label) for label in line.strip().split('\t')]
                for line in text_file if line.strip()
            ])
        digits = digits.reshape(-1, board_dim, board_dim).to(torch.uint8)

        # One line per board: the first tab-separated column is the board label.
        with open(os.path.join(current_dir_path, f'{alias}_puzzle_labels.txt'), 'r') as text_file:
            labels = torch.tensor([
                int(line.strip().split('\t')[0])
                for line in text_file if line.strip()
            ])
        labels = labels.to(torch.bool)

        if not (symbols.shape[0] == digits.shape[0] == labels.shape[0]):
            raise ValueError(
                f'Inconsistent board counts in subset `{alias}`: '
                f'{symbols.shape[0]} symbols, {digits.shape[0]} digits, {labels.shape[0]} labels'
            )

        data_dict[subset_name]['symbols'] = symbols
        data_dict[subset_name]['digits'] = digits
        data_dict[subset_name]['labels'] = labels

    # Storage
    if verbose:
        print('  Storing...')
    dataset_path = f'/content/data/{dataset_name}.pt'
    if os.path.exists(dataset_path):
        os.remove(dataset_path)
    torch.save(data_dict, dataset_path)

    # Cache Removal
    if verbose:
        print('  Removing Cache...')
    shutil.rmtree(output_dir_path)


def visualize(source_name:str, board_dim:int, split:int, n_boards:int) -> None:
    """Plot randomly chosen boards from a stored dataset file.

    Loads `/content/data/{source_name}_{dim}x{dim}_split_{split:02}.pt`, picks
    one of the subsets 'train', 'val' or 'test' at random, and draws up to
    `n_boards` randomly selected boards of that subset. Each board is one
    figure with a `board_dim` x `board_dim` grid of cell images, every cell
    titled with its digit and the figure titled with the board index, subset
    name and board label.

    Args:
        source_name: Source part of the dataset file name.
        board_dim: Board side length; also part of the file name.
        split: Split index; also part of the file name.
        n_boards: Maximum number of boards to draw.

    Raises:
        ValueError: If the dataset file does not exist.
    """

    # Initialization
    file_stem = f'{source_name}_{board_dim}x{board_dim}_split_{split:02}'
    dataset_path = f'/content/data/{file_stem}.pt'
    if not os.path.exists(dataset_path):
        raise ValueError(f'Dataset not found: {dataset_path}')
    stored:dict[str, dict[str, torch.Tensor]] = torch.load(dataset_path)

    # Visualization
    subset_names = ['train', 'val', 'test']
    subset = subset_names[torch.randint(0, 3, (1,)).item()]
    chosen = stored[subset]
    symbols, digits, labels = chosen['symbols'], chosen['digits'], chosen['labels']
    for board_idx in torch.randperm(symbols.shape[0])[:n_boards]:
        fig, axs = plt.subplots(nrows=board_dim, ncols=board_dim, figsize=(3, 3))
        fig.suptitle(f'Board `{board_idx}` of Subset `{subset}` with Label `{labels[board_idx]}`', fontsize=8)
        # Walk the grid in row-major order using a single flat counter.
        for cell in range(board_dim * board_dim):
            row, col = divmod(cell, board_dim)
            ax:plt.Axes = axs[row, col]
            ax.imshow(symbols[board_idx, row, col, :, :], cmap='gray')
            ax.set_title(f'{digits[board_idx, row, col]}', fontsize=6)
            ax.axis('off')
        fig.tight_layout()
## Collection

In [None]:
# Build the 4x4 MNIST dataset for split 2, printing progress for each stage.
collect(source_name='mnist', board_dim=4, split=2, verbose=True)

In [None]:
# Draw one random board from the 4x4 MNIST split-2 dataset file.
visualize(source_name='mnist', board_dim=4, split=2, n_boards=1)

## Transferring to Drive

In [8]:
# Copy every `.pt` dataset file from `/content/data/` to the Drive datasets folder.
# The destination path is built once instead of on every iteration.
drive_dataset_dir_path = f'/content/drive/{DRIVE_ROOT_DIR_PATH}/{DRIVE_DATASETS_DIR}/'
dataset_file_names = [
    file_name for file_name in os.listdir('/content/data/')
    if file_name.endswith('.pt')
]
if dataset_file_names:
    # Create the destination only when there is something to copy;
    # `exist_ok=True` makes re-running this cell safe.
    os.makedirs(drive_dataset_dir_path, exist_ok=True)
for file_name in dataset_file_names:
    shutil.copy(f'/content/data/{file_name}', os.path.join(drive_dataset_dir_path, file_name))