# **Environment setup: set `project_folder` for your environment (e.g. local, Google Colab)**

In [1]:
# Root folder of the project (the directory that contains `results/`).
# Leave empty to use the current working directory (local runs), or set it to the
# project path for your environment, e.g. '/content/drive/MyDrive/<project>/' on Colab.
project_folder = ''

# Later cells build paths by plain string concatenation (project_folder + 'results/...'),
# so a non-empty folder must end with a separator or the paths come out wrong.
if project_folder and not project_folder.endswith('/'):
    project_folder += '/'

# **Training and evaluation**

In [2]:
from custom_libraries.miscellaneous import *
from custom_libraries.image_dataset import *
from custom_libraries.ktree import *
import numpy as np

In [3]:
import gc
import os

# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
bs = 256          # mini-batch size of the train/validation/test tf.data pipelines
trials = 10       # independent bootstrap trials per (class pair, tree count)
epochs = 2000     # maximum number of training epochs per trial
trees_set = [1]   # tree counts to evaluate
resume = False    # if True, skip trials that already have a non-zero entry in `history`

checkpoint_path = "checkpoints/ktree_orig_checkpoint"
history_path = project_folder + 'results/ktree_history.npy'

# Class pairs to classify: [first class, second class, dataset name].
# Alternative source: np.load(project_folder + 'results/classes.npy', allow_pickle=True)
classes = [[3, 5, 'mnist'],
           [0, 6, 'fmnist'],
           [14, 17, 'emnist'],
           [2, 6, 'kmnist'],
           [3, 5, 'cifar10'],
           [5, 6, 'svhn'],
           [3, 5, 'usps']]


def make_callbacks():
    """Build a fresh list of Keras callbacks for a single `fit` run.

    A new ModelCheckpoint is needed for every run: the callback remembers its best
    monitored value across `fit` calls. A shared instance would skip saving whenever
    a later trial never beats an earlier trial's best. The following `load_weights`
    would then silently restore weights from a previous trial or dataset.

    Returns:
        list: [EarlyStopping, ModelCheckpoint], both monitoring 'val_binary_crossentropy'.
    """
    return [
        # With patience=2000 and epochs=2000 this cannot stop training early.
        tf.keras.callbacks.EarlyStopping(monitor='val_binary_crossentropy', patience=2000),
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                           monitor='val_binary_crossentropy',
                                           verbose=0,
                                           save_best_only=True,
                                           save_weights_only=True,
                                           ),
    ]


# Resume from saved results when available; otherwise start from zeros, indexed as
# [class pair, trial, tree count, stored metric] (2 metrics stored per run).
if os.path.exists(history_path):
    history = np.load(history_path, allow_pickle=True)
else:
    history = np.zeros((len(classes), trials, len(trees_set), 2))

# np.save does not create missing directories.
os.makedirs(os.path.dirname(history_path) or '.', exist_ok=True)


