In [1]:
%%capture
cd ..

In [2]:
import os
import numpy as np
import pandas as pd
import pickle
from torchvision.transforms import Resize
import torch
from utils.make_dfu import make_dfu
from mccullen.vae import ConvVAE

vae = ConvVAE(init_channels = 16, # initial number of filters
              final_channels = 64,
              latent_dim = 64, # latent dimension for sampling
              hidden_dim = 128)

# total number of scalar weights over all parameter tensors
print('# parameters:',sum(p.numel() for p in vae.parameters()))

# The cells shown below only run the encoder, so switch to inference mode up
# front: any dropout / batch-norm layers (if ConvVAE has them -- not visible
# from here) then behave deterministically.
vae.eval()

# map_location='cpu' lets the checkpoint load on a machine without a GPU.
# Kept as the last expression so the key-matching report is displayed.
vae.load_state_dict(torch.load('models/vae_overfit.pt',map_location='cpu'))

# parameters: 431569


  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>

In [3]:
# Reuse the cached dataframe when it is on disk; otherwise build it with
# make_dfu() and write the cache for the next run.
dfu_cache_path = 'data/dfu_v2.pkl'
if not os.path.exists(dfu_cache_path):
    dfu = make_dfu()
    dfu.to_pickle(dfu_cache_path)
else:
    dfu = pd.read_pickle(dfu_cache_path)

In [4]:
# load (or create) the processed image data for the vae
scripts = dfu.script.to_list()

if os.path.exists('data/PS_v2_32.pkl'):
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # a cache file that was written by this notebook.
    with open('data/PS_v2_32.pkl','rb') as f:
        PS = pickle.load(f)
else:
    rs = Resize((32,32))
    # Build each tensor from a single array instead of torch.Tensor([array]):
    # converting a Python list of ndarrays is slow and can raise a PyTorch
    # warning. [None] adds the leading axis that the list wrapper provided.
    # Assumes each picture is array-like (it is used arithmetically elsewhere
    # in this notebook) -- TODO confirm.
    PS = [(rs(torch.tensor(np.asarray(pic),dtype=torch.float32)[None]),rep)
          for pic,rep in zip(dfu['picture'],dfu['rep'])]
    with open('data/PS_v2_32.pkl','wb') as f:
        pickle.dump(PS,f)


# invert every image (1 - pixel); the cache file holds the un-inverted version
PS = [(1-p[0],p[1]) for p in PS]
ls = [p[1] for p in PS] # letters

In [5]:
%%time
# create vectors for all characters
bs=256
dl = torch.utils.data.DataLoader(dataset=PS, 
                                 batch_size=bs, 
                                 shuffle=False)

hs=[]
for b,q in dl:
    h,_,_ = vae.encoder(b)
    hs.append(h)
    
hs = torch.cat(hs).detach().numpy() # "hidden" states- vector rep for each character
# normalize vectors to length 1 for easy computation of cosine
vs = (hs.T/np.sqrt(np.sum(hs**2,axis=1))).T


CPU times: user 20.9 s, sys: 4.19 s, total: 25.1 s
Wall time: 7.99 s


In [7]:
def top_matches(im,script,n=10):
    """
    Returns the top matches for an image, restricted to a given script.

    The image is resized to 32x32 and passed through the VAE encoder; its
    unit-normalized encoding is compared by cosine with the precomputed unit
    vectors of every character whose script equals `script`.
    Uses the notebook globals `vae`, `scripts`, `vs` and `ls`.

    Remember to invert the image so that the background is 0.

    im: picture as an array (presumably 2-D, height x width -- TODO confirm)
    script: script name to restrict the search to, e.g. 'LATIN'
    n: maximum number of matches to return (default 10); None returns all

    returns: list of character-cos pairs, best match first; an empty list
    when no character belongs to `script` or when n <= 0
    """
    # positions in the global tables (scripts / vs / ls) belonging to the script
    idx = [k for k,s in enumerate(scripts) if s==script]
    if not idx or (n is not None and n<=0):
        # an unknown script used to fail inside np.inner with a shape error
        return []

    v_script = np.asarray(vs)[idx]
    ls_script = np.array([ls[k] for k in idx])

    rs = Resize((32,32))
    # [None,None] adds the batch and channel axes; the tensor is built from a
    # single array rather than a nested list of arrays, which converts slowly
    im = rs(torch.tensor(np.asarray(im),dtype=torch.float32)[None,None])
    with torch.no_grad():
        h,_,_ = vae.encoder(im)

    # unit-normalize the query so the inner product below is a cosine
    v = h.numpy()[0]
    v = v/np.sqrt(np.sum(v**2))

    coss = np.inner(v,v_script)

    # sort by descending cosine and keep only the requested number of matches
    best = np.argsort(-coss)[:n]
    return [(ls_script[k],coss[k]) for k in best]

# Example query: take one picture from the dataframe, invert it (1 - picture)
# as top_matches asks for, and search among the LATIN characters.
sample_idx = 808
im = 1 - dfu.picture.iloc[sample_idx]
top_matches(im, script='LATIN')

[('Ŧ', 0.8736037),
 ('Ĭ', 0.87193763),
 ('T', 0.86511827),
 ('Ī', 0.85592043),
 ('Ǐ', 0.84747434),
 ('Ĩ', 0.8457887),
 ('Ｔ', 0.8435838),
 ('Ï', 0.83890235),
 ('Ṫ', 0.82552445),
 ('ꞁ', 0.82385993)]