In [54]:
import numpy as np
import os
import matplotlib.pyplot as plt
import scipy.misc as sm
import pickle as pkl

# Root of the letters dataset.  Can be overridden with the LETTERS_DATA_DIR
# environment variable; defaults to the original author's local path so the
# notebook behaves exactly as before when the variable is unset.
data_dir = os.environ.get('LETTERS_DATA_DIR', '/home/lioruzan/pixel-cnn/data/letters_data')

# Load the pickled test split (a dict; this notebook reads its 'masks' entry
# and rewrites its 'labels' entry).  NOTE: pickle.load can execute arbitrary
# code -- only load files you trust.
with open(os.path.join(data_dir, 'letters_test.pkl'), 'rb') as f:
    test_letters = pkl.load(f)

''' update labels to be correct rotations for test time '''
from scipy import ndimage

# Rotation label for each quadrant (row half, column half) that can contain a
# mask's centre of mass; half index 0 = top/left, 1 = bottom/right.
QUADRANT_TO_ROTATION = {(0, 0): 0, (0, 1): 1, (1, 1): 2, (1, 0): 3}


def mask_rotation(mask):
    """Return the test-time rotation label (0-3) implied by one mask.

    Only channel 0 of `mask` is used.  The label is picked from the quadrant
    holding the centre of mass of that channel, splitting each spatial axis at
    its midpoint (for 32x32 masks this matches the former ``int(c) // 16``).

    Raises:
        ValueError: if the channel has zero total weight, so the centre of
            mass is undefined (NaN) and no quadrant can be chosen.
    """
    plane = mask[:, :, 0]
    row_c, col_c = ndimage.center_of_mass(plane)
    if np.isnan(row_c) or np.isnan(col_c):
        raise ValueError('mask is empty; cannot infer its rotation')
    quadrant = (int(row_c >= plane.shape[0] / 2), int(col_c >= plane.shape[1] / 2))
    return QUADRANT_TO_ROTATION[quadrant]


# float64 labels, matching the dtype previously produced by np.zeros.
rotations = np.array([mask_rotation(m) for m in test_letters['masks']], dtype=np.float64)

# NOTE: this overwrites letters_test.pkl in place.  Re-running the cell is
# harmless because the labels are recomputed from the (unchanged) masks.
test_letters['labels'] = rotations
with open(os.path.join(data_dir, 'letters_test.pkl'), 'wb') as f:
    pkl.dump(test_letters, f)

# Figure out batch sizes that lose the fewest test samples: for each rotation
# label, print how many of its samples are left over after filling complete
# batches of the candidate size.
for label, batch_size in enumerate((32, 20, 4, 12)):
    print((rotations == label).sum() % batch_size)


In [3]:
# Show the first test mask.  scipy.misc.imshow was removed in SciPy 1.2 (and
# relied on an external viewer), so use matplotlib like the rest of the
# notebook.  Channel 0 is the channel the rotation-label code inspects.
first_mask = next(iter(test_letters['masks']), None)
if first_mask is not None:
    plt.imshow(first_mask[:, :, 0], cmap='gray')
    plt.show()

In [20]:
# Number of masks in the test split (displayed as the cell output).
n_test_masks = len(test_letters['masks'])
n_test_masks

147

In [28]:
# MAKE IMAGES
import scipy.misc as sm
import numpy as np

# NOTE(review): `file_list` is not defined in any visible cell, and
# scipy.misc.imread was removed in SciPy 1.2 -- this cell only runs on an old
# SciPy with `file_list` already present in the kernel.
imgs = list(map(lambda img_path: sm.imread(img_path, mode='RGB'), file_list))

# Disabled: resize every image to 32x32 and stack into one (N, 32, 32, 3) array.
# resized_imgs = [sm.imresize(im, (32, 32)) for im in imgs]
# resized_imgs = [im[np.newaxis, :, :, :] for im in resized_imgs]
# images = np.vstack(resized_imgs)
# labels = np.array(label_list)

[ 192.  167.    3.] [ 47.07156859  29.70098797   0.        ]


In [29]:
# IMAGE STATS: summarise the (height, width, channels) shapes of the loaded images.
shapes = [img.shape for img in imgs]
shape_arr = np.asarray(shapes)
print(np.median(shape_arr, axis=0), shape_arr.mean(axis=0), shape_arr.std(axis=0))
print(shape_arr.max(axis=0), shape_arr.min(axis=0))
print(shape_arr.argmax(axis=0), shape_arr.argmin(axis=0))

[ 192.  167.    3.] [ 208.19724556  168.49743231    3.        ] [ 47.07156859  29.70098797   0.        ]


In [74]:
''' stitch together 4-model adaptive rotation results '''

# One checkpoint subdirectory per model (named 0 .. N_MODELS-1), each holding
# results_0.pkl .. results_<N_RUNS-1>.pkl.  The root is derived from data_dir
# instead of repeating the same absolute path.
N_MODELS, N_RUNS = 4, 10
root = os.path.join(data_dir, 'checkpoints')


def _load_pickle(path):
    """Unpickle one results file.  pickle.load can run arbitrary code -- trusted files only."""
    with open(path, 'rb') as f:
        return pkl.load(f)


