In [None]:
import os
import numpy as np
import json
from collections import defaultdict
import matplotlib.pyplot as plt

class TreeData():
    """Cover-tree structure loaded from two JSON files, with heatmap layout helpers.

    Attributes:
        tree_index, leaf_cutoff: identifiers used to build the JSON file names.
        tree_dir: directory the JSON files are read from.
        child_parent: dict mapping a child node address to its parent address.
        node_address: dict mapping a node address to the value stored in the
            JSON file; its first element is used as the node's scale (the
            second element is not interpreted by this class).
        root: node reached by following ``child_parent`` from the first node
            in ``node_address`` until no parent exists.
        tree_layout: defaultdict mapping a parent address to its children.
        max_scale, min_scale: extreme scale values over all nodes.
    """

    def __init__(self, tree_index, leaf_cutoff, tree_dir="/localdata/sorel/covertrees/"):
        """Load the tree identified by ``(leaf_cutoff, tree_index)`` from ``tree_dir``.

        Raises:
            OSError: if either JSON file cannot be opened.
        """
        self.tree_index, self.leaf_cutoff = tree_index, leaf_cutoff
        self.tree_dir = tree_dir
        self._open_tree_data()

        scales = [address[0] for address in self.node_address.values()]
        self.max_scale = np.max(scales)
        self.min_scale = np.min(scales)
        # Lazily built by _value_heatmap on first use.
        self._node_heatmap, self._heatmap_map = None, None

    def _open_tree_data(self):
        """Read the child->parent and node-address JSON files and derive root/layout."""
        prefix = f"tree_{self.leaf_cutoff}_{self.tree_index}"
        with open(os.path.join(self.tree_dir, f"{prefix}_child_parent.json")) as f:
            self.child_parent = {int(k): int(v) for k, v in json.load(f).items()}
        with open(os.path.join(self.tree_dir, f"{prefix}_node_address.json")) as f:
            self.node_address = {int(k): v for k, v in json.load(f).items()}

        # Start anywhere and climb until a node without a parent is reached.
        self.root = next(iter(self.node_address))
        while self.root in self.child_parent:
            self.root = self.child_parent[self.root]

        self.tree_layout = defaultdict(list)
        for child, parent in self.child_parent.items():
            self.tree_layout[parent].append(child)

    def _value_heatmap(self, vals):
        """Paint per-node values onto the cached heatmap layout.

        Args:
            vals: dict mapping node address -> value.

        Returns:
            A float array with the same shape as the layout; every cell that
            belongs to a node in ``vals`` holds that node's value, all other
            cells are 0.
        """
        if self._node_heatmap is None:
            self._node_heatmap, self._heatmap_map, _ = self.create_heatmap(self.root)

        final_heatmap = np.zeros(self._node_heatmap.shape[:2])
        for address, value in vals.items():
            # .get avoids inserting empty entries into the defaultdict for
            # addresses that have no cell in the layout.
            for row, col in self._heatmap_map.get(address, ()):
                final_heatmap[row, col] = value
        return final_heatmap

    def create_heatmap(self, root):
        """Lay out the subtree under ``root`` as a rectangular grid of cells.

        Rows are indexed by ``max_scale - scale``. For every leaf reached from
        ``root`` one cell is appended for the leaf and one for each of its
        ancestors (following ``child_parent`` all the way up, which may pass
        above ``root``); afterwards all rows are zero-padded to equal width so
        the next leaf starts in a fresh column.

        Returns:
            (node_heatmap, heatmap_map, included_nodes) where ``node_heatmap``
            is a zero-filled numpy array of the layout, ``heatmap_map`` maps a
            node address to the list of ``[row, col]`` cells it occupies and
            ``included_nodes`` is the set of nodes visited below ``root``.
        """
        def pad_rows(rows):
            # Right-pad every row with zeros up to the current widest row.
            width = max((len(row) for row in rows), default=0)
            for row in rows:
                row.extend([0] * (width - len(row)))
            return rows

        node_heatmap = [[] for _ in range(self.max_scale - self.min_scale + 1)]
        heatmap_map = defaultdict(list)
        included_nodes = set()

        def add_cell(node):
            # Append one cell for `node` at the end of the row for its scale.
            scale_index = max(self.node_address[node][0], self.min_scale)
            row = self.max_scale - scale_index
            heatmap_map[node].append([row, len(node_heatmap[row])])
            node_heatmap[row].append(0)

        unvisited_nodes = [root]
        while unvisited_nodes:
            node = unvisited_nodes.pop()
            included_nodes.add(node)
            if node in self.tree_layout:
                unvisited_nodes.extend(self.tree_layout[node])
                continue

            # Leaf: emit the leaf cell and one cell per ancestor.
            add_cell(node)
            while node in self.child_parent:
                node = self.child_parent[node]
                add_cell(node)
            node_heatmap = pad_rows(node_heatmap)
        return np.array(node_heatmap), heatmap_map, included_nodes

    def path_to_node(self, node):
        """Return the list of addresses from the top-most ancestor down to ``node``."""
        path = [node]
        while node in self.child_parent:
            node = self.child_parent[node]
            path.append(node)
        path.reverse()
        return path

    def baseline(self, sample_rate):
        """Build the ``BayesianBaseline`` for this tree at ``sample_rate``.

        ``BayesianBaseline`` only accepts keyword arguments, so everything is
        passed by name.
        """
        return BayesianBaseline(
            sample_rate=sample_rate,
            tree=self,
            leaf_cutoff=self.leaf_cutoff,
            tree_index=self.tree_index,
        )


