In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Input, Flatten
from tensorflow.keras.models import Model
import numpy as np
import pandas as pd

## Defining the classes that support Tree-CNN

In [2]:
class CNN (object):
    """Small Keras CNN classifier whose softmax output layer can grow or shrink.

    Three 3x3 ReLU convolutions (16, 32, 16 filters), a 512-unit ReLU dense
    layer and a ``num_classes``-way softmax, compiled with SGD and
    categorical cross-entropy.

    Args:
        num_classes: number of units in the final softmax layer.
        input_shape: shape of one input sample, e.g. ``(28, 28, 1)``.
    """

    def __init__(self, num_classes, input_shape):
        input_layer, l = self.create_network_base(num_classes, input_shape)
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.model = Model(input_layer, l)
        self._compile()

    def create_network_base(self, num_classes, input_shape):
        """Build the (uncompiled) layer graph and return ``(input_layer, output_tensor)``."""
        input_layer = Input(shape=input_shape)
        l = Conv2D(16, 3, activation='relu')(input_layer)
        l = Conv2D(32, 3, activation='relu')(l)
        l = Conv2D(16, 3, activation='relu')(l)
        l = Flatten()(l)
        l = Dense(512, activation='relu')(l)
        l = Dense(num_classes, activation='softmax')(l)

        return input_layer, l

    def _compile(self):
        """Compile ``self.model`` with the settings shared by every (re)build."""
        self.model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

    def train(self, X, Y, batch_size=16, epochs=2):
        """Fit the model on inputs ``X`` and one-hot targets ``Y``."""
        self.model.fit(X, Y, batch_size=batch_size, epochs=epochs)

    @staticmethod
    def _drop_output_unit(old_w, old_bias, idx_to_remove):
        """Return output-layer ``(weights, bias)`` without unit ``idx_to_remove``.

        Raises:
            ValueError: if ``idx_to_remove`` is not in ``[0, number of units)``.
        """
        num_units = old_bias.shape[0]
        if not 0 <= idx_to_remove < num_units:
            raise ValueError('idx_to_remove must be in [0, %d), got %r' % (num_units, idx_to_remove))
        return np.delete(old_w, idx_to_remove, axis=1), np.delete(old_bias, idx_to_remove)

    @staticmethod
    def _append_output_unit(old_w, old_bias, init_w, init_bias):
        """Return copies of ``init_w``/``init_bias`` whose leading units hold the trained values.

        Units beyond the old width keep their fresh initialisation.
        """
        num_old = old_bias.shape[0]
        new_w = init_w.copy()
        new_bias = init_bias.copy()
        new_w[:, :num_old] = old_w
        new_bias[:num_old] = old_bias
        return new_w, new_bias

    def _build_with_copied_hidden_layers(self, num_classes):
        """Build a new model with ``num_classes`` outputs and the current hidden-layer weights."""
        input_layer, l = self.create_network_base(num_classes, self.input_shape)
        new_model = Model(input_layer, l)
        # Same architecture except the output width, so hidden layers line up by index;
        # weightless layers (input, flatten) return [] and are skipped.
        for old_layer, new_layer in zip(self.model.layers[:-1], new_model.layers[:-1]):
            weights = old_layer.get_weights()
            if weights:
                new_layer.set_weights(weights)
        return new_model

    def _install(self, new_model, num_classes):
        """Replace the current model with ``new_model`` and recompile it."""
        self.model = new_model
        self._compile()
        self.num_classes = num_classes

    def remove_class(self, idx_to_remove):
        """Drop output unit ``idx_to_remove``, keeping every other trained weight.

        Raises:
            ValueError: if ``idx_to_remove`` is out of range (checked before rebuilding).
        """
        old_w, old_bias = self.model.layers[-1].get_weights()
        new_w, new_bias = self._drop_output_unit(old_w, old_bias, idx_to_remove)
        new_model = self._build_with_copied_hidden_layers(self.num_classes - 1)
        new_model.layers[-1].set_weights((new_w, new_bias))
        self._install(new_model, self.num_classes - 1)

    def add_class(self):
        """Append one freshly initialised output unit, keeping all trained weights."""
        new_model = self._build_with_copied_hidden_layers(self.num_classes + 1)
        old_w, old_bias = self.model.layers[-1].get_weights()
        init_w, init_bias = new_model.layers[-1].get_weights()
        new_model.layers[-1].set_weights(self._append_output_unit(old_w, old_bias, init_w, init_bias))
        self._install(new_model, self.num_classes + 1)

    def pred(self, img):
        """Return the model's softmax outputs for ``img``."""
        return self.model.predict(img)

