In [None]:
import random
from collections import defaultdict
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.utils.data
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture
from torch.utils.data import TensorDataset

from datasets.lawschool import LawschoolDataset
from generative.gmm import GMM
from real_nvp_encoder import FlowEncoder

In [None]:
# Apply seaborn's default theme globally so every matplotlib figure below shares one style.
sns.set_theme()

In [None]:
# Repository root: the notebook is expected to run from a direct subdirectory of it.
PROJECT_ROOT = Path.cwd().parent

In [None]:
# --- Which trained run to analyse (gamma selects the model directory) -------
dataset = 'lawschool'
gamma = 1.0
seed = 100

# --- Data preparation --------------------------------------------------------
quantiles = False   # forwarded to LawschoolDataset through Args
p_test = 0.2
p_val = 0.2
alpha = 0.05        # margin used by normalize()/denormalize() and the logit clamp

# --- Model / loader settings -------------------------------------------------
n_blocks = 4
gmm_comps1 = 10
gmm_comps2 = 10
batch_size = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Options not used anywhere in this notebook (presumably training-time
# settings), kept for reference:
# lr = 1e-2
# weight_decay = 1e-4
# kl_start = 0
# kl_end = 50
# protected_att = None
# dec_epochs = 100
# prior_epochs = 150
# n_epochs = 100
# adv_epochs = 100
# prior = 'gmm'
# out_file = None
# n_flows = 1
# train_dec = True
# log_epochs = 10
# with_test = False
# fair_criterion = 'stat_parity'
# no_early_stop = False
# load_enc = False
# load_enc = False

In [None]:
# Seed every RNG the notebook draws from: Python's `random`, NumPy's global state, PyTorch.
for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)

In [None]:
# Checkpoints are read from model_dir; figures and CSV summaries are written to plots_dir.
model_dir = PROJECT_ROOT.joinpath('code', dataset, f'gamma_{gamma}')
plots_dir = PROJECT_ROOT.joinpath('plots', dataset)
plots_dir.mkdir(exist_ok=True, parents=True)

In [None]:
class Args:
    """Minimal stand-in for an argument namespace passed to LawschoolDataset.

    Only the ``quantiles`` option is exposed.
    """

    def __init__(self, quantiles):
        # stored as a plain attribute; presumably read by the dataset constructor
        setattr(self, 'quantiles', quantiles)

In [None]:
args = Args(quantiles=quantiles)

# Every split is built with the same split proportions.
split_kwargs = {'p_test': p_test, 'p_val': p_val}
train_dataset, valid_dataset, test_dataset = (
    LawschoolDataset(split, args, **split_kwargs)
    for split in ('train', 'validation', 'test')
)

In [None]:
_all_datasets = (train_dataset, valid_dataset, test_dataset)

# Per-split features, protected-attribute values and labels.
train_all, valid_all, test_all = (ds.features for ds in _all_datasets)
train_prot, valid_prot, test_prot = (ds.protected for ds in _all_datasets)
train_targets, valid_targets, test_targets = (ds.labels for ds in _all_datasets)


In [None]:
# Columns 6:30 presumably hold a one-hot college encoding (confirm in
# LawschoolDataset); argmax recovers the college index.
college_cols = slice(6, 30)
n_colleges = college_cols.stop - college_cols.start

train_college = train_all[:, college_cols].max(dim=1)[1]
valid_college = valid_all[:, college_cols].max(dim=1)[1]
test_college = test_all[:, college_cols].max(dim=1)[1]

# Share of each college among training samples with label 0.
# minlength makes the histogram cover every college. Without it, a college whose
# index exceeds the largest one seen among label-0 samples received no rank and
# kept its raw index, which could collide with the rank of another college.
c1_cnt = np.bincount(
    train_college[train_targets == 0].detach().cpu().numpy(), minlength=n_colleges
)
c1_cnt = c1_cnt / np.sum(c1_cnt)

# college_rnk[r] is the college with the r-th smallest share (stable on ties).
college_rnk = sorted(range(n_colleges), key=lambda i: c1_cnt[i])

