In [1]:
import os
import sys
import numpy as np
from tqdm import tqdm
from rsq import SSRF
from rsq.samplers import _Sampler, ModelSampler
from functools import partial
from joblib import Parallel, delayed
from time import time

from sklearn.preprocessing import normalize

from sklearn.tree import DecisionTreeClassifier as Tree
from graspologic.cluster import GaussianCluster as GMM
from sklearn.ensemble import RandomForestClassifier
from sklearn import random_projection

from scipy.stats import entropy
  
def semi_supervised_split(X, y, n_test, p_labeled=0.01):
    n_all = len(y)
    
    n_available = n_all - n_test
    p_available = n_available / n_all
    
    unique_y, counts = np.unique(y, return_counts=True)
    priors = counts / np.sum(counts)
    n_per_class = (n_available * priors).astype(int)
    n_supervised_per_class = p_labeled * n_per_class
    n_supervised_per_class = np.array([max([n_spc, 1]) for n_spc in n_supervised_per_class]).astype(int)
    n_semisupervised_per_class = (n_per_class - n_supervised_per_class).astype(int)
        
    idx_by_label = [np.random.choice(np.where(y == c)[0], counts[i], replace=False) for i, c in enumerate(unique_y)]
    supervised = [idx_by_label[i][:n_spc] for i, n_spc in enumerate(n_supervised_per_class)]
    
    supervised = np.concatenate(supervised, axis=0)
    unsupervised = [idx_by_label[i][n_spc:n_semisupervised_per_class[i] + n_spc] 
                                     for (i, n_spc) in enumerate(n_supervised_per_class)]
    unsupervised = np.concatenate(unsupervised)
    
    
    test = np.array([i for i in range(n_all) if i not in supervised and i not in unsupervised])
        
    return supervised, unsupervised, test

def stratified_sample(y, p=0.67, replace=False):
    unique_y, counts = np.unique(y, return_counts=True)
    n_per_class = np.array([int(np.math.floor(p*c)) for c in counts])
    n_per_class = np.array([max([npc, 1]) for npc in n_per_class])
    
    inds = np.array([np.random.choice(np.where(y == unique_y[i])[0], size=npc, replace=replace) for i, npc in enumerate(n_per_class)])
    
    return np.concatenate(inds)

# X = np.load('/home/hhelm/jataware/rsq2/output/feats/crow_resnet50/birdsnap/X.npy')
# y = np.load('/home/hhelm/jataware/rsq2/output/feats/crow_resnet50/birdsnap/y.npy')

# unique_y = np.unique(y)
# y_map = {str_: i for (i,str_) in enumerate(unique_y)}
# y = np.array([y_map[_] for _ in y])

In [2]:
# --
# Helpers

def subset(X, y, n):
    if isinstance(X, list):
        sel = np.random.choice(X[0].shape[0], n, replace=False)
        return [x[sel] for x in X], y[sel]
    
    sel = np.random.choice(X.shape[0], n, replace=False)
    
    return X[sel], y[sel]

def shuffle(X, y):
    """ shuffle X and y """
    
    if isinstance(X, list):
        sel = np.random.permutation(X[0].shape[0])
        return [x[sel] for x in X], y[sel]
    
    sel = np.random.permutation(X.shape[0])
    return X[sel], y[sel]

def adjust_prevalance(X, y, n_pos):
    """ adjust the prevalance of the positive class """
    
    pos_idx = np.where(y == 1)[0]
    neg_idx = np.where(y == 0)[0]
    sel = np.hstack([
        np.random.choice(pos_idx, n_pos, replace=False),
        neg_idx,
    ])
    
    if isinstance(X, list):
        return [x[sel] for x in X], y[sel]
    
    sel = np.random.permutation(sel)
    return X[sel], y[sel]


def process_data(X, y, n_pos_labels, n_all):
    X, y  = shuffle(X, y)

    X = [normalize(x, axis=1, norm='l2') for x in X]

    X, y = subset(X, y, n=n_all)

    target = np.random.choice(np.unique(y)) # !! Uniform sample class, not according to prevalance
    y      = (y == target).astype(np.int)

    X, y = adjust_prevalance(X, y, n_pos_labels)
    
    return X, y
    

