In [None]:
import jax
from jax import numpy as jnp
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText

In [None]:
from jax import make_jaxpr

In [None]:
# Directory holding the saved .npy arrays for this experiment (relative to the notebook).
path = "./20-models/"
# Per-sample target labels, cast to float; compared against the noisy predictions below.
targets = np.load(path + "targets-multiclass-model(1)-labels-(mode:random)-(threshold:300.0)-(sigma-gnmax:40.0)-(sigma-threshold:200.0)-(budget:16.00)-transfer-.npy").astype(float)
# raw_queries = np.load(path + "raw-queries-multiclass-model(1)-labels-(mode:random)-(threshold:300.0)-(sigma-gnmax:40.0)-(sigma-threshold:200.0)-(budget:16.00)-transfer-.npy")
# Raw vote array, cast to float; used below as a 2-D array reduced along axis 1 with max / argmax.
raw_votes = np.load(path + "model(1)-raw-votes-mode-random-vote-type-discrete.npy").astype(float)
# raw_votes_attacker = np.load(path + "model(1)-raw-votes-(mode-random)-dataset-colormnist-attacker-.npy")
# aggreagted_labels = np.load(path + "aggregated-labels-multiclass-model(1)-labels-(mode:random)-(threshold:300.0)-(sigma-gnmax:40.0)-(sigma-threshold:200.0)-(budget:16.00)-transfer-.npy")

In [None]:
def one_hot(x, k, dtype=jnp.float32):
    """Encode the 1-D label array ``x`` as rows of length ``k``.

    Row ``i`` holds a one in the column whose index equals ``x[i]`` and zeros
    elsewhere (an all-zero row when no column index matches). The result is
    returned with the requested ``dtype``.
    """
    class_ids = jnp.arange(k)
    # Broadcast (n, 1) against (k,) to get an (n, k) boolean match table.
    matches = x[:, None] == class_ids
    return jnp.asarray(matches, dtype=dtype)

In [None]:
# Root PRNG key (seed 0); later cells split it to obtain the subkeys used for the noise draws.
key = jax.random.PRNGKey(0)

In [None]:
def query(sigma_threshold, sigma_gnmax, threshold, subkey1, subkey2, func):
    """Vectorised noisy-threshold + noisy-argmax query over all samples.

    Reads the module-level ``raw_votes`` (2-D, samples x classes) and
    ``targets`` arrays.

    Args:
        sigma_threshold: Scale of the Gaussian noise added to each sample's
            maximum vote count before it is compared with ``threshold``.
        sigma_gnmax: Scale of the Gaussian noise added to every vote entry
            before the argmax that yields the prediction.
        threshold: A sample counts as answered only when its noisy maximum
            vote count is strictly greater than this value.
        subkey1: PRNG key for the threshold noise.
        subkey2: PRNG key for the argmax noise.
        func: Callable applied to ``threshold``; its value is the weight given
            to an answered sample. Unanswered samples get weight 0.0.

    Returns:
        Scalar: the sum over samples of ``weight * [prediction == target]``
        divided by the number of samples.
    """
    votes = raw_votes
    # Take both sizes from the data instead of hard-coding the class count.
    num_samples, num_classes = votes.shape

    # Threshold mechanism
    noise_threshold = sigma_threshold * jax.random.normal(subkey1, [num_samples])
    noisy_vote_counts = jnp.max(votes, axis=1) + noise_threshold
    # Element-wise select; equivalent to applying a per-sample conditional,
    # and `threshold` only receives gradient through `func`.
    answered = jnp.where(noisy_vote_counts > threshold, func(threshold), 0.0)

    # GNMax mechanism
    noise_gnmax = sigma_gnmax * jax.random.normal(subkey2, [num_samples, num_classes])
    preds = (votes + noise_gnmax).argmax(axis=1).astype(float)

    # The one-hot product has a single 1 per row exactly where prediction and
    # target agree, so the weighted sum counts answered-and-correct samples.
    preds_one_hot = one_hot(preds, num_classes)
    targets_one_hot = one_hot(targets, num_classes)
    accuracy = jnp.sum(answered[:, None] * (preds_one_hot * targets_one_hot)) / num_samples
    return accuracy