# runs[i][j] holds the contents of <root>/<i>/results_<j>.pkl
runs = [[_load_pickle(os.path.join(root, str(i), 'results_{}.pkl'.format(j)))
         for j in range(N_RUNS)]
        for i in range(N_MODELS)]

def _stitch_run(run_idx):
    """Concatenate run `run_idx` from every model, undoing each model's rotation.

    Model i's sample, data and mask batches are rotated by k=-i over axes (1, 2)
    (axis 0 is the batch axis), then everything is stacked along axis 0 in
    model order.  Returns the tuple (samples, data, masks).
    """
    samples, datas, masks = [], [], []
    for model_idx, model_runs in enumerate(runs):
        for sample, (x, m) in model_runs[run_idx]:
            samples.append(np.rot90(sample, k=-model_idx, axes=(1, 2)))
            datas.append(np.rot90(x, k=-model_idx, axes=(1, 2)))
            masks.append(np.rot90(m, k=-model_idx, axes=(1, 2)))
    # Seed each stack with an empty float64 (0, 32, 32, 3) array so an empty run
    # keeps the same shape/dtype as before.  A single concatenate replaces the
    # per-batch vstack, which re-copied the growing array every iteration.
    empty = np.zeros((0, 32, 32, 3))
    return tuple(np.concatenate([empty] + parts, axis=0)
                 for parts in (samples, datas, masks))


# runss[j] = (samples, data, masks) for run j, pooled across all models.
runss = [_stitch_run(j) for j in range(len(runs[0]))]

''' calculate mean average psnr (+- mean average std)'''


def image_psnr(sample, target, max_val=255.0):
    """PSNR (dB) between one model sample and its ground-truth image.

    Args:
        sample: model output on the [-1, 1] scale; mapped to 0..255 via
            127.5 * sample + 127.5 before comparison.
        target: ground-truth image, already on the 0..255 scale.
        max_val: peak signal value (255 for 8-bit images).

    Returns:
        10 * log10(max_val**2 / MSE) as a Python float (equal to
        20 * (log10(max_val) - log10(sqrt(MSE)))); +inf when the images match
        exactly (MSE == 0), without a divide-by-zero warning.
    """
    x = 127.5 * np.asarray(sample, dtype=np.float64) + 127.5
    y = np.asarray(target, dtype=np.float64)
    mse = np.mean(np.square(x - y))
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(max_val ** 2 / mse))


average_psnrs, std_psnrs = [], []
for samples, targets, _ in runss:
    assert len(samples) == len(targets), 'samples and targets must pair up one-to-one'
    # Kept as a module-level list on purpose: a later cell displays the last run's values.
    psnrs = [image_psnr(s, t) for s, t in zip(samples, targets)]
    average_psnrs.append(np.mean(psnrs))
    std_psnrs.append(np.std(psnrs))
# Mean over runs of the per-run mean PSNR, +- mean over runs of the per-run std.
print('{:.5} +-{:.5}'.format(np.mean(average_psnrs), np.mean(std_psnrs)))

''' visualize results '''


def _to_display_range(img):
    """Map an image from the model's [-1, 1] scale to [0, 1] for plt.imshow.

    plt.imshow clips float RGB input to [0, 1], so passing [-1, 1] data
    directly renders every negative value as black.
    """
    return np.clip(0.5 * np.asarray(img, dtype=np.float64) + 0.5, 0.0, 1.0)


samples_0, data_0, masks_0 = runss[0]
# Random example from the first run, bounded by the real example count rather
# than a hard-coded 140 (which could overrun or never reach the tail).
p = np.random.randint(len(data_0))
target = data_0[p] / 127.5 - 1  # ground truth, 0..255 -> [-1, 1] as before
# Show the ground truth, then the ground truth times the mask (pixels where the
# mask is 0 become 0 in [-1, 1] space, i.e. mid-grey), then the model sample.
for img in (target, target * masks_0[p], samples_0[p]):
    plt.imshow(_to_display_range(img))
    plt.show()

In [166]:
# Per-image PSNRs of the *last* run only -- `psnrs` is whatever the PSNR cell's loop left behind.
psnrs[:]

[8.0094482437804793,
 11.576742746439566,
 14.397416623109148,
 16.050003389271453,
 13.547408969241141,
 23.056076520934205,
 12.633913553358219,
 10.246024123511299,
 15.334448796610154,
 11.782192016462133,
 17.510175958815275,
 14.843173007889474,
 16.122954277595923,
 9.2200377348149019,
 11.28626104247334,
 13.515986057497367,
 12.148469563477926,
 9.8219308970910735,
 8.7616479031172201,
 9.9355380158090778,
 19.154202360878038,
 9.3201170342204023,
 16.84945659447223,
 8.0179964062435083,
 45.277232672125834,
 16.739734812845398,
 17.787932509869016,
 14.943535805970427,
 17.171969258429161,
 16.404846435637506,
 9.6479973378444903,
 18.378914848942021,
 15.494049563983769,
 10.581258431914232,
 10.126872069824966,
 16.304898685291889,
 10.819779646181988,
 14.685991622401652,
 15.172440514593362,
 12.363724994144203,
 18.359321153140389,
 20.18213067097788,
 14.180420893095897,
 8.9782805376706776,
 16.63958832636947,
 12.905960815126329,
 13.885445578248623,
 9.83530878940037