def run_exp(X, y, model_idx, label_batchsize, n_label_rounds, n_queries, samplers):
    qinds       = np.random.choice(np.where(y == 1)[0], size=n_queries, replace=False)
    start_labs       = np.zeros(X[model_idx].shape[0]) - 1
    start_labs[qinds] = 1
    
    prop_pos = np.zeros((len(samplers), n_label_rounds))
    timer = np.zeros((len(samplers), n_label_rounds))
    
    for i, (sampler_name, sampler) in enumerate(samplers.items()):
        t       = time()
        if sampler_name == "MIPSampler":
            sampler = sampler(fts=X, labels=start_labs.copy())
        else:    
            sampler = sampler(fts=X[model_idx], labels=start_labs.copy())
        for j in range(n_label_rounds):

            priority  = sampler.get_priority()            # get priority of unlabeled instances
            next_idxs = priority[:label_batchsize]        # get next `query_bs` instances to label
            next_labs = y[next_idxs]                      # get labels for next instances
            sampler.set_label(next_idxs, next_labs)       # tell the sampler about the new labels
            
            prop_pos[i, j] = len(sampler.pos_idxs)
            timer[i,j] = float(time() - t)
        
    return prop_pos, timer    
            
    
def experiment(process_args, experiment_args):
    X, y = process_data(*process_args)
        
    return run_exp(X, y, *experiment_args)

In [35]:
class MALF(_Sampler):
    def __init__(self, fts, labels, 
                 n_trees_entropy=2, n_trees_labeledness=10,
                 max_depth_entropy=2, max_depth_labeledness=10,
                 tree_split_entropy=0.5, tree_split_labeledness=1,
                 induce_entropy=True, induce_class_entropy=GMM, induce_kwargs_entropy={'min_components':100, 'max_components':100},
                 induce_labeledness=True, induce_class_labeledness=GMM, induce_kwargs_labeledness={'min_components':100, 'max_components':100},
                 random_projection_entropy=True, projection_kwargs_entropy={'n_components': 128},
                 random_projection_labeledness=True, projection_kwargs_labeledness={'n_components': 128}
                ):
        self.fts      = fts
        self.labels   = labels
        
        self.n_trees_entropy=n_trees_entropy
        self.max_depth_entropy=max_depth_entropy
        self.tree_split_entropy = tree_split_entropy
        
        self.induce_entropy=induce_entropy
        self.induce_class_entropy=induce_class_entropy
        self.induce_kwargs_entropy=induce_kwargs_entropy
        
        self.random_projection_entropy=random_projection_entropy
        self.projection_kwargs_entropy=projection_kwargs_entropy
        
        self.entropy_forest = SSRF(self.n_trees_entropy, self.max_depth_entropy, self.tree_split_entropy, 
                                   self.induce_entropy, self.induce_class_entropy, self.induce_kwargs_entropy,
                                  self.random_projection_entropy, self.projection_kwargs_entropy)
        
        
        self.n_trees_labeledness=n_trees_labeledness
        self.max_depth_labeledness=max_depth_labeledness
        self.tree_split_labeledness=tree_split_labeledness
        
        self.induce_labeledness=induce_labeledness
        self.induce_class_labeledness=induce_class_labeledness
        self.induce_kwargs_labeledness=induce_kwargs_labeledness
        
        self.random_projection_labeledness=random_projection_labeledness
        self.projection_kwargs_labeledness=projection_kwargs_labeledness
        
        self.labeledness_forest = SSRF(self.n_trees_labeledness, self.max_depth_labeledness, self.tree_split_labeledness, 
                                   self.induce_labeledness, self.induce_class_labeledness, self.induce_kwargs_labeledness,
                                  self.random_projection_labeledness, self.projection_kwargs_labeledness)
        
        self.fitted=False
        
    def _fit_entropy(self, n_cores=-1):
        self.entropy_forest.fit(self.fts, self.labels, None, n_cores)
        
    def _fit_labeledness(self, n_cores=-1):
        temp_labels = np.zeros(len(self.labels))
        temp_labels[np.where(self.labels != -1)[0]] = 1


                
        self.labeledness_forest.fit(self.fts, temp_labels, None, n_cores)

        
    def fit(self, n_cores=-1):
        print("start timer")
        start_time = time()
        self._fit_entropy(n_cores=n_cores)
        print("fitted entropy", time() - start_time)
        int_time = time()
        self._fit_labeledness(n_cores=n_cores)
        print("fitted labeledness", time() - int_time)
        
    def get_priority(self, n_cores=-1):
        if not self.fitted:
            self.fit(n_cores)
            
        pos_idxs = self.pos_idxs
        neg_idxs = self.neg_idxs
        mis_idxs = self.mis_idxs
                    
        posteriors_entropy = self.entropy_forest.predict_proba(self.fts[mis_idxs])
        entropy_score = entropy(posteriors_entropy, axis=0)
        
        posteriors_labeledness = self.labeledness_forest.predict_proba(self.fts[mis_idxs])
        
        print(entropy_score.shape, posteriors_labeledness.shape)
        
        scores = np.multiply(entropy_score, posteriors_labeledness[:, 1])
        
        mis_ranks = mis_idxs[np.argsort(scores)]
        
        return mis_ranks
        
    
    def set_label(self, indices, labels):  
        pos_idxs = self.pos_idxs
        for idx, yy in zip(indices, labels):
            self.labels[idx] = yy
            
            assert yy in [0, 1]
                        
            if yy == 1:
                self.labels[idx] = 1
            else:
                self.labels[idx] = 0
                            
        self.update_posteriors(indices)
    
    def update_posteriors(self, indices):
        # can parallelize
        for tree in self.entropy_forest.forest:
            new_labeled_decision_paths = [tree.decision_paths[_] for _ in indices]
            nodes_with_new_labeled_data = np.unique(np.concatenate(new_labeled_decision_paths))
            
            updated_leaf_nodes = []

            for dp in tree.decision_paths:
                leaf_node = dp[-1]
                temp_intersection=set(nodes_with_new_labeled_data).intersection(set(dp))
                if len(temp_intersection) == 0:
                    continue
                
                if leaf_node in updated_leaf_nodes:
                    continue

                for i in range(1, len(dp)+1):
                    temp_node = dp[-i]

                    if temp_node in nodes_with_new_labeled_data:        
                        temp_counts = np.zeros(len(tree.classes))

                        tree_distance_to_labeled_data = np.zeros(len(tree.labeled_decision_paths) + len(indices))
                        for j, labeled_dp in enumerate(np.concatenate((tree.labeled_decision_paths, new_labeled_decision_paths))):
                            if temp_node in labeled_dp:
                                tree_distance_to_labeled_data[j] = i + len(labeled_dp) - 1 - np.where(labeled_dp == temp_node)[0][0] - 1
                            else:
                                tree_distance_to_labeled_data[j] = 100

                        min_tree_distance = np.min(tree_distance_to_labeled_data)
                        argmins = np.where(tree_distance_to_labeled_data == min_tree_distance)[0]

                        for index in argmins.astype(int):
                            temp_counts[int(self.labels[index])] += 1

                        tree.mapping[leaf_node] = temp_counts / np.sum(temp_counts)
                        break
            tree.labeled_decision_paths = np.concatenate((tree.labeled_decision_paths, new_labeled_decision_paths))
            tree.nodes_with_labeled_data = np.unique(np.concatenate((tree.nodes_with_labeled_data, nodes_with_new_labeled_data)))

