In [None]:
import os
import numpy as np
import json
import math
from collections import defaultdict
import matplotlib.pyplot as plt
   

class TreeData():
    """One cover tree read from JSON, together with a grid layout for drawing it.

    Construction reads the child->parent map and the node-address map for the
    tree identified by (leaf_cutoff, tree_index), derives the root and the
    parent->children layout, builds the blank heatmap grid, and finally loads
    every baseline file that belongs to this tree.
    """

    def __init__(self, tree_index, leaf_cutoff):
        self.tree_index = tree_index
        self.leaf_cutoff = leaf_cutoff
        self._open_tree_data()

        # Each node_address value is unpacked as a pair; its first element is the scale.
        scales = [scale for scale, _ in self.node_address.values()]
        self.max_scale = np.max(scales)
        self.min_scale = np.min(scales)
        layout = self.create_heatmap(self.root)
        self.node_heatmap, self.heatmap_map = layout[0], layout[1]
        self._open_baselines()

    def _open_tree_data(self):
        """Load the two tree JSON files and derive `root` and `tree_layout`."""
        tree_dir = "/localdata/sorel/covertrees/"
        with open(os.path.join(tree_dir, f"tree_{self.leaf_cutoff}_{self.tree_index}_child_parent.json")) as f:
            raw_child_parent = json.load(f)
            self.child_parent = {int(child): int(parent) for child, parent in raw_child_parent.items()}
        with open(os.path.join(tree_dir, f"tree_{self.leaf_cutoff}_{self.tree_index}_node_address.json")) as f:
            raw_node_address = json.load(f)
            self.node_address = {int(node): address for node, address in raw_node_address.items()}

        # Start from an arbitrary node and climb until there is no parent left.
        self.root = list(self.node_address)[0]
        while self.root in self.child_parent:
            self.root = self.child_parent[self.root]

        self.tree_layout = defaultdict(list)
        for child in self.child_parent:
            self.tree_layout[int(self.child_parent[child])].append(child)

    def _open_baselines(self):
        """Load every baseline file of this tree into `self.baselines`, keyed by sample rate."""
        baseline_dir = "/localdata/sorel/covertrees/test_set_baselines/"
        tree_prefix = f"tree_{self.leaf_cutoff}_{self.tree_index}_baseline_"
        filelist = os.listdir(baseline_dir)
        self.baselines = {}
        for fname in filelist:
            if not fname.startswith(tree_prefix):
                continue
            # The sample rate is the first "_"-separated token after the prefix.
            remainder = fname[len(tree_prefix):-len(".json")]
            sample_rate = remainder.split("_")[0]
            with open(os.path.join(baseline_dir, fname)) as f:
                self.baselines[int(sample_rate)] = json.load(f)

    def _value_heatmap(self, vals):
        """Return a grid shaped like `node_heatmap` with each node's cells set to its value.

        `vals` maps node address -> value; cells not owned by any listed node stay 0.
        """
        n_rows, n_cols = self.node_heatmap.shape[0], self.node_heatmap.shape[1]
        final_heatmap = np.zeros((n_rows, n_cols))
        for address in vals:
            for cell in self.heatmap_map[address]:
                final_heatmap[cell[0], cell[1]] = vals[address]
        return np.array(final_heatmap)

    def create_heatmap(self, root):
        """Lay the subtree under `root` out on a grid with one row per scale.

        Returns a tuple of (blank grid as an ndarray, mapping node -> list of
        [row, col] cells, set of visited nodes). Every leaf claims one new cell in
        its own row and in the row of each of its ancestors; afterwards all rows
        are zero-padded to a common width.
        """
        def pad_rows(rows):
            # Right-pad every row with zeros up to the current widest row.
            widest = 0
            for row in rows:
                widest = max(widest, len(row))
            for row in rows:
                row.extend([0] * (widest - len(row)))
            return rows

        node_heatmap = [[] for _ in range(self.max_scale - self.min_scale + 1)]
        heatmap_map = defaultdict(list)
        included_nodes = set()

        def claim_cell(node):
            # Rows are counted downward from the largest scale.
            row = self.max_scale - max(self.node_address[node][0], self.min_scale)
            heatmap_map[node].append([row, len(node_heatmap[row])])
            node_heatmap[row].append(0)

        stack = [root]
        while stack:
            current = stack.pop()
            included_nodes.add(current)
            if current in self.tree_layout:
                stack.extend(self.tree_layout[current])
                continue
            # Leaf: claim a cell for it and for each ancestor up to the top.
            claim_cell(current)
            ancestor = current
            while ancestor in self.child_parent:
                ancestor = self.child_parent[ancestor]
                claim_cell(ancestor)
            pad_rows(node_heatmap)
        return np.array(node_heatmap), heatmap_map, included_nodes

    def attack_set(self, attack_index, model_index):
        """Build the TreeAttackSet of this tree for the given attack and model indices."""
        return TreeAttackSet(self, attack_index, model_index)

    def baseline(self, sample_rate):
        """Wrap the loaded baseline for `sample_rate` in a Baseline object."""
        return Baseline(self, self.baselines[sample_rate])

    def path_to_node(self, node):
        """Return the list of addresses from the topmost ancestor down to `node`."""
        upward = [node]
        while node in self.child_parent:
            node = self.child_parent[node]
            upward.append(node)
        upward.reverse()
        return upward

