In [29]:
from dataset_slide import *
import torch
import torch.nn as nn

import numpy as np

from scipy.stats import uniform, randint

from sklearn.metrics import auc, accuracy_score, confusion_matrix, mean_squared_error
import xgboost as xgb

In [30]:
person_order = {'F1_Interaction_1': {'P2': 1, 'P1': 1, 'P3': 2},
 'F1_Interaction_2': {'P2': 1, 'P1': 1, 'P3': 2},
 'F2_Interaction_1': {'P4': 1, 'P5': 3},
 'F2_Interaction_2': {'P4': 1},
 'F3_Interaction_1': {'P8': 3, 'P6': 1, 'P7': 1},
 'F3_Interaction_2': {'P6': 1, 'P7': 1},
 'F4_Interaction_1': {'P14': 2,
  'P12': 1,
  'P11': 1,
  'P10': 1,
  'P9': 1,
  'P13': 3},
 'F4_Interaction_2': {'P12': 1,
  'P11': 1,
  'P10': 1,
  'P9': 1,
  'P13': 3},
 'F5_Interaction_1': {'P16': 2, 'P15': 1},
 'F5_Interaction_2': {'P16': 2, 'P15': 1},
 'F6_Interaction_1': {'P19': 3, 'P18': 1, 'P17': 1},
 'F6_Interaction_2': {'P19': 3, 'P18': 1, 'P17': 1},
 'F7_Interaction_1': {'P22': 3,
  'P20': 1,
  'P21': 1,
  'P23': 2},
 'F8_Interaction_1': {'P24': 1, 'P25': 3},
 'F8_Interaction_2': {'P24': 1, 'P25': 3},
 'F8_Interaction_3': {'P24': 1, 'P25': 3},
 'F10_Interaction_1': {'P27': 1, 'P28': 1},
 'F11_Interaction_1': {'P29': 1, 'P30': 2},
 'F11_Interaction_2': {'P29': 1, 'P30': 2},
 'F13_Interaction_1': {'P32': 1, 'P33': 2},
 'F17_Interaction_1': {'P37': 1, 'P38': 2},
 'F17_Interaction_2': {'P37': 1, 'P38': 2}}


group_nums = {1: ['F2_Interaction_2'],
 2: ['F2_Interaction_1',
  'F3_Interaction_2',
  'F5_Interaction_1',
  'F5_Interaction_2',
  'F8_Interaction_1',
  'F8_Interaction_2',
  'F8_Interaction_3',
  'F10_Interaction_1',
  'F11_Interaction_1',
  'F11_Interaction_2',
  'F13_Interaction_1',
  'F17_Interaction_1',
  'F17_Interaction_2'],
 3: ['F1_Interaction_1',
  'F1_Interaction_2',
  'F3_Interaction_1',
  'F6_Interaction_1',
  'F6_Interaction_2'],
 4: ['F7_Interaction_1'],
 5: ['F4_Interaction_2'],
 6: ['F4_Interaction_1']}

group_all_dataset = []
group_ids = group_nums[3]
for group_id in group_ids:
    group_specific_dataset = SpeedDatingDS(group_id = group_id, social_rel = person_order[group_id])
    group_all_dataset.append(group_specific_dataset)

SD = torch.utils.data.ConcatDataset(group_all_dataset)

########################################################################
#Dataloader
########################################################################
train_len = len(SD) - len(SD)//5
test_len = len(SD)//5

train, test = torch.utils.data.random_split(SD, (train_len, test_len), generator=torch.Generator().manual_seed(0))

batch_size = 32
trainloader = DataLoader(train, batch_size = train_len, shuffle = True, num_workers = 8)
testloader = DataLoader(test, batch_size = test_len, shuffle = True, num_workers = 8)

In [31]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for idx, batch in enumerate(trainloader):

    x_train, vb_output = batch['context'], batch['vb_output']

    labels = vb_output.sum(2).to(device).flatten(start_dim =1)
    index_labels = torch.zeros(x_train.shape[0]).long().to(device)
    index_labels[labels.nonzero()[:,0]] = labels.nonzero()[:,1] + 1 
    y_train = index_labels

for idx, batch in enumerate(testloader):
    x_test, vb_output = batch['context'], batch['vb_output']

    labels = vb_output.sum(2).to(device).flatten(start_dim =1)
    index_labels = torch.zeros(x_test.shape[0]).long().to(device)
    index_labels[labels.nonzero()[:,0]] = labels.nonzero()[:,1] + 1 
    y_test = index_labels




In [32]:
X = x_train.flatten(start_dim =1).cpu().numpy()
y = y_train.cpu().numpy()