In [3]:
class NodeList (object):
    """Placement candidates for one new class during Tree-CNN growth.

    As used by ``TreeCNN.growTreeCNN``:
        label: the new class's label.
        values: best likelihood values for the class, highest first.
        nodes: child positions matching ``values``.
        class_position: column of the class in the likelihood matrix.
    """
    def __init__(self, label, values, nodes, class_position):
        self.label = label
        self.values = values
        self.nodes = nodes
        self.class_position = class_position

    def as_dict(self):
        """Return the fields as a plain dict."""
        return {'label': self.label, 'values': self.values, 'nodes': self.nodes, 'class_position': self.class_position}

    def __str__(self):
        # __str__ must return a str; returning a dict made str()/print() raise TypeError.
        return self.__repr__()

    def __unicode__(self):
        # Python 2 hook only; kept for compatibility and aligned with __str__.
        return self.__str__()

    def __repr__(self):
        return 'label: ' + str(self.label) + ' values: ' + str(self.values) + ' nodes: ' + str(self.nodes)+' class_pos: '+ str(self.class_position)

In [4]:
class CnnNode (object):
    """One Tree-CNN node: a CNN whose output unit ``i`` stands for child ``i``.

    Child ``i`` is either a leaf (``childrens_leaf[i]`` is True and
    ``childrens[i]`` is its label) or a branch (``childrens_leaf[i]`` is False
    and ``childrens[i]`` is another CnnNode).

    Args:
        num_classes: number of initial children / output units.
        labels: label of each initial child (needs ``num_classes`` entries).
            The list is copied, so later growth never mutates the caller's list.
        max_leafes: stored on the node; not read by this class.
        input_shape: input shape of the node's CNN.
    """
    def __init__(self, num_classes, labels=None, max_leafes=10, input_shape=(28, 28, 1)):
        labels = [] if labels is None else list(labels)
        self.net = CNN(num_classes, input_shape)
        self.num_classes = num_classes
        self.childrens = list(labels)
        self.childrens_leaf = [True for _ in range(num_classes)]
        self.labels = labels
        self.max_leafes = max_leafes
        # labels_transform[i]: labels routed to output unit i (starts with the child's own label).
        self.labels_transform = {nc: [labels[nc]] for nc in range(num_classes)}

    def get_num_leafnodes(self):
        """Return how many direct children are leaves."""
        return sum(1 for is_leaf in self.childrens_leaf if is_leaf)

    def remove_leaf(self, label):
        """Remove the child labelled ``label`` together with its output unit.

        Raises:
            ValueError: if ``label`` is not one of ``self.labels``.
        """
        position_in_net = self.labels.index(label)  # ValueError when absent, before any mutation
        del self.childrens[position_in_net]
        del self.childrens_leaf[position_in_net]
        del self.labels[position_in_net]
        # labels_transform is keyed by output position: drop the removed unit and shift
        # later positions down so keys keep matching the network's outputs.
        self.labels_transform = {
            (pos if pos < position_in_net else pos - 1): mapped
            for pos, mapped in self.labels_transform.items()
            if pos != position_in_net
        }
        self.num_classes = self.num_classes - 1
        self.net.remove_class(position_in_net)

    def add_leaf(self, label):
        """Append a new leaf child ``label`` and a matching output unit."""
        position = len(self.labels)
        self.childrens.append(label)
        self.childrens_leaf.append(True)
        self.labels.append(label)
        # Keyed by output position, like __init__ and remove_leaf.
        self.labels_transform[position] = [label]
        # Keep num_classes in sync with the network (remove_leaf decrements it).
        self.num_classes = self.num_classes + 1
        self.net.add_class()

    def predict(self, imgs):
        """Return the raw softmax outputs of the node's network."""
        vector_output = self.net.pred(imgs)
        return vector_output

    def inference(self, imgs):
        """Return, for each image, the label of the highest-scoring child."""
        vector_output = self.net.pred(imgs)
        return np.asarray(self.labels)[np.argmax(vector_output, axis=1)]

    def train(self, X, Y):
        #TODO: Use the labels_transform to transform the Y
        self.net.train(X, Y)
        #TODO: Pass the train to subnodes

