In [None]:
# NOTE(review): later cells use names that are never imported explicitly (np, plt,
# figsize, subplot, imshow); presumably they are injected by this magic or by the
# star import of ocrhelpers in the next cell -- confirm before removing either.
%pylab inline

In [None]:
from importlib import reload
import os, sys, re, glob, time, pickle, IPython, logging
import scipy.ndimage as ndi
from itertools import islice
import torch
from torch import nn, optim
from torch.nn import functional as F
#from torchmore import layers, flex
#import torchtrainers as tt
from torch.utils.data import DataLoader
from webdataset import WebDataset, WebLoader
from ocropus4train import ocrhelpers as helpers
from ocropus4train.ocrhelpers import *
from ocropus4train import ocrmodels2 as models
import scipy
import scipy.ndimage
import ocrodeg
import imageio.v2 as imageio
import braceexpand

# Run a few shell diagnostics (timestamp, host, user, GPU list) at notebook start.
for command in ("date", "hostname", "whoami", "nvidia-smi -L"):
    RUN(command)

# Set both verbosity variables to "0" unconditionally.
os.environ.update({"GOPEN_VERBOSE": "0", "WDS_VERBOSE_CACHE": "0"})
# Default to GPU index 1 only when the caller has not already chosen devices.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

In [None]:
# Training shards 000001..000653 from the "words" and the "lines" collections,
# words first, then lines.
training_urls = [
    url
    for pattern in (
        "gs://ocro-iaa/words/books-{000001..000653}-words.tar",
        "gs://ocro-iaa/lines/books-{000001..000653}-lines.tar",
    )
    for url in braceexpand.braceexpand(pattern)
]
# Shard 000000 of the "words" collection is not part of the training range above.
testing_urls = "gs://ocro-iaa/words/books-000000-words.tar"

batch_size = 4

#training_urls = "data/words-simple-training.tar"
#testing_urls = "data/words-simple-testing.tar"

In [None]:
# Printable ASCII: code points 32 (space) through 126 (tilde).
chars = list(map(chr, range(32, 127)))
charset = DefaultCharset(chars)

In [None]:
# augmentations:
# - autocrop
# - shift, scale
# - threshold
# - noise, offset, contrast

from ocropus4train.ocraugment import maybe, aniso, distort, normalize, height_normalize, autoinvert, make_noise, threshold, noisify

def random_padding(a, target):
    """Prepend a random number of zero rows to the top of a 2D image.

    Args:
        a: 2D numpy array (rows x columns).
        target: nominal target height in rows.

    Returns:
        A new array with ``d`` extra zero rows above the original content:
        * if the image is more than one row shorter than ``target``, ``d`` is
          drawn from ``0 .. int(target - height) - 1``;
        * otherwise ``d`` is drawn from ``0 .. int(height * 0.1) - 1``
          (just under 10% of the current height).
        When that upper bound would be negative, no padding is added.
    """
    height = a.shape[0]
    if height < target - 1:
        # for smaller images, pad towards (but not beyond) the target height
        max_pad = int(target - height) - 1
    else:
        # for larger images, add up to 10% random padding to the top
        max_pad = int(height * 0.1) - 1
    # NOTE(review): assumes `random` is the stdlib module, whose randint is
    # inclusive on both ends and raises ValueError on an empty range. For images
    # under 10 rows in the second branch the bound is -1, so clamp it to 0.
    d = random.randint(0, max(max_pad, 0))
    return np.pad(a, ((d, 0), (0, 0)))

def preprocess(a, target=48.0):
    """Apply a random chain of augmentations to a 2D image.

    Steps, in order:
      1. ``normalize`` and ``autoinvert`` (always),
      2. random top padding relative to ``target`` (only when ``target`` is set),
      3. ``noisify``, ``distort`` and ``aniso``, each gated by ``maybe(0.5)``,
      4. Gaussian blur with sigma ``10**U(-0.3, 0.3)`` followed by ``normalize``,
         gated by ``maybe(0.1)``,
      5. ``threshold``, gated by ``maybe(0.1)``,
      6. ``height_normalize`` to a random 80-100% of ``target``
         (only when ``target`` is set).

    Args:
        a: 2D numpy array.
        target: nominal output height, or None to skip the height-dependent
            steps (padding and height normalization).

    Returns:
        The augmented 2D array.

    Raises:
        AssertionError: if the input is not a 2D numpy array, or the result is
            no longer 2D.
    """
    assert isinstance(a, np.ndarray)
    assert a.ndim == 2
    a = normalize(a)
    a = autoinvert(a)
    if target is not None:
        # Padding is sized relative to the target height; without this guard,
        # target=None crashed in random_padding even though None is handled below.
        a = random_padding(a, target)
    if maybe(0.5):
        a = noisify(a)
    if maybe(0.5):
        a = distort(a)
    if maybe(0.5):
        a = aniso(a)
    if maybe(0.1):
        # blur width is log-uniform between 10**-0.3 and 10**0.3
        sigma = 10**random.uniform(-0.3, 0.3)
        a = scipy.ndimage.gaussian_filter(a, sigma)
        a = normalize(a)
    if maybe(0.1):
        a = threshold(a)
    if target is not None:
        # shrink the nominal height by a random 0-20%
        a = height_normalize(a, target*(1.0 + random.uniform(-0.2, 0.0)))
    assert a.ndim == 2
    return a