In [None]:
class Attack():
    """One recorded attack run, scored against a baseline to guess the attacked node.

    Attributes:
        baseline: object providing ``kl_div_baseline``/``mll_baseline`` lookups
            and the ``kl_cor``, ``kl_str``, ``mll_cor``, ``mll_str`` hyper-parameters.
        attack: dict with a ``"node_trackers"`` mapping of node address to a
            dict holding ``"kl_div"`` and ``"mll"``.
        true_path: the path stored alongside the attack.
    """

    def __init__(self, baseline, attack, path):
        self.baseline = baseline
        self.attack = attack
        self.true_path = path

    def kl_divs_predict(self):
        """Return the node whose corrected KL divergence is largest and positive.

        Nodes without a baseline entry are ignored. Returns None when no node
        scores strictly above zero.
        """
        reference = self.baseline
        best_addr, best_score = None, 0.0
        for addr, tracker in self.attack["node_trackers"].items():
            node_baseline = reference.kl_div_baseline(addr)
            if node_baseline is None:
                continue
            max_kl_div, std_kl_div = node_baseline
            score = tracker["kl_div"] - max_kl_div - reference.kl_cor * std_kl_div  - reference.kl_str
            if score > best_score:
                best_addr, best_score = addr, score
        return best_addr

    def mlls_predict(self):
        """Return the node whose corrected MLL deficit is largest and positive.

        Nodes without a baseline entry are ignored. Returns None when no node
        scores strictly above zero.
        """
        reference = self.baseline
        best_addr, best_score = None, 0.0
        for addr, tracker in self.attack["node_trackers"].items():
            node_baseline = reference.mll_baseline(addr)
            if node_baseline is None:
                continue
            min_mll, std_mll = node_baseline
            score = min_mll - tracker["mll"] - reference.mll_cor * std_mll - reference.mll_str
            if score > best_score:
                best_addr, best_score = addr, score
        return best_addr

    def predict_attack(self):
        """Prefer the KL-based prediction; fall back to the MLL-based one."""
        kl_prediction = self.kl_divs_predict()
        return kl_prediction if kl_prediction is not None else self.mlls_predict()
            