# Advance `key` and derive two subkeys: subkey1 for the threshold noise, subkey2 for the argmax noise.
key, subkey1, subkey2 = jax.random.split(key, 3)
# Example evaluation with a constant weight of 1.0 for answered samples.
query(50.0, 5.0, 2.0, subkey1, subkey2, lambda x: 1.0)

In [None]:
# `query` as a function of the threshold only: sigmas fixed at 50.0 / 5.0, constant weight 1.0.
query_ = lambda t: query(50.0, 5.0, t, subkey1, subkey2, lambda x: 1.0)
# Jit-compiled version of the same single-argument function.
query_jitted = jax.jit(query_)
# Step size; note that a later cell reassigns `eps` to 1e-4.
eps=1e-0

In [None]:
def finite_diff(func, x, eps=1e-4):
    """Central-difference estimate of the derivative of ``func`` at ``x``.

    Evaluates ``func`` half a step above and half a step below ``x`` and
    divides the difference by the full step ``eps``.
    """
    half_step = eps/2
    upper = func(x + half_step)
    lower = func(x - half_step)
    return (upper - lower)/eps

In [None]:
# Central difference of `query_` at threshold 2.0 with a step of 1.
finite_diff(query_, 2.0, eps=1)

In [None]:
# 10000 step sizes evenly spaced between 1e-4 and 1.
eps_list = jnp.linspace(1e-4,1, 10000)
# Central difference of `query_` at threshold 20.0 for every step size, vectorised with vmap.
query_fdiff_eps = jax.vmap(lambda eps: finite_diff(query_, 20.0, eps=eps))(eps_list) 

In [None]:
# Plot the finite-difference estimate against the step size used to compute it.
fig, ax = plt.subplots(dpi=150)
sns.lineplot(x=eps_list, y=query_fdiff_eps, ax=ax)
ax.set_ylabel(r"$\frac{f(x_0 + \frac{\epsilon}{2}) - f(x_0 - \frac{\epsilon}{2})}{\epsilon}$")
ax.set_xlabel(r"$\epsilon$")
# Sets the same DPI that was already passed to plt.subplots above.
fig.set_dpi(150)

In [None]:
# Central-difference estimate of d(query_)/d(threshold) at threshold 2.0,
# computed with the `finite_diff` helper and a step of 1e-4.
eps = 1e-4
query_fdif = finite_diff(query_, 2., eps=eps)
query_fdif

In [None]:
# Autodiff derivative of `query` with respect to the threshold (same fixed arguments as `query_`).
query_driv = jax.grad(lambda t: query(50.0, 5.0, t, subkey1, subkey2, lambda x: 1.0))
# Evaluate it at threshold 2.0.
query_driv(2.0)

In [None]:
def _plot(columns=["Threshold", "Accuracy"], argnum=2):
    """Plot ``query`` and its gradient while sweeping one of its first three arguments.

    Args:
        columns: ``[x_label, y_label]``; used as DataFrame column names and
            therefore as axis labels.
        argnum: Index of the swept ``query`` argument (0: sigma_threshold,
            1: sigma_gnmax, 2: threshold). The other two are held at the
            values in ``defaults``.

    Draws a 2x3 figure: the top row shows the value of ``query`` over the
    sweep, the bottom row the derivative of that same curve with respect to
    the swept argument; one column per weighting function. Returns None.
    """
    fig, axes = plt.subplots(2, 3, sharex=True, dpi=150, figsize=(20, 5))
    defaults = [20.0, 5.0, 2.0]
    grad_column = f"$grad. {columns[1]}$"

    def _subplot(func, ax_id, title):
        a_range = jnp.logspace(-7, 2, num=100000)

        def _swept(value):
            # Put `value` in slot `argnum`, defaults in the others.
            args = [value if _id == argnum else defaults[_id] for _id in range(len(defaults))]
            return query(*args, subkey1, subkey2, func)

        loss_over_a_range = jax.vmap(_swept)(a_range)
        data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_over_a_range]), columns=columns)
        sns.lineplot(data=data, y=columns[1], x=columns[0], ax=axes[0, ax_id])

        # Differentiate the very function plotted above, so both rows describe
        # the same curve and the same swept argument.
        loss_driv_over_a_range = jax.vmap(jax.grad(_swept))(a_range)
        grad_data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_driv_over_a_range]),
                                 columns=[columns[0], grad_column])
        sns.lineplot(data=grad_data, x=columns[0], y=grad_column, ax=axes[1, ax_id])

        # Largest gradient value over the part of the sweep above 20.
        at = AnchoredText(f"Max Val ({columns[0]} > 20)={loss_driv_over_a_range[a_range > 20].max()}", prop=dict(size=10), frameon=True, loc='lower right')
        at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
        axes[1, ax_id].add_artist(at)
        axes[0, ax_id].set_title(title)

    # Constant weight of 1.0 for answered samples.
    _subplot(lambda x: 1.0, 0, "Step")
    # Weight is the logistic sigmoid (expit) of the threshold.
    _subplot(lambda x: jax.scipy.special.expit(x), 1, "Logit")
    # Weight is the threshold itself (identity).
    _subplot(lambda x: x, 2, "Relu")
    # return fig