# Inverse permutation: rank_of[c] is the new label assigned to college c.
rank_of = torch.empty(n_colleges, dtype=train_college.dtype, device=train_college.device)
rank_of[college_rnk] = torch.arange(
    n_colleges, dtype=train_college.dtype, device=train_college.device
)

new_train_college = rank_of[train_college]
new_valid_college = rank_of[valid_college]
new_test_college = rank_of[test_college]

# Keep the first two columns (lsat, gpa) and append the re-ranked college.
train_all = torch.cat([train_all[:, :2], new_train_college.unsqueeze(1)], dim=1).float()
valid_all = torch.cat([valid_all[:, :2], new_valid_college.unsqueeze(1)], dim=1).float()
test_all = torch.cat([test_all[:, :2], new_test_college.unsqueeze(1)], dim=1).float()

In [None]:
def compute_quants(train_all):
    """Smallest gap between distinct sorted values, for every column.

    Gaps below 1e-4 are treated as duplicates and ignored. A column with no
    qualifying gap (constant, empty or single-row) yields the sentinel 1000.0,
    which also caps every result.

    Args:
        train_all: 2-D tensor, rows are samples and columns are features.

    Returns:
        A list with one width per column.
    """
    def _min_gap(column):
        gaps = np.diff(np.sort(column.detach().cpu().numpy()))
        gaps = gaps[gaps >= 1e-4]
        return min(1000.0, gaps.min()) if gaps.size else 1000.0

    return [_min_gap(train_all[:, col]) for col in range(train_all.shape[1])]


In [None]:
# gpa (column 1) is never dequantized -- only lsat and college are floored back
# later -- so its dequantization width is zero.
gpa_col = 1
quants = compute_quants(train_all)
quants[gpa_col] = 0

In [None]:
# Feature names, in column order, after the college re-ranking step.
column_ids = 'lsat gpa college'.split()

In [None]:
# Per-column bounds of the training data, keyed by column index.
_n_cols = train_all.shape[1]
lb = {idx: train_all[:, idx].min() for idx in range(_n_cols)}
ub = {idx: train_all[:, idx].max() for idx in range(_n_cols)}


In [None]:
def normalize(x, a, b):
    """Affinely map [a, b] onto [alpha / 2, 1 - alpha / 2] (uses the global ``alpha``)."""
    return 0.5 + (1 - alpha) * ((x - a) / (b - a) - 0.5)


def denormalize(z, a, b):
    """Inverse of :func:`normalize`: map [alpha / 2, 1 - alpha / 2] back onto [a, b]."""
    return ((z - 0.5) / (1 - alpha) + 0.5) * (b - a) + a


def normalize_data(data):
    """Clamp every column to its training range and normalize it.

    Column ``idx`` uses the range [lb[idx], ub[idx] + quants[idx]]; the extra
    quant leaves room for the uniform dequantization noise added later.

    Returns a new tensor; ``data`` itself is not modified.
    """
    out = data.clone()
    for idx in range(out.shape[1]):
        hi = ub[idx] + quants[idx]
        # clamping has no effect on training data (lb/ub were computed from it)
        out[:, idx] = normalize(torch.clamp(out[:, idx], lb[idx], hi), lb[idx], hi)
    return out


def denormalize_data(data):
    """Column-wise inverse of :func:`normalize_data` (without the clamp).

    Returns a new tensor; ``data`` itself is not modified.
    """
    out = data.clone()
    for idx in range(out.shape[1]):
        out[:, idx] = denormalize(out[:, idx], lb[idx], ub[idx] + quants[idx])
    return out

In [None]:
train_all, valid_all, test_all = (
    normalize_data(split) for split in (train_all, valid_all, test_all)
)

# Dequantization widths in normalized units, shaped (1, n_cols) to broadcast over rows;
# gpa (column 1) is not dequantized.
q = torch.tensor(compute_quants(train_all)).float().unsqueeze(0).to(device)
q[0, 1] = 0


In [None]:
def _dequantize_logit(x):
    """Add uniform noise of width q, clamp to [alpha/2, 1 - alpha/2], then take the logit."""
    noisy = x + q * torch.rand(x.shape).to(device)
    return torch.clamp(noisy, alpha / 2, 1 - alpha / 2).logit()


