In [1]:
"""
    DS 300-001: Homework 5
    implementation of model extraction attack
"""

import numpy as np
from numpy.random import randn
from numpy.linalg import solve
from numpy.random import rand

# Number of parameters (weights) in the LR model. The same value is passed to
# attack(), which uses it as the dimension of the linear system it solves.
num_param = 20


class LR:
    """
    A simple logistic regression model with randomized parameters.

    The weight vector is drawn as ``randn(num_param) * 4`` when the model is
    constructed and is kept in ``self.theta``; there is no bias term.
    """
    def __init__(self, num_param):
        """
        input:
            num_param (integer) -- number of weights to draw for the model
        """
        self.num_param = num_param
        self.theta = randn(self.num_param)*4

    def query(self, x):
        """
        Return the prediction 1 / (1 + exp(-<x, theta>)) for the input x.

        input:
            x (array-like) -- one query vector of length num_param, or a batch
                of queries whose last axis has length num_param

        output:
            a scalar for a single vector, otherwise an array holding one
            prediction per query
        """
        x = np.asarray(x)
        # Validate the last axis, because that is the axis np.inner contracts.
        # Indexing shape[1] unconditionally would raise IndexError (instead of
        # a clean assertion failure) for scalars and wrong-length 1-D inputs.
        assert x.ndim >= 1 and x.shape[-1] == self.theta.shape[0]
        return 1./(1 + np.exp(-np.inner(x, self.theta)))

    def check(self, theta):
        """
        Tell whether a guessed parameter vector matches the model's weights.

        input:
            theta (array-like) -- guess with the same shape as the true weights

        output:
            True when the largest absolute difference between the guess and
            the true weights is at most 1e-5, False otherwise
        """
        theta = np.asarray(theta)
        assert theta.shape == self.theta.shape
        return np.linalg.norm(theta - self.theta, np.inf) <= 1e-5



def attack(model, num_param):
    """
    Model extraction attack against a logistic regression model.

    For a query x the model answers p = 1 / (1 + exp(-<x, theta>)), hence
    log(p / (1 - p)) = <x, theta>: every answer gives one linear equation in
    theta, and num_param independent queries determine theta.

    To keep the recovery numerically reliable the queries are
      * mutually orthonormal directions (the Q factor of a random Gaussian
        matrix), so the linear system has condition number 1, and
      * shrunk by repeated halving until no prediction is close to 0 or 1,
        because a prediction that rounds towards 1 loses the digits of
        1 - p and therefore of the logit.

    input:
        model (LR) -- target LR model, which provides two functions:
            (i) query(x), which returns the prediction for given input x
            (ii) check(theta), which checks whether theta agrees with the model's parameters
        num_param (integer) -- number of parameters in the LR model

    output:
        theta (vector) -- the parameters guessed by the attacker; it may
            contain non-finite values if the model's answers are still
            saturated after the last rescaling round
    """
    # predictions outside (p_min, 1 - p_min) are treated as saturated
    p_min = 1e-2
    # upper bound on the number of query rounds (each round halves the scale)
    max_rounds = 64

    # create a set of queries: random orthonormal directions, one per row
    directions = np.linalg.qr(randn(num_param, num_param))[0]

    scale = 1.0
    for _round in range(max_rounds):
        queries = scale * directions

        # query the model and get the predictions
        pred = np.asarray(model.query(queries), dtype=float)

        # accept this round once every prediction is safely inside (0, 1)
        if np.all((pred > p_min) & (pred < 1.0 - p_min)):
            break
        scale *= 0.5

    # convert the predictions to linear equations: logit(p) = <x, theta>
    linear = np.log(pred) - np.log1p(-pred)

    # solve the linear system to extract the parameters
    theta = solve(queries, linear)

    return theta

In [2]:
if __name__ == "__main__":

    # Instantiate the black-box target and try to recover its hidden weights.
    lr = LR(num_param)
    theta = attack(lr, num_param)

    # The model itself judges whether the guess is close enough.
    if not lr.check(theta):
        print('Attack failed...')
    else:
        print('Attack succeeds!')
        print('Extracted parameters:', theta)

Attack succeeds!
Extracted parameters: [ 2.46106757 -3.13104154 -0.64170121  2.60549873  2.1509599   5.60650149
  6.32897951 -3.98258144 -6.59689378 -0.71492531 -0.9337038  -5.28650377
 -6.18719492 -0.98616153 -4.88492285  0.62002699 -3.52002845  0.88270897
  1.13055934 -2.38527747]
