In [14]:
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import *
import torch.nn.functional as F
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
from torch.nn.modules.loss import _Loss
import warnings

warnings.filterwarnings("ignore")

class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=True)
        x = self.gc2(x, adj)
        return x


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.xavier_normal_(self.weight.data)
        #nn.init.xavier_normal_(self.bias.data)
        #self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)  # XW
        output = torch.spmm(adj, support)  # AXW
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'


class DNN(torch.nn.Module):
    def __init__(self, layers, dropout=0.1):
        super(DNN, self).__init__()

        # parameters
        self.depth = len(layers) - 1

        # set up layer order dict
        self.activation = torch.nn.ReLU

        layer_list = list()
        for i in range(self.depth - 1):
            layer_list.append(
                ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i + 1], bias=True))
            )
            layer_list.append(('activation_%d' % i, self.activation()))
            layer_list.append(('dropout_%d' % i, torch.nn.Dropout(dropout)))

        layer_list.append(
            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))
        )

        # deploy layers
        self.layers = torch.nn.Sequential(OrderedDict(layer_list))

    def forward(self, x):
        out = self.layers(x)
        return out  #F.softmax(out, dim=1)

class ParamLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(ParamLayer, self).__init__()
        #self.param = nn.Parameter(torch.randn(out_features, in_features))
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.xavier_normal_(self.weight.data)

    def forward(self, x):
        return torch.mm(x,self.weight)

def dot_sim(x, y):
    # Inner product similarity
    ip_sim = torch.mm(x, y)
    return ip_sim


def load_data(path='data/', dataset='tagged', seed=80):
    """Load user network dataset (Tagged only for now)"""
    print('Loading {} dataset...'.format(dataset))
    node_features = pd.read_csv(path + str(dataset) + '_features.csv', header=0)
    if dataset == 'german':
        node_features = node_features.drop(['PurposeOfLoan'], axis=1)
    labels = torch.from_numpy(node_features['label'].to_numpy())  # label tensor
    sensitive_attribute = torch.from_numpy(node_features['sens_2'].to_numpy()).type(torch.LongTensor)
    #sensitive_attribute = torch.from_numpy(node_features['marital_status_indicator'].to_numpy()).type(
    #    torch.LongTensor)  # gender tensor
    #sensitive_attribute_1 = torch.from_numpy(node_features['sens_1'].to_numpy()).type(torch.LongTensor)
    # last three columns are userIds, sensitives and labels
    #node_features = node_features.drop(columns=['slobodny(a)', 'mam vazny vztah', 'zenaty (vydata)', 'rozvedeny(a)'])
    features = node_features[node_features.columns[:-4]].to_numpy()
        
    print(node_features[node_features.columns[:-4]])
    relations = pd.read_csv(path + str(dataset) + '_edges.csv', header=0)
    # build graph
    try:
        adj = sp.coo_matrix((relations['weight'].to_numpy(), (relations['src'].to_numpy(),
                                                              relations['dst'].to_numpy())),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)
    except KeyError:
        adj = sp.coo_matrix((np.ones(relations.shape[0]), (relations['src'].to_numpy(),
                                                           relations['dst'].to_numpy())),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)

    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    features = feature_norm(torch.FloatTensor(features)) # normalize feature matrix
    aug_adj = aug_normalized_adjacency(sp.csr_matrix(adj))  # normalize adjacency matrix and add self loop
    aug_adj = sparse_mx_to_torch_sparse_tensor(aug_adj)

    # randomize dataset selection
    random.seed(seed)
    label_idx_0 = np.where(labels == 0)[0]
    label_idx_1 = np.where(labels == 1)[0]

    random.shuffle(label_idx_0)
    random.shuffle(label_idx_1)

    idx_train = np.append(label_idx_0[:int(0.6 * len(label_idx_0))], label_idx_1[:int(0.6 * len(label_idx_1))])
    idx_val = np.append(label_idx_0[int(0.6 * len(label_idx_0)):int(0.8 * len(label_idx_0))],
                        label_idx_1[int(0.6 * len(label_idx_1)):int(0.8 * len(label_idx_1))])
    idx_test = np.append(label_idx_0[int(0.8 * len(label_idx_0)):], label_idx_1[int(0.8 * len(label_idx_1)):])

    train_mask = torch.from_numpy(sample_mask(idx_train, node_features.shape[0]))
    val_mask = torch.from_numpy(sample_mask(idx_val, node_features.shape[0]))
    test_mask = torch.from_numpy(sample_mask(idx_test, node_features.shape[0]))

    return aug_adj, features, labels, sensitive_attribute, train_mask, val_mask, test_mask