class TreeAttackSet():
    """All results of one attack (attack_index) by one model (model_index) on one tree.

    Attributes:
      attack_set    -- defaultdict: attack rate (float) -> sample rate (int) -> parsed result JSON
      path          -- list of int node addresses read from the "attack_path.json" file
      attack_rates  -- sorted attack rates found on disk
      sample_rates  -- sorted sample rates found for the smallest attack rate
    """

    def __init__(self, tree, attack_index, model_index):
        self.tree = tree
        self.attack_index, self.model_index = attack_index, model_index
        attack_dir = "/localdata/sorel/covertrees/test_set_attack_results/"
        attack_prefix = f"model_{model_index}_tree_{self.tree.leaf_cutoff}_{self.tree.tree_index}_attack_{attack_index}_"
        filelist = os.listdir(attack_dir)
        self.attack_set = defaultdict(dict)
        for fname in filelist:
            # The path file shares the prefix but is not a result file.
            if fname.startswith(attack_prefix) and "path" not in fname:
                attack_rate, sample_rate = fname[len(attack_prefix):-len(".json")].split("_")
                with open(os.path.join(attack_dir, fname)) as f:
                    self.attack_set[float(attack_rate)][int(sample_rate)] = json.load(f)

        with open(os.path.join(attack_dir, attack_prefix + "attack_path.json")) as f:
            self.path = [int(i) for i in json.load(f)]

        self.attack_rates = sorted(self.attack_set.keys())
        self.sample_rates = sorted(self.attack_set[self.attack_rates[0]].keys())

    def _plot_grid(self, value_map, size_inches):
        """Draw one heatmap per (attack rate, sample rate) pair.

        `value_map` takes a TreeAttack and returns the node -> value dict to draw;
        `size_inches` is the (width, height) given to the figure.
        """
        # Size the grid from the sorted rate lists. Indexing attack_set[0] would
        # insert an empty entry into the defaultdict whenever no 0.0 rate exists.
        # squeeze=False keeps `ax` two-dimensional even for a single row/column.
        fig, ax = plt.subplots(len(self.attack_rates), len(self.sample_rates), squeeze=False)
        for i, ar in enumerate(self.attack_rates):
            for j, sr in enumerate(self.sample_rates):
                all_results = value_map(self.attack(ar, sr))
                pos = ax[i, j].imshow(self.tree._value_heatmap(all_results), aspect="auto")
                ax[i, j].set_xlabel(f"{ar}, {sr}")
                fig.colorbar(pos, ax=ax[i, j])
        fig.set_size_inches(*size_inches)
        plt.show()

    def tree_plot(self, val_type):
        """Plot the raw `val_type` values of every attack in the set on the tree layout."""
        self._plot_grid(lambda attack: attack.val_map(val_type), (30, 40))

    def corrected_tree_plot(self, val_type):
        """Plot the baseline-corrected `val_type` values of every attack in the set."""
        self._plot_grid(lambda attack: attack.corrected_val_map(val_type), (50, 40))

    def attack(self, attack_rate, sample_rate):
        """Return a TreeAttack for the given rates, paired with the tree's baseline for `sample_rate`."""
        results = self.attack_set[attack_rate][sample_rate]
        return TreeAttack(self, self.tree, results, self.tree.baseline(sample_rate))

    def true_attack_path(self):
        """Return the node addresses loaded from the attack path file."""
        return self.path

