In [1]:
from torchvision.transforms import Compose, ToTensor, Lambda
from loguru import logger
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# Notebook parameters (this looks like a papermill-style parameters cell — TODO confirm).
# dataset_name   : name of the dataset class to load (e.g. 'MNIST')
# n_dims         : how many highest-variance pixels to keep as feature dimensions
# n_test_samples : how many test images are written out as queries
# n_ground_truth : how many nearest training neighbours are recorded per query
# output_path    : directory for the downloaded data and every memmap written below
dataset_name = 'MNIST'
n_dims = 50
n_test_samples = 100
n_ground_truth = 20
output_path = 'outputs'

In [3]:
# Resolve the dataset class by name from torchvision.datasets rather than with eval():
# dataset_name is a notebook parameter, so eval() would execute whatever string is injected,
# and it only resolved classes that happened to be imported into this cell.
import torchvision.datasets
from torchvision.datasets import MNIST  # noqa: F401 — name kept for any later cell using it

dataset_class = getattr(torchvision.datasets, dataset_name)
transform = Compose([ToTensor()])
# NOTE(review): the `train=` keyword assumes an MNIST/CIFAR-style constructor signature.
original_train_data = dataset_class(
    f"{output_path}/{dataset_name}_data", download=True, transform=transform, train=True
)
train_dataset_length = len(original_train_data)

In [4]:
# Read the training set once, caching each batch as an (B, H, W) float32 array, so the
# variance pass reuses the batches instead of re-reading and re-transforming the dataset.
loader = DataLoader(original_train_data, batch_size=4096*16, num_workers=2, shuffle=False)
train_batches = []
for images, _ in loader:
    # Drop only the channel axis (size 1 for ToTensor'd grayscale images). A bare np.squeeze
    # would also drop the batch axis for a single-image batch and break np.concatenate.
    train_batches.append(np.squeeze(images.numpy(), axis=1))
assert train_batches, "training dataset is empty"

# Per-pixel mean, accumulated batch by batch in float32 (same arithmetic as before).
mean = np.zeros(train_batches[0].shape[1:], dtype=np.float32)
for images in train_batches:
    mean += np.sum(images, axis=0, keepdims=False)
mean /= train_dataset_length
logger.debug(f"train dataset mean: {np.mean(mean)}")

# Per-pixel population variance (divided by N) around that mean; accumulator is float64.
var = np.zeros(mean.shape)
for images in train_batches:
    var += np.sum((images - mean) ** 2, axis=0)
var /= train_dataset_length
logger.debug(f"train dataset variance: {np.mean(var)}")

# Full training set as one (N, H, W) array for the feature-selection cells below.
original_train_dataset_arr = np.concatenate(train_batches)
del mean, train_batches

2019-05-27 15:28:36.130 | DEBUG    | __main__:<module>:7 - train dataset mean: 0.13066266477108002
2019-05-27 15:28:42.938 | DEBUG    | __main__:<module>:16 - train dataset variance: 0.06725081825587355


In [5]:
# Choose the n_dims pixels with the largest training variance (highest first) and turn their
# flat indices back into a (row_indices, col_indices) tuple usable for fancy indexing.
selected_positions = np.unravel_index(
    np.argsort(var, axis=None)[::-1][:n_dims],  # flat indices sorted by descending variance
    var.shape,
)
logger.debug(f"selected positions: {selected_positions}")
del var

2019-05-27 15:28:42.949 | DEBUG    | __main__:<module>:2 - selected positions: (array([13, 14, 16, 22, 16, 15, 15, 15, 22, 14,  6,  6, 22,  6, 16,  6, 13,
       22, 17, 17, 20, 13, 14, 13, 22, 20,  9, 19, 21,  8,  8, 15,  9, 22,
       16, 16, 12, 20, 20, 20, 18,  8, 21,  9,  8, 12, 19, 21, 14, 19]), array([14, 14, 13, 11, 14, 14, 17, 13, 12, 17, 15, 16, 14, 14, 16, 17, 17,
       13, 14, 15, 13, 15, 13, 18, 10, 14, 11, 14, 11, 18, 13, 16, 12, 15,
       15, 17, 18, 16, 15, 12, 14, 14, 10, 18, 12, 17, 17, 12, 18, 15]))


In [6]:
# Keep only the selected pixels of every training image: (N, H, W) -> (N, n_dims),
# then write them to a raw, header-less float32 memmap (readers must know the shape).
train_arr = original_train_dataset_arr[:, selected_positions[0], selected_positions[1]]
logger.debug(f"selected train dataset shape: {np.shape(train_arr)}")
fp = np.memmap(f"{output_path}/train_arr", dtype='float32', mode='w+', shape=train_arr.shape)
fp[...] = train_arr
fp.flush()  # make the write-to-disk explicit rather than relying on deletion alone
del fp, original_train_dataset_arr

2019-05-27 15:28:42.995 | DEBUG    | __main__:<module>:2 - selected train dataset shape: (60000, 50)


In [7]:
# Load the test split with the same dataset class and transform as the training split.
original_test_data = dataset_class(
    f"{output_path}/{dataset_name}_data", download=True, transform=transform, train=False
)
loader = DataLoader(original_test_data, batch_size=4096*16, num_workers=2, shuffle=False)
original_test_dataset_arr = np.concatenate([
    # Drop only the channel axis so a single-image batch keeps its batch axis.
    np.squeeze(images.numpy(), axis=1)
    for images, _ in loader
])

# First n_test_samples queries, restricted to the pixels selected from the training set:
# (n_test_samples, n_dims). Slicing before the gather avoids indexing the whole split.
test_arr = original_test_dataset_arr[:n_test_samples, selected_positions[0], selected_positions[1]]
logger.debug(f"selected test dataset shape: {np.shape(test_arr)}")
fp = np.memmap(f"{output_path}/test_arr", dtype='float32', mode='w+', shape=np.shape(test_arr))
fp[:] = test_arr[:]
del fp  # deleting the memmap flushes it to disk
del original_test_dataset_arr

2019-05-27 15:28:44.417 | DEBUG    | __main__:<module>:13 - selected test dataset shape: (100, 50)


In [8]:
# Exact nearest-neighbour ground truth: squared Euclidean distance from every test query to
# every training vector, then the indices of the n_ground_truth closest training rows.
# Distances are filled one query at a time. Broadcasting all queries at once materialises
# (n_test, n_train, n_dims) temporaries (~1.2 GB each for 100 x 60000 x 50 float32), while
# the per-row reduction over the last axis is the same arithmetic, so values are unchanged.
distances = np.empty((len(test_arr), len(train_arr)), dtype=np.result_type(test_arr, train_arr))
for row, query in enumerate(test_arr):
    distances[row] = np.sum(np.square(query - train_arr), axis=-1)
ground_truth = np.argsort(distances, axis=-1)[:, :n_ground_truth]
logger.debug(f'ground truth shape: {np.shape(ground_truth)}')

# Indices are written as float32 to keep the existing on-disk format. float32 represents
# integers exactly only up to 2**24, so fail loudly instead of silently rounding indices.
assert len(train_arr) <= 2 ** 24, "float32 cannot represent every train index exactly"
fp = np.memmap(f"{output_path}/ground_truth", dtype='float32', mode='w+', shape=np.shape(ground_truth))
fp[:] = ground_truth[:]
del fp  # deleting the memmap flushes it to disk

2019-05-27 15:28:46.461 | DEBUG    | __main__:<module>:3 - ground truth shape: (100, 20)