x_test = x_test.flatten(start_dim =1).cpu().numpy()
y_test = y_test.cpu().numpy()

In [41]:
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from sklearn.metrics import accuracy_score

space={'max_depth': hp.quniform("max_depth", 3, 18, 1),
        'gamma': hp.uniform ('gamma', 1,9),
        'reg_alpha' : hp.quniform('reg_alpha', 40,180,1),
        'reg_lambda' : hp.uniform('reg_lambda', 0,1),
        'colsample_bytree' : hp.uniform('colsample_bytree', 0.5,1),
        'min_child_weight' : hp.quniform('min_child_weight', 0, 10, 1),
        'n_estimators': 180,
        'seed': 0
    }

def objective(space):
    clf=xgb.XGBClassifier(
                    n_estimatoras =space['n_estimators'], max_depth = int(space['max_depth']), gamma = space['gamma'],
                    reg_alpha = int(space['reg_alpha']),min_child_weight=int(space['min_child_weight']),
                    colsample_bytree=int(space['colsample_bytree']))
    
    evaluation = [( X, y)]
    
    clf.fit(X, y,
            eval_set=evaluation, eval_metric="mlogloss",
            early_stopping_rounds=10,verbose=False)
    

    pred = clf.predict(x_test)
    f1_score = (pred == y_test).mean()#sklearn.metrics.f1_score(pred, y_test, average='macro')
    
    print("f1: {}".format(sklearn.metrics.f1_score(pred, y_test, average='macro')))

    return {'loss': -f1_score, 'status': STATUS_OK }

trials = Trials()

best_hyperparams = fmin(fn = objective,
                        space = space,
                        algo = tpe.suggest,
                        max_evals = 100,
                        trials = trials)

