In [1]:
# I used python kernel 3.10.15 + nvidia RTX 3090 + cuda 12.1 in local machine

# install ralated modules
%pip install torch==2.1.1 torchvision==0.16.1   # compatible version of pytorch and torchvision for mamba-ssm 
%pip install causal-conv1d==1.1.1   # causal dpthwise conv 1d  module in CUDA with pytorch
%pip install mamba-ssm  # Mamba block module

Note: you may need to restart the kernel to use updated packages.
Collecting argparse (from buildtools->causal-conv1d==1.1.1)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# import modules

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchsummary import summary

from sklearn.utils import class_weight
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, OrdinalEncoder, StandardScaler,MinMaxScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
import tqdm

from mamba_ssm import Mamba

import numpy as np
import pandas as pd
from pandas.api.types import is_string_dtype
import matplotlib.pyplot as plt

import copy
from collections import defaultdict

In [3]:
# Check environments
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
torch.cuda.is_available()

torch:  2.1 ; cuda:  cu121


True

In [4]:
# Get original data
!mkdir ./datasets
!mkdir ./datasets/adult
!wget -nc https://archive.ics.uci.edu/static/public/2/adult.zip
!unzip -o ./adult.zip -d ./datasets/adult
!cp -rf ./datasets/adult/adult.data ./datasets/adult/data_processed.csv

mkdir: `./datasets' 디렉터리를 만들 수 없습니다: 파일이 있습니다
mkdir: `./datasets/adult' 디렉터리를 만들 수 없습니다: 파일이 있습니다
‘adult.zip’ 파일이 이미 있습니다. 가져오지 않음.

Archive:  ./adult.zip
  inflating: ./datasets/adult/Index  
  inflating: ./datasets/adult/adult.data  
  inflating: ./datasets/adult/adult.names  
  inflating: ./datasets/adult/adult.test  
  inflating: ./datasets/adult/old.adult.names  


In [5]:
# Set Configuration for MambaTab
config={
    'DATASET_NAME':'adult',
    'SEED':15, # random seed 지정
    'BATCH':100,
    'LR':0.0001,
    'EPOCH':1000,
    'MAMBA_SSM_DIM':32,  # MAMBA model의 dimension 설정 (d_model: Selective Structured State Machine에 담을 최대 Dimension)
    'device':'cuda'}

In [6]:
# data load and preparing

def read_data(dataset_name):
    data=pd.read_csv('./datasets/'+dataset_name+'/data_processed'+'.csv')
    
    # fill null values
    for col in data.columns: 
        #data[col].fillna(data[col].mode()[0], inplace=True)
        data[col] = data[col].fillna(data[col].mode()[0])

    # categorical encoder: 문자열인 경우 소문자로 통일하고, 숫자로 인코딩 처리
    for c in data.columns:
        if is_string_dtype(data[c]):
            data[c]=data[c].str.lower()
            enc=OrdinalEncoder()
            cur_data=np.array(data[c])
            cur_data=np.reshape(cur_data,(cur_data.shape[0],1))
            data[c] = enc.fit_transform(cur_data)

    # 마지막 column을 lable로 추출
    y_data=data[data.columns[-1]]

    # 9번째 컬럼(gender)을 공정성 지표 계산을 위해 별도로 추출
    sensitive_data=data[data.columns[9]]

    # label 컬럼 제거
    x_data = data.drop(labels = [data.columns[-1]],axis = 1)
    
    # 나머지 컬럼 스케일링 처리
    x_data=MinMaxScaler().fit_transform(x_data)
    
    x_data, y_data, sensitive_data = np.array(x_data),np.array(y_data), np.array(sensitive_data)
    
    return x_data, y_data, sensitive_data

In [7]:
# check sensitive column

d = pd.read_csv('./datasets/adult/data_processed.csv')
d[d.columns[9]]

0           Male
1           Male
2           Male
3         Female
4         Female
          ...   
32555     Female
32556       Male
32557     Female
32558       Male
32559     Female
Name:  Male, Length: 32560, dtype: object

In [8]:
# MambaTab Class

