# TP DT

Save the notebook as either PDF or HTML and make sure all the results are saved correctly (I won't run them and the original format does not save the results automatically), **and put your name in the filename**.

<div class="alert alert-success"> 
<b>Questions are in green boxes.</b>
The maximum time you should spend on each question is given as an indication only. If you take more time than that, you should come see me.
</div>
<div class="alert alert-info" role="alert"><b>Analyses are in blue boxes.</b> You should comment on your results in these boxes (Is it good? Is it expected? Why do we get such a result? Why is it different from the previous one? etc.)
</div>

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import gzip
import pickle
import matplotlib.pyplot as plt
import time

For this lab, we will use the bluebell dataset. It consists of $64\times 64$ color images, which we will have to flatten into roughly $12k$-dimensional vectors ($64\times 64\times 3 = 12288$ values). The code for the dataset comes with several train/val/test splits, but in this notebook, we will use the first split and do our own cross-validation routines.

In [None]:
# Load the dataset
from bluebell import Bluebell
X_train_ds = Bluebell('bluebell_64', 'train', split=0)
X_val_ds = Bluebell('bluebell_64', 'val', split=0)
X_train = np.array([img.flatten()/127.5 - 1. for img, lab in X_train_ds])
y_train = np.array([lab for img, lab in X_train_ds])
X_val = np.array([img.flatten()/127.5 - 1. for img, lab in X_val_ds])
y_val = np.array([lab for img, lab in X_val_ds])
plt.imshow(X_train[0].reshape(64, 64, 3)/2+0.5)
print(y_train[0])

Next, we want to reduce the number of dimensions that the decision trees will search through. Since our images contain $64\times 64\times 3$ values, this leads to a very high-dimensional space that has to be searched at each step. However, dimensions where all images have the same value, or very close values, will never be selected in the tree because they do not provide a good gain.

We will thus select only the 2048 dimensions with the highest variance to perform our analysis.

In [None]:
dim = jnp.argsort(X_train.std(axis=0), descending=True)[0:2048]
X_train = jnp.array(X_train[:, dim])
X_val = jnp.array(X_val[:, dim])
print(X_train.shape, X_val.shape)

## Implementing a randomized Decision Tree

<div class="alert alert-success"> 
    <b>Q1.</b> Implement the code of a function that finds an optimal threshold along a given dimension, using the $0-1$ loss with specified example weights, and test it on the 150th dimension. To speed things up, we will only consider 8 thresholds between the minimum and maximum value (use 'linspace'). Compare it to assigning a unique label to all samples. You should get a significant decrease of the loss from ~0.92 (random 1/12 chance) to ~0.84. <i>(Indicative time: 30 minutes for a slow version, but take the extra 30 minutes to write a parallel version that tests all thresholds at once and runs in under 1s; it is worth it for the next questions.)</i>
</div>
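
The "all thresholds at once" idea can be illustrated with broadcasting. The following is only a minimal sketch on a toy feature column, not a solution to the question: the names `x`, `w`, `ths`, `goes_left` and `left_weight` are illustrative assumptions, and it only shows how one comparison produces a mask for every candidate threshold.

```python
import jax.numpy as jnp

# Toy feature column (one value per sample) and uniform example weights.
x = jnp.array([0.0, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 7.0])
w = jnp.ones_like(x)

# 8 candidate thresholds between the minimum and maximum of the column.
ths = jnp.linspace(x.min(), x.max(), 8)

# Broadcasting (8, 1) against (1, n) gives one boolean row per threshold.
goes_left = x[None, :] < ths[:, None]               # shape (8, n)

# Total weight sent to the left branch, for all thresholds in one reduction.
left_weight = (goes_left * w[None, :]).sum(axis=1)  # shape (8,)
print(goes_left.shape, left_weight.shape)
```

The same mask can then be reused to compute weighted label statistics on each side of every threshold without a Python loop.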

In [None]:
'''
takes arguments
y_pred: prediction
y_true: true labels
weights: weights for each example
'''
@jax.jit
def zeroOneLoss(y_pred, y_true, weights):
    return (weights * (y_pred != y_true)).sum(axis=0)/(1e-12+weights.sum(axis=0))
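
As a quick sanity check of the weighted $0-1$ loss, here is a small usage sketch on toy labels. The function is repeated under the hypothetical name `zero_one_loss` so that the snippet runs on its own; the toy arrays are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_one_loss(y_pred, y_true, weights):
    # Same formula as zeroOneLoss, repeated so the snippet is self-contained.
    return (weights * (y_pred != y_true)).sum(axis=0) / (1e-12 + weights.sum(axis=0))

y_true = jnp.array([0, 1, 2, 2])
y_pred = jnp.array([0, 1, 1, 2])          # one mistake, on the third sample

uniform = jnp.ones(4)
masked = jnp.array([1.0, 1.0, 0.0, 1.0])  # third sample removed via a zero weight

loss_uniform = zero_one_loss(y_pred, y_true, uniform)  # 1 error out of 4 samples
loss_masked = zero_one_loss(y_pred, y_true, masked)    # the only error carries no weight
print(loss_uniform, loss_masked)
```

Setting a weight to zero removes the corresponding example from both the numerator and the denominator, which is what makes weights usable as a sample mask.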

In [None]:
'''
takes arguments
X: training samples
y: training labels
dim: dimension to use
w: weight associated to each example (can be True/False or 1/0 to remove some examples)
returns the gain and the threshold
'''
@jax.jit
def findBestTh(X, y, dim, w):
    
    return gain, threshold

In [None]:
%%time 
G, th = findBestTh(X_train, y_train, 150, jnp.ones(len(y_train)))

<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

We can vectorize over the dimension by using vmap. The batched function can now operate on an array of dimensions.

In [None]:
batched_findBestTh = jax.vmap(findBestTh, in_axes=(None, None, 0, None), out_axes=0)
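
To see what `in_axes` does here, consider a toy function with the same calling pattern (data first, dimension index second). This is a sketch under assumptions: `column_mean` is a hypothetical stand-in for `findBestTh`, used only because its result is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def column_mean(X, d):
    # Hypothetical stand-in: any function of (data, one dimension index).
    return X[:, d].mean()

# None: X is shared by all calls; 0: map over the first axis of the index array.
batched_column_mean = jax.vmap(column_mean, in_axes=(None, 0), out_axes=0)

X = jnp.arange(12.0).reshape(3, 4)
dims = jnp.array([0, 2, 3])
means = batched_column_mean(X, dims)  # one result per entry of dims
print(means)
```

Arguments marked `None` are broadcast unchanged, so the data matrix is not copied once per dimension.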

<div class="alert alert-success"> 
    <b>Q2.</b> Wrap the batched function in a function that tests all dimensions to find the best combination of component and threshold. Process the dimensions in blocks of 256 at a time, as we found it a good setup with respect to speed (you can change this value later to optimize for speed). Test it on the entire train set and make sure it obtains the lowest error. <i>(Indicative time: It could take you 15 minutes to an hour to code and should run in less than 5 seconds.)</i>
</div>

In [None]:
@jax.jit
def findBestDTh(X, y, w):
    return best_gain, best_dim, best_th

In [None]:
%%time
g, d, t = findBestDTh(X_train, y_train, jnp.ones(len(X_train)))
print("gain: {} dim: {} th: {}".format(g, d, t))

<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

<div class="alert alert-success"> 
    <b>Q3.</b> Implement the code of the Decision Tree class by using the previous functions. To achieve reasonable speed, loop only over dimensions that have variations (there is no threshold if all samples have the same value) in batches of 256 or more, inspired by the previous function. Do not split and slice the data but zero the associated weights instead, it is faster (all functions have the same size of arrays and are thus compiled and optimized only once). Debug it on only 256 dimensions and 256 samples, because using all dimensions/samples takes about 2 minutes. Test it on the full set with a maximum depth of 8 and a leaf size less than 10 to analyze and comment. <i>(Indicative time: It could take you 30 minutes to an hour to code and debug since it involves recursion.)</i>
</div>
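
The advice to zero weights rather than slice can be checked on a toy case. The sketch below is an illustration under assumptions (the names `masked_error` and `trace_count` are hypothetical): a Python counter inside a jitted function only increments when JAX traces it, so it shows that same-shaped masks reuse one compiled version while a sliced array triggers a new trace.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}

@jax.jit
def masked_error(y_pred, y_true, w):
    trace_count["n"] += 1  # Python side effect: runs only while JAX traces the function
    return (w * (y_pred != y_true)).sum() / (1e-12 + w.sum())

y_true = jnp.array([0, 0, 1, 1, 2, 2])
y_pred = jnp.array([0, 1, 1, 1, 2, 0])
x = jnp.array([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0])  # toy feature used for the split

# Same array shapes every time: only the weights change between calls.
go_left = (x < 0.0).astype(jnp.float32)
err_all = masked_error(y_pred, y_true, jnp.ones(6))
err_left = masked_error(y_pred, y_true, go_left)
err_right = masked_error(y_pred, y_true, 1.0 - go_left)
traces_with_masks = trace_count["n"]

# Slicing changes the array length, so the jitted function is traced again.
keep = x < 0.0
err_left_sliced = masked_error(y_pred[keep], y_true[keep], jnp.ones(3))
traces_after_slicing = trace_count["n"]
print(traces_with_masks, traces_after_slicing)
```

In a recursive tree, every node would otherwise see a different subset size, so the masked version avoids one compilation per node.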

In [None]:
class RandomizedDT():
    '''
    percent_dimension: percent of dimensions to use (random selection)
    '''
    def __init__(self, percent_dimension=1.0, max_depth=8, max_size=20, verbose=False, space=0):
        self.percent_dimension = percent_dimension
        self.max_depth = max_depth
        self.max_size = max_size
        self.verbose = verbose
        self.space = space
    
    '''
    train this decision tree on a random subset of dimensions (columns) of X with a maximum depth
    '''
    def fit(self, X, y, w=None):
        
        return       
    
    '''
    predict the set of samples
    '''
    def predict(self, X):
        if self.label is not None:
            return self.label * jnp.ones(len(X))
        return jnp.concatenate([self.left.predict([x]) if x[self.dim] < self.th else self.right.predict([x]) for x in X])

We first try on a reduced problem (256 training samples, 100 validation samples, and the first 256 dimensions) to check that our code works.

In [None]:
%%time
x_train_01 = X_train
y_train_01 = y_train
x_val_01 = X_val
y_val_01 = y_val


x_train_01 = x_train_01[0:256, 0:256]
y_train_01 = y_train_01[0:256]
x_val_01 = x_val_01[0:100, 0:256]
y_val_01 = y_val_01[0:100]


<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

<div class="alert alert-success"> 
    <b>Q4.</b> Use cross-validation on the full dataset (all 12 classes) to select a reasonable depth between 2 and 8, using random splits of half the training set to save on training time. <i>(Indicative time: maximum 10 minutes to code, about 20 minutes to run)</i>
</div>

In [None]:
%%time


<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

## Random Forest

Next, we want to mitigate the tendency of decision trees to overfit when the depth is too high and to underfit when the depth is too small by implementing random forests.

<div class="alert alert-success"> 
    <b>Q5.</b> Code a Random Forest of decision trees, each trained on a subset of the training set. Perform a coarse cross-validation to set a reasonable number of trees (25, 50, 75), percent of training data used (0.5, 0.75), and percent of dimensions used (0.5, 0.75), with a fixed depth of 3. <i>(Indicative time: less than 20 minutes to code, takes more than 20 minutes to run)</i>
</div>
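
One possible way to draw the random subsets for a tree is sketched below. It is an assumption about how the sampling could be done, not the expected implementation: it uses `jax.random.choice` without replacement, and it encodes the sample subset as a 0/1 weight vector so that array shapes stay fixed, in line with the earlier advice. The sizes and variable names are toy values chosen for illustration.

```python
import jax
import jax.numpy as jnp

n_samples, n_dims = 10, 8
percent_dataset, percent_dimension = 0.5, 0.75

key = jax.random.PRNGKey(0)
key_rows, key_cols = jax.random.split(key)  # independent keys for rows and columns

# Draw distinct sample indices and distinct dimension indices.
rows = jax.random.choice(key_rows, n_samples,
                         shape=(int(percent_dataset * n_samples),), replace=False)
cols = jax.random.choice(key_cols, n_dims,
                         shape=(int(percent_dimension * n_dims),), replace=False)

# Encode the sample subset as weights: 1 for selected samples, 0 otherwise.
w = jnp.zeros(n_samples).at[rows].set(1.0)
print(rows, cols, w)
```

Using a different key for each tree gives each tree its own subset while keeping the whole experiment reproducible.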

In [None]:
class RandomForest():
    def __init__(self, nb_trees, percent_dataset=1., percent_dimension=1., max_depth=8):
        self.nb_trees = nb_trees
        self.percent_dataset = percent_dataset
        self.percent_dimension = percent_dimension
        self.max_depth = max_depth
        self.trees = []
        
    def fit(self, X, y):
        
        return
    
    def predict(self, X):
        y = []
        for dt in self.trees:
            y.append(dt.predict(X))
        y = jax.nn.one_hot(jnp.array(y), num_classes=12)
        return y.sum(axis=0).argmax(axis=1)

<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

## Boosting

To have a more efficient training procedure, we will remove the independence between the trees by using boosting.

<div class="alert alert-warning"> 
    <b>Q6.</b> Code the BoostingClassifier that obtains a combination of Randomized Trees using AdaBoost. Each tree is trained using the weighted $0-1$ loss. To allow the tree combination, convert the output of each tree to a one-hot encoded vector. The output of the boosted trees is then the weighted sum of these one-hot vectors and the predicted class is the argmax. Test with the same parameters as the best Random Forest. <i>(Indicative time: about 30 minutes to code, runs about as fast as a random forest)</i>
</div>
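
The weight update and the weighted vote can be sketched on toy data. The question does not say which multi-class variant of AdaBoost is intended; the snippet below assumes the SAMME formulation, where the tree weight is $\alpha = \log\frac{1-\epsilon}{\epsilon} + \log(K-1)$ for $K$ classes, so treat it as one possible reading rather than the reference method. All names and arrays are illustrative.

```python
import jax
import jax.numpy as jnp

num_classes = 3
y_true = jnp.array([0, 1, 2, 0, 1, 2])
y_pred = jnp.array([0, 1, 2, 0, 2, 1])   # toy tree output: the last two samples are wrong
w = jnp.ones(6) / 6.0                    # uniform example weights for the first round

# One boosting round (SAMME variant, an assumption).
miss = (y_pred != y_true).astype(jnp.float32)
err = (w * miss).sum() / w.sum()                                 # weighted 0-1 loss
alpha = jnp.log((1.0 - err) / err) + jnp.log(num_classes - 1.0)  # weight of this tree
w_new = w * jnp.exp(alpha * miss)                                # boost misclassified samples
w_new = w_new / w_new.sum()                                      # renormalize to sum to 1

def weighted_vote(tree_preds, alphas, num_classes):
    # tree_preds: (trees, n) integer labels; alphas: (trees,) tree weights.
    votes = jax.nn.one_hot(tree_preds, num_classes)              # (trees, n, classes)
    return (alphas[:, None, None] * votes).sum(axis=0).argmax(axis=1)

tree_preds = jnp.array([[0, 1, 2], [0, 2, 2]])
labels = weighted_vote(tree_preds, jnp.array([1.0, 2.0]), num_classes)
print(alpha, w_new, labels)
```

When the two toy trees disagree, the vote follows the tree with the larger weight, which is the behavior the one-hot combination is meant to produce.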

In [None]:
class BoostedTrees():
    def __init__(self, nb_trees, percent_dataset=1., percent_dimension=1., max_depth=8):
        self.nb_trees = nb_trees
        self.percent_dataset = percent_dataset
        self.percent_dimension = percent_dimension
        self.max_depth = max_depth
        
    def fit(self, X, y):
        return
    
    def predict(self, X):
        return

<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>

## Visualization

In order to visualize the decision, we can produce an image that contains only the relevant information with respect to the decisions taken by a tree.

<div class="alert alert-warning"> 
    <b>Q7.</b> For a trained tree, select a leaf and build an image that has a value of 1 for each pixel in the decision path that should be above the threshold, 0 for each pixel in the decision path that should be below the threshold, and 0.5 everywhere else. For each class, show the average of all such images over the leaves corresponding to that class. <i>(Indicative time: about one hour to code)</i>
</div>
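
The image for a single decision path could be built as in the sketch below. It rests on assumptions: the path is represented as a hypothetical list of `(feature_index, goes_right)` pairs, and `selected_dims` plays the role of the index array that maps the retained features back to positions in the flattened image. A tiny $2\times 2\times 3$ image is used so the result can be checked by hand.

```python
import math
import jax.numpy as jnp

def path_image(path, selected_dims, image_shape):
    # path: list of (feature_index, goes_right) pairs from the root to a leaf.
    # selected_dims: maps a feature index back to its position in the flattened image.
    img = jnp.full((math.prod(image_shape),), 0.5)  # 0.5 for pixels not on the path
    for feat, goes_right in path:
        pixel = int(selected_dims[feat])
        img = img.at[pixel].set(1.0 if goes_right else 0.0)
    return img.reshape(image_shape)

selected_dims = jnp.array([7, 2, 11, 4])  # toy selection of 4 out of 12 values
path = [(0, True), (2, False)]            # feature 0 above threshold, feature 2 below
img = path_image(path, selected_dims, (2, 2, 3))
print(img)
```

Averaging such images over all leaves that predict the same class then only requires stacking them and taking the mean along the first axis.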

<div class="alert alert-info" role="alert"><b>Analyze your results in this box.</b>  Answer
</div>