In [35]:
import numpy as np
from scipy.special import softmax
from sklearn.metrics import confusion_matrix
from collections import Counter
from prettytable import PrettyTable
from scipy.spatial.distance import cdist

In [68]:
# Load the model outputs and ground-truth labels for the FairFace training split.
prob_1 = 'fairface_train_probs.npy'
target_1 = 'fairface_train_targets.npy'
# NOTE(review): softmax is applied here, so this file presumably stores raw logits
# despite its name -- TODO confirm; re-normalizing real probabilities would flatten them.
probs = softmax(np.load(prob_1), axis=1)
targets = np.load(target_1)
preds = np.argmax(probs, axis=1)      # top-1 predicted class per sample
pred_probs = np.max(probs, axis=1)    # softmax score of that top-1 class

# Pin the label set so the matrix is always num_classes x num_classes and row/column k
# is class k. Without `labels`, sklearn only uses labels that appear in targets or
# preds, which would shift the indices if some class were absent.
confusion = confusion_matrix(targets, preds, labels=np.arange(probs.shape[1]))
# Row-normalized: fraction of each true class landing in each predicted class.
# Column-normalized: fraction of each predicted class coming from each true class.
# A class with no true (resp. predicted) samples yields a NaN row (resp. column).
confusion_rowwise = confusion / confusion.sum(axis=1, keepdims=True)
confusion_colwise = confusion / confusion.sum(axis=0, keepdims=True)


# Calculate the correction matrix.
# correction_matrix[p, t] = mean - std of the softmax probability given to class t,
# taken over the training samples predicted as class p whose true class is t.
# The next cell uses row p as per-alternative thresholds for overriding a prediction of p.
num_classes = probs.shape[1]
correction_matrix = np.zeros((num_classes, num_classes), dtype=np.float32)
for cls_id in range(num_classes):
    cls_pred_mask = preds == cls_id
    cls_pred_probs = probs[cls_pred_mask]

    print(f"class {cls_id}")
    # Average softmax vector over everything predicted as cls_id (NaN if nothing was).
    if len(cls_pred_probs):
        cls_pred_probs_mean = cls_pred_probs.mean(axis=0)
    else:
        cls_pred_probs_mean = np.full(num_classes, np.nan)
    st = ', '.join([f'{x:.5f}' for x in cls_pred_probs_mean.tolist()])
    print(f"    real avg probs:\n{st}")

    for j in range(num_classes):
        # Samples predicted as cls_id whose true class is j (the correct ones when j == cls_id).
        cls_2_probs = probs[cls_pred_mask & (targets == j)]
        # For each rank k, the class most often found at that rank: [(class, count)] or [].
        sorted_indices = cls_2_probs.argsort(axis=1)[:, ::-1]
        print([Counter(sorted_indices[:, k].tolist()).most_common(1) for k in range(num_classes)])
        print("### target cls ", j)
        if len(cls_2_probs) == 0:
            # Empty confusion cell: store NaN explicitly (what numpy's empty-slice
            # mean/std would give, minus the RuntimeWarning). Since `x >= nan` is False,
            # the next cell never overrides a prediction of cls_id with class j.
            correction_matrix[cls_id, j] = np.nan
            print("    (no samples)")
            continue
        cls_2_mean = cls_2_probs.mean(axis=0)
        cls_2_std = cls_2_probs.std(axis=0)
        print(cls_2_mean)
        correction_matrix[cls_id, j] = cls_2_mean[j] - cls_2_std[j]
        print(cls_2_std)
    print()

# Post-hoc correction: for each sample, take the best class other than its top-1
# prediction and switch to it if its probability reaches the threshold
# correction_matrix[predicted, alternative].
#   corrected_preds        - candidate is the runner-up (2nd-ranked) class
#   second_corrected_preds - candidate is the 3rd-ranked class
num_classes = probs.shape[1]
corrected_preds = np.zeros_like(preds)
second_corrected_preds = np.zeros_like(preds)
for cls_id in range(num_classes):
    cls_pred_mask = preds == cls_id
    cls_pred_probs = probs[cls_pred_mask]
    rows = np.arange(len(cls_pred_probs))

    # Hide the predicted class so argmax finds the best alternative.
    cls_other_pred_probs = cls_pred_probs.copy()
    cls_other_pred_probs[:, cls_id] = 0.

    top1_cls = cls_other_pred_probs.argmax(axis=1)
    top1_cls_prob = cls_other_pred_probs[rows, top1_cls]
    top1_stand = correction_matrix[cls_id, top1_cls]
    corrected_preds[cls_pred_mask] = np.where(top1_cls_prob >= top1_stand, top1_cls, cls_id)

    # Hide each row's OWN runner-up. Pairing `rows` with `top1_cls` selects one element
    # per row; `[:, top1_cls]` would instead zero every column that is the runner-up of
    # any row, for all rows, wiping most of the matrix and making top2 meaningless.
    cls_other_pred_probs[rows, top1_cls] = 0.
    top2_cls = cls_other_pred_probs.argmax(axis=1)
    top2_cls_prob = cls_other_pred_probs[rows, top2_cls]
    top2_stand = correction_matrix[cls_id, top2_cls]
    second_corrected_preds[cls_pred_mask] = np.where(top2_cls_prob >= top2_stand, top2_cls, cls_id)