def _split_by_group(data, prot, targets, shuffle=False):
    """Split one data split by protected attribute.

    Group 1 is ``prot == 1`` and group 2 is ``prot == 0``. Features are dequantized
    and logit-transformed, targets are cast to long.

    Returns:
        (x1, x2, targets1, targets2, loader1, loader2)
    """
    group1 = _dequantize_logit(data[prot == 1])
    group2 = _dequantize_logit(data[prot == 0])
    labels1, labels2 = targets[prot == 1].long(), targets[prot == 0].long()
    loader1 = torch.utils.data.DataLoader(
        TensorDataset(group1, labels1), batch_size=batch_size, shuffle=shuffle
    )
    loader2 = torch.utils.data.DataLoader(
        TensorDataset(group2, labels2), batch_size=batch_size, shuffle=shuffle
    )
    return group1, group2, labels1, labels2, loader1, loader2


train1, train2, targets1, targets2, train1_loader, train2_loader = _split_by_group(
    train_all, train_prot, train_targets, shuffle=True
)
valid1, valid2, v_targets1, v_targets2, valid1_loader, valid2_loader = _split_by_group(
    valid_all, valid_prot, valid_targets
)
test1, test2, t_targets1, t_targets2, test1_loader, test2_loader = _split_by_group(
    test_all, test_prot, test_targets
)

In [None]:
def _load_gaussian_mixture(n_components, weights_file, means_file, covs_file):
    """Rebuild a fitted full-covariance GaussianMixture from .npy arrays in model_dir."""
    mixture = GaussianMixture(n_components=n_components, n_init=1, covariance_type='full')
    mixture.weights_ = np.load(model_dir / weights_file)
    mixture.means_ = np.load(model_dir / means_file)
    mixture.covariances_ = np.load(model_dir / covs_file)
    return mixture


gaussian_mixture1 = _load_gaussian_mixture(
    gmm_comps1, 'prior1_weights.npy', 'prior1_means.npy', 'prior1_covs.npy'
)
gaussian_mixture2 = _load_gaussian_mixture(
    gmm_comps2, 'prior2_weights.npy', 'prior2_means.npy', 'prior2_covs.npy'
)

prior1 = GMM(gaussian_mixture1, device=device)
prior2 = GMM(gaussian_mixture2, device=device)


In [None]:
in_dim = train_all.shape[1]

# Coupling masks come in complementary pairs (t, 1 - t) over the input dims.
# NOTE(review): they are drawn from the global NumPy RNG, so they presumably only
# match the masks used at training time if the RNG is in the same state here --
# confirm whether FlowEncoder stores its masks in the state_dict.
masks = []
for _ in range(20):
    t = np.array([j % 2 for j in range(in_dim)])
    np.random.shuffle(t)
    masks += [t, 1 - t]

flow1 = FlowEncoder(None, in_dim, [50, 50], n_blocks, masks).to(device)
flow2 = FlowEncoder(None, in_dim, [50, 50], n_blocks, masks).to(device)

# map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
flow1.load_state_dict(torch.load(model_dir / 'flow1.pt', map_location=device))
flow2.load_state_dict(torch.load(model_dir / 'flow2.pt', map_location=device))

In [None]:
mappings = defaultdict(list)

# Loader batches are already dequantized, clamped to [alpha / 2, 1 - alpha / 2] and
# logit-transformed (see the loader cell), i.e. they are in the space the flows
# consume and whose sigmoid gives the normalized data. Re-applying
# clamp(..., alpha / 2, 1 - alpha).logit() would squash the logits into that interval
# and take the logit a second time, so the batches are used as-is.
#
# Each group is iterated on its own: zip() would stop at the shorter loader and
# silently drop the remainder of the larger group.
with torch.no_grad():  # inference only; no autograd graphs needed
    for x1, y1 in train1_loader:
        x1_z1, _ = flow1.inverse(x1)     # group-1 sample -> latent
        x1_x2, _ = flow2.forward(x1_z1)  # latent -> group-2 counterpart

        mappings['x1_real'].append(x1.sigmoid())
        mappings['x2_fake'].append(x1_x2.sigmoid())
        mappings['z1'].append(x1_z1)
        mappings['y1'].append(y1)

    for x2, y2 in train2_loader:
        x2_z2, _ = flow2.inverse(x2)     # group-2 sample -> latent
        x2_x1, _ = flow1.forward(x2_z2)  # latent -> group-1 counterpart

        mappings['x1_fake'].append(x2_x1.sigmoid())
        mappings['x2_real'].append(x2.sigmoid())
        mappings['z2'].append(x2_z2)
        mappings['y2'].append(y2)