f1: 0.2325020112630732                                                       
  1%|      | 1/100 [00:00<00:14,  7.02trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
  3%|▏     | 3/100 [00:00<00:12,  7.92trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
  5%|▎     | 5/100 [00:00<00:11,  8.23trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
  7%|▍     | 7/100 [00:00<00:11,  8.33trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
  9%|▌     | 9/100 [00:01<00:10,  8.44trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 11%|▌    | 11/100 [00:01<00:11,  7.74trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 13%|▋    | 13/100 [00:01<00:11,  7.89trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 15%|▊    | 15/100 [00:01<00:11,  7.57trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 17%|▊    | 17/100 [00:02<00:10,  8.17trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 19%|▉    | 19/100 [00:02<00:09,  8.35trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 21%|█    | 21/100 [00:02<00:10,  7.65trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 23%|█▏   | 23/100 [00:02<00:09,  7.77trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 25%|█▎   | 25/100 [00:03<00:09,  7.68trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 27%|█▎   | 27/100 [00:03<00:09,  7.39trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 29%|█▍   | 29/100 [00:03<00:09,  7.32trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 31%|█▌   | 31/100 [00:03<00:09,  7.41trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 33%|█▋   | 33/100 [00:04<00:08,  7.48trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 35%|█▊   | 35/100 [00:04<00:09,  7.21trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 37%|█▊   | 37/100 [00:04<00:09,  6.97trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 39%|█▉   | 39/100 [00:05<00:08,  7.16trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 41%|██   | 41/100 [00:05<00:08,  6.93trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 43%|██▏  | 43/100 [00:05<00:08,  7.12trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 45%|██▎  | 45/100 [00:05<00:07,  7.16trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 47%|██▎  | 47/100 [00:06<00:07,  7.33trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 49%|██▍  | 49/100 [00:06<00:07,  7.21trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 51%|██▌  | 51/100 [00:06<00:06,  7.25trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 53%|██▋  | 53/100 [00:07<00:06,  7.02trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 55%|██▊  | 55/100 [00:07<00:06,  6.80trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 57%|██▊  | 57/100 [00:07<00:06,  7.07trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 59%|██▉  | 59/100 [00:07<00:05,  7.14trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 61%|███  | 61/100 [00:08<00:05,  7.20trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 63%|███▏ | 63/100 [00:08<00:05,  7.13trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 65%|███▎ | 65/100 [00:08<00:04,  7.07trial/s, best loss: -0.869172932330827]




f1: 0.2325020112630732                                                       
 66%|███▎ | 66/100 [00:08<00:05,  6.06trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 68%|███▍ | 68/100 [00:09<00:05,  6.39trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 70%|███▌ | 70/100 [00:09<00:04,  6.54trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 72%|███▌ | 72/100 [00:09<00:04,  6.75trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 74%|███▋ | 74/100 [00:10<00:03,  7.04trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 76%|███▊ | 76/100 [00:10<00:03,  7.13trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 78%|███▉ | 78/100 [00:10<00:03,  7.23trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 80%|████ | 80/100 [00:10<00:02,  7.24trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 82%|████ | 82/100 [00:11<00:02,  7.05trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 84%|████▏| 84/100 [00:11<00:02,  6.80trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 86%|████▎| 86/100 [00:11<00:01,  7.02trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 88%|████▍| 88/100 [00:12<00:01,  6.96trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 90%|████▌| 90/100 [00:12<00:01,  6.90trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 92%|████▌| 92/100 [00:12<00:01,  7.05trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 94%|████▋| 94/100 [00:12<00:00,  6.93trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 96%|████▊| 96/100 [00:13<00:00,  6.81trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
 98%|████▉| 98/100 [00:13<00:00,  6.76trial/s, best loss: -0.869172932330827]





f1: 0.2325020112630732                                                       
f1: 0.2325020112630732                                                       
100%|████| 100/100 [00:13<00:00,  7.20trial/s, best loss: -0.869172932330827]





In [42]:
best_hyperparams

{'colsample_bytree': 0.7620468041420051,
 'gamma': 3.898103311872073,
 'max_depth': 4.0,
 'min_child_weight': 3.0,
 'reg_alpha': 142.0,
 'reg_lambda': 0.7741583696916022}

In [43]:
clf = xgb.XGBClassifier(best_hyperparams)

clf.fit(X, y)
y_true, y_pred = y_test, clf.predict(x_test)
print("f1: {}".format(sklearn.metrics.f1_score(y_pred, y_true, average='macro')))
print("weighted_f1: {}".format(sklearn.metrics.f1_score(y_pred, y_true, average='weighted')))
print("acc: {}".format((y_pred == y_true).mean()))
print(confusion_matrix(y_test, y_pred))



f1: 0.7714938302800296
f1: 0.9231741935038688
acc: 0.9233082706766917
[[552   9   9   8]
 [  4  15   2   0]
 [ 18   0  31   0]
 [  1   0   0  16]]


In [50]:
confusion_matrix(y_test, y_pred).ravel()

array([552,   9,   9,   8,   4,  15,   2,   0,  18,   0,  31,   0,   1,
         0,   0,  16])

In [57]:
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, labels=[0,1,2,3]))

              precision    recall  f1-score   support

           0       0.96      0.96      0.96       578
           1       0.62      0.71      0.67        21
           2       0.74      0.63      0.68        49
           3       0.67      0.94      0.78        17

    accuracy                           0.92       665
   macro avg       0.75      0.81      0.77       665
weighted avg       0.93      0.92      0.92       665



In [87]:
from sklearn.metrics import confusion_matrix
import numpy as np

cm = confusion_matrix(y_test, y_pred)
# recall = np.diag(cm) / np.sum(cm, axis = 1)
# precision = np.diag(cm) / np.sum(cm, axis = 0)
cm

array([[552,   9,   9,   8],
       [  4,  15,   2,   0],
       [ 18,   0,  31,   0],
       [  1,   0,   0,  16]])

In [95]:
cm  = np.array([[552,   9 + 9+ 8+ 2],
       [  4 + 18 + 1,  15+31+16,]])
recall = np.diag(cm) / np.sum(cm, axis = 1)
precision = np.diag(cm) / np.sum(cm, axis = 0)

In [96]:
recall

array([0.95172414, 0.72941176])

In [97]:
precision

array([0.96      , 0.68888889])

In [91]:
from sklearn.metrics import confusion_matrix
import numpy as np


cm = np.array([[525.,  13.,  27.,  13.],
               [  2.,  19.,   0.,   0.],
               [  0.,   1.,  47.,   1.],
               [  0.,   0.,   0.,  17.]])
recall = np.diag(cm) / np.sum(cm, axis = 1)
precision = np.diag(cm) / np.sum(cm, axis = 0)

In [92]:
cm  = np.array([[525,  13+27+13+1 ],
       [  2+1,  19+47+17]])
recall = np.diag(cm) / np.sum(cm, axis = 1)
precision = np.diag(cm) / np.sum(cm, axis = 0)

In [93]:
precision

array([0.99431818, 0.60583942])

In [94]:
recall

array([0.90673575, 0.96511628])

In [66]:
cm

array([[552,   9,   9,   8],
       [  4,  15,   2,   0],
       [ 18,   0,  31,   0],
       [  1,   0,   0,  16]])