In [None]:
# Default sweep: argument index 2 of `query` (the threshold), x-axis labelled "Threshold".
_plot()

In [None]:
# Select argument index 0 of `query` (sigma_threshold) as the swept input; x-axis labelled "sigma_1".
_plot(columns=["sigma_1", "Accuracy"], argnum=0)

In [None]:
# Select argument index 1 of `query` (sigma_gnmax) as the swept input; x-axis labelled "sigma_2".
_plot(columns=["sigma_2", "Accuracy"], argnum=1)

# Query Iter

In [None]:
def query_iter(sigma_threshold, sigma_gnmax, threshold, subkey1, subkey2, func):
    """Sample-by-sample (``jax.lax.scan``) version of the noisy query.

    Reads the module-level ``raw_votes`` (2-D, samples x classes) and
    ``targets`` arrays and walks over the samples in order, accumulating the
    weighted number of answered-and-correct predictions.

    Args:
        sigma_threshold: Scale of the Gaussian noise added to each sample's
            maximum vote count before it is compared with ``threshold``.
        sigma_gnmax: Scale of the Gaussian noise added to every vote entry
            before the argmax that yields the prediction.
        threshold: A sample counts as answered only when its noisy maximum
            vote count is strictly greater than this value.
        subkey1: PRNG key for the threshold noise.
        subkey2: PRNG key for the argmax noise.
        func: Callable applied to ``threshold``; its value is the weight given
            to an answered sample. Unanswered samples get weight 0.0.

    Returns:
        Scalar: accumulated ``weight * [prediction == target]`` divided by the
        number of samples.
    """
    votes = jnp.asarray(raw_votes)
    # Sizes come from the data rather than a hard-coded (1000, 10) shape.
    num_samples, num_classes = votes.shape
    noise_threshold = sigma_threshold * jax.random.normal(subkey1, [num_samples])
    noise_gnmax = sigma_gnmax * jax.random.normal(subkey2, [num_samples, num_classes])

    def _predict(acc, sample):
        # One scan step: `sample` is the slice of every input for one sample.
        _target, _vote, _noise_threshold, _noise_gnmax = sample
        noisy_vote_count = _vote.max() + _noise_threshold
        answered = jnp.where(noisy_vote_count > threshold, func(threshold), 0.0)
        pred = (_vote + _noise_gnmax).argmax()
        acc = acc + answered * (pred == _target)
        return acc, answered

    # Scan over a tuple of per-sample arrays; scan slices each along axis 0,
    # so no packing into a single concatenated tensor is needed.
    per_sample = (jnp.asarray(targets), votes, noise_threshold, noise_gnmax)
    total, _ = jax.lax.scan(_predict, jnp.zeros(()), per_sample)
    return total / num_samples

In [None]:
# Single evaluation of the scan-based query at threshold 0.22 with a constant weight of 1.0.
query_iter(50.0, 5.0, 0.22, subkey1, subkey2, lambda x: 1.0)

In [None]:
# Autodiff derivative of `query_iter` with respect to the threshold (sigmas fixed at 50.0 / 5.0, constant weight 1.0).
query_iter_driv = jax.grad(lambda t: query_iter(50.0, 5.0, t, subkey1, subkey2, lambda x: 1.0))

