In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions.mixture_same_family import MixtureSameFamily
import torch.distributions as DIST
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal

import sys, os
import pandas as pd

added_path = os.path.join(os.path.abspath(".."))
if added_path not in sys.path:
    sys.path.append(added_path)
print(sys.path)

from dataset.dataset_fair import AdultDataset, COMPAS

['/local/scratch/a/gong123/AUB-CAUB/notebooks', '/local/scratch/a/gong123/anaconda3/envs/AUB-env/lib/python38.zip', '/local/scratch/a/gong123/anaconda3/envs/AUB-env/lib/python3.8', '/local/scratch/a/gong123/anaconda3/envs/AUB-env/lib/python3.8/lib-dynload', '', '/local/scratch/a/gong123/anaconda3/envs/AUB-env/lib/python3.8/site-packages', '/local/scratch/a/gong123/AUB-CAUB']


In [4]:
def conditional_errors(preds, labels, attrs):
    """
    Compute the conditional errors of A = 0/1. All the arguments need to be one-dimensional vectors.
    :param preds: The predicted label given by a model.
    :param labels: The groundtruth label.
    :param attrs: The label of sensitive attribute.
    :return: Overall classification error, error | A = 0, error | A = 1.
    """
    assert preds.shape == labels.shape and labels.shape == attrs.shape
    cls_error = 1 - np.mean(preds == labels)
    idx = attrs == 0
    error_0 = 1 - np.mean(preds[idx] == labels[idx])
    error_1 = 1 - np.mean(preds[~idx] == labels[~idx])
    return cls_error, error_0, error_1

def get_dataset_compas(root):
    compas = pd.read_csv(f"{root}/propublica.csv").values
    num_insts = compas.shape[0]
    # Random shuffle the dataset.
    indices = np.arange(num_insts)
    np.random.shuffle(indices)
    compas = compas[indices]
    # Partition the dataset into train and test split.
    ratio = 0.7
    num_train = int(num_insts * ratio)
    compas_train = COMPAS(compas[:num_train, :])
    compas_test = COMPAS(compas[num_train:, :])
    return compas_train, compas_test

