In [8]:
import logging
import os

import torch
import torch.optim as optim
import yaml

import dataset as DS
import finetune as FT

# Configure the root logger at WARNING; the project loggers FT.LOGGER and
# DS.LOGGER are raised to INFO individually in a later cell.
logging.basicConfig(level=logging.WARNING)


In [9]:


def _read_yaml(path):
    """Open ``path`` as UTF-8 text and return the result of ``yaml.safe_load``."""
    with open(path, "r", encoding="utf8") as stream:
        return yaml.safe_load(stream)


# The general config stores the paths of the data, model and training configs
# under the keys read below; each of those files is parsed in turn.
general_config = _read_yaml("./config_files/fintuning_config.yaml")
data_config = _read_yaml(general_config["data_config"])
model_config = _read_yaml(general_config["model_config"])
training_config = _read_yaml(general_config["training_config"])

In [10]:
# Raise the finetune and dataset module loggers to INFO (the root logger was
# configured at WARNING in the import cell).
for project_module in (FT, DS):
    project_module.LOGGER.setLevel(logging.INFO)

In [11]:
# Language list handed to the project's finetune helper, which returns the
# recognizer model together with its training converter.
recognizer_languages = ["ch_tra"]
model, converter = FT.get_easyocr_recognizer_and_training_converter(recognizer_languages)


In [12]:
# Switches selecting which sub-modules of the wrapped recognizer are frozen.
freeze_FeatureFxtraction = True
freeze_SequenceModeling = False

# For every enabled switch, turn off gradients for all parameters of the
# matching sub-module of model.module. The sub-module is only looked up when
# its switch is on.
for should_freeze, stage_name in (
    (freeze_FeatureFxtraction, "FeatureExtraction"),
    (freeze_SequenceModeling, "SequenceModeling"),
):
    if should_freeze:
        getattr(model.module, stage_name).requires_grad_(False)

In [13]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# CTC loss with zero_infinity=True, moved onto the selected device.
criterion = torch.nn.CTCLoss(zero_infinity=True)
criterion = criterion.to(DEVICE)

In [14]:
# Adadelta hyper-parameters.
lr, rho, eps = 1., 0.95, 1e-8

# Only parameters that still require gradients are handed to the optimizer,
# so anything frozen beforehand is left out.
filtered_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adadelta(filtered_parameters, lr=lr, rho=rho, eps=eps)

In [15]:
# Character set passed to the dataset loader: every converter entry except the
# first, joined into one string.
# NOTE(review): index 0 is presumably the CTC blank token — confirm against the converter.
character = ''.join(converter.character[1:])

# Training data comes from a single root directory.
training_set_root = "./all_data/en_train"
train_loader = DS.load_dataset(training_set_root, character=character)

# Validation roots are kept in a list and unpacked as positional arguments.
validation_set_roots = ["./all_data/en_val"]
val_loader = DS.load_dataset(*validation_set_roots, character=character)

INFO:dataset:dataset: ./all_data/en_train
    filename                                words
64    44.jpg  (895261) Greenery {Wemyss-Islamist}
402  454.jpg  Tuktamysheva (resin) Technologies !
427  490.jpg  Fourteenth . Naiads injurious_Issue
498  571.jpg  Equalization LIGURIA carbohydrate [
781  833.jpg  Buys-Horwood misinterpreting Twitch
INFO:dataset:dataset: ./all_data/en_val
    filename                                words
64    44.jpg  (895261) Greenery {Wemyss-Islamist}
402  454.jpg  Tuktamysheva (resin) Technologies !
427  490.jpg  Fourteenth . Naiads injurious_Issue
498  571.jpg  Equalization LIGURIA carbohydrate [
781  833.jpg  Buys-Horwood misinterpreting Twitch


In [16]:
# Number of fine-tuning epochs and the directory receiving one checkpoint per epoch.
NUM_EPOCHS = 10
CHECKPOINT_DIR = './saved_models/OvO'

# torch.save does not create missing parent directories. Create the checkpoint
# directory up front so a fresh checkout does not train a whole epoch and then
# crash at the first save.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # One pass over the training loader; the returned value is kept in the
    # notebook namespace for inspection after the run.
    result = FT.finetune_epoch(model, criterion, converter, optimizer, training_set_loader=train_loader)
    # Evaluate on the validation loader after every epoch.
    val_result = FT.validation(model, criterion, converter, val_loader)
    # Checkpoints are numbered from 1: epoch_1.pth, epoch_2.pth, ...
    torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/epoch_{epoch+1}.pth')



training phase:  36%|███▌      | 10/28 [00:06<00:09,  1.97it/s]