In [15]:
dataset = "bail_multi"
lr = 0.0001
weight_decay = 1e-5
dropout = 0.0
epochs = 1000
seed = 10
hidden_dim = 512
path = "data/"

reg_lambda=5e-4

loss_lt= []
loss_adv_lt = []

np.random.seed(seed)
torch.manual_seed(seed)

adj, features, labels, gender, train_mask, val_mask, test_mask= load_data(path=path, dataset=dataset, seed=seed)

train_mask_1 = (gender[train_mask] == 2) | (gender[train_mask]==3)

Loading bail_multi dataset...
       ALCHY  JUNKY  SUPER  MARRIED  FELON  WORKREL  PROPTY  PERSON  PRIORS  \
0          0      0      1        0      0        0       0       0       1   
1          0      0      1        0      0        0       0       0       2   
2          1      0      1        1      0        1       0       0       0   
3          0      0      1        1      0        0       0       0       0   
4          0      0      0        0      0        1       0       0       0   
...      ...    ...    ...      ...    ...      ...     ...     ...     ...   
18871      0      0      1        1      0        1       1       0       0   
18872      0      0      1        1      1        1       0       0       0   
18873      0      1      1        0      0        0       0       0       3   
18874      0      0      1        0      1        1       0       0       3   
18875      0      1      0        0      0        0       0       0       0   

       SCHOOL  RULE  

In [16]:
class_seen = torch.from_numpy(np.load(path+"bail_truth_class_seen.npy"))

