In [None]:
import os

In [None]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

In [None]:
##

In [None]:
# Directory holding the original SVHN `train_32x32.mat` / `test_32x32.mat`
# files; it is joined with those file names via `os.path.join` when loading.
# An empty string makes the paths resolve relative to the working directory.
path_to_svhn_dataset = ""

In [None]:
def open_svhn(mat_path: str, train=True):
    """Load an SVHN ``*_32x32.mat`` file and return sample-first images and labels.

    The ``"X"`` entry is transposed from ``(32, 32, 3, N)`` to
    ``(N, 32, 32, 3)`` and the ``"y"`` entry is squeezed to ``(N,)``. Label
    value 10 is rewritten to 0 (presumably because SVHN encodes the digit
    zero as class 10 -- confirm against the dataset description).

    Args:
        mat_path: Path of the ``.mat`` file, handed to ``scipy.io.loadmat``.
        train: Selects the sample count the sanity checks expect:
            73257 when true, 26032 otherwise.

    Returns:
        Tuple ``(images, labels, raw_X, raw_y)``. ``images`` is a transposed
        view of ``raw_X``; ``labels`` is an independent remapped copy;
        ``raw_X`` / ``raw_y`` are the untouched ``"X"`` / ``"y"`` entries of
        the loaded file.

    Raises:
        AssertionError: If shapes, dtypes (``uint8``) or value ranges
            (images 0..255, labels 0..9) differ from what is expected.
    """
    loaded_mat = sio.loadmat(mat_path)
    raw_images = loaded_mat["X"]
    raw_labels = loaded_mat["y"]

    print("X.shape", raw_images.shape)
    print("X.dtype", raw_images.dtype)
    print("y.shape", raw_labels.shape)
    print("y.dtype", raw_labels.dtype)

    # `squeeze()` returns a view, so remapping in place would also rewrite
    # loaded_mat["y"], which is handed back to the caller as the raw labels.
    # Work on a copy to keep the raw array exactly as stored in the file.
    labels = raw_labels.squeeze().copy()
    labels[labels == 10] = 0

    # Move the sample axis to the front: (32, 32, 3, N) -> (N, 32, 32, 3).
    images = np.transpose(raw_images, (3, 0, 1, 2))

    expected_count = 73257 if train else 26032
    assert images.shape == (expected_count, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (expected_count,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    return images, labels, raw_images, raw_labels

In [None]:
# Load the split checked against the 73257-sample count (train=True).
# `train_X` / `train_y` are the "X" / "y" entries of the loaded .mat file.
train_images, train_labels, train_X, train_y = open_svhn(
    mat_path=os.path.join(path_to_svhn_dataset, "train_32x32.mat"),
    train=True,
)

In [None]:
# Load the split checked against the 26032-sample count (train=False).
# `test_X` / `test_y` are the "X" / "y" entries of the loaded .mat file.
test_images, test_labels, test_X, test_y = open_svhn(
    mat_path=os.path.join(path_to_svhn_dataset, "test_32x32.mat"),
    train=False,
)

In [None]:
##

In [None]:
# Re-split the pooled train+test samples into a new train/test partition that
# keeps, for every class, exactly the per-class counts of the original splits.
#
# The legacy global RNG is seeded here and drawn from in a fixed order (one
# `choice` per class, then two `permutation`s), so re-running this cell always
# yields the same partition.
np.random.seed(0)

all_images = np.concatenate([train_images, test_images])
all_labels = np.concatenate([train_labels, test_labels])

# Position of every sample inside the pooled arrays; the new splits are
# described entirely by these indices.
all_idx = np.arange(len(all_images))

train_idx_chunks = list()
test_idx_chunks = list()

for class_idx in range(10):
    # Pooled indices of this class, in ascending order.
    class_members = all_idx[all_labels == class_idx]

    n_of_test = int((test_labels == class_idx).sum())
    n_of_train = int((train_labels == class_idx).sum())

    # Draw which positions (within this class) go to the new test split.
    new_class_test_indices = np.random.choice(
        n_of_test + n_of_train, n_of_test, replace=False
    )

    # Everything not drawn goes to train. A boolean mask gives the complement
    # in ascending order without a per-element membership scan of the array.
    is_test_position = np.zeros(n_of_test + n_of_train, dtype=bool)
    is_test_position[new_class_test_indices] = True

    train_idx_chunks.append(class_members[~is_test_position])
    test_idx_chunks.append(class_members[new_class_test_indices])

new_train_idx = np.concatenate(train_idx_chunks)
new_test_idx = np.concatenate(test_idx_chunks)

# Shuffle so the splits are not grouped by class.
train_idx_shuffle = np.random.permutation(len(new_train_idx))
test_idx_shuffle = np.random.permutation(len(new_test_idx))

new_train_idx = new_train_idx[train_idx_shuffle]
new_test_idx = new_test_idx[test_idx_shuffle]

# Gather labels and images once per split from the pooled arrays instead of
# accumulating per-sample Python lists.
new_train_labels = all_labels[new_train_idx]
new_test_labels = all_labels[new_test_idx]

new_train_images = all_images[new_train_idx]
new_test_images = all_images[new_test_idx]

# Sanity checks: split sizes are preserved, the two index sets partition the
# pooled data, and every class keeps its original per-split count.
assert len(new_train_labels) == len(train_labels)
assert len(new_test_labels) == len(test_labels)
assert np.array_equal(
    np.sort(np.concatenate([new_train_idx, new_test_idx])), all_idx
)

for i in range(10):
    assert (new_train_labels == i).sum() == (train_labels == i).sum()
    assert (new_test_labels == i).sum() == (test_labels == i).sum()

In [None]:
##

In [None]:
# Repack the remixed splits as .mat-style dictionaries: the sample axis is
# moved from the front to the back, (N, H, W, C) -> (H, W, C, N), i.e. the
# inverse of the transpose applied in `open_svhn`, and the labels become a
# column vector of shape (N, 1).
train_mat = dict(
    X=new_train_images.transpose(1, 2, 3, 0),
    y=np.reshape(new_train_labels, (-1, 1)),
)
test_mat = dict(
    X=new_test_images.transpose(1, 2, 3, 0),
    y=np.reshape(new_test_labels, (-1, 1)),
)

In [None]:
# Write the remixed splits as .mat files. The output directory is created
# first so the save does not fail with a missing-directory error.
os.makedirs("../dataset", exist_ok=True)
sio.savemat("../dataset/train_32x32_remix.mat", train_mat)
sio.savemat("../dataset/test_32x32_remix.mat", test_mat)

In [None]:
# Store, for every remixed sample, its index into the pooled (train followed
# by test) arrays, so the remix can be reproduced from the original files.
# The output directory is created first so the save cannot fail on it.
os.makedirs("../index", exist_ok=True)
np.save("../index/train_index_remix.npy", new_train_idx)
np.save("../index/test_index_remix.npy", new_test_idx)

In [None]:
##