### Load Modules


In [1]:
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import h5py 
import sklearn

### Load ground-truth and prediction data

In [13]:
# Exported ground-truth / prediction tables; the CSVs have no header row
# (presumably one row per frame — confirm against the export script).
OUTPUTS_DIR = '../experiments/outputs'

cls_gt = pd.read_csv(f'{OUTPUTS_DIR}/cls_gt.csv', header=None)
cls_pred = pd.read_csv(f'{OUTPUTS_DIR}/cls_pred.csv', header=None)
doa_gt = pd.read_csv(f'{OUTPUTS_DIR}/doa_gt.csv', header=None)
doa_pred = pd.read_csv(f'{OUTPUTS_DIR}/doa_pred.csv', header=None)

# Every later cell compares these tables row-by-row (and element-wise), so fail
# fast here if the exports are misaligned instead of producing silent garbage.
assert len(cls_gt) == len(cls_pred) == len(doa_gt) == len(doa_pred), \
    'output files have different row counts'
assert cls_gt.shape == cls_pred.shape, 'cls_gt / cls_pred shape mismatch'
assert doa_gt.shape == doa_pred.shape, 'doa_gt / doa_pred shape mismatch'

In [3]:
# Binarize predicted class scores: strictly above CLS_THRESHOLD -> 1.0, else 0.0
# (NaN compares False, so it also maps to 0.0).
# A vectorized comparison replaces DataFrame.applymap, which is deprecated since
# pandas 2.1 and evaluates a Python lambda per element.
CLS_THRESHOLD = 0.3
cls_pred_cleaned = (cls_pred > CLS_THRESHOLD).astype(float)
cls_pred_cleaned

Unnamed: 0,0,1,2
0,1.0,0.0,0.0
1,0.0,0.0,1.0
2,1.0,0.0,0.0
3,0.0,0.0,1.0
4,1.0,0.0,0.0
...,...,...,...
22405,1.0,0.0,0.0
22406,1.0,0.0,0.0
22407,0.0,1.0,0.0
22408,0.0,1.0,0.0


In [4]:
# A frame counts as correctly classified only if every binarized class label
# matches the ground truth exactly.
identical_rows = (cls_gt == cls_pred_cleaned).all(axis='columns')
num_identical_rows = identical_rows.sum()
print(num_identical_rows)

22405


In [35]:
# Spot check: is the class prediction for the frame labelled 0 fully correct?
identical_rows.loc[0]

True

In [14]:
# Derive one azimuth column per class ('az0'..'az2'), in degrees wrapped to
# [0, 360), for both ground truth and predictions.
# np.arctan2(a, b) is called with a = column k and b = column k + 3.
# NOTE(review): np.arctan2 expects (y, x); if columns 0-2 hold x and 3-5 hold y,
# the arguments are swapped. GT and predictions share the convention, so angular
# differences are unaffected — confirm the DOA column layout.
for frame in (doa_gt, doa_pred):
    for cls_idx in range(3):
        azimuth = np.degrees(np.arctan2(frame[cls_idx], frame[cls_idx + 3]))
        frame[f'az{cls_idx}'] = (azimuth + 360) % 360

In [16]:
# Keep only the derived per-class azimuth columns (degrees) as standalone frames.
azi_gt = doa_gt.loc[:, ['az0', 'az1', 'az2']].copy()
azi_pred = doa_pred.loc[:, ['az0', 'az1', 'az2']].copy()

In [22]:
# Nested Python lists of per-frame azimuths: [[az0, az1, az2], ...].
gt_azi = azi_gt.to_numpy().tolist()
pred_azi = azi_pred.to_numpy().tolist()

In [31]:

# Keep a predicted azimuth only where the class is active in the ground truth.
# Activity comes from cls_gt rather than `gt azimuth != 0`: inactive classes also
# yield azimuth 0 (arctan2(0, 0) == 0), so an active source whose azimuth is
# exactly 0 deg would otherwise be masked out and silently excluded from the metrics.
# NOTE(review): assumes cls_gt column k corresponds to 'az{k}' — confirm.
mask_arr = (cls_gt.to_numpy() != 0).astype(int)
pred_arr = np.array(pred_azi)
azi_gt_arr = np.array(gt_azi)
assert mask_arr.shape == pred_arr.shape == azi_gt_arr.shape, \
    'class labels and azimuth arrays are not aligned'

mask = mask_arr.tolist()
# One ndarray per frame, as a list (later cells index it per frame).
masked_pred = list(pred_arr * mask_arr)

In [45]:
# Mean per-frame azimuth error in degrees: sum the error over classes, then
# average over all frames (including frames with no active source).
# Azimuth is circular, so fold |pred - gt| into [0, 180]: 359.9 vs 0.1 is 0.2 deg
# apart, not 359.8. Without this, sources near 0/360 deg inflate the error.
azi_abs_diff = np.abs(np.asarray(masked_pred) - azi_gt_arr)
azi_err = np.minimum(azi_abs_diff, 360 - azi_abs_diff)
# sum(axis=1).mean() equals the former `np.mean(...) * 3` without hard-coding
# the number of classes.
mean_abs_diff = azi_err.sum(axis=1).mean()
print(mean_abs_diff)

8.40967094076935


In [48]:
# Spot check: per-class absolute azimuth difference (degrees) for frame index 2.
np.abs(np.subtract(masked_pred[2], azi_gt_arr[2]))

array([0.11061607, 0.        , 0.        ])

In [47]:
# Spot check: summed absolute azimuth difference (degrees) for frame index 2.
np.abs(masked_pred[2] - azi_gt_arr[2]).sum()

0.11061607045553501

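### Location-dependent error rate

Among frames whose binarized class prediction exactly matches the ground truth, count how many have a summed per-class azimuth error of at most 20°.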

In [61]:
# Among correctly classified frames, count those whose summed per-class azimuth
# error is within ANGLE_THRESHOLD_DEG (count) versus beyond it (neg_count).
ANGLE_THRESHOLD_DEG = 20

is_class_correct = identical_rows.values.tolist()
class_ok = identical_rows.to_numpy(dtype=bool)

# Azimuth is circular: fold |pred - gt| into [0, 180] so that 359.9 vs 0.1 is
# a 0.2 deg error rather than 359.8.
frame_abs_diff = np.abs(np.asarray(masked_pred) - azi_gt_arr)
frame_err = np.minimum(frame_abs_diff, 360 - frame_abs_diff).sum(axis=1)
# NaN errors compare False here and so fall into neg_count, as in a scalar check.
within_threshold = frame_err <= ANGLE_THRESHOLD_DEG

check_cls = int(class_ok.sum())
count = int((class_ok & within_threshold).sum())
neg_count = check_cls - count

print(check_cls)
print(count)
print(neg_count)

22405
20390
2015


In [52]:
# Display the per-class ground-truth azimuths (degrees, wrapped to [0, 360)).
azi_gt

Unnamed: 0,az0,az1,az2
0,120.0,0.0,0.0
1,0.0,0.0,300.0
2,120.0,0.0,0.0
3,0.0,0.0,240.0
4,300.0,0.0,0.0
...,...,...,...
22405,120.0,0.0,0.0
22406,240.0,0.0,0.0
22407,0.0,120.0,0.0
22408,0.0,120.0,0.0


In [57]:
# Preview only: masked_pred holds one array per frame (tens of thousands of
# entries), so displaying the whole list floods and bloats the notebook output.
masked_pred[:20]

[array([119.89655843,   0.        ,   0.        ]),
 array([  0.        ,   0.        , 299.87264589]),
 array([119.88938437,   0.        ,   0.        ]),
 array([  0.        ,   0.        , 239.87358191]),
 array([299.81043793,   0.        ,   0.        ]),
 array([0., 0., 0.]),
 array([119.91379837,   0.        ,   0.        ]),
 array([59.37093682,  0.        ,  0.        ]),
 array([  0.        ,   0.        , 180.08531813]),
 array([0., 0., 0.]),
 array([  0.       ,   0.       , 299.8129596]),
 array([  0.        ,   0.        , 299.79695846]),
 array([  0.        ,   0.        , 239.91249588]),
 array([  0.        , 119.64793081,   0.        ]),
 array([299.81689489,   0.        ,   0.        ]),
 array([  0.        , 119.42293595,   0.        ]),
 array([59.36446924,  0.        ,  0.        ]),
 array([180.04391552,   0.        ,   0.        ]),
 array([  0.        , 239.40600514,   0.        ]),
 array([  0.        , 119.53744457,   0.        ]),
 array([119.92052733,   0.   

In [58]:
# Side-by-side preview of ground-truth vs. masked predicted azimuths.
# Printing every frame emits four lines per frame (tens of thousands in total),
# which swamps the notebook, so cap it at N_PREVIEW frames.
N_PREVIEW = 20
for i in range(min(N_PREVIEW, len(masked_pred))):
    print(azi_gt_arr[i])
    print(masked_pred[i])
    print('\n')

[120.00000045   0.           0.        ]
[119.89655843   0.           0.        ]


[  0.           0.         300.00000045]
[  0.           0.         299.87264589]


[120.00000045   0.           0.        ]
[119.88938437   0.           0.        ]


[  0.           0.         239.99999955]
[  0.           0.         239.87358191]


[300.00000045   0.           0.        ]
[299.81043793   0.           0.        ]


[0. 0. 0.]
[0. 0. 0.]


[120.00000045   0.           0.        ]
[119.91379837   0.           0.        ]


[59.99999955  0.          0.        ]
[59.37093682  0.          0.        ]


[  0.   0. 180.]
[  0.           0.         180.08531813]


[0. 0. 0.]
[0. 0. 0.]


[  0.           0.         300.00000045]
[  0.          0.        299.8129596]


[  0.           0.         300.00000045]
[  0.           0.         299.79695846]


[  0.           0.         239.99999955]
[  0.           0.         239.91249588]


[ 0.         59.99999955  0.        ]
[  0.         119.64793