In [36]:
# --
# Experiment parameters & initial data loading

np.random.seed(1)


data_files = ['../output/feats/crow_resnet50/birdsnap/',
#               '../output/feats/crow_wide_resnet101_2/birdsnap/',
#               '../output/feats/crow_resnext101_32x8d/birdsnap/',
#               '../output/feats/crow_vgg19/birdsnap/'
             ]
models = [s.split('/')[-3] for s in data_files]


X = [np.load(os.path.join(f, 'X.npy')).astype(np.float64) for f in data_files]
y = np.load(os.path.join(data_files[0], 'y.npy'))
n_pos_labels=3
n_all=5000

label_batchsize=1
n_label_rounds=100
n_queries=1
samplers = {
#     "MIPSampler"        : partial(MIPSampler, initial_max_seconds=60, update_max_seconds=30, warm_start='barycenter'),
    "MALForestSampler"  : MALF,
    "SVCSampler"        : partial(ModelSampler, model='svc'),
#     "LogisticSampler"   : partial(ModelSampler, model='logistic'),
    
#     "LASSampler"        : LASSampler,
    
#     "NaiveMeanSampler"  : partial(PoolingSampler, pool_fn='mean', score_fn='pos'),
#     "RatioMeanSampler"  : partial(PoolingSampler, pool_fn='mean'),
#     "RatioMaxSampler"   : partial(PoolingSampler, pool_fn='max'),
#     "GEMSampler1"       : partial(PoolingSampler, pool_fn='gem', gem_p=1),
#     "GEMSampler4"       : partial(PoolingSampler, pool_fn='gem', gem_p=4),
#     "GEMSampler16"      : partial(PoolingSampler, pool_fn='gem', gem_p=16)
}