In [None]:
def _stack_denormalized(key):
    """Stack the batches stored under ``key`` and map them back to original units."""
    return denormalize_data(torch.vstack(mappings[key])).cpu().detach()


x1_real, x2_real, x2_fake, x1_fake = map(
    _stack_denormalized, ('x1_real', 'x2_real', 'x2_fake', 'x1_fake')
)

# Latents and labels are kept as-is (no denormalization).
z1, z2 = (torch.vstack(mappings[key]).cpu().detach() for key in ('z1', 'z2'))
y1, y2 = (torch.cat(mappings[key]).cpu().detach() for key in ('y1', 'y2'))

In [None]:
# undo dequantization
x1_real[:, 0] = torch.floor(x1_real[:, 0])
x1_real[:, 2] = torch.floor(x1_real[:, 2])

x2_real[:, 0] = torch.floor(x2_real[:, 0])
x2_real[:, 2] = torch.floor(x2_real[:, 2])

x1_fake[:, 0] = torch.floor(x1_fake[:, 0])
x1_fake[:, 2] = torch.floor(x1_fake[:, 2])

x2_fake[:, 0] = torch.floor(x2_fake[:, 0])
x2_fake[:, 2] = torch.floor(x2_fake[:, 2])

In [None]:
# Stack real and mapped samples so that row i of x1, x2, z and y refers to the same
# individual: x1 holds its group-1 view, x2 its group-2 view.
x1, x2, z, y = (
    torch.cat(parts)
    for parts in ((x1_real, x1_fake), (x2_fake, x2_real), (z1, z2), (y1, y2))
)

In [None]:
perplexities = [5, 15, 25, 35, 45]

# The t-SNE embedding depends only on the data and the perplexity, not on the
# number of k-means clusters, so compute it once per perplexity and reuse it for
# every cluster count. A fixed random_state also makes the layout identical
# across the per-cluster figures, so they can be compared.
embeddings = {
    perplexity: (
        TSNE(perplexity=perplexity, random_state=seed).fit_transform(x1),
        TSNE(perplexity=perplexity, random_state=seed).fit_transform(x2),
    )
    for perplexity in perplexities
}

for n_clusters in [4, 6, 8]:
    # Clusters are fit in x1 space only. Row i of x1 and x2 describe the same
    # individual in both groups, so the x1 labels also color / aggregate x2 to
    # show where each cluster is mapped.
    kmeans_x1 = KMeans(n_clusters=n_clusters, random_state=0).fit(x1)
    labels = kmeans_x1.labels_

    fig, axes = plt.subplots(
        nrows=len(perplexities), ncols=2, figsize=(10, 5 * len(perplexities))
    )
    for row, perplexity in enumerate(perplexities):
        x1_t_sne, x2_t_sne = embeddings[perplexity]
        axes[row, 0].scatter(x1_t_sne[:, 0], x1_t_sne[:, 1], c=labels, cmap='tab10')
        axes[row, 1].scatter(x2_t_sne[:, 0], x2_t_sne[:, 1], c=labels, cmap='tab10')
        axes[row, 0].set_ylabel(f'perplexity = {perplexity}')

    axes[0, 0].set_title('Non-White')
    axes[0, 1].set_title('White')

    fig.suptitle(f't-SNE with {n_clusters} Clusters')
    fig.tight_layout()
    fig.savefig(plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}.eps')
    plt.show()  # render now and release the figure instead of accumulating them

    # Per-cluster feature means, built in one go rather than row by row.
    clusters_x1 = pd.DataFrame(
        [x1[labels == c].mean(axis=0).numpy() for c in range(n_clusters)],
        columns=column_ids,
    )
    clusters_x2 = pd.DataFrame(
        [x2[labels == c].mean(axis=0).numpy() for c in range(n_clusters)],
        columns=column_ids,
    )

    clusters_x1.to_csv(
        plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x1.csv', index=False
    )
    clusters_x2.to_csv(
        plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x2.csv', index=False
    )