In [None]:
# Evaluate that derivative at threshold 20.0.
query_iter_driv(20.0)

In [None]:
def _plot(columns=["Threshold", "Accuracy"], argnum=2):
    """Plot ``query_iter`` and its gradient while sweeping one of its first three arguments.

    Args:
        columns: ``[x_label, y_label]``; used as DataFrame column names and
            therefore as axis labels.
        argnum: Index of the swept ``query_iter`` argument (0: sigma_threshold,
            1: sigma_gnmax, 2: threshold). The other two are held at the
            values in ``defaults``.

    Draws a 2x3 figure: the top row shows the value of ``query_iter`` over the
    sweep, the bottom row the derivative of that same curve with respect to
    the swept argument; one column per weighting function. Returns None.
    """
    fig, axes = plt.subplots(2, 3, sharex=True, dpi=150, figsize=(20, 5))
    defaults = [20.0, 5.0, 2.0]
    grad_column = f"$grad. {columns[1]}$"

    def _subplot(func, ax_id, title):
        a_range = jnp.logspace(-7, 2, num=10000)

        def _swept(value):
            # Put `value` in slot `argnum`, defaults in the others.
            args = [value if _id == argnum else defaults[_id] for _id in range(len(defaults))]
            return query_iter(*args, subkey1, subkey2, func)

        loss_over_a_range = jax.vmap(_swept)(a_range)
        data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_over_a_range]), columns=columns)
        sns.lineplot(data=data, y=columns[1], x=columns[0], ax=axes[0, ax_id])

        # Differentiate the very function plotted above, so both rows describe
        # the same curve and the same swept argument.
        loss_driv_over_a_range = jax.vmap(jax.grad(_swept))(a_range)
        grad_data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_driv_over_a_range]),
                                 columns=[columns[0], grad_column])
        sns.lineplot(data=grad_data, x=columns[0], y=grad_column, ax=axes[1, ax_id])

        # Largest gradient value over the part of the sweep above 20.
        at = AnchoredText(f"Max Val ({columns[0]} > 20)={loss_driv_over_a_range[a_range > 20].max()}", prop=dict(size=10), frameon=True, loc='lower right')
        at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
        axes[1, ax_id].add_artist(at)
        axes[0, ax_id].set_title(title)

    # Constant weight of 1.0 for answered samples.
    _subplot(lambda x: 1.0, 0, "Step")
    # Weight is the logistic sigmoid (expit) of the threshold.
    _subplot(lambda x: jax.scipy.special.expit(x), 1, "Logit")
    # Weight is the threshold itself (identity).
    _subplot(lambda x: x, 2, "Relu")
    # return fig

In [None]:
# Default sweep: argument index 2 of `query_iter` (the threshold), x-axis labelled "Threshold".
_plot()

In [None]:
# Select argument index 0 of `query_iter` (sigma_threshold) as the swept input; x-axis labelled "sigma_1".
_plot(columns=["sigma_1", "Accuracy"], argnum=0)

In [None]:
# Select argument index 1 of `query_iter` (sigma_gnmax) as the swept input; x-axis labelled "sigma_2".
_plot(columns=["sigma_2", "Accuracy"], argnum=1)

# FairPATE

In [None]:
# Synthetic sensitive-attribute value (0., 1. or 2.) for every sample, drawn with
# probabilities 0.1 / 0.3 / 0.6. A fixed-seed Generator makes the draw identical
# after a kernel restart, and the length follows `raw_votes` so the two arrays
# stay aligned.
sensitives = np.random.default_rng(0).choice(np.arange(3).astype(float), (raw_votes.shape[0],), p=[0.1, 0.3, 0.6])

In [None]:
from functools import partial
# `jax.jit` pre-configured to treat positional argument 2 as static.
# NOTE(review): not referenced anywhere else in the visible notebook.
jit_static1 = partial(jax.jit, static_argnums=2)

In [None]:
# Advance `key` again and replace subkey1 / subkey2 with fresh subkeys for the cells below.
key, subkey1, subkey2 = jax.random.split(key, 3)