def get_loaders_compas(compas_train, compas_test, batch_size):
    train_loader = DataLoader(compas_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(compas_test, batch_size=batch_size, shuffle=False)
    return train_loader ,test_loader


def get_dataset_adult(root, target="income", private="sex"):
    adult_train = AdultDataset(root_dir=root, phase='train', tar_attr=target, priv_attr=private)
    adult_test = AdultDataset(root_dir=root, phase='test', tar_attr=target, priv_attr=private)
    return adult_train, adult_test

def get_loaders_adult(adult_train, adult_test, batch_size):
    train_loader = DataLoader(adult_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(adult_test, batch_size=batch_size, shuffle=False)
    return train_loader ,test_loader

def all_pairs_gaussian_kl(mu, sigma, dim_z):
    # mu is [batchsize x dim_z]
    # sigma is [batchsize x dim_z]

    sigma_sq = sigma * sigma + 1e-8
    sigma_sq_inv = 1 / sigma_sq
    # sigma_inv is [batchsize x sizeof(latent_space)]

    # first term
    # dot product of all sigma_inv vectors with sigma
    # is the same as a matrix mult of diag
    first_term = torch.matmul(sigma_sq, sigma_sq_inv.transpose(1, 0))

    # second term (we break the mu_i-mu_j square term)
    # REMEMBER THAT THIS IS SIGMA_1, not SIGMA_0
    sqi = torch.matmul(mu * mu, sigma_sq_inv.transpose(1, 0))
    # sqi is now [batchsize x batchsize] = sum(mu[:,i]**2 / Sigma[j])

    sqj = mu * mu * sigma_sq_inv
    sqj = torch.sum(sqj, dim=1)
    # sqj is now [batchsize, 1] = mu[j]**2 / Sigma[j]

    # squared distance
    # (mu[i] - mu[j])\sigma_inv(mu[i] - mu[j]) = r[i] - 2*mu[i]*mu[j] + r[j]
    # uses broadcasting
    second_term = 2 * torch.matmul(mu, torch.transpose(mu * sigma_sq_inv, 1, 0))
    second_term = sqi + sqj.view(1, -1) - second_term

    # third term

    # log det A = tr log A
    # log \frac{ det \Sigma_1 }{ det \Sigma_0 } =
    #   \tr\log \Sigma_1 - \tr\log \Sigma_0
    # for each sample, we have B comparisons to B other samples...
    #   so this cancels out, but we keep it

    logi = 2 * torch.sum(torch.log(sigma), dim=1)
    logi = torch.reshape(logi, [-1, 1])
    logj = torch.reshape(logi, [1, -1])
    third_term = logi - logj

    # combine and return
    return 0.5 * (first_term + second_term + third_term - dim_z)

class LinearVAUB(nn.Module):
    def __init__(self, pz, cls, input_features=1, latent_features=1, hidden_features=3, learnable_loc=False, learnable_var=False, pair_kl=False):
        super(LinearVAUB, self).__init__()

        # encoder & decoder
        self.enc1 = nn.Sequential(
            nn.Linear(in_features=input_features, out_features=hidden_features),
            nn.ReLU(),
            # nn.Linear(in_features=hidden_features, out_features=hidden_features),
            # nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=latent_features*2),
        )

        # self.enc2 = nn.Sequential(
        #     nn.Linear(in_features=input_features, out_features=hidden_features),
        #     nn.ReLU(),
        #     nn.Linear(in_features=hidden_features, out_features=hidden_features),
        #     nn.ReLU(),
        #     nn.Linear(in_features=hidden_features, out_features=latent_features*2),
        # )

        self.dec1 = nn.Sequential(
            nn.Linear(in_features=latent_features, out_features=hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=input_features),
        )

        self.dec2 = nn.Sequential(
            nn.Linear(in_features=latent_features, out_features=hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=input_features),
        )

        self.E_arr = [self.enc1, self.enc1]
        self.D_arr = [self.dec1, self.dec2]

        # shared distribution
        '''
        change this line
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))
        to
        '''
        self.log_scale_arr = [nn.Parameter(torch.Tensor([0.0])) for _ in range(2)]

        self.latent_loc = nn.Parameter(torch.Tensor([0.0])) if learnable_loc else torch.Tensor([0.0])
        self.latent_log_var = nn.Parameter(torch.Tensor([0.0])) if learnable_var else torch.Tensor([0.0])

        self.Recon_Loss = [[],[]]

        self.mu = None
        self.log_var = None
        self.elbo = None
        self.criterion = nn.MSELoss()

        self.pz = pz
        self.classifier = cls
        self.pair_kl = pair_kl

    # def simulate_output(self, mean, idx):
    #
    #     noise = torch.distributions.Normal(0, 1)
    #     return mean + noise.sample(mean.size())*0.05

    def reparameterize(self, mu, log_var):

        std = torch.exp(log_var/2).float()
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return z

    def forward(self, x):
        x_encoded = self.E_arr[0](x)
        return self.classifier(x_encoded)

    def train_iter(self, X_arr, y_arr, scale_recon=1, lambda_cls=1):

        self.recon_loss = 0
        self.vaub_loss = 0
        self.elbo_loss = 0
        self.cls_loss = 0
        self.kl_loss = 0

        self.z_arr = []
        self.x_hat_arr = []
        self.x_flipped_arr = []

        criterion_cls = nn.CrossEntropyLoss()

        for idx in range(2):
            x_original = X_arr[idx]
            x = X_arr[idx]
            # encoding
            x = self.E_arr[idx](x)

            # get `mu` and `log_var`
            self.mu, self.log_var = x.chunk(2, dim=-1)
            # self.mu = x[:, 0, None] # the first feature values as mean
            # self.log_var = x[:, 1, None] # the other feature values as variance

            # get the latent vector through reparameterization
            z = self.reparameterize(self.mu, self.log_var)
            self.z_arr.append(z)

            # decoding
            x_hat = self.D_arr[idx](z)
            # self.x_hat_arr.append(self.simulate_output(x_hat, idx))

            # get flipped results
            x_flipped = self.D_arr[(idx+1)%2](z)
            # self.x_flipped_arr.append(self.simulate_output(x_flipped,(idx+1)%2))

            # reconstruction loss
            recon_loss = -1*self.gaussian_likelihood(x_hat, x_original, self.log_scale_arr[idx])
            self.Recon_Loss[idx].append(recon_loss.mean().item())

            # kl loss
            if self.pair_kl:
                kl_loss = all_pairs_gaussian_kl(self.mu, torch.exp(self.log_var/2), self.mu.size(1))
            else:
                kl_loss = self.kl_divergence_with_pz(z, self.mu, self.log_var)

            # elbo
            elbo = kl_loss + scale_recon*recon_loss
            vaub = kl_loss + recon_loss

            self.kl_loss += kl_loss.mean()
            self.recon_loss += recon_loss.mean()
            self.elbo_loss += elbo.mean()
            self.vaub_loss += vaub.mean()
            self.cls_loss += criterion_cls(self.classifier(z), y_arr[idx])

        return self.elbo_loss + lambda_cls*self.cls_loss

    def get_loss_dict(self):
        loss_dict = {
            "recon_loss": self.recon_loss.item(),
            "kl_loss": self.kl_loss.item(),
            "vaub_loss": self.vaub_loss.item(),
            "elbo_loss": self.elbo_loss.item(),
            "cls_loss": self.cls_loss.item(),
            "overall_loss": (self.elbo_loss + lambda_cls*self.cls_loss).item()
        }
        return loss_dict

    def kl_divergence_with_pz(self, z, mu, log_var):

        # mu = self.mu
        std = torch.exp(log_var/2)

        # p = torch.distributions.Normal(torch.zeros_like(mu)+self.latent_loc, torch.ones_like(std)*torch.exp(self.latent_log_var/2))
        # p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std.to(mu.device))

        log_qzx = q.log_prob(z)
        # log_pz = p.log_prob(z)
        log_pz = self.pz.log_prob(z)

        kl = (log_qzx.sum(-1) - log_pz)

        # sum over last dim to go from single dim distribution to multi-dim
        return kl


    # def kl_divergence(self, z, mu, log_var):
    #
    #     # mu = self.mu
    #     std = torch.exp(log_var/2)
    #
    #     p = torch.distributions.Normal(torch.zeros_like(mu)+self.latent_loc, torch.ones_like(std)*torch.exp(self.latent_log_var/2))
    #     # p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    #     q = torch.distributions.Normal(mu, std)
    #
    #     log_qzx = q.log_prob(z).sum(-1)
    #     log_pz = p.log_prob(z)
    #
    #     kl = (log_qzx - log_pz)
    #
    #     # sum over last dim to go from single dim distribution to multi-dim
    #     kl = kl
    #     return kl


    def gaussian_likelihood(self, x_hat, x, learnable_scale):
        scale = torch.exp(learnable_scale)
        mean = x_hat
        # print(mean.size(), scale.size())
        dist = torch.distributions.Normal(mean, scale.to(device))

        # measure prob of seeing image under p(x|z)
        # print(mean.size(), x.size())
        log_pxz = dist.log_prob(x)
        # print(log_pxz.size(), log_pxz.sum(-1).size())
        return log_pxz.sum(-1)


def train_VAUB_fair(model, train_data_loader, test_dataset, n_epochs, lr, scale_recon, lambda_cls, device):

    opt = torch.optim.Adam(params=model.parameters(), lr=lr)
    model.to(device)
    Loss_dict = {
        "recon_loss": [],
        "kl_loss": [],
        "vaub_loss": [],
        "elbo_loss": [],
        "cls_loss": [],
        "overall_loss": []
    }
    target_insts = torch.from_numpy(test_dataset.X).float().to(device)
    target_labels = np.argmax(test_dataset.Y, axis=1)
    target_attrs = np.argmax(test_dataset.A, axis=1)
    target_insts = torch.from_numpy(test_dataset.X).float().to(device)
    test_idx = target_attrs == 0
    conditional_idx = target_labels == 0

    for epoch in range(n_epochs):
        model.train()
        for batch_idx, batch in enumerate(train_data_loader):
            opt.zero_grad()
            xs, ys, attrs = batch
            X_arr = [xs[attrs==0].to(device), xs[attrs==1].to(device)]
            y_arr = [ys[attrs==0].to(device), ys[attrs==1].to(device)]
            loss = model.train_iter(X_arr, y_arr, scale_recon, lambda_cls)
            loss.backward()
            opt.step()

            for k,v in model.get_loss_dict().items():
                Loss_dict[k].append(v)

        model.eval()
        preds_labels = torch.max(model(target_insts), 1)[1].cpu().numpy()
        cls_error, error_0, error_1 = conditional_errors(preds_labels, target_labels, target_attrs)
        if epoch % 5 ==0:
            print(f"Epoch {epoch}/{n_epochs}: Loss {loss.item():.4f}  " +\
                  ", ".join([f"{k}: {Loss_dict[k][-1]:.2f}" for k in Loss_dict])+\
                  f"\nOverall predicted error = {cls_error:.2f}, Err|A=0 = {error_0:.2f}, Err|A=1 = {error_1:.2f}")


    pred_0, pred_1 = np.mean(preds_labels[test_idx]), np.mean(preds_labels[~test_idx])
    preds_labels = torch.max(model(target_insts), 1)[1].cpu().numpy()
    cond_00 = np.mean(preds_labels[np.logical_and(test_idx, conditional_idx)])
    cond_10 = np.mean(preds_labels[np.logical_and(~test_idx, conditional_idx)])
    cond_01 = np.mean(preds_labels[np.logical_and(test_idx, ~conditional_idx)])
    cond_11 = np.mean(preds_labels[np.logical_and(~test_idx, ~conditional_idx)])
    cls_error, _, _ = conditional_errors(preds_labels, target_labels, target_attrs)
    print(f"Overall Error: {cls_error:.4f}")
    print("Joint Error: |Err|A=0 + Err|A=1| = {}".format(error_0 + error_1))
    print("Error Gap: |Err|A=0 - Err|A=1| = {}".format(np.abs(error_0 - error_1)))
    print(f"DP Gap: |Pred=1|A=0 - Pred=1|A=1| = {np.abs(pred_0 - pred_1):.4e}")
    print("Equalized Odds Y = 0: |Pred = 1|A = 0, Y = 0 - Pred = 1|A = 1, Y = 0| = {}".format(
        np.abs(cond_00 - cond_10)))
    print("Equalized Odds Y = 1: |Pred = 1|A = 0, Y = 1 - Pred = 1|A = 1, Y = 1| = {}".format(
    np.abs(cond_01 - cond_11)))

    return Loss_dict

class MoGNN(nn.Module):

    def __init__(self, n_components, input_dim, loc_init=None, scale_init=None, weight_init=None):

        super(MoGNN, self).__init__()

        if loc_init is None:
            self.loc = nn.Parameter((torch.rand(n_components, input_dim)*2-1))
        else:
            self.loc = nn.Parameter(loc_init)

        if scale_init is None:
            self.log_scale = nn.Parameter(torch.zeros(n_components, input_dim))
        else:
            self.log_scale = nn.Parameter(torch.log(scale_init))

        if weight_init is None:
            self.raw_weight = nn.Parameter(torch.ones(n_components))
        else:
            self.raw_weight = nn.Parameter(torch.log(weight_init/(1-weight_init)))

    def log_prob(self, Z):

        self.loc = self.loc.to(Z.device)
        self.log_scale = self.log_scale.to(Z.device)
        self.raw_weight = self.raw_weight.to(Z.device)

        # print(torch.sigmoid(self.raw_weight))
        mix = torch.distributions.Categorical(torch.sigmoid(self.raw_weight))
        comp = torch.distributions.Independent(torch.distributions.Normal(self.loc, torch.exp(self.log_scale)), 1)
        gmm = torch.distributions.mixture_same_family.MixtureSameFamily(mix, comp)

        return gmm.log_prob(Z)

class FixedGaussian(nn.Module):

    def __init__(self, input_dim, scale=1.0, device='cpu'):
        super(FixedGaussian, self).__init__()
        self.dist = torch.distributions.MultivariateNormal(torch.zeros(input_dim).to(device), scale_tril=torch.diag(torch.ones(input_dim)*scale).to(device))

    def log_prob(self, x):
        return self.dist.log_prob(x)

    def sample(self, n_sample):
        return self.dist.sample((n_sample,))

def gmm(weight, loc, scale):
    mix = DIST.Categorical(weight)
    comp = DIST.Normal(loc, scale)
    return MixtureSameFamily(mix, comp)

class Classifier(nn.Module):

    def __init__(self, input_dim=32, hidden_dim=16):

        super(Classifier, self).__init__()
        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x


In [6]:
############################################
#####     Training hyperparameters     #####
############################################
data_dir = "/local/scratch/a/gong123/AUB-CAUB/ICLR2020-CFair/data"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 512
n_epochs = 100
lr = 1e-3
input_features = 114
hidden_features = 84
latent_features = 60
scale_recon = 1
lambda_cls = 1e-15

kwargs_pz = {
    "n_components": 10,
    "input_dim": latent_features,
    }

############################################
#####         Training models          #####
############################################
pz = MoGNN(**kwargs_pz)
cls = Classifier(input_dim=latent_features)
VAUB = LinearVAUB(
    pz = pz,
    cls = cls,
    input_features=input_features,
    hidden_features=hidden_features,
    latent_features=latent_features,
)

############################################
#####          Training data           #####
############################################
train_dataset, test_dataset = get_dataset_adult(data_dir)
train_loader, test_loader = get_loaders_adult(train_dataset, test_dataset, batch_size)

############################################
#####        Training script           #####
############################################
# for name, p in VAUB.named_parameters():
#     if p.requires_grad == True:
#         print(name)
Loss_dict = train_VAUB_fair(VAUB, train_loader, test_dataset, n_epochs, lr, scale_recon, lambda_cls, device=device)

Epoch 0/100: Loss 315.8213  recon_loss: 309.22, kl_loss: 6.60, vaub_loss: 315.82, elbo_loss: 315.82, cls_loss: 1.39, overall_loss: 315.82
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1 = 0.11
Epoch 5/100: Loss 309.1332  recon_loss: 305.75, kl_loss: 3.38, vaub_loss: 309.13, elbo_loss: 309.13, cls_loss: 1.40, overall_loss: 309.13
Overall predicted error = 0.29, Err|A=0 = 0.34, Err|A=1 = 0.19
Epoch 10/100: Loss 319.1642  recon_loss: 315.77, kl_loss: 3.39, vaub_loss: 319.16, elbo_loss: 319.16, cls_loss: 1.41, overall_loss: 319.16
Overall predicted error = 0.39, Err|A=0 = 0.44, Err|A=1 = 0.31
Epoch 15/100: Loss 322.8605  recon_loss: 318.10, kl_loss: 4.76, vaub_loss: 322.86, elbo_loss: 322.86, cls_loss: 1.42, overall_loss: 322.86
Overall predicted error = 0.57, Err|A=0 = 0.59, Err|A=1 = 0.53
Epoch 20/100: Loss 302.9802  recon_loss: 297.83, kl_loss: 5.15, vaub_loss: 302.98, elbo_loss: 302.98, cls_loss: 1.39, overall_loss: 302.98
Overall predicted error = 0.59, Err|A=0 = 0.61, Err|A=1

In [28]:
############################################
#####     Training hyperparameters     #####
############################################
data_dir = "/local/scratch/a/gong123/AUB-CAUB/ICLR2020-CFair/data"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 1024
n_epochs = 200
lr = 1e-3
input_features = 114
hidden_features = 84
latent_features = 60
scale_recon = 1
lambda_cls = 1e-15

# kwargs_pz = {
#     "n_components": 10,
#     "input_dim": latent_features,
#     }

############################################
#####         Training models          #####
############################################
pz = FixedGaussian(input_dim=latent_features, device=device)
cls = Classifier(input_dim=latent_features)
VAUB = LinearVAUB(
    pz = pz,
    cls = cls,
    input_features=input_features,
    hidden_features=hidden_features,
    latent_features=latent_features,
    pair_kl=True,
)

############################################
#####          Training data           #####
############################################
train_dataset, test_dataset = get_dataset_adult(data_dir)
train_loader, test_loader = get_loaders_adult(train_dataset, test_dataset, batch_size)

############################################
#####        Training script           #####
############################################
# for name, p in VAUB.named_parameters():
#     if p.requires_grad == True:
#         print(name)
Loss_dict = train_VAUB_fair(VAUB, train_loader, test_dataset, n_epochs, lr, scale_recon, lambda_cls, device=device)

Epoch 0/200: Loss 302.8466  recon_loss: 301.62, kl_loss: 1.23, vaub_loss: 302.85, elbo_loss: 302.85, cls_loss: 1.30, overall_loss: 302.85
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1 = 0.11
Epoch 5/200: Loss 310.2605  recon_loss: 310.07, kl_loss: 0.19, vaub_loss: 310.26, elbo_loss: 310.26, cls_loss: 1.28, overall_loss: 310.26
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1 = 0.11
Epoch 10/200: Loss 341.1550  recon_loss: 341.04, kl_loss: 0.11, vaub_loss: 341.15, elbo_loss: 341.15, cls_loss: 1.30, overall_loss: 341.15
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1 = 0.11
Epoch 15/200: Loss 339.1794  recon_loss: 339.11, kl_loss: 0.07, vaub_loss: 339.18, elbo_loss: 339.18, cls_loss: 1.29, overall_loss: 339.18
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1 = 0.11
Epoch 20/200: Loss 309.2220  recon_loss: 309.18, kl_loss: 0.04, vaub_loss: 309.22, elbo_loss: 309.22, cls_loss: 1.25, overall_loss: 309.22
Overall predicted error = 0.25, Err|A=0 = 0.31, Err|A=1

In [29]:
def sliced_wasserstein_distance(x, y, num_projections=10000, p=1):
    # Generate random projections
    projections = torch.randn(x.shape[1], num_projections).to(x.device)
    # projections = projections / torch.std(projections, dim=0, keepdim=True)

    # Compute sliced distances for x and y
    x_projected = torch.matmul(x, projections)
    y_projected = torch.matmul(y, projections)

    # Sort the projected samples
    x_sorted, _ = torch.sort(x_projected, dim=0)
    y_sorted, _ = torch.sort(y_projected, dim=0)

    # Compute the sliced Wasserstein distance
    sliced_dist = torch.mean(torch.abs(x_sorted - y_sorted) ** p) ** (1 / p)

    return sliced_dist


# Example usage
print(VAUB.z_arr[0].shape)
print(VAUB.z_arr[1].shape)

indices = torch.randperm(VAUB.z_arr[0].shape[0])[:VAUB.z_arr[1].shape[0]]
distance = sliced_wasserstein_distance(VAUB.z_arr[0][indices], VAUB.z_arr[1])
print("Sliced Wasserstein Distance:", distance.item())

torch.Size([328, 60])
torch.Size([138, 60])
Sliced Wasserstein Distance: 1.8631256818771362


In [30]:
def sliced_wasserstein_distance_whitening(model, test_loader, num_projections=1000, p=1):

    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA

    dist_list = []

    for _, (x_test, _, attrs) in enumerate(test_loader):

        # x_test = torch.from_numpy(x_test).float().to(device)
        x_test = x_test.float().to(device)
        z_test = model.E_arr[0](x_test).cpu().detach().numpy()
        # print(x_test.shape, z_test.shape)

        # centering, unit variance
        scaler = StandardScaler() # zero mean (feature-wise), unit-var
        scaler.fit(z_test)

        z_test_zm = scaler.transform(z_test)

        # whitening
        pca = PCA(whiten = True) # X_enrollment already unit variance
        pca.fit(z_test_zm)
        z_test_white = pca.transform(z_test_zm)

        # Generate random projections
        projections = torch.randn(z_test_white.shape[1], num_projections)
        # projections = projections / torch.std(projections, dim=0, keepdim=True)

        # Compute sliced distances for x and y

        x, y = z_test_white[attrs==0], z_test_white[attrs==1]
        indices = torch.randperm(x.shape[0])[:y.shape[0]]
        x_truncated, y = torch.from_numpy(x[indices]), torch.from_numpy(y)

        x_projected = torch.matmul(x_truncated, projections)
        y_projected = torch.matmul(y, projections)

        # Sort the projected samples
        x_sorted, _ = torch.sort(x_projected, dim=0)
        y_sorted, _ = torch.sort(y_projected, dim=0)

        # Compute the sliced Wasserstein distance
        dist_list.append(torch.mean(torch.abs(x_sorted - y_sorted) ** p) ** (1 / p))

    return torch.tensor(dist_list).mean()

distance = sliced_wasserstein_distance_whitening(VAUB, test_loader)
print("Sliced Wasserstein Distance:", distance.item())

Sliced Wasserstein Distance: 9.713080406188965


In [31]:
def svm_acc(model, test_dataset):

    from sklearn import svm, datasets
    from sklearn.model_selection import GridSearchCV, train_test_split
    from sklearn.metrics import accuracy_score
    from sklearn.preprocessing import StandardScaler

    test_loader_whole = DataLoader(test_dataset, batch_size=len(test_dataset))
    _, (X, y, a) = next(enumerate(test_loader_whole))

    X_train, X_test, a_train, a_test = train_test_split(X, a, test_size=0.2, random_state=42)

    X_train, X_test = X_train.float().to(device), X_test.float().to(device)
    Z_train = model.E_arr[0](X_train).cpu().detach().numpy()
    Z_test = model.E_arr[0](X_test).cpu().detach().numpy()
    # print(x_test.shape, z_test.shape)

    # centering, unit variance
    scaler = StandardScaler() # zero mean (feature-wise), unit-var
    scaler.fit(Z_train)
    Z_train_zm = scaler.transform(Z_train)
    Z_test_zm = scaler.transform(Z_test)

    svm_classifier = svm.SVC()
    parameters = {'C': np.logspace(-1, 1, 3), 'gamma': np.logspace(-1, 1, 3)}
    grid_search = GridSearchCV(svm_classifier, parameters)

    grid_search.fit(Z_train_zm, a_train)

    # Print the best parameters and the corresponding score
    print("Best parameters: ", grid_search.best_params_)
    print("Best score: ", grid_search.best_score_)

    # Evaluate the model on the test set
    best_model = grid_search.best_estimator_
    a_pred = best_model.predict(Z_test_zm)
    accuracy = accuracy_score(a_test, a_pred)
    print("Test accuracy: ", accuracy)

svm_acc(VAUB, test_dataset)

Best parameters:  {'C': 10.0, 'gamma': 0.1}
Best score:  0.9958498645294529
Test accuracy:  0.997675962815405