experiment_args = (label_batchsize, n_label_rounds, n_queries, samplers)

In [37]:
np.random.seed(1)

n_cores=30
n_mc=1

prop_pos_1 = np.zeros((len(models), len(samplers), n_mc, n_label_rounds))
timer_1 = np.zeros((len(models), len(samplers), n_mc, n_label_rounds))

for i, XX in enumerate(tqdm(X)):
    process_args = (X, y, n_pos_labels, n_all)
    
    samplers_ = samplers.copy()
    if i > 0:
        samplers_.pop('MIPSampler')

    experiment_args = (i, label_batchsize, n_label_rounds, n_queries, samplers_)
    
    condensed_experiment = lambda x: experiment(process_args, experiment_args)
    
    props_and_timers = Parallel(n_jobs=1)(delayed(condensed_experiment)(_) for _ in np.zeros(1))
    temp_props = np.array([p for (p, t) in props_and_timers]) # (n_mc, len(samplers), n_label_rounds)
    temp_timers = np.array([t for (p, t) in props_and_timers])
    
    prop_pos_1[i] = temp_props.transpose((1,0,2))
    timer_1[i] = temp_props.transpose((1,0,2))

  0%|          | 0/1 [00:00<?, ?it/s]

start timer
fitting forest
building forest


  0%|          | 0/1 [00:33<?, ?it/s]


IndexError: index 1 is out of bounds for axis 0 with size 1

In [32]:
class SSRF:
    def __init__(self, n_trees=100, max_depth=10, tree_split=None,
                 induce=True, induce_class=GMM, induce_kwargs={'min_components':10, 'max_components':10},
                 random_projection=False, projection_kwargs={}
                ):
        
        self.n_trees=n_trees
        self.max_depth=max_depth
        self.tree_split = tree_split
        
        self.induce=induce
        self.induce_class=induce_class
        self.induce_kwargs=induce_kwargs
        
        self.random_projection=random_projection
        self.projection_kwargs=projection_kwargs
        
        self.forest = []
        
    def induce_labels(self, X, GMM_kwargs={'min_clusters': 10, 'max_clusters': 20, 'covariance':'tied'}):
        self.induced_labels = GMM(**GMM_kwargs).fit_predict(X)
        
    def fit(self, X, y, y_induced, n_cores=-1):
        print("fitting forest")
        self.classes = np.array([i for i in np.unique(y) if i != -1])
        n = len(y)
        
        labeled_indices = np.where(y != -1)[0]
        
        if self.tree_split is None:
            self.tree_split = len(np.where(y != -1)[0]) / n
                    
        n_supervised_trees = int(self.n_trees * self.tree_split)
        n_semi_supervised_trees = self.n_trees - n_supervised_trees
        
        condensed_func = lambda x: self._build_tree(X, y, y_induced, x, stratified=True)
        func_tuples = np.concatenate((np.ones(n_supervised_trees), np.zeros(n_semi_supervised_trees))).astype(int)
        
        print("building forest")
        self.forest = Parallel(n_jobs=n_cores)(delayed(condensed_func)(tuple_) for tuple_ in func_tuples)
                                
    def predict_proba(self, X):
        posteriors = np.zeros((X.shape[0], len(self.classes)))
                
        for i, tree in enumerate(self.forest):
            temp = tree.predict_proba(X)
            posteriors += temp
            
        return posteriors / len(self.forest)
    
    
    def predict(self, X):
        return np.argmax(self.predict_proba(X), axis=1)
            
    def _build_tree(self, X, y, y_induced, supervised=True, stratified=True):
        def worker(tree, X, y, y_induced=None):
            if y_induced is None:
                tree.fit(X, y)
            else:
                tree.fit(X, y, y_induced)
            return tree
        
        labeled_indices = np.where(y != -1)[0]
        unlabeled_indices = np.where(y == -1)[0]
        
        print(supervised, X.shape, len(y))
        
        if len(labeled_indices) == len(y):
            all_supervised=True
        else:
            all_supervised=False
            
        if supervised:
            bag_inds = stratified_sample(y[labeled_indices], p=0.67, replace=False)
            tree = Tree(max_depth=self.max_depth)
            tree.fit(X[labeled_indices[bag_inds]], y[labeled_indices[bag_inds]])
            
        else:
            sbag_inds = stratified_sample(y[labeled_indices], p=0.67, replace=False)
            if all_supervised:
                bag_inds = sbag_inds
            else:
                ssbag_inds = np.random.choice(len(unlabeled_indices), size=int(X.shape[0]*0.67), replace=True)
                bag_inds = np.concatenate((labeled_indices[sbag_inds], ssbag_inds))
            
            tree = SemiSupervisedTreeClassifier(max_depth=self.max_depth, 
                                                induce=self.induce, induce_class=self.induce_class, induce_kwargs=self.induce_kwargs,
                                               random_projection=self.random_projection, projection_kwargs=self.projection_kwargs)
            if y_induced is None:
                tree.fit(X[bag_inds], y[bag_inds])
            else:
                tree.fit(X[bag_inds], y[bag_inds], y_induced[bag_inds])
                
        return tree