In [None]:
def query_fair_iter(sigma_threshold, sigma_gnmax, threshold, max_fairness_violation, min_group_count, subkey1, subkey2):
    """Scan-based noisy query with a per-group fairness gate.

    Reads the module-level ``raw_votes`` (2-D, samples x classes), ``targets``
    and ``sensitives`` arrays and processes the samples in order. A sample
    that passes the noisy threshold test is additionally subject to a
    fairness check that depends on the samples answered before it.

    Args:
        sigma_threshold: Scale of the Gaussian noise added to each sample's
            maximum vote count before it is compared with ``threshold``.
        sigma_gnmax: Scale of the Gaussian noise added to every vote entry
            before the argmax that yields the prediction.
        threshold: A sample passes the threshold test only when its noisy
            maximum vote count is strictly greater than this value.
        max_fairness_violation: Upper bound (exclusive) on the sample's group
            gap for the sample to stay answered once the check applies.
        min_group_count: While fewer than this many samples of the sample's
            sensitive group have been answered, the fairness check is skipped.
        subkey1: PRNG key for the threshold noise.
        subkey2: PRNG key for the argmax noise.

    Returns:
        Scalar: number of answered samples whose prediction equals the target,
        divided by the number of samples.
    """
    # Answered samples always get weight 1.0 in this variant.
    func = lambda x: 1.0

    num_sensitive_attributes = 3
    votes = jnp.asarray(raw_votes)
    # Sizes come from the data rather than a hard-coded (1000, 10) shape.
    num_samples, num_classes = votes.shape
    noise_threshold = sigma_threshold * jax.random.normal(subkey1, [num_samples])
    noise_gnmax = sigma_gnmax * jax.random.normal(subkey2, [num_samples, num_classes])

    def _calculate_gaps(sensitive_group_count, pos_classified_group_count):
        # Per group: (share of its answered samples predicted as class 1)
        # minus (the same share computed over all other groups together).
        # Groups with a zero count produce NaN/inf entries here.
        all_members = jnp.sum(sensitive_group_count)
        all_pos_classified_group_count = jnp.sum(pos_classified_group_count)
        dem_parity = jnp.divide(pos_classified_group_count, sensitive_group_count)
        others_count = all_members - sensitive_group_count
        others_pos_classified_group_count = all_pos_classified_group_count - pos_classified_group_count
        dem_parity_others = jnp.divide(others_pos_classified_group_count, others_count)
        return dem_parity - dem_parity_others

    def _apply_fairness_constraint(pred, sensitive, answered, sensitive_group_count, pos_classified_group_count):
        gaps = _calculate_gaps(sensitive_group_count, pos_classified_group_count)
        in_group = jnp.arange(num_sensitive_attributes) == sensitive
        sensitive_one_hot = in_group.astype(sensitive_group_count.dtype)

        group_size = sensitive_one_hot.dot(sensitive_group_count)
        # Select this group's gap with a mask instead of a dot product: a dot
        # product would turn the result into NaN whenever any *other* group
        # still has a NaN gap (0 * NaN is NaN).
        group_gap = jnp.sum(jnp.where(in_group, gaps, 0.0))

        # The answer is kept when the group is still below the minimum count,
        # or the prediction is class 0, or the group's gap is under the limit.
        # A NaN gap compares False and therefore does not keep the answer.
        keep = (group_size < min_group_count) | (pred == 0) | (group_gap < max_fairness_violation)
        answered = jnp.where(keep, answered, 0.0)

        # Update the running per-group counters only for answered samples.
        was_answered = answered == 1.
        sensitive_group_count = jnp.where(was_answered,
                                          sensitive_group_count + sensitive_one_hot,
                                          sensitive_group_count)
        pos_classified_group_count = jnp.where(was_answered & (pred == 1),
                                               pos_classified_group_count + sensitive_one_hot,
                                               pos_classified_group_count)

        return answered, sensitive_group_count, pos_classified_group_count

    def _predict(output, sample):
        # One scan step: `sample` is the slice of every input for one sample.
        acc, sensitive_group_count, pos_classified_group_count = output
        _target, _sensitive, _vote, _noise_threshold, _noise_gnmax = sample

        noisy_vote_count = _vote.max() + _noise_threshold
        answered = jnp.where(noisy_vote_count > threshold, func(threshold), 0.0)
        pred = (_vote + _noise_gnmax).argmax()
        answered, sensitive_group_count, pos_classified_group_count = \
            _apply_fairness_constraint(pred, _sensitive, answered, sensitive_group_count, pos_classified_group_count)
        acc = acc + answered * (pred == _target)
        return (acc, sensitive_group_count, pos_classified_group_count), answered

    # Scan over a tuple of per-sample arrays; scan slices each along axis 0.
    per_sample = (jnp.asarray(targets), jnp.asarray(sensitives), votes, noise_threshold, noise_gnmax)
    initial = (jnp.zeros(()),
               jnp.zeros((num_sensitive_attributes,)),
               jnp.zeros((num_sensitive_attributes,)))
    (total, _, _), _ = jax.lax.scan(_predict, initial, per_sample)
    return total / num_samples