class TreeAttack():
    """Results of a single attack run on a tree, paired with the matching baseline.

    Parameters:
      attack   -- the owning attack set; its `path` attribute is read by true_attack_path
      tree     -- object providing `_value_heatmap`, used for plotting
      results  -- parsed result dict with "overall_tracker" and "node_trackers" entries
      baseline -- callable (address, val_type, value, correction_factor) -> corrected value
    """

    def __init__(self, attack, tree, results, baseline):
        self.tree, self.attack, self.results, self.baseline = tree, attack, results, baseline

    def mll(self):
        """Return the "mll" entry of the overall tracker."""
        # The original evaluated this expression without returning it.
        return self.results["overall_tracker"]["mll"]

    def kl_div(self):
        """Return the "kl_div" entry of the overall tracker."""
        return self.results["overall_tracker"]["kl_div"]

    def _plot(self, all_results):
        """Draw a node -> value dict as a single heatmap with a colorbar."""
        fig, ax = plt.subplots(1)
        pos = ax.imshow(self.tree._value_heatmap(all_results), aspect="auto")
        fig.colorbar(pos, ax=ax)
        fig.set_size_inches(50, 20)
        plt.show()

    def tree_plot(self, val_type):
        """Plot the raw per-node `val_type` values on the tree layout."""
        self._plot(self.val_map(val_type))

    def corrected_tree_plot(self, val_type):
        """Plot the baseline-corrected per-node `val_type` values on the tree layout."""
        self._plot(self.corrected_val_map(val_type))

    def corrected_val_map(self, val_type, correction_factor = 0):
        """Return {node address: baseline-corrected `val_type` value} for every node tracker."""
        return {
            int(k): self.baseline(int(k), val_type, r[val_type], correction_factor)
            for k, r in self.results["node_trackers"].items()
        }

    def val_map(self, val_type):
        """Return {node address: raw `val_type` value} for every node tracker."""
        return {int(k): r[val_type] for k, r in self.results["node_trackers"].items()}

    def true_attack_path(self):
        """Return the attack path stored on the owning attack set."""
        return self.attack.path

class Baseline():
    """Per-node and overall baseline statistics, callable to correct a measured value.

    `baseline` is a dict with a "node_baselines" entry (address -> stats dict with
    "mean_<val_type>" / "max_<val_type>" keys) and an "overall_baseline" entry.
    """

    def __init__(self, tree, baseline):
        self.tree = tree
        self.baseline = baseline
        self.overall_baseline = baseline["overall_baseline"]
        # Addresses are converted to int for lookup.
        self.node_baselines = {
            int(address): stats for address, stats in baseline["node_baselines"].items()
        }

    def __call__(self, address, val_type, value, correction_factor = 0):
        """Return `value` corrected by the node's baseline, or unchanged if the node has none.

        The correction subtracts the baseline max and, scaled by `correction_factor`,
        the gap between the baseline max and the baseline mean.
        """
        if address not in self.node_baselines:
            return value
        stats = self.node_baselines[address]
        base_mean = stats[f"mean_{val_type}"]
        base_max = stats[f"max_{val_type}"]
        return value - base_max - correction_factor*(base_max - base_mean)

    def overall(self, val_type):
        """Return the (mean, max) pair of the overall baseline for `val_type`."""
        overall_mean = self.overall_baseline[f"mean_{val_type}"]
        overall_max = self.overall_baseline[f"max_{val_type}"]
        return overall_mean, overall_max