In [17]:
#lt_alpha = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]
lt_alpha = [0,0.001,0.010,0.10]
#lt_alpha = [0.5,0.7,0.9]
#lt_alpha = [0]
#lt_beta = [0]
lt_beta = [0,0.5,1,2,5,10,50]
layers = [128,128,8]
seeds = [10]
for seed in seeds:
    for beta in lt_beta:
        loss_ce_class = []
        loss_ce_seen = []
        loss_ce_sens = []
        loss_ce_dis_1 = []
        loss_ce_dis_2 = []
        total_loss_lt = []
        for alpha in lt_alpha:
            np.random.seed(seed)
            torch.manual_seed(seed)
            model_encoder = GCN(features.shape[1], hidden_dim, 128, dropout=dropout)
            model_encoder_dnn = DNN(layers=layers,dropout=0.0)
            model_param_layer = ParamLayer(4,4)
            optimizer = torch.optim.Adam(model_encoder.parameters(), lr=lr, weight_decay=weight_decay)
            optimizer_dnn = torch.optim.Adam(model_encoder_dnn.parameters(), lr=lr, weight_decay=weight_decay)
            criterion = nn.CrossEntropyLoss()
            #criterion_unseen = nn.CrossEntropyLoss(weight=torch.tensor([1.0,100.0]))
            
            model_class = DNN(layers=[4,128, 128,2],dropout=0.0)
            optimizer_class = torch.optim.Adam(model_class.parameters(), lr=lr, weight_decay=weight_decay)
            
            
            model_unseen = DNN(layers=[4,128,128,2],dropout=0.0)
            optimizer_unseen = torch.optim.Adam(model_unseen.parameters(), lr=lr, weight_decay=weight_decay)
            optimizer_sens = torch.optim.Adam(model_param_layer.parameters(), lr=lr, weight_decay=weight_decay)
            
            for epoch in range(epochs):
                model_encoder.train()
                optimizer.zero_grad()
                optimizer_dnn.zero_grad()
                optimizer_class.zero_grad()
                optimizer_unseen.zero_grad()
                optimizer_sens.zero_grad()
                
                #encoder
                temp_emb = model_encoder(features,adj)
                emb = model_encoder_dnn(temp_emb)
            
                
                #sensitive attribute val prediction
                
                #decoder
                #feat_pred = model_decoder_feat(emb)
                #adj_pred = model_decoder_adj(emb)
                
                
                #mse_loss_feat = torch.sum((feat_pred[train_mask]-features[train_mask])**2)
                #mse_loss_adj = torch.sum((adj_pred.to_sparse()-adj)**2)
                #loss =  mse_loss_adj + mse_loss_feat
                
                #classifier
                
                class_emb = F.softmax(model_class(emb[:,4:].detach()),dim=1)
                loss_class =  criterion(class_emb[train_mask][train_mask_1],labels[train_mask][train_mask_1])
                loss_class.backward()
                optimizer_class.step()
            
                
                #sensitive attribute prediction unseen & seen
                #param_emb_mul = model_param_layer(emb[:,:4].detach())
                unseen_emb = F.softmax(model_unseen(emb[:,:4].detach()),dim=1)
                loss_unseen = criterion(unseen_emb,class_seen)
                loss_unseen.backward()
                optimizer_unseen.step()
                optimizer_sens.step()
                #corr
                #corr_att = torch.cat([torch.flatten(emb[:,2:][train_mask]).unsqueeze(1),torch.flatten(emb[:,:2][train_mask]).unsqueeze(1)],dim=1)
                #corr_matrix = torch.corrcoef(corr_att)
                #corr_v = abs(corr_matrix[0,1])
                #print('here')
                #different losses
                class_emb = F.softmax(model_class(emb[:,4:]),dim=1)
                loss_class =  criterion(class_emb[train_mask][train_mask_1],labels[train_mask][train_mask_1])
                loss_ce_class.append(loss_class.item())
                
                #param_emb_mul = model_param_layer(emb[:,:4].detach())
                unseen_emb = F.softmax(model_unseen(emb[:,:4]),dim=1)
                loss_sens = criterion(unseen_emb[train_mask],class_seen[train_mask])
                loss_ce_seen.append(loss_sens.item())
                
                #sensitive attribute val prediction
                #sens_val_emb =  F.softmax(emb[:,:4])
                #loss_sens_val =  criterion(sens_val_emb[train_mask][train_mask_1],gender[train_mask][train_mask_1])
                #loss_ce_sens.append(loss_sens_val.item())
                
                #class_emb_neg = F.softmax(model_class(emb[:,:4]),dim=1)
                #loss_class_neg = criterion(class_emb_neg[train_mask],labels[train_mask])
                
                #unseen_emb_neg =  F.softmax(model_unseen(emb[:,4:]),dim=1)
                #loss_sens_neg =  criterion(unseen_emb_neg[train_mask],class_seen[train_mask])
    
                #loss_ce_dis_1.append(loss_class_neg.item())
                #loss_ce_dis_2.append(loss_sens_neg.item())
                
                #preds_class = class_emb.max(1)[1].type_as(labels)
                
                idx_s0 = gender== 0
                idx_s1 = gender == 1
                idx_s2 = class_seen == 0
            
                #parity = abs(sum(F.sigmoid(class_emb[idx_s0][:,1])*3) / sum(labels[idx_s0]) - sum(F.sigmoid(class_emb[idx_s1][:,1])*3) / sum(labels[idx_s1]))
                #parity_1 = abs((sum(F.sigmoid(class_emb[idx_s0][:,1])*3) / sum(idx_s0)) - (sum(F.sigmoid(class_emb[idx_s1][:,1])*3) / sum(idx_s1)))
                parity_2 = abs((sum(F.sigmoid(class_emb[idx_s1][:,1])*3) / sum(idx_s1)) - (sum(F.sigmoid(class_emb[idx_s2][:,1])*3) / sum(idx_s2)))
                parity_3 = abs((sum(F.sigmoid(class_emb[idx_s0][:,1])*3) / sum(idx_s0)) - (sum(F.sigmoid(class_emb[idx_s2][:,1])*3) / sum(idx_s2)))
                #parity=max(parity_1,parity_2,parity_3)
                parity=max(parity_2,parity_3)
                
                #vector_class = emb[:,4:].reshape(-1)
                #vector_unseen = emb[:,:4].reshape(-1)
                #vector_class= vector_class.view(vector_class.shape[0], 1)
                #vector_unseen= vector_unseen.view(vector_unseen.shape[0], 1)
                #print(vector_class.shape)
                #print(vector_unseen.shape)
                #vector_both = torch.cat((vector_unseen.T,vector_class.T),dim=0)
                #print(vector_both.shape)
                #coef = torch.abs(torch.corrcoef(vector_both)[0][1])
                coef = torch.abs(torch.corrcoef(emb.T))
                #print(coef)
                #print(coef[:4,4:])
                coef = torch.max(coef[:4,4:])
                #print(coef)
                
                
                total_loss = (1-alpha)*(loss_class+loss_sens)+alpha*(coef) + beta*parity
                #total_loss = (1-alpha)*(loss_class+beta*loss_sens+loss_sens_val)-alpha*(loss_class_neg+ loss_sens_neg) + beta*parity
                
                total_loss.backward()
                optimizer.step()
                optimizer_dnn.step()
                optimizer_unseen.step()
                optimizer_sens.step()
                optimizer_class.step()
                
                if epoch % 200 == 0:
                    #print("Epoch: ", epoch,"loss: ", total_loss.item() )
                    preds_1 = class_emb.max(1)[1].type_as(labels)
                    acc = acc_measurements_multi(class_emb[train_mask], labels[train_mask], gender[train_mask])
                    acc_unseen = acc_measurements_multi(unseen_emb[train_mask], class_seen[train_mask], gender[train_mask])
                    print("Epoch: ", epoch,"loss: ", total_loss.item(),  "acc :",acc[0] ,"seen/unseen acc: ", acc_unseen[0])
                    print("Confusion Matrix class: \n",confusion_matrix(labels[train_mask], preds_1[train_mask]))
                    preds_1 = unseen_emb.max(1)[1].type_as(labels)
                    print("Confusion Matrix sens: \n",confusion_matrix(class_seen[train_mask], preds_1[train_mask]))
                    print("Loss class: ",loss_class.item(),"Loss unseen: ",loss_sens.item(), "Loss Parity: " ,parity.item())
                    
    
            temp_emb = model_encoder(features, adj)
            emb = model_encoder_dnn(temp_emb)
            class_emb = F.softmax(model_class(emb[:, 4:]), dim=1)
            unseen_emb = F.softmax(model_unseen(emb[:, :4]), dim=1)
            
            preds_class = class_emb.max(1)[1].type_as(labels)
            preds_sens = unseen_emb.max(1)[1].type_as(labels)
            
            acc_test = acc_measurements(class_emb[test_mask], labels[test_mask], gender[test_mask])
            auc_roc_test, auc_m, auc_f = auc_measurements(class_emb[test_mask], labels[test_mask],
                                                          gender[test_mask])
            
            parity, equality = fair_metric(preds_class[test_mask].numpy(), labels[test_mask].numpy(),
                                           gender[test_mask].numpy())
            print('---------------EVALUATION------Alpha: ',alpha,' Beta: ',beta,' seed: ', seed,'--------------------------------')
            lt = [acc_test[0], auc_roc_test, auc_m, auc_f, parity, equality]
            print(
                f'|*|Test: acc : {acc_test[0]} || Auc layer: {auc_roc_test} || AUC Male: {auc_m} || AUC Female: {auc_f} || SP: {parity} || EQ: {equality}|*|')
            print('-----------------------------------------------------')
            try:
                acc_test = acc_measurements_multi(class_emb[test_mask], labels[test_mask], gender[test_mask])
                auc_m, auc_f = auc_measurements_multi(class_emb[test_mask], labels[test_mask], gender[test_mask])
                index, parity, equality = fair_metric_multi(preds_class[test_mask].numpy(), labels[test_mask].numpy(),
                                                            gender[test_mask])
                
                print('-----------------------------------------------------')
                lt = [acc_test[0], auc_roc_test, auc_m, auc_f, parity, equality]
                print(
                    f'|*|Test: acc : {acc_test[0]} || Auc layer: {auc_roc_test} || AUC Max: {auc_m} || AUC min: {auc_f} || SP: {parity} || EQ: {equality}|*|')
                
                acc_test = acc_measurements_multi(unseen_emb[test_mask], class_seen[test_mask], gender[test_mask])
                auc_roc_test, auc_m, auc_f = auc_measurements(unseen_emb[test_mask], class_seen[test_mask],
                                                          gender[test_mask])
                print(
                    f'|*|Test sens: acc : {acc_test[0]} || Auc layer: {auc_roc_test} || AUC Max: {auc_m} || AUC min: {auc_f}')
                print("correlation",coef)
                print('-----------------------------------------------------')
            except:
                print("Error with calculating fairness metrics")
            
            str_model = "param_saved/bail_december_2_3_only/bail_model_encoder_alpha_" + str(int(alpha*10000))+"_beta_"+str(int(beta*10))+"_seed_"+str(int(seed))+".pth"
            torch.save(model_encoder.state_dict(), str_model)
            str_model = "param_saved/bail_december_2_3_only/bail_model_encoder_dnn_alpha_" + str(int(alpha*10000))+"_beta_"+str(int(beta*10))+"_seed_"+str(int(seed))+".pth"
            torch.save(model_encoder_dnn.state_dict(), str_model)
            str_model = "param_saved/bail_december_2_3_only/bail_model_class_alpha_" + str(int(alpha*10000))+"_beta_"+str(int(beta*10))+"_seed_"+str(int(seed))+".pth"
            torch.save(model_class.state_dict(), str_model)
            str_model = "param_saved/bail_december_2_3_only/bail_model_unseen_alpha_" + str(int(alpha*10000))+"_beta_"+str(int(beta*10))+"_seed_"+str(int(seed))+".pth"
            torch.save(model_unseen.state_dict(), str_model)