In [None]:
# Map the dequantized logit-space training tensors back to original units.
# New names leave train1 / train2 untouched, so re-running this cell does not
# apply the inverse transform twice.
train1_orig = denormalize_data(train1.sigmoid())
train2_orig = denormalize_data(train2.sigmoid())

# Undo dequantization of the discrete columns (0 = lsat, 2 = college).
for data in (train1_orig, train2_orig):
    data[:, [0, 2]] = torch.floor(data[:, [0, 2]])


def _range_summary(data):
    """Per-column min / mean / max of a 2-D tensor as a DataFrame indexed by column_ids."""
    arr = data.detach().cpu().numpy()
    return pd.DataFrame(
        {'min': arr.min(0), 'mean': arr.mean(0), 'max': arr.max(0)}, index=column_ids
    )


pd.concat({'Non-White': _range_summary(train1_orig), 'White': _range_summary(train2_orig)})

In [None]:
# Linear probe on the latent codes; y_hat holds its in-sample predictions.
clf = LogisticRegression(random_state=0)
clf.fit(z, y)
y_hat = clf.predict(z)
clf.score(z, y)  # training accuracy

In [None]:
mappings = defaultdict(list)

neg_mask = y_hat == 0
z_pos = z[y_hat == 1]  # hoisted: the same candidate set for every query point
if len(z_pos) == 0:
    raise ValueError('classifier predicts no positive latents; recourse targets are undefined')


def _counterfactual_latent(z_i, n_steps=11):
    """Move ``z_i`` toward its nearest positively-predicted latent until ``clf`` flips.

    Tries beta = 0, 1/(n_steps-1), ..., 1 along the straight line from ``z_i`` to
    its Euclidean nearest neighbour in ``z_pos`` and returns the first point that
    ``clf`` predicts as non-zero, shaped (1, d). If no step flips the prediction,
    the beta = 1 point is returned.
    """
    z_i_nn = z_pos[torch.norm(z_i - z_pos, dim=1).argmin()]
    for beta in np.linspace(0, 1, n_steps):
        z_new = torch.unsqueeze(z_i + beta * (z_i_nn - z_i), 0)
        if clf.predict(z_new)[0]:
            break
    return z_new


# For every negatively-predicted point, decode the counterfactual latent with the
# group's own flow and record (original, counterfactual) pairs.
with torch.no_grad():  # inference only; no autograd graphs needed
    for x_group, flow, tag in ((x1, flow1, 'x1'), (x2, flow2, 'x2')):
        for x_i, z_i in zip(x_group[neg_mask], z[neg_mask]):
            x_i_new, _ = flow.forward(_counterfactual_latent(z_i).to(device))

            mappings[f'{tag}_old'].append(x_i.to(device))
            mappings[f'{tag}_new'].append(x_i_new.sigmoid().to(device))

In [None]:
x1_old, x2_old = (torch.vstack(mappings[key]) for key in ('x1_old', 'x2_old'))

# Undo flow normalization, then undo dequantization of the discrete columns
# (0 = lsat, 2 = college).
x1_new, x2_new = (
    denormalize_data(torch.vstack(mappings[key])) for key in ('x1_new', 'x2_new')
)
for counterfactual in (x1_new, x2_new):
    counterfactual[:, [0, 2]] = torch.floor(counterfactual[:, [0, 2]])

In [None]:
x1_diff, x2_diff = x1_new - x1_old, x2_new - x2_old

# Keep only counterfactuals that leave the college unchanged (college diff == 0).
x1_diff_without_college, x2_diff_without_college = (
    diff[diff[:, 2] == 0] for diff in (x1_diff, x2_diff)
)

# Average feature change needed to flip the prediction, per group.
avg_recourse = pd.DataFrame(
    data=[
        diff.mean(0).cpu().detach().numpy()
        for diff in (x1_diff_without_college, x2_diff_without_college)
    ],
    index=['Non-White', 'White'],
    columns=column_ids,
)
avg_recourse.to_csv(plots_dir / f'recourse_gamma_{gamma}.csv')