def _precision_recall(pred_mask, real_mask):
    """Return (precision, recall) for one class from boolean prediction/truth masks.

    Counts are summed as float32, so the returned scalars print with float32
    formatting. A zero denominator (class never predicted, or never present)
    yields nan together with a numpy RuntimeWarning.
    """
    hits = np.sum((real_mask & pred_mask).astype(np.float32))
    return hits / np.sum(pred_mask.astype(np.float32)), hits / np.sum(real_mask.astype(np.float32))


# Compare per-class precision/recall before (preds) and after (corrected_preds) correction.
# Set to True to also count a sample as predicted cls_id when second_corrected_preds picks it.
INCLUDE_SECOND_CORRECTION = False

for cls_id in range(probs.shape[1]):
    cls_real_mask = targets == cls_id
    old_precision, old_recall = _precision_recall(preds == cls_id, cls_real_mask)

    cls_pred_mask = corrected_preds == cls_id
    if INCLUDE_SECOND_CORRECTION:
        cls_pred_mask = cls_pred_mask | (second_corrected_preds == cls_id)
    precision, recall = _precision_recall(cls_pred_mask, cls_real_mask)

    print(cls_id)
    print(old_precision, precision)
    print(old_recall, recall)
    print()

class 0
    real avg probs:
0.78791, 0.01284, 0.05998, 0.02106, 0.02073, 0.02507, 0.07242
[[(0, 15794)], [(6, 11959)], [(2, 8273)], [(4, 4344)], [(4, 5354)], [(5, 4938)], [(1, 7652)]]
### target cls  0
[0.91136426 0.00335261 0.02405048 0.00564277 0.00535132 0.00780015
 0.04244041]
[0.12823327 0.01677827 0.05212643 0.02272741 0.01847489 0.02685139
 0.07003001]
[[(0, 1774)], [(6, 490)], [(2, 523)], [(5, 414)], [(4, 369)], [(4, 708)], [(3, 1031)]]
### target cls  1
[0.65623754 0.08585699 0.08250819 0.01826441 0.03221851 0.05924706
 0.06566697]
[0.20314592 0.10092703 0.08401936 0.03656876 0.04999086 0.07530073
 0.07411156]
[[(0, 8320)], [(2, 3572)], [(2, 3429)], [(5, 3133)], [(4, 3280)], [(1, 2298)], [(3, 3870)]]
### target cls  2
[0.7168818  0.01406391 0.10936074 0.01465214 0.0225999  0.03565217
 0.08678912]
[0.19313134 0.03792744 0.10441803 0.03938369 0.04537409 0.06037714
 0.09059936]
[[(0, 3826)], [(3, 1675)], [(4, 1235)], [(2, 1199)], [(2, 999)], [(5, 1853)], [(1, 2188)]]
### target c

0
0.3738933 0.39576125
0.95564836 0.38754764

1
0.8488126 0.67901945
0.62233305 0.48230198

2
0.39174178 0.30893883
0.13982195 0.3537069

3
0.7449127 0.51963747
0.4617889 0.37796044

4
0.54943407 0.32233837
0.44066697 0.29421028

5
0.63018245 0.4408756
0.55523986 0.3253511

6
0.5122488 0.2836348
0.20193142 0.57508683



