## Imports

In [None]:
import os

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras import backend as tfk
from tensorflow.compat.v1.keras.initializers import RandomNormal, TruncatedNormal
from tensorflow.compat.v1.keras.layers import (Input, Dense, Activation, Layer, Lambda,
                                     Concatenate)
from tensorflow.compat.v1.keras.models import Model, load_model
from tensorflow.compat.v1.keras.optimizers import Adam

from models import ConvNet, SoftBinaryDecisionTree
from models.utils import brand_new_tfsession, draw_tree
from tensorflow.keras.callbacks import EarlyStopping, Callback

# Switch TensorFlow to TF1-style graph/session execution; this has to run before any
# graph or session is created below.
tf.disable_v2_behavior()


# Session shared by the models in this notebook; later cells pass it back to
# brand_new_tfsession(sess) and rebind `sess` to the result.
sess = brand_new_tfsession()

## Zad. 1 
Podział danych na zbiory treningowe, walidacyjne i testowe.

In [None]:
# load MNIST data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# add channel dim: (N, 28, 28) -> (N, 28, 28, 1)
x_train, x_test = x_train[..., np.newaxis], x_test[..., np.newaxis]

# hold out last 10000 training samples for validation.
# x_valid / y_valid must be sliced BEFORE x_train / y_train are truncated, and every
# later cell (one-hot, normalization, training) relies on them being defined here.
n_valid = 10000
x_valid, y_valid = x_train[-n_valid:], y_train[-n_valid:]
x_train, y_train = x_train[:-n_valid], y_train[:-n_valid]

print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape, x_test.shape, y_test.shape)
# (50000, 28, 28, 1) (50000,) (10000, 28, 28, 1) (10000,) (10000, 28, 28, 1) (10000,)

In [None]:
# retrieve image and label shapes from training data
img_rows, img_cols, img_chans = x_train.shape[1:]
n_classes = np.unique(y_train).shape[0]

print(img_rows, img_cols, img_chans, n_classes)

In [None]:
# convert labels to 1-hot vectors
y_train = tf.keras.utils.to_categorical(y_train, n_classes)
y_valid = tf.keras.utils.to_categorical(y_valid, n_classes)
y_test = tf.keras.utils.to_categorical(y_test, n_classes)

print(y_train.shape, y_valid.shape, y_test.shape)

In [None]:
# normalize inputs and cast to float
x_train = (x_train / np.max(x_train)).astype(np.float32)
x_valid = (x_valid / np.max(x_valid)).astype(np.float32)
x_test = (x_test / np.max(x_test)).astype(np.float32)

### Neural Network
##### Sieci neuronowe jako model nauczyciela

In [None]:
# Teacher model: a ConvNet built from the image geometry and class count derived above.
nn = ConvNet(img_rows, img_cols, img_chans, n_classes)

In [None]:
# Train the teacher ConvNet built in the previous cell (maybe_train is given a model_name,
# see the NOTE below), then report its metrics on the training split.
# `nn` is reused rather than constructed again here: a second construction would
# presumably add a duplicate set of layers/variables to the TF1 graph.
nn.maybe_train(data_train=(x_train, y_train),
               data_valid=(x_valid, y_valid),
               batch_size=16, epochs=12, model_name='nn-model')
# NOTE: if the model doesn't load properly try model_name='nn-model-alternative'
nn.evaluate(x_train, y_train)

In [None]:
# Jupyter only displays the value of a cell's last expression, so bind both results and
# display them together; otherwise the validation metrics would be silently discarded.
nn_valid_metrics = nn.evaluate(x_valid, y_valid)
nn_test_metrics = nn.evaluate(x_test, y_test)
nn_valid_metrics, nn_test_metrics

## Zad. 2 Extraction of soft labels for distillation
Wyekstrahować `soft labels` potrzebne do destylacji. Trzeba wykorzystać metodę klasy `ConvNet`.

In [None]:
# Soft labels for distillation: presumably the teacher ConvNet's per-class output on the
# training images (one row of n_classes values per training sample, per the expected shape).
# NOTE(review): the ConvNet method to call lives in `models` (not visible here) — TODO fill in.
# Until it is filled in, this cell (and the distillation training cell) raises NameError.
# y_train_soft =  # TODO
y_train_soft.shape # (50000, 10)

## Zad. 3 Binary Soft Decision Tree
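Należy wypłaszczyć zbiór danych: każdy obraz o wymiarach `(img_rows, img_cols, img_chans)` (dla MNIST `(28, 28, 1)`) zamieniamy na wektor `784` cech, ponieważ drzewo przyjmuje na wejściu płaskie wektory o długości `n_features`.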

In [None]:
# x_train_flat =  # TODO
# x_valid_flat =  # TODO
# x_test_flat =  # TODO

import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(x_test_flat.reshape((x_test_flat.shape[0], img_rows, img_cols))[1])

x_train_flat.shape, x_valid_flat.shape, x_test_flat.shape  # ((50000, 784), (10000, 784), (10000, 784))

<a id='hyperparameters'></a>
### Hyperparameters
* `tree_depth`: as denoted in the [[paper](https://arxiv.org/pdf/1711.09784.pdf)], depth is in terms of inner nodes (excluding leaves / indexing depth from `0`)
* `penalty_strength`: regularization penalty strength
* `penalty_decay`: regularization penalty decay: paper authors found 0.5 optimal (note that $2^{-d} = 0.5^d$ as we use it)
* `ema_win_size`: scaling factor to the "default size of the window" used to calculate moving averages (growing exponentially with depth) of node and path probabilities
* `inv_temp`: scale logits of inner nodes to "avoid very soft decisions" [[paper](https://arxiv.org/pdf/1711.09784.pdf)]
    * pass `0` to indicate that this should be a learned parameter (single scalar learned to apply to all nodes in the tree)
* `learning_rate`: hopefully no need to explain, but let's be cool and use [Karpathy constant](https://www.urbandictionary.com/define.php?term=Karpathy%20Constant) ([source](https://twitter.com/karpathy/status/801621764144971776)) :D as default in `tree.__init__()`
* `batch_size`: we use a small one, because with increasing depth and thus amount of leaf bigots, larger batch sizes cause their loss terms to be scaled down too much by averaging, which results in poor optimization properties

In [None]:
# --- tree architecture and regularisation (see the "Hyperparameters" section above) ---
n_features = img_rows * img_cols * img_chans  # length of one flattened input image
tree_depth = 4
penalty_strength = 10.0
penalty_decay = 0.25
ema_win_size = 1000
inv_temp = 0.01

# --- optimisation ---
learning_rate = 0.005
batch_size = 4

### Regular training with hard labels

In [None]:
# Rebind `sess` via brand_new_tfsession (presumably discarding what the previous session held)
# before building a new tree.
sess = brand_new_tfsession(sess)

# Tree trained on hard labels; every hyperparameter comes from the configuration cell.
tree = SoftBinaryDecisionTree(
    tree_depth,
    n_features,
    n_classes,
    penalty_strength=penalty_strength,
    penalty_decay=penalty_decay,
    inv_temp=inv_temp,
    ema_win_size=ema_win_size,
    learning_rate=learning_rate,
)
tree.build_model()

In [None]:
epochs = 40

# Stop training once the monitored validation metric has not improved for 20 epochs.
# NOTE(review): depending on the Keras version / compile() metrics the logged key may be
# 'val_accuracy' rather than 'val_acc'; if so, EarlyStopping only warns and never stops — confirm.
es = EarlyStopping(monitor='val_acc', patience=20, verbose=1)

# Hard-label training: fit on the one-hot ground truth with distillation switched off.
tree.maybe_train(
    sess=sess,
    data_train=(x_train_flat, y_train),
    data_valid=(x_valid_flat, y_valid),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[es],
    distill=False,
)

In [None]:
# Jupyter only displays the last expression of a cell, so keep both results and display
# them together; the validation metrics would otherwise be dropped.
hard_tree_valid_metrics = tree.evaluate(x=x_valid_flat, y=y_valid, batch_size=batch_size)
hard_tree_test_metrics = tree.evaluate(x=x_test_flat, y=y_test, batch_size=batch_size)
hard_tree_valid_metrics, hard_tree_test_metrics

### Distillation: training with soft labels

In [None]:
# Rebind `sess` via brand_new_tfsession (presumably discarding the hard-label tree's state)
# before building the student tree.
sess = brand_new_tfsession(sess)

# Student tree for distillation, configured identically to the hard-label tree.
tree = SoftBinaryDecisionTree(
    tree_depth,
    n_features,
    n_classes,
    penalty_strength=penalty_strength,
    penalty_decay=penalty_decay,
    inv_temp=inv_temp,
    ema_win_size=ema_win_size,
    learning_rate=learning_rate,
)
tree.build_model()

## Zad. 4
Wytrenuj drzewo z wykorzystaniem destylacji oraz `soft labels`.

In [None]:
epochs = 50

# Stop training once the monitored validation metric has not improved for 20 epochs.
es = EarlyStopping(monitor='val_acc', patience=20, verbose=1)

# Same training call as for the hard-label tree, but fitted on the teacher's soft labels
# (y_train_soft, Zad. 2) with distill=True. Validation keeps the hard labels, so the
# reported accuracy is directly comparable with the hard-label tree.
# Without this call the next cell would evaluate an untrained tree.
tree.maybe_train(
    sess=sess, data_train=(x_train_flat, y_train_soft), data_valid=(x_valid_flat, y_valid),
    batch_size=batch_size, epochs=epochs, callbacks=[es], distill=True)

In [None]:
# Jupyter only displays the last expression of a cell, so keep both results and display
# them together; the validation metrics would otherwise be dropped.
distilled_tree_valid_metrics = tree.evaluate(x=x_valid_flat, y=y_valid, batch_size=batch_size)
distilled_tree_test_metrics = tree.evaluate(x=x_test_flat, y=y_test, batch_size=batch_size)
distilled_tree_valid_metrics, distilled_tree_test_metrics

<a id='Wizualizacja'></a>
Zwizualizujmy sobie teraz, jak drzewo rozpoznaje różne cyfry:

In [None]:
# Fixed seed -> the same random test sample of each digit is drawn on every run.
np.random.seed(0)

# One-hot rows -> integer class labels; computed once instead of once per digit.
test_labels = np.argmax(y_test, axis=1)

for digit in range(10):
    candidates = np.where(test_labels == digit)[0]
    sample_index = np.random.choice(candidates)
    input_img = x_test[sample_index]  # the un-flattened image is passed to draw_tree
    draw_tree(sess, tree, img_rows, img_cols, img_chans, input_img=input_img)

## Pytanie kontrolne
##### Które z drzew daje lepsze rezultaty i o ile?

## Zad. 5
Wytrenuj drzewa o głębokości `tree_depth=[1, 2, 3, 5]` (dla `soft labels` oraz `hard labels`) i porównaj otrzymane rezultaty
z wcześniejszymi wynikami dla głębokości `tree_depth=4`. 
Można użyć funkcji [`draw_tree`](#Wizualizacja) (z zadania 4) dla każdego drzewa w celu zwizualizowania jego węzłów.

#### Zadanie dodatkowe:
Znaleźć optymalne wartości parametrów dla drzew decyzyjnych w przypadku zarówno `soft labels` jak i `hard labels` (modyfikując wartości parametrów opisanych w sekcji [Hyperparameters](#hyperparameters)). Przetestować drzewa dla większych wartości `tree_depth`.