# In [None]:
import numpy as np
from scipy.optimize import linear_sum_assignment
import random

###############################################################################################################################
# Optimally balanced entropy based sampling ###################################################################################
###############################################################################################################################
        
def OBE(X_train, Y_train, predY, cnum):
    """Entropy of the class distribution after hypothetically adding one sample.

    The training set is treated as if one extra sample of class ``predY`` had
    been appended, and the Shannon entropy (natural log) of the resulting
    class proportions over ``range(cnum)`` is returned.

    Parameters
    ----------
    X_train : 2-D numpy array, one row per training sample.
    Y_train : 1-D numpy array of class ids aligned with the rows of X_train.
    predY : class id assumed for the hypothetical extra sample.
    cnum : number of classes; class ids are expected in ``range(cnum)``.

    Returns
    -------
    0 when the training set is empty, otherwise the entropy as a float.
    """
    if len(Y_train) < 1:
        return 0
    # The hypothetical sample enlarges the set by one, so proportions are taken
    # over N + 1.  Dividing by N (as was done before) let a proportion exceed 1
    # and could produce a negative "entropy".
    total = X_train.shape[0] + 1
    H = 0
    for i in range(cnum):
        Ni = X_train[Y_train == i, :].shape[0]
        if i == predY:
            Ni += 1
        if Ni == 0:
            # Empty classes contribute nothing (0 * log 0 is taken as 0).
            continue
        p = Ni / total
        H = H + p * np.log(p)
    return -H

def OBEBS(X_train, Y_train, X_test, Y_test, probs, labels, labels_train, cnum, tt=0):
    """Move the pool sample with the highest expected entropy into the training set.

    For every pool row ``i`` the score ``sum_j OBE(.., j, ..) * probs[i, j]``
    is computed; the first row with the strictly largest score is appended to
    the training arrays and removed from the pool arrays.

    Parameters
    ----------
    X_train, Y_train : current training features / class ids (may be empty).
    X_test, Y_test : pool features / class ids.
    probs : array indexed as ``probs[i, j]`` for pool row i and class j.
    labels : per-pool-sample labels, kept aligned with X_test.
    labels_train : per-training-sample labels, kept aligned with X_train.
    cnum : number of classes.
    tt : unused; kept so existing callers keep working.

    Returns
    -------
    (X_train, Y_train, X_test, Y_test, labels, labels_train, maxIdx, maxH)
    where the arrays are new objects, maxIdx is the chosen pool row and maxH
    its score.
    """
    rownum = probs.shape[0]
    # OBE does not depend on the pool row, so evaluate it once per class
    # instead of once per (row, class) pair.
    class_entropy = [OBE(X_train, Y_train, j, cnum) for j in range(cnum)]
    # np.inf is the portable spelling; the np.Infinity alias was removed in NumPy 2.0.
    maxH = -np.inf
    maxIdx = 0
    for i in range(rownum):
        H = 0
        for j in range(cnum):
            H = H + (class_entropy[j] * probs[i, j])
        # Strict comparison: on ties the earliest row wins.
        if H > maxH:
            maxH = H
            maxIdx = i
    if len(X_train) > 0:
        X_train = np.insert(X_train, X_train.shape[0], X_test[maxIdx, :], axis=0)
        Y_train = np.insert(Y_train, len(Y_train), Y_test[maxIdx])
        labels_train = np.insert(labels_train, len(labels_train), labels[maxIdx])
    else:
        # First selected sample: build the training arrays from scratch.
        X_train = np.array([X_test[maxIdx, :]])
        Y_train = np.array([Y_test[maxIdx]])
        labels_train = np.array([labels[maxIdx]])
    X_test = np.delete(X_test, maxIdx, axis=0)
    Y_test = np.delete(Y_test, maxIdx)
    labels = np.delete(labels, maxIdx)
    return X_train, Y_train, X_test, Y_test, labels, labels_train, maxIdx, maxH

def get_costMX(Y_pool, labels, cnum):
    """Count co-occurrences of (labels[i], Y_pool[i]) in a cnum x cnum matrix.

    Entry ``[a, b]`` of the returned float matrix is the number of positions
    ``i`` with ``labels[i] == a`` and ``Y_pool[i] == b``.
    """
    counts = np.zeros((cnum, cnum))
    for pos, pool_value in enumerate(Y_pool):
        counts[labels[pos], pool_value] += 1
    return counts

