In [169]:
''' PARSE DATA '''
import pandas as pd
import os

labels_csv = os.path.join('letters_data', 'labels.csv')
df = pd.read_csv(labels_csv, encoding='ISO-8859-8', header=None)
df = df.applymap(str.strip)

# clean data
doubles = df.applymap(lambda x: True if len(x)>1 and x.find('!') == -1 else False)[1]
X       = df.applymap(lambda x: True if x == 'X' else False)[1]

df = df[~doubles & ~X].reset_index(drop=True)

# label fix ש -> ה
df.iloc[3659][1]='ה!'

root = 'letters_data'
file_list = []
label_list = []
for _,y in df.iterrows():
    file_list.append(os.path.join(root,y[0][1:].replace('\\','/')))
    label_list.append(y[1])

l = sorted(list(set(label_list)))
label_to_idx = {}
for i,j in enumerate(l):
    label_to_idx[j] = i
label_list = [label_to_idx[s] for s in label_list]

In [197]:
''' IMAGE STATS '''
import scipy.misc as sm
import numpy as np

imgs = [sm.imread(file,mode='RGB') for file in file_list]

shapes=[o.shape for o in imgs]
print(np.median(shapes, axis=0), np.mean(shapes, axis=0), np.std(shapes, axis=0))
print(np.max(shapes,axis=0),np.min(shapes,axis=0))
print(np.argmax(shapes, axis=0), np.argmin(shapes,axis=0))

[ 192.  167.    3.] [ 208.19724556  168.49743231    3.        ] [ 47.07156859  29.70098797   0.        ]
[346 231   3] [143 111   3]
[947  86   0] [1566  952    0]


In [190]:
ars = np.array([o[1]/o[0] for o in shapes])
import scipy.stats as st
# st.mode(ars)[2]
np.histogram(ars,bins=10)

(array([ 201,  413,  549,  916, 1150,  553,  240,  206,   50,    6]),
 array([ 0.34124629,  0.46726152,  0.59327675,  0.71929198,  0.84530722,
         0.97132245,  1.09733768,  1.22335291,  1.34936814,  1.47538337,
         1.6013986 ]))

In [196]:
ars=sorted(ars)
for j in ars: print(j)

0.341246290801
0.365325077399
0.373134328358
0.374622356495
0.375404530744
0.380952380952
0.382089552239
0.383561643836
0.384146341463
0.386581469649
0.386627906977
0.387096774194
0.387573964497
0.387959866221
0.388535031847
0.389380530973
0.389380530973
0.39222614841
0.39222614841
0.394039735099
0.398119122257
0.398713826367
0.400621118012
0.403846153846
0.404761904762
0.405660377358
0.405693950178
0.405797101449
0.405882352941
0.407624633431
0.407738095238
0.408823529412
0.409395973154
0.41049382716
0.41134751773
0.411764705882
0.41265060241
0.413793103448
0.415625
0.415807560137
0.41592920354
0.416129032258
0.416666666667
0.418238993711
0.418918918919
0.419642857143
0.419672131148
0.419825072886
0.420074349442
0.42071197411
0.421538461538
0.421686746988
0.422442244224
0.422818791946
0.424242424242
0.424354243542
0.424836601307
0.425
0.425219941349
0.425605536332
0.425655976676
0.426035502959
0.427184466019
0.427272727273
0.42750929368
0.427609427609
0.428571428571
0.428571428571
0.4

In [3]:
for m in test_letters['masks']:
    sm.imshow(m*255)
    break

In [20]:
len(test_letters['masks'])

147

In [28]:
''' MAKE IMAGES '''
import scipy.misc as sm
import numpy as np

imgs = [sm.imread(file,mode='RGB') for file in file_list]

# resized_imgs = [sm.imresize(im,(32,32)) for im in imgs]

# resized_imgs = [im[np.newaxis,:,:,:] for im in resized_imgs]
# images = np.vstack(resized_imgs)
# labels = np.array(label_list)

