In [3]:
import json,numpy
import matplotlib
from matplotlib import pyplot as plt
plt.rcParams['text.usetex'] = True
# Set once, before any figure exists: matplotlib reads 'figure.facecolor'
# when plt.figure() is called, so assigning it only after the first figure
# was created (as before) could not affect that first figure.
plt.rcParams['figure.facecolor'] = 'white'

import os

# Directory that all PNGs below are written into; created if missing so the
# first savefig does not fail on a fresh checkout.
OUTDIR = 'confusions-repr'


def balance(d):
    """Row-normalise a count matrix, then rescale it to the original total.

    Each row is divided by its own sum, with the divisor floored at 10 so that
    rows with very few entries are not inflated to full weight. The result is
    then rescaled so its grand total equals ``d.sum()``.
    """
    dnew = d / numpy.maximum(d.sum(axis=1, keepdims=True), 10)
    return dnew / dnew.sum() * d.sum()


def hide_spines(ax):
    """Colour all four spines of *ax* white (visually hiding them)."""
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color('white')


os.makedirs(OUTDIR, exist_ok=True)

# For each experiment: read '<mode>.json', print accuracies per network, and
# save the off-diagonal confusion heat map (plus, for 'fish', row histograms).
# Per network entry the JSON is expected to hold 'confusion_matrix',
# '<artifact>_confusion_matrix' and '<prefix><artifact>_confusion_matrix'.
for mode, artifact, prefix in [('truck', 'watermark', ''), ('fish', 'human', 'non_')]:

    with open('%s.json' % mode) as f:
        res = json.load(f)

    keys = ['r50-clip', 'simclr-rn50', 'r50-barlowtwins', 'r50-sup']
    if mode != 'fish':
        keys += ['r50-clip_fix_filter_%d' % n for n in [1, 2, 3, 4, 5, 10, 15, 20, 30]]

    for key in keys:
        net = res[key]

        c0 = numpy.array(net['confusion_matrix'])
        ch = numpy.array(net['%s_confusion_matrix' % artifact])
        # For 'truck' the prefix is empty, so this is the same matrix as ch.
        cn = numpy.array(net['%s%s_confusion_matrix' % (prefix, artifact)])

        if mode == 'fish':
            # Reorder classes by descending row sum of the artifact matrix.
            ind = numpy.argsort(ch.sum(axis=1))[::-1]
            classes = numpy.array(res['class_names'])[ind]
            c0 = c0[ind][:, ind]
            ch0 = ch[ind][:, ind]          # reordered, un-balanced counts
            ch = balance(ch0)
            cn = balance(cn[ind][:, ind])
        elif mode == 'truck':
            classes = numpy.array(net['class_names'])

        # Accuracy (%) = trace / total for each matrix.
        acc1 = c0.diagonal().sum() / c0.sum() * 100
        acc2 = ch.diagonal().sum() / ch.sum() * 100
        acc3 = cn.diagonal().sum() / cn.sum() * 100

        L = len(classes)

        print('%8s %24s  |  %.1f  %.1f ' % (mode, key, acc1, acc2)
              + (' %.1f' % acc3 if mode == 'fish' else ''))

        if mode == 'fish':
            # Horizontal bar chart of row sums, bars growing leftwards.
            for counts, name in [(ch0, 'tot'), (ch, artifact)]:
                plt.figure(figsize=(1, 3))
                plt.subplots_adjust(left=0.02, top=0.8, bottom=0.02, right=0.98)
                plt.barh(numpy.arange(L)[::-1], -counts.sum(axis=1), color='#7a71b3')
                ax = plt.gca()
                hide_spines(ax)
                ax.yaxis.tick_right()
                plt.xlim(-50, 0)
                plt.ylim(-0.5, L - 0.5)
                plt.xticks([])
                plt.yticks([])
                plt.savefig(os.path.join(OUTDIR, '%s-%s-%s-hist.png' % (mode, key, name)), dpi=400)
                plt.close()

        # Off-diagonal part of the artifact matrix (diagonal zeroed out).
        off = ch - numpy.diag(ch.diagonal())
        nchar = 3 if mode == 'fish' else 2
        labels = [a[:nchar] for a in classes]

        plt.figure(figsize=(3, 3) if mode == 'fish' else (1.75, 1.75))
        plt.subplots_adjust(left=0.2, top=0.8, bottom=0.02, right=0.98)
        ax = plt.gca()
        hide_spines(ax)
        ax.xaxis.tick_top()
        plt.imshow(off, cmap='seismic', alpha=1, vmin=-300 / L, vmax=300 / L)
        plt.xticks(numpy.arange(L), labels, rotation=90)
        plt.yticks(numpy.arange(L), labels)

        # Draw a small black 'x' on every diagonal cell.
        for i in range(L):
            plt.plot([i - 0.25, i + 0.25], [i - 0.25, i + 0.25], color='black', lw=1)
            plt.plot([i - 0.25, i + 0.25], [i + 0.25, i - 0.25], color='black', lw=1)

        m = L - 1
        s = 0.75
        if artifact == 'human':
            # Box columns 0..5 and annotate their share of off-diagonal mass.
            p = off[:, :6].sum() / off.sum()
            plt.plot([0 - s, 5 + s, 5 + s, 0 - s, 0 - s], [0 - s, 0 - s, m + s, m + s, 0 - s],
                     color='#999999', lw=1)
            plt.text(-0.25, m + 0.5, r'%.1f\%%' % (p * 100), horizontalalignment='left',
                     verticalalignment='bottom', fontsize=(p * 100) ** .33 * 4, color='black')

        if artifact == 'watermark':
            # Box column 1 and annotate its share of off-diagonal mass.
            p = off[:, 1].sum() / off.sum()
            plt.plot([1 - s, 1 + s, 1 + s, 1 - s, 1 - s], [0 - s, 0 - s, m + s, m + s, 0 - s],
                     color='#999999', lw=1)
            plt.text(0.75, L - 0.5, r'%.1f\%%' % (p * 100), horizontalalignment='left',
                     verticalalignment='bottom', fontsize=(p * 100) ** .33 * 4, color='black')

        plt.savefig(os.path.join(OUTDIR, '%s-%s-%s.png' % (mode, key, artifact)), dpi=400)
        plt.close()

    # LaTeX-style legend of abbreviations; uses the class order of the last
    # network processed. Raw string: "\," is not a valid Python escape.
    q = 2 if mode == 'truck' else 3
    print(", ".join("[%s]%s" % (cl[:q], cl[q:].replace("_", r"\,")) for cl in classes))
    print('')
        

   truck                 r50-clip  |  85.0  80.5 
   truck              simclr-rn50  |  74.8  74.5 
   truck          r50-barlowtwins  |  80.2  80.2 
   truck                  r50-sup  |  83.8  83.2 
   truck    r50-clip_fix_filter_1  |  85.0  83.8 
   truck    r50-clip_fix_filter_2  |  84.8  84.0 
   truck    r50-clip_fix_filter_3  |  84.2  83.8 
   truck    r50-clip_fix_filter_4  |  85.5  84.8 
   truck    r50-clip_fix_filter_5  |  85.2  85.0 
   truck   r50-clip_fix_filter_10  |  85.0  85.0 
   truck   r50-clip_fix_filter_15  |  81.5  81.2 
   truck   r50-clip_fix_filter_20  |  81.2  81.2 
   truck   r50-clip_fix_filter_30  |  81.0  82.2 
[fi]re\,engine, [ga]rbage\,truck, [po]lice\,van, [tr]ailer\,truck, [to]w\,truck, [mo]ving\,van, [pi]ckup, [mi]nivan

    fish                 r50-clip  |  86.5  83.8  82.5
    fish              simclr-rn50  |  82.2  78.6  78.4
    fish          r50-barlowtwins  |  83.1  75.6  81.1
    fish                  r50-sup  |  86.2  84.2  81.9
[bar]racouta,