#### Testing CnnNode with 2 levels

In [5]:
# Two-level smoke test: a 3-way root node with a 2-way node attached under it.
# NOTE(review): appending to `childrens` leaves `childrens_leaf` and `labels` one entry
# shorter than `childrens` — confirm this ad-hoc attachment is intended.
samples = np.random.random((5, 28, 28, 1))
cnnNode, cnnNodeC = CnnNode(3, labels=[2, 3, 8]), CnnNode(2, labels=[4, 1])
cnnNode.childrens.append(cnnNodeC)
cnnNode.predict(samples)

array([[0.3802768 , 0.30769065, 0.31203258],
       [0.38332018, 0.3050918 , 0.311588  ],
       [0.3771458 , 0.32205832, 0.30079588],
       [0.38218996, 0.2913022 , 0.32650778],
       [0.38976935, 0.29702896, 0.31320173]], dtype=float32)

In [6]:
class TreeCNN (object):
    """Tree-CNN: a root CnnNode that grows as new classes are added.

    Args:
        num_class_initial: number of classes the root node starts with.
        initial_labels: labels of those classes.
        alpha: if a new class's best and second-best child likelihoods differ
            by more than this, the class joins the best child.
        beta: otherwise, if the second- and third-best differ by more than
            this, the two best children are considered for a merge.
        max_leafnodes: a branch child with at least this many leaves is full
            and no longer receives new classes.
    """
    def __init__(self, num_class_initial, initial_labels, alpha=0.1, beta=0.1, max_leafnodes=1000):
        self.root = CnnNode(num_class_initial, labels=initial_labels)
        self.alpha = alpha
        self.beta = beta
        self.max_leafnodes = max_leafnodes

    def addTasks(self, imgs_of_classes=None, labels=None):
        """Insert new classes into the tree, starting from the root.

        Args:
            imgs_of_classes: one image batch per new class.
            labels: the new classes' labels, parallel to ``imgs_of_classes``.
        """
        self.growTreeCNN(self.root, imgs_of_classes, labels)

    def train(self, X, Y):
        """Delegate training to the root node."""
        self.root.train(X, Y)

    def inference(self, X):
        """Predict labels with the root node."""
        return self.root.inference(X)

    def growTreeCNN(self, operation_node, imgs_of_classes=None, labels=None):
        """Place new classes under ``operation_node``, recursing into branch children.

        The node's averaged softmax response to each new class is turned into a
        likelihood over its children. Based on ``alpha``/``beta`` the class then
        joins the best child, causes the two best children to be merged, or
        becomes a new leaf of ``operation_node``. Classes assigned to branch
        children are passed down recursively with their images.

        Args:
            operation_node: CnnNode whose children receive the new classes.
            imgs_of_classes: one image batch per new class.
            labels: the new classes' labels, parallel to ``imgs_of_classes``.
                Not modified.
        """
        imgs_of_classes = [] if imgs_of_classes is None else imgs_of_classes
        labels = [] if labels is None else labels

        def get_Oavg_matrix(node, imgs_of_classes_, labels_):
            # Column j: the node's output averaged over the images of new class j.
            columns = [np.average(node.predict(imgs), axis=0)
                       for imgs, _ in zip(imgs_of_classes_, labels_)]
            if not columns:
                return np.zeros(shape=(node.num_classes, 0))
            return np.stack(columns, axis=1).astype(float)

        def get_loglikelihood_matrix(Oavg):
            # Softmax down each column, i.e. over the node's outputs.
            exp_Oavg = np.exp(Oavg)
            return exp_Oavg / np.sum(exp_Oavg, axis=0)

        def generate_listS(llh, labels_in, row_children):
            # One NodeList per remaining class: its three best llh rows, reported as child
            # positions via row_children. With fewer than three rows the -100 sentinel
            # makes argmax repeat a row, which then carries the value -100.
            listS = []
            for i in range(llh.shape[1]):
                col = llh[:, i].copy()
                values = []
                nodes = []
                for _ in range(3):
                    best_row = np.argmax(col)
                    values.append(col[best_row])
                    nodes.append(row_children[best_row])
                    col[best_row] = -100
                listS.append(NodeList(labels_in[i], values, nodes, i))
            # Sort List S by value of S[i].values[0]
            # NOTE(review): ascending order handles the least confident class first — confirm intended.
            listS.sort(key=lambda node_list: node_list.values[0])
            return listS

        def is_full(position):
            # Only branch children can be full.
            return (not operation_node.childrens_leaf[position]
                    and operation_node.childrens[position].get_num_leafnodes() >= self.max_leafnodes)

        def without_full_children(llh_, row_children_):
            # Drop llh rows of full branch children so they are never chosen.
            kept = [r for r, child in enumerate(row_children_) if not is_full(child)]
            return llh_[kept, :], [row_children_[r] for r in kept]

        def make_branch(position, new_label):
            # Turn leaf child `position` into a 2-class branch holding its label and `new_label`.
            operation_node.childrens_leaf[position] = False
            old_label = operation_node.labels[position]
            operation_node.childrens[position] = CnnNode(2, labels=[old_label, new_label])

        def send_to_branch(position, label):
            # Defer `label` to the recursive call on branch child `position`.
            branches_dest.setdefault(position, []).append(label)

        llh = get_loglikelihood_matrix(get_Oavg_matrix(operation_node, imgs_of_classes, labels))
        # row_children[r] is the child position described by llh row r. Rows are dropped for
        # merged or full children, so row index and child position cannot be assumed equal.
        row_children = list(range(llh.shape[0]))
        llh, row_children = without_full_children(llh, row_children)
        # Copy: `labels` must stay intact to route images to sub-levels below.
        new_labels = list(labels)
        branches_dest = {}

        listS = generate_listS(llh, new_labels, row_children)
        while len(listS) > 0:
            nodeList = listS[0]
            first, second = nodeList.nodes[0], nodeList.nodes[1]
            if nodeList.values[0] - nodeList.values[1] > self.alpha:
                # Clear winner: the new class joins the best child.
                if operation_node.childrens_leaf[first]:
                    make_branch(first, nodeList.label)
                else:
                    send_to_branch(first, nodeList.label)
                operation_node.labels_transform[first].append(nodeList.label)
            elif nodeList.values[1] - nodeList.values[2] > self.beta:
                left_is_leafnode = operation_node.childrens_leaf[first]
                right_is_leafnode = operation_node.childrens_leaf[second]
                has_space_in_left = left_is_leafnode or (operation_node.childrens[first].get_num_leafnodes() < (self.max_leafnodes - 1))

                if right_is_leafnode and has_space_in_left:
                    # Merge leaf `second` into `first`; the new class goes there as well.
                    right_label = operation_node.labels[second]
                    if left_is_leafnode:
                        make_branch(first, right_label)
                    else:
                        operation_node.childrens[first].add_leaf(right_label)
                    send_to_branch(first, nodeList.label)
                    operation_node.labels_transform[first].append(right_label)
                    operation_node.remove_leaf(right_label)
                    # Child `second` is gone and later children shifted down by one: drop its
                    # row and renumber the positions held in row_children and branches_dest.
                    kept_rows = [r for r, child in enumerate(row_children) if child != second]
                    llh = llh[kept_rows, :]
                    row_children = [row_children[r] - 1 if row_children[r] > second else row_children[r]
                                    for r in kept_rows]
                    branches_dest = {(k - 1 if k > second else k): v for k, v in branches_dest.items()}
                else:
                    if left_is_leafnode:
                        make_branch(first, nodeList.label)
                    elif right_is_leafnode:
                        make_branch(second, nodeList.label)
                    elif operation_node.childrens[first].get_num_leafnodes() < operation_node.childrens[second].get_num_leafnodes():
                        send_to_branch(first, nodeList.label)
                    else:
                        send_to_branch(second, nodeList.label)
            else:
                # No child stands out: the class becomes a new leaf of this node.
                operation_node.add_leaf(nodeList.label)

            # The class is placed: drop its column and label, then any newly full children.
            llh = np.delete(llh, nodeList.class_position, axis=1)
            del new_labels[nodeList.class_position]
            llh, row_children = without_full_children(llh, row_children)
            listS = generate_listS(llh, new_labels, row_children)

        # Send the deferred classes, with their images, down to the branch children.
        for position, dest_labels in branches_dest.items():
            pairs = [(imgs, label) for imgs, label in zip(imgs_of_classes, labels) if label in dest_labels]
            self.growTreeCNN(operation_node.childrens[position],
                             [imgs for imgs, _ in pairs],
                             [label for _, label in pairs])
        