[ 192.  167.    3.] [ 47.07156859  29.70098797   0.        ]


In [29]:
''' IMAGE STATS '''
shapes=[o.shape for o in imgs]
print(np.median(shapes, axis=0), np.mean(shapes, axis=0), np.std(shapes, axis=0))
print(np.max(shapes,axis=0),np.min(shapes,axis=0))
print(np.argmax(shapes, axis=0), np.argmin(shapes,axis=0))

[ 192.  167.    3.] [ 208.19724556  168.49743231    3.        ] [ 47.07156859  29.70098797   0.        ]


In [74]:
''' stitch together 4-model adaptive rotation results '''

root = '/home/lioruzan/pixel-cnn/data/letters_data/checkpoints'
runs= [[] for j in range(4)]
for i in range(4):
    r = os.path.join(root,str(i))
    for j in range(10):
        p = os.path.join(r,'results_{}.pkl'.format(j))
        with open(p,'rb') as f:
            runs[i].append(pkl.load(f))
            

runss=[]
for j in range(10):
    samp=np.zeros((0,32,32,3))
    data=samp.copy()
    mask=data.copy()
    for i in range(4):
        for sample,(x,m) in runs[i][j]:
            sample = np.rot90(sample, k=-i, axes=(1,2))
            x = np.rot90(x, k=-i, axes=(1,2))
            m = np.rot90(m, k=-i, axes=(1,2))
            samp = np.vstack([samp,sample])
            data = np.vstack([data,x])
            mask = np.vstack([mask,m])
    runss.append((samp,data,mask))

''' calculate mean average psnr (+- mean average std)'''
average_psnrs, std_psnrs = [], []
for o, data, _ in runss:
    psnrs=[]
    for i in range(o.shape[0]):
        #change to 0..255
        x = 127.5 * o[i] + 127.5
        y = data[i]
        #mse
        mse = np.sum( np.power(x-y,2) ) / np.prod( x.shape )
        #psnr
        psnr = 20 * ( np.log10(255) - np.log10(np.sqrt(mse)) )
        psnrs.append(psnr)
        
    psnr_avg, psnr_std = np.mean(psnrs), np.std(psnrs)
    average_psnrs.append(psnr_avg)
    std_psnrs.append(psnr_std)
print('{:.5} +-{:.5}'.format(np.mean(average_psnrs), np.mean(std_psnrs)))

''' visualize results '''
p=np.random.randint(140)
plt.imshow(runss[0][1][p]/127.5-1)
plt.show()
plt.imshow((runss[0][1][p]/127.5-1)*runss[0][2][p])
plt.show()
plt.imshow(runss[0][0][p])
plt.show()

In [166]:
psnrs

[8.0094482437804793,
 11.576742746439566,
 14.397416623109148,
 16.050003389271453,
 13.547408969241141,
 23.056076520934205,
 12.633913553358219,
 10.246024123511299,
 15.334448796610154,
 11.782192016462133,
 17.510175958815275,
 14.843173007889474,
 16.122954277595923,
 9.2200377348149019,
 11.28626104247334,
 13.515986057497367,
 12.148469563477926,
 9.8219308970910735,
 8.7616479031172201,
 9.9355380158090778,
 19.154202360878038,
 9.3201170342204023,
 16.84945659447223,
 8.0179964062435083,
 45.277232672125834,
 16.739734812845398,
 17.787932509869016,
 14.943535805970427,
 17.171969258429161,
 16.404846435637506,
 9.6479973378444903,
 18.378914848942021,
 15.494049563983769,
 10.581258431914232,
 10.126872069824966,
 16.304898685291889,
 10.819779646181988,
 14.685991622401652,
 15.172440514593362,
 12.363724994144203,
 18.359321153140389,
 20.18213067097788,
 14.180420893095897,
 8.9782805376706776,
 16.63958832636947,
 12.905960815126329,
 13.885445578248623,
 9.83530878940037