In [33]:
class SemiSupervisedTreeClassifier:
    def __init__(self, max_depth=10, induce=True, 
                 induce_class=GMM, induce_kwargs={'min_components':100, 'max_components':100},
                 random_projection=False,
                 projection_kwargs={}):
        
        self.max_depth=max_depth
        self.fitted=False
        
        self.induce=induce
        self.induce_class=induce_class
        self.induce_kwargs=induce_kwargs
        
        self.random_projection=random_projection
        self.projection_kwargs=projection_kwargs
        self.projector=None
        
    def fit(self, X, y, y_induced=None):
        self.classes = np.array([i for i in np.unique(y) if i != -1])
        
        if y_induced is None and self.induce:
            if self.random_projection:
                self.projector = random_projection.GaussianRandomProjection(**self.projection_kwargs)
                self.projector.fit(X)
                X = self.projector.transform(X.copy())
                
            y_induced = self.induce_class(**self.induce_kwargs).fit_predict(X)
            
        self.tree = Tree(max_depth=self.max_depth).fit(X, y_induced)
        decision_paths = self.tree.decision_path(X)
        self.decision_paths = [dp.nonzero()[1] for dp in decision_paths]
        
        labeled_indices = np.where(y != -1)[0]
        self.labeled_decision_paths = [self.decision_paths[i] for i in labeled_indices]
        
        self.labels = y[labeled_indices]
            
        self.nodes_with_labeled_data = np.unique(np.concatenate(self.labeled_decision_paths))
        self.projection_matrix = None
        
        self._get_mapping()
        
        self.fitted=True
    
    def _get_mapping(self):
        self.mapping = {}
                
        for dp in self.decision_paths:
            leaf_node = dp[-1]
                            
            if leaf_node in list(self.mapping.keys()):
                continue
            
            for i in range(1, len(dp)+1):
                temp_node = dp[-i]
                
                if temp_node in self.nodes_with_labeled_data:        
                    temp_counts = np.zeros(len(self.classes))

                    tree_distance_to_labeled_data = np.zeros(len(self.labeled_decision_paths))
                    for j, labeled_dp in enumerate(self.labeled_decision_paths):
                        if temp_node in labeled_dp:
                            tree_distance_to_labeled_data[j] = i + len(labeled_dp) - 1 - np.where(labeled_dp == temp_node)[0][0] - 1
                        else:
                            tree_distance_to_labeled_data[j] = 100
                                                        
                    min_tree_distance = np.min(tree_distance_to_labeled_data)
                    argmins = np.where(tree_distance_to_labeled_data == min_tree_distance)[0]

                    for index in argmins.astype(int):
                        temp_counts[int(self.labels[index])] += 1
                    
                    self.mapping[leaf_node] = temp_counts / np.sum(temp_counts)
                    break
    
    def predict_proba(self, X):
        if not self.fitted:
            raise ValueError('Not fitted')
            
        if self.random_projection:
            X = self.projector.transform(X.copy())
            
        leaf_nodes = self.tree.apply(X)
        posteriors = np.zeros((len(leaf_nodes), len(self.classes)))
        
        for i, leaf_node in enumerate(leaf_nodes):
            posteriors[i] = self.mapping[leaf_node]
            
        return posteriors
        
    def predict(self, X):
        return np.argmax(self.predict_proba(X), axis=1)