# Label classifier (dSprites): data collection

**Author**: Maleakhi A. Wijaya  
**Description**: This notebook contains the code used to collect experimental data. We compare the performance of the methods discussed in Rabanser et al. against our proposed CBSD method.

In [39]:
# Load utility functions
%run ../../scripts/dsprites_utils.py
%run ../../scripts/shift_applicator.py
%run ../../scripts/shift_dimensionality_reductor.py
%run ../../scripts/constants.py
%run ../../scripts/experiment_utils.py
%run ../../scripts/shift_statistical_test.py

## Load dataset

In [41]:
SEED = 20
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [43]:
path = "../../data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
X_train, X_test, y_train, y_test, c_train, c_test = load_dsprites(path, 100000, train_size=0.85, class_index=1)

# Hold out the training samples after the first 70,000 as a validation set
X_train, X_valid = X_train[:70000], X_train[70000:]
y_train, y_valid = y_train[:70000], y_train[70000:]
c_train, c_valid = c_train[:70000], c_train[70000:]

Training samples: 85000
Testing samples: 15000


In [46]:
# Flatten the images into the input format used when applying shifts.
# For efficiency, images are kept as 2-D arrays during preprocessing
# (number of instances x flattened image size).
# To visualise an image again, reshape it back to its original dimensions.
ORIGINAL_SHAPE = X_test.shape[1:] # constant holding the original image shape
X_test_flatten = deepcopy(X_test.reshape(X_test.shape[0], -1))
X_train_flatten = deepcopy(X_train.reshape(X_train.shape[0], -1))
X_valid_flatten = deepcopy(X_valid.reshape(X_valid.shape[0], -1))
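
The flatten-and-restore round trip described in the comments above can be sketched on a small synthetic array. This is only an illustration: the `images` array and its 64x64 shape are stand-ins chosen for the example, not the actual output of `load_dsprites`.

```python
import numpy as np

# Hypothetical stand-in for a batch of images: 5 images of 64x64 pixels.
images = np.arange(5 * 64 * 64, dtype=np.float32).reshape(5, 64, 64)
original_shape = images.shape[1:]  # per-image shape, here (64, 64)

# Flatten every image into one row: (number of instances, flattened size).
flat = images.reshape(images.shape[0], -1)

# Restore the original per-image shape, e.g. before plotting.
restored = flat.reshape(-1, *original_shape)

assert np.array_equal(restored, images)
```

Because `reshape` only changes how the same values are indexed, the restored array is element-wise identical to the input.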

## Dimensionality reduction

We implemented the following dimensionality reduction methods:
- End-to-end model (label classifier / BBSD)
- Concept bottleneck model (CBSD)
- Trained and untrained autoencoders (TAE and UAE)
- Principal component analysis (PCA)
- Sparse random projection (SRP)
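
As a rough illustration of the last two entries, the sketch below fits PCA and SRP with scikit-learn on synthetic flattened data. It is not the implementation in `shift_dimensionality_reductor.py`; the array names, sample counts, and the choice of 32 components are assumptions made for this example only.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.random_projection import SparseRandomProjection

# Synthetic stand-ins for flattened source and target images (assumed sizes).
rng = np.random.default_rng(0)
X_source = rng.random((200, 4096))
X_target = rng.random((50, 4096))

n_components = 32  # illustrative value, not taken from the notebook

# Fit each reducer on the source data only, then apply it to both sets.
pca = PCA(n_components=n_components, random_state=0).fit(X_source)
srp = SparseRandomProjection(n_components=n_components, random_state=0).fit(X_source)

X_source_pca, X_target_pca = pca.transform(X_source), pca.transform(X_target)
X_source_srp, X_target_srp = srp.transform(X_source), srp.transform(X_target)

print(X_target_pca.shape, X_target_srp.shape)
```

Fitting on the source data and reusing the fitted reducer on the target data keeps both sets in the same reduced space, which is what allows a later two-sample comparison between them.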

In [48]:
end_to_end_neural_network(3, Dataset.DSPRITES, X_train, y_train,
                         X_valid, y_valid)

TypeError: __init__() missing 1 required positional argument: 'dataset'