###############################################################################################################################
# Hungarian algorithm for best assignment
###############################################################################################################################

def hun_algo(costMX):
    """Return the 0/1 assignment matrix that maximises the total of costMX.

    The matrix is negated before being handed to SciPy's minimising solver, so
    the chosen pairs have the largest summed value.  Row ``k`` of the result
    has a single 1 in the column selected for the k-th assigned row.
    """
    row_ids, col_ids = linear_sum_assignment(-costMX)
    size = len(row_ids)
    assignment = np.zeros((size, size))
    assignment[np.arange(len(col_ids)), col_ids] = 1
    return assignment

###############################################################################################################################
# LBE = (SUM(Diff(t)) - (k-1)*T ) / (k*T)   t=1 to T*k
###############################################################################################################################

def Diff_t_(cnum, X_train, Y_train):
    """Spread between the most and least populated class in the training set.

    Counts the rows of X_train for every class id in ``range(cnum)`` and
    returns ``max(count) - min(count)`` as a float.
    """
    per_class = np.array(
        [X_train[Y_train == cls, :].shape[0] for cls in range(cnum)],
        dtype=float,
    )
    return per_class.max() - per_class.min()

###############################################################################################################################
# REL
###############################################################################################################################

def calculate_entropy(X_train, Y_train, cnum):
    """Shannon entropy (natural log) of the class proportions in the training set.

    Proportions are the per-class row counts of X_train for class ids in
    ``range(cnum)`` divided by the total number of rows.  Returns 0 for an
    empty training set.
    """
    if len(Y_train) < 1:
        return 0
    weighted_logs = 0
    for cls in range(cnum):
        members = X_train[Y_train == cls, :].shape[0]
        if members != 0:
            share = members / X_train.shape[0]
            weighted_logs = weighted_logs + share * np.log(share)
    return -weighted_logs

def calculate_optimal_entropy(X_train, Y_train, cnum):
    """Entropy of the most even split of the training rows over cnum classes.

    The row count of X_train is divided as evenly as possible: the first
    ``rows % cnum`` classes get one extra sample.  The Shannon entropy
    (natural log) of that distribution is returned, or 0 for an empty
    training set.
    """
    if len(Y_train) < 1:
        return 0
    total = X_train.shape[0]
    base, extra = divmod(total, cnum)
    weighted_logs = 0
    for cls in range(cnum):
        size = base + 1 if cls < extra else base
        if size != 0:
            share = size / total
            weighted_logs = weighted_logs + share * np.log(share)
    return -weighted_logs

###############################################################################################################################
# Adaptive assignment algorithm
###############################################################################################################################