class MambaTab(torch.nn.Module):

    def __init__(self,input_features, n_class, intermediate_representation=config['MAMBA_SSM_DIM']):
        super(MambaTab, self).__init__()
        self.linear_layer=torch.nn.Linear(input_features,intermediate_representation)
        self.relu=torch.nn.ReLU()
        self.layer_norm=torch.nn.LayerNorm(intermediate_representation)

        self.mamba=Mamba(d_model=intermediate_representation, d_state=32, d_conv=4, expand=2) # to fine-tuning
        self.output_layer=torch.nn.Linear(intermediate_representation,n_class)
    
    # deault model 참조해서 building
    def forward(self, x):
         x=self.linear_layer(x)
         x=self.layer_norm(x)
         x=self.relu(x)
         x=self.mamba(x)
         x=self.output_layer(x)
         return x

In [9]:
# Training function

def train_model(model, config, dataloader):
    best_model_wts = copy.deepcopy(model.state_dict()) # 최적 가중치 저장
    best_loss = 1e10 # 최적 손실값 초기화
    early_stopping_counter=0 # earlt stopping 값 초기화

    optimizer=torch.optim.Adam(model.parameters(),lr=config['LR'])  # Optimizer setting: Adam
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['EPOCH'], eta_min=0,verbose=False) # learning rate 감소를 위한 scheduler setting
    loss_fn=torch.nn.BCEWithLogitsLoss()    # 이진분류 처리를 위해 BCEWithLogitLoss 함수 사용
  
    # 학습 진행
    for epoch in tqdm.tqdm(range(config['EPOCH'])):
        if early_stopping_counter>=5:
          break
        
        for phase in ['train', 'val']:      
            if phase == 'train':               
                model.train()  
            else:
                model.eval()  
            
            metrics = defaultdict(float)
            epoch_samples = 0
          
            for btch,feed_dict in enumerate(dataloader[phase]):
                inputs=feed_dict[0]
                inputs=inputs.unsqueeze(0)
                labels=feed_dict[1]
                sensitives=feed_dict[2]
                
                inputs = inputs.type(torch.FloatTensor)
                inputs = inputs.to(config['device'])
                labels = labels.type(torch.FloatTensor)
                labels = labels.to(config['device'])
                # sensitives = sensitives.type(torch.FloatTensor)
                # sensitives = sensitives.to(config['device'])

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'): # 학습 단계에서만 gradient 계산

                    outputs = model(inputs)  # 모델 학습
                    outputs=outputs.squeeze()  
                    loss=loss_fn(outputs,labels) # 손실 계산
                    metrics['loss']+=loss.item()
                
                    if phase == 'train':
                        loss.backward() # gradient 계산
                        optimizer.step() # graident 업데이트
                
                epoch_samples += 1 
           
            epoch_loss = metrics['loss'] / epoch_samples # epoch 총손실 계산

            if phase == 'val':
           
                if epoch_loss<best_loss:
                    best_model_wts = copy.deepcopy(model.state_dict()) # 최적 가중치 저장
                    best_loss=epoch_loss
                    early_stopping_counter=0
                else:
                    early_stopping_counter+=1
            
        print(f"Epoch [{epoch+1}/config['EPOCH'], loss of epoch: {metrics['loss']}")

        scheduler.step()           
    
    model.load_state_dict(best_model_wts) # 최적 가중치 불러오기

    print ("training completed")
    
    return model

In [11]:
def test_result(test_model, test_dataloader):

    test_model.eval()
    
    all_probs=[]
    all_labels=[]
    all_sensitives = []

    sig=torch.nn.Sigmoid()  # 이진분류 처리를 위해 BCEWithLogitLoss 함수와 함께 Sigmoid 사용

    for inputs,labels, sensitives in test_dataloader['test']:
        
        inputs = inputs.unsqueeze(0)
        inputs = inputs.type(torch.FloatTensor)
        
        inputs = inputs.to(config['device'])
        labels = labels.to(config['device'])
        sensitives = sensitives.to(config['device'])

        with torch.set_grad_enabled(False):
            outputs = test_model(inputs) # 모델 예측
            outputs=outputs.squeeze()
            
            outputs=sig(outputs)

            # Detach 처리         
            outputs=outputs.cpu().detach().numpy()
            labels=labels.cpu().detach().numpy()
            sensitives=sensitives.cpu().detach().numpy()
            
            # 실제 값, 예측 값, 민감 속성 저장
            for i in range(outputs.shape[0]):
                all_labels.append(labels[i])
                all_probs.append(outputs[i])
                all_sensitives.append(sensitives[i])
    
    print("test completed")

    return all_labels, all_probs, all_sensitives 


In [12]:
# Data loading and data split

