In [63]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
import math
from sklearn.model_selection import train_test_split

shape = (100, 16)
n = shape[1]

np.random.seed(10)
X = np.random.randint(-5,6, size=shape) # Integers from -5 to 5
y = np.sum(X, axis=1) # Sum of each row
X = X + np.random.normal(0, .1, shape) # Add noise to input

In [54]:
from sklearn.linear_model import LinearRegression

reg = LinearRegression().fit(X, y)
reg.score(X, y), reg.coef_

(0.9993316174977624,
 array([0.98231329, 1.00293103, 1.01089872, 0.99780244, 1.01754801,
        1.01399901, 1.01117047, 1.0116978 , 0.97996713, 1.0128416 ,
        1.00504136, 1.01933667, 0.98243754, 0.98514289, 1.0177332 ,
        1.00559587]))

In [78]:
def divisive_shap_approx(Xs, y, model, beta, gamma_n, gamma_d, num_splits = 2):
    """Xs is a dictionary:
    {
        party1: X1,
        party2: X2,
        ...
    }
    """
    parties_to_split = {}
    for party, X in Xs.items():
        if X.shape[1] > np.emath.logn(beta, n):
            parties_to_split[party] = X
        else:
            shap = shapley_true(X, y, model)
            shap = shap*gamma_n/gamma_d
            results[party].append([X, shap])

    if len(parties_to_split) > 0:
        Xs_split = partition(parties_to_split, num_splits)
        gamma_n * np.mean([value(X, y, model) for X in parties_to_split.values()])
        gamma_d * np.mean([sum([value(Xi, y, model) for Xi in Xs]) for Xs in Xs_split.values()])

        for i in range(num_splits):
            divisive_shap_approx({party:Xs[i] for party, Xs in Xs_split.items()}, 
                                          y, 
                                          model,
                                          beta,
                                          gamma_n , 
                                          gamma_d)

def value(X, y, model):
    if X.size == 0:
        return 0
    else:
        reg = model.fit(X, y)
        return reg.score(X, y)
        

def findsubsets(s, n):
    return list(itertools.combinations(s, n))


def shapley_true(X, y, model):
    """Returns true shapley value of each feature (numpy array)"""
    N = X.shape[1]
    shap = np.zeros(N)
    for i in range(N):
        other_features = set(range(N))
        other_features.remove(i)
        subsets = [subset for j in range(N) for subset in findsubsets(other_features, j)]
        for subset in subsets:
            coeff = math.factorial(len(subset))*math.factorial(N - len(subset) - 1)/math.factorial(N)
            shap[i] += coeff*(value(X[:,list(subset) + [i]], y, model)-value(X[:, list(subset)], y, model))
    return np.array(shap)

def partition(Xs, num_splits):
    """Returns dict of lists:
    {
        party1: [X1, X2, ...],
        party2: [X1, X2, ...],
        ...
    }
    """

    return {party: [x.T for x in train_test_split(X.T, test_size=0.5)] for party, X in Xs.items()}

def main(Xs, y, model, beta):
    divisive_shap_approx(Xs, y, model, beta, 1, 1)
    #w = sum(shap)
    #vNs = [value(X, y, model) for X in Xs.values()]
    #if w != vN:
    #    if w == 0:
    #        shap = np.array([vN/X.shape[1]]*X.shape[1])
    #    gamma = vN/w
    #    shap = shap*gamma
    #return (X2, y2, shap)
    return results


In [79]:
shape = (10000, 64)
num_parties = 4
n = shape[1]//num_parties

np.random.seed(10)
X = np.random.randint(-5,6, size=shape) # Integers from -5 to 5
y = np.sum(X, axis=1) # Sum of each row
X = X + np.random.normal(0, .1, shape) # Add noise to input

Xs = {}
for i in range(num_parties):
 Xs["p" + str(i)] = X[:,i:i+n]

results = {key:[] for key in Xs.keys()}
main(Xs, y, LinearRegression(), n**(1/np.sqrt(n)))

#plt.hist(shap)
#plt.show()

{'p0': [[array([[ 3.92545364,  4.92380328, -3.84583502, -4.0976428 ],
          [ 0.90066968,  1.03906022,  5.04293242, -4.90046537],
          [ 2.95114941,  4.03105567, -3.81911554, -0.06162437],
          ...,
          [ 3.0000833 ,  5.02487563, -1.89748938, -3.05050071],
          [ 3.8413117 ,  4.97148022,  4.84060538,  2.8601002 ],
          [-1.09229269,  3.98906744,  3.03011106,  2.03591269]]),
   array([0.01472816, 0.01741692, 0.01622382, 0.01608128])],
  [array([[ 4.96104736, -4.87962581,  4.0850175 , -4.96012   ],
          [ 4.91619067, -0.0742365 , -4.07442025,  2.84860076],
          [ 2.07082564, -0.05582734, -0.00551704,  4.9011427 ],
          ...,
          [-2.10378392,  0.01785918, -5.10171298,  0.06495363],
          [ 0.83585879,  3.87813478,  3.71208701,  4.02933417],
          [-2.08912648, -1.16946176,  5.12913283,  0.98067211]]),
   array([0.0133642 , 0.01302671, 0.0155642 , 0.01856826])],
  [array([[ 3.03486127, -1.89619077, -5.01772872,  3.85588302],
      