In [30]:
# clustering analysis
def print_distances(embeddings_file, labels_file, class0_weight=4):
    """Print a table of Euclidean distances between per-class embedding centroids.

    Each embedding is L2-normalized, then averaged per class to get one centroid per
    entry of ``race_list``; the table shows every pairwise centroid distance.

    The ``avg`` column is the row's distance sum divided by (number of classes - 1),
    after the distance to class 0 ('White') is multiplied by ``class0_weight``. With
    the default weight of 4 this is not a true mean (it can exceed every individual
    distance in the row); the default is kept for comparability with earlier tables.

    Table cells are formatted explicitly, so numpy's global print options are
    deliberately left untouched (changing them would leak state into other cells).

    Args:
        embeddings_file: path to a ``.npy`` array of embeddings, one row per sample.
        labels_file: path to a ``.npy`` array whose first column holds each row's
            class id; assumed to be 0..6 in ``race_list`` order -- TODO confirm.
        class0_weight: multiplier applied to the distance to class 0 in ``avg``.
    """
    race_list = ['White', 'Black', 'Latino hispanic', 'East asian', 'Southeast asian', 'Indian', 'Middle eastern']
    num_classes = len(race_list)

    embeddings = np.load(embeddings_file)
    normed_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    labels = np.load(labels_file)[:, 0]

    # One centroid per class; a class with no samples yields a NaN row (numpy warns).
    cls_clusterings_mean = np.array([normed_embeddings[labels == i].mean(axis=0) for i in range(num_classes)])

    pairwise_distances = np.linalg.norm(cls_clusterings_mean[:, None, :] - cls_clusterings_mean[None, :, :], axis=2)
    weighted_pairwise_distances = pairwise_distances.copy()
    weighted_pairwise_distances[:, 0] *= class0_weight
    # The self-distance is 0, so the sum effectively covers only the other classes.
    mean_distances = weighted_pairwise_distances.sum(axis=1) / (num_classes - 1)

    table = PrettyTable(["", 'avg'] + race_list)
    for race, mean_dist, row in zip(race_list, mean_distances, pairwise_distances):
        cells = ["{0:0.2f}".format(d) for d in row.tolist()]
        table.add_row([race, "{0:0.2f}".format(mean_dist)] + cells)
    print(table)

def print_center_distances(cluster_file):
    """Print a table of pairwise Euclidean distances between class cluster centers.

    Each center is L2-normalized before the distances are computed.

    Args:
        cluster_file: path to a ``.npy`` array with one center per row, presumably in
            ``race_list`` order (White, Black, ...) -- TODO confirm against the producer.

    Raises:
        ValueError: if the number of centers differs from the number of race labels.
    """
    race_list = ['White', 'Black', 'Latino hispanic', 'East asian', 'Southeast asian', 'Indian', 'Middle eastern']

    centers = np.load(cluster_file)
    if len(centers) != len(race_list):
        raise ValueError(
            f"expected {len(race_list)} cluster centers, got {len(centers)} in {cluster_file}")
    centers = centers / np.linalg.norm(centers, axis=1, keepdims=True)
    distances = cdist(centers, centers)

    # The header has no 'avg' column, so each row is the race label followed by its distances.
    table = PrettyTable([""] + race_list)
    for race, row in zip(race_list, distances):
        table.add_row([race] + ["{0:0.2f}".format(d) for d in row.tolist()])
    print(table)
    

In [34]:
# Centroid-distance tables for each of the four FairFace validation embedding sets.
for embeddings_path, labels_path in [
    ('cls_fairface-val_embeddings.npy', 'cls_fairface-val_labels.npy'),
    ('fixmatch_fairface-val_embeddings.npy', 'fixmatch_fairface-val_labels.npy'),
    ('ssp_fairface-val_embeddings.npy', 'ssp_fairface-val_labels.npy'),
    ('cluster_fairface-val_embeddings.npy', 'cluster_fairface-val_labels.npy'),
]:
    print_distances(embeddings_path, labels_path)

+-----------------+------+-------+-------+-----------------+------------+-----------------+--------+----------------+
|                 | avg  | White | Black | Latino hispanic | East asian | Southeast asian | Indian | Middle eastern |
+-----------------+------+-------+-------+-----------------+------------+-----------------+--------+----------------+
|      White      | 0.56 |  0.00 |  0.80 |       0.36      |    0.67    |       0.67      |  0.63  |      0.26      |
|      Black      | 1.04 |  0.80 |  0.00 |       0.57      |    0.76    |       0.58      |  0.40  |      0.72      |
| Latino hispanic | 0.60 |  0.36 |  0.57 |       0.00      |    0.56    |       0.46      |  0.33  |      0.24      |
|    East asian   | 0.94 |  0.67 |  0.76 |       0.56      |    0.00    |       0.24      |  0.69  |      0.69      |
| Southeast asian | 0.85 |  0.67 |  0.58 |       0.46      |    0.24    |       0.00      |  0.53  |      0.64      |
|      Indian     | 0.82 |  0.63 |  0.40 |       0.33   