# Visual check: a 6x6 grid of independently augmented copies of one sample image.
figsize(24, 12)
testimg = imageio.imread("samples/word.jpg")
for panel in range(1, 37):
    subplot(6, 6, panel)
    imshow(preprocess(testimg))
plt.show()

In [None]:
def good(sample):
    """Pass through an (image, text) pair only if the image size is plausible.

    The last axis of the image must lie in 10..2500 and the second-to-last
    axis in 10..150 (both inclusive). Returns the pair unchanged when it
    qualifies, otherwise None.
    """
    img, txt = sample
    dims = img.shape
    acceptable = (
        dims[-1] >= 10
        and dims[-2] >= 10
        and dims[-1] <= 2500
        and dims[-2] <= 150
    )
    if acceptable:
        return img, txt
    return None

def usm_image(img):
    """Subtract a Gaussian-blurred copy (sigma 16, edge-replicated) from the image."""
    background = ndi.gaussian_filter(img, sigma=16.0, mode="nearest")
    return img - background

def img_tensor(img):
    """Turn a 2D float32 image into a torch tensor with a new leading axis of size 1.

    Raises AssertionError if the image is not 2D, is not float32, or has a
    maximum value of 10.0 or more.
    """
    assert img.ndim == 2, img.shape
    assert img.dtype == np.float32, img.dtype
    # a peak of 10 or more means the image was never normalized upstream
    assert np.amax(img) < 10.0
    tensor = torch.tensor(img)
    return torch.unsqueeze(tensor, 0)

def str_tensor(s):
    """Encode a string with the module-level ``charset`` and return it as an int64 tensor."""
    assert isinstance(s, str)
    codes = charset.encode(s)
    return torch.tensor(codes).to(torch.int64)

def pipeline(ds):
    """Attach decoding, filtering, augmentation and tensor conversion to a dataset.

    Stages, in order: decode with the "l8" image spec, extract
    (image, text) tuples, filter with ``good``, then run ``preprocess`` and
    ``usm_image`` on the first tuple element, and finally convert the pair with
    ``img_tensor`` / ``str_tensor``.
    """
    samples = ds.decode("l8").to_tuple("jpg;jpeg;ppm;png txt")
    samples = samples.map(good)
    samples = samples.map_tuple(preprocess)
    samples = samples.map_tuple(usm_image)
    return samples.map_tuple(img_tensor, str_tensor)

# Training stream: resampled shards with a 20000-sample shuffle buffer.
training = pipeline(
    WebDataset(training_urls, resampled=True).shuffle(20000)
)
# Test stream: single pass over the test shard, no shuffling.
testing = pipeline(WebDataset(testing_urls))

# with_epoch sets the nominal training epoch length to 100000 // batch_size.
training_dl = WebLoader(
    training,
    batch_size=batch_size,
    collate_fn=helpers.collate4trans,
    num_workers=8,
).with_epoch(100000 // batch_size)
testing_dl = WebLoader(
    testing,
    batch_size=batch_size,
    collate_fn=helpers.collate4trans,
    num_workers=4,
)

# Smoke test: pull one batch and check its height does not exceed 48.
images, sequences = next(iter(training_dl))
assert images.shape[-2] <= 48, images.shape
print(images.shape)

In [None]:
from ocropus4train import ocrmodels2
from importlib import reload
# Re-import the model module so edits to it are picked up without a kernel restart.
reload(ocrmodels2)
# Name of the model architecture passed to `make`; the training cell reuses `mname`.
mname = "tf_v1"
# One output per charset entry.
model = ocrmodels2.make(mname, noutput=len(charset))
# ensure it can be jitted (the scripted module itself is discarded)
torch.jit.script(model);
# Bare expression: displays the model repr as the cell output.
model

In [None]:
import time
# Log of (trial, timestamp) pairs; the initial (-1, now) entry marks the start.
errors = [(-1, time.time())]


# Each trial builds a fresh model, reloads the best checkpoint via the trainer,
# and trains; a NanError is recorded and followed by a short pause before the
# next trial.
for trial in range(10):
    model = models.make(mname, noutput=len(charset))
    trainer = helpers.LineTrainer(model, charset=charset, lr=1e-5, mode="tf")
    trainer.save_jit = False  # FIXME
    trainer.load_best()
    try:
        trainer.train(
            training_dl,
            10,
            every=15,
            learning_rates=3 * [1e-4] + 200 * [1e-5],
        )
    except helpers.NanError:
        errors.append((trial, time.time()))
        print("NaN Error during training, restarting and reloading from last checkpoint")
        time.sleep(10)