class AllData():
    """Loads every tree with tree_index below `limit` and all attack results on those trees.

    Attributes:
      trees       -- (tree_index, leaf_cutoff) -> TreeData
      attack_sets -- (tree_index, leaf_cutoff, model_index, attack_index) -> TreeAttackSet
      attacks     -- attack rate -> sample rate -> list of TreeAttack
    """

    def __init__(self, limit=48):
        tree_dir = "/localdata/sorel/covertrees/"
        attack_dir = "/localdata/sorel/covertrees/test_set_attack_results/"

        self.trees = {}
        self.attack_sets = {}
        self.attacks = defaultdict(lambda: defaultdict(list))
        for tree_name in os.listdir(tree_dir):
            # Entries containing "test" (the baseline / attack-result folders) are not trees.
            if "test" in tree_name:
                continue
            p = tree_name.split(".")[0].split("_")
            leaf_cutoff, tree_index = int(p[1]), int(p[2])
            tree_key = (tree_index, leaf_cutoff)
            # Each tree is stored as two files (child_parent and node_address) that
            # parse to the same key; load it only once instead of once per file.
            if tree_index < limit and tree_key not in self.trees:
                self.trees[tree_key] = TreeData(tree_index, leaf_cutoff)

        for attack_name in os.listdir(attack_dir):
            # Attack-path files are read by TreeAttackSet itself.
            if "path" in attack_name:
                continue
            p = attack_name[:-len(".json")].split("_")
            model_index, leaf_cutoff, tree_index, attack_index, attack_rate, sample_rate = int(p[1]), int(p[3]), int(p[4]), int(p[6]), float(p[7]), int(p[8])
            tree_key = (tree_index, leaf_cutoff)
            if tree_key not in self.trees:
                continue
            set_key = (tree_index, leaf_cutoff, model_index, attack_index)
            if set_key in self.attack_sets:
                attack_set = self.attack_sets[set_key]
            else:
                attack_set = self.trees[tree_key].attack_set(attack_index, model_index)
                self.attack_sets[set_key] = attack_set
            self.attacks[attack_rate][sample_rate].append(attack_set.attack(attack_rate, sample_rate))



In [None]:
# Shared dataset read by the analysis functions in the following cells.
# AllData(5) keeps only trees whose tree_index is below 5, plus their attack results.
attack_data = AllData(5)

In [None]:
import pandas as pd
def predict_attack(attack_vals, correction):
    """Return the address with the largest value that exceeds `correction`.

    Parameters:
      attack_vals -- dict mapping node address -> value
      correction  -- threshold a value must strictly exceed to count as a detection

    Returns the address of the largest qualifying value (the first one seen on
    ties), or None when no value exceeds `correction`.
    """
    max_val = 0.0
    max_addr = None
    for k, v in attack_vals.items():
        corrected = v - correction
        if corrected > max_val:
            # Track the corrected value so later candidates are compared on the
            # same scale. Storing raw `v` made them need to beat the current best
            # by more than `correction`, which could hide the true maximum.
            max_val = corrected
            max_addr = k
    return max_addr

def prediction_efficacy(value_type, prediction_correction_factor, baseline_correction_factor):
    """Tabulate how often predict_attack lands on the true attack path.

    Parameters:
      value_type                   -- node value name passed to corrected_val_map (e.g. "kl_div", "mll")
      prediction_correction_factor -- threshold handed to predict_attack
      baseline_correction_factor   -- correction_factor handed to corrected_val_map

    Returns a DataFrame with one column per attack rate and one row per sample
    rate. Each cell is a (success fraction, failure fraction) tuple, where a
    failure is a prediction that is not on the true attack path. Cells with no
    recorded attacks hold (nan, nan).
    """
    attack_rates = [0.0, 0.0001, 0.001, 0.01, 0.1, 1.0]
    sample_rates = [1000, 10000, 100000]

    efficacy = defaultdict(lambda: defaultdict(list))
    for ar in attack_rates:
        for sr in sample_rates:
            # .get() avoids inserting empty entries into the attack_data.attacks defaultdict.
            attacks = attack_data.attacks.get(ar, {}).get(sr, [])
            success = 0
            failure = 0
            for attack in attacks:
                predicted_attack = predict_attack(attack.corrected_val_map(value_type, baseline_correction_factor), prediction_correction_factor)
                if predicted_attack in attack.true_attack_path():
                    success += 1
                elif predicted_attack is not None:
                    failure += 1
            total = len(attacks)
            if total == 0:
                # No data for this combination: report NaN instead of dividing by zero.
                efficacy[ar][sr] = (float("nan"), float("nan"))
            else:
                efficacy[ar][sr] = (success / total, failure / total)

    return pd.DataFrame(efficacy)