In [7]:
tree = TreeCNN(num_class_initial=3, initial_labels=[0, 1, 2])  # root starts with digits 0, 1, 2

## Test with MNIST

In [8]:
from tensorflow.keras.datasets import mnist

# load_data() returns ((x_train, y_train), (x_test, y_test)).
mnist_splits = mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist_splits

In [9]:
# Old task: digits 0-2. A trailing channel axis matches the CNN input shape (28, 28, 1).
# dtype=float keeps the one-hot targets numeric regardless of the pandas default.
old_mask = y_train < 3
old_classes_X = x_train[old_mask].astype(float)[..., np.newaxis]
old_classes_Y = pd.get_dummies(pd.Series(y_train[old_mask]), dtype=float).values

# New tasks: one image batch (and one-hot frame) per digit.
# NOTE(review): digit 3 is skipped (4, 5, 6 are used) — confirm this is intentional.
new_classes_X = [x_train[y_train == i].astype(float)[..., np.newaxis] for i in range(4, 7)]
new_classes_Y = [pd.get_dummies(pd.Series(y_train[y_train == i]), dtype=float) for i in range(4, 7)]

In [10]:
# Predictions of the untrained root on the old classes.
Y_true = y_train[y_train < 3]
Y_hat = tree.inference(old_classes_X)
Y_hat

array([0, 0, 0, ..., 0, 1, 0])

In [11]:
np.mean(Y_true == Y_hat)  # accuracy: fraction of matching labels

0.3761477742576384

In [12]:
tree.train(X=old_classes_X, Y=old_classes_Y)

Train on 18623 samples
Epoch 1/2
Epoch 2/2


In [13]:
# Predictions on the old classes after training.
Y_true = y_train[y_train < 3]
Y_hat = tree.inference(old_classes_X)
Y_hat

array([0, 1, 2, ..., 0, 2, 1])

In [14]:
np.mean(Y_true == Y_hat)  # accuracy: fraction of matching labels

0.9959190248617301

### Growing the network

In [15]:
tree.addTasks(imgs_of_classes=new_classes_X, labels=list(range(4, 7)))

In [21]:
# Output shape of the root network's final (softmax) layer.
root_output_layer = tree.root.net.model.layers[-1]
root_output_layer.output_shape

(None, 4)

In [24]:
# Leaf flags of the root's first child.
first_child = tree.root.childrens[0]
first_child.childrens_leaf

[True, True]