for j, (t1, t2, ds) in enumerate(classes):

    # Skip the colour datasets for now.
    if ds in ['cifar10', 'svhn']:
        continue

    print(f"Dataset: {ds} / Pair: {t1}-{t2}")

    # Two copies of each split: one vectorized with by_row=True, one with by_row=False.
    test_ds = ImageDataset(ds, 'test', data_dir=None, shuffle_files=False)
    train_ds = ImageDataset(ds, 'train', data_dir=None, shuffle_files=False)
    test_ds_2 = ImageDataset(ds, 'test', data_dir=None, shuffle_files=False)
    train_ds_2 = ImageDataset(ds, 'train', data_dir=None, shuffle_files=False)

    for x in [train_ds, test_ds, train_ds_2, test_ds_2]:
        x.filter(t1, t2, overwrite=True)
        x.normalize()
        # 28x28 images are padded (presumably to 32x32 — confirm in ImageDataset.pad).
        if x.images.shape[1:3] == (28, 28):
            x.pad()

    for x in [train_ds, test_ds]:
        x.vectorize(merge_channels=True, by_row=True)

    for x in [train_ds_2, test_ds_2]:
        x.vectorize(merge_channels=True, by_row=False)

    # NOTE(review): images are joined along axis=1 while labels are flattened and
    # joined end-to-end (doubling their count). That only lines up if ImageDataset
    # stores images/labels in a layout not visible here — confirm the intended shapes.
    for (x, y) in [(train_ds, train_ds_2), (test_ds, test_ds_2)]:
        x.images = np.concatenate((x.images, y.images), axis=1)
        x.labels = np.concatenate((x.labels, y.labels), axis=None)
        x.shuffle()

    del train_ds_2, test_ds_2

    for k, trees in enumerate(trees_set):

        print(f"{trees}-tree")

        # Each input vector is repeated once per tree.
        test_set = tf.data.Dataset.from_tensor_slices((test_ds.images, test_ds.labels)).map(
            lambda x, y: (tf.tile(x, [trees]), y)).batch(bs)

        for i in range(trials):

            if resume and np.any(history[j, i, k] != 0):
                continue

            print(f"Trial {i + 1}")

            with tf.device('/device:GPU:0'):

                X_train, y_train, X_valid, y_valid = train_ds.bootstrap(.85, True)

                model = create_model(input_size=X_train.shape[1] * trees, num_trees=trees, use_bias=True)

                train_set = tf.data.Dataset.from_tensor_slices((X_train, y_train)).map(
                    lambda x, y: (tf.tile(x, [trees]), y)).batch(bs)
                valid_set = tf.data.Dataset.from_tensor_slices((X_valid, y_valid)).map(
                    lambda x, y: (tf.tile(x, [trees]), y)).batch(bs)

                # Fresh callbacks for every run (see make_callbacks).
                callbacks = make_callbacks()

                # The datasets are already batched, so no batch_size is passed to fit/evaluate.
                fit_history = model.fit(x=train_set, epochs=epochs,
                                        validation_data=valid_set,
                                        callbacks=callbacks, verbose=1)
                print_fit_history(fit_history, epochs)

                # Restore the weights with the best validation loss from this run.
                model.load_weights(checkpoint_path)

                evaluate_history = model.evaluate(x=test_set, verbose=0)
                print_evaluate_history(evaluate_history)

                # Drop the first evaluate() value (presumably the loss) and keep the metrics.
                history[j, i, k] = evaluate_history[1:]

                # Save after every trial so an interrupted run keeps its completed results.
                np.save(history_path, history, allow_pickle=True)

                del model, train_set, valid_set, X_train, y_train, X_valid, y_valid
                # Reset Keras global state so it does not grow across many models.
                tf.keras.backend.clear_session()
                gc.collect()

Dataset: mnist / Pair: 3-5
1-tree
Trial 1
Epoch 1/2000
Epoch 2/2000
Epoch 3/2000
Epoch 4/2000
Epoch 5/2000
Epoch 6/2000
Epoch 7/2000
Epoch 8/2000
Epoch 9/2000
Epoch 10/2000
Epoch 11/2000
Epoch 12/2000
Epoch 13/2000
Epoch 14/2000
Epoch 15/2000
Epoch 16/2000
Epoch 17/2000

KeyboardInterrupt: 

In [None]:
# Summarize the saved results: mean ± std over completed trials for each class pair and tree count.
history = np.load(project_folder + 'results/ktree_history.npy', allow_pickle=True)
print("RESULTS:")
for j, (t1, t2, ds) in enumerate(classes):
    print(f"Dataset: {ds} / Pair: {t1}-{t2}")
    for k, trees in enumerate(trees_set):
        print(f"{trees}-tree")
        per_trial = history[j, :, k, :]  # (trial, stored metric)
        # Trials that were never run (skipped datasets, interrupted runs) are all-zero
        # rows; exclude them so they do not pull the mean towards 0.
        completed = np.any(per_trial != 0, axis=1)
        # Column 1 is reported as accuracy (assumes the second stored metric is accuracy — TODO confirm).
        accuracies = per_trial[completed, 1]
        if accuracies.size == 0:
            print("Accuracy: no completed trials")
            continue
        print(f"Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f} "
              f"({accuracies.size}/{per_trial.shape[0]} trials)")