[0.5304600082884376, 0.5632732797916215, 0.5793614067561315, 0.6431902985074627]
[0.0, 0.0, 1.0, 1.0]
Epoch:  0 loss:  1.392905354499817 acc : 0.5744812362030906 seen/unseen acc:  0.3801324503311258
Confusion Matrix class: 
 [[6119  944]
 [3875  387]]
Confusion Matrix sens: 
 [[4305    0]
 [7020    0]]
Loss class:  0.6929842233657837 Loss unseen:  0.6999211311340332 Loss Parity:  0.0003268718719482422
[0.7612929962702031, 0.8484914260907315, 0.8514576584914392, 0.8325559701492538]
[0.9912971404890178, 0.9350987627523334, 0.8496066635816751, 0.992070895522388]
Epoch:  200 loss:  0.867132306098938 acc : 0.8274613686534217 seen/unseen acc:  0.9415452538631347
Confusion Matrix class: 
 [[7013   50]
 [1904 2358]]
Confusion Matrix sens: 
 [[3963  342]
 [ 320 6700]]
Loss class:  0.4948713481426239 Loss unseen:  0.3722609579563141 Loss Parity:  0.07431697845458984
[0.7592208868628264, 0.8903841979596266, 0.9361406756131421, 0.9132462686567164]
[0.9908827186075425, 0.9546342522248752, 0.8958815

In [11]:
def fair_metric_multi(pred, labels, sens):
    parity_lt = []
    eq_lt = []
    index = []
    for i in torch.unique(sens):
        idx_s0 = sens == i
        # idx_s0_y1 = np.bitwise_and(idx_s0, labels == 1)
        tn, fp, fn, tp = confusion_matrix(labels[idx_s0], pred[idx_s0]).ravel()
        sp_i = sum(pred[idx_s0]) / sum(idx_s0)
        eq_i = tp / sum(labels[idx_s0])
        for j in torch.unique(sens):
            if i != j:
                idx_s1 = sens == j
                tn_j, fp_j, fn_j, tp_j = confusion_matrix(labels[idx_s1], pred[idx_s1]).ravel()
                sp_j = sum(pred[idx_s1]) / sum(idx_s1)
                eq_j = tp_j / sum(labels[idx_s1])

                parity = abs(sp_i - sp_j)
                equality = abs(eq_i - eq_j)

                index.append([i, j])
                parity_lt.append(parity.item())
                eq_lt.append(equality.item())
    max_ = max(parity_lt)
    parity_index = parity_lt.count(max_)
    max_eq = max(eq_lt)
    eq_index = eq_lt.count(max_eq)
    print('index of sp: ', index[parity_index])
    print('index of eq: ', index[eq_index])
    # print(parity_lt)
    # print(eq_lt)
    return index, max_, max_eq

In [None]:
from sklearn.metrics import f1_score
f1_lt = []
auc_lt = []
sp_lt =[]
eo_lt = []
model_encoder = GCN(features.shape[1], hidden_dim, 128, dropout=dropout)
model_encoder_dnn = DNN(layers=layers,dropout=0.0)
model_class = DNN(layers=[4,128, 128,2],dropout=0.0)
model_unseen = DNN(layers=[4,128,128,2],dropout=0.0)
lt_alpha= [1.0]
for alpha in lt_alpha:
    str_model = "param_saved/tagged_eo_reg_alldata/tagged_model_encoder_alpha_" + str(int(alpha*10))+".pth"
    model_encoder.load_state_dict(torch.load(str_model))
    str_model = "param_saved/tagged_eo_reg_alldata/tagged_model_encoder_dnn_alpha_" + str(int(alpha*10))+".pth"
    model_encoder_dnn.load_state_dict(torch.load(str_model))
    str_model = "param_saved/tagged_eo_reg_alldata/tagged_model_class_alpha_" + str(int(alpha*10))+".pth"
    model_class.load_state_dict(torch.load(str_model))
    str_model = "param_saved/tagged_eo_reg_alldata/tagged_model_unseen_alpha_" + str(int(alpha*10))+".pth"
    model_unseen.load_state_dict(torch.load(str_model))
    
    temp_emb = model_encoder(features, adj)
    emb = model_encoder_dnn(temp_emb)
    class_emb = F.softmax(model_class(emb[:, 4:]), dim=1)
    unseen_emb = F.softmax(model_unseen(emb[:, :4]), dim=1)
    
    preds_class = class_emb.max(1)[1].type_as(labels)
    preds_sens = unseen_emb.max(1)[1].type_as(labels)
    
    acc_test = acc_measurements(class_emb[test_mask], labels[test_mask], gender[test_mask])
    auc_roc_test, auc_m, auc_f = auc_measurements(class_emb[test_mask], labels[test_mask],
                                                  gender[test_mask])
    
    parity, equality = fair_metric(preds_class[test_mask].numpy(), labels[test_mask].numpy(),
                                   gender[test_mask].numpy())
    print('---------------EVALUATION----------------Alpha: ',alpha,'-------------------')
    lt = [acc_test[0], auc_roc_test, auc_m, auc_f, parity, equality]
    #print(
    #    f'|*|Test: acc : {acc_test[0]} || Auc layer: {auc_roc_test} || AUC Male: {auc_m} || AUC Female: {auc_f} || SP: {parity} || EQ: {equality}|*|')
    print('-----------------------------------------------------')
    acc_test = acc_measurements_multi(class_emb[test_mask], labels[test_mask], gender[test_mask])
    auc_m, auc_f = auc_measurements_multi(class_emb[test_mask], labels[test_mask], gender[test_mask])
    index, parity, equality = fair_metric_multi(preds_class[test_mask].numpy(), labels[test_mask].numpy(),
                                                gender[test_mask])
    f1_s = f1_score(labels[test_mask].numpy(), preds_class[test_mask].detach().numpy())
    
    f1_lt.append(f1_s)
    auc_lt.append(auc_roc_test)
    sp_lt.append(parity)
    eo_lt.append(equality)
    
    print('-----------------------------------------------------')
    lt = [acc_test[0], auc_roc_test, auc_m, auc_f, parity, equality]
    print(
        f'|*|Test: acc : {acc_test[0]} || f1-score:{f1_s} Auc layer: {auc_roc_test} || AUC Max: {auc_m} || AUC min: {auc_f} || SP: {parity} || EQ: {equality}|*|')
    print('-----------------------------------------------------')