In [1]:
import pickle
from collections import defaultdict
import numpy as np
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization 
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical


In [5]:
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
# Absolute local path; adjust for your machine.
DATA_PATH = '/mnt/c/Users/weiwya/Desktop/cifar_resnet50_embed.p'
with open(DATA_PATH, 'rb') as fh:  # `with` closes the handle (the old one-liner leaked it)
    x_train, y_train, x_test, y_test = pickle.load(fh)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)


def _group_by_label(features, labels):
    """Map each class id to an array holding every feature row with that label.

    ``labels`` rows are indexed with ``[0]``, so a column vector of shape
    (n, 1) is expected (as printed above for y_train / y_test).
    """
    grouped = defaultdict(list)
    for row, lab in zip(features, labels):
        grouped[lab[0]].append(row)
    return {k: np.array(v) for k, v in grouped.items()}


train_dict = _group_by_label(x_train, y_train)
test_dict = _group_by_label(x_test, y_test)
# Grouping must not drop or duplicate samples.
assert sum(len(v) for v in train_dict.values()) == len(x_train)
assert sum(len(v) for v in test_dict.values()) == len(x_test)
print(len(train_dict), len(test_dict))


(50000, 1000) (50000, 1) (10000, 1000) (10000, 1)
100 100


In [41]:
#Returns all samples of the requested classes and, optionally, a random sample
#drawn from every other class (labelled -1) to serve as negative examples.
def get_data(data_dict, classes, return_negative=False):
    """Stack the samples of ``classes`` into one shuffled (data, labels) pair.

    Args:
        data_dict: mapping from class id to a 2-D array of that class's samples
            (one sample per row). All arrays must have the same column count.
        classes: iterable of class ids to include; each must be a key of data_dict.
        return_negative: if True, also draw ``n_selected // len(classes)`` rows
            (the mean per-class count) uniformly without replacement from all
            classes NOT in ``classes`` and label them -1. If fewer such rows
            exist, all of them are used rather than raising.

    Returns:
        (data, labels): ``data`` is the row-stacked samples and ``labels`` a 1-D
        integer array of matching length; both are shuffled with the same
        permutation (global NumPy RNG).

    Raises:
        ValueError: if ``classes`` is empty, or negatives are requested but
            ``data_dict`` has no class outside ``classes``.
    """
    classes = list(classes)
    if not classes:
        raise ValueError('classes must contain at least one class id')

    data = [data_dict[c] for c in classes]
    labels = []
    for c in classes:
        labels.extend([c] * len(data_dict[c]))
    n_selected = len(labels)

    if return_negative:
        others = [v for k, v in data_dict.items() if k not in classes]
        if not others:
            raise ValueError('return_negative=True but no classes exist outside `classes`')
        pool = np.vstack(others)
        # About as many negatives as one average positive class; clamp so that
        # sampling without replacement cannot request more rows than exist.
        n_neg = min(n_selected // len(classes), pool.shape[0])
        neg_idx = np.random.choice(pool.shape[0], n_neg, replace=False)
        data.append(pool[neg_idx])
        labels.extend([-1] * n_neg)

    data = np.vstack(data)
    labels = np.array(labels)
    # One shared permutation keeps rows and labels aligned.
    order = np.random.permutation(len(labels))
    return data[order], labels[order]

In [43]:
#basic MLP
#TODO: swap out for more sophisticated models

def fit_basic_model(train_data, train_labels, epochs=50, verbose=False,
                    hidden_units=256, dropout_rate=0.25, validation_split=0.25):
    """Train a one-hidden-layer MLP softmax classifier and return it.

    Args:
        train_data: 2-D array; the input width is read from ``train_data.shape[1]``.
        train_labels: integer class ids (list or array) in ``[0, n_classes)``.
        epochs: number of training epochs.
        verbose: forwarded to Keras ``Model.fit``.
        hidden_units: width of the hidden ReLU layer.
        dropout_rate: dropout rate applied after the hidden layer.
        validation_split: fraction Keras holds out for validation. Keras takes
            it from the END of the arrays before shuffling, so inputs should
            already be shuffled (``get_data`` does this).

    Returns:
        The fitted ``Sequential`` model (Dense -> Dropout -> Dense softmax).

    Raises:
        ValueError: if ``train_labels`` is empty or contains negative ids.
    """
    labels = np.asarray(train_labels, dtype=int)
    if labels.size == 0:
        raise ValueError('train_labels is empty')
    if labels.min() < 0:
        raise ValueError('train_labels must be non-negative class ids')

    input_dim = train_data.shape[1]
    # Size the softmax from the largest id and give to_categorical the same
    # count. The old np.unique count under-sized the output layer whenever an
    # id in [0, max] was missing, so the one-hot width no longer matched it.
    n_classes = int(labels.max()) + 1

    model = Sequential()
    model.add(Dense(hidden_units, activation='relu', input_shape=(input_dim,)))
    model.add(Dropout(dropout_rate))
    model.add(Dense(n_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    model.fit(train_data, to_categorical(labels, num_classes=n_classes),
              epochs=epochs,
              verbose=verbose,
              validation_split=validation_split,
              shuffle=True)
    return model

    
def get_accuracy(predict, acutual):
    """Return the fraction of positions where ``predict`` equals ``acutual``.

    Both arguments may be lists or arrays of equal length. The second parameter
    keeps its original (misspelled) name so keyword callers keep working.

    Raises:
        ValueError: if the inputs differ in shape or are empty.
    """
    # Convert both sides: comparing two plain lists yields ONE bool, which made
    # the old np.sum(...)/len(...) report 1/len for identical lists.
    predict = np.asarray(predict)
    actual = np.asarray(acutual)
    if predict.shape != actual.shape:
        raise ValueError('predict and actual differ in shape: %s vs %s'
                         % (predict.shape, actual.shape))
    if actual.size == 0:
        raise ValueError('cannot compute accuracy of empty input')
    return np.mean(predict == actual)

In [73]:
#TODO: put into Will's PL framework

def gen_task_classes(total_classes, n_tasks=10, shuffle=False):
    """Partition the class ids 0..total_classes-1 into ``n_tasks`` groups.

    Splitting uses ``np.array_split``, so group sizes differ by at most one
    when ``total_classes`` is not divisible by ``n_tasks``. With ``shuffle``
    the ids are permuted in place (global NumPy RNG) before splitting.

    Returns:
        list of 1-D integer arrays, one per task.
    """
    ids = np.arange(total_classes)
    if not shuffle:
        return np.array_split(ids, n_tasks)
    np.random.shuffle(ids)
    return np.array_split(ids, n_tasks)


def fit_all_task_models(train_dict, n_tasks=10, n_label_repeats=1,shuffle_labels=True, train_neg_exmaples=True, verbose=False):
    """Train one transformer MLP per task, then one voter MLP per task.

    The class ids (0..len(train_dict)-1) are split into ``n_tasks`` groups via
    ``gen_task_classes``; this is repeated ``n_label_repeats`` times, giving
    ``n_tasks * n_label_repeats`` tasks in total.

    Stage 1 fits ``fit_basic_model`` on each task's classes, plus a sampled
    -1 "negative" class when ``train_neg_exmaples``, with labels remapped to
    0..K-1. Stage 2 concatenates the second-to-last-layer outputs of ALL
    transformers on each task's training data and fits a voter on them.

    Args:
        train_dict: mapping class id -> 2-D array of that class's samples.
        n_tasks: groups per label partition.
        n_label_repeats: how many partitions to generate.
        shuffle_labels: shuffle class ids before each partition.
        train_neg_exmaples: add a -1 negative class to every task.
        verbose: forwarded to Keras training.

    Returns:
        transformers: {task index: fitted transformer model}
        voters: {task index: voter model over the concatenated features}
        task_label_lookup: {task index: {model output index: class id}};
            -1 (if present) sorts to output index 0.
        task_data: {task index: (data, remapped labels)} used for training.
    """
    total_labels = len(train_dict)
    print('total_labels %s' %total_labels)

    # Flat list of per-task label arrays. The previous np.vstack required all
    # groups to be the same size and failed when total_labels % n_tasks != 0.
    labels_for_tasks = []
    for _ in range(n_label_repeats):
        labels_for_tasks.extend(gen_task_classes(total_labels, n_tasks, shuffle=shuffle_labels))
    total_tasks = len(labels_for_tasks)
    print("total number of tasks %s" %total_tasks)

    transformers = {}
    task_data = {}
    task_label_lookup = {}
    voters = {}

    # Stage 1: one transformer per task.
    for task, task_labels in enumerate(labels_for_tasks):
        data, label = get_data(train_dict, task_labels, return_negative=train_neg_exmaples)
        if train_neg_exmaples:
            task_labels = np.append(task_labels, -1)
        task_labels = np.sort(task_labels)

        actual_label_to_task_labels = {l: i for i, l in enumerate(task_labels)}
        task_labels_to_actual_labels = {i: l for i, l in enumerate(task_labels)}
        label = [actual_label_to_task_labels[x] for x in label]

        #store for later stage
        transformers[task] = fit_basic_model(data, label, verbose=verbose)
        task_label_lookup[task] = task_labels_to_actual_labels
        task_data[task] = (data, label)
        print('done_transformer %s' %task)
    print('training deciders')

    # Build every feature extractor once (it was rebuilt inside the task loop,
    # i.e. total_tasks**2 times). layers[-2] is the layer feeding the softmax.
    extractors = [Model(inputs=transformers[t].inputs, outputs=transformers[t].layers[-2].output)
                  for t in range(total_tasks)]

    # Stage 2: one voter per task, on features from every transformer in task order.
    for task in range(total_tasks):
        d, l = task_data[task]
        # NOTE(review): voters see the same samples the transformers were trained
        # on, so these features are likely optimistic; consider a held-out split.
        transformed = np.hstack([m.predict(d) for m in extractors])
        voters[task] = fit_basic_model(transformed, l, verbose=verbose)
        print('done_voter %s' %task)

    return transformers, voters, task_label_lookup, task_data


In [None]:
# Seed NumPy so the label partitions and negative sampling are reproducible.
# NOTE(review): Keras weight init/dropout presumably also need a TensorFlow seed.
SEED = 0
np.random.seed(SEED)
transformers, voters, task_label_lookup, task_data = fit_all_task_models(
    train_dict,
    n_tasks=10,
    n_label_repeats=5,
    train_neg_exmaples=True,
)

total_labels 100
total number of tasks 50


In [71]:
#Task-unaware decider: sums every voter's class probabilities per sample and
#takes the argmax of the sums.
#TODO: experiment with argmax of argmax.
def vote_on_data(data, transforms, voters, task_label_lookup, total_classes, contains_neg=True):
    """Predict one class id per row of ``data`` by pooling all voters.

    Args:
        data: samples to classify (fed to each transformer's ``predict``).
        transforms: {task index: fitted transformer}, indexed 0..len(voters)-1;
            each one's ``layers[-2]`` output is used as features.
        voters: {task index: voter trained on those features concatenated in
            task-index order}.
        task_label_lookup: {task index: {voter output index: class id}}; class
            id -1 marks the negative class.
        total_classes: number of real class ids; ids must lie in [0, total_classes).
        contains_neg: currently ignored -- the -1 (negative) column is always
            dropped before the argmax. Kept for call compatibility.

    Returns:
        1-D integer array with one predicted class id per row of ``data``.
    """
    n_tasks = len(voters)
    n_data = len(data)

    # Read the `transforms` argument; the old body silently used the
    # notebook-global `transformers` instead.
    features = np.hstack([
        Model(inputs=transforms[t].inputs, outputs=transforms[t].layers[-2].output).predict(data)
        for t in range(n_tasks)
    ])

    # Extra last column collects class -1 (index -1 addresses it).
    all_prob = np.zeros((n_data, total_classes + 1))
    for t in range(n_tasks):
        probs = voters[t].predict(features)
        lookup = task_label_lookup[t]
        cols = [lookup[j] for j in range(probs.shape[1])]
        # Class ids are unique within a task, so a buffered fancy-index += is exact.
        all_prob[:, cols] += probs
    #cut out the last column: the negative class shared by every task
    return all_prob[:, :-1].argmax(axis=1)

    
    

In [75]:
#Count how many tasks each class label occurred in.
occur = defaultdict(int)
for lookup in task_label_lookup.values():
    for class_id in lookup.values():
        occur[class_id] += 1

# Derive sizes from the data rather than assuming 100 classes x 100 test images.
n_total_classes = len(train_dict)
total_correct = 0
total_seen = 0
for c_class in sorted(test_dict):
    data = test_dict[c_class]
    predict = vote_on_data(data, transformers, voters, task_label_lookup, n_total_classes, contains_neg=False)
    target = [c_class] * len(data)
    total_correct += np.sum(predict == c_class)
    total_seen += len(data)
    print(c_class, occur[c_class], get_accuracy(predict, target))

print(total_correct / total_seen)

0 2 0.56
1 2 0.62
2 2 0.21
3 2 0.28
4 2 0.18
5 2 0.15
6 2 0.62
7 2 0.41
8 2 0.62
9 2 0.7
10 2 0.14
11 2 0.3
12 2 0.43
13 2 0.64
14 2 0.27
15 2 0.6
16 2 0.65
17 2 0.53
18 2 0.39
19 2 0.24
20 2 0.74
21 2 0.68
22 2 0.57
23 2 0.46
24 2 0.71
25 2 0.19
26 2 0.08
27 2 0.16
28 2 0.49
29 2 0.4
30 2 0.14
31 2 0.71
32 2 0.16
33 2 0.3
34 2 0.47
35 2 0.44
36 2 0.59
37 2 0.47
38 2 0.54
39 2 0.7
40 2 0.58
41 2 0.69
42 2 0.47
43 2 0.68
44 2 0.41
45 2 0.25
46 2 0.15
47 2 0.14
48 2 0.43
49 2 0.59
50 2 0.13
51 2 0.31
52 2 0.55
53 2 0.75
54 2 0.46
55 2 0.32
56 2 0.57
57 2 0.47
58 2 0.26
59 2 0.17
60 2 0.65
61 2 0.45
62 2 0.37
63 2 0.18
64 2 0.09
65 2 0.31
66 2 0.24
67 2 0.52
68 2 0.45
69 2 0.58
70 2 0.24
71 2 0.28
72 2 0.17
73 2 0.28
74 2 0.28
75 2 0.69
76 2 0.56
77 2 0.57
78 2 0.42
79 2 0.52
80 2 0.43
81 2 0.11
82 2 0.33
83 2 0.49
84 2 0.27
85 2 0.42
86 2 0.27
87 2 0.61
88 2 0.59
89 2 0.75
90 2 0.18
91 2 0.24
92 2 0.2
93 2 0.21
94 2 0.54
95 2 0.71
96 2 0.35
97 2 0.62
98 2 0.09
99 2 0.69
0.4187
