In [66]:
from scipy.stats import hmean
from sklearn.metrics import roc_curve, auc

# Synthetic fixture: 20 samples x 5 classes of uniform random scores in
# [0, 1) and random 0/1 ground-truth labels.
y_pred = np.random.rand(20, 5)
y_true = np.random.randint(low=0, high=2, size=(20, 5), dtype='int')

# Blank out class 0 entirely (scores and labels) ...
y_pred[:, 0] = 0.0
y_true[:, 0] = 0

# ... then plant a single positive label in row 0 and two above-threshold
# scores in rows 0 and 1, giving class 0 one true positive and one false
# positive.
y_true[0, 0] = 1
y_pred[:2, 0] = [0.7, 0.9]


def safe_div(x,y):
    """Divide x by y, defining the indeterminate form 0/0 as 0.0.

    Only the case where both operands are zero is special-cased; a non-zero
    numerator over a zero denominator is passed straight to ``x/y``.
    """
    if y != 0.0 or x != 0.0:
        return x/y
    return 0.0
    
def round_list(a,p=2):
    """Return a new list with every element of ``a`` rounded to ``p`` digits."""
    return [round(value, p) for value in a]

def confuse_2d(y_true,y_pred):
    """Build a binary confusion matrix and derive precision, recall and F1.

    Parameters
    ----------
    y_true : sequence of int
        Ground-truth labels, each 0 or 1.
    y_pred : sequence of int
        Predicted labels, each 0 or 1; read at the same indices as
        ``y_true``.

    Returns
    -------
    (prec, recall, f1, confuse)
        ``confuse`` is a 2x2 integer array indexed ``[true, predicted]``.
        ``prec`` is ``confuse[1, 1]`` over the sum of column 1 and ``recall``
        is ``confuse[1, 1]`` over the sum of row 1; each is 0.0 when its
        denominator is zero.  ``f1`` is the harmonic mean of the two, or 0.0
        when either of them is zero.
    """
    no_classes = 2
    confuse = np.zeros((no_classes, no_classes), dtype=int)

    # `range` works on both Python 2 and 3 (`xrange` is Python-2 only).
    for i in range(len(y_true)):
        confuse[y_true[i], y_pred[i]] += 1

    true_pos = float(confuse[1, 1])
    predicted_pos = float(confuse[:, 1].sum())
    actual_pos = float(confuse[1, :].sum())

    # A zero denominator implies a zero numerator here, because the true
    # positives are part of both sums; define 0/0 as 0.0.
    prec = true_pos / predicted_pos if predicted_pos else 0.0
    recall = true_pos / actual_pos if actual_pos else 0.0

    # Harmonic mean written out explicitly (n / sum(1/x)).  The harmonic mean
    # is undefined at zero and older SciPy `hmean` raises for it, so report
    # F1 = 0.0 when there are no true positives.
    if prec > 0.0 and recall > 0.0:
        f1 = 2.0 / (1.0 / prec + 1.0 / recall)
    else:
        f1 = 0.0
    return prec, recall, f1, confuse

def multi_eval(y_pred,y_true,verbose=False):
    """Evaluate multi-label predictions one class (column) at a time.

    Parameters
    ----------
    y_pred : 2-D array
        Prediction scores, one column per class.  A score strictly greater
        than 0.5 is treated as a positive prediction.
    y_true : 2-D array
        Ground-truth labels with the same shape as ``y_pred``; cast to int.
    verbose : bool
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    (res, confusion_matrices)
        ``res`` has one row per class with columns precision, recall, F1 and
        ROC AUC.  ``confusion_matrices`` is a list holding the 2x2 confusion
        matrix returned by ``confuse_2d`` for each class.
    """
    assert y_pred.shape == y_true.shape
    # Unpacking also enforces that the inputs are 2-D.
    _, classes = y_true.shape

    # Vectorized thresholding replaces the element-wise Python 2 `xrange`
    # loop; the comparison is applied to the scores as given, before casts.
    y_pred_binary = (y_pred > 0.5).astype(int)
    y_pred = y_pred.astype(float)
    y_true = y_true.astype(int)

    res = np.zeros((classes, 4))
    confusion_matrices = []
    for i in range(classes):
        prec, recall, f1, confuse = confuse_2d(y_true[:, i], y_pred_binary[:, i])
        # The ROC curve is computed from the raw scores, not the 0/1 decisions.
        fpr, tpr, _ = roc_curve(y_true[:, i], y_pred[:, i])
        res[i] = (prec, recall, f1, auc(fpr, tpr))
        confusion_matrices.append(confuse)

    return res, confusion_matrices
    
# Run the evaluation on the synthetic fixture; the printed output appears to
# be the returned (metrics array, list of confusion matrices) pair.
multi_eval(y_pred,y_true)

(array([[ 0.5       ,  1.        ,  0.66666667,  0.94736842],
        [ 0.46153846,  0.85714286,  0.6       ,  0.43956044],
        [ 0.27272727,  0.42857143,  0.33333333,  0.52747253],
        [ 0.5       ,  0.33333333,  0.4       ,  0.39583333],
        [ 0.55555556,  0.41666667,  0.47619048,  0.39583333]]),
 [array([[18,  1],
         [ 0,  1]]), array([[6, 7],
         [1, 6]]), array([[5, 8],
         [4, 3]]), array([[4, 4],
         [8, 4]]), array([[4, 4],
         [7, 5]])])

[[0 0 1 0 1]
 [1 0 0 0 1]
 [1 0 0 0 1]
 [1 0 1 0 1]
 [0 0 0 1 1]
 [1 1 0 0 1]
 [0 1 0 1 1]
 [1 1 1 0 1]
 [0 0 1 0 0]
 [0 1 0 1 1]
 [1 1 0 0 0]
 [1 1 0 1 1]
 [0 0 0 1 1]
 [1 0 0 1 1]
 [1 1 0 1 1]
 [1 1 0 0 1]
 [1 1 0 1 0]
 [1 0 0 1 1]
 [0 0 0 0 1]
 [1 1 1 1 1]]