def adaptive_assignment(cnum, labels, labels_train, X_pool, X_train, Y_train, Q, costMX, MID):
    """Split the most mixed cluster and/or merge two clusters that share a class.

    The function first finds, among cluster ids ``range(cnum)``, the cluster
    whose training members have the highest ratio (second most frequent class
    count) / (most frequent class count).  Depending on whether such a mixed
    cluster exists (``fn2 != 0``) and whether a free cluster id is available
    (``MID``), one of four actions is taken:

    * mixed cluster, ``MID`` given: split the cluster, sending the second class
      to id ``MID``; ``MID`` becomes None.
    * mixed cluster, ``MID`` is None: split into temporary id ``cnum``, then
      merge two clusters that share a dominant class and reuse the freed id.
    * no mixed cluster, ``MID`` is None: merge two clusters sharing a dominant
      class; the emptied id is returned as ``MID``.
    * no mixed cluster, ``MID`` given: nothing is changed.

    Parameters
    ----------
    cnum : number of cluster ids / class ids iterated over.
    labels : 1-D numpy array of cluster ids for the pool rows (edited in place).
    labels_train : 1-D numpy array of cluster ids for the training rows
        (edited in place).
    X_pool, X_train : 2-D feature arrays for pool and training samples.
    Y_train : 1-D numpy array of class ids for the training rows.
    Q : 2-D array indexed ``Q[pool_row, cluster_id]`` (edited in place).
    costMX : 2-D array; the argmax of row r is used as cluster r's dominant
        class.
    MID : id of a currently unused cluster, or None.

    Returns
    -------
    (labels, labels_train, Q, MID)

    Notes
    -----
    ``SVC`` must already be in scope where this function is defined; its
    import is not part of this section (presumably ``sklearn.svm.SVC`` —
    confirm).
    """
    # Dominant class of each cluster: column of the largest entry in its costMX row.
    strong_belongness = list()
    for iii in range(costMX.shape[0]):
        strong_belongness.append(np.argmax(costMX[iii,:]))
    # np.inf is the portable spelling; the np.Inf alias was removed in NumPy 2.0
    # and raises AttributeError there.
    max_ratio = -np.inf
    chosen_group = 0
    # fn1/fn2: sizes of the two largest classes inside the chosen cluster,
    # fn1i/fn2i: the corresponding class ids.
    fn1 = 0
    fn2 = 0
    fn1i = 0
    fn2i = 0
    for iii in range(cnum):
        n1 = 0
        n2 = 0
        n1i = 0
        n2i = 0
        for jjj in range(cnum):
            # temp1 is 2 where the class is jjj (else 0); temp2 is 2 where the
            # cluster is iii (else 1).  They are equal only where both are 2,
            # i.e. training rows of class jjj that sit in cluster iii.
            temp1 = np.zeros(len(Y_train))
            temp1[Y_train == jjj] = 2
            temp2 = np.ones(len(Y_train))
            temp2[labels_train == iii] = 2
            temp3 = X_train[temp1 == temp2,:].shape[0]
            if temp3 > n1:
                # New largest class; the previous largest becomes runner-up.
                if n1 != 0:
                    n2 = n1
                    n2i = n1i
                n1 = temp3
                n1i = jjj
            elif temp3 > n2:
                n2 = temp3
                n2i = jjj
        if n1 != 0:
            # Keep the cluster with the highest runner-up / winner ratio.
            if n2/n1 > max_ratio:
                max_ratio = n2/n1
                chosen_group = iii
                fn1 = n1
                fn2 = n2
                fn1i = n1i
                fn2i = n2i
    if (fn2 != 0) and (MID is not None):
        ## Divide
        # Masks are built from the labels as they are *before* relabelling:
        # temp1==temp2 -> chosen cluster & class fn1i,
        # temp2==temp3 -> chosen cluster & class fn2i.
        temp1 = np.zeros(len(Y_train))
        temp1[Y_train == fn1i] = 2
        temp2 = np.ones(len(Y_train))
        temp2[labels_train == chosen_group] = 2
        temp3 = np.zeros(len(Y_train))
        temp3[Y_train == fn2i] = 2
        # Two-point training set: the first row of each of the two classes.
        X_train_svm = np.zeros((2,X_train.shape[1]))
        Y_train_svm = np.zeros(2)
        X_train_svm1 = np.array(X_train[temp1 == temp2,:])
        X_train_svm2 = np.array(X_train[temp2 == temp3,:])
        X_train_svm[0,:] = X_train_svm1[0,:]
        X_train_svm[1,:] = X_train_svm2[0,:]
        Y_train_svm[0] = chosen_group
        Y_train_svm[1] = MID
        clf = SVC(gamma='auto', probability=True)
        clf.fit(X_train_svm, Y_train_svm)
        labels_train[temp1 == temp2] = chosen_group
        labels_train[temp2 == temp3] = MID
        if X_train[labels_train == chosen_group,:].shape[0] > (fn1+fn2):
            # temp4==temp5 only where the class is neither fn1i nor fn2i;
            # temp6==temp2 then selects such rows that were in the chosen
            # cluster, and the classifier decides which side they go to.
            temp4 = np.zeros(len(Y_train))
            temp4[Y_train != fn1i] = 2
            temp5 = np.ones(len(Y_train))
            temp5[Y_train != fn2i] = 2
            temp6 = np.zeros(len(Y_train))
            temp6[temp4==temp5] = 2
            svm_predictions = clf.predict(X_train[temp6==temp2])
            svm_idx = np.where(temp6 == temp2)[0]
            for iii,jjj in enumerate(svm_idx):
                labels_train[jjj] = svm_predictions[iii]
        # NOTE(review): unlike the MID-is-None branch below, there is no guard
        # for an empty selection of pool rows here — confirm predict() is
        # never reached with zero rows.
        svm_predictions = clf.predict(X_pool[labels==chosen_group])
        svm_probabilities = clf.predict_proba(X_pool)
        svm_idx = np.where(labels==chosen_group)[0]
        for iii,jjj in enumerate(svm_idx):
            labels[jjj] = svm_predictions[iii]
        ### Update Q
        # The chosen cluster's column is split between chosen_group and MID.
        # NOTE(review): predict_proba columns follow clf.classes_ (sorted in
        # scikit-learn), so which of columns 0/1 belongs to chosen_group
        # depends on whether chosen_group < MID — confirm the pairing below.
        Q_temp = Q.copy()
        for iii in range(Q.shape[0]):
            for jjj in range(cnum):
                if jjj == chosen_group:
                    Q[iii,jjj] = Q_temp[iii,chosen_group]*svm_probabilities[iii,1]
                elif jjj == MID:
                    Q[iii,jjj] = Q_temp[iii,chosen_group]*svm_probabilities[iii,0]
                else:
                    Q[iii,jjj] = Q_temp[iii,jjj]
        # Shift each row so its minimum is 0, then scale it to sum to 1.
        for iii in range(Q.shape[0]):
            Q[iii,:] = (Q[iii,:] - Q[iii,:].min()) / (Q[iii,:] - Q[iii,:].min()).sum()
        # The spare id has been consumed.
        MID = None
    elif (fn2 != 0) and (MID is None):
        ## Divide
        # Same split as above, but with no spare id available the second
        # class is parked under the temporary id `cnum`.
        temp1 = np.zeros(len(Y_train))
        temp1[Y_train == fn1i] = 2
        temp2 = np.ones(len(Y_train))
        temp2[labels_train == chosen_group] = 2
        temp3 = np.zeros(len(Y_train))
        temp3[Y_train == fn2i] = 2
        X_train_svm = np.zeros((2,X_train.shape[1]))
        Y_train_svm = np.zeros(2)
        X_train_svm1 = np.array(X_train[temp1 == temp2,:])
        X_train_svm2 = np.array(X_train[temp2 == temp3,:])
        X_train_svm[0,:] = X_train_svm1[0,:]
        X_train_svm[1,:] = X_train_svm2[0,:]
        Y_train_svm[0] = chosen_group
        Y_train_svm[1] = cnum
        clf = SVC(gamma='auto', probability=True)
        clf.fit(X_train_svm, Y_train_svm)
        labels_train[temp1 == temp2] = chosen_group
        labels_train[temp2 == temp3] = cnum
        if X_train[labels_train == chosen_group,:].shape[0] > (fn1+fn2):
            # Rows of the chosen cluster whose class is neither fn1i nor fn2i
            # are assigned by the classifier.
            temp4 = np.zeros(len(Y_train))
            temp4[Y_train != fn1i] = 2
            temp5 = np.ones(len(Y_train))
            temp5[Y_train != fn2i] = 2
            temp6 = np.zeros(len(Y_train))
            temp6[temp4==temp5] = 2
            svm_predictions = clf.predict(X_train[temp6==temp2])
            svm_idx = np.where(temp6 == temp2)[0]
            for iii,jjj in enumerate(svm_idx):
                labels_train[jjj] = svm_predictions[iii]
        if len(X_pool[labels==chosen_group]) > 0:
            svm_predictions = clf.predict(X_pool[labels==chosen_group])
            svm_probabilities = clf.predict_proba(X_pool)
            svm_idx = np.where(labels==chosen_group)[0]
            for iii,jjj in enumerate(svm_idx):
                labels[jjj] = svm_predictions[iii]
            ### Update Q
            # Q_temp carries one extra column (index cnum) for the temporary cluster.
            # NOTE(review): same predict_proba column-order question as in the
            # branch above — confirm which column belongs to chosen_group.
            Q_temp = np.zeros((Q.shape[0],cnum+1))
            for iii in range(Q.shape[0]):
                for jjj in range(cnum+1):
                    if jjj == chosen_group:
                        Q_temp[iii,jjj] = Q[iii,chosen_group]*svm_probabilities[iii,1]
                    elif jjj == cnum:
                        Q_temp[iii,jjj] = Q[iii,chosen_group]*svm_probabilities[iii,0]
                    else:
                        Q_temp[iii,jjj] = Q[iii,jjj]
            ## Merge 
            # After the split the chosen cluster is dominated by fn1i and the
            # temporary cluster (appended entry, index cnum) by fn2i.
            strong_belongness[chosen_group] = fn1i
            strong_belongness.append(fn2i)
            # Among cluster pairs with the same dominant class, pick the pair
            # whose class has the most training rows; gti2 is folded into gti1.
            # NOTE(review): if no pair matches, gti1 == gti2 == 0 and the code
            # below still runs — confirm that case cannot occur or is intended.
            max_elements = 0
            gti1 = 0
            gti2 = 0
            for iii in range(len(strong_belongness)-1):
                for jjj in range(iii+1,len(strong_belongness)):
                    if strong_belongness[iii] == strong_belongness[jjj]:
                        if len(Y_train[Y_train == strong_belongness[iii]]) > max_elements:
                            max_elements = len(Y_train[Y_train == strong_belongness[iii]])
                            gti1 = iii
                            gti2 = jjj
            labels_train[labels_train == gti2] = gti1
            labels[labels == gti2] = gti1
            if gti2 != cnum:
                # The id freed by the merge takes over the temporary cluster.
                labels_train[labels_train == cnum] = gti2
                labels[labels == cnum] = gti2
            ### Update Q
            # Fold the cnum+1 temporary columns back into Q's cnum columns.
            for iii in range(Q.shape[0]):
                for jjj in range(cnum+1):
                    if jjj == gti1: 
                        Q[iii,gti1] = Q_temp[iii,gti1] + Q_temp[iii,gti2]
                    elif (jjj == gti2) and (gti2 != cnum):
                        Q[iii,gti2] = Q_temp[iii,cnum]
                    elif jjj != cnum:
                        Q[iii,jjj] = Q_temp[iii,jjj]
            # Shift each row so its minimum is 0, then scale it to sum to 1.
            for iii in range(Q.shape[0]):
                Q[iii,:] = (Q[iii,:] - Q[iii,:].min()) / (Q[iii,:] - Q[iii,:].min()).sum()
        else:
            ## Simple merge 
            # No pool row sits in the chosen cluster: move the temporary
            # cluster's training rows to the first cluster whose dominant
            # class (taken from costMX) is fn2i.  Q and labels stay as they are.
            print(strong_belongness)
            for iii in range(len(strong_belongness)):
                if strong_belongness[iii] == fn2i:
                    print(iii)
                    labels_train[labels_train == cnum] = iii
                    break
    elif (fn2 == 0) and (MID is None):
        ## Merge 
        # Every cluster is pure in the training set: merge the pair of
        # clusters sharing a dominant class that has the most training rows.
        # NOTE(review): if no pair matches, gti1 == gti2 == 0, column 0 of Q is
        # doubled and MID becomes 0 — confirm this case is unreachable or intended.
        max_elements = 0
        gti1 = 0
        gti2 = 0
        for iii in range(len(strong_belongness)-1):
            for jjj in range(iii+1,len(strong_belongness)):
                if strong_belongness[iii] == strong_belongness[jjj]:
                    if len(Y_train[Y_train == strong_belongness[iii]]) > max_elements:
                        max_elements = len(Y_train[Y_train == strong_belongness[iii]])
                        gti1 = iii
                        gti2 = jjj
        labels_train[labels_train == gti2] = gti1
        labels[labels == gti2] = gti1
        ### Update Q
        # gti1 absorbs gti2's column; gti2's column is zeroed.
        Q_temp = Q.copy()
        for iii in range(Q.shape[0]):
            for jjj in range(cnum):
                if jjj == gti1: 
                    Q[iii,gti1] = Q_temp[iii,gti1] + Q_temp[iii,gti2]
                elif jjj == gti2:
                    Q[iii,gti2] = 0
                else:
                    Q[iii,jjj] = Q_temp[iii,jjj]
        # The emptied cluster id is reported so a later call can reuse it.
        MID = gti2
    elif (fn2 == 0) and (MID is not None):
        # Nothing to split and a spare id is already pending: no change.
        return labels, labels_train, Q, MID
    return labels, labels_train, Q, MID