In [None]:
def paired_prediction_efficacy(hyper):
    """Tabulate prediction efficacy using KL divergence first and MLL as a fallback.

    `hyper` maps each sample rate to a dict with keys "kl_core" / "mll_core"
    (correction_factor for corrected_val_map) and "kl_str" / "mll_str"
    (threshold for predict_attack). The MLL prediction is consulted only when the
    KL prediction is None.

    Returns a DataFrame with one column per attack rate and one row per sample
    rate. Each cell is a (success fraction, failure fraction) tuple; cells with
    no recorded attacks hold (nan, nan).
    """
    attack_rates = [0.0, 0.0001, 0.001, 0.01, 0.1, 1.0]
    sample_rates = [1000, 10000, 100000]

    efficacy = defaultdict(lambda: defaultdict(list))
    for ar in attack_rates:
        for sr in sample_rates:
            # .get() avoids inserting empty entries into the attack_data.attacks defaultdict.
            attacks = attack_data.attacks.get(ar, {}).get(sr, [])
            success = 0
            failure = 0
            for attack in attacks:
                predicted = predict_attack(attack.corrected_val_map("kl_div", hyper[sr]["kl_core"]), hyper[sr]["kl_str"])
                if predicted is None:
                    # KL found nothing above its threshold: fall back to MLL.
                    predicted = predict_attack(attack.corrected_val_map("mll", hyper[sr]["mll_core"]), hyper[sr]["mll_str"])
                if predicted in attack.true_attack_path():
                    success += 1
                elif predicted is not None:
                    failure += 1
            total = len(attacks)
            if total == 0:
                # No data for this combination: report NaN instead of dividing by zero.
                efficacy[ar][sr] = (float("nan"), float("nan"))
            else:
                efficacy[ar][sr] = (success / total, failure / total)

    return pd.DataFrame(efficacy)
# Settings per sample rate, consumed by paired_prediction_efficacy:
# "*_core" is passed as the baseline correction_factor and "*_str" as the
# predict_attack threshold, for the KL-divergence and MLL values respectively.
hyper = {
    1000:  {"kl_str": 3, "kl_core":4.5, "mll_str":11, "mll_core":4.5},
    10000: {"kl_str": 5, "kl_core":3, "mll_str":14, "mll_core":3.5},
    100000: {"kl_str": 6, "kl_core":3, "mll_str":17, "mll_core":2.5}}
# Last expression of the cell: the resulting DataFrame is displayed.
paired_prediction_efficacy(hyper)

In [None]:
from sklearn.linear_model import LinearRegression
def ml_kl_scatter(ar,sr):
    """Scatter baseline spread against excess over the baseline max, for MLL and KL divergence.

    For every attack recorded at attack rate `ar` and sample rate `sr`, each node
    whose value exceeds its baseline max contributes one point:
    x = baseline max - baseline mean, y = value - baseline max.
    MLL points go on the upper axes (orange), KL-divergence points on the lower axes.
    """
    def collect(attack, val_type, max_key, mean_key, spreads, excesses):
        # Append one (spread, excess) point per node that exceeds its baseline max.
        node_baselines = attack.baseline.node_baselines
        for address, value in attack.val_map(val_type).items():
            if address not in node_baselines:
                continue
            stats = node_baselines[address]
            if value - stats[max_key] > 0:
                excesses.append(value - stats[max_key])
                spreads.append(stats[max_key] - stats[mean_key])

    kl_spread, kl_excess = [], []
    mll_spread, mll_excess = [], []
    fig, (mll_ax, kl_ax) = plt.subplots(2)
    for attack in attack_data.attacks[ar][sr]:
        collect(attack, "kl_div", "max_kl_div", "mean_kl_div", kl_spread, kl_excess)
        collect(attack, "mll", "max_mll", "mean_mll", mll_spread, mll_excess)

    mll_ax.scatter(np.array(mll_spread), np.array(mll_excess), color="orange")
    kl_ax.scatter(np.array(kl_spread).flatten(), np.array(kl_excess))
    plt.show()


In [None]:
# Baseline-excess scatter plots for attack rate 0.0 at sample rate 1000.
ml_kl_scatter(0.0,1000)

In [None]:
# Baseline-excess scatter plots for attack rate 0.0 at sample rate 10000.
ml_kl_scatter(0.0,10000)

In [None]:
# Baseline-excess scatter plots for attack rate 0.0 at sample rate 100000.
ml_kl_scatter(0.0,100000)
