# In [6]:
import numpy as np
import pandas as pd

from svm_proxy import *
from sklearn.metrics import accuracy_score

# In [None]:
class test_proxy:
    """Track how closely an incrementally updated proxy matches a known linear model.

    The "true" model is the hyperplane ``true_weights . x + true_bias = 0``.
    While an attacker issues queries (see :meth:`steal`), the proxy is updated
    with each queried point and up to three metrics are refreshed:

    * ``min_dist_DB`` -- smallest distance seen so far between a queried point
      and the true decision boundary (only when ``'min_dist_DB'`` is in ``logs``).
    * ``proxy_acc``   -- accuracy of the proxy on ``X_valid`` / ``y_valid``
      (only when a validation set was supplied).
    * ``avg_L2_proj`` -- mean L2 distance between the projections of ``X_proj``
      onto the proxy boundary and onto the true boundary (only when a
      projection set was supplied).
    """

    def __init__(self, proxy, true_weights, true_bias, X_proj=None,
                 y_proj=None, X_valid=None, y_valid=None, logs=None):
        """
        inputs   : proxy = the incremental model; must provide ``update(x, y)``,
                           ``weights`` and ``bias`` (and ``predict(X)`` when a
                           validation set is supplied)
                   true_weights, true_bias = true DB params
                   X_proj, y_proj = dataset for projection metric
                   X_valid, y_valid = dataset for accuracy metric
                   logs = list of (string) metrics (min_dist_DB, ..);
                          ``None`` means no optional metric is tracked
        """
        # A bare string would otherwise be treated as a sequence of characters.
        if logs is None:
            logs = []
        elif isinstance(logs, str):
            logs = [logs]
        self.logs = list(logs)

        self.proxy = proxy

        # Keep the raw parameters (needed by dist_true_DB) as well as the
        # augmented vector [bias, w_1, ..., w_d] used by the projection metric.
        self.true_weights = np.ravel(np.asarray(true_weights, dtype=float))
        self.intercept = float(np.ravel(np.asarray(true_bias, dtype=float))[0])
        self.true_weights_vec = np.insert(self.true_weights, 0, self.intercept)

        # Large sentinel so the first queried point always becomes the minimum.
        self.min_dist_DB = 1e10 if 'min_dist_DB' in self.logs else None

        # Every metric/dataset attribute always exists, so steal() can test
        # for "not supplied" without risking an AttributeError.
        self.avg_L2_proj = None
        self.proxy_acc = None
        self.X_proj, self.y_proj = None, None
        self.X_valid, self.y_valid = None, None

        # stored datasets (a leading column of ones is added for the bias term)
        if X_proj is not None:
            self.X_proj, self.y_proj = self._with_bias_column(X_proj), y_proj

        if X_valid is not None:
            self.X_valid, self.y_valid = self._with_bias_column(X_valid), y_valid

    @staticmethod
    def _with_bias_column(X):
        """Return ``X`` as an array with a leading column of ones prepended."""
        X = np.asarray(X)
        return np.c_[np.ones((X.shape[0], 1)), X]

    @staticmethod
    def _project_onto_boundary(X, weight_vec):
        """Orthogonally project the rows of ``X`` onto ``{z : weight_vec . z = 0}``.

        Dividing by ``||weight_vec||^2`` makes the result independent of the
        scale of the weights, so two parameterisations of the same boundary
        give the same projections.
        """
        sq_norm = float(weight_vec @ weight_vec)
        if sq_norm == 0.0:
            # Degenerate all-zero weights define no hyperplane; leave X as is.
            return X.copy()
        return X - np.outer(X @ weight_vec, weight_vec) / sq_norm

    def dist_true_DB(self, x):
        """Return the distance from point ``x`` to the true decision boundary."""
        x = np.ravel(np.asarray(x, dtype=float))
        fx = abs(float(np.dot(self.true_weights, x)) + self.intercept)
        return fx / np.linalg.norm(self.true_weights)

    def steal(self, attacker, max_queries):
        """Run ``max_queries`` attacker queries, updating the proxy and metrics.

        ``attacker.query_fit()`` must return ``(x, y, attacker_acc)``: the
        queried point, its label and the attacker's accuracy (the latter is
        currently unused here).
        """
        for _ in range(max_queries):
            x, y, attacker_acc = attacker.query_fit()  # query and improve fit; return the query point and accuracy
            self.proxy.update(x, y)

            # 1. minimum distance from hyperplane
            if 'min_dist_DB' in self.logs:
                self.min_dist_DB = min(self.min_dist_DB, self.dist_true_DB(x))
                # log in wandb

            # 2. accuracy
            if self.X_valid is not None:
                # NOTE(review): X_valid carries the prepended bias column --
                # confirm that proxy.predict expects augmented inputs.
                self.proxy_acc = accuracy_score(self.y_valid, self.proxy.predict(self.X_valid))
                # log in wandb

            # 3. projection metric
            if self.X_proj is not None:
                proxy_weight_vec = np.insert(
                    np.ravel(np.asarray(self.proxy.weights, dtype=float)), 0, self.proxy.bias)
                X_proxy_proj = self._project_onto_boundary(self.X_proj, proxy_weight_vec)
                X_true_proj = self._project_onto_boundary(self.X_proj, self.true_weights_vec)

                self.avg_L2_proj = float(np.mean(np.linalg.norm(X_proxy_proj - X_true_proj, axis=1)))
                # log in wandb

            # 4. other metrics
                