In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torchvision

import numpy as np
import pandas as pd

import math

import cv2

In [2]:
H_LO = 16
H_HI = 640
W_LO = 16
W_HI = 640

class ScaleToImageSize:
  def __init__(self, patch_size: int = 16, w_lo: int = W_LO, w_hi: int = W_HI, h_lo: int = H_LO, h_hi: int = H_HI, ) -> None:
    assert w_lo <= w_hi and h_lo <= h_hi
    self.w_lo = w_lo
    self.w_hi = w_hi
    self.h_lo = h_lo
    self.h_hi = h_hi
    self.patch_size = patch_size
  
  def __call__(self, img: np.ndarray) -> np.ndarray:
    h, w = img.shape[:2]
    r = h / w
    lo_r = self.h_lo / self.w_hi
    hi_r = self.h_hi / self.w_lo
    assert lo_r <= h / w <= hi_r, f"img ratio h:w {r} not in range [{lo_r}, {hi_r}]"

    scale_r = min(self.h_hi / h, self.w_hi / w)
    if scale_r < 1.0:
      # one of h or w highr that hi, so scale down
      img = cv2.resize(img, None, fx=scale_r, fy=scale_r, interpolation=cv2.INTER_CUBIC)

    scale_r = max(self.h_lo / h, self.w_lo / w)
    if scale_r > 1.0:
      # one of h or w lower that lo, so scale up
      img = cv2.resize(img, None, fx=scale_r, fy=scale_r, interpolation=cv2.INTER_CUBIC)
    
    h, w = img.shape[:2]
    
    # in the rectangle, do not scale
    assert self.h_lo <= h <= self.h_hi and self.w_lo <= w <= self.w_hi
    
    new_h, new_w = img.shape[:2]
    
    new_h = self.patch_size * math.ceil(new_h / self.patch_size)
    new_w = self.patch_size * math.ceil(new_w / self.patch_size)
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    return img


In [3]:
from torchvision.transforms import transforms, Compose

file = 'train_30738'

test_im = cv2.imread('../train/' + file + '.jpg', cv2.IMREAD_GRAYSCALE)

transform = Compose([ScaleToImageSize(patch_size=16), transforms.ToTensor()])

In [4]:
# 1 image batch test

imgs = [transform(test_im)]

In [5]:
import math

h_x = [s.size(1) for s in imgs]
w_x = [s.size(2) for s in imgs]

max_height_x = max(h_x)
max_width_x = max(w_x)

x = torch.zeros(len(h_x), 1, max_height_x, max_width_x)
x_mask = torch.ones(len(h_x), math.floor(max_height_x/16)*16, math.floor(max_width_x/16)*16, dtype=torch.float)

for idx, s_x in enumerate(imgs):
    x[idx, :, :h_x[idx], :w_x[idx]] = s_x
    x_mask[idx, :h_x[idx], :w_x[idx]] = 0

In [6]:
label_dict = { line.strip().split('\t')[0]:line.strip().split('\t')[1]for line in open ('train_ssml_sd.txt').readlines()}

In [7]:
from bivi_vocab import vocab, vocab_full

vocab.print()

{'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 108, '!': 4, '"': 5, '(': 6, ')': 7, '*': 8, '+': 9, ',': 10, '-': 11, '-:[:0]': 12, '-:[:135]': 13, '-:[:15]': 14, '-:[:30]': 15, '-:[:345]': 16, '-:[:45]': 17, '-:[:60]': 18, '-:[:90]': 19, '-[:0]': 20, '-[:105]': 21, '-[:120]': 22, '-[:135]': 23, '-[:150]': 24, '-[:15]': 25, '-[:165]': 26, '-[:180]': 27, '-[:195]': 28, '-[:210]': 29, '-[:225]': 30, '-[:240]': 31, '-[:255]': 32, '-[:270]': 33, '-[:285]': 34, '-[:300]': 35, '-[:30]': 36, '-[:315]': 37, '-[:330]': 38, '-[:345]': 39, '-[:45]': 40, '-[:60]': 41, '-[:75]': 42, '-[:90]': 43, '.': 44, '/': 45, '//': 46, '0': 47, '1': 48, '2': 49, '3': 50, '4': 51, '5': 52, '6': 53, '7': 54, '8': 55, '9': 56, ':': 57, ';': 58, '<': 59, '<:[:0]': 60, '<:[:105]': 61, '<:[:120]': 62, '<:[:135]': 63, '<:[:150]': 64, '<:[:15]': 65, '<:[:165]': 66, '<:[:180]': 67, '<:[:195]': 68, '<:[:210]': 69, '<:[:225]': 70, '<:[:240]': 71, '<:[:255]': 72, '<:[:270]': 73, '<:[:285]': 74, '<:[:300]': 75, '<:[:30]': 7

In [8]:
from bivi_utils import to_bi_tgt_out

seqs_y = [vocab.words2indices(label_dict[file + '.json'].split())]

tgt, out = to_bi_tgt_out(seqs_y, "cpu")

In [9]:
print(x.shape)
print(x_mask.shape)
print(tgt)

torch.Size([1, 1, 320, 320])
torch.Size([1, 320, 320])
tensor([[  1, 260, 331,  20, 223, 305,  20, 306, 136, 235, 333],
        [  2, 333, 235, 136, 306,  20, 305, 223,  20, 331, 260]])


In [10]:
print(label_dict[file + '.json'])

\chemfig { -[:0] C branch( -[:0] branch) =[:90] O }


In [11]:
print(seqs_y)

[[260, 331, 20, 223, 305, 20, 306, 136, 235, 333]]


In [12]:
from bivi_model import CrocsBiVision

model = CrocsBiVision.load_from_checkpoint('../bivision-logs-new2/epoch=49-step=685450-val_loss=0.6076-val_wer=0.2679-359.ckpt')
   
# model = Crocs.load_from_checkpoint('dense-logs-new/epoch=54-step=375595-val_loss=0.2932-val_wer=0.3506-359.ckpt')

In [13]:
# Test on batch of 1

print(x.shape)
print(x_mask.shape)

output = model(x, x_mask, out)

torch.Size([1, 1, 320, 320])
torch.Size([1, 320, 320])


In [14]:
print(output.argmax(dim=2))

tensor([[  1,  20, 305, 305,  20, 235,  43, 235, 333,   2, 331],
        [ 50,   2, 306,  20, 305, 223,  20, 235, 300,   1, 260]])


In [15]:
print(out)

tensor([[260, 331,  20, 223, 305,  20, 306, 136, 235, 333,   2],
        [333, 235, 136, 306,  20, 305, 223,  20, 331, 260,   1]])


In [16]:
torch.eq(out, output.argmax(dim=2))

tensor([[False, False, False, False, False, False, False, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False, False,
         False]])

In [17]:
from bivi_utils import we_rate

we_rate(output.argmax(dim=2), seqs_y, 'cpu')

tensor(1.)