In [None]:
# Single evaluation: sigma_threshold=50.0, sigma_gnmax=5.0, threshold=0.22,
# max_fairness_violation=0.2, min_group_count=50.
query_fair_iter(50.0, 5.0, 0.22, 0.2, 50, subkey1, subkey2)

In [None]:
# query_fair_iter_jit = jax.jit(lambda t: query_fair_iter(t, 5.0, 0.22, 0.2, 50, subkey1, subkey2))
# query_fair_iter_jit(50.0)

In [None]:
# query_fair_iter_driv = jax.grad(lambda x: query_fair_iter_jit(x)[0][0])
# query_fair_iter_driv(50.0)

In [None]:
def _plot(columns=["Threshold", "Accuracy"], argnum=2):
    """Plot ``query_fair_iter`` and its gradient while sweeping one of its first five arguments.

    Args:
        columns: ``[x_label, y_label]``; used as DataFrame column names and
            therefore as axis labels.
        argnum: Index of the swept ``query_fair_iter`` argument
            (0: sigma_threshold, 1: sigma_gnmax, 2: threshold,
            3: max_fairness_violation, 4: min_group_count). The remaining
            arguments are held at the values in ``defaults``.

    Draws two stacked panels sharing the x-axis: the value of
    ``query_fair_iter`` over the sweep, and its derivative with respect to
    the swept argument. Returns None.
    """
    fig, axes = plt.subplots(2, 1, sharex=True, dpi=150, figsize=(5, 5))
    # Make the axes 2-D so they can be indexed as [row, column].
    axes = axes.reshape(2, -1)
    defaults = [50.0, 5.0, 0.22, 0.2, 50]
    grad_column = f"$grad. {columns[1]}$"

    def _subplot(ax_id, title):
        a_range = jnp.logspace(-7, 2, num=1000)

        def _swept(value):
            # Put `value` in slot `argnum`, defaults in the others.
            args = [value if _id == argnum else defaults[_id] for _id in range(len(defaults))]
            return query_fair_iter(*args, subkey1, subkey2)

        loss_over_a_range = jax.vmap(_swept)(a_range)
        data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_over_a_range]), columns=columns)
        sns.lineplot(data=data, y=columns[1], x=columns[0], ax=axes[0, ax_id])

        # The swept values are floats, so a plain grad is sufficient even when
        # the corresponding default is an integer.
        loss_driv_over_a_range = jax.vmap(jax.grad(_swept))(a_range)
        grad_data = pd.DataFrame(np.asarray(jnp.c_[a_range, loss_driv_over_a_range]),
                                 columns=[columns[0], grad_column])
        sns.lineplot(data=grad_data, x=columns[0], y=grad_column, ax=axes[1, ax_id])

        # Largest gradient value over the part of the sweep above 20.
        at = AnchoredText(f"Max Val ({columns[0]} > 20)={loss_driv_over_a_range[a_range > 20].max()}", prop=dict(size=10), frameon=True, loc='lower right')
        at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
        axes[1, ax_id].add_artist(at)
        axes[0, ax_id].set_title(title)

    _subplot(0, "")
    # return fig

In [None]:
# Sweep argument index 4 of `query_fair_iter` (min_group_count); label the
# x-axis accordingly instead of using the default "Threshold".
_plot(columns=["min_group_count", "Accuracy"], argnum=4)