class BayesianBaseline():
    """Per-tree baseline statistics loaded from JSON, plus access to attack results.

    Attributes:
        leaf_cutoff, tree_index, sample_rate: identifiers used in file names.
        baseline: the decoded baseline JSON; ``baseline["node_baselines"]`` is
            re-keyed by integer node address.
        kl_str, kl_cor, mll_str, mll_cor: prediction hyper-parameters; they
            only exist after ``set_prediction_hypers`` has been called.
    """

    # Directories the JSON files are read from; override on a subclass or
    # instance to point at a different data location.
    BASELINE_DIR = "/localdata/sorel/covertrees/test_set_baselines"
    ATTACK_DIR = "/localdata/sorel/covertrees/test_set_attack_results"

    def __init__(self, *, sample_rate, tree=None, leaf_cutoff=None, tree_index=None):
        """Load the baseline file for ``(leaf_cutoff, tree_index, sample_rate)``.

        Args:
            sample_rate: value embedded in the baseline file name.
            tree: optional already-loaded tree object; when given, missing
                ``leaf_cutoff``/``tree_index`` are taken from it.
            leaf_cutoff, tree_index: tree identifiers embedded in file names.

        Raises:
            OSError: if the baseline file cannot be opened.
        """
        # A tree already knows its identifiers; without this the file name
        # would be built from the literal string "None".
        if tree is not None:
            if leaf_cutoff is None:
                leaf_cutoff = getattr(tree, "leaf_cutoff", None)
            if tree_index is None:
                tree_index = getattr(tree, "tree_index", None)

        self.leaf_cutoff, self.tree_index, self.sample_rate = leaf_cutoff, tree_index, sample_rate
        self._tree = tree
        baseline_file = f"tree_{self.leaf_cutoff}_{self.tree_index}_baseline_{self.sample_rate}.json"
        with open(os.path.join(self.BASELINE_DIR, baseline_file)) as f:
            self.baseline = json.load(f)
        self.baseline["node_baselines"] = {int(k): v for k, v in self.baseline["node_baselines"].items()}
        self._loo_violators = None

    @property
    def tree(self):
        """The tree for this baseline, loaded on first access if not supplied."""
        if self._tree is None:
            self._tree = TreeData(self.tree_index, self.leaf_cutoff)
        return self._tree

    def attacks(self, attack_rate, model_indices=range(5), attack_indices=range(10)):
        """Load every available attack result for this tree at ``attack_rate``.

        Combinations of ``(model_index, attack_index)`` whose files are missing
        or cannot be parsed are skipped.

        Returns:
            list of ``Attack`` objects, in model-major order.
        """
        attacks = []
        prefix_template = f"tree_{self.leaf_cutoff}_{self.tree_index}_attack"
        for model_index in model_indices:
            for attack_index in attack_indices:
                stem = f"model_{model_index}_{prefix_template}_{attack_index}"
                attack_filename = f"{stem}_{attack_rate}_{self.sample_rate}.json"
                path_filename = f"{stem}_attack_path.json"
                try:
                    with open(os.path.join(self.ATTACK_DIR, attack_filename)) as f:
                        attack = json.load(f)
                    attack["node_trackers"] = {int(k): v for k, v in attack["node_trackers"].items()}
                    with open(os.path.join(self.ATTACK_DIR, path_filename)) as f:
                        path = [int(i) for i in json.load(f)]
                except (OSError, ValueError, KeyError, TypeError):
                    # Best-effort load: not every combination exists on disk.
                    # Only I/O and decoding problems are skipped, so that
                    # KeyboardInterrupt and programming errors still surface.
                    continue
                attacks.append(Attack(self, attack, path))
        return attacks

    def set_prediction_hypers(self, *, kl_str, kl_cor, mll_str, mll_cor):
        """Store the four hyper-parameters read by ``Attack`` when scoring nodes."""
        self.kl_str, self.kl_cor, self.mll_str, self.mll_cor = kl_str, kl_cor, mll_str, mll_cor

    def _node_baseline(self, address, first_key, second_key):
        # Shared lookup: pair of stored statistics for `address`, or None.
        entry = self.baseline["node_baselines"].get(address)
        if entry is None:
            return None
        return entry[first_key], entry[second_key]

    def mll_baseline(self, address=None):
        """Return ``(min_mll, std_mll)`` for a node, or None if it has no entry.

        With ``address=None`` the stored ``"overall_baseline"`` is returned as is.
        """
        if address is None:
            return self.baseline["overall_baseline"]
        return self._node_baseline(address, "min_mll", "std_mll")

    def kl_div_baseline(self, address=None):
        """Return ``(max_kl_div, std_kl_div)`` for a node, or None if it has no entry.

        With ``address=None`` the stored ``"overall_baseline"`` is returned as is.
        """
        if address is None:
            return self.baseline["overall_baseline"]
        return self._node_baseline(address, "max_kl_div", "std_kl_div")


