In [None]:
from pathlib import Path

from torch import cuda, device as Device
from torchvision.transforms import Compose

from htr_crnn_ctc.datasource import ParquetDataSource, InMemoryDataSource
from htr_crnn_ctc.dataset import CTCDataset
from htr_crnn_ctc.transforms import Deslant, Rescale, ToRGB, ToTensor, Normalise
from htr_crnn_ctc.dataloader import CTCDataLoader
from htr_crnn_ctc.model import CTCModel
from htr_crnn_ctc.learn import Learner
from htr_crnn_ctc.utils import get_decode_map

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
dev = Device("cuda") if cuda.is_available() else Device("cpu")
print("dev:", dev)

In [None]:
# Directory holding the IAM-line data. It is built with pathlib so the notebook
# also runs on POSIX systems. A "tmp\\dataset\\..." string literal only resolves
# as a nested path on Windows. On Windows, str() of this path is unchanged.
DATA_DIR = Path("tmp") / "dataset" / "IAM-line" / "data"

# Raw parquet source. Only needed to (re)build the in-memory cache:
# uncomment these lines together with the from_datasource/dump calls below.
# pds = ParquetDataSource(
#     file=str(DATA_DIR / "train.parquet"),
#     map_columns=None
# )

# Load samples from the binary cache rather than re-reading the parquet file
# (the ".deslanted" filename suggests deslanting was already applied).
imds = InMemoryDataSource(file_path=str(DATA_DIR / "train.parquet.deslanted.bin"))
# imds.from_datasource(pds)
# imds.dump()
imds.load()

In [None]:
# Deslanting is kept out of `trans`. It is prepended explicitly where needed
# (e.g. `pre_trans + trans` when predicting on a raw image from disk).
pre_trans = [Deslant()]

# Per-sample pipeline:
# - Rescale to output_size (64, 800), presumably (height, width), with
#   randomized augmentation as enabled by the random_* arguments.
# - Convert to RGB and to a tensor.
# - Normalise with the widely used ImageNet per-channel mean/std, presumably
#   to match the pretrained ResNet backbone loaded later.
trans = [
    Rescale(output_size=(64, 800), random_pad=True, border_pad=(10, 40),
            random_rotation=2, random_stretch=1.2),
    ToRGB(),
    ToTensor(rgb=True),
    Normalise(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

In [None]:
# Dataset over the cached in-memory samples.
# - Alternative source: swap `imds` for `pds` to read the parquet file directly.
# - Alternative transform: swap `trans` for `pre_trans` when rebuilding the
#   deslanted cache.
# char_dict=None: the dataset presumably derives the character set from the data
# (it is read back as `ds.char_dict` further down).
ds = CTCDataset(data_source=imds, char_dict=None, transform=Compose(trans))

In [None]:
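# One-off step for rebuilding the deslanted cache. It presumably expects:
# - `ds` built with Compose(pre_trans);
# - `imds` filled via from_datasource(pds).
# It writes each processed sample back into `imds` and dumps it to its file_path.
# for i in range(len(ds)):
#     imds[i] = ds[i]

# imds.dump()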

In [None]:
# Train/validation loader over `ds`:
# - validation_split=0.2, presumably the fraction held out for validation;
# - shuffling uses a fixed seed so the split is reproducible;
# - device=dev, presumably where batches are moved.
# NOTE(review): the validation split comes from `ds`, whose transform includes
# the randomized Rescale augmentation. Validation metrics may therefore be noisy.
dl = CTCDataLoader(
    dataset=ds,
    train_batch_size=120, validation_batch_size=240,
    validation_split=0.2, shuffle=True, seed=42,
    device=dev,
)

In [None]:
# One output per character in the dataset's vocabulary, plus one extra index
# (presumably the CTC blank label).
num_classes = len(ds.char_dict) + 1

# chan_in=3 matches the RGB tensors produced by ToRGB()/ToTensor(rgb=True).
model = CTCModel(chan_in=3, time_step=96, feature_size=512, hidden_size=512,
                 output_size=num_classes, num_rnn_layers=4, rnn_dropout=0)
model = model.to(dev)

# Load pretrained ResNet weights into the model (per the method name).
model.load_pretrained_resnet()

In [None]:
# The decode map is derived from the same char_dict the dataset uses,
# presumably so predicted class indices can be turned back into characters.
decode_map = get_decode_map(ds.char_dict)
learn = Learner(model=model, dataloader=dl, decode_map=decode_map)

In [None]:
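# Optional learning-rate sweep from start_lr to end_lr, presumably used to
# pick max_lr/base_lr for the fit_one_cycle calls below. Uncomment to run.
# learn.freeze()
# log, lr = learn.find_lr(start_lr=1e-5, end_lr=1e1, wd=0.1)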

In [None]:
# Two-stage training with identical one-cycle settings:
# 1. with learn.freeze() applied (presumably only non-pretrained layers train);
# 2. again after learn.unfreeze().
one_cycle_kwargs = dict(epochs=5, max_lr=1e-3, base_lr=1e-4, wd=0.1)

learn.freeze()
learn.fit_one_cycle(**one_cycle_kwargs)

learn.unfreeze()
learn.fit_one_cycle(**one_cycle_kwargs)

In [None]:
# Predict a single image.
# - Use `dev` (CUDA when available, otherwise CPU), the same device the model
#   was moved to. Hard-coding "cuda" breaks on CPU-only machines.
# - Image from disk: prepend `pre_trans` (deslanting), since the cached
#   training samples appear to be deslanted already.
# - Cached sample (commented variant below): presumably already deslanted,
#   so only `trans` is applied.
# NOTE(review): `trans` contains the randomized Rescale augmentation
# (random_pad/random_rotation/random_stretch), so repeated predictions may differ.
# Consider a deterministic inference pipeline.
# learn.predict(trans, img_ndarray=imds[0].image, dev=dev)
learn.predict(Compose(pre_trans + trans), img_path="test.png", dev=dev)

In [None]:
# Display predictions for up to 10 samples together with their images.
learn.batch_predict(up_to=10, show_img=True)