## Demo of Secure Multi-Party Linear Regression

Source: Section 2 of *Secure multi-party linear regression at plaintext speed*

https://github.com/jbloom22/DASH

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import t
from statsmodels.api import OLS

### SIMULATE data

In [None]:
np.random.seed(0)

K = 10

# Alice's data
N1 = 1000
y1 = np.random.randn(N1)
C1 = np.random.randn(N1, K)

# Bob's data
N2 = 2000
y2 = np.random.randn(N2)
C2 = np.random.randn(N2, K)

# Carla's data
N3 = 1500
y3 = np.random.randn(N3)
C3 = np.random.randn(N3, K)

### PRIVATE COMPUTATION - Compress

In [None]:
# Alice
yy1 = y1.T @ y1
Cty1 = C1.T @ y1
CtC1 = C1.T @ C1

# Bob
yy2 = y2.T @ y2
Cty2 = C2.T @ y2
CtC2 = C2.T @ C2

# Carla
yy3 = y3.T @ y3
Cty3 = C3.T @ y3
CtC3 = C3.T @ C3

### MULTI-PARTY COMPUTATION - Combine

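Computation is now independent of the sample sizes: each party contributes only a scalar, a length-`K` vector, and a `K` x `K` matrix. Practical security follows from the non-invertibility of the compression step. For theoretical guarantees, run the first one, two, or all three of the following cells under secure multi-party computation.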
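
As an illustration of how the additions in the next cell could be run securely, here is a minimal single-process sketch of additive secret sharing with fixed-point encoding. It is not taken from the paper or the DASH repository; the names `MODULUS`, `SCALE`, `encode`, `share`, `decode`, and `secure_sum` are hypothetical, and the sketch assumes honest-but-curious parties and values small enough to fit the chosen modulus.

```python
import secrets
import numpy as np

# Hypothetical parameters for this sketch (not from the paper).
MODULUS = 2**127 - 1  # shares live in the integers modulo this value
SCALE = 2**40         # fixed-point scale used to encode floats as integers


def encode(x):
    """Map a float scalar or array to a flat list of fixed-point residues."""
    flat = np.asarray(x, dtype=float).ravel()
    return [int(round(float(v) * SCALE)) % MODULUS for v in flat]


def share(encoded, n_parties):
    """Split each residue into n_parties random shares that sum to it mod MODULUS."""
    shares = [[secrets.randbelow(MODULUS) for _ in encoded]
              for _ in range(n_parties - 1)]
    last = [(v - sum(col)) % MODULUS for v, col in zip(encoded, zip(*shares))]
    return shares + [last]


def decode(residues, shape):
    """Map residues back to floats, treating the upper half of the ring as negative."""
    centered = [r - MODULUS if r > MODULUS // 2 else r for r in residues]
    return np.array([c / SCALE for c in centered]).reshape(shape)


def secure_sum(private_values):
    """Sum one private value per party; only random-looking shares are exchanged."""
    n = len(private_values)
    shape = np.shape(private_values[0])
    # dealt[i][j] is the share that party i sends to party j.
    dealt = [share(encode(v), n) for v in private_values]
    # Party j adds the shares it received and publishes only that partial sum.
    partial = [[sum(col) % MODULUS
                for col in zip(*(dealt[i][j] for i in range(n)))]
               for j in range(n)]
    total = [sum(col) % MODULUS for col in zip(*partial)]
    return decode(total, shape)


# Toy stand-ins for the three parties' compressed matrices (3 covariates).
rng = np.random.default_rng(0)
CtC_parts = []
for n_rows in (1000, 2000, 1500):
    C_part = rng.standard_normal((n_rows, 3))
    CtC_parts.append(C_part.T @ C_part)

CtC_secure = secure_sum(CtC_parts)
print(np.allclose(CtC_secure, sum(CtC_parts)))
```

The same `secure_sum` call would apply unchanged to the scalars `yy1`, `yy2`, `yy3` and the vectors `Cty1`, `Cty2`, `Cty3`, since `encode` flattens its input and `decode` restores the shape. In a real deployment each party would run on its own machine and the shares would travel over authenticated channels.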

In [None]:
# Residual degrees of freedom: total sample size minus number of covariates
D = N1 + N2 + N3 - K
yty = yy1 + yy2 + yy3
Cty = Cty1 + Cty2 + Cty3
CtC = CtC1 + CtC2 + CtC3

To limit secure computation to addition, parties can share `D`, `yty`, `Cty`, and `CtC` and run the remaining cells privately. Note that this leaks information beyond the per-coefficient statistics and p-values.

In [None]:
invCtC = np.linalg.inv(CtC)
beta = invCtC @ Cty  # OLS coefficient estimates
residual_sq = (yty - beta @ CtC @ beta) / D  # residual variance estimate
sigma_sq = np.diag(invCtC) * residual_sq  # squared standard errors of beta
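
In matrix notation, the cell above can be read as follows. The symbols $y$ and $C$ (the stacked responses and covariates of all parties), $\hat\beta$, $\hat s^2$, and $\hat\sigma_k^2$ are notation assumed here for illustration; they correspond to `beta`, `residual_sq`, and the entries of `sigma_sq`.

$$
\hat\beta = (C^\top C)^{-1} C^\top y
$$

$$
\hat s^2 = \frac{y^\top y - \hat\beta^\top C^\top C \, \hat\beta}{D},
\qquad
\hat\sigma_k^2 = \hat s^2 \left[ (C^\top C)^{-1} \right]_{kk}
$$

The numerator of $\hat s^2$ equals the residual sum of squares, because $C^\top C \, \hat\beta = C^\top y$ gives

$$
\lVert y - C\hat\beta \rVert^2
= y^\top y - 2\,\hat\beta^\top C^\top y + \hat\beta^\top C^\top C \, \hat\beta
= y^\top y - \hat\beta^\top C^\top C \, \hat\beta .
$$

So the residual variance is recoverable from the summed statistics `yty`, `Cty`, and `CtC` without any party revealing individual rows.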

To reduce secure computation, parties could choose to share `beta` and `sigma_sq` and compute the remaining cell privately.

In [None]:
sigma = np.sqrt(sigma_sq)  # standard errors
tstat = beta / sigma
pval = 2 * t.cdf(-abs(tstat), D)  # two-sided p-values from the t distribution

In [None]:
df = pd.DataFrame({'beta': beta,
                   'sigma': sigma, 
                   'tstat': tstat,
                   'pval': pval})

df

### VERIFY correctness

Compute the same results with the `OLS` model from the `statsmodels` API:

In [None]:
y = np.concatenate([y1, y2, y3])
C = np.concatenate([C1, C2, C3])

res = OLS(y, C, hasconst=False).fit()

In [None]:
df2 = pd.DataFrame({'beta': res.params,
                    'sigma': res.bse, 
                    'tstat': res.tvalues, 
                    'pval': res.pvalues})

df2

Verify agreement up to 10 digits after the decimal point:

In [None]:
df = df.round(10)
df2 = df2.round(10)
np.array(df == df2).all()  # Returns True
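
Comparing rounded values for exact equality can fail when a value sits close to a rounding boundary, even if the two results differ only by floating-point noise. A tolerance-based check is one alternative. The sketch below is self-contained and uses toy data rather than the notebook's variables; the names `beta_normal`, `beta_lstsq`, and `agree` are introduced here for illustration only.

```python
import numpy as np

# Toy data: 200 samples, 4 covariates (hypothetical sizes for illustration).
rng = np.random.default_rng(0)
y = rng.standard_normal(200)
C = rng.standard_normal((200, 4))

# Coefficients from the compressed statistics (normal equations).
beta_normal = np.linalg.solve(C.T @ C, C.T @ y)

# Coefficients from a least-squares solve on the full data.
beta_lstsq, *_ = np.linalg.lstsq(C, y, rcond=None)

# Agreement within an absolute tolerance of 1e-10, with no rounding step.
agree = np.allclose(beta_normal, beta_lstsq, rtol=0, atol=1e-10)
print(agree)
```

Applied to this notebook, the same idea would be `np.allclose(df.to_numpy(), df2.to_numpy(), rtol=0, atol=1e-10)` on the unrounded data frames.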