# STAFF - CoSTCo

In [1]:
import sys
sys.path.append('../src')
from read import *
from model import *
from train import *
from metric import *
import torch

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (previously 'cuda:0' was hard-coded, which made every
# later `.to(cfg.device)` call fail on CPU-only hosts).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

### 1. Configure and load the data

In [2]:
# Which dataset to load and where it lives on disk.
data = 'lastfm_time'
data_path = '../data'

# Shared experiment configuration; later cells keep adding fields to this object.
cfg = DotMap()
cfg.name, cfg.dpath, cfg.opath = data, data_path, '../output'
# NOTE(review): the loader logs "Sparsify the minority group to make it more
# unfair" — presumably `unfair` controls that sparsification; confirm in read.py.
cfg.unfair = 0.05
cfg.bs = 1024    # presumably the mini-batch size
cfg.random = 1   # presumably a seed / split index — TODO confirm
cfg.device = device

# Load the tensor, then record its per-mode sizes in the config.
tensor = TensorDataset(cfg=cfg, path=cfg.dpath, name=cfg.name)
cfg.sizes = tensor.sizes

***********************************************************************************
[1] Read lastfm_time self...
[2] Read metadata...
[3] No normalization; values are already binary...
[4] Split the tensor into training/validation/test
 [4 - 1] Sparsify the minority group to make it more unfair
[5] Make statistics of group information
[6] Change the date type into torch
[7] Read lastfm_time tensor done...!
Tensor      || ['user', 'artist', 'time']; value
NNZ         || [861, 3066, 1586]; 76727 | 14311 | 14311
Sens. Attr  || user, gender: maj(['M']) min(['F'])
Entity      || Majority: 493 Minority: 368
NNZ         || Majority: [74740] Minority: [1987]
***********************************************************************************


### 2. Configure training hyperparameters and augmentation

In [3]:
# Factorization rank and optimizer settings.
cfg.rank = 10                  # factor matrices end up with 10 columns (see the model repr below)
cfg.lr, cfg.wd = 0.001, 0.001  # presumably learning rate and weight decay
cfg.n_iters = 10000            # presumably the training-iteration budget
verbose = True                 # NOTE(review): no visible cell reads this flag

In [4]:
# Backbone and augmentation settings.
cfg.nc = 64                     # presumably conv channels; the model repr shows 64-channel conv/fc layers
cfg.tf = cfg.aug_tf = 'costco'  # presumably the backbone for the main model and for augmentation
cfg.sampling = 'knn'            # log: "Augment entities with fair K-NN graph"
cfg.aug_modes = "0"             # mode 0 is 'user' here; log: "Augmentation for the 'user' mode"
cfg.K = 3                       # presumably K of the K-NN graph
cfg.gamma = 0.9                 # NOTE(review): meaning not visible in this notebook — TODO document
cfg.wd2 = 0.01                  # presumably a second weight-decay term — TODO confirm
cfg.aug_training = True
# NOTE(review): called after the config is complete; confirm what load_data() prepares.
tensor.load_data()

In [5]:
# Build the augmented data. Per the log below, this constructs a fair K-NN graph
# for the 'user' mode and writes dist/graph/df CSV files under cfg.opath.
read_augment(tensor, cfg)

***********************************************************************************
Augment entities with fair K-NN graph 
Augmentation for the 'user' mode 
Save file as [../output/lastfm_time/sampling/0.05_costco_0.9_3_1_dist.csv]
Save file as [../output/lastfm_time/sampling/0.05_costco_0.9_3_1_graph.csv]
Save file as [../output/lastfm_time/sampling/0.05_costco_0.9_3_1_df.csv]


### 3. Build the model

In [6]:
# Instantiate the CoSTCo backbone, move its parameters to the configured device,
# and display the module structure.
model = CoSTCo(cfg)
model = model.to(cfg.device)
model

CoSTCo(
  (factors): ParameterList(
      (0): Parameter containing: [torch.float32 of size 1722x10 (GPU 0)]
      (1): Parameter containing: [torch.float32 of size 3066x10 (GPU 0)]
      (2): Parameter containing: [torch.float32 of size 1586x10 (GPU 0)]
  )
  (conv1): Conv2d(1, 64, kernel_size=(1, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(10, 1), stride=(1, 1))
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (last_act): ReLU()
)

### 4. Train the model

In [7]:
# Fit the model. Per the log below, each iteration reports the training loss,
# the fairness loss, and train/validation RMSE.
# wandb=None presumably disables Weights & Biases logging — TODO confirm.
trainer = Trainer(model, tensor, cfg, wandb=None)
trainer.train()

Iters:   1 || training loss: 14.24993	fair loss: 277010.84058	Train RMSE: 0.33100 Valid RMSE: 0.34379	
Iters:   2 || training loss: 11.41565	fair loss: 159649.60327	Train RMSE: 0.32834 Valid RMSE: 0.34139	
Iters:   3 || training loss: 11.23882	fair loss: 91306.07678	Train RMSE: 0.32526 Valid RMSE: 0.33872	
Iters:   4 || training loss: 10.96116	fair loss: 51359.21594	Train RMSE: 0.32046 Valid RMSE: 0.33442	
Iters:   5 || training loss: 10.54917	fair loss: 28267.64645	Train RMSE: 0.31213 Valid RMSE: 0.32752	
Iters:   6 || training loss: 9.87865	fair loss: 15167.12035	Train RMSE: 0.29734 Valid RMSE: 0.31638	
Iters:   7 || training loss: 8.75933	fair loss: 7908.57309	Train RMSE: 0.27299 Valid RMSE: 0.30271	
Iters:   8 || training loss: 7.59871	fair loss: 3995.84687	Train RMSE: 0.25567 Valid RMSE: 0.29722	
Iters:   9 || training loss: 6.96648	fair loss: 1950.89920	Train RMSE: 0.24854 Valid RMSE: 0.29572	
Iters:  10 || training loss: 6.77588	fair loss: 917.99455	Train RMSE: 0.24572 Valid RMS

### 5. Evaluate the model's accuracy and fairness on tensor completion

In [8]:
# Score the trained model on the test split: accuracy (MSE) and group fairness
# (MAD = absolute gap between the two groups' errors, plus each group's error).
res = evaluate_model(model, tensor)

test_rmse = res['test_rmse']
mad = abs(res['MAD_Error'])
err_group0, err_group1 = res['Group0_Error'], res['Group1_Error']

print(f"MSE : {test_rmse * test_rmse:.4f}")
print(f"MAD: {mad:.5f} Error1 : {err_group0:.5f} Error2: {err_group1:.5f}")

Test NRE: 0.2717 Test RMSE: 0.2521
***********************************************************************************
Calculate group fairness...
***********************************************************************************
MSE : 0.0635
MAD: 0.05190 Error1 : 0.04511 Error2: 0.09701