In [None]:
def stats_for_sample_rate(sample_rate, *, kl_str, kl_cor, mll_str, mll_cor,
                          leaf_cutoff=500, tree_indices=range(48),
                          attack_rates=(0, 0.0001, 0.001, 0.01, 0.1, 1)):
    """Score every available attack at ``sample_rate`` and tally the outcomes.

    Args:
        sample_rate: passed to ``BayesianBaseline`` to select the baseline files.
        kl_str, kl_cor, mll_str, mll_cor: prediction hyper-parameters forwarded
            to ``BayesianBaseline.set_prediction_hypers``.
        leaf_cutoff: tree identifier forwarded to ``BayesianBaseline``.
        tree_indices: iterable of tree indices to evaluate.
        attack_rates: iterable of attack rates to evaluate.

    Returns:
        A 6-tuple of dicts keyed by attack rate:
        ``(successes, penetration, failures, total, overall_kl_div, overall_mll)``.
        ``successes`` counts predictions that lie on the attack's true path,
        ``penetration`` sums ``index_of_prediction / len(true_path)`` over those
        successes, ``failures`` counts non-None predictions off the path,
        ``total`` counts all attacks, and the last two collect each attack's
        overall ``kl_div`` and ``mll`` values.
    """
    attack_rates = list(attack_rates)
    successes = {ar: 0 for ar in attack_rates}
    penetration = {ar: 0 for ar in attack_rates}
    failures = {ar: 0 for ar in attack_rates}
    total = {ar: 0 for ar in attack_rates}
    overall_kl_div = {ar: [] for ar in attack_rates}
    overall_mll = {ar: [] for ar in attack_rates}

    for tree_index in tree_indices:
        baseline = BayesianBaseline(sample_rate=sample_rate, leaf_cutoff=leaf_cutoff, tree_index=tree_index)
        baseline.set_prediction_hypers(kl_str=kl_str, kl_cor=kl_cor, mll_str=mll_str, mll_cor=mll_cor)
        for attack_rate in attack_rates:
            for attack in baseline.attacks(attack_rate):
                overall = attack.attack["overall_tracker"]
                overall_kl_div[attack_rate].append(overall["kl_div"])
                overall_mll[attack_rate].append(overall["mll"])

                pred = attack.predict_attack()
                if pred in attack.true_path:
                    # Relative depth of the predicted node along the true path.
                    penetration[attack_rate] += attack.true_path.index(pred) / len(attack.true_path)
                    successes[attack_rate] += 1
                elif pred is not None:
                    failures[attack_rate] += 1
                total[attack_rate] += 1
    return successes, penetration, failures, total, overall_kl_div, overall_mll

In [None]:
# Evaluate sample_rate=100000 with its own prediction hyper-parameters,
# then show the resulting tuple as the cell output.
stats_100000 = stats_for_sample_rate(
    100000,
    kl_str=40,
    kl_cor=10,
    mll_str=50,
    mll_cor=1.7,
)
stats_100000

In [None]:
# Evaluate sample_rate=10000 with its own prediction hyper-parameters,
# then show the resulting tuple as the cell output.
stats_10000 = stats_for_sample_rate(
    10000,
    kl_str=6.5,
    kl_cor=10,
    mll_str=20,
    mll_cor=1.3,
)
stats_10000