x_data, y_data, sensitive_data = read_data(dataset_name=config['DATASET_NAME'])

x_train, x_test, y_train, y_test, sensitive_train, sensitive_test = train_test_split(x_data, y_data, sensitive_data, test_size=0.2,random_state=config['SEED'],stratify=y_data,shuffle=True)
val_size = int(len(y_data)*0.1)
x_train, x_val, y_train, y_val, sensitive_train, sensitive_val = train_test_split(x_train, y_train, sensitive_train, test_size=val_size,random_state=config['SEED'],stratify=y_train, shuffle=True)

print("Train:",x_train.shape)
print("Val:",x_val.shape)
print("Test:",x_test.shape)

Train: (22792, 14)
Val: (3256, 14)
Test: (6512, 14)


In [13]:
# convert data from numpy float array to tensor
x_train = torch.FloatTensor(x_train)
x_val = torch.FloatTensor(x_val)
x_test = torch.FloatTensor(x_test)

y_train = torch.FloatTensor(y_train)
y_val = torch.FloatTensor(y_val)
y_test = torch.FloatTensor(y_test)

sensitive_train = torch.FloatTensor(sensitive_train)
sensitive_val = torch.FloatTensor(sensitive_val)
sensitive_test = torch.FloatTensor(sensitive_test)

# dataset grouping
train_set = TensorDataset(x_train, y_train, sensitive_train)
val_set = TensorDataset(x_val, y_val, sensitive_val)
test_set = TensorDataset(x_test, y_test, sensitive_test)

# build data lodaer
dataloader = {
      'train': DataLoader(train_set, batch_size=config['BATCH'], shuffle=True, num_workers=4),
      'val': DataLoader(val_set, batch_size=config['BATCH'], shuffle=False, num_workers=4),
      'test': DataLoader(test_set, batch_size=config['BATCH'], shuffle=False, num_workers=4)
   }

# Get the model: "n_class=1 is to use a single output logit strategy,  where n_class does not refer to the number of classes and is sufficient for binary classification"
model=MambaTab(input_features=x_train.shape[1], n_class=1)
model=model.to(config['device'])


In [14]:
for batch_data in dataloader['test']:

    inputs, labels, sensitive = batch_data
    print("Batch inputs: ", inputs)
    print("Batch labels: ", labels)
    print("Batch sensitive features: ", sensitive)


Batch inputs:  tensor([[0.0000, 0.5000, 0.0696,  ..., 0.0000, 0.1939, 0.9512],
        [0.5205, 0.5000, 0.1006,  ..., 0.0000, 0.3980, 0.9512],
        [0.1096, 0.5000, 0.0752,  ..., 0.0000, 0.3980, 0.9512],
        ...,
        [0.2192, 0.5000, 0.1678,  ..., 0.0000, 0.4082, 0.9512],
        [0.2877, 0.5000, 0.1005,  ..., 0.4708, 0.4388, 0.9512],
        [0.1507, 0.5000, 0.0274,  ..., 0.0000, 0.5000, 0.9512]])
