## Convert training data

Take the `.npz` (or `.npy`) training files and convert them to `.memmap` files for easy use on memory-limited machines.

In [1]:
import numpy as np
import os.path

In [2]:
# Directory holding the source arrays (data.npz or data.npy, plus labels.npz);
# the .memmap and .memmap.meta outputs are written into the same directory.
training_dir = "/lustre_scratch/duncanwp/combined_v3_typed_new_composite"

In [3]:
# Load the training samples, preferring the compressed archive and falling
# back to a plain .npy file. NpzFile objects keep their underlying file open
# until closed, so they are used as context managers; indexing 'arr_0' reads
# the array fully into memory before the archive is closed.
data_npz_path = os.path.join(training_dir, 'data.npz')
data_npy_path = os.path.join(training_dir, 'data.npy')
if os.path.isfile(data_npz_path):
    with np.load(data_npz_path) as data_archive:
        all_data = data_archive['arr_0']
elif os.path.isfile(data_npy_path):
    all_data = np.load(data_npy_path)
else:
    raise ValueError("No training data found")

with np.load(os.path.join(training_dir, 'labels.npz')) as labels_archive:
    all_labels = labels_archive['arr_0']

# The paired shuffle below only keeps samples and labels aligned when both
# arrays have the same number of entries along the first axis.
if all_data.shape[0] != all_labels.shape[0]:
    raise ValueError(
        "Data and labels have different lengths: "
        f"{all_data.shape[0]} != {all_labels.shape[0]}"
    )

# Shuffle the data in-place since the original training datasets are roughly ordered
# Set a fixed seed for reproducibility
R_SEED = 12345
rstate = np.random.RandomState(R_SEED)
rstate.shuffle(all_data)
# Re-seed so the labels receive exactly the same permutation as the data
rstate = np.random.RandomState(R_SEED)
rstate.shuffle(all_labels)

In [4]:
# Display the label dtype; the labels memmap created later in this notebook
# hard-codes 'uint8', so this is the value to compare against.
all_labels.dtype

dtype('uint8')

In [5]:
# Create a writable on-disk array with the same shape as the in-memory
# (already shuffled) samples and copy them across.
# Reference: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
filename = os.path.join(training_dir, 'data.memmap')
fp = np.memmap(
    filename,
    dtype='uint8',
    mode='w+',
    shape=all_data.shape,
)
fp[:] = all_data
# Push any buffered changes out to the file on disk
fp.flush()

In [6]:
# Store the shape (first line) and dtype (second line) of data.memmap in a
# sidecar text file.
data_meta_path = os.path.join(training_dir, 'data.memmap.meta')
with open(data_meta_path, 'w+') as meta_file:
    meta_file.write(str(all_data.shape) + '\n' + 'uint8')

In [7]:
# Create a writable on-disk array with the same shape as the in-memory
# (already shuffled) labels and copy them across.
# Reference: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
filename = os.path.join(training_dir, 'labels.memmap')
fp = np.memmap(
    filename,
    dtype='uint8',
    mode='w+',
    shape=all_labels.shape,
)
fp[:] = all_labels
# Push any buffered changes out to the file on disk
fp.flush()

In [8]:
# Store the shape (first line) and dtype (second line) of labels.memmap in a
# sidecar text file; a raw memmap file has no header of its own, so this is
# what a reader needs to reopen it.
with open(os.path.join(training_dir, 'labels.memmap.meta'), 'w+') as f:
    # Record the labels' own shape (previously the data shape was written
    # here), so the metadata matches the array stored in labels.memmap.
    f.writelines([str(all_labels.shape), '\n', 'uint8'])