In [None]:
# Evaluate sample_rate=1000 with its own prediction hyper-parameters,
# then show the resulting tuple as the cell output.
stats_1000 = stats_for_sample_rate(
    1000,
    kl_str=8,
    kl_cor=7,
    mll_str=20,
    mll_cor=1.3,
)
stats_1000

In [None]:
def attack_stats_to_table(attack_stats):
    """Turn the raw tallies from ``stats_for_sample_rate`` into summary tables.

    Args:
        attack_stats: the 6-tuple ``(successes, penetration, failures, total,
            overall_kl_div, overall_mll)`` of dicts keyed by attack rate. The
            key ``0`` must be present; it is used as the reference row.

    Returns:
        ``(attack_table, kl_div_table, mll_table)``, each keyed by attack rate.

        * ``attack_table[0]`` is ``(0, false_positive_rate)`` where every
          non-None prediction at rate 0 counts as a false positive. For other
          rates it is ``(true_positive_rate, false_positive_rate,
          mean_penetration)``; ``mean_penetration`` is None when there were
          no successes.
        * ``kl_div_table[ar]`` / ``mll_table[ar]`` are ``(shift, spread)``
          pairs relative to the rate-0 distribution; the rate-0 row is
          ``(0, 1)`` by construction.

        A rate with no recorded attacks (``total == 0``) yields NaN for its
        rates instead of raising ``ZeroDivisionError``.
    """
    successes, penetration, failures, total, overall_kl_div, overall_mll = attack_stats

    def rate(count, denominator):
        # Fraction of attacks, or NaN when nothing was recorded for that rate.
        return float(count) / denominator if denominator else float("nan")

    zero_mean_kl_div = np.mean(overall_kl_div[0])
    zero_std_kl_div = np.std(overall_kl_div[0])
    zero_mean_mll = np.mean(overall_mll[0])
    zero_std_mll = np.std(overall_mll[0])

    # With no attack present there is no true path to hit, so any prediction
    # (on-path or not) is counted as a false positive.
    attack_table = {0: (0, rate(successes[0] + failures[0], total[0]))}
    kl_div_table = {0: (0, 1)}
    mll_table = {0: (0, 1)}

    for ar in successes.keys():
        if ar == 0:
            continue
        mean_penetration = float(penetration[ar]) / successes[ar] if successes[ar] > 0 else None
        attack_table[ar] = (rate(successes[ar], total[ar]), rate(failures[ar], total[ar]), mean_penetration)

        mean_kl_div, std_kl_div = np.mean(overall_kl_div[ar]), np.std(overall_kl_div[ar])
        mean_mll, std_mll = np.mean(overall_mll[ar]), np.std(overall_mll[ar])
        # NOTE(review): the KL shift is normalised by the rate-0 *mean* while
        # the MLL shift is normalised by the rate-0 *std* -- kept as in the
        # original; confirm this asymmetry is intended.
        kl_div_table[ar] = ((mean_kl_div - zero_mean_kl_div) / zero_mean_kl_div, std_kl_div / zero_std_kl_div)
        mll_table[ar] = ((mean_mll - zero_mean_mll) / zero_std_mll, std_mll / zero_std_mll)
    return attack_table, kl_div_table, mll_table
# Build the three summary tables keyed by sample rate. Each call returns
# (attack_table, kl_div_table, mll_table); zip(*...) regroups the results so
# that every output dict collects one kind of table across all sample rates.
complete_attack_table, complete_kl_div_table, complete_mll_table = (
    dict(per_kind)
    for per_kind in zip(*(
        [(rate, table) for table in attack_stats_to_table(stats)]
        for rate, stats in (
            (1000, stats_1000),
            (10000, stats_10000),
            (100000, stats_100000),
        )
    ))
)

In [None]:
# Display detection rates: {sample_rate: {attack_rate: (true_pos, false_pos[, mean_penetration])}}.
complete_attack_table

In [None]:
# Display overall KL-divergence shift/spread relative to attack rate 0, per sample rate.
complete_kl_div_table

In [None]:
# Display overall MLL shift/spread relative to attack rate 0, per sample rate.
complete_mll_table