Batch labels:  tensor([0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 1.])
Batch sensitive features:  tensor([0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 

In [15]:
summary(model, x_train.shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 22792, 32]             480
         LayerNorm-2            [-1, 22792, 32]              64
              ReLU-3            [-1, 22792, 32]               0
            Conv1d-4            [-1, 64, 22795]             320
              SiLU-5            [-1, 64, 22792]               0
            Linear-6                   [-1, 66]           4,224
            Linear-7            [-1, 22792, 32]           2,048
             Mamba-8            [-1, 22792, 32]               0
            Linear-9             [-1, 22792, 1]              33
Total params: 7,169
Trainable params: 7,169
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.22
Forward/backward pass size (MB): 50.26
Params size (MB): 0.03
Estimated Total Size (MB): 51.50
---------------------------------------------

In [16]:
# Train-validate the model
model=train_model(model, config, dataloader)




  0%|          | 1/1000 [00:01<19:55,  1.20s/it]

Epoch [1/config['EPOCH'], loss of epoch: 17.769164711236954


  0%|          | 2/1000 [00:02<19:17,  1.16s/it]

Epoch [2/config['EPOCH'], loss of epoch: 15.35576856136322


  0%|          | 3/1000 [00:03<18:59,  1.14s/it]

Epoch [3/config['EPOCH'], loss of epoch: 14.282690972089767


  0%|          | 4/1000 [00:04<18:39,  1.12s/it]

Epoch [4/config['EPOCH'], loss of epoch: 13.630553483963013


  0%|          | 5/1000 [00:05<18:42,  1.13s/it]

Epoch [5/config['EPOCH'], loss of epoch: 13.332520961761475


  1%|          | 6/1000 [00:06<18:39,  1.13s/it]

Epoch [6/config['EPOCH'], loss of epoch: 12.920447587966919


  1%|          | 7/1000 [00:07<18:43,  1.13s/it]

Epoch [7/config['EPOCH'], loss of epoch: 12.691672950983047


  1%|          | 8/1000 [00:09<18:36,  1.13s/it]

Epoch [8/config['EPOCH'], loss of epoch: 12.493073672056198


  1%|          | 9/1000 [00:10<18:32,  1.12s/it]

Epoch [9/config['EPOCH'], loss of epoch: 12.345839977264404


  1%|          | 10/1000 [00:11<18:36,  1.13s/it]

Epoch [10/config['EPOCH'], loss of epoch: 12.190184563398361


  1%|          | 11/1000 [00:12<18:36,  1.13s/it]

Epoch [11/config['EPOCH'], loss of epoch: 12.235721856355667


  1%|          | 12/1000 [00:13<18:42,  1.14s/it]

Epoch [12/config['EPOCH'], loss of epoch: 12.102530807256699


  1%|▏         | 13/1000 [00:14<18:54,  1.15s/it]

Epoch [13/config['EPOCH'], loss of epoch: 11.892494916915894


  1%|▏         | 14/1000 [00:15<18:54,  1.15s/it]

Epoch [14/config['EPOCH'], loss of epoch: 11.896155446767807


  2%|▏         | 15/1000 [00:17<18:44,  1.14s/it]

Epoch [15/config['EPOCH'], loss of epoch: 11.786630615592003


  2%|▏         | 16/1000 [00:18<18:40,  1.14s/it]

Epoch [16/config['EPOCH'], loss of epoch: 11.756490260362625


  2%|▏         | 17/1000 [00:19<18:36,  1.14s/it]

Epoch [17/config['EPOCH'], loss of epoch: 11.64432530105114


  2%|▏         | 18/1000 [00:20<18:26,  1.13s/it]

Epoch [18/config['EPOCH'], loss of epoch: 11.643199846148491


  2%|▏         | 19/1000 [00:21<18:24,  1.13s/it]

Epoch [19/config['EPOCH'], loss of epoch: 11.542551413178444


  2%|▏         | 20/1000 [00:22<18:29,  1.13s/it]

Epoch [20/config['EPOCH'], loss of epoch: 11.492808148264885


  2%|▏         | 21/1000 [00:23<18:26,  1.13s/it]

Epoch [21/config['EPOCH'], loss of epoch: 11.448875144124031


  2%|▏         | 22/1000 [00:24<18:11,  1.12s/it]

Epoch [22/config['EPOCH'], loss of epoch: 11.466156601905823


  2%|▏         | 23/1000 [00:26<18:11,  1.12s/it]

Epoch [23/config['EPOCH'], loss of epoch: 11.40346945822239


  2%|▏         | 24/1000 [00:27<18:06,  1.11s/it]

Epoch [24/config['EPOCH'], loss of epoch: 11.413935884833336


  2%|▎         | 25/1000 [00:28<18:03,  1.11s/it]

Epoch [25/config['EPOCH'], loss of epoch: 11.404731273651123


  3%|▎         | 26/1000 [00:29<18:04,  1.11s/it]

Epoch [26/config['EPOCH'], loss of epoch: 11.34216271340847


  3%|▎         | 27/1000 [00:30<18:01,  1.11s/it]

Epoch [27/config['EPOCH'], loss of epoch: 11.350034102797508


  3%|▎         | 28/1000 [00:31<17:57,  1.11s/it]

Epoch [28/config['EPOCH'], loss of epoch: 11.279611125588417


  3%|▎         | 29/1000 [00:32<17:58,  1.11s/it]

Epoch [29/config['EPOCH'], loss of epoch: 11.318166300654411


  3%|▎         | 30/1000 [00:33<17:48,  1.10s/it]

Epoch [30/config['EPOCH'], loss of epoch: 11.256501272320747


  3%|▎         | 31/1000 [00:34<17:59,  1.11s/it]

Epoch [31/config['EPOCH'], loss of epoch: 11.234853848814964


  3%|▎         | 32/1000 [00:36<18:05,  1.12s/it]

Epoch [32/config['EPOCH'], loss of epoch: 11.209640473127365


  3%|▎         | 33/1000 [00:37<18:10,  1.13s/it]

Epoch [33/config['EPOCH'], loss of epoch: 11.188303589820862


  3%|▎         | 34/1000 [00:38<18:00,  1.12s/it]

Epoch [34/config['EPOCH'], loss of epoch: 11.224361717700958


  4%|▎         | 35/1000 [00:39<17:56,  1.12s/it]

Epoch [35/config['EPOCH'], loss of epoch: 11.257421866059303


  4%|▎         | 36/1000 [00:40<18:02,  1.12s/it]

Epoch [36/config['EPOCH'], loss of epoch: 11.144161328673363


  4%|▎         | 37/1000 [00:41<17:55,  1.12s/it]

Epoch [37/config['EPOCH'], loss of epoch: 11.213214978575706


  4%|▍         | 38/1000 [00:42<17:47,  1.11s/it]

Epoch [38/config['EPOCH'], loss of epoch: 11.113940000534058


  4%|▍         | 39/1000 [00:43<17:46,  1.11s/it]

Epoch [39/config['EPOCH'], loss of epoch: 11.191370666027069


  4%|▍         | 40/1000 [00:44<17:53,  1.12s/it]

Epoch [40/config['EPOCH'], loss of epoch: 11.135790959000587


  4%|▍         | 41/1000 [00:46<17:53,  1.12s/it]

Epoch [41/config['EPOCH'], loss of epoch: 11.122004956007004


  4%|▍         | 42/1000 [00:47<17:48,  1.12s/it]

Epoch [42/config['EPOCH'], loss of epoch: 11.073901668190956


  4%|▍         | 43/1000 [00:48<17:40,  1.11s/it]

Epoch [43/config['EPOCH'], loss of epoch: 11.18762369453907


  4%|▍         | 44/1000 [00:49<17:40,  1.11s/it]

Epoch [44/config['EPOCH'], loss of epoch: 11.216801807284355


  4%|▍         | 45/1000 [00:50<17:39,  1.11s/it]

Epoch [45/config['EPOCH'], loss of epoch: 11.046562999486923


  5%|▍         | 46/1000 [00:51<17:41,  1.11s/it]

Epoch [46/config['EPOCH'], loss of epoch: 11.191536903381348


  5%|▍         | 47/1000 [00:52<17:41,  1.11s/it]

Epoch [47/config['EPOCH'], loss of epoch: 11.057764664292336


  5%|▍         | 48/1000 [00:53<17:33,  1.11s/it]

Epoch [48/config['EPOCH'], loss of epoch: 11.058096393942833


  5%|▍         | 49/1000 [00:54<17:36,  1.11s/it]

Epoch [49/config['EPOCH'], loss of epoch: 11.037525907158852


  5%|▌         | 50/1000 [00:56<17:28,  1.10s/it]

Epoch [50/config['EPOCH'], loss of epoch: 11.007569193840027


  5%|▌         | 51/1000 [00:57<17:40,  1.12s/it]

Epoch [51/config['EPOCH'], loss of epoch: 11.017319217324257


  5%|▌         | 52/1000 [00:58<17:51,  1.13s/it]

Epoch [52/config['EPOCH'], loss of epoch: 11.039473488926888


  5%|▌         | 53/1000 [00:59<18:05,  1.15s/it]

Epoch [53/config['EPOCH'], loss of epoch: 10.984064921736717


  5%|▌         | 54/1000 [01:00<18:02,  1.14s/it]

Epoch [54/config['EPOCH'], loss of epoch: 10.99089914560318


  6%|▌         | 55/1000 [01:01<18:01,  1.14s/it]

Epoch [55/config['EPOCH'], loss of epoch: 10.987623170018196


  6%|▌         | 56/1000 [01:03<18:11,  1.16s/it]

Epoch [56/config['EPOCH'], loss of epoch: 11.147830426692963


  6%|▌         | 57/1000 [01:04<18:08,  1.15s/it]

Epoch [57/config['EPOCH'], loss of epoch: 10.993157342076302


  6%|▌         | 58/1000 [01:05<18:07,  1.15s/it]

Epoch [58/config['EPOCH'], loss of epoch: 10.977584674954414


  6%|▌         | 59/1000 [01:06<18:07,  1.16s/it]

Epoch [59/config['EPOCH'], loss of epoch: 10.981714591383934


  6%|▌         | 60/1000 [01:07<18:13,  1.16s/it]

Epoch [60/config['EPOCH'], loss of epoch: 10.95949113368988


  6%|▌         | 61/1000 [01:08<18:09,  1.16s/it]

Epoch [61/config['EPOCH'], loss of epoch: 11.026782363653183


  6%|▌         | 62/1000 [01:09<18:10,  1.16s/it]

Epoch [62/config['EPOCH'], loss of epoch: 10.935862004756927


  6%|▋         | 63/1000 [01:11<18:04,  1.16s/it]

Epoch [63/config['EPOCH'], loss of epoch: 10.987715765833855


  6%|▋         | 64/1000 [01:12<18:08,  1.16s/it]

Epoch [64/config['EPOCH'], loss of epoch: 10.989186331629753


  6%|▋         | 65/1000 [01:13<18:09,  1.17s/it]

Epoch [65/config['EPOCH'], loss of epoch: 10.900382995605469


  7%|▋         | 66/1000 [01:14<18:09,  1.17s/it]

Epoch [66/config['EPOCH'], loss of epoch: 10.92583404481411


  7%|▋         | 67/1000 [01:15<18:13,  1.17s/it]

Epoch [67/config['EPOCH'], loss of epoch: 10.951724514365196


  7%|▋         | 68/1000 [01:16<18:10,  1.17s/it]

Epoch [68/config['EPOCH'], loss of epoch: 10.93062537908554


  7%|▋         | 69/1000 [01:18<18:00,  1.16s/it]

Epoch [69/config['EPOCH'], loss of epoch: 10.92515940964222


  7%|▋         | 70/1000 [01:19<17:32,  1.13s/it]

Epoch [70/config['EPOCH'], loss of epoch: 10.950438067317009
training completed





In [17]:
# Test model
all_labels, all_probs, all_sensitives = test_result(model, dataloader)

test completed


In [None]:
def calculate_fairness_metrics(y_true, y_pred, sensitive_attributes):
    
    metrics = {}
    
    # Demographic Parity: P(Y_pred=1 | A=0) = P(Y_pred=1 | A=1)
    dp_0 = np.mean(y_pred[sensitive_attributes == 0])
    dp_1 = np.mean(y_pred[sensitive_attributes == 1])
    
    metrics['Demographic Parity'] = abs(dp_0 - dp_1)
    
    # Equal Opportunity: P(Y_pred=1 | Y_true=1, A=0) = P(Y_pred=1 | Y_true=1, A=1)
    eo_0 = np.mean(y_pred[(y_true == 1) & (sensitive_attributes == 0)])
    eo_1 = np.mean(y_pred[(y_true == 1) & (sensitive_attributes == 1)])
    
    metrics['Equal Opportunity'] = abs(eo_0 - eo_1)
    
    # Equality of Odds: Same true positive and false positive rates for both groups
    tp_0 = np.mean(y_pred[(y_true == 1) & (sensitive_attributes == 0)])
    fp_0 = np.mean(y_pred[(y_true == 0) & (sensitive_attributes == 0)])
    tp_1 = np.mean(y_pred[(y_true == 1) & (sensitive_attributes == 1)])
    fp_1 = np.mean(y_pred[(y_true == 0) & (sensitive_attributes == 1)])
    
    metrics['Equality of Odds'] = abs(tp_0 - tp_1) + abs(fp_0 - fp_1)
    
    return metrics

In [32]:
 # AUROC 계산        
auroc_score = roc_auc_score(all_labels, all_probs)
print("\nAUROC score: ", auroc_score)
  
# 공성정 지표 계산
fairness_scores = calculate_fairness_metrics(np.array(all_labels), np.array(all_probs), np.array(all_sensitives))

print("\n공정성 지표\n")
for metric, score in fairness_scores.items():
    print(f"{metric}: {score}")


AUROC score:  0.9022063695016843

공정성 지표

Demographic Parity: 0.19244429469108582
Equal Opportunity: 0.049435317516326904
Equality of Odds: 0.16788774728775024


list

In [None]:
#ROC Curve visualization

fpr, tpr, thresholds = roc_curve(all_labels, all_probs)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f'ROC curve (AUROC = {auroc_score:.2f})')
plt.plot([0, 1], [0, 1], 'r--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()
